diff --git a/monai/networks/blocks/aspp.py b/monai/networks/blocks/aspp.py
index 41ed39c359..f8bf8a5ba6 100644
--- a/monai/networks/blocks/aspp.py
+++ b/monai/networks/blocks/aspp.py
@@ -39,6 +39,7 @@ def __init__(
         dilations: Sequence[int] = (1, 2, 4, 6),
         norm_type: Optional[Union[Tuple, str]] = "BATCH",
         acti_type: Optional[Union[Tuple, str]] = "LEAKYRELU",
+        bias: bool = False,
     ) -> None:
         """
         Args:
@@ -54,6 +55,9 @@ def __init__(
                 Defaults to batch norm.
             acti_type: final kernel-size-one convolution activation type.
                 Defaults to leaky ReLU.
+            bias: whether to have a bias term in convolution blocks. Defaults to False.
+                According to `Performance Tuning Guide `_,
+                if a conv layer is directly followed by a batch norm layer, bias should be False.
 
         Raises:
             ValueError: When ``kernel_sizes`` length differs from ``dilations``.
@@ -88,6 +92,7 @@ def __init__(
             kernel_size=1,
             act=acti_type,
             norm=norm_type,
+            bias=bias,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
diff --git a/monai/networks/nets/autoencoder.py b/monai/networks/nets/autoencoder.py
index d0089198d5..d0a54b8148 100644
--- a/monai/networks/nets/autoencoder.py
+++ b/monai/networks/nets/autoencoder.py
@@ -37,6 +37,7 @@ def __init__(
         act: Optional[Union[Tuple, str]] = Act.PRELU,
         norm: Union[Tuple, str] = Norm.INSTANCE,
         dropout: Optional[Union[Tuple, str, float]] = None,
+        bias: bool = True,
     ) -> None:
         super().__init__()
 
@@ -51,6 +52,7 @@ def __init__(
         self.act = act
         self.norm = norm
         self.dropout = dropout
+        self.bias = bias
         self.num_inter_units = num_inter_units
         self.inter_channels = inter_channels if inter_channels is not None else []
         self.inter_dilations = list(inter_dilations or [1] * len(self.inter_channels))
@@ -103,6 +105,7 @@ def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> Tu
                     norm=self.norm,
                     dropout=self.dropout,
                     dilation=di,
+                    bias=self.bias,
                 )
             else:
                 unit = Convolution(
@@ -115,6 +118,7 @@ def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> Tu
                     norm=self.norm,
                     dropout=self.dropout,
                     dilation=di,
+                    bias=self.bias,
                 )
 
             intermediate.add_module("inter_%i" % i, unit)
@@ -148,6 +152,7 @@ def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, i
                 act=self.act,
                 norm=self.norm,
                 dropout=self.dropout,
+                bias=self.bias,
                 last_conv_only=is_last,
             )
         return Convolution(
@@ -159,6 +164,7 @@ def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, i
             act=self.act,
             norm=self.norm,
             dropout=self.dropout,
+            bias=self.bias,
             conv_only=is_last,
         )
 
@@ -175,6 +181,7 @@ def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, i
             act=self.act,
             norm=self.norm,
             dropout=self.dropout,
+            bias=self.bias,
             conv_only=is_last and self.num_res_units == 0,
             is_transposed=True,
         )
@@ -192,6 +199,7 @@ def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, i
                 act=self.act,
                 norm=self.norm,
                 dropout=self.dropout,
+                bias=self.bias,
                 last_conv_only=is_last,
             )
 
diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py
index 18bb1d3a6c..63205f45ee 100644
--- a/monai/networks/nets/basic_unet.py
+++ b/monai/networks/nets/basic_unet.py
@@ -31,6 +31,7 @@ def __init__(
         out_chns: int,
         act: Union[str, tuple],
         norm: Union[str, tuple],
+        bias: bool,
         dropout: Union[float, tuple] = 0.0,
     ):
         """
