From b1b886e91e561bcda420650c1bebb587d2b5417d Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 4 Jan 2023 20:50:52 +0000 Subject: [PATCH 1/2] Flexible interp modes Signed-off-by: Wenqi Li --- monai/networks/blocks/localnet_block.py | 21 +++++++++++++---- monai/networks/blocks/regunet_block.py | 10 +++++++- monai/networks/nets/regunet.py | 31 ++++++++++++++++++++----- tests/test_localnet.py | 2 ++ tests/test_localnet_block.py | 12 +++++++++- tests/test_regunet_block.py | 1 + 6 files changed, 64 insertions(+), 13 deletions(-) diff --git a/monai/networks/blocks/localnet_block.py b/monai/networks/blocks/localnet_block.py index 41b76c7d4c..9cdd0cecc2 100644 --- a/monai/networks/blocks/localnet_block.py +++ b/monai/networks/blocks/localnet_block.py @@ -166,7 +166,7 @@ def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]: class LocalNetUpSampleBlock(nn.Module): """ - A up-sample module that can be used for LocalNet, based on: + An up-sample module that can be used for LocalNet, based on: `Weakly-supervised convolutional neural networks for multimodal image registration `_. `Label-driven weakly-supervised learning for multimodal deformable image registration @@ -176,12 +176,21 @@ class LocalNetUpSampleBlock(nn.Module): DeepReg (https://github.com/DeepRegNet/DeepReg) """ - def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None: + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + mode: str = "nearest", + align_corners: Optional[bool] = None, + ) -> None: """ Args: spatial_dims: number of spatial dimensions. in_channels: number of input channels. out_channels: number of output channels. + mode: interpolation mode of the additive upsampling, default to 'nearest'. + align_corners: whether to align corners for the additive upsampling, default to None. Raises: ValueError: when ``in_channels != 2 * out_channels`` """ @@ -199,9 +208,11 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> No f"got in_channels={in_channels}, out_channels={out_channels}" ) self.out_channels = out_channels + self.mode = mode + self.align_corners = align_corners - def addictive_upsampling(self, x, mid) -> torch.Tensor: - x = F.interpolate(x, mid.shape[2:]) + def additive_upsampling(self, x, mid) -> torch.Tensor: + x = F.interpolate(x, mid.shape[2:], mode=self.mode, align_corners=self.align_corners) # [(batch, out_channels, ...), (batch, out_channels, ...)] x = x.split(split_size=int(self.out_channels), dim=1) # (batch, out_channels, ...) @@ -226,7 +237,7 @@ def forward(self, x, mid) -> torch.Tensor: "expecting mid spatial dimensions be exactly the double of x spatial dimensions, " f"got x of shape {x.shape}, mid of shape {mid.shape}" ) - h0 = self.deconv_block(x) + self.addictive_upsampling(x, mid) + h0 = self.deconv_block(x) + self.additive_upsampling(x, mid) r1 = h0 + mid r2 = self.conv_block(h0) out: torch.Tensor = self.residual_block(r2, r1) diff --git a/monai/networks/blocks/regunet_block.py b/monai/networks/blocks/regunet_block.py index 78e2598b4b..306b57a827 100644 --- a/monai/networks/blocks/regunet_block.py +++ b/monai/networks/blocks/regunet_block.py @@ -200,6 +200,8 @@ def __init__( out_channels: int, kernel_initializer: Optional[str] = "kaiming_uniform", activation: Optional[str] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, ): """ @@ -211,6 +213,8 @@ def __init__( out_channels: number of output channels kernel_initializer: kernel initializer activation: kernel activation function + mode: feature map interpolation mode, default to "nearest". + align_corners: whether to align corners for feature map interpolation. """ super().__init__() self.extract_levels = extract_levels @@ -228,6 +232,8 @@ def __init__( for d in extract_levels ] ) + self.mode = mode + self.align_corners = align_corners def forward(self, x: List[torch.Tensor], image_size: List[int]) -> torch.Tensor: """ @@ -240,7 +246,9 @@ def forward(self, x: List[torch.Tensor], image_size: List[int]) -> torch.Tensor: Tensor of shape (batch, `out_channels`, size1, size2, size3), where (size1, size2, size3) = ``image_size`` """ feature_list = [ - F.interpolate(layer(x[self.max_level - level]), size=image_size) + F.interpolate( + layer(x[self.max_level - level]), size=image_size, mode=self.mode, align_corners=self.align_corners + ) for layer, level in zip(self.layers, self.extract_levels) ] out: torch.Tensor = torch.mean(torch.stack(feature_list, dim=0), dim=0) diff --git a/monai/networks/nets/regunet.py b/monai/networks/nets/regunet.py index ee78342459..01fd9b905f 100644 --- a/monai/networks/nets/regunet.py +++ b/monai/networks/nets/regunet.py @@ -337,14 +337,23 @@ def build_output_block(self): class AdditiveUpSampleBlock(nn.Module): - def __init__(self, spatial_dims: int, in_channels: int, out_channels: int): + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + mode: str = "nearest", + align_corners: Optional[bool] = None, + ): super().__init__() self.deconv = get_deconv_block(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels) + self.mode = mode + self.align_corners = align_corners def forward(self, x: torch.Tensor) -> torch.Tensor: output_size = [size * 2 for size in x.shape[2:]] deconved = self.deconv(x) - resized = F.interpolate(x, output_size) + resized = F.interpolate(x, output_size, mode=self.mode, align_corners=self.align_corners) resized = torch.sum(torch.stack(resized.split(split_size=resized.shape[1] // 2, dim=1), dim=-1), dim=-1) out: torch.Tensor = deconved + resized return out @@ -372,8 +381,10 @@ def __init__( out_activation: Optional[str] = None, out_channels: int = 3, pooling: bool = True, - use_addictive_sampling: bool = True, + use_additive_sampling: bool = True, concat_skip: bool = False, + mode: str = "nearest", + align_corners: Optional[bool] = None, ): """ Args: @@ -385,10 +396,14 @@ def __init__( out_channels: number of channels for the output extract_levels: list, which levels from net to extract. The maximum level must equal to ``depth`` pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv3d - use_addictive_sampling: whether use additive up-sampling layer for decoding. + use_additive_sampling: whether use additive up-sampling layer for decoding. concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition + mode: mode for interpolation when use_additive_sampling, default is "nearest". + align_corners: align_corners for interpolation when use_additive_sampling, default is None. """ - self.use_additive_upsampling = use_addictive_sampling + self.use_additive_upsampling = use_additive_sampling + self.mode = mode + self.align_corners = align_corners super().__init__( spatial_dims=spatial_dims, in_channels=in_channels, @@ -412,7 +427,11 @@ def build_bottom_block(self, in_channels: int, out_channels: int): def build_up_sampling_block(self, in_channels: int, out_channels: int) -> nn.Module: if self.use_additive_upsampling: return AdditiveUpSampleBlock( - spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels + spatial_dims=self.spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + mode=self.mode, + align_corners=self.align_corners, ) return get_deconv_block(spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels) diff --git a/tests/test_localnet.py b/tests/test_localnet.py index 9ad50b9be8..bdf8a44198 100644 --- a/tests/test_localnet.py +++ b/tests/test_localnet.py @@ -32,6 +32,8 @@ "extract_levels": (0, 1), "pooling": False, "concat_skip": True, + "mode": "bilinear", + "align_corners": True, }, (1, 2, 16, 16), (1, 2, 16, 16), diff --git a/tests/test_localnet_block.py b/tests/test_localnet_block.py index d85509344e..57932860ed 100644 --- a/tests/test_localnet_block.py +++ b/tests/test_localnet_block.py @@ -25,7 +25,17 @@ [{"spatial_dims": spatial_dims, "in_channels": 2, "out_channels": 4, "kernel_size": 3}] for spatial_dims in [2, 3] ] -TEST_CASE_UP_SAMPLE = [[{"spatial_dims": spatial_dims, "in_channels": 4, "out_channels": 2}] for spatial_dims in [2, 3]] +TEST_CASE_UP_SAMPLE = [ + [ + { + "spatial_dims": spatial_dims, + "in_channels": 4, + "out_channels": 2, + "mode": "bilinear" if spatial_dims == 2 else "trilinear", + } + ] + for spatial_dims in [2, 3] +] TEST_CASE_EXTRACT = [ [{"spatial_dims": spatial_dims, "in_channels": 2, "out_channels": 3, "act": act, "initializer": initializer}] diff --git a/tests/test_regunet_block.py b/tests/test_regunet_block.py index 3be02ea377..81190fd038 100644 --- a/tests/test_regunet_block.py +++ b/tests/test_regunet_block.py @@ -53,6 +53,7 @@ "out_channels": 1, "kernel_initializer": "zeros", "activation": "sigmoid", + "mode": "trilinear", }, [(1, 3, 2, 2, 2), (1, 2, 4, 4, 4), (1, 1, 8, 8, 8)], (3, 3, 3), From 23b1ef421aa132d553f754d8ce06dac5003609a0 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 5 Jan 2023 09:08:38 +0000 Subject: [PATCH 2/2] remove win+conda test Signed-off-by: Wenqi Li --- .github/workflows/conda.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/conda.yml b/.github/workflows/conda.yml index d175c9bdaf..a03ba0cf2c 100644 --- a/.github/workflows/conda.yml +++ b/.github/workflows/conda.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - os: [windows-latest, ubuntu-latest] + os: [ubuntu-latest] python-version: ["3.9"] runs-on: ${{ matrix.os }} env: