diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py
index 17626e4582..de9bba8e95 100644
--- a/monai/transforms/spatial/array.py
+++ b/monai/transforms/spatial/array.py
@@ -931,6 +931,9 @@ class AffineGrid(Transform):
         as_tensor_output: whether to output tensor instead of numpy array.
             defaults to True.
         device: device to store the output grid data.
+        affine: If applied, ignore the params (`rotate_params`, etc.) and use the
+            supplied matrix. Should be square with each side = num of image spatial
+            dimensions + 1.
 
     """
 
@@ -942,6 +945,7 @@ def __init__(
         scale_params: Optional[Union[Sequence[float], float]] = None,
         as_tensor_output: bool = True,
         device: Optional[torch.device] = None,
+        affine: Optional[Union[np.ndarray, torch.Tensor]] = None,
     ) -> None:
         self.rotate_params = rotate_params
         self.shear_params = shear_params
@@ -951,8 +955,12 @@ def __init__(
         self.as_tensor_output = as_tensor_output
         self.device = device
 
+        self.affine = affine
+
     def __call__(
-        self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None
+        self,
+        spatial_size: Optional[Sequence[int]] = None,
+        grid: Optional[Union[np.ndarray, torch.Tensor]] = None,
     ) -> Union[np.ndarray, torch.Tensor]:
         """
         Args:
@@ -969,27 +977,32 @@ def __call__(
         else:
             raise ValueError("Incompatible values: grid=None and spatial_size=None.")
 
-        spatial_dims = len(grid.shape) - 1
-        affine = np.eye(spatial_dims + 1)
-        if self.rotate_params:
-            affine = affine @ create_rotate(spatial_dims, self.rotate_params)
-        if self.shear_params:
-            affine = affine @ create_shear(spatial_dims, self.shear_params)
-        if self.translate_params:
-            affine = affine @ create_translate(spatial_dims, self.translate_params)
-        if self.scale_params:
-            affine = affine @ create_scale(spatial_dims, self.scale_params)
-        affine = torch.as_tensor(np.ascontiguousarray(affine), device=self.device)
+        if self.affine is None:
+            spatial_dims = len(grid.shape) - 1
+            affine = np.eye(spatial_dims + 1)
+            if self.rotate_params:
+                affine = affine @ create_rotate(spatial_dims, self.rotate_params)
+            if self.shear_params:
+                affine = affine @ create_shear(spatial_dims, self.shear_params)
+            if self.translate_params:
+                affine = affine @ create_translate(spatial_dims, self.translate_params)
+            if self.scale_params:
+                affine = affine @ create_scale(spatial_dims, self.scale_params)
+            self.affine = affine
+
+        self.affine = torch.as_tensor(np.ascontiguousarray(self.affine), device=self.device)
 
         grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone()
         if self.device:
             grid = grid.to(self.device)
-        grid = (affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:]))
+        grid = (self.affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:]))
         if grid is None or not isinstance(grid, torch.Tensor):
             raise ValueError("Unknown grid.")
-        if self.as_tensor_output:
-            return grid
-        return np.asarray(grid.cpu().numpy())
+        return grid if self.as_tensor_output else np.asarray(grid.cpu().numpy())
+
+    def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]:
+        """Get the most recently applied transformation matrix"""
+        return self.affine
 
 
 class RandAffineGrid(RandomizableTransform):
