Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.spatial.functional import (
affine_func,
flip,
flip_image,
flip_point,
orientation,
resize,
rotate,
Expand Down Expand Up @@ -684,8 +685,9 @@ class Flip(InvertibleTransform, LazyTransform):
def __init__(self, spatial_axis: Sequence[int] | int | None = None, lazy: bool = False) -> None:
    LazyTransform.__init__(self, lazy=lazy)
    # spatial axes to flip over; None means all spatial axes
    self.spatial_axis = spatial_axis
    # operators tried in order by __call__; each returns None when the input
    # data kind is not the one it handles (point vs. pixel data)
    self.operators = [flip_point, flip_image]

def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor:
def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor: # type: ignore[return]
"""
Args:
img: channel first array, must have shape: (num_channels, H[, W, ..., ])
Expand All @@ -695,7 +697,10 @@ def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor:
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
lazy_ = self.lazy if lazy is None else lazy
return flip(img, self.spatial_axis, lazy=lazy_, transform_info=self.get_transform_info()) # type: ignore
for operator in self.operators:
ret: torch.Tensor = operator(img, self.spatial_axis, lazy=lazy_, transform_info=self.get_transform_info())
if ret is not None:
return ret

def inverse(self, data: torch.Tensor) -> torch.Tensor:
self.pop_transform(data)
Expand Down
97 changes: 81 additions & 16 deletions monai/transforms/spatial/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from monai.transforms.intensity.array import GaussianSmooth
from monai.transforms.inverse import TraceableTransform
from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
from monai.transforms.utils_pytorch_numpy_unification import allclose
from monai.transforms.utils_pytorch_numpy_unification import allclose, concatenate
from monai.utils import (
LazyAttr,
TraceKeys,
Expand All @@ -44,13 +44,24 @@
fall_back_tuple,
optional_import,
)
from monai.utils.enums import MetaKeys, KindKeys

nib, has_nib = optional_import("nibabel")
cupy, _ = optional_import("cupy")
cupy_ndi, _ = optional_import("cupyx.scipy.ndimage")
np_ndi, _ = optional_import("scipy.ndimage")

__all__ = ["spatial_resample", "orientation", "flip", "resize", "rotate", "zoom", "rotate90", "affine_func"]
__all__ = [
"spatial_resample",
"orientation",
"flip_image",
"flip_point",
"resize",
"rotate",
"zoom",
"rotate90",
"affine_func",
]


def _maybe_new_metatensor(img, dtype=None, device=None):
Expand Down Expand Up @@ -229,7 +240,26 @@ def orientation(img, original_affine, spatial_ornt, lazy, transform_info) -> tor
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out # type: ignore


def flip(img, sp_axes, lazy, transform_info):
def flip_helper(data, sp_axes, lazy, transform_info):
    """
    Shared implementation for ``flip_image`` and ``flip_point``: compute the
    flip axes, the homogeneous affine describing the flip, and the tracked
    transform metadata.

    Args:
        data: channel-first image tensor or point coordinates.
        sp_axes: spatial axes along which to flip; ``None`` flips over all spatial axes.
        lazy: whether the operation is to be performed lazily.
        transform_info: a dictionary with the relevant information pertaining to an applied transform.

    Returns:
        a tuple of ``(axes, meta_info, xform)`` where ``axes`` include the channel
        dimension, ``meta_info`` is the tracked transform metadata and ``xform`` is
        the homogeneous affine of the flip.
    """
    sp_size = data.peek_pending_shape() if isinstance(data, MetaTensor) else data.shape[1:]
    sp_size = convert_to_numpy(sp_size, wrap_sequence=True).tolist()
    extra_info = {"axes": sp_axes}  # track the spatial axes
    axes = monai.transforms.utils.map_spatial_axes(data.ndim, sp_axes)  # use the axes with channel dim
    rank = data.peek_pending_rank() if isinstance(data, MetaTensor) else torch.tensor(3.0, dtype=torch.double)
    # read the data kind defensively: plain torch.Tensor inputs have no
    # ``meta``/``kind`` attribute, so default them to pixel data (matches the
    # guards used by flip_image/flip_point)
    kind = data.meta.get(MetaKeys.KIND, KindKeys.PIXEL) if isinstance(data, MetaTensor) else KindKeys.PIXEL
    # axes include the channel dim
    xform = torch.eye(int(rank) + 1, dtype=torch.double)
    for axis in axes:
        sp = axis - 1
        if kind == KindKeys.PIXEL:
            # pixel data flips within the image grid: x -> (size - 1) - x;
            # point data is only negated (reflection handled by the caller's metadata)
            xform[sp, -1] = sp_size[sp] - 1
        xform[sp, sp] = xform[sp, sp] * -1
    # NOTE(review): the pre-refactor ``flip`` also passed ``sp_size=sp_size`` to
    # track_transform_meta — confirm omitting it keeps lazy pending ops correct
    meta_info = TraceableTransform.track_transform_meta(
        data, affine=xform, extra_info=extra_info, lazy=lazy, transform_info=transform_info
    )
    return axes, meta_info, xform


def flip_image(img, sp_axes, lazy, transform_info):
    """
    Functional implementation of flip for image (pixel-kind) data.
    This function operates eagerly or lazily according to
    ``lazy`` (default ``False``).

    Args:
        img: channel-first image data to be flipped.
        sp_axes: spatial axes along which to flip over. Default is None.
            The default `axis=None` will flip over all of the axes of the input array.
            If axis is negative it counts from the last to the first axis.
            If axis is a tuple of ints, flipping is performed on all of the axes
            specified in the tuple.
        lazy: a flag that indicates whether the operation should be performed lazily or not
        transform_info: a dictionary with the relevant information pertaining to an applied transform.

    Returns:
        the flipped image, or ``None`` when ``img`` is not pixel-kind data —
        ``Flip.__call__`` tries each operator in turn until one returns non-None.
    """
    # only handle pixel data; point-kind data is handled by ``flip_point``
    kind = img.meta.get(MetaKeys.KIND, KindKeys.PIXEL) if isinstance(img, MetaTensor) else KindKeys.PIXEL
    if kind != KindKeys.PIXEL:
        return None
    axes, meta_info, _ = flip_helper(img, sp_axes, lazy, transform_info)
    out = _maybe_new_metatensor(img)
    if lazy:
        # lazy mode: only record the pending transform metadata
        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
    out = torch.flip(out, axes)
    return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out


def flip_point(points, sp_axes, lazy, transform_info):
    """
    Functional implementation of flip for point (coordinate) data.
    Lazy evaluation is not yet supported for point data.

    Args:
        points: point coordinates, represented by a torch tensor or ndarray with dimensions of 1xNx2 or 1xNx3.
            Here 1 represents the channel dimension.
        sp_axes: spatial axes along which to flip over. Default is None.
            The default `axis=None` will flip over all of the axes of the input array.
            If axis is negative it counts from the last to the first axis.
            If axis is a tuple of ints, flipping is performed on all of the axes
            specified in the tuple.
        lazy: a flag that indicates whether the operation should be performed lazily or not.
        transform_info: a dictionary with the relevant information pertaining to an applied transform.

    Returns:
        the flipped points, or ``None`` when ``points`` is not point-kind data —
        ``Flip.__call__`` tries each operator in turn until one returns non-None.

    Raises:
        NotImplementedError: if ``lazy`` is True (lazy point flipping is unsupported).
    """
    # only handle point data; pixel-kind data is handled by ``flip_image``
    kind = points.meta.get(MetaKeys.KIND, KindKeys.PIXEL) if isinstance(points, MetaTensor) else KindKeys.PIXEL
    if kind != KindKeys.POINT:
        return None
    _, meta_info, xform = flip_helper(points, sp_axes, lazy, transform_info)

    out = _maybe_new_metatensor(points)
    if lazy:
        # TODO: add lazy support — note the original also had an unreachable
        # ``return`` after this raise; it has been removed as dead code
        raise NotImplementedError("lazy evaluation of flip is not supported for point data")
    # TODO: use CoordinateTransformd.apply_affine_to_points instead
    # drop/restore the channel dim around the affine application
    out = apply_affine_to_points(out[0], xform, dtype=torch.float64).unsqueeze(0)
    return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out


def resize(
img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info
):
Expand Down Expand Up @@ -610,3 +663,15 @@ def affine_func(
out = _maybe_new_metatensor(img, dtype=torch.float32, device=resampler.device)
out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
return out if image_only else (out, affine)


def apply_affine_to_points(data, affine, dtype):
    """
    Apply a homogeneous affine transform to a set of point coordinates.

    Args:
        data: point coordinates with shape (N, spatial_dims).
        affine: a (spatial_dims + 1) x (spatial_dims + 1) homogeneous affine matrix.
        dtype: dtype used for the computation and the returned data.

    Returns:
        the transformed coordinates, converted back to the type of ``data``.
    """
    data = convert_to_tensor(data, track_meta=get_track_meta())
    data_: torch.Tensor = convert_to_tensor(data, track_meta=False, dtype=dtype)

    # create the homogeneous "ones" column and the affine on the same
    # device/dtype as the points: the original used default-device/dtype
    # tensors, which fails for points on non-CPU devices
    ones = torch.ones((data_.shape[0], 1), dtype=data_.dtype, device=data_.device)
    affine_t = convert_to_tensor(affine, track_meta=False, dtype=data_.dtype, device=data_.device)
    homogeneous = concatenate((data_, ones), axis=1)
    transformed_homogeneous = torch.matmul(affine_t, homogeneous.T)
    # drop the homogeneous row and transpose back to (N, spatial_dims)
    transformed_coordinates = transformed_homogeneous[:-1].T
    out, *_ = convert_to_dst_type(transformed_coordinates, data, dtype=dtype)

    return out
26 changes: 26 additions & 0 deletions tests/test_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

import unittest
from copy import deepcopy

import numpy as np
import torch
Expand All @@ -32,6 +33,15 @@
for device in TEST_DEVICES:
TORCH_CASES.append([[0, 1], torch.zeros((1, 3, 2)), track_meta, *device])

POINT_2D_WITH_REFER = MetaTensor(
[[[3, 4], [5, 7], [6, 2], [7, 8]]], meta={"kind": "point", "refer_meta": {"spatial_shape": (10, 10)}}
)
POINT_3D = MetaTensor([[[3, 4, 5], [5, 7, 6], [6, 2, 7]]], meta={"kind": "point"})
POINT_CASES = []
for spatial_axis in [[0], [1], [0, 1]]:
for point in [POINT_2D_WITH_REFER, POINT_3D]:
POINT_CASES.append([spatial_axis, point])


class TestFlip(NumpyImageTestCase2D):

Expand Down Expand Up @@ -73,6 +83,22 @@ def test_torch(self, spatial_axis, img: torch.Tensor, track_meta: bool, device):
with self.assertRaisesRegex(ValueError, "MetaTensor"):
xform.inverse(res)

@parameterized.expand(POINT_CASES)
def test_points(self, spatial_axis, point):
    """Flipping point data reflects coordinates with a reference shape, negates them otherwise."""
    xform = Flip(spatial_axis=spatial_axis)
    res = xform(point)  # type: ignore[arg-type]
    self.assertEqual(point.shape, res.shape)
    expected = deepcopy(point)
    has_refer = point.meta.get("refer_meta", None) is not None
    for ax in spatial_axis:
        # with a (10, 10) reference spatial shape: x -> size - x; otherwise: x -> -x
        expected[..., ax] = (10, 10)[ax] - point[..., ax] if has_refer else -point[..., ax]
    assert_allclose(res, expected, type_test="tensor")
    test_local_inversion(xform, res, point)


if __name__ == "__main__":
unittest.main()
26 changes: 26 additions & 0 deletions tests/test_flipd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

import unittest
from copy import deepcopy

import numpy as np
import torch
Expand All @@ -33,6 +34,15 @@
for device in TEST_DEVICES:
TORCH_CASES.append([[0, 1], torch.zeros((1, 3, 2)), track_meta, *device])

POINT_2D_WITH_REFER = MetaTensor(
[[[3, 4], [5, 7], [6, 2], [7, 8]]], meta={"kind": "point", "refer_meta": {"spatial_shape": (10, 10)}}
)
POINT_3D = MetaTensor([[[3, 4, 5], [5, 7, 6], [6, 2, 7]]], meta={"kind": "point"})
POINT_CASES = []
for spatial_axis in [[0], [1], [0, 1]]:
for point in [POINT_2D_WITH_REFER, POINT_3D]:
POINT_CASES.append([spatial_axis, point])


class TestFlipd(NumpyImageTestCase2D):

Expand Down Expand Up @@ -80,6 +90,22 @@ def test_meta_dict(self):
res = xform({"image": torch.zeros(1, 3, 4)})
self.assertEqual(res["image"].applied_operations, res["image_transforms"])

@parameterized.expand(POINT_CASES)
def test_points(self, spatial_axis, point):
    """Flipd on point data reflects coordinates with a reference shape, negates them otherwise."""
    xform = Flipd(keys="point", spatial_axis=spatial_axis)
    res = xform({"point": point})  # type: ignore[arg-type]
    flipped = res["point"]
    self.assertEqual(point.shape, flipped.shape)
    expected = deepcopy(point)
    has_refer = point.meta.get("refer_meta", None) is not None
    for ax in spatial_axis:
        # with a (10, 10) reference spatial shape: x -> size - x; otherwise: x -> -x
        expected[..., ax] = (10, 10)[ax] - point[..., ax] if has_refer else -point[..., ax]
    assert_allclose(flipped, expected, type_test="tensor")
    test_local_inversion(xform, {"point": flipped}, {"point": point}, "point")


if __name__ == "__main__":
unittest.main()