diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 5979b996ee..816e9d58f2 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -637,6 +637,8 @@ class Rotate90(Transform): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: """ Args: @@ -651,14 +653,15 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.") self.spatial_axes = spatial_axes_ - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - - result: np.ndarray = np.rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes)) - return result.astype(img.dtype) + rot90 = torch.rot90 if isinstance(img, torch.Tensor) else np.rot90 + out: NdarrayOrTensor = rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes)) + out, *_ = convert_data_type(out, dtype=img.dtype) + return out class RandRotate90(RandomizableTransform): @@ -667,6 +670,8 @@ class RandRotate90(RandomizableTransform): in the plane specified by `spatial_axes`. """ + backend = Rotate90.backend + def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: Tuple[int, int] = (0, 1)) -> None: """ Args: @@ -686,7 +691,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self._rand_k = self.R.randint(self.max_k) + 1 super().randomize(None) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 22dafb3f80..c09d8e8011 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -382,6 +382,8 @@ class Rotate90d(MapTransform, InvertibleTransform): Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`. """ + backend = Rotate90.backend + def __init__( self, keys: KeysCollection, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1), allow_missing_keys: bool = False ) -> None: @@ -395,14 +397,14 @@ def __init__( super().__init__(keys, allow_missing_keys) self.rotator = Rotate90(k, spatial_axes) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) d[key] = self.rotator(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): _ = self.get_most_recent_transform(d, key) @@ -411,9 +413,6 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar num_times_rotated = self.rotator.k num_times_to_rotate = 4 - num_times_rotated inverse_transform = Rotate90(num_times_to_rotate, spatial_axes) - # Might need to convert to numpy - if isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]).cpu().numpy() # Apply inverse d[key] = inverse_transform(d[key]) # Remove the applied transform @@ -429,6 +428,8 @@ class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform): in the plane specified by `spatial_axes`. """ + backend = Rotate90.backend + def __init__( self, keys: KeysCollection, @@ -461,7 +462,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self._rand_k = self.R.randint(self.max_k) + 1 super().randomize(None) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: self.randomize() d = dict(data) @@ -472,7 +473,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. self.push_transform(d, key, extra_info={"rand_k": self._rand_k}) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -482,9 +483,6 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar num_times_rotated = transform[InverseKeys.EXTRA_INFO]["rand_k"] num_times_to_rotate = 4 - num_times_rotated inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes) - # Might need to convert to numpy - if isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]).cpu().numpy() # Apply inverse d[key] = inverse_transform(d[key]) # Remove the applied transform diff --git a/tests/test_inverse_collation.py b/tests/test_inverse_collation.py index c302e04017..c5dd9f1210 100644 --- a/tests/test_inverse_collation.py +++ b/tests/test_inverse_collation.py @@ -48,7 +48,12 @@ for t in [ RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1, 2]), RandAxisFlipd(keys=KEYS, prob=0.5), - RandRotate90d(keys=KEYS, spatial_axes=(1, 2)), + Compose( + [ + RandRotate90d(keys=KEYS, spatial_axes=(1, 2)), + ToTensord(keys=KEYS), + ] + ), RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), RandRotated(keys=KEYS, prob=0.5, range_x=np.pi), RandAffined( @@ -67,7 +72,12 @@ for t in [ RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1]), RandAxisFlipd(keys=KEYS, prob=0.5), - RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0, 1)), + Compose( + [ + RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0, 1)), + ToTensord(keys=KEYS), + ] + ), RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), RandRotated(keys=KEYS, prob=0.5, range_x=np.pi), RandAffined( diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py index a8c544558f..47bfa69582 100644 --- a/tests/test_pad_collation.py +++ b/tests/test_pad_collation.py @@ -20,6 +20,7 @@ from monai.data import CacheDataset, DataLoader from monai.data.utils import decollate_batch, pad_list_data_collate from monai.transforms import ( + Compose, PadListDataCollate, RandRotate, RandRotate90, @@ -29,6 +30,8 @@ RandSpatialCropd, RandZoom, RandZoomd, + ToTensor, + ToTensord, ) from monai.utils import set_determinism @@ -41,12 +44,12 @@ TESTS.append((dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True))) TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False))) TESTS.append((dict, pad_collate, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) - TESTS.append((dict, pad_collate, RandRotate90d("image", prob=1, max_k=2))) + TESTS.append((dict, pad_collate, Compose([RandRotate90d("image", prob=1, max_k=2), ToTensord("image")]))) TESTS.append((list, pad_collate, RandSpatialCrop(roi_size=[8, 7], random_size=True))) TESTS.append((list, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False))) TESTS.append((list, pad_collate, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) - TESTS.append((list, pad_collate, RandRotate90(prob=1, max_k=2))) + TESTS.append((list, pad_collate, Compose([RandRotate90(prob=1, max_k=2), ToTensor()]))) class _Dataset(torch.utils.data.Dataset): diff --git a/tests/test_rand_rotate90.py b/tests/test_rand_rotate90.py index 50a1b28e53..f339158f94 100644 --- a/tests/test_rand_rotate90.py +++ b/tests/test_rand_rotate90.py @@ -14,49 +14,53 @@ import numpy as np from monai.transforms import RandRotate90 -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestRandRotate90(NumpyImageTestCase2D): def test_default(self): rotate = RandRotate90() - rotate.set_random_state(123) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(123) + rotated = rotate(p(self.imt[0])) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) + assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) def test_k(self): rotate = RandRotate90(max_k=2) - rotate.set_random_state(234) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(234) + rotated = rotate(p(self.imt[0])) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) + assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) def test_spatial_axes(self): rotate = RandRotate90(spatial_axes=(0, 1)) - rotate.set_random_state(234) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(234) + rotated = rotate(p(self.imt[0])) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) + assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) def test_prob_k_spatial_axes(self): rotate = RandRotate90(prob=1.0, max_k=2, spatial_axes=(0, 1)) - rotate.set_random_state(234) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(234) + rotated = rotate(p(self.imt[0])) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 1, (0, 1))) + expected = np.stack(expected) + assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) if __name__ == "__main__": diff --git a/tests/test_rand_rotate90d.py b/tests/test_rand_rotate90d.py index a487b695f5..f9083afb0c 100644 --- a/tests/test_rand_rotate90d.py +++ b/tests/test_rand_rotate90d.py @@ -14,53 +14,57 @@ import numpy as np from monai.transforms import RandRotate90d -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestRandRotate90d(NumpyImageTestCase2D): def test_default(self): key = None rotate = RandRotate90d(keys=key) - rotate.set_random_state(123) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(123) + rotated = rotate({key: p(self.imt[0])}) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) + assert_allclose(rotated[key], expected) def test_k(self): key = "test" rotate = RandRotate90d(keys=key, max_k=2) - rotate.set_random_state(234) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(234) + rotated = rotate({key: p(self.imt[0])}) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) + assert_allclose(rotated[key], expected) def test_spatial_axes(self): key = "test" rotate = RandRotate90d(keys=key, spatial_axes=(0, 1)) - rotate.set_random_state(234) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(234) + rotated = rotate({key: p(self.imt[0])}) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) + assert_allclose(rotated[key], expected) def test_prob_k_spatial_axes(self): key = "test" rotate = RandRotate90d(keys=key, prob=1.0, max_k=2, spatial_axes=(0, 1)) - rotate.set_random_state(234) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotate.set_random_state(234) + rotated = rotate({key: p(self.imt[0])}) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 1, (0, 1))) + expected = np.stack(expected) + assert_allclose(rotated[key], expected) def test_no_key(self): key = "unknown" diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py index 4ab39d5cf6..03a967a16b 100644 --- a/tests/test_rotate90.py +++ b/tests/test_rotate90.py @@ -14,45 +14,49 @@ import numpy as np from monai.transforms import Rotate90 -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestRotate90(NumpyImageTestCase2D): def test_rotate90_default(self): rotate = Rotate90() - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotated = rotate(p(self.imt[0])) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 1, (0, 1))) + expected = np.stack(expected) + assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) def test_k(self): rotate = Rotate90(k=2) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 2, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotated = rotate(p(self.imt[0])) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 2, (0, 1))) + expected = np.stack(expected) + assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) def test_spatial_axes(self): rotate = Rotate90(spatial_axes=(0, -1)) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, -1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotated = rotate(p(self.imt[0])) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 1, (0, -1))) + expected = np.stack(expected) + assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) def test_prob_k_spatial_axes(self): rotate = Rotate90(k=2, spatial_axes=(0, 1)) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 2, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in TEST_NDARRAYS: + rotated = rotate(p(self.imt[0])) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 2, (0, 1))) + expected = np.stack(expected) + assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) if __name__ == "__main__": diff --git a/tests/test_rotate90d.py b/tests/test_rotate90d.py index 3d71ead82a..a1fa3c977c 100644 --- a/tests/test_rotate90d.py +++ b/tests/test_rotate90d.py @@ -14,49 +14,53 @@ import numpy as np from monai.transforms import Rotate90d -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestRotate90d(NumpyImageTestCase2D): def test_rotate90_default(self): key = "test" rotate = Rotate90d(keys=key) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotated = rotate({key: p(self.imt[0])}) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 1, (0, 1))) + expected = np.stack(expected) + assert_allclose(rotated[key], expected) def test_k(self): key = None rotate = Rotate90d(keys=key, k=2) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 2, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotated = rotate({key: p(self.imt[0])}) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 2, (0, 1))) + expected = np.stack(expected) + assert_allclose(rotated[key], expected) def test_spatial_axes(self): key = "test" rotate = Rotate90d(keys=key, spatial_axes=(0, 1)) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotated = rotate({key: p(self.imt[0])}) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 1, (0, 1))) + expected = np.stack(expected) + assert_allclose(rotated[key], expected) def test_prob_k_spatial_axes(self): key = "test" rotate = Rotate90d(keys=key, k=2, spatial_axes=(0, 1)) - rotated = rotate({key: self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 2, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated[key], expected)) + for p in TEST_NDARRAYS: + rotated = rotate({key: p(self.imt[0])}) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 2, (0, 1))) + expected = np.stack(expected) + assert_allclose(rotated[key], expected) def test_no_key(self): key = "unknown"