20 changes: 20 additions & 0 deletions docs/source/networks.rst
@@ -188,6 +188,26 @@ Blocks
.. autoclass:: PatchEmbeddingBlock
:members:

`FactorizedIncreaseBlock`
~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: FactorizedIncreaseBlock
:members:

`FactorizedReduceBlock`
~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: FactorizedReduceBlock
:members:

`P3DActiConvNormBlock`
~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: P3DActiConvNormBlock
:members:

`ActiConvNormBlock`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: ActiConvNormBlock
:members:

`Warp`
~~~~~~
.. autoclass:: Warp
1 change: 1 addition & 0 deletions monai/networks/blocks/__init__.py
@@ -14,6 +14,7 @@
from .aspp import SimpleASPP
from .convolutions import Convolution, ResidualUnit
from .crf import CRF
from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock
from .downsample import MaxAvgPool
from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding
from .fcn import FCN, GCN, MCFCN, Refine
265 changes: 265 additions & 0 deletions monai/networks/blocks/dints_block.py
@@ -0,0 +1,265 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Tuple, Union

import torch

from monai.networks.layers.factories import Conv
from monai.networks.layers.utils import get_act_layer, get_norm_layer

__all__ = ["FactorizedIncreaseBlock", "FactorizedReduceBlock", "P3DActiConvNormBlock", "ActiConvNormBlock"]


class FactorizedIncreaseBlock(torch.nn.Sequential):
"""
Up-sample the spatial dimensions by a factor of two using trilinear interpolation, then apply
activation, a 1 x 1 x 1 convolution, and normalization. Requires ``spatial_dims=3``, since the
up-sampling mode is ``trilinear``.
"""

def __init__(
self,
in_channel: int,
out_channel: int,
spatial_dims: int = 3,
act_name: Union[Tuple, str] = "RELU",
norm_name: Union[Tuple, str] = "INSTANCE",
):
"""
Args:
in_channel: number of input channels.
out_channel: number of output channels.
spatial_dims: number of spatial dimensions.
act_name: activation layer type and arguments.
norm_name: feature normalization type and arguments.
"""
super().__init__()
self._in_channel = in_channel
self._out_channel = out_channel
self._spatial_dims = spatial_dims

conv_type = Conv[Conv.CONV, self._spatial_dims]

self.add_module("up", torch.nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True))
self.add_module("acti", get_act_layer(name=act_name))
self.add_module(
"conv",
conv_type(
in_channels=self._in_channel,
out_channels=self._out_channel,
kernel_size=1,
stride=1,
padding=0,
groups=1,
bias=False,
dilation=1,
),
)
self.add_module(
"norm", get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel)
)
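
# A usage sketch (shapes assume ``spatial_dims=3``, which the ``trilinear``
# up-sampling requires): each spatial dimension doubles and the 1 x 1 x 1
# convolution maps the channels.
#
#     >>> block = FactorizedIncreaseBlock(in_channel=32, out_channel=16)
#     >>> block(torch.randn(7, 32, 24, 16, 8)).shape
#     torch.Size([7, 16, 48, 32, 16])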


class FactorizedReduceBlock(torch.nn.Module):
"""
Down-sample the spatial dimensions by a factor of two using stride-2 convolutions.
The length along each spatial dimension must be a multiple of 2.
"""

def __init__(
self,
in_channel: int,
out_channel: int,
spatial_dims: int = 3,
act_name: Union[Tuple, str] = "RELU",
norm_name: Union[Tuple, str] = "INSTANCE",
):
"""
Args:
in_channel: number of input channels.
out_channel: number of output channels.
spatial_dims: number of spatial dimensions.
act_name: activation layer type and arguments.
norm_name: feature normalization type and arguments.
"""
super().__init__()
self._in_channel = in_channel
self._out_channel = out_channel
self._spatial_dims = spatial_dims

conv_type = Conv[Conv.CONV, self._spatial_dims]

self.act = get_act_layer(name=act_name)
self.conv_1 = conv_type(
in_channels=self._in_channel,
out_channels=self._out_channel // 2,
kernel_size=1,
stride=2,
padding=0,
groups=1,
bias=False,
dilation=1,
)
self.conv_2 = conv_type(
in_channels=self._in_channel,
out_channels=self._out_channel // 2,
kernel_size=1,
stride=2,
padding=0,
groups=1,
bias=False,
dilation=1,
)
self.norm = get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
The length along each spatial dimension must be a multiple of 2.
"""
x = self.act(x)
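# the second branch is offset by one voxel so the two stride-2 branches sample
# complementary positions; the indexing below assumes ``spatial_dims=3``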
out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:, 1:])], dim=1)
out = self.norm(out)
return out
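
# Concatenating the two stride-2 branches restores ``out_channel`` channels
# (``out_channel`` is assumed even; an odd value would yield
# ``2 * (out_channel // 2)`` channels). A usage sketch:
#
#     >>> block = FactorizedReduceBlock(in_channel=32, out_channel=16)
#     >>> block(torch.randn(7, 32, 24, 16, 8)).shape
#     torch.Size([7, 16, 12, 8, 4])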


class P3DActiConvNormBlock(torch.nn.Sequential):
"""
-- (act) -- (conv) -- (conv_1) -- (norm) --
"""

def __init__(
self,
in_channel: int,
out_channel: int,
kernel_size: int,
padding: int,
mode: int = 0,
act_name: Union[Tuple, str] = "RELU",
norm_name: Union[Tuple, str] = "INSTANCE",
):
"""
Args:
in_channel: number of input channels.
out_channel: number of output channels.
kernel_size: kernel size to be expanded to 3D.
padding: padding size to be expanded to 3D.
mode: mode for the anisotropic kernels:

- 0: ``(k, k, 1)``, ``(1, 1, k)``,
- 1: ``(k, 1, k)``, ``(1, k, 1)``,
- 2: ``(1, k, k)``, ``(k, 1, 1)``.

act_name: activation layer type and arguments.
norm_name: feature normalization type and arguments.
"""
super().__init__()
self._in_channel = in_channel
self._out_channel = out_channel
self._p3dmode = int(mode)

conv_type = Conv[Conv.CONV, 3]

if self._p3dmode == 0: # (k, k, 1), (1, 1, k)
kernel_size0 = (kernel_size, kernel_size, 1)
kernel_size1 = (1, 1, kernel_size)
padding0 = (padding, padding, 0)
padding1 = (0, 0, padding)
elif self._p3dmode == 1: # (k, 1, k), (1, k, 1)
kernel_size0 = (kernel_size, 1, kernel_size)
kernel_size1 = (1, kernel_size, 1)
padding0 = (padding, 0, padding)
padding1 = (0, padding, 0)
elif self._p3dmode == 2: # (1, k, k), (k, 1, 1)
kernel_size0 = (1, kernel_size, kernel_size)
kernel_size1 = (kernel_size, 1, 1)
padding0 = (0, padding, padding)
padding1 = (padding, 0, 0)
else:
raise ValueError("`mode` must be 0, 1, or 2.")

self.add_module("acti", get_act_layer(name=act_name))
self.add_module(
"conv",
conv_type(
in_channels=self._in_channel,
out_channels=self._in_channel,
kernel_size=kernel_size0,
stride=1,
padding=padding0,
groups=1,
bias=False,
dilation=1,
),
)
self.add_module(
"conv_1",
conv_type(
in_channels=self._in_channel,
out_channels=self._out_channel,
kernel_size=kernel_size1,
stride=1,
padding=padding1,
groups=1,
bias=False,
dilation=1,
),
)
self.add_module("norm", get_norm_layer(name=norm_name, spatial_dims=3, channels=self._out_channel))


class ActiConvNormBlock(torch.nn.Sequential):
"""
-- (act) -- (conv) -- (norm) --
"""

def __init__(
self,
in_channel: int,
out_channel: int,
kernel_size: int = 3,
padding: int = 1,
spatial_dims: int = 3,
act_name: Union[Tuple, str] = "RELU",
norm_name: Union[Tuple, str] = "INSTANCE",
):
"""
Args:
in_channel: number of input channels.
out_channel: number of output channels.
kernel_size: kernel size of the convolution.
padding: padding size of the convolution.
spatial_dims: number of spatial dimensions.
act_name: activation layer type and arguments.
norm_name: feature normalization type and arguments.
"""
super().__init__()
self._in_channel = in_channel
self._out_channel = out_channel
self._spatial_dims = spatial_dims

conv_type = Conv[Conv.CONV, self._spatial_dims]
self.add_module("acti", get_act_layer(name=act_name))
self.add_module(
"conv",
conv_type(
in_channels=self._in_channel,
out_channels=self._out_channel,
kernel_size=kernel_size,
stride=1,
padding=padding,
groups=1,
bias=False,
dilation=1,
),
)
self.add_module(
"norm", get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel)
)
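
# A usage sketch (the spatial size is preserved because the defaults satisfy
# ``padding = (kernel_size - 1) // 2``); this matches the 2D test case below:
#
#     >>> block = ActiConvNormBlock(in_channel=32, out_channel=16, spatial_dims=2)
#     >>> block(torch.randn(7, 32, 13, 32)).shape
#     torch.Size([7, 16, 13, 32])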
2 changes: 2 additions & 0 deletions monai/networks/nets/__init__.py
@@ -29,6 +29,8 @@
densenet201,
densenet264,
)

# from .dints import DiNTS
from .dynunet import DynUNet, DynUnet, Dynunet
from .efficientnet import (
BlockArgs,
38 changes: 38 additions & 0 deletions tests/test_acn_block.py
@@ -0,0 +1,38 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.networks.blocks.dints_block import ActiConvNormBlock

TEST_CASES = [
[{"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 1}, (7, 32, 16, 31, 7), (7, 16, 16, 31, 7)],
[
{"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 1, "spatial_dims": 2},
(7, 32, 13, 32),
(7, 16, 13, 32),
],
]


class TestACNBlock(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_acn_block(self, input_param, input_shape, expected_shape):
net = ActiConvNormBlock(**input_param)
result = net(torch.randn(input_shape))
self.assertEqual(result.shape, expected_shape)


if __name__ == "__main__":
unittest.main()
34 changes: 34 additions & 0 deletions tests/test_factorized_increase.py
@@ -0,0 +1,34 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.networks.blocks.dints_block import FactorizedIncreaseBlock

TEST_CASES_3D = [
[{"in_channel": 32, "out_channel": 16}, (7, 32, 24, 16, 8), (7, 16, 48, 32, 16)],
[{"in_channel": 1, "out_channel": 2}, (1, 1, 1, 1, 1), (1, 2, 2, 2, 2)],
]


class TestFactInc(unittest.TestCase):
@parameterized.expand(TEST_CASES_3D)
def test_factorized_increase_3d(self, input_param, input_shape, expected_shape):
net = FactorizedIncreaseBlock(**input_param)
result = net(torch.randn(input_shape))
self.assertEqual(result.shape, expected_shape)


if __name__ == "__main__":
unittest.main()