From 103d8972e75c0676dae09ee040d6a48ecef13c54 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 18 Jul 2023 13:32:05 +0800 Subject: [PATCH 1/5] add `is_sqrt` Signed-off-by: KumoLiu --- monai/utils/__init__.py | 1 + monai/utils/misc.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 5fa62ed36b..b5c0ea2a83 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -94,6 +94,7 @@ str2list, to_tuple_of_dictionaries, zip_with, + is_sqrt, ) from .module import ( InvalidPyTorchVersionError, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index bad3f2b9a9..3dc7f3e054 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -14,6 +14,7 @@ import inspect import itertools import os +import math import pprint import random import shutil @@ -853,3 +854,10 @@ def run_cmd(cmd_list: list[str], **kwargs: Any) -> subprocess.CompletedProcess: output = str(e.stdout.decode(errors="replace")) errors = str(e.stderr.decode(errors="replace")) raise RuntimeError(f"subprocess call error {e.returncode}: {errors}, {output}.") from e + + +def is_sqrt(size): + size = ensure_tuple(size) + sqrt_size = [int(math.sqrt(_size)) for _size in size] + ret = [_i * _j for _i, _j in zip(sqrt_size, sqrt_size)] + return ensure_tuple(ret) == size From 82be756682e8a6c6d5df8924f621ed15cd583dc3 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 18 Jul 2023 13:32:33 +0800 Subject: [PATCH 2/5] fix hard-code kernel size in `ViTAutoEnc` Signed-off-by: KumoLiu --- monai/networks/nets/vitautoenc.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py index 6ae8135438..12d7d4e376 100644 --- a/monai/networks/nets/vitautoenc.py +++ b/monai/networks/nets/vitautoenc.py @@ -11,6 +11,7 @@ from __future__ import annotations +import math from collections.abc import Sequence import torch @@ -19,7 +20,7 @@ from monai.networks.blocks.patchembedding import PatchEmbeddingBlock from monai.networks.blocks.transformerblock import TransformerBlock from monai.networks.layers import Conv -from monai.utils import ensure_tuple_rep +from monai.utils import ensure_tuple_rep, is_sqrt __all__ = ["ViTAutoEnc"] @@ -78,9 +79,14 @@ def __init__( """ super().__init__() - + if not is_sqrt(patch_size): + raise ValueError(f"patch_size should be square number, got {patch_size}.") self.patch_size = ensure_tuple_rep(patch_size, spatial_dims) + self.img_size = ensure_tuple_rep(img_size, spatial_dims) self.spatial_dims = spatial_dims + for m, p in zip(self.img_size, self.patch_size): + if m % p != 0: + raise ValueError(f"patch_size={patch_size} should be divisible by img_size={img_size}.") self.patch_embedding = PatchEmbeddingBlock( in_channels=in_channels, @@ -100,12 +106,12 @@ def __init__( ) self.norm = nn.LayerNorm(hidden_size) - new_patch_size = [4] * self.spatial_dims conv_trans = Conv[Conv.CONVTRANS, self.spatial_dims] # self.conv3d_transpose* is to be compatible with existing 3d model weights. - self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=new_patch_size, stride=new_patch_size) + up_kernel_size = [int(math.sqrt(i)) for i in self.patch_size] + self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=up_kernel_size, stride=up_kernel_size) self.conv3d_transpose_1 = conv_trans( - in_channels=deconv_chns, out_channels=out_channels, kernel_size=new_patch_size, stride=new_patch_size + in_channels=deconv_chns, out_channels=out_channels, kernel_size=up_kernel_size, stride=up_kernel_size ) def forward(self, x): From 5193f345728d5d630675aa3c24ac7b39d21d14dd Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 18 Jul 2023 13:39:07 +0800 Subject: [PATCH 3/5] add unittest Signed-off-by: KumoLiu --- tests/test_vitautoenc.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py index 7a3956e964..5e95d3c7fb 100644 --- a/tests/test_vitautoenc.py +++ b/tests/test_vitautoenc.py @@ -49,7 +49,7 @@ { "in_channels": 1, "img_size": (512, 512, 32), - "patch_size": (16, 16, 16), + "patch_size": (64, 64, 16), "hidden_size": 768, "mlp_dim": 3072, "num_layers": 4, @@ -147,6 +147,19 @@ def test_ill_arg(self): dropout_rate=0.3, ) + with self.assertRaises(ValueError): + ViTAutoEnc( + in_channels=4, + img_size=(96, 96, 96), + patch_size=(9, 9, 9), + hidden_size=768, + mlp_dim=3072, + num_layers=12, + num_heads=12, + pos_embed="perc", + dropout_rate=0.3, + ) + if __name__ == "__main__": unittest.main() From e310625012460d7e8703e476518d3f99540dbf1b Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 18 Jul 2023 13:54:56 +0800 Subject: [PATCH 4/5] fix flake8 Signed-off-by: KumoLiu --- monai/utils/__init__.py | 2 +- monai/utils/misc.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index b5c0ea2a83..58bde3f5e8 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -80,6 +80,7 @@ is_module_ver_at_least, is_scalar, is_scalar_tensor, + is_sqrt, issequenceiterable, list_to_dict, path_to_uri, @@ -94,7 +95,6 @@ str2list, to_tuple_of_dictionaries, zip_with, - is_sqrt, ) from .module import ( InvalidPyTorchVersionError, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 3dc7f3e054..20ef256708 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -13,8 +13,8 @@ import inspect import itertools -import os import math +import os import pprint import random import shutil From d60f955cf5c504f93135dad0bd9b63951b1af0f9 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 18 Jul 2023 14:02:23 +0800 Subject: [PATCH 5/5] add docstring Signed-off-by: KumoLiu --- monai/utils/misc.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 20ef256708..18f05b6e9f 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -856,8 +856,11 @@ def run_cmd(cmd_list: list[str], **kwargs: Any) -> subprocess.CompletedProcess: raise RuntimeError(f"subprocess call error {e.returncode}: {errors}, {output}.") from e -def is_sqrt(size): - size = ensure_tuple(size) - sqrt_size = [int(math.sqrt(_size)) for _size in size] - ret = [_i * _j for _i, _j in zip(sqrt_size, sqrt_size)] - return ensure_tuple(ret) == size +def is_sqrt(num: Sequence[int] | int) -> bool: + """ + Determine if the input is a square number or a squence of square numbers. + """ + num = ensure_tuple(num) + sqrt_num = [int(math.sqrt(_num)) for _num in num] + ret = [_i * _j for _i, _j in zip(sqrt_num, sqrt_num)] + return ensure_tuple(ret) == num