diff --git a/docs/source/networks.rst b/docs/source/networks.rst
index 36d62752d4..31bc6de6f8 100644
--- a/docs/source/networks.rst
+++ b/docs/source/networks.rst
@@ -470,6 +470,11 @@ Nets
 .. autoclass:: ViT
   :members:
 
+`ViTAutoEnc`
+~~~~~~~~~~~~
+.. autoclass:: ViTAutoEnc
+  :members:
+
 `FullyConnectedNet`
 ~~~~~~~~~~~~~~~~~~~
 .. autoclass:: FullyConnectedNet
diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py
index 3b8d1dd6ec..a07297be13 100644
--- a/monai/networks/nets/__init__.py
+++ b/monai/networks/nets/__init__.py
@@ -83,4 +83,5 @@
 from .unetr import UNETR
 from .varautoencoder import VarAutoEncoder
 from .vit import ViT
+from .vitautoenc import ViTAutoEnc
 from .vnet import VNet
diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py
new file mode 100644
index 0000000000..097534d230
--- /dev/null
+++ b/monai/networks/nets/vitautoenc.py
@@ -0,0 +1,115 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+from typing import Sequence, Union
+
+import torch
+import torch.nn as nn
+
+from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
+from monai.networks.blocks.transformerblock import TransformerBlock
+
+__all__ = ["ViTAutoEnc"]
+
+
+class ViTAutoEnc(nn.Module):
+    """
+    Vision Transformer (ViT), based on: "Dosovitskiy et al.,
+    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"
+
+    Modified to also return an output with the same spatial size as the input image.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        img_size: Union[Sequence[int], int],
+        patch_size: Union[Sequence[int], int],
+        hidden_size: int = 768,
+        mlp_dim: int = 3072,
+        num_layers: int = 12,
+        num_heads: int = 12,
+        pos_embed: str = "conv",
+        dropout_rate: float = 0.0,
+        spatial_dims: int = 3,
+    ) -> None:
+        """
+        Args:
+            in_channels: number of input channels.
+            img_size: dimension of input image.
+            patch_size: dimension of patch size.
+            hidden_size: dimension of hidden layer.
+            mlp_dim: dimension of feedforward layer.
+            num_layers: number of transformer blocks.
+            num_heads: number of attention heads.
+            pos_embed: position embedding layer type.
+            dropout_rate: fraction of the input units to drop.
+            spatial_dims: number of spatial dimensions.
+
+        Examples::
+
+            # for single channel input with image size of (96,96,96) and conv position embedding;
+            # the output will have the same spatial size as the input
+            >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), pos_embed='conv')
+
+            # for 3-channel input with image size of (128,128,128);
+            # the output will have the same spatial size as the input
+            >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), pos_embed='conv')
+
+        """
+
+        super().__init__()
+
+        if not (0 <= dropout_rate <= 1):
+            raise ValueError("dropout_rate should be between 0 and 1.")
+
+        if hidden_size % num_heads != 0:
+            raise ValueError("hidden_size should be divisible by num_heads.")
+
+        if spatial_dims == 2:
+            raise ValueError("Not implemented for 2 dimensions, please try 3")
+
+        self.patch_embedding = PatchEmbeddingBlock(
+            in_channels=in_channels,
+            img_size=img_size,
+            patch_size=patch_size,
+            hidden_size=hidden_size,
+            num_heads=num_heads,
+            pos_embed=pos_embed,
+            dropout_rate=dropout_rate,
+            spatial_dims=spatial_dims,
+        )
+        self.blocks = nn.ModuleList(
+            [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for _ in range(num_layers)]
+        )
+        self.norm = nn.LayerNorm(hidden_size)
+
+        # two stride-4 transposed convolutions upsample each token by a factor of 16 per axis,
+        # matching the expected 16^3 patch size; the final output has a single channel
+        new_patch_size = (4, 4, 4)
+        self.conv3d_transpose = nn.ConvTranspose3d(hidden_size, 16, kernel_size=new_patch_size, stride=new_patch_size)
+        self.conv3d_transpose_1 = nn.ConvTranspose3d(
+            in_channels=16, out_channels=1, kernel_size=new_patch_size, stride=new_patch_size
+        )
+
+    def forward(self, x):
+        x = self.patch_embedding(x)
+        hidden_states_out = []
+        for blk in self.blocks:
+            x = blk(x)
+            hidden_states_out.append(x)
+        x = self.norm(x)
+        # reshape the token sequence (B, N, hidden) into a cubic feature volume (B, hidden, d, d, d),
+        # assuming the number of patches N is a perfect cube
+        x = x.transpose(1, 2)
+        cuberoot = round(math.pow(x.size()[2], 1 / 3))
+        x_shape = x.size()
+        x = torch.reshape(x, [x_shape[0], x_shape[1], cuberoot, cuberoot, cuberoot])
+        # decode back to the input resolution with the two transposed convolutions
+        x = self.conv3d_transpose(x)
+        x = self.conv3d_transpose_1(x)
+        return x, hidden_states_out
diff --git a/tests/min_tests.py b/tests/min_tests.py
index ccf1d58ebd..ed48f6986f 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -138,6 +138,7 @@ def run_testsuit():
         "test_unetr",
         "test_unetr_block",
         "test_vit",
+        "test_vitautoenc",
         "test_write_metrics_reports",
         "test_zoom",
         "test_zoom_affine",
diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py
new file mode 100644
index 0000000000..13cb0d8325
--- /dev/null
+++ b/tests/test_vitautoenc.py
@@ -0,0 +1,120 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks import eval_mode
+from monai.networks.nets.vitautoenc import ViTAutoEnc
+
+TEST_CASE_Vitautoenc = []
+for in_channels in [1, 4]:
+    for img_size in [64, 96, 128]:
+        for patch_size in [16]:
+            for pos_embed in ["conv", "perceptron"]:
+                for nd in [3]:
+                    test_case = [
+                        {
+                            "in_channels": in_channels,
+                            "img_size": (img_size,) * nd,
+                            "patch_size": (patch_size,) * nd,
+                            "hidden_size": 768,
+                            "mlp_dim": 3072,
+                            "num_layers": 4,
+                            "num_heads": 12,
+                            "pos_embed": pos_embed,
+                            "dropout_rate": 0.6,
+                        },
+                        (2, in_channels, *([img_size] * nd)),
+                        (2, 1, *([img_size] * nd)),
+                    ]
+
+                    TEST_CASE_Vitautoenc.append(test_case)
+
+
+class TestViTAutoEnc(unittest.TestCase):
+    @parameterized.expand(TEST_CASE_Vitautoenc)
+    def test_shape(self, input_param, input_shape, expected_shape):
+        net = ViTAutoEnc(**input_param)
+        with eval_mode(net):
+            result, _ = net(torch.randn(input_shape))
+            self.assertEqual(result.shape, expected_shape)
+
+    def test_ill_arg(self):
+        with self.assertRaises(ValueError):
+            ViTAutoEnc(
+                in_channels=1,
+                img_size=(128, 128, 128),
+                patch_size=(16, 16, 16),
+                hidden_size=128,
+                mlp_dim=3072,
+                num_layers=12,
+                num_heads=12,
+                pos_embed="conv",
+                dropout_rate=5.0,
+            )
+
+        with self.assertRaises(ValueError):
+            ViTAutoEnc(
+                in_channels=1,
+                img_size=(32, 32, 32),
+                patch_size=(64, 64, 64),
+                hidden_size=512,
+                mlp_dim=3072,
+                num_layers=12,
+                num_heads=8,
+                pos_embed="perceptron",
+                dropout_rate=0.3,
+            )
+
+        with self.assertRaises(ValueError):
+            ViTAutoEnc(
+                in_channels=1,
+                img_size=(96, 96, 96),
+                patch_size=(8, 8, 8),
+                hidden_size=512,
+                mlp_dim=3072,
+                num_layers=12,
+                num_heads=14,
+                pos_embed="conv",
+                dropout_rate=0.3,
+            )
+
+        with self.assertRaises(ValueError):
+            ViTAutoEnc(
+                in_channels=1,
+                img_size=(97, 97, 97),
+                patch_size=(4, 4, 4),
+                hidden_size=768,
+                mlp_dim=3072,
+                num_layers=12,
+                num_heads=8,
+                pos_embed="perceptron",
+                dropout_rate=0.3,
+            )
+
+        with self.assertRaises(ValueError):
+            ViTAutoEnc(
+                in_channels=4,
+                img_size=(96, 96, 96),
+                patch_size=(16, 16, 16),
+                hidden_size=768,
+                mlp_dim=3072,
+                num_layers=12,
+                num_heads=12,
+                pos_embed="perc",
+                dropout_rate=0.3,
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
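
A minimal usage sketch (illustration only, not part of the patch), assuming the default 16^3 patch size that the decoder expects; it shows the two-element return value (reconstruction and per-block hidden states) and the single-channel output at input resolution:

    import torch

    from monai.networks.nets import ViTAutoEnc

    # build the autoencoder for single-channel 96x96x96 volumes with 16x16x16 patches
    net = ViTAutoEnc(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16), pos_embed="conv")

    with torch.no_grad():
        # forward pass returns the reconstructed volume and the hidden states of every transformer block
        reconstruction, hidden_states = net(torch.randn(1, 1, 96, 96, 96))

    print(reconstruction.shape)  # torch.Size([1, 1, 96, 96, 96]) - same spatial size as the input, one channel
    print(len(hidden_states))    # 12, one entry per transformer block (the default num_layers)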