From 9e650107403b13d3eb02252aa32cfb979064ee61 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 18 Mar 2021 08:58:34 +0000 Subject: [PATCH] Inverse Affined and RandAffined (#1781) Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 2 +- monai/transforms/inverse.py | 131 +++++++++++++++++++++++++ monai/transforms/spatial/dictionary.py | 126 ++++++++++++++++++++---- requirements-dev.txt | 1 + tests/test_inverse.py | 48 ++++++++- tests/test_rand_elasticd_2d.py | 2 + tests/test_rand_elasticd_3d.py | 2 + 7 files changed, 290 insertions(+), 22 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 5b12da4d21..b5dced2983 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -138,7 +138,7 @@ ThresholdIntensityD, ThresholdIntensityDict, ) -from .inverse import InvertibleTransform +from .inverse import InvertibleTransform, NonRigidTransform from .io.array import LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .post.array import ( diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index f9de8746ca..fae560669d 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -9,12 +9,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import Dict, Hashable, Optional, Tuple import numpy as np +import torch from monai.transforms.transform import RandomizableTransform, Transform from monai.utils.enums import InverseKeys +from monai.utils.module import optional_import + +sitk, has_sitk = optional_import("SimpleITK") +vtk, has_vtk = optional_import("vtk") +vtk_numpy_support, _ = optional_import("vtk.util.numpy_support") __all__ = ["InvertibleTransform"] @@ -111,3 +118,127 @@ def inverse(self, data: dict) -> Dict[Hashable, np.ndarray]: """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + +class NonRigidTransform(Transform): + @staticmethod + def _get_disp_to_def_arr(shape, spacing): + def_to_disp = np.mgrid[[slice(0, i) for i in shape]].astype(np.float64) + for idx, i in enumerate(shape): + # shift for origin (in MONAI, center of image) + def_to_disp[idx] -= (i - 1) / 2 + # if supplied, account for spacing (e.g., for control point grids) + if spacing is not None: + def_to_disp[idx] *= spacing[idx] + return def_to_disp + + @staticmethod + def _inv_disp_w_sitk(fwd_disp, num_iters): + fwd_disp_sitk = sitk.GetImageFromArray(fwd_disp, isVector=True) + inv_disp_sitk = sitk.InvertDisplacementField(fwd_disp_sitk, num_iters) + inv_disp = sitk.GetArrayFromImage(inv_disp_sitk) + return inv_disp + + @staticmethod + def _inv_disp_w_vtk(fwd_disp): + orig_shape = fwd_disp.shape + required_num_tensor_components = 3 + # VTK requires 3 tensor components, so if shape was (H, W, 2), make it + # (H, W, 1, 3) (i.e., depth 1 with a 3rd tensor component of 0s) + while fwd_disp.shape[-1] < required_num_tensor_components: + fwd_disp = np.append(fwd_disp, np.zeros(fwd_disp.shape[:-1] + (1,)), axis=-1) + fwd_disp = fwd_disp[..., None, :] + + # Create VTKDoubleArray. Shape needs to be (H*W*D, 3) + fwd_disp_flattened = fwd_disp.reshape(-1, required_num_tensor_components) # need to keep this in memory + vtk_data_array = vtk_numpy_support.numpy_to_vtk(fwd_disp_flattened) + + # Generating the vtkImageData + fwd_disp_vtk = vtk.vtkImageData() + fwd_disp_vtk.SetOrigin(0, 0, 0) + fwd_disp_vtk.SetSpacing(1, 1, 1) + fwd_disp_vtk.SetDimensions(*fwd_disp.shape[:-1][::-1]) # VTK spacing opposite order to numpy + fwd_disp_vtk.GetPointData().SetScalars(vtk_data_array) + + if __debug__: + fwd_disp_vtk_np = vtk_numpy_support.vtk_to_numpy(fwd_disp_vtk.GetPointData().GetArray(0)) + assert fwd_disp_vtk_np.size == fwd_disp.size + assert fwd_disp_vtk_np.min() == fwd_disp.min() + assert fwd_disp_vtk_np.max() == fwd_disp.max() + assert fwd_disp_vtk.GetNumberOfScalarComponents() == required_num_tensor_components + + # create b-spline coefficients for the displacement grid + bspline_filter = vtk.vtkImageBSplineCoefficients() + bspline_filter.SetInputData(fwd_disp_vtk) + bspline_filter.Update() + + # use these b-spline coefficients to create a transform + bspline_transform = vtk.vtkBSplineTransform() + bspline_transform.SetCoefficientData(bspline_filter.GetOutput()) + bspline_transform.Update() + + # invert the b-spline transform onto a new grid + grid_maker = vtk.vtkTransformToGrid() + grid_maker.SetInput(bspline_transform.GetInverse()) + grid_maker.SetGridOrigin(fwd_disp_vtk.GetOrigin()) + grid_maker.SetGridSpacing(fwd_disp_vtk.GetSpacing()) + grid_maker.SetGridExtent(fwd_disp_vtk.GetExtent()) + grid_maker.SetGridScalarTypeToFloat() + grid_maker.Update() + + # Get inverse displacement as an image + inv_disp_vtk = grid_maker.GetOutput() + + # Convert back to numpy and reshape + inv_disp = vtk_numpy_support.vtk_to_numpy(inv_disp_vtk.GetPointData().GetArray(0)) + # if there were originally < 3 tensor components, remove the zeros we added at the start + inv_disp = inv_disp[..., : orig_shape[-1]] + # reshape to original + inv_disp = inv_disp.reshape(orig_shape) + + return inv_disp + + @staticmethod + def compute_inverse_deformation( + num_spatial_dims, fwd_def_orig, spacing=None, num_iters: int = 100, use_package: str = "vtk" + ): + """Package can be vtk or sitk.""" + if use_package.lower() == "vtk" and not has_vtk: + warnings.warn("Please install VTK to estimate inverse of non-rigid transforms. Data has not been modified") + return None + if use_package.lower() == "sitk" and not has_sitk: + warnings.warn( + "Please install SimpleITK to estimate inverse of non-rigid transforms. Data has not been modified" + ) + return None + + # Convert to numpy if necessary + if isinstance(fwd_def_orig, torch.Tensor): + fwd_def_orig = fwd_def_orig.cpu().numpy() + # Remove any extra dimensions (we'll add them back in at the end) + fwd_def = fwd_def_orig[:num_spatial_dims] + # Def -> disp + def_to_disp = NonRigidTransform._get_disp_to_def_arr(fwd_def.shape[1:], spacing) + fwd_disp = fwd_def - def_to_disp + # move tensor component to end (T,H,W,[D])->(H,W,[D],T) + fwd_disp = np.moveaxis(fwd_disp, 0, -1) + + # If using vtk... + if use_package.lower() == "vtk": + inv_disp = NonRigidTransform._inv_disp_w_vtk(fwd_disp) + # If using sitk... + elif use_package.lower() == "sitk": + inv_disp = NonRigidTransform._inv_disp_w_sitk(fwd_disp, num_iters) + else: + raise RuntimeError("Enter vtk or sitk for inverse calculation") + + # move tensor component back to beginning + inv_disp = np.moveaxis(inv_disp, -1, 0) + # Disp -> def + inv_def = inv_disp + def_to_disp + # Add back in any removed dimensions + ndim_in = fwd_def_orig.shape[0] + ndim_out = inv_def.shape[0] + inv_def = np.concatenate([inv_def, fwd_def_orig[ndim_out:ndim_in]]) + + return inv_def diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index caa1a34e08..dc6ef816ca 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -25,7 +25,7 @@ from monai.networks.layers import AffineTransform from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad -from monai.transforms.inverse import InvertibleTransform +from monai.transforms.inverse import InvertibleTransform, NonRigidTransform from monai.transforms.spatial.array import ( Affine, AffineGrid, @@ -50,9 +50,9 @@ ensure_tuple, ensure_tuple_rep, fall_back_tuple, + optional_import, ) from monai.utils.enums import InverseKeys -from monai.utils.module import optional_import nib, _ = optional_import("nibabel") @@ -730,7 +730,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d -class Rand2DElasticd(RandomizableTransform, MapTransform): +class Rand2DElasticd(RandomizableTransform, MapTransform, InvertibleTransform, NonRigidTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rand2DElastic`. """ @@ -822,6 +822,17 @@ def randomize(self, spatial_size: Sequence[int]) -> None: super().randomize(None) self.rand_2d_elastic.randomize(spatial_size) + @staticmethod + def cpg_to_dvf(cpg, spacing, output_shape): + grid = torch.nn.functional.interpolate( + recompute_scale_factor=True, + input=cpg.unsqueeze(0), + scale_factor=ensure_tuple_rep(spacing, 2), + mode=InterpolateMode.BICUBIC.value, + align_corners=False, + ) + return CenterSpatialCrop(roi_size=output_shape)(grid[0]) + def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: @@ -831,25 +842,64 @@ def __call__( self.randomize(spatial_size=sp_size) if self._do_transform: - grid = self.rand_2d_elastic.deform_grid(spatial_size=sp_size) - grid = self.rand_2d_elastic.rand_affine_grid(grid=grid) - grid = torch.nn.functional.interpolate( # type: ignore - recompute_scale_factor=True, - input=grid.unsqueeze(0), - scale_factor=ensure_tuple_rep(self.rand_2d_elastic.deform_grid.spacing, 2), - mode=InterpolateMode.BICUBIC.value, - align_corners=False, - ) - grid = CenterSpatialCrop(roi_size=sp_size)(grid[0]) + cpg = self.rand_2d_elastic.deform_grid(spatial_size=sp_size) + cpg_w_affine = self.rand_2d_elastic.rand_affine_grid(grid=cpg) + affine = self.rand_2d_elastic.rand_affine_grid.get_transformation_matrix() + grid = self.cpg_to_dvf(cpg_w_affine, self.rand_2d_elastic.deform_grid.spacing, sp_size) + extra_info: Optional[Dict] = {"cpg": deepcopy(cpg), "affine": deepcopy(affine)} else: grid = create_grid(spatial_size=sp_size) + extra_info = None for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): + self.push_transform(d, key, extra_info=extra_info) d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) return d + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + # This variable will be `not None` if vtk or sitk is present + inv_def_no_affine = None + + for idx, (key, mode, padding_mode) in enumerate(self.key_iterator(d, self.mode, self.padding_mode)): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + if transform[InverseKeys.DO_TRANSFORM.value]: + orig_size = transform[InverseKeys.ORIG_SIZE.value] + # Only need to calculate inverse deformation once as it is the same for all keys + if idx == 0: + # If magnitude == 0, then non-rigid component is identity -- so just create blank + if self.rand_2d_elastic.deform_grid.magnitude == (0.0, 0.0): + inv_def_no_affine = create_grid(spatial_size=orig_size) + else: + fwd_cpg_no_affine = transform[InverseKeys.EXTRA_INFO.value]["cpg"] + fwd_def_no_affine = self.cpg_to_dvf( + fwd_cpg_no_affine, self.rand_2d_elastic.deform_grid.spacing, orig_size + ) + inv_def_no_affine = self.compute_inverse_deformation(len(orig_size), fwd_def_no_affine) + # if inverse did not succeed (sitk or vtk present), data will not be changed. + if inv_def_no_affine is not None: + fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"] + inv_affine = np.linalg.inv(fwd_affine) + inv_def_w_affine_wrong_size = AffineGrid(affine=inv_affine, as_tensor_output=False)( + grid=inv_def_no_affine + ) + # Back to original size + inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine_wrong_size) # type: ignore + # Apply inverse transform + if inv_def_no_affine is not None: + out = self.rand_2d_elastic.resampler(d[key], inv_def_w_affine, mode, padding_mode) + d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out + + else: + d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + + return d + -class Rand3DElasticd(RandomizableTransform, MapTransform): +class Rand3DElasticd(RandomizableTransform, MapTransform, InvertibleTransform, NonRigidTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rand3DElastic`. """ @@ -949,17 +999,55 @@ def __call__( sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, data[self.keys[0]].shape[1:]) self.randomize(grid_size=sp_size) - grid = create_grid(spatial_size=sp_size) + grid_no_affine = create_grid(spatial_size=sp_size) + affine = np.eye(4) if self._do_transform: device = self.rand_3d_elastic.device - grid = torch.tensor(grid).to(device) + grid_no_affine = torch.tensor(grid_no_affine).to(device) gaussian = GaussianFilter(spatial_dims=3, sigma=self.rand_3d_elastic.sigma, truncated=3.0).to(device) offset = torch.tensor(self.rand_3d_elastic.rand_offset, device=device).unsqueeze(0) - grid[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude - grid = self.rand_3d_elastic.rand_affine_grid(grid=grid) + grid_no_affine[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude + grid_w_affine = self.rand_3d_elastic.rand_affine_grid(grid=grid_no_affine) + affine = self.rand_3d_elastic.rand_affine_grid.get_transformation_matrix() + else: + grid_w_affine = grid_no_affine + affine = np.eye(len(sp_size) + 1) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): - d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) + self.push_transform(d, key, extra_info={"grid_no_affine": grid_no_affine, "affine": affine}) + d[key] = self.rand_3d_elastic.resampler(d[key], grid_w_affine, mode=mode, padding_mode=padding_mode) + return d + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for idx, (key, mode, padding_mode) in enumerate(self.key_iterator(d, self.mode, self.padding_mode)): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + if transform[InverseKeys.DO_TRANSFORM.value]: + orig_size = transform[InverseKeys.ORIG_SIZE.value] + # Only need to calculate inverse deformation once as it is the same for all keys + if idx == 0: + fwd_def_no_affine = transform[InverseKeys.EXTRA_INFO.value]["grid_no_affine"] + inv_def_no_affine = self.compute_inverse_deformation(len(orig_size), fwd_def_no_affine) + # if inverse did not succeed (sitk or vtk present), data will not be changed. + if inv_def_no_affine is not None: + fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"] + inv_affine = np.linalg.inv(fwd_affine) + inv_def_w_affine_wrong_size = AffineGrid(affine=inv_affine, as_tensor_output=False)( + grid=inv_def_no_affine + ) + # Back to original size + inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine_wrong_size) # type: ignore + # Apply inverse transform + if inv_def_w_affine is not None: + out = self.rand_3d_elastic.resampler(d[key], inv_def_w_affine, mode, padding_mode) + d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out + else: + d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + return d diff --git a/requirements-dev.txt b/requirements-dev.txt index dc4181b310..e8c63d48d7 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -32,3 +32,4 @@ sphinx-autodoc-typehints==1.11.1 sphinx-rtd-theme==0.5.0 cucim==0.18.1 openslide-python==1.1.2 +vtk diff --git a/tests/test_inverse.py b/tests/test_inverse.py index c1225ea11c..d5574ce34d 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -34,6 +34,8 @@ InvertibleTransform, LoadImaged, Orientationd, + Rand2DElasticd, + Rand3DElasticd, RandAffined, RandAxisFlipd, RandFlipd, @@ -55,12 +57,13 @@ ) from monai.utils import first, get_seed, optional_import, set_determinism from monai.utils.enums import InverseKeys -from tests.utils import make_nifti_image, make_rand_affine +from tests.utils import make_nifti_image, make_rand_affine, test_is_quick if TYPE_CHECKING: - + has_vtk = True has_nib = True else: + _, has_vtk = optional_import("vtk") _, has_nib = optional_import("nibabel") KEYS = ["image", "label"] @@ -401,6 +404,47 @@ ) ) +if has_vtk: + TESTS.append( + ( + "Rand2DElasticd 2d", + "2D", + 2e-1, + Rand2DElasticd( + KEYS, + spacing=(10.0, 10.0), + magnitude_range=(1, 1), + spatial_size=(155, 192), + prob=1, + padding_mode="zeros", + rotate_range=[(np.pi / 6, np.pi / 6)], + shear_range=[(0.5, 0.5)], + translate_range=[10, 5], + scale_range=[(0.2, 0.2), (0.3, 0.3)], + ), + ) + ) + +if not test_is_quick() and has_vtk: + TESTS.append( + ( + "Rand3DElasticd 3d", + "3D", + 1e-1, + Rand3DElasticd( + KEYS, + sigma_range=(3, 5), + magnitude_range=(100, 100), + prob=1, + padding_mode="zeros", + rotate_range=[np.pi / 6, np.pi / 7], + shear_range=[(0.5, 0.5), 0.2], + translate_range=[10, 5, 3], + scale_range=[(0.8, 1.2), (0.9, 1.3)], + ), + ) + ) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore diff --git a/tests/test_rand_elasticd_2d.py b/tests/test_rand_elasticd_2d.py index f8eb026088..88f2438606 100644 --- a/tests/test_rand_elasticd_2d.py +++ b/tests/test_rand_elasticd_2d.py @@ -142,6 +142,8 @@ def test_rand_2d_elasticd(self, input_param, input_data, expected_val): g.set_random_state(123) res = g(input_data) for key in res: + if "_transforms" in key: + continue result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) diff --git a/tests/test_rand_elasticd_3d.py b/tests/test_rand_elasticd_3d.py index 47ab814882..cf9f56c109 100644 --- a/tests/test_rand_elasticd_3d.py +++ b/tests/test_rand_elasticd_3d.py @@ -113,6 +113,8 @@ def test_rand_3d_elasticd(self, input_param, input_data, expected_val): g.set_random_state(123) res = g(input_data) for key in res: + if "_transforms" in key: + continue result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor))