From 2fbafac3884f63b111b4edfb35583347198199f8 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 6 Sep 2023 17:52:37 +0800 Subject: [PATCH 01/11] avoid breaking in create `BundleWorkflow` Signed-off-by: KumoLiu --- monai/bundle/workflows.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 8d53f2e88c..5f34578f7b 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -43,6 +43,10 @@ class BundleWorkflow(ABC): or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. default to `None` for common workflow. + workflow: specifies the workflow type: "train" or "training" for a training workflow, + or "infer", "inference", "eval", "evaluation" for a inference workflow, + other unsupported string will raise a ValueError. + default to `None` for common workflow. """ @@ -56,7 +60,8 @@ class BundleWorkflow(ABC): new_name="workflow_type", msg_suffix="please use `workflow_type` instead.", ) - def __init__(self, workflow_type: str | None = None): + def __init__(self, workflow_type: str | None = None, workflow: str | None = None): + workflow_type = workflow if workflow is not None else workflow_type if workflow_type is None: self.properties = copy(MetaProperties) self.workflow_type = None @@ -198,6 +203,10 @@ class ConfigWorkflow(BundleWorkflow): or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. default to `None` for common workflow. + workflow: specifies the workflow type: "train" or "training" for a training workflow, + or "infer", "inference", "eval", "evaluation" for a inference workflow, + other unsupported string will raise a ValueError. + default to `None` for common workflow. override: id-value pairs to override or add the corresponding config content. e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg`` @@ -221,8 +230,10 @@ def __init__( final_id: str = "finalize", tracking: str | dict | None = None, workflow_type: str | None = None, + workflow: str | None = None, **override: Any, ) -> None: + workflow_type = workflow if workflow is not None else workflow_type super().__init__(workflow_type=workflow_type) if config_file is not None: _config_files = ensure_tuple(config_file) From 5b81e88a2e78baebb264132cd8d7bc1865dd4194 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 6 Sep 2023 22:41:34 +0800 Subject: [PATCH 02/11] avoid breaking changes in `load` Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 54 ++++++++++++++++++++++++++--------- tests/test_bundle_download.py | 22 ++++++++++++-- 2 files changed, 61 insertions(+), 15 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index cdea2b4218..f3625e87ee 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -28,6 +28,7 @@ from monai.apps.mmars.mmars import _get_all_ngc_models from monai.apps.utils import _basename, download_url, extractall, get_logger +from monai.bundle.config_item import ConfigComponent from monai.bundle.config_parser import ConfigParser from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow @@ -247,7 +248,7 @@ def _process_bundle_dir(bundle_dir: PathLike | None = None) -> Path: return Path(bundle_dir) -@deprecated_arg_default("source", "github", "monaihosting", since="1.3", replaced="1.4") +@deprecated_arg_default("source", "github", "monaihosting", since="1.3", replaced="1.5") def download( name: str | None = None, version: str | None = None, @@ -375,8 +376,9 @@ def download( ) -@deprecated_arg("net_name", since="1.3", removed="1.4", msg_suffix="please use ``model`` instead.") -@deprecated_arg("net_kwargs", since="1.3", removed="1.3", msg_suffix="please use ``model`` instead.") +@deprecated_arg("net_name", since="1.3", removed="1.5", msg_suffix="please use ``model`` instead.") +@deprecated_arg("net_kwargs", since="1.3", removed="1.5", msg_suffix="please use ``model`` instead.") +@deprecated_arg_default("return_state_dict", "True", "False", since="1.3", replaced="1.5") def load( name: str, model: torch.nn.Module | None = None, @@ -395,8 +397,10 @@ def load( workflow_name: str | BundleWorkflow | None = None, args_file: str | None = None, copy_model_args: dict | None = None, + return_state_dict: bool = True, + net_override: dict = {}, net_name: str | None = None, - **net_override: Any, + **net_kwargs: Any ) -> object | tuple[torch.nn.Module, dict, dict] | Any: """ Load model weights or TorchScript module of a bundle. @@ -441,7 +445,12 @@ def load( workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow". args_file: a JSON or YAML file to provide default values for all the args in "download" function. copy_model_args: other arguments for the `monai.networks.copy_model_state` function. + return_state_dict: whether to return state dict, if True, return state_dict, else a corresponding network + from `_workflow.network_def` will be instantiated and load the achieved weights. net_override: id-value pairs to override the parameters in the network of the bundle. + net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights. + This argument only works when loading weights. + net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`. Returns: 1. If `load_ts_module` is `False` and `model` is `None`, @@ -449,9 +458,10 @@ def load( else return an instantiated network that loaded the weights. 2. If `load_ts_module` is `False` and `model` is not `None`, return an instantiated network that loaded the weights. - 3. If `load_ts_module` is `True`, return a triple that include a TorchScript module, + 1. If `load_ts_module` is `True`, return a triple that include a TorchScript module, the corresponding metadata dict, and extra files dict. please check `monai.data.load_net_with_metadata` for more details. + 2. If `model` is `None`, """ bundle_dir_ = _process_bundle_dir(bundle_dir) @@ -466,7 +476,7 @@ def load( if remove_prefix: name = _remove_ngc_prefix(name, prefix=remove_prefix) full_path = os.path.join(bundle_dir_, name, model_file) - if not os.path.exists(full_path) or model is None: + if not os.path.exists(full_path): download( name=name, version=version, @@ -477,6 +487,14 @@ def load( progress=progress, args_file=args_file, ) + + # loading with `torch.jit.load` + if load_ts_module is True: + return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files) + # loading with `torch.load` + model_dict = torch.load(full_path, map_location=torch.device(device)) + + if model is None and not return_state_dict: train_config_file = bundle_dir_ / name / "configs" / f"{workflow_type}.json" if train_config_file.is_file(): _net_override = {f"network_def#{key}": value for key, value in net_override.items()} @@ -490,21 +508,31 @@ def load( else: _workflow = None - # loading with `torch.jit.load` - if load_ts_module is True: - return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files) - # loading with `torch.load` - model_dict = torch.load(full_path, map_location=torch.device(device)) if not isinstance(model_dict, Mapping): warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.") model_dict = get_state_dict(model_dict) - if model is None and _workflow is None: + if (model is None and _workflow is None) or return_state_dict: return model_dict - model = _workflow.network_def if model is None else model + + if model is None: + if return_state_dict: + return model_dict + elif _workflow is not None: + if getattr(_workflow, "network_def") is None: + warnings.warn("No available network definition in the bundle, return state dict instead.") + return model_dict + else: + model = _workflow.network_def + elif net_name is not None: + net_kwargs["_target_"] = net_name + configer = ConfigComponent(config=net_kwargs) + model = configer.instantiate() + model.to(device) copy_model_state(dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], **copy_model_args) + return model diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index 2457af3229..499f923325 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -146,6 +146,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) source="github", progress=False, device=device, + return_state_dict=False, ) # prepare network @@ -176,11 +177,27 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) device=device, net_name=model_name, source="github", + return_state_dict=False, ) model_2.eval() output_2 = model_2.forward(input_tensor) assert_allclose(output_2, expected_output, atol=1e-4, rtol=1e-4, type_test=False) + # test forward compatibility with return_state_dict=True. + model_3 = load( + name=bundle_name, + model_file=model_file, + bundle_dir=tempdir, + progress=False, + device=device, + net_name=model_name, + source="github", + **net_args, + ) + model_3.eval() + output_3 = model_3.forward(input_tensor) + assert_allclose(output_3, expected_output, atol=1e-4, rtol=1e-4, type_test=False) + @parameterized.expand([TEST_CASE_7]) @skip_if_quick def test_load_weights_with_net_override(self, bundle_name, device, net_override): @@ -188,7 +205,7 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override) # download bundle, and load weights from the downloaded path with tempfile.TemporaryDirectory() as tempdir: # load weights - model = load(name=bundle_name, bundle_dir=tempdir, source="monaihosting", progress=False, device=device) + model = load(name=bundle_name, bundle_dir=tempdir, source="monaihosting", progress=False, device=device, return_state_dict=False) # prepare data and test input_tensor = torch.rand(1, 1, 96, 96, 96).to(device) @@ -209,7 +226,8 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override) source="monaihosting", progress=False, device=device, - **net_override, + return_state_dict=False, + net_override=net_override, ) # prepare data and test From 02d1ab259ecbe5841cebfc2e8d64544a5a3bd0b8 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 7 Sep 2023 00:39:20 +0800 Subject: [PATCH 03/11] add unittests Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 19 ++++++++++--------- tests/test_bundle_download.py | 6 +++--- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index f3625e87ee..0f2c846a4a 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -458,12 +458,16 @@ def load( else return an instantiated network that loaded the weights. 2. If `load_ts_module` is `False` and `model` is not `None`, return an instantiated network that loaded the weights. - 1. If `load_ts_module` is `True`, return a triple that include a TorchScript module, + 3. If `load_ts_module` is `True`, return a triple that include a TorchScript module, the corresponding metadata dict, and extra files dict. please check `monai.data.load_net_with_metadata` for more details. - 2. If `model` is `None`, + 4. If `return_state_dict` is True, return model weights, only used for compatibility + when `model` and `net_name` are all `None`. """ + if return_state_dict and (model is not None or net_name is not None): + warnings.warn("Incompatible values: model and net_name are all specified, return state dict instead.") + bundle_dir_ = _process_bundle_dir(bundle_dir) copy_model_args = {} if copy_model_args is None else copy_model_args @@ -494,7 +498,8 @@ def load( # loading with `torch.load` model_dict = torch.load(full_path, map_location=torch.device(device)) - if model is None and not return_state_dict: + _workflow = None + if (model is None and not return_state_dict) or net_name is None: train_config_file = bundle_dir_ / name / "configs" / f"{workflow_type}.json" if train_config_file.is_file(): _net_override = {f"network_def#{key}": value for key, value in net_override.items()} @@ -505,20 +510,16 @@ def load( workflow_type=workflow_type, **_net_override, ) - else: - _workflow = None if not isinstance(model_dict, Mapping): warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.") model_dict = get_state_dict(model_dict) - if (model is None and _workflow is None) or return_state_dict: + if return_state_dict: return model_dict if model is None: - if return_state_dict: - return model_dict - elif _workflow is not None: + if _workflow is not None: if getattr(_workflow, "network_def") is None: warnings.warn("No available network definition in the bundle, return state dict instead.") return model_dict diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index 499f923325..2fe1a8a19d 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -146,7 +146,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) source="github", progress=False, device=device, - return_state_dict=False, + return_state_dict=True, ) # prepare network @@ -175,7 +175,6 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) bundle_dir=tempdir, progress=False, device=device, - net_name=model_name, source="github", return_state_dict=False, ) @@ -183,7 +182,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) output_2 = model_2.forward(input_tensor) assert_allclose(output_2, expected_output, atol=1e-4, rtol=1e-4, type_test=False) - # test forward compatibility with return_state_dict=True. + # test compatibility with return_state_dict=True. model_3 = load( name=bundle_name, model_file=model_file, @@ -192,6 +191,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) device=device, net_name=model_name, source="github", + return_state_dict=False, **net_args, ) model_3.eval() From 75dd8e96c938b969d897133e657fe05f82ef73d9 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 7 Sep 2023 00:42:34 +0800 Subject: [PATCH 04/11] fix flake8 Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 2 +- tests/ngc_bundle_download.py | 2 +- tests/test_bundle_download.py | 9 ++++++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 0f2c846a4a..4d8e171010 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -400,7 +400,7 @@ def load( return_state_dict: bool = True, net_override: dict = {}, net_name: str | None = None, - **net_kwargs: Any + **net_kwargs: Any, ) -> object | tuple[torch.nn.Module, dict, dict] | Any: """ Load model weights or TorchScript module of a bundle. diff --git a/tests/ngc_bundle_download.py b/tests/ngc_bundle_download.py index f699914f6a..60fdee80d2 100644 --- a/tests/ngc_bundle_download.py +++ b/tests/ngc_bundle_download.py @@ -83,7 +83,7 @@ def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download self.assertTrue(check_hash(filepath=full_file_path, val=hash_val)) model = load( - name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix + name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix, return_state_dict=False ) assert_allclose( model.state_dict()[TESTCASE_WEIGHTS["key"]], diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index 2fe1a8a19d..d43cf3b9c0 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -205,7 +205,14 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override) # download bundle, and load weights from the downloaded path with tempfile.TemporaryDirectory() as tempdir: # load weights - model = load(name=bundle_name, bundle_dir=tempdir, source="monaihosting", progress=False, device=device, return_state_dict=False) + model = load( + name=bundle_name, + bundle_dir=tempdir, + source="monaihosting", + progress=False, + device=device, + return_state_dict=False, + ) # prepare data and test input_tensor = torch.rand(1, 1, 96, 96, 96).to(device) From 403fce52dd53abaddfbb233c5099796b3192ba02 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 7 Sep 2023 00:45:23 +0800 Subject: [PATCH 05/11] fix mypy Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 4d8e171010..80d6c91a19 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -378,7 +378,7 @@ def download( @deprecated_arg("net_name", since="1.3", removed="1.5", msg_suffix="please use ``model`` instead.") @deprecated_arg("net_kwargs", since="1.3", removed="1.5", msg_suffix="please use ``model`` instead.") -@deprecated_arg_default("return_state_dict", "True", "False", since="1.3", replaced="1.5") +@deprecated_arg("return_state_dict", since="1.3", removed="1.5") def load( name: str, model: torch.nn.Module | None = None, @@ -528,11 +528,11 @@ def load( elif net_name is not None: net_kwargs["_target_"] = net_name configer = ConfigComponent(config=net_kwargs) - model = configer.instantiate() + model = configer.instantiate() # type: ignore - model.to(device) + model.to(device) # type: ignore - copy_model_state(dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], **copy_model_args) + copy_model_state(dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], **copy_model_args) # type: ignore return model From d14ab24c083b221dec2e9468f3242b009c827ee1 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 7 Sep 2023 00:47:38 +0800 Subject: [PATCH 06/11] fix flake8 Signed-off-by: KumoLiu --- tests/ngc_bundle_download.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/ngc_bundle_download.py b/tests/ngc_bundle_download.py index 60fdee80d2..ba35f2b80c 100644 --- a/tests/ngc_bundle_download.py +++ b/tests/ngc_bundle_download.py @@ -83,7 +83,12 @@ def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download self.assertTrue(check_hash(filepath=full_file_path, val=hash_val)) model = load( - name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix, return_state_dict=False + name=bundle_name, + source="ngc", + version=version, + bundle_dir=tempdir, + remove_prefix=remove_prefix, + return_state_dict=False, ) assert_allclose( model.state_dict()[TESTCASE_WEIGHTS["key"]], From 416c008b7b592f59a548818172b106a33e42195b Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 7 Sep 2023 00:58:49 +0800 Subject: [PATCH 07/11] minor fix Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 80d6c91a19..2ed29b64e1 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -498,37 +498,38 @@ def load( # loading with `torch.load` model_dict = torch.load(full_path, map_location=torch.device(device)) + if not isinstance(model_dict, Mapping): + warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.") + model_dict = get_state_dict(model_dict) + + if return_state_dict: + return model_dict + _workflow = None - if (model is None and not return_state_dict) or net_name is None: - train_config_file = bundle_dir_ / name / "configs" / f"{workflow_type}.json" - if train_config_file.is_file(): + if model is None and net_name is None: + bundle_config_file = bundle_dir_ / name / "configs" / f"{workflow_type}.json" + if bundle_config_file.is_file(): _net_override = {f"network_def#{key}": value for key, value in net_override.items()} _workflow = create_workflow( workflow_name=workflow_name, args_file=args_file, - config_file=str(train_config_file), + config_file=str(bundle_config_file), workflow_type=workflow_type, **_net_override, ) - - if not isinstance(model_dict, Mapping): - warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.") - model_dict = get_state_dict(model_dict) - - if return_state_dict: - return model_dict - - if model is None: + else: + warnings.warn(f"Cannot find the config file: {bundle_config_file}, return state dict instead.") + return model_dict if _workflow is not None: if getattr(_workflow, "network_def") is None: warnings.warn("No available network definition in the bundle, return state dict instead.") return model_dict else: model = _workflow.network_def - elif net_name is not None: - net_kwargs["_target_"] = net_name - configer = ConfigComponent(config=net_kwargs) - model = configer.instantiate() # type: ignore + elif net_name is not None: + net_kwargs["_target_"] = net_name + configer = ConfigComponent(config=net_kwargs) + model = configer.instantiate() # type: ignore model.to(device) # type: ignore From 929186253f610c944c720868b6448c9e94b41f20 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 7 Sep 2023 01:00:09 +0800 Subject: [PATCH 08/11] fix flake8 Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 2ed29b64e1..d1c8b0ce3e 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -533,7 +533,11 @@ def load( model.to(device) # type: ignore - copy_model_state(dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], **copy_model_args) # type: ignore + copy_model_state( # type: ignore + dst=model, + src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], + **copy_model_args + ) return model From df961144e5756051c81181368d8177d4c808dfa7 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 7 Sep 2023 01:01:56 +0800 Subject: [PATCH 09/11] fix mypy Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index d1c8b0ce3e..a689cb107e 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -533,10 +533,8 @@ def load( model.to(device) # type: ignore - copy_model_state( # type: ignore - dst=model, - src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], - **copy_model_args + copy_model_state( + dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], **copy_model_args # type: ignore ) return model From d7d2737a34ab35bc1bbb81b9cf6b198c5e94c82e Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 7 Sep 2023 10:14:23 +0800 Subject: [PATCH 10/11] Update monai/bundle/scripts.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/bundle/scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index a689cb107e..7b7daf6c0a 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -521,7 +521,7 @@ def load( warnings.warn(f"Cannot find the config file: {bundle_config_file}, return state dict instead.") return model_dict if _workflow is not None: - if getattr(_workflow, "network_def") is None: + if not hasattr(_workflow, "network_def"): warnings.warn("No available network definition in the bundle, return state dict instead.") return model_dict else: From 8a631692f1237527fbd4f1ee78cd08b0d4bf6f99 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 7 Sep 2023 14:37:52 +0800 Subject: [PATCH 11/11] fix flake8 Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 7b7daf6c0a..fc8dafbc77 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -398,7 +398,7 @@ def load( args_file: str | None = None, copy_model_args: dict | None = None, return_state_dict: bool = True, - net_override: dict = {}, + net_override: dict | None = None, net_name: str | None = None, **net_kwargs: Any, ) -> object | tuple[torch.nn.Module, dict, dict] | Any: @@ -447,7 +447,7 @@ def load( copy_model_args: other arguments for the `monai.networks.copy_model_state` function. return_state_dict: whether to return state dict, if True, return state_dict, else a corresponding network from `_workflow.network_def` will be instantiated and load the achieved weights. - net_override: id-value pairs to override the parameters in the network of the bundle. + net_override: id-value pairs to override the parameters in the network of the bundle, default to `None`. net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights. This argument only works when loading weights. net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`. @@ -469,6 +469,7 @@ def load( warnings.warn("Incompatible values: model and net_name are all specified, return state dict instead.") bundle_dir_ = _process_bundle_dir(bundle_dir) + net_override = {} if net_override is None else net_override copy_model_args = {} if copy_model_args is None else copy_model_args if device is None: