From 822bf35d41edcd569f598549dabd430472c3dcca Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 8 Feb 2021 17:38:10 +0800 Subject: [PATCH 1/4] [DLMED] Enhance doc-string of transforms Signed-off-by: Nic Ma --- monai/transforms/spatial/array.py | 13 +++++++++++-- tests/test_rotate90.py | 4 ++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 75a25459e8..f0cf047aa6 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -285,11 +285,16 @@ def __call__( class Flip(Transform): """ Reverses the order of elements along the given spatial axis. Preserves shape. - Uses ``np.flip`` in practice. See numpy.flip for additional details. - https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html + Uses ``np.flip`` in practice. See numpy.flip for additional details: + https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html. Args: spatial_axis: spatial axes along which to flip over. Default is None. + The default `axis=None` will flip over all of the axes of the input array. + If axis is negative it counts from the last to the first axis. + If axis is a tuple of ints, flipping is performed on all of the axes + specified in the tuple. + """ def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]]) -> None: @@ -567,6 +572,9 @@ def __call__( class Rotate90(Transform): """ Rotate an array by 90 degrees in the plane specified by `axes`. + See np.rot90 for additional details: + https://numpy.org/doc/stable/reference/generated/numpy.rot90.html. + """ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: @@ -575,6 +583,7 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: k: number of times to rotate by 90 degrees. spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. + If axis is negative it counts from the last to the first axis. """ self.k = k spatial_axes_ = ensure_tuple(spatial_axes) diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py index a8b4e3f57c..4ab39d5cf6 100644 --- a/tests/test_rotate90.py +++ b/tests/test_rotate90.py @@ -37,11 +37,11 @@ def test_k(self): self.assertTrue(np.allclose(rotated, expected)) def test_spatial_axes(self): - rotate = Rotate90(spatial_axes=(0, 1)) + rotate = Rotate90(spatial_axes=(0, -1)) rotated = rotate(self.imt[0]) expected = [] for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) + expected.append(np.rot90(channel, 1, (0, -1))) expected = np.stack(expected) self.assertTrue(np.allclose(rotated, expected)) From a53e5b3ff588bef424b648dd9f01ef2fe5b44b04 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 9 Feb 2021 18:42:14 +0800 Subject: [PATCH 2/4] [DLMED] optimize the numpy computation Signed-off-by: Nic Ma --- monai/transforms/__init__.py | 1 + monai/transforms/spatial/array.py | 15 +++++------- monai/transforms/utils.py | 40 +++++++++++++++++++++++++++++++ tests/test_flip.py | 2 +- 4 files changed, 48 insertions(+), 10 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index ebd21a1c45..4dc7744755 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -338,6 +338,7 @@ in_bounds, is_empty, map_binary_to_indices, + map_spatial_axes, rand_choice, rescale_array, rescale_array_int_max, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index f0cf047aa6..9cbe1536d0 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -31,6 +31,7 @@ create_scale, create_shear, create_translate, + map_spatial_axes, ) from monai.utils import ( GridSampleMode, @@ -297,7 +298,7 @@ class Flip(Transform): """ - def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]]) -> None: + def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: self.spatial_axis = spatial_axis def __call__(self, img: np.ndarray) -> np.ndarray: @@ -305,10 +306,8 @@ def __call__(self, img: np.ndarray) -> np.ndarray: Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - flipped = [] - for channel in img: - flipped.append(np.flip(channel, self.spatial_axis)) - return np.stack(flipped).astype(img.dtype) + + return np.flip(img, map_spatial_axes(img, self.spatial_axis)).astype(img.dtype) class Resize(Transform): @@ -596,10 +595,8 @@ def __call__(self, img: np.ndarray) -> np.ndarray: Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - rotated = [] - for channel in img: - rotated.append(np.rot90(channel, self.k, self.spatial_axes)) - return np.stack(rotated).astype(img.dtype) + + return np.rot90(img, self.k, map_spatial_axes(img, self.spatial_axes)).astype(img.dtype) class RandRotate90(Randomizable, Transform): diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e5e9f81cc6..287ebf7382 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -48,6 +48,7 @@ "get_largest_connected_component_mask", "get_extreme_points", "extreme_points_to_image", + "map_spatial_axes", ] @@ -690,3 +691,42 @@ def extreme_points_to_image( points_image = (points_image - min_intensity) / (max_intensity - min_intensity) points_image = points_image * (rescale_max - rescale_min) + rescale_min return points_image + + +def map_spatial_axes( + img: np.ndarray, + spatial_axes: Optional[Union[Sequence[int], int]] = None, + channel_first: bool = True, +) -> List[int]: + """ + Utility to map the spatial axes to real axes in channel first/last shape. + For example: + If `channel_first` is True, and `img` has 3 spatial dims, map spatial axes to real axes as below: + None -> [1, 2, 3] + [0, 1] -> [1, 2] + [0, -1] -> [1, -1] + If `channel_first` is False, and `img` has 3 spatial dims, map spatial axes to real axes as below: + None -> [0, 1, 2] + [0, 1] -> [0, 1] + [0, -1] -> [0, -2] + + Args: + img: target image data to compute real axes. + spatial_axes: spatial axes to be converted, default is None. + The default `None` will convert to all the spatial axes of the image. + If axis is negative it counts from the last to the first axis. + If axis is a tuple of ints. + channel_first: the image data is channel first or channel last, defaut to channel first. + + """ + if spatial_axes is None: + spatial_axes_ = list(range(1, img.ndim) if channel_first else range(0, img.ndim - 1)) + else: + spatial_axes_ = [] + for a in ensure_tuple(spatial_axes): + if channel_first: + spatial_axes_.append(a if a < 0 else a + 1) + else: + spatial_axes_.append(a - 1 if a < 0 else a) + + return spatial_axes_ diff --git a/tests/test_flip.py b/tests/test_flip.py index 7a2af02585..fe169c4da8 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -19,7 +19,7 @@ INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] -VALID_CASES = [("no_axis", None), ("one_axis", 1), ("many_axis", [0, 1])] +VALID_CASES = [("no_axis", None), ("one_axis", 1), ("many_axis", [0, 1]), ("negative_axis", [0, -1])] class TestFlip(NumpyImageTestCase2D): From 5ce3ea446955f7425218af8ce404277bd2ef43c2 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 9 Feb 2021 19:13:12 +0800 Subject: [PATCH 3/4] [DLMED] fix flake8 issue Signed-off-by: Nic Ma --- monai/transforms/spatial/array.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 9cbe1536d0..3d9b7f3a63 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -307,7 +307,8 @@ def __call__(self, img: np.ndarray) -> np.ndarray: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - return np.flip(img, map_spatial_axes(img, self.spatial_axis)).astype(img.dtype) + result: np.ndarray = np.flip(img, map_spatial_axes(img, self.spatial_axis)) + return result.astype(img.dtype) class Resize(Transform): @@ -596,7 +597,8 @@ def __call__(self, img: np.ndarray) -> np.ndarray: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - return np.rot90(img, self.k, map_spatial_axes(img, self.spatial_axes)).astype(img.dtype) + result: np.ndarray = np.rot90(img, self.k, map_spatial_axes(img, self.spatial_axes)) + return result.astype(img.dtype) class RandRotate90(Randomizable, Transform): From c42c4b34f40175c26b984e6b745ce37d96bab2de Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 9 Feb 2021 22:19:31 +0800 Subject: [PATCH 4/4] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/transforms/spatial/array.py | 4 ++-- monai/transforms/utils.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 3d9b7f3a63..df10480188 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -307,7 +307,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - result: np.ndarray = np.flip(img, map_spatial_axes(img, self.spatial_axis)) + result: np.ndarray = np.flip(img, map_spatial_axes(img.ndim, self.spatial_axis)) return result.astype(img.dtype) @@ -597,7 +597,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - result: np.ndarray = np.rot90(img, self.k, map_spatial_axes(img, self.spatial_axes)) + result: np.ndarray = np.rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes)) return result.astype(img.dtype) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 287ebf7382..9a84eb00d9 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -694,7 +694,7 @@ def extreme_points_to_image( def map_spatial_axes( - img: np.ndarray, + img_ndim: int, spatial_axes: Optional[Union[Sequence[int], int]] = None, channel_first: bool = True, ) -> List[int]: @@ -711,7 +711,7 @@ def map_spatial_axes( [0, -1] -> [0, -2] Args: - img: target image data to compute real axes. + img_ndim: dimension number of the target image. spatial_axes: spatial axes to be converted, default is None. The default `None` will convert to all the spatial axes of the image. If axis is negative it counts from the last to the first axis. @@ -720,7 +720,7 @@ def map_spatial_axes( """ if spatial_axes is None: - spatial_axes_ = list(range(1, img.ndim) if channel_first else range(0, img.ndim - 1)) + spatial_axes_ = list(range(1, img_ndim) if channel_first else range(0, img_ndim - 1)) else: spatial_axes_ = [] for a in ensure_tuple(spatial_axes):