diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index fc8175f630..43c72b5a78 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -11,17 +11,18 @@ from __future__ import annotations -from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence +from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence from copy import deepcopy import numpy as np from monai.config import KeysCollection +from monai.config.type_definitions import NdarrayTensor from monai.data.dataset import Dataset from monai.data.iterable_dataset import IterableDataset from monai.data.utils import iter_patch from monai.transforms import apply_transform -from monai.utils import NumpyPadMode, ensure_tuple, first, look_up_option +from monai.utils import NumpyPadMode, ensure_tuple, first __all__ = ["PatchDataset", "GridPatchDataset", "PatchIter", "PatchIterd"] @@ -34,17 +35,25 @@ class PatchIter: """ def __init__( - self, patch_size: Sequence[int], start_pos: Sequence[int] = (), mode: str = NumpyPadMode.WRAP, **pad_opts: dict + self, + patch_size: Sequence[int], + start_pos: Sequence[int] = (), + mode: str | None = NumpyPadMode.WRAP, + **pad_opts: dict, ): """ Args: patch_size: size of patches to generate slices for, 0/None selects whole dimension start_pos: starting position in the array, default is 0 for each dimension - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - One of the listed string values or a user supplied function. Defaults to ``"wrap"``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. + If None, no wrapping is performed. Defaults to ``"wrap"``. + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + requires pytorch >= 1.10 for best compatibility. pad_opts: other arguments for the `np.pad` function. note that `np.pad` treats channel dimension as the first dimension. @@ -58,10 +67,10 @@ def __init__( """ self.patch_size = (None,) + tuple(patch_size) # expand to have the channel dim self.start_pos = ensure_tuple(start_pos) - self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) + self.mode = mode self.pad_opts = pad_opts - def __call__(self, array: np.ndarray): + def __call__(self, array: NdarrayTensor) -> Generator[tuple[NdarrayTensor, np.ndarray], None, None]: """ Args: array: the image to generate patches from. @@ -89,10 +98,14 @@ class PatchIterd: keys: keys of the corresponding items to iterate patches. patch_size: size of patches to generate slices for, 0/None selects whole dimension start_pos: starting position in the array, default is 0 for each dimension - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - One of the listed string values or a user supplied function. Defaults to ``"wrap"``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. + If None, no wrapping is performed. Defaults to ``"wrap"``. + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + requires pytorch >= 1.10 for best compatibility. pad_opts: other arguments for the `np.pad` function. note that `np.pad` treats channel dimension as the first dimension. @@ -107,13 +120,15 @@ def __init__( keys: KeysCollection, patch_size: Sequence[int], start_pos: Sequence[int] = (), - mode: str = NumpyPadMode.WRAP, + mode: str | None = NumpyPadMode.WRAP, **pad_opts, ): self.keys = ensure_tuple(keys) self.patch_iter = PatchIter(patch_size=patch_size, start_pos=start_pos, mode=mode, **pad_opts) - def __call__(self, data: Mapping[Hashable, np.ndarray]): + def __call__( + self, data: Mapping[Hashable, NdarrayTensor] + ) -> Generator[tuple[Mapping[Hashable, NdarrayTensor], np.ndarray], None, None]: d = dict(data) original_spatial_shape = d[first(self.keys)].shape[1:] diff --git a/monai/data/utils.py b/monai/data/utils.py index 5461fda937..2c035afb3f 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -247,14 +247,14 @@ def iter_patch_position( def iter_patch( - arr: np.ndarray, + arr: NdarrayOrTensor, patch_size: Sequence[int] | int = 0, start_pos: Sequence[int] = (), overlap: Sequence[float] | float = 0.0, copy_back: bool = True, mode: str | None = NumpyPadMode.WRAP, **pad_opts: dict, -): +) -> Generator[tuple[NdarrayOrTensor, np.ndarray], None, None]: """ Yield successive patches from `arr` of size `patch_size`. The iteration can start from position `start_pos` in `arr` but drawing from a padded array extended by the `patch_size` in each dimension (so these coordinates can be negative @@ -268,9 +268,16 @@ def iter_patch( overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. copy_back: if True data from the yielded patches is copied back to `arr` once the generator completes - mode: One of the listed string values in ``monai.utils.NumpyPadMode`` or ``monai.utils.PytorchPadMode``, - or a user supplied function. If None, no wrapping is performed. Defaults to ``"wrap"``. - pad_opts: padding options, see `numpy.pad` + mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. + If None, no wrapping is performed. Defaults to ``"wrap"``. + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + requires pytorch >= 1.10 for best compatibility. + pad_opts: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. Yields: Patches of array data from `arr` which are views into a padded array which can be modified, if `copy_back` is @@ -285,6 +292,9 @@ def iter_patch( Nth_dim_start, Nth_dim_end]] """ + + from monai.transforms.croppad.functional import pad_nd # needs to be here to avoid circular import + # ensure patchSize and startPos are the right length patch_size_ = get_valid_patch_size(arr.shape, patch_size) start_pos = ensure_tuple_size(start_pos, arr.ndim) @@ -296,7 +306,7 @@ def iter_patch( _overlap = [op if v else 0.0 for op, v in zip(ensure_tuple_rep(overlap, arr.ndim), is_v)] # overlap if v else 0.0 # pad image by maximum values needed to ensure patches are taken from inside an image if padded: - arrpad = np.pad(arr, tuple((p, p) for p in _pad_size), look_up_option(mode, NumpyPadMode).value, **pad_opts) + arrpad = pad_nd(arr, to_pad=[(p, p) for p in _pad_size], mode=mode, **pad_opts) # type: ignore # choose a start position in the padded image start_pos_padded = tuple(s + p for s, p in zip(start_pos, _pad_size)) @@ -319,7 +329,7 @@ def iter_patch( # copy back data from the padded image if required if copy_back: slices = tuple(slice(p, p + s) for p, s in zip(_pad_size, arr.shape)) - arr[...] = arrpad[slices] + arr[...] = arrpad[slices] # type: ignore def get_valid_patch_size(image_size: Sequence[int], patch_size: Sequence[int] | int | np.ndarray) -> tuple[int, ...]: diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index aa13d54c51..94689d2fcf 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -150,7 +150,7 @@ def __call__( # type: ignore[override] kwargs_.update(kwargs) img_t = convert_to_tensor(data=img, track_meta=get_track_meta()) - return pad_func(img_t, to_pad_, mode_, self.get_transform_info(), kwargs_) + return pad_func(img_t, to_pad_, self.get_transform_info(), mode_, **kwargs_) def inverse(self, data: MetaTensor) -> MetaTensor: transform = self.pop_transform(data) diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py index fa95958bd5..e694edb737 100644 --- a/monai/transforms/croppad/functional.py +++ b/monai/transforms/croppad/functional.py @@ -21,6 +21,7 @@ import torch from torch.nn.functional import pad as pad_pt +from monai.config.type_definitions import NdarrayTensor from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor from monai.data.utils import to_affine_nd @@ -49,7 +50,7 @@ def _convert_pt_pad_mode(padding_mode): return PytorchPadMode.REPLICATE # "nearest", "border", and others -def _np_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor: +def _np_pad(img: NdarrayTensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> NdarrayTensor: if isinstance(img, torch.Tensor): if img.is_cuda: warnings.warn(f"Padding: moving img {img.shape} from cuda to cpu for dtype={img.dtype} mode={mode}.") @@ -59,14 +60,13 @@ def _np_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kw mode = convert_pad_mode(dst=img_np, mode=mode).value if mode == "constant" and "value" in kwargs: kwargs["constant_values"] = kwargs.pop("value") - out = torch.as_tensor(np.pad(img, pad_width, mode=mode, **kwargs)) # type: ignore - if isinstance(img, MetaTensor): - out = convert_to_dst_type(out, dst=img)[0] - return out + img_np = np.pad(img_np, pad_width, mode=mode, **kwargs) # type: ignore + return convert_to_dst_type(img_np, dst=img)[0] -def _pt_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor: - mode = convert_pad_mode(dst=img, mode=mode).value +def _pt_pad(img: NdarrayTensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> NdarrayTensor: + img_pt = torch.as_tensor(img) + mode = convert_pad_mode(dst=img_pt, mode=mode).value if mode == "constant" and "constant_values" in kwargs: _kwargs = kwargs.copy() _kwargs["value"] = _kwargs.pop("constant_values") @@ -74,13 +74,18 @@ def _pt_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kw _kwargs = kwargs pt_pad_width = [val for sublist in pad_width[1:] for val in sublist[::-1]][::-1] # torch.pad expects `[B, C, H, W, [D]]` shape - return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **_kwargs).squeeze(0) + img_pt = pad_pt(img_pt.unsqueeze(0), pt_pad_width, mode=mode, **_kwargs).squeeze(0) + return convert_to_dst_type(img_pt, dst=img)[0] -def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs): +def pad_nd( + img: NdarrayTensor, to_pad: list[tuple[int, int]], mode: str = PytorchPadMode.CONSTANT, **kwargs +) -> NdarrayTensor: """ - PyTorch/Numpy pad ``img`` with integers ``to_pad`` amounts. Depending on the ``mode`` and input dtype, - a suitable backend will be used automatically. + Pad `img` for a given an amount of padding in each dimension. + + `torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch, + in which case `np.pad` will be used. Args: img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. @@ -90,27 +95,31 @@ def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. """ if mode in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}: return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs) - mode = convert_pad_mode(dst=img, mode=mode).value try: - _pad = ( - _np_pad - if mode in {"reflect", "replicate"} and img.dtype in {torch.int16, torch.int64, torch.bool, torch.uint8} - else _pt_pad - ) + _pad = _np_pad + if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"} and img.dtype not in { + torch.int16, + torch.int64, + torch.bool, + torch.uint8, + }: + _pad = _pt_pad return _pad(img, pad_width=to_pad, mode=mode, **kwargs) except (ValueError, TypeError, RuntimeError) as err: if isinstance(err, NotImplementedError) or any( k in str(err) for k in ("supported", "unexpected keyword", "implemented", "value") ): return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs) - raise ValueError(f"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device}") from err + raise ValueError( + f"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device if isinstance(img, torch.Tensor) else None}" + ) from err def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int, ...], mode: str, **kwargs): @@ -148,23 +157,30 @@ def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int, def pad_func( - img: torch.Tensor, to_pad: tuple[tuple[int, int]], mode: str, transform_info: dict, kwargs + img: torch.Tensor, + to_pad: tuple[tuple[int, int]], + transform_info: dict, + mode: str = PytorchPadMode.CONSTANT, + **kwargs, ) -> torch.Tensor: """ Functional implementation of padding a MetaTensor. This function operates eagerly or lazily according to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + `torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch, + in which case `np.pad` will be used. + Args: img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...]. note that it including channel dimension. + transform_info: a dictionary with the relevant information pertaining to an applied transform. mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html - transform_info: a dictionary with the relevant information pertaining to an applied transform. kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. """ diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index c8dc12193a..15cebc93c5 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -47,7 +47,6 @@ from monai.transforms.traits import MultiSampleTrait from monai.transforms.transform import LazyTransform, Randomizable, RandomizableTransform, Transform from monai.transforms.utils import ( - convert_pad_mode, create_control_grid, create_grid, create_rotate, @@ -57,7 +56,7 @@ map_spatial_axes, scale_affine, ) -from monai.transforms.utils_pytorch_numpy_unification import linalg_inv, moveaxis +from monai.transforms.utils_pytorch_numpy_unification import argsort, argwhere, linalg_inv, moveaxis from monai.utils import ( GridSampleMode, GridSamplePadMode, @@ -77,7 +76,7 @@ optional_import, ) from monai.utils.deprecate_utils import deprecated_arg -from monai.utils.enums import GridPatchSort, PatchKeys, PytorchPadMode, TraceKeys, TransformBackends +from monai.utils.enums import GridPatchSort, PatchKeys, TraceKeys, TransformBackends from monai.utils.misc import ImageMetaKey as Key from monai.utils.module import look_up_option from monai.utils.type_conversion import convert_data_type, get_equivalent_dtype, get_torch_dtype_from_string @@ -3004,12 +3003,16 @@ class GridPatch(Transform, MultiSampleTrait): threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries. - Defaults to None, which means no padding will be applied. - Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}. + Available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. + Defaults to `None`, which means no padding will be applied. See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + requires pytorch >= 1.10 for best compatibility. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. - + note that `np.pad` treats channel dimension as the first dimension. Returns: MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension), with following metadata: @@ -3036,34 +3039,35 @@ def __init__( ): self.patch_size = ensure_tuple(patch_size) self.offset = ensure_tuple(offset) if offset else (0,) * len(self.patch_size) - self.pad_mode: NumpyPadMode | None = convert_pad_mode(dst=np.zeros(1), mode=pad_mode) if pad_mode else None + self.pad_mode = pad_mode self.pad_kwargs = pad_kwargs self.overlap = overlap self.num_patches = num_patches self.sort_fn = sort_fn.lower() if sort_fn else None self.threshold = threshold - def filter_threshold(self, image_np: np.ndarray, locations: np.ndarray): + def filter_threshold(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tuple[NdarrayOrTensor, np.ndarray]: """ Filter the patches and their locations according to a threshold. Args: - image_np: a numpy.ndarray representing a stack of patches. + image_np: a numpy.ndarray or torch.Tensor representing a stack of patches. locations: a numpy.ndarray representing the stack of location of each patch. Returns: - tuple[numpy.ndarray, numpy.ndarray]: tuple of filtered patches and locations. + tuple[NdarrayOrTensor, numpy.ndarray]: tuple of filtered patches and locations. """ n_dims = len(image_np.shape) - idx = np.argwhere(image_np.sum(axis=tuple(range(1, n_dims))) < self.threshold).reshape(-1) - return image_np[idx], locations[idx] + idx = argwhere(image_np.sum(tuple(range(1, n_dims))) < self.threshold).reshape(-1) + idx_np = convert_data_type(idx, np.ndarray)[0] + return image_np[idx], locations[idx_np] # type: ignore - def filter_count(self, image_np: np.ndarray, locations: np.ndarray): + def filter_count(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tuple[NdarrayOrTensor, np.ndarray]: """ Sort the patches based on the sum of their intensity, and just keep `self.num_patches` of them. Args: - image_np: a numpy.ndarray representing a stack of patches. + image_np: a numpy.ndarray or torch.Tensor representing a stack of patches. locations: a numpy.ndarray representing the stack of location of each patch. """ if self.sort_fn is None: @@ -3072,14 +3076,15 @@ def filter_count(self, image_np: np.ndarray, locations: np.ndarray): elif self.num_patches is not None: n_dims = len(image_np.shape) if self.sort_fn == GridPatchSort.MIN: - idx = np.argsort(image_np.sum(axis=tuple(range(1, n_dims)))) + idx = argsort(image_np.sum(tuple(range(1, n_dims)))) elif self.sort_fn == GridPatchSort.MAX: - idx = np.argsort(-image_np.sum(axis=tuple(range(1, n_dims)))) + idx = argsort(-image_np.sum(tuple(range(1, n_dims)))) else: raise ValueError(f'`sort_fn` should be either "min", "max" or None! {self.sort_fn} provided!') idx = idx[: self.num_patches] - image_np = image_np[idx] - locations = locations[idx] + idx_np = convert_data_type(idx, np.ndarray)[0] + image_np = image_np[idx] # type: ignore + locations = locations[idx_np] return image_np, locations def __call__(self, array: NdarrayOrTensor) -> MetaTensor: @@ -3094,9 +3099,8 @@ def __call__(self, array: NdarrayOrTensor) -> MetaTensor: with defined `PatchKeys.LOCATION` and `PatchKeys.COUNT` metadata. """ # create the patch iterator which sweeps the image row-by-row - array_np, *_ = convert_data_type(array, np.ndarray) patch_iterator = iter_patch( - array_np, + array, patch_size=(None,) + self.patch_size, # expand to have the channel dim start_pos=(0,) + self.offset, # expand to have the channel dim overlap=self.overlap, @@ -3105,8 +3109,8 @@ def __call__(self, array: NdarrayOrTensor) -> MetaTensor: **self.pad_kwargs, ) patches = list(zip(*patch_iterator)) - patched_image = np.array(patches[0]) - locations = np.array(patches[1])[:, 1:, 0] # only keep the starting location + patched_image = np.stack(patches[0]) if isinstance(array, np.ndarray) else torch.stack(patches[0]) + locations = np.stack(patches[1])[:, 1:, 0] # only keep the starting location # Apply threshold filtering if self.threshold is not None: @@ -3120,11 +3124,22 @@ def __call__(self, array: NdarrayOrTensor) -> MetaTensor: if self.threshold is None: padding = self.num_patches - len(patched_image) if padding > 0: - patched_image = np.pad( - patched_image, - [[0, padding], [0, 0]] + [[0, 0]] * len(self.patch_size), - constant_values=self.pad_kwargs.get("constant_values", 0), - ) + # pad constant patches to the end of the first dim + constant_values = self.pad_kwargs.get("constant_values", 0) + padding_shape = (padding, *list(patched_image.shape)[1:]) + constant_padding: NdarrayOrTensor + if isinstance(patched_image, np.ndarray): + constant_padding = np.full(padding_shape, constant_values, dtype=patched_image.dtype) + patched_image = np.concatenate([patched_image, constant_padding], axis=0) + else: + constant_padding = torch.full( + padding_shape, + constant_values, + dtype=patched_image.dtype, + layout=patched_image.layout, + device=patched_image.device, + ) + patched_image = torch.cat([patched_image, constant_padding], dim=0) locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0) # Convert to MetaTensor @@ -3162,11 +3177,16 @@ class RandGridPatch(GridPatch, RandomizableTransform, MultiSampleTrait): threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries. - Defaults to None, which means no padding will be applied. - Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}. + Available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. + Defaults to `None`, which means no padding will be applied. See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + requires pytorch >= 1.10 for best compatibility. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. Returns: MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension), @@ -3190,7 +3210,7 @@ def __init__( overlap: Sequence[float] | float = 0.0, sort_fn: str | None = None, threshold: float | None = None, - pad_mode: str = PytorchPadMode.CONSTANT, + pad_mode: str | None = None, **pad_kwargs, ): super().__init__( diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 36e86da903..2f34f57ca2 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -67,7 +67,7 @@ ensure_tuple_rep, fall_back_tuple, ) -from monai.utils.enums import PytorchPadMode, TraceKeys +from monai.utils.enums import TraceKeys from monai.utils.module import optional_import nib, _ = optional_import("nibabel") @@ -1953,12 +1953,17 @@ class GridPatchd(MapTransform, MultiSampleTrait): threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries. - Defaults to None, which means no padding will be applied. - Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}. + Available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. + Defaults to `None`, which means no padding will be applied. See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + requires pytorch >= 1.10 for best compatibility. allow_missing_keys: don't raise exception if key is missing. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. Returns: dictionary, contains the all the original key/value with the values for `keys` @@ -1981,7 +1986,7 @@ def __init__( overlap: float = 0.0, sort_fn: str | None = None, threshold: float | None = None, - pad_mode: str = PytorchPadMode.CONSTANT, + pad_mode: str | None = None, allow_missing_keys: bool = False, **pad_kwargs, ): @@ -2028,12 +2033,17 @@ class RandGridPatchd(RandomizableTransform, MapTransform, MultiSampleTrait): threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries. - Defaults to None, which means no padding will be applied. - Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}. + Available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. + Defaults to `None`, which means no padding will be applied. See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + requires pytorch >= 1.10 for best compatibility. allow_missing_keys: don't raise exception if key is missing. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. Returns: dictionary, contains the all the original key/value with the values for `keys` @@ -2058,7 +2068,7 @@ def __init__( overlap: float = 0.0, sort_fn: str | None = None, threshold: float | None = None, - pad_mode: str = PytorchPadMode.CONSTANT, + pad_mode: str | None = None, allow_missing_keys: bool = False, **pad_kwargs, ): diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index cad15df181..0774d50314 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -28,6 +28,8 @@ "clip", "percentile", "where", + "argwhere", + "argsort", "nonzero", "floor_divide", "unravel_index", @@ -140,6 +142,36 @@ def where(condition: NdarrayOrTensor, x=None, y=None) -> NdarrayOrTensor: return result +def argwhere(a: NdarrayTensor) -> NdarrayTensor: + """`np.argwhere` with equivalent implementation for torch. + + Args: + a: input data. + + Returns: + Indices of elements that are non-zero. Indices are grouped by element. + This array will have shape (N, a.ndim) where N is the number of non-zero items. + """ + if isinstance(a, np.ndarray): + return np.argwhere(a) # type: ignore + return torch.argwhere(a) # type: ignore + + +def argsort(a: NdarrayTensor, axis: int | None = -1) -> NdarrayTensor: + """`np.argsort` with equivalent implementation for torch. + + Args: + a: the array/tensor to sort. + axis: axis along which to sort. + + Returns: + Array/Tensor of indices that sort a along the specified axis. + """ + if isinstance(a, np.ndarray): + return np.argsort(a, axis=axis) # type: ignore + return torch.argsort(a, dim=axis) # type: ignore + + def nonzero(x: NdarrayOrTensor) -> NdarrayOrTensor: """`np.nonzero` with equivalent implementation for torch. diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index 937dda344b..ba33547260 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -20,7 +20,7 @@ from monai.data import DataLoader, GridPatchDataset, PatchIter, PatchIterd, iter_patch from monai.transforms import RandShiftIntensity, RandShiftIntensityd from monai.utils import set_determinism -from tests.utils import assert_allclose, get_arange_img +from tests.utils import TEST_NDARRAYS, assert_allclose, get_arange_img def identity_generator(x): @@ -29,6 +29,34 @@ def identity_generator(x): yield item, idx +TEST_CASES_ITER_PATCH = [] +for p in TEST_NDARRAYS: + TEST_CASES_ITER_PATCH.append([p, True]) + TEST_CASES_ITER_PATCH.append([p, False]) + +A = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1) +A11 = A[:, :2, :2] +A12 = A[:, :2, 2:] +A21 = A[:, 2:, :2] +A22 = A[:, 2:, 2:] +COORD11 = [[0, 3], [0, 2], [0, 2]] +COORD12 = [[0, 3], [0, 2], [2, 4]] +COORD21 = [[0, 3], [2, 4], [0, 2]] +COORD22 = [[0, 3], [2, 4], [2, 4]] + +TEST_CASE_0 = [{"patch_size": (2, 2)}, A, [A11, A12, A21, A22], np.array([COORD11, COORD12, COORD21, COORD22])] +TEST_CASE_1 = [{"patch_size": (2, 2), "start_pos": (0, 2, 2)}, A, [A22], np.array([COORD22])] +TEST_CASE_2 = [{"patch_size": (2, 2), "start_pos": (0, 0, 2)}, A, [A12, A22], np.array([COORD12, COORD22])] +TEST_CASE_3 = [{"patch_size": (2, 2), "start_pos": (0, 2, 0)}, A, [A21, A22], np.array([COORD21, COORD22])] + +TEST_CASES_PATCH_ITER = [] +for p in TEST_NDARRAYS: + TEST_CASES_PATCH_ITER.append([p, *TEST_CASE_0]) + TEST_CASES_PATCH_ITER.append([p, *TEST_CASE_1]) + TEST_CASES_PATCH_ITER.append([p, *TEST_CASE_2]) + TEST_CASES_PATCH_ITER.append([p, *TEST_CASE_3]) + + class TestGridPatchDataset(unittest.TestCase): def setUp(self): set_determinism(seed=1234) @@ -36,14 +64,33 @@ def setUp(self): def tearDown(self): set_determinism(None) - @parameterized.expand([[True], [False]]) - def test_iter_patch(self, cb): + @parameterized.expand(TEST_CASES_ITER_PATCH) + def test_iter_patch(self, in_type, cb): shape = (10, 30, 30) - input_img = get_arange_img(shape) + input_img = in_type(get_arange_img(shape)) for p, _ in iter_patch(input_img, patch_size=(None, 10, 30, None), copy_back=cb): p += 1.0 - assert_allclose(p, get_arange_img(shape) + 1.0) - assert_allclose(input_img, get_arange_img(shape) + (1.0 if cb else 0.0)) + assert_allclose(p, in_type(get_arange_img(shape)) + 1.0, type_test=True, device_test=True) + assert_allclose( + input_img, in_type(get_arange_img(shape)) + (1.0 if cb else 0.0), type_test=True, device_test=True + ) + + @parameterized.expand(TEST_CASES_PATCH_ITER) + def test_patch_iter(self, in_type, input_parameters, image, expected, coords): + input_image = in_type(image) + patch_iterator = PatchIter(**input_parameters)(input_image) + for (result_image, result_loc), expected_patch, coord in zip(patch_iterator, expected, coords): + assert_allclose(result_image, in_type(expected_patch), type_test=True, device_test=True) + assert_allclose(result_loc, coord, type_test=True, device_test=True) + + @parameterized.expand(TEST_CASES_PATCH_ITER) + def test_patch_iterd(self, in_type, input_parameters, image, expected, coords): + image_key = "image" + input_dict = {image_key: in_type(image)} + patch_iterator = PatchIterd(keys=image_key, **input_parameters)(input_dict) + for (result_image_dict, result_loc), expected_patch, coord in zip(patch_iterator, expected, coords): + assert_allclose(result_image_dict[image_key], in_type(expected_patch), type_test=True, device_test=True) + assert_allclose(result_loc, coord, type_test=True, device_test=True) def test_shape(self): # test Iterable input data diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py index 766b37cf31..342c8cbecb 100644 --- a/tests/test_grid_patch.py +++ b/tests/test_grid_patch.py @@ -14,6 +14,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.data import MetaTensor, set_track_meta @@ -97,13 +98,19 @@ class TestGridPatch(unittest.TestCase): @parameterized.expand(TEST_CASES) + @SkipIfBeforePyTorchVersion((1, 11, 1)) def test_grid_patch(self, in_type, input_parameters, image, expected): input_image = in_type(image) splitter = GridPatch(**input_parameters) output = splitter(input_image) self.assertEqual(len(output), len(expected)) for output_patch, expected_patch in zip(output, expected): - assert_allclose(output_patch, expected_patch, type_test=False) + assert_allclose( + output_patch, + in_type(expected_patch), + type_test=False, + device_test=True if isinstance(in_type(expected_patch), torch.Tensor) else False, + ) @parameterized.expand([TEST_CASE_META_0, TEST_CASE_META_1]) @SkipIfBeforePyTorchVersion((1, 9, 1)) diff --git a/tests/test_grid_patchd.py b/tests/test_grid_patchd.py index 46928150cd..3e22a55238 100644 --- a/tests/test_grid_patchd.py +++ b/tests/test_grid_patchd.py @@ -14,10 +14,11 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms.spatial.dictionary import GridPatchd -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose A = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1) A11 = A[:, :2, :2] @@ -77,6 +78,7 @@ class TestGridPatchd(unittest.TestCase): @parameterized.expand(TEST_SINGLE) + @SkipIfBeforePyTorchVersion((1, 11, 1)) def test_grid_patchd(self, in_type, input_parameters, image_dict, expected): image_key = "image" input_dict = {} @@ -88,7 +90,12 @@ def test_grid_patchd(self, in_type, input_parameters, image_dict, expected): output = splitter(input_dict) self.assertEqual(len(output[image_key]), len(expected)) for output_patch, expected_patch in zip(output[image_key], expected): - assert_allclose(output_patch, expected_patch, type_test=False) + assert_allclose( + output_patch, + in_type(expected_patch), + type_test=False, + device_test=True if isinstance(in_type(expected_patch), torch.Tensor) else False, + ) if __name__ == "__main__": diff --git a/tests/test_rand_grid_patch.py b/tests/test_rand_grid_patch.py index cb66276a8c..fa1ba145a0 100644 --- a/tests/test_rand_grid_patch.py +++ b/tests/test_rand_grid_patch.py @@ -14,6 +14,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.data import MetaTensor, set_track_meta @@ -21,8 +22,6 @@ from monai.utils import set_determinism from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose -set_determinism(1234) - A = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1) A11 = A[:, :2, :2] A12 = A[:, :2, 2:] @@ -53,6 +52,7 @@ "max_offset": -1, "sort_fn": "min", "num_patches": 1, + "pad_mode": "constant", "constant_values": 255, }, A, @@ -60,14 +60,14 @@ ] TEST_CASE_10 = [{"patch_size": (2, 2), "min_offset": 0, "max_offset": 0, "threshold": 50.0}, A, [A11]] -TEST_CASE_MEAT_0 = [ +TEST_CASE_META_0 = [ {"patch_size": (2, 2)}, A, [A11, A12, A21, A22], [{"location": [0, 0]}, {"location": [0, 2]}, {"location": [2, 0]}, {"location": [2, 2]}], ] -TEST_CASE_MEAT_1 = [ +TEST_CASE_META_1 = [ {"patch_size": (2, 2)}, MetaTensor(x=A, meta={"path": "path/to/file"}), [A11, A12, A21, A22], @@ -95,7 +95,14 @@ class TestRandGridPatch(unittest.TestCase): + def setUp(self): + set_determinism(seed=1234) + + def tearDown(self): + set_determinism(None) + @parameterized.expand(TEST_SINGLE) + @SkipIfBeforePyTorchVersion((1, 11, 1)) def test_rand_grid_patch(self, in_type, input_parameters, image, expected): input_image = in_type(image) splitter = RandGridPatch(**input_parameters) @@ -103,9 +110,14 @@ def test_rand_grid_patch(self, in_type, input_parameters, image, expected): output = splitter(input_image) self.assertEqual(len(output), len(expected)) for output_patch, expected_patch in zip(output, expected): - assert_allclose(output_patch, expected_patch, type_test=False) + assert_allclose( + output_patch, + in_type(expected_patch), + type_test=False, + device_test=True if isinstance(in_type(expected_patch), torch.Tensor) else False, + ) - @parameterized.expand([TEST_CASE_MEAT_0, TEST_CASE_MEAT_1]) + @parameterized.expand([TEST_CASE_META_0, TEST_CASE_META_1]) @SkipIfBeforePyTorchVersion((1, 9, 1)) def test_rand_grid_patch_meta(self, input_parameters, image, expected, expected_meta): set_track_meta(True) diff --git a/tests/test_rand_grid_patchd.py b/tests/test_rand_grid_patchd.py index 15f4d5447f..c6d6b82729 100644 --- a/tests/test_rand_grid_patchd.py +++ b/tests/test_rand_grid_patchd.py @@ -14,13 +14,12 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms.spatial.dictionary import RandGridPatchd from monai.utils import set_determinism -from tests.utils import TEST_NDARRAYS, assert_allclose - -set_determinism(1234) +from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose A = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1) A11 = A[:, :2, :2] @@ -52,6 +51,7 @@ "max_offset": -1, "sort_fn": "min", "num_patches": 1, + "pad_mode": "constant", "constant_values": 255, }, {"image": A}, @@ -75,7 +75,14 @@ class TestRandGridPatchd(unittest.TestCase): + def setUp(self): + set_determinism(seed=1234) + + def tearDown(self): + set_determinism(None) + @parameterized.expand(TEST_SINGLE) + @SkipIfBeforePyTorchVersion((1, 11, 1)) def test_rand_grid_patchd(self, in_type, input_parameters, image_dict, expected): image_key = "image" input_dict = {} @@ -88,7 +95,12 @@ def test_rand_grid_patchd(self, in_type, input_parameters, image_dict, expected) output = splitter(input_dict) self.assertEqual(len(output[image_key]), len(expected)) for output_patch, expected_patch in zip(output[image_key], expected): - assert_allclose(output_patch, expected_patch, type_test=False) + assert_allclose( + output_patch, + in_type(expected_patch), + type_test=False, + device_test=True if isinstance(in_type(expected_patch), torch.Tensor) else False, + ) if __name__ == "__main__":