From 765cfc391a0128302d14fa722fb011cc8e671357 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 27 Feb 2023 14:35:06 +0000 Subject: [PATCH 1/5] update padding mode Signed-off-by: Wenqi Li --- monai/transforms/croppad/functional.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py index 96d6cd8121..179167101f 100644 --- a/monai/transforms/croppad/functional.py +++ b/monai/transforms/croppad/functional.py @@ -63,13 +63,14 @@ def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. """ - if mode in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}: + if mode in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty", "wrap"}: return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs) mode = convert_pad_mode(dst=img, mode=mode).value try: _pad = ( _pt_pad - if mode in {"reflect", "replicate"} and img.dtype not in {torch.int16, torch.int64, torch.bool, torch.uint8} + if mode in {"reflect", "replicate", "constant", "circular"} + and img.dtype not in {torch.int64, torch.bool} else _np_pad ) return _pad(img, pad_width=to_pad, mode=mode, **kwargs) From cf1c93a5774939467b602915a8765782db1c1b38 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 27 Feb 2023 10:18:45 -0500 Subject: [PATCH 2/5] update and tests Signed-off-by: Wenqi Li --- monai/transforms/croppad/functional.py | 33 +++++++++++++++++++------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py index 179167101f..ae959882ad 100644 --- a/monai/transforms/croppad/functional.py +++ b/monai/transforms/croppad/functional.py @@ -15,6 +15,7 @@ from __future__ import annotations +import warnings import numpy as np import torch from torch.nn.functional import pad as pad_pt @@ -29,20 +30,34 @@ def _np_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor: - img_np = img.detach().cpu().numpy() if isinstance(img, torch.Tensor) else img + if isinstance(img, torch.Tensor): + if img.is_cuda: + warnings.warn(f"Padding: moving img {img.shape} on cuda to cpu for dtype={img.dtype} mode={mode}.") + img_np = img.detach().cpu().numpy() + else: + img_np = img mode = convert_pad_mode(dst=img_np, mode=mode).value if mode == "constant" and "value" in kwargs: - kwargs["constant_values"] = kwargs.pop("value") - out = torch.as_tensor(np.pad(img, pad_width, mode=mode, **kwargs)) # type: ignore + _kwargs = kwargs.copy() + _kwargs["constant_values"] = _kwargs.pop("value") + else: + _kwargs = kwargs + out = torch.as_tensor(np.pad(img, pad_width, mode=mode, **_kwargs)) # type: ignore if isinstance(img, MetaTensor): out = convert_to_dst_type(out, dst=img)[0] return out def _pt_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor: + mode = convert_pad_mode(dst=img, mode=mode).value + if mode == "constant" and "constant_values" in kwargs: + _kwargs = kwargs.copy() + _kwargs["value"] = _kwargs.pop("constant_values") + else: + _kwargs = kwargs pt_pad_width = [val for sublist in pad_width[1:] for val in sublist[::-1]][::-1] # torch.pad expects `[B, C, H, W, [D]]` shape - return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0) + return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, 
**_kwargs).squeeze(0) def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs): @@ -68,15 +83,15 @@ def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs mode = convert_pad_mode(dst=img, mode=mode).value try: _pad = ( - _pt_pad - if mode in {"reflect", "replicate", "constant", "circular"} - and img.dtype not in {torch.int64, torch.bool} - else _np_pad + _np_pad + if mode in {"reflect", "replicate"} + and img.dtype in {torch.int16, torch.int64, torch.bool, torch.uint8} + else _pt_pad ) return _pad(img, pad_width=to_pad, mode=mode, **kwargs) except (ValueError, TypeError, RuntimeError) as err: if isinstance(err, NotImplementedError) or any( - k in str(err) for k in ("supported", "unexpected keyword", "implemented") + k in str(err) for k in ("supported", "unexpected keyword", "implemented", "value") ): return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs) raise ValueError(f"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device}") from err From 7f45545dd6418db27744a46faf33324a20cc60a6 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 27 Feb 2023 15:20:35 +0000 Subject: [PATCH 3/5] update Signed-off-by: Wenqi Li --- monai/transforms/croppad/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py index ae959882ad..648473833b 100644 --- a/monai/transforms/croppad/functional.py +++ b/monai/transforms/croppad/functional.py @@ -16,6 +16,7 @@ from __future__ import annotations import warnings + import numpy as np import torch from torch.nn.functional import pad as pad_pt @@ -84,8 +85,7 @@ def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs try: _pad = ( _np_pad - if mode in {"reflect", "replicate"} - and img.dtype in {torch.int16, torch.int64, torch.bool, torch.uint8} + if mode in {"reflect", "replicate"} and img.dtype in {torch.int16, torch.int64, torch.bool, torch.uint8} else _pt_pad ) return _pad(img, pad_width=to_pad, mode=mode, **kwargs) From e6e16de69318329b8f4fd21e01592d94e179aa4f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 27 Feb 2023 16:32:38 +0000 Subject: [PATCH 4/5] update Signed-off-by: Wenqi Li --- monai/transforms/croppad/functional.py | 2 +- monai/transforms/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py index 648473833b..2d07f9f42c 100644 --- a/monai/transforms/croppad/functional.py +++ b/monai/transforms/croppad/functional.py @@ -79,7 +79,7 @@ def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. 
""" - if mode in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty", "wrap"}: + if mode in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}: return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs) mode = convert_pad_mode(dst=img, mode=mode).value try: diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index a8db6818bc..d3c8eb606f 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1628,13 +1628,13 @@ def convert_pad_mode(dst: NdarrayOrTensor, mode: str | None): if isinstance(dst, torch.Tensor): if mode == "wrap": mode = "circular" - if mode == "edge": + elif mode == "edge": mode = "replicate" return look_up_option(mode, PytorchPadMode) if isinstance(dst, np.ndarray): if mode == "circular": mode = "wrap" - if mode == "replicate": + elif mode == "replicate": mode = "edge" return look_up_option(mode, NumpyPadMode) raise ValueError(f"unsupported data type: {type(dst)}.") From 5e44fada7b4cb29ae917ffc2401f732f7c7de9dd Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 28 Feb 2023 11:43:14 +0000 Subject: [PATCH 5/5] update based on comments Signed-off-by: Wenqi Li --- monai/transforms/croppad/functional.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py index ce6575721f..cefd42ca58 100644 --- a/monai/transforms/croppad/functional.py +++ b/monai/transforms/croppad/functional.py @@ -33,17 +33,14 @@ def _np_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor: if isinstance(img, torch.Tensor): if img.is_cuda: - warnings.warn(f"Padding: moving img {img.shape} on cuda to cpu for dtype={img.dtype} mode={mode}.") + warnings.warn(f"Padding: moving img {img.shape} from cuda to cpu for dtype={img.dtype} mode={mode}.") img_np = img.detach().cpu().numpy() else: img_np = img mode = convert_pad_mode(dst=img_np, mode=mode).value if mode == "constant" and "value" in kwargs: - _kwargs = kwargs.copy() - _kwargs["constant_values"] = _kwargs.pop("value") - else: - _kwargs = kwargs - out = torch.as_tensor(np.pad(img, pad_width, mode=mode, **_kwargs)) # type: ignore + kwargs["constant_values"] = kwargs.pop("value") + out = torch.as_tensor(np.pad(img, pad_width, mode=mode, **kwargs)) # type: ignore if isinstance(img, MetaTensor): out = convert_to_dst_type(out, dst=img)[0] return out