Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 22 additions & 17 deletions monai/networks/nets/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.layers.factories import Conv, Norm, Pool
from monai.networks.layers.utils import get_pool_layer
from monai.utils.module import look_up_option

__all__ = ["ResNet", "resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"]

Expand Down Expand Up @@ -162,7 +163,9 @@ class ResNet(nn.Module):
conv1_t_size: size of first convolution layer, determines kernel and padding.
conv1_t_stride: stride of first convolution layer.
no_max_pool: bool argument to determine whether to use the maxpool layer.
shortcut_type: which downsample block to use.
shortcut_type: which downsample block to use. Options are 'A' and 'B'; defaults to 'B'.
- 'A': using `self._downsample_basic_block`.
- 'B': kernel_size 1 conv + norm.
widen_factor: widen output for each layer.
num_classes: number of outputs (classifications).
"""
Expand Down Expand Up @@ -198,7 +201,7 @@ def __init__(
]

block_avgpool = get_avgpool()
conv1_kernel, conv1_stride, con1_padding = get_conv1(conv1_t_size, conv1_t_stride)
conv1_kernel, conv1_stride, conv1_padding = get_conv1(conv1_t_size, conv1_t_stride)
block_inplanes = [int(x * widen_factor) for x in block_inplanes]

self.in_planes = block_inplanes[0]
Expand All @@ -209,7 +212,7 @@ def __init__(
self.in_planes,
kernel_size=conv1_kernel[spatial_dims],
stride=conv1_stride[spatial_dims],
padding=con1_padding[spatial_dims],
padding=conv1_padding[spatial_dims],
bias=False,
)
self.bn1 = norm_type(self.in_planes)
Expand All @@ -234,14 +237,9 @@ def __init__(
nn.init.constant_(torch.as_tensor(m.bias), 0)

def _downsample_basic_block(self, x: torch.Tensor, planes: int, stride: int, spatial_dims: int = 3) -> torch.Tensor:
    """
    Shortcut type 'A': subsample ``x`` with a kernel-size-1 average pool, then zero-pad the
    channel dimension (dim 1) up to ``planes``.

    Args:
        x: input tensor; dim 0 and dim 1 are treated as batch and channel, any further dims are pooled.
        planes: number of channels of the returned tensor.
        stride: stride of the kernel-size-1 average pooling.
        spatial_dims: number of spatial dimensions, forwarded to ``get_pool_layer``.

    Returns:
        The pooled tensor concatenated along dim 1 with ``planes - out.size(1)`` all-zero channels.
    """
    # Dimension-agnostic pooling: the layer is selected by ``spatial_dims`` instead of hard-coding the 3D op.
    pool = get_pool_layer(("avg", {"kernel_size": 1, "stride": stride}), spatial_dims=spatial_dims)
    out: torch.Tensor = pool(x)

    # Create the padding with the dtype/device of ``out`` so both operands of ``cat`` always match;
    # ``*out.shape[2:]`` copies however many trailing dims the pooled tensor has.
    zero_pads = torch.zeros(out.size(0), planes - out.size(1), *out.shape[2:], dtype=out.dtype, device=out.device)

    # NOTE(review): ``out.data`` is kept as-is from the existing code; it appears to bypass autograd
    # tracking for the pooled branch -- confirm that this is intended.
    return torch.cat([out.data, zero_pads], dim=1)

def _make_layer(
Expand All @@ -259,22 +257,29 @@ def _make_layer(

downsample: Union[nn.Module, partial, None] = None
if stride != 1 or self.in_planes != planes * block.expansion:
if shortcut_type == "A":
if look_up_option(shortcut_type, {"A", "B"}) == "A":
downsample = partial(
self._downsample_basic_block, planes=planes * block.expansion, kernel_size=1, stride=stride
self._downsample_basic_block,
planes=planes * block.expansion,
stride=stride,
spatial_dims=spatial_dims,
)
else:
downsample = nn.Sequential(
conv_type(self.in_planes, planes * block.expansion, kernel_size=1, stride=stride),
norm_type(planes * block.expansion),
)

layers = []
layers.append(
layers = [
block(
in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample
in_planes=self.in_planes,
planes=planes,
spatial_dims=spatial_dims,
stride=stride,
downsample=downsample,
)
)
]

self.in_planes = planes * block.expansion
for _i in range(1, blocks):
layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims))
Expand Down
14 changes: 13 additions & 1 deletion tests/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,26 @@
(2, 3),
]

TEST_CASE_2_A = [ # 2D, batch 2, 1 input channel, shortcut type A
{"pretrained": False, "spatial_dims": 2, "n_input_channels": 1, "num_classes": 3, "shortcut_type": "A"},
(2, 1, 32, 64),
(2, 3),
]

TEST_CASE_3 = [ # 1D, batch 1, 2 input channels
{"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, "num_classes": 3},
(1, 2, 32),
(1, 3),
]

TEST_CASE_3_A = [ # 1D, batch 1, 2 input channels, shortcut type A
{"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, "num_classes": 3, "shortcut_type": "A"},
(1, 2, 32),
(1, 3),
]

# Cross every test case (including the shortcut-type-'A' variants) with every ResNet factory;
# each entry is [model_factory, *case].
TEST_CASES = [
    [model, *case]
    for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]
    for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]
]

Expand Down