From 4d178e23c245034d6ea85b223784cfb7b9c26da0 Mon Sep 17 00:00:00 2001 From: Qingpeng Li Date: Mon, 27 Mar 2023 21:56:34 +0800 Subject: [PATCH 1/9] support GPU tensor for * `GridPatch`, `GridPatchd`, `RandGridPatch` and `RandGridPatchd` * `GridPatchDataset`, `PatchIter`, `PatchIterd` and `iter_patch` --- monai/data/grid_dataset.py | 43 ++++++---- monai/data/utils.py | 22 +++-- monai/transforms/croppad/array.py | 2 +- monai/transforms/croppad/functional.py | 46 ++++++----- monai/transforms/spatial/array.py | 80 +++++++++++-------- monai/transforms/spatial/dictionary.py | 28 ++++--- .../utils_pytorch_numpy_unification.py | 38 +++++++++ tests/test_grid_dataset.py | 56 +++++++++++-- tests/test_grid_patch.py | 4 +- tests/test_grid_patchd.py | 4 +- tests/test_rand_grid_patch.py | 19 +++-- tests/test_rand_grid_patchd.py | 13 ++- 12 files changed, 253 insertions(+), 102 deletions(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index fc8175f630..4cb316b7ca 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -11,17 +11,18 @@ from __future__ import annotations -from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence +from collections.abc import Generator, Callable, Hashable, Iterable, Mapping, Sequence from copy import deepcopy import numpy as np from monai.config import KeysCollection +from monai.config.type_definitions import NdarrayTensor from monai.data.dataset import Dataset from monai.data.iterable_dataset import IterableDataset from monai.data.utils import iter_patch from monai.transforms import apply_transform -from monai.utils import NumpyPadMode, ensure_tuple, first, look_up_option +from monai.utils import NumpyPadMode, ensure_tuple, first __all__ = ["PatchDataset", "GridPatchDataset", "PatchIter", "PatchIterd"] @@ -34,17 +35,25 @@ class PatchIter: """ def __init__( - self, patch_size: Sequence[int], start_pos: Sequence[int] = (), mode: str = NumpyPadMode.WRAP, **pad_opts: dict + self, + patch_size: Sequence[int], + start_pos: Sequence[int] = (), + mode: str | None = NumpyPadMode.WRAP, + **pad_opts: dict, ): """ Args: patch_size: size of patches to generate slices for, 0/None selects whole dimension start_pos: starting position in the array, default is 0 for each dimension - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - One of the listed string values or a user supplied function. Defaults to ``"wrap"``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. + If None, no wrapping is performed. Defaults to ``"wrap"``. + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + requires pytorch >= 1.10 for best compatibility. pad_opts: other arguments for the `np.pad` function. note that `np.pad` treats channel dimension as the first dimension. @@ -58,10 +67,10 @@ def __init__( """ self.patch_size = (None,) + tuple(patch_size) # expand to have the channel dim self.start_pos = ensure_tuple(start_pos) - self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) + self.mode = mode self.pad_opts = pad_opts - def __call__(self, array: np.ndarray): + def __call__(self, array: NdarrayTensor) -> Generator[tuple[NdarrayTensor, np.ndarray], None, None]: """ Args: array: the image to generate patches from. @@ -89,10 +98,14 @@ class PatchIterd: keys: keys of the corresponding items to iterate patches. patch_size: size of patches to generate slices for, 0/None selects whole dimension start_pos: starting position in the array, default is 0 for each dimension - mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} - One of the listed string values or a user supplied function. Defaults to ``"wrap"``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. + If None, no wrapping is performed. Defaults to ``"wrap"``. + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + requires pytorch >= 1.10 for best compatibility. pad_opts: other arguments for the `np.pad` function. note that `np.pad` treats channel dimension as the first dimension. @@ -107,13 +120,13 @@ def __init__( keys: KeysCollection, patch_size: Sequence[int], start_pos: Sequence[int] = (), - mode: str = NumpyPadMode.WRAP, + mode: str | None = NumpyPadMode.WRAP, **pad_opts, ): self.keys = ensure_tuple(keys) self.patch_iter = PatchIter(patch_size=patch_size, start_pos=start_pos, mode=mode, **pad_opts) - def __call__(self, data: Mapping[Hashable, np.ndarray]): + def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Generator[tuple[Mapping[Hashable, NdarrayTensor], np.ndarray], None, None]: d = dict(data) original_spatial_shape = d[first(self.keys)].shape[1:] diff --git a/monai/data/utils.py b/monai/data/utils.py index 5461fda937..822ef7c24d 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -247,14 +247,14 @@ def iter_patch_position( def iter_patch( - arr: np.ndarray, + arr: NdarrayTensor, patch_size: Sequence[int] | int = 0, start_pos: Sequence[int] = (), overlap: Sequence[float] | float = 0.0, copy_back: bool = True, mode: str | None = NumpyPadMode.WRAP, **pad_opts: dict, -): +) -> Generator[tuple[NdarrayTensor, np.ndarray], None, None]: """ Yield successive patches from `arr` of size `patch_size`. The iteration can start from position `start_pos` in `arr` but drawing from a padded array extended by the `patch_size` in each dimension (so these coordinates can be negative @@ -268,9 +268,16 @@ def iter_patch( overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. copy_back: if True data from the yielded patches is copied back to `arr` once the generator completes - mode: One of the listed string values in ``monai.utils.NumpyPadMode`` or ``monai.utils.PytorchPadMode``, - or a user supplied function. If None, no wrapping is performed. Defaults to ``"wrap"``. - pad_opts: padding options, see `numpy.pad` + mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. + If None, no wrapping is performed. Defaults to ``"wrap"``. + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + requires pytorch >= 1.10 for best compatibility. + pad_opts: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. Yields: Patches of array data from `arr` which are views into a padded array which can be modified, if `copy_back` is @@ -285,6 +292,9 @@ def iter_patch( Nth_dim_start, Nth_dim_end]] """ + + from monai.transforms.croppad.functional import pad_nd # needs to be here to avoid circular import + # ensure patchSize and startPos are the right length patch_size_ = get_valid_patch_size(arr.shape, patch_size) start_pos = ensure_tuple_size(start_pos, arr.ndim) @@ -296,7 +306,7 @@ def iter_patch( _overlap = [op if v else 0.0 for op, v in zip(ensure_tuple_rep(overlap, arr.ndim), is_v)] # overlap if v else 0.0 # pad image by maximum values needed to ensure patches are taken from inside an image if padded: - arrpad = np.pad(arr, tuple((p, p) for p in _pad_size), look_up_option(mode, NumpyPadMode).value, **pad_opts) + arrpad = pad_nd(arr, to_pad=tuple((p, p) for p in _pad_size), mode=mode, **pad_opts) # choose a start position in the padded image start_pos_padded = tuple(s + p for s, p in zip(start_pos, _pad_size)) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index aa13d54c51..94689d2fcf 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -150,7 +150,7 @@ def __call__( # type: ignore[override] kwargs_.update(kwargs) img_t = convert_to_tensor(data=img, track_meta=get_track_meta()) - return pad_func(img_t, to_pad_, mode_, self.get_transform_info(), kwargs_) + return pad_func(img_t, to_pad_, self.get_transform_info(), mode_, **kwargs_) def inverse(self, data: MetaTensor) -> MetaTensor: transform = self.pop_transform(data) diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py index fa95958bd5..f1354fd755 100644 --- a/monai/transforms/croppad/functional.py +++ b/monai/transforms/croppad/functional.py @@ -21,6 +21,7 @@ import torch from torch.nn.functional import pad as pad_pt +from monai.config.type_definitions import NdarrayTensor from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor from monai.data.utils import to_affine_nd @@ -49,7 +50,7 @@ def _convert_pt_pad_mode(padding_mode): return PytorchPadMode.REPLICATE # "nearest", "border", and others -def _np_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor: +def _np_pad(img: NdarrayTensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> NdarrayTensor: if isinstance(img, torch.Tensor): if img.is_cuda: warnings.warn(f"Padding: moving img {img.shape} from cuda to cpu for dtype={img.dtype} mode={mode}.") @@ -59,14 +60,13 @@ def _np_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kw mode = convert_pad_mode(dst=img_np, mode=mode).value if mode == "constant" and "value" in kwargs: kwargs["constant_values"] = kwargs.pop("value") - out = torch.as_tensor(np.pad(img, pad_width, mode=mode, **kwargs)) # type: ignore - if isinstance(img, MetaTensor): - out = convert_to_dst_type(out, dst=img)[0] - return out + img_np = np.pad(img_np, pad_width, mode=mode, **kwargs) + return convert_to_dst_type(img_np, dst=img)[0] -def _pt_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor: - mode = convert_pad_mode(dst=img, mode=mode).value +def _pt_pad(img: NdarrayTensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> NdarrayTensor: + img_pt = torch.as_tensor(img) + mode = convert_pad_mode(dst=img_pt, mode=mode).value if mode == "constant" and "constant_values" in kwargs: _kwargs = kwargs.copy() _kwargs["value"] = _kwargs.pop("constant_values") @@ -74,13 +74,16 @@ def _pt_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kw _kwargs = kwargs pt_pad_width = [val for sublist in pad_width[1:] for val in sublist[::-1]][::-1] # torch.pad expects `[B, C, H, W, [D]]` shape - return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **_kwargs).squeeze(0) + img_pt = pad_pt(img_pt.unsqueeze(0), pt_pad_width, mode=mode, **_kwargs).squeeze(0) + return convert_to_dst_type(img_pt, dst=img)[0] -def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs): +def pad_nd(img: NdarrayTensor, to_pad: list[tuple[int, int]], mode: str=PytorchPadMode.CONSTANT, **kwargs) -> NdarrayTensor: """ - PyTorch/Numpy pad ``img`` with integers ``to_pad`` amounts. Depending on the ``mode`` and input dtype, - a suitable backend will be used automatically. + Pad `img` for a given an amount of padding in each dimension. + + `torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch, + in which case `np.pad` will be used. Args: img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. @@ -90,20 +93,18 @@ def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. """ if mode in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}: return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs) - mode = convert_pad_mode(dst=img, mode=mode).value try: - _pad = ( - _np_pad - if mode in {"reflect", "replicate"} and img.dtype in {torch.int16, torch.int64, torch.bool, torch.uint8} - else _pt_pad - ) + _pad = _np_pad + if (mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"} and + img.dtype not in {torch.int16, torch.int64, torch.bool, torch.uint8}): + _pad = _pt_pad return _pad(img, pad_width=to_pad, mode=mode, **kwargs) except (ValueError, TypeError, RuntimeError) as err: if isinstance(err, NotImplementedError) or any( @@ -148,23 +149,26 @@ def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int, def pad_func( - img: torch.Tensor, to_pad: tuple[tuple[int, int]], mode: str, transform_info: dict, kwargs + img: torch.Tensor, to_pad: tuple[tuple[int, int]], transform_info: dict, mode: str=PytorchPadMode.CONSTANT, **kwargs ) -> torch.Tensor: """ Functional implementation of padding a MetaTensor. This function operates eagerly or lazily according to ``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``). + `torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch, + in which case `np.pad` will be used. + Args: img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim. to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...]. note that it including channel dimension. + transform_info: a dictionary with the relevant information pertaining to an applied transform. mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html - transform_info: a dictionary with the relevant information pertaining to an applied transform. kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. """ diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index c8dc12193a..25ee9d8bff 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -26,7 +26,7 @@ import torch from monai.config import USE_COMPILED, DtypeLike -from monai.config.type_definitions import NdarrayOrTensor +from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine @@ -47,7 +47,6 @@ from monai.transforms.traits import MultiSampleTrait from monai.transforms.transform import LazyTransform, Randomizable, RandomizableTransform, Transform from monai.transforms.utils import ( - convert_pad_mode, create_control_grid, create_grid, create_rotate, @@ -57,7 +56,7 @@ map_spatial_axes, scale_affine, ) -from monai.transforms.utils_pytorch_numpy_unification import linalg_inv, moveaxis +from monai.transforms.utils_pytorch_numpy_unification import argsort, argwhere, linalg_inv, moveaxis from monai.utils import ( GridSampleMode, GridSamplePadMode, @@ -77,7 +76,7 @@ optional_import, ) from monai.utils.deprecate_utils import deprecated_arg -from monai.utils.enums import GridPatchSort, PatchKeys, PytorchPadMode, TraceKeys, TransformBackends +from monai.utils.enums import GridPatchSort, PatchKeys, TraceKeys, TransformBackends from monai.utils.misc import ImageMetaKey as Key from monai.utils.module import look_up_option from monai.utils.type_conversion import convert_data_type, get_equivalent_dtype, get_torch_dtype_from_string @@ -3004,12 +3003,16 @@ class GridPatch(Transform, MultiSampleTrait): threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries. - Defaults to None, which means no padding will be applied. - Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}. + Available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. + Defaults to `None`, which means no padding will be applied. See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + requires pytorch >= 1.10 for best compatibility. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. - + note that `np.pad` treats channel dimension as the first dimension. Returns: MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension), with following metadata: @@ -3036,34 +3039,35 @@ def __init__( ): self.patch_size = ensure_tuple(patch_size) self.offset = ensure_tuple(offset) if offset else (0,) * len(self.patch_size) - self.pad_mode: NumpyPadMode | None = convert_pad_mode(dst=np.zeros(1), mode=pad_mode) if pad_mode else None + self.pad_mode = pad_mode self.pad_kwargs = pad_kwargs self.overlap = overlap self.num_patches = num_patches self.sort_fn = sort_fn.lower() if sort_fn else None self.threshold = threshold - def filter_threshold(self, image_np: np.ndarray, locations: np.ndarray): + def filter_threshold(self, image_np: NdarrayTensor, locations: np.ndarray) -> tuple[NdarrayTensor, np.ndarray]: """ Filter the patches and their locations according to a threshold. Args: - image_np: a numpy.ndarray representing a stack of patches. + image_np: a numpy.ndarray or torch.Tensor representing a stack of patches. locations: a numpy.ndarray representing the stack of location of each patch. Returns: - tuple[numpy.ndarray, numpy.ndarray]: tuple of filtered patches and locations. + tuple[NdarrayOrTensor, numpy.ndarray]: tuple of filtered patches and locations. """ n_dims = len(image_np.shape) - idx = np.argwhere(image_np.sum(axis=tuple(range(1, n_dims))) < self.threshold).reshape(-1) - return image_np[idx], locations[idx] + idx = argwhere(image_np.sum(axis=tuple(range(1, n_dims))) < self.threshold).reshape(-1) + idx_np = convert_data_type(idx, np.ndarray)[0] + return image_np[idx], locations[idx_np] - def filter_count(self, image_np: np.ndarray, locations: np.ndarray): + def filter_count(self, image_np: NdarrayTensor, locations: np.ndarray) -> tuple[NdarrayTensor, np.ndarray]: """ Sort the patches based on the sum of their intensity, and just keep `self.num_patches` of them. Args: - image_np: a numpy.ndarray representing a stack of patches. + image_np: a numpy.ndarray or torch.Tensor representing a stack of patches. locations: a numpy.ndarray representing the stack of location of each patch. """ if self.sort_fn is None: @@ -3072,14 +3076,15 @@ def filter_count(self, image_np: np.ndarray, locations: np.ndarray): elif self.num_patches is not None: n_dims = len(image_np.shape) if self.sort_fn == GridPatchSort.MIN: - idx = np.argsort(image_np.sum(axis=tuple(range(1, n_dims)))) + idx = argsort(image_np.sum(axis=tuple(range(1, n_dims)))) elif self.sort_fn == GridPatchSort.MAX: - idx = np.argsort(-image_np.sum(axis=tuple(range(1, n_dims)))) + idx = argsort(-image_np.sum(axis=tuple(range(1, n_dims)))) else: raise ValueError(f'`sort_fn` should be either "min", "max" or None! {self.sort_fn} provided!') idx = idx[: self.num_patches] + idx_np = convert_data_type(idx, np.ndarray)[0] image_np = image_np[idx] - locations = locations[idx] + locations = locations[idx_np] return image_np, locations def __call__(self, array: NdarrayOrTensor) -> MetaTensor: @@ -3094,9 +3099,8 @@ def __call__(self, array: NdarrayOrTensor) -> MetaTensor: with defined `PatchKeys.LOCATION` and `PatchKeys.COUNT` metadata. """ # create the patch iterator which sweeps the image row-by-row - array_np, *_ = convert_data_type(array, np.ndarray) patch_iterator = iter_patch( - array_np, + array, patch_size=(None,) + self.patch_size, # expand to have the channel dim start_pos=(0,) + self.offset, # expand to have the channel dim overlap=self.overlap, @@ -3105,8 +3109,8 @@ def __call__(self, array: NdarrayOrTensor) -> MetaTensor: **self.pad_kwargs, ) patches = list(zip(*patch_iterator)) - patched_image = np.array(patches[0]) - locations = np.array(patches[1])[:, 1:, 0] # only keep the starting location + patched_image = np.stack(patches[0]) if isinstance(array, np.ndarray) else torch.stack(patches[0]) + locations = np.stack(patches[1])[:, 1:, 0] # only keep the starting location # Apply threshold filtering if self.threshold is not None: @@ -3120,11 +3124,18 @@ def __call__(self, array: NdarrayOrTensor) -> MetaTensor: if self.threshold is None: padding = self.num_patches - len(patched_image) if padding > 0: - patched_image = np.pad( - patched_image, - [[0, padding], [0, 0]] + [[0, 0]] * len(self.patch_size), - constant_values=self.pad_kwargs.get("constant_values", 0), - ) + # pad constant patches to the end of the first dim + constant_values=self.pad_kwargs.get("constant_values", 0) + padding_shape = (padding, *list(patched_image.shape)[1:]) + if isinstance(patched_image, np.ndarray): + constant_padding = np.full(padding_shape, constant_values, dtype=patched_image.dtype) + patched_image = np.concatenate([patched_image, constant_padding], axis=0) + else: + constant_padding = torch.full( + padding_shape, constant_values, + dtype=patched_image.dtype, layout=patched_image.layout, device=patched_image.device + ) + patched_image = torch.cat([patched_image, constant_padding], dim=0) locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0) # Convert to MetaTensor @@ -3162,11 +3173,16 @@ class RandGridPatch(GridPatch, RandomizableTransform, MultiSampleTrait): threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries. - Defaults to None, which means no padding will be applied. - Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}. + Available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. + Defaults to `None`, which means no padding will be applied. See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + requires pytorch >= 1.10 for best compatibility. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. Returns: MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension), @@ -3190,7 +3206,7 @@ def __init__( overlap: Sequence[float] | float = 0.0, sort_fn: str | None = None, threshold: float | None = None, - pad_mode: str = PytorchPadMode.CONSTANT, + pad_mode: str | None = None, **pad_kwargs, ): super().__init__( diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 36e86da903..2f34f57ca2 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -67,7 +67,7 @@ ensure_tuple_rep, fall_back_tuple, ) -from monai.utils.enums import PytorchPadMode, TraceKeys +from monai.utils.enums import TraceKeys from monai.utils.module import optional_import nib, _ = optional_import("nibabel") @@ -1953,12 +1953,17 @@ class GridPatchd(MapTransform, MultiSampleTrait): threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries. - Defaults to None, which means no padding will be applied. - Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}. + Available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. + Defaults to `None`, which means no padding will be applied. See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + requires pytorch >= 1.10 for best compatibility. allow_missing_keys: don't raise exception if key is missing. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. Returns: dictionary, contains the all the original key/value with the values for `keys` @@ -1981,7 +1986,7 @@ def __init__( overlap: float = 0.0, sort_fn: str | None = None, threshold: float | None = None, - pad_mode: str = PytorchPadMode.CONSTANT, + pad_mode: str | None = None, allow_missing_keys: bool = False, **pad_kwargs, ): @@ -2028,12 +2033,17 @@ class RandGridPatchd(RandomizableTransform, MapTransform, MultiSampleTrait): threshold: a value to keep only the patches whose sum of intensities are less than the threshold. Defaults to no filtering. pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries. - Defaults to None, which means no padding will be applied. - Available modes:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, - ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}. + Available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, + ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. + One of the listed string values or a user supplied function. + Defaults to `None`, which means no padding will be applied. See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + requires pytorch >= 1.10 for best compatibility. allow_missing_keys: don't raise exception if key is missing. pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. + note that `np.pad` treats channel dimension as the first dimension. Returns: dictionary, contains the all the original key/value with the values for `keys` @@ -2058,7 +2068,7 @@ def __init__( overlap: float = 0.0, sort_fn: str | None = None, threshold: float | None = None, - pad_mode: str = PytorchPadMode.CONSTANT, + pad_mode: str | None = None, allow_missing_keys: bool = False, **pad_kwargs, ): diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index cad15df181..10db9e9b31 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -28,6 +28,8 @@ "clip", "percentile", "where", + "argwhere", + "argsort", "nonzero", "floor_divide", "unravel_index", @@ -140,6 +142,42 @@ def where(condition: NdarrayOrTensor, x=None, y=None) -> NdarrayOrTensor: return result +def argwhere(a: NdarrayTensor) -> NdarrayTensor: + """`np.argwhere` with equivalent implementation for torch. + + Args: + a: input data. + + Returns: + Indices of elements that are non-zero. Indices are grouped by element. + This array will have shape (N, a.ndim) where N is the number of non-zero items. + """ + result: NdarrayTensor + if isinstance(a, np.ndarray): + result = np.argwhere(a) + else: + result = torch.argwhere(a) + return result + + +def argsort(a: NdarrayTensor, axis: int | None=-1) -> NdarrayTensor: + """`np.argsort` with equivalent implementation for torch. + + Args: + a: the array/tensor to sort. + axis: axis along which to sort. + + Returns: + Array/Tensor of indices that sort a along the specified axis. + """ + result: NdarrayOrTensor + if isinstance(a, np.ndarray): + result = np.argsort(a, axis=axis) + else: + result = torch.argsort(a, axis=axis) + return result + + def nonzero(x: NdarrayOrTensor) -> NdarrayOrTensor: """`np.nonzero` with equivalent implementation for torch. diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index 937dda344b..dab3858c2f 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -20,14 +20,39 @@ from monai.data import DataLoader, GridPatchDataset, PatchIter, PatchIterd, iter_patch from monai.transforms import RandShiftIntensity, RandShiftIntensityd from monai.utils import set_determinism -from tests.utils import assert_allclose, get_arange_img - +from tests.utils import TEST_NDARRAYS, assert_allclose, get_arange_img def identity_generator(x): # simple transform that returns the input itself for idx, item in enumerate(x): yield item, idx +TEST_CASES_ITER_PATCH = [] +for p in TEST_NDARRAYS: + TEST_CASES_ITER_PATCH.append([p, True]) + TEST_CASES_ITER_PATCH.append([p, False]) + +A = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1) +A11 = A[:, :2, :2] +A12 = A[:, :2, 2:] +A21 = A[:, 2:, :2] +A22 = A[:, 2:, 2:] +COORD11 = [[0,3],[0,2],[0,2]] +COORD12 = [[0,3],[0,2],[2,4]] +COORD21 = [[0,3],[2,4],[0,2]] +COORD22 = [[0,3],[2,4],[2,4]] + +TEST_CASE_0 = [{"patch_size": (2, 2)}, A, [A11, A12, A21, A22], np.array([COORD11, COORD12, COORD21, COORD22])] +TEST_CASE_1 = [{"patch_size": (2, 2), "start_pos": (0, 2, 2)}, A, [A22], np.array([COORD22])] +TEST_CASE_2 = [{"patch_size": (2, 2), "start_pos": (0, 0, 2)}, A, [A12, A22], np.array([COORD12, COORD22])] +TEST_CASE_3 = [{"patch_size": (2, 2), "start_pos": (0, 2, 0)}, A, [A21, A22], np.array([COORD21, COORD22])] + +TEST_CASES_PATCH_ITER = [] +for p in TEST_NDARRAYS: + TEST_CASES_PATCH_ITER.append([p, *TEST_CASE_0]) + TEST_CASES_PATCH_ITER.append([p, *TEST_CASE_1]) + TEST_CASES_PATCH_ITER.append([p, *TEST_CASE_2]) + TEST_CASES_PATCH_ITER.append([p, *TEST_CASE_3]) class TestGridPatchDataset(unittest.TestCase): def setUp(self): @@ -36,14 +61,31 @@ def setUp(self): def tearDown(self): set_determinism(None) - @parameterized.expand([[True], [False]]) - def test_iter_patch(self, cb): + @parameterized.expand(TEST_CASES_ITER_PATCH) + def test_iter_patch(self, in_type, cb): shape = (10, 30, 30) - input_img = get_arange_img(shape) + input_img = in_type(get_arange_img(shape)) for p, _ in iter_patch(input_img, patch_size=(None, 10, 30, None), copy_back=cb): p += 1.0 - assert_allclose(p, get_arange_img(shape) + 1.0) - assert_allclose(input_img, get_arange_img(shape) + (1.0 if cb else 0.0)) + assert_allclose(p, in_type(get_arange_img(shape)) + 1.0, type_test=True, device_test=True) + assert_allclose(input_img, in_type(get_arange_img(shape)) + (1.0 if cb else 0.0), type_test=True, device_test=True) + + @parameterized.expand(TEST_CASES_PATCH_ITER) + def test_patch_iter(self, in_type, input_parameters, image, expected, coords): + input_image = in_type(image) + patch_iterator = PatchIter(**input_parameters)(input_image) + for (result_image, result_loc), expected_patch, coord in zip(patch_iterator, expected, coords): + assert_allclose(result_image, in_type(expected_patch), type_test=True, device_test=True) + assert_allclose(result_loc, coord, type_test=True, device_test=True) + + @parameterized.expand(TEST_CASES_PATCH_ITER) + def test_patch_iterd(self, in_type, input_parameters, image, expected, coords): + image_key = "image" + input_dict = {image_key : in_type(image)} + patch_iterator = PatchIterd(keys=image_key, **input_parameters)(input_dict) + for (result_image_dict, result_loc), expected_patch, coord in zip(patch_iterator, expected, coords): + assert_allclose(result_image_dict[image_key], in_type(expected_patch), type_test=True, device_test=True) + assert_allclose(result_loc, coord, type_test=True, device_test=True) def test_shape(self): # test Iterable input data diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py index 766b37cf31..b6f4ed0bc2 100644 --- a/tests/test_grid_patch.py +++ b/tests/test_grid_patch.py @@ -14,6 +14,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.data import MetaTensor, set_track_meta @@ -103,7 +104,8 @@ def test_grid_patch(self, in_type, input_parameters, image, expected): output = splitter(input_image) self.assertEqual(len(output), len(expected)) for output_patch, expected_patch in zip(output, expected): - assert_allclose(output_patch, expected_patch, type_test=False) + assert_allclose(output_patch, in_type(expected_patch), type_test=False, + device_test=True if isinstance(in_type(expected_patch), torch.Tensor) else False) @parameterized.expand([TEST_CASE_META_0, TEST_CASE_META_1]) @SkipIfBeforePyTorchVersion((1, 9, 1)) diff --git a/tests/test_grid_patchd.py b/tests/test_grid_patchd.py index 46928150cd..41e47fde1e 100644 --- a/tests/test_grid_patchd.py +++ b/tests/test_grid_patchd.py @@ -14,6 +14,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms.spatial.dictionary import GridPatchd @@ -88,7 +89,8 @@ def test_grid_patchd(self, in_type, input_parameters, image_dict, expected): output = splitter(input_dict) self.assertEqual(len(output[image_key]), len(expected)) for output_patch, expected_patch in zip(output[image_key], expected): - assert_allclose(output_patch, expected_patch, type_test=False) + assert_allclose(output_patch, in_type(expected_patch), type_test=False, + device_test=True if isinstance(in_type(expected_patch), torch.Tensor) else False) if __name__ == "__main__": diff --git a/tests/test_rand_grid_patch.py b/tests/test_rand_grid_patch.py index cb66276a8c..520d1b64c8 100644 --- a/tests/test_rand_grid_patch.py +++ b/tests/test_rand_grid_patch.py @@ -14,6 +14,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.data import MetaTensor, set_track_meta @@ -21,8 +22,6 @@ from monai.utils import set_determinism from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose -set_determinism(1234) - A = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1) A11 = A[:, :2, :2] A12 = A[:, :2, 2:] @@ -53,6 +52,7 @@ "max_offset": -1, "sort_fn": "min", "num_patches": 1, + "pad_mode": "constant", "constant_values": 255, }, A, @@ -60,14 +60,14 @@ ] TEST_CASE_10 = [{"patch_size": (2, 2), "min_offset": 0, "max_offset": 0, "threshold": 50.0}, A, [A11]] -TEST_CASE_MEAT_0 = [ +TEST_CASE_META_0 = [ {"patch_size": (2, 2)}, A, [A11, A12, A21, A22], [{"location": [0, 0]}, {"location": [0, 2]}, {"location": [2, 0]}, {"location": [2, 2]}], ] -TEST_CASE_MEAT_1 = [ +TEST_CASE_META_1 = [ {"patch_size": (2, 2)}, MetaTensor(x=A, meta={"path": "path/to/file"}), [A11, A12, A21, A22], @@ -95,6 +95,12 @@ class TestRandGridPatch(unittest.TestCase): + def setUp(self): + set_determinism(seed=1234) + + def tearDown(self): + set_determinism(None) + @parameterized.expand(TEST_SINGLE) def test_rand_grid_patch(self, in_type, input_parameters, image, expected): input_image = in_type(image) @@ -103,9 +109,10 @@ def test_rand_grid_patch(self, in_type, input_parameters, image, expected): output = splitter(input_image) self.assertEqual(len(output), len(expected)) for output_patch, expected_patch in zip(output, expected): - assert_allclose(output_patch, expected_patch, type_test=False) + assert_allclose(output_patch, in_type(expected_patch), type_test=False, + device_test=True if isinstance(in_type(expected_patch), torch.Tensor) else False) - @parameterized.expand([TEST_CASE_MEAT_0, TEST_CASE_MEAT_1]) + @parameterized.expand([TEST_CASE_META_0, TEST_CASE_META_1]) @SkipIfBeforePyTorchVersion((1, 9, 1)) def test_rand_grid_patch_meta(self, input_parameters, image, expected, expected_meta): set_track_meta(True) diff --git a/tests/test_rand_grid_patchd.py b/tests/test_rand_grid_patchd.py index 15f4d5447f..a520f51f73 100644 --- a/tests/test_rand_grid_patchd.py +++ b/tests/test_rand_grid_patchd.py @@ -14,14 +14,13 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms.spatial.dictionary import RandGridPatchd from monai.utils import set_determinism from tests.utils import TEST_NDARRAYS, assert_allclose -set_determinism(1234) - A = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1) A11 = A[:, :2, :2] A12 = A[:, :2, 2:] @@ -52,6 +51,7 @@ "max_offset": -1, "sort_fn": "min", "num_patches": 1, + "pad_mode": "constant", "constant_values": 255, }, {"image": A}, @@ -75,6 +75,12 @@ class TestRandGridPatchd(unittest.TestCase): + def setUp(self): + set_determinism(seed=1234) + + def tearDown(self): + set_determinism(None) + @parameterized.expand(TEST_SINGLE) def test_rand_grid_patchd(self, in_type, input_parameters, image_dict, expected): image_key = "image" @@ -88,7 +94,8 @@ def test_rand_grid_patchd(self, in_type, input_parameters, image_dict, expected) output = splitter(input_dict) self.assertEqual(len(output[image_key]), len(expected)) for output_patch, expected_patch in zip(output[image_key], expected): - assert_allclose(output_patch, expected_patch, type_test=False) + assert_allclose(output_patch, in_type(expected_patch), type_test=False, + device_test=True if isinstance(in_type(expected_patch), torch.Tensor) else False) if __name__ == "__main__": From 96e720c8ba0a489b96937130c0d86b1311b2b7da Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 Mar 2023 15:14:34 +0000 Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/croppad/functional.py | 4 ++-- monai/transforms/spatial/array.py | 2 +- tests/test_rand_grid_patch.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py index f1354fd755..3674f27b37 100644 --- a/monai/transforms/croppad/functional.py +++ b/monai/transforms/croppad/functional.py @@ -80,8 +80,8 @@ def _pt_pad(img: NdarrayTensor, pad_width: list[tuple[int, int]], mode: str, **k def pad_nd(img: NdarrayTensor, to_pad: list[tuple[int, int]], mode: str=PytorchPadMode.CONSTANT, **kwargs) -> NdarrayTensor: """ - Pad `img` for a given an amount of padding in each dimension. - + Pad `img` for a given an amount of padding in each dimension. + `torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch, in which case `np.pad` will be used. diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 25ee9d8bff..9647f56cac 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -3132,7 +3132,7 @@ def __call__(self, array: NdarrayOrTensor) -> MetaTensor: patched_image = np.concatenate([patched_image, constant_padding], axis=0) else: constant_padding = torch.full( - padding_shape, constant_values, + padding_shape, constant_values, dtype=patched_image.dtype, layout=patched_image.layout, device=patched_image.device ) patched_image = torch.cat([patched_image, constant_padding], dim=0) diff --git a/tests/test_rand_grid_patch.py b/tests/test_rand_grid_patch.py index 520d1b64c8..5111bc7b07 100644 --- a/tests/test_rand_grid_patch.py +++ b/tests/test_rand_grid_patch.py @@ -100,7 +100,7 @@ def setUp(self): def tearDown(self): set_determinism(None) - + @parameterized.expand(TEST_SINGLE) def test_rand_grid_patch(self, in_type, input_parameters, image, expected): input_image = in_type(image) From 537f6c2e77edc76704209ebaad846de8b08ca2fb Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 28 Mar 2023 11:13:43 +0000 Subject: [PATCH 3/9] [MONAI] code formatting Signed-off-by: monai-bot --- monai/data/grid_dataset.py | 6 ++++-- monai/data/utils.py | 2 +- monai/transforms/croppad/functional.py | 18 ++++++++++++++---- monai/transforms/spatial/array.py | 9 ++++++--- .../utils_pytorch_numpy_unification.py | 2 +- tests/test_grid_dataset.py | 17 +++++++++++------ tests/test_grid_patch.py | 8 ++++++-- tests/test_grid_patchd.py | 8 ++++++-- tests/test_rand_grid_patch.py | 8 ++++++-- tests/test_rand_grid_patchd.py | 8 ++++++-- 10 files changed, 61 insertions(+), 25 deletions(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 4cb316b7ca..43c72b5a78 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -11,7 +11,7 @@ from __future__ import annotations -from collections.abc import Generator, Callable, Hashable, Iterable, Mapping, Sequence +from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence from copy import deepcopy import numpy as np @@ -126,7 +126,9 @@ def __init__( self.keys = ensure_tuple(keys) self.patch_iter = PatchIter(patch_size=patch_size, start_pos=start_pos, mode=mode, **pad_opts) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Generator[tuple[Mapping[Hashable, NdarrayTensor], np.ndarray], None, None]: + def __call__( + self, data: Mapping[Hashable, NdarrayTensor] + ) -> Generator[tuple[Mapping[Hashable, NdarrayTensor], np.ndarray], None, None]: d = dict(data) original_spatial_shape = d[first(self.keys)].shape[1:] diff --git a/monai/data/utils.py b/monai/data/utils.py index 822ef7c24d..cf4891af4d 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -293,7 +293,7 @@ def iter_patch( """ - from monai.transforms.croppad.functional import pad_nd # needs to be here to avoid circular import + from monai.transforms.croppad.functional import pad_nd # needs to be here to avoid circular import # ensure patchSize and startPos are the right length patch_size_ = get_valid_patch_size(arr.shape, patch_size) diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py index 3674f27b37..e789169677 100644 --- a/monai/transforms/croppad/functional.py +++ b/monai/transforms/croppad/functional.py @@ -78,7 +78,9 @@ def _pt_pad(img: NdarrayTensor, pad_width: list[tuple[int, int]], mode: str, **k return convert_to_dst_type(img_pt, dst=img)[0] -def pad_nd(img: NdarrayTensor, to_pad: list[tuple[int, int]], mode: str=PytorchPadMode.CONSTANT, **kwargs) -> NdarrayTensor: +def pad_nd( + img: NdarrayTensor, to_pad: list[tuple[int, int]], mode: str = PytorchPadMode.CONSTANT, **kwargs +) -> NdarrayTensor: """ Pad `img` for a given an amount of padding in each dimension. @@ -102,8 +104,12 @@ def pad_nd(img: NdarrayTensor, to_pad: list[tuple[int, int]], mode: str=PytorchP return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs) try: _pad = _np_pad - if (mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"} and - img.dtype not in {torch.int16, torch.int64, torch.bool, torch.uint8}): + if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"} and img.dtype not in { + torch.int16, + torch.int64, + torch.bool, + torch.uint8, + }: _pad = _pt_pad return _pad(img, pad_width=to_pad, mode=mode, **kwargs) except (ValueError, TypeError, RuntimeError) as err: @@ -149,7 +155,11 @@ def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int, def pad_func( - img: torch.Tensor, to_pad: tuple[tuple[int, int]], transform_info: dict, mode: str=PytorchPadMode.CONSTANT, **kwargs + img: torch.Tensor, + to_pad: tuple[tuple[int, int]], + transform_info: dict, + mode: str = PytorchPadMode.CONSTANT, + **kwargs, ) -> torch.Tensor: """ Functional implementation of padding a MetaTensor. This function operates eagerly or lazily according diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 9647f56cac..80e904cd9c 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -3125,15 +3125,18 @@ def __call__(self, array: NdarrayOrTensor) -> MetaTensor: padding = self.num_patches - len(patched_image) if padding > 0: # pad constant patches to the end of the first dim - constant_values=self.pad_kwargs.get("constant_values", 0) + constant_values = self.pad_kwargs.get("constant_values", 0) padding_shape = (padding, *list(patched_image.shape)[1:]) if isinstance(patched_image, np.ndarray): constant_padding = np.full(padding_shape, constant_values, dtype=patched_image.dtype) patched_image = np.concatenate([patched_image, constant_padding], axis=0) else: constant_padding = torch.full( - padding_shape, constant_values, - dtype=patched_image.dtype, layout=patched_image.layout, device=patched_image.device + padding_shape, + constant_values, + dtype=patched_image.dtype, + layout=patched_image.layout, + device=patched_image.device, ) patched_image = torch.cat([patched_image, constant_padding], dim=0) locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 10db9e9b31..403191936d 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -160,7 +160,7 @@ def argwhere(a: NdarrayTensor) -> NdarrayTensor: return result -def argsort(a: NdarrayTensor, axis: int | None=-1) -> NdarrayTensor: +def argsort(a: NdarrayTensor, axis: int | None = -1) -> NdarrayTensor: """`np.argsort` with equivalent implementation for torch. Args: diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index dab3858c2f..ba33547260 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -22,11 +22,13 @@ from monai.utils import set_determinism from tests.utils import TEST_NDARRAYS, assert_allclose, get_arange_img + def identity_generator(x): # simple transform that returns the input itself for idx, item in enumerate(x): yield item, idx + TEST_CASES_ITER_PATCH = [] for p in TEST_NDARRAYS: TEST_CASES_ITER_PATCH.append([p, True]) @@ -37,10 +39,10 @@ def identity_generator(x): A12 = A[:, :2, 2:] A21 = A[:, 2:, :2] A22 = A[:, 2:, 2:] -COORD11 = [[0,3],[0,2],[0,2]] -COORD12 = [[0,3],[0,2],[2,4]] -COORD21 = [[0,3],[2,4],[0,2]] -COORD22 = [[0,3],[2,4],[2,4]] +COORD11 = [[0, 3], [0, 2], [0, 2]] +COORD12 = [[0, 3], [0, 2], [2, 4]] +COORD21 = [[0, 3], [2, 4], [0, 2]] +COORD22 = [[0, 3], [2, 4], [2, 4]] TEST_CASE_0 = [{"patch_size": (2, 2)}, A, [A11, A12, A21, A22], np.array([COORD11, COORD12, COORD21, COORD22])] TEST_CASE_1 = [{"patch_size": (2, 2), "start_pos": (0, 2, 2)}, A, [A22], np.array([COORD22])] @@ -54,6 +56,7 @@ def identity_generator(x): TEST_CASES_PATCH_ITER.append([p, *TEST_CASE_2]) TEST_CASES_PATCH_ITER.append([p, *TEST_CASE_3]) + class TestGridPatchDataset(unittest.TestCase): def setUp(self): set_determinism(seed=1234) @@ -68,7 +71,9 @@ def test_iter_patch(self, in_type, cb): for p, _ in iter_patch(input_img, patch_size=(None, 10, 30, None), copy_back=cb): p += 1.0 assert_allclose(p, in_type(get_arange_img(shape)) + 1.0, type_test=True, device_test=True) - assert_allclose(input_img, in_type(get_arange_img(shape)) + (1.0 if cb else 0.0), type_test=True, device_test=True) + assert_allclose( + input_img, in_type(get_arange_img(shape)) + (1.0 if cb else 0.0), type_test=True, device_test=True + ) @parameterized.expand(TEST_CASES_PATCH_ITER) def test_patch_iter(self, in_type, input_parameters, image, expected, coords): @@ -81,7 +86,7 @@ def test_patch_iter(self, in_type, input_parameters, image, expected, coords): @parameterized.expand(TEST_CASES_PATCH_ITER) def test_patch_iterd(self, in_type, input_parameters, image, expected, coords): image_key = "image" - input_dict = {image_key : in_type(image)} + input_dict = {image_key: in_type(image)} patch_iterator = PatchIterd(keys=image_key, **input_parameters)(input_dict) for (result_image_dict, result_loc), expected_patch, coord in zip(patch_iterator, expected, coords): assert_allclose(result_image_dict[image_key], in_type(expected_patch), type_test=True, device_test=True) diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py index b6f4ed0bc2..2c01106a61 100644 --- a/tests/test_grid_patch.py +++ b/tests/test_grid_patch.py @@ -104,8 +104,12 @@ def test_grid_patch(self, in_type, input_parameters, image, expected): output = splitter(input_image) self.assertEqual(len(output), len(expected)) for output_patch, expected_patch in zip(output, expected): - assert_allclose(output_patch, in_type(expected_patch), type_test=False, - device_test=True if isinstance(in_type(expected_patch), torch.Tensor) else False) + assert_allclose( + output_patch, + in_type(expected_patch), + type_test=False, + device_test=True if isinstance(in_type(expected_patch), torch.Tensor) else False, + ) @parameterized.expand([TEST_CASE_META_0, TEST_CASE_META_1]) @SkipIfBeforePyTorchVersion((1, 9, 1)) diff --git a/tests/test_grid_patchd.py b/tests/test_grid_patchd.py index 41e47fde1e..3394877b3e 100644 --- a/tests/test_grid_patchd.py +++ b/tests/test_grid_patchd.py @@ -89,8 +89,12 @@ def test_grid_patchd(self, in_type, input_parameters, image_dict, expected): output = splitter(input_dict) self.assertEqual(len(output[image_key]), len(expected)) for output_patch, expected_patch in zip(output[image_key], expected): - assert_allclose(output_patch, in_type(expected_patch), type_test=False, - device_test=True if isinstance(in_type(expected_patch), torch.Tensor) else False) + assert_allclose( + output_patch, + in_type(expected_patch), + type_test=False, + device_test=True if isinstance(in_type(expected_patch), torch.Tensor) else False, + ) if __name__ == "__main__": diff --git a/tests/test_rand_grid_patch.py b/tests/test_rand_grid_patch.py index 5111bc7b07..7d6cd5deda 100644 --- a/tests/test_rand_grid_patch.py +++ b/tests/test_rand_grid_patch.py @@ -109,8 +109,12 @@ def test_rand_grid_patch(self, in_type, input_parameters, image, expected): output = splitter(input_image) self.assertEqual(len(output), len(expected)) for output_patch, expected_patch in zip(output, expected): - assert_allclose(output_patch, in_type(expected_patch), type_test=False, - device_test=True if isinstance(in_type(expected_patch), torch.Tensor) else False) + assert_allclose( + output_patch, + in_type(expected_patch), + type_test=False, + device_test=True if isinstance(in_type(expected_patch), torch.Tensor) else False, + ) @parameterized.expand([TEST_CASE_META_0, TEST_CASE_META_1]) @SkipIfBeforePyTorchVersion((1, 9, 1)) diff --git a/tests/test_rand_grid_patchd.py b/tests/test_rand_grid_patchd.py index a520f51f73..513553870c 100644 --- a/tests/test_rand_grid_patchd.py +++ b/tests/test_rand_grid_patchd.py @@ -94,8 +94,12 @@ def test_rand_grid_patchd(self, in_type, input_parameters, image_dict, expected) output = splitter(input_dict) self.assertEqual(len(output[image_key]), len(expected)) for output_patch, expected_patch in zip(output[image_key], expected): - assert_allclose(output_patch, in_type(expected_patch), type_test=False, - device_test=True if isinstance(in_type(expected_patch), torch.Tensor) else False) + assert_allclose( + output_patch, + in_type(expected_patch), + type_test=False, + device_test=True if isinstance(in_type(expected_patch), torch.Tensor) else False, + ) if __name__ == "__main__": From b9264266b20bd491171b236586237ebfe7c967e3 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 28 Mar 2023 13:58:41 +0100 Subject: [PATCH 4/9] update to skip test Signed-off-by: Wenqi Li --- tests/test_grid_patchd.py | 3 ++- tests/test_rand_grid_patchd.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_grid_patchd.py b/tests/test_grid_patchd.py index 3394877b3e..3e22a55238 100644 --- a/tests/test_grid_patchd.py +++ b/tests/test_grid_patchd.py @@ -18,7 +18,7 @@ from parameterized import parameterized from monai.transforms.spatial.dictionary import GridPatchd -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose A = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1) A11 = A[:, :2, :2] @@ -78,6 +78,7 @@ class TestGridPatchd(unittest.TestCase): @parameterized.expand(TEST_SINGLE) + @SkipIfBeforePyTorchVersion((1, 11, 1)) def test_grid_patchd(self, in_type, input_parameters, image_dict, expected): image_key = "image" input_dict = {} diff --git a/tests/test_rand_grid_patchd.py b/tests/test_rand_grid_patchd.py index 513553870c..c6d6b82729 100644 --- a/tests/test_rand_grid_patchd.py +++ b/tests/test_rand_grid_patchd.py @@ -19,7 +19,7 @@ from monai.transforms.spatial.dictionary import RandGridPatchd from monai.utils import set_determinism -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose A = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1) A11 = A[:, :2, :2] @@ -82,6 +82,7 @@ def tearDown(self): set_determinism(None) @parameterized.expand(TEST_SINGLE) + @SkipIfBeforePyTorchVersion((1, 11, 1)) def test_rand_grid_patchd(self, in_type, input_parameters, image_dict, expected): image_key = "image" input_dict = {} From 27721d56a400efc4e31944df3cc83dfd624ad1d8 Mon Sep 17 00:00:00 2001 From: Qingpeng Li Date: Wed, 29 Mar 2023 18:16:20 +0800 Subject: [PATCH 5/9] fix type check DCO Remediation Commit for Qingpeng Li I, Qingpeng Li , hereby add my Signed-off-by to this commit: 4d178e23c245034d6ea85b223784cfb7b9c26da0 Signed-off-by: Qingpeng Li --- monai/data/utils.py | 8 ++++---- monai/transforms/croppad/functional.py | 4 ++-- monai/transforms/spatial/array.py | 17 ++++++++++------- .../utils_pytorch_numpy_unification.py | 14 ++++---------- 4 files changed, 20 insertions(+), 23 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index cf4891af4d..06b6946713 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -247,14 +247,14 @@ def iter_patch_position( def iter_patch( - arr: NdarrayTensor, + arr: NdarrayOrTensor, patch_size: Sequence[int] | int = 0, start_pos: Sequence[int] = (), overlap: Sequence[float] | float = 0.0, copy_back: bool = True, mode: str | None = NumpyPadMode.WRAP, **pad_opts: dict, -) -> Generator[tuple[NdarrayTensor, np.ndarray], None, None]: +) -> Generator[tuple[NdarrayOrTensor, np.ndarray], None, None]: """ Yield successive patches from `arr` of size `patch_size`. The iteration can start from position `start_pos` in `arr` but drawing from a padded array extended by the `patch_size` in each dimension (so these coordinates can be negative @@ -306,7 +306,7 @@ def iter_patch( _overlap = [op if v else 0.0 for op, v in zip(ensure_tuple_rep(overlap, arr.ndim), is_v)] # overlap if v else 0.0 # pad image by maximum values needed to ensure patches are taken from inside an image if padded: - arrpad = pad_nd(arr, to_pad=tuple((p, p) for p in _pad_size), mode=mode, **pad_opts) + arrpad = pad_nd(arr, to_pad=[(p, p) for p in _pad_size], mode=mode, **pad_opts) # type: ignore # choose a start position in the padded image start_pos_padded = tuple(s + p for s, p in zip(start_pos, _pad_size)) @@ -329,7 +329,7 @@ def iter_patch( # copy back data from the padded image if required if copy_back: slices = tuple(slice(p, p + s) for p, s in zip(_pad_size, arr.shape)) - arr[...] = arrpad[slices] + arr[...] = arrpad[slices] # type: ignore def get_valid_patch_size(image_size: Sequence[int], patch_size: Sequence[int] | int | np.ndarray) -> tuple[int, ...]: diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py index e789169677..016ec59f76 100644 --- a/monai/transforms/croppad/functional.py +++ b/monai/transforms/croppad/functional.py @@ -60,7 +60,7 @@ def _np_pad(img: NdarrayTensor, pad_width: list[tuple[int, int]], mode: str, **k mode = convert_pad_mode(dst=img_np, mode=mode).value if mode == "constant" and "value" in kwargs: kwargs["constant_values"] = kwargs.pop("value") - img_np = np.pad(img_np, pad_width, mode=mode, **kwargs) + img_np = np.pad(img_np, pad_width, mode=mode, **kwargs) # type: ignore return convert_to_dst_type(img_np, dst=img)[0] @@ -117,7 +117,7 @@ def pad_nd( k in str(err) for k in ("supported", "unexpected keyword", "implemented", "value") ): return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs) - raise ValueError(f"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device}") from err + raise ValueError(f"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device if isinstance(img, torch.Tensor) else None}") from err def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int, ...], mode: str, **kwargs): diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 80e904cd9c..e2c380d2e8 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -3046,7 +3046,7 @@ def __init__( self.sort_fn = sort_fn.lower() if sort_fn else None self.threshold = threshold - def filter_threshold(self, image_np: NdarrayTensor, locations: np.ndarray) -> tuple[NdarrayTensor, np.ndarray]: + def filter_threshold(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tuple[NdarrayOrTensor, np.ndarray]: """ Filter the patches and their locations according to a threshold. @@ -3058,11 +3058,12 @@ def filter_threshold(self, image_np: NdarrayTensor, locations: np.ndarray) -> tu tuple[NdarrayOrTensor, numpy.ndarray]: tuple of filtered patches and locations. """ n_dims = len(image_np.shape) - idx = argwhere(image_np.sum(axis=tuple(range(1, n_dims))) < self.threshold).reshape(-1) + sum_ = np.sum if isinstance(image_np, np.ndarray) else torch.sum + idx = argwhere(sum_(image_np, tuple(range(1, n_dims))) < self.threshold).reshape(-1) idx_np = convert_data_type(idx, np.ndarray)[0] - return image_np[idx], locations[idx_np] + return image_np[idx], locations[idx_np] # type: ignore - def filter_count(self, image_np: NdarrayTensor, locations: np.ndarray) -> tuple[NdarrayTensor, np.ndarray]: + def filter_count(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tuple[NdarrayOrTensor, np.ndarray]: """ Sort the patches based on the sum of their intensity, and just keep `self.num_patches` of them. @@ -3075,15 +3076,16 @@ def filter_count(self, image_np: NdarrayTensor, locations: np.ndarray) -> tuple[ locations = locations[: self.num_patches] elif self.num_patches is not None: n_dims = len(image_np.shape) + sum_ = np.sum if isinstance(image_np, np.ndarray) else torch.sum if self.sort_fn == GridPatchSort.MIN: - idx = argsort(image_np.sum(axis=tuple(range(1, n_dims)))) + idx = argsort(sum_(image_np, tuple(range(1, n_dims)))) elif self.sort_fn == GridPatchSort.MAX: - idx = argsort(-image_np.sum(axis=tuple(range(1, n_dims)))) + idx = argsort(-sum_(image_np, tuple(range(1, n_dims)))) else: raise ValueError(f'`sort_fn` should be either "min", "max" or None! {self.sort_fn} provided!') idx = idx[: self.num_patches] idx_np = convert_data_type(idx, np.ndarray)[0] - image_np = image_np[idx] + image_np = image_np[idx] # type: ignore locations = locations[idx_np] return image_np, locations @@ -3127,6 +3129,7 @@ def __call__(self, array: NdarrayOrTensor) -> MetaTensor: # pad constant patches to the end of the first dim constant_values = self.pad_kwargs.get("constant_values", 0) padding_shape = (padding, *list(patched_image.shape)[1:]) + constant_padding : NdarrayOrTensor if isinstance(patched_image, np.ndarray): constant_padding = np.full(padding_shape, constant_values, dtype=patched_image.dtype) patched_image = np.concatenate([patched_image, constant_padding], axis=0) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 403191936d..115c1d02e6 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -152,12 +152,9 @@ def argwhere(a: NdarrayTensor) -> NdarrayTensor: Indices of elements that are non-zero. Indices are grouped by element. This array will have shape (N, a.ndim) where N is the number of non-zero items. """ - result: NdarrayTensor if isinstance(a, np.ndarray): - result = np.argwhere(a) - else: - result = torch.argwhere(a) - return result + return np.argwhere(a) # type: ignore + return torch.argwhere(a) # type: ignore def argsort(a: NdarrayTensor, axis: int | None = -1) -> NdarrayTensor: @@ -170,12 +167,9 @@ def argsort(a: NdarrayTensor, axis: int | None = -1) -> NdarrayTensor: Returns: Array/Tensor of indices that sort a along the specified axis. """ - result: NdarrayOrTensor if isinstance(a, np.ndarray): - result = np.argsort(a, axis=axis) - else: - result = torch.argsort(a, axis=axis) - return result + return np.argsort(a, axis=axis) # type: ignore + return torch.argsort(a, dim=axis) # type: ignore def nonzero(x: NdarrayOrTensor) -> NdarrayOrTensor: From dccd42d4bc1105c0e3c301c0ea0bf3c3b02bf6cf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Mar 2023 10:18:25 +0000 Subject: [PATCH 6/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/spatial/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index e2c380d2e8..d5879d361f 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -26,7 +26,7 @@ import torch from monai.config import USE_COMPILED, DtypeLike -from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor +from monai.config.type_definitions import NdarrayOrTensor from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine From 8f64c3dea0ee91c225f80be7f167e2aadeb26f94 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Wed, 29 Mar 2023 10:33:58 +0000 Subject: [PATCH 7/9] [MONAI] code formatting Signed-off-by: monai-bot --- monai/data/utils.py | 2 +- monai/transforms/croppad/functional.py | 6 ++++-- monai/transforms/spatial/array.py | 6 +++--- monai/transforms/utils_pytorch_numpy_unification.py | 8 ++++---- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 06b6946713..2c035afb3f 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -329,7 +329,7 @@ def iter_patch( # copy back data from the padded image if required if copy_back: slices = tuple(slice(p, p + s) for p, s in zip(_pad_size, arr.shape)) - arr[...] = arrpad[slices] # type: ignore + arr[...] = arrpad[slices] # type: ignore def get_valid_patch_size(image_size: Sequence[int], patch_size: Sequence[int] | int | np.ndarray) -> tuple[int, ...]: diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py index 016ec59f76..e694edb737 100644 --- a/monai/transforms/croppad/functional.py +++ b/monai/transforms/croppad/functional.py @@ -60,7 +60,7 @@ def _np_pad(img: NdarrayTensor, pad_width: list[tuple[int, int]], mode: str, **k mode = convert_pad_mode(dst=img_np, mode=mode).value if mode == "constant" and "value" in kwargs: kwargs["constant_values"] = kwargs.pop("value") - img_np = np.pad(img_np, pad_width, mode=mode, **kwargs) # type: ignore + img_np = np.pad(img_np, pad_width, mode=mode, **kwargs) # type: ignore return convert_to_dst_type(img_np, dst=img)[0] @@ -117,7 +117,9 @@ def pad_nd( k in str(err) for k in ("supported", "unexpected keyword", "implemented", "value") ): return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs) - raise ValueError(f"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device if isinstance(img, torch.Tensor) else None}") from err + raise ValueError( + f"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device if isinstance(img, torch.Tensor) else None}" + ) from err def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int, ...], mode: str, **kwargs): diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index d5879d361f..2fbd81f2f1 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -3061,7 +3061,7 @@ def filter_threshold(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> sum_ = np.sum if isinstance(image_np, np.ndarray) else torch.sum idx = argwhere(sum_(image_np, tuple(range(1, n_dims))) < self.threshold).reshape(-1) idx_np = convert_data_type(idx, np.ndarray)[0] - return image_np[idx], locations[idx_np] # type: ignore + return image_np[idx], locations[idx_np] # type: ignore def filter_count(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tuple[NdarrayOrTensor, np.ndarray]: """ @@ -3085,7 +3085,7 @@ def filter_count(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tupl raise ValueError(f'`sort_fn` should be either "min", "max" or None! {self.sort_fn} provided!') idx = idx[: self.num_patches] idx_np = convert_data_type(idx, np.ndarray)[0] - image_np = image_np[idx] # type: ignore + image_np = image_np[idx] # type: ignore locations = locations[idx_np] return image_np, locations @@ -3129,7 +3129,7 @@ def __call__(self, array: NdarrayOrTensor) -> MetaTensor: # pad constant patches to the end of the first dim constant_values = self.pad_kwargs.get("constant_values", 0) padding_shape = (padding, *list(patched_image.shape)[1:]) - constant_padding : NdarrayOrTensor + constant_padding: NdarrayOrTensor if isinstance(patched_image, np.ndarray): constant_padding = np.full(padding_shape, constant_values, dtype=patched_image.dtype) patched_image = np.concatenate([patched_image, constant_padding], axis=0) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 115c1d02e6..0774d50314 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -153,8 +153,8 @@ def argwhere(a: NdarrayTensor) -> NdarrayTensor: This array will have shape (N, a.ndim) where N is the number of non-zero items. """ if isinstance(a, np.ndarray): - return np.argwhere(a) # type: ignore - return torch.argwhere(a) # type: ignore + return np.argwhere(a) # type: ignore + return torch.argwhere(a) # type: ignore def argsort(a: NdarrayTensor, axis: int | None = -1) -> NdarrayTensor: @@ -168,8 +168,8 @@ def argsort(a: NdarrayTensor, axis: int | None = -1) -> NdarrayTensor: Array/Tensor of indices that sort a along the specified axis. """ if isinstance(a, np.ndarray): - return np.argsort(a, axis=axis) # type: ignore - return torch.argsort(a, dim=axis) # type: ignore + return np.argsort(a, axis=axis) # type: ignore + return torch.argsort(a, dim=axis) # type: ignore def nonzero(x: NdarrayOrTensor) -> NdarrayOrTensor: From b546d28cf7ad9c3564d2d8472a1a793e54732108 Mon Sep 17 00:00:00 2001 From: Qingpeng Li Date: Wed, 29 Mar 2023 19:22:36 +0800 Subject: [PATCH 8/9] fix type check Signed-off-by: Qingpeng Li --- monai/transforms/spatial/array.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 2fbd81f2f1..15cebc93c5 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -3058,8 +3058,7 @@ def filter_threshold(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tuple[NdarrayOrTensor, numpy.ndarray]: tuple of filtered patches and locations. """ n_dims = len(image_np.shape) - sum_ = np.sum if isinstance(image_np, np.ndarray) else torch.sum - idx = argwhere(sum_(image_np, tuple(range(1, n_dims))) < self.threshold).reshape(-1) + idx = argwhere(image_np.sum(tuple(range(1, n_dims))) < self.threshold).reshape(-1) idx_np = convert_data_type(idx, np.ndarray)[0] return image_np[idx], locations[idx_np] # type: ignore @@ -3076,11 +3075,10 @@ def filter_count(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tupl locations = locations[: self.num_patches] elif self.num_patches is not None: n_dims = len(image_np.shape) - sum_ = np.sum if isinstance(image_np, np.ndarray) else torch.sum if self.sort_fn == GridPatchSort.MIN: - idx = argsort(sum_(image_np, tuple(range(1, n_dims)))) + idx = argsort(image_np.sum(tuple(range(1, n_dims)))) elif self.sort_fn == GridPatchSort.MAX: - idx = argsort(-sum_(image_np, tuple(range(1, n_dims)))) + idx = argsort(-image_np.sum(tuple(range(1, n_dims)))) else: raise ValueError(f'`sort_fn` should be either "min", "max" or None! {self.sort_fn} provided!') idx = idx[: self.num_patches] From 3aa78797bda14f473b02d44319ad20da4cfe80a2 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 29 Mar 2023 20:47:13 +0100 Subject: [PATCH 9/9] update Signed-off-by: Wenqi Li --- tests/test_grid_patch.py | 1 + tests/test_rand_grid_patch.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py index 2c01106a61..342c8cbecb 100644 --- a/tests/test_grid_patch.py +++ b/tests/test_grid_patch.py @@ -98,6 +98,7 @@ class TestGridPatch(unittest.TestCase): @parameterized.expand(TEST_CASES) + @SkipIfBeforePyTorchVersion((1, 11, 1)) def test_grid_patch(self, in_type, input_parameters, image, expected): input_image = in_type(image) splitter = GridPatch(**input_parameters) diff --git a/tests/test_rand_grid_patch.py b/tests/test_rand_grid_patch.py index 7d6cd5deda..fa1ba145a0 100644 --- a/tests/test_rand_grid_patch.py +++ b/tests/test_rand_grid_patch.py @@ -102,6 +102,7 @@ def tearDown(self): set_determinism(None) @parameterized.expand(TEST_SINGLE) + @SkipIfBeforePyTorchVersion((1, 11, 1)) def test_rand_grid_patch(self, in_type, input_parameters, image, expected): input_image = in_type(image) splitter = RandGridPatch(**input_parameters)