diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index ebd21a1c45..4dc7744755 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -338,6 +338,7 @@ in_bounds, is_empty, map_binary_to_indices, + map_spatial_axes, rand_choice, rescale_array, rescale_array_int_max, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index f0cf047aa6..df10480188 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -31,6 +31,7 @@ create_scale, create_shear, create_translate, + map_spatial_axes, ) from monai.utils import ( GridSampleMode, @@ -297,7 +298,7 @@ class Flip(Transform): """ - def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]]) -> None: + def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: self.spatial_axis = spatial_axis def __call__(self, img: np.ndarray) -> np.ndarray: @@ -305,10 +306,9 @@ def __call__(self, img: np.ndarray) -> np.ndarray: Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - flipped = [] - for channel in img: - flipped.append(np.flip(channel, self.spatial_axis)) - return np.stack(flipped).astype(img.dtype) + + result: np.ndarray = np.flip(img, map_spatial_axes(img.ndim, self.spatial_axis)) + return result.astype(img.dtype) class Resize(Transform): @@ -596,10 +596,9 @@ def __call__(self, img: np.ndarray) -> np.ndarray: Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - rotated = [] - for channel in img: - rotated.append(np.rot90(channel, self.k, self.spatial_axes)) - return np.stack(rotated).astype(img.dtype) + + result: np.ndarray = np.rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes)) + return result.astype(img.dtype) class RandRotate90(Randomizable, Transform): diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e5e9f81cc6..9a84eb00d9 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -48,6 +48,7 @@ "get_largest_connected_component_mask", "get_extreme_points", "extreme_points_to_image", + "map_spatial_axes", ] @@ -690,3 +691,42 @@ def extreme_points_to_image( points_image = (points_image - min_intensity) / (max_intensity - min_intensity) points_image = points_image * (rescale_max - rescale_min) + rescale_min return points_image + + +def map_spatial_axes( + img_ndim: int, + spatial_axes: Optional[Union[Sequence[int], int]] = None, + channel_first: bool = True, +) -> List[int]: + """ + Utility to map the spatial axes to real axes in channel first/last shape. + For example: + If `channel_first` is True, and `img` has 3 spatial dims, map spatial axes to real axes as below: + None -> [1, 2, 3] + [0, 1] -> [1, 2] + [0, -1] -> [1, -1] + If `channel_first` is False, and `img` has 3 spatial dims, map spatial axes to real axes as below: + None -> [0, 1, 2] + [0, 1] -> [0, 1] + [0, -1] -> [0, -2] + + Args: + img_ndim: dimension number of the target image. + spatial_axes: spatial axes to be converted, default is None. + The default `None` will convert to all the spatial axes of the image. + If axis is negative it counts from the last to the first axis. + If axis is a tuple of ints. + channel_first: the image data is channel first or channel last, defaut to channel first. + + """ + if spatial_axes is None: + spatial_axes_ = list(range(1, img_ndim) if channel_first else range(0, img_ndim - 1)) + else: + spatial_axes_ = [] + for a in ensure_tuple(spatial_axes): + if channel_first: + spatial_axes_.append(a if a < 0 else a + 1) + else: + spatial_axes_.append(a - 1 if a < 0 else a) + + return spatial_axes_ diff --git a/tests/test_flip.py b/tests/test_flip.py index 7a2af02585..fe169c4da8 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -19,7 +19,7 @@ INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] -VALID_CASES = [("no_axis", None), ("one_axis", 1), ("many_axis", [0, 1])] +VALID_CASES = [("no_axis", None), ("one_axis", 1), ("many_axis", [0, 1]), ("negative_axis", [0, -1])] class TestFlip(NumpyImageTestCase2D):