@@ -40,12 +41,14 @@ def __init__(
             out_chns: number of output channels.
             act: activation type and arguments.
             norm: feature normalization type and arguments.
+            bias: whether to have a bias term in convolution blocks.
            dropout: dropout ratio. Defaults to no dropout.
+
         """
         super().__init__()
 
-        conv_0 = Convolution(dim, in_chns, out_chns, act=act, norm=norm, dropout=dropout, padding=1)
-        conv_1 = Convolution(dim, out_chns, out_chns, act=act, norm=norm, dropout=dropout, padding=1)
+        conv_0 = Convolution(dim, in_chns, out_chns, act=act, norm=norm, dropout=dropout, bias=bias, padding=1)
+        conv_1 = Convolution(dim, out_chns, out_chns, act=act, norm=norm, dropout=dropout, bias=bias, padding=1)
         self.add_module("conv_0", conv_0)
         self.add_module("conv_1", conv_1)
 
@@ -60,6 +63,7 @@ def __init__(
         out_chns: int,
         act: Union[str, tuple],
         norm: Union[str, tuple],
+        bias: bool,
         dropout: Union[float, tuple] = 0.0,
     ):
         """
@@ -69,12 +73,14 @@ def __init__(
             out_chns: number of output channels.
             act: activation type and arguments.
             norm: feature normalization type and arguments.
+            bias: whether to have a bias term in convolution blocks.
             dropout: dropout ratio. Defaults to no dropout.
+
         """
         super().__init__()
 
         max_pooling = Pool["MAX", dim](kernel_size=2)
-        convs = TwoConv(dim, in_chns, out_chns, act, norm, dropout)
+        convs = TwoConv(dim, in_chns, out_chns, act, norm, bias, dropout)
         self.add_module("max_pooling", max_pooling)
         self.add_module("convs", convs)
 
@@ -90,6 +96,7 @@ def __init__(
         out_chns: int,
         act: Union[str, tuple],
         norm: Union[str, tuple],
+        bias: bool,
         dropout: Union[float, tuple] = 0.0,
         upsample: str = "deconv",
         pre_conv: Optional[Union[nn.Module, str]] = "default",
@@ -105,6 +112,7 @@ def __init__(
             out_chns: number of output channels.
             act: activation type and arguments.
             norm: feature normalization type and arguments.
+            bias: whether to have a bias term in convolution blocks.
             dropout: dropout ratio. Defaults to no dropout.
             upsample: upsampling mode, available options are
                 ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
@@ -132,9 +140,9 @@ def __init__(
             interp_mode=interp_mode,
             align_corners=align_corners,
         )
-        self.convs = TwoConv(dim, cat_chns + up_chns, out_chns, act, norm, dropout)
+        self.convs = TwoConv(dim, cat_chns + up_chns, out_chns, act, norm, bias, dropout)
 
-    def forward(self, x: torch.Tensor, x_e: torch.Tensor):
+    def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]):
         """
 
         Args:
@@ -143,15 +151,18 @@ def forward(self, x: torch.Tensor, x_e: torch.Tensor):
         """
         x_0 = self.upsample(x)
 
-        # handling spatial shapes due to the 2x maxpooling with odd edge lengths.
-        dimensions = len(x.shape) - 2
-        sp = [0] * (dimensions * 2)
-        for i in range(dimensions):
-            if x_e.shape[-i - 1] != x_0.shape[-i - 1]:
-                sp[i * 2 + 1] = 1
-        x_0 = torch.nn.functional.pad(x_0, sp, "replicate")
+        if x_e is not None:
+            # handling spatial shapes due to the 2x maxpooling with odd edge lengths.
+            dimensions = len(x.shape) - 2
+            sp = [0] * (dimensions * 2)
+            for i in range(dimensions):
+                if x_e.shape[-i - 1] != x_0.shape[-i - 1]:
+                    sp[i * 2 + 1] = 1
+            x_0 = torch.nn.functional.pad(x_0, sp, "replicate")
+            x = self.convs(torch.cat([x_e, x_0], dim=1))  # input channels: (cat_chns + up_chns)
+        else:
+            x = self.convs(x_0)
 
-        x = self.convs(torch.cat([x_e, x_0], dim=1))  # input channels: (cat_chns + up_chns)
         return x
@@ -164,6 +175,7 @@ def __init__(
         features: Sequence[int] = (32, 32, 64, 128, 256, 32),
         act: Union[str, tuple] = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}),
         norm: Union[str, tuple] = ("instance", {"affine": True}),
+        bias: bool = True,
         dropout: Union[float, tuple] = 0.0,
         upsample: str = "deconv",
     ):
@@ -188,6 +200,9 @@ def __init__(
             act: activation type and arguments. Defaults to LeakyReLU.
             norm: feature normalization type and arguments. Defaults to instance norm.
+            bias: whether to have a bias term in convolution blocks. Defaults to True.
+                According to `Performance Tuning Guide `_,
+                if a conv layer is directly followed by a batch norm layer, bias should be False.
             dropout: dropout ratio. Defaults to no dropout.
             upsample: upsampling mode, available options are
                 ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
@@ -214,16 +229,16 @@ def __init__(
         fea = ensure_tuple_rep(features, 6)
         print(f"BasicUNet features: {fea}.")
 
-        self.conv_0 = TwoConv(dimensions, in_channels, features[0], act, norm, dropout)
-        self.down_1 = Down(dimensions, fea[0], fea[1], act, norm, dropout)
-        self.down_2 = Down(dimensions, fea[1], fea[2], act, norm, dropout)
-        self.down_3 = Down(dimensions, fea[2], fea[3], act, norm, dropout)
-        self.down_4 = Down(dimensions, fea[3], fea[4], act, norm, dropout)
+        self.conv_0 = TwoConv(dimensions, in_channels, features[0], act, norm, bias, dropout)
+        self.down_1 = Down(dimensions, fea[0], fea[1], act, norm, bias, dropout)
+        self.down_2 = Down(dimensions, fea[1], fea[2], act, norm, bias, dropout)
+        self.down_3 = Down(dimensions, fea[2], fea[3], act, norm, bias, dropout)
+        self.down_4 = Down(dimensions, fea[3], fea[4], act, norm, bias, dropout)
 
-        self.upcat_4 = UpCat(dimensions, fea[4], fea[3], fea[3], act, norm, dropout, upsample)
-        self.upcat_3 = UpCat(dimensions, fea[3], fea[2], fea[2], act, norm, dropout, upsample)
-        self.upcat_2 = UpCat(dimensions, fea[2], fea[1], fea[1], act, norm, dropout, upsample)
-        self.upcat_1 = UpCat(dimensions, fea[1], fea[0], fea[5], act, norm, dropout, upsample, halves=False)
+        self.upcat_4 = UpCat(dimensions, fea[4], fea[3], fea[3], act, norm, bias, dropout, upsample)
+        self.upcat_3 = UpCat(dimensions, fea[3], fea[2], fea[2], act, norm, bias, dropout, upsample)
+        self.upcat_2 = UpCat(dimensions, fea[2], fea[1], fea[1], act, norm, bias, dropout, upsample)
+        self.upcat_1 = UpCat(dimensions, fea[1], fea[0], fea[5], act, norm, bias, dropout, upsample, halves=False)
 
         self.final_conv = Conv["conv", dimensions](fea[5], out_channels, kernel_size=1)
 
diff --git a/monai/networks/nets/highresnet.py b/monai/networks/nets/highresnet.py
index a67a5088ce..12908a9119 100644
--- a/monai/networks/nets/highresnet.py
+++ b/monai/networks/nets/highresnet.py
@@ -43,6 +43,7 @@ def __init__(
         dilation: Union[Sequence[int], int] = 1,
         norm_type: Union[Tuple, str] = ("batch", {"affine": True}),
         acti_type: Union[Tuple, str] = ("relu", {"inplace": True}),
+        bias: bool = False,
         channel_matching: Union[ChannelMatching, str] = ChannelMatching.PAD,
     ) -> None:
         """
@@ -56,6 +57,9 @@ def __init__(
                 Defaults to ``("batch", {"affine": True})``.
             acti_type: {``"relu"``, ``"prelu"``, ``"relu6"``}
                 Non-linear activation using ReLU or PReLU. Defaults to ``"relu"``.
+            bias: whether to have a bias term in convolution blocks. Defaults to False.
+                According to `Performance Tuning Guide `_,
+                if a conv layer is directly followed by a batch norm layer, bias should be False.
             channel_matching: {``"pad"``, ``"project"``}
                 Specifies handling residual branch and conv branch channel mismatches. Defaults to ``"pad"``.
@@ -85,6 +89,7 @@ def __init__(
                     out_channels=_out_chns,
                     kernel_size=kernel_size,
                     dilation=dilation,
+                    bias=bias,
                 )
             )
             _in_chns = _out_chns
@@ -116,6 +121,9 @@ class HighResNet(nn.Module):
            Defaults to ``("relu", {"inplace": True})``.
         dropout_prob: probability of the feature map to be zeroed
            (only applies to the penultimate conv layer).
+        bias: whether to have a bias term in convolution blocks. Defaults to False.
+            According to `Performance Tuning Guide `_,
+            if a conv layer is directly followed by a batch norm layer, bias should be False.
         layer_params: specifying key parameters of each layer/block.
         channel_matching: {``"pad"``, ``"project"``}
             Specifies handling residual branch and conv branch channel mismatches. Defaults to ``"pad"``.
@@ -132,6 +140,7 @@ def __init__(
         norm_type: Union[str, tuple] = ("batch", {"affine": True}),
         acti_type: Union[str, tuple] = ("relu", {"inplace": True}),
         dropout_prob: Optional[Union[Tuple, str, float]] = 0.0,
+        bias: bool = False,
         layer_params: Sequence[Dict] = DEFAULT_LAYER_PARAMS_3D,
         channel_matching: Union[ChannelMatching, str] = ChannelMatching.PAD,
     ) -> None:
@@ -151,6 +160,7 @@ def __init__(
                 adn_ordering="NA",
                 act=acti_type,
                 norm=norm_type,
+                bias=bias,
             )
         )
 
@@ -168,6 +178,7 @@ def __init__(
                         dilation=_dilation,
                         norm_type=norm_type,
                         acti_type=acti_type,
+                        bias=bias,
                         channel_matching=channel_matching,
                     )
                 )
