diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py
index 5689cf1071..fdf31d9701 100644
--- a/monai/networks/nets/attentionunet.py
+++ b/monai/networks/nets/attentionunet.py
@@ -29,7 +29,7 @@ def __init__(
         spatial_dims: int,
         in_channels: int,
         out_channels: int,
-        kernel_size: int = 3,
+        kernel_size: Sequence[int] | int = 3,
         strides: int = 1,
         dropout=0.0,
     ):
@@ -219,7 +219,13 @@ def __init__(
         self.kernel_size = kernel_size
         self.dropout = dropout

-        head = ConvBlock(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=channels[0], dropout=dropout)
+        head = ConvBlock(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            out_channels=channels[0],
+            dropout=dropout,
+            kernel_size=self.kernel_size,
+        )
         reduce_channels = Convolution(
             spatial_dims=spatial_dims,
             in_channels=channels[0],
@@ -245,6 +251,7 @@ def _create_block(channels: Sequence[int], strides: Sequence[int]) -> nn.Module:
                         out_channels=channels[1],
                         strides=strides[0],
                         dropout=self.dropout,
+                        kernel_size=self.kernel_size,
                     ),
                     subblock,
                 ),
@@ -271,6 +278,7 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int) -> nn.Module:
                 out_channels=out_channels,
                 strides=strides,
                 dropout=self.dropout,
+                kernel_size=self.kernel_size,
             ),
             up_kernel_size=self.up_kernel_size,
             strides=strides,
diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py
index 83f6cabc5e..6a577f763f 100644
--- a/tests/test_attentionunet.py
+++ b/tests/test_attentionunet.py
@@ -14,11 +14,17 @@
 import unittest

 import torch
+import torch.nn as nn

 import monai.networks.nets.attentionunet as att
 from tests.utils import skip_if_no_cuda, skip_if_quick


+def get_net_parameters(net: nn.Module) -> int:
+    """Returns the total number of parameters in a Module."""
+    return sum(param.numel() for param in net.parameters())
+
+
 class TestAttentionUnet(unittest.TestCase):

     def test_attention_block(self):
@@ -50,6 +56,20 @@ def test_attentionunet(self):
             self.assertEqual(output.shape[0], input.shape[0])
             self.assertEqual(output.shape[1], 2)

+    def test_attentionunet_kernel_size(self):
+        args_dict = {
+            "spatial_dims": 2,
+            "in_channels": 1,
+            "out_channels": 2,
+            "channels": (3, 4, 5),
+            "up_kernel_size": 5,
+            "strides": (1, 2),
+        }
+        model_a = att.AttentionUnet(**args_dict, kernel_size=5)
+        model_b = att.AttentionUnet(**args_dict, kernel_size=7)
+        self.assertEqual(get_net_parameters(model_a), 3534)
+        self.assertEqual(get_net_parameters(model_b), 5574)
+
     @skip_if_no_cuda
     def test_attentionunet_gpu(self):
         for dims in [2, 3]:
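
With this change, the `kernel_size` argument given to `AttentionUnet` is forwarded to every `ConvBlock` in the encoder path instead of each block silently falling back to its default of 3. A minimal usage sketch follows; the channel, stride, and input-size values are illustrative (taken from the new test case above), not prescribed by the change itself.

```python
import torch

from monai.networks.nets.attentionunet import AttentionUnet

# kernel_size is now passed through to every ConvBlock in the network
model = AttentionUnet(
    spatial_dims=2,
    in_channels=1,
    out_channels=2,
    channels=(3, 4, 5),
    strides=(1, 2),
    up_kernel_size=5,
    kernel_size=7,  # previously ignored by the inner ConvBlocks, which used the default of 3
)

x = torch.randn(1, 1, 64, 64)
y = model(x)
print(y.shape)  # expected: torch.Size([1, 2, 64, 64])
```

The parameter-count assertions in `test_attentionunet_kernel_size` rely on this pass-through: a larger kernel only changes the total number of weights if the inner convolutions actually receive it.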