From 31d4bee73d0beb1aa69144f71a91fd803e3f4361 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 29 Nov 2023 12:04:59 +0800 Subject: [PATCH 1/6] fix #7265 Signed-off-by: KumoLiu --- tests/test_swin_unetr.py | 39 +++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py index 9f6b1e7c0a..6921718769 100644 --- a/tests/test_swin_unetr.py +++ b/tests/test_swin_unetr.py @@ -24,7 +24,7 @@ from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR, filter_swinunetr from monai.networks.utils import copy_model_state from monai.utils import optional_import -from tests.utils import assert_allclose, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick, testing_data_config +from tests.utils import assert_allclose, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick, testing_data_config, SkipIfBeforePyTorchVersion einops, has_einops = optional_import("einops") @@ -38,23 +38,25 @@ for img_size in ((64, 32, 192), (96, 32)): for feature_size in [12]: for norm_name in ["instance"]: - test_case = [ - { - "spatial_dims": len(img_size), - "in_channels": in_channels, - "out_channels": out_channels, - "img_size": img_size, - "feature_size": feature_size, - "depths": depth, - "norm_name": norm_name, - "attn_drop_rate": attn_drop_rate, - "downsample": test_merging_mode[case_idx % 4], - }, - (2, in_channels, *img_size), - (2, out_channels, *img_size), - ] - case_idx += 1 - TEST_CASE_SWIN_UNETR.append(test_case) + for use_checkpoint in [True, False]: + test_case = [ + { + "spatial_dims": len(img_size), + "in_channels": in_channels, + "out_channels": out_channels, + "img_size": img_size, + "feature_size": feature_size, + "depths": depth, + "norm_name": norm_name, + "attn_drop_rate": attn_drop_rate, + "downsample": test_merging_mode[case_idx % 4], + "use_checkpoint": use_checkpoint, + }, + (2, in_channels, *img_size), + (2, out_channels, *img_size), + ] + case_idx += 1 + TEST_CASE_SWIN_UNETR.append(test_case) TEST_CASE_FILTER = [ [ @@ -65,6 +67,7 @@ ] +@SkipIfBeforePyTorchVersion((1, 10)) class TestSWINUNETR(unittest.TestCase): @parameterized.expand(TEST_CASE_SWIN_UNETR) @skipUnless(has_einops, "Requires einops") From b012aa9d41935afc6ba91e0cc523513caf3e2e0b Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 29 Nov 2023 14:16:03 +0800 Subject: [PATCH 2/6] fix flake8 Signed-off-by: KumoLiu --- tests/test_swin_unetr.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py index 6921718769..10dfa87b0b 100644 --- a/tests/test_swin_unetr.py +++ b/tests/test_swin_unetr.py @@ -24,7 +24,14 @@ from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR, filter_swinunetr from monai.networks.utils import copy_model_state from monai.utils import optional_import -from tests.utils import assert_allclose, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick, testing_data_config, SkipIfBeforePyTorchVersion +from tests.utils import ( + SkipIfBeforePyTorchVersion, + assert_allclose, + skip_if_downloading_fails, + skip_if_no_cuda, + skip_if_quick, + testing_data_config, +) einops, has_einops = optional_import("einops") From f613581a94482800e4d02f56b0a7018e0f384c05 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 29 Nov 2023 22:59:00 +0800 Subject: [PATCH 3/6] fix ci Signed-off-by: KumoLiu --- tests/test_swin_unetr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py index 10dfa87b0b..d676197437 100644 --- a/tests/test_swin_unetr.py +++ b/tests/test_swin_unetr.py @@ -74,7 +74,7 @@ ] -@SkipIfBeforePyTorchVersion((1, 10)) +@SkipIfBeforePyTorchVersion((1, 11)) class TestSWINUNETR(unittest.TestCase): @parameterized.expand(TEST_CASE_SWIN_UNETR) @skipUnless(has_einops, "Requires einops") From 163bdd6e5049681fdb13e9bcda9de91e738f106a Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 30 Nov 2023 00:10:43 +0800 Subject: [PATCH 4/6] address comments Signed-off-by: KumoLiu --- tests/test_swin_unetr.py | 45 +++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py index d676197437..df9d1acc80 100644 --- a/tests/test_swin_unetr.py +++ b/tests/test_swin_unetr.py @@ -45,36 +45,35 @@ for img_size in ((64, 32, 192), (96, 32)): for feature_size in [12]: for norm_name in ["instance"]: - for use_checkpoint in [True, False]: - test_case = [ - { - "spatial_dims": len(img_size), - "in_channels": in_channels, - "out_channels": out_channels, - "img_size": img_size, - "feature_size": feature_size, - "depths": depth, - "norm_name": norm_name, - "attn_drop_rate": attn_drop_rate, - "downsample": test_merging_mode[case_idx % 4], - "use_checkpoint": use_checkpoint, - }, - (2, in_channels, *img_size), - (2, out_channels, *img_size), - ] - case_idx += 1 - TEST_CASE_SWIN_UNETR.append(test_case) + test_case = [ + { + "spatial_dims": len(img_size), + "in_channels": in_channels, + "out_channels": out_channels, + "img_size": img_size, + "feature_size": feature_size, + "depths": depth, + "norm_name": norm_name, + "attn_drop_rate": attn_drop_rate, + "downsample": test_merging_mode[case_idx % 4], + }, + (2, in_channels, *img_size), + (2, out_channels, *img_size), + ] + case_idx += 1 + TEST_CASE_SWIN_UNETR.append(test_case) TEST_CASE_FILTER = [ [ {"img_size": (96, 96, 96), "in_channels": 1, "out_channels": 14, "feature_size": 48, "use_checkpoint": True}, + (1, 1, 96, 96, 96), + (1, 14, 96, 96, 96), "swinViT.layers1.0.blocks.0.norm1.weight", torch.tensor([0.9473, 0.9343, 0.8566, 0.8487, 0.8065, 0.7779, 0.6333, 0.5555]), ] ] -@SkipIfBeforePyTorchVersion((1, 11)) class TestSWINUNETR(unittest.TestCase): @parameterized.expand(TEST_CASE_SWIN_UNETR) @skipUnless(has_einops, "Requires einops") @@ -119,7 +118,8 @@ def test_patch_merging(self): @parameterized.expand(TEST_CASE_FILTER) @skip_if_quick @skip_if_no_cuda - def test_filter_swinunetr(self, input_param, key, value): + @SkipIfBeforePyTorchVersion((1, 11)) + def test_filter_swinunetr(self, input_param, input_shape, expected_shape, key, value): with skip_if_downloading_fails(): with tempfile.TemporaryDirectory() as tempdir: file_name = "ssl_pretrained_weights.pth" @@ -131,6 +131,9 @@ def test_filter_swinunetr(self, input_param, key, value): ssl_weight = torch.load(weight_path)["model"] net = SwinUNETR(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) dst_dict, loaded, not_loaded = copy_model_state(net, ssl_weight, filter_func=filter_swinunetr) assert_allclose(dst_dict[key][:8], value, atol=1e-4, rtol=1e-4, type_test=False) self.assertTrue(len(loaded) == 157 and len(not_loaded) == 2) From 6ccbe7185c06c8c6ebc126cbc55206e046240c20 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 30 Nov 2023 11:24:09 +0800 Subject: [PATCH 5/6] address comments Signed-off-by: KumoLiu --- tests/test_swin_unetr.py | 46 +++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py index df9d1acc80..ec352bd788 100644 --- a/tests/test_swin_unetr.py +++ b/tests/test_swin_unetr.py @@ -31,6 +31,7 @@ skip_if_no_cuda, skip_if_quick, testing_data_config, + pytorch_after, ) einops, has_einops = optional_import("einops") @@ -38,6 +39,7 @@ TEST_CASE_SWIN_UNETR = [] case_idx = 0 test_merging_mode = ["mergingv2", "merging", PatchMerging, PatchMergingV2] +checkpoint_vals = [True, False] if pytorch_after(1,11) else [False] for attn_drop_rate in [0.4]: for in_channels in [1]: for depth in [[2, 1, 1, 1], [1, 2, 1, 1]]: @@ -45,29 +47,29 @@ for img_size in ((64, 32, 192), (96, 32)): for feature_size in [12]: for norm_name in ["instance"]: - test_case = [ - { - "spatial_dims": len(img_size), - "in_channels": in_channels, - "out_channels": out_channels, - "img_size": img_size, - "feature_size": feature_size, - "depths": depth, - "norm_name": norm_name, - "attn_drop_rate": attn_drop_rate, - "downsample": test_merging_mode[case_idx % 4], - }, - (2, in_channels, *img_size), - (2, out_channels, *img_size), - ] - case_idx += 1 - TEST_CASE_SWIN_UNETR.append(test_case) + for use_checkpoint in checkpoint_vals: + test_case = [ + { + "spatial_dims": len(img_size), + "in_channels": in_channels, + "out_channels": out_channels, + "img_size": img_size, + "feature_size": feature_size, + "depths": depth, + "norm_name": norm_name, + "attn_drop_rate": attn_drop_rate, + "downsample": test_merging_mode[case_idx % 4], + "use_checkpoint": use_checkpoint, + }, + (2, in_channels, *img_size), + (2, out_channels, *img_size), + ] + case_idx += 1 + TEST_CASE_SWIN_UNETR.append(test_case) TEST_CASE_FILTER = [ [ {"img_size": (96, 96, 96), "in_channels": 1, "out_channels": 14, "feature_size": 48, "use_checkpoint": True}, - (1, 1, 96, 96, 96), - (1, 14, 96, 96, 96), "swinViT.layers1.0.blocks.0.norm1.weight", torch.tensor([0.9473, 0.9343, 0.8566, 0.8487, 0.8065, 0.7779, 0.6333, 0.5555]), ] @@ -118,8 +120,7 @@ def test_patch_merging(self): @parameterized.expand(TEST_CASE_FILTER) @skip_if_quick @skip_if_no_cuda - @SkipIfBeforePyTorchVersion((1, 11)) - def test_filter_swinunetr(self, input_param, input_shape, expected_shape, key, value): + def test_filter_swinunetr(self, input_param, key, value): with skip_if_downloading_fails(): with tempfile.TemporaryDirectory() as tempdir: file_name = "ssl_pretrained_weights.pth" @@ -131,9 +132,6 @@ def test_filter_swinunetr(self, input_param, input_shape, expected_shape, key, v ssl_weight = torch.load(weight_path)["model"] net = SwinUNETR(**input_param) - with eval_mode(net): - result = net(torch.randn(input_shape)) - self.assertEqual(result.shape, expected_shape) dst_dict, loaded, not_loaded = copy_model_state(net, ssl_weight, filter_func=filter_swinunetr) assert_allclose(dst_dict[key][:8], value, atol=1e-4, rtol=1e-4, type_test=False) self.assertTrue(len(loaded) == 157 and len(not_loaded) == 2) From 9d338a3f130e4c7cdc094349e03b23e4058881c3 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 30 Nov 2023 11:26:58 +0800 Subject: [PATCH 6/6] fix flake8 Signed-off-by: KumoLiu --- tests/test_swin_unetr.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py index ec352bd788..e34e5a3c8e 100644 --- a/tests/test_swin_unetr.py +++ b/tests/test_swin_unetr.py @@ -25,13 +25,12 @@ from monai.networks.utils import copy_model_state from monai.utils import optional_import from tests.utils import ( - SkipIfBeforePyTorchVersion, assert_allclose, + pytorch_after, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick, testing_data_config, - pytorch_after, ) einops, has_einops = optional_import("einops") @@ -39,7 +38,7 @@ TEST_CASE_SWIN_UNETR = [] case_idx = 0 test_merging_mode = ["mergingv2", "merging", PatchMerging, PatchMergingV2] -checkpoint_vals = [True, False] if pytorch_after(1,11) else [False] +checkpoint_vals = [True, False] if pytorch_after(1, 11) else [False] for attn_drop_rate in [0.4]: for in_channels in [1]: for depth in [[2, 1, 1, 1], [1, 2, 1, 1]]: