diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 3b349e1103..5979619088 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -12,6 +12,7 @@ from __future__ import annotations import os +import sys import time import warnings from abc import ABC, abstractmethod @@ -170,6 +171,7 @@ class ConfigWorkflow(BundleWorkflow): """ Specification for the config-based bundle workflow. Standardized the `initialize`, `run`, `finalize` behavior in a config-based training, evaluation, or inference. + Before `run`, we add bundle root directory to Python search directories automatically. For more information: https://docs.monai.io/en/latest/mb_specification.html. Args: @@ -224,23 +226,23 @@ def __init__( super().__init__(workflow_type=workflow_type) if config_file is not None: _config_files = ensure_tuple(config_file) - config_root_path = Path(_config_files[0]).parent + self.config_root_path = Path(_config_files[0]).parent for _config_file in _config_files: _config_file = Path(_config_file) - if _config_file.parent != config_root_path: + if _config_file.parent != self.config_root_path: warnings.warn( - f"Not all config files are in {config_root_path}. If logging_file and meta_file are" - f"not specified, {config_root_path} will be used as the default config root directory." + f"Not all config files are in {self.config_root_path}. If logging_file and meta_file are" + f"not specified, {self.config_root_path} will be used as the default config root directory." ) if not _config_file.is_file(): raise FileNotFoundError(f"Cannot find the config file: {_config_file}.") else: - config_root_path = Path("configs") + self.config_root_path = Path("configs") - logging_file = str(config_root_path / "logging.conf") if logging_file is None else logging_file + logging_file = str(self.config_root_path / "logging.conf") if logging_file is None else logging_file if logging_file is not None: if not os.path.exists(logging_file): - if logging_file == str(config_root_path / "logging.conf"): + if logging_file == str(self.config_root_path / "logging.conf"): warnings.warn(f"Default logging file in {logging_file} does not exist, skipping logging.") else: raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.") @@ -250,7 +252,7 @@ def __init__( self.parser = ConfigParser() self.parser.read_config(f=config_file) - meta_file = str(config_root_path / "metadata.json") if meta_file is None else meta_file + meta_file = str(self.config_root_path / "metadata.json") if meta_file is None else meta_file if isinstance(meta_file, str) and not os.path.exists(meta_file): raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.") else: @@ -283,8 +285,13 @@ def initialize(self) -> Any: def run(self) -> Any: """ Run the bundle workflow, it can be a training, evaluation or inference. + Before run, we add bundle root directory to Python search directories automatically. """ + _bundle_root_path = ( + self.config_root_path.parent if self.config_root_path.name == "configs" else self.config_root_path + ) + sys.path.insert(1, str(_bundle_root_path)) if self.run_id not in self.parser: raise ValueError(f"run ID '{self.run_id}' doesn't exist in the config file.") return self._run_expr(id=self.run_id) diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py index 42abc1a5e0..e6f3d6b8c6 100644 --- a/tests/test_integration_bundle_run.py +++ b/tests/test_integration_bundle_run.py @@ -14,6 +14,7 @@ import json import os import shutil +import subprocess import sys import tempfile import unittest @@ -44,6 +45,14 @@ def run(self): return self.val +class _Runnable43: + def __init__(self, func): + self.func = func + + def run(self): + self.func() + + class TestBundleRun(unittest.TestCase): def setUp(self): self.data_dir = tempfile.mkdtemp() @@ -77,6 +86,69 @@ def test_tiny(self): with self.assertRaises(RuntimeError): # test wrong run_id="run" command_line_tests(cmd + ["run", "run", "--config_file", config_file]) + with self.assertRaises(RuntimeError): + # test missing meta file + command_line_tests(cmd + ["run", "training", "--config_file", config_file]) + + def test_scripts_fold(self): + # test scripts directory has been added to Python search directories automatically + config_file = os.path.join(self.data_dir, "tiny_config.json") + meta_file = os.path.join(self.data_dir, "tiny_meta.json") + scripts_dir = os.path.join(self.data_dir, "scripts") + script_file = os.path.join(scripts_dir, "test_scripts_fold.py") + init_file = os.path.join(scripts_dir, "__init__.py") + + with open(config_file, "w") as f: + json.dump( + { + "imports": ["$import scripts"], + "trainer": { + "_target_": "tests.test_integration_bundle_run._Runnable43", + "func": "$scripts.tiny_test", + }, + # keep this test case to cover the "runner_id" arg + "training": "$@trainer.run()", + }, + f, + ) + with open(meta_file, "w") as f: + json.dump( + {"version": "0.1.0", "monai_version": "1.1.0", "pytorch_version": "1.13.1", "numpy_version": "1.22.2"}, + f, + ) + + os.mkdir(scripts_dir) + script_file_lines = ["def tiny_test():\n", " print('successfully added scripts fold!') \n"] + init_file_line = "from .test_scripts_fold import tiny_test\n" + with open(script_file, "w") as f: + f.writelines(script_file_lines) + f.close() + with open(init_file, "w") as f: + f.write(init_file_line) + f.close() + + cmd = ["coverage", "run", "-m", "monai.bundle"] + # test both CLI entry "run" and "run_workflow" + expected_condition = "successfully added scripts fold!" + command_run = cmd + ["run", "training", "--config_file", config_file, "--meta_file", meta_file] + completed_process = subprocess.run(command_run, check=True, capture_output=True, text=True) + output = repr(completed_process.stdout).replace("\\n", "\n").replace("\\t", "\t") # Get the captured output + print(output) + + self.assertTrue(expected_condition in output) + command_run_workflow = cmd + [ + "run_workflow", + "--run_id", + "training", + "--config_file", + config_file, + "--meta_file", + meta_file, + ] + completed_process = subprocess.run(command_run_workflow, check=True, capture_output=True, text=True) + output = repr(completed_process.stdout).replace("\\n", "\n").replace("\\t", "\t") # Get the captured output + print(output) + self.assertTrue(expected_condition in output) with self.assertRaises(RuntimeError): # test missing meta file