diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index a93c48984c..bf5ed2b180 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -737,6 +737,18 @@ Spatial :members: :special-members: __call__ +`GridPatch` +""""""""""" +.. autoclass:: GridPatch + :members: + :special-members: __call__ + +`RandGridPatch` +""""""""""""""" +.. autoclass:: RandGridPatch + :members: + :special-members: __call__ + `GridSplit` """"""""""" .. autoclass:: GridSplit @@ -1513,6 +1525,18 @@ Spatial (Dict) :members: :special-members: __call__ +`GridPatchd` +"""""""""""" +.. autoclass:: GridPatchd + :members: + :special-members: __call__ + +`RandGridPatchd` +"""""""""""""""" +.. autoclass:: RandGridPatchd + :members: + :special-members: __call__ + `GridSplitd` """""""""""" .. autoclass:: GridSplitd diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py index a44dce1e3f..34cabac50c 100644 --- a/monai/apps/pathology/transforms/spatial/array.py +++ b/monai/apps/pathology/transforms/spatial/array.py @@ -17,12 +17,13 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.transforms.transform import Randomizable, Transform -from monai.utils import convert_data_type, convert_to_dst_type +from monai.utils import convert_data_type, convert_to_dst_type, deprecated from monai.utils.enums import TransformBackends __all__ = ["SplitOnGrid", "TileOnGrid"] +@deprecated(since="0.8", msg_suffix="use `monai.transforms.GridSplit` instead.") class SplitOnGrid(Transform): """ Split the image into patches based on the provided grid shape. @@ -107,6 +108,7 @@ def get_params(self, image_size): return patch_size, steps +@deprecated(since="0.8", msg_suffix="use `monai.transforms.GridPatch` or `monai.transforms.RandGridPatch` instead.") class TileOnGrid(Randomizable, Transform): """ Tile the 2D image into patches on a grid and maintain a subset of it. diff --git a/monai/apps/pathology/transforms/spatial/dictionary.py b/monai/apps/pathology/transforms/spatial/dictionary.py index d5c34a0840..022d82a053 100644 --- a/monai/apps/pathology/transforms/spatial/dictionary.py +++ b/monai/apps/pathology/transforms/spatial/dictionary.py @@ -15,12 +15,14 @@ from monai.config import KeysCollection from monai.config.type_definitions import NdarrayOrTensor from monai.transforms.transform import MapTransform, Randomizable +from monai.utils import deprecated from .array import SplitOnGrid, TileOnGrid __all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict", "TileOnGridd", "TileOnGridD", "TileOnGridDict"] +@deprecated(since="0.8", msg_suffix="use `monai.transforms.GridSplitd` instead.") class SplitOnGridd(MapTransform): """ Split the image into patches based on the provided grid shape. @@ -55,6 +57,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d +@deprecated(since="0.8", msg_suffix="use `monai.transforms.GridPatchd` or `monai.transforms.RandGridPatchd` instead.") class TileOnGridd(Randomizable, MapTransform): """ Tile the 2D image into patches on a grid and maintain a subset of it. diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py index 6fe5435d57..83b4fd9fe3 100644 --- a/monai/data/wsi_datasets.py +++ b/monai/data/wsi_datasets.py @@ -206,13 +206,13 @@ def __init__( elif isinstance(offset_limits[0], tuple): self.offset_limits = offset_limits else: - ValueError( + raise ValueError( "The offset limits should be either a tuple of integers or tuple of tuple of integers." ) else: - ValueError("The offset limits should be a tuple.") + raise ValueError("The offset limits should be a tuple.") else: - ValueError( + raise ValueError( f'Invalid string for offset "{offset}". It should be either "random" as a string,' "an integer, or a tuple of integers defining the offset." ) @@ -238,7 +238,7 @@ def _evaluate_patch_coordinates(self, sample): """Define the location for each patch based on sliding-window approach""" patch_size = self._get_size(sample) level = self._get_level(sample) - start_pos = self._get_offset(sample) + offset = self._get_offset(sample) wsi_obj = self._get_wsi_object(sample) wsi_size = self.wsi_reader.get_size(wsi_obj, 0) @@ -246,7 +246,7 @@ def _evaluate_patch_coordinates(self, sample): patch_size_ = tuple(p * downsample for p in patch_size) # patch size at level 0 locations = list( iter_patch_position( - image_size=wsi_size, patch_size=patch_size_, start_pos=start_pos, overlap=self.overlap, padded=False + image_size=wsi_size, patch_size=patch_size_, start_pos=offset, overlap=self.overlap, padded=False ) ) sample["size"] = patch_size diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py index fdf7de3d63..955651999a 100644 --- a/monai/data/wsi_reader.py +++ b/monai/data/wsi_reader.py @@ -278,7 +278,7 @@ def __init__(self, backend="cucim", level: int = 0, **kwargs): elif self.backend == "openslide": self.reader = OpenSlideWSIReader(level=level, **kwargs) else: - raise ValueError("The supported backends are: cucim") + raise ValueError(f"The supported backends are cucim and openslide, '{self.backend}' was given.") self.supported_suffixes = self.reader.supported_suffixes def get_level_count(self, wsi) -> int: diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index c2385499b3..18459c1b7b 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -311,6 +311,7 @@ AffineGrid, Flip, GridDistortion, + GridPatch, GridSplit, Orientation, Rand2DElastic, @@ -321,6 +322,7 @@ RandDeformGrid, RandFlip, RandGridDistortion, + RandGridPatch, RandRotate, RandRotate90, RandZoom, @@ -343,6 +345,9 @@ GridDistortiond, GridDistortionD, GridDistortionDict, + GridPatchd, + GridPatchD, + GridPatchDict, GridSplitd, GridSplitD, GridSplitDict, @@ -367,6 +372,9 @@ RandGridDistortiond, RandGridDistortionD, RandGridDistortionDict, + RandGridPatchd, + RandGridPatchD, + RandGridPatchDict, RandRotate90d, RandRotate90D, RandRotate90Dict, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index f833f57ebb..eb854c8d23 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -22,13 +22,21 @@ from monai.config import USE_COMPILED, DtypeLike from monai.config.type_definitions import NdarrayOrTensor -from monai.data.utils import AFFINE_TOL, compute_shape_offset, reorient_spatial_axes, to_affine_nd, zoom_affine +from monai.data.utils import ( + AFFINE_TOL, + compute_shape_offset, + iter_patch, + reorient_spatial_axes, + to_affine_nd, + zoom_affine, +) from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.networks.utils import meshgrid_ij, normalize_transform from monai.transforms.croppad.array import CenterSpatialCrop, Pad from monai.transforms.intensity.array import GaussianSmooth from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe, Transform from monai.transforms.utils import ( + convert_pad_mode, create_control_grid, create_grid, create_rotate, @@ -44,6 +52,7 @@ InterpolateMode, NumpyPadMode, PytorchPadMode, + convert_to_dst_type, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, @@ -53,10 +62,10 @@ pytorch_after, ) from monai.utils.deprecate_utils import deprecated_arg -from monai.utils.enums import TransformBackends +from monai.utils.enums import GridPatchSort, TransformBackends from monai.utils.misc import ImageMetaKey as Key from monai.utils.module import look_up_option -from monai.utils.type_conversion import convert_data_type, convert_to_dst_type +from monai.utils.type_conversion import convert_data_type nib, has_nib = optional_import("nibabel") @@ -68,6 +77,8 @@ "Flip", "GridDistortion", "GridSplit", + "GridPatch", + "RandGridPatch", "Resize", "Rotate", "Zoom", @@ -2577,7 +2588,6 @@ def __call__( image, shape=(*self.grid, n_channels, split_size[0], split_size[1]), strides=(x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride), - writeable=False, ) # Flatten the first two dimensions strided_image = strided_image.reshape(-1, *strided_image.shape[2:]) @@ -2609,3 +2619,161 @@ def _get_params( ) return size, steps + + +class GridPatch(Transform): + """ + Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps. + It can sort the patches and return all or a subset of them. + + Args: + patch_size: size of patches to generate slices for, 0 or None selects whole dimension + offset: offset of starting position in the array, default is 0 for each dimension. + num_patches: number of patches to return. Defaults to None, which returns all the available patches. + overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). + If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. + sort_fn: a callable or string that defines the order of the patches to be returned. If it is a callable, it + will be passed directly to the `key` argument of `sorted` function. The string can be "min" or "max", + which are, respectively, the minimum and maximum of the sum of intensities of a patch across all dimensions + and channels. Also "random" creates a random order of patches. + By default no sorting is being done and patches are returned in a row-major order. + pad_mode: refer to NumpyPadMode and PytorchPadMode. Defaults to ``"constant"``. + pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. + + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + patch_size: Sequence[int], + offset: Sequence[int] = (), + num_patches: Optional[int] = None, + overlap: Union[Sequence[float], float] = 0.0, + sort_fn: Optional[Union[Callable, str]] = None, + pad_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, + **pad_kwargs, + ): + self.patch_size = ensure_tuple(patch_size) + self.offset = ensure_tuple(offset) + self.pad_mode: NumpyPadMode = convert_pad_mode(dst=np.zeros(1), mode=pad_mode) + self.pad_kwargs = pad_kwargs + self.overlap = overlap + self.num_patches = num_patches + self.sort_fn: Optional[Callable] + if isinstance(sort_fn, str): + if sort_fn == GridPatchSort.RANDOM.value: + self.sort_fn = np.random.random + elif sort_fn == GridPatchSort.MIN.value: + self.sort_fn = self.get_patch_sum + elif sort_fn == GridPatchSort.MAX.value: + self.sort_fn = self.get_negative_patch_sum + else: + raise ValueError( + f'sort_fn should be one of the following values, "{sort_fn}" was given:', + [enum.value for enum in GridPatchSort], + ) + else: + self.sort_fn = sort_fn + + @staticmethod + def get_patch_sum(x): + return x[0].sum() + + @staticmethod + def get_negative_patch_sum(x): + return -x[0].sum() + + def __call__(self, array: NdarrayOrTensor): + # create the patch iterator which sweeps the image row-by-row + array_np, *_ = convert_data_type(array, np.ndarray) + patch_iterator = iter_patch( + array_np, + patch_size=(None,) + self.patch_size, # expand to have the channel dim + start_pos=(0,) + self.offset, # expand to have the channel dim + overlap=self.overlap, + copy_back=False, + mode=self.pad_mode, + **self.pad_kwargs, + ) + if self.sort_fn is not None: + output = sorted(patch_iterator, key=self.sort_fn) + else: + output = list(patch_iterator) + if self.num_patches: + output = output[: self.num_patches] + if len(output) < self.num_patches: + patch = np.full((array.shape[0], *self.patch_size), self.pad_kwargs.get("constant_values", 0)) + slices = np.zeros((3, len(self.patch_size))) + output += [(patch, slices)] * (self.num_patches - len(output)) + + output = [convert_to_dst_type(src=patch, dst=array)[0] for patch in output] + + return output + + +class RandGridPatch(GridPatch, RandomizableTransform): + """ + Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps, + and with random offset for the minimal corner of the image, (0,0) for 2D and (0,0,0) for 3D. + It can sort the patches and return all or a subset of them. + + Args: + patch_size: size of patches to generate slices for, 0 or None selects whole dimension + min_offset: the minimum range of offset to be selected randomly. Defaults to 0. + max_offset: the maximum range of offset to be selected randomly. + Defaults to image size modulo patch size. + num_patches: number of patches to return. Defaults to None, which returns all the available patches. + overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). + If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. + sort_fn: a callable or string that defines the order of the patches to be returned. If it is a callable, it + will be passed directly to the `key` argument of `sorted` function. The string can be "min" or "max", + which are, respectively, the minimum and maximum of the sum of intensities of a patch across all dimensions + and channels. Also "random" creates a random order of patches. + By default no sorting is being done and patches are returned in a row-major order. + pad_mode: refer to NumpyPadMode and PytorchPadMode. Defaults to ``"constant"``. + pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. + + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + patch_size: Sequence[int], + min_offset: Optional[Union[Sequence[int], int]] = None, + max_offset: Optional[Union[Sequence[int], int]] = None, + num_patches: Optional[int] = None, + overlap: Union[Sequence[float], float] = 0.0, + sort_fn: Optional[Union[Callable, str]] = None, + pad_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, + **pad_kwargs, + ): + super().__init__( + patch_size=patch_size, + offset=(), + num_patches=num_patches, + overlap=overlap, + sort_fn=sort_fn, + pad_mode=pad_mode, + **pad_kwargs, + ) + self.min_offset = min_offset + self.max_offset = max_offset + + def randomize(self, array): + if self.min_offset is None: + min_offset = (0,) * len(self.patch_size) + else: + min_offset = ensure_tuple_rep(self.min_offset, len(self.patch_size)) + if self.max_offset is None: + max_offset = tuple(s % p for s, p in zip(array.shape[1:], self.patch_size)) + else: + max_offset = ensure_tuple_rep(self.max_offset, len(self.patch_size)) + + self.offset = tuple(self.R.randint(low=low, high=high + 1) for low, high in zip(min_offset, max_offset)) + + def __call__(self, array: NdarrayOrTensor, randomize: bool = True): + if randomize: + self.randomize(array) + return super().__call__(array) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 28d9c66448..4a38dbbb59 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -17,7 +17,7 @@ from copy import deepcopy from enum import Enum -from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -34,6 +34,7 @@ AffineGrid, Flip, GridDistortion, + GridPatch, GridSplit, Orientation, Rand2DElastic, @@ -42,6 +43,7 @@ RandAxisFlip, RandFlip, RandGridDistortion, + RandGridPatch, RandRotate, RandZoom, ResampleToMatch, @@ -63,6 +65,7 @@ ensure_tuple, ensure_tuple_rep, fall_back_tuple, + first, ) from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import PostFix, TraceKeys @@ -133,6 +136,12 @@ "GridSplitd", "GridSplitD", "GridSplitDict", + "GridPatchd", + "GridPatchD", + "GridPatchDict", + "RandGridPatchd", + "RandGridPatchD", + "RandGridPatchDict", ] GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] @@ -2194,6 +2203,179 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashab return output +class GridPatchd(MapTransform): + """ + Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps. + It can sort the patches and return all or a subset of them. + + Args: + keys: keys of the corresponding items to be transformed. + patch_size: size of patches to generate slices for, 0 or None selects whole dimension + offset: starting position in the array, default is 0 for each dimension. + np.random.randint(0, patch_size, 2) creates random start between 0 and `patch_size` for a 2D image. + num_patches: number of patches to return. Defaults to None, which returns all the available patches. + overlap: amount of overlap between patches in each dimension. Default to 0.0. + sort_fn: a callable or string that defines the order of the patches to be returned. If it is a callable, it + will be passed directly to the `key` argument of `sorted` function. The string can be "min" or "max", + which are, respectively, the minimum and maximum of the sum of intensities of a patch across all dimensions + and channels. Also "random" creates a random order of patches. + By default no sorting is being done and patches are returned in a row-major order. + pad_mode: refer to NumpyPadMode and PytorchPadMode. Defaults to ``"constant"``. + allow_missing_keys: don't raise exception if key is missing. + pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. + + Returns: + a list of dictionaries, each of which contains the all the original key/value with the values for `keys` + replaced by the patches. It also add the following new keys: + + "slices": slices from the image that defines the patch, + "patch_size": size of the extracted patch + "num_patches": total number of patches in the image + "offset": the amount of offset for the patches in the image (starting position of upper left patch) + """ + + backend = GridPatch.backend + + def __init__( + self, + keys: KeysCollection, + patch_size: Sequence[int], + offset: Sequence[int] = (), + num_patches: Optional[int] = None, + overlap: float = 0.0, + sort_fn: Optional[Union[Callable, str]] = None, + pad_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, + allow_missing_keys: bool = False, + **pad_kwargs, + ): + super().__init__(keys, allow_missing_keys) + self.patcher = GridPatch( + patch_size=patch_size, + offset=offset, + num_patches=num_patches, + overlap=overlap, + sort_fn=sort_fn, + pad_mode=pad_mode, + **pad_kwargs, + ) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict]: + d = dict(data) + original_spatial_shape = d[first(self.keys)].shape[1:] + output = [] + results = [self.patcher(d[key]) for key in self.keys] + num_patches = min(len(r) for r in results) + for patch in zip(*results): + new_dict = {k: v[0] for k, v in zip(self.keys, patch)} + # fill in the extra keys with unmodified data + for k in set(d.keys()).difference(set(self.keys)): + new_dict[k] = deepcopy(d[k]) + # fill additional metadata + new_dict["original_spatial_shape"] = original_spatial_shape + new_dict["slices"] = patch[0][1] # use the coordinate of the first item + new_dict["patch_size"] = self.patcher.patch_size + new_dict["num_patches"] = num_patches + new_dict["offset"] = self.patcher.offset + output.append(new_dict) + return output + + +class RandGridPatchd(RandomizableTransform, MapTransform): + """ + Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps, + and with random offset for the minimal corner of the image, (0,0) for 2D and (0,0,0) for 3D. + It can sort the patches and return all or a subset of them. + + Args: + keys: keys of the corresponding items to be transformed. + patch_size: size of patches to generate slices for, 0 or None selects whole dimension + min_offset: the minimum range of starting position to be selected randomly. Defaults to 0. + max_offset: the maximum range of starting position to be selected randomly. + Defaults to image size modulo patch size. + num_patches: number of patches to return. Defaults to None, which returns all the available patches. + overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0). + If only one float number is given, it will be applied to all dimensions. Defaults to 0.0. + sort_fn: a callable or string that defines the order of the patches to be returned. If it is a callable, it + will be passed directly to the `key` argument of `sorted` function. The string can be "min" or "max", + which are, respectively, the minimum and maximum of the sum of intensities of a patch across all dimensions + and channels. Also "random" creates a random order of patches. + By default no sorting is being done and patches are returned in a row-major order. + pad_mode: refer to NumpyPadMode and PytorchPadMode. Defaults to ``"constant"``. + allow_missing_keys: don't raise exception if key is missing. + pad_kwargs: other arguments for the `np.pad` or `torch.pad` function. + + Returns: + a list of dictionaries, each of which contains the all the original key/value with the values for `keys` + replaced by the patches. It also add the following new keys: + + "slices": slices from the image that defines the patch, + "patch_size": size of the extracted patch + "num_patches": total number of patches in the image + "offset": the amount of offset for the patches in the image (starting position of the first patch) + + """ + + backend = RandGridPatch.backend + + def __init__( + self, + keys: KeysCollection, + patch_size: Sequence[int], + min_offset: Optional[Union[Sequence[int], int]] = None, + max_offset: Optional[Union[Sequence[int], int]] = None, + num_patches: Optional[int] = None, + overlap: float = 0.0, + sort_fn: Optional[Union[Callable, str]] = None, + pad_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT, + allow_missing_keys: bool = False, + **pad_kwargs, + ): + MapTransform.__init__(self, keys, allow_missing_keys) + self.patcher = RandGridPatch( + patch_size=patch_size, + min_offset=min_offset, + max_offset=max_offset, + num_patches=num_patches, + overlap=overlap, + sort_fn=sort_fn, + pad_mode=pad_mode, + **pad_kwargs, + ) + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandGridPatchd": + super().set_random_state(seed, state) + self.patcher.set_random_state(seed, state) + return self + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict]: + d = dict(data) + original_spatial_shape = d[first(self.keys)].shape[1:] + # all the keys share the same random noise + first_key: Union[Hashable, List] = self.first_key(d) + if first_key == []: + return [d] + self.patcher.randomize(d[first_key]) # type: ignore + results = [self.patcher(d[key], randomize=False) for key in self.keys] + + num_patches = min(len(r) for r in results) + output = [] + for patch in zip(*results): + new_dict = {k: v[0] for k, v in zip(self.keys, patch)} + # fill in the extra keys with unmodified data + for k in set(d.keys()).difference(set(self.keys)): + new_dict[k] = deepcopy(d[k]) + # fill additional metadata + new_dict["original_spatial_shape"] = original_spatial_shape + new_dict["slices"] = patch[0][1] # use the coordinate of the first item + new_dict["patch_size"] = self.patcher.patch_size + new_dict["num_patches"] = num_patches + new_dict["offset"] = self.patcher.offset + output.append(new_dict) + return output + + SpatialResampleD = SpatialResampleDict = SpatialResampled ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd SpacingD = SpacingDict = Spacingd @@ -2215,3 +2397,5 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashab ZoomD = ZoomDict = Zoomd RandZoomD = RandZoomDict = RandZoomd GridSplitD = GridSplitDict = GridSplitd +GridPatchD = GridPatchDict = GridPatchd +RandGridPatchD = RandGridPatchDict = RandGridPatchd diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index e7ecab077d..cd8555d173 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -22,6 +22,7 @@ CommonKeys, DiceCEReduction, ForwardMode, + GridPatchSort, GridSampleMode, GridSamplePadMode, InterpolateMode, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index af044f30fe..50b55560f9 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -37,6 +37,7 @@ "ForwardMode", "TransformBackends", "BoxModeName", + "GridPatchSort", ] @@ -329,3 +330,13 @@ class BoxModeName(Enum): XYZWHD = "xyzwhd" # [xmin, ymin, zmin, xsize, ysize, zsize] CCWH = "ccwh" # [xcenter, ycenter, xsize, ysize] CCCWHD = "cccwhd" # [xcenter, ycenter, zcenter, xsize, ysize, zsize] + + +class GridPatchSort(Enum): + """ + The sorting method for the generated patches in `GridPatch` + """ + + RANDOM = "random" + MIN = "min" + MAX = "max" diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py new file mode 100644 index 0000000000..c1d73f262f --- /dev/null +++ b/tests/test_grid_patch.py @@ -0,0 +1,78 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms.spatial.array import GridPatch +from tests.utils import TEST_NDARRAYS, assert_allclose + +A = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1) +A11 = A[:, :2, :2] +A12 = A[:, :2, 2:] +A21 = A[:, 2:, :2] +A22 = A[:, 2:, 2:] + +TEST_CASE_0 = [{"patch_size": (2, 2)}, A, [A11, A12, A21, A22]] +TEST_CASE_1 = [{"patch_size": (2, 2), "num_patches": 3}, A, [A11, A12, A21]] +TEST_CASE_2 = [{"patch_size": (2, 2), "num_patches": 5}, A, [A11, A12, A21, A22, np.zeros((3, 2, 2))]] +TEST_CASE_3 = [{"patch_size": (2, 2), "offset": (0, 0)}, A, [A11, A12, A21, A22]] +TEST_CASE_4 = [{"patch_size": (2, 2), "offset": (0, 0)}, A, [A11, A12, A21, A22]] +TEST_CASE_5 = [{"patch_size": (2, 2), "offset": (2, 2)}, A, [A22]] +TEST_CASE_6 = [{"patch_size": (2, 2), "offset": (0, 2)}, A, [A12, A22]] +TEST_CASE_7 = [{"patch_size": (2, 2), "offset": (2, 0)}, A, [A21, A22]] +TEST_CASE_8 = [{"patch_size": (2, 2), "num_patches": 3, "sort_fn": "max"}, A, [A22, A21, A12]] +TEST_CASE_9 = [{"patch_size": (2, 2), "num_patches": 4, "sort_fn": "min"}, A, [A11, A12, A21, A22]] +TEST_CASE_10 = [{"patch_size": (2, 2), "overlap": 0.5, "num_patches": 3}, A, [A11, A[:, :2, 1:3], A12]] +TEST_CASE_11 = [ + {"patch_size": (3, 3), "num_patches": 2, "constant_values": 255}, + A, + [A[:, :3, :3], np.pad(A[:, :3, 3:], ((0, 0), (0, 0), (0, 2)), mode="constant", constant_values=255)], +] +TEST_CASE_12 = [ + {"patch_size": (3, 3), "offset": (-2, -2), "num_patches": 2}, + A, + [np.zeros((3, 3, 3)), np.pad(A[:, :1, 1:4], ((0, 0), (2, 0), (0, 0)), mode="constant")], +] + + +TEST_SINGLE = [] +for p in TEST_NDARRAYS: + TEST_SINGLE.append([p, *TEST_CASE_0]) + TEST_SINGLE.append([p, *TEST_CASE_1]) + TEST_SINGLE.append([p, *TEST_CASE_2]) + TEST_SINGLE.append([p, *TEST_CASE_3]) + TEST_SINGLE.append([p, *TEST_CASE_4]) + TEST_SINGLE.append([p, *TEST_CASE_5]) + TEST_SINGLE.append([p, *TEST_CASE_6]) + TEST_SINGLE.append([p, *TEST_CASE_7]) + TEST_SINGLE.append([p, *TEST_CASE_8]) + TEST_SINGLE.append([p, *TEST_CASE_9]) + TEST_SINGLE.append([p, *TEST_CASE_10]) + TEST_SINGLE.append([p, *TEST_CASE_11]) + TEST_SINGLE.append([p, *TEST_CASE_12]) + + +class TestGridPatch(unittest.TestCase): + @parameterized.expand(TEST_SINGLE) + def test_grid_patch(self, in_type, input_parameters, image, expected): + input_image = in_type(image) + splitter = GridPatch(**input_parameters) + output = list(splitter(input_image)) + self.assertEqual(len(output), len(expected)) + for output_patch, expected_patch in zip(output, expected): + assert_allclose(output_patch[0], expected_patch, type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_grid_patchd.py b/tests/test_grid_patchd.py new file mode 100644 index 0000000000..a9eec8a2f6 --- /dev/null +++ b/tests/test_grid_patchd.py @@ -0,0 +1,83 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms.spatial.dictionary import GridPatchd +from tests.utils import TEST_NDARRAYS, assert_allclose + +A = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1) +A11 = A[:, :2, :2] +A12 = A[:, :2, 2:] +A21 = A[:, 2:, :2] +A22 = A[:, 2:, 2:] + +TEST_CASE_0 = [{"patch_size": (2, 2)}, {"image": A}, [A11, A12, A21, A22]] +TEST_CASE_1 = [{"patch_size": (2, 2), "num_patches": 3}, {"image": A}, [A11, A12, A21]] +TEST_CASE_2 = [{"patch_size": (2, 2), "num_patches": 5}, {"image": A}, [A11, A12, A21, A22, np.zeros((3, 2, 2))]] +TEST_CASE_3 = [{"patch_size": (2, 2), "offset": (0, 0)}, {"image": A}, [A11, A12, A21, A22]] +TEST_CASE_4 = [{"patch_size": (2, 2), "offset": (0, 0)}, {"image": A}, [A11, A12, A21, A22]] +TEST_CASE_5 = [{"patch_size": (2, 2), "offset": (2, 2)}, {"image": A}, [A22]] +TEST_CASE_6 = [{"patch_size": (2, 2), "offset": (0, 2)}, {"image": A}, [A12, A22]] +TEST_CASE_7 = [{"patch_size": (2, 2), "offset": (2, 0)}, {"image": A}, [A21, A22]] +TEST_CASE_8 = [{"patch_size": (2, 2), "num_patches": 3, "sort_fn": "max"}, {"image": A}, [A22, A21, A12]] +TEST_CASE_9 = [{"patch_size": (2, 2), "num_patches": 4, "sort_fn": "min"}, {"image": A}, [A11, A12, A21, A22]] +TEST_CASE_10 = [{"patch_size": (2, 2), "overlap": 0.5, "num_patches": 3}, {"image": A}, [A11, A[:, :2, 1:3], A12]] +TEST_CASE_11 = [ + {"patch_size": (3, 3), "num_patches": 2, "constant_values": 255}, + {"image": A}, + [A[:, :3, :3], np.pad(A[:, :3, 3:], ((0, 0), (0, 0), (0, 2)), mode="constant", constant_values=255)], +] +TEST_CASE_12 = [ + {"patch_size": (3, 3), "offset": (-2, -2), "num_patches": 2}, + {"image": A}, + [np.zeros((3, 3, 3)), np.pad(A[:, :1, 1:4], ((0, 0), (2, 0), (0, 0)), mode="constant")], +] + + +TEST_SINGLE = [] +for p in TEST_NDARRAYS: + TEST_SINGLE.append([p, *TEST_CASE_0]) + TEST_SINGLE.append([p, *TEST_CASE_1]) + TEST_SINGLE.append([p, *TEST_CASE_2]) + TEST_SINGLE.append([p, *TEST_CASE_3]) + TEST_SINGLE.append([p, *TEST_CASE_4]) + TEST_SINGLE.append([p, *TEST_CASE_5]) + TEST_SINGLE.append([p, *TEST_CASE_6]) + TEST_SINGLE.append([p, *TEST_CASE_7]) + TEST_SINGLE.append([p, *TEST_CASE_8]) + TEST_SINGLE.append([p, *TEST_CASE_9]) + TEST_SINGLE.append([p, *TEST_CASE_10]) + TEST_SINGLE.append([p, *TEST_CASE_11]) + TEST_SINGLE.append([p, *TEST_CASE_12]) + + +class TestGridPatchd(unittest.TestCase): + @parameterized.expand(TEST_SINGLE) + def test_grid_patchd(self, in_type, input_parameters, image_dict, expected): + image_key = "image" + input_dict = {} + for k, v in image_dict.items(): + input_dict[k] = v + if k == image_key: + input_dict[k] = in_type(v) + splitter = GridPatchd(keys=image_key, **input_parameters) + output = list(splitter(input_dict)) + self.assertEqual(len(output), len(expected)) + for output_patch, expected_patch in zip(output, expected): + assert_allclose(output_patch[image_key], expected_patch, type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_grid_patch.py b/tests/test_rand_grid_patch.py new file mode 100644 index 0000000000..36da899982 --- /dev/null +++ b/tests/test_rand_grid_patch.py @@ -0,0 +1,86 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms.spatial.array import RandGridPatch +from monai.utils import set_determinism +from tests.utils import TEST_NDARRAYS, assert_allclose + +set_determinism(1234) + +A = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1) +A11 = A[:, :2, :2] +A12 = A[:, :2, 2:] +A21 = A[:, 2:, :2] +A22 = A[:, 2:, 2:] + +TEST_CASE_0 = [{"patch_size": (2, 2), "min_offset": 0, "max_offset": 0}, A, [A11, A12, A21, A22]] +TEST_CASE_1 = [{"patch_size": (2, 2), "min_offset": 0, "num_patches": 3}, A, [A11, A12, A21]] +TEST_CASE_2 = [ + {"patch_size": (2, 2), "min_offset": 0, "max_offset": 0, "num_patches": 5}, + A, + [A11, A12, A21, A22, np.zeros((3, 2, 2))], +] +TEST_CASE_3 = [{"patch_size": (2, 2), "min_offset": 0, "max_offset": 0}, A, [A11, A12, A21, A22]] +TEST_CASE_4 = [{"patch_size": (2, 2)}, A, [A11, A12, A21, A22]] +TEST_CASE_5 = [{"patch_size": (2, 2), "min_offset": 2, "max_offset": 2}, A, [A22]] +TEST_CASE_6 = [{"patch_size": (2, 2), "min_offset": (0, 2), "max_offset": (0, 2)}, A, [A12, A22]] +TEST_CASE_7 = [{"patch_size": (2, 2), "min_offset": 1, "max_offset": 2}, A, [A22]] +TEST_CASE_8 = [ + {"patch_size": (2, 2), "min_offset": 0, "max_offset": 1, "num_patches": 1, "sort_fn": "max"}, + A, + [A[:, 1:3, 1:3]], +] +TEST_CASE_9 = [ + { + "patch_size": (3, 3), + "min_offset": -3, + "max_offset": -1, + "sort_fn": "min", + "num_patches": 1, + "constant_values": 255, + }, + A, + [np.pad(A[:, :2, 1:], ((0, 0), (1, 0), (0, 0)), mode="constant", constant_values=255)], +] + +TEST_SINGLE = [] +for p in TEST_NDARRAYS: + TEST_SINGLE.append([p, *TEST_CASE_0]) + TEST_SINGLE.append([p, *TEST_CASE_1]) + TEST_SINGLE.append([p, *TEST_CASE_2]) + TEST_SINGLE.append([p, *TEST_CASE_3]) + TEST_SINGLE.append([p, *TEST_CASE_4]) + TEST_SINGLE.append([p, *TEST_CASE_5]) + TEST_SINGLE.append([p, *TEST_CASE_6]) + TEST_SINGLE.append([p, *TEST_CASE_7]) + TEST_SINGLE.append([p, *TEST_CASE_8]) + TEST_SINGLE.append([p, *TEST_CASE_9]) + + +class TestRandGridPatch(unittest.TestCase): + @parameterized.expand(TEST_SINGLE) + def test_rand_grid_patch(self, in_type, input_parameters, image, expected): + input_image = in_type(image) + splitter = RandGridPatch(**input_parameters) + splitter.set_random_state(1234) + output = list(splitter(input_image)) + self.assertEqual(len(output), len(expected)) + for output_patch, expected_patch in zip(output, expected): + assert_allclose(output_patch[0], expected_patch, type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_grid_patchd.py b/tests/test_rand_grid_patchd.py new file mode 100644 index 0000000000..6f89a3d155 --- /dev/null +++ b/tests/test_rand_grid_patchd.py @@ -0,0 +1,91 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms.spatial.dictionary import RandGridPatchd +from monai.utils import set_determinism +from tests.utils import TEST_NDARRAYS, assert_allclose + +set_determinism(1234) + +A = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1) +A11 = A[:, :2, :2] +A12 = A[:, :2, 2:] +A21 = A[:, 2:, :2] +A22 = A[:, 2:, 2:] + +TEST_CASE_0 = [{"patch_size": (2, 2), "min_offset": 0, "max_offset": 0}, {"image": A}, [A11, A12, A21, A22]] +TEST_CASE_1 = [{"patch_size": (2, 2), "min_offset": 0, "num_patches": 3}, {"image": A}, [A11, A12, A21]] +TEST_CASE_2 = [ + {"patch_size": (2, 2), "min_offset": 0, "max_offset": 0, "num_patches": 5}, + {"image": A}, + [A11, A12, A21, A22, np.zeros((3, 2, 2))], +] +TEST_CASE_3 = [{"patch_size": (2, 2), "min_offset": 0, "max_offset": 0}, {"image": A}, [A11, A12, A21, A22]] +TEST_CASE_4 = [{"patch_size": (2, 2)}, {"image": A}, [A11, A12, A21, A22]] +TEST_CASE_5 = [{"patch_size": (2, 2), "min_offset": 2, "max_offset": 2}, {"image": A}, [A22]] +TEST_CASE_6 = [{"patch_size": (2, 2), "min_offset": (0, 2), "max_offset": (0, 2)}, {"image": A}, [A12, A22]] +TEST_CASE_7 = [{"patch_size": (2, 2), "min_offset": 1, "max_offset": 2}, {"image": A}, [A22]] +TEST_CASE_8 = [ + {"patch_size": (2, 2), "min_offset": 0, "max_offset": 1, "num_patches": 1, "sort_fn": "max"}, + {"image": A}, + [A[:, 1:3, 1:3]], +] +TEST_CASE_9 = [ + { + "patch_size": (3, 3), + "min_offset": -3, + "max_offset": -1, + "sort_fn": "min", + "num_patches": 1, + "constant_values": 255, + }, + {"image": A}, + [np.pad(A[:, :2, 1:], ((0, 0), (1, 0), (0, 0)), mode="constant", constant_values=255)], +] + +TEST_SINGLE = [] +for p in TEST_NDARRAYS: + TEST_SINGLE.append([p, *TEST_CASE_0]) + TEST_SINGLE.append([p, *TEST_CASE_1]) + TEST_SINGLE.append([p, *TEST_CASE_2]) + TEST_SINGLE.append([p, *TEST_CASE_3]) + TEST_SINGLE.append([p, *TEST_CASE_4]) + TEST_SINGLE.append([p, *TEST_CASE_5]) + TEST_SINGLE.append([p, *TEST_CASE_6]) + TEST_SINGLE.append([p, *TEST_CASE_7]) + TEST_SINGLE.append([p, *TEST_CASE_8]) + TEST_SINGLE.append([p, *TEST_CASE_9]) + + +class TestRandGridPatchd(unittest.TestCase): + @parameterized.expand(TEST_SINGLE) + def test_rand_grid_patchd(self, in_type, input_parameters, image_dict, expected): + image_key = "image" + input_dict = {} + for k, v in image_dict.items(): + input_dict[k] = v + if k == image_key: + input_dict[k] = in_type(v) + splitter = RandGridPatchd(keys=image_key, **input_parameters) + splitter.set_random_state(1234) + output = list(splitter(input_dict)) + self.assertEqual(len(output), len(expected)) + for output_patch, expected_patch in zip(output, expected): + assert_allclose(output_patch[image_key], expected_patch, type_test=False) + + +if __name__ == "__main__": + unittest.main()