From e294f515a0cba8d1f96429255ad0998757c3cd22 Mon Sep 17 00:00:00 2001
From: ahatamizadeh
Date: Sun, 29 Aug 2021 11:48:37 -0700
Subject: [PATCH 1/4] add classification support for ViT Model

Signed-off-by: ahatamizadeh
---
 monai/networks/nets/unetr.py | 10 ++++----
 monai/networks/nets/vit.py   | 16 +++++++++----
 tests/test_vit.py            | 44 ++++++++++++++++++------------------
 3 files changed, 39 insertions(+), 31 deletions(-)

diff --git a/monai/networks/nets/unetr.py b/monai/networks/nets/unetr.py
index ed49847515..9990cb6643 100644
--- a/monai/networks/nets/unetr.py
+++ b/monai/networks/nets/unetr.py
@@ -34,9 +34,9 @@ def __init__(
         hidden_size: int = 768,
         mlp_dim: int = 3072,
         num_heads: int = 12,
-        pos_embed: str = "perceptron",
+        pos_embed: str = "conv",
         norm_name: Union[Tuple, str] = "instance",
-        conv_block: bool = False,
+        conv_block: bool = True,
         res_block: bool = True,
         dropout_rate: float = 0.0,
         spatial_dims: int = 3,
@@ -59,13 +59,13 @@
 
         Examples::
 
-            # for single channel input 4-channel output with patch size of (96,96,96), feature size of 32 and batch norm
+            # for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm
             >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch')
 
-            # for single channel input 4-channel output with patch size of (96,96), feature size of 32 and batch norm
+            # for single channel input 4-channel output with image size of (96,96), feature size of 32 and batch norm
             >>> net = UNETR(in_channels=1, out_channels=4, img_size=96, feature_size=32, norm_name='batch', spatial_dims=2)
 
-            # for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm
+            # for 4-channel input 3-channel output with image size of (128,128,128), conv position embedding and instance norm
             >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance')
 
         """
diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py
index 0fd55cac62..3a5d94cc37 100644
--- a/monai/networks/nets/vit.py
+++ b/monai/networks/nets/vit.py
@@ -12,6 +12,7 @@
 
 from typing import Sequence, Union
 
+import torch
 import torch.nn as nn
 
 from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
@@ -33,7 +34,7 @@
         mlp_dim: int = 3072,
         num_layers: int = 12,
         num_heads: int = 12,
-        pos_embed: str = "perceptron",
+        pos_embed: str = "conv",
         classification: bool = False,
         num_classes: int = 2,
         dropout_rate: float = 0.0,
@@ -56,12 +57,15 @@
 
         Examples::
 
-            # for single channel input with patch size of (96,96,96), conv position embedding and segmentation backbone
+            # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
             >>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv')
 
-            # for 3-channel with patch size of (128,128,128), 24 layers and classification backbone
+            # for 3-channel with image size of (128,128,128), 24 layers and classification backbone
             >>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True)
 
+            # for 3-channel with image size of (224,224), 12 layers and classification backbone
+            >>> net = ViT(in_channels=3, img_size=(224,224), pos_embed='conv', classification=True, spatial_dims=2)
+
         """
 
         super(ViT, self).__init__()
@@ -88,10 +92,14 @@
         )
         self.norm = nn.LayerNorm(hidden_size)
         if self.classification:
-            self.classification_head = nn.Linear(hidden_size, num_classes)
+            self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
+            self.classification_head = nn.Sequential(nn.Linear(hidden_size, num_classes), nn.Tanh())
 
     def forward(self, x):
         x = self.patch_embedding(x)
+        if self.classification:
+            cls_token = self.cls_token.expand(x.shape[0], -1, -1)
+            x = torch.cat((cls_token, x), dim=1)
         hidden_states_out = []
         for blk in self.blocks:
             x = blk(x)
diff --git a/tests/test_vit.py b/tests/test_vit.py
index 0dce73b0cb..fdb5da0dc3 100644
--- a/tests/test_vit.py
+++ b/tests/test_vit.py
@@ -28,28 +28,28 @@
                             for num_layers in [4]:
                                 for num_classes in [2]:
                                     for pos_embed in ["conv"]:
-                                        # for classification in [False, True]:  # TODO: test classification
-                                        for nd in (2, 3):
-                                            test_case = [
-                                                {
-                                                    "in_channels": in_channels,
-                                                    "img_size": (img_size,) * nd,
-                                                    "patch_size": (patch_size,) * nd,
-                                                    "hidden_size": hidden_size,
-                                                    "mlp_dim": mlp_dim,
-                                                    "num_layers": num_layers,
-                                                    "num_heads": num_heads,
-                                                    "pos_embed": pos_embed,
-                                                    "classification": False,
-                                                    "num_classes": num_classes,
-                                                    "dropout_rate": dropout_rate,
-                                                },
-                                                (2, in_channels, *([img_size] * nd)),
-                                                (2, (img_size // patch_size) ** nd, hidden_size),
-                                            ]
-                                            if nd == 2:
-                                                test_case[0]["spatial_dims"] = 2  # type: ignore
-                                            TEST_CASE_Vit.append(test_case)
+                                        for classification in [False, True]:
+                                            for nd in (2, 3):
+                                                test_case = [
+                                                    {
+                                                        "in_channels": in_channels,
+                                                        "img_size": (img_size,) * nd,
+                                                        "patch_size": (patch_size,) * nd,
+                                                        "hidden_size": hidden_size,
+                                                        "mlp_dim": mlp_dim,
+                                                        "num_layers": num_layers,
+                                                        "num_heads": num_heads,
+                                                        "pos_embed": pos_embed,
+                                                        "classification": False,
+                                                        "num_classes": num_classes,
+                                                        "dropout_rate": dropout_rate,
+                                                    },
+                                                    (2, in_channels, *([img_size] * nd)),
+                                                    (2, (img_size // patch_size) ** nd, hidden_size),
+                                                ]
+                                                if nd == 2:
+                                                    test_case[0]["spatial_dims"] = 2  # type: ignore
+                                                TEST_CASE_Vit.append(test_case)
 
 
 class TestPatchEmbeddingBlock(unittest.TestCase):
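The heart of patch 1 is the class-token handling in `forward`: `patch_embedding` yields a `(batch, n_patches, hidden_size)` sequence, the learnable token is broadcast across the batch and prepended, and the classification head reads the first slot. Below is a minimal sketch of just that mechanic in plain PyTorch; the shapes are illustrative (a 96^3 volume with 16^3 patches gives 216 tokens), and the final `head(x[:, 0])` step mirrors the head application in the rest of vit.py, which this hunk does not show.

```python
import torch
import torch.nn as nn

batch, n_patches, hidden_size, num_classes = 2, 216, 768, 2

# learnable class token, shape (1, 1, hidden_size), as introduced by the patch
cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))

x = torch.randn(batch, n_patches, hidden_size)  # stand-in for patch_embedding output
tok = cls_token.expand(x.shape[0], -1, -1)      # broadcast to (batch, 1, hidden_size)
x = torch.cat((tok, x), dim=1)                  # (batch, n_patches + 1, hidden_size)

# after the transformer blocks, only the class-token slot feeds the head
head = nn.Sequential(nn.Linear(hidden_size, num_classes), nn.Tanh())
logits = head(x[:, 0])                          # (batch, num_classes)
print(logits.shape)                             # torch.Size([2, 2])
```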
From c4dfe3bdacf742c313604020905eca7a9a175343 Mon Sep 17 00:00:00 2001
From: ahatamizadeh
Date: Sun, 29 Aug 2021 12:18:48 -0700
Subject: [PATCH 2/4] add classification support for ViT Model

Signed-off-by: ahatamizadeh
---
 tests/test_vit.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/tests/test_vit.py b/tests/test_vit.py
index fdb5da0dc3..45a942196d 100644
--- a/tests/test_vit.py
+++ b/tests/test_vit.py
@@ -26,7 +26,7 @@
                     for num_heads in [12]:
                         for mlp_dim in [3072]:
                             for num_layers in [4]:
-                                for num_classes in [2]:
+                                for num_classes in [8]:
                                     for pos_embed in ["conv"]:
                                         for classification in [False, True]:
                                             for nd in (2, 3):
@@ -40,7 +40,7 @@
                                                         "num_layers": num_layers,
                                                         "num_heads": num_heads,
                                                         "pos_embed": pos_embed,
-                                                        "classification": False,
+                                                        "classification": classification,
                                                         "num_classes": num_classes,
                                                         "dropout_rate": dropout_rate,
                                                     },
@@ -47,7 +47,9 @@
                                                     (2, in_channels, *([img_size] * nd)),
                                                     (2, (img_size // patch_size) ** nd, hidden_size),
                                                 ]
                                                 if nd == 2:
                                                     test_case[0]["spatial_dims"] = 2  # type: ignore
+                                                if test_case[0]["classification"]:
+                                                    test_case[2] = (2, test_case[0]["num_classes"])  # type: ignore
                                                 TEST_CASE_Vit.append(test_case)
 
@@ -113,7 +115,7 @@ def test_ill_arg(self):
                 num_layers=12,
                 num_heads=8,
                 pos_embed="perceptron",
-                classification=False,
+                classification=True,
                 dropout_rate=0.3,
             )
 
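Patch 2 makes the generated cases exercise `classification=True` and re-points the expected output shape: `(2, num_classes)` when the head is active, `(2, (img_size // patch_size) ** nd, hidden_size)` for the segmentation backbone. Condensed, what the parameterized test now asserts looks roughly like the sketch below; it assumes the `ViT` constructor and the (output, hidden states) return of this branch, and keeps `num_layers` small for speed.

```python
import torch

from monai.networks.nets import ViT

img_size, patch_size, hidden_size, num_classes = 96, 16, 768, 8

for classification in (False, True):
    net = ViT(
        in_channels=1,
        img_size=(img_size,) * 3,
        patch_size=(patch_size,) * 3,
        num_layers=4,
        classification=classification,
        num_classes=num_classes,
    )
    out, hidden_states = net(torch.randn(2, 1, img_size, img_size, img_size))
    # classification head: (2, 8); segmentation backbone: (2, 216, 768)
    expected = (2, num_classes) if classification else (2, (img_size // patch_size) ** 3, hidden_size)
    assert out.shape == expected
```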
@@ -49,7 +49,7 @@
                                                 ]
                                                 if nd == 2:
                                                     test_case[0]["spatial_dims"] = 2  # type: ignore
-                                                if test_case[0]["classification"]:
+                                                if test_case[0]["classification"]:  # type: ignore
                                                     test_case[2] = (2, test_case[0]["num_classes"])  # type: ignore
                                                 TEST_CASE_Vit.append(test_case)
 

From cd147a061c520737a42ba20f19e20edf89ed19aa Mon Sep 17 00:00:00 2001
From: monai-bot
Date: Sun, 29 Aug 2021 23:09:54 +0000
Subject: [PATCH 4/4] [MONAI] python code formatting

Signed-off-by: monai-bot
---
 monai/apps/deepedit/transforms.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/apps/deepedit/transforms.py b/monai/apps/deepedit/transforms.py
index 2ab2722076..845e7bd1d0 100644
--- a/monai/apps/deepedit/transforms.py
+++ b/monai/apps/deepedit/transforms.py
@@ -6,10 +6,10 @@
 
 from monai.config import KeysCollection
 from monai.transforms.transform import MapTransform, Randomizable, Transform
+from monai.utils import optional_import
 
 logger = logging.getLogger(__name__)
 
-from monai.utils import optional_import
 distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")
 
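Taken together, the series enables the 2D classification configuration that patch 1 advertises in the vit.py docstring. A usage sketch against the patched branch follows; note that `patch_size` is a required constructor argument, so it is supplied here even though the docstring example omits it.

```python
import torch

from monai.networks.nets import ViT

# 2D classification backbone, following the docstring example added in patch 1
net = ViT(
    in_channels=3,
    img_size=(224, 224),
    patch_size=(16, 16),
    pos_embed="conv",
    classification=True,
    spatial_dims=2,
)
logits, hidden_states = net(torch.randn(1, 3, 224, 224))  # input: (batch, channels, H, W)
print(logits.shape)  # torch.Size([1, 2]) with the default num_classes=2
```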