diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 2224b42a74..fc16e8c86e 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -164,6 +164,11 @@ Layers .. currentmodule:: monai.networks.layers +`ChannelPad` +~~~~~~~~~~~~ +.. autoclass:: ChannelPad + :members: + `SkipConnection` ~~~~~~~~~~~~~~~~ .. autoclass:: SkipConnection diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index 4860b2862c..48012dfb1c 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -18,13 +18,82 @@ from torch.autograd import Function from monai.networks.layers.convutils import gaussian_1d, same_padding -from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, SkipMode, ensure_tuple_rep, optional_import +from monai.networks.layers.factories import Conv +from monai.utils import ( + PT_BEFORE_1_7, + ChannelMatching, + InvalidPyTorchVersionError, + SkipMode, + ensure_tuple_rep, + optional_import, +) _C, _ = optional_import("monai._C") if not PT_BEFORE_1_7: fft, _ = optional_import("torch.fft") -__all__ = ["SkipConnection", "Flatten", "GaussianFilter", "LLTM", "Reshape", "separable_filtering", "HilbertTransform"] +__all__ = [ + "SkipConnection", + "Flatten", + "GaussianFilter", + "LLTM", + "Reshape", + "separable_filtering", + "HilbertTransform", + "ChannelPad", +] + + +class ChannelPad(nn.Module): + """ + Expand the input tensor's channel dimension from length `in_channels` to `out_channels`, + by padding or a projection. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + mode: Union[ChannelMatching, str] = ChannelMatching.PAD, + ): + """ + + Args: + spatial_dims: number of spatial dimensions of the input image. + in_channels: number of input channels. + out_channels: number of output channels. + mode: {``"pad"``, ``"project"``} + Specifies handling residual branch and conv branch channel mismatches. Defaults to ``"pad"``. + + - ``"pad"``: with zero padding. + - ``"project"``: with a trainable conv with kernel size one. + """ + super().__init__() + self.project = None + self.pad = None + if in_channels == out_channels: + return + mode = ChannelMatching(mode) + if mode == ChannelMatching.PROJECT: + conv_type = Conv[Conv.CONV, spatial_dims] + self.project = conv_type(in_channels, out_channels, kernel_size=1) + return + if mode == ChannelMatching.PAD: + if in_channels > out_channels: + raise ValueError('Incompatible values: channel_matching="pad" and in_channels > out_channels.') + pad_1 = (out_channels - in_channels) // 2 + pad_2 = out_channels - in_channels - pad_1 + pad = [0, 0] * spatial_dims + [pad_1, pad_2] + [0, 0] + self.pad = tuple(pad) + return + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.project is not None: + return torch.as_tensor(self.project(x)) # as_tensor used to get around mypy typing bug + if self.pad is not None: + return F.pad(x, self.pad) + return x class SkipConnection(nn.Module): diff --git a/monai/networks/nets/highresnet.py b/monai/networks/nets/highresnet.py index c2adfd237a..918b5b5349 100644 --- a/monai/networks/nets/highresnet.py +++ b/monai/networks/nets/highresnet.py @@ -9,21 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, Optional, Sequence, Union +from typing import Dict, Optional, Sequence, Tuple, Union import torch import torch.nn as nn -import torch.nn.functional as F -from monai.networks.layers.convutils import same_padding -from monai.networks.layers.factories import Conv, Dropout, Norm -from monai.utils import Activation, ChannelMatching, Normalisation +from monai.networks.blocks import ADN, Convolution +from monai.networks.layers.simplelayers import ChannelPad +from monai.utils import ChannelMatching -SUPPORTED_NORM = { - Normalisation.BATCH: lambda spatial_dims: Norm[Norm.BATCH, spatial_dims], - Normalisation.INSTANCE: lambda spatial_dims: Norm[Norm.INSTANCE, spatial_dims], -} -SUPPORTED_ACTI = {Activation.RELU: nn.ReLU, Activation.PRELU: nn.PReLU, Activation.RELU6: nn.ReLU6} DEFAULT_LAYER_PARAMS_3D = ( # initial conv layer {"name": "conv_0", "n_features": 16, "kernel_size": 3}, @@ -37,64 +31,6 @@ ) -class ConvNormActi(nn.Module): - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: int, - norm_type: Optional[Union[Normalisation, str]] = None, - acti_type: Optional[Union[Activation, str]] = None, - dropout_prob: Optional[float] = None, - ) -> None: - """ - Args: - spatial_dims: number of spatial dimensions of the input image. - in_channels: number of input channels. - out_channels: number of output channels. - kernel_size: size of the convolving kernel. - norm_type: {``"batch"``, ``"instance"``} - Feature normalisation with batchnorm or instancenorm. Defaults to ``"batch"``. - acti_type: {``"relu"``, ``"prelu"``, ``"relu6"``} - Non-linear activation using ReLU or PReLU. Defaults to ``"relu"``. - dropout_prob: probability of the feature map to be zeroed - (only applies to the penultimate conv layer). - """ - - super(ConvNormActi, self).__init__() - - layers = nn.ModuleList() - - conv_type = Conv[Conv.CONV, spatial_dims] - padding_size = same_padding(kernel_size) - conv = conv_type(in_channels, out_channels, kernel_size, padding=padding_size) - layers.append(conv) - - if norm_type is not None: - norm_type = Normalisation(norm_type) - layers.append(SUPPORTED_NORM[norm_type](spatial_dims)(out_channels)) - if acti_type is not None: - acti_type = Activation(acti_type) - layers.append(SUPPORTED_ACTI[acti_type](inplace=True)) - if dropout_prob is not None: - dropout_type = Dropout[Dropout.DROPOUT, spatial_dims] - layers.append(dropout_type(p=dropout_prob)) - self.layers = nn.Sequential(*layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return torch.as_tensor(self.layers(x)) - - -class ChannelPad(nn.Module): - def __init__(self, pad): - super().__init__() - self.pad = tuple(pad) - - def forward(self, x): - return F.pad(x, self.pad) - - class HighResBlock(nn.Module): def __init__( self, @@ -103,8 +39,8 @@ def __init__( out_channels: int, kernels: Sequence[int] = (3, 3), dilation: Union[Sequence[int], int] = 1, - norm_type: Union[Normalisation, str] = Normalisation.INSTANCE, - acti_type: Union[Activation, str] = Activation.RELU, + norm_type: Union[Tuple, str] = ("batch", {"affine": True}), + acti_type: Union[Tuple, str] = ("relu", {"inplace": True}), channel_matching: Union[ChannelMatching, str] = ChannelMatching.PAD, ) -> None: """ @@ -114,51 +50,39 @@ def __init__( out_channels: number of output channels. kernels: each integer k in `kernels` corresponds to a convolution layer with kernel size k. dilation: spacing between kernel elements. 
- norm_type: {``"batch"``, ``"instance"``} - Feature normalisation with batchnorm or instancenorm. Defaults to ``"instance"``. + norm_type: feature normalization type and arguments. + Defaults to ``("batch", {"affine": True})``. acti_type: {``"relu"``, ``"prelu"``, ``"relu6"``} Non-linear activation using ReLU or PReLU. Defaults to ``"relu"``. channel_matching: {``"pad"``, ``"project"``} Specifies handling residual branch and conv branch channel mismatches. Defaults to ``"pad"``. - ``"pad"``: with zero padding. - - ``"project"``: with a trainable conv with kernel size. + - ``"project"``: with a trainable conv with kernel size one. Raises: ValueError: When ``channel_matching=pad`` and ``in_channels > out_channels``. Incompatible values. """ super(HighResBlock, self).__init__() - conv_type = Conv[Conv.CONV, spatial_dims] - norm_type = Normalisation(norm_type) - acti_type = Activation(acti_type) - - self.project = None - self.pad = None - - if in_channels != out_channels: - channel_matching = ChannelMatching(channel_matching) - - if channel_matching == ChannelMatching.PROJECT: - self.project = conv_type(in_channels, out_channels, kernel_size=1) - - if channel_matching == ChannelMatching.PAD: - if in_channels > out_channels: - raise ValueError('Incompatible values: channel_matching="pad" and in_channels > out_channels.') - pad_1 = (out_channels - in_channels) // 2 - pad_2 = out_channels - in_channels - pad_1 - pad = [0, 0] * spatial_dims + [pad_1, pad_2] + [0, 0] - self.pad = ChannelPad(pad) + self.chn_pad = ChannelPad( + spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, mode=channel_matching + ) layers = nn.ModuleList() _in_chns, _out_chns = in_channels, out_channels for kernel_size in kernels: - layers.append(SUPPORTED_NORM[norm_type](spatial_dims)(_in_chns)) - layers.append(SUPPORTED_ACTI[acti_type](inplace=True)) layers.append( - conv_type( - _in_chns, _out_chns, kernel_size, padding=same_padding(kernel_size, dilation), dilation=dilation + ADN(ordering="NA", in_channels=_in_chns, act=acti_type, norm=norm_type, norm_dim=spatial_dims) + ) + layers.append( + Convolution( + dimensions=spatial_dims, + in_channels=_in_chns, + out_channels=_out_chns, + kernel_size=kernel_size, + dilation=dilation, ) ) _in_chns = _out_chns @@ -167,14 +91,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: x_conv: torch.Tensor = self.layers(x) - - if self.project is not None: - return x_conv + torch.as_tensor(self.project(x)) # as_tensor used to get around mypy typing bug - - if self.pad is not None: - return x_conv + torch.as_tensor(self.pad(x)) - - return x_conv + x + return x_conv + torch.as_tensor(self.chn_pad(x)) class HighResNet(nn.Module): @@ -191,13 +108,18 @@ class HighResNet(nn.Module): spatial_dims: number of spatial dimensions of the input image. in_channels: number of input channels. out_channels: number of output channels. - norm_type: {``"batch"``, ``"instance"``} - Feature normalisation with batchnorm or instancenorm. Defaults to ``"batch"``. - acti_type: {``"relu"``, ``"prelu"``, ``"relu6"``} - Non-linear activation using ReLU or PReLU. Defaults to ``"relu"``. + norm_type: feature normalization type and arguments. + Defaults to ``("batch", {"affine": True})``. + acti_type: activation type and arguments. + Defaults to ``("relu", {"inplace": True})``. dropout_prob: probability of the feature map to be zeroed (only applies to the penultimate conv layer). layer_params: specifying key parameters of each layer/block. 
+ channel_matching: {``"pad"``, ``"project"``} + Specifies handling residual branch and conv branch channel mismatches. Defaults to ``"pad"``. + + - ``"pad"``: with zero padding. + - ``"project"``: with a trainable conv with kernel size one. """ def __init__( @@ -205,10 +127,11 @@ def __init__( spatial_dims: int = 3, in_channels: int = 1, out_channels: int = 1, - norm_type: Union[Normalisation, str] = Normalisation.BATCH, - acti_type: Union[Activation, str] = Activation.RELU, - dropout_prob: Optional[float] = None, + norm_type: Union[str, tuple] = ("batch", {"affine": True}), + acti_type: Union[str, tuple] = ("relu", {"inplace": True}), + dropout_prob: Optional[Union[Tuple, str, float]] = 0.0, layer_params: Sequence[Dict] = DEFAULT_LAYER_PARAMS_3D, + channel_matching: Union[ChannelMatching, str] = ChannelMatching.PAD, ) -> None: super(HighResNet, self).__init__() @@ -218,14 +141,14 @@ def __init__( params = layer_params[0] _in_chns, _out_chns = in_channels, params["n_features"] blocks.append( - ConvNormActi( - spatial_dims, - _in_chns, - _out_chns, + Convolution( + dimensions=spatial_dims, + in_channels=_in_chns, + out_channels=_out_chns, kernel_size=params["kernel_size"], - norm_type=norm_type, - acti_type=acti_type, - dropout_prob=None, + adn_ordering="NA", + act=acti_type, + norm=norm_type, ) ) @@ -236,13 +159,14 @@ def __init__( for _ in range(params["repeat"]): blocks.append( HighResBlock( - spatial_dims, - _in_chns, - _out_chns, - params["kernels"], + spatial_dims=spatial_dims, + in_channels=_in_chns, + out_channels=_out_chns, + kernels=params["kernels"], dilation=_dilation, norm_type=norm_type, acti_type=acti_type, + channel_matching=channel_matching, ) ) _in_chns = _out_chns @@ -251,28 +175,30 @@ def __init__( params = layer_params[-2] _in_chns, _out_chns = _out_chns, params["n_features"] blocks.append( - ConvNormActi( - spatial_dims, - _in_chns, - _out_chns, + Convolution( + dimensions=spatial_dims, + in_channels=_in_chns, + out_channels=_out_chns, kernel_size=params["kernel_size"], - norm_type=norm_type, - acti_type=acti_type, - dropout_prob=dropout_prob, + adn_ordering="NAD", + act=acti_type, + norm=norm_type, + dropout=dropout_prob, ) ) params = layer_params[-1] _in_chns = _out_chns blocks.append( - ConvNormActi( - spatial_dims, - _in_chns, - out_channels, + Convolution( + dimensions=spatial_dims, + in_channels=_in_chns, + out_channels=out_channels, kernel_size=params["kernel_size"], - norm_type=norm_type, - acti_type=None, - dropout_prob=None, + adn_ordering="NAD", + act=acti_type, + norm=norm_type, + dropout=dropout_prob, ) ) diff --git a/tests/test_channel_pad.py b/tests/test_channel_pad.py new file mode 100644 index 0000000000..00d0eab65a --- /dev/null +++ b/tests/test_channel_pad.py @@ -0,0 +1,48 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.layers import ChannelPad + +TEST_CASES_3D = [] +for type_1 in ("pad", "project"): + input_shape = (16, 10, 32, 24, 48) + out_chns = 13 + result_shape = list(input_shape) + result_shape[1] = out_chns + test_case = [ + {"spatial_dims": 3, "in_channels": 10, "out_channels": out_chns, "mode": type_1}, + input_shape, + result_shape, + ] + TEST_CASES_3D.append(test_case) + + +class TestChannelPad(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) + def test_shape(self, input_param, input_shape, expected_shape): + net = ChannelPad(**input_param) + net.eval() + with torch.no_grad(): + result = net(torch.randn(input_shape)) + self.assertEqual(list(result.shape), list(expected_shape)) + + def test_wrong_mode(self): + with self.assertRaises(ValueError): + ChannelPad(3, 10, 20, mode="test")(torch.randn(10, 10)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_highresnet.py b/tests/test_highresnet.py index 0afa4b1a9b..10f4f41fea 100644 --- a/tests/test_highresnet.py +++ b/tests/test_highresnet.py @@ -53,7 +53,7 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) - @TimedCall(seconds=200, force_quit=True) + @TimedCall(seconds=400, force_quit=True) def test_script(self): input_param, input_shape, expected_shape = TEST_CASE_1 net = HighResNet(**input_param)
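
Usage note (illustrative, not part of the diff): a minimal sketch of the new ChannelPad layer, reusing the shapes from tests/test_channel_pad.py above. The import path and constructor arguments follow the diff; the commented shapes are what the pad/project logic implies.

    import torch

    from monai.networks.layers import ChannelPad

    x = torch.randn(16, 10, 32, 24, 48)  # (batch, channel, D, H, W)

    # "pad" inserts zero-valued channels around the existing ones (1 before, 2 after here)
    padded = ChannelPad(spatial_dims=3, in_channels=10, out_channels=13, mode="pad")(x)
    print(padded.shape)  # torch.Size([16, 13, 32, 24, 48])

    # "project" uses a trainable kernel-size-one convolution instead of zero padding
    projected = ChannelPad(spatial_dims=3, in_channels=10, out_channels=13, mode="project")(x)
    print(projected.shape)  # torch.Size([16, 13, 32, 24, 48])

    # when in_channels == out_channels the layer is an identity
    assert ChannelPad(spatial_dims=3, in_channels=10, out_channels=10)(x).shape == x.shape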
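Usage note (illustrative, not part of the diff): constructing HighResNet with the refactored signature. The tuple-style norm/activation arguments and the new channel_matching parameter come from the diff; the 32^3 input size and the assumption that spatial dimensions are preserved (same-padded convolutions) are mine, not stated in the diff.

    import torch

    from monai.networks.nets import HighResNet

    net = HighResNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=3,
        norm_type=("batch", {"affine": True}),
        acti_type=("relu", {"inplace": True}),
        dropout_prob=0.1,
        channel_matching="pad",
    )
    net.eval()
    with torch.no_grad():
        y = net(torch.randn(1, 1, 32, 32, 32))
    print(y.shape)  # expected: torch.Size([1, 3, 32, 32, 32])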