Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@
in_bounds,
is_empty,
map_binary_to_indices,
map_spatial_axes,
rand_choice,
rescale_array,
rescale_array_int_max,
Expand Down
17 changes: 8 additions & 9 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
create_scale,
create_shear,
create_translate,
map_spatial_axes,
)
from monai.utils import (
GridSampleMode,
Expand Down Expand Up @@ -297,18 +298,17 @@ class Flip(Transform):

"""

def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]]) -> None:
def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None:
self.spatial_axis = spatial_axis

def __call__(self, img: np.ndarray) -> np.ndarray:
"""
Args:
img: channel first array, must have shape: (num_channels, H[, W, ..., ]),
"""
flipped = []
for channel in img:
flipped.append(np.flip(channel, self.spatial_axis))
return np.stack(flipped).astype(img.dtype)

result: np.ndarray = np.flip(img, map_spatial_axes(img.ndim, self.spatial_axis))
return result.astype(img.dtype)


class Resize(Transform):
Expand Down Expand Up @@ -596,10 +596,9 @@ def __call__(self, img: np.ndarray) -> np.ndarray:
Args:
img: channel first array, must have shape: (num_channels, H[, W, ..., ]),
"""
rotated = []
for channel in img:
rotated.append(np.rot90(channel, self.k, self.spatial_axes))
return np.stack(rotated).astype(img.dtype)

result: np.ndarray = np.rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes))
return result.astype(img.dtype)


class RandRotate90(Randomizable, Transform):
Expand Down
40 changes: 40 additions & 0 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
"get_largest_connected_component_mask",
"get_extreme_points",
"extreme_points_to_image",
"map_spatial_axes",
]


Expand Down Expand Up @@ -690,3 +691,42 @@ def extreme_points_to_image(
points_image = (points_image - min_intensity) / (max_intensity - min_intensity)
points_image = points_image * (rescale_max - rescale_min) + rescale_min
return points_image


def map_spatial_axes(
img_ndim: int,
spatial_axes: Optional[Union[Sequence[int], int]] = None,
channel_first: bool = True,
) -> List[int]:
"""
Utility to map the spatial axes to real axes in channel first/last shape.
For example:
If `channel_first` is True, and `img` has 3 spatial dims, map spatial axes to real axes as below:
None -> [1, 2, 3]
[0, 1] -> [1, 2]
[0, -1] -> [1, -1]
If `channel_first` is False, and `img` has 3 spatial dims, map spatial axes to real axes as below:
None -> [0, 1, 2]
[0, 1] -> [0, 1]
[0, -1] -> [0, -2]

Args:
img_ndim: dimension number of the target image.
spatial_axes: spatial axes to be converted, default is None.
The default `None` will convert to all the spatial axes of the image.
If axis is negative it counts from the last to the first axis.
If axis is a tuple of ints.
channel_first: the image data is channel first or channel last, defaut to channel first.

"""
if spatial_axes is None:
spatial_axes_ = list(range(1, img_ndim) if channel_first else range(0, img_ndim - 1))
else:
spatial_axes_ = []
for a in ensure_tuple(spatial_axes):
if channel_first:
spatial_axes_.append(a if a < 0 else a + 1)
else:
spatial_axes_.append(a - 1 if a < 0 else a)

return spatial_axes_
2 changes: 1 addition & 1 deletion tests/test_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)]

VALID_CASES = [("no_axis", None), ("one_axis", 1), ("many_axis", [0, 1])]
VALID_CASES = [("no_axis", None), ("one_axis", 1), ("many_axis", [0, 1]), ("negative_axis", [0, -1])]


class TestFlip(NumpyImageTestCase2D):
Expand Down