From a0c8fcd39aac1de6cea7be8dd1899493d3aa766b Mon Sep 17 00:00:00 2001 From: binliu Date: Tue, 21 Mar 2023 08:57:00 +0000 Subject: [PATCH 1/5] add pre_conv parameter to FlexibleUNet and update the type hint in basic unet Signed-off-by: binliu --- monai/networks/nets/basic_unet.py | 3 ++- monai/networks/nets/flexible_unet.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py index b26fdcb622..4656e6b707 100644 --- a/monai/networks/nets/basic_unet.py +++ b/monai/networks/nets/basic_unet.py @@ -19,6 +19,7 @@ from monai.networks.blocks import Convolution, UpSample from monai.networks.layers.factories import Conv, Pool from monai.utils import ensure_tuple_rep +from typing import Optional __all__ = ["BasicUnet", "Basicunet", "basicunet", "BasicUNet"] @@ -149,7 +150,7 @@ def __init__( self.convs = TwoConv(spatial_dims, cat_chns + up_chns, out_chns, act, norm, bias, dropout) self.is_pad = is_pad - def forward(self, x: torch.Tensor, x_e: torch.Tensor | None): + def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]): """ Args: diff --git a/monai/networks/nets/flexible_unet.py b/monai/networks/nets/flexible_unet.py index a880cafdc3..ac2124b5f9 100644 --- a/monai/networks/nets/flexible_unet.py +++ b/monai/networks/nets/flexible_unet.py @@ -232,6 +232,7 @@ def __init__( dropout: float | tuple = 0.0, decoder_bias: bool = False, upsample: str = "nontrainable", + pre_conv: str = "default", interp_mode: str = "nearest", is_pad: bool = True, ) -> None: @@ -262,6 +263,8 @@ def __init__( decoder_bias: whether to have a bias term in decoder's convolution blocks. upsample: upsampling mode, available options are``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. + pre_conv:a conv block applied before upsampling. Only used in the "nontrainable" or + "pixelshuffle" mode, default to `default`. interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``} Only used in the "nontrainable" mode. is_pad: whether to pad upsampling features to fit features from encoder. Default to True. @@ -309,7 +312,7 @@ def __init__( bias=decoder_bias, upsample=upsample, interp_mode=interp_mode, - pre_conv="default", + pre_conv=pre_conv, align_corners=None, is_pad=is_pad, ) From f9ee3753a4d410e48817ea5038d47f32ad0d2119 Mon Sep 17 00:00:00 2001 From: binliu Date: Tue, 21 Mar 2023 08:59:20 +0000 Subject: [PATCH 2/5] update the import order Signed-off-by: binliu --- monai/networks/nets/basic_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py index 4656e6b707..301edcbebd 100644 --- a/monai/networks/nets/basic_unet.py +++ b/monai/networks/nets/basic_unet.py @@ -12,6 +12,7 @@ from __future__ import annotations from collections.abc import Sequence +from typing import Optional import torch import torch.nn as nn @@ -19,7 +20,6 @@ from monai.networks.blocks import Convolution, UpSample from monai.networks.layers.factories import Conv, Pool from monai.utils import ensure_tuple_rep -from typing import Optional __all__ = ["BasicUnet", "Basicunet", "basicunet", "BasicUNet"] From 982bad580948d7b36ae824ae08fbfc00d91baa75 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Mar 2023 09:05:45 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/basic_unet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py index 301edcbebd..b26fdcb622 100644 --- a/monai/networks/nets/basic_unet.py +++ b/monai/networks/nets/basic_unet.py @@ -12,7 +12,6 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Optional import torch import torch.nn as nn @@ -150,7 +149,7 @@ def __init__( self.convs = TwoConv(spatial_dims, cat_chns + up_chns, out_chns, act, norm, bias, dropout) self.is_pad = is_pad - def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]): + def forward(self, x: torch.Tensor, x_e: torch.Tensor | None): """ Args: From f4b9838a5462094c3811d52566ada5626b9fdfb6 Mon Sep 17 00:00:00 2001 From: binliu Date: Tue, 21 Mar 2023 09:15:01 +0000 Subject: [PATCH 4/5] change the type hint to Union to avoid being fixed by pre commit Signed-off-by: binliu --- monai/networks/nets/basic_unet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py index b26fdcb622..a43b828ad5 100644 --- a/monai/networks/nets/basic_unet.py +++ b/monai/networks/nets/basic_unet.py @@ -12,6 +12,7 @@ from __future__ import annotations from collections.abc import Sequence +from typing import Union import torch import torch.nn as nn @@ -149,7 +150,7 @@ def __init__( self.convs = TwoConv(spatial_dims, cat_chns + up_chns, out_chns, act, norm, bias, dropout) self.is_pad = is_pad - def forward(self, x: torch.Tensor, x_e: torch.Tensor | None): + def forward(self, x: torch.Tensor, x_e: Union[torch.Tensor, None]): """ Args: From 8b78144e8c11f9c2871ad92afad17cc2371523a7 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 21 Mar 2023 09:41:43 +0000 Subject: [PATCH 5/5] workaround for typing tensorrt Signed-off-by: Wenqi Li --- monai/networks/nets/basic_unet.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py index a43b828ad5..c91335ccba 100644 --- a/monai/networks/nets/basic_unet.py +++ b/monai/networks/nets/basic_unet.py @@ -12,7 +12,6 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Union import torch import torch.nn as nn @@ -150,12 +149,12 @@ def __init__( self.convs = TwoConv(spatial_dims, cat_chns + up_chns, out_chns, act, norm, bias, dropout) self.is_pad = is_pad - def forward(self, x: torch.Tensor, x_e: Union[torch.Tensor, None]): + def forward(self, x: torch.Tensor, x_e: torch.Tensor): """ Args: x: features to be upsampled. - x_e: features from the encoder. + x_e: optional features from the encoder, if None, this branch is not in use. """ x_0 = self.upsample(x)