From 0134abf5dc59569233ef8f8003fa3211d0d853e6 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 19 Nov 2021 17:44:22 +0000 Subject: [PATCH 1/9] adds dints blocks Signed-off-by: Wenqi Li --- docs/source/networks.rst | 20 +++ monai/networks/blocks/__init__.py | 1 + monai/networks/blocks/dints_block.py | 212 +++++++++++++++++++++++++++ tests/test_factorized_increase.py | 34 +++++ tests/test_factorized_reduce.py | 34 +++++ tests/test_p3d_block.py | 41 ++++++ 6 files changed, 342 insertions(+) create mode 100644 monai/networks/blocks/dints_block.py create mode 100644 tests/test_factorized_increase.py create mode 100644 tests/test_factorized_reduce.py create mode 100644 tests/test_p3d_block.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 39889266d9..0aad63b373 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -188,6 +188,26 @@ Blocks .. autoclass:: PatchEmbeddingBlock :members: +`FactorizedIncreaseBlock` +~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: FactorizedIncreaseBlock + :members: + +`FactorizedReduceBlock` +~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: FactorizedReduceBlock + :members: + +`P3DReLUConvNormBlock` +~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: P3DReLUConvNormBlock + :members: + +`ReLUConvNormBlock` +~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ReLUConvNormBlock + :members: + `Warp` ~~~~~~ .. autoclass:: Warp diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index db723f622d..cb3e64304c 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -14,6 +14,7 @@ from .aspp import SimpleASPP from .convolutions import Convolution, ResidualUnit from .crf import CRF +from .dints_block import FactorizedIncreaseBlock, FactorizedReduceBlock, P3DReLUConvNormBlock, ReLUConvNormBlock from .downsample import MaxAvgPool from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding from .fcn import FCN, GCN, MCFCN, Refine diff --git a/monai/networks/blocks/dints_block.py b/monai/networks/blocks/dints_block.py new file mode 100644 index 0000000000..6c4681161a --- /dev/null +++ b/monai/networks/blocks/dints_block.py @@ -0,0 +1,212 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn + +from monai.networks.layers.factories import Conv +from monai.networks.layers.utils import get_act_layer, get_norm_layer + +__all__ = ["FactorizedIncreaseBlock", "FactorizedReduceBlock", "P3DReLUConvNormBlock", "ReLUConvNormBlock"] + + +class FactorizedIncreaseBlock(nn.Module): + """ + Up-sampling the features by 2 using linear interpolation and convolutions. + """ + + def __init__( + self, + in_channel: int, + out_channel: int, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = "INSTANCE", + ): + super().__init__() + self._in_channel = in_channel + self._out_channel = out_channel + self._spatial_dims = spatial_dims + + conv_type = Conv[Conv.CONV, self._spatial_dims] + + self.op = nn.Sequential( + nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True), + get_act_layer(name=act_name), + conv_type( + in_channels=self._in_channel, + out_channels=self._out_channel, + kernel_size=1, + stride=1, + padding=0, + groups=1, + bias=False, + dilation=1, + ), + get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel), + ) + + def forward(self, x): + return self.op(x) + + +class FactorizedReduceBlock(nn.Module): + """ + Down-sampling the feature by 2 using stride. + The length along each spatial dimension must be a multiple of 2. + """ + + def __init__( + self, + in_channel: int, + out_channel: int, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = "INSTANCE", + ): + super().__init__() + self._in_channel = in_channel + self._out_channel = out_channel + self._spatial_dims = spatial_dims + + conv_type = Conv[Conv.CONV, self._spatial_dims] + + self.act = get_act_layer(name=act_name) + self.conv_1 = conv_type( + in_channels=self._in_channel, + out_channels=self._out_channel // 2, + kernel_size=1, + stride=2, + padding=0, + groups=1, + bias=False, + dilation=1, + ) + self.conv_2 = conv_type( + in_channels=self._in_channel, + out_channels=self._out_channel // 2, + kernel_size=1, + stride=2, + padding=0, + groups=1, + bias=False, + dilation=1, + ) + self.norm = get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel) + + def forward(self, x): + x = self.act(x) + out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:, 1:])], dim=1) + out = self.norm(out) + return out + + +class P3DReLUConvNormBlock(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: int, + padding: int, + p3dmode: int = 0, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = "INSTANCE", + ): + super().__init__() + self._in_channel = in_channel + self._out_channel = out_channel + self._p3dmode = p3dmode + + conv_type = Conv[Conv.CONV, 3] + + if self._p3dmode == 0: # 3 x 3 x 1 + kernel_size0 = (kernel_size, kernel_size, 1) + kernel_size1 = (1, 1, kernel_size) + padding0 = (padding, padding, 0) + padding1 = (0, 0, padding) + elif self._p3dmode == 1: # 3 x 1 x 3 + kernel_size0 = (kernel_size, 1, kernel_size) + kernel_size1 = (1, kernel_size, 1) + padding0 = (padding, 0, padding) + padding1 = (0, padding, 0) + elif self._p3dmode == 2: # 1 x 3 x 3 + kernel_size0 = (1, kernel_size, kernel_size) + kernel_size1 = (kernel_size, 1, 1) + padding0 = (0, padding, padding) + padding1 = (padding, 0, 0) + + self.op = nn.Sequential( + get_act_layer(name=act_name), + conv_type( + in_channels=self._in_channel, + out_channels=self._out_channel, + kernel_size=kernel_size0, + stride=1, + padding=padding0, + groups=1, + bias=False, + dilation=1, + ), + conv_type( + in_channels=self._in_channel, + out_channels=self._out_channel, + kernel_size=kernel_size1, + stride=1, + padding=padding1, + groups=1, + bias=False, + dilation=1, + ), + get_norm_layer(name=norm_name, spatial_dims=3, channels=self._out_channel), + ) + + def forward(self, x): + return self.op(x) + + +class ReLUConvNormBlock(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: int = 3, + padding: int = 1, + spatial_dims: int = 3, + act_name: Union[Tuple, str] = "RELU", + norm_name: Union[Tuple, str] = "INSTANCE", + ): + super().__init__() + self._in_channel = in_channel + self._out_channel = out_channel + self._spatial_dims = spatial_dims + + conv_type = Conv[Conv.CONV, self._spatial_dims] + + self.op = nn.Sequential( + get_act_layer(name=act_name), + conv_type( + in_channels=self._in_channel, + out_channels=self._out_channel, + kernel_size=kernel_size, + stride=1, + padding=padding, + groups=1, + bias=False, + dilation=1, + ), + get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel), + ) + + def forward(self, x): + return self.op(x) diff --git a/tests/test_factorized_increase.py b/tests/test_factorized_increase.py new file mode 100644 index 0000000000..6ef3bb465d --- /dev/null +++ b/tests/test_factorized_increase.py @@ -0,0 +1,34 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.blocks.dints_block import FactorizedIncreaseBlock + +TEST_CASES_3D = [ + [{"in_channel": 32, "out_channel": 16}, (7, 32, 24, 16, 8), (7, 16, 48, 32, 16)], + [{"in_channel": 1, "out_channel": 2}, (1, 1, 1, 1, 1), (1, 2, 2, 2, 2)], +] + + +class TestFactInc(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) + def test_factorized_increase_3d(self, input_param, input_shape, expected_shape): + net = FactorizedIncreaseBlock(**input_param) + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_factorized_reduce.py b/tests/test_factorized_reduce.py new file mode 100644 index 0000000000..3184d1c6ca --- /dev/null +++ b/tests/test_factorized_reduce.py @@ -0,0 +1,34 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.blocks.dints_block import FactorizedReduceBlock + +TEST_CASES_3D = [ + [{"in_channel": 32, "out_channel": 16}, (7, 32, 24, 16, 8), (7, 16, 12, 8, 4)], + [{"in_channel": 16, "out_channel": 32}, (7, 16, 23, 15, 7), (7, 32, 11, 7, 3)], +] + + +class TestFactRed(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) + def test_factorized_increase_3d(self, input_param, input_shape, expected_shape): + net = FactorizedReduceBlock(**input_param) + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_p3d_block.py b/tests/test_p3d_block.py new file mode 100644 index 0000000000..b97954ce61 --- /dev/null +++ b/tests/test_p3d_block.py @@ -0,0 +1,41 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.blocks.dints_block import P3DActiConvNormBlock + +TEST_CASES_3D = [ + [{"c_in": 32, "c_out": 16, "kernel_size": 3, "padding": 0, "p3dmode": 0}, (7, 32, 16, 32, 8), (7, 16, 14, 30, 6)], + [ + {"c_in": 32, "c_out": 16, "kernel_size": 3, "padding": 1, "p3dmode": 0}, # check padding + (7, 32, 16, 32, 8), + (7, 16, 16, 32, 8), + ], + [{"c_in": 32, "c_out": 16, "kernel_size": 3, "padding": 0, "p3dmode": 1}, (7, 32, 16, 32, 8), (7, 16, 14, 30, 6)], + [{"c_in": 32, "c_out": 16, "kernel_size": 3, "padding": 0, "p3dmode": 2}, (7, 32, 16, 32, 8), (7, 16, 14, 30, 6)], + [{"c_in": 32, "c_out": 16, "kernel_size": 4, "padding": 0, "p3dmode": 0}, (7, 32, 16, 32, 8), (7, 16, 13, 29, 5)], +] + + +class TestP3D(unittest.TestCase): + @parameterized.expand(TEST_CASES_3D) + def test_factorized_increase_3d(self, input_param, input_shape, expected_shape): + net = P3DActiConvNormBlock(**input_param) + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() From 79276fc9d9931edea3a2ddde161c80b92899d020 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 19 Nov 2021 18:17:32 +0000 Subject: [PATCH 2/9] docstring updates Signed-off-by: Wenqi Li --- docs/source/networks.rst | 8 ++-- monai/networks/blocks/__init__.py | 2 +- monai/networks/blocks/dints_block.py | 64 +++++++++++++++++++++++++--- tests/test_p3d_block.py | 10 ++--- 4 files changed, 68 insertions(+), 16 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 0aad63b373..452df4580d 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -198,14 +198,14 @@ Blocks .. autoclass:: FactorizedReduceBlock :members: -`P3DReLUConvNormBlock` +`P3DActiConvNormBlock` ~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: P3DReLUConvNormBlock +.. autoclass:: P3DActiConvNormBlock :members: -`ReLUConvNormBlock` +`ActiConvNormBlock` ~~~~~~~~~~~~~~~~~~~ -.. autoclass:: ReLUConvNormBlock +.. autoclass:: ActiConvNormBlock :members: `Warp` diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index cb3e64304c..68c5be82d7 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -14,7 +14,7 @@ from .aspp import SimpleASPP from .convolutions import Convolution, ResidualUnit from .crf import CRF -from .dints_block import FactorizedIncreaseBlock, FactorizedReduceBlock, P3DReLUConvNormBlock, ReLUConvNormBlock +from .dints_block import FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock, ActiConvNormBlock from .downsample import MaxAvgPool from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding from .fcn import FCN, GCN, MCFCN, Refine diff --git a/monai/networks/blocks/dints_block.py b/monai/networks/blocks/dints_block.py index 6c4681161a..c73ced3d41 100644 --- a/monai/networks/blocks/dints_block.py +++ b/monai/networks/blocks/dints_block.py @@ -18,12 +18,12 @@ from monai.networks.layers.factories import Conv from monai.networks.layers.utils import get_act_layer, get_norm_layer -__all__ = ["FactorizedIncreaseBlock", "FactorizedReduceBlock", "P3DReLUConvNormBlock", "ReLUConvNormBlock"] +__all__ = ["FactorizedIncreaseBlock", "FactorizedReduceBlock", "P3DActiConvNormBlock", "ActiConvNormBlock"] class FactorizedIncreaseBlock(nn.Module): """ - Up-sampling the features by 2 using linear interpolation and convolutions. + Up-sampling the features by two using linear interpolation and convolutions. """ def __init__( @@ -34,6 +34,14 @@ def __init__( act_name: Union[Tuple, str] = "RELU", norm_name: Union[Tuple, str] = "INSTANCE", ): + """ + Args: + in_channel: number of input channels + out_channel: number of output channels + spatial_dims: number of spatial dimensions + act_name: activation layer type and arguments. + norm_name: feature normalization type and arguments. + """ super().__init__() self._in_channel = in_channel self._out_channel = out_channel @@ -75,6 +83,14 @@ def __init__( act_name: Union[Tuple, str] = "RELU", norm_name: Union[Tuple, str] = "INSTANCE", ): + """ + Args: + in_channel: number of input channels + out_channel: number of output channels. + spatial_dims: number of spatial dimensions. + act_name: activation layer type and arguments. + norm_name: feature normalization type and arguments. + """ super().__init__() self._in_channel = in_channel self._out_channel = out_channel @@ -106,27 +122,48 @@ def __init__( self.norm = get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel) def forward(self, x): + """ + The length along each spatial dimension must be a multiple of 2. + """ x = self.act(x) out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:, 1:])], dim=1) out = self.norm(out) return out -class P3DReLUConvNormBlock(nn.Module): +class P3DActiConvNormBlock(nn.Module): + """ + -- (act) -- (conv) -- (norm) -- + """ def __init__( self, in_channel: int, out_channel: int, kernel_size: int, padding: int, - p3dmode: int = 0, + mode: int = 0, act_name: Union[Tuple, str] = "RELU", norm_name: Union[Tuple, str] = "INSTANCE", ): + """ + Args: + in_channel: number of input channels. + out_channel: number of output channels. + kernel_size: kernel size to be expanded to 3D. + padding: padding size to be expanded to 3D. + mode: mode for the anisotropic kernels: + + - 0: ``(k, k, 1)``, ``(1, 1, k)``, + - 1: ``(k, 1, k)``, ``(1, k, 1)``, + - 2: ``(1, k, k)``. ``(k, 1, 1)``. + + act_name:activation layer type and arguments. + norm_name: feature normalization type and arguments. + """ super().__init__() self._in_channel = in_channel self._out_channel = out_channel - self._p3dmode = p3dmode + self._p3dmode = int(mode) conv_type = Conv[Conv.CONV, 3] @@ -145,6 +182,8 @@ def __init__( kernel_size1 = (kernel_size, 1, 1) padding0 = (0, padding, padding) padding1 = (padding, 0, 0) + else: + raise ValueError("`mode` must be 0, 1, or 2.") self.op = nn.Sequential( get_act_layer(name=act_name), @@ -175,7 +214,10 @@ def forward(self, x): return self.op(x) -class ReLUConvNormBlock(nn.Module): +class ActiConvNormBlock(nn.Sequential): + """ + -- (Acti) -- (Conv) -- (Norm) -- + """ def __init__( self, in_channel: int, @@ -186,6 +228,16 @@ def __init__( act_name: Union[Tuple, str] = "RELU", norm_name: Union[Tuple, str] = "INSTANCE", ): + """ + Args: + in_channel: number of input channels. + out_channel: number of output channels. + kernel_size: kernel size of the convolution. + padding: padding size of the convolution. + spatial_dims: number of spatial dimensions. + act_name: activation layer type and arguments. + norm_name: feature normalization type and arguments. + """ super().__init__() self._in_channel = in_channel self._out_channel = out_channel diff --git a/tests/test_p3d_block.py b/tests/test_p3d_block.py index b97954ce61..15f0366d82 100644 --- a/tests/test_p3d_block.py +++ b/tests/test_p3d_block.py @@ -17,15 +17,15 @@ from monai.networks.blocks.dints_block import P3DActiConvNormBlock TEST_CASES_3D = [ - [{"c_in": 32, "c_out": 16, "kernel_size": 3, "padding": 0, "p3dmode": 0}, (7, 32, 16, 32, 8), (7, 16, 14, 30, 6)], + [{"c_in": 32, "c_out": 16, "kernel_size": 3, "padding": 0, "mode": 0}, (7, 32, 16, 32, 8), (7, 16, 14, 30, 6)], [ - {"c_in": 32, "c_out": 16, "kernel_size": 3, "padding": 1, "p3dmode": 0}, # check padding + {"c_in": 32, "c_out": 16, "kernel_size": 3, "padding": 1, "mode": 0}, # check padding (7, 32, 16, 32, 8), (7, 16, 16, 32, 8), ], - [{"c_in": 32, "c_out": 16, "kernel_size": 3, "padding": 0, "p3dmode": 1}, (7, 32, 16, 32, 8), (7, 16, 14, 30, 6)], - [{"c_in": 32, "c_out": 16, "kernel_size": 3, "padding": 0, "p3dmode": 2}, (7, 32, 16, 32, 8), (7, 16, 14, 30, 6)], - [{"c_in": 32, "c_out": 16, "kernel_size": 4, "padding": 0, "p3dmode": 0}, (7, 32, 16, 32, 8), (7, 16, 13, 29, 5)], + [{"c_in": 32, "c_out": 16, "kernel_size": 3, "padding": 0, "mode": 1}, (7, 32, 16, 32, 8), (7, 16, 14, 30, 6)], + [{"c_in": 32, "c_out": 16, "kernel_size": 3, "padding": 0, "mode": 2}, (7, 32, 16, 32, 8), (7, 16, 14, 30, 6)], + [{"c_in": 32, "c_out": 16, "kernel_size": 4, "padding": 0, "mode": 0}, (7, 32, 16, 32, 8), (7, 16, 13, 29, 5)], ] From 96dbb9c6c3088553d798258e49cdb58aa33af93a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 19 Nov 2021 18:21:19 +0000 Subject: [PATCH 3/9] update to use sequential Signed-off-by: Wenqi Li --- monai/networks/blocks/__init__.py | 2 +- monai/networks/blocks/dints_block.py | 59 ++++++++++++++-------------- 2 files changed, 31 insertions(+), 30 deletions(-) diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index 68c5be82d7..01a5bfca2a 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -14,7 +14,7 @@ from .aspp import SimpleASPP from .convolutions import Convolution, ResidualUnit from .crf import CRF -from .dints_block import FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock, ActiConvNormBlock +from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock from .downsample import MaxAvgPool from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding from .fcn import FCN, GCN, MCFCN, Refine diff --git a/monai/networks/blocks/dints_block.py b/monai/networks/blocks/dints_block.py index c73ced3d41..54cbd56e76 100644 --- a/monai/networks/blocks/dints_block.py +++ b/monai/networks/blocks/dints_block.py @@ -10,10 +10,9 @@ # limitations under the License. -from typing import Optional, Sequence, Tuple, Union +from typing import Tuple, Union import torch -import torch.nn as nn from monai.networks.layers.factories import Conv from monai.networks.layers.utils import get_act_layer, get_norm_layer @@ -21,7 +20,7 @@ __all__ = ["FactorizedIncreaseBlock", "FactorizedReduceBlock", "P3DActiConvNormBlock", "ActiConvNormBlock"] -class FactorizedIncreaseBlock(nn.Module): +class FactorizedIncreaseBlock(torch.nn.Sequential): """ Up-sampling the features by two using linear interpolation and convolutions. """ @@ -49,9 +48,10 @@ def __init__( conv_type = Conv[Conv.CONV, self._spatial_dims] - self.op = nn.Sequential( - nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True), - get_act_layer(name=act_name), + self.add_module("up", torch.nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True)) + self.add_module("acti", get_act_layer(name=act_name)) + self.add_module( + "conv", conv_type( in_channels=self._in_channel, out_channels=self._out_channel, @@ -62,14 +62,13 @@ def __init__( bias=False, dilation=1, ), - get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel), ) - - def forward(self, x): - return self.op(x) + self.add_module( + "norm", get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel) + ) -class FactorizedReduceBlock(nn.Module): +class FactorizedReduceBlock(torch.nn.Module): """ Down-sampling the feature by 2 using stride. The length along each spatial dimension must be a multiple of 2. @@ -131,10 +130,11 @@ def forward(self, x): return out -class P3DActiConvNormBlock(nn.Module): +class P3DActiConvNormBlock(torch.nn.Sequential): """ -- (act) -- (conv) -- (norm) -- """ + def __init__( self, in_channel: int, @@ -167,17 +167,17 @@ def __init__( conv_type = Conv[Conv.CONV, 3] - if self._p3dmode == 0: # 3 x 3 x 1 + if self._p3dmode == 0: # (k, k, 1), (1, 1, k) kernel_size0 = (kernel_size, kernel_size, 1) kernel_size1 = (1, 1, kernel_size) padding0 = (padding, padding, 0) padding1 = (0, 0, padding) - elif self._p3dmode == 1: # 3 x 1 x 3 + elif self._p3dmode == 1: # (k, 1, k), (1, k, 1) kernel_size0 = (kernel_size, 1, kernel_size) kernel_size1 = (1, kernel_size, 1) padding0 = (padding, 0, padding) padding1 = (0, padding, 0) - elif self._p3dmode == 2: # 1 x 3 x 3 + elif self._p3dmode == 2: # (1, k, k), (k, 1, 1) kernel_size0 = (1, kernel_size, kernel_size) kernel_size1 = (kernel_size, 1, 1) padding0 = (0, padding, padding) @@ -185,8 +185,9 @@ def __init__( else: raise ValueError("`mode` must be 0, 1, or 2.") - self.op = nn.Sequential( - get_act_layer(name=act_name), + self.add_module("acti", get_act_layer(name=act_name)) + self.add_module( + "conv", conv_type( in_channels=self._in_channel, out_channels=self._out_channel, @@ -197,6 +198,9 @@ def __init__( bias=False, dilation=1, ), + ) + self.add_module( + "conv_1", conv_type( in_channels=self._in_channel, out_channels=self._out_channel, @@ -207,17 +211,15 @@ def __init__( bias=False, dilation=1, ), - get_norm_layer(name=norm_name, spatial_dims=3, channels=self._out_channel), ) - - def forward(self, x): - return self.op(x) + self.add_module("norm", get_norm_layer(name=norm_name, spatial_dims=3, channels=self._out_channel)) -class ActiConvNormBlock(nn.Sequential): +class ActiConvNormBlock(torch.nn.Sequential): """ -- (Acti) -- (Conv) -- (Norm) -- """ + def __init__( self, in_channel: int, @@ -244,9 +246,9 @@ def __init__( self._spatial_dims = spatial_dims conv_type = Conv[Conv.CONV, self._spatial_dims] - - self.op = nn.Sequential( - get_act_layer(name=act_name), + self.add_module("acti", get_act_layer(name=act_name)) + self.add_module( + "conv", conv_type( in_channels=self._in_channel, out_channels=self._out_channel, @@ -257,8 +259,7 @@ def __init__( bias=False, dilation=1, ), - get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel), ) - - def forward(self, x): - return self.op(x) + self.add_module( + "norm", get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel) + ) From 15c7deb4e936acc6a31d6fff04ee316f4aa7b2d7 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 19 Nov 2021 18:36:30 +0000 Subject: [PATCH 4/9] update tests Signed-off-by: Wenqi Li --- monai/networks/blocks/dints_block.py | 2 +- tests/test_p3d_block.py | 26 +++++++++++++++++++++----- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/monai/networks/blocks/dints_block.py b/monai/networks/blocks/dints_block.py index 54cbd56e76..6d0195a17c 100644 --- a/monai/networks/blocks/dints_block.py +++ b/monai/networks/blocks/dints_block.py @@ -202,7 +202,7 @@ def __init__( self.add_module( "conv_1", conv_type( - in_channels=self._in_channel, + in_channels=self._out_channel, out_channels=self._out_channel, kernel_size=kernel_size1, stride=1, diff --git a/tests/test_p3d_block.py b/tests/test_p3d_block.py index 15f0366d82..1a424e60d9 100644 --- a/tests/test_p3d_block.py +++ b/tests/test_p3d_block.py @@ -17,15 +17,31 @@ from monai.networks.blocks.dints_block import P3DActiConvNormBlock TEST_CASES_3D = [ - [{"c_in": 32, "c_out": 16, "kernel_size": 3, "padding": 0, "mode": 0}, (7, 32, 16, 32, 8), (7, 16, 14, 30, 6)], [ - {"c_in": 32, "c_out": 16, "kernel_size": 3, "padding": 1, "mode": 0}, # check padding + {"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 0, "mode": 0}, + (7, 32, 16, 32, 8), + (7, 16, 14, 30, 6), + ], + [ + {"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 1, "mode": 0}, # check padding (7, 32, 16, 32, 8), (7, 16, 16, 32, 8), ], - [{"c_in": 32, "c_out": 16, "kernel_size": 3, "padding": 0, "mode": 1}, (7, 32, 16, 32, 8), (7, 16, 14, 30, 6)], - [{"c_in": 32, "c_out": 16, "kernel_size": 3, "padding": 0, "mode": 2}, (7, 32, 16, 32, 8), (7, 16, 14, 30, 6)], - [{"c_in": 32, "c_out": 16, "kernel_size": 4, "padding": 0, "mode": 0}, (7, 32, 16, 32, 8), (7, 16, 13, 29, 5)], + [ + {"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 0, "mode": 1}, + (7, 32, 16, 32, 8), + (7, 16, 14, 30, 6), + ], + [ + {"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 0, "mode": 2}, + (7, 32, 16, 32, 8), + (7, 16, 14, 30, 6), + ], + [ + {"in_channel": 32, "out_channel": 16, "kernel_size": 4, "padding": 0, "mode": 0}, + (7, 32, 16, 32, 8), + (7, 16, 13, 29, 5), + ], ] From a1859b6453398a3fd04620f30a84a0ee024116b6 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 19 Nov 2021 18:37:53 +0000 Subject: [PATCH 5/9] fixes tests Signed-off-by: Wenqi Li --- tests/test_factorized_reduce.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_factorized_reduce.py b/tests/test_factorized_reduce.py index 3184d1c6ca..48a91a7c14 100644 --- a/tests/test_factorized_reduce.py +++ b/tests/test_factorized_reduce.py @@ -18,7 +18,7 @@ TEST_CASES_3D = [ [{"in_channel": 32, "out_channel": 16}, (7, 32, 24, 16, 8), (7, 16, 12, 8, 4)], - [{"in_channel": 16, "out_channel": 32}, (7, 16, 23, 15, 7), (7, 32, 11, 7, 3)], + [{"in_channel": 16, "out_channel": 32}, (7, 16, 22, 14, 6), (7, 32, 11, 7, 3)], ] From 500da75f9a5be2ab3e062a9bb5a1be1836085477 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 19 Nov 2021 18:45:28 +0000 Subject: [PATCH 6/9] update tests Signed-off-by: Wenqi Li --- monai/networks/blocks/dints_block.py | 6 ++--- tests/test_acn_block.py | 38 ++++++++++++++++++++++++++++ tests/test_factorized_reduce.py | 2 +- tests/test_p3d_block.py | 2 +- 4 files changed, 43 insertions(+), 5 deletions(-) create mode 100644 tests/test_acn_block.py diff --git a/monai/networks/blocks/dints_block.py b/monai/networks/blocks/dints_block.py index 6d0195a17c..1869e09268 100644 --- a/monai/networks/blocks/dints_block.py +++ b/monai/networks/blocks/dints_block.py @@ -120,7 +120,7 @@ def __init__( ) self.norm = get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ The length along each spatial dimension must be a multiple of 2. """ @@ -190,7 +190,7 @@ def __init__( "conv", conv_type( in_channels=self._in_channel, - out_channels=self._out_channel, + out_channels=self._in_channel, kernel_size=kernel_size0, stride=1, padding=padding0, @@ -202,7 +202,7 @@ def __init__( self.add_module( "conv_1", conv_type( - in_channels=self._out_channel, + in_channels=self._in_channel, out_channels=self._out_channel, kernel_size=kernel_size1, stride=1, diff --git a/tests/test_acn_block.py b/tests/test_acn_block.py new file mode 100644 index 0000000000..5e2467a565 --- /dev/null +++ b/tests/test_acn_block.py @@ -0,0 +1,38 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.blocks.dints_block import ActiConvNormBlock + +TEST_CASES = [ + [{"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 1}, (7, 32, 16, 31, 7), (7, 16, 16, 31, 7)], + [ + {"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 1, "spatial_dims": 2}, + (7, 32, 13, 32), + (7, 16, 13, 32), + ], +] + + +class TestACNBlock(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_acn_block(self, input_param, input_shape, expected_shape): + net = ActiConvNormBlock(**input_param) + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_factorized_reduce.py b/tests/test_factorized_reduce.py index 48a91a7c14..9ed1e8f6ca 100644 --- a/tests/test_factorized_reduce.py +++ b/tests/test_factorized_reduce.py @@ -24,7 +24,7 @@ class TestFactRed(unittest.TestCase): @parameterized.expand(TEST_CASES_3D) - def test_factorized_increase_3d(self, input_param, input_shape, expected_shape): + def test_factorized_reduce_3d(self, input_param, input_shape, expected_shape): net = FactorizedReduceBlock(**input_param) result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_p3d_block.py b/tests/test_p3d_block.py index 1a424e60d9..9953a4a0fe 100644 --- a/tests/test_p3d_block.py +++ b/tests/test_p3d_block.py @@ -47,7 +47,7 @@ class TestP3D(unittest.TestCase): @parameterized.expand(TEST_CASES_3D) - def test_factorized_increase_3d(self, input_param, input_shape, expected_shape): + def test_3d(self, input_param, input_shape, expected_shape): net = P3DActiConvNormBlock(**input_param) result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) From dc4409ee5a00935283e95995b3204e775017cb9e Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 19 Nov 2021 22:33:02 +0000 Subject: [PATCH 7/9] update based on comments Signed-off-by: Wenqi Li --- tests/test_p3d_block.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/tests/test_p3d_block.py b/tests/test_p3d_block.py index 9953a4a0fe..b9237cba01 100644 --- a/tests/test_p3d_block.py +++ b/tests/test_p3d_block.py @@ -33,12 +33,26 @@ (7, 16, 14, 30, 6), ], [ - {"in_channel": 32, "out_channel": 16, "kernel_size": 3, "padding": 0, "mode": 2}, + { + "in_channel": 32, + "out_channel": 16, + "kernel_size": 3, + "padding": 0, + "mode": 2, + "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.2}), + }, (7, 32, 16, 32, 8), (7, 16, 14, 30, 6), ], [ - {"in_channel": 32, "out_channel": 16, "kernel_size": 4, "padding": 0, "mode": 0}, + { + "in_channel": 32, + "out_channel": 16, + "kernel_size": 4, + "padding": 0, + "mode": 0, + "norm_name": ("INSTANCE", {"affine": True}), + }, (7, 32, 16, 32, 8), (7, 16, 13, 29, 5), ], @@ -52,6 +66,10 @@ def test_3d(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) + def test_ill(self): + with self.assertRaises(ValueError): + P3DActiConvNormBlock(in_channel=32, out_channel=16, kernel_size=3, padding=0, mode=3) + if __name__ == "__main__": unittest.main() From bb4a5973ee00bafc7ae7b3d49fb1c79e6dccf064 Mon Sep 17 00:00:00 2001 From: dongy Date: Tue, 16 Nov 2021 08:30:20 -0800 Subject: [PATCH 8/9] update scripts Signed-off-by: dongy --- monai/networks/nets/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index f7348b7b44..f59d3c35d8 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -29,6 +29,7 @@ densenet201, densenet264, ) +from .dints import DiNTS from .dynunet import DynUNet, DynUnet, Dynunet from .efficientnet import ( BlockArgs, From 9cca0f8e2c07ebfeaef59372276be20387ab2c3b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 19 Nov 2021 22:44:30 +0000 Subject: [PATCH 9/9] temp remote dints import Signed-off-by: Wenqi Li --- monai/networks/nets/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index f59d3c35d8..3dcb0782ed 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -29,7 +29,8 @@ densenet201, densenet264, ) -from .dints import DiNTS + +# from .dints import DiNTS from .dynunet import DynUNet, DynUnet, Dynunet from .efficientnet import ( BlockArgs,