diff --git a/machine/jobs/build_clearml_helper.py b/machine/jobs/build_clearml_helper.py index e605e498..029bb8bc 100644 --- a/machine/jobs/build_clearml_helper.py +++ b/machine/jobs/build_clearml_helper.py @@ -110,7 +110,7 @@ def local_progress(progress_status: ProgressStatus) -> None: return local_progress -def update_settings(settings: Settings, args: dict): +def update_settings(settings: Settings, args: dict, task: Optional[Task], logger: logging.Logger): settings.update(args) settings.model_type = cast(str, settings.model_type).lower() if "build_options" in settings: @@ -121,4 +121,11 @@ def update_settings(settings: Settings, args: dict): except TypeError as e: raise TypeError(f"Build options could not be parsed: {e}") from e settings.update({settings.model_type: build_options}) + if "align_pretranslations" in build_options: + settings.update({"align_pretranslations": build_options["align_pretranslations"]}) + if task is not None and "tags" in build_options: + tags = build_options["tags"] + if isinstance(tags, str) or (isinstance(tags, list) and all(isinstance(tag, str) for tag in tags)): + task.add_tags(tags) settings.data_dir = os.path.expanduser(cast(str, settings.data_dir)) + logger.info(f"Config: {settings.as_dict()}") diff --git a/machine/jobs/build_nmt_engine.py b/machine/jobs/build_nmt_engine.py index 8ba836cc..75f13b42 100644 --- a/machine/jobs/build_nmt_engine.py +++ b/machine/jobs/build_nmt_engine.py @@ -1,14 +1,12 @@ import argparse -import json import logging -import os -from typing import Callable, Optional, cast +from typing import Callable, Optional from clearml import Task from ..utils.canceled_error import CanceledError from ..utils.progress_status import ProgressStatus -from .build_clearml_helper import report_clearml_progress +from .build_clearml_helper import report_clearml_progress, update_settings from .config import SETTINGS from .nmt_engine_build_job import NmtEngineBuildJob from .nmt_model_factory import NmtModelFactory @@ -47,26 +45,11 @@ def clearml_progress(status: ProgressStatus) -> None: try: logger.info("NMT Engine Build Job started") - - SETTINGS.update(args) - model_type = cast(str, SETTINGS.model_type).lower() - if "build_options" in SETTINGS: - try: - build_options = json.loads(cast(str, SETTINGS.build_options)) - except ValueError as e: - raise ValueError("Build options could not be parsed: Invalid JSON") from e - except TypeError as e: - raise TypeError(f"Build options could not be parsed: {e}") from e - SETTINGS.update({model_type: build_options}) - if "align_pretranslations" in build_options: - SETTINGS.update({"align_pretranslations": build_options["align_pretranslations"]}) - SETTINGS.data_dir = os.path.expanduser(cast(str, SETTINGS.data_dir)) - - logger.info(f"Config: {SETTINGS.as_dict()}") + update_settings(SETTINGS, args, task, logger) translation_file_service = TranslationFileService(SharedFileServiceType.CLEARML, SETTINGS) nmt_model_factory: NmtModelFactory - if model_type == "huggingface": + if SETTINGS.model_type == "huggingface": from .huggingface.hugging_face_nmt_model_factory import HuggingFaceNmtModelFactory nmt_model_factory = HuggingFaceNmtModelFactory(SETTINGS) diff --git a/machine/jobs/build_smt_engine.py b/machine/jobs/build_smt_engine.py index 45cdd81f..8a9883d1 100644 --- a/machine/jobs/build_smt_engine.py +++ b/machine/jobs/build_smt_engine.py @@ -46,9 +46,7 @@ def run(args: dict) -> None: try: logger.info("SMT Engine Build Job started") - update_settings(SETTINGS, args) - - logger.info(f"Config: {SETTINGS.as_dict()}") + update_settings(SETTINGS, args, task, logger) shared_file_service = TranslationFileService(SharedFileServiceType.CLEARML, SETTINGS) smt_model_factory: SmtModelFactory diff --git a/machine/jobs/build_word_alignment_model.py b/machine/jobs/build_word_alignment_model.py index eaeaefaa..1181e665 100644 --- a/machine/jobs/build_word_alignment_model.py +++ b/machine/jobs/build_word_alignment_model.py @@ -47,9 +47,7 @@ def run(args: dict): try: logger.info("Word Alignment Build Job started") - update_settings(SETTINGS, args) - - logger.info(f"Config: {SETTINGS.as_dict()}") + update_settings(SETTINGS, args, task, logger) word_alignment_file_service = WordAlignmentFileService(SharedFileServiceType.CLEARML, SETTINGS) word_alignment_model_factory: WordAlignmentModelFactory