# Reviewer notes (free text placed before the first "diff --git" header; not part of any hunk).
#
# monai/networks/nets/swin_unetr.py
#   * Imports `Final` (typing_extensions) and `deprecated_arg` (monai.utils.deprecate_utils).
#   * SwinUNETR gains a class-level `patch_size: Final[int] = 2`; `__init__` is wrapped in
#     `@deprecated_arg(name="img_size", since="1.3", removed="1.5", ...)`.
#   * In `__init__`, the local `patch_size` tuple is renamed `patch_sizes` and is built from
#     `self.patch_size`; the nested per-stage divisibility loop is replaced by a single call to
#     the new `_check_input_size(img_size)`.
#   * `_check_input_size` (decorated with `@torch.jit.unused`) raises ValueError when any entry
#     of `spatial_shape` is not divisible by `patch_size ** 5`. The reported indices are the
#     failing positions plus 2 -- presumably so they line up with the full input tensor shape,
#     since forward() passes `x_in.shape[2:]`; confirm.
#   * `SwinUNETR.forward` runs the same check on `x_in.shape[2:]` unless
#     `torch.jit.is_scripting()` is true.
#   * The `forward(self, x, mask_matrix)` shown in the hunk at old line 669 now passes
#     `use_reentrant=False` to both `checkpoint.checkpoint` calls.
#
# runtests.sh
#   * `doPylintFormat=true` is commented out, with a link to MONAI issue 7094.
#
# NOTE(review): this copy was recovered from a whitespace-collapsed paste. Line breaks were
#   restored so that every hunk matches the line counts in its `@@` header; indentation of
#   context and changed lines is inferred from syntax -- verify against upstream before applying.
# NOTE(review): the first context line of the second hunk is a lone `"` in the recovered text
#   and looks truncated -- confirm against the target file.
diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py
index 5c08a3ca8a..5bff07581b 100644
--- a/monai/networks/nets/swin_unetr.py
+++ b/monai/networks/nets/swin_unetr.py
@@ -20,11 +20,13 @@
 import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 from torch.nn import LayerNorm
+from typing_extensions import Final
 
 from monai.networks.blocks import MLPBlock as Mlp
 from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
 from monai.networks.layers import DropPath, trunc_normal_
 from monai.utils import ensure_tuple_rep, look_up_option, optional_import
+from monai.utils.deprecate_utils import deprecated_arg
 
 rearrange, _ = optional_import("einops", name="rearrange")
 
@@ -49,6 +51,15 @@ class SwinUNETR(nn.Module):
     "
     """
 
+    patch_size: Final[int] = 2
+
+    @deprecated_arg(
+        name="img_size",
+        since="1.3",
+        removed="1.5",
+        msg_suffix="The img_size argument is not required anymore and "
+        "checks on the input size are run during forward().",
+    )
     def __init__(
         self,
         img_size: Sequence[int] | int,
@@ -69,7 +80,10 @@ def __init__(
     ) -> None:
         """
         Args:
-            img_size: dimension of input image.
+            img_size: spatial dimension of input image.
+                This argument is only used for checking that the input image size is divisible by the patch size.
+                The tensor passed to forward() can have a dynamic shape as long as its spatial dimensions are divisible by 2**5.
+                It will be removed in an upcoming version.
             in_channels: dimension of input channels.
             out_channels: dimension of output channels.
             feature_size: dimension of network feature size.
@@ -103,16 +117,13 @@ def __init__(
         super().__init__()
 
         img_size = ensure_tuple_rep(img_size, spatial_dims)
-        patch_size = ensure_tuple_rep(2, spatial_dims)
+        patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
         window_size = ensure_tuple_rep(7, spatial_dims)
 
         if spatial_dims not in (2, 3):
             raise ValueError("spatial dimension should be 2 or 3.")
 
-        for m, p in zip(img_size, patch_size):
-            for i in range(5):
-                if m % np.power(p, i + 1) != 0:
-                    raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.")
+        self._check_input_size(img_size)
 
         if not (0 <= drop_rate <= 1):
             raise ValueError("dropout rate should be between 0 and 1.")
@@ -132,7 +143,7 @@ def __init__(
             in_chans=in_channels,
             embed_dim=feature_size,
             window_size=window_size,
-            patch_size=patch_size,
+            patch_size=patch_sizes,
             depths=depths,
             num_heads=num_heads,
             mlp_ratio=4.0,
@@ -297,7 +308,20 @@ def load_from(self, weights):
                 weights["state_dict"]["module.layers4.0.downsample.norm.bias"]
             )
 
+    @torch.jit.unused
+    def _check_input_size(self, spatial_shape):
+        img_size = np.array(spatial_shape)
+        remainder = (img_size % np.power(self.patch_size, 5)) > 0
+        if remainder.any():
+            wrong_dims = (np.where(remainder)[0] + 2).tolist()
+            raise ValueError(
+                f"spatial dimensions {wrong_dims} of input image (spatial shape: {spatial_shape})"
+                f" must be divisible by {self.patch_size}**5."
+            )
+
     def forward(self, x_in):
+        if not torch.jit.is_scripting():
+            self._check_input_size(x_in.shape[2:])
         hidden_states_out = self.swinViT(x_in, self.normalize)
         enc0 = self.encoder1(x_in)
         enc1 = self.encoder2(hidden_states_out[0])
@@ -669,12 +693,12 @@ def load_from(self, weights, n_block, layer):
     def forward(self, x, mask_matrix):
         shortcut = x
         if self.use_checkpoint:
-            x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
+            x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix, use_reentrant=False)
         else:
             x = self.forward_part1(x, mask_matrix)
         x = shortcut + self.drop_path(x)
         if self.use_checkpoint:
-            x = x + checkpoint.checkpoint(self.forward_part2, x)
+            x = x + checkpoint.checkpoint(self.forward_part2, x, use_reentrant=False)
         else:
             x = x + self.forward_part2(x)
         return x
diff --git a/runtests.sh b/runtests.sh
index f71bc4c9ff..40cc09144f 100755
--- a/runtests.sh
+++ b/runtests.sh
@@ -261,7 +261,7 @@ do
             doBlackFormat=true
             doIsortFormat=true
             doFlake8Format=true
-            doPylintFormat=true
+            # doPylintFormat=true # https://github.com/Project-MONAI/MONAI/issues/7094
             doRuffFormat=true
             doCopyRight=true
         ;;