From faa8dc6b31b67a8091931fd7ae7df66f62db97d7 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 8 Aug 2023 14:41:48 +0800 Subject: [PATCH 01/56] add `AutoBundle` and `from_bundle` Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 47 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 322a2efc5b..189ec57eb7 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1501,3 +1501,50 @@ def init_bundle( copyfile(str(ckpt_file), str(models_dir / "model.pt")) elif network is not None: save_state(network, str(models_dir / "model.pt")) + + +class AutoBundle(): + def __init__(self, bundle_name_or_path: str, bundle_dir: str = None, workflow: str = "train", **kwargs) -> None: + _bundle_name_or_path = Path(ensure_tuple(bundle_name_or_path)[0]) + if _bundle_name_or_path.is_file(): + config_file = bundle_name_or_path + logging_file = f"{_bundle_name_or_path.parent}/logging.conf" + self.meta_file = f"{_bundle_name_or_path.parent}/metadata.json" + else: + download_args = { + "version": kwargs.pop("version", None), + "source": kwargs.pop("source", download_source), + } + download(bundle_name_or_path, bundle_dir=bundle_dir, **download_args) + bundle_dir = _process_bundle_dir(bundle_dir) + config_file = f"{bundle_dir}/{bundle_name_or_path}/configs/{workflow}.json" + logging_file = f"{bundle_dir}/{bundle_name_or_path}/configs/logging.conf" + self.meta_file = f"{bundle_dir}/{bundle_name_or_path}/configs/metadata.json" + + self.workflow = ConfigWorkflow( + config_file=config_file, + meta_file=self.meta_file, + logging_file=logging_file, + workflow=workflow, + **kwargs + ) + self.workflow.initialize() + + def from_bundle(self, property: str | None = None, meta: str | None = None): + if property is not None and meta is not None: + raise ValueError("Incompatible values: both property and meta are specified.") + if property is not None: + if property in self.workflow.properties: + return getattr(self.workflow, property) + raise ValueError(f"Missing property {property} in the bundle.") + if meta is not None: + metadata = ConfigParser.load_config_files(files=self.meta_file) + if meta.lower() in metadata.keys(): + return {meta: metadata[meta.lower()]} + raise ValueError(f"Missing meta {meta} informtation in metadata.json.") + + def train(self): + pass + + def predict(self): + pass From be7aa45a2e1e46e9be02ee9048b264e792faa67b Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 8 Aug 2023 14:55:20 +0800 Subject: [PATCH 02/56] minor fix Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 4fdcff300c..4013977036 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1541,7 +1541,7 @@ def init_bundle( save_state(network, str(models_dir / "model.pt")) -class AutoBundle(): +class AutoBundle: def __init__(self, bundle_name_or_path: str, bundle_dir: str = None, workflow: str = "train", **kwargs) -> None: _bundle_name_or_path = Path(ensure_tuple(bundle_name_or_path)[0]) if _bundle_name_or_path.is_file(): @@ -1549,10 +1549,7 @@ def __init__(self, bundle_name_or_path: str, bundle_dir: str = None, workflow: s logging_file = f"{_bundle_name_or_path.parent}/logging.conf" self.meta_file = f"{_bundle_name_or_path.parent}/metadata.json" else: - download_args = { - "version": kwargs.pop("version", None), - "source": kwargs.pop("source", download_source), - } + download_args = {"version": kwargs.pop("version", None), "source": kwargs.pop("source", download_source)} download(bundle_name_or_path, bundle_dir=bundle_dir, **download_args) bundle_dir = _process_bundle_dir(bundle_dir) config_file = f"{bundle_dir}/{bundle_name_or_path}/configs/{workflow}.json" @@ -1560,11 +1557,7 @@ def __init__(self, bundle_name_or_path: str, bundle_dir: str = None, workflow: s self.meta_file = f"{bundle_dir}/{bundle_name_or_path}/configs/metadata.json" self.workflow = ConfigWorkflow( - config_file=config_file, - meta_file=self.meta_file, - logging_file=logging_file, - workflow=workflow, - **kwargs + config_file=config_file, meta_file=self.meta_file, logging_file=logging_file, workflow=workflow, **kwargs ) self.workflow.initialize() From a7bc4e86b8b399229a0d7bafcceb8eff08e72952 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 10 Aug 2023 22:16:15 +0800 Subject: [PATCH 03/56] rename to `BundleManager` Signed-off-by: KumoLiu --- monai/bundle/__init__.py | 1 + monai/bundle/scripts.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index e8ea9d62b0..2605973495 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -16,6 +16,7 @@ from .properties import InferProperties, TrainProperties from .reference_resolver import ReferenceResolver from .scripts import ( + BundleManager, ckpt_export, download, get_all_bundles_list, diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 4013977036..c7ec837c11 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1541,27 +1541,27 @@ def init_bundle( save_state(network, str(models_dir / "model.pt")) -class AutoBundle: +class BundleManager: def __init__(self, bundle_name_or_path: str, bundle_dir: str = None, workflow: str = "train", **kwargs) -> None: _bundle_name_or_path = Path(ensure_tuple(bundle_name_or_path)[0]) if _bundle_name_or_path.is_file(): config_file = bundle_name_or_path - logging_file = f"{_bundle_name_or_path.parent}/logging.conf" - self.meta_file = f"{_bundle_name_or_path.parent}/metadata.json" + logging_file = _bundle_name_or_path.parent / "logging.conf" + self.meta_file = _bundle_name_or_path.parent / "metadata.json" else: download_args = {"version": kwargs.pop("version", None), "source": kwargs.pop("source", download_source)} download(bundle_name_or_path, bundle_dir=bundle_dir, **download_args) bundle_dir = _process_bundle_dir(bundle_dir) - config_file = f"{bundle_dir}/{bundle_name_or_path}/configs/{workflow}.json" - logging_file = f"{bundle_dir}/{bundle_name_or_path}/configs/logging.conf" - self.meta_file = f"{bundle_dir}/{bundle_name_or_path}/configs/metadata.json" + config_file = bundle_dir / bundle_name_or_path / "configs" / f"{workflow}.json" + logging_file = bundle_dir / bundle_name_or_path / "configs" / "logging.conf" + self.meta_file = bundle_dir / bundle_name_or_path / "configs" / "metadata.json" self.workflow = ConfigWorkflow( - config_file=config_file, meta_file=self.meta_file, logging_file=logging_file, workflow=workflow, **kwargs + config_file=str(config_file), meta_file=str(self.meta_file), logging_file=str(logging_file), workflow=workflow, **kwargs ) self.workflow.initialize() - def from_bundle(self, property: str | None = None, meta: str | None = None): + def get(self, property: str | None = None, meta: str | None = None) -> Any: if property is not None and meta is not None: raise ValueError("Incompatible values: both property and meta are specified.") if property is not None: From b8bc55a988a2ca0d31b80b48b17732de0fa24e7c Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 10 Aug 2023 22:59:35 +0800 Subject: [PATCH 04/56] support multi configs Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index c7ec837c11..1bfc3ea520 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1542,22 +1542,32 @@ def init_bundle( class BundleManager: - def __init__(self, bundle_name_or_path: str, bundle_dir: str = None, workflow: str = "train", **kwargs) -> None: + def __init__(self, bundle_name_or_path: str | Sequence[str], bundle_dir: str = None, configs: str | Sequence[str] = "train", **kwargs) -> None: + configs = ensure_tuple(configs) + if "train" in configs: + workflow = "train" + else: + workflow = "infer" + _bundle_name_or_path = Path(ensure_tuple(bundle_name_or_path)[0]) if _bundle_name_or_path.is_file(): config_file = bundle_name_or_path - logging_file = _bundle_name_or_path.parent / "logging.conf" - self.meta_file = _bundle_name_or_path.parent / "metadata.json" + config_root_path = _bundle_name_or_path.parent else: download_args = {"version": kwargs.pop("version", None), "source": kwargs.pop("source", download_source)} download(bundle_name_or_path, bundle_dir=bundle_dir, **download_args) bundle_dir = _process_bundle_dir(bundle_dir) - config_file = bundle_dir / bundle_name_or_path / "configs" / f"{workflow}.json" - logging_file = bundle_dir / bundle_name_or_path / "configs" / "logging.conf" - self.meta_file = bundle_dir / bundle_name_or_path / "configs" / "metadata.json" + config_root_path = bundle_dir / bundle_name_or_path / "configs" + if len(configs) > 0: + config_file = [str(config_root_path / f"{_config}.json") for _config in configs] + else: + config_file = str(config_root_path / f"{configs[0]}.json") + + logging_file = config_root_path / "logging.conf" + self.meta_file = config_root_path / "metadata.json" self.workflow = ConfigWorkflow( - config_file=str(config_file), meta_file=str(self.meta_file), logging_file=str(logging_file), workflow=workflow, **kwargs + config_file=config_file, meta_file=str(self.meta_file), logging_file=str(logging_file), workflow=workflow, **kwargs ) self.workflow.initialize() From eb6a9d48dd91aac5733aba8ba6a939e311ff680c Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 10 Aug 2023 23:42:08 +0800 Subject: [PATCH 05/56] add docstring Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 63 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 3 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 1bfc3ea520..588f1fcf99 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1542,7 +1542,48 @@ def init_bundle( class BundleManager: - def __init__(self, bundle_name_or_path: str | Sequence[str], bundle_dir: str = None, configs: str | Sequence[str] = "train", **kwargs) -> None: + """ + The `BundleManager` class facilitates the automatic downloading and instantiation of bundles. + It allows users to retrieve bundle properties and meta information. + + Typical usage examples: + + .. code-block:: python + + from monai.bundle import BundleManager + + # Create a BundleManager instance for the 'spleen_ct_segmentation' bundle + bundle = BundleManager("spleen_ct_segmentation") + + # Get properties defined in `TrainProperties` or `InferProperties` + train_preprocessing = bundle.get("train_preprocessing") + print(train_preprocessing) + + # Also support to retrieve meta information from the "metadata.json" file + version = bundle.get(meta = "version") + print(version) + + + Args: + bundle_name_or_path: the name or file path of the bundle. If a list of file paths is provided, + their contents will be merged. + bundle_dir: the target directory to store downloaded bundle. + Defaults to the 'bundle' subfolder under `torch.hub.get_dir()`. + + target directory to store the downloaded data. + Default is `bundle` subfolder under `torch.hub.get_dir()`. + configs: The name of the config file(s), supporting multiple names. + Defaults to "train". + kwargs: Additional arguments for download or workflow class instantiation. + """ + + def __init__( + self, + bundle_name_or_path: str | Sequence[str], + bundle_dir: PathLike | None = None, + configs: str | Sequence[str] = "train", + **kwargs, + ) -> None: configs = ensure_tuple(configs) if "train" in configs: workflow = "train" @@ -1554,7 +1595,11 @@ def __init__(self, bundle_name_or_path: str | Sequence[str], bundle_dir: str = N config_file = bundle_name_or_path config_root_path = _bundle_name_or_path.parent else: - download_args = {"version": kwargs.pop("version", None), "source": kwargs.pop("source", download_source)} + download_args = { + "version": kwargs.pop("version", None), + "args_file": kwargs.pop("args_file", None), + "source": kwargs.pop("source", download_source), + } download(bundle_name_or_path, bundle_dir=bundle_dir, **download_args) bundle_dir = _process_bundle_dir(bundle_dir) config_root_path = bundle_dir / bundle_name_or_path / "configs" @@ -1567,11 +1612,23 @@ def __init__(self, bundle_name_or_path: str | Sequence[str], bundle_dir: str = N self.meta_file = config_root_path / "metadata.json" self.workflow = ConfigWorkflow( - config_file=config_file, meta_file=str(self.meta_file), logging_file=str(logging_file), workflow=workflow, **kwargs + config_file=config_file, + meta_file=str(self.meta_file), + logging_file=str(logging_file), + workflow=workflow, + **kwargs, ) self.workflow.initialize() def get(self, property: str | None = None, meta: str | None = None) -> Any: + """ + Get information from the bundle. + + Args: + property: the target property, defined in `TrainProperties` or `InferProperties`. + meta: meta information retrieved from the "metadata.json" file, such as version, changelog, etc. + + """ if property is not None and meta is not None: raise ValueError("Incompatible values: both property and meta are specified.") if property is not None: From fc677b5da0907e526be2dae2363cd9764d4011a1 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 11 Aug 2023 15:05:27 +0800 Subject: [PATCH 06/56] change `bundle_name_or_path` to `bundle_name` and `config_path` Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 588f1fcf99..836a82c922 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1560,16 +1560,15 @@ class BundleManager: print(train_preprocessing) # Also support to retrieve meta information from the "metadata.json" file - version = bundle.get(meta = "version") + version = bundle.get(meta="version") print(version) - Args: - bundle_name_or_path: the name or file path of the bundle. If a list of file paths is provided, - their contents will be merged. + bundle_name: the name of the bundle. Defaults to None. + config_path: file path of the bundle. If a list of file paths is provided, + their contents will be merged. Defaults to None. bundle_dir: the target directory to store downloaded bundle. Defaults to the 'bundle' subfolder under `torch.hub.get_dir()`. - target directory to store the downloaded data. Default is `bundle` subfolder under `torch.hub.get_dir()`. configs: The name of the config file(s), supporting multiple names. @@ -1579,30 +1578,36 @@ class BundleManager: def __init__( self, - bundle_name_or_path: str | Sequence[str], + bundle_name: str | None = None, + config_path: str | Sequence[str] | None = None, bundle_dir: PathLike | None = None, configs: str | Sequence[str] = "train", - **kwargs, + **kwargs: Any, ) -> None: + if bundle_name is None and config_path is None: + raise ValueError("Must specify bundle_name or config_path.") configs = ensure_tuple(configs) if "train" in configs: workflow = "train" else: workflow = "infer" - _bundle_name_or_path = Path(ensure_tuple(bundle_name_or_path)[0]) - if _bundle_name_or_path.is_file(): - config_file = bundle_name_or_path - config_root_path = _bundle_name_or_path.parent + if config_path is not None: + _config_path = Path(ensure_tuple(config_path)[0]) + if _config_path.is_file(): + config_file = config_path + config_root_path = _config_path.parent + else: + raise FileNotFoundError(f"Cannot find the config file: {config_path}.") else: download_args = { "version": kwargs.pop("version", None), "args_file": kwargs.pop("args_file", None), "source": kwargs.pop("source", download_source), } - download(bundle_name_or_path, bundle_dir=bundle_dir, **download_args) + download(bundle_name, bundle_dir=bundle_dir, **download_args) bundle_dir = _process_bundle_dir(bundle_dir) - config_root_path = bundle_dir / bundle_name_or_path / "configs" + config_root_path = bundle_dir / bundle_name / "configs" # type: ignore if len(configs) > 0: config_file = [str(config_root_path / f"{_config}.json") for _config in configs] else: @@ -1632,7 +1637,7 @@ def get(self, property: str | None = None, meta: str | None = None) -> Any: if property is not None and meta is not None: raise ValueError("Incompatible values: both property and meta are specified.") if property is not None: - if property in self.workflow.properties: + if property in self.workflow.properties: # type: ignore return getattr(self.workflow, property) raise ValueError(f"Missing property {property} in the bundle.") if meta is not None: From 8fb5076713f5c4a7be6124c037f563420b8fb9c2 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 11 Aug 2023 15:52:06 +0800 Subject: [PATCH 07/56] support flexible postfix Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 836a82c922..211df9ccbc 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1541,6 +1541,15 @@ def init_bundle( save_state(network, str(models_dir / "model.pt")) +def _find_bundle_file(root_dir: Path, file_name: str, suffix: Sequence[str] = ("json", "yaml", "yml")) -> str | None: + # find bundle file with possible suffix + for _suffix in suffix: + full_name = f"{file_name}.{_suffix}" + if full_name in os.listdir(root_dir): + return full_name + return None + + class BundleManager: """ The `BundleManager` class facilitates the automatic downloading and instantiation of bundles. @@ -1609,9 +1618,9 @@ def __init__( bundle_dir = _process_bundle_dir(bundle_dir) config_root_path = bundle_dir / bundle_name / "configs" # type: ignore if len(configs) > 0: - config_file = [str(config_root_path / f"{_config}.json") for _config in configs] + config_file = [str(config_root_path / _find_bundle_file(config_root_path, _config)) for _config in configs] # type: ignore else: - config_file = str(config_root_path / f"{configs[0]}.json") + config_file = str(config_root_path / _find_bundle_file(config_root_path, configs[0])) # type: ignore logging_file = config_root_path / "logging.conf" self.meta_file = config_root_path / "metadata.json" From e4d061fd9246aa44882edee314292c03fe147727 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 11 Aug 2023 16:00:24 +0800 Subject: [PATCH 08/56] minor fix Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 211df9ccbc..225daf9b65 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1618,7 +1618,9 @@ def __init__( bundle_dir = _process_bundle_dir(bundle_dir) config_root_path = bundle_dir / bundle_name / "configs" # type: ignore if len(configs) > 0: - config_file = [str(config_root_path / _find_bundle_file(config_root_path, _config)) for _config in configs] # type: ignore + config_file = [ # type: ignore + str(config_root_path / _find_bundle_file(config_root_path, _config)) for _config in configs + ] else: config_file = str(config_root_path / _find_bundle_file(config_root_path, configs[0])) # type: ignore From d49f3d0f519b64d5b2aba7da28da6aaba1129c64 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 11 Aug 2023 16:55:34 +0800 Subject: [PATCH 09/56] minor fix Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 225daf9b65..cb45135757 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1618,8 +1618,8 @@ def __init__( bundle_dir = _process_bundle_dir(bundle_dir) config_root_path = bundle_dir / bundle_name / "configs" # type: ignore if len(configs) > 0: - config_file = [ # type: ignore - str(config_root_path / _find_bundle_file(config_root_path, _config)) for _config in configs + config_file = [ + str(config_root_path / _find_bundle_file(config_root_path, _config)) for _config in configs # type: ignore ] else: config_file = str(config_root_path / _find_bundle_file(config_root_path, configs[0])) # type: ignore From dc8858b53ca5d33b423f7a6d74644362d1a085f5 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 11 Aug 2023 19:20:22 +0800 Subject: [PATCH 10/56] update based on comments Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 40 +++++++++++++++++++--------------------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index cb45135757..6051c9c42f 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1541,11 +1541,11 @@ def init_bundle( save_state(network, str(models_dir / "model.pt")) -def _find_bundle_file(root_dir: Path, file_name: str, suffix: Sequence[str] = ("json", "yaml", "yml")) -> str | None: +def _find_config_file(root_dir: Path, file_name: str, suffix: Sequence[str] = ("json", "yaml", "yml")) -> Path | None: # find bundle file with possible suffix for _suffix in suffix: - full_name = f"{file_name}.{_suffix}" - if full_name in os.listdir(root_dir): + full_name = root_dir / f"{file_name}.{_suffix}" + if full_name.is_file(): return full_name return None @@ -1619,43 +1619,41 @@ def __init__( config_root_path = bundle_dir / bundle_name / "configs" # type: ignore if len(configs) > 0: config_file = [ - str(config_root_path / _find_bundle_file(config_root_path, _config)) for _config in configs # type: ignore + str(_find_config_file(config_root_path, _config)) for _config in configs # type: ignore ] else: - config_file = str(config_root_path / _find_bundle_file(config_root_path, configs[0])) # type: ignore + config_file = str(_find_config_file(config_root_path, configs[0])) # type: ignore logging_file = config_root_path / "logging.conf" self.meta_file = config_root_path / "metadata.json" self.workflow = ConfigWorkflow( config_file=config_file, - meta_file=str(self.meta_file), - logging_file=str(logging_file), + meta_file=str(self.meta_file) if self.meta_file.is_file() else None, + logging_file=str(logging_file) if logging_file.is_file() else None, workflow=workflow, **kwargs, ) self.workflow.initialize() - def get(self, property: str | None = None, meta: str | None = None) -> Any: + def get(self, id: str = "", default: Any | None = None) -> Any: """ - Get information from the bundle. + Get information from the bundle by id. Args: - property: the target property, defined in `TrainProperties` or `InferProperties`. - meta: meta information retrieved from the "metadata.json" file, such as version, changelog, etc. + id: id to specify the expected position. + It could be the target property, defined in `TrainProperties` or `InferProperties`. + Or meta information retrieved from the "metadata.json" file, such as version, changelog, etc. """ - if property is not None and meta is not None: - raise ValueError("Incompatible values: both property and meta are specified.") - if property is not None: - if property in self.workflow.properties: # type: ignore - return getattr(self.workflow, property) - raise ValueError(f"Missing property {property} in the bundle.") - if meta is not None: + if id in self.workflow.properties: # type: ignore + return getattr(self.workflow, id) + elif self.meta_file.is_file(): metadata = ConfigParser.load_config_files(files=self.meta_file) - if meta.lower() in metadata.keys(): - return {meta: metadata[meta.lower()]} - raise ValueError(f"Missing meta {meta} informtation in metadata.json.") + if id.lower() in metadata.keys(): + return {id: metadata[id.lower()]} + warnings.warn("Specified ``id`` is invalid, return default value.") + return default def train(self): pass From cc5c2c7502376cd892ff6168714cb7ec57e10313 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 11 Aug 2023 19:21:28 +0800 Subject: [PATCH 11/56] fix flake8 Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 6051c9c42f..80e7261d5c 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1618,9 +1618,7 @@ def __init__( bundle_dir = _process_bundle_dir(bundle_dir) config_root_path = bundle_dir / bundle_name / "configs" # type: ignore if len(configs) > 0: - config_file = [ - str(_find_config_file(config_root_path, _config)) for _config in configs # type: ignore - ] + config_file = [str(_find_config_file(config_root_path, _config)) for _config in configs] # type: ignore else: config_file = str(_find_config_file(config_root_path, configs[0])) # type: ignore @@ -1641,7 +1639,7 @@ def get(self, id: str = "", default: Any | None = None) -> Any: Get information from the bundle by id. Args: - id: id to specify the expected position. + id: id to specify the expected position. It could be the target property, defined in `TrainProperties` or `InferProperties`. Or meta information retrieved from the "metadata.json" file, such as version, changelog, etc. From 52c9f525352f52fcb46996dac4a6da3d733a0949 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 11 Aug 2023 19:25:23 +0800 Subject: [PATCH 12/56] update docstring Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 80e7261d5c..9430ddd91b 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1569,7 +1569,7 @@ class BundleManager: print(train_preprocessing) # Also support to retrieve meta information from the "metadata.json" file - version = bundle.get(meta="version") + version = bundle.get("version") print(version) Args: @@ -1642,6 +1642,7 @@ def get(self, id: str = "", default: Any | None = None) -> Any: id: id to specify the expected position. It could be the target property, defined in `TrainProperties` or `InferProperties`. Or meta information retrieved from the "metadata.json" file, such as version, changelog, etc. + default: default value to return if the specified ``id`` is invalid. """ if id in self.workflow.properties: # type: ignore @@ -1650,7 +1651,7 @@ def get(self, id: str = "", default: Any | None = None) -> Any: metadata = ConfigParser.load_config_files(files=self.meta_file) if id.lower() in metadata.keys(): return {id: metadata[id.lower()]} - warnings.warn("Specified ``id`` is invalid, return default value.") + warnings.warn("Specified ``id`` is invalid or missing 'metadata.json', return default value.") return default def train(self): From 5eec914f3f98f20414d10ae9f3090b1296567061 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Sun, 13 Aug 2023 17:34:30 +0800 Subject: [PATCH 13/56] update based on comments Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 9430ddd91b..7e0765f532 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -63,7 +63,7 @@ # set BUNDLE_DOWNLOAD_SRC="ngc" to use NGC source in default for bundle download # set BUNDLE_DOWNLOAD_SRC="monaihosting" to use monaihosting source in default for bundle download -download_source = os.environ.get("BUNDLE_DOWNLOAD_SRC", "github") +DEFAULT_DOWNLOAD_SOURCE = os.environ.get("BUNDLE_DOWNLOAD_SRC", "github") PPRINT_CONFIG_N = 5 @@ -253,7 +253,7 @@ def download( name: str | None = None, version: str | None = None, bundle_dir: PathLike | None = None, - source: str = download_source, + source: str = DEFAULT_DOWNLOAD_SOURCE, repo: str | None = None, url: str | None = None, remove_prefix: str | None = "monai_", @@ -382,7 +382,7 @@ def load( model_file: str | None = None, load_ts_module: bool = False, bundle_dir: PathLike | None = None, - source: str = download_source, + source: str = DEFAULT_DOWNLOAD_SOURCE, repo: str | None = None, remove_prefix: str | None = "monai_", progress: bool = True, @@ -1582,7 +1582,13 @@ class BundleManager: Default is `bundle` subfolder under `torch.hub.get_dir()`. configs: The name of the config file(s), supporting multiple names. Defaults to "train". - kwargs: Additional arguments for download or workflow class instantiation. + version: version name of the target bundle to download, like: "0.1.0". If `None`, will download + the latest version. + source: storage location name. This argument is used when `url` is `None`. + In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and + it should be "ngc", "monaihosting" or "github". + args_file: a JSON or YAML file to provide default values for all the args in download function. + kwargs: Additional arguments for workflow class instantiation. """ def __init__( @@ -1591,6 +1597,9 @@ def __init__( config_path: str | Sequence[str] | None = None, bundle_dir: PathLike | None = None, configs: str | Sequence[str] = "train", + version: str | None = None, + source: str = DEFAULT_DOWNLOAD_SOURCE, + args_file: str | None = None, **kwargs: Any, ) -> None: if bundle_name is None and config_path is None: @@ -1609,13 +1618,14 @@ def __init__( else: raise FileNotFoundError(f"Cannot find the config file: {config_path}.") else: - download_args = { - "version": kwargs.pop("version", None), - "args_file": kwargs.pop("args_file", None), - "source": kwargs.pop("source", download_source), - } - download(bundle_name, bundle_dir=bundle_dir, **download_args) bundle_dir = _process_bundle_dir(bundle_dir) + download( + bundle_name, + bundle_dir=bundle_dir, + version=version, + source=source, + args_file=args_file, + ) config_root_path = bundle_dir / bundle_name / "configs" # type: ignore if len(configs) > 0: config_file = [str(_find_config_file(config_root_path, _config)) for _config in configs] # type: ignore From 42bac3d200d6b7ead229969cb1817e6fd398b41a Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Sun, 13 Aug 2023 17:43:02 +0800 Subject: [PATCH 14/56] fix flake8 Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 7e0765f532..d1e27cf25d 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1619,13 +1619,7 @@ def __init__( raise FileNotFoundError(f"Cannot find the config file: {config_path}.") else: bundle_dir = _process_bundle_dir(bundle_dir) - download( - bundle_name, - bundle_dir=bundle_dir, - version=version, - source=source, - args_file=args_file, - ) + download(bundle_name, bundle_dir=bundle_dir, version=version, source=source, args_file=args_file) config_root_path = bundle_dir / bundle_name / "configs" # type: ignore if len(configs) > 0: config_file = [str(_find_config_file(config_root_path, _config)) for _config in configs] # type: ignore From 090399705757c3a9ddbe2033468d5a8f11faafed Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 14 Aug 2023 14:35:33 +0800 Subject: [PATCH 15/56] rename `bundle_name` to `name` Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index d1e27cf25d..48cec9cdae 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1573,7 +1573,7 @@ class BundleManager: print(version) Args: - bundle_name: the name of the bundle. Defaults to None. + name: the name of the bundle. Defaults to None. config_path: file path of the bundle. If a list of file paths is provided, their contents will be merged. Defaults to None. bundle_dir: the target directory to store downloaded bundle. @@ -1593,7 +1593,7 @@ class BundleManager: def __init__( self, - bundle_name: str | None = None, + name: str | None = None, config_path: str | Sequence[str] | None = None, bundle_dir: PathLike | None = None, configs: str | Sequence[str] = "train", @@ -1602,8 +1602,8 @@ def __init__( args_file: str | None = None, **kwargs: Any, ) -> None: - if bundle_name is None and config_path is None: - raise ValueError("Must specify bundle_name or config_path.") + if name is None and config_path is None: + raise ValueError("Must specify name or config_path.") configs = ensure_tuple(configs) if "train" in configs: workflow = "train" @@ -1619,8 +1619,8 @@ def __init__( raise FileNotFoundError(f"Cannot find the config file: {config_path}.") else: bundle_dir = _process_bundle_dir(bundle_dir) - download(bundle_name, bundle_dir=bundle_dir, version=version, source=source, args_file=args_file) - config_root_path = bundle_dir / bundle_name / "configs" # type: ignore + download(name, bundle_dir=bundle_dir, version=version, source=source, args_file=args_file) + config_root_path = bundle_dir / name / "configs" # type: ignore if len(configs) > 0: config_file = [str(_find_config_file(config_root_path, _config)) for _config in configs] # type: ignore else: From 6576c15d2b8a960aa35d097fd6720551ad7bff61 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 14 Aug 2023 14:51:17 +0800 Subject: [PATCH 16/56] update docstring Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 48cec9cdae..995ce25776 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1578,8 +1578,6 @@ class BundleManager: their contents will be merged. Defaults to None. bundle_dir: the target directory to store downloaded bundle. Defaults to the 'bundle' subfolder under `torch.hub.get_dir()`. - target directory to store the downloaded data. - Default is `bundle` subfolder under `torch.hub.get_dir()`. configs: The name of the config file(s), supporting multiple names. Defaults to "train". version: version name of the target bundle to download, like: "0.1.0". If `None`, will download From 0409cac54584e1ccebadaf1dd697c4ed45af1508 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 15 Aug 2023 14:58:19 +0800 Subject: [PATCH 17/56] add `MetaProterties` Signed-off-by: KumoLiu --- monai/bundle/properties.py | 30 +++++++++++++++- monai/bundle/scripts.py | 74 +++++++++++++++++++++++++++++++++++++- 2 files changed, 102 insertions(+), 2 deletions(-) diff --git a/monai/bundle/properties.py b/monai/bundle/properties.py index 16ecf77268..17844b4ffc 100644 --- a/monai/bundle/properties.py +++ b/monai/bundle/properties.py @@ -13,7 +13,7 @@ to interact with the bundle workflow. Some properties are required and some are optional, optional properties mean: if some component of the bundle workflow refer to the property, the property must be defined, otherwise, the property can be None. -Every item in this `TrainProperties` or `InferProperties` dictionary is a property, +Every item in this `TrainProperties` or `InferProperties` or `MetaProterties` dictionary is a property, the key is the property name and the values include: 1. description. 2. whether it's a required property. @@ -48,6 +48,11 @@ BundleProperty.REQUIRED: True, BundlePropertyConfig.ID: f"train{ID_SEP_KEY}trainer", }, + "network_def": { + BundleProperty.DESC: "network module for the training.", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: "network_def", + }, "max_epochs": { BundleProperty.DESC: "max number of epochs to execute the training.", BundleProperty.REQUIRED: True, @@ -216,3 +221,26 @@ BundlePropertyConfig.REF_ID: f"evaluator{ID_SEP_KEY}key_val_metric", }, } + +MetaProterties = { + "version": { + BundleProperty.DESC: "bundle version", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: "version", + }, + "monai_version": { + BundleProperty.DESC: "required monai version used for bundle", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: "monai_version", + }, + "pytorch_version": { + BundleProperty.DESC: "required pytorch version used for bundle", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: "pytorch_version", + }, + "numpy_version": { + BundleProperty.DESC: "required numpy version used for bundle", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: "numpy_version", + }, +} diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 995ce25776..43b786d392 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -417,7 +417,7 @@ def load( If used, it should be in the form of "repo_owner/repo_name/release_tag". remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles have the ``monai_`` prefix, which is not existing in their model zoo contrasts. In order to - maintain the consistency between these two sources, remove prefix is necessary. + maintain the consistency between these three sources, remove prefix is necessary. Therefore, if specified, downloaded folder name will remove the prefix. progress: whether to display a progress bar when downloading. device: target device of returned weights or module, if `None`, prefer to "cuda" if existing. @@ -1550,6 +1550,78 @@ def _find_config_file(root_dir: Path, file_name: str, suffix: Sequence[str] = (" return None +def load_bundle_state_dict( + name: str | None = None, + bundle_dir: PathLike | None = None, + model: str | None = None, + device: str | None = None, + dst_prefix: str = "", + key_in_ckpt: str | None = None, + mapping: dict = None, + exclude_vars: str = None, + inplace: bool = True, + config_path: str | Sequence[str] | None = None, + configs: str | Sequence[str] = "train", + version: str | None = None, + source: str = DEFAULT_DOWNLOAD_SOURCE, + args_file: str | None = None, + model_file: str | None = None, + load_ts_module: bool = False, + config_files: Sequence[str] = (), + **override: Any +): + """ + + Args: + + Examples + + """ + bundle_dir = _process_bundle_dir(bundle_dir) + if device is None: + device = "cuda:0" if is_available() else "cpu" + if model_file is None: + model_file = os.path.join("models", "model.ts" if load_ts_module is True else "model.pt") + full_path = os.path.join(bundle_dir, name, model_file) + if not os.path.exists(full_path) or model is None: + _override = {f"network_def#{key}": value for key, value in override.items()} + bundle = BundleManager( + name=name, + config_path=config_path, + bundle_dir=bundle_dir, + configs=configs, + version=version, + source=source, + args_file=args_file, + **_override + ) + + # loading with `torch.jit.load` + if load_ts_module is True: + return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files) + + # loading with `torch.load` + model_dict = torch.load(full_path, map_location=torch.device(device)) + if not isinstance(model_dict, Mapping): + warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.") + model_dict = get_state_dict(model_dict) + + if model is not None: + model.to(device) + else: + model = bundle.get("network_def").to(device) + + copy_model_state( + dst=model, + src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], + dst_prefix=dst_prefix, + mapping=mapping, + exclude_vars=exclude_vars, + inplace=inplace, + ) + return model + + class BundleManager: """ The `BundleManager` class facilitates the automatic downloading and instantiation of bundles. From 7057fe267a06a49f83ed2b6df46c37ac14b0fecf Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 23 Aug 2023 16:29:27 +0800 Subject: [PATCH 18/56] support getting meta property Signed-off-by: KumoLiu --- monai/bundle/properties.py | 23 +++++++++++++++++++---- monai/bundle/workflows.py | 6 ++++-- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/monai/bundle/properties.py b/monai/bundle/properties.py index 17844b4ffc..0b74c26085 100644 --- a/monai/bundle/properties.py +++ b/monai/bundle/properties.py @@ -226,21 +226,36 @@ "version": { BundleProperty.DESC: "bundle version", BundleProperty.REQUIRED: False, - BundlePropertyConfig.ID: "version", + BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}version", }, "monai_version": { BundleProperty.DESC: "required monai version used for bundle", BundleProperty.REQUIRED: False, - BundlePropertyConfig.ID: "monai_version", + BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}monai_version", }, "pytorch_version": { BundleProperty.DESC: "required pytorch version used for bundle", BundleProperty.REQUIRED: False, - BundlePropertyConfig.ID: "pytorch_version", + BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}pytorch_version", }, "numpy_version": { BundleProperty.DESC: "required numpy version used for bundle", BundleProperty.REQUIRED: False, - BundlePropertyConfig.ID: "numpy_version", + BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}numpy_version", + }, + "description": { + BundleProperty.DESC: "description for bundle", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}description", + }, + "spatial_shape": { + BundleProperty.DESC: "spatial shape for the inputs", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}network_data_format{ID_SEP_KEY}inputs{ID_SEP_KEY}image{ID_SEP_KEY}spatial_shape", + }, + "channel_def": { + BundleProperty.DESC: "channel definition for the prediction", + BundleProperty.REQUIRED: False, + BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}network_data_format{ID_SEP_KEY}outputs{ID_SEP_KEY}pred{ID_SEP_KEY}channel_def", }, } diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 6bd966592e..da19aed0bf 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -22,7 +22,7 @@ from monai.apps.utils import get_logger from monai.bundle.config_parser import ConfigParser -from monai.bundle.properties import InferProperties, TrainProperties +from monai.bundle.properties import InferProperties, TrainProperties, MetaProterties from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY from monai.utils import BundleProperty, BundlePropertyConfig @@ -50,14 +50,16 @@ class BundleWorkflow(ABC): def __init__(self, workflow: str | None = None): if workflow is None: - self.properties = None + self.properties = copy(MetaProterties) self.workflow = None return if workflow.lower() in self.supported_train_type: self.properties = copy(TrainProperties) + self.properties.update(copy(MetaProterties)) self.workflow = "train" elif workflow.lower() in self.supported_infer_type: self.properties = copy(InferProperties) + self.properties.update(copy(MetaProterties)) self.workflow = "infer" else: raise ValueError(f"Unsupported workflow type: '{workflow}'.") From 7ab0a7499453585d32e3e35058d9f2188eb646a1 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 23 Aug 2023 16:31:45 +0800 Subject: [PATCH 19/56] remove `BundleManager` Signed-off-by: KumoLiu --- monai/bundle/__init__.py | 1 - monai/bundle/scripts.py | 185 --------------------------------------- 2 files changed, 186 deletions(-) diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index 2605973495..e8ea9d62b0 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -16,7 +16,6 @@ from .properties import InferProperties, TrainProperties from .reference_resolver import ReferenceResolver from .scripts import ( - BundleManager, ckpt_export, download, get_all_bundles_list, diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 43b786d392..449d303d20 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1548,188 +1548,3 @@ def _find_config_file(root_dir: Path, file_name: str, suffix: Sequence[str] = (" if full_name.is_file(): return full_name return None - - -def load_bundle_state_dict( - name: str | None = None, - bundle_dir: PathLike | None = None, - model: str | None = None, - device: str | None = None, - dst_prefix: str = "", - key_in_ckpt: str | None = None, - mapping: dict = None, - exclude_vars: str = None, - inplace: bool = True, - config_path: str | Sequence[str] | None = None, - configs: str | Sequence[str] = "train", - version: str | None = None, - source: str = DEFAULT_DOWNLOAD_SOURCE, - args_file: str | None = None, - model_file: str | None = None, - load_ts_module: bool = False, - config_files: Sequence[str] = (), - **override: Any -): - """ - - Args: - - Examples - - """ - bundle_dir = _process_bundle_dir(bundle_dir) - if device is None: - device = "cuda:0" if is_available() else "cpu" - if model_file is None: - model_file = os.path.join("models", "model.ts" if load_ts_module is True else "model.pt") - full_path = os.path.join(bundle_dir, name, model_file) - if not os.path.exists(full_path) or model is None: - _override = {f"network_def#{key}": value for key, value in override.items()} - bundle = BundleManager( - name=name, - config_path=config_path, - bundle_dir=bundle_dir, - configs=configs, - version=version, - source=source, - args_file=args_file, - **_override - ) - - # loading with `torch.jit.load` - if load_ts_module is True: - return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files) - - # loading with `torch.load` - model_dict = torch.load(full_path, map_location=torch.device(device)) - if not isinstance(model_dict, Mapping): - warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.") - model_dict = get_state_dict(model_dict) - - if model is not None: - model.to(device) - else: - model = bundle.get("network_def").to(device) - - copy_model_state( - dst=model, - src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], - dst_prefix=dst_prefix, - mapping=mapping, - exclude_vars=exclude_vars, - inplace=inplace, - ) - return model - - -class BundleManager: - """ - The `BundleManager` class facilitates the automatic downloading and instantiation of bundles. - It allows users to retrieve bundle properties and meta information. - - Typical usage examples: - - .. code-block:: python - - from monai.bundle import BundleManager - - # Create a BundleManager instance for the 'spleen_ct_segmentation' bundle - bundle = BundleManager("spleen_ct_segmentation") - - # Get properties defined in `TrainProperties` or `InferProperties` - train_preprocessing = bundle.get("train_preprocessing") - print(train_preprocessing) - - # Also support to retrieve meta information from the "metadata.json" file - version = bundle.get("version") - print(version) - - Args: - name: the name of the bundle. Defaults to None. - config_path: file path of the bundle. If a list of file paths is provided, - their contents will be merged. Defaults to None. - bundle_dir: the target directory to store downloaded bundle. - Defaults to the 'bundle' subfolder under `torch.hub.get_dir()`. - configs: The name of the config file(s), supporting multiple names. - Defaults to "train". - version: version name of the target bundle to download, like: "0.1.0". If `None`, will download - the latest version. - source: storage location name. This argument is used when `url` is `None`. - In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and - it should be "ngc", "monaihosting" or "github". - args_file: a JSON or YAML file to provide default values for all the args in download function. - kwargs: Additional arguments for workflow class instantiation. - """ - - def __init__( - self, - name: str | None = None, - config_path: str | Sequence[str] | None = None, - bundle_dir: PathLike | None = None, - configs: str | Sequence[str] = "train", - version: str | None = None, - source: str = DEFAULT_DOWNLOAD_SOURCE, - args_file: str | None = None, - **kwargs: Any, - ) -> None: - if name is None and config_path is None: - raise ValueError("Must specify name or config_path.") - configs = ensure_tuple(configs) - if "train" in configs: - workflow = "train" - else: - workflow = "infer" - - if config_path is not None: - _config_path = Path(ensure_tuple(config_path)[0]) - if _config_path.is_file(): - config_file = config_path - config_root_path = _config_path.parent - else: - raise FileNotFoundError(f"Cannot find the config file: {config_path}.") - else: - bundle_dir = _process_bundle_dir(bundle_dir) - download(name, bundle_dir=bundle_dir, version=version, source=source, args_file=args_file) - config_root_path = bundle_dir / name / "configs" # type: ignore - if len(configs) > 0: - config_file = [str(_find_config_file(config_root_path, _config)) for _config in configs] # type: ignore - else: - config_file = str(_find_config_file(config_root_path, configs[0])) # type: ignore - - logging_file = config_root_path / "logging.conf" - self.meta_file = config_root_path / "metadata.json" - - self.workflow = ConfigWorkflow( - config_file=config_file, - meta_file=str(self.meta_file) if self.meta_file.is_file() else None, - logging_file=str(logging_file) if logging_file.is_file() else None, - workflow=workflow, - **kwargs, - ) - self.workflow.initialize() - - def get(self, id: str = "", default: Any | None = None) -> Any: - """ - Get information from the bundle by id. - - Args: - id: id to specify the expected position. - It could be the target property, defined in `TrainProperties` or `InferProperties`. - Or meta information retrieved from the "metadata.json" file, such as version, changelog, etc. - default: default value to return if the specified ``id`` is invalid. - - """ - if id in self.workflow.properties: # type: ignore - return getattr(self.workflow, id) - elif self.meta_file.is_file(): - metadata = ConfigParser.load_config_files(files=self.meta_file) - if id.lower() in metadata.keys(): - return {id: metadata[id.lower()]} - warnings.warn("Specified ``id`` is invalid or missing 'metadata.json', return default value.") - return default - - def train(self): - pass - - def predict(self): - pass From 79bf87de7b03e6204bfd2e8f987d584e875599e4 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 23 Aug 2023 17:05:11 +0800 Subject: [PATCH 20/56] simplify the specification for meta_file and logging_file in `ConfigWorkflow` Signed-off-by: KumoLiu --- monai/bundle/workflows.py | 66 ++++++++++++++++++++++----------------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index da19aed0bf..014fcf3c6e 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -24,7 +24,7 @@ from monai.bundle.config_parser import ConfigParser from monai.bundle.properties import InferProperties, TrainProperties, MetaProterties from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY -from monai.utils import BundleProperty, BundlePropertyConfig +from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg_default, ensure_tuple __all__ = ["BundleWorkflow", "ConfigWorkflow"] @@ -168,18 +168,18 @@ class ConfigWorkflow(BundleWorkflow): For more information: https://docs.monai.io/en/latest/mb_specification.html. Args: - run_id: ID name of the expected config expression to run, default to "run". - to run the config, the target config must contain this ID. + config_file: filepath of the config file, if it is a list of file paths, the content of them will be merged. + meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged. + If None, default to "configs/metadata.json", which is commonly used for bundles in MONAI model zoo. + logging_file: config file for `logging` module in the program. for more details: + https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig. + If None, default to "configs/logging.conf", which is commonly used for bundles in MONAI model zoo. init_id: ID name of the expected config expression to initialize before running, default to "initialize". allow a config to have no `initialize` logic and the ID. + run_id: ID name of the expected config expression to run, default to "run". + to run the config, the target config must contain this ID. final_id: ID name of the expected config expression to finalize after running, default to "finalize". allow a config to have no `finalize` logic and the ID. - meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged. - Default to "configs/metadata.json", which is commonly used for bundles in MONAI model zoo. - config_file: filepath of the config file, if it is a list of file paths, the content of them will be merged. - logging_file: config file for `logging` module in the program. for more details: - https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig. - Default to "configs/logging.conf", which is commonly used for bundles in MONAI model zoo. tracking: if not None, enable the experiment tracking at runtime with optionally configurable and extensible. if "mlflow", will add `MLFlowHandler` to the parsed bundle with default tracking settings, if other string, treat it as file path to load the tracking settings. @@ -190,17 +190,18 @@ class ConfigWorkflow(BundleWorkflow): workflow: specifies the workflow type: "train" or "training" for a training workflow, or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. - default to `None` for common workflow. + default to `train` for training workflow. override: id-value pairs to override or add the corresponding config content. e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg`` """ + @deprecated_arg_default("workflow", None, "train", since="1.3", replaced="1.4") def __init__( self, config_file: str | Sequence[str], - meta_file: str | Sequence[str] | None = "configs/metadata.json", - logging_file: str | None = "configs/logging.conf", + meta_file: str | Sequence[str] | None = None, + logging_file: str | None = None, init_id: str = "initialize", run_id: str = "run", final_id: str = "finalize", @@ -209,26 +210,33 @@ def __init__( **override: Any, ) -> None: super().__init__(workflow=workflow) - if logging_file is not None: - if not os.path.exists(logging_file): - if logging_file == "configs/logging.conf": - warnings.warn("Default logging file in 'configs/logging.conf' does not exist, skipping logging.") - else: - raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.") - else: - logger.info(f"Setting logging properties based on config: {logging_file}.") - fileConfig(logging_file, disable_existing_loggers=False) + _config_path = Path(ensure_tuple(config_file)[0]) + if _config_path.is_file(): + config_file = config_file + config_root_path = _config_path.parent + else: + raise FileNotFoundError(f"Cannot find the config file: {config_file}.") + + if logging_file is None: + logging_file = config_root_path / "logging.conf" + if not logging_file.is_file(): + warnings.warn(f"Default logging file in {logging_file} does not exist, skipping logging.") + if not os.path.exists(logging_file): + raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.") + else: + logger.info(f"Setting logging properties based on config: {logging_file}.") + fileConfig(logging_file, disable_existing_loggers=False) self.parser = ConfigParser() self.parser.read_config(f=config_file) - if meta_file is not None: - if isinstance(meta_file, str) and not os.path.exists(meta_file): - if meta_file == "configs/metadata.json": - warnings.warn("Default metadata file in 'configs/metadata.json' does not exist, skipping loading.") - else: - raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.") - else: - self.parser.read_meta(f=meta_file) + if meta_file is None: + meta_file = config_root_path / "metadata.json" + if not meta_file.is_file(): + warnings.warn(f"Default metadata file in {meta_file} does not exist, skipping logging.") + if not os.path.exists(meta_file): + raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.") + else: + self.parser.read_meta(f=meta_file) # the rest key-values in the _args are to override config content self.parser.update(pairs=override) From 1a2abeaa6043d720daccaf23d3d819a171e6f549 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 23 Aug 2023 17:08:09 +0800 Subject: [PATCH 21/56] fix flake8 Signed-off-by: KumoLiu --- monai/bundle/properties.py | 3 ++- monai/bundle/workflows.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/bundle/properties.py b/monai/bundle/properties.py index 0b74c26085..850955d8fa 100644 --- a/monai/bundle/properties.py +++ b/monai/bundle/properties.py @@ -251,7 +251,8 @@ "spatial_shape": { BundleProperty.DESC: "spatial shape for the inputs", BundleProperty.REQUIRED: False, - BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}network_data_format{ID_SEP_KEY}inputs{ID_SEP_KEY}image{ID_SEP_KEY}spatial_shape", + BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}network_data_format{ID_SEP_KEY}inputs{ID_SEP_KEY}image" + f"{ID_SEP_KEY}spatial_shape", }, "channel_def": { BundleProperty.DESC: "channel definition for the prediction", diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 014fcf3c6e..853e0a0c82 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -22,7 +22,7 @@ from monai.apps.utils import get_logger from monai.bundle.config_parser import ConfigParser -from monai.bundle.properties import InferProperties, TrainProperties, MetaProterties +from monai.bundle.properties import InferProperties, MetaProterties, TrainProperties from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg_default, ensure_tuple From a8ae128aca61fed63d9d36b7744f5176f466ec1a Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 23 Aug 2023 17:22:22 +0800 Subject: [PATCH 22/56] fix mypy Signed-off-by: KumoLiu --- monai/bundle/workflows.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 853e0a0c82..d0a8f3061b 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -24,6 +24,7 @@ from monai.bundle.config_parser import ConfigParser from monai.bundle.properties import InferProperties, MetaProterties, TrainProperties from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY +from monai.config import PathLike from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg_default, ensure_tuple __all__ = ["BundleWorkflow", "ConfigWorkflow"] @@ -200,8 +201,8 @@ class ConfigWorkflow(BundleWorkflow): def __init__( self, config_file: str | Sequence[str], - meta_file: str | Sequence[str] | None = None, - logging_file: str | None = None, + meta_file: PathLike | Sequence[PathLike] | None = None, + logging_file: PathLike | None = None, init_id: str = "initialize", run_id: str = "run", final_id: str = "finalize", @@ -221,9 +222,9 @@ def __init__( logging_file = config_root_path / "logging.conf" if not logging_file.is_file(): warnings.warn(f"Default logging file in {logging_file} does not exist, skipping logging.") - if not os.path.exists(logging_file): - raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.") else: + if not Path(logging_file).is_file(): + raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.") logger.info(f"Setting logging properties based on config: {logging_file}.") fileConfig(logging_file, disable_existing_loggers=False) @@ -232,10 +233,10 @@ def __init__( if meta_file is None: meta_file = config_root_path / "metadata.json" if not meta_file.is_file(): - warnings.warn(f"Default metadata file in {meta_file} does not exist, skipping logging.") - if not os.path.exists(meta_file): - raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.") + warnings.warn(f"Default metadata file in {meta_file} does not exist, skipping loading.") else: + if not Path(ensure_tuple(meta_file)[0]).is_file(): + raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.") self.parser.read_meta(f=meta_file) # the rest key-values in the _args are to override config content From c0c4e01f220cec00725e9d1174131a1aa06ae426 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 23 Aug 2023 19:04:13 +0800 Subject: [PATCH 23/56] simplify the specification for meta_file and logging_file in ConfigWorkflow Signed-off-by: KumoLiu --- monai/bundle/workflows.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index d0a8f3061b..44a9a74f06 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -218,26 +218,28 @@ def __init__( else: raise FileNotFoundError(f"Cannot find the config file: {config_file}.") - if logging_file is None: - logging_file = config_root_path / "logging.conf" - if not logging_file.is_file(): - warnings.warn(f"Default logging file in {logging_file} does not exist, skipping logging.") - else: - if not Path(logging_file).is_file(): - raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.") - logger.info(f"Setting logging properties based on config: {logging_file}.") - fileConfig(logging_file, disable_existing_loggers=False) + logging_file = config_root_path / "logging.conf" if logging_file is None else logging_file + if logging_file is not None: + if not os.path.exists(logging_file): + if logging_file == config_root_path / "logging.conf": + warnings.warn(f"Default logging file in {logging_file} does not exist, skipping logging.") + else: + raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.") + else: + logger.info(f"Setting logging properties based on config: {logging_file}.") + fileConfig(logging_file, disable_existing_loggers=False) self.parser = ConfigParser() self.parser.read_config(f=config_file) - if meta_file is None: - meta_file = config_root_path / "metadata.json" - if not meta_file.is_file(): - warnings.warn(f"Default metadata file in {meta_file} does not exist, skipping loading.") - else: - if not Path(ensure_tuple(meta_file)[0]).is_file(): - raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.") - self.parser.read_meta(f=meta_file) + meta_file = config_root_path / "metadata.json" if meta_file is None else meta_file + if meta_file is not None: + if isinstance(meta_file, str) and not os.path.exists(meta_file): + if meta_file == config_root_path / "metadata.json": + warnings.warn(f"Default metadata file in {meta_file} does not exist, skipping loading.") + else: + raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.") + else: + self.parser.read_meta(f=meta_file) # the rest key-values in the _args are to override config content self.parser.update(pairs=override) From fcc8b1b2ae85806f3c699d535f48bb2ef5659547 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 23 Aug 2023 19:26:35 +0800 Subject: [PATCH 24/56] add `create_workflow` and update `load` Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 95 +++++++++++++++++++++++------------------ 1 file changed, 53 insertions(+), 42 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 449d303d20..b89cfef90e 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -378,6 +378,7 @@ def download( def load( name: str, + model: str | None = None, version: str | None = None, model_file: str | None = None, load_ts_module: bool = False, @@ -389,8 +390,13 @@ def load( device: str | None = None, key_in_ckpt: str | None = None, config_files: Sequence[str] = (), - net_name: str | None = None, - **net_kwargs: Any, + dst_prefix: str = "", + mapping: dict = None, + exclude_vars: str = None, + inplace: bool = True, + workflow_name: str | BundleWorkflow | None = None, + args_file: str | None = None, + **override: Any ) -> object | tuple[torch.nn.Module, dict, dict] | Any: """ Load model weights or TorchScript module of a bundle. @@ -425,9 +431,6 @@ def load( weights. if not nested checkpoint, no need to set. config_files: extra filenames would be loaded. The argument only works when loading a TorchScript module, see `_extra_files` in `torch.jit.load` for more details. - net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights. - This argument only works when loading weights. - net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`. Returns: 1. If `load_ts_module` is `False` and `net_name` is `None`, return model weights. @@ -439,15 +442,12 @@ def load( """ bundle_dir_ = _process_bundle_dir(bundle_dir) - + if device is None: + device = "cuda:0" if is_available() else "cpu" if model_file is None: model_file = os.path.join("models", "model.ts" if load_ts_module is True else "model.pt") - if source == "ngc": - name = _add_ngc_prefix(name) - if remove_prefix: - name = _remove_ngc_prefix(name, prefix=remove_prefix) full_path = os.path.join(bundle_dir_, name, model_file) - if not os.path.exists(full_path): + if not os.path.exists(full_path) or model is None: download( name=name, version=version, @@ -457,9 +457,10 @@ def load( remove_prefix=remove_prefix, progress=progress, ) + train_config_file = str(bundle_dir_ / name / "configs" / "train.json") + _override = {f"network_def#{key}": value for key, value in override.items()} + workflow = create_workflow(workflow_name=workflow_name, args_file=args_file, config_file=train_config_file, workflow="train", **_override) - if device is None: - device = "cuda:0" if is_available() else "cpu" # loading with `torch.jit.load` if load_ts_module is True: return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files) @@ -469,13 +470,17 @@ def load( warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.") model_dict = get_state_dict(model_dict) - if net_name is None: - return model_dict - net_kwargs["_target_"] = net_name - configer = ConfigComponent(config=net_kwargs) - model = configer.instantiate() - model.to(device) # type: ignore - copy_model_state(dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt]) # type: ignore + model = workflow.network_def if model is None else model + model.to(device) + + copy_model_state( + dst=model, + src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], + dst_prefix=dst_prefix, + mapping=mapping, + exclude_vars=exclude_vars, + inplace=inplace, + ) return model @@ -734,7 +739,7 @@ def run( workflow.finalize() -def run_workflow(workflow: str | BundleWorkflow | None = None, args_file: str | None = None, **kwargs: Any) -> None: +def run_workflow(workflow_name: str | BundleWorkflow | None = None, args_file: str | None = None, **kwargs: Any) -> None: """ Specify `bundle workflow` to run monai bundle components and workflows. The workflow should be subclass of `BundleWorkflow` and be available to import. @@ -748,35 +753,17 @@ def run_workflow(workflow: str | BundleWorkflow | None = None, args_file: str | python -m monai.bundle run_workflow --meta_file --config_file # Set the workflow to other customized BundleWorkflow subclass: - python -m monai.bundle run_workflow --workflow CustomizedWorkflow ... + python -m monai.bundle run_workflow --workflow_name CustomizedWorkflow ... Args: - workflow: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow". + workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow". args_file: a JSON or YAML file to provide default values for this API. so that the command line inputs can be simplified. kwargs: arguments to instantiate the workflow class. """ - _args = _update_args(args=args_file, workflow=workflow, **kwargs) - _log_input_summary(tag="run", args=_args) - (workflow_name,) = _pop_args(_args, workflow=ConfigWorkflow) # the default workflow name is "ConfigWorkflow" - if isinstance(workflow_name, str): - workflow_class, has_built_in = optional_import("monai.bundle", name=str(workflow_name)) # search built-in - if not has_built_in: - workflow_class = locate(str(workflow_name)) # search dotted path - if workflow_class is None: - raise ValueError(f"cannot locate specified workflow class: {workflow_name}.") - elif issubclass(workflow_name, BundleWorkflow): - workflow_class = workflow_name - else: - raise ValueError( - "Argument `workflow` must be a bundle workflow class name" - f"or subclass of BundleWorkflow, got: {workflow_name}." - ) - - workflow_ = workflow_class(**_args) - workflow_.initialize() + workflow_ = create_workflow(workflow_name=workflow_name, args_file=args_file, **kwargs) workflow_.run() workflow_.finalize() @@ -1548,3 +1535,27 @@ def _find_config_file(root_dir: Path, file_name: str, suffix: Sequence[str] = (" if full_name.is_file(): return full_name return None + + +def create_workflow(workflow_name: str | BundleWorkflow | None = None, args_file: str | None = None, **kwargs: Any) -> None: + _args = _update_args(args=args_file, workflow_name=workflow_name, **kwargs) + _log_input_summary(tag="run", args=_args) + (workflow_name,) = _pop_args(_args, workflow_name=ConfigWorkflow) # the default workflow name is "ConfigWorkflow" + if isinstance(workflow_name, str): + workflow_class, has_built_in = optional_import("monai.bundle", name=str(workflow_name)) # search built-in + if not has_built_in: + workflow_class = locate(str(workflow_name)) # search dotted path + if workflow_class is None: + raise ValueError(f"cannot locate specified workflow class: {workflow_name}.") + elif issubclass(workflow_name, BundleWorkflow): + workflow_class = workflow_name + else: + raise ValueError( + "Argument `workflow` must be a bundle workflow class name" + f"or subclass of BundleWorkflow, got: {workflow_name}." + ) + + workflow_ = workflow_class(**_args) + workflow_.initialize() + + return workflow_ From 30ebb163f35381dbf7e32c94a3645332916943b8 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 23 Aug 2023 19:26:58 +0800 Subject: [PATCH 25/56] remove `_find_config_file` Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index b89cfef90e..b356efa4d5 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1528,15 +1528,6 @@ def init_bundle( save_state(network, str(models_dir / "model.pt")) -def _find_config_file(root_dir: Path, file_name: str, suffix: Sequence[str] = ("json", "yaml", "yml")) -> Path | None: - # find bundle file with possible suffix - for _suffix in suffix: - full_name = root_dir / f"{file_name}.{_suffix}" - if full_name.is_file(): - return full_name - return None - - def create_workflow(workflow_name: str | BundleWorkflow | None = None, args_file: str | None = None, **kwargs: Any) -> None: _args = _update_args(args=args_file, workflow_name=workflow_name, **kwargs) _log_input_summary(tag="run", args=_args) From f609f981ba4046321da74867c91bc83234beca27 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 23 Aug 2023 11:27:36 +0000 Subject: [PATCH 26/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/bundle/scripts.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index b356efa4d5..25aac991cd 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -28,7 +28,6 @@ from monai.apps.mmars.mmars import _get_all_ngc_models from monai.apps.utils import _basename, download_url, extractall, get_logger -from monai.bundle.config_item import ConfigComponent from monai.bundle.config_parser import ConfigParser from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow From ee525f8410607aac3b68edc120909448640f3cac Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 23 Aug 2023 19:49:30 +0800 Subject: [PATCH 27/56] update docstring Signed-off-by: KumoLiu --- monai/bundle/__init__.py | 1 + monai/bundle/scripts.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index e8ea9d62b0..858805e648 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -29,6 +29,7 @@ trt_export, verify_metadata, verify_net_in_out, + create_workflow ) from .utils import ( DEFAULT_EXP_MGMT_SETTINGS, diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index b356efa4d5..51bb5db806 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -431,6 +431,14 @@ def load( weights. if not nested checkpoint, no need to set. config_files: extra filenames would be loaded. The argument only works when loading a TorchScript module, see `_extra_files` in `torch.jit.load` for more details. + dst_prefix: used in `copy_model_state`, `dst` key prefix, so that `dst[dst_prefix + src_key]` + will be assigned to the value of `src[src_key]`. + mapping: used in `copy_model_state`, a `{"src_key": "dst_key"}` dict, indicating that `dst[dst_prefix + dst_key]` + to be assigned to the value of `src[src_key]`. + exclude_vars: used in `copy_model_state`, a regular expression to match the `dst` variable names, + so that their values are not overwritten by `src`. + inplace: used in `copy_model_state`, whether to set the `dst` module with the updated `state_dict` via `load_state_dict`. + This option is only available when `dst` is a `torch.nn.Module`. Returns: 1. If `load_ts_module` is `False` and `net_name` is `None`, return model weights. @@ -1529,6 +1537,29 @@ def init_bundle( def create_workflow(workflow_name: str | BundleWorkflow | None = None, args_file: str | None = None, **kwargs: Any) -> None: + """ + Specify `bundle workflow` to create monai bundle workflows. + The workflow should be subclass of `BundleWorkflow` and be available to import. + It can be MONAI existing bundle workflows or user customized workflows. + + Typical usage examples: + + .. code-block:: python + + # Specify config_file path to create workflow: + workflow = create_workflow(config_file="/workspace/spleen_ct_segmentation/configs/train.json", workflow="train") + + # Set the workflow to other customized BundleWorkflow subclass to create workflow: + workflow = create_workflow(workflow_name=CustomizedWorkflow) + + Args: + workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow". + args_file: a JSON or YAML file to provide default values for this API. + so that the command line inputs can be simplified. + kwargs: arguments to instantiate the workflow class. + + """ + _args = _update_args(args=args_file, workflow_name=workflow_name, **kwargs) _log_input_summary(tag="run", args=_args) (workflow_name,) = _pop_args(_args, workflow_name=ConfigWorkflow) # the default workflow name is "ConfigWorkflow" From 83cd3d3b697c05d725e841cdd8982ea35c035501 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 24 Aug 2023 11:31:03 +0800 Subject: [PATCH 28/56] minor fix Signed-off-by: KumoLiu --- tests/test_integration_bundle_run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py index 74ac93bc27..08ae555ffd 100644 --- a/tests/test_integration_bundle_run.py +++ b/tests/test_integration_bundle_run.py @@ -147,7 +147,7 @@ def test_customized_workflow(self): filename = os.path.join(self.data_dir, "image.nii") nib.save(nib.Nifti1Image(test_image, np.eye(4)), filename) - cmd = "-m fire monai.bundle.scripts run_workflow --workflow tests.nonconfig_workflow.NonConfigWorkflow" + cmd = "-m fire monai.bundle.scripts run_workflow --workflow_name tests.nonconfig_workflow.NonConfigWorkflow" cmd += f" --filename {filename} --output_dir {self.data_dir}" command_line_tests(["coverage", "run"] + cmd.split(" ")) loader = LoadImage(image_only=True) From 559711d1536719cd2c9995728ea8d3a698370ae3 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 24 Aug 2023 14:41:30 +0800 Subject: [PATCH 29/56] add unittests for load Signed-off-by: KumoLiu --- monai/bundle/workflows.py | 2 +- tests/test_bundle_download.py | 54 +++++++++++++++++++++++++++++++++-- 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 44a9a74f06..1e350d1e3a 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -191,7 +191,7 @@ class ConfigWorkflow(BundleWorkflow): workflow: specifies the workflow type: "train" or "training" for a training workflow, or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. - default to `train` for training workflow. + default to `None` for common workflow. override: id-value pairs to override or add the corresponding config content. e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg`` diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index 36e935bf08..7c89b09a07 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -17,11 +17,13 @@ import unittest import torch +import numpy as np from parameterized import parameterized import monai.networks.nets as nets from monai.apps import check_hash -from monai.bundle import ConfigParser, load +from monai.utils import set_determinism +from monai.bundle import ConfigParser, load, create_workflow from tests.utils import ( SkipIfBeforePyTorchVersion, assert_allclose, @@ -64,6 +66,12 @@ "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting/brats_mri_segmentation/versions/0.3.9/files/brats_mri_segmentation_v0.3.9.zip", ] +TEST_CASE_7 = [ + "spleen_ct_segmentation", + "cuda" if torch.cuda.is_available() else "cpu", + {"spatial_dims": 3, "out_channels": 5} +] + class TestDownload(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) @@ -159,20 +167,62 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) # load instantiated model directly and test, since the bundle has been downloaded, # there is no need to input `repo` + _model_2 = nets.__dict__[model_name](**net_args) model_2 = load( name=bundle_name, + model=_model_2, model_file=model_file, bundle_dir=tempdir, progress=False, device=device, net_name=model_name, source="github", - **net_args, ) model_2.eval() output_2 = model_2.forward(input_tensor) assert_allclose(output_2, expected_output, atol=1e-4, rtol=1e-4, type_test=False) + @parameterized.expand([TEST_CASE_7]) + @skip_if_quick + def test_load_weights_with_net_override(self, bundle_name, device, net_override): + with skip_if_downloading_fails(): + # download bundle, and load weights from the downloaded path + with tempfile.TemporaryDirectory() as tempdir: + # load weights + model = load( + name=bundle_name, + bundle_dir=tempdir, + source="monaihosting", + progress=False, + device=device, + ) + + # prepare data and test + input_tensor = torch.rand(1, 1, 96, 96, 96).to(device) + output = model(input_tensor) + model_path = f"{tempdir}/spleen_ct_segmentation/models/model.pt" + workflow = create_workflow(config_file=f"{tempdir}/spleen_ct_segmentation/configs/train.json", workflow="train") + expected_model = workflow.network_def.to(device) + expected_model.load_state_dict(torch.load(model_path)) + expected_output = expected_model(input_tensor) + assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False) + + # using net_override to override kwargs in network directly + model_2 = load( + name=bundle_name, + bundle_dir=tempdir, + source="monaihosting", + progress=False, + device=device, + **net_override + ) + + # prepare data and test + input_tensor = torch.rand(1, 1, 96, 96, 96).to(device) + output = model_2(input_tensor) + expected_shape = (1, 5, 96, 96, 96) + np.testing.assert_equal(output.shape, expected_shape) + @parameterized.expand([TEST_CASE_5]) @skip_if_quick @SkipIfBeforePyTorchVersion((1, 7, 1)) From 7038b90ccd1295911d5a0809caf6a0bb99a65618 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 24 Aug 2023 14:45:50 +0800 Subject: [PATCH 30/56] update the docstring for the `load` Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 47 +++++++++++++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 9b73208173..4335c16320 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -377,8 +377,9 @@ def download( def load( name: str, - model: str | None = None, + model: torch.nn.Module | None = None, version: str | None = None, + workflow: str = "train", model_file: str | None = None, load_ts_module: bool = False, bundle_dir: PathLike | None = None, @@ -395,7 +396,7 @@ def load( inplace: bool = True, workflow_name: str | BundleWorkflow | None = None, args_file: str | None = None, - **override: Any + **net_override: Any ) -> object | tuple[torch.nn.Module, dict, dict] | Any: """ Load model weights or TorchScript module of a bundle. @@ -407,8 +408,15 @@ def load( https://github.com/Project-MONAI/model-zoo/releases/tag/hosting_storage_v1. "monai_brats_mri_segmentation" in ngc: https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=monai. + "mednist_gan" in monaihosting: + https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting/mednist_gan/versions/0.2.0/files/mednist_gan_v0.2.0.zip + model: a pytorch module to be updated. Default to None, using the "network_def" in the bundle. version: version name of the target bundle to download, like: "0.1.0". If `None`, will download the latest version. + workflow: specifies the workflow type: "train" or "training" for a training workflow, + or "infer", "inference", "eval", "evaluation" for a inference workflow, + other unsupported string will raise a ValueError. + default to `train` for training workflow. model_file: the relative path of the model weights or TorchScript module within bundle. If `None`, "models/model.pt" or "models/model.ts" will be used. load_ts_module: a flag to specify if loading the TorchScript module. @@ -438,10 +446,15 @@ def load( so that their values are not overwritten by `src`. inplace: used in `copy_model_state`, whether to set the `dst` module with the updated `state_dict` via `load_state_dict`. This option is only available when `dst` is a `torch.nn.Module`. + workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow". + args_file: a JSON or YAML file to provide default values for all the args in "download" function. + net_override: id-value pairs to override the parameters in the network of the bundle. Returns: - 1. If `load_ts_module` is `False` and `net_name` is `None`, return model weights. - 2. If `load_ts_module` is `False` and `net_name` is not `None`, + 1. If `load_ts_module` is `False` and `model` is `None`, + return model weights if can't find "network_def" in the bundle, + else return an instantiated network that loaded the weights. + 2. If `load_ts_module` is `False` and `model` is not `None`, return an instantiated network that loaded the weights. 3. If `load_ts_module` is `True`, return a triple that include a TorchScript module, the corresponding metadata dict, and extra files dict. @@ -463,10 +476,14 @@ def load( repo=repo, remove_prefix=remove_prefix, progress=progress, + args_file=args_file, ) - train_config_file = str(bundle_dir_ / name / "configs" / "train.json") - _override = {f"network_def#{key}": value for key, value in override.items()} - workflow = create_workflow(workflow_name=workflow_name, args_file=args_file, config_file=train_config_file, workflow="train", **_override) + train_config_file = bundle_dir_ / name / "configs" / f"{workflow}.json" + if train_config_file.is_file(): + _net_override = {f"network_def#{key}": value for key, value in net_override.items()} + _workflow = create_workflow(workflow_name=workflow_name, args_file=args_file, config_file=str(train_config_file), workflow=workflow, **_net_override) + else: + _workflow = None # loading with `torch.jit.load` if load_ts_module is True: @@ -477,7 +494,10 @@ def load( warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.") model_dict = get_state_dict(model_dict) - model = workflow.network_def if model is None else model + if model is None and _workflow is None: + return model_dict + + model = _workflow.network_def if model is None else model model.to(device) copy_model_state( @@ -1535,7 +1555,12 @@ def init_bundle( save_state(network, str(models_dir / "model.pt")) -def create_workflow(workflow_name: str | BundleWorkflow | None = None, args_file: str | None = None, **kwargs: Any) -> None: +def create_workflow( + workflow_name: str | BundleWorkflow | None = None, + config_file: str | Sequence[str] | None = None, + args_file: str | None = None, + **kwargs: Any +) -> None: """ Specify `bundle workflow` to create monai bundle workflows. The workflow should be subclass of `BundleWorkflow` and be available to import. @@ -1553,13 +1578,13 @@ def create_workflow(workflow_name: str | BundleWorkflow | None = None, args_file Args: workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow". + config_file: filepath of the config file, if it is a list of file paths, the content of them will be merged. args_file: a JSON or YAML file to provide default values for this API. so that the command line inputs can be simplified. kwargs: arguments to instantiate the workflow class. """ - - _args = _update_args(args=args_file, workflow_name=workflow_name, **kwargs) + _args = _update_args(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs) _log_input_summary(tag="run", args=_args) (workflow_name,) = _pop_args(_args, workflow_name=ConfigWorkflow) # the default workflow name is "ConfigWorkflow" if isinstance(workflow_name, str): From a5992b9a8c6fe41921ed66c1c24775a5e01819d5 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 24 Aug 2023 14:47:36 +0800 Subject: [PATCH 31/56] fix flake8 Signed-off-by: KumoLiu --- monai/bundle/__init__.py | 4 ++-- monai/bundle/scripts.py | 16 ++++++++++++---- tests/test_bundle_download.py | 20 ++++++++------------ 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index 858805e648..7eddeaa641 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -13,10 +13,11 @@ from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable from .config_parser import ConfigParser -from .properties import InferProperties, TrainProperties +from .properties import InferProperties, MetaProterties, TrainProperties from .reference_resolver import ReferenceResolver from .scripts import ( ckpt_export, + create_workflow, download, get_all_bundles_list, get_bundle_info, @@ -29,7 +30,6 @@ trt_export, verify_metadata, verify_net_in_out, - create_workflow ) from .utils import ( DEFAULT_EXP_MGMT_SETTINGS, diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 4335c16320..8c1ee20a2e 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -396,7 +396,7 @@ def load( inplace: bool = True, workflow_name: str | BundleWorkflow | None = None, args_file: str | None = None, - **net_override: Any + **net_override: Any, ) -> object | tuple[torch.nn.Module, dict, dict] | Any: """ Load model weights or TorchScript module of a bundle. @@ -481,7 +481,13 @@ def load( train_config_file = bundle_dir_ / name / "configs" / f"{workflow}.json" if train_config_file.is_file(): _net_override = {f"network_def#{key}": value for key, value in net_override.items()} - _workflow = create_workflow(workflow_name=workflow_name, args_file=args_file, config_file=str(train_config_file), workflow=workflow, **_net_override) + _workflow = create_workflow( + workflow_name=workflow_name, + args_file=args_file, + config_file=str(train_config_file), + workflow=workflow, + **_net_override, + ) else: _workflow = None @@ -766,7 +772,9 @@ def run( workflow.finalize() -def run_workflow(workflow_name: str | BundleWorkflow | None = None, args_file: str | None = None, **kwargs: Any) -> None: +def run_workflow( + workflow_name: str | BundleWorkflow | None = None, args_file: str | None = None, **kwargs: Any +) -> None: """ Specify `bundle workflow` to run monai bundle components and workflows. The workflow should be subclass of `BundleWorkflow` and be available to import. @@ -1559,7 +1567,7 @@ def create_workflow( workflow_name: str | BundleWorkflow | None = None, config_file: str | Sequence[str] | None = None, args_file: str | None = None, - **kwargs: Any + **kwargs: Any, ) -> None: """ Specify `bundle workflow` to create monai bundle workflows. diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index 7c89b09a07..a4a1fcc10d 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -16,14 +16,14 @@ import tempfile import unittest -import torch import numpy as np +import torch from parameterized import parameterized import monai.networks.nets as nets from monai.apps import check_hash +from monai.bundle import ConfigParser, create_workflow, load from monai.utils import set_determinism -from monai.bundle import ConfigParser, load, create_workflow from tests.utils import ( SkipIfBeforePyTorchVersion, assert_allclose, @@ -69,7 +69,7 @@ TEST_CASE_7 = [ "spleen_ct_segmentation", "cuda" if torch.cuda.is_available() else "cpu", - {"spatial_dims": 3, "out_channels": 5} + {"spatial_dims": 3, "out_channels": 5}, ] @@ -189,19 +189,15 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override) # download bundle, and load weights from the downloaded path with tempfile.TemporaryDirectory() as tempdir: # load weights - model = load( - name=bundle_name, - bundle_dir=tempdir, - source="monaihosting", - progress=False, - device=device, - ) + model = load(name=bundle_name, bundle_dir=tempdir, source="monaihosting", progress=False, device=device) # prepare data and test input_tensor = torch.rand(1, 1, 96, 96, 96).to(device) output = model(input_tensor) model_path = f"{tempdir}/spleen_ct_segmentation/models/model.pt" - workflow = create_workflow(config_file=f"{tempdir}/spleen_ct_segmentation/configs/train.json", workflow="train") + workflow = create_workflow( + config_file=f"{tempdir}/spleen_ct_segmentation/configs/train.json", workflow="train" + ) expected_model = workflow.network_def.to(device) expected_model.load_state_dict(torch.load(model_path)) expected_output = expected_model(input_tensor) @@ -214,7 +210,7 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override) source="monaihosting", progress=False, device=device, - **net_override + **net_override, ) # prepare data and test From 0531989f0564cfbe42cfa2aaa0c18ec4d30512aa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 24 Aug 2023 06:48:22 +0000 Subject: [PATCH 32/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_bundle_download.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index a4a1fcc10d..a18c518c80 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -23,7 +23,6 @@ import monai.networks.nets as nets from monai.apps import check_hash from monai.bundle import ConfigParser, create_workflow, load -from monai.utils import set_determinism from tests.utils import ( SkipIfBeforePyTorchVersion, assert_allclose, From a5cfd71346e062b1ed2104516ecf05d7d60e9b35 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 24 Aug 2023 14:57:38 +0800 Subject: [PATCH 33/56] fix mypy Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 8c1ee20a2e..a8b3082b48 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -391,8 +391,8 @@ def load( key_in_ckpt: str | None = None, config_files: Sequence[str] = (), dst_prefix: str = "", - mapping: dict = None, - exclude_vars: str = None, + mapping: dict | None = None, + exclude_vars: str | None = None, inplace: bool = True, workflow_name: str | BundleWorkflow | None = None, args_file: str | None = None, @@ -502,8 +502,7 @@ def load( if model is None and _workflow is None: return model_dict - - model = _workflow.network_def if model is None else model + model = _workflow.network_def if model is None else model # type: ignore model.to(device) copy_model_state( @@ -1568,7 +1567,7 @@ def create_workflow( config_file: str | Sequence[str] | None = None, args_file: str | None = None, **kwargs: Any, -) -> None: +) -> Any: """ Specify `bundle workflow` to create monai bundle workflows. The workflow should be subclass of `BundleWorkflow` and be available to import. @@ -1601,7 +1600,7 @@ def create_workflow( workflow_class = locate(str(workflow_name)) # search dotted path if workflow_class is None: raise ValueError(f"cannot locate specified workflow class: {workflow_name}.") - elif issubclass(workflow_name, BundleWorkflow): + elif issubclass(type(workflow_name), BundleWorkflow): workflow_class = workflow_name else: raise ValueError( From 82a3c23d379bea45063611728b6374dc5f612ca0 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 24 Aug 2023 16:27:45 +0800 Subject: [PATCH 34/56] fix ci Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 10 +++++----- monai/bundle/workflows.py | 12 ++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index a8b3082b48..a9e5a50a01 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -712,12 +712,12 @@ def run( final_id: ID name of the expected config expression to finalize after running, default to "finalize". it's optional for both configs and this `run` function. meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged. - Default to "configs/metadata.json", which is commonly used for bundles in MONAI model zoo. + Default to None. config_file: filepath of the config file, if `None`, must be provided in `args_file`. if it is a list of file paths, the content of them will be merged. logging_file: config file for `logging` module in the program. for more details: https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig. - Default to "configs/logging.conf", which is commonly used for bundles in MONAI model zoo. + Default to None. tracking: if not None, enable the experiment tracking at runtime with optionally configurable and extensible. if "mlflow", will add `MLFlowHandler` to the parsed bundle with default tracking settings, if other string, treat it as file path to load the tracking settings. @@ -749,11 +749,11 @@ def run( config_file_, meta_file_, init_id_, run_id_, final_id_, logging_file_, tracking_ = _pop_args( _args, config_file=None, - meta_file="configs/metadata.json", + meta_file=None, init_id="initialize", run_id="run", final_id="finalize", - logging_file="configs/logging.conf", + logging_file=None, tracking=None, ) workflow = ConfigWorkflow( @@ -1600,7 +1600,7 @@ def create_workflow( workflow_class = locate(str(workflow_name)) # search dotted path if workflow_class is None: raise ValueError(f"cannot locate specified workflow class: {workflow_name}.") - elif issubclass(type(workflow_name), BundleWorkflow): + elif issubclass(workflow_name, BundleWorkflow): # type: ignore workflow_class = workflow_name else: raise ValueError( diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 1e350d1e3a..f0c107e35a 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -201,8 +201,8 @@ class ConfigWorkflow(BundleWorkflow): def __init__( self, config_file: str | Sequence[str], - meta_file: PathLike | Sequence[PathLike] | None = None, - logging_file: PathLike | None = None, + meta_file: str | Sequence[str] | None = None, + logging_file: str | None = None, init_id: str = "initialize", run_id: str = "run", final_id: str = "finalize", @@ -218,10 +218,10 @@ def __init__( else: raise FileNotFoundError(f"Cannot find the config file: {config_file}.") - logging_file = config_root_path / "logging.conf" if logging_file is None else logging_file + logging_file = str(config_root_path / "logging.conf") if logging_file is None else logging_file if logging_file is not None: if not os.path.exists(logging_file): - if logging_file == config_root_path / "logging.conf": + if logging_file == str(config_root_path / "logging.conf"): warnings.warn(f"Default logging file in {logging_file} does not exist, skipping logging.") else: raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.") @@ -231,10 +231,10 @@ def __init__( self.parser = ConfigParser() self.parser.read_config(f=config_file) - meta_file = config_root_path / "metadata.json" if meta_file is None else meta_file + meta_file = str(config_root_path / "metadata.json") if meta_file is None else meta_file if meta_file is not None: if isinstance(meta_file, str) and not os.path.exists(meta_file): - if meta_file == config_root_path / "metadata.json": + if meta_file == str(config_root_path / "metadata.json"): warnings.warn(f"Default metadata file in {meta_file} does not exist, skipping loading.") else: raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.") From 0ad2364084b4769e92234dba9c41554571c5cad7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 24 Aug 2023 08:28:24 +0000 Subject: [PATCH 35/56] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/bundle/workflows.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index f0c107e35a..12653c9631 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -24,7 +24,6 @@ from monai.bundle.config_parser import ConfigParser from monai.bundle.properties import InferProperties, MetaProterties, TrainProperties from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY -from monai.config import PathLike from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg_default, ensure_tuple __all__ = ["BundleWorkflow", "ConfigWorkflow"] From 6938718b757e0366b8fc6344a1e7dbafcf588bbd Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 24 Aug 2023 17:26:26 +0800 Subject: [PATCH 36/56] fix ci Signed-off-by: KumoLiu --- monai/bundle/workflows.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 12653c9631..b21ad8fdd9 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -210,12 +210,15 @@ def __init__( **override: Any, ) -> None: super().__init__(workflow=workflow) - _config_path = Path(ensure_tuple(config_file)[0]) - if _config_path.is_file(): - config_file = config_file - config_root_path = _config_path.parent + if config_file is not None: + _config_path = Path(ensure_tuple(config_file)[0]) + if _config_path.is_file(): + config_file = config_file + config_root_path = _config_path.parent + else: + raise FileNotFoundError(f"Cannot find the config file: {config_file}.") else: - raise FileNotFoundError(f"Cannot find the config file: {config_file}.") + config_root_path = Path("configs") logging_file = str(config_root_path / "logging.conf") if logging_file is None else logging_file if logging_file is not None: From 5da4ea6c8e05f123c5d53913037950a05daa8985 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 25 Aug 2023 09:47:29 +0800 Subject: [PATCH 37/56] Update monai/bundle/__init__.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/bundle/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index 7eddeaa641..dd556e9eb3 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -13,7 +13,7 @@ from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable from .config_parser import ConfigParser -from .properties import InferProperties, MetaProterties, TrainProperties +from .properties import InferProperties, MetaProperties, TrainProperties from .reference_resolver import ReferenceResolver from .scripts import ( ckpt_export, From 85275d3d8807b21d83023b080094b0f8858a3c5b Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 25 Aug 2023 10:08:57 +0800 Subject: [PATCH 38/56] fix typo Signed-off-by: KumoLiu --- monai/bundle/properties.py | 4 ++-- monai/bundle/workflows.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/bundle/properties.py b/monai/bundle/properties.py index 850955d8fa..7f250db311 100644 --- a/monai/bundle/properties.py +++ b/monai/bundle/properties.py @@ -13,7 +13,7 @@ to interact with the bundle workflow. Some properties are required and some are optional, optional properties mean: if some component of the bundle workflow refer to the property, the property must be defined, otherwise, the property can be None. -Every item in this `TrainProperties` or `InferProperties` or `MetaProterties` dictionary is a property, +Every item in this `TrainProperties` or `InferProperties` or `MetaProperties` dictionary is a property, the key is the property name and the values include: 1. description. 2. whether it's a required property. @@ -222,7 +222,7 @@ }, } -MetaProterties = { +MetaProperties = { "version": { BundleProperty.DESC: "bundle version", BundleProperty.REQUIRED: False, diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index b21ad8fdd9..b172bfcc22 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -22,7 +22,7 @@ from monai.apps.utils import get_logger from monai.bundle.config_parser import ConfigParser -from monai.bundle.properties import InferProperties, MetaProterties, TrainProperties +from monai.bundle.properties import InferProperties, MetaProperties, TrainProperties from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg_default, ensure_tuple @@ -50,16 +50,16 @@ class BundleWorkflow(ABC): def __init__(self, workflow: str | None = None): if workflow is None: - self.properties = copy(MetaProterties) + self.properties = copy(MetaProperties) self.workflow = None return if workflow.lower() in self.supported_train_type: self.properties = copy(TrainProperties) - self.properties.update(copy(MetaProterties)) + self.properties.update(copy(MetaProperties)) self.workflow = "train" elif workflow.lower() in self.supported_infer_type: self.properties = copy(InferProperties) - self.properties.update(copy(MetaProterties)) + self.properties.update(copy(MetaProperties)) self.workflow = "infer" else: raise ValueError(f"Unsupported workflow type: '{workflow}'.") From 63d3a845b92a643773e08a05482b939bbb88e20d Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 25 Aug 2023 11:16:39 +0800 Subject: [PATCH 39/56] update according to comments Signed-off-by: KumoLiu --- monai/bundle/properties.py | 2 +- monai/bundle/scripts.py | 19 +++---------------- monai/bundle/workflows.py | 6 ++---- tests/test_bundle_download.py | 4 ++-- 4 files changed, 8 insertions(+), 23 deletions(-) diff --git a/monai/bundle/properties.py b/monai/bundle/properties.py index 7f250db311..81e32b89fd 100644 --- a/monai/bundle/properties.py +++ b/monai/bundle/properties.py @@ -225,7 +225,7 @@ MetaProperties = { "version": { BundleProperty.DESC: "bundle version", - BundleProperty.REQUIRED: False, + BundleProperty.REQUIRED: True, BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}version", }, "monai_version": { diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index a9e5a50a01..2462e9aab5 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -390,12 +390,9 @@ def load( device: str | None = None, key_in_ckpt: str | None = None, config_files: Sequence[str] = (), - dst_prefix: str = "", - mapping: dict | None = None, - exclude_vars: str | None = None, - inplace: bool = True, workflow_name: str | BundleWorkflow | None = None, args_file: str | None = None, + copy_model_args: dict = {}, **net_override: Any, ) -> object | tuple[torch.nn.Module, dict, dict] | Any: """ @@ -438,16 +435,9 @@ def load( weights. if not nested checkpoint, no need to set. config_files: extra filenames would be loaded. The argument only works when loading a TorchScript module, see `_extra_files` in `torch.jit.load` for more details. - dst_prefix: used in `copy_model_state`, `dst` key prefix, so that `dst[dst_prefix + src_key]` - will be assigned to the value of `src[src_key]`. - mapping: used in `copy_model_state`, a `{"src_key": "dst_key"}` dict, indicating that `dst[dst_prefix + dst_key]` - to be assigned to the value of `src[src_key]`. - exclude_vars: used in `copy_model_state`, a regular expression to match the `dst` variable names, - so that their values are not overwritten by `src`. - inplace: used in `copy_model_state`, whether to set the `dst` module with the updated `state_dict` via `load_state_dict`. - This option is only available when `dst` is a `torch.nn.Module`. workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow". args_file: a JSON or YAML file to provide default values for all the args in "download" function. + copy_model_args: other arguments for the `monai.networks.copy_model_state` function. net_override: id-value pairs to override the parameters in the network of the bundle. Returns: @@ -508,10 +498,7 @@ def load( copy_model_state( dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], - dst_prefix=dst_prefix, - mapping=mapping, - exclude_vars=exclude_vars, - inplace=inplace, + **copy_model_args ) return model diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index b172bfcc22..1373341b80 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -54,12 +54,10 @@ def __init__(self, workflow: str | None = None): self.workflow = None return if workflow.lower() in self.supported_train_type: - self.properties = copy(TrainProperties) - self.properties.update(copy(MetaProperties)) + self.properties = {**TrainProperties, **MetaProperties} self.workflow = "train" elif workflow.lower() in self.supported_infer_type: - self.properties = copy(InferProperties) - self.properties.update(copy(MetaProperties)) + self.properties = {**InferProperties, **MetaProperties} self.workflow = "infer" else: raise ValueError(f"Unsupported workflow type: '{workflow}'.") diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index a18c518c80..8299acc6fd 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -153,7 +153,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) net_args = json.load(f)["network_def"] model_name = net_args["_target_"] del net_args["_target_"] - model = nets.__dict__[model_name](**net_args) + model = getattr(nets, model_name)(**net_args) model.to(device) model.load_state_dict(weights) model.eval() @@ -166,7 +166,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) # load instantiated model directly and test, since the bundle has been downloaded, # there is no need to input `repo` - _model_2 = nets.__dict__[model_name](**net_args) + _model_2 = getattr(nets, model_name)(**net_args) model_2 = load( name=bundle_name, model=_model_2, From 864dadb38e824e1faebd5d8abf660dc39145d68c Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 25 Aug 2023 11:18:00 +0800 Subject: [PATCH 40/56] fix flake8 Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 2462e9aab5..af8f3f9209 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -495,11 +495,7 @@ def load( model = _workflow.network_def if model is None else model # type: ignore model.to(device) - copy_model_state( - dst=model, - src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], - **copy_model_args - ) + copy_model_state(dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], **copy_model_args) return model From fa3b940972b25efcf120cda24e8d68b882b48d41 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 25 Aug 2023 11:26:10 +0800 Subject: [PATCH 41/56] Update monai/bundle/workflows.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/bundle/workflows.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 1373341b80..6b59ca3fd2 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -166,8 +166,8 @@ class ConfigWorkflow(BundleWorkflow): For more information: https://docs.monai.io/en/latest/mb_specification.html. Args: - config_file: filepath of the config file, if it is a list of file paths, the content of them will be merged. - meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged. + config_file: filepath of the config file, if this is a list of file paths, their contents will be merged in order. + meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order. If None, default to "configs/metadata.json", which is commonly used for bundles in MONAI model zoo. logging_file: config file for `logging` module in the program. for more details: https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig. From 54c65af21f7168d6720db99cce87a4725d9f131b Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 25 Aug 2023 11:31:30 +0800 Subject: [PATCH 42/56] minor fix Signed-off-by: KumoLiu --- monai/bundle/workflows.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 1373341b80..e1d7524c4d 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -211,7 +211,6 @@ def __init__( if config_file is not None: _config_path = Path(ensure_tuple(config_file)[0]) if _config_path.is_file(): - config_file = config_file config_root_path = _config_path.parent else: raise FileNotFoundError(f"Cannot find the config file: {config_file}.") From c4ac4678a1e68f4c1c6009d4c509f19b1186a8b7 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 28 Aug 2023 11:58:26 +0800 Subject: [PATCH 43/56] fix version ci error Signed-off-by: KumoLiu --- tests/nonconfig_workflow.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/nonconfig_workflow.py b/tests/nonconfig_workflow.py index 34f22aa565..a0a2e42e2c 100644 --- a/tests/nonconfig_workflow.py +++ b/tests/nonconfig_workflow.py @@ -50,9 +50,13 @@ def __init__(self, filename, output_dir): self._preprocessing = None self._postprocessing = None self._evaluator = None + self._version = None def initialize(self): set_determinism(0) + if self._version is None: + self._version = "0.1.0" + if self._preprocessing is None: self._preprocessing = Compose( [LoadImaged(keys="image"), EnsureChannelFirstd(keys="image"), ScaleIntensityd(keys="image")] @@ -118,6 +122,8 @@ def _get_property(self, name, property): return self._preprocessing if name == "postprocessing": return self._postprocessing + if name == "version": + return self._version if property[BundleProperty.REQUIRED]: raise ValueError(f"unsupported property '{name}' is required in the bundle properties.") @@ -142,5 +148,7 @@ def _set_property(self, name, property, value): self._preprocessing = value elif name == "postprocessing": self._postprocessing = value + elif name == "version": + self._version = value elif property[BundleProperty.REQUIRED]: raise ValueError(f"unsupported property '{name}' is required in the bundle properties.") From b267d7654365ff2fb4077d23d289584afc1eace0 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 28 Aug 2023 13:04:24 +0800 Subject: [PATCH 44/56] fix flake8 Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index af8f3f9209..49bce451e7 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -392,7 +392,7 @@ def load( config_files: Sequence[str] = (), workflow_name: str | BundleWorkflow | None = None, args_file: str | None = None, - copy_model_args: dict = {}, + copy_model_args: dict | None = None, **net_override: Any, ) -> object | tuple[torch.nn.Module, dict, dict] | Any: """ @@ -452,6 +452,8 @@ def load( """ bundle_dir_ = _process_bundle_dir(bundle_dir) + copy_model_args = {} if copy_model_args is None else copy_model_args + if device is None: device = "cuda:0" if is_available() else "cpu" if model_file is None: From 05e51e5f15e1cb07175b0ffe93f1fcd5101c6145 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 28 Aug 2023 18:47:01 +0800 Subject: [PATCH 45/56] set "meta_file" as required in `ConfigWorkflow` Signed-off-by: KumoLiu --- monai/bundle/properties.py | 6 +++--- monai/bundle/workflows.py | 12 ++++-------- tests/nonconfig_workflow.py | 12 ++++++++++++ tests/test_integration_bundle_run.py | 19 +++++++++++++++++-- 4 files changed, 36 insertions(+), 13 deletions(-) diff --git a/monai/bundle/properties.py b/monai/bundle/properties.py index 81e32b89fd..a75e862a84 100644 --- a/monai/bundle/properties.py +++ b/monai/bundle/properties.py @@ -230,17 +230,17 @@ }, "monai_version": { BundleProperty.DESC: "required monai version used for bundle", - BundleProperty.REQUIRED: False, + BundleProperty.REQUIRED: True, BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}monai_version", }, "pytorch_version": { BundleProperty.DESC: "required pytorch version used for bundle", - BundleProperty.REQUIRED: False, + BundleProperty.REQUIRED: True, BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}pytorch_version", }, "numpy_version": { BundleProperty.DESC: "required numpy version used for bundle", - BundleProperty.REQUIRED: False, + BundleProperty.REQUIRED: True, BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}numpy_version", }, "description": { diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 472b5ebfc5..efb48b47b2 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -231,14 +231,10 @@ def __init__( self.parser = ConfigParser() self.parser.read_config(f=config_file) meta_file = str(config_root_path / "metadata.json") if meta_file is None else meta_file - if meta_file is not None: - if isinstance(meta_file, str) and not os.path.exists(meta_file): - if meta_file == str(config_root_path / "metadata.json"): - warnings.warn(f"Default metadata file in {meta_file} does not exist, skipping loading.") - else: - raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.") - else: - self.parser.read_meta(f=meta_file) + if isinstance(meta_file, str) and not os.path.exists(meta_file): + raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.") + else: + self.parser.read_meta(f=meta_file) # the rest key-values in the _args are to override config content self.parser.update(pairs=override) diff --git a/tests/nonconfig_workflow.py b/tests/nonconfig_workflow.py index a0a2e42e2c..6e1d5b58e7 100644 --- a/tests/nonconfig_workflow.py +++ b/tests/nonconfig_workflow.py @@ -51,12 +51,24 @@ def __init__(self, filename, output_dir): self._postprocessing = None self._evaluator = None self._version = None + self._monai_version = None + self._pytorch_version = None + self._numpy_version = None def initialize(self): set_determinism(0) if self._version is None: self._version = "0.1.0" + if self._monai_version is None: + self._monai_version = "1.1.0" + + if self._pytorch_version is None: + self._pytorch_version = "1.13.1" + + if self._numpy_version is None: + self._numpy_version = "1.22.2" + if self._preprocessing is None: self._preprocessing = Compose( [LoadImaged(keys="image"), EnsureChannelFirstd(keys="image"), ScaleIntensityd(keys="image")] diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py index 08ae555ffd..91bafe6eb5 100644 --- a/tests/test_integration_bundle_run.py +++ b/tests/test_integration_bundle_run.py @@ -53,6 +53,7 @@ def tearDown(self): def test_tiny(self): config_file = os.path.join(self.data_dir, "tiny_config.json") + meta_file = os.path.join(self.data_dir, "tiny_meta.json") with open(config_file, "w") as f: json.dump( { @@ -62,14 +63,28 @@ def test_tiny(self): }, f, ) + with open(meta_file, "w") as f: + json.dump( + { + "version": "0.1.0", + "monai_version": "1.1.0", + "pytorch_version": "1.13.1", + "numpy_version": "1.22.2", + }, + f, + ) cmd = ["coverage", "run", "-m", "monai.bundle"] # test both CLI entry "run" and "run_workflow" - command_line_tests(cmd + ["run", "training", "--config_file", config_file]) - command_line_tests(cmd + ["run_workflow", "--run_id", "training", "--config_file", config_file]) + command_line_tests(cmd + ["run", "training", "--config_file", config_file, "--meta_file", meta_file]) + command_line_tests(cmd + ["run_workflow", "--run_id", "training", "--config_file", config_file, "--meta_file", meta_file]) with self.assertRaises(RuntimeError): # test wrong run_id="run" command_line_tests(cmd + ["run", "run", "--config_file", config_file]) + with self.assertRaises(RuntimeError): + # test missing meta file + command_line_tests(cmd + ["run", "training", "--config_file", config_file]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, config_file, expected_shape): test_image = np.random.rand(*expected_shape) From bc3a134d2a44e368a4c5bf0610f7f663b893fd5a Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 28 Aug 2023 18:48:59 +0800 Subject: [PATCH 46/56] fix flake8 Signed-off-by: KumoLiu --- tests/test_integration_bundle_run.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py index 91bafe6eb5..42abc1a5e0 100644 --- a/tests/test_integration_bundle_run.py +++ b/tests/test_integration_bundle_run.py @@ -65,18 +65,15 @@ def test_tiny(self): ) with open(meta_file, "w") as f: json.dump( - { - "version": "0.1.0", - "monai_version": "1.1.0", - "pytorch_version": "1.13.1", - "numpy_version": "1.22.2", - }, + {"version": "0.1.0", "monai_version": "1.1.0", "pytorch_version": "1.13.1", "numpy_version": "1.22.2"}, f, ) cmd = ["coverage", "run", "-m", "monai.bundle"] # test both CLI entry "run" and "run_workflow" command_line_tests(cmd + ["run", "training", "--config_file", config_file, "--meta_file", meta_file]) - command_line_tests(cmd + ["run_workflow", "--run_id", "training", "--config_file", config_file, "--meta_file", meta_file]) + command_line_tests( + cmd + ["run_workflow", "--run_id", "training", "--config_file", config_file, "--meta_file", meta_file] + ) with self.assertRaises(RuntimeError): # test wrong run_id="run" command_line_tests(cmd + ["run", "run", "--config_file", config_file]) From dcd76dd73d56cb58fb62a9439e03894c1ca2e3b3 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 28 Aug 2023 21:35:37 +0800 Subject: [PATCH 47/56] minor fix Signed-off-by: KumoLiu --- tests/nonconfig_workflow.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/nonconfig_workflow.py b/tests/nonconfig_workflow.py index 6e1d5b58e7..13ed3a4ab3 100644 --- a/tests/nonconfig_workflow.py +++ b/tests/nonconfig_workflow.py @@ -136,6 +136,12 @@ def _get_property(self, name, property): return self._postprocessing if name == "version": return self._version + if name == "monai_version": + return self._monai_version + if name == "pytorch_version": + return self._pytorch_version + if name == "numpy_version": + return self._numpy_version if property[BundleProperty.REQUIRED]: raise ValueError(f"unsupported property '{name}' is required in the bundle properties.") @@ -162,5 +168,11 @@ def _set_property(self, name, property, value): self._postprocessing = value elif name == "version": self._version = value + elif name == "monai_version": + self._monai_version = value + elif name == "pytorch_version": + self._pytorch_version = value + elif name == "numpy_version": + self._numpy_version = value elif property[BundleProperty.REQUIRED]: raise ValueError(f"unsupported property '{name}' is required in the bundle properties.") From 8e74d0454db5807fe494e2954078529c44447d28 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 28 Aug 2023 22:33:42 +0800 Subject: [PATCH 48/56] fix unittests Signed-off-by: KumoLiu --- tests/test_bundle_utils.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/test_bundle_utils.py b/tests/test_bundle_utils.py index d92f6e517f..3c28ba2f2f 100644 --- a/tests/test_bundle_utils.py +++ b/tests/test_bundle_utils.py @@ -100,7 +100,19 @@ def test_load_config_zip(self): self.assertEqual(p["test_dict"]["b"], "c") def test_run(self): - command_line_tests(["python", "-m", "monai.bundle", "run", "test", "--test", "$print('hello world')"]) + command_line_tests( + [ + "python", + "-m", + "monai.bundle", + "run", + "test", + "--test", + "$print('hello world')", + "--meta_file", + self.metadata_name, + ] + ) def test_load_config_ts(self): # create a Torchscript zip of the bundle From fe762587b0210e88c33cead38fd1b0077ae82593 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 29 Aug 2023 15:47:59 +0800 Subject: [PATCH 49/56] address comments Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 49bce451e7..17f11064ee 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -741,7 +741,7 @@ def run( logging_file=None, tracking=None, ) - workflow = ConfigWorkflow( + workflow = create_workflow( config_file=config_file_, meta_file=meta_file_, logging_file=logging_file_, @@ -751,7 +751,6 @@ def run( tracking=tracking_, **_args, ) - workflow.initialize() workflow.run() workflow.finalize() @@ -1589,7 +1588,7 @@ def create_workflow( workflow_class = workflow_name else: raise ValueError( - "Argument `workflow` must be a bundle workflow class name" + "Argument `workflow_name` must be a bundle workflow class name" f"or subclass of BundleWorkflow, got: {workflow_name}." ) From e5c8bfd928c68dea5a60f1bca74041ab5bba4486 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 29 Aug 2023 16:40:52 +0800 Subject: [PATCH 50/56] address comments Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 43 +++++++++++--------------------------- tests/test_bundle_utils.py | 2 ++ 2 files changed, 14 insertions(+), 31 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 17f11064ee..1af2bd8188 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -717,40 +717,17 @@ def run( """ - _args = _update_args( - args=args_file, - run_id=run_id, - init_id=init_id, - final_id=final_id, - meta_file=meta_file, + workflow = create_workflow( config_file=config_file, + args_file=args_file, + meta_file=meta_file, logging_file=logging_file, + init_id=init_id, + run_id=run_id, + final_id=final_id, tracking=tracking, **override, ) - if "config_file" not in _args: - warnings.warn("`config_file` not provided for 'monai.bundle run'.") - _log_input_summary(tag="run", args=_args) - config_file_, meta_file_, init_id_, run_id_, final_id_, logging_file_, tracking_ = _pop_args( - _args, - config_file=None, - meta_file=None, - init_id="initialize", - run_id="run", - final_id="finalize", - logging_file=None, - tracking=None, - ) - workflow = create_workflow( - config_file=config_file_, - meta_file=meta_file_, - logging_file=logging_file_, - init_id=init_id_, - run_id=run_id_, - final_id=final_id_, - tracking=tracking_, - **_args, - ) workflow.run() workflow.finalize() @@ -1577,7 +1554,7 @@ def create_workflow( """ _args = _update_args(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs) _log_input_summary(tag="run", args=_args) - (workflow_name,) = _pop_args(_args, workflow_name=ConfigWorkflow) # the default workflow name is "ConfigWorkflow" + (workflow_name, config_file) = _pop_args(_args, workflow_name=ConfigWorkflow, config_file=None) # the default workflow name is "ConfigWorkflow" if isinstance(workflow_name, str): workflow_class, has_built_in = optional_import("monai.bundle", name=str(workflow_name)) # search built-in if not has_built_in: @@ -1592,7 +1569,11 @@ def create_workflow( f"or subclass of BundleWorkflow, got: {workflow_name}." ) - workflow_ = workflow_class(**_args) + if config_file is not None: + workflow_ = workflow_class(config_file=config_file, **_args) + else: + workflow_ = workflow_class(**_args) + workflow_.initialize() return workflow_ diff --git a/tests/test_bundle_utils.py b/tests/test_bundle_utils.py index 3c28ba2f2f..391a56bc3c 100644 --- a/tests/test_bundle_utils.py +++ b/tests/test_bundle_utils.py @@ -109,6 +109,8 @@ def test_run(self): "test", "--test", "$print('hello world')", + "--config_file", + self.test_name, "--meta_file", self.metadata_name, ] From 2789c5f29745fbd13813be1d8134d8d5b080f7dc Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 29 Aug 2023 16:57:21 +0800 Subject: [PATCH 51/56] fix flake8 Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 1af2bd8188..1f082e5f2e 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1554,7 +1554,9 @@ def create_workflow( """ _args = _update_args(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs) _log_input_summary(tag="run", args=_args) - (workflow_name, config_file) = _pop_args(_args, workflow_name=ConfigWorkflow, config_file=None) # the default workflow name is "ConfigWorkflow" + (workflow_name, config_file) = _pop_args( + _args, workflow_name=ConfigWorkflow, config_file=None + ) # the default workflow name is "ConfigWorkflow" if isinstance(workflow_name, str): workflow_class, has_built_in = optional_import("monai.bundle", name=str(workflow_name)) # search built-in if not has_built_in: From 561a68e362026de630bc247f542eb8bccf852c1a Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 30 Aug 2023 10:56:54 +0800 Subject: [PATCH 52/56] add deprecated_arg for `net_name` Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 1f082e5f2e..e6d3e3823e 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -375,6 +375,8 @@ def download( ) +@deprecated_arg("net_name", since="1.3", removed="1.4", msg_suffix="please use ``model`` instead.") +@deprecated_arg("net_kwargs", since="1.3", removed="1.3", msg_suffix="please use ``model`` instead.") def load( name: str, model: torch.nn.Module | None = None, @@ -393,6 +395,7 @@ def load( workflow_name: str | BundleWorkflow | None = None, args_file: str | None = None, copy_model_args: dict | None = None, + net_name: str | None = None, **net_override: Any, ) -> object | tuple[torch.nn.Module, dict, dict] | Any: """ From 0039860de8b02e0e62f6aa935d7bdb88dd5c545e Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 30 Aug 2023 11:14:40 +0800 Subject: [PATCH 53/56] update `worklfow` to `workflow_type` Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 10 +++++----- monai/bundle/workflows.py | 32 ++++++++++++++++--------------- monai/fl/client/monai_algo.py | 4 ++-- tests/nonconfig_workflow.py | 2 +- tests/test_bundle_download.py | 2 +- tests/test_bundle_workflow.py | 4 ++-- tests/test_fl_monai_algo.py | 14 +++++++------- tests/test_fl_monai_algo_dist.py | 8 ++++---- tests/test_fl_monai_algo_stats.py | 4 ++-- tests/test_handler_mlflow.py | 2 +- 10 files changed, 42 insertions(+), 40 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index e6d3e3823e..be6c0caba6 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -381,7 +381,7 @@ def load( name: str, model: torch.nn.Module | None = None, version: str | None = None, - workflow: str = "train", + workflow_type: str = "train", model_file: str | None = None, load_ts_module: bool = False, bundle_dir: PathLike | None = None, @@ -413,7 +413,7 @@ def load( model: a pytorch module to be updated. Default to None, using the "network_def" in the bundle. version: version name of the target bundle to download, like: "0.1.0". If `None`, will download the latest version. - workflow: specifies the workflow type: "train" or "training" for a training workflow, + workflow_type: specifies the workflow type: "train" or "training" for a training workflow, or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. default to `train` for training workflow. @@ -473,14 +473,14 @@ def load( progress=progress, args_file=args_file, ) - train_config_file = bundle_dir_ / name / "configs" / f"{workflow}.json" + train_config_file = bundle_dir_ / name / "configs" / f"{workflow_type}.json" if train_config_file.is_file(): _net_override = {f"network_def#{key}": value for key, value in net_override.items()} _workflow = create_workflow( workflow_name=workflow_name, args_file=args_file, config_file=str(train_config_file), - workflow=workflow, + workflow_type=workflow_type, **_net_override, ) else: @@ -1542,7 +1542,7 @@ def create_workflow( .. code-block:: python # Specify config_file path to create workflow: - workflow = create_workflow(config_file="/workspace/spleen_ct_segmentation/configs/train.json", workflow="train") + workflow = create_workflow(config_file="/workspace/spleen_ct_segmentation/configs/train.json", workflow_type="train") # Set the workflow to other customized BundleWorkflow subclass to create workflow: workflow = create_workflow(workflow_name=CustomizedWorkflow) diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index efb48b47b2..5f23c2ea58 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -24,7 +24,7 @@ from monai.bundle.config_parser import ConfigParser from monai.bundle.properties import InferProperties, MetaProperties, TrainProperties from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY -from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg_default, ensure_tuple +from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg_default, ensure_tuple, deprecated_arg __all__ = ["BundleWorkflow", "ConfigWorkflow"] @@ -38,7 +38,7 @@ class BundleWorkflow(ABC): And also provides the interface to get / set public properties to interact with a bundle workflow. Args: - workflow: specifies the workflow type: "train" or "training" for a training workflow, + workflow_type: specifies the workflow type: "train" or "training" for a training workflow, or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. default to `None` for common workflow. @@ -48,19 +48,20 @@ class BundleWorkflow(ABC): supported_train_type: tuple = ("train", "training") supported_infer_type: tuple = ("infer", "inference", "eval", "evaluation") - def __init__(self, workflow: str | None = None): - if workflow is None: + @deprecated_arg("workflow", since="1.3", removed="1.5", new_name="workflow_type", msg_suffix="please use `workflow_type` instead.") + def __init__(self, workflow_type: str | None = None): + if workflow_type is None: self.properties = copy(MetaProperties) - self.workflow = None + self.workflow_type = None return - if workflow.lower() in self.supported_train_type: + if workflow_type.lower() in self.supported_train_type: self.properties = {**TrainProperties, **MetaProperties} - self.workflow = "train" - elif workflow.lower() in self.supported_infer_type: + self.workflow_type = "train" + elif workflow_type.lower() in self.supported_infer_type: self.properties = {**InferProperties, **MetaProperties} - self.workflow = "infer" + self.workflow_type = "infer" else: - raise ValueError(f"Unsupported workflow type: '{workflow}'.") + raise ValueError(f"Unsupported workflow type: '{workflow_type}'.") @abstractmethod def initialize(self, *args: Any, **kwargs: Any) -> Any: @@ -128,7 +129,7 @@ def get_workflow_type(self): Get the workflow type, it can be `None`, "train", or "infer". """ - return self.workflow + return self.workflow_type def add_property(self, name: str, required: str, desc: str | None = None) -> None: """ @@ -185,7 +186,7 @@ class ConfigWorkflow(BundleWorkflow): will patch the target config content with `tracking handlers` and the top-level items of `configs`. for detailed usage examples, please check the tutorial: https://github.com/Project-MONAI/tutorials/blob/main/experiment_management/bundle_integrate_mlflow.ipynb. - workflow: specifies the workflow type: "train" or "training" for a training workflow, + workflow_type: specifies the workflow type: "train" or "training" for a training workflow, or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. default to `None` for common workflow. @@ -194,7 +195,8 @@ class ConfigWorkflow(BundleWorkflow): """ - @deprecated_arg_default("workflow", None, "train", since="1.3", replaced="1.4") + @deprecated_arg("workflow", since="1.3", removed="1.5", new_name="workflow_type", msg_suffix="please use `workflow_type` instead.") + @deprecated_arg_default("workflow_type", None, "train", since="1.3", replaced="1.4") def __init__( self, config_file: str | Sequence[str], @@ -204,10 +206,10 @@ def __init__( run_id: str = "run", final_id: str = "finalize", tracking: str | dict | None = None, - workflow: str | None = None, + workflow_type: str | None = None, **override: Any, ) -> None: - super().__init__(workflow=workflow) + super().__init__(workflow_type=workflow_type) if config_file is not None: _config_path = Path(ensure_tuple(config_file)[0]) if _config_path.is_file(): diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py index 4838b784e7..fc79d7c420 100644 --- a/monai/fl/client/monai_algo.py +++ b/monai/fl/client/monai_algo.py @@ -149,7 +149,7 @@ def initialize(self, extra=None): if self.workflow is None: config_train_files = self._add_config_files(self.config_train_filename) self.workflow = ConfigWorkflow( - config_file=config_train_files, meta_file=None, logging_file=None, workflow="train" + config_file=config_train_files, meta_file=None, logging_file=None, workflow_type="train" ) self.workflow.initialize() self.workflow.bundle_root = self.bundle_root @@ -431,7 +431,7 @@ def initialize(self, extra=None): if "run_name" not in self.train_kwargs: self.train_kwargs["run_name"] = f"{self.client_name}_{timestamp}" self.train_workflow = ConfigWorkflow( - config_file=config_train_files, meta_file=None, logging_file=None, workflow="train", **self.train_kwargs + config_file=config_train_files, meta_file=None, logging_file=None, workflow_type="train", **self.train_kwargs ) if self.train_workflow is not None: self.train_workflow.initialize() diff --git a/tests/nonconfig_workflow.py b/tests/nonconfig_workflow.py index 13ed3a4ab3..7b5328bf72 100644 --- a/tests/nonconfig_workflow.py +++ b/tests/nonconfig_workflow.py @@ -37,7 +37,7 @@ class NonConfigWorkflow(BundleWorkflow): """ def __init__(self, filename, output_dir): - super().__init__(workflow="inference") + super().__init__(workflow_type="inference") self.filename = filename self.output_dir = output_dir self._bundle_root = "will override" diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index 8299acc6fd..2457af3229 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -195,7 +195,7 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override) output = model(input_tensor) model_path = f"{tempdir}/spleen_ct_segmentation/models/model.pt" workflow = create_workflow( - config_file=f"{tempdir}/spleen_ct_segmentation/configs/train.json", workflow="train" + config_file=f"{tempdir}/spleen_ct_segmentation/configs/train.json", workflow_type="train" ) expected_model = workflow.network_def.to(device) expected_model.load_state_dict(torch.load(model_path)) diff --git a/tests/test_bundle_workflow.py b/tests/test_bundle_workflow.py index 247ed5ecd4..4291eedf3f 100644 --- a/tests/test_bundle_workflow.py +++ b/tests/test_bundle_workflow.py @@ -95,7 +95,7 @@ def test_inference_config(self, config_file): } # test standard MONAI model-zoo config workflow inferer = ConfigWorkflow( - workflow="infer", + workflow_type="infer", config_file=config_file, logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"), **override, @@ -106,7 +106,7 @@ def test_inference_config(self, config_file): def test_train_config(self, config_file): # test standard MONAI model-zoo config workflow trainer = ConfigWorkflow( - workflow="train", + workflow_type="train", config_file=config_file, logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"), init_id="initialize", diff --git a/tests/test_fl_monai_algo.py b/tests/test_fl_monai_algo.py index 026f7ca8b8..e649e7374f 100644 --- a/tests/test_fl_monai_algo.py +++ b/tests/test_fl_monai_algo.py @@ -36,7 +36,7 @@ { "bundle_root": _data_dir, "train_workflow": ConfigWorkflow( - config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file + config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow_type="train", logging_file=_logging_file ), "config_evaluate_filename": None, "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"), @@ -54,7 +54,7 @@ { "bundle_root": _data_dir, "train_workflow": ConfigWorkflow( - config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file + config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow_type="train", logging_file=_logging_file ), "config_evaluate_filename": None, "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"), @@ -66,7 +66,7 @@ "bundle_root": _data_dir, "train_workflow": ConfigWorkflow( config_file=os.path.join(_data_dir, "config_fl_train.json"), - workflow="train", + workflow_type="train", logging_file=_logging_file, tracking={ "handlers_id": DEFAULT_HANDLERS_ID, @@ -95,7 +95,7 @@ os.path.join(_data_dir, "config_fl_train.json"), os.path.join(_data_dir, "config_fl_evaluate.json"), ], - workflow="train", + workflow_type="train", logging_file=_logging_file, tracking="mlflow", tracking_uri=path_to_uri(_data_dir) + "/mlflow_1", @@ -130,7 +130,7 @@ os.path.join(_data_dir, "config_fl_train.json"), os.path.join(_data_dir, "config_fl_evaluate.json"), ], - workflow="train", + workflow_type="train", logging_file=_logging_file, ), "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"), @@ -141,7 +141,7 @@ { "bundle_root": _data_dir, "train_workflow": ConfigWorkflow( - config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file + config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow_type="train", logging_file=_logging_file ), "config_evaluate_filename": None, "send_weight_diff": False, @@ -161,7 +161,7 @@ { "bundle_root": _data_dir, "train_workflow": ConfigWorkflow( - config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file + config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow_type="train", logging_file=_logging_file ), "config_evaluate_filename": None, "send_weight_diff": True, diff --git a/tests/test_fl_monai_algo_dist.py b/tests/test_fl_monai_algo_dist.py index f6dc626ad9..1302ab6618 100644 --- a/tests/test_fl_monai_algo_dist.py +++ b/tests/test_fl_monai_algo_dist.py @@ -41,15 +41,15 @@ def test_train(self): pathjoin(_data_dir, "config_fl_evaluate.json"), pathjoin(_data_dir, "multi_gpu_evaluate.json"), ] - train_workflow = ConfigWorkflow(config_file=train_configs, workflow="train", logging_file=_logging_file) + train_workflow = ConfigWorkflow(config_file=train_configs, workflow_type="train", logging_file=_logging_file) # simulate the case that this application has specific requirements for a bundle workflow train_workflow.add_property(name="loader", required=True, config_id="train#training_transforms#0", desc="NA") # initialize algo algo = MonaiAlgo( bundle_root=_data_dir, - train_workflow=ConfigWorkflow(config_file=train_configs, workflow="train", logging_file=_logging_file), - eval_workflow=ConfigWorkflow(config_file=eval_configs, workflow="train", logging_file=_logging_file), + train_workflow=ConfigWorkflow(config_file=train_configs, workflow_type="train", logging_file=_logging_file), + eval_workflow=ConfigWorkflow(config_file=eval_configs, workflow_type="train", logging_file=_logging_file), config_filters_filename=pathjoin(_root_dir, "testing_data", "config_fl_filters.json"), ) algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"}) @@ -90,7 +90,7 @@ def test_evaluate(self): algo = MonaiAlgo( bundle_root=_data_dir, config_train_filename=None, - eval_workflow=ConfigWorkflow(config_file=config_file, workflow="train", logging_file=_logging_file), + eval_workflow=ConfigWorkflow(config_file=config_file, workflow_type="train", logging_file=_logging_file), config_filters_filename=pathjoin(_data_dir, "config_fl_filters.json"), ) algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"}) diff --git a/tests/test_fl_monai_algo_stats.py b/tests/test_fl_monai_algo_stats.py index e46b6b899a..307b3f539c 100644 --- a/tests/test_fl_monai_algo_stats.py +++ b/tests/test_fl_monai_algo_stats.py @@ -30,7 +30,7 @@ { "bundle_root": _data_dir, "workflow": ConfigWorkflow( - workflow="train", + workflow_type="train", config_file=os.path.join(_data_dir, "config_fl_stats_1.json"), logging_file=_logging_file, meta_file=None, @@ -49,7 +49,7 @@ { "bundle_root": _data_dir, "workflow": ConfigWorkflow( - workflow="train", + workflow_type="train", config_file=[ os.path.join(_data_dir, "config_fl_stats_1.json"), os.path.join(_data_dir, "config_fl_stats_2.json"), diff --git a/tests/test_handler_mlflow.py b/tests/test_handler_mlflow.py index d5578c01bc..92cf17eadb 100644 --- a/tests/test_handler_mlflow.py +++ b/tests/test_handler_mlflow.py @@ -255,7 +255,7 @@ def test_dataset_tracking(self): meta_file = os.path.join(bundle_root, "configs/metadata.json") logging_file = os.path.join(bundle_root, "configs/logging.conf") workflow = ConfigWorkflow( - workflow="infer", + workflow_type="infer", config_file=config_file, meta_file=meta_file, logging_file=logging_file, From 134d1ea3b613d5f80f6979dee0f4edcf073976ad Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 30 Aug 2023 11:22:02 +0800 Subject: [PATCH 54/56] fix flake8 Signed-off-by: KumoLiu --- monai/bundle/workflows.py | 18 +++++++++++++++--- monai/fl/client/monai_algo.py | 6 +++++- tests/test_fl_monai_algo.py | 16 ++++++++++++---- 3 files changed, 32 insertions(+), 8 deletions(-) diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 5f23c2ea58..49d65a634e 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -24,7 +24,7 @@ from monai.bundle.config_parser import ConfigParser from monai.bundle.properties import InferProperties, MetaProperties, TrainProperties from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY -from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg_default, ensure_tuple, deprecated_arg +from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg, deprecated_arg_default, ensure_tuple __all__ = ["BundleWorkflow", "ConfigWorkflow"] @@ -48,7 +48,13 @@ class BundleWorkflow(ABC): supported_train_type: tuple = ("train", "training") supported_infer_type: tuple = ("infer", "inference", "eval", "evaluation") - @deprecated_arg("workflow", since="1.3", removed="1.5", new_name="workflow_type", msg_suffix="please use `workflow_type` instead.") + @deprecated_arg( + "workflow", + since="1.3", + removed="1.5", + new_name="workflow_type", + msg_suffix="please use `workflow_type` instead.", + ) def __init__(self, workflow_type: str | None = None): if workflow_type is None: self.properties = copy(MetaProperties) @@ -195,7 +201,13 @@ class ConfigWorkflow(BundleWorkflow): """ - @deprecated_arg("workflow", since="1.3", removed="1.5", new_name="workflow_type", msg_suffix="please use `workflow_type` instead.") + @deprecated_arg( + "workflow", + since="1.3", + removed="1.5", + new_name="workflow_type", + msg_suffix="please use `workflow_type` instead.", + ) @deprecated_arg_default("workflow_type", None, "train", since="1.3", replaced="1.4") def __init__( self, diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py index fc79d7c420..933be96e34 100644 --- a/monai/fl/client/monai_algo.py +++ b/monai/fl/client/monai_algo.py @@ -431,7 +431,11 @@ def initialize(self, extra=None): if "run_name" not in self.train_kwargs: self.train_kwargs["run_name"] = f"{self.client_name}_{timestamp}" self.train_workflow = ConfigWorkflow( - config_file=config_train_files, meta_file=None, logging_file=None, workflow_type="train", **self.train_kwargs + config_file=config_train_files, + meta_file=None, + logging_file=None, + workflow_type="train", + **self.train_kwargs, ) if self.train_workflow is not None: self.train_workflow.initialize() diff --git a/tests/test_fl_monai_algo.py b/tests/test_fl_monai_algo.py index e649e7374f..ca781ff166 100644 --- a/tests/test_fl_monai_algo.py +++ b/tests/test_fl_monai_algo.py @@ -36,7 +36,9 @@ { "bundle_root": _data_dir, "train_workflow": ConfigWorkflow( - config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow_type="train", logging_file=_logging_file + config_file=os.path.join(_data_dir, "config_fl_train.json"), + workflow_type="train", + logging_file=_logging_file, ), "config_evaluate_filename": None, "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"), @@ -54,7 +56,9 @@ { "bundle_root": _data_dir, "train_workflow": ConfigWorkflow( - config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow_type="train", logging_file=_logging_file + config_file=os.path.join(_data_dir, "config_fl_train.json"), + workflow_type="train", + logging_file=_logging_file, ), "config_evaluate_filename": None, "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"), @@ -141,7 +145,9 @@ { "bundle_root": _data_dir, "train_workflow": ConfigWorkflow( - config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow_type="train", logging_file=_logging_file + config_file=os.path.join(_data_dir, "config_fl_train.json"), + workflow_type="train", + logging_file=_logging_file, ), "config_evaluate_filename": None, "send_weight_diff": False, @@ -161,7 +167,9 @@ { "bundle_root": _data_dir, "train_workflow": ConfigWorkflow( - config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow_type="train", logging_file=_logging_file + config_file=os.path.join(_data_dir, "config_fl_train.json"), + workflow_type="train", + logging_file=_logging_file, ), "config_evaluate_filename": None, "send_weight_diff": True, From e93309e46df58a0ab7098be978f9fb94b99a07f7 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 30 Aug 2023 13:24:20 +0800 Subject: [PATCH 55/56] fix ci Signed-off-by: KumoLiu --- monai/fl/client/monai_algo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py index 933be96e34..626bc9651d 100644 --- a/monai/fl/client/monai_algo.py +++ b/monai/fl/client/monai_algo.py @@ -317,12 +317,12 @@ class MonaiAlgo(ClientAlgo, MonaiAlgoStats): config_train_filename: bundle training config path relative to bundle_root. can be a list of files. defaults to "configs/train.json". only useful when `train_workflow` is None. train_kwargs: other args of the `ConfigWorkflow` of train, except for `config_file`, `meta_file`, - `logging_file`, `workflow`. only useful when `train_workflow` is None. + `logging_file`, `workflow_type`. only useful when `train_workflow` is None. config_evaluate_filename: bundle evaluation config path relative to bundle_root. can be a list of files. if "default", ["configs/train.json", "configs/evaluate.json"] will be used. this arg is only useful when `eval_workflow` is None. eval_kwargs: other args of the `ConfigWorkflow` of evaluation, except for `config_file`, `meta_file`, - `logging_file`, `workflow`. only useful when `eval_workflow` is None. + `logging_file`, `workflow_type`. only useful when `eval_workflow` is None. config_filters_filename: filter configuration file. Can be a list of files; defaults to `None`. disable_ckpt_loading: do not use any CheckpointLoader if defined in train/evaluate configs; defaults to `True`. best_model_filepath: location of best model checkpoint; defaults "models/model.pt" relative to `bundle_root`. @@ -459,7 +459,7 @@ def initialize(self, extra=None): config_file=config_eval_files, meta_file=None, logging_file=None, - workflow=self.eval_workflow_name, + workflow_type=self.eval_workflow_name, **self.eval_kwargs, ) if self.eval_workflow is not None: From edb8768b9fb1ad8f0d39e4419f99373d3eaaf4b1 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 30 Aug 2023 13:34:19 +0800 Subject: [PATCH 56/56] address comments Signed-off-by: KumoLiu --- monai/bundle/workflows.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 49d65a634e..3b349e1103 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -223,11 +223,17 @@ def __init__( ) -> None: super().__init__(workflow_type=workflow_type) if config_file is not None: - _config_path = Path(ensure_tuple(config_file)[0]) - if _config_path.is_file(): - config_root_path = _config_path.parent - else: - raise FileNotFoundError(f"Cannot find the config file: {config_file}.") + _config_files = ensure_tuple(config_file) + config_root_path = Path(_config_files[0]).parent + for _config_file in _config_files: + _config_file = Path(_config_file) + if _config_file.parent != config_root_path: + warnings.warn( + f"Not all config files are in {config_root_path}. If logging_file and meta_file are" + f"not specified, {config_root_path} will be used as the default config root directory." + ) + if not _config_file.is_file(): + raise FileNotFoundError(f"Cannot find the config file: {_config_file}.") else: config_root_path = Path("configs")