diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index da3aa30141..804b3b06f0 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -46,6 +46,9 @@ class BundleWorkflow(ABC): or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. default to `None` for common workflow. + meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order. + logging_file: config file for `logging` module in the program. for more details: + https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig. """ @@ -59,11 +62,40 @@ class BundleWorkflow(ABC): new_name="workflow_type", msg_suffix="please use `workflow_type` instead.", ) - def __init__(self, workflow_type: str | None = None, workflow: str | None = None): + def __init__( + self, + workflow_type: str | None = None, + workflow: str | None = None, + meta_file: str | Sequence[str] | None = None, + logging_file: str | None = None, + ): + if logging_file is not None: + if not os.path.isfile(logging_file): + raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.") + logger.info(f"Setting logging properties based on config: {logging_file}.") + fileConfig(logging_file, disable_existing_loggers=False) + + if meta_file is not None: + if isinstance(meta_file, str) and not os.path.isfile(meta_file): + logger.error( + f"Cannot find the metadata config file: {meta_file}. " + "Please see: https://docs.monai.io/en/stable/mb_specification.html" + ) + meta_file = None + if isinstance(meta_file, list): + for f in meta_file: + if not os.path.isfile(f): + logger.error( + f"Cannot find the metadata config file: {f}. " + "Please see: https://docs.monai.io/en/stable/mb_specification.html" + ) + meta_file = None + workflow_type = workflow if workflow is not None else workflow_type if workflow_type is None: self.properties = copy(MetaProperties) self.workflow_type = None + self.meta_file = meta_file return if workflow_type.lower() in self.supported_train_type: self.properties = {**TrainProperties, **MetaProperties} @@ -74,6 +106,8 @@ def __init__(self, workflow_type: str | None = None, workflow: str | None = None else: raise ValueError(f"Unsupported workflow type: '{workflow_type}'.") + self.meta_file = meta_file + @abstractmethod def initialize(self, *args: Any, **kwargs: Any) -> Any: """ @@ -142,6 +176,13 @@ def get_workflow_type(self): """ return self.workflow_type + def get_meta_file(self): + """ + Get the meta file. + + """ + return self.meta_file + def add_property(self, name: str, required: str, desc: str | None = None) -> None: """ Besides the default predefined properties, some 3rd party applications may need the bundle @@ -233,25 +274,26 @@ def __init__( **override: Any, ) -> None: workflow_type = workflow if workflow is not None else workflow_type - super().__init__(workflow_type=workflow_type) if config_file is not None: _config_files = ensure_tuple(config_file) - self.config_root_path = Path(_config_files[0]).parent + config_root_path = Path(_config_files[0]).parent for _config_file in _config_files: _config_file = Path(_config_file) - if _config_file.parent != self.config_root_path: + if _config_file.parent != config_root_path: logger.warn( - f"Not all config files are in {self.config_root_path}. If logging_file and meta_file are" - f"not specified, {self.config_root_path} will be used as the default config root directory." + f"Not all config files are in {config_root_path}. If logging_file and meta_file are" + f"not specified, {config_root_path} will be used as the default config root directory." ) if not _config_file.is_file(): raise FileNotFoundError(f"Cannot find the config file: {_config_file}.") else: - self.config_root_path = Path("configs") - + config_root_path = Path("configs") + meta_file = str(config_root_path / "metadata.json") if meta_file is None else meta_file + super().__init__(workflow_type=workflow_type, meta_file=meta_file) + self.config_root_path = config_root_path logging_file = str(self.config_root_path / "logging.conf") if logging_file is None else logging_file if logging_file is not None: - if not os.path.exists(logging_file): + if not os.path.isfile(logging_file): if logging_file == str(self.config_root_path / "logging.conf"): logger.warn(f"Default logging file in {logging_file} does not exist, skipping logging.") else: @@ -262,14 +304,8 @@ def __init__( self.parser = ConfigParser() self.parser.read_config(f=config_file) - meta_file = str(self.config_root_path / "metadata.json") if meta_file is None else meta_file - if isinstance(meta_file, str) and not os.path.exists(meta_file): - logger.error( - f"Cannot find the metadata config file: {meta_file}. " - "Please see: https://docs.monai.io/en/stable/mb_specification.html" - ) - else: - self.parser.read_meta(f=meta_file) + if self.meta_file is not None: + self.parser.read_meta(f=self.meta_file) # the rest key-values in the _args are to override config content self.parser.update(pairs=override) diff --git a/tests/nonconfig_workflow.py b/tests/nonconfig_workflow.py index 7b5328bf72..b2c44c12c6 100644 --- a/tests/nonconfig_workflow.py +++ b/tests/nonconfig_workflow.py @@ -36,8 +36,8 @@ class NonConfigWorkflow(BundleWorkflow): """ - def __init__(self, filename, output_dir): - super().__init__(workflow_type="inference") + def __init__(self, filename, output_dir, meta_file=None, logging_file=None): + super().__init__(workflow_type="inference", meta_file=meta_file, logging_file=logging_file) self.filename = filename self.output_dir = output_dir self._bundle_root = "will override" diff --git a/tests/test_bundle_workflow.py b/tests/test_bundle_workflow.py index f7da37acef..0b0d51cbfb 100644 --- a/tests/test_bundle_workflow.py +++ b/tests/test_bundle_workflow.py @@ -35,6 +35,8 @@ TEST_CASE_3 = [os.path.join(os.path.dirname(__file__), "testing_data", "config_fl_train.json")] +TEST_CASE_NON_CONFIG_WRONG_LOG = [None, "logging.conf", "Cannot find the logging config file: logging.conf."] + class TestBundleWorkflow(unittest.TestCase): @@ -144,8 +146,14 @@ def test_train_config(self, config_file): def test_non_config(self): # test user defined python style workflow inferer = NonConfigWorkflow(self.filename, self.data_dir) + self.assertEqual(inferer.meta_file, None) self._test_inferer(inferer) + @parameterized.expand([TEST_CASE_NON_CONFIG_WRONG_LOG]) + def test_non_config_wrong_log_cases(self, meta_file, logging_file, expected_error): + with self.assertRaisesRegex(FileNotFoundError, expected_error): + NonConfigWorkflow(self.filename, self.data_dir, meta_file, logging_file) + if __name__ == "__main__": unittest.main()