From 38c4e90c9eab16f8eac5caf460cf0b2d2738aca6 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 13 Sep 2021 12:38:27 +0100 Subject: [PATCH 1/5] con1_padding -> conv1_padding Signed-off-by: Wenqi Li --- monai/networks/nets/resnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index a5e6b7ab81..ecb5b850d1 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -198,7 +198,7 @@ def __init__( ] block_avgpool = get_avgpool() - conv1_kernel, conv1_stride, con1_padding = get_conv1(conv1_t_size, conv1_t_stride) + conv1_kernel, conv1_stride, conv1_padding = get_conv1(conv1_t_size, conv1_t_stride) block_inplanes = [int(x * widen_factor) for x in block_inplanes] self.in_planes = block_inplanes[0] @@ -209,7 +209,7 @@ def __init__( self.in_planes, kernel_size=conv1_kernel[spatial_dims], stride=conv1_stride[spatial_dims], - padding=con1_padding[spatial_dims], + padding=conv1_padding[spatial_dims], bias=False, ) self.bn1 = norm_type(self.in_planes) From 1e16e6adb13c86be44832d5d804ccf0750964a52 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 13 Sep 2021 13:12:13 +0100 Subject: [PATCH 2/5] simpler init. Signed-off-by: Wenqi Li --- monai/networks/nets/resnet.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index ecb5b850d1..72ddcf02ce 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -269,12 +269,16 @@ def _make_layer( norm_type(planes * block.expansion), ) - layers = [] - layers.append( + layers = [ block( - in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample + in_planes=self.in_planes, + planes=planes, + spatial_dims=spatial_dims, + stride=stride, + downsample=downsample, ) - ) + ] + self.in_planes = planes * block.expansion for _i in range(1, blocks): layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims)) From c24b8cc40d2aa77243d4c8c145e0b4c4defd14ec Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 13 Sep 2021 13:42:53 +0100 Subject: [PATCH 3/5] fixes 2715 Signed-off-by: Wenqi Li --- monai/networks/nets/resnet.py | 22 ++++++++++++---------- tests/test_resnet.py | 14 +++++++++++++- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 72ddcf02ce..d6e0f6f528 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -17,6 +17,8 @@ import torch.nn.functional as F from monai.networks.layers.factories import Conv, Norm, Pool +from monai.networks.layers.utils import get_pool_layer +from monai.utils.module import look_up_option __all__ = ["ResNet", "resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"] @@ -162,7 +164,9 @@ class ResNet(nn.Module): conv1_t_size: size of first convolution layer, determines kernel and padding. conv1_t_stride: stride of first convolution layer. no_max_pool: bool argument to determine if to use maxpool layer. - shortcut_type: which downsample block to use. + shortcut_type: which downsample block to use. Options are 'A', 'B', default to 'B'. + - 'A': using `self._downsample_basic_block`. + - 'B': kernel_size 1 conv + norm. widen_factor: widen output for each layer. num_classes: number of output (classifications) """ @@ -234,14 +238,9 @@ def __init__( nn.init.constant_(torch.as_tensor(m.bias), 0) def _downsample_basic_block(self, x: torch.Tensor, planes: int, stride: int, spatial_dims: int = 3) -> torch.Tensor: - assert spatial_dims == 3 - out: torch.Tensor = F.avg_pool3d(x, kernel_size=1, stride=stride) - zero_pads = torch.zeros(out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4)) - if isinstance(out.data, torch.FloatTensor): - zero_pads = zero_pads.cuda() - + out: torch.Tensor = get_pool_layer(("avg", {"kernel_size": 1, "stride": stride}), spatial_dims=spatial_dims)(x) + zero_pads = torch.zeros(out.size(0), planes - out.size(1), *out.shape[2:], dtype=out.dtype, device=out.device) out = torch.cat([out.data, zero_pads], dim=1) - return out def _make_layer( @@ -259,9 +258,12 @@ def _make_layer( downsample: Union[nn.Module, partial, None] = None if stride != 1 or self.in_planes != planes * block.expansion: - if shortcut_type == "A": + if look_up_option(shortcut_type, {"A", "B"}) == "A": downsample = partial( - self._downsample_basic_block, planes=planes * block.expansion, kernel_size=1, stride=stride + self._downsample_basic_block, + planes=planes * block.expansion, + stride=stride, + spatial_dims=spatial_dims, ) else: downsample = nn.Sequential( diff --git a/tests/test_resnet.py b/tests/test_resnet.py index c4ba5c2e16..a3b77138f3 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -42,14 +42,26 @@ (2, 3), ] +TEST_CASE_2_A = [ # 2D, batch 2, 1 input channel, shortcut type A + {"pretrained": False, "spatial_dims": 2, "n_input_channels": 1, "num_classes": 3, "shortcut_type": "A"}, + (2, 1, 32, 64), + (2, 3), +] + TEST_CASE_3 = [ # 1D, batch 1, 2 input channels {"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, "num_classes": 3}, (1, 2, 32), (1, 3), ] +TEST_CASE_3_A = [ # 1D, batch 1, 2 input channels + {"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, "num_classes": 3, "shortcut_type": "A"}, + (1, 2, 32), + (1, 3), +] + TEST_CASES = [] -for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]: +for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A]: for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]: TEST_CASES.append([model, *case]) From faa53e8f599d7abef21fb35393cd8c65f5f739b3 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 13 Sep 2021 13:46:45 +0100 Subject: [PATCH 4/5] adds 3d tests Signed-off-by: Wenqi Li --- tests/test_resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_resnet.py b/tests/test_resnet.py index a3b77138f3..16cd6f4865 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -61,7 +61,7 @@ ] TEST_CASES = [] -for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A]: +for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]: for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]: TEST_CASES.append([model, *case]) From 3c6358aaa0d7d722adba44d09b11be81223f1e01 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 13 Sep 2021 14:10:18 +0100 Subject: [PATCH 5/5] fixes flake8 error Signed-off-by: Wenqi Li --- monai/networks/nets/resnet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index d6e0f6f528..3b86dc3d62 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -14,7 +14,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from monai.networks.layers.factories import Conv, Norm, Pool from monai.networks.layers.utils import get_pool_layer