Skip to content
49 changes: 49 additions & 0 deletions monai/networks/nets/swin_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,3 +1055,52 @@ def forward(self, x, normalize=True):
x4 = self.layers4[0](x3.contiguous())
x4_out = self.proj_out(x4, normalize)
return [x0_out, x1_out, x2_out, x3_out, x4_out]


def filter_swinunetr(key, value):
"""
A filter function used to filter the pretrained weights from [1], then the weights can be loaded into MONAI SwinUNETR Model.
This function is typically used with `monai.networks.copy_model_state`
[1] "Valanarasu JM et al., Disruptive Autoencoders: Leveraging Low-level features for 3D Medical Image Pre-training
<https://arxiv.org/abs/2307.16896>"

Args:
key: the key in the source state dict used for the update.
value: the value in the source state dict used for the update.

Examples::

import torch
from monai.apps import download_url
from monai.networks.utils import copy_model_state
from monai.networks.nets.swin_unetr import SwinUNETR, filter_swinunetr

model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=3, feature_size=48)
resource = (
"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth"
)
ssl_weights_path = "./ssl_pretrained_weights.pth"
download_url(resource, ssl_weights_path)
ssl_weights = torch.load(ssl_weights_path)["model"]

dst_dict, loaded, not_loaded = copy_model_state(model, ssl_weights, filter_func=filter_swinunetr)

"""
if key in [
"encoder.mask_token",
"encoder.norm.weight",
"encoder.norm.bias",
"out.conv.conv.weight",
"out.conv.conv.bias",
]:
return None

if key[:8] == "encoder.":
if key[8:19] == "patch_embed":
new_key = "swinViT." + key[8:]
else:
new_key = "swinViT." + key[8:18] + key[20:]

return new_key, value
else:
return None
11 changes: 10 additions & 1 deletion monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ def copy_model_state(
mapping=None,
exclude_vars=None,
inplace=True,
filter_func=None,
):
"""
Compute a module state_dict, of which the keys are the same as `dst`. The values of `dst` are overwritten
Expand All @@ -490,7 +491,7 @@ def copy_model_state(

Args:
dst: a pytorch module or state dict to be updated.
src: a pytorch module or state dist used to get the values used for the update.
src: a pytorch module or state dict used to get the values used for the update.
dst_prefix: `dst` key prefix, so that `dst[dst_prefix + src_key]`
will be assigned to the value of `src[src_key]`.
mapping: a `{"src_key": "dst_key"}` dict, indicating that `dst[dst_prefix + dst_key]`
Expand All @@ -499,6 +500,8 @@ def copy_model_state(
so that their values are not overwritten by `src`.
inplace: whether to set the `dst` module with the updated `state_dict` via `load_state_dict`.
This option is only available when `dst` is a `torch.nn.Module`.
filter_func: a filter function used to filter the weights to be loaded.
See 'filter_swinunetr' in "monai.networks.nets.swin_unetr.py".

Examples:
.. code-block:: python
Expand Down Expand Up @@ -536,6 +539,12 @@ def copy_model_state(
warnings.warn(f"Param. shape changed from {dst_dict[dst_key].shape} to {src_dict[s].shape}.")
dst_dict[dst_key] = src_dict[s]
updated_keys.append(dst_key)
if filter_func is not None:
for key, value in src_dict.items():
new_pair = filter_func(key, value)
if new_pair is not None and new_pair[0] not in to_skip:
dst_dict[new_pair[0]] = new_pair[1]
updated_keys.append(new_pair[0])

updated_keys = sorted(set(updated_keys))
unchanged_keys = sorted(set(all_keys).difference(updated_keys))
Expand Down
33 changes: 32 additions & 1 deletion tests/test_swin_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,20 @@

from __future__ import annotations

import os
import tempfile
import unittest
from unittest import skipUnless

import torch
from parameterized import parameterized

from monai.apps import download_url
from monai.networks import eval_mode
from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR, filter_swinunetr
from monai.networks.utils import copy_model_state
from monai.utils import optional_import
from tests.utils import assert_allclose, skip_if_downloading_fails, skip_if_quick, testing_data_config

einops, has_einops = optional_import("einops")

Expand Down Expand Up @@ -51,6 +56,14 @@
case_idx += 1
TEST_CASE_SWIN_UNETR.append(test_case)

TEST_CASE_FILTER = [
[
{"img_size": (96, 96, 96), "in_channels": 1, "out_channels": 14, "feature_size": 48, "use_checkpoint": True},
"swinViT.layers1.0.blocks.0.norm1.weight",
torch.tensor([0.9473, 0.9343, 0.8566, 0.8487, 0.8065, 0.7779, 0.6333, 0.5555]),
]
]


class TestSWINUNETR(unittest.TestCase):
@parameterized.expand(TEST_CASE_SWIN_UNETR)
Expand Down Expand Up @@ -93,6 +106,24 @@ def test_patch_merging(self):
t = PatchMerging(dim)(torch.zeros((1, 21, 20, 20, dim)))
self.assertEqual(t.shape, torch.Size([1, 11, 10, 10, 20]))

@parameterized.expand(TEST_CASE_FILTER)
@skip_if_quick
def test_filter_swinunetr(self, input_param, key, value):
with skip_if_downloading_fails():
with tempfile.TemporaryDirectory() as tempdir:
file_name = "ssl_pretrained_weights.pth"
data_spec = testing_data_config("models", f"{file_name.split('.', 1)[0]}")
weight_path = os.path.join(tempdir, file_name)
download_url(
data_spec["url"], weight_path, hash_val=data_spec["hash_val"], hash_type=data_spec["hash_type"]
)

ssl_weight = torch.load(weight_path)["model"]
net = SwinUNETR(**input_param)
dst_dict, loaded, not_loaded = copy_model_state(net, ssl_weight, filter_func=filter_swinunetr)
assert_allclose(dst_dict[key][:8], value, atol=1e-4, rtol=1e-4, type_test=False)
self.assertTrue(len(loaded) == 157 and len(not_loaded) == 2)


if __name__ == "__main__":
unittest.main()
5 changes: 5 additions & 0 deletions tests/testing_data/data_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@
"url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnext50_32x4d-a260b3a4.pth",
"hash_type": "sha256",
"hash_val": "a260b3a40f82dfe37c58d26a612bcf7bef0d27c6fed096226b0e4e9fb364168e"
},
"ssl_pretrained_weights": {
"url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth",
"hash_type": "sha256",
"hash_val": "c3564f40a6a051d3753a6d8fae5cc8eaf21ce8d82a9a3baf80748d15664055e8"
}
},
"configs": {
Expand Down