Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,8 @@ class Rotate90(Transform):

"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None:
"""
Args:
Expand All @@ -651,14 +653,15 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None:
raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.")
self.spatial_axes = spatial_axes_

def __call__(self, img: np.ndarray) -> np.ndarray:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Args:
img: channel first array, must have shape: (num_channels, H[, W, ..., ]),
"""

result: np.ndarray = np.rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes))
return result.astype(img.dtype)
rot90 = torch.rot90 if isinstance(img, torch.Tensor) else np.rot90
out: NdarrayOrTensor = rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes))
out, *_ = convert_data_type(out, dtype=img.dtype)
return out


class RandRotate90(RandomizableTransform):
Expand All @@ -667,6 +670,8 @@ class RandRotate90(RandomizableTransform):
in the plane specified by `spatial_axes`.
"""

backend = Rotate90.backend

def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: Tuple[int, int] = (0, 1)) -> None:
"""
Args:
Expand All @@ -686,7 +691,7 @@ def randomize(self, data: Optional[Any] = None) -> None:
self._rand_k = self.R.randint(self.max_k) + 1
super().randomize(None)

def __call__(self, img: np.ndarray) -> np.ndarray:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Args:
img: channel first array, must have shape: (num_channels, H[, W, ..., ]),
Expand Down
18 changes: 8 additions & 10 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ class Rotate90d(MapTransform, InvertibleTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`.
"""

backend = Rotate90.backend

def __init__(
self, keys: KeysCollection, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1), allow_missing_keys: bool = False
) -> None:
Expand All @@ -395,14 +397,14 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.rotator = Rotate90(k, spatial_axes)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
self.push_transform(d, key)
d[key] = self.rotator(d[key])
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
_ = self.get_most_recent_transform(d, key)
Expand All @@ -411,9 +413,6 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
num_times_rotated = self.rotator.k
num_times_to_rotate = 4 - num_times_rotated
inverse_transform = Rotate90(num_times_to_rotate, spatial_axes)
# Might need to convert to numpy
if isinstance(d[key], torch.Tensor):
d[key] = torch.Tensor(d[key]).cpu().numpy()
# Apply inverse
d[key] = inverse_transform(d[key])
# Remove the applied transform
Expand All @@ -429,6 +428,8 @@ class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform):
in the plane specified by `spatial_axes`.
"""

backend = Rotate90.backend

def __init__(
self,
keys: KeysCollection,
Expand Down Expand Up @@ -461,7 +462,7 @@ def randomize(self, data: Optional[Any] = None) -> None:
self._rand_k = self.R.randint(self.max_k) + 1
super().randomize(None)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
self.randomize()
d = dict(data)

Expand All @@ -472,7 +473,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np.
self.push_transform(d, key, extra_info={"rand_k": self._rand_k})
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
Expand All @@ -482,9 +483,6 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
num_times_rotated = transform[InverseKeys.EXTRA_INFO]["rand_k"]
num_times_to_rotate = 4 - num_times_rotated
inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes)
# Might need to convert to numpy
if isinstance(d[key], torch.Tensor):
d[key] = torch.Tensor(d[key]).cpu().numpy()
# Apply inverse
d[key] = inverse_transform(d[key])
# Remove the applied transform
Expand Down
14 changes: 12 additions & 2 deletions tests/test_inverse_collation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,12 @@
for t in [
RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1, 2]),
RandAxisFlipd(keys=KEYS, prob=0.5),
RandRotate90d(keys=KEYS, spatial_axes=(1, 2)),
Compose(
[
RandRotate90d(keys=KEYS, spatial_axes=(1, 2)),
ToTensord(keys=KEYS),
]
),
RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
RandRotated(keys=KEYS, prob=0.5, range_x=np.pi),
RandAffined(
Expand All @@ -67,7 +72,12 @@
for t in [
RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1]),
RandAxisFlipd(keys=KEYS, prob=0.5),
RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0, 1)),
Compose(
[
RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0, 1)),
ToTensord(keys=KEYS),
]
),
RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
RandRotated(keys=KEYS, prob=0.5, range_x=np.pi),
RandAffined(
Expand Down
7 changes: 5 additions & 2 deletions tests/test_pad_collation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from monai.data import CacheDataset, DataLoader
from monai.data.utils import decollate_batch, pad_list_data_collate
from monai.transforms import (
Compose,
PadListDataCollate,
RandRotate,
RandRotate90,
Expand All @@ -29,6 +30,8 @@
RandSpatialCropd,
RandZoom,
RandZoomd,
ToTensor,
ToTensord,
)
from monai.utils import set_determinism

Expand All @@ -41,12 +44,12 @@
TESTS.append((dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True)))
TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False)))
TESTS.append((dict, pad_collate, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
TESTS.append((dict, pad_collate, RandRotate90d("image", prob=1, max_k=2)))
TESTS.append((dict, pad_collate, Compose([RandRotate90d("image", prob=1, max_k=2), ToTensord("image")])))

TESTS.append((list, pad_collate, RandSpatialCrop(roi_size=[8, 7], random_size=True)))
TESTS.append((list, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False)))
TESTS.append((list, pad_collate, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
TESTS.append((list, pad_collate, RandRotate90(prob=1, max_k=2)))
TESTS.append((list, pad_collate, Compose([RandRotate90(prob=1, max_k=2), ToTensor()])))


class _Dataset(torch.utils.data.Dataset):
Expand Down
62 changes: 33 additions & 29 deletions tests/test_rand_rotate90.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,49 +14,53 @@
import numpy as np

from monai.transforms import RandRotate90
from tests.utils import NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


class TestRandRotate90(NumpyImageTestCase2D):
def test_default(self):
rotate = RandRotate90()
rotate.set_random_state(123)
rotated = rotate(self.imt[0])
expected = []
for channel in self.imt[0]:
expected.append(np.rot90(channel, 0, (0, 1)))
expected = np.stack(expected)
self.assertTrue(np.allclose(rotated, expected))
for p in TEST_NDARRAYS:
rotate.set_random_state(123)
rotated = rotate(p(self.imt[0]))
expected = []
for channel in self.imt[0]:
expected.append(np.rot90(channel, 0, (0, 1)))
expected = np.stack(expected)
assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8)

def test_k(self):
rotate = RandRotate90(max_k=2)
rotate.set_random_state(234)
rotated = rotate(self.imt[0])
expected = []
for channel in self.imt[0]:
expected.append(np.rot90(channel, 0, (0, 1)))
expected = np.stack(expected)
self.assertTrue(np.allclose(rotated, expected))
for p in TEST_NDARRAYS:
rotate.set_random_state(234)
rotated = rotate(p(self.imt[0]))
expected = []
for channel in self.imt[0]:
expected.append(np.rot90(channel, 0, (0, 1)))
expected = np.stack(expected)
assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8)

def test_spatial_axes(self):
rotate = RandRotate90(spatial_axes=(0, 1))
rotate.set_random_state(234)
rotated = rotate(self.imt[0])
expected = []
for channel in self.imt[0]:
expected.append(np.rot90(channel, 0, (0, 1)))
expected = np.stack(expected)
self.assertTrue(np.allclose(rotated, expected))
for p in TEST_NDARRAYS:
rotate.set_random_state(234)
rotated = rotate(p(self.imt[0]))
expected = []
for channel in self.imt[0]:
expected.append(np.rot90(channel, 0, (0, 1)))
expected = np.stack(expected)
assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8)

def test_prob_k_spatial_axes(self):
rotate = RandRotate90(prob=1.0, max_k=2, spatial_axes=(0, 1))
rotate.set_random_state(234)
rotated = rotate(self.imt[0])
expected = []
for channel in self.imt[0]:
expected.append(np.rot90(channel, 1, (0, 1)))
expected = np.stack(expected)
self.assertTrue(np.allclose(rotated, expected))
for p in TEST_NDARRAYS:
rotate.set_random_state(234)
rotated = rotate(p(self.imt[0]))
expected = []
for channel in self.imt[0]:
expected.append(np.rot90(channel, 1, (0, 1)))
expected = np.stack(expected)
assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8)


if __name__ == "__main__":
Expand Down
62 changes: 33 additions & 29 deletions tests/test_rand_rotate90d.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,53 +14,57 @@
import numpy as np

from monai.transforms import RandRotate90d
from tests.utils import NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


class TestRandRotate90d(NumpyImageTestCase2D):
def test_default(self):
key = None
rotate = RandRotate90d(keys=key)
rotate.set_random_state(123)
rotated = rotate({key: self.imt[0]})
expected = []
for channel in self.imt[0]:
expected.append(np.rot90(channel, 0, (0, 1)))
expected = np.stack(expected)
self.assertTrue(np.allclose(rotated[key], expected))
for p in TEST_NDARRAYS:
rotate.set_random_state(123)
rotated = rotate({key: p(self.imt[0])})
expected = []
for channel in self.imt[0]:
expected.append(np.rot90(channel, 0, (0, 1)))
expected = np.stack(expected)
assert_allclose(rotated[key], expected)

def test_k(self):
key = "test"
rotate = RandRotate90d(keys=key, max_k=2)
rotate.set_random_state(234)
rotated = rotate({key: self.imt[0]})
expected = []
for channel in self.imt[0]:
expected.append(np.rot90(channel, 0, (0, 1)))
expected = np.stack(expected)
self.assertTrue(np.allclose(rotated[key], expected))
for p in TEST_NDARRAYS:
rotate.set_random_state(234)
rotated = rotate({key: p(self.imt[0])})
expected = []
for channel in self.imt[0]:
expected.append(np.rot90(channel, 0, (0, 1)))
expected = np.stack(expected)
assert_allclose(rotated[key], expected)

def test_spatial_axes(self):
key = "test"
rotate = RandRotate90d(keys=key, spatial_axes=(0, 1))
rotate.set_random_state(234)
rotated = rotate({key: self.imt[0]})
expected = []
for channel in self.imt[0]:
expected.append(np.rot90(channel, 0, (0, 1)))
expected = np.stack(expected)
self.assertTrue(np.allclose(rotated[key], expected))
for p in TEST_NDARRAYS:
rotate.set_random_state(234)
rotated = rotate({key: p(self.imt[0])})
expected = []
for channel in self.imt[0]:
expected.append(np.rot90(channel, 0, (0, 1)))
expected = np.stack(expected)
assert_allclose(rotated[key], expected)

def test_prob_k_spatial_axes(self):
key = "test"
rotate = RandRotate90d(keys=key, prob=1.0, max_k=2, spatial_axes=(0, 1))
rotate.set_random_state(234)
rotated = rotate({key: self.imt[0]})
expected = []
for channel in self.imt[0]:
expected.append(np.rot90(channel, 1, (0, 1)))
expected = np.stack(expected)
self.assertTrue(np.allclose(rotated[key], expected))
for p in TEST_NDARRAYS:
rotate.set_random_state(234)
rotated = rotate({key: p(self.imt[0])})
expected = []
for channel in self.imt[0]:
expected.append(np.rot90(channel, 1, (0, 1)))
expected = np.stack(expected)
assert_allclose(rotated[key], expected)

def test_no_key(self):
key = "unknown"
Expand Down
Loading