diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py
index 177a54e105..a57b57425e 100644
--- a/monai/networks/nets/attentionunet.py
+++ b/monai/networks/nets/attentionunet.py
@@ -143,12 +143,27 @@ def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:


 class AttentionLayer(nn.Module):
-    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, submodule: nn.Module, dropout=0.0):
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        submodule: nn.Module,
+        up_kernel_size=3,
+        strides=2,
+        dropout=0.0,
+    ):
         super().__init__()
         self.attention = AttentionBlock(
             spatial_dims=spatial_dims, f_g=in_channels, f_l=in_channels, f_int=in_channels // 2
         )
-        self.upconv = UpConv(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=in_channels, strides=2)
+        self.upconv = UpConv(
+            spatial_dims=spatial_dims,
+            in_channels=out_channels,
+            out_channels=in_channels,
+            strides=strides,
+            kernel_size=up_kernel_size,
+        )
         self.merge = Convolution(
             spatial_dims=spatial_dims, in_channels=2 * in_channels, out_channels=in_channels, dropout=dropout
         )
@@ -174,7 +189,7 @@ class AttentionUnet(nn.Module):
         channels (Sequence[int]): sequence of channels. Top block first. The length of `channels` should be no less than 2.
         strides (Sequence[int]): stride to use for convolutions.
         kernel_size: convolution kernel size.
-        upsample_kernel_size: convolution kernel size for transposed convolution layers.
+        up_kernel_size: convolution kernel size for transposed convolution layers.
         dropout: dropout ratio. Defaults to no dropout.
     """

@@ -210,9 +225,9 @@ def __init__(
         )
         self.up_kernel_size = up_kernel_size

-        def _create_block(channels: Sequence[int], strides: Sequence[int], level: int = 0) -> nn.Module:
+        def _create_block(channels: Sequence[int], strides: Sequence[int]) -> nn.Module:
             if len(channels) > 2:
-                subblock = _create_block(channels[1:], strides[1:], level=level + 1)
+                subblock = _create_block(channels[1:], strides[1:])
                 return AttentionLayer(
                     spatial_dims=spatial_dims,
                     in_channels=channels[0],
@@ -227,17 +242,19 @@ def _create_block(channels: Sequence[int], strides: Sequence[int], level: int =
                         ),
                         subblock,
                     ),
+                    up_kernel_size=self.up_kernel_size,
+                    strides=strides[0],
                     dropout=dropout,
                 )
             else:
                 # the next layer is the bottom so stop recursion,
-                # create the bottom layer as the sublock for this layer
-                return self._get_bottom_layer(channels[0], channels[1], strides[0], level=level + 1)
+                # create the bottom layer as the subblock for this layer
+                return self._get_bottom_layer(channels[0], channels[1], strides[0])

         encdec = _create_block(self.channels, self.strides)
         self.model = nn.Sequential(head, encdec, reduce_channels)

-    def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, level: int) -> nn.Module:
+    def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int) -> nn.Module:
         return AttentionLayer(
             spatial_dims=self.dimensions,
             in_channels=in_channels,
@@ -249,6 +266,8 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, l
                 strides=strides,
                 dropout=self.dropout,
             ),
+            up_kernel_size=self.up_kernel_size,
+            strides=strides,
             dropout=self.dropout,
         )

diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py
index b2f53f9c16..e1df7b8acd 100644
--- a/tests/test_attentionunet.py
+++ b/tests/test_attentionunet.py
@@ -39,7 +39,7 @@ def test_attentionunet(self):
             shape = (3, 1) + (92,) * dims
             input = torch.rand(*shape)
             model = att.AttentionUnet(
-                spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), strides=(2, 2)
+                spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), up_kernel_size=5, strides=(1, 2)
             )
             output = model(input)
             self.assertEqual(output.shape[2:], input.shape[2:])
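
A minimal usage sketch (not part of the patch), assuming a MONAI build containing this change; the channel counts and input size are illustrative. It exercises the `up_kernel_size` and per-level `strides` arguments that the diff wires through to each `UpConv`:

    import torch
    from monai.networks.nets import AttentionUnet

    # Before this patch, `up_kernel_size` was stored but never forwarded and
    # every UpConv was hard-coded to strides=2; both now reach the upsampling path.
    model = AttentionUnet(
        spatial_dims=2,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64),
        strides=(2, 2),
        up_kernel_size=3,
    )

    x = torch.rand(3, 1, 96, 96)
    y = model(x)
    print(y.shape)  # torch.Size([3, 2, 96, 96]) -- spatial size preserved

Because each downsampling stride is now matched by an upsampling stride of the same value, the output keeps the input's spatial shape, which is exactly what the updated test asserts with the non-uniform strides=(1, 2).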