2 changes: 1 addition & 1 deletion docs/source/metrics.rst
@@ -30,5 +30,5 @@ Metrics
.. autofunction:: compute_average_surface_distance

`Occlusion sensitivity`
--------------------------
-----------------------
.. autofunction:: compute_occlusion_sensitivity
5 changes: 5 additions & 0 deletions docs/source/networks.rst
@@ -10,6 +10,11 @@ Blocks
.. automodule:: monai.networks.blocks
.. currentmodule:: monai.networks.blocks

`ADN`
~~~~~
.. autoclass:: ADN
:members:

`Convolution`
~~~~~~~~~~~~~
.. autoclass:: Convolution
1 change: 1 addition & 0 deletions monai/networks/blocks/__init__.py
@@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .acti_norm import ADN
from .aspp import SimpleASPP
from .convolutions import Convolution, ResidualUnit
from .downsample import MaxAvgPool
119 changes: 119 additions & 0 deletions monai/networks/blocks/acti_norm.py
@@ -0,0 +1,119 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

import torch.nn as nn

from monai.networks.layers.factories import Act, Dropout, Norm, split_args
from monai.utils import has_option


class ADN(nn.Sequential):
"""
Constructs a sequential module of optional activation, dropout, and normalization layers
(with an arbitrary order)::

-- (Norm) -- (Dropout) -- (Acti) --

Args:
ordering: a string representing the ordering of activation, dropout, and normalization. Defaults to "NDA".
in_channels: `C` from an expected input of size (N, C, H[, W, D]).
act: activation type and arguments. Defaults to PReLU.
norm: feature normalization type and arguments. Defaults to instance norm.
norm_dim: determine the spatial dimensions of the normalization layer.
Defaults to `dropout_dim` if unspecified.
dropout: dropout ratio. Defaults to no dropout.
dropout_dim: determine the spatial dimensions of dropout.
Defaults to `norm_dim` if unspecified.

- When dropout_dim = 1, randomly zeroes some of the elements for each channel.
- When dropout_dim = 2, randomly zeroes out entire channels (a channel is a 2D feature map).
- When dropout_dim = 3, randomly zeroes out entire channels (a channel is a 3D feature map).

Examples::

# activation, group norm, dropout
>>> norm_params = ("GROUP", {"num_groups": 1, "affine": False})
>>> ADN(norm=norm_params, in_channels=1, dropout_dim=1, dropout=0.8, ordering="AND")
ADN(
(A): ReLU()
(N): GroupNorm(1, 1, eps=1e-05, affine=False)
(D): Dropout(p=0.8, inplace=False)
)

# LeakyReLU, dropout
>>> act_params = ("leakyrelu", {"negative_slope": 0.1, "inplace": True})
>>> ADN(act=act_params, in_channels=1, dropout_dim=1, dropout=0.8, ordering="AD")
ADN(
(A): LeakyReLU(negative_slope=0.1, inplace=True)
(D): Dropout(p=0.8, inplace=False)
)

See also:

:py:class:`monai.networks.layers.Dropout`
:py:class:`monai.networks.layers.Act`
:py:class:`monai.networks.layers.Norm`
:py:class:`monai.networks.layers.split_args`

"""

def __init__(
self,
ordering: str = "NDA",
in_channels: Optional[int] = None,
act: Optional[Union[Tuple, str]] = "RELU",
norm: Optional[Union[Tuple, str]] = None,
norm_dim: Optional[int] = None,
dropout: Optional[Union[Tuple, str, float]] = None,
dropout_dim: Optional[int] = None,
) -> None:
super().__init__()

op_dict = {"A": None, "D": None, "N": None}
# define the normalisation type and the arguments to the constructor
if norm is not None:
if norm_dim is None and dropout_dim is None:
raise ValueError("norm_dim or dropout_dim needs to be specified.")
norm_name, norm_args = split_args(norm)
norm_type = Norm[norm_name, norm_dim or dropout_dim]
kw_args = dict(norm_args)
if has_option(norm_type, "num_features") and "num_features" not in kw_args:
kw_args["num_features"] = in_channels
if has_option(norm_type, "num_channels") and "num_channels" not in kw_args:
kw_args["num_channels"] = in_channels
op_dict["N"] = norm_type(**kw_args)

# define the activation type and the arguments to the constructor
if act is not None:
act_name, act_args = split_args(act)
act_type = Act[act_name]
op_dict["A"] = act_type(**act_args)

if dropout is not None:
# if dropout was specified simply as a p value, use default name and make a keyword map with the value
if isinstance(dropout, (int, float)):
drop_name = Dropout.DROPOUT
drop_args = {"p": float(dropout)}
else:
drop_name, drop_args = split_args(dropout)

if norm_dim is None and dropout_dim is None:
raise ValueError("norm_dim or dropout_dim needs to be specified.")
drop_type = Dropout[drop_name, dropout_dim or norm_dim]
op_dict["D"] = drop_type(**drop_args)

for item in ordering.upper():
if item not in op_dict:
raise ValueError(f"ordering must be a string of {op_dict}, got {item} in it.")
if op_dict[item] is not None:
self.add_module(item, op_dict[item]) # type: ignore
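
As a usage sketch (illustrative only, not part of the changed files), the new ADN block can be exercised roughly as follows; the factory names and tensor sizes below are assumptions rather than values taken from the diff:

import torch
from monai.networks.blocks import ADN

x = torch.randn(4, 8, 32, 32)  # hypothetical 2D feature map: batch 4, 8 channels

# norm -> dropout -> activation, i.e. the default "NDA" ordering
adn = ADN(
    ordering="NDA",
    in_channels=8,
    act="relu",
    norm="instance",
    norm_dim=2,     # 2D normalization layer
    dropout=0.25,
    dropout_dim=1,  # element-wise dropout
)
print(adn)           # Sequential with children named "N", "D", "A"
print(adn(x).shape)  # torch.Size([4, 8, 32, 32])
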
96 changes: 41 additions & 55 deletions monai/networks/blocks/convolutions.py
@@ -15,15 +15,16 @@
import torch
import torch.nn as nn

from monai.networks.blocks import ADN
from monai.networks.layers.convutils import same_padding, stride_minus_kernel_padding
from monai.networks.layers.factories import Act, Conv, Dropout, Norm, split_args
from monai.networks.layers.factories import Conv


class Convolution(nn.Sequential):
"""
Constructs a convolution with normalization, optional dropout, and optional activation layers::

-- (Conv|ConvTrans) -- Norm -- (Dropout) -- (Acti) --
-- (Conv|ConvTrans) -- (Norm -- Dropout -- Acti) --

if ``conv_only`` set to ``True``::

@@ -35,14 +36,18 @@ class Convolution(nn.Sequential):
out_channels: number of output channels.
strides: convolution stride. Defaults to 1.
kernel_size: convolution kernel size. Defaults to 3.
adn_ordering: a string representing the ordering of activation, normalization, and dropout.
Defaults to "NDA".
act: activation type and arguments. Defaults to PReLU.
norm: feature normalization type and arguments. Defaults to instance norm.
dropout: dropout ratio. Defaults to no dropout.
dropout_dim: determine the dimensions of dropout. Defaults to 1.
When dropout_dim = 1, randomly zeroes some of the elements for each channel.
When dropout_dim = 2, Randomly zeroes out entire channels (a channel is a 2D feature map).
When dropout_dim = 3, Randomly zeroes out entire channels (a channel is a 3D feature map).
The value of dropout_dim should be no no larger than the value of dimensions.

- When dropout_dim = 1, randomly zeroes some of the elements for each channel.
- When dropout_dim = 2, randomly zeroes out entire channels (a channel is a 2D feature map).
- When dropout_dim = 3, randomly zeroes out entire channels (a channel is a 3D feature map).

The value of dropout_dim should be no larger than the value of `dimensions`.
dilation: dilation rate. Defaults to 1.
groups: controls the connections between inputs and outputs. Defaults to 1.
bias: whether to have a bias term. Defaults to True.
@@ -56,10 +61,7 @@
See also:

:py:class:`monai.networks.layers.Conv`
:py:class:`monai.networks.layers.Dropout`
:py:class:`monai.networks.layers.Act`
:py:class:`monai.networks.layers.Norm`
:py:class:`monai.networks.layers.split_args`
:py:class:`monai.networks.blocks.ADN`

"""

@@ -70,10 +72,11 @@ def __init__(
out_channels: int,
strides: Union[Sequence[int], int] = 1,
kernel_size: Union[Sequence[int], int] = 3,
act: Optional[Union[Tuple, str]] = Act.PRELU,
norm: Union[Tuple, str] = Norm.INSTANCE,
adn_ordering: str = "NDA",
act: Optional[Union[Tuple, str]] = "PRELU",
norm: Optional[Union[Tuple, str]] = "INSTANCE",
dropout: Optional[Union[Tuple, str, float]] = None,
dropout_dim: int = 1,
dropout_dim: Optional[int] = 1,
dilation: Union[Sequence[int], int] = 1,
groups: int = 1,
bias: bool = True,
@@ -90,33 +93,6 @@ def __init__(
if padding is None:
padding = same_padding(kernel_size, dilation)
conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, dimensions]
# define the normalisation type and the arguments to the constructor
if norm is not None:
norm_name, norm_args = split_args(norm)
norm_type = Norm[norm_name, dimensions]
else:
norm_type = norm_args = None

# define the activation type and the arguments to the constructor
if act is not None:
act_name, act_args = split_args(act)
act_type = Act[act_name]
else:
act_type = act_args = None

if dropout:
# if dropout was specified simply as a p value, use default name and make a keyword map with the value
if isinstance(dropout, (int, float)):
drop_name = Dropout.DROPOUT
drop_args = {"p": dropout}
else:
drop_name, drop_args = split_args(dropout)

if dropout_dim > dimensions:
raise ValueError(
f"dropout_dim should be no larger than dimensions, got dropout_dim={dropout_dim} and dimensions={dimensions}."
)
drop_type = Dropout[drop_name, dropout_dim]

if is_transposed:
if output_padding is None:
@@ -147,14 +123,18 @@ def __init__(
self.add_module("conv", conv)

if not conv_only:
if norm is not None:
self.add_module("norm", norm_type(out_channels, **norm_args))

if dropout:
self.add_module("dropout", drop_type(**drop_args))

if act is not None:
self.add_module("act", act_type(**act_args))
self.add_module(
"adn",
ADN(
ordering=adn_ordering,
in_channels=out_channels,
act=act,
norm=norm,
norm_dim=dimensions,
dropout=dropout,
dropout_dim=dropout_dim,
),
)


class ResidualUnit(nn.Module):
@@ -168,14 +148,18 @@ class ResidualUnit(nn.Module):
strides: convolution stride. Defaults to 1.
kernel_size: convolution kernel size. Defaults to 3.
subunits: number of convolutions. Defaults to 2.
adn_ordering: a string representing the ordering of activation, normalization, and dropout.
Defaults to "NDA".
act: activation type and arguments. Defaults to PReLU.
norm: feature normalization type and arguments. Defaults to instance norm.
dropout: dropout ratio. Defaults to no dropout.
dropout_dim: determine the dimensions of dropout. Defaults to 1.
When dropout_dim = 1, randomly zeroes some of the elements for each channel.
When dropout_dim = 2, Randomly zero out entire channels (a channel is a 2D feature map).
When dropout_dim = 3, Randomly zero out entire channels (a channel is a 3D feature map).
The value of dropout_dim should be no no larger than the value of dimensions.

- When dropout_dim = 1, randomly zeroes some of the elements for each channel.
- When dropout_dim = 2, randomly zeroes out entire channels (a channel is a 2D feature map).
- When dropout_dim = 3, randomly zeroes out entire channels (a channel is a 3D feature map).

The value of dropout_dim should be no larger than the value of `dimensions`.
dilation: dilation rate. Defaults to 1.
bias: whether to have a bias term. Defaults to True.
last_conv_only: for the last subunit, whether to use the convolutional layer only.
@@ -197,10 +181,11 @@ def __init__(
strides: Union[Sequence[int], int] = 1,
kernel_size: Union[Sequence[int], int] = 3,
subunits: int = 2,
act: Optional[Union[Tuple, str]] = Act.PRELU,
norm: Union[Tuple, str] = Norm.INSTANCE,
adn_ordering: str = "NDA",
act: Optional[Union[Tuple, str]] = "PRELU",
norm: Optional[Union[Tuple, str]] = "INSTANCE",
dropout: Optional[Union[Tuple, str, float]] = None,
dropout_dim: int = 1,
dropout_dim: Optional[int] = 1,
dilation: Union[Sequence[int], int] = 1,
bias: bool = True,
last_conv_only: bool = False,
@@ -226,6 +211,7 @@ def __init__(
out_channels,
strides=sstrides,
kernel_size=kernel_size,
adn_ordering=adn_ordering,
act=act,
norm=norm,
dropout=dropout,
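
For orientation, a rough call-level sketch of the refactored classes with the new `adn_ordering` argument (illustrative only; the shapes and argument values are assumptions, not part of the changed files):

import torch
from monai.networks.blocks import Convolution, ResidualUnit

x = torch.randn(1, 1, 64, 64)  # hypothetical single-channel 2D input

# convolution followed by the ADN sub-module (norm -> dropout -> activation)
conv = Convolution(
    dimensions=2,
    in_channels=1,
    out_channels=16,
    adn_ordering="NDA",
    act="PRELU",
    norm="INSTANCE",
    dropout=0.1,
)
print(conv(x).shape)  # torch.Size([1, 16, 64, 64])

# the residual unit forwards the same ordering string to each sub-convolution
res = ResidualUnit(
    dimensions=2,
    in_channels=1,
    out_channels=16,
    subunits=2,
    adn_ordering="AND",  # activation -> norm -> dropout
)
print(res(x).shape)   # torch.Size([1, 16, 64, 64])
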
23 changes: 21 additions & 2 deletions monai/networks/layers/factories.py
@@ -60,7 +60,7 @@ def use_factory(fact_args):
layer = use_factory( (fact.TEST, kwargs) )
"""

from typing import Any, Callable, Dict, Tuple, Type, Union
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

import torch.nn as nn

@@ -216,7 +216,26 @@ def batch_factory(dim: int) -> Type[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.Bat
return types[dim - 1]


Norm.add_factory_callable("group", lambda: nn.modules.GroupNorm)
@Norm.factory_function("group")
def group_factory(_dim: Optional[int] = None) -> Type[nn.GroupNorm]:
return nn.GroupNorm


@Norm.factory_function("layer")
def layer_factory(_dim: Optional[int] = None) -> Type[nn.LayerNorm]:
return nn.LayerNorm


@Norm.factory_function("localresponse")
def local_response_factory(_dim: Optional[int] = None) -> Type[nn.LocalResponseNorm]:
return nn.LocalResponseNorm


@Norm.factory_function("syncbatch")
def sync_batch_factory(_dim: Optional[int] = None) -> Type[nn.SyncBatchNorm]:
return nn.SyncBatchNorm


Act.add_factory_callable("elu", lambda: nn.modules.ELU)
Act.add_factory_callable("relu", lambda: nn.modules.ReLU)
Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU)
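
A short sketch of how the new normalization factory entries resolve to torch modules (illustrative only; it assumes the factory also accepts a bare name when the dimension argument is irrelevant):

import torch.nn as nn
from monai.networks.layers.factories import Norm

# the dimension argument is accepted but ignored by these factory functions
group_norm = Norm["group", 2](num_groups=4, num_channels=8)
layer_norm = Norm["layer"](normalized_shape=[8, 32, 32])
local_resp = Norm["localresponse"](size=5)

assert isinstance(group_norm, nn.GroupNorm)
assert isinstance(layer_norm, nn.LayerNorm)
assert isinstance(local_resp, nn.LocalResponseNorm)
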
13 changes: 10 additions & 3 deletions monai/networks/nets/basic_unet.py
@@ -123,13 +123,12 @@ def forward(self, x: torch.Tensor, x_e: torch.Tensor):
x_0 = self.upsample(x)

# handling spatial shapes due to the 2x maxpooling with odd edge lengths.
dimensions = x.ndim - 2 # type: ignore
dimensions = len(x.shape) - 2
sp = [0] * (dimensions * 2)
for i in range(dimensions):
if x_e.shape[-i - 1] != x_0.shape[-i - 1]:
sp[i * 2 + 1] = 1
if sum(sp) != 0:
x_0 = torch.nn.functional.pad(x_0, sp, "replicate")
x_0 = torch.nn.functional.pad(x_0, sp, "replicate")

x = self.convs(torch.cat([x_e, x_0], dim=1)) # input channels: (cat_chns + up_chns)
return x
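
To make the shape handling above concrete (an illustrative sketch with made-up sizes): when an encoder level has an odd edge length, the upsampled decoder tensor comes back one element short and is padded at the end of the mismatched dimension before concatenation:

import torch
import torch.nn.functional as F

x_e = torch.zeros(1, 8, 64, 64)  # encoder feature map (hypothetical sizes)
x_0 = torch.zeros(1, 8, 63, 63)  # upsampled decoder output after an odd-length level

dimensions = len(x_0.shape) - 2
sp = [0] * (dimensions * 2)
for i in range(dimensions):
    if x_e.shape[-i - 1] != x_0.shape[-i - 1]:
        sp[i * 2 + 1] = 1  # pad one element at the end of this spatial dim
x_0 = F.pad(x_0, sp, "replicate")
print(x_0.shape)  # torch.Size([1, 8, 64, 64]); now concatenable with x_e
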
@@ -177,9 +176,17 @@ def __init__(
# for spatial 2D
>>> net = BasicUNet(dimensions=2, features=(64, 128, 256, 512, 1024, 128))

# for spatial 2D, with group norm
>>> net = BasicUNet(dimensions=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4}))

# for spatial 3D
>>> net = BasicUNet(dimensions=3, features=(32, 32, 64, 128, 256, 32))

See Also

- :py:class:`monai.networks.nets.DynUNet`
- :py:class:`monai.networks.nets.UNet`

"""
super().__init__()
