1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -19,3 +19,4 @@ sphinxcontrib-qthelp
sphinxcontrib-serializinghtml
sphinx-autodoc-typehints==1.11.1
pandas
einops
4 changes: 2 additions & 2 deletions docs/source/installation.md
@@ -174,9 +174,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is

- The options are
```
[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas]
[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops]
```
which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`,
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim` `openslide-python` and `pandas`, respectively.
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas` and `einops`, respectively.

- `pip install 'monai[all]'` installs all the optional dependencies.
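For example (a sketch following the extras syntax above; the listed options can also be combined):
```
pip install 'monai[einops]'           # install MONAI with only the einops extra
pip install 'monai[nibabel,einops]'   # install MONAI with several extras at once
```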
39 changes: 39 additions & 0 deletions docs/source/networks.rst
@@ -79,11 +79,30 @@ Blocks
.. autoclass:: ResBlock
:members:

`SABlock Block`
~~~~~~~~~~~~~~~~~
.. autoclass:: SABlock
:members:

`Squeeze-and-Excitation`
~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: ChannelSELayer
:members:

`Transformer Block`
~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: TransformerBlock
:members:

`UNETR Block`
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: UnetrBasicBlock
:members:
.. autoclass:: UnetrUpBlock
:members:
.. autoclass:: UnetrPrUpBlock
:members:

`Residual Squeeze-and-Excitation`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: ResidualSELayer
@@ -159,6 +178,16 @@ Blocks
.. autoclass:: LocalNetFeatureExtractorBlock
:members:

`MLP Block`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: MLPBlock
:members:

`Patch Embedding Block`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: PatchEmbeddingBlock
:members:

`Warp`
~~~~~~
.. autoclass:: Warp
@@ -389,6 +418,11 @@ Nets
.. autoclass:: Unet
.. autoclass:: unet

`UNETR`
~~~~~~~~~~~~~~~~~
.. autoclass:: UNETR
:members:

`BasicUNet`
~~~~~~~~~~~
.. autoclass:: BasicUNet
@@ -426,6 +460,11 @@ Nets
.. autoclass:: VarAutoEncoder
:members:

`ViT`
~~~~~~
.. autoclass:: ViT
:members:

`FullyConnectedNet`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: FullyConnectedNet
1 change: 1 addition & 0 deletions monai/config/deviceconfig.py
@@ -81,6 +81,7 @@ def get_optional_config_values():
output["lmdb"] = get_package_version("lmdb")
output["psutil"] = psutil_version
output["pandas"] = get_package_version("pandas")
output["einops"] = get_package_version("einops")

return output

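With this change, MONAI's optional-dependency report also covers `einops`. A minimal check (sketch; the printed value depends on whether `einops` is installed):

```python
# Sketch: inspect the optional-dependency versions reported by MONAI,
# which now include einops alongside lmdb, psutil, pandas, etc.
from monai.config.deviceconfig import get_optional_config_values

versions = get_optional_config_values()
print(versions.get("einops"))  # an installed version string, or a "not installed" marker
```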
5 changes: 5 additions & 0 deletions monai/networks/blocks/__init__.py
@@ -18,8 +18,11 @@
from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding
from .fcn import FCN, GCN, MCFCN, Refine
from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock
from .mlp import MLPBlock
from .patchembedding import PatchEmbeddingBlock
from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock
from .segresnet_block import ResBlock
from .selfattention import SABlock
from .squeeze_and_excitation import (
ChannelSELayer,
ResidualSELayer,
@@ -28,5 +31,7 @@
SEResNetBottleneck,
SEResNeXtBottleneck,
)
from .transformerblock import TransformerBlock
from .unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
from .upsample import SubpixelUpsample, Subpixelupsample, SubpixelUpSample, Upsample, UpSample
from .warp import DVF2DDF, Warp
51 changes: 51 additions & 0 deletions monai/networks/blocks/mlp.py
@@ -0,0 +1,51 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch.nn as nn


class MLPBlock(nn.Module):
"""
A multi-layer perceptron block, based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
"""

def __init__(
self,
hidden_size: int,
mlp_dim: int,
dropout_rate: float = 0.0,
) -> None:
"""
Args:
hidden_size: dimension of hidden layer.
mlp_dim: dimension of feedforward layer.
            dropout_rate: fraction of the input units to drop.

"""

super().__init__()

if not (0 <= dropout_rate <= 1):
raise AssertionError("dropout_rate should be between 0 and 1.")

self.linear1 = nn.Linear(hidden_size, mlp_dim)
self.linear2 = nn.Linear(mlp_dim, hidden_size)
self.fn = nn.GELU()
self.drop1 = nn.Dropout(dropout_rate)
self.drop2 = nn.Dropout(dropout_rate)

def forward(self, x):
x = self.fn(self.linear1(x))
[Contributor review comment on this line: might be more robust if there's a torch.flatten as the first step?]
x = self.drop1(x)
x = self.linear2(x)
x = self.drop2(x)
return x
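A minimal usage sketch for the block above (`MLPBlock` is re-exported from `monai.networks.blocks` per the `__init__.py` change above; the tensor shape is an illustrative assumption):

```python
# Sketch: run MLPBlock over a batch of token embeddings.
import torch

from monai.networks.blocks import MLPBlock

block = MLPBlock(hidden_size=768, mlp_dim=3072, dropout_rate=0.1)
x = torch.randn(2, 196, 768)  # (batch, tokens, hidden_size) -- illustrative shape
out = block(x)                # linear -> GELU -> dropout -> linear -> dropout
print(out.shape)              # torch.Size([2, 196, 768])
```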
138 changes: 138 additions & 0 deletions monai/networks/blocks/patchembedding.py
@@ -0,0 +1,138 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
from typing import Tuple, Union

import torch
import torch.nn as nn

from monai.utils import optional_import

einops, has_einops = optional_import("einops")


class PatchEmbeddingBlock(nn.Module):
"""
A patch embedding block, based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
"""

def __init__(
self,
in_channels: int,
img_size: Union[int, Tuple[int, int, int]],
patch_size: Union[int, Tuple[int, int, int]],
hidden_size: int,
num_heads: int,
pos_embed: Union[Tuple, str], # type: ignore
classification: bool,
dropout_rate: float = 0.0,
) -> None:
"""
Args:
in_channels: dimension of input channels.
img_size: dimension of input image.
patch_size: dimension of patch size.
hidden_size: dimension of hidden layer.
num_heads: number of attention heads.
pos_embed: position embedding layer type.
classification: bool argument to determine if classification is used.
            dropout_rate: fraction of the input units to drop.

"""

super().__init__()

if not (0 <= dropout_rate <= 1):
raise AssertionError("dropout_rate should be between 0 and 1.")

if hidden_size % num_heads != 0:
raise AssertionError("hidden size should be divisible by num_heads.")

if img_size < patch_size: # type: ignore
raise AssertionError("patch_size should be smaller than img_size.")

if pos_embed not in ["conv", "perceptron"]:
raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")

if pos_embed == "perceptron":
if img_size[0] % patch_size[0] != 0: # type: ignore
raise AssertionError("img_size should be divisible by patch_size for perceptron patch embedding.")

if has_einops: # type: ignore
from einops.layers.torch import Rearrange # type: ignore

self.Rearrange = Rearrange # type: ignore
else:
            raise ValueError("Requires einops.")

self.n_patches = (
(img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) * (img_size[2] // patch_size[2]) # type: ignore
)
self.patch_dim = in_channels * patch_size[0] * patch_size[1] * patch_size[2] # type: ignore
self.pos_embed = pos_embed
if self.pos_embed == "conv":
self.patch_embeddings = nn.Conv3d(
in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size # type: ignore
)
elif self.pos_embed == "perceptron":
self.patch_embeddings = nn.Sequential( # type: ignore
self.Rearrange(
"b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)",
p1=patch_size[0], # type: ignore
p2=patch_size[1], # type: ignore
p3=patch_size[2], # type: ignore
),
nn.Linear(self.patch_dim, hidden_size),
)
self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
self.dropout = nn.Dropout(dropout_rate)
self.trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
self.apply(self._init_weights)

def _init_weights(self, m):
if isinstance(m, nn.Linear):
self.trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)

def trunc_normal_(self, tensor, mean, std, a, b):
# From PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

with torch.no_grad():
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
tensor.uniform_(2 * l - 1, 2 * u - 1)
tensor.erfinv_()
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
tensor.clamp_(min=a, max=b)
return tensor

def forward(self, x):
if self.pos_embed == "conv":
x = self.patch_embeddings(x)
x = x.flatten(2)
x = x.transpose(-1, -2)
elif self.pos_embed == "perceptron":
x = self.patch_embeddings(x)
embeddings = x + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
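A minimal usage sketch for the block above, using the convolutional projection on a 3D volume (shapes are illustrative assumptions; `einops` must be installed):

```python
# Sketch: embed a 3D volume into a sequence of patch tokens.
import torch

from monai.networks.blocks import PatchEmbeddingBlock

embed = PatchEmbeddingBlock(
    in_channels=1,
    img_size=(96, 96, 96),
    patch_size=(16, 16, 16),
    hidden_size=768,
    num_heads=12,
    pos_embed="conv",  # "perceptron" also works when img_size is divisible by patch_size
    classification=False,
    dropout_rate=0.0,
)
x = torch.randn(2, 1, 96, 96, 96)  # (batch, channels, H, W, D)
tokens = embed(x)
print(tokens.shape)                # torch.Size([2, 216, 768]); 216 = (96 // 16) ** 3 patches
```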
68 changes: 68 additions & 0 deletions monai/networks/blocks/selfattention.py
@@ -0,0 +1,68 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn

from monai.utils import optional_import

einops, has_einops = optional_import("einops")


class SABlock(nn.Module):
"""
A self-attention block, based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
"""

def __init__(
self,
hidden_size: int,
num_heads: int,
dropout_rate: float = 0.0,
) -> None:
"""
Args:
hidden_size: dimension of hidden layer.
num_heads: number of attention heads.
            dropout_rate: fraction of the input units to drop.

"""

super().__init__()

if not (0 <= dropout_rate <= 1):
raise AssertionError("dropout_rate should be between 0 and 1.")

if hidden_size % num_heads != 0:
raise AssertionError("hidden size should be divisible by num_heads.")

self.num_heads = num_heads
self.out_proj = nn.Linear(hidden_size, hidden_size)
self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim ** -0.5
if has_einops:
self.rearrange = einops.rearrange
else:
            raise ValueError("Requires einops.")

def forward(self, x):
q, k, v = self.rearrange(self.qkv(x), "b h (qkv l d) -> qkv b l h d", qkv=3, l=self.num_heads)
att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.rearrange(x, "b h l d -> b l (h d)")
x = self.out_proj(x)
x = self.drop_output(x)
return x
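A minimal usage sketch for the block above (the token count and hidden size are illustrative assumptions):

```python
# Sketch: apply multi-head self-attention to a token sequence.
import torch

from monai.networks.blocks import SABlock

attn = SABlock(hidden_size=768, num_heads=12, dropout_rate=0.1)
x = torch.randn(2, 216, 768)  # (batch, tokens, hidden_size)
out = attn(x)                 # qkv projection, scaled dot-product attention, output projection
print(out.shape)              # torch.Size([2, 216, 768])
```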