Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions monai/networks/nets/attentionunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,27 @@ def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:


class AttentionLayer(nn.Module):
def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, submodule: nn.Module, dropout=0.0):
def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
submodule: nn.Module,
up_kernel_size=3,
strides=2,
dropout=0.0,
):
super().__init__()
self.attention = AttentionBlock(
spatial_dims=spatial_dims, f_g=in_channels, f_l=in_channels, f_int=in_channels // 2
)
self.upconv = UpConv(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=in_channels, strides=2)
self.upconv = UpConv(
spatial_dims=spatial_dims,
in_channels=out_channels,
out_channels=in_channels,
strides=strides,
kernel_size=up_kernel_size,
)
self.merge = Convolution(
spatial_dims=spatial_dims, in_channels=2 * in_channels, out_channels=in_channels, dropout=dropout
)
Expand All @@ -174,7 +189,7 @@ class AttentionUnet(nn.Module):
channels (Sequence[int]): sequence of channels. Top block first. The length of `channels` should be no less than 2.
strides (Sequence[int]): stride to use for convolutions.
kernel_size: convolution kernel size.
upsample_kernel_size: convolution kernel size for transposed convolution layers.
up_kernel_size: convolution kernel size for transposed convolution layers.
dropout: dropout ratio. Defaults to no dropout.
"""

Expand Down Expand Up @@ -210,9 +225,9 @@ def __init__(
)
self.up_kernel_size = up_kernel_size

def _create_block(channels: Sequence[int], strides: Sequence[int], level: int = 0) -> nn.Module:
def _create_block(channels: Sequence[int], strides: Sequence[int]) -> nn.Module:
if len(channels) > 2:
subblock = _create_block(channels[1:], strides[1:], level=level + 1)
subblock = _create_block(channels[1:], strides[1:])
return AttentionLayer(
spatial_dims=spatial_dims,
in_channels=channels[0],
Expand All @@ -227,17 +242,19 @@ def _create_block(channels: Sequence[int], strides: Sequence[int], level: int =
),
subblock,
),
up_kernel_size=self.up_kernel_size,
strides=strides[0],
dropout=dropout,
)
else:
# the next layer is the bottom so stop recursion,
# create the bottom layer as the sublock for this layer
return self._get_bottom_layer(channels[0], channels[1], strides[0], level=level + 1)
# create the bottom layer as the subblock for this layer
return self._get_bottom_layer(channels[0], channels[1], strides[0])

encdec = _create_block(self.channels, self.strides)
self.model = nn.Sequential(head, encdec, reduce_channels)

def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, level: int) -> nn.Module:
def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int) -> nn.Module:
return AttentionLayer(
spatial_dims=self.dimensions,
in_channels=in_channels,
Expand All @@ -249,6 +266,8 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, l
strides=strides,
dropout=self.dropout,
),
up_kernel_size=self.up_kernel_size,
strides=strides,
dropout=self.dropout,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_attentionunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_attentionunet(self):
shape = (3, 1) + (92,) * dims
input = torch.rand(*shape)
model = att.AttentionUnet(
spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), strides=(2, 2)
spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), up_kernel_size=5, strides=(1, 2)
)
output = model(input)
self.assertEqual(output.shape[2:], input.shape[2:])
Expand Down