From a170243770ae52e17592cba34e215c3293d0d63f Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 16 May 2025 15:48:56 +0800 Subject: [PATCH 1/5] fix #8453 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- tests/bundle/test_bundle_download.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/bundle/test_bundle_download.py b/tests/bundle/test_bundle_download.py index dc73ea6d99..e7dce43c6a 100644 --- a/tests/bundle/test_bundle_download.py +++ b/tests/bundle/test_bundle_download.py @@ -280,7 +280,6 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) source="github", progress=False, device=device, - return_state_dict=True, ) # prepare network with open(os.path.join(bundle_root, bundle_files[2])) as f: @@ -349,7 +348,6 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override) source="monaihosting", progress=False, device=device, - return_state_dict=False, ) # prepare data and test From ba8fc01de7cb2fc25ef21362eb5b6d51ac1cf297 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 16 May 2025 17:05:44 +0800 Subject: [PATCH 2/5] fix #8453 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/bundle/scripts.py | 2 -- tests/bundle/test_bundle_download.py | 7 +++---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 2873997290..2046a6242a 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -702,8 +702,6 @@ def load( 3. If `load_ts_module` is `True`, return a triple that include a TorchScript module, the corresponding metadata dict, and extra files dict. please check `monai.data.load_net_with_metadata` for more details. - 4. If `return_state_dict` is True, return model weights, only used for compatibility - when `model` and `net_name` are all `None`. """ bundle_dir_ = _process_bundle_dir(bundle_dir) diff --git a/tests/bundle/test_bundle_download.py b/tests/bundle/test_bundle_download.py index e7dce43c6a..5777d518f4 100644 --- a/tests/bundle/test_bundle_download.py +++ b/tests/bundle/test_bundle_download.py @@ -268,11 +268,10 @@ class TestLoad(unittest.TestCase): @skip_if_quick def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file): with skip_if_downloading_fails(): - # download bundle, and load weights from the downloaded path with tempfile.TemporaryDirectory() as tempdir: bundle_root = os.path.join(tempdir, bundle_name) # load weights - weights = load( + model_1 = load( name=bundle_name, model_file=model_file, bundle_dir=tempdir, @@ -288,7 +287,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) del net_args["_target_"] model = getattr(nets, model_name)(**net_args) model.to(device) - model.load_state_dict(weights) + model.load_state_dict(model_1) model.eval() # prepare data and test @@ -334,6 +333,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) output_3 = model_3.forward(input_tensor) assert_allclose(output_3, expected_output, atol=1e-4, rtol=1e-4, type_test=False) + @parameterized.expand([TEST_CASE_8]) @skip_if_quick @skipUnless(has_huggingface_hub, "Requires `huggingface_hub`.") @@ -369,7 +369,6 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override) source="monaihosting", progress=False, device=device, - return_state_dict=False, net_override=net_override, ) From c498a3b1647fc81682a151c23b0d34674e23a47d Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 16 May 2025 17:07:46 +0800 Subject: [PATCH 3/5] fix format Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- tests/bundle/test_bundle_download.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tests/bundle/test_bundle_download.py b/tests/bundle/test_bundle_download.py index 5777d518f4..0a237ad53b 100644 --- a/tests/bundle/test_bundle_download.py +++ b/tests/bundle/test_bundle_download.py @@ -333,7 +333,6 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) output_3 = model_3.forward(input_tensor) assert_allclose(output_3, expected_output, atol=1e-4, rtol=1e-4, type_test=False) - @parameterized.expand([TEST_CASE_8]) @skip_if_quick @skipUnless(has_huggingface_hub, "Requires `huggingface_hub`.") @@ -342,13 +341,7 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override) # download bundle, and load weights from the downloaded path with tempfile.TemporaryDirectory() as tempdir: # load weights - model = load( - name=bundle_name, - bundle_dir=tempdir, - source="monaihosting", - progress=False, - device=device, - ) + model = load(name=bundle_name, bundle_dir=tempdir, source="monaihosting", progress=False, device=device) # prepare data and test input_tensor = torch.rand(1, 1, 96, 96, 96).to(device) From 407d244a68df22a34f34e753be1fe31c1dac0643 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 16 May 2025 17:10:37 +0800 Subject: [PATCH 4/5] fix #8453 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- tests/bundle/test_bundle_download.py | 3 --- tests/ngc_bundle_download.py | 1 - 2 files changed, 4 deletions(-) diff --git a/tests/bundle/test_bundle_download.py b/tests/bundle/test_bundle_download.py index 0a237ad53b..a073d1ba76 100644 --- a/tests/bundle/test_bundle_download.py +++ b/tests/bundle/test_bundle_download.py @@ -311,13 +311,11 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) progress=False, device=device, source="github", - return_state_dict=False, ) model_2.eval() output_2 = model_2.forward(input_tensor) assert_allclose(output_2, expected_output, atol=1e-4, rtol=1e-4, type_test=False) - # test compatibility with return_state_dict=True. model_3 = load( name=bundle_name, model_file=model_file, @@ -326,7 +324,6 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) device=device, net_name=model_name, source="github", - return_state_dict=False, **net_args, ) model_3.eval() diff --git a/tests/ngc_bundle_download.py b/tests/ngc_bundle_download.py index 7953f6201b..10ce16a0e6 100644 --- a/tests/ngc_bundle_download.py +++ b/tests/ngc_bundle_download.py @@ -88,7 +88,6 @@ def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download version=version, bundle_dir=tempdir, remove_prefix=remove_prefix, - return_state_dict=False, ) assert_allclose( model.state_dict()[TESTCASE_WEIGHTS["key"]], From 1774ed302b96b2107963e3352ffc937e0700ce20 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 16 May 2025 22:20:31 +0800 Subject: [PATCH 5/5] fix format Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- tests/ngc_bundle_download.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/ngc_bundle_download.py b/tests/ngc_bundle_download.py index 10ce16a0e6..612397b5a1 100644 --- a/tests/ngc_bundle_download.py +++ b/tests/ngc_bundle_download.py @@ -83,11 +83,7 @@ def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download self.assertTrue(check_hash(filepath=full_file_path, val=hash_val)) model = load( - name=bundle_name, - source="ngc", - version=version, - bundle_dir=tempdir, - remove_prefix=remove_prefix, + name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix ) assert_allclose( model.state_dict()[TESTCASE_WEIGHTS["key"]],