From eba5f2653a598d5034c5d2ee3c7868fed7d7478f Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Sun, 17 Jan 2021 17:50:38 +0000 Subject: [PATCH 1/2] 1442 add initialization Signed-off-by: kate-sann5100 --- monai/networks/blocks/localnet_block.py | 9 +++++++++ monai/networks/nets/localnet.py | 9 ++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py index ee7fac0690..9828cbb81d 100644 --- a/monai/networks/blocks/localnet_block.py +++ b/monai/networks/blocks/localnet_block.py @@ -285,6 +285,7 @@ def __init__( in_channels: int, out_channels: int, act: Optional[Union[Tuple, str]] = "RELU", + initializer: str = "kaiming_uniform", ) -> None: """ Args: @@ -298,6 +299,14 @@ def __init__( self.conv_block = get_conv_block( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, act=act, norm=None ) + if initializer == "kaiming_uniform": + nn.init.kaiming_normal_(self.conv_block.conv.weight) + elif initializer == "zeros": + nn.init.zeros_(self.conv_block.conv.weight) + else: + raise ValueError( + f"initializer {initializer} is not supported, " "currently supporting kaiming_uniform and zeros" + ) def forward(self, x) -> torch.Tensor: """ diff --git a/monai/networks/nets/localnet.py b/monai/networks/nets/localnet.py index 1bb3dcbc21..ea8abca185 100644 --- a/monai/networks/nets/localnet.py +++ b/monai/networks/nets/localnet.py @@ -32,15 +32,17 @@ def __init__( num_channel_initial: int, extract_levels: List[int], out_activation: Optional[Union[Tuple, str]], + out_initializer: str = "kaiming_uniform", ) -> None: """ Args: spatial_dims: number of spatial dimensions. in_channels: number of input channels. out_channels: number of output channels. - num_channel_initial: number of initial channels, - extract_levels: number of extraction levels, - out_activation: activation to use at end layer, + num_channel_initial: number of initial channels. + extract_levels: number of extraction levels. + out_activation: activation to use at end layer. + out_initializer: initializer for extraction layers. """ super(LocalNet, self).__init__() self.extract_levels = extract_levels @@ -85,6 +87,7 @@ def __init__( in_channels=num_channels[level], out_channels=out_channels, act=out_activation, + initializer=out_initializer, ) for level in self.extract_levels ] From 0ccaee969d8ee572aca019fe4a397f9538d3888b Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Tue, 19 Jan 2021 14:43:12 +0000 Subject: [PATCH 2/2] 1442 fix typing and add test cases Signed-off-by: kate-sann5100 --- monai/networks/blocks/localnet_block.py | 23 ++++++++------ tests/test_localnet.py | 41 ++++++++++--------------- tests/test_localnet_block.py | 15 ++++----- 3 files changed, 36 insertions(+), 43 deletions(-) diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py index 9828cbb81d..4166c08774 100644 --- a/monai/networks/blocks/localnet_block.py +++ b/monai/networks/blocks/localnet_block.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Tuple, Type, Union import torch from torch import nn @@ -6,7 +6,7 @@ from monai.networks.blocks import Convolution from monai.networks.layers import same_padding -from monai.networks.layers.factories import Norm, Pool +from monai.networks.layers.factories import Conv, Norm, Pool def get_conv_block( @@ -299,14 +299,17 @@ def __init__( self.conv_block = get_conv_block( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, act=act, norm=None ) - if initializer == "kaiming_uniform": - nn.init.kaiming_normal_(self.conv_block.conv.weight) - elif initializer == "zeros": - nn.init.zeros_(self.conv_block.conv.weight) - else: - raise ValueError( - f"initializer {initializer} is not supported, " "currently supporting kaiming_uniform and zeros" - ) + conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims] + for m in self.conv_block.modules(): + if isinstance(m, conv_type): + if initializer == "kaiming_uniform": + nn.init.kaiming_normal_(torch.as_tensor(m.weight)) + elif initializer == "zeros": + nn.init.zeros_(torch.as_tensor(m.weight)) + else: + raise ValueError( + f"initializer {initializer} is not supported, " "currently supporting kaiming_uniform and zeros" + ) def forward(self, x) -> torch.Tensor: """ diff --git a/tests/test_localnet.py b/tests/test_localnet.py index d4f812e811..97a10d0c83 100644 --- a/tests/test_localnet.py +++ b/tests/test_localnet.py @@ -10,15 +10,6 @@ device = "cuda" if torch.cuda.is_available() else "cpu" -param_variations_2d = { - "spatial_dims": 2, - "in_channels": 2, - "out_channels": 2, - "num_channel_initial": 16, - "extract_levels": [0, 1, 2], - "out_activation": ["sigmoid", None], -} - TEST_CASE_LOCALNET_2D = [ [ { @@ -41,23 +32,25 @@ for num_channel_initial in [4, 16, 32]: for extract_levels in [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]: for out_activation in ["sigmoid", None]: - TEST_CASE_LOCALNET_3D.append( - [ - { - "spatial_dims": 3, - "in_channels": in_channels, - "out_channels": out_channels, - "num_channel_initial": num_channel_initial, - "extract_levels": extract_levels, - "out_activation": out_activation, - }, - (1, in_channels, 16, 16, 16), - (1, out_channels, 16, 16, 16), - ] - ) + for out_initializer in ["kaiming_uniform", "zeros"]: + TEST_CASE_LOCALNET_3D.append( + [ + { + "spatial_dims": 3, + "in_channels": in_channels, + "out_channels": out_channels, + "num_channel_initial": num_channel_initial, + "extract_levels": extract_levels, + "out_activation": out_activation, + "out_initializer": out_initializer, + }, + (1, in_channels, 16, 16, 16), + (1, out_channels, 16, 16, 16), + ] + ) -class TestDynUNet(unittest.TestCase): +class TestLocalNet(unittest.TestCase): @parameterized.expand(TEST_CASE_LOCALNET_2D + TEST_CASE_LOCALNET_3D) def test_shape(self, input_param, input_shape, expected_shape): net = LocalNet(**input_param).to(device) diff --git a/tests/test_localnet_block.py b/tests/test_localnet_block.py index af5ef19222..e6171aeae9 100644 --- a/tests/test_localnet_block.py +++ b/tests/test_localnet_block.py @@ -17,15 +17,8 @@ TEST_CASE_UP_SAMPLE = [[{"spatial_dims": spatial_dims, "in_channels": 4, "out_channels": 2}] for spatial_dims in [2, 3]] TEST_CASE_EXTRACT = [ - [ - { - "spatial_dims": spatial_dims, - "in_channels": 2, - "out_channels": 3, - "act": act, - } - ] - for spatial_dims, act in zip([2, 3], ["sigmoid", None]) + [{"spatial_dims": spatial_dims, "in_channels": 2, "out_channels": 3, "act": act, "initializer": initializer}] + for spatial_dims, act, initializer in zip([2, 3], ["sigmoid", None], ["kaiming_uniform", "zeros"]) ] in_size = 4 @@ -93,6 +86,10 @@ def test_shape(self, input_param): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) + def test_ill_arg(self): + with self.assertRaises(ValueError): + LocalNetFeatureExtractorBlock(spatial_dims=2, in_channels=2, out_channels=2, initializer="none") + if __name__ == "__main__": unittest.main()