@@ -185,6 +196,7 @@ def __init__(
                 adn_ordering="NAD",
                 act=acti_type,
                 norm=norm_type,
+                bias=bias,
                 dropout=dropout_prob,
             )
         )
@@ -200,6 +212,7 @@ def __init__(
                 adn_ordering="NAD",
                 act=acti_type,
                 norm=norm_type,
+                bias=bias,
                 dropout=dropout_prob,
             )
         )
diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index 158b154042..70cc816fe9 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -38,7 +38,8 @@ def __init__(
         num_res_units: int = 0,
         act: Union[Tuple, str] = Act.PRELU,
         norm: Union[Tuple, str] = Norm.INSTANCE,
-        dropout=0.0,
+        dropout: float = 0.0,
+        bias: bool = True,
     ) -> None:
         """
         Enhanced version of UNet which has residual units implemented with the ResidualUnit class.
@@ -60,6 +61,9 @@ def __init__(
             act: activation type and arguments. Defaults to PReLU.
             norm: feature normalization type and arguments. Defaults to instance norm.
             dropout: dropout ratio. Defaults to no dropout.
+            bias: whether to have a bias term in convolution blocks. Defaults to True.
+                According to `Performance Tuning Guide `_,
+                if a conv layer is directly followed by a batch norm layer, bias should be False.
 
         Note: The acceptable spatial size of input data depends on the parameters of the network,
             to set appropriate spatial size, please check the tutorial for more details:
@@ -97,6 +101,7 @@ def __init__(
         self.act = act
         self.norm = norm
         self.dropout = dropout
+        self.bias = bias
 
         def _create_block(
             inc: int, outc: int, channels: Sequence[int], strides: Sequence[int], is_top: bool
@@ -151,6 +156,7 @@ def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_
                 act=self.act,
                 norm=self.norm,
                 dropout=self.dropout,
+                bias=self.bias,
             )
         return Convolution(
             self.dimensions,
@@ -161,6 +167,7 @@ def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_
             act=self.act,
             norm=self.norm,
             dropout=self.dropout,
+            bias=self.bias,
         )
 
     def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module:
@@ -190,6 +197,7 @@ def _get_up_layer(self, in_channels: int, out_channels: int, strides: int, is_to
             act=self.act,
             norm=self.norm,
             dropout=self.dropout,
+            bias=self.bias,
             conv_only=is_top and self.num_res_units == 0,
             is_transposed=True,
         )
@@ -205,6 +213,7 @@ def _get_up_layer(self, in_channels: int, out_channels: int, strides: int, is_to
                 act=self.act,
                 norm=self.norm,
                 dropout=self.dropout,
+                bias=self.bias,
                 last_conv_only=is_top,
             )
             conv = nn.Sequential(conv, ru)
diff --git a/monai/networks/nets/varautoencoder.py b/monai/networks/nets/varautoencoder.py
index 72caa3a2cb..7f54890992 100644
--- a/monai/networks/nets/varautoencoder.py
+++ b/monai/networks/nets/varautoencoder.py
@@ -43,6 +43,7 @@ def __init__(
         act: Optional[Union[Tuple, str]] = Act.PRELU,
         norm: Union[Tuple, str] = Norm.INSTANCE,
         dropout: Optional[Union[Tuple, str, float]] = None,
+        bias: bool = True,
     ) -> None:
 
         self.in_channels, *self.in_shape = in_shape
@@ -65,6 +66,7 @@ def __init__(
             act,
             norm,
             dropout,
+            bias,
         )
 
         padding = same_padding(self.kernel_size)
diff --git a/monai/networks/nets/vnet.py b/monai/networks/nets/vnet.py
index dc71cb104b..72f3290a89 100644
--- a/monai/networks/nets/vnet.py
+++ b/monai/networks/nets/vnet.py
@@ -29,7 +29,7 @@ def get_acti_layer(act: Union[Tuple[str, Dict], str], nchan: int = 0):
 
 
 class LUConv(nn.Module):
-    def __init__(self, spatial_dims: int, nchan: int, act: Union[Tuple[str, Dict], str]):
+    def __init__(self, spatial_dims: int, nchan: int, act: Union[Tuple[str, Dict], str], bias: bool = False):
         super(LUConv, self).__init__()
 
         self.act_function = get_acti_layer(act, nchan)
@@ -40,6 +40,7 @@ def __init__(self, spatial_dims: int, nchan: int, act: Union[Tuple[str, Dict], s
             kernel_size=5,
             act=None,
             norm=Norm.BATCH,
+            bias=bias,
         )
 
     def forward(self, x):
@@ -48,15 +49,22 @@ def forward(self, x):
         return out
 
 
-def _make_nconv(spatial_dims: int, nchan: int, depth: int, act: Union[Tuple[str, Dict], str]):
+def _make_nconv(spatial_dims: int, nchan: int, depth: int, act: Union[Tuple[str, Dict], str], bias: bool = False):
     layers = []
     for _ in range(depth):
-        layers.append(LUConv(spatial_dims, nchan, act))
+        layers.append(LUConv(spatial_dims, nchan, act, bias))
     return nn.Sequential(*layers)
 
 
 class InputTransition(nn.Module):
-    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, act: Union[Tuple[str, Dict], str]):
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        act: Union[Tuple[str, Dict], str],
+        bias: bool = False,
+    ):
         super(InputTransition, self).__init__()
 
         if 16 % in_channels != 0:
@@ -72,6 +80,7 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, act:
             kernel_size=5,
             act=None,
             norm=Norm.BATCH,
+            bias=bias,
         )
 
     def forward(self, x):
@@ -91,6 +100,7 @@ def __init__(
         act: Union[Tuple[str, Dict], str],
         dropout_prob: Optional[float] = None,
         dropout_dim: int = 3,
+        bias: bool = False,
     ):
         super(DownTransition, self).__init__()
 
@@ -99,11 +109,11 @@ def __init__(
         dropout_type: Type[Union[nn.Dropout, nn.Dropout2d, nn.Dropout3d]] = Dropout[Dropout.DROPOUT, dropout_dim]
 
         out_channels = 2 * in_channels
-        self.down_conv = conv_type(in_channels, out_channels, kernel_size=2, stride=2)
+        self.down_conv = conv_type(in_channels, out_channels, kernel_size=2, stride=2, bias=bias)
         self.bn1 = norm_type(out_channels)
         self.act_function1 = get_acti_layer(act, out_channels)
         self.act_function2 = get_acti_layer(act, out_channels)
-        self.ops = _make_nconv(spatial_dims, out_channels, nconvs, act)
+        self.ops = _make_nconv(spatial_dims, out_channels, nconvs, act, bias)
         self.dropout = dropout_type(dropout_prob) if dropout_prob is not None else None
 
     def forward(self, x):
@@ -156,7 +166,14 @@ def forward(self, x, skipx):
 
 
 class OutputTransition(nn.Module):
-    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, act: Union[Tuple[str, Dict], str]):
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        act: Union[Tuple[str, Dict], str],
+        bias: bool = False,
+    ):
         super(OutputTransition, self).__init__()
 
         conv_type: Type[Union[nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims]
@@ -169,6 +186,7 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, act:
             kernel_size=5,
             act=None,
             norm=Norm.BATCH,
+            bias=bias,
         )
         self.conv2 = conv_type(out_channels, out_channels, kernel_size=1)
 
@@ -201,6 +219,10 @@ class VNet(nn.Module):
             - ``dropout_dim = 1``, randomly zeroes some of the elements for each channel.
             - ``dropout_dim = 2``, Randomly zeroes out entire channels (a channel is a 2D feature map).
            - ``dropout_dim = 3``, Randomly zeroes out entire channels (a channel is a 3D feature map).
+        bias: whether to have a bias term in convolution blocks. Defaults to False.
+            According to `Performance Tuning Guide `_,
+            if a conv layer is directly followed by a batch norm layer, bias should be False.
+
     """
 
     def __init__(
@@ -211,22 +233,23 @@ def __init__(
         act: Union[Tuple[str, Dict], str] = ("elu", {"inplace": True}),
         dropout_prob: float = 0.5,
         dropout_dim: int = 3,
+        bias: bool = False,
     ):
         super().__init__()
 
         if spatial_dims not in (2, 3):
             raise AssertionError("spatial_dims can only be 2 or 3.")
 
-        self.in_tr = InputTransition(spatial_dims, in_channels, 16, act)
-        self.down_tr32 = DownTransition(spatial_dims, 16, 1, act)
-        self.down_tr64 = DownTransition(spatial_dims, 32, 2, act)
-        self.down_tr128 = DownTransition(spatial_dims, 64, 3, act, dropout_prob=dropout_prob)
-        self.down_tr256 = DownTransition(spatial_dims, 128, 2, act, dropout_prob=dropout_prob)
+        self.in_tr = InputTransition(spatial_dims, in_channels, 16, act, bias=bias)
+        self.down_tr32 = DownTransition(spatial_dims, 16, 1, act, bias=bias)
+        self.down_tr64 = DownTransition(spatial_dims, 32, 2, act, bias=bias)
+        self.down_tr128 = DownTransition(spatial_dims, 64, 3, act, dropout_prob=dropout_prob, bias=bias)
+        self.down_tr256 = DownTransition(spatial_dims, 128, 2, act, dropout_prob=dropout_prob, bias=bias)
         self.up_tr256 = UpTransition(spatial_dims, 256, 256, 2, act, dropout_prob=dropout_prob)
         self.up_tr128 = UpTransition(spatial_dims, 256, 128, 2, act, dropout_prob=dropout_prob)
         self.up_tr64 = UpTransition(spatial_dims, 128, 64, 1, act)
         self.up_tr32 = UpTransition(spatial_dims, 64, 32, 1, act)
-        self.out_tr = OutputTransition(spatial_dims, 32, out_channels, act)
+        self.out_tr = OutputTransition(spatial_dims, 32, out_channels, act, bias=bias)
 
     def forward(self, x):
         out16 = self.in_tr(x)
diff --git a/tests/test_vnet.py b/tests/test_vnet.py
index c64b566c42..4eba5396b2 100644
--- a/tests/test_vnet.py
+++ b/tests/test_vnet.py
@@ -73,3 +73,7 @@ def test_script(self):
         net = VNet(spatial_dims=3, in_channels=1, out_channels=3, dropout_dim=3)
         test_data = torch.randn(1, 1, 32, 32, 32)
         test_script_save(net, test_data)
+
+
+if __name__ == "__main__":
+    unittest.main()