From c3c485baed6e245b4eccd0b90a21d02b70f23e1b Mon Sep 17 00:00:00 2001 From: Peter Kaplinsky Date: Wed, 1 May 2024 12:18:35 -0400 Subject: [PATCH 1/6] Propagate kernel size through attention unet convblocks Signed-off-by: Peter Kaplinsky --- monai/networks/nets/attentionunet.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py index 5689cf1071..4de0fe0281 100644 --- a/monai/networks/nets/attentionunet.py +++ b/monai/networks/nets/attentionunet.py @@ -219,7 +219,13 @@ def __init__( self.kernel_size = kernel_size self.dropout = dropout - head = ConvBlock(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=channels[0], dropout=dropout) + head = ConvBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + dropout=dropout, + kernel_size=self.kernel_size, + ) reduce_channels = Convolution( spatial_dims=spatial_dims, in_channels=channels[0], @@ -245,6 +251,7 @@ def _create_block(channels: Sequence[int], strides: Sequence[int]) -> nn.Module: out_channels=channels[1], strides=strides[0], dropout=self.dropout, + kernel_size=self.kernel_size, ), subblock, ), @@ -271,6 +278,7 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int) - out_channels=out_channels, strides=strides, dropout=self.dropout, + kernel_size=self.kernel_size, ), up_kernel_size=self.up_kernel_size, strides=strides, From 4bbf2acb20287bee2e8ef297210ece84bb8a198f Mon Sep 17 00:00:00 2001 From: Peter Kaplinsky Date: Wed, 1 May 2024 12:41:10 -0400 Subject: [PATCH 2/6] attentionunet kernel_size type fix Signed-off-by: Peter Kaplinsky --- monai/networks/nets/attentionunet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py index 4de0fe0281..fdf31d9701 100644 --- a/monai/networks/nets/attentionunet.py +++ b/monai/networks/nets/attentionunet.py @@ -29,7 +29,7 @@ def __init__( spatial_dims: int, in_channels: int, out_channels: int, - kernel_size: int = 3, + kernel_size: Sequence[int] | int = 3, strides: int = 1, dropout=0.0, ): From 8d3d36c74c1a7ba156e8e23428ff7e8a38931374 Mon Sep 17 00:00:00 2001 From: Peter Kaplinsky Date: Wed, 1 May 2024 16:20:53 -0400 Subject: [PATCH 3/6] Trigger rerun Signed-off-by: Peter Kaplinsky From 424ec2e54eecc3b632701e1abdd9b95f3867cdfb Mon Sep 17 00:00:00 2001 From: Peter Kaplinsky Date: Thu, 2 May 2024 14:23:52 -0400 Subject: [PATCH 4/6] add tests for attentionunet kernel size Signed-off-by: Peter Kaplinsky --- tests/test_attentionunet.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py index 83f6cabc5e..395362d7d1 100644 --- a/tests/test_attentionunet.py +++ b/tests/test_attentionunet.py @@ -14,11 +14,17 @@ import unittest import torch +import torch.nn as nn import monai.networks.nets.attentionunet as att from tests.utils import skip_if_no_cuda, skip_if_quick +def get_net_parameters(net: nn.Module) -> int: + """Returns the total number of parameters in a Module.""" + return sum(param.numel() for param in net.parameters()) + + class TestAttentionUnet(unittest.TestCase): def test_attention_block(self): @@ -50,6 +56,27 @@ def test_attentionunet(self): self.assertEqual(output.shape[0], input.shape[0]) self.assertEqual(output.shape[1], 2) + def test_attentionunet_kernel_size(self): + model_a = att.AttentionUnet( + spatial_dims=2, + in_channels=1, + out_channels=2, 
+ channels=(3, 4, 5), + up_kernel_size=5, + strides=(1, 2), + kernel_size=5, + ) + model_b = att.AttentionUnet( + spatial_dims=2, + in_channels=1, + out_channels=2, + channels=(3, 4, 5), + up_kernel_size=5, + strides=(1, 2), + kernel_size=7, + ) + self.assertNotEqual(get_net_parameters(model_a), get_net_parameters(model_b)) + @skip_if_no_cuda def test_attentionunet_gpu(self): for dims in [2, 3]: From c7611f6b27f5a3b00ad4c5b1aa6eab66917df7ff Mon Sep 17 00:00:00 2001 From: Peter Kaplinsky Date: Thu, 2 May 2024 14:38:27 -0400 Subject: [PATCH 5/6] reformat attentionunet tests Signed-off-by: Peter Kaplinsky --- tests/test_attentionunet.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py index 395362d7d1..c07c340d6c 100644 --- a/tests/test_attentionunet.py +++ b/tests/test_attentionunet.py @@ -57,24 +57,16 @@ def test_attentionunet(self): self.assertEqual(output.shape[1], 2) def test_attentionunet_kernel_size(self): - model_a = att.AttentionUnet( - spatial_dims=2, - in_channels=1, - out_channels=2, - channels=(3, 4, 5), - up_kernel_size=5, - strides=(1, 2), - kernel_size=5, - ) - model_b = att.AttentionUnet( - spatial_dims=2, - in_channels=1, - out_channels=2, - channels=(3, 4, 5), - up_kernel_size=5, - strides=(1, 2), - kernel_size=7, - ) + args_dict = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 2, + "channels": (3, 4, 5), + "up_kernel_size": 5, + "strides": (1, 2), + } + model_a = att.AttentionUnet(**args_dict, kernel_size=5) + model_b = att.AttentionUnet(**args_dict, kernel_size=7) self.assertNotEqual(get_net_parameters(model_a), get_net_parameters(model_b)) @skip_if_no_cuda From 054169827933652e369ddfbf87a5432012c8d258 Mon Sep 17 00:00:00 2001 From: Peter Kaplinsky Date: Fri, 3 May 2024 11:14:02 -0400 Subject: [PATCH 6/6] test param count in att-unet Signed-off-by: Peter Kaplinsky --- tests/test_attentionunet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py index c07c340d6c..6a577f763f 100644 --- a/tests/test_attentionunet.py +++ b/tests/test_attentionunet.py @@ -67,7 +67,8 @@ def test_attentionunet_kernel_size(self): } model_a = att.AttentionUnet(**args_dict, kernel_size=5) model_b = att.AttentionUnet(**args_dict, kernel_size=7) - self.assertNotEqual(get_net_parameters(model_a), get_net_parameters(model_b)) + self.assertEqual(get_net_parameters(model_a), 3534) + self.assertEqual(get_net_parameters(model_b), 5574) @skip_if_no_cuda def test_attentionunet_gpu(self):
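
For reference, a minimal usage sketch of what this series enables (not taken from the patches themselves): once kernel_size is propagated into every ConvBlock, two AttentionUnet models that differ only in kernel_size should differ in parameter count. The count_parameters helper below is a hypothetical stand-in for get_net_parameters from the tests; the constructor arguments mirror the ones used in test_attentionunet_kernel_size.

import torch
import monai.networks.nets.attentionunet as att

def count_parameters(net: torch.nn.Module) -> int:
    # Same idea as get_net_parameters in the tests: total element count
    # over all parameters of the module.
    return sum(p.numel() for p in net.parameters())

# Shared configuration mirroring the test; only kernel_size differs.
common = {
    "spatial_dims": 2,
    "in_channels": 1,
    "out_channels": 2,
    "channels": (3, 4, 5),
    "strides": (1, 2),
    "up_kernel_size": 5,
}
model_k5 = att.AttentionUnet(**common, kernel_size=5)
model_k7 = att.AttentionUnet(**common, kernel_size=7)

# With the fix, the kernel size actually reaches the ConvBlocks, so the
# two parameter counts are no longer identical.
print(count_parameters(model_k5), count_parameters(model_k7))

# Both configurations still run a forward pass; spatial size is preserved
# because odd kernels are padded symmetrically.
x = torch.rand(1, 1, 32, 32)
for model in (model_k5, model_k7):
    y = model(x)
    assert y.shape[0] == x.shape[0] and y.shape[1] == 2

After the annotation widening in patch 2 (kernel_size: Sequence[int] | int), a per-dimension sequence such as kernel_size=(3, 5) should also be accepted, since the value is forwarded unchanged to the underlying Convolution blocks.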