From e8967351bb3b79330e0c1f7fedc8a8d446e77e96 Mon Sep 17 00:00:00 2001 From: NoTody Date: Thu, 14 Sep 2023 06:28:54 -0400 Subject: [PATCH 01/17] add sincos positional embedding for patch embedding class Signed-off-by: NoTody --- monai/networks/blocks/patchembedding.py | 38 +++++++--- monai/networks/blocks/pos_embed_utils.py | 89 ++++++++++++++++++++++++ monai/networks/nets/unetr.py | 8 +-- monai/networks/nets/vit.py | 13 ++-- monai/networks/nets/vitautoenc.py | 10 +-- requirements-dev.txt | 1 + tests/test_patchembedding.py | 52 +++++++------- 7 files changed, 163 insertions(+), 48 deletions(-) create mode 100644 monai/networks/blocks/pos_embed_utils.py diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 57c0c5ee02..4410eca9b0 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -19,6 +19,7 @@ import torch.nn.functional as F from torch.nn import LayerNorm +from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding from monai.networks.layers import Conv, trunc_normal_ from monai.utils import ensure_tuple_rep, optional_import from monai.utils.module import look_up_option @@ -35,7 +36,8 @@ class PatchEmbeddingBlock(nn.Module): Example:: >>> from monai.networks.blocks import PatchEmbeddingBlock - >>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4, pos_embed="conv") + >>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4, + >>> patch_embed="conv", pos_embed="sincos") """ @@ -46,7 +48,8 @@ def __init__( patch_size: Sequence[int] | int, hidden_size: int, num_heads: int, - pos_embed: str, + patch_embed: str, + pos_embed: str = "norm", dropout_rate: float = 0.0, spatial_dims: int = 3, ) -> None: @@ -57,11 +60,10 @@ def __init__( patch_size: dimension of patch size. hidden_size: dimension of hidden layer. num_heads: number of attention heads. + patch_embed: patch embedding layer type. pos_embed: position embedding layer type. dropout_rate: faction of the input units to drop. spatial_dims: number of spatial dimensions. - - """ super().__init__() @@ -72,24 +74,29 @@ def __init__( if hidden_size % num_heads != 0: raise ValueError(f"hidden size {hidden_size} should be divisible by num_heads {num_heads}.") - self.pos_embed = look_up_option(pos_embed, SUPPORTED_EMBEDDING_TYPES) + self.patch_embed = look_up_option(patch_embed, SUPPORTED_EMBEDDING_TYPES) img_size = ensure_tuple_rep(img_size, spatial_dims) patch_size = ensure_tuple_rep(patch_size, spatial_dims) for m, p in zip(img_size, patch_size): if m < p: raise ValueError("patch_size should be smaller than img_size.") - if self.pos_embed == "perceptron" and m % p != 0: + if self.patch_embed == "perceptron" and m % p != 0: raise ValueError("patch_size should be divisible by img_size for perceptron.") self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)]) self.patch_dim = int(in_channels * np.prod(patch_size)) + grid_size = [] + for in_size, pa_size in zip(img_size, patch_size): + assert in_size % pa_size == 0, "input size and patch size are not proper" + grid_size.append(in_size // pa_size) + self.patch_embeddings: nn.Module - if self.pos_embed == "conv": + if self.patch_embed == "conv": self.patch_embeddings = Conv[Conv.CONV, spatial_dims]( in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size ) - elif self.pos_embed == "perceptron": + elif self.patch_embed == "perceptron": # for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)" chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims] from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars) @@ -100,7 +107,18 @@ def __init__( ) self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) self.dropout = nn.Dropout(dropout_rate) - trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) + + if pos_embed == "none": + pass + elif pos_embed == "learnable": + trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) + elif pos_embed == "sincos": + with torch.no_grad(): + pos_embed = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims) + self.position_embeddings.data.copy_(pos_embed.float()) + else: + raise ValueError(f"pos_embed type {pos_embed} not supported.") + self.apply(self._init_weights) def _init_weights(self, m): @@ -114,7 +132,7 @@ def _init_weights(self, m): def forward(self, x): x = self.patch_embeddings(x) - if self.pos_embed == "conv": + if self.patch_embed == "conv": x = x.flatten(2).transpose(-1, -2) embeddings = x + self.position_embeddings embeddings = self.dropout(embeddings) diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py new file mode 100644 index 0000000000..e3acbb0647 --- /dev/null +++ b/monai/networks/blocks/pos_embed_utils.py @@ -0,0 +1,89 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + +import torch +import torch.nn as nn +from timm.models.layers import to_2tuple, to_3tuple + +__all__ = ["build_sincos_position_embedding"] + +def build_sincos_position_embedding( + grid_size: Optional[int], embed_dim: int, spatial_dims: int = 3, temperature: float = 10000.0 +) -> torch.nn.Parameter: + """ + Builds a sin-cos position embedding based on the given grid size, embed dimension, spatial dimensions, and temperature. + Reference: https://github.com/cvlab-stonybrook/SelfMedMAE/blob/68d191dfcc1c7d0145db93a6a570362de29e3b30/lib/models/mae3d.py + + Args: + grid_size (int or Tuple[int]): The size of the grid in each spatial dimension. + embed_dim (int): The dimension of the embedding. + spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D). + temperature (float): The temperature for the sin-cos position embedding. + + Returns: + pos_embed (nn.Parameter): The sin-cos position embedding as a learnable parameter. + """ + + if spatial_dims == 2: + grid_size = to_2tuple(grid_size) + h, w = grid_size + grid_h = torch.arange(h, dtype=torch.float32) + grid_w = torch.arange(w, dtype=torch.float32) + + grid_h, grid_w = torch.meshgrid(grid_h, grid_w, indexing='ij') + + assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding" + + pos_dim = embed_dim // 4 + omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim + omega = 1.0 / (temperature**omega) + out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega]) + out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega]) + pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :] + elif spatial_dims == 3: + grid_size = to_3tuple(grid_size) + h, w, d = grid_size + grid_h = torch.arange(h, dtype=torch.float32) + grid_w = torch.arange(w, dtype=torch.float32) + grid_d = torch.arange(d, dtype=torch.float32) + + grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d, indexing='ij') + + assert embed_dim % 6 == 0, "Embed dimension must be divisible by 6 for 3D sin-cos position embedding" + + pos_dim = embed_dim // 6 + omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim + omega = 1.0 / (temperature**omega) + out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega]) + out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega]) + out_d = torch.einsum("m,d->md", [grid_d.flatten(), omega]) + pos_emb = torch.cat( + [ + torch.sin(out_w), + torch.cos(out_w), + torch.sin(out_h), + torch.cos(out_h), + torch.sin(out_d), + torch.cos(out_d), + ], + dim=1, + )[None, :, :] + else: + raise NotImplementedError("Spatial Dimension Size {spatial_dims} Not Implemented!") + + pos_embed = nn.Parameter(pos_emb) + pos_embed.requires_grad = False + + return pos_embed diff --git a/monai/networks/nets/unetr.py b/monai/networks/nets/unetr.py index 4cdcd73c4d..219ae1f888 100644 --- a/monai/networks/nets/unetr.py +++ b/monai/networks/nets/unetr.py @@ -36,7 +36,7 @@ def __init__( hidden_size: int = 768, mlp_dim: int = 3072, num_heads: int = 12, - pos_embed: str = "conv", + patch_embed: str = "conv", norm_name: tuple | str = "instance", conv_block: bool = True, res_block: bool = True, @@ -54,7 +54,7 @@ def __init__( hidden_size: dimension of hidden layer. Defaults to 768. mlp_dim: dimension of feedforward layer. Defaults to 3072. num_heads: number of attention heads. Defaults to 12. - pos_embed: position embedding layer type. Defaults to "conv". + patch_embed: patch embedding layer type. Defaults to "conv". norm_name: feature normalization type and arguments. Defaults to "instance". conv_block: if convolutional block is used. Defaults to True. res_block: if residual block is used. Defaults to True. @@ -72,7 +72,7 @@ def __init__( >>> net = UNETR(in_channels=1, out_channels=4, img_size=96, feature_size=32, norm_name='batch', spatial_dims=2) # for 4-channel input 3-channel output with image size of (128,128,128), conv position embedding and instance norm - >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance') + >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), patch_embed='conv', norm_name='instance') """ @@ -98,7 +98,7 @@ def __init__( mlp_dim=mlp_dim, num_layers=self.num_layers, num_heads=num_heads, - pos_embed=pos_embed, + patch_embed=patch_embed, classification=self.classification, dropout_rate=dropout_rate, spatial_dims=spatial_dims, diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py index 8cd42b54b1..1f51a15237 100644 --- a/monai/networks/nets/vit.py +++ b/monai/networks/nets/vit.py @@ -39,7 +39,8 @@ def __init__( mlp_dim: int = 3072, num_layers: int = 12, num_heads: int = 12, - pos_embed: str = "conv", + patch_embed: str = "conv", + pos_embed: str = "learnable", classification: bool = False, num_classes: int = 2, dropout_rate: float = 0.0, @@ -57,7 +58,8 @@ def __init__( mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072. num_layers (int, optional): number of transformer blocks. Defaults to 12. num_heads (int, optional): number of attention heads. Defaults to 12. - pos_embed (str, optional): position embedding layer type. Defaults to "conv". + patch_embed (str, optional): patch embedding layer type. Defaults to "conv". + pos_embed (str, optional): position embedding type. Defaults to "learnable". classification (bool, optional): bool argument to determine if classification is used. Defaults to False. num_classes (int, optional): number of classes if classification is used. Defaults to 2. dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0. @@ -71,13 +73,13 @@ def __init__( Examples:: # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone - >>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv') + >>> net = ViT(in_channels=1, img_size=(96,96,96), patch_embed='conv', pos_embed='sincos') # for 3-channel with image size of (128,128,128), 24 layers and classification backbone - >>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True) + >>> net = ViT(in_channels=3, img_size=(128,128,128), patch_embed='conv', pos_embed='sincos', classification=True) # for 3-channel with image size of (224,224), 12 layers and classification backbone - >>> net = ViT(in_channels=3, img_size=(224,224), pos_embed='conv', classification=True, spatial_dims=2) + >>> net = ViT(in_channels=3, img_size=(224,224), patch_embed='conv', pos_embed='sincos', classification=True, spatial_dims=2) """ @@ -96,6 +98,7 @@ def __init__( patch_size=patch_size, hidden_size=hidden_size, num_heads=num_heads, + patch_embed=patch_embed, pos_embed=pos_embed, dropout_rate=dropout_rate, spatial_dims=spatial_dims, diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py index 12d7d4e376..c1808072b0 100644 --- a/monai/networks/nets/vitautoenc.py +++ b/monai/networks/nets/vitautoenc.py @@ -44,7 +44,7 @@ def __init__( mlp_dim: int = 3072, num_layers: int = 12, num_heads: int = 12, - pos_embed: str = "conv", + patch_embed: str = "conv", dropout_rate: float = 0.0, spatial_dims: int = 3, qkv_bias: bool = False, @@ -61,7 +61,7 @@ def __init__( mlp_dim: dimension of feedforward layer. Defaults to 3072. num_layers: number of transformer blocks. Defaults to 12. num_heads: number of attention heads. Defaults to 12. - pos_embed: position embedding layer type. Defaults to "conv". + patch_embed: position embedding layer type. Defaults to "conv". dropout_rate: faction of the input units to drop. Defaults to 0.0. spatial_dims: number of spatial dimensions. Defaults to 3. qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False. @@ -71,10 +71,10 @@ def __init__( # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone # It will provide an output of same size as that of the input - >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), pos_embed='conv') + >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), patch_embed='conv') # for 3-channel with image size of (128,128,128), output will be same size as of input - >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), pos_embed='conv') + >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), patch_embed='conv') """ @@ -94,7 +94,7 @@ def __init__( patch_size=patch_size, hidden_size=hidden_size, num_heads=num_heads, - pos_embed=pos_embed, + patch_embed=patch_embed, dropout_rate=dropout_rate, spatial_dims=self.spatial_dims, ) diff --git a/requirements-dev.txt b/requirements-dev.txt index 9620ea253d..0761a53326 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -56,3 +56,4 @@ filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523 zarr lpips==0.1.4 nvidia-ml-py +timm>=0.9.7 \ No newline at end of file diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py index ae7fd14401..91e0a86553 100644 --- a/tests/test_patchembedding.py +++ b/tests/test_patchembedding.py @@ -31,25 +31,27 @@ for img_size in [32, 64]: for patch_size in [8, 16]: for num_heads in [8, 12]: - for pos_embed in ["conv", "perceptron"]: - # for classification in (False, True): # TODO: add classification tests - for nd in (2, 3): - test_case = [ - { - "in_channels": in_channels, - "img_size": (img_size,) * nd, - "patch_size": (patch_size,) * nd, - "hidden_size": hidden_size, - "num_heads": num_heads, - "pos_embed": pos_embed, - "dropout_rate": dropout_rate, - }, - (2, in_channels, *([img_size] * nd)), - (2, (img_size // patch_size) ** nd, hidden_size), - ] - if nd == 2: - test_case[0]["spatial_dims"] = 2 # type: ignore - TEST_CASE_PATCHEMBEDDINGBLOCK.append(test_case) + for patch_embed in ["conv", "perceptron"]: + for pos_embed in ["none", "learnable", "sincos"]: + # for classification in (False, True): # TODO: add classification tests + for nd in (2, 3): + test_case = [ + { + "in_channels": in_channels, + "img_size": (img_size,) * nd, + "patch_size": (patch_size,) * nd, + "hidden_size": hidden_size, + "num_heads": num_heads, + "patch_embed": patch_embed, + "pos_embed": pos_embed, + "dropout_rate": dropout_rate, + }, + (2, in_channels, *([img_size] * nd)), + (2, (img_size // patch_size) ** nd, hidden_size), + ] + if nd == 2: + test_case[0]["spatial_dims"] = 2 # type: ignore + TEST_CASE_PATCHEMBEDDINGBLOCK.append(test_case) TEST_CASE_PATCHEMBED = [] for patch_size in [2]: @@ -96,7 +98,8 @@ def test_ill_arg(self): patch_size=(16, 16, 16), hidden_size=128, num_heads=12, - pos_embed="conv", + patch_embed="conv", + pos_embed="sincos", dropout_rate=5.0, ) @@ -107,7 +110,8 @@ def test_ill_arg(self): patch_size=(64, 64, 64), hidden_size=512, num_heads=8, - pos_embed="perceptron", + patch_embed="perceptron", + pos_embed="sincos", dropout_rate=0.3, ) @@ -118,7 +122,7 @@ def test_ill_arg(self): patch_size=(8, 8, 8), hidden_size=512, num_heads=14, - pos_embed="conv", + patch_embed="conv", dropout_rate=0.3, ) @@ -129,7 +133,7 @@ def test_ill_arg(self): patch_size=(4, 4, 4), hidden_size=768, num_heads=8, - pos_embed="perceptron", + patch_embed="perceptron", dropout_rate=0.3, ) @@ -140,7 +144,7 @@ def test_ill_arg(self): patch_size=(16, 16, 16), hidden_size=768, num_heads=12, - pos_embed="perc", + patch_embed="perc", dropout_rate=0.3, ) From 9a8790e89c1cc905b32f0d823d9521cc3948b1a7 Mon Sep 17 00:00:00 2001 From: NoTody Date: Thu, 14 Sep 2023 06:32:35 -0400 Subject: [PATCH 02/17] correct default positional embedding in patch embedding Signed-off-by: NoTody --- monai/networks/blocks/patchembedding.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 4410eca9b0..4d8dd248a6 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -25,7 +25,8 @@ from monai.utils.module import look_up_option Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") -SUPPORTED_EMBEDDING_TYPES = {"conv", "perceptron"} +SUPPORTED_PATCH_EMBEDDING_TYPES = {"conv", "perceptron"} +SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"} class PatchEmbeddingBlock(nn.Module): @@ -49,7 +50,7 @@ def __init__( hidden_size: int, num_heads: int, patch_embed: str, - pos_embed: str = "norm", + pos_embed: str = "learnable", dropout_rate: float = 0.0, spatial_dims: int = 3, ) -> None: @@ -74,7 +75,8 @@ def __init__( if hidden_size % num_heads != 0: raise ValueError(f"hidden size {hidden_size} should be divisible by num_heads {num_heads}.") - self.patch_embed = look_up_option(patch_embed, SUPPORTED_EMBEDDING_TYPES) + self.patch_embed = look_up_option(patch_embed, SUPPORTED_PATCH_EMBEDDING_TYPES) + self.pos_embed = look_up_option(pos_embed, SUPPORTED_POS_EMBEDDING_TYPES) img_size = ensure_tuple_rep(img_size, spatial_dims) patch_size = ensure_tuple_rep(patch_size, spatial_dims) From 7c4c5806898ab33c0b0a3ab3bb38210448805a55 Mon Sep 17 00:00:00 2001 From: NoTody Date: Thu, 14 Sep 2023 07:12:40 -0400 Subject: [PATCH 03/17] style errors fix Signed-off-by: NoTody --- merged.zarr/.zarray | 26 +++++++++++++++++++++++ merged.zarr/0.0.0.0 | Bin 0 -> 215 bytes monai/networks/blocks/pos_embed_utils.py | 5 +++-- monai/networks/nets/vit.py | 3 ++- test.zarr/.zarray | 26 +++++++++++++++++++++++ test.zarr/0.0.0.0 | Bin 0 -> 215 bytes 6 files changed, 57 insertions(+), 3 deletions(-) create mode 100644 merged.zarr/.zarray create mode 100644 merged.zarr/0.0.0.0 create mode 100644 test.zarr/.zarray create mode 100644 test.zarr/0.0.0.0 diff --git a/merged.zarr/.zarray b/merged.zarr/.zarray new file mode 100644 index 0000000000..f656b710e2 --- /dev/null +++ b/merged.zarr/.zarray @@ -0,0 +1,26 @@ +{ + "chunks": [ + 2, + 3, + 4, + 4 + ], + "compressor": { + "blocksize": 0, + "clevel": 5, + "cname": "lz4", + "id": "blosc", + "shuffle": 1 + }, + "dtype": "kJGGB0#(!h~*g=8PvP**NPlzTCQbHpm!63r@Ox#Cj{`>MqJxD*;bh3NeV z*kC+Cb!}VV3 torch.nn.Parameter: @@ -42,7 +43,7 @@ def build_sincos_position_embedding( grid_h = torch.arange(h, dtype=torch.float32) grid_w = torch.arange(w, dtype=torch.float32) - grid_h, grid_w = torch.meshgrid(grid_h, grid_w, indexing='ij') + grid_h, grid_w = torch.meshgrid(grid_h, grid_w, indexing="ij") assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding" @@ -59,7 +60,7 @@ def build_sincos_position_embedding( grid_w = torch.arange(w, dtype=torch.float32) grid_d = torch.arange(d, dtype=torch.float32) - grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d, indexing='ij') + grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d, indexing="ij") assert embed_dim % 6 == 0, "Embed dimension must be divisible by 6 for 3D sin-cos position embedding" diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py index 1f51a15237..4d8c041e85 100644 --- a/monai/networks/nets/vit.py +++ b/monai/networks/nets/vit.py @@ -79,7 +79,8 @@ def __init__( >>> net = ViT(in_channels=3, img_size=(128,128,128), patch_embed='conv', pos_embed='sincos', classification=True) # for 3-channel with image size of (224,224), 12 layers and classification backbone - >>> net = ViT(in_channels=3, img_size=(224,224), patch_embed='conv', pos_embed='sincos', classification=True, spatial_dims=2) + >>> net = ViT(in_channels=3, img_size=(224,224), patch_embed='conv', pos_embed='sincos', classification=True, + >>> spatial_dims=2) """ diff --git a/test.zarr/.zarray b/test.zarr/.zarray new file mode 100644 index 0000000000..f656b710e2 --- /dev/null +++ b/test.zarr/.zarray @@ -0,0 +1,26 @@ +{ + "chunks": [ + 2, + 3, + 4, + 4 + ], + "compressor": { + "blocksize": 0, + "clevel": 5, + "cname": "lz4", + "id": "blosc", + "shuffle": 1 + }, + "dtype": "kJGGB0#(!h~*g=8PvP**NPlzTCQbHpm!63r@Ox#Cj{`>MqJxD*;bh3NeV z*kC+Cb!}VV3 Date: Thu, 14 Sep 2023 07:13:20 -0400 Subject: [PATCH 04/17] remove test temp files Signed-off-by: NoTody --- merged.zarr/.zarray | 26 -------------------------- merged.zarr/0.0.0.0 | Bin 215 -> 0 bytes test.zarr/.zarray | 26 -------------------------- test.zarr/0.0.0.0 | Bin 215 -> 0 bytes 4 files changed, 52 deletions(-) delete mode 100644 merged.zarr/.zarray delete mode 100644 merged.zarr/0.0.0.0 delete mode 100644 test.zarr/.zarray delete mode 100644 test.zarr/0.0.0.0 diff --git a/merged.zarr/.zarray b/merged.zarr/.zarray deleted file mode 100644 index f656b710e2..0000000000 --- a/merged.zarr/.zarray +++ /dev/null @@ -1,26 +0,0 @@ -{ - "chunks": [ - 2, - 3, - 4, - 4 - ], - "compressor": { - "blocksize": 0, - "clevel": 5, - "cname": "lz4", - "id": "blosc", - "shuffle": 1 - }, - "dtype": "kJGGB0#(!h~*g=8PvP**NPlzTCQbHpm!63r@Ox#Cj{`>MqJxD*;bh3NeV z*kC+Cb!}VV3kJGGB0#(!h~*g=8PvP**NPlzTCQbHpm!63r@Ox#Cj{`>MqJxD*;bh3NeV z*kC+Cb!}VV3 Date: Thu, 14 Sep 2023 11:16:43 +0000 Subject: [PATCH 05/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 0761a53326..1c902b867e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -56,4 +56,4 @@ filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523 zarr lpips==0.1.4 nvidia-ml-py -timm>=0.9.7 \ No newline at end of file +timm>=0.9.7 From dd15afe6c5b34797032c15cebd9aebd487c429dd Mon Sep 17 00:00:00 2001 From: NoTody Date: Fri, 15 Sep 2023 05:10:15 -0400 Subject: [PATCH 06/17] Exclude timm with self-written functions for 2-tuple and 3-tuple. Add Deprecation for pos_embed. Signed-off-by: NoTody --- monai/networks/blocks/patchembedding.py | 34 +++++++++++++++--------- monai/networks/blocks/pos_embed_utils.py | 13 ++++++++- monai/networks/nets/unetr.py | 21 ++++++++++----- monai/networks/nets/vit.py | 31 +++++++++++++-------- monai/networks/nets/vitautoenc.py | 23 ++++++++++------ requirements-dev.txt | 3 +-- tests/test_patchembedding.py | 22 +++++++-------- 7 files changed, 94 insertions(+), 53 deletions(-) diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 4d8dd248a6..29d513cc8f 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -21,7 +21,7 @@ from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding from monai.networks.layers import Conv, trunc_normal_ -from monai.utils import ensure_tuple_rep, optional_import +from monai.utils import deprecated_arg, ensure_tuple_rep, optional_import from monai.utils.module import look_up_option Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -38,10 +38,15 @@ class PatchEmbeddingBlock(nn.Module): >>> from monai.networks.blocks import PatchEmbeddingBlock >>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4, - >>> patch_embed="conv", pos_embed="sincos") + >>> proj_type="conv", pos_embed_type="sincos") """ - + @deprecated_arg( + name="pos_embed", + since="1.4", + new_name="proj_type", + msg_suffix="please use `proj_type` instead.", + ) def __init__( self, in_channels: int, @@ -49,8 +54,9 @@ def __init__( patch_size: Sequence[int] | int, hidden_size: int, num_heads: int, - patch_embed: str, - pos_embed: str = "learnable", + pos_embed: str = "conv", + proj_type: str = "conv", + pos_embed_type: str = "learnable", dropout_rate: float = 0.0, spatial_dims: int = 3, ) -> None: @@ -61,10 +67,12 @@ def __init__( patch_size: dimension of patch size. hidden_size: dimension of hidden layer. num_heads: number of attention heads. - patch_embed: patch embedding layer type. - pos_embed: position embedding layer type. + proj_type: patch embedding layer type. + pos_embed_type: position embedding layer type. dropout_rate: faction of the input units to drop. spatial_dims: number of spatial dimensions. + .. deprecated:: 1.4 + ``pos_embed`` is deprecated in favor of ``proj_type``. """ super().__init__() @@ -75,8 +83,8 @@ def __init__( if hidden_size % num_heads != 0: raise ValueError(f"hidden size {hidden_size} should be divisible by num_heads {num_heads}.") - self.patch_embed = look_up_option(patch_embed, SUPPORTED_PATCH_EMBEDDING_TYPES) - self.pos_embed = look_up_option(pos_embed, SUPPORTED_POS_EMBEDDING_TYPES) + self.patch_embed = look_up_option(proj_type, SUPPORTED_PATCH_EMBEDDING_TYPES) + self.pos_embed = look_up_option(pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES) img_size = ensure_tuple_rep(img_size, spatial_dims) patch_size = ensure_tuple_rep(patch_size, spatial_dims) @@ -110,16 +118,16 @@ def __init__( self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) self.dropout = nn.Dropout(dropout_rate) - if pos_embed == "none": + if pos_embed_type == "none": pass - elif pos_embed == "learnable": + elif pos_embed_type == "learnable": trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) - elif pos_embed == "sincos": + elif pos_embed_type == "sincos": with torch.no_grad(): pos_embed = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims) self.position_embeddings.data.copy_(pos_embed.float()) else: - raise ValueError(f"pos_embed type {pos_embed} not supported.") + raise ValueError(f"pos_embed_type {pos_embed_type} not supported.") self.apply(self._init_weights) diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py index b08eb184e5..813b477ae1 100644 --- a/monai/networks/blocks/pos_embed_utils.py +++ b/monai/networks/blocks/pos_embed_utils.py @@ -15,10 +15,19 @@ import torch import torch.nn as nn -from timm.models.layers import to_2tuple, to_3tuple + +from itertools import repeat +import collections.abc __all__ = ["build_sincos_position_embedding"] +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + return parse def build_sincos_position_embedding( grid_size: Optional[int], embed_dim: int, spatial_dims: int = 3, temperature: float = 10000.0 @@ -38,6 +47,7 @@ def build_sincos_position_embedding( """ if spatial_dims == 2: + to_2tuple = _ntuple(2) grid_size = to_2tuple(grid_size) h, w = grid_size grid_h = torch.arange(h, dtype=torch.float32) @@ -54,6 +64,7 @@ def build_sincos_position_embedding( out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega]) pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :] elif spatial_dims == 3: + to_3tuple = _ntuple(3) grid_size = to_3tuple(grid_size) h, w, d = grid_size grid_h = torch.arange(h, dtype=torch.float32) diff --git a/monai/networks/nets/unetr.py b/monai/networks/nets/unetr.py index 219ae1f888..c588fa7c5a 100644 --- a/monai/networks/nets/unetr.py +++ b/monai/networks/nets/unetr.py @@ -18,7 +18,7 @@ from monai.networks.blocks.dynunet_block import UnetOutBlock from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock from monai.networks.nets.vit import ViT -from monai.utils import ensure_tuple_rep +from monai.utils import deprecated_arg, ensure_tuple_rep class UNETR(nn.Module): @@ -26,7 +26,12 @@ class UNETR(nn.Module): UNETR based on: "Hatamizadeh et al., UNETR: Transformers for 3D Medical Image Segmentation " """ - + @deprecated_arg( + name="pos_embed", + since="1.4", + new_name="proj_type", + msg_suffix="please use `proj_type` instead.", + ) def __init__( self, in_channels: int, @@ -36,7 +41,8 @@ def __init__( hidden_size: int = 768, mlp_dim: int = 3072, num_heads: int = 12, - patch_embed: str = "conv", + pos_embed: str = "conv", + proj_type: str = "conv", norm_name: tuple | str = "instance", conv_block: bool = True, res_block: bool = True, @@ -54,7 +60,7 @@ def __init__( hidden_size: dimension of hidden layer. Defaults to 768. mlp_dim: dimension of feedforward layer. Defaults to 3072. num_heads: number of attention heads. Defaults to 12. - patch_embed: patch embedding layer type. Defaults to "conv". + proj_type: patch embedding layer type. Defaults to "conv". norm_name: feature normalization type and arguments. Defaults to "instance". conv_block: if convolutional block is used. Defaults to True. res_block: if residual block is used. Defaults to True. @@ -62,7 +68,8 @@ def __init__( spatial_dims: number of spatial dims. Defaults to 3. qkv_bias: apply the bias term for the qkv linear layer in self attention block. Defaults to False. save_attn: to make accessible the attention in self attention block. Defaults to False. - + .. deprecated:: 1.4 + ``pos_embed`` is deprecated in favor of ``proj_type``. Examples:: # for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm @@ -72,7 +79,7 @@ def __init__( >>> net = UNETR(in_channels=1, out_channels=4, img_size=96, feature_size=32, norm_name='batch', spatial_dims=2) # for 4-channel input 3-channel output with image size of (128,128,128), conv position embedding and instance norm - >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), patch_embed='conv', norm_name='instance') + >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), proj_type='conv', norm_name='instance') """ @@ -98,7 +105,7 @@ def __init__( mlp_dim=mlp_dim, num_layers=self.num_layers, num_heads=num_heads, - patch_embed=patch_embed, + proj_type=proj_type, classification=self.classification, dropout_rate=dropout_rate, spatial_dims=spatial_dims, diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py index 4d8c041e85..7bf4f46512 100644 --- a/monai/networks/nets/vit.py +++ b/monai/networks/nets/vit.py @@ -19,6 +19,8 @@ from monai.networks.blocks.patchembedding import PatchEmbeddingBlock from monai.networks.blocks.transformerblock import TransformerBlock +from monai.utils import deprecated_arg + __all__ = ["ViT"] @@ -29,7 +31,12 @@ class ViT(nn.Module): ViT supports Torchscript but only works for Pytorch after 1.8. """ - + @deprecated_arg( + name="pos_embed", + since="1.4", + new_name="proj_type", + msg_suffix="please use `proj_type` instead.", + ) def __init__( self, in_channels: int, @@ -39,8 +46,9 @@ def __init__( mlp_dim: int = 3072, num_layers: int = 12, num_heads: int = 12, - patch_embed: str = "conv", - pos_embed: str = "learnable", + pos_embed: str = "conv", + proj_type: str = "conv", + pos_embed_type: str = "learnable", classification: bool = False, num_classes: int = 2, dropout_rate: float = 0.0, @@ -58,8 +66,8 @@ def __init__( mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072. num_layers (int, optional): number of transformer blocks. Defaults to 12. num_heads (int, optional): number of attention heads. Defaults to 12. - patch_embed (str, optional): patch embedding layer type. Defaults to "conv". - pos_embed (str, optional): position embedding type. Defaults to "learnable". + proj_type (str, optional): patch embedding layer type. Defaults to "conv". + pos_embed_type (str, optional): position embedding type. Defaults to "learnable". classification (bool, optional): bool argument to determine if classification is used. Defaults to False. num_classes (int, optional): number of classes if classification is used. Defaults to 2. dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0. @@ -69,17 +77,18 @@ def __init__( Set to other values to remove this function. qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False. save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False. - + .. deprecated:: 1.4 + ``pos_embed`` is deprecated in favor of ``proj_type``. Examples:: # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone - >>> net = ViT(in_channels=1, img_size=(96,96,96), patch_embed='conv', pos_embed='sincos') + >>> net = ViT(in_channels=1, img_size=(96,96,96), proj_type='conv', pos_embed_type='sincos') # for 3-channel with image size of (128,128,128), 24 layers and classification backbone - >>> net = ViT(in_channels=3, img_size=(128,128,128), patch_embed='conv', pos_embed='sincos', classification=True) + >>> net = ViT(in_channels=3, img_size=(128,128,128), proj_type='conv', pos_embed_type='sincos', classification=True) # for 3-channel with image size of (224,224), 12 layers and classification backbone - >>> net = ViT(in_channels=3, img_size=(224,224), patch_embed='conv', pos_embed='sincos', classification=True, + >>> net = ViT(in_channels=3, img_size=(224,224), proj_type='conv', pos_embed_type='sincos', classification=True, >>> spatial_dims=2) """ @@ -99,8 +108,8 @@ def __init__( patch_size=patch_size, hidden_size=hidden_size, num_heads=num_heads, - patch_embed=patch_embed, - pos_embed=pos_embed, + proj_type=proj_type, + pos_embed_type=pos_embed_type, dropout_rate=dropout_rate, spatial_dims=spatial_dims, ) diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py index c1808072b0..eef0a25418 100644 --- a/monai/networks/nets/vitautoenc.py +++ b/monai/networks/nets/vitautoenc.py @@ -20,7 +20,7 @@ from monai.networks.blocks.patchembedding import PatchEmbeddingBlock from monai.networks.blocks.transformerblock import TransformerBlock from monai.networks.layers import Conv -from monai.utils import ensure_tuple_rep, is_sqrt +from monai.utils import deprecated_arg, ensure_tuple_rep, is_sqrt __all__ = ["ViTAutoEnc"] @@ -32,7 +32,12 @@ class ViTAutoEnc(nn.Module): Modified to also give same dimension outputs as the input size of the image """ - + @deprecated_arg( + name="pos_embed", + since="1.4", + new_name="proj_type", + msg_suffix="please use `proj_type` instead.", + ) def __init__( self, in_channels: int, @@ -44,7 +49,8 @@ def __init__( mlp_dim: int = 3072, num_layers: int = 12, num_heads: int = 12, - patch_embed: str = "conv", + pos_embed: str = "conv", + proj_type: str = "conv", dropout_rate: float = 0.0, spatial_dims: int = 3, qkv_bias: bool = False, @@ -61,20 +67,21 @@ def __init__( mlp_dim: dimension of feedforward layer. Defaults to 3072. num_layers: number of transformer blocks. Defaults to 12. num_heads: number of attention heads. Defaults to 12. - patch_embed: position embedding layer type. Defaults to "conv". + proj_type: position embedding layer type. Defaults to "conv". dropout_rate: faction of the input units to drop. Defaults to 0.0. spatial_dims: number of spatial dimensions. Defaults to 3. qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False. save_attn: to make accessible the attention in self attention block. Defaults to False. Defaults to False. - + .. deprecated:: 1.4 + ``pos_embed`` is deprecated in favor of ``proj_type``. Examples:: # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone # It will provide an output of same size as that of the input - >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), patch_embed='conv') + >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), proj_type='conv') # for 3-channel with image size of (128,128,128), output will be same size as of input - >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), patch_embed='conv') + >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), proj_type='conv') """ @@ -94,7 +101,7 @@ def __init__( patch_size=patch_size, hidden_size=hidden_size, num_heads=num_heads, - patch_embed=patch_embed, + proj_type=proj_type, dropout_rate=dropout_rate, spatial_dims=self.spatial_dims, ) diff --git a/requirements-dev.txt b/requirements-dev.txt index 0761a53326..c242a61cff 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -55,5 +55,4 @@ typeguard<3 # https://github.com/microsoft/nni/issues/5457 filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523 zarr lpips==0.1.4 -nvidia-ml-py -timm>=0.9.7 \ No newline at end of file +nvidia-ml-py \ No newline at end of file diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py index 91e0a86553..ba4e5c0b99 100644 --- a/tests/test_patchembedding.py +++ b/tests/test_patchembedding.py @@ -31,8 +31,8 @@ for img_size in [32, 64]: for patch_size in [8, 16]: for num_heads in [8, 12]: - for patch_embed in ["conv", "perceptron"]: - for pos_embed in ["none", "learnable", "sincos"]: + for proj_type in ["conv", "perceptron"]: + for pos_embed_type in ["none", "learnable", "sincos"]: # for classification in (False, True): # TODO: add classification tests for nd in (2, 3): test_case = [ @@ -42,8 +42,8 @@ "patch_size": (patch_size,) * nd, "hidden_size": hidden_size, "num_heads": num_heads, - "patch_embed": patch_embed, - "pos_embed": pos_embed, + "pos_embed": proj_type, + "pos_embed_type": pos_embed_type, "dropout_rate": dropout_rate, }, (2, in_channels, *([img_size] * nd)), @@ -98,8 +98,8 @@ def test_ill_arg(self): patch_size=(16, 16, 16), hidden_size=128, num_heads=12, - patch_embed="conv", - pos_embed="sincos", + pos_embed="conv", + pos_embed_type="sincos", dropout_rate=5.0, ) @@ -110,8 +110,8 @@ def test_ill_arg(self): patch_size=(64, 64, 64), hidden_size=512, num_heads=8, - patch_embed="perceptron", - pos_embed="sincos", + pos_embed="perceptron", + pos_embed_type="sincos", dropout_rate=0.3, ) @@ -122,7 +122,7 @@ def test_ill_arg(self): patch_size=(8, 8, 8), hidden_size=512, num_heads=14, - patch_embed="conv", + pos_embed="conv", dropout_rate=0.3, ) @@ -133,7 +133,7 @@ def test_ill_arg(self): patch_size=(4, 4, 4), hidden_size=768, num_heads=8, - patch_embed="perceptron", + pos_embed="perceptron", dropout_rate=0.3, ) @@ -144,7 +144,7 @@ def test_ill_arg(self): patch_size=(16, 16, 16), hidden_size=768, num_heads=12, - patch_embed="perc", + pos_embed="perc", dropout_rate=0.3, ) From 216ccc064b587a38021bd5412c169aad6365d9bb Mon Sep 17 00:00:00 2001 From: NoTody Date: Fri, 15 Sep 2023 05:15:55 -0400 Subject: [PATCH 07/17] remove timm in requirements-dev dependencies Signed-off-by: NoTody --- requirements-dev.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 1c902b867e..c242a61cff 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -55,5 +55,4 @@ typeguard<3 # https://github.com/microsoft/nni/issues/5457 filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523 zarr lpips==0.1.4 -nvidia-ml-py -timm>=0.9.7 +nvidia-ml-py \ No newline at end of file From 342fe83ec0e86bd03f9228f1d82f83cf5aee6a4f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Sep 2023 09:17:55 +0000 Subject: [PATCH 08/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index c242a61cff..9620ea253d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -55,4 +55,4 @@ typeguard<3 # https://github.com/microsoft/nni/issues/5457 filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523 zarr lpips==0.1.4 -nvidia-ml-py \ No newline at end of file +nvidia-ml-py From 35c12ed6cac127a145023cd02b64214e62cdc54c Mon Sep 17 00:00:00 2001 From: NoTody Date: Fri, 15 Sep 2023 06:01:41 -0400 Subject: [PATCH 09/17] lint formatting. change conflict variables in patchembedding. Signed-off-by: NoTody --- monai/networks/blocks/patchembedding.py | 32 +++++++++++------------- monai/networks/blocks/pos_embed_utils.py | 8 +++--- monai/networks/nets/unetr.py | 8 ++---- monai/networks/nets/vit.py | 9 ++----- monai/networks/nets/vitautoenc.py | 8 ++---- 5 files changed, 25 insertions(+), 40 deletions(-) diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 29d513cc8f..0d8e7212d3 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -41,12 +41,8 @@ class PatchEmbeddingBlock(nn.Module): >>> proj_type="conv", pos_embed_type="sincos") """ - @deprecated_arg( - name="pos_embed", - since="1.4", - new_name="proj_type", - msg_suffix="please use `proj_type` instead.", - ) + + @deprecated_arg(name="pos_embed", since="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead.") def __init__( self, in_channels: int, @@ -83,15 +79,15 @@ def __init__( if hidden_size % num_heads != 0: raise ValueError(f"hidden size {hidden_size} should be divisible by num_heads {num_heads}.") - self.patch_embed = look_up_option(proj_type, SUPPORTED_PATCH_EMBEDDING_TYPES) - self.pos_embed = look_up_option(pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES) + self.proj_type = look_up_option(proj_type, SUPPORTED_PATCH_EMBEDDING_TYPES) + self.pos_embed_type = look_up_option(pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES) img_size = ensure_tuple_rep(img_size, spatial_dims) patch_size = ensure_tuple_rep(patch_size, spatial_dims) for m, p in zip(img_size, patch_size): if m < p: raise ValueError("patch_size should be smaller than img_size.") - if self.patch_embed == "perceptron" and m % p != 0: + if self.proj_type == "perceptron" and m % p != 0: raise ValueError("patch_size should be divisible by img_size for perceptron.") self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)]) self.patch_dim = int(in_channels * np.prod(patch_size)) @@ -102,11 +98,11 @@ def __init__( grid_size.append(in_size // pa_size) self.patch_embeddings: nn.Module - if self.patch_embed == "conv": + if self.proj_type == "conv": self.patch_embeddings = Conv[Conv.CONV, spatial_dims]( in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size ) - elif self.patch_embed == "perceptron": + elif self.proj_type == "perceptron": # for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)" chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims] from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars) @@ -118,16 +114,16 @@ def __init__( self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) self.dropout = nn.Dropout(dropout_rate) - if pos_embed_type == "none": + if self.pos_embed_type == "none": pass - elif pos_embed_type == "learnable": + elif self.pos_embed_type == "learnable": trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) - elif pos_embed_type == "sincos": + elif self.pos_embed_type == "sincos": with torch.no_grad(): - pos_embed = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims) - self.position_embeddings.data.copy_(pos_embed.float()) + pos_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims) + self.position_embeddings.data.copy_(pos_embeddings.float()) else: - raise ValueError(f"pos_embed_type {pos_embed_type} not supported.") + raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.") self.apply(self._init_weights) @@ -142,7 +138,7 @@ def _init_weights(self, m): def forward(self, x): x = self.patch_embeddings(x) - if self.patch_embed == "conv": + if self.proj_type == "conv": x = x.flatten(2).transpose(-1, -2) embeddings = x + self.position_embeddings embeddings = self.dropout(embeddings) diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py index 813b477ae1..111ddbafb8 100644 --- a/monai/networks/blocks/pos_embed_utils.py +++ b/monai/networks/blocks/pos_embed_utils.py @@ -11,24 +11,26 @@ from __future__ import annotations +import collections.abc +from itertools import repeat from typing import Optional import torch import torch.nn as nn -from itertools import repeat -import collections.abc - __all__ = ["build_sincos_position_embedding"] + # From PyTorch internals def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): return tuple(x) return tuple(repeat(x, n)) + return parse + def build_sincos_position_embedding( grid_size: Optional[int], embed_dim: int, spatial_dims: int = 3, temperature: float = 10000.0 ) -> torch.nn.Parameter: diff --git a/monai/networks/nets/unetr.py b/monai/networks/nets/unetr.py index c588fa7c5a..46dd474606 100644 --- a/monai/networks/nets/unetr.py +++ b/monai/networks/nets/unetr.py @@ -26,12 +26,8 @@ class UNETR(nn.Module): UNETR based on: "Hatamizadeh et al., UNETR: Transformers for 3D Medical Image Segmentation " """ - @deprecated_arg( - name="pos_embed", - since="1.4", - new_name="proj_type", - msg_suffix="please use `proj_type` instead.", - ) + + @deprecated_arg(name="pos_embed", since="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead.") def __init__( self, in_channels: int, diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py index 7bf4f46512..7639dd2f58 100644 --- a/monai/networks/nets/vit.py +++ b/monai/networks/nets/vit.py @@ -18,7 +18,6 @@ from monai.networks.blocks.patchembedding import PatchEmbeddingBlock from monai.networks.blocks.transformerblock import TransformerBlock - from monai.utils import deprecated_arg __all__ = ["ViT"] @@ -31,12 +30,8 @@ class ViT(nn.Module): ViT supports Torchscript but only works for Pytorch after 1.8. """ - @deprecated_arg( - name="pos_embed", - since="1.4", - new_name="proj_type", - msg_suffix="please use `proj_type` instead.", - ) + + @deprecated_arg(name="pos_embed", since="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead.") def __init__( self, in_channels: int, diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py index eef0a25418..222160e7fd 100644 --- a/monai/networks/nets/vitautoenc.py +++ b/monai/networks/nets/vitautoenc.py @@ -32,12 +32,8 @@ class ViTAutoEnc(nn.Module): Modified to also give same dimension outputs as the input size of the image """ - @deprecated_arg( - name="pos_embed", - since="1.4", - new_name="proj_type", - msg_suffix="please use `proj_type` instead.", - ) + + @deprecated_arg(name="pos_embed", since="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead.") def __init__( self, in_channels: int, From 259313cc2ce430eb41628e3e99a87d5ae14b5db2 Mon Sep 17 00:00:00 2001 From: NoTody Date: Fri, 15 Sep 2023 06:24:05 -0400 Subject: [PATCH 10/17] add deprecation doc indent Signed-off-by: NoTody --- monai/networks/nets/unetr.py | 2 ++ monai/networks/nets/vit.py | 2 ++ monai/networks/nets/vitautoenc.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/monai/networks/nets/unetr.py b/monai/networks/nets/unetr.py index 46dd474606..c242552823 100644 --- a/monai/networks/nets/unetr.py +++ b/monai/networks/nets/unetr.py @@ -64,8 +64,10 @@ def __init__( spatial_dims: number of spatial dims. Defaults to 3. qkv_bias: apply the bias term for the qkv linear layer in self attention block. Defaults to False. save_attn: to make accessible the attention in self attention block. Defaults to False. + .. deprecated:: 1.4 ``pos_embed`` is deprecated in favor of ``proj_type``. + Examples:: # for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py index 7639dd2f58..2153673676 100644 --- a/monai/networks/nets/vit.py +++ b/monai/networks/nets/vit.py @@ -72,8 +72,10 @@ def __init__( Set to other values to remove this function. qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False. save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False. + .. deprecated:: 1.4 ``pos_embed`` is deprecated in favor of ``proj_type``. + Examples:: # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py index 222160e7fd..ece5cb9e14 100644 --- a/monai/networks/nets/vitautoenc.py +++ b/monai/networks/nets/vitautoenc.py @@ -68,8 +68,10 @@ def __init__( spatial_dims: number of spatial dimensions. Defaults to 3. qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False. save_attn: to make accessible the attention in self attention block. Defaults to False. Defaults to False. + .. deprecated:: 1.4 ``pos_embed`` is deprecated in favor of ``proj_type``. + Examples:: # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone From 64d911fbfae5acda816a2c298acc39aeb6baf123 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Sep 2023 10:24:31 +0000 Subject: [PATCH 11/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/unetr.py | 4 ++-- monai/networks/nets/vit.py | 4 ++-- monai/networks/nets/vitautoenc.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/networks/nets/unetr.py b/monai/networks/nets/unetr.py index c242552823..6ee5e1a9d4 100644 --- a/monai/networks/nets/unetr.py +++ b/monai/networks/nets/unetr.py @@ -64,10 +64,10 @@ def __init__( spatial_dims: number of spatial dims. Defaults to 3. qkv_bias: apply the bias term for the qkv linear layer in self attention block. Defaults to False. save_attn: to make accessible the attention in self attention block. Defaults to False. - + .. deprecated:: 1.4 ``pos_embed`` is deprecated in favor of ``proj_type``. - + Examples:: # for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py index 2153673676..6968cdcef4 100644 --- a/monai/networks/nets/vit.py +++ b/monai/networks/nets/vit.py @@ -72,10 +72,10 @@ def __init__( Set to other values to remove this function. qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False. save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False. - + .. deprecated:: 1.4 ``pos_embed`` is deprecated in favor of ``proj_type``. - + Examples:: # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py index ece5cb9e14..9c9f13c363 100644 --- a/monai/networks/nets/vitautoenc.py +++ b/monai/networks/nets/vitautoenc.py @@ -68,10 +68,10 @@ def __init__( spatial_dims: number of spatial dimensions. Defaults to 3. qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False. save_attn: to make accessible the attention in self attention block. Defaults to False. Defaults to False. - + .. deprecated:: 1.4 ``pos_embed`` is deprecated in favor of ``proj_type``. - + Examples:: # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone From 1f8dcef90d455996ecec0b026ca0af64bcd5ee58 Mon Sep 17 00:00:00 2001 From: NoTody Date: Fri, 15 Sep 2023 12:46:45 -0400 Subject: [PATCH 12/17] Fix typing definition in pos_embed_utils.py. Remove image/patch shape check. Signed-off-by: NoTody --- monai/networks/blocks/patchembedding.py | 9 ++++----- monai/networks/blocks/pos_embed_utils.py | 14 +++++++------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 0d8e7212d3..43f0090614 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -92,11 +92,6 @@ def __init__( self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)]) self.patch_dim = int(in_channels * np.prod(patch_size)) - grid_size = [] - for in_size, pa_size in zip(img_size, patch_size): - assert in_size % pa_size == 0, "input size and patch size are not proper" - grid_size.append(in_size // pa_size) - self.patch_embeddings: nn.Module if self.proj_type == "conv": self.patch_embeddings = Conv[Conv.CONV, spatial_dims]( @@ -119,6 +114,10 @@ def __init__( elif self.pos_embed_type == "learnable": trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) elif self.pos_embed_type == "sincos": + grid_size = [] + for in_size, pa_size in zip(img_size, patch_size): + grid_size.append(in_size // pa_size) + with torch.no_grad(): pos_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims) self.position_embeddings.data.copy_(pos_embeddings.float()) diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py index 111ddbafb8..e1f47cd7e9 100644 --- a/monai/networks/blocks/pos_embed_utils.py +++ b/monai/networks/blocks/pos_embed_utils.py @@ -13,7 +13,7 @@ import collections.abc from itertools import repeat -from typing import Optional +from typing import List, Union import torch import torch.nn as nn @@ -32,14 +32,14 @@ def parse(x): def build_sincos_position_embedding( - grid_size: Optional[int], embed_dim: int, spatial_dims: int = 3, temperature: float = 10000.0 + grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, temperature: float = 10000.0 ) -> torch.nn.Parameter: """ Builds a sin-cos position embedding based on the given grid size, embed dimension, spatial dimensions, and temperature. Reference: https://github.com/cvlab-stonybrook/SelfMedMAE/blob/68d191dfcc1c7d0145db93a6a570362de29e3b30/lib/models/mae3d.py Args: - grid_size (int or Tuple[int]): The size of the grid in each spatial dimension. + grid_size (List[int]): The size of the grid in each spatial dimension. embed_dim (int): The dimension of the embedding. spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D). temperature (float): The temperature for the sin-cos position embedding. @@ -50,8 +50,8 @@ def build_sincos_position_embedding( if spatial_dims == 2: to_2tuple = _ntuple(2) - grid_size = to_2tuple(grid_size) - h, w = grid_size + grid_size_t = to_2tuple(grid_size) + h, w = grid_size_t grid_h = torch.arange(h, dtype=torch.float32) grid_w = torch.arange(w, dtype=torch.float32) @@ -67,8 +67,8 @@ def build_sincos_position_embedding( pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :] elif spatial_dims == 3: to_3tuple = _ntuple(3) - grid_size = to_3tuple(grid_size) - h, w, d = grid_size + grid_size_t = to_3tuple(grid_size) + h, w, d = grid_size_t grid_h = torch.arange(h, dtype=torch.float32) grid_w = torch.arange(w, dtype=torch.float32) grid_d = torch.arange(d, dtype=torch.float32) From d148fc1768f314001c3db151e9adc53dbdec8530 Mon Sep 17 00:00:00 2001 From: NoTody Date: Fri, 15 Sep 2023 12:54:51 -0400 Subject: [PATCH 13/17] lint formatting Signed-off-by: NoTody --- monai/networks/blocks/patchembedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 43f0090614..903482ec0e 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -117,7 +117,7 @@ def __init__( grid_size = [] for in_size, pa_size in zip(img_size, patch_size): grid_size.append(in_size // pa_size) - + with torch.no_grad(): pos_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims) self.position_embeddings.data.copy_(pos_embeddings.float()) From 44fc76b5ac73803549e90257f15704ac394e1e8a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 16 Sep 2023 12:52:06 +0100 Subject: [PATCH 14/17] update deprecating ver. Signed-off-by: Wenqi Li --- monai/networks/blocks/patchembedding.py | 2 +- tests/test_patchembedding.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 903482ec0e..cb259b4ef5 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -42,7 +42,7 @@ class PatchEmbeddingBlock(nn.Module): """ - @deprecated_arg(name="pos_embed", since="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead.") + @deprecated_arg(name="pos_embed", since="1.2", new_name="proj_type", msg_suffix="please use `proj_type` instead.") def __init__( self, in_channels: int, diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py index ba4e5c0b99..18f89d8ea1 100644 --- a/tests/test_patchembedding.py +++ b/tests/test_patchembedding.py @@ -136,6 +136,16 @@ def test_ill_arg(self): pos_embed="perceptron", dropout_rate=0.3, ) + with self.assertRaises(ValueError): + PatchEmbeddingBlock( + in_channels=1, + img_size=(97, 97, 97), + patch_size=(4, 4, 4), + hidden_size=768, + num_heads=8, + proj_type="perceptron", + dropout_rate=0.3, + ) with self.assertRaises(ValueError): PatchEmbeddingBlock( From 7ca8f87c530e91d4f729fd93095635bb3ba5a0c7 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 16 Sep 2023 15:06:21 +0100 Subject: [PATCH 15/17] update deprecating ver. Signed-off-by: Wenqi Li --- monai/networks/nets/unetr.py | 2 +- monai/networks/nets/vit.py | 2 +- monai/networks/nets/vitautoenc.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/networks/nets/unetr.py b/monai/networks/nets/unetr.py index 6ee5e1a9d4..bfcd6e7d47 100644 --- a/monai/networks/nets/unetr.py +++ b/monai/networks/nets/unetr.py @@ -27,7 +27,7 @@ class UNETR(nn.Module): UNETR: Transformers for 3D Medical Image Segmentation " """ - @deprecated_arg(name="pos_embed", since="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead.") + @deprecated_arg(name="pos_embed", since="1.2", new_name="proj_type", msg_suffix="please use `proj_type` instead.") def __init__( self, in_channels: int, diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py index 6968cdcef4..85f428afc1 100644 --- a/monai/networks/nets/vit.py +++ b/monai/networks/nets/vit.py @@ -31,7 +31,7 @@ class ViT(nn.Module): ViT supports Torchscript but only works for Pytorch after 1.8. """ - @deprecated_arg(name="pos_embed", since="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead.") + @deprecated_arg(name="pos_embed", since="1.2", new_name="proj_type", msg_suffix="please use `proj_type` instead.") def __init__( self, in_channels: int, diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py index 9c9f13c363..6063b192a4 100644 --- a/monai/networks/nets/vitautoenc.py +++ b/monai/networks/nets/vitautoenc.py @@ -33,7 +33,7 @@ class ViTAutoEnc(nn.Module): Modified to also give same dimension outputs as the input size of the image """ - @deprecated_arg(name="pos_embed", since="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead.") + @deprecated_arg(name="pos_embed", since="1.2", new_name="proj_type", msg_suffix="please use `proj_type` instead.") def __init__( self, in_channels: int, From 3ac01ad1bca38a09ca832f54de46fe755e34a660 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 16 Sep 2023 15:08:52 +0100 Subject: [PATCH 16/17] fixes typo Signed-off-by: Wenqi Li --- monai/networks/blocks/mlp.py | 2 +- monai/networks/blocks/patchembedding.py | 2 +- monai/networks/blocks/selfattention.py | 2 +- monai/networks/blocks/transformerblock.py | 2 +- monai/networks/nets/transchex.py | 2 +- monai/networks/nets/vit.py | 2 +- monai/networks/nets/vitautoenc.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py index e3ab94b32a..d3510b64d3 100644 --- a/monai/networks/blocks/mlp.py +++ b/monai/networks/blocks/mlp.py @@ -32,7 +32,7 @@ def __init__( Args: hidden_size: dimension of hidden layer. mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used. - dropout_rate: faction of the input units to drop. + dropout_rate: fraction of the input units to drop. act: activation type and arguments. Defaults to GELU. Also supports "GEGLU" and others. dropout_mode: dropout mode, can be "vit" or "swin". "vit" mode uses two dropout instances as implemented in diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index cb259b4ef5..f6d390692e 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -65,7 +65,7 @@ def __init__( num_heads: number of attention heads. proj_type: patch embedding layer type. pos_embed_type: position embedding layer type. - dropout_rate: faction of the input units to drop. + dropout_rate: fraction of the input units to drop. spatial_dims: number of spatial dimensions. .. deprecated:: 1.4 ``pos_embed`` is deprecated in favor of ``proj_type``. diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 71fb549db8..7c81c1704f 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -37,7 +37,7 @@ def __init__( Args: hidden_size (int): dimension of hidden layer. num_heads (int): number of attention heads. - dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0. + dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index f7d4e0e130..ddf959dad2 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -37,7 +37,7 @@ def __init__( hidden_size (int): dimension of hidden layer. mlp_dim (int): dimension of feedforward layer. num_heads (int): number of attention heads. - dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0. + dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. diff --git a/monai/networks/nets/transchex.py b/monai/networks/nets/transchex.py index 31e27ffbf2..c73415b63e 100644 --- a/monai/networks/nets/transchex.py +++ b/monai/networks/nets/transchex.py @@ -314,7 +314,7 @@ def __init__( num_language_layers: number of language transformer layers. num_vision_layers: number of vision transformer layers. num_mixed_layers: number of mixed transformer layers. - drop_out: faction of the input units to drop. + drop_out: fraction of the input units to drop. The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`. diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py index 85f428afc1..f033d7ff4a 100644 --- a/monai/networks/nets/vit.py +++ b/monai/networks/nets/vit.py @@ -65,7 +65,7 @@ def __init__( pos_embed_type (str, optional): position embedding type. Defaults to "learnable". classification (bool, optional): bool argument to determine if classification is used. Defaults to False. num_classes (int, optional): number of classes if classification is used. Defaults to 2. - dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0. + dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. spatial_dims (int, optional): number of spatial dimensions. Defaults to 3. post_activation (str, optional): add a final acivation function to the classification head when `classification` is True. Default to "Tanh" for `nn.Tanh()`. diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py index 6063b192a4..59aae2d54a 100644 --- a/monai/networks/nets/vitautoenc.py +++ b/monai/networks/nets/vitautoenc.py @@ -64,7 +64,7 @@ def __init__( num_layers: number of transformer blocks. Defaults to 12. num_heads: number of attention heads. Defaults to 12. proj_type: position embedding layer type. Defaults to "conv". - dropout_rate: faction of the input units to drop. Defaults to 0.0. + dropout_rate: fraction of the input units to drop. Defaults to 0.0. spatial_dims: number of spatial dimensions. Defaults to 3. qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False. save_attn: to make accessible the attention in self attention block. Defaults to False. Defaults to False. From 3be960481e1c3445614a434d3f049391aced089c Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 16 Sep 2023 15:54:05 +0100 Subject: [PATCH 17/17] skip TypeError: meshgrid() got an unexpected keyword argument 'indexing' Signed-off-by: Wenqi Li --- tests/test_patchembedding.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py index 18f89d8ea1..77ade984eb 100644 --- a/tests/test_patchembedding.py +++ b/tests/test_patchembedding.py @@ -21,6 +21,7 @@ from monai.networks import eval_mode from monai.networks.blocks.patchembedding import PatchEmbed, PatchEmbeddingBlock from monai.utils import optional_import +from tests.utils import SkipIfBeforePyTorchVersion einops, has_einops = optional_import("einops") @@ -74,6 +75,7 @@ TEST_CASE_PATCHEMBED.append(test_case) +@SkipIfBeforePyTorchVersion((1, 11, 1)) class TestPatchEmbeddingBlock(unittest.TestCase): def setUp(self): self.threads = torch.get_num_threads()