diff --git a/monai/networks/blocks/upsample.py b/monai/networks/blocks/upsample.py index f3c680f050..5320611ce6 100644 --- a/monai/networks/blocks/upsample.py +++ b/monai/networks/blocks/upsample.py @@ -60,21 +60,21 @@ def __init__( thus if size is defined, `scale_factor` will not be used. Defaults to None. mode: {``"deconv"``, ``"nontrainable"``, ``"pixelshuffle"``}. Defaults to ``"deconv"``. - pre_conv: a conv block applied before upsampling. Defaults to None. + pre_conv: a conv block applied before upsampling. Defaults to "default". When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized when Only used in the "nontrainable" or "pixelshuffle" mode. interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``} - Only used when ``mode`` is ``UpsampleMode.NONTRAINABLE``. + Only used in the "nontrainable" mode. If ends with ``"linear"`` will use ``spatial dims`` to determine the correct interpolation. This corresponds to linear, bilinear, trilinear for 1D, 2D, and 3D respectively. The interpolation mode. Defaults to ``"linear"``. See also: https://pytorch.org/docs/stable/nn.html#upsample align_corners: set the align_corners parameter of `torch.nn.Upsample`. Defaults to True. - Only used in the nontrainable mode. + Only used in the "nontrainable" mode. bias: whether to have a bias term in the default preconv and deconv layers. Defaults to True. apply_pad_pool: if True the upsampled tensor is padded then average pooling is applied with a kernel the size of `scale_factor` with a stride of 1. See also: :py:class:`monai.networks.blocks.SubpixelUpsample`. - Only used in the pixelshuffle mode. + Only used in the "pixelshuffle" mode. """ super().__init__() scale_factor_ = ensure_tuple_rep(scale_factor, dimensions) @@ -104,6 +104,10 @@ def __init__( ) elif pre_conv is not None and pre_conv != "default": self.add_module("preconv", pre_conv) # type: ignore + elif pre_conv is None and (out_channels != in_channels): + raise ValueError( + "in the nontrainable mode, if not setting pre_conv, out_channels should equal to in_channels." + ) interp_mode = InterpolateMode(interp_mode) linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR] diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py index 08f2c92272..18bb1d3a6c 100644 --- a/monai/networks/nets/basic_unet.py +++ b/monai/networks/nets/basic_unet.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Union +from typing import Optional, Sequence, Union import torch import torch.nn as nn @@ -92,6 +92,9 @@ def __init__( norm: Union[str, tuple], dropout: Union[float, tuple] = 0.0, upsample: str = "deconv", + pre_conv: Optional[Union[nn.Module, str]] = "default", + interp_mode: str = "linear", + align_corners: Optional[bool] = True, halves: bool = True, ): """ @@ -105,12 +108,30 @@ def __init__( dropout: dropout ratio. Defaults to no dropout. upsample: upsampling mode, available options are ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. + pre_conv: a conv block applied before upsampling. + Only used in the "nontrainable" or "pixelshuffle" mode. + interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``} + Only used in the "nontrainable" mode. + align_corners: set the align_corners parameter for upsample. Defaults to True. + Only used in the "nontrainable" mode. halves: whether to halve the number of channels during upsampling. + This parameter does not work on ``nontrainable`` mode if ``pre_conv`` is `None`. """ super().__init__() - - up_chns = in_chns // 2 if halves else in_chns - self.upsample = UpSample(dim, in_chns, up_chns, 2, mode=upsample) + if upsample == "nontrainable" and pre_conv is None: + up_chns = in_chns + else: + up_chns = in_chns // 2 if halves else in_chns + self.upsample = UpSample( + dim, + in_chns, + up_chns, + 2, + mode=upsample, + pre_conv=pre_conv, + interp_mode=interp_mode, + align_corners=align_corners, + ) self.convs = TwoConv(dim, cat_chns + up_chns, out_chns, act, norm, dropout) def forward(self, x: torch.Tensor, x_e: torch.Tensor): diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py index b0ea249c6a..3922249b78 100644 --- a/monai/networks/nets/dynunet.py +++ b/monai/networks/nets/dynunet.py @@ -26,7 +26,7 @@ class DynUNetSkipLayer(nn.Module): Defines a layer in the UNet topology which combines the downsample and upsample pathways with the skip connection. The member `next_layer` may refer to instances of this class or the final bottleneck layer at the bottom the UNet structure. The purpose of using a recursive class like this is to get around the Torchscript restrictions on - looping over lists of layers and accumulating lists of output tensors which much be indexed. The `heads` list is + looping over lists of layers and accumulating lists of output tensors which must be indexed. The `heads` list is shared amongst all the instances of this class and is used to store the output from the supervision heads during forward passes of the network. """