diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py
index 3739a83e71..2347217428 100644
--- a/monai/transforms/spatial/array.py
+++ b/monai/transforms/spatial/array.py
@@ -34,7 +34,8 @@
 from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.spatial.functional import (
     affine_func,
-    flip,
+    flip_image,
+    flip_point,
     orientation,
     resize,
     rotate,
@@ -684,8 +685,9 @@ class Flip(InvertibleTransform, LazyTransform):
     def __init__(self, spatial_axis: Sequence[int] | int | None = None, lazy: bool = False) -> None:
         LazyTransform.__init__(self, lazy=lazy)
         self.spatial_axis = spatial_axis
+        self.operators = [flip_point, flip_image]
 
     def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor:
         """
         Args:
             img: channel first array, must have shape: (num_channels, H[, W, ..., ])
@@ -695,7 +697,13 @@ def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor:
         """
         img = convert_to_tensor(img, track_meta=get_track_meta())
         lazy_ = self.lazy if lazy is None else lazy
-        return flip(img, self.spatial_axis, lazy=lazy_, transform_info=self.get_transform_info())  # type: ignore
+        transform_info = self.get_transform_info()
+        for operator in self.operators:
+            # each operator returns None when the `kind` of the input is not the one it handles
+            ret: torch.Tensor = operator(img, self.spatial_axis, lazy=lazy_, transform_info=transform_info)
+            if ret is not None:
+                return ret
+        raise ValueError(f"{type(self).__name__}: no flip operator accepted the input, check its `kind` metadata.")
 
     def inverse(self, data: torch.Tensor) -> torch.Tensor:
         self.pop_transform(data)
diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py
index add4e7f5ea..e5712ee239 100644
--- a/monai/transforms/spatial/functional.py
+++ b/monai/transforms/spatial/functional.py
@@ -32,7 +32,7 @@
 from monai.transforms.intensity.array import GaussianSmooth
 from monai.transforms.inverse import TraceableTransform
 from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
-from monai.transforms.utils_pytorch_numpy_unification import allclose
+from monai.transforms.utils_pytorch_numpy_unification import allclose, concatenate
 from monai.utils import (
     LazyAttr,
     TraceKeys,
@@ -44,13 +44,24 @@
     fall_back_tuple,
     optional_import,
 )
+from monai.utils.enums import KindKeys, MetaKeys
 
 nib, has_nib = optional_import("nibabel")
 cupy, _ = optional_import("cupy")
 cupy_ndi, _ = optional_import("cupyx.scipy.ndimage")
 np_ndi, _ = optional_import("scipy.ndimage")
 
-__all__ = ["spatial_resample", "orientation", "flip", "resize", "rotate", "zoom", "rotate90", "affine_func"]
+__all__ = [
+    "spatial_resample",
+    "orientation",
+    "flip_image",
+    "flip_point",
+    "resize",
+    "rotate",
+    "zoom",
+    "rotate90",
+    "affine_func",
+]
 
 
 def _maybe_new_metatensor(img, dtype=None, device=None):
@@ -229,7 +240,49 @@ def orientation(img, original_affine, spatial_ornt, lazy, transform_info) -> tor
     return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out  # type: ignore
 
 
-def flip(img, sp_axes, lazy, transform_info):
+def flip_helper(data, sp_axes, lazy, transform_info):
+    """
+    Build the flip affine of ``data`` and track it in the metadata.
+    Shared by ``flip_image`` and ``flip_point``.
+
+    Args:
+        data: channel first image, or point coordinates with a ``kind`` metadata.
+        sp_axes: spatial axes along which to flip over.
+        lazy: a flag that indicates whether the operation should be performed lazily or not.
+        transform_info: a dictionary with the relevant information pertaining to an applied transform.
+
+    Returns:
+        the flip axes (channel dim included), the tracked meta information and the flip affine.
+    """
+    # same lookup as `flip_image` and `flip_point`: a plain tensor has no `kind` attribute
+    kind = data.meta.get(MetaKeys.KIND, KindKeys.PIXEL) if isinstance(data, MetaTensor) else KindKeys.PIXEL
+    is_pixel = kind == KindKeys.PIXEL
+    sp_size = data.peek_pending_shape() if isinstance(data, MetaTensor) else data.shape[1:]
+    sp_size = convert_to_numpy(sp_size, wrap_sequence=True).tolist()
+    extra_info = {"axes": sp_axes}  # track the spatial axes
+    axes = monai.transforms.utils.map_spatial_axes(data.ndim, sp_axes)  # use the axes with channel dim
+    rank = data.peek_pending_rank() if isinstance(data, MetaTensor) else torch.tensor(3.0, dtype=torch.double)
+    # axes include the channel dim
+    xform = torch.eye(int(rank) + 1, dtype=torch.double)
+    for axis in axes:
+        sp = axis - 1
+        if is_pixel:
+            # voxel indices are mirrored within the grid, point coordinates are only negated
+            xform[sp, -1] = sp_size[sp] - 1
+        xform[sp, sp] = xform[sp, sp] * -1
+    # the spatial size is only meaningful for images, as `flip` passed it before the split
+    meta_info = TraceableTransform.track_transform_meta(
+        data,
+        sp_size=sp_size if is_pixel else None,
+        affine=xform,
+        extra_info=extra_info,
+        lazy=lazy,
+        transform_info=transform_info,
+    )
+    return axes, meta_info, xform
+
+
+def flip_image(img, sp_axes, lazy, transform_info):
     """
     Functional implementation of flip.
     This function operates eagerly or lazily according to
@@ -245,19 +298,10 @@ def flip(img, sp_axes, lazy, transform_info):
         lazy: a flag that indicates whether the operation should be performed lazily or not
         transform_info: a dictionary with the relevant information pertaining to an applied transform.
     """
-    sp_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
-    sp_size = convert_to_numpy(sp_size, wrap_sequence=True).tolist()
-    extra_info = {"axes": sp_axes}  # track the spatial axes
-    axes = monai.transforms.utils.map_spatial_axes(img.ndim, sp_axes)  # use the axes with channel dim
-    rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else torch.tensor(3.0, dtype=torch.double)
-    # axes include the channel dim
-    xform = torch.eye(int(rank) + 1, dtype=torch.double)
-    for axis in axes:
-        sp = axis - 1
-        xform[sp, sp], xform[sp, -1] = xform[sp, sp] * -1, sp_size[sp] - 1
-    meta_info = TraceableTransform.track_transform_meta(
-        img, sp_size=sp_size, affine=xform, extra_info=extra_info, transform_info=transform_info, lazy=lazy
-    )
+    kind = img.meta.get(MetaKeys.KIND, KindKeys.PIXEL) if isinstance(img, MetaTensor) else KindKeys.PIXEL
+    if kind != KindKeys.PIXEL:
+        return None
+    axes, meta_info, _ = flip_helper(img, sp_axes, lazy, transform_info)
     out = _maybe_new_metatensor(img)
     if lazy:
         return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
@@ -265,6 +309,43 @@ def flip(img, sp_axes, lazy, transform_info):
     return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
 
 
+def flip_point(points, sp_axes, lazy, transform_info):
+    """
+    Functional implementation of flip points.
+    Only the eager mode is implemented, ``lazy=True`` is not supported yet.
+
+    Args:
+        points: point coordinates, represented by a torch tensor or ndarray with dimensions of 1xNx2 or 1xNx3.
+            Here 1 represents the channel dimension.
+        sp_axes: spatial axes along which to flip over. Default is None.
+            The default `axis=None` will flip over all of the axes of the input array.
+            If axis is negative it counts from the last to the first axis.
+            If axis is a tuple of ints, flipping is performed on all of the axes
+            specified in the tuple.
+        lazy: a flag that indicates whether the operation should be performed lazily or not.
+        transform_info: a dictionary with the relevant information pertaining to an applied transform.
+
+    Returns:
+        the flipped points, or ``None`` when the ``kind`` metadata of ``points`` is not point.
+
+    Raises:
+        NotImplementedError: when ``lazy`` is True.
+    """
+    kind = points.meta.get(MetaKeys.KIND, KindKeys.PIXEL) if isinstance(points, MetaTensor) else KindKeys.PIXEL
+    if kind != KindKeys.POINT:
+        return None
+    if lazy:
+        # TODO: add lazy support; raise before tracking anything in the metadata
+        raise NotImplementedError("lazy flipping of point coordinates is not supported yet.")
+    _, meta_info, xform = flip_helper(points, sp_axes, lazy, transform_info)
+    # NOTE(review): `xform` is sized from the affine rank of `points`, confirm it matches 1xNx2 inputs
+
+    out = _maybe_new_metatensor(points)
+    # TODO: use CoordinateTransformd.apply_affine_to_points instead
+    out = apply_affine_to_points(out[0], xform, dtype=torch.float64).unsqueeze(0)
+    return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
+
+
 def resize(
     img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info
 ):
@@ -610,3 +691,39 @@ def affine_func(
     out = _maybe_new_metatensor(img, dtype=torch.float32, device=resampler.device)
     out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
     return out if image_only else (out, affine)
+
+
+def apply_affine_to_points(data, affine, dtype):
+    """
+    Apply a homogeneous affine to point coordinates.
+
+    Args:
+        data: point coordinates with shape (N, D).
+        affine: homogeneous affine with shape (D+1, D+1).
+        dtype: data type of the converted points and of the output.
+
+    Returns:
+        the transformed coordinates with shape (N, D), converted back with ``convert_to_dst_type``.
+
+    Raises:
+        ValueError: when the shape of ``affine`` does not match the dimension of the points.
+    """
+    data = convert_to_tensor(data, track_meta=get_track_meta())
+    data_: torch.Tensor = convert_to_tensor(data, track_meta=False, dtype=dtype)
+    # `torch.matmul` rejects mixed dtypes/devices, compute on the device of the points with a common dtype
+    affine_ = torch.as_tensor(affine, device=data_.device)
+    compute_dtype = torch.promote_types(data_.dtype, affine_.dtype)
+    data_, affine_ = data_.to(compute_dtype), affine_.to(compute_dtype)
+    n_dims = data_.shape[1]
+    if tuple(affine_.shape) != (n_dims + 1, n_dims + 1):
+        raise ValueError(
+            f"affine must have shape ({n_dims + 1}, {n_dims + 1}) for {n_dims}D points, got {tuple(affine_.shape)}."
+        )
+
+    ones = torch.ones((data_.shape[0], 1), dtype=compute_dtype, device=data_.device)
+    homogeneous = concatenate((data_, ones), axis=1)
+    transformed_homogeneous = torch.matmul(affine_, homogeneous.T)
+    transformed_coordinates = transformed_homogeneous[:-1].T
+    out, *_ = convert_to_dst_type(transformed_coordinates, data, dtype=dtype)
+
+    return out
diff --git a/tests/test_flip.py b/tests/test_flip.py
index 789ec86920..971da97dd2 100644
--- a/tests/test_flip.py
+++ b/tests/test_flip.py
@@ -12,6 +12,7 @@
 from __future__ import annotations
 
 import unittest
+from copy import deepcopy
 
 import numpy as np
 import torch
@@ -32,6 +33,15 @@
     for device in TEST_DEVICES:
         TORCH_CASES.append([[0, 1], torch.zeros((1, 3, 2)), track_meta, *device])
 
+POINT_2D_WITH_REFER = MetaTensor(
+    [[[3, 4], [5, 7], [6, 2], [7, 8]]], meta={"kind": "point", "refer_meta": {"spatial_shape": (10, 10)}}
+)
+POINT_3D = MetaTensor([[[3, 4, 5], [5, 7, 6], [6, 2, 7]]], meta={"kind": "point"})
+POINT_CASES = []
+for spatial_axis in [[0], [1], [0, 1]]:
+    for point in [POINT_2D_WITH_REFER, POINT_3D]:
+        POINT_CASES.append([spatial_axis, point])
+
 
 class TestFlip(NumpyImageTestCase2D):
 
@@ -73,6 +83,24 @@ def test_torch(self, spatial_axis, img: torch.Tensor, track_meta: bool, device):
             with self.assertRaisesRegex(ValueError, "MetaTensor"):
                 xform.inverse(res)
 
+    @parameterized.expand(POINT_CASES)
+    def test_points(self, spatial_axis, point):
+        init_param = {"spatial_axis": spatial_axis}
+        xform = Flip(**init_param)
+        res = xform(point)  # type: ignore[arg-type]
+        self.assertEqual(point.shape, res.shape)
+        expected = deepcopy(point)
+        # NOTE(review): mirroring around `spatial_shape` is expected here, but nothing in this patch reads
+        # `refer_meta` (flip_helper only negates point coordinates) -- confirm the intended behaviour.
+        if point.meta.get("refer_meta", None) is not None:
+            for _axes in spatial_axis:
+                expected[..., _axes] = (10, 10)[_axes] - point[..., _axes]
+        else:
+            for _axes in spatial_axis:
+                expected[..., _axes] = -point[..., _axes]
+        assert_allclose(res, expected, type_test="tensor")
+        test_local_inversion(xform, res, point)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_flipd.py b/tests/test_flipd.py
index 1df6d34056..1d11e8654d 100644
--- a/tests/test_flipd.py
+++ b/tests/test_flipd.py
@@ -12,6 +12,7 @@
 from __future__ import annotations
 
 import unittest
+from copy import deepcopy
 
 import numpy as np
 import torch
@@ -33,6 +34,15 @@
     for device in TEST_DEVICES:
         TORCH_CASES.append([[0, 1], torch.zeros((1, 3, 2)), track_meta, *device])
 
+POINT_2D_WITH_REFER = MetaTensor(
+    [[[3, 4], [5, 7], [6, 2], [7, 8]]], meta={"kind": "point", "refer_meta": {"spatial_shape": (10, 10)}}
+)
+POINT_3D = MetaTensor([[[3, 4, 5], [5, 7, 6], [6, 2, 7]]], meta={"kind": "point"})
+POINT_CASES = []
+for spatial_axis in [[0], [1], [0, 1]]:
+    for point in [POINT_2D_WITH_REFER, POINT_3D]:
+        POINT_CASES.append([spatial_axis, point])
+
 
 class TestFlipd(NumpyImageTestCase2D):
 
@@ -80,6 +90,24 @@ def test_meta_dict(self):
         res = xform({"image": torch.zeros(1, 3, 4)})
         self.assertEqual(res["image"].applied_operations, res["image_transforms"])
 
+    @parameterized.expand(POINT_CASES)
+    def test_points(self, spatial_axis, point):
+        init_param = {"keys": "point", "spatial_axis": spatial_axis}
+        xform = Flipd(**init_param)
+        res = xform({"point": point})  # type: ignore[arg-type]
+        self.assertEqual(point.shape, res["point"].shape)
+        expected = deepcopy(point)
+        # NOTE(review): mirroring around `spatial_shape` is expected here, but nothing in this patch reads
+        # `refer_meta` (flip_helper only negates point coordinates) -- confirm the intended behaviour.
+        if point.meta.get("refer_meta", None) is not None:
+            for _axes in spatial_axis:
+                expected[..., _axes] = (10, 10)[_axes] - point[..., _axes]
+        else:
+            for _axes in spatial_axis:
+                expected[..., _axes] = -point[..., _axes]
+        assert_allclose(res["point"], expected, type_test="tensor")
+        test_local_inversion(xform, {"point": res["point"]}, {"point": point}, "point")
+
 
 if __name__ == "__main__":
     unittest.main()