43 changes: 28 additions & 15 deletions monai/transforms/spatial/array.py
@@ -441,14 +441,16 @@ class Rotate(Transform, ThreadUnsafe):
the output data type is always ``np.float32``.
"""

backend = [TransformBackends.TORCH]

def __init__(
self,
angle: Union[Sequence[float], float],
keep_size: bool = True,
mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
align_corners: bool = False,
dtype: DtypeLike = np.float64,
dtype: Union[DtypeLike, torch.dtype] = np.float64,
) -> None:
self.angle = angle
self.keep_size = keep_size
@@ -460,12 +462,12 @@ def __init__(

def __call__(
self,
img: np.ndarray,
img: NdarrayOrTensor,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
align_corners: Optional[bool] = None,
dtype: DtypeLike = None,
) -> np.ndarray:
dtype: Union[DtypeLike, torch.dtype] = None,
) -> torch.Tensor:
"""
Args:
img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D].
@@ -488,7 +490,11 @@ def __call__(

"""
_dtype = dtype or self.dtype or img.dtype
im_shape = np.asarray(img.shape[1:]) # spatial dimensions

img_t: torch.Tensor
img_t, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype) # type: ignore

im_shape = np.asarray(img_t.shape[1:]) # spatial dimensions
input_ndim = len(im_shape)
if input_ndim not in (2, 3):
raise ValueError(f"Unsupported img dimension: {input_ndim}, available options are [2, 3].")
@@ -506,20 +512,23 @@ def __call__(
shift_1 = create_translate(input_ndim, (-(output_shape - 1) / 2).tolist())
transform = shift @ transform @ shift_1

transform_t: torch.Tensor
transform_t, *_ = convert_to_dst_type(transform, img_t) # type: ignore

xform = AffineTransform(
normalized=False,
mode=look_up_option(mode or self.mode, GridSampleMode),
padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode),
align_corners=self.align_corners if align_corners is None else align_corners,
reverse_indexing=True,
)
output = xform(
torch.as_tensor(np.ascontiguousarray(img).astype(_dtype)).unsqueeze(0),
torch.as_tensor(np.ascontiguousarray(transform).astype(_dtype)),
output: torch.Tensor = xform(
img_t.unsqueeze(0),
transform_t,
spatial_size=output_shape,
)
self._rotation_matrix = transform
return np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32)
return output.squeeze(0).detach().float()

def get_rotation_matrix(self) -> Optional[np.ndarray]:
"""
@@ -738,6 +747,8 @@ class RandRotate(RandomizableTransform):
the output data type is always ``np.float32``.
"""

backend = Rotate.backend

def __init__(
self,
range_x: Union[Tuple[float, float], float] = 0.0,
@@ -748,7 +759,7 @@ def __init__(
mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
align_corners: bool = False,
dtype: DtypeLike = np.float64,
dtype: Union[DtypeLike, torch.dtype] = np.float64,
) -> None:
RandomizableTransform.__init__(self, prob)
self.range_x = ensure_tuple(range_x)
@@ -779,12 +790,12 @@ def randomize(self, data: Optional[Any] = None) -> None:

def __call__(
self,
img: np.ndarray,
img: NdarrayOrTensor,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
align_corners: Optional[bool] = None,
dtype: DtypeLike = None,
) -> np.ndarray:
dtype: Union[DtypeLike, torch.dtype] = None,
) -> torch.Tensor:
"""
Args:
img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D).
@@ -802,7 +813,9 @@ def __call__(
"""
self.randomize()
if not self._do_transform:
return img
img_t: torch.Tensor
img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore
return img_t
rotator = Rotate(
angle=self.x if img.ndim == 3 else (self.x, self.y, self.z),
keep_size=self.keep_size,
@@ -811,7 +824,7 @@ def __call__(
align_corners=self.align_corners if align_corners is None else align_corners,
dtype=dtype or self.dtype or img.dtype,
)
return np.array(rotator(img))
return rotator(img)


class RandFlip(RandomizableTransform):
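Taken together, the array.py changes make Rotate and RandRotate accept either a numpy array or a torch tensor and always return a torch.Tensor: the input is converted with convert_data_type, the affine matrix is moved onto the image's type/device with convert_to_dst_type, and the AffineTransform output is returned as a float tensor. A minimal usage sketch of the new behaviour (illustrative values, not part of the diff):

    import numpy as np
    import torch
    from monai.transforms import RandRotate, Rotate

    img_np = np.random.rand(1, 64, 64).astype(np.float32)  # channel-first 2D image
    img_t = torch.as_tensor(img_np)                         # same data as a tensor

    rot = Rotate(angle=np.pi / 4, keep_size=True, mode="bilinear")
    out_np_input = rot(img_np)  # numpy input is converted internally
    out_t_input = rot(img_t)    # tensor input is used directly

    # both paths now return torch tensors
    assert isinstance(out_np_input, torch.Tensor)
    assert isinstance(out_t_input, torch.Tensor)

    # RandRotate follows the same contract; even when the random draw skips the
    # rotation, the (unrotated) image is still returned as a tensor
    rand_rot = RandRotate(range_x=np.pi / 4, prob=0.0)
    assert isinstance(rand_rot(img_np), torch.Tensor)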
40 changes: 28 additions & 12 deletions monai/transforms/spatial/dictionary.py
@@ -57,6 +57,7 @@
)
from monai.utils.enums import InverseKeys
from monai.utils.module import optional_import
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type

nib, _ = optional_import("nibabel")
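
The two helpers imported in this hunk do the type juggling for the rest of the file: convert_data_type coerces numpy data or tensors into a requested container and dtype, while convert_to_dst_type moves a second object (here the rotation matrix) onto the same type, device and dtype as a reference tensor. A rough sketch of how they are used in this diff; the exact return tuple is treated as opaque, since it can vary slightly across MONAI versions:

    import numpy as np
    import torch
    from monai.utils.type_conversion import convert_data_type, convert_to_dst_type

    arr = np.ones((1, 4, 4), dtype=np.float32)

    # first element of the returned tuple is the converted data; the remainder is
    # bookkeeping about the original type/device, which this diff discards with *_
    img_t, *_ = convert_data_type(arr, torch.Tensor, dtype=torch.float32)

    affine = np.eye(3)
    affine_t, *_ = convert_to_dst_type(affine, img_t)  # now a tensor matching img_t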

@@ -1287,6 +1288,8 @@ class Rotated(MapTransform, InvertibleTransform):
allow_missing_keys: don't raise exception if key is missing.
"""

backend = Rotate.backend

def __init__(
self,
keys: KeysCollection,
@@ -1295,7 +1298,7 @@ def __init__(
mode: GridSampleModeSequence = GridSampleMode.BILINEAR,
padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER,
align_corners: Union[Sequence[bool], bool] = False,
dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64,
dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], Union[DtypeLike, torch.dtype]] = np.float64,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
@@ -1306,7 +1309,7 @@ def __init__(
self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
self.dtype = ensure_tuple_rep(dtype, len(self.keys))

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key, mode, padding_mode, align_corners, dtype in self.key_iterator(
d, self.mode, self.padding_mode, self.align_corners, self.dtype
@@ -1333,7 +1336,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
)
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = deepcopy(dict(data))
for key, dtype in self.key_iterator(d, self.dtype):
transform = self.get_most_recent_transform(d, key)
@@ -1351,12 +1354,17 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
align_corners=False if align_corners == "none" else align_corners,
reverse_indexing=True,
)
img_t: torch.Tensor
img_t, *_ = convert_data_type(d[key], torch.Tensor, dtype=dtype) # type: ignore
transform_t: torch.Tensor
transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) # type: ignore

output = xform(
torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0),
torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)),
img_t.unsqueeze(0),
transform_t,
spatial_size=transform[InverseKeys.ORIG_SIZE],
)
d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32)
d[key] = output.squeeze(0).detach().float()
# Remove the applied transform
self.pop_transform(d, key)

@@ -1398,6 +1406,8 @@ class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform):
allow_missing_keys: don't raise exception if key is missing.
"""

backend = Rotate.backend

def __init__(
self,
keys: KeysCollection,
@@ -1409,7 +1419,7 @@ def __init__(
mode: GridSampleModeSequence = GridSampleMode.BILINEAR,
padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER,
align_corners: Union[Sequence[bool], bool] = False,
dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64,
dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], Union[DtypeLike, torch.dtype]] = np.float64,
allow_missing_keys: bool = False,
) -> None:
MapTransform.__init__(self, keys, allow_missing_keys)
@@ -1440,7 +1450,7 @@ def randomize(self, data: Optional[Any] = None) -> None:
self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1])
self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1])

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
self.randomize()
d = dict(data)
angle: Union[Sequence[float], float] = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z)
@@ -1462,6 +1472,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
)
rot_mat = rotator.get_rotation_matrix()
else:
d[key], *_ = convert_data_type(d[key], torch.Tensor)
rot_mat = np.eye(d[key].ndim)
self.push_transform(
d,
@@ -1476,7 +1487,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
)
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = deepcopy(dict(data))
for key, dtype in self.key_iterator(d, self.dtype):
transform = self.get_most_recent_transform(d, key)
@@ -1496,12 +1507,17 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
align_corners=False if align_corners == "none" else align_corners,
reverse_indexing=True,
)
img_t: torch.Tensor
img_t, *_ = convert_data_type(d[key], torch.Tensor, dtype=dtype) # type: ignore
transform_t: torch.Tensor
transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) # type: ignore
output: torch.Tensor
output = xform(
torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0),
torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)),
img_t.unsqueeze(0),
transform_t,
spatial_size=transform[InverseKeys.ORIG_SIZE],
)
d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32)
d[key] = output.squeeze(0).detach().float()
# Remove the applied transform
self.pop_transform(d, key)

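The dictionary-based Rotated and RandRotated mirror the array changes: values in the data dict may arrive as numpy arrays or tensors, the forward call delegates to Rotate, and inverse() rebuilds the inverse rotation matrix from the recorded transform and applies it through AffineTransform on tensors. A small sketch of the round trip, assuming a single "image" key (values are illustrative, not taken from the diff):

    import numpy as np
    import torch
    from monai.transforms import Rotated

    data = {"image": torch.rand(1, 48, 48, 48)}  # channel-first 3D volume
    rot_d = Rotated(keys="image", angle=(0.0, 0.0, np.pi / 2), mode="bilinear")

    rotated = rot_d(data)              # forward call records the transform for inversion
    restored = rot_d.inverse(rotated)  # inverse path now also runs through torch
    print(type(restored["image"]), restored["image"].shape)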
91 changes: 51 additions & 40 deletions tests/test_rand_rotate.py
@@ -10,25 +10,60 @@
# limitations under the License.

import unittest
from typing import List, Tuple

import numpy as np
import scipy.ndimage
import torch
from parameterized import parameterized

from monai.transforms import RandRotate
from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D

TEST_CASES_2D: List[Tuple] = []
for p in TEST_NDARRAYS:
TEST_CASES_2D.append((p, np.pi / 2, True, "bilinear", "border", False))
TEST_CASES_2D.append((p, np.pi / 4, True, "nearest", "border", False))
TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", True))
TEST_CASES_2D.append((p, (-np.pi / 4, 0), False, "nearest", "zeros", True))

class TestRandRotate2D(NumpyImageTestCase2D):
@parameterized.expand(
[
(np.pi / 2, True, "bilinear", "border", False),
(np.pi / 4, True, "nearest", "border", False),
(np.pi, False, "nearest", "zeros", True),
((-np.pi / 4, 0), False, "nearest", "zeros", True),
]
TEST_CASES_3D: List[Tuple] = []
for p in TEST_NDARRAYS:
TEST_CASES_3D.append(
(p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109))
)
TEST_CASES_3D.append(
(
p,
np.pi / 4,
(-np.pi / 9, np.pi / 4.5),
(np.pi / 9, np.pi / 6),
False,
"nearest",
"border",
True,
(1, 89, 105, 104),
)
)
TEST_CASES_3D.append(
(
p,
0.0,
(2 * np.pi, 2.06 * np.pi),
(-np.pi / 180, np.pi / 180),
True,
"nearest",
"zeros",
True,
(1, 48, 64, 80),
)
)
def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_corners):
TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90)))


class TestRandRotate2D(NumpyImageTestCase2D):
@parameterized.expand(TEST_CASES_2D)
def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners):
rotate_fn = RandRotate(
range_x=degrees,
prob=1.0,
@@ -38,7 +73,7 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor
align_corners=align_corners,
)
rotate_fn.set_random_state(243)
rotated = rotate_fn(self.imt[0])
rotated = rotate_fn(im_type(self.imt[0]))

_order = 0 if mode == "nearest" else 1
if mode == "border":
@@ -52,38 +87,14 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor
self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False
)
expected = np.stack(expected).astype(np.float32)
rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated
good = np.sum(np.isclose(expected, rotated[0], atol=1e-3))
self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels")


class TestRandRotate3D(NumpyImageTestCase3D):
@parameterized.expand(
[
(np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)),
(
np.pi / 4,
(-np.pi / 9, np.pi / 4.5),
(np.pi / 9, np.pi / 6),
False,
"nearest",
"border",
True,
(1, 89, 105, 104),
),
(
0.0,
(2 * np.pi, 2.06 * np.pi),
(-np.pi / 180, np.pi / 180),
True,
"nearest",
"zeros",
True,
(1, 48, 64, 80),
),
((-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90)),
]
)
def test_correct_results(self, x, y, z, keep_size, mode, padding_mode, align_corners, expected):
@parameterized.expand(TEST_CASES_3D)
def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected):
rotate_fn = RandRotate(
range_x=x,
range_y=y,
@@ -95,8 +106,8 @@ def test_correct_results(self, x, y, z, keep_size, mode, padding_mode, align_cor
align_corners=align_corners,
)
rotate_fn.set_random_state(243)
rotated = rotate_fn(self.imt[0])
np.testing.assert_allclose(rotated.shape, expected)
rotated = rotate_fn(im_type(self.imt[0]))
torch.testing.assert_allclose(rotated.shape, expected, rtol=1e-7, atol=0)


if __name__ == "__main__":
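The test refactor moves the hard-coded parameter lists into module-level TEST_CASES_2D / TEST_CASES_3D lists, with each case duplicated for every container constructor in TEST_NDARRAYS, so the same expected results are checked for numpy and torch inputs. The pattern in miniature (TEST_NDARRAYS is a simplified stand-in here; the real tuple in tests/utils may also include CUDA tensors):

    import unittest

    import numpy as np
    import torch
    from parameterized import parameterized

    TEST_NDARRAYS = (np.array, torch.as_tensor)  # simplified stand-in for tests.utils.TEST_NDARRAYS

    CASES = [(p, np.pi / 2, True) for p in TEST_NDARRAYS]

    class TestSketch(unittest.TestCase):
        @parameterized.expand(CASES)
        def test_case(self, im_type, angle, keep_size):
            img = im_type(np.zeros((1, 8, 8), dtype=np.float32))  # same data, different container
            self.assertEqual(tuple(img.shape), (1, 8, 8))

    if __name__ == "__main__":
        unittest.main()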