From 12d5e8ff7bb398c24abeecc7bfb431573a2df6d7 Mon Sep 17 00:00:00 2001 From: vnath Date: Tue, 26 Oct 2021 16:08:55 -0500 Subject: [PATCH 01/15] Adding ViT Autoencoder --- monai/networks/nets/__init__.py | 1 + monai/networks/nets/vit_ae.py | 134 ++++++++++++++++++++++++++++++++ tests/test_vit_ae.py | 128 ++++++++++++++++++++++++++++++ 3 files changed, 263 insertions(+) create mode 100644 monai/networks/nets/vit_ae.py create mode 100644 tests/test_vit_ae.py diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 3b8d1dd6ec..a4069a0771 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -83,4 +83,5 @@ from .unetr import UNETR from .varautoencoder import VarAutoEncoder from .vit import ViT +from .vit_ae import ViT_AE from .vnet import VNet diff --git a/monai/networks/nets/vit_ae.py b/monai/networks/nets/vit_ae.py new file mode 100644 index 0000000000..cf9b186718 --- /dev/null +++ b/monai/networks/nets/vit_ae.py @@ -0,0 +1,134 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Sequence, Union + +import math +import torch +import torch.nn as nn + +from monai.networks.blocks.patchembedding import PatchEmbeddingBlock +from monai.networks.blocks.transformerblock import TransformerBlock + +__all__ = ["ViT_AE"] + + +class ViT_AE(nn.Module): + """ + Vision Transformer (ViT), based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + + Modified to also give same dimension outputs as the input size of the image + """ + + def __init__( + self, + in_channels: int, + img_size: Union[Sequence[int], int], + patch_size: Union[Sequence[int], int], + hidden_size: int = 768, + mlp_dim: int = 3072, + num_layers: int = 12, + num_heads: int = 12, + pos_embed: str = "conv", + classification: bool = False, + same_as_input_size: bool = True, + num_classes: int = 2, + dropout_rate: float = 0.0, + spatial_dims: int = 3, + ) -> None: + """ + Args: + in_channels: dimension of input channels. + img_size: dimension of input image. + patch_size: dimension of patch size. + hidden_size: dimension of hidden layer. + mlp_dim: dimension of feedforward layer. + num_layers: number of transformer blocks. + num_heads: number of attention heads. + pos_embed: position embedding layer type. + classification: bool argument to determine if classification is used. + num_classes: number of classes if classification is used. + dropout_rate: faction of the input units to drop. + spatial_dims: number of spatial dimensions. + + Examples:: + + # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone + >>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv') + + # for 3-channel with image size of (128,128,128), 24 layers and classification backbone + >>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True) + + # for 3-channel with image size of (224,224), 12 layers and classification backbone + >>> net = ViT(in_channels=3, img_size=(224,224), pos_embed='conv', classification=True, spatial_dims=2) + + """ + + super(ViT_AE, self).__init__() + + if not (0 <= dropout_rate <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + + if hidden_size % num_heads != 0: + raise ValueError("hidden_size should be divisible by num_heads.") + + self.same_as_input_size = same_as_input_size + self.classification = classification + self.patch_embedding = PatchEmbeddingBlock( + in_channels=in_channels, + img_size=img_size, + patch_size=patch_size, + hidden_size=hidden_size, + num_heads=num_heads, + pos_embed=pos_embed, + dropout_rate=dropout_rate, + spatial_dims=spatial_dims, + ) + self.blocks = nn.ModuleList( + [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)] + ) + self.norm = nn.LayerNorm(hidden_size) + if self.classification: + self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) + self.classification_head = nn.Sequential(nn.Linear(hidden_size, num_classes), nn.Tanh()) + + if self.same_as_input_size: + new_patch_size = [4, 4, 4] + self.conv3d_transpose = nn.ConvTranspose3d(hidden_size, 16, kernel_size=new_patch_size, stride=new_patch_size) + self.conv3d_transpose_1 = nn.ConvTranspose3d(in_channels=16, out_channels=1, kernel_size=new_patch_size, stride=new_patch_size) + + + def forward(self, x): + x = self.patch_embedding(x) + if self.classification: + cls_token = self.cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + hidden_states_out = [] + for blk in self.blocks: + x = blk(x) + hidden_states_out.append(x) + x = self.norm(x) + if self.classification: + x = self.classification_head(x[:, 0]) + if self.same_as_input_size: + x = x.transpose(1, 2) + cubeRoot = round(math.pow(x.size()[2], 1 / 3)) + #print('X shape is {}'.format(x.shape)) + x_flatten = x.unflatten(2, (cubeRoot, cubeRoot, cubeRoot)) + #print('X flatten shape is {}'.format(x_flatten.shape)) + x = self.conv3d_transpose(x.unflatten(2, (cubeRoot, cubeRoot, cubeRoot))) + #print('X after Conv shape is {}'.format(x.shape)) + x = self.conv3d_transpose_1(x) + #print('X after Conv 1 shape is {}'.format(x.shape)) + #print('Get ready to see the reds ...') + return x, hidden_states_out diff --git a/tests/test_vit_ae.py b/tests/test_vit_ae.py new file mode 100644 index 0000000000..49caba34ef --- /dev/null +++ b/tests/test_vit_ae.py @@ -0,0 +1,128 @@ +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.vit_ae import ViT_AE + +TEST_CASE_Vit = [] +for dropout_rate in [0.6]: + for in_channels in [4]: + for hidden_size in [768]: + for img_size in [64, 96, 128]: + for patch_size in [16]: + for num_heads in [12]: + for mlp_dim in [3072]: + for num_layers in [4]: + for num_classes in [8]: + for pos_embed in ["conv"]: + for classification in [False]: + for same_as_input_size in [True]: + for nd in (2, 3): + test_case = [ + { + "in_channels": in_channels, + "img_size": (img_size,) * nd, + "patch_size": (patch_size,) * nd, + "hidden_size": hidden_size, + "mlp_dim": mlp_dim, + "num_layers": num_layers, + "num_heads": num_heads, + "pos_embed": pos_embed, + "classification": classification, + "num_classes": num_classes, + "dropout_rate": dropout_rate, + "same_as_input_size": same_as_input_size + }, + (2, in_channels, *([img_size] * nd)), + (2, (img_size // patch_size) ** nd, hidden_size), + ] + if nd == 2: + test_case[0]["spatial_dims"] = 2 # type: ignore + if test_case[0]["classification"]: # type: ignore + test_case[2] = (2, test_case[0]["num_classes"]) # type: ignore + TEST_CASE_Vit.append(test_case) + + +class TestPatchEmbeddingBlock(unittest.TestCase): + @parameterized.expand(TEST_CASE_Vit) + def test_shape(self, input_param, input_shape, expected_shape): + net = ViT_AE(**input_param) + with eval_mode(net): + result, _ = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + ViT_AE( + in_channels=1, + img_size=(128, 128, 128), + patch_size=(16, 16, 16), + hidden_size=128, + mlp_dim=3072, + num_layers=12, + num_heads=12, + pos_embed="conv", + classification=False, + dropout_rate=5.0, + ) + + with self.assertRaises(ValueError): + ViT_AE( + in_channels=1, + img_size=(32, 32, 32), + patch_size=(64, 64, 64), + hidden_size=512, + mlp_dim=3072, + num_layers=12, + num_heads=8, + pos_embed="perceptron", + classification=False, + dropout_rate=0.3, + ) + + with self.assertRaises(ValueError): + ViT_AE( + in_channels=1, + img_size=(96, 96, 96), + patch_size=(8, 8, 8), + hidden_size=512, + mlp_dim=3072, + num_layers=12, + num_heads=14, + pos_embed="conv", + classification=False, + dropout_rate=0.3, + ) + + with self.assertRaises(ValueError): + ViT_AE( + in_channels=1, + img_size=(97, 97, 97), + patch_size=(4, 4, 4), + hidden_size=768, + mlp_dim=3072, + num_layers=12, + num_heads=8, + pos_embed="perceptron", + classification=True, + dropout_rate=0.3, + ) + + with self.assertRaises(ValueError): + ViT_AE( + in_channels=4, + img_size=(96, 96, 96), + patch_size=(16, 16, 16), + hidden_size=768, + mlp_dim=3072, + num_layers=12, + num_heads=12, + pos_embed="perc", + classification=False, + dropout_rate=0.3, + ) + +if __name__ == "__main__": + unittest.main() From b7439c95be07511d0f5c78b5e8f33b33815fcad4 Mon Sep 17 00:00:00 2001 From: vnath Date: Sun, 31 Oct 2021 17:17:30 -0500 Subject: [PATCH 02/15] Fixing Debug print statement Signed-off-by: vnath --- monai/networks/nets/vit_ae.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/monai/networks/nets/vit_ae.py b/monai/networks/nets/vit_ae.py index cf9b186718..8e59259569 100644 --- a/monai/networks/nets/vit_ae.py +++ b/monai/networks/nets/vit_ae.py @@ -123,12 +123,6 @@ def forward(self, x): if self.same_as_input_size: x = x.transpose(1, 2) cubeRoot = round(math.pow(x.size()[2], 1 / 3)) - #print('X shape is {}'.format(x.shape)) - x_flatten = x.unflatten(2, (cubeRoot, cubeRoot, cubeRoot)) - #print('X flatten shape is {}'.format(x_flatten.shape)) x = self.conv3d_transpose(x.unflatten(2, (cubeRoot, cubeRoot, cubeRoot))) - #print('X after Conv shape is {}'.format(x.shape)) x = self.conv3d_transpose_1(x) - #print('X after Conv 1 shape is {}'.format(x.shape)) - #print('Get ready to see the reds ...') return x, hidden_states_out From e7bb3af70507a61919d9a50fe56899bd50244316 Mon Sep 17 00:00:00 2001 From: vnath Date: Sun, 31 Oct 2021 17:25:30 -0500 Subject: [PATCH 03/15] Truncating Classification Related Code snippets Signed-off-by: vnath --- monai/networks/nets/vit_ae.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/monai/networks/nets/vit_ae.py b/monai/networks/nets/vit_ae.py index 8e59259569..8d1126def3 100644 --- a/monai/networks/nets/vit_ae.py +++ b/monai/networks/nets/vit_ae.py @@ -40,9 +40,7 @@ def __init__( num_layers: int = 12, num_heads: int = 12, pos_embed: str = "conv", - classification: bool = False, same_as_input_size: bool = True, - num_classes: int = 2, dropout_rate: float = 0.0, spatial_dims: int = 3, ) -> None: @@ -83,7 +81,6 @@ def __init__( raise ValueError("hidden_size should be divisible by num_heads.") self.same_as_input_size = same_as_input_size - self.classification = classification self.patch_embedding = PatchEmbeddingBlock( in_channels=in_channels, img_size=img_size, @@ -98,9 +95,6 @@ def __init__( [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)] ) self.norm = nn.LayerNorm(hidden_size) - if self.classification: - self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) - self.classification_head = nn.Sequential(nn.Linear(hidden_size, num_classes), nn.Tanh()) if self.same_as_input_size: new_patch_size = [4, 4, 4] @@ -110,16 +104,11 @@ def __init__( def forward(self, x): x = self.patch_embedding(x) - if self.classification: - cls_token = self.cls_token.expand(x.shape[0], -1, -1) - x = torch.cat((cls_token, x), dim=1) hidden_states_out = [] for blk in self.blocks: x = blk(x) hidden_states_out.append(x) x = self.norm(x) - if self.classification: - x = self.classification_head(x[:, 0]) if self.same_as_input_size: x = x.transpose(1, 2) cubeRoot = round(math.pow(x.size()[2], 1 / 3)) From df33387a5cda35437629a47455260571f5843d9e Mon Sep 17 00:00:00 2001 From: vnath Date: Sun, 31 Oct 2021 17:31:48 -0500 Subject: [PATCH 04/15] Added explanation for input arguments Signed-off-by: vnath --- monai/networks/nets/vit_ae.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/monai/networks/nets/vit_ae.py b/monai/networks/nets/vit_ae.py index 8d1126def3..d20d8a8b81 100644 --- a/monai/networks/nets/vit_ae.py +++ b/monai/networks/nets/vit_ae.py @@ -46,7 +46,7 @@ def __init__( ) -> None: """ Args: - in_channels: dimension of input channels. + in_channels: dimension of input channels or the number of channels for input img_size: dimension of input image. patch_size: dimension of patch size. hidden_size: dimension of hidden layer. @@ -54,8 +54,6 @@ def __init__( num_layers: number of transformer blocks. num_heads: number of attention heads. pos_embed: position embedding layer type. - classification: bool argument to determine if classification is used. - num_classes: number of classes if classification is used. dropout_rate: faction of the input units to drop. spatial_dims: number of spatial dimensions. @@ -101,7 +99,6 @@ def __init__( self.conv3d_transpose = nn.ConvTranspose3d(hidden_size, 16, kernel_size=new_patch_size, stride=new_patch_size) self.conv3d_transpose_1 = nn.ConvTranspose3d(in_channels=16, out_channels=1, kernel_size=new_patch_size, stride=new_patch_size) - def forward(self, x): x = self.patch_embedding(x) hidden_states_out = [] From 4eea01e82bb0dc8b92baa2479642173b2c79ab57 Mon Sep 17 00:00:00 2001 From: vnath Date: Sun, 31 Oct 2021 17:32:53 -0500 Subject: [PATCH 05/15] Added explanation for input arguments Signed-off-by: vnath --- monai/networks/nets/vit_ae.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/networks/nets/vit_ae.py b/monai/networks/nets/vit_ae.py index d20d8a8b81..b1047934e3 100644 --- a/monai/networks/nets/vit_ae.py +++ b/monai/networks/nets/vit_ae.py @@ -56,6 +56,7 @@ def __init__( pos_embed: position embedding layer type. dropout_rate: faction of the input units to drop. spatial_dims: number of spatial dimensions. + same_as_input_size: If set to True, ViT_AE will return output of same dimension as of input Examples:: From c2e5e14b2221040fc6678c3874d375aa21152bc2 Mon Sep 17 00:00:00 2001 From: vnath Date: Sun, 31 Oct 2021 17:36:00 -0500 Subject: [PATCH 06/15] Changed Class Name for vit autoencoder Signed-off-by: vnath --- monai/networks/nets/__init__.py | 2 +- monai/networks/nets/{vit_ae.py => vitautoenc.py} | 4 ++-- tests/test_vit_ae.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) rename monai/networks/nets/{vit_ae.py => vitautoenc.py} (98%) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index a4069a0771..a07297be13 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -83,5 +83,5 @@ from .unetr import UNETR from .varautoencoder import VarAutoEncoder from .vit import ViT -from .vit_ae import ViT_AE +from .vitautoenc import ViTAutoEnc from .vnet import VNet diff --git a/monai/networks/nets/vit_ae.py b/monai/networks/nets/vitautoenc.py similarity index 98% rename from monai/networks/nets/vit_ae.py rename to monai/networks/nets/vitautoenc.py index b1047934e3..07bdd3d923 100644 --- a/monai/networks/nets/vit_ae.py +++ b/monai/networks/nets/vitautoenc.py @@ -19,10 +19,10 @@ from monai.networks.blocks.patchembedding import PatchEmbeddingBlock from monai.networks.blocks.transformerblock import TransformerBlock -__all__ = ["ViT_AE"] +__all__ = ["ViTAutoEnc"] -class ViT_AE(nn.Module): +class ViTAutoEnc(nn.Module): """ Vision Transformer (ViT), based on: "Dosovitskiy et al., An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " diff --git a/tests/test_vit_ae.py b/tests/test_vit_ae.py index 49caba34ef..cfafb60c6e 100644 --- a/tests/test_vit_ae.py +++ b/tests/test_vit_ae.py @@ -4,7 +4,7 @@ from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.nets.vit_ae import ViT_AE +from monai.networks.nets.vitautoenc import ViT_AE TEST_CASE_Vit = [] for dropout_rate in [0.6]: From 9bcc9c927a22634953f9afe2831e0899b525fe29 Mon Sep 17 00:00:00 2001 From: vnath Date: Sun, 31 Oct 2021 17:39:13 -0500 Subject: [PATCH 07/15] Changed Class Name for vit autoencoder Signed-off-by: vnath --- monai/networks/nets/vitautoenc.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py index 07bdd3d923..a2d29dc5e2 100644 --- a/monai/networks/nets/vitautoenc.py +++ b/monai/networks/nets/vitautoenc.py @@ -61,17 +61,15 @@ def __init__( Examples:: # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone - >>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv') + # It will provide an output of same size as that of the input + >>> net = ViTAutoEnc(in_channels=1, img_size=(96,96,96), pos_embed='conv', same_as_input_size=True) - # for 3-channel with image size of (128,128,128), 24 layers and classification backbone - >>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True) - - # for 3-channel with image size of (224,224), 12 layers and classification backbone - >>> net = ViT(in_channels=3, img_size=(224,224), pos_embed='conv', classification=True, spatial_dims=2) + # for 3-channel with image size of (128,128,128), output will be same size as of input + >>> net = ViTAutoEnc(in_channels=3, img_size=(128,128,128), pos_embed='conv', same_as_input_size=True) """ - super(ViT_AE, self).__init__() + super(ViTAutoEnc, self).__init__() if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") From b64160046d6cd9dd73f9cca77371b7d7ba6f506b Mon Sep 17 00:00:00 2001 From: vnath Date: Sun, 31 Oct 2021 17:43:11 -0500 Subject: [PATCH 08/15] Addressed Wenqi comments for test cases Signed-off-by: vnath --- tests/min_tests.py | 1 + tests/{test_vit_ae.py => test_vitautoenc.py} | 14 +++++++------- 2 files changed, 8 insertions(+), 7 deletions(-) rename tests/{test_vit_ae.py => test_vitautoenc.py} (96%) diff --git a/tests/min_tests.py b/tests/min_tests.py index 22f15f7279..da0e5d8b66 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -138,6 +138,7 @@ def run_testsuit(): "test_unetr", "test_unetr_block", "test_vit", + "test_vitautoenc", "test_write_metrics_reports", "test_zoom", "test_zoom_affine", diff --git a/tests/test_vit_ae.py b/tests/test_vitautoenc.py similarity index 96% rename from tests/test_vit_ae.py rename to tests/test_vitautoenc.py index cfafb60c6e..cc47f94484 100644 --- a/tests/test_vit_ae.py +++ b/tests/test_vitautoenc.py @@ -4,7 +4,7 @@ from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.nets.vitautoenc import ViT_AE +from monai.networks.nets.vitautoenc import ViTAutoEnc TEST_CASE_Vit = [] for dropout_rate in [0.6]: @@ -48,14 +48,14 @@ class TestPatchEmbeddingBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_Vit) def test_shape(self, input_param, input_shape, expected_shape): - net = ViT_AE(**input_param) + net = ViTAutoEnc(**input_param) with eval_mode(net): result, _ = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) def test_ill_arg(self): with self.assertRaises(ValueError): - ViT_AE( + ViTAutoEnc( in_channels=1, img_size=(128, 128, 128), patch_size=(16, 16, 16), @@ -69,7 +69,7 @@ def test_ill_arg(self): ) with self.assertRaises(ValueError): - ViT_AE( + ViTAutoEnc( in_channels=1, img_size=(32, 32, 32), patch_size=(64, 64, 64), @@ -83,7 +83,7 @@ def test_ill_arg(self): ) with self.assertRaises(ValueError): - ViT_AE( + ViTAutoEnc( in_channels=1, img_size=(96, 96, 96), patch_size=(8, 8, 8), @@ -97,7 +97,7 @@ def test_ill_arg(self): ) with self.assertRaises(ValueError): - ViT_AE( + ViTAutoEnc( in_channels=1, img_size=(97, 97, 97), patch_size=(4, 4, 4), @@ -111,7 +111,7 @@ def test_ill_arg(self): ) with self.assertRaises(ValueError): - ViT_AE( + ViTAutoEnc( in_channels=4, img_size=(96, 96, 96), patch_size=(16, 16, 16), From 04c35de291a098bceede6400a7498a65fbe60bed Mon Sep 17 00:00:00 2001 From: vnath Date: Sun, 31 Oct 2021 17:48:57 -0500 Subject: [PATCH 09/15] Removed the flag for same as input size for cleaner code Signed-off-by: vnath --- monai/networks/nets/vitautoenc.py | 23 +++++++++-------------- tests/test_vitautoenc.py | 10 ++++++++++ 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py index a2d29dc5e2..2c9a7d7474 100644 --- a/monai/networks/nets/vitautoenc.py +++ b/monai/networks/nets/vitautoenc.py @@ -40,7 +40,6 @@ def __init__( num_layers: int = 12, num_heads: int = 12, pos_embed: str = "conv", - same_as_input_size: bool = True, dropout_rate: float = 0.0, spatial_dims: int = 3, ) -> None: @@ -56,16 +55,15 @@ def __init__( pos_embed: position embedding layer type. dropout_rate: faction of the input units to drop. spatial_dims: number of spatial dimensions. - same_as_input_size: If set to True, ViT_AE will return output of same dimension as of input Examples:: # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone # It will provide an output of same size as that of the input - >>> net = ViTAutoEnc(in_channels=1, img_size=(96,96,96), pos_embed='conv', same_as_input_size=True) + >>> net = ViTAutoEnc(in_channels=1, img_size=(96,96,96), pos_embed='conv') # for 3-channel with image size of (128,128,128), output will be same size as of input - >>> net = ViTAutoEnc(in_channels=3, img_size=(128,128,128), pos_embed='conv', same_as_input_size=True) + >>> net = ViTAutoEnc(in_channels=3, img_size=(128,128,128), pos_embed='conv') """ @@ -77,7 +75,6 @@ def __init__( if hidden_size % num_heads != 0: raise ValueError("hidden_size should be divisible by num_heads.") - self.same_as_input_size = same_as_input_size self.patch_embedding = PatchEmbeddingBlock( in_channels=in_channels, img_size=img_size, @@ -93,10 +90,9 @@ def __init__( ) self.norm = nn.LayerNorm(hidden_size) - if self.same_as_input_size: - new_patch_size = [4, 4, 4] - self.conv3d_transpose = nn.ConvTranspose3d(hidden_size, 16, kernel_size=new_patch_size, stride=new_patch_size) - self.conv3d_transpose_1 = nn.ConvTranspose3d(in_channels=16, out_channels=1, kernel_size=new_patch_size, stride=new_patch_size) + new_patch_size = [4, 4, 4] + self.conv3d_transpose = nn.ConvTranspose3d(hidden_size, 16, kernel_size=new_patch_size, stride=new_patch_size) + self.conv3d_transpose_1 = nn.ConvTranspose3d(in_channels=16, out_channels=1, kernel_size=new_patch_size, stride=new_patch_size) def forward(self, x): x = self.patch_embedding(x) @@ -105,9 +101,8 @@ def forward(self, x): x = blk(x) hidden_states_out.append(x) x = self.norm(x) - if self.same_as_input_size: - x = x.transpose(1, 2) - cubeRoot = round(math.pow(x.size()[2], 1 / 3)) - x = self.conv3d_transpose(x.unflatten(2, (cubeRoot, cubeRoot, cubeRoot))) - x = self.conv3d_transpose_1(x) + x = x.transpose(1, 2) + cubeRoot = round(math.pow(x.size()[2], 1 / 3)) + x = self.conv3d_transpose(x.unflatten(2, (cubeRoot, cubeRoot, cubeRoot))) + x = self.conv3d_transpose_1(x) return x, hidden_states_out diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py index cc47f94484..ccab841935 100644 --- a/tests/test_vitautoenc.py +++ b/tests/test_vitautoenc.py @@ -1,3 +1,13 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import unittest import torch From 84902410e7bf18545007a4d4b0eadfa84c62d616 Mon Sep 17 00:00:00 2001 From: vnath Date: Mon, 1 Nov 2021 10:08:11 -0500 Subject: [PATCH 10/15] Test cases for loop removed, much more polishing done of the code Signed-off-by: vnath --- monai/networks/nets/vitautoenc.py | 4 +- tests/test_vitautoenc.py | 63 +++++++++++-------------------- 2 files changed, 24 insertions(+), 43 deletions(-) diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py index 2c9a7d7474..3eb68b8156 100644 --- a/monai/networks/nets/vitautoenc.py +++ b/monai/networks/nets/vitautoenc.py @@ -60,10 +60,10 @@ def __init__( # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone # It will provide an output of same size as that of the input - >>> net = ViTAutoEnc(in_channels=1, img_size=(96,96,96), pos_embed='conv') + >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), pos_embed='conv') # for 3-channel with image size of (128,128,128), output will be same size as of input - >>> net = ViTAutoEnc(in_channels=3, img_size=(128,128,128), pos_embed='conv') + >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), pos_embed='conv') """ diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py index ccab841935..6e37d301e9 100644 --- a/tests/test_vitautoenc.py +++ b/tests/test_vitautoenc.py @@ -17,42 +17,28 @@ from monai.networks.nets.vitautoenc import ViTAutoEnc TEST_CASE_Vit = [] -for dropout_rate in [0.6]: - for in_channels in [4]: - for hidden_size in [768]: - for img_size in [64, 96, 128]: - for patch_size in [16]: - for num_heads in [12]: - for mlp_dim in [3072]: - for num_layers in [4]: - for num_classes in [8]: - for pos_embed in ["conv"]: - for classification in [False]: - for same_as_input_size in [True]: - for nd in (2, 3): - test_case = [ - { - "in_channels": in_channels, - "img_size": (img_size,) * nd, - "patch_size": (patch_size,) * nd, - "hidden_size": hidden_size, - "mlp_dim": mlp_dim, - "num_layers": num_layers, - "num_heads": num_heads, - "pos_embed": pos_embed, - "classification": classification, - "num_classes": num_classes, - "dropout_rate": dropout_rate, - "same_as_input_size": same_as_input_size - }, - (2, in_channels, *([img_size] * nd)), - (2, (img_size // patch_size) ** nd, hidden_size), - ] - if nd == 2: - test_case[0]["spatial_dims"] = 2 # type: ignore - if test_case[0]["classification"]: # type: ignore - test_case[2] = (2, test_case[0]["num_classes"]) # type: ignore - TEST_CASE_Vit.append(test_case) +for in_channels in [1, 4]: + for img_size in [64, 96, 128]: + for patch_size in [16]: + for pos_embed in ["conv", "perceptron"]: + for nd in [3]: + test_case = [ + { + "in_channels": in_channels, + "img_size": (img_size,) * nd, + "patch_size": (patch_size,) * nd, + "hidden_size": 768, + "mlp_dim": 3072, + "num_layers": 4, + "num_heads": 12, + "pos_embed": pos_embed, + "dropout_rate": 0.6, + }, + (2, in_channels, *([img_size] * nd)), + (2, 1, *([img_size] * nd)), + ] + + TEST_CASE_Vit.append(test_case) class TestPatchEmbeddingBlock(unittest.TestCase): @@ -74,7 +60,6 @@ def test_ill_arg(self): num_layers=12, num_heads=12, pos_embed="conv", - classification=False, dropout_rate=5.0, ) @@ -88,7 +73,6 @@ def test_ill_arg(self): num_layers=12, num_heads=8, pos_embed="perceptron", - classification=False, dropout_rate=0.3, ) @@ -102,7 +86,6 @@ def test_ill_arg(self): num_layers=12, num_heads=14, pos_embed="conv", - classification=False, dropout_rate=0.3, ) @@ -116,7 +99,6 @@ def test_ill_arg(self): num_layers=12, num_heads=8, pos_embed="perceptron", - classification=True, dropout_rate=0.3, ) @@ -130,7 +112,6 @@ def test_ill_arg(self): num_layers=12, num_heads=12, pos_embed="perc", - classification=False, dropout_rate=0.3, ) From ee0d3dfbd9e010f1e00855462630c3b5418fd2ad Mon Sep 17 00:00:00 2001 From: vnath Date: Mon, 1 Nov 2021 11:02:33 -0500 Subject: [PATCH 11/15] Doc changes added Signed-off-by: vnath --- docs/source/networks.rst | 5 +++++ monai/networks/nets/vitautoenc.py | 13 +++++++------ tests/test_vitautoenc.py | 1 + 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 36d62752d4..cf27c43ce8 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -470,6 +470,11 @@ Nets .. autoclass:: ViT :members: +`ViTAutoEnc` +~~~~~ +.. autoclass:: ViTAutoEnc + :members: + `FullyConnectedNet` ~~~~~~~~~~~~~~~~~~~ .. autoclass:: FullyConnectedNet diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py index 3eb68b8156..ca632e0234 100644 --- a/monai/networks/nets/vitautoenc.py +++ b/monai/networks/nets/vitautoenc.py @@ -10,10 +10,9 @@ # limitations under the License. +import math from typing import Sequence, Union -import math -import torch import torch.nn as nn from monai.networks.blocks.patchembedding import PatchEmbeddingBlock @@ -90,9 +89,11 @@ def __init__( ) self.norm = nn.LayerNorm(hidden_size) - new_patch_size = [4, 4, 4] + new_patch_size = (4, 4, 4) self.conv3d_transpose = nn.ConvTranspose3d(hidden_size, 16, kernel_size=new_patch_size, stride=new_patch_size) - self.conv3d_transpose_1 = nn.ConvTranspose3d(in_channels=16, out_channels=1, kernel_size=new_patch_size, stride=new_patch_size) + self.conv3d_transpose_1 = nn.ConvTranspose3d( + in_channels=16, out_channels=1, kernel_size=new_patch_size, stride=new_patch_size + ) def forward(self, x): x = self.patch_embedding(x) @@ -102,7 +103,7 @@ def forward(self, x): hidden_states_out.append(x) x = self.norm(x) x = x.transpose(1, 2) - cubeRoot = round(math.pow(x.size()[2], 1 / 3)) - x = self.conv3d_transpose(x.unflatten(2, (cubeRoot, cubeRoot, cubeRoot))) + cuberoot = round(math.pow(x.size()[2], 1 / 3)) + x = self.conv3d_transpose(x.unflatten(2, (cuberoot, cuberoot, cuberoot))) x = self.conv3d_transpose_1(x) return x, hidden_states_out diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py index 6e37d301e9..72cf65dfb3 100644 --- a/tests/test_vitautoenc.py +++ b/tests/test_vitautoenc.py @@ -115,5 +115,6 @@ def test_ill_arg(self): dropout_rate=0.3, ) + if __name__ == "__main__": unittest.main() From 62b2851a14d7cda07959b2441bcbccec0c32bbc7 Mon Sep 17 00:00:00 2001 From: vnath Date: Mon, 1 Nov 2021 17:44:22 -0500 Subject: [PATCH 12/15] Fixed rst formatting a raising error and also added raising error for 2D inputs Signed-off-by: vnath --- docs/source/networks.rst | 2 +- monai/networks/nets/vitautoenc.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index cf27c43ce8..31bc6de6f8 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -471,7 +471,7 @@ Nets :members: `ViTAutoEnc` -~~~~~ +~~~~~~~~~~~~ .. autoclass:: ViTAutoEnc :members: diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py index ca632e0234..9bdfcd6ab0 100644 --- a/monai/networks/nets/vitautoenc.py +++ b/monai/networks/nets/vitautoenc.py @@ -74,6 +74,9 @@ def __init__( if hidden_size % num_heads != 0: raise ValueError("hidden_size should be divisible by num_heads.") + if spatial_dims == 2: + raise ValueError("Not implemented for 2 dimensions, please try 3") + self.patch_embedding = PatchEmbeddingBlock( in_channels=in_channels, img_size=img_size, From 2040ffb98c56dc1bf4d5428622721bd8d234756a Mon Sep 17 00:00:00 2001 From: vnath Date: Tue, 2 Nov 2021 11:57:16 -0500 Subject: [PATCH 13/15] Modified the ViTAutoEnc to adapt for version 1.6 of pytorch for backward comptability Signed-off-by: vnath --- monai/networks/nets/vitautoenc.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py index 9bdfcd6ab0..48c7fbeb25 100644 --- a/monai/networks/nets/vitautoenc.py +++ b/monai/networks/nets/vitautoenc.py @@ -13,6 +13,7 @@ import math from typing import Sequence, Union +import torch import torch.nn as nn from monai.networks.blocks.patchembedding import PatchEmbeddingBlock @@ -107,6 +108,8 @@ def forward(self, x): x = self.norm(x) x = x.transpose(1, 2) cuberoot = round(math.pow(x.size()[2], 1 / 3)) - x = self.conv3d_transpose(x.unflatten(2, (cuberoot, cuberoot, cuberoot))) + x_shape = x.size() + x = torch.reshape(x, [x_shape[0], x_shape[1], cuberoot, cuberoot, cuberoot]) + x = self.conv3d_transpose(x) x = self.conv3d_transpose_1(x) return x, hidden_states_out From 6454b3f222c07f3bae07549ee0a69be57da0235b Mon Sep 17 00:00:00 2001 From: vnath Date: Tue, 2 Nov 2021 13:58:30 -0500 Subject: [PATCH 14/15] Variable name changed for test case file Signed-off-by: vnath --- tests/test_vitautoenc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py index 72cf65dfb3..13cb0d8325 100644 --- a/tests/test_vitautoenc.py +++ b/tests/test_vitautoenc.py @@ -16,7 +16,7 @@ from monai.networks import eval_mode from monai.networks.nets.vitautoenc import ViTAutoEnc -TEST_CASE_Vit = [] +TEST_CASE_Vitautoenc = [] for in_channels in [1, 4]: for img_size in [64, 96, 128]: for patch_size in [16]: @@ -38,11 +38,11 @@ (2, 1, *([img_size] * nd)), ] - TEST_CASE_Vit.append(test_case) + TEST_CASE_Vitautoenc.append(test_case) class TestPatchEmbeddingBlock(unittest.TestCase): - @parameterized.expand(TEST_CASE_Vit) + @parameterized.expand(TEST_CASE_Vitautoenc) def test_shape(self, input_param, input_shape, expected_shape): net = ViTAutoEnc(**input_param) with eval_mode(net): From 8f35e2a41c4ad73c06fa913f8ee4c5ba9194c109 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Nov 2021 18:59:11 +0000 Subject: [PATCH 15/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/vitautoenc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py index 48c7fbeb25..097534d230 100644 --- a/monai/networks/nets/vitautoenc.py +++ b/monai/networks/nets/vitautoenc.py @@ -67,7 +67,7 @@ def __init__( """ - super(ViTAutoEnc, self).__init__() + super().__init__() if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.")