From 363e574febf5979e3ac10bdcb31cc486a6ed20e3 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 5 Aug 2021 17:02:45 +0100 Subject: [PATCH 1/6] 2d/3d patchembedding Signed-off-by: Wenqi Li --- monai/networks/blocks/patchembedding.py | 73 +++++++++++++------------ monai/networks/nets/vit.py | 8 ++- tests/test_patchembedding.py | 30 +++++----- 3 files changed, 60 insertions(+), 51 deletions(-) diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 1f312e9126..547734f9e5 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -11,31 +11,42 @@ import math -from typing import Tuple, Union +from typing import Tuple +import numpy as np import torch import torch.nn as nn -from monai.utils import optional_import +from monai.networks.layers import Conv +from monai.utils import ensure_tuple_rep, optional_import +from monai.utils.module import look_up_option Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") +SUPPORTED_EMBEDDING_TYPES = {"conv", "perceptron"} class PatchEmbeddingBlock(nn.Module): """ A patch embedding block, based on: "Dosovitskiy et al., An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + + Example:: + + >>> from monai.networks.blocks import PatchEmbeddingBlock + >>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4, pos_embed="conv") + """ def __init__( self, in_channels: int, - img_size: Tuple[int, int, int], - patch_size: Tuple[int, int, int], + img_size: Tuple[int, ...], + patch_size: Tuple[int, ...], hidden_size: int, num_heads: int, pos_embed: str, dropout_rate: float = 0.0, + spatial_dims: int = 3, ) -> None: """ Args: @@ -46,47 +57,44 @@ def __init__( num_heads: number of attention heads. pos_embed: position embedding layer type. dropout_rate: faction of the input units to drop. 
+ spatial_dims: number of spatial dimensions + """ - super().__init__() + super(PatchEmbeddingBlock, self).__init__() if not (0 <= dropout_rate <= 1): - raise AssertionError("dropout_rate should be between 0 and 1.") + raise ValueError("dropout_rate should be between 0 and 1.") if hidden_size % num_heads != 0: - raise AssertionError("hidden size should be divisible by num_heads.") + raise ValueError("hidden size should be divisible by num_heads.") + + self.pos_embed = look_up_option(pos_embed, SUPPORTED_EMBEDDING_TYPES) + img_size = ensure_tuple_rep(img_size, spatial_dims) + patch_size = ensure_tuple_rep(patch_size, spatial_dims) for m, p in zip(img_size, patch_size): if m < p: - raise AssertionError("patch_size should be smaller than img_size.") + raise ValueError("patch_size should be smaller than img_size.") + if self.pos_embed == "perceptron" and m % p != 0: + raise ValueError("patch_size should be divisible by img_size for perceptron.") + self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)]) + self.patch_dim = in_channels * np.prod(patch_size) - if pos_embed not in ["conv", "perceptron"]: - raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") - - if pos_embed == "perceptron": - if img_size[0] % patch_size[0] != 0: - raise AssertionError("img_size should be divisible by patch_size for perceptron patch embedding.") - - self.n_patches = ( - (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) * (img_size[2] // patch_size[2]) - ) - self.patch_dim = in_channels * patch_size[0] * patch_size[1] * patch_size[2] - - self.pos_embed = pos_embed - self.patch_embeddings: Union[nn.Conv3d, nn.Sequential] + self.patch_embeddings: nn.Module if self.pos_embed == "conv": - self.patch_embeddings = nn.Conv3d( + self.patch_embeddings = Conv[Conv.CONV, spatial_dims]( in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size ) elif self.pos_embed == "perceptron": + # for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)" + chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims] + from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars) + to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)" + axes_len = {f"p{i+1}": p for i, p in enumerate(patch_size)} self.patch_embeddings = nn.Sequential( - Rearrange( - "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)", - p1=patch_size[0], - p2=patch_size[1], - p3=patch_size[2], - ), + Rearrange(f"{from_chars} -> {to_chars}", **axes_len), nn.Linear(self.patch_dim, hidden_size), ) self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) @@ -121,12 +129,9 @@ def norm_cdf(x): return tensor def forward(self, x): + x = self.patch_embeddings(x) if self.pos_embed == "conv": - x = self.patch_embeddings(x) - x = x.flatten(2) - x = x.transpose(-1, -2) - elif self.pos_embed == "perceptron": - x = self.patch_embeddings(x) + x = x.flatten(2).transpose(-1, -2) embeddings = x + self.position_embeddings embeddings = self.dropout(embeddings) return embeddings diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py index 3e90a36757..b8bc54aa4c 100644 --- a/monai/networks/nets/vit.py +++ b/monai/networks/nets/vit.py @@ -75,7 +75,13 @@ def __init__( self.classification = classification self.patch_embedding = PatchEmbeddingBlock( - in_channels, img_size, patch_size, hidden_size, num_heads, pos_embed, dropout_rate + in_channels=in_channels, + img_size=img_size, + patch_size=patch_size, + 
hidden_size=hidden_size, + num_heads=num_heads, + pos_embed=pos_embed, + dropout_rate=dropout_rate, ) self.blocks = nn.ModuleList( [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)] diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py index 5283153880..6c9ac78a99 100644 --- a/tests/test_patchembedding.py +++ b/tests/test_patchembedding.py @@ -12,7 +12,6 @@ import unittest from unittest import skipUnless -import numpy as np import torch from parameterized import parameterized @@ -23,31 +22,30 @@ einops, has_einops = optional_import("einops") TEST_CASE_PATCHEMBEDDINGBLOCK = [] -for dropout_rate in np.linspace(0, 1, 2): +for dropout_rate in (0.5,): for in_channels in [1, 4]: for hidden_size in [360, 768]: for img_size in [96, 128]: for patch_size in [8, 16]: for num_heads in [8, 12]: for pos_embed in ["conv", "perceptron"]: - for classification in ["False", "True"]: - if classification: - out = (2, (img_size // patch_size) ** 3 + 1, hidden_size) - else: - out = (2, (img_size // patch_size) ** 3, hidden_size) + # for classification in (False, True): # TODO: add classification tests + for nd in (2, 3): test_case = [ { "in_channels": in_channels, - "img_size": (img_size, img_size, img_size), - "patch_size": (patch_size, patch_size, patch_size), + "img_size": (img_size,) * nd, + "patch_size": (patch_size,) * nd, "hidden_size": hidden_size, "num_heads": num_heads, "pos_embed": pos_embed, "dropout_rate": dropout_rate, }, - (2, in_channels, img_size, *([img_size] * 2)), - (2, (img_size // patch_size) ** 3, hidden_size), + (2, in_channels, *([img_size] * nd)), + (2, (img_size // patch_size) ** nd, hidden_size), ] + if nd == 2: + test_case[0]["spatial_dims"] = 2 # type: ignore TEST_CASE_PATCHEMBEDDINGBLOCK.append(test_case) @@ -61,7 +59,7 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) def test_ill_arg(self): - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): PatchEmbeddingBlock( in_channels=1, img_size=(128, 128, 128), @@ -72,7 +70,7 @@ def test_ill_arg(self): dropout_rate=5.0, ) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): PatchEmbeddingBlock( in_channels=1, img_size=(32, 32, 32), @@ -83,7 +81,7 @@ def test_ill_arg(self): dropout_rate=0.3, ) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): PatchEmbeddingBlock( in_channels=1, img_size=(96, 96, 96), @@ -94,7 +92,7 @@ def test_ill_arg(self): dropout_rate=0.3, ) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): PatchEmbeddingBlock( in_channels=1, img_size=(97, 97, 97), @@ -105,7 +103,7 @@ def test_ill_arg(self): dropout_rate=0.3, ) - with self.assertRaises(KeyError): + with self.assertRaises(ValueError): PatchEmbeddingBlock( in_channels=4, img_size=(96, 96, 96), From 3d72d0872575f7721e8315733d305b934049b0c1 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 5 Aug 2021 17:57:29 +0100 Subject: [PATCH 2/6] minor updates for selfattention Signed-off-by: Wenqi Li --- monai/networks/blocks/mlp.py | 2 +- monai/networks/blocks/selfattention.py | 16 ++++++---------- monai/networks/blocks/transformerblock.py | 4 ++-- tests/test_selfattention.py | 4 ++-- 4 files changed, 11 insertions(+), 15 deletions(-) diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py index b108188605..11b5e6fc15 100644 --- a/monai/networks/blocks/mlp.py +++ b/monai/networks/blocks/mlp.py @@ -35,7 +35,7 @@ def __init__( 
super().__init__() if not (0 <= dropout_rate <= 1): - raise AssertionError("dropout_rate should be between 0 and 1.") + raise ValueError("dropout_rate should be between 0 and 1.") self.linear1 = nn.Linear(hidden_size, mlp_dim) self.linear2 = nn.Linear(mlp_dim, hidden_size) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index bd5bbfa072..9dc45cccc8 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -14,7 +14,7 @@ from monai.utils import optional_import -einops, has_einops = optional_import("einops") +einops, _ = optional_import("einops") class SABlock(nn.Module): @@ -37,13 +37,13 @@ def __init__( """ - super().__init__() + super(SABlock, self).__init__() if not (0 <= dropout_rate <= 1): - raise AssertionError("dropout_rate should be between 0 and 1.") + raise ValueError("dropout_rate should be between 0 and 1.") if hidden_size % num_heads != 0: - raise AssertionError("hidden size should be divisible by num_heads.") + raise ValueError("hidden size should be divisible by num_heads.") self.num_heads = num_heads self.out_proj = nn.Linear(hidden_size, hidden_size) @@ -52,17 +52,13 @@ def __init__( self.drop_weights = nn.Dropout(dropout_rate) self.head_dim = hidden_size // num_heads self.scale = self.head_dim ** -0.5 - if has_einops: - self.rearrange = einops.rearrange - else: - raise ValueError('"Requires einops.') def forward(self, x): - q, k, v = self.rearrange(self.qkv(x), "b h (qkv l d) -> qkv b l h d", qkv=3, l=self.num_heads) + q, k, v = einops.rearrange(self.qkv(x), "b h (qkv l d) -> qkv b l h d", qkv=3, l=self.num_heads) att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1) att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) - x = self.rearrange(x, "b h l d -> b l (h d)") + x = einops.rearrange(x, "b h l d -> b l (h d)") x = self.out_proj(x) x = self.drop_output(x) return x diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index 3dd80f58ad..c7a948ed76 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -40,10 +40,10 @@ def __init__( super().__init__() if not (0 <= dropout_rate <= 1): - raise AssertionError("dropout_rate should be between 0 and 1.") + raise ValueError("dropout_rate should be between 0 and 1.") if hidden_size % num_heads != 0: - raise AssertionError("hidden size should be divisible by num_heads.") + raise ValueError("hidden_size should be divisible by num_heads.") self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) self.norm1 = nn.LayerNorm(hidden_size) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 2430b82c9b..3d561aac2f 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -49,10 +49,10 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) def test_ill_arg(self): - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): SABlock(hidden_size=128, num_heads=12, dropout_rate=6.0) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4) From a41bdfd8276bac6ed08bdd402836c4e31380b0a6 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 5 Aug 2021 18:18:49 +0100 Subject: [PATCH 3/6] 2d vit Signed-off-by: Wenqi Li --- monai/networks/blocks/patchembedding.py | 2 +- monai/networks/nets/vit.py | 18 +++++++-------- 
tests/test_mlp.py | 2 +- tests/test_vit.py | 29 ++++++++++++------------- 4 files changed, 25 insertions(+), 26 deletions(-) diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 547734f9e5..6dc467f13a 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -57,7 +57,7 @@ def __init__( num_heads: number of attention heads. pos_embed: position embedding layer type. dropout_rate: faction of the input units to drop. - spatial_dims: number of spatial dimensions + spatial_dims: number of spatial dimensions. """ diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py index b8bc54aa4c..6c733a4331 100644 --- a/monai/networks/nets/vit.py +++ b/monai/networks/nets/vit.py @@ -27,8 +27,8 @@ class ViT(nn.Module): def __init__( self, in_channels: int, - img_size: Tuple[int, int, int], - patch_size: Tuple[int, int, int], + img_size: Tuple[int, ...], + patch_size: Tuple[int, ...], hidden_size: int = 768, mlp_dim: int = 3072, num_layers: int = 12, @@ -37,6 +37,7 @@ def __init__( classification: bool = False, num_classes: int = 2, dropout_rate: float = 0.0, + spatial_dims: int = 3, ) -> None: """ Args: @@ -51,6 +52,7 @@ def __init__( classification: bool argument to determine if classification is used. num_classes: number of classes if classification is used. dropout_rate: faction of the input units to drop. + spatial_dims: number of spatial dimensions. Examples:: @@ -58,20 +60,17 @@ def __init__( >>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv') # for 3-channel with patch size of (128,128,128), 24 layers and classification backbone - >>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification= True) + >>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True) """ - super().__init__() + super(ViT, self).__init__() if not (0 <= dropout_rate <= 1): - raise AssertionError("dropout_rate should be between 0 and 1.") + raise ValueError("dropout_rate should be between 0 and 1.") if hidden_size % num_heads != 0: - raise AssertionError("hidden size should be divisible by num_heads.") - - if pos_embed not in ["conv", "perceptron"]: - raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") + raise ValueError("hidden_size should be divisible by num_heads.") self.classification = classification self.patch_embedding = PatchEmbeddingBlock( @@ -82,6 +81,7 @@ def __init__( num_heads=num_heads, pos_embed=pos_embed, dropout_rate=dropout_rate, + spatial_dims=spatial_dims, ) self.blocks = nn.ModuleList( [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)] diff --git a/tests/test_mlp.py b/tests/test_mlp.py index efc8db74c2..7a93f81ec3 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -44,7 +44,7 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) def test_ill_arg(self): - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=5.0) diff --git a/tests/test_vit.py b/tests/test_vit.py index 0d0d58093b..fd32b303c7 100644 --- a/tests/test_vit.py +++ b/tests/test_vit.py @@ -28,28 +28,27 @@ for num_layers in [4]: for num_classes in [2]: for pos_embed in ["conv"]: - for classification in ["False"]: - if classification: - out = (2, num_classes) - else: - out = (2, (img_size // patch_size) ** 3, hidden_size) # type: ignore + # for classification in 
[False, True]: # TODO: test classification + for nd in (2, 3): test_case = [ { "in_channels": in_channels, - "img_size": (img_size, img_size, img_size), - "patch_size": (patch_size, patch_size, patch_size), + "img_size": (img_size,) * nd, + "patch_size": (patch_size,) * nd, "hidden_size": hidden_size, "mlp_dim": mlp_dim, "num_layers": num_layers, "num_heads": num_heads, "pos_embed": pos_embed, - "classification": classification, + "classification": False, "num_classes": num_classes, "dropout_rate": dropout_rate, }, - (2, in_channels, img_size, *([img_size] * 2)), - out, + (2, in_channels, *([img_size] * nd)), + (2, (img_size // patch_size) ** nd, hidden_size), ] + if nd == 2: + test_case[0]["spatial_dims"] = 2 TEST_CASE_Vit.append(test_case) @@ -62,7 +61,7 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) def test_ill_arg(self): - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): ViT( in_channels=1, img_size=(128, 128, 128), @@ -76,7 +75,7 @@ def test_ill_arg(self): dropout_rate=5.0, ) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): ViT( in_channels=1, img_size=(32, 32, 32), @@ -90,7 +89,7 @@ def test_ill_arg(self): dropout_rate=0.3, ) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): ViT( in_channels=1, img_size=(96, 96, 96), @@ -104,7 +103,7 @@ def test_ill_arg(self): dropout_rate=0.3, ) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): ViT( in_channels=1, img_size=(97, 97, 97), @@ -118,7 +117,7 @@ def test_ill_arg(self): dropout_rate=0.3, ) - with self.assertRaises(KeyError): + with self.assertRaises(ValueError): ViT( in_channels=4, img_size=(96, 96, 96), From e248698afd1879190e4564d8125c1495a8171ab3 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 5 Aug 2021 18:24:56 +0100 Subject: [PATCH 4/6] fixes type hint Signed-off-by: Wenqi Li --- tests/test_vit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_vit.py b/tests/test_vit.py index fd32b303c7..0dce73b0cb 100644 --- a/tests/test_vit.py +++ b/tests/test_vit.py @@ -48,7 +48,7 @@ (2, (img_size // patch_size) ** nd, hidden_size), ] if nd == 2: - test_case[0]["spatial_dims"] = 2 + test_case[0]["spatial_dims"] = 2 # type: ignore TEST_CASE_Vit.append(test_case) From f822f67e4c8bc1f63412570e22868917b259e1da Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 5 Aug 2021 19:01:53 +0100 Subject: [PATCH 5/6] update unetr Signed-off-by: Wenqi Li --- monai/networks/blocks/patchembedding.py | 6 +-- monai/networks/blocks/unetr_block.py | 11 ++--- monai/networks/nets/unetr.py | 60 ++++++++++++------------- monai/networks/nets/vit.py | 6 +-- tests/test_unetr.py | 51 ++++++++++----------- tests/test_unetr_block.py | 48 ++++++++++---------- 6 files changed, 88 insertions(+), 94 deletions(-) diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 6dc467f13a..c1fcfa9af7 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -11,7 +11,7 @@ import math -from typing import Tuple +from typing import Sequence, Union import numpy as np import torch @@ -40,8 +40,8 @@ class PatchEmbeddingBlock(nn.Module): def __init__( self, in_channels: int, - img_size: Tuple[int, ...], - patch_size: Tuple[int, ...], + img_size: Union[Sequence[int], int], + patch_size: Union[Sequence[int], int], hidden_size: int, num_heads: int, pos_embed: str, diff --git 
a/monai/networks/blocks/unetr_block.py b/monai/networks/blocks/unetr_block.py index 20c39f6240..a0852d05e0 100644 --- a/monai/networks/blocks/unetr_block.py +++ b/monai/networks/blocks/unetr_block.py @@ -28,9 +28,8 @@ def __init__( self, spatial_dims: int, in_channels: int, - out_channels: int, # type: ignore + out_channels: int, kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], upsample_kernel_size: Union[Sequence[int], int], norm_name: Union[Tuple, str], res_block: bool = False, @@ -41,7 +40,6 @@ def __init__( in_channels: number of input channels. out_channels: number of output channels. kernel_size: convolution kernel size. - stride: convolution stride. upsample_kernel_size: convolution kernel size for transposed convolution layers. norm_name: feature normalization type and arguments. res_block: bool argument to determine if residual block is used. @@ -148,7 +146,7 @@ def __init__( is_transposed=True, ), UnetResBlock( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, @@ -173,7 +171,7 @@ def __init__( is_transposed=True, ), UnetBasicBlock( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, @@ -257,5 +255,4 @@ def __init__( ) def forward(self, inp): - out = self.layer(inp) - return out + return self.layer(inp) diff --git a/monai/networks/nets/unetr.py b/monai/networks/nets/unetr.py index 1ac9c9ee49..ed49847515 100644 --- a/monai/networks/nets/unetr.py +++ b/monai/networks/nets/unetr.py @@ -9,13 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Sequence, Tuple, Union import torch.nn as nn from monai.networks.blocks.dynunet_block import UnetOutBlock from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock from monai.networks.nets.vit import ViT +from monai.utils import ensure_tuple_rep class UNETR(nn.Module): @@ -28,7 +29,7 @@ def __init__( self, in_channels: int, out_channels: int, - img_size: Tuple[int, int, int], + img_size: Union[Sequence[int], int], feature_size: int = 16, hidden_size: int = 768, mlp_dim: int = 3072, @@ -38,6 +39,7 @@ def __init__( conv_block: bool = False, res_block: bool = True, dropout_rate: float = 0.0, + spatial_dims: int = 3, ) -> None: """ Args: @@ -53,35 +55,33 @@ def __init__( conv_block: bool argument to determine if convolutional block is used. res_block: bool argument to determine if residual block is used. dropout_rate: faction of the input units to drop. + spatial_dims: number of spatial dims. 
Examples:: # for single channel input 4-channel output with patch size of (96,96,96), feature size of 32 and batch norm >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch') + # for single channel input 4-channel output with patch size of (96,96), feature size of 32 and batch norm + >>> net = UNETR(in_channels=1, out_channels=4, img_size=96, feature_size=32, norm_name='batch', spatial_dims=2) + # for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance') """ - super().__init__() + super(UNETR, self).__init__() if not (0 <= dropout_rate <= 1): - raise AssertionError("dropout_rate should be between 0 and 1.") + raise ValueError("dropout_rate should be between 0 and 1.") if hidden_size % num_heads != 0: - raise AssertionError("hidden size should be divisible by num_heads.") - - if pos_embed not in ["conv", "perceptron"]: - raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") + raise ValueError("hidden_size should be divisible by num_heads.") self.num_layers = 12 - self.patch_size = (16, 16, 16) - self.feat_size = ( - img_size[0] // self.patch_size[0], - img_size[1] // self.patch_size[1], - img_size[2] // self.patch_size[2], - ) + img_size = ensure_tuple_rep(img_size, spatial_dims) + self.patch_size = ensure_tuple_rep(16, spatial_dims) + self.feat_size = tuple(img_d // p_d for img_d, p_d in zip(img_size, self.patch_size)) self.hidden_size = hidden_size self.classification = False self.vit = ViT( @@ -95,9 +95,10 @@ def __init__( pos_embed=pos_embed, classification=self.classification, dropout_rate=dropout_rate, + spatial_dims=spatial_dims, ) self.encoder1 = UnetrBasicBlock( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=feature_size, kernel_size=3, @@ -106,7 +107,7 @@ def __init__( res_block=res_block, ) self.encoder2 = UnetrPrUpBlock( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=hidden_size, out_channels=feature_size * 2, num_layer=2, @@ -118,7 +119,7 @@ def __init__( res_block=res_block, ) self.encoder3 = UnetrPrUpBlock( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=hidden_size, out_channels=feature_size * 4, num_layer=1, @@ -130,7 +131,7 @@ def __init__( res_block=res_block, ) self.encoder4 = UnetrPrUpBlock( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=hidden_size, out_channels=feature_size * 8, num_layer=0, @@ -142,50 +143,48 @@ def __init__( res_block=res_block, ) self.decoder5 = UnetrUpBlock( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=hidden_size, out_channels=feature_size * 8, - stride=1, kernel_size=3, upsample_kernel_size=2, norm_name=norm_name, res_block=res_block, ) self.decoder4 = UnetrUpBlock( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=feature_size * 8, out_channels=feature_size * 4, - stride=1, kernel_size=3, upsample_kernel_size=2, norm_name=norm_name, res_block=res_block, ) self.decoder3 = UnetrUpBlock( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=feature_size * 4, out_channels=feature_size * 2, - stride=1, kernel_size=3, upsample_kernel_size=2, norm_name=norm_name, res_block=res_block, ) self.decoder2 = UnetrUpBlock( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=feature_size * 2, out_channels=feature_size, - stride=1, kernel_size=3, upsample_kernel_size=2, norm_name=norm_name, res_block=res_block, ) - 
self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) # type: ignore + self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels) def proj_feat(self, x, hidden_size, feat_size): - x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) - x = x.permute(0, 4, 1, 2, 3).contiguous() + new_view = (x.size(0), *feat_size, hidden_size) + x = x.view(new_view) + new_axes = (0, len(x.shape) - 1) + tuple(d + 1 for d in range(len(feat_size))) + x = x.permute(new_axes).contiguous() return x def forward(self, x_in): @@ -202,5 +201,4 @@ def forward(self, x_in): dec2 = self.decoder4(dec3, enc3) dec1 = self.decoder3(dec2, enc2) out = self.decoder2(dec1, enc1) - logits = self.out(out) - return logits + return self.out(out) diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py index 6c733a4331..0fd55cac62 100644 --- a/monai/networks/nets/vit.py +++ b/monai/networks/nets/vit.py @@ -10,7 +10,7 @@ # limitations under the License. -from typing import Tuple +from typing import Sequence, Union import torch.nn as nn @@ -27,8 +27,8 @@ class ViT(nn.Module): def __init__( self, in_channels: int, - img_size: Tuple[int, ...], - patch_size: Tuple[int, ...], + img_size: Union[Sequence[int], int], + patch_size: Union[Sequence[int], int], hidden_size: int = 768, mlp_dim: int = 3072, num_layers: int = 12, diff --git a/tests/test_unetr.py b/tests/test_unetr.py index cd50cb487c..d19ed2ca59 100644 --- a/tests/test_unetr.py +++ b/tests/test_unetr.py @@ -28,27 +28,28 @@ for mlp_dim in [3072]: for norm_name in ["instance"]: for pos_embed in ["perceptron"]: - for conv_block in [True]: - for res_block in [False]: - test_case = [ - { - "in_channels": in_channels, - "out_channels": out_channels, - "img_size": (img_size, img_size, img_size), - "hidden_size": hidden_size, - "feature_size": feature_size, - "norm_name": norm_name, - "mlp_dim": mlp_dim, - "num_heads": num_heads, - "pos_embed": pos_embed, - "dropout_rate": dropout_rate, - "conv_block": conv_block, - "res_block": res_block, - }, - (2, in_channels, img_size, *([img_size] * 2)), - (2, out_channels, img_size, *([img_size] * 2)), - ] - TEST_CASE_UNETR.append(test_case) + for nd in (2, 3): + test_case = [ + { + "in_channels": in_channels, + "out_channels": out_channels, + "img_size": (img_size,) * nd, + "hidden_size": hidden_size, + "feature_size": feature_size, + "norm_name": norm_name, + "mlp_dim": mlp_dim, + "num_heads": num_heads, + "pos_embed": pos_embed, + "dropout_rate": dropout_rate, + "conv_block": True, + "res_block": False, + }, + (2, in_channels, *([img_size] * nd)), + (2, out_channels, *([img_size] * nd)), + ] + if nd == 2: + test_case[0]["spatial_dims"] = 2 # type: ignore + TEST_CASE_UNETR.append(test_case) class TestPatchEmbeddingBlock(unittest.TestCase): @@ -60,7 +61,7 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) def test_ill_arg(self): - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): UNETR( in_channels=1, out_channels=3, @@ -74,7 +75,7 @@ def test_ill_arg(self): dropout_rate=5.0, ) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): UNETR( in_channels=1, out_channels=4, @@ -88,7 +89,7 @@ def test_ill_arg(self): dropout_rate=0.5, ) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): UNETR( in_channels=1, out_channels=3, @@ -102,7 +103,7 @@ def test_ill_arg(self): dropout_rate=0.4, ) - with 
self.assertRaises(KeyError): + with self.assertRaises(ValueError): UNETR( in_channels=1, out_channels=4, diff --git a/tests/test_unetr_block.py b/tests/test_unetr_block.py index 0b22838fae..7546918a2c 100644 --- a/tests/test_unetr_block.py +++ b/tests/test_unetr_block.py @@ -20,7 +20,7 @@ from tests.utils import test_script_save TEST_CASE_UNETR_BASIC_BLOCK = [] -for spatial_dims in range(2, 4): +for spatial_dims in range(1, 4): for kernel_size in [1, 3]: for stride in [2]: for norm_name in [("GROUP", {"num_groups": 16}), ("batch", {"track_running_stats": False}), "instance"]: @@ -45,34 +45,32 @@ TEST_UP_BLOCK = [] in_channels, out_channels = 4, 2 -for spatial_dims in range(2, 4): +for spatial_dims in range(1, 4): for kernel_size in [1, 3]: - for stride in [1, 2]: - for res_block in [False, True]: - for norm_name in ["instance"]: - for in_size in [15, 16]: - out_size = in_size * stride - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": in_channels, - "out_channels": out_channels, - "kernel_size": kernel_size, - "norm_name": norm_name, - "stride": stride, - "res_block": res_block, - "upsample_kernel_size": stride, - }, - (1, in_channels, *([in_size] * spatial_dims)), - (1, out_channels, *([out_size] * spatial_dims)), - (1, out_channels, *([in_size * stride] * spatial_dims)), - ] - TEST_UP_BLOCK.append(test_case) + for res_block in [False, True]: + for norm_name in ["instance"]: + for in_size in [15, 16]: + out_size = in_size * stride + test_case = [ + { + "spatial_dims": spatial_dims, + "in_channels": in_channels, + "out_channels": out_channels, + "kernel_size": kernel_size, + "norm_name": norm_name, + "res_block": res_block, + "upsample_kernel_size": stride, + }, + (1, in_channels, *([in_size] * spatial_dims)), + (1, out_channels, *([out_size] * spatial_dims)), + (1, out_channels, *([in_size * stride] * spatial_dims)), + ] + TEST_UP_BLOCK.append(test_case) TEST_PRUP_BLOCK = [] in_channels, out_channels = 4, 2 -for spatial_dims in range(2, 4): +for spatial_dims in range(1, 4): for kernel_size in [1, 3]: for upsample_kernel_size in [2, 3]: for stride in [1, 2]: @@ -81,7 +79,7 @@ for in_size in [15, 16]: for num_layer in [0, 2]: in_size_tmp = in_size - for _num in range(num_layer + 1): + for _ in range(num_layer + 1): out_size = in_size_tmp * upsample_kernel_size in_size_tmp = out_size test_case = [ From 10d6117e0b562570dcb6eb4937322770a9bc9545 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 5 Aug 2021 19:30:36 +0100 Subject: [PATCH 6/6] fixes unit test Signed-off-by: Wenqi Li --- tests/test_transformerblock.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_transformerblock.py b/tests/test_transformerblock.py index 24d16c77aa..616e3e7ec9 100644 --- a/tests/test_transformerblock.py +++ b/tests/test_transformerblock.py @@ -46,10 +46,10 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) def test_ill_arg(self): - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): TransformerBlock(hidden_size=128, num_heads=12, mlp_dim=2048, dropout_rate=4.0) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4)
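
For spatial_dims=2, the Rearrange pattern assembled in PATCH 1/6 expands to
"b c (h p1) (w p2) -> b (h w) (p1 p2 c)". A quick sanity check of the token
count and patch dimension implied by that pattern (a sketch, assuming einops
is installed; the shapes follow the n_patches/patch_dim arithmetic in the
patch):

    import torch
    from einops.layers.torch import Rearrange

    # 2d instance of the dynamically built perceptron embedding pattern
    rearrange = Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=8, p2=8)
    tokens = rearrange(torch.randn(2, 4, 32, 32))
    print(tokens.shape)  # torch.Size([2, 16, 256])

This agrees with n_patches = (32 // 8) ** 2 = 16 and
patch_dim = in_channels * 8 * 8 = 256 as computed via np.prod in the patch.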
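
PATCH 5/6 generalizes UNETR.proj_feat in the same spirit: the ViT token
sequence is mapped back onto the spatial grid for any spatial_dims. For
spatial_dims=2 the new_view/new_axes computation reduces to the explicit
operations below (a sketch of the 2d case only):

    import torch

    # (B, n_tokens, hidden_size) -> (B, hidden_size, H // 16, W // 16)
    x = torch.randn(2, 36, 768)              # ViT tokens for a 96x96 input
    feat_size, hidden_size = (6, 6), 768     # img_size // patch_size per dim
    x = x.view(2, *feat_size, hidden_size)   # (2, 6, 6, 768)
    x = x.permute(0, 3, 1, 2).contiguous()   # (2, 768, 6, 6)

Here (0, 3, 1, 2) is exactly new_axes = (0, len(x.shape) - 1) + (1, 2) from
the patched proj_feat.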
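
With the series applied, the blocks and networks accept an int or a sequence
for img_size/patch_size plus a spatial_dims argument. An end-to-end usage
sketch mirroring the shapes exercised by the updated tests (the output
comments assume the patched modules on this branch, and that ViT's forward
returns the token sequence together with the per-layer hidden states, as in
MONAI):

    import torch

    from monai.networks.blocks import PatchEmbeddingBlock
    from monai.networks.nets import UNETR, ViT

    # 2d patch embedding: (32 // 8) ** 2 = 16 tokens of width hidden_size=32
    emb = PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8,
                              hidden_size=32, num_heads=4, pos_embed="conv",
                              spatial_dims=2)
    print(emb(torch.randn(2, 4, 32, 32)).shape)  # torch.Size([2, 16, 32])

    # 2d ViT: (96 // 16) ** 2 = 36 tokens, default hidden_size=768
    vit = ViT(in_channels=1, img_size=96, patch_size=16, pos_embed="conv",
              spatial_dims=2)
    tokens, hidden_states = vit(torch.randn(2, 1, 96, 96))
    print(tokens.shape)  # torch.Size([2, 36, 768])

    # 2d UNETR: segmentation output keeps the input spatial shape
    net = UNETR(in_channels=1, out_channels=4, img_size=96, feature_size=32,
                norm_name="batch", spatial_dims=2)
    print(net(torch.randn(2, 1, 96, 96)).shape)  # torch.Size([2, 4, 96, 96])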