diff --git a/docs/requirements.txt b/docs/requirements.txt index f412a19307..db6ea92bd8 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -19,3 +19,4 @@ sphinxcontrib-qthelp sphinxcontrib-serializinghtml sphinx-autodoc-typehints==1.11.1 pandas +einops diff --git a/docs/source/installation.md b/docs/source/installation.md index efa5cf08a8..d8dddff205 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -174,9 +174,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas] +[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops] ``` which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim` `openslide-python` and `pandas`, respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas` and `einops`, respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 4177696aa4..f70fc4f024 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -79,11 +79,30 @@ Blocks .. autoclass:: ResBlock :members: +`SABlock Block` +~~~~~~~~~~~~~~~~~ +.. autoclass:: SABlock + :members: + `Squeeze-and-Excitation` ~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: ChannelSELayer :members: +`Transformer Block` +~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: TransformerBlock + :members: + +`UNETR Block` +~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: UnetrBasicBlock + :members: +.. autoclass:: UnetrUpBlock + :members: +.. autoclass:: UnetrPrUpBlock + :members: + `Residual Squeeze-and-Excitation` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: ResidualSELayer @@ -159,6 +178,16 @@ Blocks .. autoclass:: LocalNetFeatureExtractorBlock :members: +`MLP Block` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: MLPBlock + :members: + +`Patch Embedding Block` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: PatchEmbeddingBlock + :members: + `Warp` ~~~~~~ .. autoclass:: Warp @@ -389,6 +418,11 @@ Nets .. autoclass:: Unet .. autoclass:: unet +`UNETR` +~~~~~~~~~~~~~~~~~ +.. autoclass:: UNETR + :members: + `BasicUNet` ~~~~~~~~~~~ .. autoclass:: BasicUNet @@ -426,6 +460,11 @@ Nets .. autoclass:: VarAutoEncoder :members: +`ViT` +~~~~~~ +.. autoclass:: ViT + :members: + `FullyConnectedNet` ~~~~~~~~~~~~~~~~~~~ .. 
autoclass:: FullyConnectedNet diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index c7bd9e00e8..73180cab14 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -81,6 +81,7 @@ def get_optional_config_values(): output["lmdb"] = get_package_version("lmdb") output["psutil"] = psutil_version output["pandas"] = get_package_version("pandas") + output["einops"] = get_package_version("einops") return output diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index ed6ac12430..db723f622d 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -18,8 +18,11 @@ from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding from .fcn import FCN, GCN, MCFCN, Refine from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock +from .mlp import MLPBlock +from .patchembedding import PatchEmbeddingBlock from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock from .segresnet_block import ResBlock +from .selfattention import SABlock from .squeeze_and_excitation import ( ChannelSELayer, ResidualSELayer, @@ -28,5 +31,7 @@ SEResNetBottleneck, SEResNeXtBottleneck, ) +from .transformerblock import TransformerBlock +from .unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock from .upsample import SubpixelUpsample, Subpixelupsample, SubpixelUpSample, Upsample, UpSample from .warp import DVF2DDF, Warp diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py new file mode 100644 index 0000000000..b108188605 --- /dev/null +++ b/monai/networks/blocks/mlp.py @@ -0,0 +1,51 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.nn as nn + + +class MLPBlock(nn.Module): + """ + A multi-layer perceptron block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + """ + + def __init__( + self, + hidden_size: int, + mlp_dim: int, + dropout_rate: float = 0.0, + ) -> None: + """ + Args: + hidden_size: dimension of hidden layer. + mlp_dim: dimension of feedforward layer. + dropout_rate: faction of the input units to drop. 
+ + """ + + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise AssertionError("dropout_rate should be between 0 and 1.") + + self.linear1 = nn.Linear(hidden_size, mlp_dim) + self.linear2 = nn.Linear(mlp_dim, hidden_size) + self.fn = nn.GELU() + self.drop1 = nn.Dropout(dropout_rate) + self.drop2 = nn.Dropout(dropout_rate) + + def forward(self, x): + x = self.fn(self.linear1(x)) + x = self.drop1(x) + x = self.linear2(x) + x = self.drop2(x) + return x diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py new file mode 100644 index 0000000000..6b80fdcc40 --- /dev/null +++ b/monai/networks/blocks/patchembedding.py @@ -0,0 +1,138 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import Tuple, Union + +import torch +import torch.nn as nn + +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") + + +class PatchEmbeddingBlock(nn.Module): + """ + A patch embedding block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + """ + + def __init__( + self, + in_channels: int, + img_size: Union[int, Tuple[int, int, int]], + patch_size: Union[int, Tuple[int, int, int]], + hidden_size: int, + num_heads: int, + pos_embed: Union[Tuple, str], # type: ignore + classification: bool, + dropout_rate: float = 0.0, + ) -> None: + """ + Args: + in_channels: dimension of input channels. + img_size: dimension of input image. + patch_size: dimension of patch size. + hidden_size: dimension of hidden layer. + num_heads: number of attention heads. + pos_embed: position embedding layer type. + classification: bool argument to determine if classification is used. + dropout_rate: faction of the input units to drop. 
+ + """ + + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise AssertionError("dropout_rate should be between 0 and 1.") + + if hidden_size % num_heads != 0: + raise AssertionError("hidden size should be divisible by num_heads.") + + if img_size < patch_size: # type: ignore + raise AssertionError("patch_size should be smaller than img_size.") + + if pos_embed not in ["conv", "perceptron"]: + raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") + + if pos_embed == "perceptron": + if img_size[0] % patch_size[0] != 0: # type: ignore + raise AssertionError("img_size should be divisible by patch_size for perceptron patch embedding.") + + if has_einops: # type: ignore + from einops.layers.torch import Rearrange # type: ignore + + self.Rearrange = Rearrange # type: ignore + else: + raise ValueError('"Requires einops.') + + self.n_patches = ( + (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) * (img_size[2] // patch_size[2]) # type: ignore + ) + self.patch_dim = in_channels * patch_size[0] * patch_size[1] * patch_size[2] # type: ignore + self.pos_embed = pos_embed + if self.pos_embed == "conv": + self.patch_embeddings = nn.Conv3d( + in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size # type: ignore + ) + elif self.pos_embed == "perceptron": + self.patch_embeddings = nn.Sequential( # type: ignore + self.Rearrange( + "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)", + p1=patch_size[0], # type: ignore + p2=patch_size[1], # type: ignore + p3=patch_size[2], # type: ignore + ), + nn.Linear(self.patch_dim, hidden_size), + ) + self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) + self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) + self.dropout = nn.Dropout(dropout_rate) + self.trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + self.trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def trunc_normal_(self, tensor, mean, std, a, b): + # From PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + with torch.no_grad(): + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + tensor.uniform_(2 * l - 1, 2 * u - 1) + tensor.erfinv_() + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + tensor.clamp_(min=a, max=b) + return tensor + + def forward(self, x): + if self.pos_embed == "conv": + x = self.patch_embeddings(x) + x = x.flatten(2) + x = x.transpose(-1, -2) + elif self.pos_embed == "perceptron": + x = self.patch_embeddings(x) + embeddings = x + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py new file mode 100644 index 0000000000..bd5bbfa072 --- /dev/null +++ b/monai/networks/blocks/selfattention.py @@ -0,0 +1,68 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") + + +class SABlock(nn.Module): + """ + A self-attention block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + dropout_rate: float = 0.0, + ) -> None: + """ + Args: + hidden_size: dimension of hidden layer. + num_heads: number of attention heads. + dropout_rate: faction of the input units to drop. + + """ + + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise AssertionError("dropout_rate should be between 0 and 1.") + + if hidden_size % num_heads != 0: + raise AssertionError("hidden size should be divisible by num_heads.") + + self.num_heads = num_heads + self.out_proj = nn.Linear(hidden_size, hidden_size) + self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False) + self.drop_output = nn.Dropout(dropout_rate) + self.drop_weights = nn.Dropout(dropout_rate) + self.head_dim = hidden_size // num_heads + self.scale = self.head_dim ** -0.5 + if has_einops: + self.rearrange = einops.rearrange + else: + raise ValueError('"Requires einops.') + + def forward(self, x): + q, k, v = self.rearrange(self.qkv(x), "b h (qkv l d) -> qkv b l h d", qkv=3, l=self.num_heads) + att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1) + att_mat = self.drop_weights(att_mat) + x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + x = self.rearrange(x, "b h l d -> b l (h d)") + x = self.out_proj(x) + x = self.drop_output(x) + return x diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py new file mode 100644 index 0000000000..3dd80f58ad --- /dev/null +++ b/monai/networks/blocks/transformerblock.py @@ -0,0 +1,56 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.nn as nn + +from monai.networks.blocks.mlp import MLPBlock +from monai.networks.blocks.selfattention import SABlock + + +class TransformerBlock(nn.Module): + """ + A transformer block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + """ + + def __init__( + self, + hidden_size: int, + mlp_dim: int, + num_heads: int, + dropout_rate: float = 0.0, + ) -> None: + """ + Args: + hidden_size: dimension of hidden layer. + mlp_dim: dimension of feedforward layer. + num_heads: number of attention heads. + dropout_rate: faction of the input units to drop. 
+ + """ + + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise AssertionError("dropout_rate should be between 0 and 1.") + + if hidden_size % num_heads != 0: + raise AssertionError("hidden size should be divisible by num_heads.") + + self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) + self.norm1 = nn.LayerNorm(hidden_size) + self.attn = SABlock(hidden_size, num_heads, dropout_rate) + self.norm2 = nn.LayerNorm(hidden_size) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x diff --git a/monai/networks/blocks/unetr_block.py b/monai/networks/blocks/unetr_block.py new file mode 100644 index 0000000000..20c39f6240 --- /dev/null +++ b/monai/networks/blocks/unetr_block.py @@ -0,0 +1,261 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Sequence, Tuple, Union + +import torch +import torch.nn as nn + +from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, get_conv_layer + + +class UnetrUpBlock(nn.Module): + """ + An upsampling module that can be used for UNETR: "Hatamizadeh et al., + UNETR: Transformers for 3D Medical Image Segmentation " + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, # type: ignore + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int], + upsample_kernel_size: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + res_block: bool = False, + ) -> None: + """ + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + kernel_size: convolution kernel size. + stride: convolution stride. + upsample_kernel_size: convolution kernel size for transposed convolution layers. + norm_name: feature normalization type and arguments. + res_block: bool argument to determine if residual block is used. 
+ + """ + + super(UnetrUpBlock, self).__init__() + upsample_stride = upsample_kernel_size + self.transp_conv = get_conv_layer( + spatial_dims, + in_channels, + out_channels, + kernel_size=upsample_kernel_size, + stride=upsample_stride, + conv_only=True, + is_transposed=True, + ) + + if res_block: + self.conv_block = UnetResBlock( + spatial_dims, + out_channels + out_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + norm_name=norm_name, + ) + else: + self.conv_block = UnetBasicBlock( # type: ignore + spatial_dims, + out_channels + out_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + norm_name=norm_name, + ) + + def forward(self, inp, skip): + # number of channels for skip should equals to out_channels + out = self.transp_conv(inp) + out = torch.cat((out, skip), dim=1) + out = self.conv_block(out) + return out + + +class UnetrPrUpBlock(nn.Module): + """ + A projection upsampling module that can be used for UNETR: "Hatamizadeh et al., + UNETR: Transformers for 3D Medical Image Segmentation " + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_layer: int, + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int], + upsample_kernel_size: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + conv_block: bool = False, + res_block: bool = False, + ) -> None: + """ + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + num_layer: number of upsampling blocks. + kernel_size: convolution kernel size. + stride: convolution stride. + upsample_kernel_size: convolution kernel size for transposed convolution layers. + norm_name: feature normalization type and arguments. + conv_block: bool argument to determine if convolutional block is used. + res_block: bool argument to determine if residual block is used. 
+ + """ + + super().__init__() + + upsample_stride = upsample_kernel_size + self.transp_conv_init = get_conv_layer( + spatial_dims, + in_channels, + out_channels, + kernel_size=upsample_kernel_size, + stride=upsample_stride, + conv_only=True, + is_transposed=True, + ) + if conv_block: + if res_block: + self.blocks = nn.ModuleList( + [ + nn.Sequential( + get_conv_layer( + spatial_dims, + out_channels, + out_channels, + kernel_size=upsample_kernel_size, + stride=upsample_stride, + conv_only=True, + is_transposed=True, + ), + UnetResBlock( + spatial_dims=3, + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + norm_name=norm_name, + ), + ) + for i in range(num_layer) + ] + ) + else: + self.blocks = nn.ModuleList( + [ + nn.Sequential( + get_conv_layer( + spatial_dims, + out_channels, + out_channels, + kernel_size=upsample_kernel_size, + stride=upsample_stride, + conv_only=True, + is_transposed=True, + ), + UnetBasicBlock( + spatial_dims=3, + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + norm_name=norm_name, + ), + ) + for i in range(num_layer) + ] + ) + else: + self.blocks = nn.ModuleList( + [ + get_conv_layer( + spatial_dims, + out_channels, + out_channels, + kernel_size=upsample_kernel_size, + stride=upsample_stride, + conv_only=True, + is_transposed=True, + ) + for i in range(num_layer) + ] + ) + + def forward(self, x): + x = self.transp_conv_init(x) + for blk in self.blocks: + x = blk(x) + return x + + +class UnetrBasicBlock(nn.Module): + """ + A CNN module that can be used for UNETR, based on: "Hatamizadeh et al., + UNETR: Transformers for 3D Medical Image Segmentation " + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + res_block: bool = False, + ) -> None: + """ + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + kernel_size: convolution kernel size. + stride: convolution stride. + norm_name: feature normalization type and arguments. + res_block: bool argument to determine if residual block is used. 
+ + """ + + super().__init__() + + if res_block: + self.layer = UnetResBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + norm_name=norm_name, + ) + else: + self.layer = UnetBasicBlock( # type: ignore + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + norm_name=norm_name, + ) + + def forward(self, inp): + out = self.layer(inp) + return out diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 82a68aeea8..9cf6c5e07f 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -74,5 +74,7 @@ ) from .torchvision_fc import TorchVisionFCModel, TorchVisionFullyConvModel from .unet import UNet, Unet, unet +from .unetr import UNETR from .varautoencoder import VarAutoEncoder +from .vit import ViT from .vnet import VNet diff --git a/monai/networks/nets/unetr.py b/monai/networks/nets/unetr.py new file mode 100644 index 0000000000..4e8b68b43e --- /dev/null +++ b/monai/networks/nets/unetr.py @@ -0,0 +1,198 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union + +import torch.nn as nn + +from monai.networks.blocks.dynunet_block import UnetOutBlock +from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock +from monai.networks.nets.vit import ViT + + +class UNETR(nn.Module): + """ + UNETR based on: "Hatamizadeh et al., + UNETR: Transformers for 3D Medical Image Segmentation " + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + img_size: Tuple, # type: ignore + feature_size: int, + hidden_size: int, + mlp_dim: int, + num_heads: int, + pos_embed: Union[Tuple, str], + norm_name: Union[Tuple, str], + conv_block: bool = False, + res_block: bool = False, + dropout_rate: float = 0.0, + ) -> None: + """ + Args: + in_channels: dimension of input channels. + out_channels: dimension of output channels. + img_size: dimension of input image. + feature_size: dimension of network feature size. + hidden_size: dimension of hidden layer. + mlp_dim: dimension of feedforward layer. + num_heads: number of attention heads. + pos_embed: position embedding layer type. + norm_name: feature normalization type and arguments. + conv_block: bool argument to determine if convolutional block is used. + res_block: bool argument to determine if residual block is used. + dropout_rate: faction of the input units to drop. 
+ + """ + + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise AssertionError("dropout_rate should be between 0 and 1.") + + if hidden_size % num_heads != 0: + raise AssertionError("hidden size should be divisible by num_heads.") + + if pos_embed not in ["conv", "perceptron"]: + raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") + + self.num_layers = 12 + self.patch_size = (16, 16, 16) + self.feat_size = ( + img_size[0] // self.patch_size[0], # type: ignore + img_size[1] // self.patch_size[1], # type: ignore + img_size[2] // self.patch_size[2], # type: ignore + ) + self.hidden_size = hidden_size + self.classification = False + self.vit = ViT( + in_channels=in_channels, + img_size=img_size, + patch_size=self.patch_size, + hidden_size=hidden_size, + mlp_dim=mlp_dim, + num_layers=self.num_layers, + num_heads=num_heads, + pos_embed=pos_embed, + classification=self.classification, + dropout_rate=dropout_rate, + ) + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=in_channels, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=res_block, + ) + self.encoder2 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=hidden_size, + out_channels=feature_size * 2, + num_layer=2, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name=norm_name, + conv_block=conv_block, + res_block=res_block, + ) + self.encoder3 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=hidden_size, + out_channels=feature_size * 4, + num_layer=1, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name=norm_name, + conv_block=conv_block, + res_block=res_block, + ) + self.encoder4 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=hidden_size, + out_channels=feature_size * 8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name=norm_name, + conv_block=conv_block, + res_block=res_block, + ) + self.decoder5 = UnetrUpBlock( + spatial_dims=3, + in_channels=hidden_size, + out_channels=feature_size * 8, + stride=1, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=res_block, + ) + self.decoder4 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 8, + out_channels=feature_size * 4, + stride=1, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=res_block, + ) + self.decoder3 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 4, + out_channels=feature_size * 2, + stride=1, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=res_block, + ) + self.decoder2 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 2, + out_channels=feature_size, + stride=1, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=res_block, + ) + self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) # type: ignore + + def proj_feat(self, x, hidden_size, feat_size): + x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) + x = x.permute(0, 4, 1, 2, 3).contiguous() + return x + + def forward(self, x_in): + x, hidden_states_out = self.vit(x_in) + enc1 = self.encoder1(x_in) + x2 = hidden_states_out[3] + enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size)) + x3 = hidden_states_out[6] + enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size)) + x4 = hidden_states_out[9] + enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size)) + dec4 = self.proj_feat(x, self.hidden_size, 
self.feat_size) + dec3 = self.decoder5(dec4, enc4) + dec2 = self.decoder4(dec3, enc3) + dec1 = self.decoder3(dec2, enc2) + out = self.decoder2(dec1, enc1) + logits = self.out(out) + return logits diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py new file mode 100644 index 0000000000..1a43414aca --- /dev/null +++ b/monai/networks/nets/vit.py @@ -0,0 +1,91 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Tuple, Union + +import torch.nn as nn + +from monai.networks.blocks.patchembedding import PatchEmbeddingBlock +from monai.networks.blocks.transformerblock import TransformerBlock + + +class ViT(nn.Module): + """ + Vision Transformer (ViT), based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + """ + + def __init__( + self, + in_channels: int, + img_size: Tuple, # type: ignore + patch_size: Tuple, # type: ignore + hidden_size: int, + mlp_dim: int, + num_layers: int, + num_heads: int, + pos_embed: Union[Tuple, str], + classification: bool, + num_classes: int = 2, + dropout_rate: float = 0.0, + ) -> None: + """ + Args: + in_channels: dimension of input channels. + img_size: dimension of input image. + patch_size: dimension of patch size. + hidden_size: dimension of hidden layer. + mlp_dim: dimension of feedforward layer. + num_layers: number of transformer blocks. + num_heads: number of attention heads. + pos_embed: position embedding layer type. + classification: bool argument to determine if classification is used. + num_classes: number of classes if classification is used. + dropout_rate: faction of the input units to drop. 
+ + """ + + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise AssertionError("dropout_rate should be between 0 and 1.") + + if hidden_size % num_heads != 0: + raise AssertionError("hidden size should be divisible by num_heads.") + + if img_size < patch_size: + raise AssertionError("patch_size should be smaller than img_size.") + + if pos_embed not in ["conv", "perceptron"]: + raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") + + self.classification = classification + self.patch_embedding = PatchEmbeddingBlock( + in_channels, img_size, patch_size, hidden_size, num_heads, pos_embed, classification, dropout_rate # type: ignore + ) + self.blocks = nn.ModuleList( + [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)] # type: ignore + ) + self.norm = nn.LayerNorm(hidden_size) + if self.classification: + self.classification_head = nn.Linear(hidden_size, num_classes) + + def forward(self, x): + x = self.patch_embedding(x) + hidden_states_out = [] + for blk in self.blocks: + x = blk(x) + hidden_states_out.append(x) + x = self.norm(x) + if self.classification: + x = self.classification_head(x[:, 0]) + return x, hidden_states_out diff --git a/requirements-dev.txt b/requirements-dev.txt index 51a50074ff..ea20b89b32 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -34,4 +34,5 @@ sphinx-rtd-theme==0.5.2 cucim~=0.19.0; platform_system == "Linux" openslide-python==1.1.2 pandas -requests \ No newline at end of file +requests +einops diff --git a/setup.cfg b/setup.cfg index a78794fd94..082119370f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,7 @@ all = cucim~=0.19.0 openslide-python==1.1.2 pandas + einops nibabel = nibabel skimage = @@ -71,7 +72,8 @@ openslide = openslide-python==1.1.2 pandas = pandas - +einops = + einops [flake8] select = B,C,E,F,N,P,T4,W,B9 max_line_length = 120 diff --git a/tests/min_tests.py b/tests/min_tests.py index 046f9b4a40..a3f140b856 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -127,6 +127,13 @@ def run_testsuit(): "test_write_metrics_reports", "test_csv_dataset", "test_csv_iterable_dataset", + "test_mlp", + "test_patchembedding", + "test_selfattention", + "test_transformerblock", + "test_unetr", + "test_unetr_block", + "test_vit", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_mlp.py b/tests/test_mlp.py new file mode 100644 index 0000000000..efc8db74c2 --- /dev/null +++ b/tests/test_mlp.py @@ -0,0 +1,52 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.blocks.mlp import MLPBlock + +TEST_CASE_MLP = [] +for dropout_rate in np.linspace(0, 1, 4): + for hidden_size in [128, 256, 512, 768]: + for mlp_dim in [512, 1028, 2048, 3072]: + + test_case = [ + { + "hidden_size": hidden_size, + "mlp_dim": mlp_dim, + "dropout_rate": dropout_rate, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_MLP.append(test_case) + + +class TestMLPBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_MLP) + def test_shape(self, input_param, input_shape, expected_shape): + net = MLPBlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(AssertionError): + MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=5.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py new file mode 100644 index 0000000000..c19c970c2b --- /dev/null +++ b/tests/test_patchembedding.py @@ -0,0 +1,127 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.blocks.patchembedding import PatchEmbeddingBlock +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") + +TEST_CASE_PATCHEMBEDDINGBLOCK = [] +for dropout_rate in np.linspace(0, 1, 2): + for in_channels in [1, 4]: + for hidden_size in [360, 768]: + for img_size in [96, 128]: + for patch_size in [8, 16]: + for num_heads in [8, 12]: + for pos_embed in ["conv", "perceptron"]: + for classification in ["False", "True"]: + if classification: + out = (2, (img_size // patch_size) ** 3 + 1, hidden_size) + else: + out = (2, (img_size // patch_size) ** 3, hidden_size) + test_case = [ + { + "in_channels": in_channels, + "img_size": (img_size, img_size, img_size), + "patch_size": (patch_size, patch_size, patch_size), + "hidden_size": hidden_size, + "num_heads": num_heads, + "pos_embed": pos_embed, + "classification": classification, + "dropout_rate": dropout_rate, + }, + (2, in_channels, img_size, *([img_size] * 2)), + (2, (img_size // patch_size) ** 3, hidden_size), + ] + TEST_CASE_PATCHEMBEDDINGBLOCK.append(test_case) + + +class TestPatchEmbeddingBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_PATCHEMBEDDINGBLOCK) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, input_param, input_shape, expected_shape): + net = PatchEmbeddingBlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(AssertionError): + PatchEmbeddingBlock( + in_channels=1, + img_size=(128, 128, 128), + patch_size=(16, 16, 16), + hidden_size=128, + num_heads=12, + pos_embed="conv", + classification=False, + dropout_rate=5.0, + ) + + with self.assertRaises(AssertionError): + PatchEmbeddingBlock( + in_channels=1, + img_size=(32, 32, 32), + patch_size=(64, 64, 64), + hidden_size=512, + num_heads=8, + pos_embed="perceptron", + classification=False, + dropout_rate=0.3, + ) + + with self.assertRaises(AssertionError): + PatchEmbeddingBlock( + in_channels=1, + img_size=(96, 96, 96), + patch_size=(8, 8, 8), + hidden_size=512, + num_heads=14, + pos_embed="conv", + classification=False, + dropout_rate=0.3, + ) + + with self.assertRaises(AssertionError): + PatchEmbeddingBlock( + in_channels=1, + img_size=(97, 97, 97), + patch_size=(4, 4, 4), + hidden_size=768, + num_heads=8, + pos_embed="perceptron", + classification=False, + dropout_rate=0.3, + ) + + with self.assertRaises(KeyError): + PatchEmbeddingBlock( + in_channels=4, + img_size=(96, 96, 96), + patch_size=(16, 16, 16), + hidden_size=768, + num_heads=12, + pos_embed="perc", + classification=False, + dropout_rate=0.3, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py new file mode 100644 index 0000000000..2430b82c9b --- /dev/null +++ b/tests/test_selfattention.py @@ -0,0 +1,60 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.blocks.selfattention import SABlock +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") + +TEST_CASE_SABLOCK = [] +for dropout_rate in np.linspace(0, 1, 4): + for hidden_size in [360, 480, 600, 768]: + for num_heads in [4, 6, 8, 12]: + + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_SABLOCK.append(test_case) + + +class TestResBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_SABLOCK) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, input_param, input_shape, expected_shape): + net = SABlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(AssertionError): + SABlock(hidden_size=128, num_heads=12, dropout_rate=6.0) + + with self.assertRaises(AssertionError): + SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_transformerblock.py b/tests/test_transformerblock.py new file mode 100644 index 0000000000..24d16c77aa --- /dev/null +++ b/tests/test_transformerblock.py @@ -0,0 +1,57 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.blocks.transformerblock import TransformerBlock + +TEST_CASE_TRANSFORMERBLOCK = [] +for dropout_rate in np.linspace(0, 1, 4): + for hidden_size in [360, 480, 600, 768]: + for num_heads in [4, 8, 12]: + for mlp_dim in [1024, 3072]: + + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "mlp_dim": mlp_dim, + "dropout_rate": dropout_rate, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_TRANSFORMERBLOCK.append(test_case) + + +class TestTransformerBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_TRANSFORMERBLOCK) + def test_shape(self, input_param, input_shape, expected_shape): + net = TransformerBlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(AssertionError): + TransformerBlock(hidden_size=128, num_heads=12, mlp_dim=2048, dropout_rate=4.0) + + with self.assertRaises(AssertionError): + TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_unetr.py b/tests/test_unetr.py new file mode 100644 index 0000000000..cd50cb487c --- /dev/null +++ b/tests/test_unetr.py @@ -0,0 +1,121 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.unetr import UNETR + +TEST_CASE_UNETR = [] +for dropout_rate in [0.4]: + for in_channels in [1]: + for out_channels in [2]: + for hidden_size in [768]: + for img_size in [96, 128]: + for feature_size in [16]: + for num_heads in [8]: + for mlp_dim in [3072]: + for norm_name in ["instance"]: + for pos_embed in ["perceptron"]: + for conv_block in [True]: + for res_block in [False]: + test_case = [ + { + "in_channels": in_channels, + "out_channels": out_channels, + "img_size": (img_size, img_size, img_size), + "hidden_size": hidden_size, + "feature_size": feature_size, + "norm_name": norm_name, + "mlp_dim": mlp_dim, + "num_heads": num_heads, + "pos_embed": pos_embed, + "dropout_rate": dropout_rate, + "conv_block": conv_block, + "res_block": res_block, + }, + (2, in_channels, img_size, *([img_size] * 2)), + (2, out_channels, img_size, *([img_size] * 2)), + ] + TEST_CASE_UNETR.append(test_case) + + +class TestPatchEmbeddingBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_UNETR) + def test_shape(self, input_param, input_shape, expected_shape): + net = UNETR(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(AssertionError): + UNETR( + in_channels=1, + out_channels=3, + img_size=(128, 128, 128), + feature_size=16, + hidden_size=128, + mlp_dim=3072, + num_heads=12, + pos_embed="conv", + norm_name="instance", + dropout_rate=5.0, + ) + + with self.assertRaises(AssertionError): + UNETR( + in_channels=1, + out_channels=4, + img_size=(32, 32, 32), + feature_size=32, + hidden_size=512, + mlp_dim=3072, + num_heads=12, + pos_embed="conv", + norm_name="instance", + dropout_rate=0.5, + ) + + with self.assertRaises(AssertionError): + UNETR( + in_channels=1, + out_channels=3, + img_size=(96, 96, 96), + feature_size=16, + hidden_size=512, + mlp_dim=3072, + num_heads=14, + pos_embed="conv", + norm_name="batch", + dropout_rate=0.4, + ) + + with self.assertRaises(KeyError): + UNETR( + in_channels=1, + out_channels=4, + img_size=(96, 96, 96), + feature_size=8, + hidden_size=768, + mlp_dim=3072, + num_heads=12, + pos_embed="perc", + norm_name="instance", + dropout_rate=0.2, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_unetr_block.py b/tests/test_unetr_block.py new file mode 100644 index 0000000000..ba988d07e6 --- /dev/null +++ b/tests/test_unetr_block.py @@ -0,0 +1,159 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.blocks.dynunet_block import get_padding +from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock +from tests.utils import test_script_save + +TEST_CASE_UNETR_BASIC_BLOCK = [] +for spatial_dims in range(2, 4): + for kernel_size in [1, 3]: + for stride in [2]: + for norm_name in [("GROUP", {"num_groups": 16}), ("batch", {"track_running_stats": False}), "instance"]: + for in_size in [15, 16]: + padding = get_padding(kernel_size, stride) + if not isinstance(padding, int): + padding = padding[0] + out_size = int((in_size + 2 * padding - kernel_size) / stride) + 1 + test_case = [ + { + "spatial_dims": spatial_dims, + "in_channels": 16, + "out_channels": 16, + "kernel_size": kernel_size, + "norm_name": norm_name, + "stride": stride, + }, + (1, 16, *([in_size] * spatial_dims)), + (1, 16, *([out_size] * spatial_dims)), + ] + TEST_CASE_UNETR_BASIC_BLOCK.append(test_case) + +TEST_UP_BLOCK = [] +in_channels, out_channels = 4, 2 +for spatial_dims in range(2, 4): + for kernel_size in [1, 3]: + for stride in [1, 2]: + for res_block in [False, True]: + for norm_name in ["instance"]: + for in_size in [15, 16]: + out_size = in_size * stride + test_case = [ + { + "spatial_dims": spatial_dims, + "in_channels": in_channels, + "out_channels": out_channels, + "kernel_size": kernel_size, + "norm_name": norm_name, + "stride": stride, + "res_block": res_block, + "upsample_kernel_size": stride, + }, + (1, in_channels, *([in_size] * spatial_dims)), + (1, out_channels, *([out_size] * spatial_dims)), + (1, out_channels, *([in_size * stride] * spatial_dims)), + ] + TEST_UP_BLOCK.append(test_case) + + +TEST_PRUP_BLOCK = [] +in_channels, out_channels = 4, 2 +for spatial_dims in range(2, 4): + for kernel_size in [1, 3]: + for upsample_kernel_size in [2, 3]: + for stride in [1, 2]: + for res_block in [False, True]: + for norm_name in ["instance"]: + for in_size in [15, 16]: + for num_layer in [0, 2]: + in_size_tmp = in_size + for _num in range(num_layer + 1): + out_size = in_size_tmp * upsample_kernel_size + in_size_tmp = out_size + test_case = [ + { + "spatial_dims": spatial_dims, + "in_channels": in_channels, + "out_channels": out_channels, + "num_layer": num_layer, + "kernel_size": kernel_size, + "norm_name": norm_name, + "stride": stride, + "res_block": res_block, + "upsample_kernel_size": upsample_kernel_size, + }, + (1, in_channels, *([in_size] * spatial_dims)), + (1, out_channels, *([out_size] * spatial_dims)), + ] + TEST_PRUP_BLOCK.append(test_case) + + +class TestResBasicBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_UNETR_BASIC_BLOCK) + def test_shape(self, input_param, input_shape, expected_shape): + for net in [UnetrBasicBlock(**input_param)]: + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(KeyError): + UnetrBasicBlock(3, 4, 2, kernel_size=3, stride=1, norm_name="norm") + with self.assertRaises(AssertionError): + UnetrBasicBlock(3, 4, 2, kernel_size=1, stride=4, norm_name="batch") + + def test_script(self): + input_param, input_shape, _ = TEST_CASE_UNETR_BASIC_BLOCK[0] + net = UnetrBasicBlock(**input_param) + with eval_mode(net): + test_data = torch.randn(input_shape) + test_script_save(net, test_data) + + +class TestUpBlock(unittest.TestCase): + @parameterized.expand(TEST_UP_BLOCK) + def test_shape(self, 
input_param, input_shape, expected_shape, skip_shape): + net = UnetrUpBlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape), torch.randn(skip_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_script(self): + input_param, input_shape, _, skip_shape = TEST_UP_BLOCK[0] + net = UnetrUpBlock(**input_param) + test_data = torch.randn(input_shape) + skip_data = torch.randn(skip_shape) + test_script_save(net, test_data, skip_data) + + +class TestPrUpBlock(unittest.TestCase): + @parameterized.expand(TEST_PRUP_BLOCK) + def test_shape(self, input_param, input_shape, expected_shape): + net = UnetrPrUpBlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_script(self): + input_param, input_shape, _ = TEST_PRUP_BLOCK[0] + net = UnetrPrUpBlock(**input_param) + test_data = torch.randn(input_shape) + test_script_save(net, test_data) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vit.py b/tests/test_vit.py new file mode 100644 index 0000000000..0d0d58093b --- /dev/null +++ b/tests/test_vit.py @@ -0,0 +1,137 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.vit import ViT + +TEST_CASE_Vit = [] +for dropout_rate in [0.6]: + for in_channels in [4]: + for hidden_size in [768]: + for img_size in [96, 128]: + for patch_size in [16]: + for num_heads in [12]: + for mlp_dim in [3072]: + for num_layers in [4]: + for num_classes in [2]: + for pos_embed in ["conv"]: + for classification in ["False"]: + if classification: + out = (2, num_classes) + else: + out = (2, (img_size // patch_size) ** 3, hidden_size) # type: ignore + test_case = [ + { + "in_channels": in_channels, + "img_size": (img_size, img_size, img_size), + "patch_size": (patch_size, patch_size, patch_size), + "hidden_size": hidden_size, + "mlp_dim": mlp_dim, + "num_layers": num_layers, + "num_heads": num_heads, + "pos_embed": pos_embed, + "classification": classification, + "num_classes": num_classes, + "dropout_rate": dropout_rate, + }, + (2, in_channels, img_size, *([img_size] * 2)), + out, + ] + TEST_CASE_Vit.append(test_case) + + +class TestPatchEmbeddingBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_Vit) + def test_shape(self, input_param, input_shape, expected_shape): + net = ViT(**input_param) + with eval_mode(net): + result, _ = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(AssertionError): + ViT( + in_channels=1, + img_size=(128, 128, 128), + patch_size=(16, 16, 16), + hidden_size=128, + mlp_dim=3072, + num_layers=12, + num_heads=12, + pos_embed="conv", + classification=False, + dropout_rate=5.0, + ) + + with self.assertRaises(AssertionError): + ViT( + in_channels=1, + img_size=(32, 32, 32), + patch_size=(64, 64, 
64), + hidden_size=512, + mlp_dim=3072, + num_layers=12, + num_heads=8, + pos_embed="perceptron", + classification=False, + dropout_rate=0.3, + ) + + with self.assertRaises(AssertionError): + ViT( + in_channels=1, + img_size=(96, 96, 96), + patch_size=(8, 8, 8), + hidden_size=512, + mlp_dim=3072, + num_layers=12, + num_heads=14, + pos_embed="conv", + classification=False, + dropout_rate=0.3, + ) + + with self.assertRaises(AssertionError): + ViT( + in_channels=1, + img_size=(97, 97, 97), + patch_size=(4, 4, 4), + hidden_size=768, + mlp_dim=3072, + num_layers=12, + num_heads=8, + pos_embed="perceptron", + classification=False, + dropout_rate=0.3, + ) + + with self.assertRaises(KeyError): + ViT( + in_channels=4, + img_size=(96, 96, 96), + patch_size=(16, 16, 16), + hidden_size=768, + mlp_dim=3072, + num_layers=12, + num_heads=12, + pos_embed="perc", + classification=False, + dropout_rate=0.3, + ) + + +if __name__ == "__main__": + unittest.main()
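
For reviewers who want to try the new modules locally, here is a small illustrative sketch (not part of the patch itself) exercising the token-level blocks added under monai/networks/blocks; the shapes follow tests/test_selfattention.py and tests/test_transformerblock.py, and SABlock needs the optional einops dependency introduced by this patch.

import torch
from monai.networks.blocks import MLPBlock, SABlock, TransformerBlock

x = torch.randn(2, 512, 768)  # (batch, num_tokens, hidden_size)
mlp = MLPBlock(hidden_size=768, mlp_dim=3072, dropout_rate=0.1)
attn = SABlock(hidden_size=768, num_heads=12, dropout_rate=0.1)
block = TransformerBlock(hidden_size=768, mlp_dim=3072, num_heads=12, dropout_rate=0.1)
# all three blocks preserve the (batch, num_tokens, hidden_size) shape
print(mlp(x).shape, attn(x).shape, block(x).shape)  # each torch.Size([2, 512, 768])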
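A minimal usage sketch for the ViT backbone added in monai/networks/nets/vit.py, with parameter values mirroring tests/test_vit.py; it assumes einops is installed and is included here only for illustration.

import torch
from monai.networks.nets import ViT

# 12-layer ViT over a single-channel 96^3 volume; 16^3 patches give (96/16)^3 = 216 tokens
model = ViT(
    in_channels=1,
    img_size=(96, 96, 96),
    patch_size=(16, 16, 16),
    hidden_size=768,
    mlp_dim=3072,
    num_layers=12,
    num_heads=12,
    pos_embed="conv",
    classification=False,
    dropout_rate=0.0,
)
tokens, hidden_states = model(torch.randn(1, 1, 96, 96, 96))
print(tokens.shape)      # torch.Size([1, 216, 768])
print(len(hidden_states))  # 12 intermediate token maps, later reused by UNETR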
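Finally, a sketch of the UNETR segmentation network from monai/networks/nets/unetr.py, again mirroring the configurations exercised in tests/test_unetr.py; note that the internal ViT patch size is fixed to 16, so every spatial dimension of img_size must be divisible by 16.

import torch
from monai.networks.nets import UNETR

# ViT encoder plus CNN decoder; hidden states 3, 6 and 9 feed the skip connections
net = UNETR(
    in_channels=1,
    out_channels=2,
    img_size=(96, 96, 96),
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    pos_embed="perceptron",
    norm_name="instance",
    conv_block=True,
    res_block=True,
    dropout_rate=0.0,
)
logits = net(torch.randn(1, 1, 96, 96, 96))
print(logits.shape)  # torch.Size([1, 2, 96, 96, 96])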