From 6e0f7be083711f3d6c9288b5e88c9334a182f0f4 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Mon, 19 Oct 2020 13:31:50 +0100
Subject: [PATCH 1/4] adds an acti-norm-dropout block

Signed-off-by: Wenqi Li
---
 docs/source/networks.rst              |   5 ++
 monai/networks/blocks/__init__.py     |   1 +
 monai/networks/blocks/acti_norm.py    | 119 ++++++++++++++++++++++++++
 monai/networks/blocks/convolutions.py |  96 +++++++++------------
 monai/networks/layers/factories.py    |  23 ++++-
 monai/networks/nets/basic_unet.py     |   5 +-
 monai/utils/module.py                 |  19 +++-
 tests/test_adn.py                     |  84 ++++++++++++++++++
 tests/test_basic_unet.py              |   9 +-
 tests/test_convolutions.py            |   4 +-
 10 files changed, 301 insertions(+), 64 deletions(-)
 create mode 100644 monai/networks/blocks/acti_norm.py
 create mode 100644 tests/test_adn.py

diff --git a/docs/source/networks.rst b/docs/source/networks.rst
index 8e7c667026..e48ff7c814 100644
--- a/docs/source/networks.rst
+++ b/docs/source/networks.rst
@@ -10,6 +10,11 @@ Blocks
 .. automodule:: monai.networks.blocks
 .. currentmodule:: monai.networks.blocks

+`ADN`
+~~~~~
+.. autoclass:: ADN
+    :members:
+
 `Convolution`
 ~~~~~~~~~~~~~
 .. autoclass:: Convolution
diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py
index 65ca519929..a80726f1a8 100644
--- a/monai/networks/blocks/__init__.py
+++ b/monai/networks/blocks/__init__.py
@@ -9,6 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from .acti_norm import ADN
 from .aspp import SimpleASPP
 from .convolutions import Convolution, ResidualUnit
 from .downsample import MaxAvgPool
diff --git a/monai/networks/blocks/acti_norm.py b/monai/networks/blocks/acti_norm.py
new file mode 100644
index 0000000000..5a74ecbeeb
--- /dev/null
+++ b/monai/networks/blocks/acti_norm.py
@@ -0,0 +1,119 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch.nn as nn
+
+from monai.networks.layers.factories import Act, Dropout, Norm, split_args
+from monai.utils import has_option
+
+
+class ADN(nn.Sequential):
+    """
+    Constructs a sequential module of optional activation, dropout, and normalization layers
+    (with an arbitrary order)::
+
+        -- (Norm) -- (Dropout) -- (Acti) --
+
+    Args:
+        ordering: a string representing the ordering of activation, dropout, and normalization. Defaults to "NDA".
+        in_channels: `C` from an expected input of size (N, C, H[, W, D]).
+        act: activation type and arguments. Defaults to ReLU.
+        norm: feature normalization type and arguments. Defaults to no normalization.
+        norm_dim: determines the spatial dimensions of the normalization layer.
+            defaults to `dropout_dim` if unspecified.
+        dropout: dropout ratio. Defaults to no dropout.
+        dropout_dim: determines the spatial dimensions of dropout.
+            defaults to `norm_dim` if unspecified.
+
+            - When dropout_dim = 1, randomly zeroes some of the elements for each channel.
+            - When dropout_dim = 2, randomly zeroes out entire channels (a channel is a 2D feature map).
+            - When dropout_dim = 3, randomly zeroes out entire channels (a channel is a 3D feature map).
+
+    Examples::
+        # activation, group norm, dropout
+        >>> norm_params = ("GROUP", {"num_groups": 1, "affine": False})
+        >>> ADN(norm=norm_params, in_channels=1, dropout_dim=1, dropout=0.8, ordering="AND")
+        ADN(
+            (A): ReLU()
+            (N): GroupNorm(1, 1, eps=1e-05, affine=False)
+            (D): Dropout(p=0.8, inplace=False)
+        )
+
+        # LeakyReLU, dropout
+        >>> act_params = ("leakyrelu", {"negative_slope": 0.1, "inplace": True})
+        >>> ADN(act=act_params, in_channels=1, dropout_dim=1, dropout=0.8, ordering="AD")
+        ADN(
+            (A): LeakyReLU(negative_slope=0.1, inplace=True)
+            (D): Dropout(p=0.8, inplace=False)
+        )
+
+
+    See also:
+
+        :py:class:`monai.networks.layers.Dropout`
+        :py:class:`monai.networks.layers.Act`
+        :py:class:`monai.networks.layers.Norm`
+        :py:class:`monai.networks.layers.split_args`
+
+    """
+
+    def __init__(
+        self,
+        ordering: str = "NDA",
+        in_channels: Optional[int] = None,
+        act: Optional[Union[Tuple, str]] = "RELU",
+        norm: Optional[Union[Tuple, str]] = None,
+        norm_dim: Optional[int] = None,
+        dropout: Optional[Union[Tuple, str, float]] = None,
+        dropout_dim: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+
+        op_dict = {"A": None, "D": None, "N": None}
+        # define the normalisation type and the arguments to the constructor
+        if norm is not None:
+            if norm_dim is None and dropout_dim is None:
+                raise ValueError("norm_dim or dropout_dim needs to be specified.")
+            norm_name, norm_args = split_args(norm)
+            norm_type = Norm[norm_name, norm_dim or dropout_dim]
+            kw_args = dict(norm_args)
+            if has_option(norm_type, "num_features") and "num_features" not in kw_args:
+                kw_args["num_features"] = in_channels
+            if has_option(norm_type, "num_channels") and "num_channels" not in kw_args:
+                kw_args["num_channels"] = in_channels
+            op_dict["N"] = norm_type(**kw_args)
+
+        # define the activation type and the arguments to the constructor
+        if act is not None:
+            act_name, act_args = split_args(act)
+            act_type = Act[act_name]
+            op_dict["A"] = act_type(**act_args)
+
+        if dropout is not None:
+            # if dropout was specified simply as a p value, use default name and make a keyword map with the value
+            if isinstance(dropout, (int, float)):
+                drop_name = Dropout.DROPOUT
+                drop_args = {"p": dropout}
+            else:
+                drop_name, drop_args = split_args(dropout)
+
+            if norm_dim is None and dropout_dim is None:
+                raise ValueError("norm_dim or dropout_dim needs to be specified.")
+            drop_type = Dropout[drop_name, dropout_dim or norm_dim]
+            op_dict["D"] = drop_type(**drop_args)
+
+        for item in ordering.upper():
+            if item not in op_dict:
+                raise ValueError(f"ordering must be a string of {op_dict}, got {item} in it.")
+            if op_dict[item] is not None:
+                self.add_module(item, op_dict[item])  # type: ignore
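Note: to exercise the new block in isolation, here is a minimal sketch assuming the patch above is applied (the Conv2d/batch-norm choices are illustrative, not part of the patch)::

    import torch
    import torch.nn as nn

    from monai.networks.blocks import ADN

    # conv followed by norm -> dropout -> activation, i.e. the default "NDA" ordering
    block = nn.Sequential(
        nn.Conv2d(1, 8, kernel_size=3, padding=1),
        ADN(ordering="NDA", in_channels=8, act="relu", norm="batch", norm_dim=2, dropout=0.1),
    )
    x = torch.randn(2, 1, 32, 32)
    print(block(x).shape)  # expected: torch.Size([2, 8, 32, 32])

Because `dropout_dim` is left unspecified, the block falls back to `norm_dim`, so the dropout constructed here should be the 2D channel-wise variant.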
diff --git a/monai/networks/blocks/convolutions.py b/monai/networks/blocks/convolutions.py
index 7c0ca3808b..eafe028a06 100644
--- a/monai/networks/blocks/convolutions.py
+++ b/monai/networks/blocks/convolutions.py
@@ -15,15 +15,16 @@
 import torch
 import torch.nn as nn

+from monai.networks.blocks import ADN
 from monai.networks.layers.convutils import same_padding, stride_minus_kernel_padding
-from monai.networks.layers.factories import Act, Conv, Dropout, Norm, split_args
+from monai.networks.layers.factories import Conv


 class Convolution(nn.Sequential):
     """
     Constructs a convolution with normalization, optional dropout, and optional activation layers::

-        -- (Conv|ConvTrans) -- Norm -- (Dropout) -- (Acti) --
+        -- (Conv|ConvTrans) -- (Norm -- Dropout -- Acti) --

     if ``conv_only`` set to ``True``::

@@ -35,14 +36,18 @@ class Convolution(nn.Sequential):
         out_channels: number of output channels.
         strides: convolution stride. Defaults to 1.
         kernel_size: convolution kernel size. Defaults to 3.
+        adn_ordering: a string representing the ordering of activation, normalization, and dropout.
+            Defaults to "NDA".
         act: activation type and arguments. Defaults to PReLU.
         norm: feature normalization type and arguments. Defaults to instance norm.
         dropout: dropout ratio. Defaults to no dropout.
         dropout_dim: determine the dimensions of dropout. Defaults to 1.
-            When dropout_dim = 1, randomly zeroes some of the elements for each channel.
-            When dropout_dim = 2, Randomly zeroes out entire channels (a channel is a 2D feature map).
-            When dropout_dim = 3, Randomly zeroes out entire channels (a channel is a 3D feature map).
-            The value of dropout_dim should be no no larger than the value of dimensions.
+
+            - When dropout_dim = 1, randomly zeroes some of the elements for each channel.
+            - When dropout_dim = 2, randomly zeroes out entire channels (a channel is a 2D feature map).
+            - When dropout_dim = 3, randomly zeroes out entire channels (a channel is a 3D feature map).
+
+            The value of dropout_dim should be no larger than the value of `dimensions`.
         dilation: dilation rate. Defaults to 1.
         groups: controls the connections between inputs and outputs. Defaults to 1.
         bias: whether to have a bias term. Defaults to True.
@@ -56,10 +61,7 @@ class Convolution(nn.Sequential):
     See also:

         :py:class:`monai.networks.layers.Conv`
-        :py:class:`monai.networks.layers.Dropout`
-        :py:class:`monai.networks.layers.Act`
-        :py:class:`monai.networks.layers.Norm`
-        :py:class:`monai.networks.layers.split_args`
+        :py:class:`monai.networks.blocks.ADN`

     """
@@ -70,10 +72,11 @@ def __init__(
         out_channels: int,
         strides: Union[Sequence[int], int] = 1,
         kernel_size: Union[Sequence[int], int] = 3,
-        act: Optional[Union[Tuple, str]] = Act.PRELU,
-        norm: Union[Tuple, str] = Norm.INSTANCE,
+        adn_ordering: str = "NDA",
+        act: Optional[Union[Tuple, str]] = "PRELU",
+        norm: Optional[Union[Tuple, str]] = "INSTANCE",
         dropout: Optional[Union[Tuple, str, float]] = None,
-        dropout_dim: int = 1,
+        dropout_dim: Optional[int] = 1,
         dilation: Union[Sequence[int], int] = 1,
         groups: int = 1,
         bias: bool = True,
@@ -90,33 +93,6 @@ def __init__(
         if padding is None:
             padding = same_padding(kernel_size, dilation)
         conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, dimensions]
-        # define the normalisation type and the arguments to the constructor
-        if norm is not None:
-            norm_name, norm_args = split_args(norm)
-            norm_type = Norm[norm_name, dimensions]
-        else:
-            norm_type = norm_args = None
-
-        # define the activation type and the arguments to the constructor
-        if act is not None:
-            act_name, act_args = split_args(act)
-            act_type = Act[act_name]
-        else:
-            act_type = act_args = None
-
-        if dropout:
-            # if dropout was specified simply as a p value, use default name and make a keyword map with the value
-            if isinstance(dropout, (int, float)):
-                drop_name = Dropout.DROPOUT
-                drop_args = {"p": dropout}
-            else:
-                drop_name, drop_args = split_args(dropout)
-
-            if dropout_dim > dimensions:
-                raise ValueError(
-                    f"dropout_dim should be no larger than dimensions, got dropout_dim={dropout_dim} and dimensions={dimensions}."
-                )
-            drop_type = Dropout[drop_name, dropout_dim]

         if is_transposed:
             if output_padding is None:
@@ -147,14 +123,18 @@ def __init__(
         self.add_module("conv", conv)

         if not conv_only:
-            if norm is not None:
-                self.add_module("norm", norm_type(out_channels, **norm_args))
-
-            if dropout:
-                self.add_module("dropout", drop_type(**drop_args))
-
-            if act is not None:
-                self.add_module("act", act_type(**act_args))
+            self.add_module(
+                "adn",
+                ADN(
+                    ordering=adn_ordering,
+                    in_channels=out_channels,
+                    act=act,
+                    norm=norm,
+                    norm_dim=dimensions,
+                    dropout=dropout,
+                    dropout_dim=dropout_dim,
+                ),
+            )


 class ResidualUnit(nn.Module):
@@ -168,14 +148,18 @@ class ResidualUnit(nn.Module):
         strides: convolution stride. Defaults to 1.
         kernel_size: convolution kernel size. Defaults to 3.
         subunits: number of convolutions. Defaults to 2.
+        adn_ordering: a string representing the ordering of activation, normalization, and dropout.
+            Defaults to "NDA".
         act: activation type and arguments. Defaults to PReLU.
         norm: feature normalization type and arguments. Defaults to instance norm.
         dropout: dropout ratio. Defaults to no dropout.
         dropout_dim: determine the dimensions of dropout. Defaults to 1.
-            When dropout_dim = 1, randomly zeroes some of the elements for each channel.
-            When dropout_dim = 2, Randomly zero out entire channels (a channel is a 2D feature map).
-            When dropout_dim = 3, Randomly zero out entire channels (a channel is a 3D feature map).
-            The value of dropout_dim should be no no larger than the value of dimensions.
+
+            - When dropout_dim = 1, randomly zeroes some of the elements for each channel.
+            - When dropout_dim = 2, randomly zeroes out entire channels (a channel is a 2D feature map).
+            - When dropout_dim = 3, randomly zeroes out entire channels (a channel is a 3D feature map).
+
+            The value of dropout_dim should be no larger than the value of `dimensions`.
         dilation: dilation rate. Defaults to 1.
         bias: whether to have a bias term. Defaults to True.
         last_conv_only: for the last subunit, whether to use the convolutional layer only.
@@ -197,10 +181,11 @@ def __init__(
         strides: Union[Sequence[int], int] = 1,
         kernel_size: Union[Sequence[int], int] = 3,
         subunits: int = 2,
-        act: Optional[Union[Tuple, str]] = Act.PRELU,
-        norm: Union[Tuple, str] = Norm.INSTANCE,
+        adn_ordering: str = "NDA",
+        act: Optional[Union[Tuple, str]] = "PRELU",
+        norm: Optional[Union[Tuple, str]] = "INSTANCE",
         dropout: Optional[Union[Tuple, str, float]] = None,
-        dropout_dim: int = 1,
+        dropout_dim: Optional[int] = 1,
         dilation: Union[Sequence[int], int] = 1,
         bias: bool = True,
         last_conv_only: bool = False,
@@ -226,6 +211,7 @@ def __init__(
                 out_channels,
                 strides=sstrides,
                 kernel_size=kernel_size,
+                adn_ordering=adn_ordering,
                 act=act,
                 norm=norm,
                 dropout=dropout,
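Note: with this refactor the post-convolution ordering becomes a constructor argument; a quick hedged comparison, with arbitrary parameter values and the default instance norm + PReLU::

    import torch

    from monai.networks.blocks import Convolution

    # the same convolution, differing only in the ADN ordering
    conv_nda = Convolution(2, 1, 4, adn_ordering="NDA", dropout=0.1)  # norm -> dropout -> act
    conv_dan = Convolution(2, 1, 4, adn_ordering="DAN", dropout=0.1)  # dropout -> act -> norm
    x = torch.randn(1, 1, 16, 16)
    print(conv_nda(x).shape, conv_dan(x).shape)  # both: torch.Size([1, 4, 16, 16])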
diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py
index 1e4c7febae..63014b8894 100644
--- a/monai/networks/layers/factories.py
+++ b/monai/networks/layers/factories.py
@@ -60,7 +60,7 @@ def use_factory(fact_args):
     layer = use_factory( (fact.TEST, kwargs) )
 """

-from typing import Any, Callable, Dict, Tuple, Type, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

 import torch.nn as nn

@@ -216,7 +216,26 @@ def batch_factory(dim: int) -> Type[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]]:
     return types[dim - 1]


-Norm.add_factory_callable("group", lambda: nn.modules.GroupNorm)
+@Norm.factory_function("group")
+def group_factory(_dim: Optional[int] = None) -> Type[nn.GroupNorm]:
+    return nn.GroupNorm
+
+
+@Norm.factory_function("layer")
+def layer_factory(_dim: Optional[int] = None) -> Type[nn.LayerNorm]:
+    return nn.LayerNorm
+
+
+@Norm.factory_function("localresponse")
+def local_response_factory(_dim: Optional[int] = None) -> Type[nn.LocalResponseNorm]:
+    return nn.LocalResponseNorm
+
+
+@Norm.factory_function("syncbatch")
+def sync_batch_factory(_dim: Optional[int] = None) -> Type[nn.SyncBatchNorm]:
+    return nn.SyncBatchNorm
+
+
 Act.add_factory_callable("elu", lambda: nn.modules.ELU)
 Act.add_factory_callable("relu", lambda: nn.modules.ReLU)
 Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU)
diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py
index 06905d2218..746378e0a1 100644
--- a/monai/networks/nets/basic_unet.py
+++ b/monai/networks/nets/basic_unet.py
@@ -123,13 +123,12 @@ def forward(self, x: torch.Tensor, x_e: torch.Tensor):
         x_0 = self.upsample(x)

         # handling spatial shapes due to the 2x maxpooling with odd edge lengths.
-        dimensions = x.ndim - 2  # type: ignore
+        dimensions = len(x.shape) - 2
         sp = [0] * (dimensions * 2)
         for i in range(dimensions):
             if x_e.shape[-i - 1] != x_0.shape[-i - 1]:
                 sp[i * 2 + 1] = 1
-        if sum(sp) != 0:
-            x_0 = torch.nn.functional.pad(x_0, sp, "replicate")
+        x_0 = torch.nn.functional.pad(x_0, sp, "replicate")
         x = self.convs(torch.cat([x_e, x_0], dim=1))  # input channels: (cat_chns + up_chns)
         return x
diff --git a/monai/utils/module.py b/monai/utils/module.py
index 459486ad32..4b8c1e91e7 100644
--- a/monai/utils/module.py
+++ b/monai/utils/module.py
@@ -9,11 +9,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 import sys
 from importlib import import_module
 from pkgutil import walk_packages
 from re import match
-from typing import Any, Callable, List, Tuple
+from typing import Any, Callable, List, Sequence, Tuple, Union
+
+from .misc import ensure_tuple

 OPTIONAL_IMPORT_MSG_FMT = "{}"

@@ -25,6 +28,7 @@
     "optional_import",
     "load_submodules",
     "get_full_type_name",
+    "has_option",
 ]

@@ -214,3 +218,16 @@ def __call__(self, *_args, **_kwargs):
             raise self._exception

     return _LazyRaise(), False
+
+
+def has_option(obj, keywords: Union[str, Sequence[str]]) -> bool:
+    """
+    Return a boolean indicating whether the given callable `obj` has the `keywords` in its signature.
+    """
+    if not callable(obj):
+        return False
+    sig = inspect.signature(obj)
+    for key in ensure_tuple(keywords):
+        if key not in sig.parameters:
+            return False
+    return True
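Note: `has_option` is what lets `ADN` pass `num_features` or `num_channels` only when the chosen norm constructor accepts it. A small illustration of the intended behaviour (the outputs in the comments are what I would expect from the PyTorch signatures)::

    import torch.nn as nn

    from monai.utils import has_option

    print(has_option(nn.BatchNorm2d, "num_features"))  # True
    print(has_option(nn.GroupNorm, "num_channels"))    # True
    print(has_option(nn.GroupNorm, "num_features"))    # False
    # a sequence of names requires all of them to be present
    print(has_option(nn.InstanceNorm3d, ("num_features", "affine")))  # True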
diff --git a/tests/test_adn.py b/tests/test_adn.py
new file mode 100644
index 0000000000..71ac286b03
--- /dev/null
+++ b/tests/test_adn.py
@@ -0,0 +1,84 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from parameterized import parameterized
+
+from monai.networks.blocks import ADN
+from tests.utils import TorchImageTestCase2D, TorchImageTestCase3D
+
+TEST_CASES_2D = [
+    [{"act": None}],
+    [{"norm_dim": 2}],
+    [{"norm_dim": 2, "act": "relu", "dropout": 0.8, "ordering": "DA"}],
+    [{"dropout_dim": 1, "dropout": 0.8, "ordering": "DA"}],
+    [{"norm": "BATCH", "norm_dim": 2, "in_channels": 1, "dropout_dim": 1, "dropout": 0.8, "ordering": "NDA"}],
+    [{"norm": "BATCH", "norm_dim": 2, "in_channels": 1, "dropout_dim": 1, "dropout": 0.8, "ordering": "AND"}],
+    [{"norm": "INSTANCE", "norm_dim": 2, "dropout_dim": 1, "dropout": 0.8, "ordering": "AND"}],
+    [
+        {
+            "norm": ("GROUP", {"num_groups": 1, "affine": False}),
+            "in_channels": 1,
+            "norm_dim": 2,
+            "dropout_dim": 1,
+            "dropout": 0.8,
+            "ordering": "AND",
+        }
+    ],
+    [{"norm": ("localresponse", {"size": 4}), "norm_dim": 2, "dropout_dim": 1, "dropout": 0.8, "ordering": "AND"}],
+]
+
+TEST_CASES_3D = [
+    [{"norm_dim": 3}],
+    [{"act": "prelu", "dropout_dim": 1, "dropout": 0.8, "ordering": "DA"}],
+    [{"dropout_dim": 1, "dropout": 0.8, "ordering": "DA"}],
+    [{"norm": "BATCH", "norm_dim": 3, "in_channels": 1, "dropout_dim": 1, "dropout": 0.8, "ordering": "NDA"}],
+    [{"norm": "BATCH", "norm_dim": 3, "in_channels": 1, "dropout_dim": 1, "dropout": 0.8, "ordering": "AND"}],
+    [{"norm": "INSTANCE", "norm_dim": 3, "dropout_dim": 1, "dropout": 0.8, "ordering": "AND"}],
+    [
+        {
+            "norm": ("layer", {"normalized_shape": (64, 80)}),
+            "norm_dim": 3,
+            "dropout_dim": 1,
+            "dropout": 0.8,
+            "ordering": "AND",
+        }
+    ],
+]
+
+
+class TestADN2D(TorchImageTestCase2D):
+    @parameterized.expand(TEST_CASES_2D)
+    def test_adn_2d(self, args):
+        adn = ADN(**args)
+        print(adn)
+        out = adn(self.imt)
+        expected_shape = (1, self.input_channels, self.im_shape[0], self.im_shape[1])
+        self.assertEqual(out.shape, expected_shape)
+
+    def test_no_input(self):
+        with self.assertRaises(ValueError):
+            ADN(norm="instance")
+
+
+class TestADN3D(TorchImageTestCase3D):
+    @parameterized.expand(TEST_CASES_3D)
+    def test_adn_3d(self, args):
+        adn = ADN(**args)
+        print(adn)
+        out = adn(self.imt)
+        expected_shape = (1, self.input_channels, self.im_shape[1], self.im_shape[0], self.im_shape[2])
+        self.assertEqual(out.shape, expected_shape)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_basic_unet.py b/tests/test_basic_unet.py
index 33abac86bb..757edb01fb 100644
--- a/tests/test_basic_unet.py
+++ b/tests/test_basic_unet.py
@@ -15,6 +15,7 @@
 from parameterized import parameterized

 from monai.networks.nets import BasicUNet
+from tests.utils import test_script_save

 CASES_2D = []
 for mode in ["pixelshuffle", "nontrainable", "deconv"]:
@@ -71,7 +72,7 @@
 ]


-class TestBaseUNET(unittest.TestCase):
+class TestBasicUNET(unittest.TestCase):
     @parameterized.expand(CASES_2D + CASES_3D)
     def test_shape(self, input_param, input_shape, expected_shape):
         device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -81,6 +82,12 @@ def test_shape(self, input_param, input_shape, expected_shape):
         result = net(torch.randn(input_shape).to(device))
         self.assertEqual(result.shape, expected_shape)

+    def test_script(self):
+        net = BasicUNet(dimensions=2, in_channels=1, out_channels=3)
+        test_data = torch.randn(16, 1, 32, 32)
+        out_orig, out_reloaded = test_script_save(net, test_data)
+        assert torch.allclose(out_orig, out_reloaded)
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_convolutions.py b/tests/test_convolutions.py
index 7755fab5df..bb6ea45e62 100644
--- a/tests/test_convolutions.py
+++ b/tests/test_convolutions.py
@@ -68,13 +68,13 @@ def test_transpose2(self):

 class TestConvolution3D(TorchImageTestCase3D):
     def test_conv1(self):
-        conv = Convolution(3, self.input_channels, self.output_channels)
+        conv = Convolution(3, self.input_channels, self.output_channels, dropout=0.1, adn_ordering="DAN")
         out = conv(self.imt)
         expected_shape = (1, self.output_channels, self.im_shape[1], self.im_shape[0], self.im_shape[2])
         self.assertEqual(out.shape, expected_shape)

     def test_conv1_no_acti(self):
-        conv = Convolution(3, self.input_channels, self.output_channels, act=None)
+        conv = Convolution(3, self.input_channels, self.output_channels, act=None, adn_ordering="AND")
         out = conv(self.imt)
         expected_shape = (1, self.output_channels, self.im_shape[1], self.im_shape[0], self.im_shape[2])
         self.assertEqual(out.shape, expected_shape)
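Note: the factory changes in this patch mean dimension-qualified lookups such as `Norm["group", 2]` now resolve; previously only the dimension-free `Norm["group"]` form worked, because the factory was registered as a zero-argument callable. A small check (module paths in the comments are what I would expect from PyTorch)::

    from monai.networks.layers import Norm

    # the new factories accept and ignore a spatial dimension argument
    print(Norm["group", 2])          # <class 'torch.nn.modules.normalization.GroupNorm'>
    print(Norm["layer", 3])          # <class 'torch.nn.modules.normalization.LayerNorm'>
    print(Norm["localresponse", 1])  # <class 'torch.nn.modules.normalization.LocalResponseNorm'>
    print(Norm["syncbatch", 2])      # <class 'torch.nn.modules.batchnorm.SyncBatchNorm'>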
From 0183355490f1ce3d41329e5770c7f5bae0ffe3b0 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Wed, 21 Oct 2020 09:50:53 +0100
Subject: [PATCH 2/4] fixes unit tests

Signed-off-by: Wenqi Li
---
 monai/networks/blocks/acti_norm.py | 2 +-
 monai/networks/nets/unet.py        | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/monai/networks/blocks/acti_norm.py b/monai/networks/blocks/acti_norm.py
index 5a74ecbeeb..92629d2322 100644
--- a/monai/networks/blocks/acti_norm.py
+++ b/monai/networks/blocks/acti_norm.py
@@ -103,7 +103,7 @@ def __init__(
             # if dropout was specified simply as a p value, use default name and make a keyword map with the value
             if isinstance(dropout, (int, float)):
                 drop_name = Dropout.DROPOUT
-                drop_args = {"p": dropout}
+                drop_args = {"p": float(dropout)}
             else:
                 drop_name, drop_args = split_args(dropout)

diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index ad0e6d375f..d9326176eb 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -35,7 +35,7 @@ def __init__(
         num_res_units: int = 0,
         act=Act.PRELU,
         norm=Norm.INSTANCE,
-        dropout=0,
+        dropout=0.0,
     ) -> None:
         """
         Enhanced version of UNet which has residual units implemented with the ResidualUnit class.
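Note: the `float(dropout)` cast guards against integer ratios such as `dropout=1`, which would otherwise construct `nn.Dropout` with an int `p`; my reading is that this, together with `dropout=0.0` in `UNet`, keeps the argument consistently typed for the TorchScript tests. A hedged check::

    from monai.networks.blocks import ADN

    adn = ADN(ordering="DA", act="relu", dropout=1, dropout_dim=1)
    print(adn)
    # expected:
    # ADN(
    #   (D): Dropout(p=1.0, inplace=False)
    #   (A): ReLU()
    # )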
From 83ac0095f17a5f6a7c43be794eb7e3d97d8af68c Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Wed, 21 Oct 2020 11:10:49 +0100
Subject: [PATCH 3/4] update docstrings

Signed-off-by: Wenqi Li
---
 docs/source/metrics.rst            | 2 +-
 monai/networks/blocks/acti_norm.py | 2 +-
 monai/networks/nets/basic_unet.py  | 8 ++++++++
 3 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index a1af1b4623..a54b070dd2 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -30,5 +30,5 @@ Metrics
 .. autofunction:: compute_average_surface_distance

 `Occlusion sensitivity`
---------------------------
+-----------------------
 .. autofunction:: compute_occlusion_sensitivity
\ No newline at end of file
diff --git a/monai/networks/blocks/acti_norm.py b/monai/networks/blocks/acti_norm.py
index 92629d2322..585726edf2 100644
--- a/monai/networks/blocks/acti_norm.py
+++ b/monai/networks/blocks/acti_norm.py
@@ -40,6 +40,7 @@ class ADN(nn.Sequential):
             - When dropout_dim = 3, randomly zeroes out entire channels (a channel is a 3D feature map).

     Examples::
+
         # activation, group norm, dropout
         >>> norm_params = ("GROUP", {"num_groups": 1, "affine": False})
         >>> ADN(norm=norm_params, in_channels=1, dropout_dim=1, dropout=0.8, ordering="AND")
@@ -57,7 +58,6 @@ class ADN(nn.Sequential):
             (D): Dropout(p=0.8, inplace=False)
         )

-
     See also:

         :py:class:`monai.networks.layers.Dropout`
diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py
index 746378e0a1..ebac0273a9 100644
--- a/monai/networks/nets/basic_unet.py
+++ b/monai/networks/nets/basic_unet.py
@@ -176,9 +176,17 @@ def __init__(
             # for spatial 2D
             >>> net = BasicUNet(dimensions=2, features=(64, 128, 256, 512, 1024, 128))

+            # for spatial 2D, with group norm
+            >>> net = BasicUNet(dimensions=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4}))
+
             # for spatial 3D
             >>> net = BasicUNet(dimensions=3, features=(32, 32, 64, 128, 256, 32))

+        See Also
+
+            - :py:class:`monai.networks.nets.DynUNet`
+            - :py:class:`monai.networks.nets.UNet`
+
         """
         super().__init__()
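Note: the group-norm example added to the `BasicUNet` docstring relies on the `num_channels` plumbing introduced in patch 1; a hedged end-to-end check, with spatial size and channel counts chosen to satisfy the pooling and group-divisibility constraints::

    import torch

    from monai.networks.nets import BasicUNet

    net = BasicUNet(dimensions=2, in_channels=1, out_channels=2, norm=("group", {"num_groups": 4}))
    with torch.no_grad():
        y = net(torch.randn(1, 1, 64, 64))
    print(y.shape)  # expected: torch.Size([1, 2, 64, 64])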
From bb6f3afb3f63e48172113d681721497195cc2664 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Wed, 21 Oct 2020 11:55:33 +0100
Subject: [PATCH 4/4] 1D test cases

Signed-off-by: Wenqi Li
---
 tests/test_basic_unet.py | 20 +++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/tests/test_basic_unet.py b/tests/test_basic_unet.py
index 757edb01fb..b4d3f9fce6 100644
--- a/tests/test_basic_unet.py
+++ b/tests/test_basic_unet.py
@@ -17,6 +17,23 @@
 from monai.networks.nets import BasicUNet
 from tests.utils import test_script_save

+CASES_1D = []
+for mode in ["pixelshuffle", "nontrainable", "deconv", None]:
+    kwargs = {
+        "dimensions": 1,
+        "in_channels": 5,
+        "out_channels": 8,
+    }
+    if mode is not None:
+        kwargs["upsample"] = mode  # type: ignore
+    CASES_1D.append(
+        [
+            kwargs,
+            (10, 5, 17),
+            (10, 8, 17),
+        ]
+    )
+
 CASES_2D = []
 for mode in ["pixelshuffle", "nontrainable", "deconv"]:
     for d1 in range(17, 64, 14):
@@ -73,9 +90,10 @@


 class TestBasicUNET(unittest.TestCase):
-    @parameterized.expand(CASES_2D + CASES_3D)
+    @parameterized.expand(CASES_1D + CASES_2D + CASES_3D)
     def test_shape(self, input_param, input_shape, expected_shape):
         device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(input_param)
         net = BasicUNet(**input_param).to(device)
         net.eval()
         with torch.no_grad():