diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py index b26fdcb622..c91335ccba 100644 --- a/monai/networks/nets/basic_unet.py +++ b/monai/networks/nets/basic_unet.py @@ -149,12 +149,12 @@ def __init__( self.convs = TwoConv(spatial_dims, cat_chns + up_chns, out_chns, act, norm, bias, dropout) self.is_pad = is_pad - def forward(self, x: torch.Tensor, x_e: torch.Tensor | None): + def forward(self, x: torch.Tensor, x_e: torch.Tensor): """ Args: x: features to be upsampled. - x_e: features from the encoder. + x_e: optional features from the encoder, if None, this branch is not in use. """ x_0 = self.upsample(x) diff --git a/monai/networks/nets/flexible_unet.py b/monai/networks/nets/flexible_unet.py index a880cafdc3..ac2124b5f9 100644 --- a/monai/networks/nets/flexible_unet.py +++ b/monai/networks/nets/flexible_unet.py @@ -232,6 +232,7 @@ def __init__( dropout: float | tuple = 0.0, decoder_bias: bool = False, upsample: str = "nontrainable", + pre_conv: str = "default", interp_mode: str = "nearest", is_pad: bool = True, ) -> None: @@ -262,6 +263,8 @@ def __init__( decoder_bias: whether to have a bias term in decoder's convolution blocks. upsample: upsampling mode, available options are``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. + pre_conv:a conv block applied before upsampling. Only used in the "nontrainable" or + "pixelshuffle" mode, default to `default`. interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``} Only used in the "nontrainable" mode. is_pad: whether to pad upsampling features to fit features from encoder. Default to True. @@ -309,7 +312,7 @@ def __init__( bias=decoder_bias, upsample=upsample, interp_mode=interp_mode, - pre_conv="default", + pre_conv=pre_conv, align_corners=None, is_pad=is_pad, )