diff --git a/src/madengine/tools/discover_models.py b/src/madengine/tools/discover_models.py index 64110471..d6776740 100644 --- a/src/madengine/tools/discover_models.py +++ b/src/madengine/tools/discover_models.py @@ -95,8 +95,8 @@ def discover_models(self) -> None: # Update model name using backslash-separated path model_dict["name"] = dirname + '/' + model_dict["name"] # Update relative path for dockerfile and scripts - model_dict["dockerfile"] = os.path.join("scripts", dirname, model_dict["dockerfile"]) - model_dict["scripts"] = os.path.join("scripts", dirname, model_dict["scripts"]) + model_dict["dockerfile"] = os.path.normpath(os.path.join("scripts", dirname, model_dict["dockerfile"])) + model_dict["scripts"] = os.path.normpath(os.path.join("scripts", dirname, model_dict["scripts"])) self.models.append(model_dict) self.model_list.append(model_dict["name"]) @@ -144,8 +144,9 @@ def select_models(self) -> None: # of the tags are extra args to be passed into the model script. if len(tag_list) > 1: extra_args = [tag_ for tag_ in tag_list[1:]] - extra_args = " ".join(extra_args) - extra_args = " " + extra_args + extra_args = [tag_.strip().replace("=", " ") for tag_ in extra_args] + extra_args = " --".join(extra_args) + extra_args = " --" + extra_args else: extra_args = "" @@ -160,8 +161,8 @@ def select_models(self) -> None: custom_model.update_model() # Update relative path for dockerfile and scripts dirname = custom_model.name.split("/")[0] - custom_model.dockerfile = os.path.join("scripts", dirname, custom_model.dockerfile) - custom_model.scripts = os.path.join("scripts", dirname, custom_model.scripts) + custom_model.dockerfile = os.path.normpath(os.path.join("scripts", dirname, custom_model.dockerfile)) + custom_model.scripts = os.path.normpath(os.path.join("scripts", dirname, custom_model.scripts)) model_dict = custom_model.to_dict() model_dict["args"] = model_dict["args"] + extra_args tag_models.append(model_dict) diff --git a/tests/test_discover.py b/tests/test_discover.py index 0319977b..d0643985 100644 --- a/tests/test_discover.py +++ b/tests/test_discover.py @@ -69,7 +69,7 @@ def test_additional_args(self, global_data, clean_test_temp_files): with open(os.path.join(BASE_DIR, "perf.csv"), "r") as csv_file: csv_reader = csv.DictReader(csv_file) for row in csv_reader: - if row["model"] == "dummy2/model2" and row["status"] == "SUCCESS" and "batch-size=32" in row["args"]: + if row["model"] == "dummy2/model2" and row["status"] == "SUCCESS" and "--batch-size 32" in row["args"]: success = True if not success: pytest.fail("dummy2/model2:batch-size=32 did not run successfully.")