@@ -1040,6 +1053,7 @@ def __init__(
 
         self.as_tensor_output = as_tensor_output
         self.device = device
+        self.affine: Optional[Union[np.ndarray, torch.Tensor]] = None
 
     def _get_rand_param(self, param_range, add_scalar: float = 0.0):
         out_param = []
@@ -1059,7 +1073,9 @@ def randomize(self, data: Optional[Any] = None) -> None:
         self.scale_params = self._get_rand_param(self.scale_range, 1.0)
 
     def __call__(
-        self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None
+        self,
+        spatial_size: Optional[Sequence[int]] = None,
+        grid: Optional[Union[np.ndarray, torch.Tensor]] = None,
    ) -> Union[np.ndarray, torch.Tensor]:
         """
         Args:
@@ -1078,7 +1094,13 @@ def __call__(
             as_tensor_output=self.as_tensor_output,
             device=self.device,
         )
-        return affine_grid(spatial_size, grid)
+        grid = affine_grid(spatial_size, grid)
+        self.affine = affine_grid.get_transformation_matrix()
+        return grid
+
+    def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]:
+        """Get the most recently applied transformation matrix"""
+        return self.affine
 
 
 class RandDeformGrid(RandomizableTransform):
diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py
index 1c6b6a14bc..caa1a34e08 100644
--- a/monai/transforms/spatial/dictionary.py
+++ b/monai/transforms/spatial/dictionary.py
@@ -28,6 +28,7 @@
 from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.spatial.array import (
     Affine,
+    AffineGrid,
     Flip,
     Orientation,
     Rand2DElastic,
@@ -501,7 +502,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
         return d
 
 
-class Affined(RandomizableTransform, MapTransform):
+class Affined(MapTransform, InvertibleTransform):
     """
     Dictionary-based wrapper of :py:class:`monai.transforms.Affine`.
     """
@@ -570,11 +571,38 @@ def __call__(
     ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]:
         d = dict(data)
         for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
+            orig_size = d[key].shape[1:]
             d[key] = self.affine(d[key], mode=mode, padding_mode=padding_mode)
+            affine = self.affine.affine_grid.get_transformation_matrix()
+            self.push_transform(d, key, orig_size=orig_size, extra_info={"affine": affine})
+        return d
+
+    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+        d = deepcopy(dict(data))
+
+        for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
+            transform = self.get_most_recent_transform(d, key)
+            orig_size = transform[InverseKeys.ORIG_SIZE.value]
+            # Create inverse transform
+            fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"]
+            inv_affine = np.linalg.inv(fwd_affine)
+
+            affine_grid = AffineGrid(affine=inv_affine)
+            grid: torch.Tensor = affine_grid(orig_size)  # type: ignore
+
+            # Apply inverse transform
+            out = self.affine.resampler(d[key], grid, mode, padding_mode)
+
+            # Convert to numpy
+            d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy()
+
+            # Remove the applied transform
+            self.pop_transform(d, key)
+
         return d
 
 
-class RandAffined(RandomizableTransform, MapTransform):
+class RandAffined(RandomizableTransform, MapTransform, InvertibleTransform):
     """
     Dictionary-based wrapper of :py:class:`monai.transforms.RandAffine`.
     """
@@ -667,13 +695,40 @@ def __call__(
         sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:])
         if self._do_transform:
             grid = self.rand_affine.rand_affine_grid(spatial_size=sp_size)
+            affine = self.rand_affine.rand_affine_grid.get_transformation_matrix()
         else:
             grid = create_grid(spatial_size=sp_size)
+            affine = np.eye(len(sp_size) + 1)
 
         for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
+            self.push_transform(d, key, extra_info={"affine": affine})
             d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode)
         return d
+
+    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+        d = deepcopy(dict(data))
+
+        for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
+            transform = self.get_most_recent_transform(d, key)
+            orig_size = transform[InverseKeys.ORIG_SIZE.value]
+            # Create inverse transform
+            fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"]
+            inv_affine = np.linalg.inv(fwd_affine)
+
+            affine_grid = AffineGrid(affine=inv_affine)
+            grid: torch.Tensor = affine_grid(orig_size)  # type: ignore
+
+            # Apply inverse transform
+            out = self.rand_affine.resampler(d[key], grid, mode, padding_mode)
+
+            # Convert to numpy
+            d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy()
+
+            # Remove the applied transform
+            self.pop_transform(d, key)
+
+        return d
 
 
 class Rand2DElasticd(RandomizableTransform, MapTransform):
     """
diff --git a/tests/test_inverse.py b/tests/test_inverse.py
index 03e1270ea3..c1225ea11c 100644
--- a/tests/test_inverse.py
+++ b/tests/test_inverse.py
@@ -24,6 +24,7 @@
 from monai.networks.nets import UNet
 from monai.transforms import (
     AddChanneld,
+    Affined,
     BorderPadd,
     CenterSpatialCropd,
     Compose,
@@ -33,6 +34,7 @@
     InvertibleTransform,
     LoadImaged,
     Orientationd,
+    RandAffined,
     RandAxisFlipd,
     RandFlipd,
     Randomizable,
@@ -365,6 +367,40 @@
     )
 )
 
+TESTS.append(
+    (
+        "Affine 3d",
+        "3D",
+        1e-1,
+        Affined(
+            KEYS,
+            spatial_size=[155, 179, 192],
+            rotate_params=[np.pi / 6, -np.pi / 5, np.pi / 7],
+            shear_params=[0.5, 0.5],
+            translate_params=[10, 5, -4],
+            scale_params=[0.8, 1.3],
+        ),
+    )
+)
+
+TESTS.append(
+    (
+        "RandAffine 3d",
+        "3D",
+        1e-1,
+        RandAffined(
+            KEYS,
+            [155, 179, 192],
+            prob=1,
+            padding_mode="zeros",
+            rotate_range=[np.pi / 6, -np.pi / 5, np.pi / 7],
+            shear_range=[(0.5, 0.5)],
+            translate_range=[10, 5, -4],
+            scale_range=[(0.8, 1.2), (0.9, 1.3)],
+        ),
+    )
+)
+
 TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS]
 
 TESTS = TESTS + TESTS_COMPOSE_X2  # type: ignore
diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py
index 54d71ad8f7..ae2adbe3b3 100644
--- a/tests/test_rand_affined.py
+++ b/tests/test_rand_affined.py
@@ -145,6 +145,8 @@ def test_rand_affined(self, input_param, input_data, expected_val):
         res = g(input_data)
         for key in res:
             result = res[key]
+            if "_transforms" in key:
+                continue
             expected = expected_val[key] if isinstance(expected_val, dict) else expected_val
             self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor))
             if isinstance(result, torch.Tensor):