Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions monai/data/grid_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@

from __future__ import annotations

from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence
from copy import deepcopy

import numpy as np

from monai.config import KeysCollection
from monai.config.type_definitions import NdarrayTensor
from monai.data.dataset import Dataset
from monai.data.iterable_dataset import IterableDataset
from monai.data.utils import iter_patch
from monai.transforms import apply_transform
from monai.utils import NumpyPadMode, ensure_tuple, first, look_up_option
from monai.utils import NumpyPadMode, ensure_tuple, first

__all__ = ["PatchDataset", "GridPatchDataset", "PatchIter", "PatchIterd"]

Expand All @@ -34,17 +35,25 @@ class PatchIter:
"""

def __init__(
self, patch_size: Sequence[int], start_pos: Sequence[int] = (), mode: str = NumpyPadMode.WRAP, **pad_opts: dict
self,
patch_size: Sequence[int],
start_pos: Sequence[int] = (),
mode: str | None = NumpyPadMode.WRAP,
**pad_opts: dict,
):
"""

Args:
patch_size: size of patches to generate slices for, 0/None selects whole dimension
start_pos: starting position in the array, default is 0 for each dimension
mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
One of the listed string values or a user supplied function. Defaults to ``"wrap"``.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
(PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
One of the listed string values or a user supplied function.
If None, no wrapping is performed. Defaults to ``"wrap"``.
See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
requires pytorch >= 1.10 for best compatibility.
pad_opts: other arguments for the `np.pad` function.
note that `np.pad` treats channel dimension as the first dimension.

Expand All @@ -58,10 +67,10 @@ def __init__(
"""
self.patch_size = (None,) + tuple(patch_size) # expand to have the channel dim
self.start_pos = ensure_tuple(start_pos)
self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode)
self.mode = mode
self.pad_opts = pad_opts

def __call__(self, array: np.ndarray):
def __call__(self, array: NdarrayTensor) -> Generator[tuple[NdarrayTensor, np.ndarray], None, None]:
"""
Args:
array: the image to generate patches from.
Expand Down Expand Up @@ -89,10 +98,14 @@ class PatchIterd:
keys: keys of the corresponding items to iterate patches.
patch_size: size of patches to generate slices for, 0/None selects whole dimension
start_pos: starting position in the array, default is 0 for each dimension
mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
One of the listed string values or a user supplied function. Defaults to ``"wrap"``.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
(PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
One of the listed string values or a user supplied function.
If None, no wrapping is performed. Defaults to ``"wrap"``.
See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
requires pytorch >= 1.10 for best compatibility.
pad_opts: other arguments for the `np.pad` function.
note that `np.pad` treats channel dimension as the first dimension.

Expand All @@ -107,13 +120,15 @@ def __init__(
keys: KeysCollection,
patch_size: Sequence[int],
start_pos: Sequence[int] = (),
mode: str = NumpyPadMode.WRAP,
mode: str | None = NumpyPadMode.WRAP,
**pad_opts,
):
self.keys = ensure_tuple(keys)
self.patch_iter = PatchIter(patch_size=patch_size, start_pos=start_pos, mode=mode, **pad_opts)

def __call__(self, data: Mapping[Hashable, np.ndarray]):
def __call__(
self, data: Mapping[Hashable, NdarrayTensor]
) -> Generator[tuple[Mapping[Hashable, NdarrayTensor], np.ndarray], None, None]:
d = dict(data)
original_spatial_shape = d[first(self.keys)].shape[1:]

Expand Down
24 changes: 17 additions & 7 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,14 +247,14 @@ def iter_patch_position(


def iter_patch(
arr: np.ndarray,
arr: NdarrayOrTensor,
patch_size: Sequence[int] | int = 0,
start_pos: Sequence[int] = (),
overlap: Sequence[float] | float = 0.0,
copy_back: bool = True,
mode: str | None = NumpyPadMode.WRAP,
**pad_opts: dict,
):
) -> Generator[tuple[NdarrayOrTensor, np.ndarray], None, None]:
"""
Yield successive patches from `arr` of size `patch_size`. The iteration can start from position `start_pos` in `arr`
but drawing from a padded array extended by the `patch_size` in each dimension (so these coordinates can be negative
Expand All @@ -268,9 +268,16 @@ def iter_patch(
overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
copy_back: if True data from the yielded patches is copied back to `arr` once the generator completes
mode: One of the listed string values in ``monai.utils.NumpyPadMode`` or ``monai.utils.PytorchPadMode``,
or a user supplied function. If None, no wrapping is performed. Defaults to ``"wrap"``.
pad_opts: padding options, see `numpy.pad`
mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
(PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
One of the listed string values or a user supplied function.
If None, no wrapping is performed. Defaults to ``"wrap"``.
See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
requires pytorch >= 1.10 for best compatibility.
pad_opts: other arguments for the `np.pad` or `torch.nn.functional.pad` function.
note that `np.pad` treats channel dimension as the first dimension.

Yields:
Patches of array data from `arr` which are views into a padded array which can be modified, if `copy_back` is
Expand All @@ -285,6 +292,9 @@ def iter_patch(
Nth_dim_start, Nth_dim_end]]

"""

from monai.transforms.croppad.functional import pad_nd # needs to be here to avoid circular import

# ensure patch_size and start_pos are the right length
patch_size_ = get_valid_patch_size(arr.shape, patch_size)
start_pos = ensure_tuple_size(start_pos, arr.ndim)
Expand All @@ -296,7 +306,7 @@ def iter_patch(
_overlap = [op if v else 0.0 for op, v in zip(ensure_tuple_rep(overlap, arr.ndim), is_v)] # overlap if v else 0.0
# pad image by maximum values needed to ensure patches are taken from inside an image
if padded:
arrpad = np.pad(arr, tuple((p, p) for p in _pad_size), look_up_option(mode, NumpyPadMode).value, **pad_opts)
arrpad = pad_nd(arr, to_pad=[(p, p) for p in _pad_size], mode=mode, **pad_opts) # type: ignore
# choose a start position in the padded image
start_pos_padded = tuple(s + p for s, p in zip(start_pos, _pad_size))

Expand All @@ -319,7 +329,7 @@ def iter_patch(
# copy back data from the padded image if required
if copy_back:
slices = tuple(slice(p, p + s) for p, s in zip(_pad_size, arr.shape))
arr[...] = arrpad[slices]
arr[...] = arrpad[slices] # type: ignore


def get_valid_patch_size(image_size: Sequence[int], patch_size: Sequence[int] | int | np.ndarray) -> tuple[int, ...]:
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __call__( # type: ignore[override]
kwargs_.update(kwargs)

img_t = convert_to_tensor(data=img, track_meta=get_track_meta())
return pad_func(img_t, to_pad_, mode_, self.get_transform_info(), kwargs_)
return pad_func(img_t, to_pad_, self.get_transform_info(), mode_, **kwargs_)

def inverse(self, data: MetaTensor) -> MetaTensor:
transform = self.pop_transform(data)
Expand Down
60 changes: 38 additions & 22 deletions monai/transforms/croppad/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch
from torch.nn.functional import pad as pad_pt

from monai.config.type_definitions import NdarrayTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import to_affine_nd
Expand Down Expand Up @@ -49,7 +50,7 @@ def _convert_pt_pad_mode(padding_mode):
return PytorchPadMode.REPLICATE # "nearest", "border", and others


def _np_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor:
def _np_pad(img: NdarrayTensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> NdarrayTensor:
if isinstance(img, torch.Tensor):
if img.is_cuda:
warnings.warn(f"Padding: moving img {img.shape} from cuda to cpu for dtype={img.dtype} mode={mode}.")
Expand All @@ -59,28 +60,32 @@ def _np_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kw
mode = convert_pad_mode(dst=img_np, mode=mode).value
if mode == "constant" and "value" in kwargs:
kwargs["constant_values"] = kwargs.pop("value")
out = torch.as_tensor(np.pad(img, pad_width, mode=mode, **kwargs)) # type: ignore
if isinstance(img, MetaTensor):
out = convert_to_dst_type(out, dst=img)[0]
return out
img_np = np.pad(img_np, pad_width, mode=mode, **kwargs) # type: ignore
return convert_to_dst_type(img_np, dst=img)[0]


def _pt_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor:
mode = convert_pad_mode(dst=img, mode=mode).value
def _pt_pad(img: NdarrayTensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> NdarrayTensor:
img_pt = torch.as_tensor(img)
mode = convert_pad_mode(dst=img_pt, mode=mode).value
if mode == "constant" and "constant_values" in kwargs:
_kwargs = kwargs.copy()
_kwargs["value"] = _kwargs.pop("constant_values")
else:
_kwargs = kwargs
pt_pad_width = [val for sublist in pad_width[1:] for val in sublist[::-1]][::-1]
# torch.pad expects `[B, C, H, W, [D]]` shape
return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **_kwargs).squeeze(0)
img_pt = pad_pt(img_pt.unsqueeze(0), pt_pad_width, mode=mode, **_kwargs).squeeze(0)
return convert_to_dst_type(img_pt, dst=img)[0]


def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs):
def pad_nd(
img: NdarrayTensor, to_pad: list[tuple[int, int]], mode: str = PytorchPadMode.CONSTANT, **kwargs
) -> NdarrayTensor:
"""
PyTorch/Numpy pad ``img`` with integers ``to_pad`` amounts. Depending on the ``mode`` and input dtype,
a suitable backend will be used automatically.
Pad `img` by the given amount of padding in each dimension.

`torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch,
in which case `np.pad` will be used.

Args:
img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim.
Expand All @@ -90,27 +95,31 @@ def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
(PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
kwargs: other arguments for the `np.pad` or `torch.pad` function.
note that `np.pad` treats channel dimension as the first dimension.
"""
if mode in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}:
return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
mode = convert_pad_mode(dst=img, mode=mode).value
try:
_pad = (
_np_pad
if mode in {"reflect", "replicate"} and img.dtype in {torch.int16, torch.int64, torch.bool, torch.uint8}
else _pt_pad
)
_pad = _np_pad
if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"} and img.dtype not in {
torch.int16,
torch.int64,
torch.bool,
torch.uint8,
}:
_pad = _pt_pad
return _pad(img, pad_width=to_pad, mode=mode, **kwargs)
except (ValueError, TypeError, RuntimeError) as err:
if isinstance(err, NotImplementedError) or any(
k in str(err) for k in ("supported", "unexpected keyword", "implemented", "value")
):
return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
raise ValueError(f"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device}") from err
raise ValueError(
f"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device if isinstance(img, torch.Tensor) else None}"
) from err


def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int, ...], mode: str, **kwargs):
Expand Down Expand Up @@ -148,23 +157,30 @@ def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int,


def pad_func(
img: torch.Tensor, to_pad: tuple[tuple[int, int]], mode: str, transform_info: dict, kwargs
img: torch.Tensor,
to_pad: tuple[tuple[int, int]],
transform_info: dict,
mode: str = PytorchPadMode.CONSTANT,
**kwargs,
) -> torch.Tensor:
"""
Functional implementation of padding a MetaTensor. This function operates eagerly or lazily according
to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``).

`torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch,
in which case `np.pad` will be used.

Args:
img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim.
to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...].
note that it includes the channel dimension.
transform_info: a dictionary with the relevant information pertaining to an applied transform.
mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
(PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
transform_info: a dictionary with the relevant information pertaining to an applied transform.
kwargs: other arguments for the `np.pad` or `torch.pad` function.
note that `np.pad` treats channel dimension as the first dimension.
"""
Expand Down
Loading