diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index a9cb847b93..3e3afab5d8 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -441,6 +441,8 @@ class Rotate(Transform, ThreadUnsafe): the output data type is always ``np.float32``. """ + backend = [TransformBackends.TORCH] + def __init__( self, angle: Union[Sequence[float], float], @@ -448,7 +450,7 @@ def __init__( mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, align_corners: bool = False, - dtype: DtypeLike = np.float64, + dtype: Union[DtypeLike, torch.dtype] = np.float64, ) -> None: self.angle = angle self.keep_size = keep_size @@ -460,12 +462,12 @@ def __init__( def __call__( self, - img: np.ndarray, + img: NdarrayOrTensor, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, - dtype: DtypeLike = None, - ) -> np.ndarray: + dtype: Union[DtypeLike, torch.dtype] = None, + ) -> torch.Tensor: """ Args: img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D]. @@ -488,7 +490,11 @@ def __call__( """ _dtype = dtype or self.dtype or img.dtype - im_shape = np.asarray(img.shape[1:]) # spatial dimensions + + img_t: torch.Tensor + img_t, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype) # type: ignore + + im_shape = np.asarray(img_t.shape[1:]) # spatial dimensions input_ndim = len(im_shape) if input_ndim not in (2, 3): raise ValueError(f"Unsupported img dimension: {input_ndim}, available options are [2, 3].") @@ -506,6 +512,9 @@ def __call__( shift_1 = create_translate(input_ndim, (-(output_shape - 1) / 2).tolist()) transform = shift @ transform @ shift_1 + transform_t: torch.Tensor + transform_t, *_ = convert_to_dst_type(transform, img_t) # type: ignore + xform = AffineTransform( normalized=False, mode=look_up_option(mode or self.mode, GridSampleMode), @@ -513,13 +522,13 @@ def __call__( align_corners=self.align_corners if align_corners is None else align_corners, reverse_indexing=True, ) - output = xform( - torch.as_tensor(np.ascontiguousarray(img).astype(_dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(transform).astype(_dtype)), + output: torch.Tensor = xform( + img_t.unsqueeze(0), + transform_t, spatial_size=output_shape, ) self._rotation_matrix = transform - return np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + return output.squeeze(0).detach().float() def get_rotation_matrix(self) -> Optional[np.ndarray]: """ @@ -738,6 +747,8 @@ class RandRotate(RandomizableTransform): the output data type is always ``np.float32``. """ + backend = Rotate.backend + def __init__( self, range_x: Union[Tuple[float, float], float] = 0.0, @@ -748,7 +759,7 @@ def __init__( mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, align_corners: bool = False, - dtype: DtypeLike = np.float64, + dtype: Union[DtypeLike, torch.dtype] = np.float64, ) -> None: RandomizableTransform.__init__(self, prob) self.range_x = ensure_tuple(range_x) @@ -779,12 +790,12 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__( self, - img: np.ndarray, + img: NdarrayOrTensor, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, - dtype: DtypeLike = None, - ) -> np.ndarray: + dtype: Union[DtypeLike, torch.dtype] = None, + ) -> torch.Tensor: """ Args: img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D). @@ -802,7 +813,9 @@ def __call__( """ self.randomize() if not self._do_transform: - return img + img_t: torch.Tensor + img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore + return img_t rotator = Rotate( angle=self.x if img.ndim == 3 else (self.x, self.y, self.z), keep_size=self.keep_size, @@ -811,7 +824,7 @@ def __call__( align_corners=self.align_corners if align_corners is None else align_corners, dtype=dtype or self.dtype or img.dtype, ) - return np.array(rotator(img)) + return rotator(img) class RandFlip(RandomizableTransform): diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 96fe21db12..903512e9b8 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -57,6 +57,7 @@ ) from monai.utils.enums import InverseKeys from monai.utils.module import optional_import +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type nib, _ = optional_import("nibabel") @@ -1287,6 +1288,8 @@ class Rotated(MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = Rotate.backend + def __init__( self, keys: KeysCollection, @@ -1295,7 +1298,7 @@ def __init__( mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, align_corners: Union[Sequence[bool], bool] = False, - dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, + dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], Union[DtypeLike, torch.dtype]] = np.float64, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) @@ -1306,7 +1309,7 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype @@ -1333,7 +1336,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) @@ -1351,12 +1354,17 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar align_corners=False if align_corners == "none" else align_corners, reverse_indexing=True, ) + img_t: torch.Tensor + img_t, *_ = convert_data_type(d[key], torch.Tensor, dtype=dtype) # type: ignore + transform_t: torch.Tensor + transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) # type: ignore + output = xform( - torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), + img_t.unsqueeze(0), + transform_t, spatial_size=transform[InverseKeys.ORIG_SIZE], ) - d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + d[key] = output.squeeze(0).detach().float() # Remove the applied transform self.pop_transform(d, key) @@ -1398,6 +1406,8 @@ class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = Rotate.backend + def __init__( self, keys: KeysCollection, @@ -1409,7 +1419,7 @@ def __init__( mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, align_corners: Union[Sequence[bool], bool] = False, - dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, + dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], Union[DtypeLike, torch.dtype]] = np.float64, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) @@ -1440,7 +1450,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: self.randomize() d = dict(data) angle: Union[Sequence[float], float] = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z) @@ -1462,6 +1472,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) rot_mat = rotator.get_rotation_matrix() else: + d[key], *_ = convert_data_type(d[key], torch.Tensor) rot_mat = np.eye(d[key].ndim) self.push_transform( d, @@ -1476,7 +1487,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) @@ -1496,12 +1507,17 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar align_corners=False if align_corners == "none" else align_corners, reverse_indexing=True, ) + img_t: torch.Tensor + img_t, *_ = convert_data_type(d[key], torch.Tensor, dtype=dtype) # type: ignore + transform_t: torch.Tensor + transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) # type: ignore + output: torch.Tensor output = xform( - torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), + img_t.unsqueeze(0), + transform_t, spatial_size=transform[InverseKeys.ORIG_SIZE], ) - d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + d[key] = output.squeeze(0).detach().float() # Remove the applied transform self.pop_transform(d, key) diff --git a/tests/test_rand_rotate.py b/tests/test_rand_rotate.py index 0ff8508a0f..4817e81735 100644 --- a/tests/test_rand_rotate.py +++ b/tests/test_rand_rotate.py @@ -10,25 +10,60 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import RandRotate -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, np.pi / 2, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4, True, "nearest", "border", False)) + TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", True)) + TEST_CASES_2D.append((p, (-np.pi / 4, 0), False, "nearest", "zeros", True)) -class TestRandRotate2D(NumpyImageTestCase2D): - @parameterized.expand( - [ - (np.pi / 2, True, "bilinear", "border", False), - (np.pi / 4, True, "nearest", "border", False), - (np.pi, False, "nearest", "zeros", True), - ((-np.pi / 4, 0), False, "nearest", "zeros", True), - ] +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append( + (p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)) + ) + TEST_CASES_3D.append( + ( + p, + np.pi / 4, + (-np.pi / 9, np.pi / 4.5), + (np.pi / 9, np.pi / 6), + False, + "nearest", + "border", + True, + (1, 89, 105, 104), + ) + ) + TEST_CASES_3D.append( + ( + p, + 0.0, + (2 * np.pi, 2.06 * np.pi), + (-np.pi / 180, np.pi / 180), + True, + "nearest", + "zeros", + True, + (1, 48, 64, 80), + ) ) - def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_corners): + TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90))) + + +class TestRandRotate2D(NumpyImageTestCase2D): + @parameterized.expand(TEST_CASES_2D) + def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners): rotate_fn = RandRotate( range_x=degrees, prob=1.0, @@ -38,7 +73,7 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor align_corners=align_corners, ) rotate_fn.set_random_state(243) - rotated = rotate_fn(self.imt[0]) + rotated = rotate_fn(im_type(self.imt[0])) _order = 0 if mode == "nearest" else 1 if mode == "border": @@ -52,38 +87,14 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) expected = np.stack(expected).astype(np.float32) + rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated good = np.sum(np.isclose(expected, rotated[0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") class TestRandRotate3D(NumpyImageTestCase3D): - @parameterized.expand( - [ - (np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)), - ( - np.pi / 4, - (-np.pi / 9, np.pi / 4.5), - (np.pi / 9, np.pi / 6), - False, - "nearest", - "border", - True, - (1, 89, 105, 104), - ), - ( - 0.0, - (2 * np.pi, 2.06 * np.pi), - (-np.pi / 180, np.pi / 180), - True, - "nearest", - "zeros", - True, - (1, 48, 64, 80), - ), - ((-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90)), - ] - ) - def test_correct_results(self, x, y, z, keep_size, mode, padding_mode, align_corners, expected): + @parameterized.expand(TEST_CASES_3D) + def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected): rotate_fn = RandRotate( range_x=x, range_y=y, @@ -95,8 +106,8 @@ def test_correct_results(self, x, y, z, keep_size, mode, padding_mode, align_cor align_corners=align_corners, ) rotate_fn.set_random_state(243) - rotated = rotate_fn(self.imt[0]) - np.testing.assert_allclose(rotated.shape, expected) + rotated = rotate_fn(im_type(self.imt[0])) + torch.testing.assert_allclose(rotated.shape, expected, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_rand_rotated.py b/tests/test_rand_rotated.py index 47b4b7107e..4c9a27f668 100644 --- a/tests/test_rand_rotated.py +++ b/tests/test_rand_rotated.py @@ -10,26 +10,104 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import RandRotated from monai.utils import GridSampleMode, GridSamplePadMode -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, np.pi / 2, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4, True, "nearest", "border", False)) + TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", True)) + TEST_CASES_2D.append((p, (-np.pi / 4, 0), False, "nearest", "zeros", True)) -class TestRandRotated2D(NumpyImageTestCase2D): - @parameterized.expand( - [ - (np.pi / 2, True, "bilinear", "border", False), - (np.pi / 4, True, "nearest", "border", False), - (np.pi, False, "nearest", "zeros", True), - ((-np.pi / 4, 0), False, "nearest", "zeros", True), - ] + +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append( + (p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)) ) - def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_corners): + TEST_CASES_3D.append( + ( + p, + np.pi / 2, + -np.pi / 6, + (0.0, np.pi), + False, + GridSampleMode.NEAREST, + GridSamplePadMode.BORDER, + False, + (1, 87, 104, 109), + ) + ) + TEST_CASES_3D.append( + ( + p, + np.pi / 4, + (-np.pi / 9, np.pi / 4.5), + (np.pi / 9, np.pi / 6), + False, + "nearest", + "border", + True, + (1, 89, 105, 104), + ) + ) + TEST_CASES_3D.append( + ( + p, + np.pi / 4, + (-np.pi / 9, np.pi / 4.5), + (np.pi / 9, np.pi / 6), + False, + GridSampleMode.NEAREST, + GridSamplePadMode.BORDER, + True, + (1, 89, 105, 104), + ) + ) + TEST_CASES_3D.append( + ( + p, + 0.0, + (2 * np.pi, 2.06 * np.pi), + (-np.pi / 180, np.pi / 180), + True, + "nearest", + "zeros", + True, + (1, 48, 64, 80), + ) + ) + TEST_CASES_3D.append( + ( + p, + 0.0, + (2 * np.pi, 2.06 * np.pi), + (-np.pi / 180, np.pi / 180), + True, + GridSampleMode.NEAREST, + GridSamplePadMode.ZEROS, + True, + (1, 48, 64, 80), + ) + ) + TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90))) + TEST_CASES_3D.append( + (p, (-np.pi / 4, 0), 0, 0, False, GridSampleMode.NEAREST, GridSamplePadMode.ZEROS, False, (1, 48, 77, 90)) + ) + + +class TestRandRotated2D(NumpyImageTestCase2D): + @parameterized.expand(TEST_CASES_2D) + def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners): rotate_fn = RandRotated( "img", range_x=degrees, @@ -40,7 +118,7 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor align_corners=align_corners, ) rotate_fn.set_random_state(243) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) _order = 0 if mode == "nearest" else 1 if padding_mode == "border": @@ -53,70 +131,16 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor expected = scipy.ndimage.rotate( self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v expected = np.stack(expected).astype(np.float32) good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") class TestRandRotated3D(NumpyImageTestCase3D): - @parameterized.expand( - [ - (np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)), - ( - np.pi / 2, - -np.pi / 6, - (0.0, np.pi), - False, - GridSampleMode.NEAREST, - GridSamplePadMode.BORDER, - False, - (1, 87, 104, 109), - ), - ( - np.pi / 4, - (-np.pi / 9, np.pi / 4.5), - (np.pi / 9, np.pi / 6), - False, - "nearest", - "border", - True, - (1, 89, 105, 104), - ), - ( - np.pi / 4, - (-np.pi / 9, np.pi / 4.5), - (np.pi / 9, np.pi / 6), - False, - GridSampleMode.NEAREST, - GridSamplePadMode.BORDER, - True, - (1, 89, 105, 104), - ), - ( - 0.0, - (2 * np.pi, 2.06 * np.pi), - (-np.pi / 180, np.pi / 180), - True, - "nearest", - "zeros", - True, - (1, 48, 64, 80), - ), - ( - 0.0, - (2 * np.pi, 2.06 * np.pi), - (-np.pi / 180, np.pi / 180), - True, - GridSampleMode.NEAREST, - GridSamplePadMode.ZEROS, - True, - (1, 48, 64, 80), - ), - ((-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90)), - ((-np.pi / 4, 0), 0, 0, False, GridSampleMode.NEAREST, GridSamplePadMode.ZEROS, False, (1, 48, 77, 90)), - ] - ) - def test_correct_shapes(self, x, y, z, keep_size, mode, padding_mode, align_corners, expected): + @parameterized.expand(TEST_CASES_3D) + def test_correct_shapes(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected): rotate_fn = RandRotated( "img", range_x=x, @@ -129,7 +153,7 @@ def test_correct_shapes(self, x, y, z, keep_size, mode, padding_mode, align_corn align_corners=align_corners, ) rotate_fn.set_random_state(243) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) np.testing.assert_allclose(rotated["img"].shape, expected) diff --git a/tests/test_rotate.py b/tests/test_rotate.py index 436c952d4b..16a9c6d124 100644 --- a/tests/test_rotate.py +++ b/tests/test_rotate.py @@ -10,42 +10,44 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import Rotate -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D - -TEST_CASES_2D = [ - (np.pi / 6, False, "bilinear", "border", False), - (np.pi / 4, True, "bilinear", "border", False), - (-np.pi / 4.5, True, "nearest", "reflection", False), - (np.pi, False, "nearest", "zeros", False), - (-np.pi / 2, False, "bilinear", "zeros", True), -] - -TEST_CASES_3D = [ - (-np.pi / 2, True, "nearest", "border", False), - (np.pi / 4, True, "bilinear", "border", False), - (-np.pi / 4.5, True, "nearest", "reflection", False), - (np.pi, False, "nearest", "zeros", False), - (-np.pi / 2, False, "bilinear", "zeros", False), -] - -TEST_CASES_SHAPE_3D = [ - ([-np.pi / 2, 1.0, 2.0], "nearest", "border", False), - ([np.pi / 4, 0, 0], "bilinear", "border", False), - ([-np.pi / 4.5, -20, 20], "nearest", "reflection", False), -] +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D + +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, np.pi / 6, False, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, -np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", False)) + TEST_CASES_2D.append((p, -np.pi / 2, False, "bilinear", "zeros", True)) + +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append((p, -np.pi / 2, True, "nearest", "border", False)) + TEST_CASES_3D.append((p, np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_3D.append((p, -np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_3D.append((p, np.pi, False, "nearest", "zeros", False)) + TEST_CASES_3D.append((p, -np.pi / 2, False, "bilinear", "zeros", False)) + +TEST_CASES_SHAPE_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_SHAPE_3D.append((p, [-np.pi / 2, 1.0, 2.0], "nearest", "border", False)) + TEST_CASES_SHAPE_3D.append((p, [np.pi / 4, 0, 0], "bilinear", "border", False)) + TEST_CASES_SHAPE_3D.append((p, [-np.pi / 4.5, -20, 20], "nearest", "reflection", False)) class TestRotate2D(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES_2D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotate(angle, keep_size, mode, padding_mode, align_corners) - rotated = rotate_fn(self.imt[0]) + rotated = rotate_fn(im_type(self.imt[0])) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated.shape) _order = 0 if mode == "nearest" else 1 @@ -70,15 +72,16 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne ) ) expected = np.stack(expected).astype(np.float32) + rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated good = np.sum(np.isclose(expected, rotated, atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") class TestRotate3D(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotate([angle, 0, 0], keep_size, mode, padding_mode, align_corners) - rotated = rotate_fn(self.imt[0]) + rotated = rotate_fn(im_type(self.imt[0])) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated.shape) _order = 0 if mode == "nearest" else 1 @@ -103,23 +106,25 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne ) ) expected = np.stack(expected).astype(np.float32) + rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated n_good = np.sum(np.isclose(expected, rotated, atol=1e-3)) self.assertLessEqual(expected.size - n_good, 5, "diff at most 5 pixels") @parameterized.expand(TEST_CASES_SHAPE_3D) - def test_correct_shape(self, angle, mode, padding_mode, align_corners): + def test_correct_shape(self, im_type, angle, mode, padding_mode, align_corners): rotate_fn = Rotate(angle, True, align_corners=align_corners) - rotated = rotate_fn(self.imt[0], mode=mode, padding_mode=padding_mode) + rotated = rotate_fn(im_type(self.imt[0]), mode=mode, padding_mode=padding_mode) np.testing.assert_allclose(self.imt[0].shape, rotated.shape) def test_ill_case(self): - rotate_fn = Rotate(10, True) - with self.assertRaises(ValueError): # wrong shape - rotate_fn(self.imt) - - rotate_fn = Rotate(10, keep_size=False) - with self.assertRaises(ValueError): # wrong mode - rotate_fn(self.imt[0], mode="trilinear") + for p in TEST_NDARRAYS: + rotate_fn = Rotate(10, True) + with self.assertRaises(ValueError): # wrong shape + rotate_fn(p(self.imt)) + + rotate_fn = Rotate(10, keep_size=False) + with self.assertRaises(ValueError): # wrong mode + rotate_fn(p(self.imt[0]), mode="trilinear") if __name__ == "__main__": diff --git a/tests/test_rotated.py b/tests/test_rotated.py index 2ea421101b..cd27dd5406 100644 --- a/tests/test_rotated.py +++ b/tests/test_rotated.py @@ -10,36 +10,38 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import Rotated -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D -TEST_CASES_2D = [ - (-np.pi / 6, False, "bilinear", "border", False), - (-np.pi / 4, True, "bilinear", "border", False), - (np.pi / 4.5, True, "nearest", "reflection", False), - (-np.pi, False, "nearest", "zeros", False), - (np.pi / 2, False, "bilinear", "zeros", True), -] +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, -np.pi / 6, False, "bilinear", "border", False)) + TEST_CASES_2D.append((p, -np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_2D.append((p, -np.pi, False, "nearest", "zeros", False)) + TEST_CASES_2D.append((p, np.pi / 2, False, "bilinear", "zeros", True)) -TEST_CASES_3D = [ - (-np.pi / 6, False, "bilinear", "border", False), - (-np.pi / 4, True, "bilinear", "border", False), - (np.pi / 4.5, True, "nearest", "reflection", False), - (-np.pi, False, "nearest", "zeros", False), - (np.pi / 2, False, "bilinear", "zeros", True), -] +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append((p, -np.pi / 6, False, "bilinear", "border", False)) + TEST_CASES_3D.append((p, -np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_3D.append((p, np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_3D.append((p, -np.pi, False, "nearest", "zeros", False)) + TEST_CASES_3D.append((p, np.pi / 2, False, "bilinear", "zeros", True)) class TestRotated2D(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES_2D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotated(("img", "seg"), angle, keep_size, (mode, "nearest"), padding_mode, align_corners) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) _order = 0 if mode == "nearest" else 1 @@ -52,6 +54,8 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne expected = scipy.ndimage.rotate( self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") @@ -64,9 +68,9 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne class TestRotated3D(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotated(("img", "seg"), [0, angle, 0], keep_size, (mode, "nearest"), padding_mode, align_corners) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) _order = 0 if mode == "nearest" else 1 @@ -79,6 +83,8 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne expected = scipy.ndimage.rotate( self.imt[0, 0], np.rad2deg(angle), (0, 2), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v good = np.sum(np.isclose(expected.astype(np.float32), rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 voxels.") @@ -91,9 +97,9 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne class TestRotated3DXY(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotated(("img", "seg"), [0, 0, angle], keep_size, (mode, "nearest"), padding_mode, align_corners) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) _order = 0 if mode == "nearest" else 1 @@ -106,6 +112,8 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne expected = scipy.ndimage.rotate( self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 voxels")