diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index e923c1bb7d..fca975d40e 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -175,6 +175,7 @@ class ResNet(nn.Module): widen_factor: widen output for each layer. num_classes: number of output (classifications). feed_forward: whether to add the FC layer for the output, default to `True`. + bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`. """ @@ -192,6 +193,7 @@ def __init__( widen_factor: float = 1.0, num_classes: int = 400, feed_forward: bool = True, + bias_downsample: bool = True, # for backwards compatibility (also see PR #5477) ) -> None: super().__init__() @@ -216,6 +218,7 @@ def __init__( self.in_planes = block_inplanes[0] self.no_max_pool = no_max_pool + self.bias_downsample = bias_downsample conv1_kernel_size = ensure_tuple_rep(conv1_t_size, spatial_dims) conv1_stride = ensure_tuple_rep(conv1_t_stride, spatial_dims) @@ -277,7 +280,13 @@ def _make_layer( ) else: downsample = nn.Sequential( - conv_type(self.in_planes, planes * block.expansion, kernel_size=1, stride=stride), + conv_type( + self.in_planes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=self.bias_downsample, + ), norm_type(planes * block.expansion), ) @@ -323,7 +332,9 @@ def _resnet( progress: bool, **kwargs: Any, ) -> ResNet: + # `bias_downsample` defaults to `not pretrained`; an explicit value in kwargs takes precedence (no duplicate-keyword TypeError) + kwargs.setdefault("bias_downsample", not pretrained) model: ResNet = ResNet(block, layers, block_inplanes, **kwargs) if pretrained: # Author of paper zipped the state_dict on googledrive, # so would need to download, unzip and read (2.8gb file for a ~150mb state dict).
diff --git a/tests/test_resnet.py b/tests/test_resnet.py index ae05f36210..b09b97a450 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -140,11 +140,27 @@ (1, 3), ] +TEST_CASE_7 = [ # 1D, batch 1, 2 input channels, bias_downsample + { + "block": "bottleneck", + "layers": [3, 4, 6, 3], + "block_inplanes": [64, 128, 256, 512], + "spatial_dims": 1, + "n_input_channels": 2, + "num_classes": 3, + "conv1_t_size": [3], + "conv1_t_stride": 1, + "bias_downsample": False, # set to False if pretrained=True (PR #5477) + }, + (1, 2, 32), + (1, 3), +] + TEST_CASES = [] for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]: for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]: TEST_CASES.append([model, *case]) -for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6]: +for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7]: TEST_CASES.append([ResNet, *case]) TEST_SCRIPT_CASES = [