diff --git a/tests/test_swin_unetr.py b/tests/test_swin_unetr.py index 9f6b1e7c0a..e34e5a3c8e 100644 --- a/tests/test_swin_unetr.py +++ b/tests/test_swin_unetr.py @@ -24,13 +24,21 @@ from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR, filter_swinunetr from monai.networks.utils import copy_model_state from monai.utils import optional_import -from tests.utils import assert_allclose, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick, testing_data_config +from tests.utils import ( + assert_allclose, + pytorch_after, + skip_if_downloading_fails, + skip_if_no_cuda, + skip_if_quick, + testing_data_config, +) einops, has_einops = optional_import("einops") TEST_CASE_SWIN_UNETR = [] case_idx = 0 test_merging_mode = ["mergingv2", "merging", PatchMerging, PatchMergingV2] +checkpoint_vals = [True, False] if pytorch_after(1, 11) else [False] for attn_drop_rate in [0.4]: for in_channels in [1]: for depth in [[2, 1, 1, 1], [1, 2, 1, 1]]: @@ -38,23 +46,25 @@ for img_size in ((64, 32, 192), (96, 32)): for feature_size in [12]: for norm_name in ["instance"]: - test_case = [ - { - "spatial_dims": len(img_size), - "in_channels": in_channels, - "out_channels": out_channels, - "img_size": img_size, - "feature_size": feature_size, - "depths": depth, - "norm_name": norm_name, - "attn_drop_rate": attn_drop_rate, - "downsample": test_merging_mode[case_idx % 4], - }, - (2, in_channels, *img_size), - (2, out_channels, *img_size), - ] - case_idx += 1 - TEST_CASE_SWIN_UNETR.append(test_case) + for use_checkpoint in checkpoint_vals: + test_case = [ + { + "spatial_dims": len(img_size), + "in_channels": in_channels, + "out_channels": out_channels, + "img_size": img_size, + "feature_size": feature_size, + "depths": depth, + "norm_name": norm_name, + "attn_drop_rate": attn_drop_rate, + "downsample": test_merging_mode[case_idx % 4], + "use_checkpoint": use_checkpoint, + }, + (2, in_channels, *img_size), + (2, out_channels, *img_size), + ] + case_idx += 1 + TEST_CASE_SWIN_UNETR.append(test_case) TEST_CASE_FILTER = [ [