From ed0ea9b11e06242b3c73dc395bf61c3420bdccdb Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 31 Aug 2023 14:29:38 +0800 Subject: [PATCH 1/7] support regular in mapping Signed-off-by: KumoLiu --- monai/networks/utils.py | 12 +++++++----- tests/test_copy_model_state.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 83429a2837..18b7b21cb9 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -531,11 +531,13 @@ def copy_model_state( updated_keys.append(dst_key) for s in mapping if mapping else {}: dst_key = f"{dst_prefix}{mapping[s]}" - if dst_key in dst_dict and dst_key not in to_skip: - if dst_dict[dst_key].shape != src_dict[s].shape: - warnings.warn(f"Param. shape changed from {dst_dict[dst_key].shape} to {src_dict[s].shape}.") - dst_dict[dst_key] = src_dict[s] - updated_keys.append(dst_key) + src_keys = sorted(set(s_key for s_key in src_dict if s_key not in to_skip and re.compile(s).search(s_key))) + dst_keys = sorted(set(d_key for d_key in dst_dict if d_key not in to_skip and re.compile(dst_key).search(d_key))) + for _src_key, _dst_key in zip(src_keys, dst_keys): + if dst_dict[_dst_key].shape != src_dict[_src_key].shape: + warnings.warn(f"Param. shape changed from {dst_dict[_dst_key].shape} to {src_dict[_src_key].shape}.") + dst_dict[_dst_key] = src_dict[_src_key] + updated_keys.append(_dst_key) updated_keys = sorted(set(updated_keys)) unchanged_keys = sorted(set(all_keys).difference(updated_keys)) diff --git a/tests/test_copy_model_state.py b/tests/test_copy_model_state.py index 2e7513b234..f70f99c99f 100644 --- a/tests/test_copy_model_state.py +++ b/tests/test_copy_model_state.py @@ -134,7 +134,7 @@ def test_set_map_across(self, device_0, device_1): model_two.to(device_1) # test weight map model_dict, ch, unch = copy_model_state( - model_one, model_two, mapping={"layer_1.weight": "layer.weight", "layer_1.bias": "layer_1.weight"} + model_one, model_two, mapping={"layer_1.weight": "^layer.weight", "layer_1.bias": "layer_1.weight"} ) model_one.load_state_dict(model_dict) x = np.random.randn(4, 10) From 95ed94178cd1c671fbddc676c1b6994001e2aa42 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 31 Aug 2023 06:47:14 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 18b7b21cb9..13b5fd0701 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -531,8 +531,8 @@ def copy_model_state( updated_keys.append(dst_key) for s in mapping if mapping else {}: dst_key = f"{dst_prefix}{mapping[s]}" - src_keys = sorted(set(s_key for s_key in src_dict if s_key not in to_skip and re.compile(s).search(s_key))) - dst_keys = sorted(set(d_key for d_key in dst_dict if d_key not in to_skip and re.compile(dst_key).search(d_key))) + src_keys = sorted({s_key for s_key in src_dict if s_key not in to_skip and re.compile(s).search(s_key)}) + dst_keys = sorted({d_key for d_key in dst_dict if d_key not in to_skip and re.compile(dst_key).search(d_key)}) for _src_key, _dst_key in zip(src_keys, dst_keys): if dst_dict[_dst_key].shape != src_dict[_src_key].shape: warnings.warn(f"Param. shape changed from {dst_dict[_dst_key].shape} to {src_dict[_src_key].shape}.") From a98b92fda5405fb3243e7e0db7a6c17ac0ff2230 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 11 Sep 2023 18:41:31 +0800 Subject: [PATCH 3/7] Revert "support regular in mapping" This reverts commit ed0ea9b11e06242b3c73dc395bf61c3420bdccdb. Signed-off-by: KumoLiu --- monai/networks/utils.py | 12 +++++------- tests/test_copy_model_state.py | 2 +- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 13b5fd0701..83429a2837 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -531,13 +531,11 @@ def copy_model_state( updated_keys.append(dst_key) for s in mapping if mapping else {}: dst_key = f"{dst_prefix}{mapping[s]}" - src_keys = sorted({s_key for s_key in src_dict if s_key not in to_skip and re.compile(s).search(s_key)}) - dst_keys = sorted({d_key for d_key in dst_dict if d_key not in to_skip and re.compile(dst_key).search(d_key)}) - for _src_key, _dst_key in zip(src_keys, dst_keys): - if dst_dict[_dst_key].shape != src_dict[_src_key].shape: - warnings.warn(f"Param. shape changed from {dst_dict[_dst_key].shape} to {src_dict[_src_key].shape}.") - dst_dict[_dst_key] = src_dict[_src_key] - updated_keys.append(_dst_key) + if dst_key in dst_dict and dst_key not in to_skip: + if dst_dict[dst_key].shape != src_dict[s].shape: + warnings.warn(f"Param. shape changed from {dst_dict[dst_key].shape} to {src_dict[s].shape}.") + dst_dict[dst_key] = src_dict[s] + updated_keys.append(dst_key) updated_keys = sorted(set(updated_keys)) unchanged_keys = sorted(set(all_keys).difference(updated_keys)) diff --git a/tests/test_copy_model_state.py b/tests/test_copy_model_state.py index f70f99c99f..2e7513b234 100644 --- a/tests/test_copy_model_state.py +++ b/tests/test_copy_model_state.py @@ -134,7 +134,7 @@ def test_set_map_across(self, device_0, device_1): model_two.to(device_1) # test weight map model_dict, ch, unch = copy_model_state( - model_one, model_two, mapping={"layer_1.weight": "^layer.weight", "layer_1.bias": "layer_1.weight"} + model_one, model_two, mapping={"layer_1.weight": "layer.weight", "layer_1.bias": "layer_1.weight"} ) model_one.load_state_dict(model_dict) x = np.random.randn(4, 10) From 3138b6ad47e6e4afe866cf8d556352b91f00c75f Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 11 Sep 2023 18:55:29 +0800 Subject: [PATCH 4/7] add `filter_func` Signed-off-by: KumoLiu --- monai/networks/utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 83429a2837..0b93b0db27 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -478,6 +478,7 @@ def copy_model_state( mapping=None, exclude_vars=None, inplace=True, + filter_func=None, ): """ Compute a module state_dict, of which the keys are the same as `dst`. The values of `dst` are overwritten @@ -536,6 +537,12 @@ def copy_model_state( warnings.warn(f"Param. shape changed from {dst_dict[dst_key].shape} to {src_dict[s].shape}.") dst_dict[dst_key] = src_dict[s] updated_keys.append(dst_key) + if filter_func is not None: + for key, value in src_dict.items(): + new_pair = filter_func(key, value) + if new_pair is not None: + dst_dict[new_pair[0]] = new_pair[1] + updated_keys.append(new_pair[0]) updated_keys = sorted(set(updated_keys)) unchanged_keys = sorted(set(all_keys).difference(updated_keys)) From 0feaec44e5c5fe8b3649a06eec83138a6feecb76 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 12 Sep 2023 13:54:03 +0800 Subject: [PATCH 5/7] add unittests Signed-off-by: KumoLiu --- monai/networks/nets/swin_unetr.py | 31 +++++++++++++++++++++ monai/networks/utils.py | 6 +++-- tests/test_swin_unetr.py | 42 ++++++++++++++++++++++++++++- tests/testing_data/data_config.json | 5 ++++ 4 files changed, 81 insertions(+), 3 deletions(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 6cd9719611..73aa152b0c 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -1055,3 +1055,34 @@ def forward(self, x, normalize=True): x4 = self.layers4[0](x3.contiguous()) x4_out = self.proj_out(x4, normalize) return [x0_out, x1_out, x2_out, x3_out, x4_out] + + +def filter_swinunetr(key, value): + """ + A filter function used to filter the pretrained weights from [1], then the weights can be loaded into MONAI SwinUNETR Model. + [1] "Valanarasu JM et al., Disruptive Autoencoders: Leveraging Low-level features for 3D Medical Image Pre-training + " + + Args: + key: the key in the source state dict used for the update. + value: the value in the source state dict used for the update. + + """ + if key in [ + "encoder.mask_token", + "encoder.norm.weight", + "encoder.norm.bias", + "out.conv.conv.weight", + "out.conv.conv.bias", + ]: + return None + + if key[:8] == "encoder.": + if key[8:19] == "patch_embed": + new_key = "swinViT." + key[8:] + else: + new_key = "swinViT." + key[8:18] + key[20:] + + return new_key, value + else: + return None diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 0b93b0db27..32704c6095 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -491,7 +491,7 @@ def copy_model_state( Args: dst: a pytorch module or state dict to be updated. - src: a pytorch module or state dist used to get the values used for the update. + src: a pytorch module or state dict used to get the values used for the update. dst_prefix: `dst` key prefix, so that `dst[dst_prefix + src_key]` will be assigned to the value of `src[src_key]`. mapping: a `{"src_key": "dst_key"}` dict, indicating that `dst[dst_prefix + dst_key]` @@ -500,6 +500,8 @@ def copy_model_state( so that their values are not overwritten by `src`. inplace: whether to set the `dst` module with the updated `state_dict` via `load_state_dict`. This option is only available when `dst` is a `torch.nn.Module`. + filter_func: a filter function used to filter the weights to be loaded. + See 'filter_swinunetr' in "monai.networks.nets.swin_unetr.py". Examples: .. code-block:: python @@ -540,7 +542,7 @@ def copy_model_state( if filter_func is not None: for key, value in src_dict.items(): new_pair = filter_func(key, value) - if new_pair is not None: + if new_pair is not None and new_pair[0] not in to_skip: dst_dict[new_pair[0]] = new_pair[1] updated_keys.append(new_pair[0]) diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py index 636fcc9e31..803d59c321 100644 --- a/tests/test_swin_unetr.py +++ b/tests/test_swin_unetr.py @@ -11,15 +11,20 @@ from __future__ import annotations +import os +import tempfile import unittest from unittest import skipUnless import torch from parameterized import parameterized +from monai.apps import download_url from monai.networks import eval_mode -from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR +from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR, filter_swinunetr +from monai.networks.utils import copy_model_state from monai.utils import optional_import +from tests.utils import assert_allclose, skip_if_downloading_fails, skip_if_quick, testing_data_config einops, has_einops = optional_import("einops") @@ -51,6 +56,23 @@ case_idx += 1 TEST_CASE_SWIN_UNETR.append(test_case) +TEST_CASE_FILTER = [ + [ + { + "img_size": (96, 96, 96), + "in_channels": 1, + "out_channels": 14, + "feature_size": 48, + "drop_rate": 0.0, + "attn_drop_rate": 0.0, + "dropout_path_rate": 0.0, + "use_checkpoint": True, + }, + "swinViT.layers1.0.blocks.0.norm1.weight", + torch.tensor([0.9473, 0.9343, 0.8566, 0.8487, 0.8065, 0.7779, 0.6333, 0.5555]), + ] +] + class TestSWINUNETR(unittest.TestCase): @parameterized.expand(TEST_CASE_SWIN_UNETR) @@ -93,6 +115,24 @@ def test_patch_merging(self): t = PatchMerging(dim)(torch.zeros((1, 21, 20, 20, dim))) self.assertEqual(t.shape, torch.Size([1, 11, 10, 10, 20])) + @parameterized.expand(TEST_CASE_FILTER) + @skip_if_quick + def test_filter_swinunetr(self, input_param, key, value): + with skip_if_downloading_fails(): + with tempfile.TemporaryDirectory() as tempdir: + file_name = "ssl_pretrained_weights.pth" + data_spec = testing_data_config("models", f"{file_name.split('.', 1)[0]}") + weight_path = os.path.join(tempdir, file_name) + download_url( + data_spec["url"], weight_path, hash_val=data_spec["hash_val"], hash_type=data_spec["hash_type"] + ) + + ssl_weight = torch.load(weight_path)["model"] + net = SwinUNETR(**input_param) + dst_dict, loaded, not_loaded = copy_model_state(net, ssl_weight, filter_func=filter_swinunetr) + assert_allclose(dst_dict[key][:8], value, atol=1e-4, rtol=1e-4, type_test=False) + self.assertTrue(len(loaded) == 157 and len(not_loaded) == 2) + if __name__ == "__main__": unittest.main() diff --git a/tests/testing_data/data_config.json b/tests/testing_data/data_config.json index 4bdac6abba..c0666119d9 100644 --- a/tests/testing_data/data_config.json +++ b/tests/testing_data/data_config.json @@ -128,6 +128,11 @@ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnext50_32x4d-a260b3a4.pth", "hash_type": "sha256", "hash_val": "a260b3a40f82dfe37c58d26a612bcf7bef0d27c6fed096226b0e4e9fb364168e" + }, + "ssl_pretrained_weights": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth", + "hash_type": "sha256", + "hash_val": "c3564f40a6a051d3753a6d8fae5cc8eaf21ce8d82a9a3baf80748d15664055e8" } }, "configs": { From 67e0edbf250eadeff4f36198a7b15f5e8f586796 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 12 Sep 2023 16:08:15 +0800 Subject: [PATCH 6/7] address comments Signed-off-by: KumoLiu --- monai/networks/nets/swin_unetr.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 73aa152b0c..196c188351 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -1060,6 +1060,7 @@ def forward(self, x, normalize=True): def filter_swinunetr(key, value): """ A filter function used to filter the pretrained weights from [1], then the weights can be loaded into MONAI SwinUNETR Model. + This function is typically used with `monai.networks.copy_model_state` [1] "Valanarasu JM et al., Disruptive Autoencoders: Leveraging Low-level features for 3D Medical Image Pre-training " From bbf75296a358fe1c26ed66a60b9b5d12c94a6e08 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 12 Sep 2023 23:21:02 +0800 Subject: [PATCH 7/7] address comments Signed-off-by: KumoLiu --- monai/networks/nets/swin_unetr.py | 17 +++++++++++++++++ tests/test_swin_unetr.py | 11 +---------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 196c188351..5c08a3ca8a 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -1068,6 +1068,23 @@ def filter_swinunetr(key, value): key: the key in the source state dict used for the update. value: the value in the source state dict used for the update. + Examples:: + + import torch + from monai.apps import download_url + from monai.networks.utils import copy_model_state + from monai.networks.nets.swin_unetr import SwinUNETR, filter_swinunetr + + model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=3, feature_size=48) + resource = ( + "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth" + ) + ssl_weights_path = "./ssl_pretrained_weights.pth" + download_url(resource, ssl_weights_path) + ssl_weights = torch.load(ssl_weights_path)["model"] + + dst_dict, loaded, not_loaded = copy_model_state(model, ssl_weights, filter_func=filter_swinunetr) + """ if key in [ "encoder.mask_token", diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py index 803d59c321..9197308aa5 100644 --- a/tests/test_swin_unetr.py +++ b/tests/test_swin_unetr.py @@ -58,16 +58,7 @@ TEST_CASE_FILTER = [ [ - { - "img_size": (96, 96, 96), - "in_channels": 1, - "out_channels": 14, - "feature_size": 48, - "drop_rate": 0.0, - "attn_drop_rate": 0.0, - "dropout_path_rate": 0.0, - "use_checkpoint": True, - }, + {"img_size": (96, 96, 96), "in_channels": 1, "out_channels": 14, "feature_size": 48, "use_checkpoint": True}, "swinViT.layers1.0.blocks.0.norm1.weight", torch.tensor([0.9473, 0.9343, 0.8566, 0.8487, 0.8065, 0.7779, 0.6333, 0.5555]), ]