diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 2873997290..2046a6242a 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -702,8 +702,6 @@ def load( 3. If `load_ts_module` is `True`, return a triple that include a TorchScript module, the corresponding metadata dict, and extra files dict. please check `monai.data.load_net_with_metadata` for more details. - 4. If `return_state_dict` is True, return model weights, only used for compatibility - when `model` and `net_name` are all `None`. """ bundle_dir_ = _process_bundle_dir(bundle_dir) diff --git a/tests/bundle/test_bundle_download.py b/tests/bundle/test_bundle_download.py index dc73ea6d99..a073d1ba76 100644 --- a/tests/bundle/test_bundle_download.py +++ b/tests/bundle/test_bundle_download.py @@ -268,11 +268,10 @@ class TestLoad(unittest.TestCase): @skip_if_quick def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file): with skip_if_downloading_fails(): - # download bundle, and load weights from the downloaded path with tempfile.TemporaryDirectory() as tempdir: bundle_root = os.path.join(tempdir, bundle_name) # load weights - weights = load( + model_1 = load( name=bundle_name, model_file=model_file, bundle_dir=tempdir, @@ -280,7 +279,6 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) source="github", progress=False, device=device, - return_state_dict=True, ) # prepare network with open(os.path.join(bundle_root, bundle_files[2])) as f: @@ -289,7 +287,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) del net_args["_target_"] model = getattr(nets, model_name)(**net_args) model.to(device) - model.load_state_dict(weights) + model.load_state_dict(model_1) model.eval() # prepare data and test @@ -313,13 +311,11 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) progress=False, device=device, source="github", - return_state_dict=False, ) model_2.eval() output_2 = model_2.forward(input_tensor) assert_allclose(output_2, expected_output, atol=1e-4, rtol=1e-4, type_test=False) - # test compatibility with return_state_dict=True. model_3 = load( name=bundle_name, model_file=model_file, @@ -328,7 +324,6 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) device=device, net_name=model_name, source="github", - return_state_dict=False, **net_args, ) model_3.eval() @@ -343,14 +338,7 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override) # download bundle, and load weights from the downloaded path with tempfile.TemporaryDirectory() as tempdir: # load weights - model = load( - name=bundle_name, - bundle_dir=tempdir, - source="monaihosting", - progress=False, - device=device, - return_state_dict=False, - ) + model = load(name=bundle_name, bundle_dir=tempdir, source="monaihosting", progress=False, device=device) # prepare data and test input_tensor = torch.rand(1, 1, 96, 96, 96).to(device) @@ -371,7 +359,6 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override) source="monaihosting", progress=False, device=device, - return_state_dict=False, net_override=net_override, ) diff --git a/tests/ngc_bundle_download.py b/tests/ngc_bundle_download.py index 7953f6201b..612397b5a1 100644 --- a/tests/ngc_bundle_download.py +++ b/tests/ngc_bundle_download.py @@ -83,12 +83,7 @@ def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download self.assertTrue(check_hash(filepath=full_file_path, val=hash_val)) model = load( - name=bundle_name, - source="ngc", - version=version, - bundle_dir=tempdir, - remove_prefix=remove_prefix, - return_state_dict=False, + name=bundle_name, source="ngc", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix ) assert_allclose( model.state_dict()[TESTCASE_WEIGHTS["key"]],