From ff059b093fe1a7249e30ab4b9fe94e270cd11d85 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 16 Sep 2021 15:46:08 +0100 Subject: [PATCH] torch `SpatialCrop`, `SpatialCropd` Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 1 + monai/transforms/croppad/array.py | 78 +++++++++++-------- monai/transforms/croppad/dictionary.py | 47 ++++++----- monai/transforms/spatial/array.py | 2 +- monai/transforms/utils.py | 4 +- .../utils_pytorch_numpy_unification.py | 21 +++++ tests/test_spatial_crop.py | 28 ++++--- tests/test_spatial_cropd.py | 67 +++++++++------- 8 files changed, 152 insertions(+), 96 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index e693c06dc2..826f66a34c 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -537,6 +537,7 @@ clip, floor_divide, in1d, + maximum, moveaxis, nonzero, percentile, diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index fac14f4582..0e6529703f 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -22,7 +22,7 @@ from torch.nn.functional import pad as pad_pt from monai.config import IndexSelection -from monai.config.type_definitions import NdarrayTensor +from monai.config.type_definitions import NdarrayOrTensor from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.transform import Randomizable, Transform from monai.transforms.utils import ( @@ -36,6 +36,7 @@ map_classes_to_indices, weighted_patch_samples, ) +from monai.transforms.utils_pytorch_numpy_unification import floor_divide, maximum from monai.utils import ( Method, NumpyPadMode, @@ -98,8 +99,7 @@ def __init__( @staticmethod def _np_pad(img: np.ndarray, all_pad_width, mode, **kwargs) -> np.ndarray: - img_np, *_ = convert_data_type(img, np.ndarray) - return np.pad(img_np, all_pad_width, mode=mode, **kwargs) # type: ignore + return np.pad(img, all_pad_width, mode=mode, **kwargs) # type: ignore @staticmethod def _pt_pad(img: torch.Tensor, all_pad_width, mode, **kwargs) -> torch.Tensor: @@ -109,9 +109,9 @@ def _pt_pad(img: torch.Tensor, all_pad_width, mode, **kwargs) -> torch.Tensor: def __call__( self, - img: NdarrayTensor, + img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, - ) -> NdarrayTensor: + ) -> NdarrayOrTensor: """ Args: img: data to be transformed, assuming `img` is channel-first and @@ -132,7 +132,7 @@ def __call__( pad = self._pt_pad else: pad = self._np_pad # type: ignore - return pad(img, self.to_pad, mode, **self.kwargs) + return pad(img, self.to_pad, mode, **self.kwargs) # type: ignore class SpatialPad(Transform): @@ -190,9 +190,9 @@ def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int def __call__( self, - img: NdarrayTensor, + img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, - ) -> NdarrayTensor: + ) -> NdarrayOrTensor: """ Args: img: data to be transformed, assuming `img` is channel-first and @@ -255,9 +255,9 @@ def __init__( def __call__( self, - img: NdarrayTensor, + img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, - ) -> NdarrayTensor: + ) -> NdarrayOrTensor: """ Args: img: data to be transformed, assuming `img` is channel-first and @@ -337,9 +337,9 @@ def __init__( def __call__( self, - img: NdarrayTensor, + img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, 
PytorchPadMode, str]] = None, - ) -> NdarrayTensor: + ) -> NdarrayOrTensor: """ Args: img: data to be transformed, assuming `img` is channel-first @@ -377,12 +377,14 @@ class SpatialCrop(Transform): - the start and end coordinates of the ROI """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, - roi_center: Union[Sequence[int], np.ndarray, None] = None, - roi_size: Union[Sequence[int], np.ndarray, None] = None, - roi_start: Union[Sequence[int], np.ndarray, None] = None, - roi_end: Union[Sequence[int], np.ndarray, None] = None, + roi_center: Union[Sequence[int], NdarrayOrTensor, None] = None, + roi_size: Union[Sequence[int], NdarrayOrTensor, None] = None, + roi_start: Union[Sequence[int], NdarrayOrTensor, None] = None, + roi_end: Union[Sequence[int], NdarrayOrTensor, None] = None, roi_slices: Optional[Sequence[slice]] = None, ) -> None: """ @@ -395,33 +397,38 @@ def __init__( use the end coordinate of image. roi_slices: list of slices for each of the spatial dimensions. """ + roi_start_torch: torch.Tensor + if roi_slices: if not all(s.step is None or s.step == 1 for s in roi_slices): raise ValueError("Only slice steps of 1/None are currently supported") self.slices = list(roi_slices) else: if roi_center is not None and roi_size is not None: - roi_center = np.asarray(roi_center, dtype=np.int16) - roi_size = np.asarray(roi_size, dtype=np.int16) - roi_start_np = np.maximum(roi_center - np.floor_divide(roi_size, 2), 0) - roi_end_np = np.maximum(roi_start_np + roi_size, roi_start_np) + roi_center = torch.as_tensor(roi_center, dtype=torch.int16) + roi_size = torch.as_tensor(roi_size, dtype=torch.int16, device=roi_center.device) + roi_start_torch = maximum( # type: ignore + roi_center - floor_divide(roi_size, 2), + torch.zeros_like(roi_center), + ) + roi_end_torch = maximum(roi_start_torch + roi_size, roi_start_torch) else: if roi_start is None or roi_end is None: raise ValueError("Please specify either roi_center, roi_size or roi_start, roi_end.") - roi_start_np = np.maximum(np.asarray(roi_start, dtype=np.int16), 0) - roi_end_np = np.maximum(np.asarray(roi_end, dtype=np.int16), roi_start_np) - # Allow for 1D by converting back to np.array (since np.maximum will convert to int) - roi_start_np = roi_start_np if isinstance(roi_start_np, np.ndarray) else np.array([roi_start_np]) - roi_end_np = roi_end_np if isinstance(roi_end_np, np.ndarray) else np.array([roi_end_np]) - # convert to slices - self.slices = [slice(s, e) for s, e in zip(roi_start_np, roi_end_np)] - - def __call__(self, img: Union[np.ndarray, torch.Tensor]): + roi_start_torch = torch.as_tensor(roi_start, dtype=torch.int16) + roi_start_torch = maximum(roi_start_torch, torch.zeros_like(roi_start_torch)) # type: ignore + roi_end_torch = maximum(torch.as_tensor(roi_end, dtype=torch.int16), roi_start_torch) + # convert to slices (accounting for 1d) + if roi_start_torch.numel() == 1: + self.slices = [slice(int(roi_start_torch.item()), int(roi_end_torch.item()))] + else: + self.slices = [slice(int(s.item()), int(e.item())) for s, e in zip(roi_start_torch, roi_end_torch)] + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. 
""" - img, *_ = convert_data_type(img, np.ndarray) sd = min(len(self.slices), len(img.shape[1:])) # spatial dims slices = [slice(None)] + self.slices[:sd] return img[tuple(slices)] @@ -822,7 +829,8 @@ def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) -> results = [] for center in self.centers: cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) - results.append(cropper(img)) + cropped: np.ndarray = cropper(img) # type: ignore + results.append(cropped) return results @@ -962,7 +970,8 @@ def __call__( if self.centers is not None: for center in self.centers: cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore - results.append(cropper(img)) + cropped: np.ndarray = cropper(img) # type: ignore + results.append(cropped) return results @@ -1098,7 +1107,8 @@ def __call__( if self.centers is not None: for center in self.centers: cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore - results.append(cropper(img)) + cropped: np.ndarray = cropper(img) # type: ignore + results.append(cropped) return results @@ -1146,7 +1156,7 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ img, *_ = convert_data_type(img, np.ndarray) # type: ignore - return self.padder(self.cropper(img), mode=mode) + return self.padder(self.cropper(img), mode=mode) # type: ignore class BoundingRect(Transform): diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index a504b23179..2d50ba0b34 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -25,7 +25,7 @@ import numpy as np from monai.config import IndexSelection, KeysCollection -from monai.config.type_definitions import NdarrayTensor +from monai.config.type_definitions import NdarrayOrTensor from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.croppad.array import ( BorderPad, @@ -147,14 +147,14 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padder = SpatialPad(spatial_size, method, **kwargs) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, m in self.key_iterator(d, self.mode): self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -222,14 +222,14 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padder = BorderPad(spatial_border=spatial_border, **kwargs) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, m in self.key_iterator(d, self.mode): self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: 
Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -298,14 +298,14 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padder = DivisiblePad(k=k, method=method, **kwargs) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, m in self.key_iterator(d, self.mode): self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -339,6 +339,8 @@ class SpatialCropd(MapTransform, InvertibleTransform): - the start and end coordinates of the ROI """ + backend = SpatialCrop.backend + def __init__( self, keys: KeysCollection, @@ -365,14 +367,14 @@ def __init__( super().__init__(keys, allow_missing_keys) self.cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) d[key] = self.cropper(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -426,7 +428,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.push_transform(d, key, orig_size=orig_size) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -481,7 +483,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -576,7 +578,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = cropper(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -772,7 +774,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n ret.append(cropped) return ret - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: d = deepcopy(dict(data)) # We changed the transform name from RandSpatialCropd to RandSpatialCropSamplesd # Need to revert that since we're calling RandSpatialCropd's inverse @@ -859,7 +861,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = 
self.cropper.crop_pad(img=d[key], box_start=box_start, box_end=box_end, mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -964,7 +966,8 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n for i, center in enumerate(self.centers): cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) orig_size = img.shape[1:] - results[i][key] = cropper(img) + cropped: np.ndarray = cropper(img) # type: ignore + results[i][key] = cropped self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) if self.center_coord_key: results[i][self.center_coord_key] = center @@ -979,7 +982,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n return results - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -1136,7 +1139,8 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n img = d[key] cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore orig_size = img.shape[1:] - results[i][key] = cropper(img) + cropped: np.ndarray = cropper(img) # type: ignore + results[i][key] = cropped self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) # add `patch_index` to the meta data for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): @@ -1147,7 +1151,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n return results - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -1315,7 +1319,8 @@ def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarr img = d[key] cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore orig_size = img.shape[1:] - results[i][key] = cropper(img) + cropped: np.ndarray = cropper(img) # type: ignore + results[i][key] = cropped self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) # add `patch_index` to the meta data for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): @@ -1326,7 +1331,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarr return results - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -1399,7 +1404,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, 
NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index cb80c2036b..8b1fb854f2 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -618,7 +618,7 @@ def __call__( img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32) # type: ignore _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim - zoomed: torch.Tensor = torch.nn.functional.interpolate( # type: ignore + zoomed: NdarrayOrTensor = torch.nn.functional.interpolate( # type: ignore recompute_scale_factor=True, input=img_t.unsqueeze(0), scale_factor=list(_zoom), diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 883d5e5faa..05e45bf26f 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -22,7 +22,7 @@ import monai import monai.transforms.transform from monai.config import DtypeLike, IndexSelection -from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor +from monai.config.type_definitions import NdarrayOrTensor from monai.networks.layers import GaussianFilter from monai.transforms.compose import Compose, OneOf from monai.transforms.transform import MapTransform, Transform @@ -1271,7 +1271,7 @@ def print_table_column(name, torch, numpy, color=Colors.none): print_color(f"Number of uncategorised: {n_uncategorized}", Colors.red) -def convert_pad_mode(dst: NdarrayTensor, mode: Union[NumpyPadMode, PytorchPadMode, str]): +def convert_pad_mode(dst: NdarrayOrTensor, mode: Union[NumpyPadMode, PytorchPadMode, str]): """ Utility to convert padding mode between numpy array and PyTorch Tensor. diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 8808e25265..7b65f690c0 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -28,6 +28,7 @@ "unravel_index", "ravel", "any_np_pt", + "maximum", ] @@ -216,3 +217,23 @@ def any_np_pt(x: NdarrayOrTensor, axis: int): # older versions of pytorch require the input to be cast to boolean return torch.any(x.bool(), axis) return np.any(x, axis) + + +def maximum(a: NdarrayOrTensor, b: NdarrayOrTensor) -> NdarrayOrTensor: + """`np.maximum` with equivalent implementation for torch. + + `torch.maximum` only available from pt>1.6, else use `torch.stack` and `torch.max`. + + Args: + a: first array/tensor + b: second array/tensor + + Returns: + Element-wise maximum between two arrays/tensors. 
+ """ + if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + # is torch and has torch.maximum (pt>1.6) + if hasattr(torch, "maximum"): + return torch.maximum(a, b) + return torch.stack((a, b)).max(dim=0)[0] + return np.maximum(a, b) diff --git a/tests/test_spatial_crop.py b/tests/test_spatial_crop.py index c76915f0a3..c18dd86e46 100644 --- a/tests/test_spatial_crop.py +++ b/tests/test_spatial_crop.py @@ -16,8 +16,9 @@ from parameterized import parameterized from monai.transforms import SpatialCrop +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASES = [ +TESTS = [ [ {"roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, (3, 3, 3, 3), @@ -53,17 +54,24 @@ class TestSpatialCrop(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_shape, expected_shape): input_data = np.random.randint(0, 2, size=input_shape) - result = SpatialCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) - - @parameterized.expand(TEST_CASES) - def test_tensor_shape(self, input_param, input_shape, expected_shape): - input_data = torch.randint(0, 2, size=input_shape, device="cuda" if torch.cuda.is_available() else "cpu") - result = SpatialCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) + results = [] + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS + (None,): + input_param_mod = { + k: q(v) if k != "roi_slices" and q is not None else v for k, v in input_param.items() + } + im = p(input_data) + result = SpatialCrop(**input_param_mod)(im) + self.assertEqual(type(im), type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, im.device) + self.assertTupleEqual(result.shape, expected_shape) + results.append(result) + if len(results) > 1: + assert_allclose(results[0], results[-1], type_test=False) @parameterized.expand(TEST_ERRORS) def test_error(self, input_param): diff --git a/tests/test_spatial_cropd.py b/tests/test_spatial_cropd.py index 797c25d34b..17743124e0 100644 --- a/tests/test_spatial_cropd.py +++ b/tests/test_spatial_cropd.py @@ -15,38 +15,49 @@ from parameterized import parameterized from monai.transforms import SpatialCropd +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - {"keys": ["img"], "roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 2, 2, 2), - ], - [ - {"keys": ["img"], "roi_start": [0, 0, 0], "roi_end": [2, 2, 2]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 2, 2, 2), - ], - [ - {"keys": ["img"], "roi_start": [0, 0], "roi_end": [2, 2]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 2, 2, 3), - ], - [ - {"keys": ["img"], "roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 2, 2, 2), - ], - [ - {"keys": ["img"], "roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - (3, 1, 2, 2), - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["img"], "roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + (3, 2, 2, 2), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "roi_start": [0, 0, 0], "roi_end": [2, 2, 2]}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + (3, 2, 2, 2), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "roi_start": [0, 0], "roi_end": [2, 2]}, + {"img": p(np.random.randint(0, 
2, size=[3, 3, 3, 3]))}, + (3, 2, 2, 3), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "roi_start": [0, 0, 0, 0, 0], "roi_end": [2, 2, 2, 2, 2]}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + (3, 2, 2, 2), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "roi_slices": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + (3, 1, 2, 2), + ] + ) class TestSpatialCropd(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = SpatialCropd(**input_param)(input_data) self.assertTupleEqual(result["img"].shape, expected_shape)
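
For reference, a minimal usage sketch of the torch-enabled `SpatialCrop` / `SpatialCropd` added by this patch. It is illustrative only: it assumes a MONAI build that includes the changes above, and the shapes and ROI values are made up for the example. The key behaviour it demonstrates is that the transform now accepts either a numpy array or a torch tensor and returns the same type it was given.

    import numpy as np
    import torch

    from monai.transforms import SpatialCrop, SpatialCropd

    # channel-first image: 1 channel, 4x4x4 spatial dims (illustrative data)
    img_np = np.arange(64, dtype=np.float32).reshape(1, 4, 4, 4)
    img_pt = torch.as_tensor(img_np)

    cropper = SpatialCrop(roi_center=[2, 2, 2], roi_size=[2, 2, 2])

    # the same transform handles numpy and torch inputs,
    # returning the same container type it was given
    out_np = cropper(img_np)  # np.ndarray, spatial dims cropped to 2x2x2
    out_pt = cropper(img_pt)  # torch.Tensor, same crop
    assert out_np.shape == tuple(out_pt.shape) == (1, 2, 2, 2)

    # the dictionary version applies the same crop per key
    d = SpatialCropd(keys="img", roi_start=[0, 0, 0], roi_end=[2, 2, 2])({"img": img_pt})
    assert tuple(d["img"].shape) == (1, 2, 2, 2)

Because the slices are computed once in `__init__` (via the unified `maximum`/`floor_divide` helpers) and applied with plain indexing in `__call__`, the torch path also preserves the input tensor's device, which is what the updated `test_spatial_crop.py` checks.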