58 changes: 40 additions & 18 deletions monai/transforms/spatial/array.py
@@ -931,6 +931,9 @@ class AffineGrid(Transform):
as_tensor_output: whether to output tensor instead of numpy array.
defaults to True.
device: device to store the output grid data.
affine: if applied, ignore the params (`rotate_params`, etc.) and use the
supplied matrix. Should be square, with each side equal to the number of image
spatial dimensions + 1 (e.g. a 4x4 matrix for a 3D image).

"""

@@ -942,6 +945,7 @@ def __init__(
scale_params: Optional[Union[Sequence[float], float]] = None,
as_tensor_output: bool = True,
device: Optional[torch.device] = None,
affine: Optional[Union[np.ndarray, torch.Tensor]] = None,
) -> None:
self.rotate_params = rotate_params
self.shear_params = shear_params
@@ -951,8 +955,12 @@ def __init__(
self.as_tensor_output = as_tensor_output
self.device = device

self.affine = affine

def __call__(
self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None
self,
spatial_size: Optional[Sequence[int]] = None,
grid: Optional[Union[np.ndarray, torch.Tensor]] = None,
) -> Union[np.ndarray, torch.Tensor]:
"""
Args:
@@ -969,27 +977,32 @@ def __call__(
else:
raise ValueError("Incompatible values: grid=None and spatial_size=None.")

spatial_dims = len(grid.shape) - 1
affine = np.eye(spatial_dims + 1)
if self.rotate_params:
affine = affine @ create_rotate(spatial_dims, self.rotate_params)
if self.shear_params:
affine = affine @ create_shear(spatial_dims, self.shear_params)
if self.translate_params:
affine = affine @ create_translate(spatial_dims, self.translate_params)
if self.scale_params:
affine = affine @ create_scale(spatial_dims, self.scale_params)
affine = torch.as_tensor(np.ascontiguousarray(affine), device=self.device)
if self.affine is None:
spatial_dims = len(grid.shape) - 1
affine = np.eye(spatial_dims + 1)
if self.rotate_params:
affine = affine @ create_rotate(spatial_dims, self.rotate_params)
if self.shear_params:
affine = affine @ create_shear(spatial_dims, self.shear_params)
if self.translate_params:
affine = affine @ create_translate(spatial_dims, self.translate_params)
if self.scale_params:
affine = affine @ create_scale(spatial_dims, self.scale_params)
self.affine = affine

self.affine = torch.as_tensor(np.ascontiguousarray(self.affine), device=self.device)

grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone()
if self.device:
grid = grid.to(self.device)
grid = (affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:]))
grid = (self.affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:]))
if grid is None or not isinstance(grid, torch.Tensor):
raise ValueError("Unknown grid.")
if self.as_tensor_output:
return grid
return np.asarray(grid.cpu().numpy())
return grid if self.as_tensor_output else np.asarray(grid.cpu().numpy())

def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]:
"""Get the most recently applied transformation matrix"""
return self.affine


class RandAffineGrid(RandomizableTransform):
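
As a usage note for the new `affine` argument (an illustrative sketch, not part of the diff — the matrix, spatial size and variable names are assumptions):

```python
import numpy as np
from monai.transforms import AffineGrid

# 4x4 homogeneous matrix for a 3D image: translate 10 voxels along the first spatial axis
matrix = np.eye(4)
matrix[0, 3] = 10.0

affine_grid = AffineGrid(affine=matrix)        # rotate/shear/translate/scale params are ignored
grid = affine_grid(spatial_size=(64, 64, 64))  # sampling grid with `matrix` applied
applied = affine_grid.get_transformation_matrix()  # matrix used for the most recent call
```
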
@@ -1040,6 +1053,7 @@ def __init__(

self.as_tensor_output = as_tensor_output
self.device = device
self.affine: Optional[Union[np.ndarray, torch.Tensor]] = None

def _get_rand_param(self, param_range, add_scalar: float = 0.0):
out_param = []
@@ -1059,7 +1073,9 @@ def randomize(self, data: Optional[Any] = None) -> None:
self.scale_params = self._get_rand_param(self.scale_range, 1.0)

def __call__(
self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None
self,
spatial_size: Optional[Sequence[int]] = None,
grid: Optional[Union[np.ndarray, torch.Tensor]] = None,
) -> Union[np.ndarray, torch.Tensor]:
"""
Args:
@@ -1078,7 +1094,13 @@ def __call__(
as_tensor_output=self.as_tensor_output,
device=self.device,
)
return affine_grid(spatial_size, grid)
grid = affine_grid(spatial_size, grid)
self.affine = affine_grid.get_transformation_matrix()
return grid

def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]:
"""Get the most recently applied transformation matrix"""
return self.affine


class RandDeformGrid(RandomizableTransform):
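
`RandAffineGrid` now keeps the randomly sampled matrix as well; a minimal sketch assuming 2D data (the ranges, seed and size are illustrative):

```python
from monai.transforms import RandAffineGrid

rand_grid = RandAffineGrid(rotate_range=(0.5,), translate_range=(5, 5))
rand_grid.set_random_state(seed=0)
grid = rand_grid(spatial_size=(32, 32))         # parameters are re-randomized on each call
matrix = rand_grid.get_transformation_matrix()  # matrix that produced `grid`
```
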
59 changes: 57 additions & 2 deletions monai/transforms/spatial/dictionary.py
@@ -28,6 +28,7 @@
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.spatial.array import (
Affine,
AffineGrid,
Flip,
Orientation,
Rand2DElastic,
@@ -501,7 +502,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
return d


class Affined(RandomizableTransform, MapTransform):
class Affined(MapTransform, InvertibleTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.Affine`.
"""
@@ -570,11 +571,38 @@ def __call__(
) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]:
d = dict(data)
for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
orig_size = d[key].shape[1:]
d[key] = self.affine(d[key], mode=mode, padding_mode=padding_mode)
affine = self.affine.affine_grid.get_transformation_matrix()
self.push_transform(d, key, orig_size=orig_size, extra_info={"affine": affine})
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = deepcopy(dict(data))

for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
transform = self.get_most_recent_transform(d, key)
orig_size = transform[InverseKeys.ORIG_SIZE.value]
# Create inverse transform
fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"]
inv_affine = np.linalg.inv(fwd_affine)

affine_grid = AffineGrid(affine=inv_affine)
grid: torch.Tensor = affine_grid(orig_size) # type: ignore

# Apply inverse transform
out = self.affine.resampler(d[key], grid, mode, padding_mode)

# Convert to numpy
d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy()

# Remove the applied transform
self.pop_transform(d, key)

return d
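
With the recorded matrix, `Affined` becomes invertible end to end; a minimal round-trip sketch (the key name, shapes and parameters are illustrative assumptions, not taken from this PR):

```python
import numpy as np
from monai.transforms import Affined

data = {"image": np.random.rand(1, 32, 32, 32).astype(np.float32)}
xform = Affined(
    keys="image",
    spatial_size=(32, 32, 32),
    rotate_params=(np.pi / 8, 0.0, 0.0),
    translate_params=(2.0, 2.0, 0.0),
    padding_mode="zeros",
)
forward = xform(data)              # forward call records the applied matrix via push_transform
restored = xform.inverse(forward)  # resamples with the inverted matrix
# restored["image"] approximately recovers data["image"], up to interpolation error
```
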


class RandAffined(RandomizableTransform, MapTransform):
class RandAffined(RandomizableTransform, MapTransform, InvertibleTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.RandAffine`.
"""
@@ -667,13 +695,40 @@ def __call__(
sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:])
if self._do_transform:
grid = self.rand_affine.rand_affine_grid(spatial_size=sp_size)
affine = self.rand_affine.rand_affine_grid.get_transformation_matrix()
else:
grid = create_grid(spatial_size=sp_size)
affine = np.eye(len(sp_size) + 1)

for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
self.push_transform(d, key, extra_info={"affine": affine})
d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode)
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = deepcopy(dict(data))

for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
transform = self.get_most_recent_transform(d, key)
orig_size = transform[InverseKeys.ORIG_SIZE.value]
# Create inverse transform
fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"]
inv_affine = np.linalg.inv(fwd_affine)

affine_grid = AffineGrid(affine=inv_affine)
grid: torch.Tensor = affine_grid(orig_size) # type: ignore

# Apply inverse transform
out = self.rand_affine.resampler(d[key], grid, mode, padding_mode)

# Convert to numpy
d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy()

# Remove the applied transform
self.pop_transform(d, key)

return d
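
`RandAffined` follows the same pattern, recording the identity matrix when the random transform is skipped; a sketch under the same assumptions as above:

```python
import numpy as np
from monai.transforms import RandAffined

data = {"image": np.random.rand(1, 32, 32, 32).astype(np.float32)}
xform = RandAffined(
    keys="image",
    spatial_size=(32, 32, 32),
    prob=1.0,
    rotate_range=(np.pi / 8, 0.0, 0.0),
    padding_mode="zeros",
)
xform.set_random_state(seed=0)
forward = xform(data)              # pushes the sampled (or identity) matrix for each key
restored = xform.inverse(forward)  # resamples with np.linalg.inv of the recorded matrix
```
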


class Rand2DElasticd(RandomizableTransform, MapTransform):
"""
36 changes: 36 additions & 0 deletions tests/test_inverse.py
@@ -24,6 +24,7 @@
from monai.networks.nets import UNet
from monai.transforms import (
AddChanneld,
Affined,
BorderPadd,
CenterSpatialCropd,
Compose,
@@ -33,6 +34,7 @@
InvertibleTransform,
LoadImaged,
Orientationd,
RandAffined,
RandAxisFlipd,
RandFlipd,
Randomizable,
@@ -365,6 +367,40 @@
)
)

TESTS.append(
(
"Affine 3d",
"3D",
1e-1,
Affined(
KEYS,
spatial_size=[155, 179, 192],
rotate_params=[np.pi / 6, -np.pi / 5, np.pi / 7],
shear_params=[0.5, 0.5],
translate_params=[10, 5, -4],
scale_params=[0.8, 1.3],
),
)
)

TESTS.append(
(
"RandAffine 3d",
"3D",
1e-1,
RandAffined(
KEYS,
[155, 179, 192],
prob=1,
padding_mode="zeros",
rotate_range=[np.pi / 6, -np.pi / 5, np.pi / 7],
shear_range=[(0.5, 0.5)],
translate_range=[10, 5, -4],
scale_range=[(0.8, 1.2), (0.9, 1.3)],
),
)
)

TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS]

TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore
2 changes: 2 additions & 0 deletions tests/test_rand_affined.py
@@ -145,6 +145,8 @@ def test_rand_affined(self, input_param, input_data, expected_val):
res = g(input_data)
for key in res:
result = res[key]
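# "_transforms" keys hold the invertible-transform tracking records, not image data, so skip the value checks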
if "_transforms" in key:
continue
expected = expected_val[key] if isinstance(expected_val, dict) else expected_val
self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor))
if isinstance(result, torch.Tensor):