diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py index fc086eb9a8..cefd42ca58 100644 --- a/monai/transforms/croppad/functional.py +++ b/monai/transforms/croppad/functional.py @@ -15,6 +15,8 @@ from __future__ import annotations +import warnings + import numpy as np import torch from torch.nn.functional import pad as pad_pt @@ -29,7 +31,12 @@ def _np_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor: - img_np = img.detach().cpu().numpy() if isinstance(img, torch.Tensor) else img + if isinstance(img, torch.Tensor): + if img.is_cuda: + warnings.warn(f"Padding: moving img {img.shape} from cuda to cpu for dtype={img.dtype} mode={mode}.") + img_np = img.detach().cpu().numpy() + else: + img_np = img mode = convert_pad_mode(dst=img_np, mode=mode).value if mode == "constant" and "value" in kwargs: kwargs["constant_values"] = kwargs.pop("value") @@ -40,9 +47,15 @@ def _np_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kw def _pt_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor: + mode = convert_pad_mode(dst=img, mode=mode).value + if mode == "constant" and "constant_values" in kwargs: + _kwargs = kwargs.copy() + _kwargs["value"] = _kwargs.pop("constant_values") + else: + _kwargs = kwargs pt_pad_width = [val for sublist in pad_width[1:] for val in sublist[::-1]][::-1] # torch.pad expects `[B, C, H, W, [D]]` shape - return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0) + return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **_kwargs).squeeze(0) def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs): @@ -68,14 +81,14 @@ def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs mode = convert_pad_mode(dst=img, mode=mode).value try: _pad = ( - _pt_pad - if mode in {"reflect", "replicate"} and img.dtype not in {torch.int16, torch.int64, torch.bool, torch.uint8} - else _np_pad + _np_pad + if mode in {"reflect", "replicate"} and img.dtype in {torch.int16, torch.int64, torch.bool, torch.uint8} + else _pt_pad ) return _pad(img, pad_width=to_pad, mode=mode, **kwargs) except (ValueError, TypeError, RuntimeError) as err: if isinstance(err, NotImplementedError) or any( - k in str(err) for k in ("supported", "unexpected keyword", "implemented") + k in str(err) for k in ("supported", "unexpected keyword", "implemented", "value") ): return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs) raise ValueError(f"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device}") from err diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index a8db6818bc..d3c8eb606f 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1628,13 +1628,13 @@ def convert_pad_mode(dst: NdarrayOrTensor, mode: str | None): if isinstance(dst, torch.Tensor): if mode == "wrap": mode = "circular" - if mode == "edge": + elif mode == "edge": mode = "replicate" return look_up_option(mode, PytorchPadMode) if isinstance(dst, np.ndarray): if mode == "circular": mode = "wrap" - if mode == "replicate": + elif mode == "replicate": mode = "edge" return look_up_option(mode, NumpyPadMode) raise ValueError(f"unsupported data type: {type(dst)}.")