diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py index 6ae8135438..12d7d4e376 100644 --- a/monai/networks/nets/vitautoenc.py +++ b/monai/networks/nets/vitautoenc.py @@ -11,6 +11,7 @@ from __future__ import annotations +import math from collections.abc import Sequence import torch @@ -19,7 +20,7 @@ from monai.networks.blocks.patchembedding import PatchEmbeddingBlock from monai.networks.blocks.transformerblock import TransformerBlock from monai.networks.layers import Conv -from monai.utils import ensure_tuple_rep +from monai.utils import ensure_tuple_rep, is_sqrt __all__ = ["ViTAutoEnc"] @@ -78,9 +79,14 @@ def __init__( """ super().__init__() - + if not is_sqrt(patch_size): + raise ValueError(f"patch_size should be square number, got {patch_size}.") self.patch_size = ensure_tuple_rep(patch_size, spatial_dims) + self.img_size = ensure_tuple_rep(img_size, spatial_dims) self.spatial_dims = spatial_dims + for m, p in zip(self.img_size, self.patch_size): + if m % p != 0: + raise ValueError(f"patch_size={patch_size} should be divisible by img_size={img_size}.") self.patch_embedding = PatchEmbeddingBlock( in_channels=in_channels, @@ -100,12 +106,12 @@ def __init__( ) self.norm = nn.LayerNorm(hidden_size) - new_patch_size = [4] * self.spatial_dims conv_trans = Conv[Conv.CONVTRANS, self.spatial_dims] # self.conv3d_transpose* is to be compatible with existing 3d model weights. 
- self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=new_patch_size, stride=new_patch_size) + up_kernel_size = [int(math.sqrt(i)) for i in self.patch_size] + self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=up_kernel_size, stride=up_kernel_size) self.conv3d_transpose_1 = conv_trans( - in_channels=deconv_chns, out_channels=out_channels, kernel_size=new_patch_size, stride=new_patch_size + in_channels=deconv_chns, out_channels=out_channels, kernel_size=up_kernel_size, stride=up_kernel_size ) def forward(self, x): diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 5fa62ed36b..58bde3f5e8 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -80,6 +80,7 @@ is_module_ver_at_least, is_scalar, is_scalar_tensor, + is_sqrt, issequenceiterable, list_to_dict, path_to_uri, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index bad3f2b9a9..18f05b6e9f 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -13,6 +13,7 @@ import inspect import itertools +import math import os import pprint import random @@ -853,3 +854,13 @@ def run_cmd(cmd_list: list[str], **kwargs: Any) -> subprocess.CompletedProcess: output = str(e.stdout.decode(errors="replace")) errors = str(e.stderr.decode(errors="replace")) raise RuntimeError(f"subprocess call error {e.returncode}: {errors}, {output}.") from e + + +def is_sqrt(num: Sequence[int] | int) -> bool: + """ + Determine if the input is a square number or a sequence of square numbers. 
+ """ + num = ensure_tuple(num) + sqrt_num = [int(math.sqrt(_num)) for _num in num] + ret = [_i * _j for _i, _j in zip(sqrt_num, sqrt_num)] + return ensure_tuple(ret) == num diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py index 7a3956e964..5e95d3c7fb 100644 --- a/tests/test_vitautoenc.py +++ b/tests/test_vitautoenc.py @@ -49,7 +49,7 @@ { "in_channels": 1, "img_size": (512, 512, 32), - "patch_size": (16, 16, 16), + "patch_size": (64, 64, 16), "hidden_size": 768, "mlp_dim": 3072, "num_layers": 4, @@ -147,6 +147,19 @@ def test_ill_arg(self): dropout_rate=0.3, ) + with self.assertRaises(ValueError): + ViTAutoEnc( + in_channels=4, + img_size=(96, 96, 96), + patch_size=(9, 9, 9), + hidden_size=768, + mlp_dim=3072, + num_layers=12, + num_heads=12, + pos_embed="perc", + dropout_rate=0.3, + ) + if __name__ == "__main__": unittest.main()