From 546c509ace6f352e3ebd202691490755e41b2005 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 13 Aug 2021 22:49:37 +0800 Subject: [PATCH 1/4] add unetdecoder Signed-off-by: Yiheng Wang --- monai/networks/blocks/upsample.py | 10 +- monai/networks/nets/basic_unet.py | 146 ++++++++++++++++++++++++++---- monai/networks/nets/dynunet.py | 2 +- 3 files changed, 136 insertions(+), 22 deletions(-) diff --git a/monai/networks/blocks/upsample.py b/monai/networks/blocks/upsample.py index f3c680f050..c3f05d9963 100644 --- a/monai/networks/blocks/upsample.py +++ b/monai/networks/blocks/upsample.py @@ -60,21 +60,21 @@ def __init__( thus if size is defined, `scale_factor` will not be used. Defaults to None. mode: {``"deconv"``, ``"nontrainable"``, ``"pixelshuffle"``}. Defaults to ``"deconv"``. - pre_conv: a conv block applied before upsampling. Defaults to None. + pre_conv: a conv block applied before upsampling. Defaults to "default". When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized when Only used in the "nontrainable" or "pixelshuffle" mode. interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``} - Only used when ``mode`` is ``UpsampleMode.NONTRAINABLE``. + Only used in the "nontrainable" mode. If ends with ``"linear"`` will use ``spatial dims`` to determine the correct interpolation. This corresponds to linear, bilinear, trilinear for 1D, 2D, and 3D respectively. The interpolation mode. Defaults to ``"linear"``. See also: https://pytorch.org/docs/stable/nn.html#upsample align_corners: set the align_corners parameter of `torch.nn.Upsample`. Defaults to True. - Only used in the nontrainable mode. + Only used in the "nontrainable" mode. bias: whether to have a bias term in the default preconv and deconv layers. Defaults to True. apply_pad_pool: if True the upsampled tensor is padded then average pooling is applied with a kernel the size of `scale_factor` with a stride of 1. See also: :py:class:`monai.networks.blocks.SubpixelUpsample`. - Only used in the pixelshuffle mode. + Only used in the "pixelshuffle" mode. """ super().__init__() scale_factor_ = ensure_tuple_rep(scale_factor, dimensions) @@ -104,6 +104,8 @@ def __init__( ) elif pre_conv is not None and pre_conv != "default": self.add_module("preconv", pre_conv) # type: ignore + elif pre_conv is None and (out_channels != in_channels): + raise ValueError(f"in the nontrainable mode, if not setting pre_conv, out_channels should equal to in_channels.") interp_mode = InterpolateMode(interp_mode) linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR] diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py index 08f2c92272..9660d5cb9e 100644 --- a/monai/networks/nets/basic_unet.py +++ b/monai/networks/nets/basic_unet.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Union +from typing import Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -92,6 +92,9 @@ def __init__( norm: Union[str, tuple], dropout: Union[float, tuple] = 0.0, upsample: str = "deconv", + pre_conv: Optional[str] = "default", + interp_mode: str = "linear", + align_corners: Optional[bool] = True, halves: bool = True, ): """ @@ -105,12 +108,30 @@ def __init__( dropout: dropout ratio. Defaults to no dropout. upsample: upsampling mode, available options are ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. + pre_conv: a conv block applied before upsampling. + Only used in the "nontrainable" or "pixelshuffle" mode. + interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``} + Only used in the "nontrainable" mode. + align_corners: set the align_corners parameter for upsample. Defaults to True. + Only used in the "nontrainable" mode. halves: whether to halve the number of channels during upsampling. + This parameter does not work on ``nontrainable`` mode if ``pre_conv`` is `None`. """ super().__init__() - - up_chns = in_chns // 2 if halves else in_chns - self.upsample = UpSample(dim, in_chns, up_chns, 2, mode=upsample) + if upsample == "nontrainable" and pre_conv is None: + up_chns = in_chns + else: + up_chns = in_chns // 2 if halves else in_chns + self.upsample = UpSample( + dim, + in_chns, + up_chns, + 2, + mode=upsample, + pre_conv=pre_conv, + interp_mode=interp_mode, + align_corners=align_corners, + ) self.convs = TwoConv(dim, cat_chns + up_chns, out_chns, act, norm, dropout) def forward(self, x: torch.Tensor, x_e: torch.Tensor): @@ -134,6 +155,85 @@ def forward(self, x: torch.Tensor, x_e: torch.Tensor): return x +class BasicUNetDecoder(nn.Module): + def __init__( + self, + dim: int, + encoder_channels: Sequence[int], + decoder_channels: Sequence[int], + act: Union[str, tuple], + norm: Union[str, tuple], + dropout: Union[float, tuple], + upsample: str, + pre_conv: Optional[str], + interp_mode: str, + align_corners: Optional[bool], + ): + """ + Decoder of BasicUNet. + This class refers to `segmentation_models.pytorch + `_. + + Args: + dim: number of spatial dimensions. + encoder_channels: number of output channels for all feature maps in encoder. + `len(encoder_channels)` should be no less than 2. + decoder_channels: number of output channels for all feature maps in decoder. + `len(decoder_channels)` should equal to `len(encoder_channels) - 1`. + act: activation type and arguments. + norm: feature normalization type and arguments. + dropout: dropout ratio. + upsample: upsampling mode, available options are + ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. + pre_conv: a conv block applied before upsampling. + Only used in the "nontrainable" or "pixelshuffle" mode. + interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``} + Only used in the "nontrainable" mode. + align_corners: set the align_corners parameter for upsample. Defaults to True. + Only used in the "nontrainable" mode. + + """ + super().__init__() + if len(encoder_channels) < 2: + raise ValueError("the length of `encoder_channels` should be no less than 2") + if len(decoder_channels) != len(encoder_channels) - 1: + raise ValueError("`len(decoder_channels)` should equal to `len(encoder_channels) - 1`") + + in_channels = [encoder_channels[-1]] + list(decoder_channels[:-1]) + skip_channels = list(encoder_channels[:-1][::-1]) + halves = [True] * (len(skip_channels) - 1) + halves.append(False) + blocks = [] + for in_chn, skip_chn, out_chn, halve in zip(in_channels, skip_channels, decoder_channels, halves): + blocks.append( + UpCat( + dim=dim, + in_chns=in_chn, + cat_chns=skip_chn, + out_chns=out_chn, + act=act, + norm=norm, + dropout=dropout, + upsample=upsample, + pre_conv=pre_conv, + interp_mode=interp_mode, + align_corners=align_corners, + halves=halve, + ) + ) + self.blocks = nn.ModuleList(blocks) + + def forward(self, *feature_maps: Sequence[torch.Tensor]): + + feature_maps = feature_maps[0][::-1] + skips = feature_maps[1:] + x = feature_maps[0] + for i, block in enumerate(self.blocks): + x = block(x, skips[i]) + + return x + + class BasicUNet(nn.Module): def __init__( self, @@ -145,6 +245,9 @@ def __init__( norm: Union[str, tuple] = ("instance", {"affine": True}), dropout: Union[float, tuple] = 0.0, upsample: str = "deconv", + pre_conv: Optional[str] = "default", + interp_mode: str = "linear", + align_corners: Optional[bool] = True, ): """ A UNet implementation with 1D/2D/3D supports. @@ -161,24 +264,25 @@ def __init__( out_channels: number of output channels. Defaults to 2. features: six integers as numbers of features. Defaults to ``(32, 32, 64, 128, 256, 32)``, - - the first five values correspond to the five-level encoder feature sizes. - the last value corresponds to the feature size after the last upsampling. - act: activation type and arguments. Defaults to LeakyReLU. norm: feature normalization type and arguments. Defaults to instance norm. dropout: dropout ratio. Defaults to no dropout. upsample: upsampling mode, available options are ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. + pre_conv: a conv block applied before upsampling. + Only used in the "nontrainable" or "pixelshuffle" mode. + interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``} + Only used in the "nontrainable" mode. + align_corners: set the align_corners parameter for upsample. Defaults to True. + Only used in the "nontrainable" mode. Examples:: - # for spatial 2D >>> net = BasicUNet(dimensions=2, features=(64, 128, 256, 512, 1024, 128)) - # for spatial 2D, with group norm >>> net = BasicUNet(dimensions=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4})) - # for spatial 3D >>> net = BasicUNet(dimensions=3, features=(32, 32, 64, 128, 256, 32)) @@ -192,6 +296,9 @@ def __init__( fea = ensure_tuple_rep(features, 6) print(f"BasicUNet features: {fea}.") + encoder_channels = list(features[:-1]) + decoder_channels = list(features[1:-2][::-1]) + decoder_channels.append(fea[-1]) self.conv_0 = TwoConv(dimensions, in_channels, features[0], act, norm, dropout) self.down_1 = Down(dimensions, fea[0], fea[1], act, norm, dropout) @@ -199,10 +306,18 @@ def __init__( self.down_3 = Down(dimensions, fea[2], fea[3], act, norm, dropout) self.down_4 = Down(dimensions, fea[3], fea[4], act, norm, dropout) - self.upcat_4 = UpCat(dimensions, fea[4], fea[3], fea[3], act, norm, dropout, upsample) - self.upcat_3 = UpCat(dimensions, fea[3], fea[2], fea[2], act, norm, dropout, upsample) - self.upcat_2 = UpCat(dimensions, fea[2], fea[1], fea[1], act, norm, dropout, upsample) - self.upcat_1 = UpCat(dimensions, fea[1], fea[0], fea[5], act, norm, dropout, upsample, halves=False) + self.decoder = BasicUNetDecoder( + dim=dimensions, + encoder_channels=encoder_channels, + decoder_channels=decoder_channels, + act=act, + norm=norm, + dropout=dropout, + upsample=upsample, + pre_conv=pre_conv, + interp_mode=interp_mode, + align_corners=align_corners, + ) self.final_conv = Conv["conv", dimensions](fea[5], out_channels, kernel_size=1) @@ -225,10 +340,7 @@ def forward(self, x: torch.Tensor): x3 = self.down_3(x2) x4 = self.down_4(x3) - u4 = self.upcat_4(x4, x3) - u3 = self.upcat_3(u4, x2) - u2 = self.upcat_2(u3, x1) - u1 = self.upcat_1(u2, x0) + u1 = self.decoder([x0, x1, x2, x3, x4]) logits = self.final_conv(u1) return logits diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index b0ea249c6a..3922249b78 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -26,7 +26,7 @@ class DynUNetSkipLayer(nn.Module): Defines a layer in the UNet topology which combines the downsample and upsample pathways with the skip connection. The member `next_layer` may refer to instances of this class or the final bottleneck layer at the bottom the UNet structure. The purpose of using a recursive class like this is to get around the Torchscript restrictions on - looping over lists of layers and accumulating lists of output tensors which much be indexed. The `heads` list is + looping over lists of layers and accumulating lists of output tensors which must be indexed. The `heads` list is shared amongst all the instances of this class and is used to store the output from the supervision heads during forward passes of the network. """ From 6249bdba37bf24d3c14109ea2e64e5b91f1d36a3 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 14 Aug 2021 15:43:31 +0800 Subject: [PATCH 2/4] remove decoder Signed-off-by: Yiheng Wang --- monai/networks/nets/basic_unet.py | 126 ++++-------------------------- 1 file changed, 15 insertions(+), 111 deletions(-) diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py index 9660d5cb9e..57660f15bd 100644 --- a/monai/networks/nets/basic_unet.py +++ b/monai/networks/nets/basic_unet.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Union import torch import torch.nn as nn @@ -92,7 +92,7 @@ def __init__( norm: Union[str, tuple], dropout: Union[float, tuple] = 0.0, upsample: str = "deconv", - pre_conv: Optional[str] = "default", + pre_conv: Optional[Union[nn.Module, str]] = "default", interp_mode: str = "linear", align_corners: Optional[bool] = True, halves: bool = True, @@ -155,85 +155,6 @@ def forward(self, x: torch.Tensor, x_e: torch.Tensor): return x -class BasicUNetDecoder(nn.Module): - def __init__( - self, - dim: int, - encoder_channels: Sequence[int], - decoder_channels: Sequence[int], - act: Union[str, tuple], - norm: Union[str, tuple], - dropout: Union[float, tuple], - upsample: str, - pre_conv: Optional[str], - interp_mode: str, - align_corners: Optional[bool], - ): - """ - Decoder of BasicUNet. - This class refers to `segmentation_models.pytorch - `_. - - Args: - dim: number of spatial dimensions. - encoder_channels: number of output channels for all feature maps in encoder. - `len(encoder_channels)` should be no less than 2. - decoder_channels: number of output channels for all feature maps in decoder. - `len(decoder_channels)` should equal to `len(encoder_channels) - 1`. - act: activation type and arguments. - norm: feature normalization type and arguments. - dropout: dropout ratio. - upsample: upsampling mode, available options are - ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. - pre_conv: a conv block applied before upsampling. - Only used in the "nontrainable" or "pixelshuffle" mode. - interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``} - Only used in the "nontrainable" mode. - align_corners: set the align_corners parameter for upsample. Defaults to True. - Only used in the "nontrainable" mode. - - """ - super().__init__() - if len(encoder_channels) < 2: - raise ValueError("the length of `encoder_channels` should be no less than 2") - if len(decoder_channels) != len(encoder_channels) - 1: - raise ValueError("`len(decoder_channels)` should equal to `len(encoder_channels) - 1`") - - in_channels = [encoder_channels[-1]] + list(decoder_channels[:-1]) - skip_channels = list(encoder_channels[:-1][::-1]) - halves = [True] * (len(skip_channels) - 1) - halves.append(False) - blocks = [] - for in_chn, skip_chn, out_chn, halve in zip(in_channels, skip_channels, decoder_channels, halves): - blocks.append( - UpCat( - dim=dim, - in_chns=in_chn, - cat_chns=skip_chn, - out_chns=out_chn, - act=act, - norm=norm, - dropout=dropout, - upsample=upsample, - pre_conv=pre_conv, - interp_mode=interp_mode, - align_corners=align_corners, - halves=halve, - ) - ) - self.blocks = nn.ModuleList(blocks) - - def forward(self, *feature_maps: Sequence[torch.Tensor]): - - feature_maps = feature_maps[0][::-1] - skips = feature_maps[1:] - x = feature_maps[0] - for i, block in enumerate(self.blocks): - x = block(x, skips[i]) - - return x - - class BasicUNet(nn.Module): def __init__( self, @@ -245,9 +166,6 @@ def __init__( norm: Union[str, tuple] = ("instance", {"affine": True}), dropout: Union[float, tuple] = 0.0, upsample: str = "deconv", - pre_conv: Optional[str] = "default", - interp_mode: str = "linear", - align_corners: Optional[bool] = True, ): """ A UNet implementation with 1D/2D/3D supports. @@ -264,25 +182,24 @@ def __init__( out_channels: number of output channels. Defaults to 2. features: six integers as numbers of features. Defaults to ``(32, 32, 64, 128, 256, 32)``, + - the first five values correspond to the five-level encoder feature sizes. - the last value corresponds to the feature size after the last upsampling. + act: activation type and arguments. Defaults to LeakyReLU. norm: feature normalization type and arguments. Defaults to instance norm. dropout: dropout ratio. Defaults to no dropout. upsample: upsampling mode, available options are ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. - pre_conv: a conv block applied before upsampling. - Only used in the "nontrainable" or "pixelshuffle" mode. - interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``} - Only used in the "nontrainable" mode. - align_corners: set the align_corners parameter for upsample. Defaults to True. - Only used in the "nontrainable" mode. Examples:: + # for spatial 2D >>> net = BasicUNet(dimensions=2, features=(64, 128, 256, 512, 1024, 128)) + # for spatial 2D, with group norm >>> net = BasicUNet(dimensions=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4})) + # for spatial 3D >>> net = BasicUNet(dimensions=3, features=(32, 32, 64, 128, 256, 32)) @@ -296,9 +213,6 @@ def __init__( fea = ensure_tuple_rep(features, 6) print(f"BasicUNet features: {fea}.") - encoder_channels = list(features[:-1]) - decoder_channels = list(features[1:-2][::-1]) - decoder_channels.append(fea[-1]) self.conv_0 = TwoConv(dimensions, in_channels, features[0], act, norm, dropout) self.down_1 = Down(dimensions, fea[0], fea[1], act, norm, dropout) @@ -306,18 +220,10 @@ def __init__( self.down_3 = Down(dimensions, fea[2], fea[3], act, norm, dropout) self.down_4 = Down(dimensions, fea[3], fea[4], act, norm, dropout) - self.decoder = BasicUNetDecoder( - dim=dimensions, - encoder_channels=encoder_channels, - decoder_channels=decoder_channels, - act=act, - norm=norm, - dropout=dropout, - upsample=upsample, - pre_conv=pre_conv, - interp_mode=interp_mode, - align_corners=align_corners, - ) + self.upcat_4 = UpCat(dimensions, fea[4], fea[3], fea[3], act, norm, dropout, upsample) + self.upcat_3 = UpCat(dimensions, fea[3], fea[2], fea[2], act, norm, dropout, upsample) + self.upcat_2 = UpCat(dimensions, fea[2], fea[1], fea[1], act, norm, dropout, upsample) + self.upcat_1 = UpCat(dimensions, fea[1], fea[0], fea[5], act, norm, dropout, upsample, halves=False) self.final_conv = Conv["conv", dimensions](fea[5], out_channels, kernel_size=1) @@ -335,12 +241,10 @@ def forward(self, x: torch.Tensor): """ x0 = self.conv_0(x) - x1 = self.down_1(x0) - x2 = self.down_2(x1) - x3 = self.down_3(x2) - x4 = self.down_4(x3) - - u1 = self.decoder([x0, x1, x2, x3, x4]) + u4 = self.upcat_4(x4, x3) + u3 = self.upcat_3(u4, x2) + u2 = self.upcat_2(u3, x1) + u1 = self.upcat_1(u2, x0) logits = self.final_conv(u1) return logits From 84f0bd6a9348e4536703547032bddabc0a585de7 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 14 Aug 2021 15:45:47 +0800 Subject: [PATCH 3/4] recover mis deleted lines Signed-off-by: Yiheng Wang --- monai/networks/nets/basic_unet.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py index 57660f15bd..18bb1d3a6c 100644 --- a/monai/networks/nets/basic_unet.py +++ b/monai/networks/nets/basic_unet.py @@ -199,7 +199,7 @@ def __init__( # for spatial 2D, with group norm >>> net = BasicUNet(dimensions=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4})) - + # for spatial 3D >>> net = BasicUNet(dimensions=3, features=(32, 32, 64, 128, 256, 32)) @@ -241,6 +241,11 @@ def forward(self, x: torch.Tensor): """ x0 = self.conv_0(x) + x1 = self.down_1(x0) + x2 = self.down_2(x1) + x3 = self.down_3(x2) + x4 = self.down_4(x3) + u4 = self.upcat_4(x4, x3) u3 = self.upcat_3(u4, x2) u2 = self.upcat_2(u3, x1) From 84c11ff3265edb0b740fe14938a7d38aa6c8b41b Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 14 Aug 2021 15:55:55 +0800 Subject: [PATCH 4/4] fix black error Signed-off-by: Yiheng Wang --- monai/networks/blocks/upsample.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/networks/blocks/upsample.py b/monai/networks/blocks/upsample.py index c3f05d9963..5320611ce6 100644 --- a/monai/networks/blocks/upsample.py +++ b/monai/networks/blocks/upsample.py @@ -105,7 +105,9 @@ def __init__( elif pre_conv is not None and pre_conv != "default": self.add_module("preconv", pre_conv) # type: ignore elif pre_conv is None and (out_channels != in_channels): - raise ValueError(f"in the nontrainable mode, if not setting pre_conv, out_channels should equal to in_channels.") + raise ValueError( + "in the nontrainable mode, if not setting pre_conv, out_channels should equal to in_channels." + ) interp_mode = InterpolateMode(interp_mode) linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR]