From d2a040bfe9cfdc5c2798a450bce22aeff2ec7b95 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 1 Mar 2021 13:20:43 +0000 Subject: [PATCH 01/64] inverse transformations Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 1 + monai/transforms/compose.py | 21 +- monai/transforms/croppad/dictionary.py | 222 ++++++++++- monai/transforms/inverse_transform.py | 213 ++++++++++ monai/transforms/spatial/array.py | 96 +++-- monai/transforms/spatial/dictionary.py | 484 +++++++++++++++++++++-- requirements-dev.txt | 1 + tests/test_inverse.py | 512 +++++++++++++++++++++++++ tests/test_spacingd.py | 8 +- 9 files changed, 1475 insertions(+), 83 deletions(-) create mode 100644 monai/transforms/inverse_transform.py create mode 100644 tests/test_inverse.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 8b30d76bec..718e7e32dd 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -138,6 +138,7 @@ ThresholdIntensityD, ThresholdIntensityDict, ) +from .inverse_transform import InvertibleTransform, NonRigidTransform from .io.array import LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .post.array import ( diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index a9f66b12a0..678667bb0a 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -13,10 +13,13 @@ """ import warnings -from typing import Any, Callable, Optional, Sequence, Union +from copy import deepcopy +from typing import Any, Callable, Hashable, Mapping, Optional, Sequence, Tuple, Union import numpy as np +from monai.transforms.inverse_transform import InvertibleTransform + # For backwards compatiblity (so this still works: from monai.transforms.compose import MapTransform) from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform, Transform # noqa: F401 from monai.transforms.utils import apply_transform @@ -25,7 +28,7 @@ __all__ = ["Compose"] -class Compose(RandomizableTransform): +class Compose(RandomizableTransform, InvertibleTransform): """ ``Compose`` provides the ability to chain a series of calls together in a sequence. Each transform in the sequence must take a single argument and @@ -136,3 +139,17 @@ def __call__(self, input_): for _transform in self.transforms: input_ = apply_transform(_transform, input_) return input_ + + def inverse(self, data, keys: Optional[Tuple[Hashable, ...]] = None): + if not isinstance(data, Mapping): + raise RuntimeError("Inverse method only available for dictionary transforms") + d = deepcopy(dict(data)) + if keys: + keys = ensure_tuple(keys) + + # loop backwards over transforms + for t in reversed(self.transforms): + # check if transform is one of the invertible ones + if isinstance(t, InvertibleTransform): + d = t.inverse(d, keys) + return d diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 9739c6322f..aee5a7039c 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -15,6 +15,8 @@ Class names are ended with 'd' to denote dictionary-based transforms. 
""" +from copy import deepcopy +from math import floor from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np @@ -30,6 +32,7 @@ SpatialCrop, SpatialPad, ) +from monai.transforms.inverse_transform import InvertibleTransform from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform from monai.transforms.utils import ( generate_pos_neg_label_crop_centers, @@ -82,7 +85,7 @@ NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str] -class SpatialPadd(MapTransform): +class SpatialPadd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.SpatialPad`. Performs padding to the data, symmetric for all sides or all on one side for each dimension. @@ -117,11 +120,34 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in zip(self.keys, self.mode): + self.append_applied_transforms(d, key) d[key] = self.padder(d[key], mode=m) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = transform["orig_size"] + if self.padder.method == Method.SYMMETRIC: + current_size = d[key].shape[1:] + roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)] + else: + roi_center = [floor(r / 2) if r % 2 == 0 else (r - 1) // 2 for r in orig_size] + + inverse_transform = SpatialCrop(roi_center, orig_size) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d -class BorderPadd(MapTransform): + +class BorderPadd(MapTransform, InvertibleTransform): """ Pad the input data by adding specified borders to every dimension. Dictionary-based wrapper of :py:class:`monai.transforms.BorderPad`. @@ -162,11 +188,38 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in zip(self.keys, self.mode): + self.append_applied_transforms(d, key) d[key] = self.padder(d[key], mode=m) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = np.array(transform["orig_size"]) + roi_start = np.array(self.padder.spatial_border) + # Need to convert single value to [min1,min2,...] + if roi_start.size == 1: + roi_start = np.full((len(orig_size)), roi_start) + # need to convert [min1,max1,min2,...] to [min1,min2,...] + elif roi_start.size == 2 * orig_size.size: + roi_start = roi_start[::2] + roi_end = np.array(transform["orig_size"]) + roi_start + + inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + -class DivisiblePadd(MapTransform): +class DivisiblePadd(MapTransform, InvertibleTransform): """ Pad the input data, so that the spatial sizes are divisible by `k`. Dictionary-based wrapper of :py:class:`monai.transforms.DivisiblePad`. 
@@ -198,11 +251,32 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in zip(self.keys, self.mode): + self.append_applied_transforms(d, key) d[key] = self.padder(d[key], mode=m) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = np.array(transform["orig_size"]) + current_size = np.array(d[key].shape[1:]) + roi_start = np.floor((current_size - orig_size) / 2) + roi_end = orig_size + roi_start + inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + -class SpatialCropd(MapTransform): +class SpatialCropd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.SpatialCrop`. Either a spatial center and size must be provided, or alternatively if center and size @@ -232,11 +306,35 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: + self.append_applied_transforms(d, key) d[key] = self.cropper(d[key]) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = transform["orig_size"] + pad_to_start = self.cropper.roi_start + pad_to_end = orig_size - self.cropper.roi_end + # interweave mins and maxes + pad = np.empty((2 * len(orig_size)), dtype=np.int32) + pad[0::2] = pad_to_start + pad[1::2] = pad_to_end + inverse_transform = BorderPad(pad.tolist()) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + -class CenterSpatialCropd(MapTransform): +class CenterSpatialCropd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.CenterSpatialCrop`. 
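The crop inverses above, and the centred variants that follow, all reduce to the same bookkeeping: work out how many voxels were removed at each end of each axis, interleave those counts into BorderPad's [min1, max1, min2, max2, ...] ordering, and pad back. A self-contained numpy sketch of that step, with made-up sizes:

import numpy as np

orig_size = np.array([100, 101])   # spatial size before the crop
roi_start = np.array([49, 51])     # where the crop began
roi_end = np.array([90, 89])       # where the crop ended (exclusive)

pad_to_start = roi_start           # voxels removed before the ROI
pad_to_end = orig_size - roi_end   # voxels removed after the ROI

# interweave mins and maxes into BorderPad's expected layout
pad = np.empty(2 * len(orig_size), dtype=np.int32)
pad[0::2] = pad_to_start
pad[1::2] = pad_to_end
print(pad.tolist())                # [49, 10, 51, 12]

Note that the padded-back border is zero-filled, so these inverses restore shape and position rather than the discarded content.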
@@ -254,11 +352,38 @@ def __init__(self, keys: KeysCollection, roi_size: Union[Sequence[int], int]) -> def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: + orig_size = d[key].shape[1:] d[key] = self.cropper(d[key]) + self.append_applied_transforms(d, key, orig_size=orig_size) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = np.array(transform["orig_size"]) + current_size = np.array(d[key].shape[1:]) + pad_to_start = np.floor((orig_size - current_size) / 2) + # in each direction, if original size is even and current size is odd, += 1 + pad_to_start[np.logical_and(orig_size % 2 == 0, current_size % 2 == 1)] += 1 + pad_to_end = orig_size - current_size - pad_to_start + pad = np.empty((2 * len(orig_size)), dtype=np.int32) + pad[0::2] = pad_to_start + pad[1::2] = pad_to_end + inverse_transform = BorderPad(pad.tolist()) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) -class RandSpatialCropd(RandomizableTransform, MapTransform): + return d + + +class RandSpatialCropd(RandomizableTransform, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandSpatialCrop`. Crop image with random size or specific size ROI. It can crop at a random position as @@ -285,6 +410,7 @@ def __init__( ) -> None: RandomizableTransform.__init__(self) MapTransform.__init__(self, keys) + self._do_transform = True self.roi_size = roi_size self.random_center = random_center self.random_size = random_size @@ -306,12 +432,50 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda raise AssertionError for key in self.keys: if self.random_center: + self.append_applied_transforms(d, key, {"slices": [(i.start, i.stop) for i in self._slices[1:]]}) # type: ignore d[key] = d[key][self._slices] else: + self.append_applied_transforms(d, key) cropper = CenterSpatialCrop(self._size) d[key] = cropper(d[key]) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = transform["orig_size"] + random_center = self.random_center + pad_to_start = np.empty((len(orig_size)), dtype=np.int32) + pad_to_end = np.empty((len(orig_size)), dtype=np.int32) + if random_center: + for i, _slice in enumerate(transform["extra_info"]["slices"]): + pad_to_start[i] = _slice[0] + pad_to_end[i] = orig_size[i] - _slice[1] + else: + current_size = d[key].shape[1:] + for i, (o_s, c_s) in enumerate(zip(orig_size, current_size)): + pad_to_start[i] = pad_to_end[i] = (o_s - c_s) / 2 + if o_s % 2 == 0 and c_s % 2 == 1: + pad_to_start[i] += 1 + elif o_s % 2 == 1 and c_s % 2 == 0: + pad_to_end[i] += 1 + # interweave mins and maxes + pad = np.empty((2 * len(orig_size)), dtype=np.int32) + pad[0::2] = pad_to_start + pad[1::2] = pad_to_end + inverse_transform = BorderPad(pad.tolist()) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + class 
RandSpatialCropSamplesd(RandomizableTransform, MapTransform): """ @@ -366,7 +530,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n return [self.cropper(data) for _ in range(self.num_samples)] -class CropForegroundd(MapTransform): +class CropForegroundd(MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.CropForeground`. Crop only the foreground object of the expected images. @@ -418,9 +582,33 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[self.end_coord_key] = np.asarray(box_end) cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) for key in self.keys: + self.append_applied_transforms(d, key, extra_info={"box_start": box_start, "box_end": box_end}) d[key] = cropper(d[key]) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = np.array(transform["orig_size"]) + extra_info = transform["extra_info"] + pad_to_start = np.array(extra_info["box_start"]) + pad_to_end = orig_size - np.array(extra_info["box_end"]) + # interweave mins and maxes + pad = np.empty((2 * len(orig_size)), dtype=np.int32) + pad[0::2] = pad_to_start + pad[1::2] = pad_to_end + inverse_transform = BorderPad(pad.tolist()) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + class RandWeightedCropd(RandomizableTransform, MapTransform): """ @@ -596,7 +784,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n return results -class ResizeWithPadOrCropd(MapTransform): +class ResizeWithPadOrCropd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.ResizeWithPadOrCrop`. @@ -624,7 +812,25 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: + orig_size = d[key].shape[1:] d[key] = self.padcropper(d[key]) + self.append_applied_transforms(d, key, orig_size=orig_size) + return d + + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = transform["orig_size"] + inverse_transform = ResizeWithPadOrCrop(spatial_size=orig_size, mode=self.padcropper.padder.mode) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + return d diff --git a/monai/transforms/inverse_transform.py b/monai/transforms/inverse_transform.py new file mode 100644 index 0000000000..38919d1fc0 --- /dev/null +++ b/monai/transforms/inverse_transform.py @@ -0,0 +1,213 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import Dict, Hashable, Optional, Tuple + +import numpy as np +import torch + +from monai.transforms.transform import RandomizableTransform, Transform +from monai.utils import optional_import + +sitk, has_sitk = optional_import("SimpleITK") +vtk, has_vtk = optional_import("vtk") +vtk_numpy_support, _ = optional_import("vtk.util.numpy_support") + +__all__ = ["InvertibleTransform", "NonRigidTransform"] + + +class InvertibleTransform(Transform): + """Classes for invertible transforms. + + This class exists so that an ``invert`` method can be implemented. This allows, for + example, images to be cropped, rotated, padded, etc., during training and inference, + and after be returned to their original size before saving to file for comparison in + an external viewer. + + When the `__call__` method is called, a serialization of the class is stored. When + the `inverse` method is called, the serialization is then removed. We use last in, + first out for the inverted transforms. + """ + + def append_applied_transforms( + self, + data: dict, + key: Hashable, + extra_info: Optional[dict] = None, + orig_size: Optional[Tuple] = None, + ) -> None: + """Append to list of applied transforms for that key.""" + key_transform = str(key) + "_transforms" + info = { + "class": self.__class__.__name__, + "id": id(self), + "orig_size": orig_size or data[key].shape[1:], + } + if extra_info is not None: + info["extra_info"] = extra_info + # If class is randomizable transform, store whether the transform was actually performed (based on `prob`) + if isinstance(self, RandomizableTransform): + info["do_transform"] = self._do_transform + # If this is the first, create list + if key_transform not in data: + data[key_transform] = [] + data[key_transform].append(info) + + def check_transforms_match(self, transform: dict) -> None: + # Check transorms are of same type. + if transform["id"] != id(self): + raise RuntimeError("Should inverse most recently applied invertible transform first") + + def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: + """Get most recent transform.""" + transform = dict(data[str(key) + "_transforms"][-1]) + self.check_transforms_match(transform) + return transform + + @staticmethod + def remove_most_recent_transform(data: dict, key: Hashable) -> None: + """Remove most recent transform.""" + data[str(key) + "_transforms"].pop() + + def inverse(self, data: dict, keys: Optional[Tuple[Hashable, ...]] = None) -> Dict[Hashable, np.ndarray]: + """ + Inverse of ``__call__``. + + Raises: + NotImplementedError: When the subclass does not override this method. 
+ + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + +class NonRigidTransform(Transform): + @staticmethod + def _get_disp_to_def_arr(shape, spacing): + def_to_disp = np.mgrid[[slice(0, i) for i in shape]].astype(np.float64) + for idx, i in enumerate(shape): + # shift for origin (in MONAI, center of image) + def_to_disp[idx] -= (i - 1) / 2 + # if supplied, account for spacing (e.g., for control point grids) + if spacing is not None: + def_to_disp[idx] *= spacing[idx] + return def_to_disp + + @staticmethod + def _inv_disp_w_sitk(fwd_disp, num_iters): + fwd_disp_sitk = sitk.GetImageFromArray(fwd_disp, isVector=True) + inv_disp_sitk = sitk.InvertDisplacementField(fwd_disp_sitk, num_iters) + inv_disp = sitk.GetArrayFromImage(inv_disp_sitk) + return inv_disp + + @staticmethod + def _inv_disp_w_vtk(fwd_disp): + orig_shape = fwd_disp.shape + required_num_tensor_components = 3 + # VTK requires 3 tensor components, so if shape was (H, W, 2), make it + # (H, W, 1, 3) (i.e., depth 1 with a 3rd tensor component of 0s) + while fwd_disp.shape[-1] < required_num_tensor_components: + fwd_disp = np.append(fwd_disp, np.zeros(fwd_disp.shape[:-1] + (1,)), axis=-1) + fwd_disp = fwd_disp[..., None, :] + + # Create VTKDoubleArray. Shape needs to be (H*W*D, 3) + fwd_disp_flattened = fwd_disp.reshape(-1, required_num_tensor_components) # need to keep this in memory + vtk_data_array = vtk_numpy_support.numpy_to_vtk(fwd_disp_flattened) + + # Generating the vtkImageData + fwd_disp_vtk = vtk.vtkImageData() + fwd_disp_vtk.SetOrigin(0, 0, 0) + fwd_disp_vtk.SetSpacing(1, 1, 1) + fwd_disp_vtk.SetDimensions(*fwd_disp.shape[:-1][::-1]) # VTK spacing opposite order to numpy + fwd_disp_vtk.GetPointData().SetScalars(vtk_data_array) + + if __debug__: + fwd_disp_vtk_np = vtk_numpy_support.vtk_to_numpy(fwd_disp_vtk.GetPointData().GetArray(0)) + assert fwd_disp_vtk_np.size == fwd_disp.size + assert fwd_disp_vtk_np.min() == fwd_disp.min() + assert fwd_disp_vtk_np.max() == fwd_disp.max() + assert fwd_disp_vtk.GetNumberOfScalarComponents() == required_num_tensor_components + + # create b-spline coefficients for the displacement grid + bspline_filter = vtk.vtkImageBSplineCoefficients() + bspline_filter.SetInputData(fwd_disp_vtk) + bspline_filter.Update() + + # use these b-spline coefficients to create a transform + bspline_transform = vtk.vtkBSplineTransform() + bspline_transform.SetCoefficientData(bspline_filter.GetOutput()) + bspline_transform.Update() + + # invert the b-spline transform onto a new grid + grid_maker = vtk.vtkTransformToGrid() + grid_maker.SetInput(bspline_transform.GetInverse()) + grid_maker.SetGridOrigin(fwd_disp_vtk.GetOrigin()) + grid_maker.SetGridSpacing(fwd_disp_vtk.GetSpacing()) + grid_maker.SetGridExtent(fwd_disp_vtk.GetExtent()) + grid_maker.SetGridScalarTypeToFloat() + grid_maker.Update() + + # Get inverse displacement as an image + inv_disp_vtk = grid_maker.GetOutput() + + # Convert back to numpy and reshape + inv_disp = vtk_numpy_support.vtk_to_numpy(inv_disp_vtk.GetPointData().GetArray(0)) + # if there were originally < 3 tensor components, remove the zeros we added at the start + inv_disp = inv_disp[..., : orig_shape[-1]] + # reshape to original + inv_disp = inv_disp.reshape(orig_shape) + + return inv_disp + + @staticmethod + def compute_inverse_deformation( + num_spatial_dims, fwd_def_orig, spacing=None, num_iters: int = 100, use_package: str = "vtk" + ): + """Package can be vtk or sitk.""" + if use_package.lower() == "vtk" and not has_vtk: + 
warnings.warn("Please install VTK to estimate inverse of non-rigid transforms. Data has not been modified") + return None + if use_package.lower() == "sitk" and not has_sitk: + warnings.warn( + "Please install SimpleITK to estimate inverse of non-rigid transforms. Data has not been modified" + ) + return None + + # Convert to numpy if necessary + if isinstance(fwd_def_orig, torch.Tensor): + fwd_def_orig = fwd_def_orig.cpu().numpy() + # Remove any extra dimensions (we'll add them back in at the end) + fwd_def = fwd_def_orig[:num_spatial_dims] + # Def -> disp + def_to_disp = NonRigidTransform._get_disp_to_def_arr(fwd_def.shape[1:], spacing) + fwd_disp = fwd_def - def_to_disp + # move tensor component to end (T,H,W,[D])->(H,W,[D],T) + fwd_disp = np.moveaxis(fwd_disp, 0, -1) + + # If using vtk... + if use_package.lower() == "vtk": + inv_disp = NonRigidTransform._inv_disp_w_vtk(fwd_disp) + # If using sitk... + elif use_package.lower() == "sitk": + inv_disp = NonRigidTransform._inv_disp_w_sitk(fwd_disp, num_iters) + else: + raise RuntimeError("Enter vtk or sitk for inverse calculation") + + # move tensor component back to beginning + inv_disp = np.moveaxis(inv_disp, -1, 0) + # Disp -> def + inv_def = inv_disp + def_to_disp + # Add back in any removed dimensions + ndim_in = fwd_def_orig.shape[0] + ndim_out = inv_def.shape[0] + inv_def = np.concatenate([inv_def, fwd_def_orig[ndim_out:ndim_in]]) + + return inv_def diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 3559d0eb3c..b80c5ba4a9 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -150,7 +150,7 @@ def __call__( ValueError: When ``pixdim`` is nonpositive. Returns: - data_array (resampled into `self.pixdim`), original pixdim, current pixdim. + data_array (resampled into `self.pixdim`), original affine, current affine. """ _dtype = dtype or self.dtype or data_array.dtype @@ -280,7 +280,7 @@ def __call__( ornt[:, 0] += 1 # skip channel dim ornt = np.concatenate([np.array([[0, 1]]), ornt]) shape = data_array.shape[1:] - data_array = nib.orientations.apply_orientation(data_array, ornt) + data_array = np.ascontiguousarray(nib.orientations.apply_orientation(data_array, ornt)) new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, shape) new_affine = to_affine_nd(affine, new_affine) return data_array, affine, new_affine @@ -316,7 +316,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: class Resize(Transform): """ - Resize the input image to given spatial size. + Resize the input image to given spatial size (with scaling, not cropping/padding). Implemented using :py:class:`torch.nn.functional.interpolate`. Args: @@ -428,7 +428,8 @@ def __call__( padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, dtype: DtypeLike = None, - ) -> np.ndarray: + return_rotation_matrix: bool = False, + ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: """ Args: img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D]. @@ -445,6 +446,7 @@ def __call__( dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. + return_rotation_matrix: whether or not to return the applied rotation matrix. Raises: ValueError: When ``img`` spatially is not one of [2D, 3D]. 
@@ -481,7 +483,10 @@ def __call__( torch.as_tensor(np.ascontiguousarray(transform).astype(_dtype)), spatial_size=output_shape, ) - return np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + output_np = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + if return_rotation_matrix: + return output_np, transform + return output_np class Zoom(Transform): @@ -589,7 +594,7 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: If axis is negative it counts from the last to the first axis. """ self.k = k - spatial_axes_ = ensure_tuple(spatial_axes) + spatial_axes_: Tuple[int, int] = ensure_tuple(spatial_axes) # type: ignore if len(spatial_axes_) != 2: raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.") self.spatial_axes = spatial_axes_ @@ -619,7 +624,7 @@ def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: Tuple[int, i spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. """ - RandomizableTransform.__init__(self, min(max(prob, 0.0), 1.0)) + RandomizableTransform.__init__(self, prob) self.max_k = max_k self.spatial_axes = spatial_axes @@ -742,7 +747,7 @@ def __call__( align_corners=self.align_corners if align_corners is None else align_corners, dtype=dtype or self.dtype or img.dtype, ) - return rotator(img) + return np.array(rotator(img)) class RandFlip(RandomizableTransform): @@ -757,7 +762,7 @@ class RandFlip(RandomizableTransform): """ def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: - RandomizableTransform.__init__(self, min(max(prob, 0.0), 1.0)) + RandomizableTransform.__init__(self, prob) self.flipper = Flip(spatial_axis=spatial_axis) def __call__(self, img: np.ndarray) -> np.ndarray: @@ -893,6 +898,9 @@ class AffineGrid(Transform): as_tensor_output: whether to output tensor instead of numpy array. defaults to True. device: device to store the output grid data. + affine: If applied, ignore the params (`rotate_params`, etc.) and use the + supplied matrix. Should be square with each side = num of image spatial + dimensions + 1. """ @@ -904,6 +912,7 @@ def __init__( scale_params: Optional[Union[Sequence[float], float]] = None, as_tensor_output: bool = True, device: Optional[torch.device] = None, + affine: Optional[Union[np.ndarray, torch.Tensor]] = None, ) -> None: self.rotate_params = rotate_params self.shear_params = shear_params @@ -913,13 +922,19 @@ def __init__( self.as_tensor_output = as_tensor_output self.device = device + self.affine = affine + def __call__( - self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None - ) -> Union[np.ndarray, torch.Tensor]: + self, + spatial_size: Optional[Sequence[int]] = None, + grid: Optional[Union[np.ndarray, torch.Tensor]] = None, + return_affine: bool = False, + ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: """ Args: spatial_size: output grid size. grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. + return_affine: boolean as to whether to return the generated affine matrix or not. Raises: ValueError: When ``grid=None`` and ``spatial_size=None``. Incompatible values. 
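With return_rotation_matrix=True, Rotate now also hands back the homogeneous matrix it resampled with; that matrix is what the dictionary wrappers store and later invert with np.linalg.inv. A usage sketch assuming the patched signature; the angle and image are illustrative:

import numpy as np
from monai.transforms import Rotate

img = np.random.rand(1, 64, 64).astype(np.float32)   # channel-first 2D image

rotated, rot_mat = Rotate(angle=np.pi / 6)(img, return_rotation_matrix=True)
print(rotated.shape)      # (1, 64, 64) with the default keep_size=True
print(rot_mat.shape)      # (3, 3) homogeneous matrix for a 2D image

inv_rot_mat = np.linalg.inv(rot_mat)  # feeding this to AffineTransform undoes the rotation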
@@ -931,16 +946,20 @@ def __call__( else: raise ValueError("Incompatible values: grid=None and spatial_size=None.") - spatial_dims = len(grid.shape) - 1 - affine = np.eye(spatial_dims + 1) - if self.rotate_params: - affine = affine @ create_rotate(spatial_dims, self.rotate_params) - if self.shear_params: - affine = affine @ create_shear(spatial_dims, self.shear_params) - if self.translate_params: - affine = affine @ create_translate(spatial_dims, self.translate_params) - if self.scale_params: - affine = affine @ create_scale(spatial_dims, self.scale_params) + affine: Union[np.ndarray, torch.Tensor] + if self.affine is None: + spatial_dims = len(grid.shape) - 1 + affine = np.eye(spatial_dims + 1) + if self.rotate_params: + affine = affine @ create_rotate(spatial_dims, self.rotate_params) + if self.shear_params: + affine = affine @ create_shear(spatial_dims, self.shear_params) + if self.translate_params: + affine = affine @ create_translate(spatial_dims, self.translate_params) + if self.scale_params: + affine = affine @ create_scale(spatial_dims, self.scale_params) + else: + affine = self.affine affine = torch.as_tensor(np.ascontiguousarray(affine), device=self.device) grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() @@ -949,9 +968,10 @@ def __call__( grid = (affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:])) if grid is None or not isinstance(grid, torch.Tensor): raise ValueError("Unknown grid.") - if self.as_tensor_output: - return grid - return np.asarray(grid.cpu().numpy()) + output: Union[np.ndarray, torch.Tensor] = grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()) + if return_affine: + return output, affine + return output class RandAffineGrid(RandomizableTransform): @@ -1021,12 +1041,16 @@ def randomize(self, data: Optional[Any] = None) -> None: self.scale_params = self._get_rand_param(self.scale_range, 1.0) def __call__( - self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None - ) -> Union[np.ndarray, torch.Tensor]: + self, + spatial_size: Optional[Sequence[int]] = None, + grid: Optional[Union[np.ndarray, torch.Tensor]] = None, + return_affine: bool = False, + ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: """ Args: spatial_size: output grid size. grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. + return_affine: boolean as to whether to return the generated affine matrix or not. Returns: a 2D (3xHxW) or 3D (4xHxWxD) grid. 
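The new affine argument lets AffineGrid skip parameter composition and build a grid straight from a supplied matrix, which is how the dictionary inverses later in this patch replay a stored transform backwards. A sketch under the patched signature; the parameters are illustrative:

import numpy as np
from monai.transforms.spatial.array import AffineGrid

spatial_size = (32, 32)

# forward grid plus the homogeneous matrix composed from the parameters
fwd_grid, fwd_affine = AffineGrid(rotate_params=0.3, as_tensor_output=False)(
    spatial_size=spatial_size, return_affine=True
)

# building the inverse grid is just a matter of inverting that matrix
inv_affine = np.linalg.inv(np.asarray(fwd_affine))
inv_grid = AffineGrid(affine=inv_affine, as_tensor_output=False)(spatial_size=spatial_size)
print(fwd_grid.shape, inv_grid.shape)  # (3, 32, 32) (3, 32, 32)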
@@ -1040,7 +1064,7 @@ def __call__( as_tensor_output=self.as_tensor_output, device=self.device, ) - return affine_grid(spatial_size, grid) + return affine_grid(spatial_size, grid, return_affine) class RandDeformGrid(RandomizableTransform): @@ -1266,7 +1290,7 @@ def __call__( See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) - grid = self.affine_grid(spatial_size=sp_size) + grid: torch.Tensor = self.affine_grid(spatial_size=sp_size) # type: ignore return self.resampler( img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode ) @@ -1357,7 +1381,8 @@ def __call__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + return_affine: bool = False, + ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1372,17 +1397,26 @@ def __call__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + return_affine: boolean as to whether to return the generated affine matrix or not. """ self.randomize() sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) + affine = np.eye(len(sp_size) + 1) if self._do_transform: - grid = self.rand_affine_grid(spatial_size=sp_size) + out = self.rand_affine_grid(spatial_size=sp_size) + if return_affine: + grid, affine = out + else: + grid = out else: grid = create_grid(spatial_size=sp_size) - return self.resampler( + resampled = self.resampler( img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode ) + if return_affine: + return resampled, affine + return resampled class Rand2DElastic(RandomizableTransform): diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 6693d75bcd..3b654c67f6 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -15,15 +15,19 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ +from copy import deepcopy from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch from monai.config import DtypeLike, KeysCollection +from monai.networks.layers import AffineTransform from monai.networks.layers.simplelayers import GaussianFilter -from monai.transforms.croppad.array import CenterSpatialCrop +from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad +from monai.transforms.inverse_transform import InvertibleTransform, NonRigidTransform from monai.transforms.spatial.array import ( + AffineGrid, Flip, Orientation, Rand2DElastic, @@ -45,8 +49,11 @@ ensure_tuple, ensure_tuple_rep, fall_back_tuple, + optional_import, ) +nib, _ = optional_import("nibabel") + __all__ = [ "Spacingd", "Orientationd", @@ -98,7 +105,7 @@ NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str] -class Spacingd(MapTransform): +class Spacingd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Spacing`. 
@@ -177,10 +184,11 @@ def __call__( ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: d: Dict = dict(data) for idx, key in enumerate(self.keys): - meta_data = d[f"{key}_{self.meta_key_postfix}"] + meta_data_key = f"{key}_{self.meta_key_postfix}" + meta_data = d[meta_data_key] # resample array of each corresponding key # using affine fetched from d[affine_key] - d[key], _, new_affine = self.spacing_transform( + d[key], old_affine, new_affine = self.spacing_transform( data_array=np.asarray(d[key]), affine=meta_data["affine"], mode=self.mode[idx], @@ -188,12 +196,46 @@ def __call__( align_corners=self.align_corners[idx], dtype=self.dtype[idx], ) + self.append_applied_transforms( + d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine} + ) # set the 'affine' key meta_data["affine"] = new_affine return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for idx, key in enumerate(keys or self.keys): + transform = self.get_most_recent_transform(d, key) + if self.spacing_transform.diagonal: + raise RuntimeError( + "Spacingd:inverse not yet implemented for diagonal=True. " + + "Please raise a github issue if you need this feature" + ) + # Create inverse transform + meta_data = d[transform["extra_info"]["meta_data_key"]] + old_affine = np.array(transform["extra_info"]["old_affine"]) + orig_pixdim = np.sqrt(np.sum(np.square(old_affine), 0))[:-1] + inverse_transform = Spacing(orig_pixdim, diagonal=self.spacing_transform.diagonal) + # Apply inverse + d[key], _, new_affine = inverse_transform( + data_array=np.asarray(d[key]), + affine=meta_data["affine"], + mode=self.mode[idx], + padding_mode=self.padding_mode[idx], + align_corners=self.align_corners[idx], + dtype=self.dtype[idx], + ) + meta_data["affine"] = new_affine + # Remove the applied transform + self.remove_most_recent_transform(d, key) -class Orientationd(MapTransform): + return d + + +class Orientationd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Orientation`. 
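Spacingd.inverse recovers the original pixdim from the affine it stored on the way in: for an affine built from rotation and scaling, the Euclidean norm of each spatial column is the voxel spacing along that axis. A quick numpy check of the exact expression used above, with an illustrative affine:

import numpy as np

# a 3D affine with 2.0 / 1.5 / 1.0 mm spacing and a 30-degree in-plane rotation
theta = np.deg2rad(30)
rotation = np.array([
    [np.cos(theta), -np.sin(theta), 0.0, 0.0],
    [np.sin(theta), np.cos(theta), 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
])
scaling = np.diag([2.0, 1.5, 1.0, 1.0])
old_affine = rotation @ scaling

# same expression as Spacingd.inverse: column norms, dropping the last (translation) column
orig_pixdim = np.sqrt(np.sum(np.square(old_affine), 0))[:-1]
print(orig_pixdim)  # [2.  1.5 1. ]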
@@ -246,13 +288,40 @@ def __call__( ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: d: Dict = dict(data) for key in self.keys: - meta_data = d[f"{key}_{self.meta_key_postfix}"] - d[key], _, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) + meta_data_key = f"{key}_{self.meta_key_postfix}" + meta_data = d[meta_data_key] + d[key], old_affine, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) + self.append_applied_transforms( + d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine} + ) + d[meta_data_key]["affine"] = new_affine + return d + + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + meta_data = d[transform["extra_info"]["meta_data_key"]] + orig_affine = transform["extra_info"]["old_affine"] + orig_axcodes = nib.orientations.aff2axcodes(orig_affine) + inverse_transform = Orientation( + axcodes=orig_axcodes, + as_closest_canonical=self.ornt_transform.as_closest_canonical, + labels=self.ornt_transform.labels, + ) + # Apply inverse + d[key], _, new_affine = inverse_transform(d[key], affine=meta_data["affine"]) meta_data["affine"] = new_affine + # Remove the applied transform + self.remove_most_recent_transform(d, key) + return d -class Rotate90d(MapTransform): +class Rotate90d(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`. """ @@ -270,11 +339,33 @@ def __init__(self, keys: KeysCollection, k: int = 1, spatial_axes: Tuple[int, in def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: + self.append_applied_transforms(d, key) d[key] = self.rotator(d[key]) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in keys or self.keys: + _ = self.get_most_recent_transform(d, key) + # Create inverse transform + spatial_axes = self.rotator.spatial_axes + num_times_rotated = self.rotator.k + num_times_to_rotate = 4 - num_times_rotated + inverse_transform = Rotate90(num_times_to_rotate, spatial_axes) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Apply inverse + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + -class RandRotate90d(RandomizableTransform, MapTransform): +class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRotate90`. With probability `prob`, input arrays are rotated by 90 degrees @@ -319,10 +410,33 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. 
for key in self.keys: if self._do_transform: d[key] = rotator(d[key]) + self.append_applied_transforms(d, key, extra_info={"rand_k": self._rand_k}) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform["do_transform"]: + # Create inverse transform + num_times_rotated = transform["extra_info"]["rand_k"] + num_times_to_rotate = 4 - num_times_rotated + inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Apply inverse + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) -class Resized(MapTransform): + return d + + +class Resized(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Resize`. @@ -358,11 +472,30 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for idx, key in enumerate(self.keys): + self.append_applied_transforms(d, key) d[key] = self.resizer(d[key], mode=self.mode[idx], align_corners=self.align_corners[idx]) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for idx, key in enumerate(keys or self.keys): + transform = self.get_most_recent_transform(d, key) + orig_size = transform["orig_size"] + mode = self.mode[idx] + align_corners = self.align_corners[idx] + # Create inverse transform + inverse_transform = Resize(orig_size, mode, align_corners) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + -class RandAffined(RandomizableTransform, MapTransform): +class RandAffined(RandomizableTransform, MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.RandAffine`. 
""" @@ -452,16 +585,44 @@ def __call__( sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:]) if self._do_transform: - grid = self.rand_affine.rand_affine_grid(spatial_size=sp_size) + grid, affine = self.rand_affine.rand_affine_grid(spatial_size=sp_size, return_affine=True) else: grid = create_grid(spatial_size=sp_size) + affine = np.eye(len(sp_size) + 1) for idx, key in enumerate(self.keys): + self.append_applied_transforms(d, key, extra_info={"affine": affine}) d[key] = self.rand_affine.resampler(d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx]) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for idx, key in enumerate(keys or self.keys): + transform = self.get_most_recent_transform(d, key) + orig_size = transform["orig_size"] + # Create inverse transform + fwd_affine = transform["extra_info"]["affine"] + inv_affine = np.linalg.inv(fwd_affine) + + affine_grid = AffineGrid(affine=inv_affine) + grid: torch.Tensor = affine_grid(orig_size) # type: ignore + + # Apply inverse transform + out = self.rand_affine.resampler(d[key], grid, self.mode[idx], self.padding_mode[idx]) + + # Convert to numpy + d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy() + + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + -class Rand2DElasticd(RandomizableTransform, MapTransform): +class Rand2DElasticd(RandomizableTransform, MapTransform, InvertibleTransform, NonRigidTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rand2DElastic`. """ @@ -551,6 +712,17 @@ def randomize(self, spatial_size: Sequence[int]) -> None: super().randomize(None) self.rand_2d_elastic.randomize(spatial_size) + @staticmethod + def cpg_to_dvf(cpg, spacing, output_shape): + grid = torch.nn.functional.interpolate( + recompute_scale_factor=True, + input=cpg.unsqueeze(0), + scale_factor=ensure_tuple_rep(spacing, 2), + mode=InterpolateMode.BILINEAR.value, + align_corners=False, + ) + return CenterSpatialCrop(roi_size=output_shape)(grid[0]) + def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: @@ -560,27 +732,69 @@ def __call__( self.randomize(spatial_size=sp_size) if self._do_transform: - grid = self.rand_2d_elastic.deform_grid(spatial_size=sp_size) - grid = self.rand_2d_elastic.rand_affine_grid(grid=grid) - grid = torch.nn.functional.interpolate( # type: ignore - recompute_scale_factor=True, - input=grid.unsqueeze(0), - scale_factor=ensure_tuple_rep(self.rand_2d_elastic.deform_grid.spacing, 2), - mode=InterpolateMode.BICUBIC.value, - align_corners=False, - ) - grid = CenterSpatialCrop(roi_size=sp_size)(grid[0]) + cpg = self.rand_2d_elastic.deform_grid(spatial_size=sp_size) + cpg_w_affine, affine = self.rand_2d_elastic.rand_affine_grid(grid=cpg, return_affine=True) + grid = self.cpg_to_dvf(cpg_w_affine, self.rand_2d_elastic.deform_grid.spacing, sp_size) + extra_info: Optional[Dict] = {"cpg": deepcopy(cpg), "affine": deepcopy(affine)} else: grid = create_grid(spatial_size=sp_size) + extra_info = None for idx, key in enumerate(self.keys): + self.append_applied_transforms(d, key, extra_info=extra_info) d[key] = self.rand_2d_elastic.resampler( d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx] ) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, 
...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + # This variable will be `not None` if vtk or sitk is present + inv_def_no_affine = None + + for idx, key in enumerate(keys or self.keys): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + if transform["do_transform"]: + orig_size = transform["orig_size"] + # Only need to calculate inverse deformation once as it is the same for all keys + if idx == 0: + # If magnitude == 0, then non-rigid component is identity -- so just create blank + if self.rand_2d_elastic.deform_grid.magnitude == (0.0, 0.0): + inv_def_no_affine = create_grid(spatial_size=orig_size) + else: + fwd_cpg_no_affine = transform["extra_info"]["cpg"] + fwd_def_no_affine = self.cpg_to_dvf( + fwd_cpg_no_affine, self.rand_2d_elastic.deform_grid.spacing, orig_size + ) + inv_def_no_affine = self.compute_inverse_deformation(len(orig_size), fwd_def_no_affine) + # if inverse did not succeed (sitk or vtk present), data will not be changed. + if inv_def_no_affine is not None: + fwd_affine = transform["extra_info"]["affine"] + inv_affine = np.linalg.inv(fwd_affine) + inv_def_w_affine_wrong_size = AffineGrid(affine=inv_affine, as_tensor_output=False)( + grid=inv_def_no_affine + ) + # Back to original size + inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine_wrong_size) # type: ignore + # Apply inverse transform + if inv_def_no_affine is not None: + out = self.rand_2d_elastic.resampler( + d[key], inv_def_w_affine, self.mode[idx], self.padding_mode[idx] + ) + d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out + + else: + d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + -class Rand3DElasticd(RandomizableTransform, MapTransform): +class Rand3DElasticd(RandomizableTransform, MapTransform, InvertibleTransform, NonRigidTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rand3DElastic`. 
""" @@ -678,23 +892,62 @@ def __call__( sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, data[self.keys[0]].shape[1:]) self.randomize(grid_size=sp_size) - grid = create_grid(spatial_size=sp_size) + grid_no_affine = create_grid(spatial_size=sp_size) if self._do_transform: device = self.rand_3d_elastic.device - grid = torch.tensor(grid).to(device) + grid_no_affine = torch.tensor(grid_no_affine).to(device) gaussian = GaussianFilter(spatial_dims=3, sigma=self.rand_3d_elastic.sigma, truncated=3.0).to(device) offset = torch.tensor(self.rand_3d_elastic.rand_offset, device=device).unsqueeze(0) - grid[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude - grid = self.rand_3d_elastic.rand_affine_grid(grid=grid) + grid_no_affine[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude + grid_w_affine, affine = self.rand_3d_elastic.rand_affine_grid(grid=grid_no_affine, return_affine=True) for idx, key in enumerate(self.keys): + self.append_applied_transforms( + d, key, extra_info={"grid_no_affine": grid_no_affine.cpu().numpy(), "affine": affine} + ) d[key] = self.rand_3d_elastic.resampler( - d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx] + d[key], grid_w_affine, mode=self.mode[idx], padding_mode=self.padding_mode[idx] ) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for idx, key in enumerate(keys or self.keys): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + if transform["do_transform"]: + orig_size = transform["orig_size"] + # Only need to calculate inverse deformation once as it is the same for all keys + if idx == 0: + fwd_def_no_affine = transform["extra_info"]["grid_no_affine"] + inv_def_no_affine = self.compute_inverse_deformation(len(orig_size), fwd_def_no_affine) + # if inverse did not succeed (sitk or vtk present), data will not be changed. + if inv_def_no_affine is not None: + fwd_affine = transform["extra_info"]["affine"] + inv_affine = np.linalg.inv(fwd_affine) + inv_def_w_affine_wrong_size = AffineGrid(affine=inv_affine, as_tensor_output=False)( + grid=inv_def_no_affine + ) + # Back to original size + inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine_wrong_size) # type: ignore + # Apply inverse transform + if inv_def_w_affine is not None: + out = self.rand_3d_elastic.resampler( + d[key], inv_def_w_affine, self.mode[idx], self.padding_mode[idx] + ) + d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out + else: + d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + -class Flipd(MapTransform): +class Flipd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Flip`. 
@@ -713,11 +966,28 @@ def __init__(self, keys: KeysCollection, spatial_axis: Optional[Union[Sequence[i def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: + self.append_applied_transforms(d, key) d[key] = self.flipper(d[key]) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in keys or self.keys: + _ = self.get_most_recent_transform(d, key) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Inverse is same as forward + d[key] = self.flipper(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d -class RandFlipd(RandomizableTransform, MapTransform): + +class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandFlip`. @@ -748,10 +1018,29 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key in self.keys: if self._do_transform: d[key] = self.flipper(d[key]) + self.append_applied_transforms(d, key) + return d + + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform["do_transform"]: + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Inverse is same as forward + d[key] = self.flipper(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + return d -class Rotated(MapTransform): +class Rotated(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rotate`. 
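Flipping is its own inverse, which makes Flipd a convenient place to see the bookkeeping every invertible dictionary transform in this patch relies on: the forward call pushes a record onto "<key>_transforms" and the inverse pops it again. A round-trip sketch assuming this patch is applied; the key name and shape are illustrative:

import numpy as np
from monai.transforms import Flipd

data = {"image": np.random.rand(1, 32, 32).astype(np.float32)}
flipper = Flipd(keys="image", spatial_axis=0)

flipped = flipper(data)
record = flipped["image_transforms"][-1]      # most recently applied transform
print(record["class"], record["orig_size"])   # Flipd (32, 32)

restored = flipper.inverse(flipped)           # applies the flip again and pops the record
print(len(restored["image_transforms"]))      # 0
print(np.array_equal(restored["image"], data["image"]))  # True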
@@ -799,17 +1088,49 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for idx, key in enumerate(self.keys): - d[key] = self.rotator( + orig_size = d[key].shape[1:] + d[key], rot_mat = self.rotator( d[key], mode=self.mode[idx], padding_mode=self.padding_mode[idx], align_corners=self.align_corners[idx], dtype=self.dtype[idx], + return_rotation_matrix=True, + ) + self.append_applied_transforms(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) + return d + + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for idx, key in enumerate(keys or self.keys): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + fwd_rot_mat = transform["extra_info"]["rot_mat"] + inv_rot_mat = np.linalg.inv(fwd_rot_mat) + + xform = AffineTransform( + normalized=False, + mode=self.mode[idx], + padding_mode=self.padding_mode[idx], + align_corners=self.align_corners[idx], + reverse_indexing=True, + ) + dtype = self.dtype[idx] + output = xform( + torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), + torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), + spatial_size=transform["orig_size"], ) + d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + return d -class RandRotated(RandomizableTransform, MapTransform): +class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRotate` Randomly rotates the input arrays. @@ -888,23 +1209,60 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.randomize() d = dict(data) if not self._do_transform: + for key in self.keys: + self.append_applied_transforms(d, key, extra_info={"rot_mat": np.eye(4)}) return d + angle: Union[Sequence[float], float] = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z) rotator = Rotate( - angle=self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z), + angle=angle, keep_size=self.keep_size, ) for idx, key in enumerate(self.keys): - d[key] = rotator( + orig_size = d[key].shape[1:] + d[key], rot_mat = rotator( d[key], mode=self.mode[idx], padding_mode=self.padding_mode[idx], align_corners=self.align_corners[idx], dtype=self.dtype[idx], + return_rotation_matrix=True, ) + self.append_applied_transforms(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for idx, key in enumerate(keys or self.keys): + transform = self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform["do_transform"]: + # Create inverse transform + fwd_rot_mat = transform["extra_info"]["rot_mat"] + inv_rot_mat = np.linalg.inv(fwd_rot_mat) + + xform = AffineTransform( + normalized=False, + mode=self.mode[idx], + padding_mode=self.padding_mode[idx], + align_corners=self.align_corners[idx], + reverse_indexing=True, + ) + dtype = self.dtype[idx] + output = xform( + torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), + torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), + spatial_size=transform["orig_size"], + ) + 
d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d -class Zoomd(MapTransform): + +class Zoomd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Zoom`. @@ -946,6 +1304,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for idx, key in enumerate(self.keys): + self.append_applied_transforms(d, key) d[key] = self.zoomer( d[key], mode=self.mode[idx], @@ -954,8 +1313,31 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for idx, key in enumerate(keys or self.keys): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + zoom = np.array(self.zoomer.zoom) + inverse_transform = Zoom(zoom=1 / zoom, keep_size=self.zoomer.keep_size) + # Apply inverse + d[key] = inverse_transform( + d[key], + mode=self.mode[idx], + padding_mode=self.padding_mode[idx], + align_corners=self.align_corners[idx], + ) + # Size might be out by 1 voxel so pad + d[key] = SpatialPad(transform["orig_size"])(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + -class RandZoomd(RandomizableTransform, MapTransform): +class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform): """ Dict-based version :py:class:`monai.transforms.RandZoom`. @@ -1021,6 +1403,8 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.randomize() d = dict(data) if not self._do_transform: + for key in self.keys: + self.append_applied_transforms(d, key, extra_info={"zoom": self._zoom}) return d img_dims = data[self.keys[0]].ndim @@ -1032,6 +1416,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self._zoom = ensure_tuple_rep(self._zoom[0], img_dims - 2) + ensure_tuple(self._zoom[-1]) zoomer = Zoom(self._zoom, keep_size=self.keep_size) for idx, key in enumerate(self.keys): + self.append_applied_transforms(d, key, extra_info={"zoom": self._zoom}) d[key] = zoomer( d[key], mode=self.mode[idx], @@ -1040,6 +1425,29 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for idx, key in enumerate(keys or self.keys): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + zoom = np.array(transform["extra_info"]["zoom"]) + inverse_transform = Zoom(zoom=1 / zoom, keep_size=self.keep_size) + # Apply inverse + d[key] = inverse_transform( + d[key], + mode=self.mode[idx], + padding_mode=self.padding_mode[idx], + align_corners=self.align_corners[idx], + ) + # Size might be out by 1 voxel so pad + d[key] = SpatialPad(transform["orig_size"])(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + SpacingD = SpacingDict = Spacingd OrientationD = OrientationDict = Orientationd diff --git a/requirements-dev.txt b/requirements-dev.txt index 2a43e63d73..1f7b608f79 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -30,3 +30,4 @@ Sphinx==3.3.0 recommonmark==0.6.0 sphinx-autodoc-typehints==1.11.1 
sphinx-rtd-theme==0.5.0 +vtk \ No newline at end of file diff --git a/tests/test_inverse.py b/tests/test_inverse.py new file mode 100644 index 0000000000..ed570394f9 --- /dev/null +++ b/tests/test_inverse.py @@ -0,0 +1,512 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import sys +import unittest +from typing import TYPE_CHECKING, List, Tuple + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data import CacheDataset, DataLoader, create_test_image_2d, create_test_image_3d +from monai.data.utils import decollate_batch +from monai.networks.nets import UNet +from monai.transforms import ( + AddChannel, + AddChanneld, + BorderPadd, + CenterSpatialCropd, + Compose, + CropForegroundd, + DivisiblePadd, + Flipd, + InvertibleTransform, + LoadImaged, + Orientationd, + Rand2DElasticd, + Rand3DElasticd, + RandAffined, + RandFlipd, + RandRotate90d, + RandRotated, + RandSpatialCropd, + RandZoomd, + Resized, + ResizeWithPadOrCrop, + ResizeWithPadOrCropd, + Rotate90d, + Rotated, + Spacingd, + SpatialCropd, + SpatialPad, + SpatialPadd, + Zoomd, +) +from monai.utils import first, optional_import, set_determinism +from tests.utils import make_nifti_image, make_rand_affine, test_is_quick + +if TYPE_CHECKING: + import matplotlib.pyplot as plt + + has_matplotlib = True + has_vtk = True +else: + plt, has_matplotlib = optional_import("matplotlib.pyplot") + _, has_vtk = optional_import("vtk") + +set_determinism(seed=0) + +AFFINE = make_rand_affine() +AFFINE[0] *= 2 + +IM_1D = AddChannel()(np.arange(0, 10)) +IM_2D_FNAME, SEG_2D_FNAME = [make_nifti_image(i) for i in create_test_image_2d(101, 100)] +IM_3D_FNAME, SEG_3D_FNAME = [make_nifti_image(i, AFFINE) for i in create_test_image_3d(100, 101, 107)] + +KEYS = ["image", "label"] +DATA_1D = {"image": IM_1D, "label": IM_1D, "other": IM_1D} +LOAD_IMS = Compose([LoadImaged(KEYS), AddChanneld(KEYS)]) +DATA_2D = LOAD_IMS({"image": IM_2D_FNAME, "label": SEG_2D_FNAME}) +DATA_3D = LOAD_IMS({"image": IM_3D_FNAME, "label": SEG_3D_FNAME}) + +TESTS: List[Tuple] = [] + +TESTS.append( + ( + "SpatialPadd (x2) 2d", + DATA_2D, + 0.0, + SpatialPadd(KEYS, spatial_size=[111, 113], method="end"), + SpatialPadd(KEYS, spatial_size=[118, 117]), + ) +) + +TESTS.append( + ( + "SpatialPadd 3d", + DATA_3D, + 0.0, + SpatialPadd(KEYS, spatial_size=[112, 113, 116]), + ) +) + +TESTS.append( + ( + "RandRotated, prob 0", + DATA_2D, + 0, + RandRotated(KEYS, prob=0), + ) +) + +TESTS.append( + ( + "SpatialCropd 2d", + DATA_2D, + 3e-2, + SpatialCropd(KEYS, [49, 51], [90, 89]), + ) +) + +TESTS.append( + ( + "SpatialCropd 3d", + DATA_3D, + 4e-2, + SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]), + ) +) + +TESTS.append(("RandSpatialCropd 2d", DATA_2D, 5e-2, RandSpatialCropd(KEYS, [96, 93], True, False))) + +TESTS.append(("RandSpatialCropd 3d", DATA_3D, 2e-2, RandSpatialCropd(KEYS, [96, 93, 92], False, False))) + +TESTS.append( + ( + "BorderPadd 2d", + DATA_2D, + 0, + BorderPadd(KEYS, [3, 7, 2, 5]), + ) +) + 
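
Each TESTS entry follows the pattern (test name, input data, acceptable mean absolute difference after a forward/inverse round trip, one or more dictionary transforms). As a reading aid only, the round trip performed on every entry by the test class defined later in this file is roughly:

    # rough sketch of the round trip, mirroring test_inverse further down in this file
    name, data, acceptable_diff, *transforms = TESTS[0]
    fwd = dict(data)
    for t in transforms:                # each forward call records itself under "<key>_transforms"
        fwd = t(fwd)
    bck = dict(fwd)
    for t in reversed(transforms):      # undo in reverse order (last in, first out)
        if isinstance(t, InvertibleTransform):
            bck = t.inverse(bck)
    # np.abs(bck["image"] - data["image"]).mean() should be within acceptable_diff
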
+TESTS.append( + ( + "BorderPadd 2d", + DATA_2D, + 0, + BorderPadd(KEYS, [3, 7]), + ) +) + +TESTS.append( + ( + "BorderPadd 3d", + DATA_3D, + 0, + BorderPadd(KEYS, [4]), + ) +) + +TESTS.append( + ( + "DivisiblePadd 2d", + DATA_2D, + 0, + DivisiblePadd(KEYS, k=4), + ) +) + +TESTS.append( + ( + "DivisiblePadd 3d", + DATA_3D, + 0, + DivisiblePadd(KEYS, k=[4, 8, 11]), + ) +) + +TESTS.append( + ( + "Flipd 3d", + DATA_3D, + 0, + Flipd(KEYS, [1, 2]), + ) +) + +TESTS.append( + ( + "Flipd 3d", + DATA_3D, + 0, + Flipd(KEYS, [1, 2]), + ) +) + +TESTS.append( + ( + "RandFlipd 3d", + DATA_3D, + 0, + RandFlipd(KEYS, 1, [1, 2]), + ) +) + +TESTS.append( + ( + "Rotated 2d", + DATA_2D, + 8e-2, + Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False), + ) +) + +TESTS.append( + ( + "Rotated 3d", + DATA_3D, + 5e-2, + Rotated(KEYS, [random.uniform(np.pi / 6, np.pi) for _ in range(3)], True), # type: ignore + ) +) + +TESTS.append( + ( + "RandRotated 3d", + DATA_3D, + 5e-2, + RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1), # type: ignore + ) +) + +TESTS.append( + ( + "Orientationd 3d", + DATA_3D, + 0, + # For data loader, output needs to be same size, so input must be square/cubic + SpatialPadd(KEYS, max(DATA_3D["image"].shape)), + Orientationd(KEYS, "RAS"), + ) +) + +TESTS.append( + ( + "Rotate90d 2d", + DATA_2D, + 0, + Rotate90d(KEYS), + ) +) + +TESTS.append( + ( + "Rotate90d 3d", + DATA_3D, + 0, + Rotate90d(KEYS, k=2, spatial_axes=(1, 2)), + ) +) + +TESTS.append( + ( + "RandRotate90d 3d", + DATA_3D, + 0, + # For data loader, output needs to be same size, so input must be square/cubic + SpatialPadd(KEYS, max(DATA_3D["image"].shape)), + RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)), + ) +) + +TESTS.append( + ( + "Zoomd 1d", + DATA_1D, + 0, + Zoomd(KEYS, zoom=2, keep_size=False), + ) +) + +TESTS.append( + ( + "Zoomd 2d", + DATA_2D, + 2e-1, + Zoomd(KEYS, zoom=0.9), + ) +) + +TESTS.append( + ( + "Zoomd 3d", + DATA_3D, + 3e-2, + Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False), + ) +) + +TESTS.append(("RandZoom 3d", DATA_3D, 9e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True))) + +TESTS.append( + ( + "CenterSpatialCropd 2d", + DATA_2D, + 0, + CenterSpatialCropd(KEYS, roi_size=95), + ) +) + +TESTS.append( + ( + "CenterSpatialCropd 3d", + DATA_3D, + 0, + CenterSpatialCropd(KEYS, roi_size=[95, 97, 98]), + ) +) + +TESTS.append(("CropForegroundd 2d", DATA_2D, 0, CropForegroundd(KEYS, source_key="label", margin=2))) + +TESTS.append(("CropForegroundd 3d", DATA_3D, 0, CropForegroundd(KEYS, source_key="label"))) + +TESTS.append(("Spacingd 3d", DATA_3D, 3e-2, Spacingd(KEYS, [0.5, 0.7, 0.9], diagonal=False))) + +TESTS.append(("Resized 2d", DATA_2D, 2e-1, Resized(KEYS, [50, 47]))) + +TESTS.append(("Resized 3d", DATA_3D, 5e-2, Resized(KEYS, [201, 150, 78]))) + +TESTS.append(("ResizeWithPadOrCropd 3d", DATA_3D, 1e-2, ResizeWithPadOrCropd(KEYS, [201, 150, 78]))) + +TESTS.append( + ( + "RandAffine 3d", + DATA_3D, + 5e-2, + RandAffined( + KEYS, + [155, 179, 192], + prob=1, + padding_mode="zeros", + rotate_range=[np.pi / 6, -np.pi / 5, np.pi / 7], + shear_range=[(0.5, 0.5)], + translate_range=[10, 5, -4], + scale_range=[(0.8, 1.2), (0.9, 1.3)], + ), + ) +) + +if has_vtk: + TESTS.append( + ( + "Rand2DElasticd 2d", + DATA_2D, + 2e-1, + Rand2DElasticd( + KEYS, + spacing=(10.0, 10.0), + magnitude_range=(1, 1), + spatial_size=(155, 192), + prob=1, + padding_mode="zeros", + rotate_range=[(np.pi / 6, np.pi / 6)], + shear_range=[(0.5, 0.5)], + 
translate_range=[10, 5], + scale_range=[(0.2, 0.2), (0.3, 0.3)], + ), + ) + ) + +if not test_is_quick and has_vtk: + TESTS.append( + ( + "Rand3DElasticd 3d", + DATA_3D, + 1e-1, + Rand3DElasticd( + KEYS, + sigma_range=(3, 5), + magnitude_range=(100, 100), + prob=1, + padding_mode="zeros", + rotate_range=[np.pi / 6, np.pi / 7], + shear_range=[(0.5, 0.5), 0.2], + translate_range=[10, 5, 3], + scale_range=[(0.8, 1.2), (0.9, 1.3)], + ), + ) + ) + +TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] + +TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore + + +# Should fail because uses an array transform (SpatialPad), as opposed to dictionary +TEST_FAIL_0 = (DATA_2D["image"], 0.0, Compose([SpatialPad(spatial_size=[101, 103])])) +TESTS_FAIL = [TEST_FAIL_0] + + +def plot_im(orig, fwd_bck, fwd): + diff_orig_fwd_bck = orig - fwd_bck + ims_to_show = [orig, fwd, fwd_bck, diff_orig_fwd_bck] + titles = ["x", "fx", "f⁻¹fx", "x - f⁻¹fx"] + fig, axes = plt.subplots(1, 4, gridspec_kw={"width_ratios": [i.shape[1] for i in ims_to_show]}) + vmin = min(np.array(i).min() for i in [orig, fwd_bck, fwd]) + vmax = max(np.array(i).max() for i in [orig, fwd_bck, fwd]) + for im, title, ax in zip(ims_to_show, titles, axes): + _vmin, _vmax = (vmin, vmax) if id(im) != id(diff_orig_fwd_bck) else (None, None) + im = np.squeeze(np.array(im)) + while im.ndim > 2: + im = im[..., im.shape[-1] // 2] + im_show = ax.imshow(np.squeeze(im), vmin=_vmin, vmax=_vmax) + ax.set_title(title, fontsize=25) + ax.axis("off") + fig.colorbar(im_show, ax=ax) + plt.show() + + +class TestInverse(unittest.TestCase): + def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff): + for key in keys: + orig = orig_d[key] + fwd_bck = fwd_bck_d[key] + if isinstance(fwd_bck, torch.Tensor): + fwd_bck = fwd_bck.cpu().numpy() + unmodified = unmodified_d[key] + if isinstance(orig, np.ndarray): + mean_diff = np.mean(np.abs(orig - fwd_bck)) + unmodded_diff = np.mean(np.abs(orig - ResizeWithPadOrCrop(orig.shape[1:])(unmodified))) + try: + self.assertLessEqual(mean_diff, acceptable_diff) + except AssertionError: + print( + f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}" + ) + if has_matplotlib and orig[0].ndim > 1: + plot_im(orig, fwd_bck, unmodified) + elif orig[0].ndim == 1: + print(orig) + print(fwd_bck) + raise + + @parameterized.expand(TESTS) + def test_inverse(self, _, data, acceptable_diff, *transforms): + name = _ + + forwards = [data.copy()] + + # Apply forwards + for t in transforms: + forwards.append(t(forwards[-1])) + + # Check that error is thrown when inverse are used out of order. 
+ t = SpatialPadd("image", [10, 5]) + with self.assertRaises(RuntimeError): + t.inverse(forwards[-1]) + + # Apply inverses + fwd_bck = forwards[-1].copy() + for i, t in enumerate(reversed(transforms)): + if isinstance(t, InvertibleTransform): + fwd_bck = t.inverse(fwd_bck) + self.check_inverse(name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff) + + @parameterized.expand(TESTS_FAIL) + def test_fail(self, data, _, *transform): + d = transform[0](data) + with self.assertRaises(RuntimeError): + d = transform[0].inverse(d) + + def test_inverse_inferred_seg(self): + + test_data = [] + for _ in range(4): + image, label = create_test_image_2d(100, 101) + test_data.append({"image": image, "label": label.astype(np.float32)}) + + batch_size = 2 + # num workers = 0 for mac + num_workers = 2 if sys.platform != "darwin" else 0 + transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), CenterSpatialCropd(KEYS, (110, 99))]) + num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, InvertibleTransform)) + + dataset = CacheDataset(test_data, transform=transforms, progress=False) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + + device = "cuda" if torch.cuda.is_available() else "cpu" + model = UNet( + dimensions=2, + in_channels=1, + out_channels=1, + channels=(2, 4), + strides=(2,), + ).to(device) + + data = first(loader) + labels = data["label"].to(device) + segs = model(labels).detach().cpu() + segs_dict = {"label": segs, "label_transforms": data["label_transforms"]} + segs_dict_decollated = decollate_batch(segs_dict) + + # inverse of individual segmentation + seg_dict = first(segs_dict_decollated) + inv_seg = transforms.inverse(seg_dict, "label")["label"] + self.assertEqual(len(data["label_transforms"]), num_invertible_transforms) + self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms) + self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_spacingd.py b/tests/test_spacingd.py index ec32563543..c56313c41e 100644 --- a/tests/test_spacingd.py +++ b/tests/test_spacingd.py @@ -21,7 +21,7 @@ def test_spacingd_3d(self): data = {"image": np.ones((2, 10, 15, 20)), "image_meta_dict": {"affine": np.eye(4)}} spacing = Spacingd(keys="image", pixdim=(1, 2, 1.4)) res = spacing(data) - self.assertEqual(("image", "image_meta_dict"), tuple(sorted(res))) + self.assertEqual(("image", "image_meta_dict", "image_transforms"), tuple(sorted(res))) np.testing.assert_allclose(res["image"].shape, (2, 10, 8, 15)) np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag([1, 2, 1.4, 1.0])) @@ -29,7 +29,7 @@ def test_spacingd_2d(self): data = {"image": np.ones((2, 10, 20)), "image_meta_dict": {"affine": np.eye(3)}} spacing = Spacingd(keys="image", pixdim=(1, 2, 1.4)) res = spacing(data) - self.assertEqual(("image", "image_meta_dict"), tuple(sorted(res))) + self.assertEqual(("image", "image_meta_dict", "image_transforms"), tuple(sorted(res))) np.testing.assert_allclose(res["image"].shape, (2, 10, 10)) np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 2, 1))) @@ -49,7 +49,7 @@ def test_interp_all(self): ), ) res = spacing(data) - self.assertEqual(("image", "image_meta_dict", "seg", "seg_meta_dict"), tuple(sorted(res))) + self.assertEqual(("image", "image_meta_dict", "image_transforms", "seg", "seg_meta_dict", "seg_transforms"), tuple(sorted(res))) 
np.testing.assert_allclose(res["image"].shape, (2, 1, 46)) np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 0.2, 1, 1))) @@ -69,7 +69,7 @@ def test_interp_sep(self): ), ) res = spacing(data) - self.assertEqual(("image", "image_meta_dict", "seg", "seg_meta_dict"), tuple(sorted(res))) + self.assertEqual(("image", "image_meta_dict", "image_transforms", "seg", "seg_meta_dict", "seg_transforms"), tuple(sorted(res))) np.testing.assert_allclose(res["image"].shape, (2, 1, 46)) np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 0.2, 1, 1))) From 7dd3539403424a1025c132c80fb2485aff800a84 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 1 Mar 2021 14:20:26 +0000 Subject: [PATCH 02/64] autofix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_spacingd.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_spacingd.py b/tests/test_spacingd.py index c56313c41e..e4efe4241d 100644 --- a/tests/test_spacingd.py +++ b/tests/test_spacingd.py @@ -49,7 +49,10 @@ def test_interp_all(self): ), ) res = spacing(data) - self.assertEqual(("image", "image_meta_dict", "image_transforms", "seg", "seg_meta_dict", "seg_transforms"), tuple(sorted(res))) + self.assertEqual( + ("image", "image_meta_dict", "image_transforms", "seg", "seg_meta_dict", "seg_transforms"), + tuple(sorted(res)), + ) np.testing.assert_allclose(res["image"].shape, (2, 1, 46)) np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 0.2, 1, 1))) @@ -69,7 +72,10 @@ def test_interp_sep(self): ), ) res = spacing(data) - self.assertEqual(("image", "image_meta_dict", "image_transforms", "seg", "seg_meta_dict", "seg_transforms"), tuple(sorted(res))) + self.assertEqual( + ("image", "image_meta_dict", "image_transforms", "seg", "seg_meta_dict", "seg_transforms"), + tuple(sorted(res)), + ) np.testing.assert_allclose(res["image"].shape, (2, 1, 46)) np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 0.2, 1, 1))) From 3236334cb0097d510623effc0d7fd3e34d69a778 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 1 Mar 2021 14:59:51 +0000 Subject: [PATCH 03/64] fix rand elastic 3d Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 6 +++++- tests/test_rand_elasticd_3d.py | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 3b654c67f6..4fd6008d73 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -893,6 +893,7 @@ def __call__( self.randomize(grid_size=sp_size) grid_no_affine = create_grid(spatial_size=sp_size) + affine = np.eye(4) if self._do_transform: device = self.rand_3d_elastic.device grid_no_affine = torch.tensor(grid_no_affine).to(device) @@ -900,10 +901,13 @@ def __call__( offset = torch.tensor(self.rand_3d_elastic.rand_offset, device=device).unsqueeze(0) grid_no_affine[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude grid_w_affine, affine = self.rand_3d_elastic.rand_affine_grid(grid=grid_no_affine, return_affine=True) + else: + grid_w_affine = grid_no_affine + affine = np.eye(len(sp_size) + 1) for idx, key in enumerate(self.keys): self.append_applied_transforms( - d, key, extra_info={"grid_no_affine": grid_no_affine.cpu().numpy(), "affine": affine} + d, key, 
extra_info={"grid_no_affine": grid_no_affine, "affine": affine} ) d[key] = self.rand_3d_elastic.resampler( d[key], grid_w_affine, mode=self.mode[idx], padding_mode=self.padding_mode[idx] diff --git a/tests/test_rand_elasticd_3d.py b/tests/test_rand_elasticd_3d.py index 47ab814882..cf9f56c109 100644 --- a/tests/test_rand_elasticd_3d.py +++ b/tests/test_rand_elasticd_3d.py @@ -113,6 +113,8 @@ def test_rand_3d_elasticd(self, input_param, input_data, expected_val): g.set_random_state(123) res = g(input_data) for key in res: + if "_transforms" in key: + continue result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) From faf771c0ed8c747dbe3f3e96a52f0f2aad20198c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 1 Mar 2021 15:15:36 +0000 Subject: [PATCH 04/64] fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_rand_affined.py | 2 ++ tests/test_rand_elasticd_2d.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py index 54d71ad8f7..ae2adbe3b3 100644 --- a/tests/test_rand_affined.py +++ b/tests/test_rand_affined.py @@ -145,6 +145,8 @@ def test_rand_affined(self, input_param, input_data, expected_val): res = g(input_data) for key in res: result = res[key] + if "_transforms" in key: + continue expected = expected_val[key] if isinstance(expected_val, dict) else expected_val self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) if isinstance(result, torch.Tensor): diff --git a/tests/test_rand_elasticd_2d.py b/tests/test_rand_elasticd_2d.py index f8eb026088..88f2438606 100644 --- a/tests/test_rand_elasticd_2d.py +++ b/tests/test_rand_elasticd_2d.py @@ -142,6 +142,8 @@ def test_rand_2d_elasticd(self, input_param, input_data, expected_val): g.set_random_state(123) res = g(input_data) for key in res: + if "_transforms" in key: + continue result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) From 1ca59399f55c4fcb00a0f17d8d7f01b2bbedb53d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 1 Mar 2021 16:28:33 +0000 Subject: [PATCH 05/64] fix 2d elastic Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 4fd6008d73..9b52d76375 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -718,7 +718,7 @@ def cpg_to_dvf(cpg, spacing, output_shape): recompute_scale_factor=True, input=cpg.unsqueeze(0), scale_factor=ensure_tuple_rep(spacing, 2), - mode=InterpolateMode.BILINEAR.value, + mode=InterpolateMode.BICUBIC.value, align_corners=False, ) return CenterSpatialCrop(roi_size=output_shape)(grid[0]) From a5138bb12f050e7b0d9b2b293e079926eb554f0b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 1 Mar 2021 16:54:48 +0000 Subject: [PATCH 06/64] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff 
--git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 9b52d76375..96f3859fcf 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -906,9 +906,7 @@ def __call__( affine = np.eye(len(sp_size) + 1) for idx, key in enumerate(self.keys): - self.append_applied_transforms( - d, key, extra_info={"grid_no_affine": grid_no_affine, "affine": affine} - ) + self.append_applied_transforms(d, key, extra_info={"grid_no_affine": grid_no_affine, "affine": affine}) d[key] = self.rand_3d_elastic.resampler( d[key], grid_w_affine, mode=self.mode[idx], padding_mode=self.padding_mode[idx] ) From 0e891394f8e58f50f0a95dd5b1c9dc9f4e9cfd3b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Mar 2021 10:31:46 +0000 Subject: [PATCH 07/64] update inverse docstrings Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/inverse_transform.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/monai/transforms/inverse_transform.py b/monai/transforms/inverse_transform.py index 38919d1fc0..ec1c684f7c 100644 --- a/monai/transforms/inverse_transform.py +++ b/monai/transforms/inverse_transform.py @@ -36,6 +36,16 @@ class InvertibleTransform(Transform): When the `__call__` method is called, a serialization of the class is stored. When the `inverse` method is called, the serialization is then removed. We use last in, first out for the inverted transforms. + + Note to developers: When converting a transform to an invertible transform, you need to: + 1. Inherit from this class. + 2. In `__call__`, add a call to `append_applied_transforms`. + 3. Any extra information that might be needed for the inverse can be included with the + dictionary `extra_info`. This dictionary should have the same keys regardless of + whether `do_transform` was True or False and can only contain objects that are + accepted in pytorch's batch (e.g., `None` is not allowed). + 4. Implement an `inverse` method. Make sure that after performing the inverse, + `remove_most_recent_transform` is called. """ def append_applied_transforms( From c3fa403b2c4847609ac40df18382d1deceb0e4e6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Mar 2021 10:42:41 +0000 Subject: [PATCH 08/64] inverse for randaxisflipd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 27 +++++++++++++++++++++++--- tests/test_inverse.py | 11 +++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index d93f0deb5c..f65d3054ca 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -1039,10 +1039,12 @@ def inverse( d[key] = torch.Tensor(d[key]).cpu().numpy() # Inverse is same as forward d[key] = self.flipper(d[key]) - # Remove the applied transform - self.remove_most_recent_transform(d, key) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + return d -class RandAxisFlipd(RandomizableTransform, MapTransform): + +class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandAxisFlip`. 
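
The developer note added to inverse_transform.py above lists the four steps for making a transform invertible. A minimal sketch of a hypothetical custom transform following those steps (AddOffsetd is not part of this patch; the helper names and import paths follow the code in this PR):

    from copy import deepcopy

    from monai.transforms import InvertibleTransform, MapTransform


    class AddOffsetd(MapTransform, InvertibleTransform):
        """Hypothetical example: add a constant offset to every selected key."""

        def __init__(self, keys, offset: float = 1.0) -> None:
            super().__init__(keys)
            self.offset = offset

        def __call__(self, data):
            d = dict(data)
            for key in self.keys:
                # record the applied transform; extra_info must be batch-collatable
                self.append_applied_transforms(d, key, extra_info={"offset": self.offset})
                d[key] = d[key] + self.offset
            return d

        def inverse(self, data, keys=None):
            d = deepcopy(dict(data))
            for key in keys or self.keys:
                transform = self.get_most_recent_transform(d, key)
                # a random transform would also check transform["do_transform"] here
                d[key] = d[key] - transform["extra_info"]["offset"]
                # always pop the applied transform once it has been inverted
                self.remove_most_recent_transform(d, key)
            return d
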
@@ -1072,6 +1074,25 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key in self.keys: if self._do_transform: d[key] = flipper(d[key]) + self.append_applied_transforms(d, key, extra_info={"axis": self._axis}) + return d + + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform["do_transform"]: + flipper = Flip(spatial_axis=transform["extra_info"]["axis"]) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Inverse is same as forward + d[key] = flipper(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) return d diff --git a/tests/test_inverse.py b/tests/test_inverse.py index ed570394f9..0bcbfdc46c 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -36,6 +36,7 @@ Rand2DElasticd, Rand3DElasticd, RandAffined, + RandAxisFlipd, RandFlipd, RandRotate90d, RandRotated, @@ -203,6 +204,16 @@ ) ) +TESTS.append( + ( + "RandAxisFlipd 3d", + DATA_3D, + 0, + RandAxisFlipd(KEYS, 1), + ) +) + + TESTS.append( ( "Rotated 2d", From b8cea259d6958392da15cad143ae2983ca0544bd Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Mar 2021 11:34:32 +0000 Subject: [PATCH 09/64] remove shuffle Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_decollate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 5c6f04b48e..d60eb64316 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -73,7 +73,7 @@ def test_decollation(self, _, data, batch_size=2, num_workers=2): transforms = Compose([LoadImaged("image"), transforms]) dataset = CacheDataset(data, transforms, progress=False) - loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) for b, batch_data in enumerate(loader): decollated_1 = decollate_batch(batch_data) From d4bfbb3037c4cfd913c6a7560c68bfc059e7458e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Mar 2021 18:02:56 +0000 Subject: [PATCH 10/64] merge Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- docs/source/transforms.rst | 12 ++++ monai/data/image_reader.py | 39 ++++++++---- monai/transforms/__init__.py | 4 ++ monai/transforms/io/array.py | 24 +++++++ monai/transforms/io/dictionary.py | 28 ++++++--- monai/transforms/utility/array.py | 29 ++++++++- monai/transforms/utility/dictionary.py | 31 ++++++++++ tests/min_tests.py | 2 + tests/test_ensure_channel_first.py | 86 ++++++++++++++++++++++++++ tests/test_ensure_channel_firstd.py | 62 +++++++++++++++++++ tests/test_nifti_endianness.py | 48 ++++++++++++++ 11 files changed, 343 insertions(+), 22 deletions(-) create mode 100644 tests/test_ensure_channel_first.py create mode 100644 tests/test_ensure_channel_firstd.py create mode 100644 tests/test_nifti_endianness.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 00d8cb9053..dd10176de9 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -437,6 +437,12 @@ 
Utility :members: :special-members: __call__ +`EnsureChannelFirst` +"""""""""""""""""""" +.. autoclass:: EnsureChannelFirst + :members: + :special-members: __call__ + `RepeatChannel` """"""""""""""" .. autoclass:: RepeatChannel @@ -890,6 +896,12 @@ Utility (Dict) :members: :special-members: __call__ +`EnsureChannelFirstd` +""""""""""""""""""""" +.. autoclass:: EnsureChannelFirstd + :members: + :special-members: __call__ + `RepeatChanneld` """""""""""""""" .. autoclass:: RepeatChanneld diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index e458833979..dfbdaf5b41 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -109,6 +109,17 @@ def _copy_compatible_dict(from_dict: Dict, to_dict: Dict): ) +def _stack_images(image_list: List, meta_dict: Dict): + if len(image_list) > 1: + if meta_dict.get("original_channel_dim", None) not in ("no_channel", None): + raise RuntimeError("can not read a list of images which already have channel dimension.") + meta_dict["original_channel_dim"] = 0 + img_array = np.stack(image_list, axis=0) + else: + img_array = image_list[0] + return img_array + + class ITKReader(ImageReader): """ Load medical images based on ITK library. @@ -200,11 +211,12 @@ def get_data(self, img): header["original_affine"] = self._get_affine(i) header["affine"] = header["original_affine"].copy() header["spatial_shape"] = self._get_spatial_shape(i) - img_array.append(self._get_array_data(i)) + data = self._get_array_data(i) + img_array.append(data) + header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1 _copy_compatible_dict(header, compatible_meta) - img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0] - return img_array_, compatible_meta + return _stack_images(img_array, compatible_meta), compatible_meta def _get_meta_dict(self, img) -> Dict: """ @@ -265,6 +277,7 @@ def _get_spatial_shape(self, img) -> np.ndarray: img: a ITK image object loaded from a image file. 
""" + # the img data should have no channel dim or the last dim is channel shape = list(itk.size(img)) shape.reverse() return np.asarray(shape) @@ -371,11 +384,12 @@ def get_data(self, img): i = nib.as_closest_canonical(i) header["affine"] = self._get_affine(i) header["spatial_shape"] = self._get_spatial_shape(i) - img_array.append(self._get_array_data(i)) + data = self._get_array_data(i) + img_array.append(data) + header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1 _copy_compatible_dict(header, compatible_meta) - img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0] - return img_array_, compatible_meta + return _stack_images(img_array, compatible_meta), compatible_meta def _get_meta_dict(self, img) -> Dict: """ @@ -408,6 +422,7 @@ def _get_spatial_shape(self, img) -> np.ndarray: """ ndim = img.header["dim"][0] spatial_rank = min(ndim, 3) + # the img data should have no channel dim or the last dim is channel return np.asarray(img.header["dim"][1 : spatial_rank + 1]) def _get_array_data(self, img) -> np.ndarray: @@ -504,12 +519,12 @@ def get_data(self, img): for i in ensure_tuple(img): header = {} if isinstance(i, np.ndarray): + # can not detect the channel dim of numpy array, use all the dims as spatial_shape header["spatial_shape"] = i.shape img_array.append(i) _copy_compatible_dict(header, compatible_meta) - img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0] - return img_array_, compatible_meta + return _stack_images(img_array, compatible_meta), compatible_meta class PILReader(ImageReader): @@ -582,11 +597,12 @@ def get_data(self, img): for i in ensure_tuple(img): header = self._get_meta_dict(i) header["spatial_shape"] = self._get_spatial_shape(i) - img_array.append(np.asarray(i)) + data = np.asarray(i) + img_array.append(data) + header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1 _copy_compatible_dict(header, compatible_meta) - img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0] - return img_array_, compatible_meta + return _stack_images(img_array, compatible_meta), compatible_meta def _get_meta_dict(self, img) -> Dict: """ @@ -608,4 +624,5 @@ def _get_spatial_shape(self, img) -> np.ndarray: Args: img: a PIL Image object loaded from a image file. """ + # the img data should have no channel dim or the last dim is channel return np.asarray((img.width, img.height)) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 72e0dba15e..e57394aab9 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -251,6 +251,7 @@ CastToType, ConvertToMultiChannelBasedOnBratsClasses, DataStats, + EnsureChannelFirst, FgBgToIndices, Identity, LabelToMask, @@ -297,6 +298,9 @@ DeleteItemsd, DeleteItemsD, DeleteItemsDict, + EnsureChannelFirstd, + EnsureChannelFirstD, + EnsureChannelFirstDict, FgBgToIndicesd, FgBgToIndicesD, FgBgToIndicesDict, diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 855621e432..9c4f631699 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -33,6 +33,27 @@ __all__ = ["LoadImage", "SaveImage"] +def switch_endianness(data, old, new): + """ + If any numpy arrays have `old` (e.g., ">"), + replace with `new` (e.g., "<"). 
+ """ + if isinstance(data, np.ndarray): + if data.dtype.byteorder == old: + data = data.newbyteorder(new) + elif isinstance(data, tuple): + data = (switch_endianness(x, old, new) for x in data) + elif isinstance(data, list): + data = [switch_endianness(x, old, new) for x in data] + elif isinstance(data, dict): + data = {k: switch_endianness(v, old, new) for k, v in data.items()} + elif isinstance(data, (bool, str, float, int)): + pass + else: + raise AssertionError() + return data + + class LoadImage(Transform): """ Load image file or files from provided path based on reader. @@ -132,6 +153,9 @@ def __call__( if self.image_only: return img_array meta_data[Key.FILENAME_OR_OBJ] = ensure_tuple(filename)[0] + # make sure all elements in metadata are little endian + meta_data = switch_endianness(meta_data, ">", "<") + return img_array, meta_data diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 55707f750e..d9b6b5e6ab 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -59,6 +59,7 @@ def __init__( dtype: DtypeLike = np.float32, meta_key_postfix: str = "meta_dict", overwriting: bool = False, + image_only: bool = False, *args, **kwargs, ) -> None: @@ -76,11 +77,13 @@ def __init__( For example, load nifti file for `image`, store the metadata into `image_meta_dict`. overwriting: whether allow to overwrite existing meta data of same key. default is False, which will raise exception if encountering existing key. + image_only: if True return dictionary containing just only the image volumes, otherwise return + dictionary containing image data array and header dict per input key. args: additional parameters for reader if providing a reader name. kwargs: additional parameters for reader if providing a reader name. 
""" super().__init__(keys) - self._loader = LoadImage(reader, False, dtype, *args, **kwargs) + self._loader = LoadImage(reader, image_only, dtype, *args, **kwargs) if not isinstance(meta_key_postfix, str): raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.") self.meta_key_postfix = meta_key_postfix @@ -98,15 +101,20 @@ def __call__(self, data, reader: Optional[ImageReader] = None): d = dict(data) for key in self.keys: data = self._loader(d[key], reader) - if not isinstance(data, (tuple, list)): - raise ValueError("loader must return a tuple or list.") - d[key] = data[0] - if not isinstance(data[1], dict): - raise ValueError("metadata must be a dict.") - key_to_add = f"{key}_{self.meta_key_postfix}" - if key_to_add in d and not self.overwriting: - raise KeyError(f"Meta data with key {key_to_add} already exists and overwriting=False.") - d[key_to_add] = data[1] + if self._loader.image_only: + if not isinstance(data, np.ndarray): + raise ValueError("loader must return a numpy array (because image_only=True was used).") + d[key] = data + else: + if not isinstance(data, (tuple, list)): + raise ValueError("loader must return a tuple or list (because image_only=False was used).") + d[key] = data[0] + if not isinstance(data[1], dict): + raise ValueError("metadata must be a dict.") + key_to_add = f"{key}_{self.meta_key_postfix}" + if key_to_add in d and not self.overwriting: + raise KeyError(f"Meta data with key {key_to_add} already exists and overwriting=False.") + d[key_to_add] = data[1] return d diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 24d2feb781..62daf9309c 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -15,7 +15,7 @@ import logging import time -from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -39,6 +39,7 @@ "AsChannelFirst", "AsChannelLast", "AddChannel", + "EnsureChannelFirst", "RepeatChannel", "RemoveRepeatedChannel", "SplitChannel", @@ -149,6 +150,32 @@ def __call__(self, img: NdarrayTensor): return img[None] +class EnsureChannelFirst(Transform): + """ + Automatically adjust or add the channel dimension of input data to ensure `channel_first` shape. + It extracts the `original_channel_dim` info from provided meta_data dictionary. + Typical values of `original_channel_dim` can be: "no_channel", 0, -1. + Convert the data to `channel_first` based on the `original_channel_dim` information. + + """ + + def __call__(self, img: np.ndarray, meta_dict: Optional[Dict] = None): + """ + Apply the transform to `img`. + """ + if not isinstance(meta_dict, dict): + raise ValueError("meta_dict must be a dictionay data.") + + channel_dim = meta_dict.get("original_channel_dim", None) + + if channel_dim is None: + raise ValueError("meta_dict must contain `original_channel_dim` information.") + elif channel_dim == "no_channel": + return AddChannel()(img) + else: + return AsChannelFirst(channel_dim=channel_dim)(img) + + class RepeatChannel(Transform): """ Repeat channel data to construct expected input shape for models. 
diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index e9d923d0fd..4a0808fdbb 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -31,6 +31,7 @@ CastToType, ConvertToMultiChannelBasedOnBratsClasses, DataStats, + EnsureChannelFirst, FgBgToIndices, Identity, LabelToMask, @@ -60,6 +61,7 @@ "AsChannelFirstd", "AsChannelLastd", "AddChanneld", + "EnsureChannelFirstd", "RepeatChanneld", "RemoveRepeatedChanneld", "SplitChanneld", @@ -89,6 +91,8 @@ "AsChannelLastDict", "AddChannelD", "AddChannelDict", + "EnsureChannelFirstD", + "EnsureChannelFirstDict", "RandLambdaD", "RandLambdaDict", "RepeatChannelD", @@ -217,6 +221,32 @@ def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, Nda return d +class EnsureChannelFirstd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.EnsureChannelFirst`. + """ + + def __init__(self, keys: KeysCollection, meta_key_postfix: str = "meta_dict") -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + meta_key_postfix: `key_{postfix}` was used to store the metadata in `LoadImaged`. + So need the key to extract metadata for channel dim information, default is `meta_dict`. + For example, for data with key `image`, metadata by default is in `image_meta_dict`. + + """ + super().__init__(keys) + self.adjuster = EnsureChannelFirst() + self.meta_key_postfix = meta_key_postfix + + def __call__(self, data) -> Dict[Hashable, np.ndarray]: + d = dict(data) + for key in self.keys: + d[key] = self.adjuster(d[key], d[f"{key}_{self.meta_key_postfix}"]) + return d + + class RepeatChanneld(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.RepeatChannel`. @@ -894,6 +924,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd AddChannelD = AddChannelDict = AddChanneld +EnsureChannelFirstD = EnsureChannelFirstDict = EnsureChannelFirstd RemoveRepeatedChannelD = RemoveRepeatedChannelDict = RemoveRepeatedChanneld RepeatChannelD = RepeatChannelDict = RepeatChanneld SplitChannelD = SplitChannelDict = SplitChanneld diff --git a/tests/min_tests.py b/tests/min_tests.py index 999a1aeaa0..83c1ceea9f 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -109,6 +109,8 @@ def run_testsuit(): "test_deepgrow_dataset", "test_save_image", "test_save_imaged", + "test_ensure_channel_first", + "test_ensure_channel_firstd", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py new file mode 100644 index 0000000000..ff656f2e24 --- /dev/null +++ b/tests/test_ensure_channel_first.py @@ -0,0 +1,86 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
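
Before the new tests, a short usage sketch of the EnsureChannelFirst/EnsureChannelFirstd pair added above (the meta dicts are hand-built here for illustration; the tests below obtain them from LoadImage/LoadImaged):

    import numpy as np

    from monai.transforms import EnsureChannelFirst, EnsureChannelFirstd

    # array version: channel information comes from the meta dict's "original_channel_dim"
    img = np.zeros((64, 64, 32), dtype=np.float32)
    print(EnsureChannelFirst()(img, {"original_channel_dim": "no_channel"}).shape)  # (1, 64, 64, 32)

    rgb = np.zeros((48, 48, 3), dtype=np.uint8)
    print(EnsureChannelFirst()(rgb, {"original_channel_dim": -1}).shape)  # (3, 48, 48)

    # dictionary version reads "<key>_meta_dict", as produced by LoadImaged
    data = {"img": rgb, "img_meta_dict": {"original_channel_dim": -1}}
    print(EnsureChannelFirstd(keys="img")(data)["img"].shape)  # (3, 48, 48)
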
+ +import os +import tempfile +import unittest + +import itk +import nibabel as nib +import numpy as np +from parameterized import parameterized +from PIL import Image + +from monai.data import ITKReader +from monai.transforms import EnsureChannelFirst, LoadImage + +TEST_CASE_1 = [{"image_only": False}, ["test_image.nii.gz"], None] + +TEST_CASE_2 = [{"image_only": False}, ["test_image.nii.gz"], -1] + +TEST_CASE_3 = [ + {"image_only": False}, + ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], + None, +] + +TEST_CASE_4 = [{"reader": ITKReader(), "image_only": False}, ["test_image.nii.gz"], None] + +TEST_CASE_5 = [{"reader": ITKReader(), "image_only": False}, ["test_image.nii.gz"], -1] + +TEST_CASE_6 = [ + {"reader": ITKReader(), "image_only": False}, + ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], + None, +] + +TEST_CASE_7 = [ + {"image_only": False, "reader": ITKReader(pixel_type=itk.UC)}, + "tests/testing_data/CT_DICOM", + None, +] + + +class TestEnsureChannelFirst(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + def test_load_nifti(self, input_param, filenames, original_channel_dim): + if original_channel_dim is None: + test_image = np.random.rand(128, 128, 128) + elif original_channel_dim == -1: + test_image = np.random.rand(128, 128, 128, 1) + + with tempfile.TemporaryDirectory() as tempdir: + for i, name in enumerate(filenames): + filenames[i] = os.path.join(tempdir, name) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) + result, header = LoadImage(**input_param)(filenames) + result = EnsureChannelFirst()(result, header) + self.assertEqual(result.shape[0], len(filenames)) + + @parameterized.expand([TEST_CASE_7]) + def test_itk_dicom_series_reader(self, input_param, filenames, original_channel_dim): + result, header = LoadImage(**input_param)(filenames) + result = EnsureChannelFirst()(result, header) + self.assertEqual(result.shape[0], 1) + + def test_load_png(self): + spatial_size = (256, 256, 3) + test_image = np.random.randint(0, 256, size=spatial_size) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, "test_image.png") + Image.fromarray(test_image.astype("uint8")).save(filename) + result, header = LoadImage(image_only=False)(filename) + result = EnsureChannelFirst()(result, header) + self.assertEqual(result.shape[0], 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_ensure_channel_firstd.py b/tests/test_ensure_channel_firstd.py new file mode 100644 index 0000000000..a5298f4453 --- /dev/null +++ b/tests/test_ensure_channel_firstd.py @@ -0,0 +1,62 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import tempfile +import unittest + +import nibabel as nib +import numpy as np +from parameterized import parameterized +from PIL import Image + +from monai.transforms import EnsureChannelFirstd, LoadImaged + +TEST_CASE_1 = [{"keys": "img"}, ["test_image.nii.gz"], None] + +TEST_CASE_2 = [{"keys": "img"}, ["test_image.nii.gz"], -1] + +TEST_CASE_3 = [ + {"keys": "img"}, + ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], + None, +] + + +class TestEnsureChannelFirstd(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_load_nifti(self, input_param, filenames, original_channel_dim): + if original_channel_dim is None: + test_image = np.random.rand(128, 128, 128) + elif original_channel_dim == -1: + test_image = np.random.rand(128, 128, 128, 1) + + with tempfile.TemporaryDirectory() as tempdir: + for i, name in enumerate(filenames): + filenames[i] = os.path.join(tempdir, name) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) + result = LoadImaged(**input_param)({"img": filenames}) + result = EnsureChannelFirstd(**input_param)(result) + self.assertEqual(result["img"].shape[0], len(filenames)) + + def test_load_png(self): + spatial_size = (256, 256, 3) + test_image = np.random.randint(0, 256, size=spatial_size) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, "test_image.png") + Image.fromarray(test_image.astype("uint8")).save(filename) + result = LoadImaged(keys="img")({"img": filename}) + result = EnsureChannelFirstd(keys="img")(result) + self.assertEqual(result["img"].shape[0], 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_nifti_endianness.py b/tests/test_nifti_endianness.py new file mode 100644 index 0000000000..14317c0832 --- /dev/null +++ b/tests/test_nifti_endianness.py @@ -0,0 +1,48 @@ +import tempfile +import unittest +from typing import TYPE_CHECKING, List, Tuple +from unittest.case import skipUnless + +import numpy as np +from parameterized import parameterized + +from monai.data import DataLoader, Dataset, create_test_image_2d +from monai.transforms import LoadImage, LoadImaged +from monai.utils.module import optional_import + +if TYPE_CHECKING: + import nibabel as nib + + has_nib = True +else: + nib, has_nib = optional_import("nibabel") + +TESTS: List[Tuple] = [] +for endianness in ["<", ">"]: + for use_array in [True, False]: + for image_only in [True, False]: + TESTS.append((endianness, use_array, image_only)) + + +class TestNiftiEndianness(unittest.TestCase): + def setUp(self): + self.im, _ = create_test_image_2d(100, 100) + self.fname = tempfile.NamedTemporaryFile(suffix=".nii.gz").name + + @parameterized.expand(TESTS) + @skipUnless(has_nib, "Requires NiBabel") + def test_endianness(self, endianness, use_array, image_only): + + hdr = nib.Nifti1Header(endianness=endianness) + nii = nib.Nifti1Image(self.im, np.eye(4), header=hdr) + nib.save(nii, self.fname) + + data = [self.fname] if use_array else [{"image": self.fname}] + tr = LoadImage(image_only=image_only) if use_array else LoadImaged("image", image_only=image_only) + check_ds = Dataset(data, tr) + check_loader = DataLoader(check_ds, batch_size=1) + _ = next(iter(check_loader)) + + +if __name__ == "__main__": + unittest.main() From 088f6260a78ceb931ecc70349aa41c62c79a3aaf Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Mar 2021 18:03:04 +0000 Subject: [PATCH 11/64] debug message Signed-off-by: Richard Brown 
<33289025+rijobro@users.noreply.github.com> --- tests/test_decollate.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_decollate.py b/tests/test_decollate.py index d60eb64316..70ce514c6f 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -79,9 +79,17 @@ def test_decollation(self, _, data, batch_size=2, num_workers=2): decollated_1 = decollate_batch(batch_data) decollated_2 = Decollated()(batch_data) - for decollated in [decollated_1, decollated_2]: + for z, decollated in enumerate([decollated_1, decollated_2]): for i, d in enumerate(decollated): - self.check_match(dataset[b * batch_size + i], d) + try: + self.check_match(dataset[b * batch_size + i], d) + except RuntimeError: + print(f"problem with b={b}, i={i}, decollated_{z+1}") + print("d") + print(d) + print("dataset[b * batch_size + i]") + print(dataset[b * batch_size + i]) + raise if __name__ == "__main__": From 4b7ba75127178640ab4bf4f8415aa688f8300a9a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Mar 2021 19:10:10 +0000 Subject: [PATCH 12/64] don't write to file if no nibabel Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 0bcbfdc46c..e35e65df01 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -61,24 +61,33 @@ has_matplotlib = True has_vtk = True + has_nib = True else: plt, has_matplotlib = optional_import("matplotlib.pyplot") _, has_vtk = optional_import("vtk") + _, has_nib = optional_import("nibabel") -set_determinism(seed=0) AFFINE = make_rand_affine() AFFINE[0] *= 2 -IM_1D = AddChannel()(np.arange(0, 10)) -IM_2D_FNAME, SEG_2D_FNAME = [make_nifti_image(i) for i in create_test_image_2d(101, 100)] -IM_3D_FNAME, SEG_3D_FNAME = [make_nifti_image(i, AFFINE) for i in create_test_image_3d(100, 101, 107)] - KEYS = ["image", "label"] +IM_1D = AddChannel()(np.arange(0, 10)) DATA_1D = {"image": IM_1D, "label": IM_1D, "other": IM_1D} -LOAD_IMS = Compose([LoadImaged(KEYS), AddChanneld(KEYS)]) -DATA_2D = LOAD_IMS({"image": IM_2D_FNAME, "label": SEG_2D_FNAME}) -DATA_3D = LOAD_IMS({"image": IM_3D_FNAME, "label": SEG_3D_FNAME}) + +IM_2D, SEG_2D = create_test_image_2d(101, 100) +IM_3D, SEG_3D = create_test_image_3d(100, 101, 107) +if has_nib: + IM_2D_FNAME, SEG_2D_FNAME = [make_nifti_image(i) for i in create_test_image_2d(101, 100)] + IM_3D_FNAME, SEG_3D_FNAME = [make_nifti_image(i, AFFINE) for i in create_test_image_3d(100, 101, 107)] + + LOAD_IMS = Compose([LoadImaged(KEYS), AddChanneld(KEYS)]) + DATA_2D = LOAD_IMS({"image": IM_2D_FNAME, "label": SEG_2D_FNAME}) + DATA_3D = LOAD_IMS({"image": IM_3D_FNAME, "label": SEG_3D_FNAME}) +else: + ADD_CH = AddChanneld(KEYS) + DATA_2D = ADD_CH({"image": IM_2D, "label": SEG_2D}) + DATA_3D = ADD_CH({"image": IM_3D, "label": SEG_3D}) TESTS: List[Tuple] = [] @@ -429,6 +438,12 @@ def plot_im(orig, fwd_bck, fwd): class TestInverse(unittest.TestCase): + def setUp(self): + set_determinism(seed=0) + + def tearDown(self): + set_determinism(seed=None) + def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff): for key in keys: orig = orig_d[key] From e8bbdd74db08e834fb0ee5b05ceb926811234dff Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 3 Mar 2021 13:19:53 +0000 Subject: [PATCH 13/64] skip if no meta data 
key Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 7 +++++++ tests/test_inverse.py | 6 +++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 566d9868f0..25ff6c618b 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -20,6 +20,7 @@ import numpy as np import torch +import warnings from monai.config import DtypeLike, KeysCollection from monai.networks.layers import AffineTransform @@ -192,6 +193,9 @@ def __call__( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): meta_data_key = f"{key}_{self.meta_key_postfix}" + if meta_data_key not in data.keys(): + warnings.warn(f"No meta data found with key: {meta_data_key}. Nothing to do.") + continue meta_data = d[meta_data_key] # resample array of each corresponding key # using affine fetched from d[affine_key] @@ -298,6 +302,9 @@ def __call__( d: Dict = dict(data) for key in self.key_iterator(d): meta_data_key = f"{key}_{self.meta_key_postfix}" + if meta_data_key not in data.keys(): + warnings.warn(f"No meta data found with key: {meta_data_key}. Nothing to do.") + continue meta_data = d[meta_data_key] d[key], old_affine, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) self.append_applied_transforms( diff --git a/tests/test_inverse.py b/tests/test_inverse.py index e35e65df01..e46fa3dced 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -245,7 +245,7 @@ ( "RandRotated 3d", DATA_3D, - 5e-2, + 1e-1, RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1), # type: ignore ) ) @@ -347,13 +347,13 @@ TESTS.append(("Resized 3d", DATA_3D, 5e-2, Resized(KEYS, [201, 150, 78]))) -TESTS.append(("ResizeWithPadOrCropd 3d", DATA_3D, 1e-2, ResizeWithPadOrCropd(KEYS, [201, 150, 78]))) +TESTS.append(("ResizeWithPadOrCropd 3d", DATA_3D, 3e-2, ResizeWithPadOrCropd(KEYS, [201, 150, 78]))) TESTS.append( ( "RandAffine 3d", DATA_3D, - 5e-2, + 7e-2, RandAffined( KEYS, [155, 179, 192], From af791a564361f574e7a86709777485549f3a2127 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 3 Mar 2021 13:22:36 +0000 Subject: [PATCH 14/64] more lenient thresholds for tests Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index e46fa3dced..e4f627f378 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -236,7 +236,7 @@ ( "Rotated 3d", DATA_3D, - 5e-2, + 1e-1, Rotated(KEYS, [random.uniform(np.pi / 6, np.pi) for _ in range(3)], True), # type: ignore ) ) @@ -353,7 +353,7 @@ ( "RandAffine 3d", DATA_3D, - 7e-2, + 1e-1, RandAffined( KEYS, [155, 179, 192], From ffc1b9b00f696ed74b765f62d1757fb86024ca43 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 3 Mar 2021 13:57:24 +0000 Subject: [PATCH 15/64] isort Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 25ff6c618b..ad71b9a085 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -15,12 +15,12 @@ Class names are ended with 'd' 
to denote dictionary-based transforms. """ +import warnings from copy import deepcopy from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch -import warnings from monai.config import DtypeLike, KeysCollection from monai.networks.layers import AffineTransform From 2d6f9a1061e040402c2a4a2bb3e1718c9856b282 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 3 Mar 2021 14:06:40 +0000 Subject: [PATCH 16/64] mypy Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index ad71b9a085..f93fa002fa 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -934,7 +934,7 @@ def __call__( for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): self.append_applied_transforms(d, key, extra_info={"grid_no_affine": grid_no_affine, "affine": affine}) - d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) + d[key] = self.rand_3d_elastic.resampler(d[key], grid_w_affine, mode=mode, padding_mode=padding_mode) return d def inverse( From e5804e1e796555fd21762c57e60a4141913ecb3c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 3 Mar 2021 14:33:00 +0000 Subject: [PATCH 17/64] tests require nibabel Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 127 +++++++++++++++++++++--------------------- 1 file changed, 64 insertions(+), 63 deletions(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index e4f627f378..b9ae0e3704 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -67,34 +67,14 @@ _, has_vtk = optional_import("vtk") _, has_nib = optional_import("nibabel") - -AFFINE = make_rand_affine() -AFFINE[0] *= 2 - KEYS = ["image", "label"] -IM_1D = AddChannel()(np.arange(0, 10)) -DATA_1D = {"image": IM_1D, "label": IM_1D, "other": IM_1D} - -IM_2D, SEG_2D = create_test_image_2d(101, 100) -IM_3D, SEG_3D = create_test_image_3d(100, 101, 107) -if has_nib: - IM_2D_FNAME, SEG_2D_FNAME = [make_nifti_image(i) for i in create_test_image_2d(101, 100)] - IM_3D_FNAME, SEG_3D_FNAME = [make_nifti_image(i, AFFINE) for i in create_test_image_3d(100, 101, 107)] - - LOAD_IMS = Compose([LoadImaged(KEYS), AddChanneld(KEYS)]) - DATA_2D = LOAD_IMS({"image": IM_2D_FNAME, "label": SEG_2D_FNAME}) - DATA_3D = LOAD_IMS({"image": IM_3D_FNAME, "label": SEG_3D_FNAME}) -else: - ADD_CH = AddChanneld(KEYS) - DATA_2D = ADD_CH({"image": IM_2D, "label": SEG_2D}) - DATA_3D = ADD_CH({"image": IM_3D, "label": SEG_3D}) TESTS: List[Tuple] = [] TESTS.append( ( "SpatialPadd (x2) 2d", - DATA_2D, + "2D", 0.0, SpatialPadd(KEYS, spatial_size=[111, 113], method="end"), SpatialPadd(KEYS, spatial_size=[118, 117]), @@ -104,7 +84,7 @@ TESTS.append( ( "SpatialPadd 3d", - DATA_3D, + "3D", 0.0, SpatialPadd(KEYS, spatial_size=[112, 113, 116]), ) @@ -113,7 +93,7 @@ TESTS.append( ( "RandRotated, prob 0", - DATA_2D, + "2D", 0, RandRotated(KEYS, prob=0), ) @@ -122,7 +102,7 @@ TESTS.append( ( "SpatialCropd 2d", - DATA_2D, + "2D", 3e-2, SpatialCropd(KEYS, [49, 51], [90, 89]), ) @@ -131,20 +111,20 @@ TESTS.append( ( "SpatialCropd 3d", - DATA_3D, + "3D", 4e-2, SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]), ) ) -TESTS.append(("RandSpatialCropd 
2d", DATA_2D, 5e-2, RandSpatialCropd(KEYS, [96, 93], True, False))) +TESTS.append(("RandSpatialCropd 2d", "2D", 5e-2, RandSpatialCropd(KEYS, [96, 93], True, False))) -TESTS.append(("RandSpatialCropd 3d", DATA_3D, 2e-2, RandSpatialCropd(KEYS, [96, 93, 92], False, False))) +TESTS.append(("RandSpatialCropd 3d", "3D", 2e-2, RandSpatialCropd(KEYS, [96, 93, 92], False, False))) TESTS.append( ( "BorderPadd 2d", - DATA_2D, + "2D", 0, BorderPadd(KEYS, [3, 7, 2, 5]), ) @@ -153,7 +133,7 @@ TESTS.append( ( "BorderPadd 2d", - DATA_2D, + "2D", 0, BorderPadd(KEYS, [3, 7]), ) @@ -162,7 +142,7 @@ TESTS.append( ( "BorderPadd 3d", - DATA_3D, + "3D", 0, BorderPadd(KEYS, [4]), ) @@ -171,7 +151,7 @@ TESTS.append( ( "DivisiblePadd 2d", - DATA_2D, + "2D", 0, DivisiblePadd(KEYS, k=4), ) @@ -180,7 +160,7 @@ TESTS.append( ( "DivisiblePadd 3d", - DATA_3D, + "3D", 0, DivisiblePadd(KEYS, k=[4, 8, 11]), ) @@ -189,7 +169,7 @@ TESTS.append( ( "Flipd 3d", - DATA_3D, + "3D", 0, Flipd(KEYS, [1, 2]), ) @@ -198,7 +178,7 @@ TESTS.append( ( "Flipd 3d", - DATA_3D, + "3D", 0, Flipd(KEYS, [1, 2]), ) @@ -207,7 +187,7 @@ TESTS.append( ( "RandFlipd 3d", - DATA_3D, + "3D", 0, RandFlipd(KEYS, 1, [1, 2]), ) @@ -216,7 +196,7 @@ TESTS.append( ( "RandAxisFlipd 3d", - DATA_3D, + "3D", 0, RandAxisFlipd(KEYS, 1), ) @@ -226,7 +206,7 @@ TESTS.append( ( "Rotated 2d", - DATA_2D, + "2D", 8e-2, Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False), ) @@ -235,7 +215,7 @@ TESTS.append( ( "Rotated 3d", - DATA_3D, + "3D", 1e-1, Rotated(KEYS, [random.uniform(np.pi / 6, np.pi) for _ in range(3)], True), # type: ignore ) @@ -244,7 +224,7 @@ TESTS.append( ( "RandRotated 3d", - DATA_3D, + "3D", 1e-1, RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1), # type: ignore ) @@ -253,10 +233,10 @@ TESTS.append( ( "Orientationd 3d", - DATA_3D, + "3D", 0, # For data loader, output needs to be same size, so input must be square/cubic - SpatialPadd(KEYS, max(DATA_3D["image"].shape)), + SpatialPadd(KEYS, 110), Orientationd(KEYS, "RAS"), ) ) @@ -264,7 +244,7 @@ TESTS.append( ( "Rotate90d 2d", - DATA_2D, + "2D", 0, Rotate90d(KEYS), ) @@ -273,7 +253,7 @@ TESTS.append( ( "Rotate90d 3d", - DATA_3D, + "3D", 0, Rotate90d(KEYS, k=2, spatial_axes=(1, 2)), ) @@ -282,10 +262,10 @@ TESTS.append( ( "RandRotate90d 3d", - DATA_3D, + "3D", 0, # For data loader, output needs to be same size, so input must be square/cubic - SpatialPadd(KEYS, max(DATA_3D["image"].shape)), + SpatialPadd(KEYS, 110), RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)), ) ) @@ -293,7 +273,7 @@ TESTS.append( ( "Zoomd 1d", - DATA_1D, + "1D", 0, Zoomd(KEYS, zoom=2, keep_size=False), ) @@ -302,7 +282,7 @@ TESTS.append( ( "Zoomd 2d", - DATA_2D, + "2D", 2e-1, Zoomd(KEYS, zoom=0.9), ) @@ -311,18 +291,18 @@ TESTS.append( ( "Zoomd 3d", - DATA_3D, + "3D", 3e-2, Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False), ) ) -TESTS.append(("RandZoom 3d", DATA_3D, 9e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True))) +TESTS.append(("RandZoom 3d", "3D", 9e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True))) TESTS.append( ( "CenterSpatialCropd 2d", - DATA_2D, + "2D", 0, CenterSpatialCropd(KEYS, roi_size=95), ) @@ -331,28 +311,28 @@ TESTS.append( ( "CenterSpatialCropd 3d", - DATA_3D, + "3D", 0, CenterSpatialCropd(KEYS, roi_size=[95, 97, 98]), ) ) -TESTS.append(("CropForegroundd 2d", DATA_2D, 0, CropForegroundd(KEYS, source_key="label", margin=2))) +TESTS.append(("CropForegroundd 2d", "2D", 0, CropForegroundd(KEYS, source_key="label", 
margin=2))) -TESTS.append(("CropForegroundd 3d", DATA_3D, 0, CropForegroundd(KEYS, source_key="label"))) +TESTS.append(("CropForegroundd 3d", "3D", 0, CropForegroundd(KEYS, source_key="label"))) -TESTS.append(("Spacingd 3d", DATA_3D, 3e-2, Spacingd(KEYS, [0.5, 0.7, 0.9], diagonal=False))) +TESTS.append(("Spacingd 3d", "3D", 3e-2, Spacingd(KEYS, [0.5, 0.7, 0.9], diagonal=False))) -TESTS.append(("Resized 2d", DATA_2D, 2e-1, Resized(KEYS, [50, 47]))) +TESTS.append(("Resized 2d", "2D", 2e-1, Resized(KEYS, [50, 47]))) -TESTS.append(("Resized 3d", DATA_3D, 5e-2, Resized(KEYS, [201, 150, 78]))) +TESTS.append(("Resized 3d", "3D", 5e-2, Resized(KEYS, [201, 150, 78]))) -TESTS.append(("ResizeWithPadOrCropd 3d", DATA_3D, 3e-2, ResizeWithPadOrCropd(KEYS, [201, 150, 78]))) +TESTS.append(("ResizeWithPadOrCropd 3d", "3D", 3e-2, ResizeWithPadOrCropd(KEYS, [201, 150, 78]))) TESTS.append( ( "RandAffine 3d", - DATA_3D, + "3D", 1e-1, RandAffined( KEYS, @@ -371,7 +351,7 @@ TESTS.append( ( "Rand2DElasticd 2d", - DATA_2D, + "2D", 2e-1, Rand2DElasticd( KEYS, @@ -392,7 +372,7 @@ TESTS.append( ( "Rand3DElasticd 3d", - DATA_3D, + "3D", 1e-1, Rand3DElasticd( KEYS, @@ -414,7 +394,7 @@ # Should fail because uses an array transform (SpatialPad), as opposed to dictionary -TEST_FAIL_0 = (DATA_2D["image"], 0.0, Compose([SpatialPad(spatial_size=[101, 103])])) +TEST_FAIL_0 = ("2D", 0.0, Compose([SpatialPad(spatial_size=[101, 103])])) TESTS_FAIL = [TEST_FAIL_0] @@ -439,8 +419,26 @@ def plot_im(orig, fwd_bck, fwd): class TestInverse(unittest.TestCase): def setUp(self): + if not has_nib: + self.skipTest("nibabel required for test_inverse") + set_determinism(seed=0) + self.all_data = {} + + affine = make_rand_affine() + affine[0] *= 2 + + im_1d = AddChannel()(np.arange(0, 10)) + self.all_data["1D"] = {"image": im_1d, "label": im_1d, "other": im_1d} + + im_2d_fname, seg_2d_fname = [make_nifti_image(i) for i in create_test_image_2d(101, 100)] + im_3d_fname, seg_3d_fname = [make_nifti_image(i, affine) for i in create_test_image_3d(100, 101, 107)] + + load_ims = Compose([LoadImaged(KEYS), AddChanneld(KEYS)]) + self.all_data["2D"] = load_ims({"image": im_2d_fname, "label": seg_2d_fname}) + self.all_data["3D"] = load_ims({"image": im_3d_fname, "label": seg_3d_fname}) + def tearDown(self): set_determinism(seed=None) @@ -468,9 +466,11 @@ def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_ raise @parameterized.expand(TESTS) - def test_inverse(self, _, data, acceptable_diff, *transforms): + def test_inverse(self, _, data_name, acceptable_diff, *transforms): name = _ + data = self.all_data[data_name] + forwards = [data.copy()] # Apply forwards @@ -490,7 +490,8 @@ def test_inverse(self, _, data, acceptable_diff, *transforms): self.check_inverse(name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff) @parameterized.expand(TESTS_FAIL) - def test_fail(self, data, _, *transform): + def test_fail(self, data_name, _, *transform): + data = self.all_data[data_name]["image"] d = transform[0](data) with self.assertRaises(RuntimeError): d = transform[0].inverse(d) From e62a4ce139f5976ce51749db053f66613afa4783 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 3 Mar 2021 14:36:01 +0000 Subject: [PATCH 18/64] undo skip if no metadata Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py 
b/monai/transforms/spatial/dictionary.py index f93fa002fa..1c3f52ea10 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -15,7 +15,6 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ -import warnings from copy import deepcopy from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union @@ -193,9 +192,6 @@ def __call__( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): meta_data_key = f"{key}_{self.meta_key_postfix}" - if meta_data_key not in data.keys(): - warnings.warn(f"No meta data found with key: {meta_data_key}. Nothing to do.") - continue meta_data = d[meta_data_key] # resample array of each corresponding key # using affine fetched from d[affine_key] @@ -302,9 +298,6 @@ def __call__( d: Dict = dict(data) for key in self.key_iterator(d): meta_data_key = f"{key}_{self.meta_key_postfix}" - if meta_data_key not in data.keys(): - warnings.warn(f"No meta data found with key: {meta_data_key}. Nothing to do.") - continue meta_data = d[meta_data_key] d[key], old_affine, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) self.append_applied_transforms( From df76f0706a9a2527ab031764bd0d0e54c5ed63ad Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 3 Mar 2021 14:59:32 +0000 Subject: [PATCH 19/64] update test_decollate Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_decollate.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 70ce514c6f..785a847003 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -13,7 +13,6 @@ import numpy as np import torch -from parameterized import parameterized from monai.data import CacheDataset, DataLoader, create_test_image_2d from monai.data.utils import decollate_batch @@ -24,18 +23,6 @@ _, has_nib = optional_import("nibabel") -IM_2D = create_test_image_2d(100, 101)[0] -DATA_2D = {"image": make_nifti_image(IM_2D) if has_nib else IM_2D} - -TESTS = [] -TESTS.append( - ( - "2D", - [DATA_2D for _ in range(6)], - ) -) - - class TestDeCollate(unittest.TestCase): def setUp(self) -> None: set_determinism(seed=0) @@ -58,8 +45,11 @@ def check_match(self, in1, in2): else: raise RuntimeError(f"Not sure how to compare types. 
type(in1): {type(in1)}, type(in2): {type(in2)}") - @parameterized.expand(TESTS) - def test_decollation(self, _, data, batch_size=2, num_workers=2): + def test_decollation(self, batch_size=2, num_workers=2): + + im = create_test_image_2d(100, 101)[0] + data = [{"image": make_nifti_image(im) if has_nib else im} for _ in range(6)] + transforms = Compose( [ AddChanneld("image"), From cd0f3059ba64b53d0f65d68e5dbdaeec9ed26e4b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 3 Mar 2021 15:04:07 +0000 Subject: [PATCH 20/64] isort Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_decollate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 785a847003..5a5c39cf86 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -23,6 +23,7 @@ _, has_nib = optional_import("nibabel") + class TestDeCollate(unittest.TestCase): def setUp(self) -> None: set_determinism(seed=0) From 014f7dff5201be8a1597eaeaf47434b5bd5a5d47 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 4 Mar 2021 14:33:03 +0000 Subject: [PATCH 21/64] changes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/handlers/segmentation_saver.py | 7 ++++ monai/networks/nets/densenet.py | 4 +-- monai/networks/nets/senet.py | 4 +-- monai/transforms/io/array.py | 1 + monai/transforms/io/dictionary.py | 1 + monai/transforms/utility/array.py | 2 +- tests/test_convert_to_multi_channel.py | 1 + tests/test_decollate.py | 22 ++++++------- tests/test_densenet.py | 41 ++++++++++++++++++++---- tests/test_integration_sliding_window.py | 2 +- tests/test_inverse.py | 4 +-- tests/test_senet.py | 32 +++++++++++++++--- 12 files changed, 90 insertions(+), 31 deletions(-) diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index a46918b893..56370fd41c 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -41,6 +41,7 @@ def __init__( scale: Optional[int] = None, dtype: DtypeLike = np.float64, output_dtype: DtypeLike = np.float32, + squeeze_end_dims: bool = True, batch_transform: Callable = lambda x: x, output_transform: Callable = lambda x: x, name: Optional[str] = None, @@ -77,6 +78,11 @@ def __init__( If None, use the data type of input data. It's used for Nifti format only. output_dtype: data type for saving data. Defaults to ``np.float32``, it's used for Nifti format only. + squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel + has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and + then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false, + image will always be saved as (H,W,D,C). + it's used for NIfTI format only. batch_transform: a callable that is used to transform the ignite.engine.batch into expected format to extract the meta_data dictionary. 
output_transform: a callable that is used to transform the @@ -96,6 +102,7 @@ def __init__( scale=scale, dtype=dtype, output_dtype=output_dtype, + squeeze_end_dims=squeeze_end_dims, save_batch=True, ) self.batch_transform = batch_transform diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py index ad1d1d6e5f..a59ab99e68 100644 --- a/monai/networks/nets/densenet.py +++ b/monai/networks/nets/densenet.py @@ -210,14 +210,14 @@ def _load_state_dict(model, model_url, progress): `_ """ pattern = re.compile( - r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" + r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" ) state_dict = load_state_dict_from_url(model_url, progress=progress) for key in list(state_dict.keys()): res = pattern.match(key) if res: - new_key = res.group(1) + res.group(2) + new_key = res.group(1) + ".layers" + res.group(2) + res.group(3) state_dict[new_key] = state_dict[key] del state_dict[key] diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py index 655ff203c7..ef67f853d6 100644 --- a/monai/networks/nets/senet.py +++ b/monai/networks/nets/senet.py @@ -275,7 +275,7 @@ def _load_state_dict(model, model_url, progress): if pattern_conv.match(key): new_key = re.sub(pattern_conv, r"\1conv.\2", key) elif pattern_bn.match(key): - new_key = re.sub(pattern_bn, r"\1conv\2norm.\3", key) + new_key = re.sub(pattern_bn, r"\1conv\2adn.N.\3", key) elif pattern_se.match(key): state_dict[key] = state_dict[key].squeeze() new_key = re.sub(pattern_se, r"\1se_layer.fc.0.\2", key) @@ -285,7 +285,7 @@ def _load_state_dict(model, model_url, progress): elif pattern_down_conv.match(key): new_key = re.sub(pattern_down_conv, r"\1project.conv.\2", key) elif pattern_down_bn.match(key): - new_key = re.sub(pattern_down_bn, r"\1project.norm.\2", key) + new_key = re.sub(pattern_down_bn, r"\1project.adn.N.\2", key) if new_key: state_dict[new_key] = state_dict[key] del state_dict[key] diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 4ede04cf69..a256a16ec8 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -206,6 +206,7 @@ class SaveImage(Transform): has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false, image will always be saved as (H,W,D,C). + it's used for NIfTI format only. """ diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 8a428e1118..50ab8f9868 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -172,6 +172,7 @@ class SaveImaged(MapTransform): has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false, image will always be saved as (H,W,D,C). + it's used for NIfTI format only. 
""" diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 62daf9309c..8776238711 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -655,7 +655,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: result.append(np.logical_or(np.logical_or(img == 1, img == 4), img == 2)) # label 4 is ET result.append(img == 4) - return np.stack(result, axis=0).astype(np.float32) + return np.stack(result, axis=0) class AddExtremePointsChannel(RandomizableTransform): diff --git a/tests/test_convert_to_multi_channel.py b/tests/test_convert_to_multi_channel.py index ea27371ac7..03510ad38c 100644 --- a/tests/test_convert_to_multi_channel.py +++ b/tests/test_convert_to_multi_channel.py @@ -27,6 +27,7 @@ class TestConvertToMultiChannel(unittest.TestCase): def test_type_shape(self, data, expected_result): result = ConvertToMultiChannelBasedOnBratsClasses()(data) np.testing.assert_equal(result, expected_result) + self.assertEqual(f"{result.dtype}", "bool") if __name__ == "__main__": diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 5a5c39cf86..84b92bdb2c 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys import unittest import numpy as np @@ -34,7 +35,12 @@ def tearDown(self) -> None: def check_match(self, in1, in2): if isinstance(in1, dict): self.assertTrue(isinstance(in2, dict)) - self.check_match(list(in1.keys()), list(in2.keys())) + for (k1, v1), (k2, v2) in zip(in1.items(), in2.items()): + self.check_match(k1, k2) + # Transform ids won't match for windows with multiprocessing + if k1 == "id" and sys.platform == "win32": + continue + self.check_match(v1, v2) self.check_match(list(in1.values()), list(in2.values())) elif any(isinstance(in1, i) for i in [list, tuple]): for l1, l2 in zip(in1, in2): @@ -64,23 +70,15 @@ def test_decollation(self, batch_size=2, num_workers=2): transforms = Compose([LoadImaged("image"), transforms]) dataset = CacheDataset(data, transforms, progress=False) - loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0) for b, batch_data in enumerate(loader): decollated_1 = decollate_batch(batch_data) decollated_2 = Decollated()(batch_data) - for z, decollated in enumerate([decollated_1, decollated_2]): + for decollated in [decollated_1, decollated_2]: for i, d in enumerate(decollated): - try: - self.check_match(dataset[b * batch_size + i], d) - except RuntimeError: - print(f"problem with b={b}, i={i}, decollated_{z+1}") - print("d") - print(d) - print("dataset[b * batch_size + i]") - print(dataset[b * batch_size + i]) - raise + self.check_match(dataset[b * batch_size + i], d) if __name__ == "__main__": diff --git a/tests/test_densenet.py b/tests/test_densenet.py index 876689314a..41b5fbf7d6 100644 --- a/tests/test_densenet.py +++ b/tests/test_densenet.py @@ -10,14 +10,25 @@ # limitations under the License. 
import unittest +from typing import TYPE_CHECKING +from unittest import skipUnless import torch from parameterized import parameterized from monai.networks import eval_mode from monai.networks.nets import densenet121, densenet169, densenet201, densenet264 +from monai.utils import optional_import from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save +if TYPE_CHECKING: + import torchvision + + has_torchvision = True +else: + torchvision, has_torchvision = optional_import("torchvision") + + device = "cuda" if torch.cuda.is_available() else "cpu" TEST_CASE_1 = [ # 4-channel 3D, batch 2 @@ -50,27 +61,45 @@ TEST_PRETRAINED_2D_CASE_1 = [ # 4-channel 2D, batch 2 densenet121, {"pretrained": True, "progress": True, "spatial_dims": 2, "in_channels": 2, "out_channels": 3}, - (2, 2, 32, 64), - (2, 3), + (1, 2, 32, 64), + (1, 3), ] TEST_PRETRAINED_2D_CASE_2 = [ # 4-channel 2D, batch 2 densenet121, - {"pretrained": True, "progress": False, "spatial_dims": 2, "in_channels": 2, "out_channels": 3}, - (2, 2, 32, 64), - (2, 3), + {"pretrained": True, "progress": False, "spatial_dims": 2, "in_channels": 2, "out_channels": 1}, + (1, 2, 32, 64), + (1, 1), +] + +TEST_PRETRAINED_2D_CASE_3 = [ + densenet121, + {"pretrained": True, "progress": False, "spatial_dims": 2, "in_channels": 3, "out_channels": 1}, + (1, 3, 32, 32), ] class TestPretrainedDENSENET(unittest.TestCase): @parameterized.expand([TEST_PRETRAINED_2D_CASE_1, TEST_PRETRAINED_2D_CASE_2]) @skip_if_quick - def test_121_3d_shape_pretrain(self, model, input_param, input_shape, expected_shape): + def test_121_2d_shape_pretrain(self, model, input_param, input_shape, expected_shape): net = test_pretrained_networks(model, input_param, device) with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) + @parameterized.expand([TEST_PRETRAINED_2D_CASE_3]) + @skipUnless(has_torchvision, "Requires `torchvision` package.") + def test_pretrain_consistency(self, model, input_param, input_shape): + example = torch.randn(input_shape).to(device) + net = test_pretrained_networks(model, input_param, device) + with eval_mode(net): + result = net.features.forward(example) + torchvision_net = torchvision.models.densenet121(pretrained=True).to(device) + with eval_mode(torchvision_net): + expected_result = torchvision_net.features.forward(example) + self.assertTrue(torch.all(result == expected_result)) + class TestDENSENET(unittest.TestCase): @parameterized.expand(TEST_CASES) diff --git a/tests/test_integration_sliding_window.py b/tests/test_integration_sliding_window.py index c4d020276e..faec377586 100644 --- a/tests/test_integration_sliding_window.py +++ b/tests/test_integration_sliding_window.py @@ -84,7 +84,7 @@ def test_training(self): ) output_image = nib.load(output_file).get_fdata() np.testing.assert_allclose(np.sum(output_image), 33621) - np.testing.assert_allclose(output_image.shape, (28, 25, 63, 1)) + np.testing.assert_allclose(output_image.shape, (28, 25, 63)) if __name__ == "__main__": diff --git a/tests/test_inverse.py b/tests/test_inverse.py index b9ae0e3704..ed8b5bb30e 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -499,11 +499,11 @@ def test_fail(self, data_name, _, *transform): def test_inverse_inferred_seg(self): test_data = [] - for _ in range(4): + for _ in range(20): image, label = create_test_image_2d(100, 101) test_data.append({"image": image, "label": label.astype(np.float32)}) - batch_size = 2 + batch_size = 10 # num workers = 0 for mac 
num_workers = 2 if sys.platform != "darwin" else 0 transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), CenterSpatialCropd(KEYS, (110, 99))]) diff --git a/tests/test_senet.py b/tests/test_senet.py index 883d75d62d..c1327ceb7d 100644 --- a/tests/test_senet.py +++ b/tests/test_senet.py @@ -10,6 +10,8 @@ # limitations under the License. import unittest +from typing import TYPE_CHECKING +from unittest import skipUnless import torch from parameterized import parameterized @@ -23,8 +25,17 @@ se_resnext101_32x4d, senet154, ) +from monai.utils import optional_import from tests.utils import test_pretrained_networks, test_script_save +if TYPE_CHECKING: + import pretrainedmodels + + has_cadene_pretrain = True +else: + pretrainedmodels, has_cadene_pretrain = optional_import("pretrainedmodels") + + device = "cuda" if torch.cuda.is_available() else "cpu" NET_ARGS = {"spatial_dims": 3, "in_channels": 2, "num_classes": 2} @@ -56,11 +67,7 @@ def test_script(self, net, net_args): class TestPretrainedSENET(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_PRETRAINED, - ] - ) + @parameterized.expand([TEST_CASE_PRETRAINED]) def test_senet_shape(self, model, input_param): net = test_pretrained_networks(model, input_param, device) input_data = torch.randn(3, 3, 64, 64).to(device) @@ -70,6 +77,21 @@ def test_senet_shape(self, model, input_param): result = net(input_data) self.assertEqual(result.shape, expected_shape) + @parameterized.expand([TEST_CASE_PRETRAINED]) + @skipUnless(has_cadene_pretrain, "Requires `pretrainedmodels` package.") + def test_pretrain_consistency(self, model, input_param): + input_data = torch.randn(1, 3, 64, 64).to(device) + net = test_pretrained_networks(model, input_param, device) + with eval_mode(net): + result = net.features(input_data) + cadene_net = pretrainedmodels.se_resnet50().to(device) + with eval_mode(cadene_net): + expected_result = cadene_net.features(input_data) + # The difference between Cadene's senet and our version is that + # we use nn.Linear as the FC layer, but Cadene's version uses + # a conv layer with kernel size equals to 1. It may bring a little difference. 
+ self.assertTrue(torch.allclose(result, expected_result, rtol=1e-5, atol=1e-5)) + if __name__ == "__main__": unittest.main() From d4e2c5463eac936ac5802ac887f43a4c4a8f37dc Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 5 Mar 2021 12:51:44 +0000 Subject: [PATCH 22/64] inverse_transform.py -> inverse.py Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 2 +- monai/transforms/compose.py | 2 +- monai/transforms/croppad/dictionary.py | 2 +- monai/transforms/{inverse_transform.py => inverse.py} | 0 monai/transforms/spatial/dictionary.py | 2 +- 5 files changed, 4 insertions(+), 4 deletions(-) rename monai/transforms/{inverse_transform.py => inverse.py} (100%) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index e57394aab9..fbaaed7279 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -138,7 +138,7 @@ ThresholdIntensityD, ThresholdIntensityDict, ) -from .inverse_transform import InvertibleTransform, NonRigidTransform +from .inverse import InvertibleTransform, NonRigidTransform from .io.array import LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .post.array import ( diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 678667bb0a..341d3e4e49 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -18,7 +18,7 @@ import numpy as np -from monai.transforms.inverse_transform import InvertibleTransform +from monai.transforms.inverse import InvertibleTransform # For backwards compatiblity (so this still works: from monai.transforms.compose import MapTransform) from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform, Transform # noqa: F401 diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index aef55ddc93..42eccd7259 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -32,7 +32,7 @@ SpatialCrop, SpatialPad, ) -from monai.transforms.inverse_transform import InvertibleTransform +from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform from monai.transforms.utils import ( generate_pos_neg_label_crop_centers, diff --git a/monai/transforms/inverse_transform.py b/monai/transforms/inverse.py similarity index 100% rename from monai/transforms/inverse_transform.py rename to monai/transforms/inverse.py diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 1c3f52ea10..e65a6da072 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -25,7 +25,7 @@ from monai.networks.layers import AffineTransform from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad -from monai.transforms.inverse_transform import InvertibleTransform, NonRigidTransform +from monai.transforms.inverse import InvertibleTransform, NonRigidTransform from monai.transforms.spatial.array import ( AffineGrid, Flip, From 2a0eb62dbb772dcade5948acc88beb78b919f01b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 5 Mar 2021 18:03:15 +0000 Subject: [PATCH 23/64] with AllowMissingKeysMode Signed-off-by: Richard Brown 
<33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 4 +- monai/transforms/compose.py | 3 +- monai/transforms/transform.py | 29 +++++++++++- monai/transforms/utils.py | 66 +++++++++++++++++---------- tests/test_with_allow_missing_keys.py | 58 +++++++++++++++++++++++ 5 files changed, 129 insertions(+), 31 deletions(-) create mode 100644 tests/test_with_allow_missing_keys.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index fbaaed7279..a24abaf21c 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -242,7 +242,7 @@ ZoomD, ZoomDict, ) -from .transform import MapTransform, Randomizable, RandomizableTransform, Transform +from .transform import apply_transform, MapTransform, Randomizable, RandomizableTransform, Transform from .utility.array import ( AddChannel, AddExtremePointsChannel, @@ -346,7 +346,6 @@ ToTensorDict, ) from .utils import ( - apply_transform, copypaste_arrays, create_control_grid, create_grid, @@ -371,4 +370,5 @@ resize_center, weighted_patch_samples, zero_margins, + AllowMissingKeysMode, ) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 341d3e4e49..6d8ebeb73e 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -21,8 +21,7 @@ from monai.transforms.inverse import InvertibleTransform # For backwards compatiblity (so this still works: from monai.transforms.compose import MapTransform) -from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform, Transform # noqa: F401 -from monai.transforms.utils import apply_transform +from monai.transforms.transform import apply_transform, MapTransform, Randomizable, RandomizableTransform, Transform # noqa: F401 from monai.utils import MAX_SEED, ensure_tuple, get_seed __all__ = ["Compose"] diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 7a09efa6d5..02b3ee6c71 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -13,16 +13,41 @@ """ from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, Hashable, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple, Union import numpy as np from monai.config import KeysCollection from monai.utils import MAX_SEED, ensure_tuple -__all__ = ["Randomizable", "RandomizableTransform", "Transform", "MapTransform"] +__all__ = ["apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"] + +def apply_transform(transform: Callable, data, map_items: bool = True): + """ + Transform `data` with `transform`. + If `data` is a list or tuple and `map_data` is True, each item of `data` will be transformed + and this method returns a list of outcomes. + otherwise transform will be applied once with `data` as the argument. + + Args: + transform: a callable to be used to transform `data` + data: an object to be transformed. + map_items: whether to apply transform to each item in `data`, + if `data` is a list or tuple. Defaults to True. + + Raises: + Exception: When ``transform`` raises an exception. 
+ + """ + try: + if isinstance(data, (list, tuple)) and map_items: + return [transform(item) for item in data] + return transform(data) + except Exception as e: + raise RuntimeError(f"applying transform {transform}") from e + class Randomizable(ABC): """ An interface for handling random state locally, currently based on a class variable `R`, diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 9a84eb00d9..a07a2b8dbd 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -10,6 +10,8 @@ # limitations under the License. import itertools +from monai.transforms.transform import MapTransform +from monai.transforms.compose import Compose import random import warnings from typing import Callable, List, Optional, Sequence, Tuple, Union @@ -49,6 +51,7 @@ "get_extreme_points", "extreme_points_to_image", "map_spatial_axes", + "AllowMissingKeysMode", ] @@ -363,31 +366,6 @@ def _correct_centers( return centers -def apply_transform(transform: Callable, data, map_items: bool = True): - """ - Transform `data` with `transform`. - If `data` is a list or tuple and `map_data` is True, each item of `data` will be transformed - and this method returns a list of outcomes. - otherwise transform will be applied once with `data` as the argument. - - Args: - transform: a callable to be used to transform `data` - data: an object to be transformed. - map_items: whether to apply transform to each item in `data`, - if `data` is a list or tuple. Defaults to True. - - Raises: - Exception: When ``transform`` raises an exception. - - """ - try: - if isinstance(data, (list, tuple)) and map_items: - return [transform(item) for item in data] - return transform(data) - except Exception as e: - raise RuntimeError(f"applying transform {transform}") from e - - def create_grid( spatial_size: Sequence[int], spacing: Optional[Sequence[float]] = None, @@ -730,3 +708,41 @@ def map_spatial_axes( spatial_axes_.append(a - 1 if a < 0 else a) return spatial_axes_ + +class AllowMissingKeysMode(): + """Temporarily set all MapTransforms to not throw an error if keys are missing. After, revert to original states. + + Args: + transform: either MapTransform or a Compose + + Example: + + .. code-block:: python + + data = {"image": np.arange(16, dtype=float).reshape(1, 4, 4)} + t = SpatialPadd(["image", "label"], 10, allow_missing_keys=False) + _ = t(data) # would raise exception + with AllowMissingKeysMode(t): + _ = t(data) # OK! 
+ """ + def __init__(self, transform: Union[MapTransform, Compose]): + if isinstance(transform, MapTransform): + self.transforms = [transform] + elif isinstance(transform, Compose): + # Only keep contained MapTransforms + self.transforms = [t for t in transform.flatten().transforms if isinstance(t, MapTransform)] + else: + self.transforms = [] + + # Get the state of each `allow_missing_keys` + self.orig_states = [t.allow_missing_keys for t in self.transforms] + + def __enter__(self): + # Set all to True + for t in self.transforms: + t.allow_missing_keys = True + + def __exit__(self, type, value, traceback): + # Revert + for t, o_s in zip(self.transforms, self.orig_states): + t.allow_missing_keys = o_s diff --git a/tests/test_with_allow_missing_keys.py b/tests/test_with_allow_missing_keys.py new file mode 100644 index 0000000000..77fa0a1a90 --- /dev/null +++ b/tests/test_with_allow_missing_keys.py @@ -0,0 +1,58 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import unittest +from monai.transforms import Compose, SpatialPadd, SpatialPad, AllowMissingKeysMode + +class TestWithAllowMissingKeys(unittest.TestCase): + def setUp(self): + self.data = {"image": np.arange(16, dtype=float).reshape(1, 4, 4)} + + def test_map_transform(self): + for amk in [True, False]: + t = SpatialPadd(["image", "label"], 10, allow_missing_keys=amk) + with AllowMissingKeysMode(t): + # check state is True + self.assertTrue(t.allow_missing_keys) + # and that transform works even though key is missing + _ = t(self.data) + # check it has returned to original state + self.assertEqual(t.allow_missing_keys, amk) + if not amk: + # should fail because amks==False and key is missing + with self.assertRaises(KeyError): + _ = t(self.data) + + def test_compose(self): + amks = [True, False, True] + t = Compose([SpatialPadd(["image", "label"], 10, allow_missing_keys=amk) for amk in amks]) + with AllowMissingKeysMode(t): + # check states are all True + for _t in t.transforms: + self.assertTrue(_t.allow_missing_keys) + # and that transform works even though key is missing + _ = t(self.data) + # check they've returned to original state + for _t, amk in zip(t.transforms, amks): + self.assertEqual(_t.allow_missing_keys, amk) + # should fail because not all amks==True and key is missing + with self.assertRaises([KeyError, RuntimeError]): + _ = t(self.data) + + def test_array_transform(self): + for t in [SpatialPad(10), Compose([SpatialPad(10)])]: + with AllowMissingKeysMode(t): + # should work as nothing should have changed + _ = t(self.data["image"]) + +if __name__ == "__main__": + unittest.main() From 53c58b0e2bafb00263dbaa41b5652b5b2188dff2 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 5 Mar 2021 18:22:36 +0000 Subject: [PATCH 24/64] remove keys from inverse method Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/compose.py | 18 ++-- monai/transforms/croppad/dictionary.py | 32 +++---- 
monai/transforms/inverse.py | 2 +- monai/transforms/spatial/dictionary.py | 112 +++++++++++++------------ tests/test_inverse.py | 16 +--- 5 files changed, 86 insertions(+), 94 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 6d8ebeb73e..54e3b8690d 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -139,16 +139,12 @@ def __call__(self, input_): input_ = apply_transform(_transform, input_) return input_ - def inverse(self, data, keys: Optional[Tuple[Hashable, ...]] = None): - if not isinstance(data, Mapping): - raise RuntimeError("Inverse method only available for dictionary transforms") - d = deepcopy(dict(data)) - if keys: - keys = ensure_tuple(keys) + def inverse(self, data): + invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)] + if len(invertible_transforms) == 0: + warnings.warn("inverse has been called but no invertible transforms have been supplied") # loop backwards over transforms - for t in reversed(self.transforms): - # check if transform is one of the invertible ones - if isinstance(t, InvertibleTransform): - d = t.inverse(d, keys) - return d + for t in reversed(invertible_transforms): + data = t.inverse(data) + return data diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 42eccd7259..e0e815a36b 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -127,10 +127,10 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in keys or self.keys: + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = transform["orig_size"] @@ -197,11 +197,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in keys or self.keys: + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = np.array(transform["orig_size"]) @@ -265,11 +265,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in keys or self.keys: + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = np.array(transform["orig_size"]) @@ -322,11 +322,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in keys or self.keys: + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = transform["orig_size"] @@ -372,11 +372,11 @@ def __call__(self, data: 
Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in keys or self.keys: + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = np.array(transform["orig_size"]) @@ -457,11 +457,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in keys or self.keys: + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = transform["orig_size"] @@ -607,10 +607,10 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in keys or self.keys: + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = np.array(transform["orig_size"]) @@ -844,10 +844,10 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in keys or self.keys: + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = transform["orig_size"] diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index ec1c684f7c..c326ad1c4e 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -88,7 +88,7 @@ def remove_most_recent_transform(data: dict, key: Hashable) -> None: """Remove most recent transform.""" data[str(key) + "_transforms"].pop() - def inverse(self, data: dict, keys: Optional[Tuple[Hashable, ...]] = None) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: dict) -> Dict[Hashable, np.ndarray]: """ Inverse of ``__call__``. 
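A minimal usage sketch of the inverse workflow as it stands after this commit: `Compose.inverse` now takes only the data dictionary and walks its invertible transforms in reverse, and `AllowMissingKeysMode` lets a dictionary that lacks some of the configured keys (e.g. a prediction carrying only "label") be inverted. The names and behaviour below are taken from tests/test_inverse.py and tests/test_with_allow_missing_keys.py in this series and may still change; the random arrays are placeholders standing in for real images, so treat this as an illustration rather than the final API.

    import numpy as np

    from monai.transforms import (
        AddChanneld,
        AllowMissingKeysMode,
        CenterSpatialCropd,
        Compose,
        SpatialPadd,
    )

    keys = ["image", "label"]
    transforms = Compose(
        [
            AddChanneld(keys),
            SpatialPadd(keys, (150, 153)),
            CenterSpatialCropd(keys, (110, 99)),
        ]
    )

    # placeholder data standing in for create_test_image_2d output
    data = {
        "image": np.random.rand(100, 101).astype(np.float32),
        "label": (np.random.rand(100, 101) > 0.5).astype(np.float32),
    }
    fwd = transforms(data)  # each invertible transform records itself under "<key>_transforms"

    # invert a dictionary that only carries the label (as for an inferred segmentation)
    seg_only = {"label": fwd["label"], "label_transforms": fwd["label_transforms"]}
    with AllowMissingKeysMode(transforms):  # "image" is absent, so don't raise
        inverted = transforms.inverse(seg_only)

    # the pad/crop are undone in reverse order; AddChanneld is not invertible, hence shape[1:]
    assert inverted["label"].shape[1:] == data["label"].shape

The same per-key stack drives every `inverse` method in this series: fetch the most recent applied-transform record, build and apply the inverse operation, then pop the record so the next call sees the transform before it.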
diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index e65a6da072..1f6dba768a 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -211,10 +211,12 @@ def __call__( return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for idx, key in enumerate(keys or self.keys): + for key, mode, padding_mode, align_corners, dtype in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners, self.dtype + ): transform = self.get_most_recent_transform(d, key) if self.spacing_transform.diagonal: raise RuntimeError( @@ -230,10 +232,10 @@ def inverse( d[key], _, new_affine = inverse_transform( data_array=np.asarray(d[key]), affine=meta_data["affine"], - mode=self.mode[idx], - padding_mode=self.padding_mode[idx], - align_corners=self.align_corners[idx], - dtype=self.dtype[idx], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=dtype, ) meta_data["affine"] = new_affine # Remove the applied transform @@ -307,10 +309,10 @@ def __call__( return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in keys or self.keys: + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform meta_data = d[transform["extra_info"]["meta_data_key"]] @@ -356,10 +358,10 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in keys or self.keys: + for key in self.key_iterator(d): _ = self.get_most_recent_transform(d, key) # Create inverse transform spatial_axes = self.rotator.spatial_axes @@ -428,10 +430,10 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. 
return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in keys or self.keys: + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform["do_transform"]: @@ -493,14 +495,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for idx, key in enumerate(keys or self.keys): + for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): transform = self.get_most_recent_transform(d, key) orig_size = transform["orig_size"] - mode = self.mode[idx] - align_corners = self.align_corners[idx] # Create inverse transform inverse_transform = Resize(orig_size, mode, align_corners) # Apply inverse transform @@ -614,11 +614,11 @@ def __call__( return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for idx, key in enumerate(keys or self.keys): + for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): transform = self.get_most_recent_transform(d, key) orig_size = transform["orig_size"] # Create inverse transform @@ -629,7 +629,7 @@ def inverse( grid: torch.Tensor = affine_grid(orig_size) # type: ignore # Apply inverse transform - out = self.rand_affine.resampler(d[key], grid, self.mode[idx], self.padding_mode[idx]) + out = self.rand_affine.resampler(d[key], grid, mode, padding_mode) # Convert to numpy d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy() @@ -766,13 +766,13 @@ def __call__( return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) # This variable will be `not None` if vtk or sitk is present inv_def_no_affine = None - for idx, key in enumerate(keys or self.keys): + for idx, (key, mode, padding_mode) in enumerate(self.key_iterator(d, self.mode, self.padding_mode)): transform = self.get_most_recent_transform(d, key) # Create inverse transform if transform["do_transform"]: @@ -800,7 +800,7 @@ def inverse( # Apply inverse transform if inv_def_no_affine is not None: out = self.rand_2d_elastic.resampler( - d[key], inv_def_w_affine, self.mode[idx], self.padding_mode[idx] + d[key], inv_def_w_affine, mode, padding_mode ) d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out @@ -931,11 +931,11 @@ def __call__( return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for idx, key in enumerate(keys or self.keys): + for idx, (key, mode, padding_mode) in enumerate(self.key_iterator(d, self.mode, self.padding_mode)): transform = self.get_most_recent_transform(d, key) # Create inverse transform if transform["do_transform"]: @@ -956,7 +956,7 @@ def inverse( # Apply inverse transform if inv_def_w_affine is not None: out = self.rand_3d_elastic.resampler( - 
d[key], inv_def_w_affine, self.mode[idx], self.padding_mode[idx] + d[key], inv_def_w_affine, mode, padding_mode ) d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out else: @@ -997,10 +997,10 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in keys or self.keys: + for key in self.key_iterator(d): _ = self.get_most_recent_transform(d, key) # Might need to convert to numpy if isinstance(d[key], torch.Tensor): @@ -1050,10 +1050,10 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in keys or self.keys: + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform["do_transform"]: @@ -1102,10 +1102,10 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in keys or self.keys: + for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform["do_transform"]: @@ -1185,10 +1185,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for idx, key in enumerate(keys or self.keys): + for key, mode, padding_mode, align_corners, dtype in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners, self.dtype + ): transform = self.get_most_recent_transform(d, key) # Create inverse transform fwd_rot_mat = transform["extra_info"]["rot_mat"] @@ -1196,12 +1198,11 @@ def inverse( xform = AffineTransform( normalized=False, - mode=self.mode[idx], - padding_mode=self.padding_mode[idx], - align_corners=self.align_corners[idx], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, reverse_indexing=True, ) - dtype = self.dtype[idx] output = xform( torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), @@ -1319,10 +1320,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for idx, key in enumerate(keys or self.keys): + for key, mode, padding_mode, align_corners, dtype in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners, self.dtype + ): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform["do_transform"]: @@ -1332,12 +1335,11 @@ def inverse( 
xform = AffineTransform( normalized=False, - mode=self.mode[idx], - padding_mode=self.padding_mode[idx], - align_corners=self.align_corners[idx], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, reverse_indexing=True, ) - dtype = self.dtype[idx] output = xform( torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), @@ -1406,10 +1408,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for idx, key in enumerate(keys or self.keys): + for key, mode, padding_mode, align_corners in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners + ): transform = self.get_most_recent_transform(d, key) # Create inverse transform zoom = np.array(self.zoomer.zoom) @@ -1417,9 +1421,9 @@ def inverse( # Apply inverse d[key] = inverse_transform( d[key], - mode=self.mode[idx], - padding_mode=self.padding_mode[idx], - align_corners=self.align_corners[idx], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, ) # Size might be out by 1 voxel so pad d[key] = SpatialPad(transform["orig_size"])(d[key]) @@ -1522,10 +1526,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d def inverse( - self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + self, data: Mapping[Hashable, np.ndarray] ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for idx, key in enumerate(keys or self.keys): + for key, mode, padding_mode, align_corners in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners + ): transform = self.get_most_recent_transform(d, key) # Create inverse transform zoom = np.array(transform["extra_info"]["zoom"]) @@ -1533,9 +1539,9 @@ def inverse( # Apply inverse d[key] = inverse_transform( d[key], - mode=self.mode[idx], - padding_mode=self.padding_mode[idx], - align_corners=self.align_corners[idx], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, ) # Size might be out by 1 voxel so pad d[key] = SpatialPad(transform["orig_size"])(d[key]) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index ed8b5bb30e..89f8fbf6a0 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -52,6 +52,7 @@ SpatialPad, SpatialPadd, Zoomd, + AllowMissingKeysMode, ) from monai.utils import first, optional_import, set_determinism from tests.utils import make_nifti_image, make_rand_affine, test_is_quick @@ -393,11 +394,6 @@ TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore -# Should fail because uses an array transform (SpatialPad), as opposed to dictionary -TEST_FAIL_0 = ("2D", 0.0, Compose([SpatialPad(spatial_size=[101, 103])])) -TESTS_FAIL = [TEST_FAIL_0] - - def plot_im(orig, fwd_bck, fwd): diff_orig_fwd_bck = orig - fwd_bck ims_to_show = [orig, fwd, fwd_bck, diff_orig_fwd_bck] @@ -489,13 +485,6 @@ def test_inverse(self, _, data_name, acceptable_diff, *transforms): fwd_bck = t.inverse(fwd_bck) self.check_inverse(name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff) - @parameterized.expand(TESTS_FAIL) - def test_fail(self, data_name, _, *transform): - data = self.all_data[data_name]["image"] - d = transform[0](data) - with self.assertRaises(RuntimeError): - d = transform[0].inverse(d) - def 
test_inverse_inferred_seg(self): test_data = [] @@ -529,7 +518,8 @@ def test_inverse_inferred_seg(self): # inverse of individual segmentation seg_dict = first(segs_dict_decollated) - inv_seg = transforms.inverse(seg_dict, "label")["label"] + with AllowMissingKeysMode(transforms): + inv_seg = transforms.inverse(seg_dict)["label"] self.assertEqual(len(data["label_transforms"]), num_invertible_transforms) self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms) self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape) From 7036ac8e66700d24958c5fc8505e09b79afdee9a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 5 Mar 2021 18:41:31 +0000 Subject: [PATCH 25/64] autofixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 4 +- monai/transforms/compose.py | 11 +++-- monai/transforms/croppad/dictionary.py | 32 +++--------- monai/transforms/spatial/dictionary.py | 68 +++++++------------------- monai/transforms/transform.py | 4 +- monai/transforms/utils.py | 9 ++-- tests/test_inverse.py | 3 +- tests/test_with_allow_missing_keys.py | 10 ++-- 8 files changed, 50 insertions(+), 91 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index a24abaf21c..9976a4b1ea 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -242,7 +242,7 @@ ZoomD, ZoomDict, ) -from .transform import apply_transform, MapTransform, Randomizable, RandomizableTransform, Transform +from .transform import MapTransform, Randomizable, RandomizableTransform, Transform, apply_transform from .utility.array import ( AddChannel, AddExtremePointsChannel, @@ -346,6 +346,7 @@ ToTensorDict, ) from .utils import ( + AllowMissingKeysMode, copypaste_arrays, create_control_grid, create_grid, @@ -370,5 +371,4 @@ resize_center, weighted_patch_samples, zero_margins, - AllowMissingKeysMode, ) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 54e3b8690d..7d3abef04f 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -13,15 +13,20 @@ """ import warnings -from copy import deepcopy -from typing import Any, Callable, Hashable, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Sequence, Union import numpy as np from monai.transforms.inverse import InvertibleTransform # For backwards compatiblity (so this still works: from monai.transforms.compose import MapTransform) -from monai.transforms.transform import apply_transform, MapTransform, Randomizable, RandomizableTransform, Transform # noqa: F401 +from monai.transforms.transform import ( # noqa: F401 + MapTransform, + Randomizable, + RandomizableTransform, + Transform, + apply_transform, +) from monai.utils import MAX_SEED, ensure_tuple, get_seed __all__ = ["Compose"] diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index e0e815a36b..cbed5f6e52 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -126,9 +126,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.padder(d[key], mode=m) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -196,9 +194,7 
@@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.padder(d[key], mode=m) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -264,9 +260,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.padder(d[key], mode=m) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -321,9 +315,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.cropper(d[key]) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -371,9 +363,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.append_applied_transforms(d, key, orig_size=orig_size) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -456,9 +446,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = cropper(d[key]) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -606,9 +594,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = cropper(d[key]) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -843,9 +829,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.append_applied_transforms(d, key, orig_size=orig_size) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 1f6dba768a..47db6762ef 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -210,9 +210,7 @@ def __call__( meta_data["affine"] = new_affine return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype @@ -308,9 +306,7 @@ def __call__( d[meta_data_key]["affine"] = new_affine return d - def inverse( - self, 
data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -357,9 +353,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.rotator(d[key]) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.key_iterator(d): _ = self.get_most_recent_transform(d, key) @@ -429,9 +423,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. self.append_applied_transforms(d, key, extra_info={"rand_k": self._rand_k}) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -494,9 +486,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.resizer(d[key], mode=mode, align_corners=align_corners) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): transform = self.get_most_recent_transform(d, key) @@ -613,9 +603,7 @@ def __call__( d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): @@ -765,9 +753,7 @@ def __call__( d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) # This variable will be `not None` if vtk or sitk is present inv_def_no_affine = None @@ -799,9 +785,7 @@ def inverse( inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine_wrong_size) # type: ignore # Apply inverse transform if inv_def_no_affine is not None: - out = self.rand_2d_elastic.resampler( - d[key], inv_def_w_affine, mode, padding_mode - ) + out = self.rand_2d_elastic.resampler(d[key], inv_def_w_affine, mode, padding_mode) d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out else: @@ -930,9 +914,7 @@ def __call__( d[key] = self.rand_3d_elastic.resampler(d[key], grid_w_affine, mode=mode, padding_mode=padding_mode) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for idx, (key, mode, padding_mode) in enumerate(self.key_iterator(d, self.mode, self.padding_mode)): @@ -955,9 +937,7 @@ def inverse( inv_def_w_affine = 
CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine_wrong_size) # type: ignore # Apply inverse transform if inv_def_w_affine is not None: - out = self.rand_3d_elastic.resampler( - d[key], inv_def_w_affine, mode, padding_mode - ) + out = self.rand_3d_elastic.resampler(d[key], inv_def_w_affine, mode, padding_mode) d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out else: d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) @@ -996,9 +976,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.flipper(d[key]) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.key_iterator(d): _ = self.get_most_recent_transform(d, key) @@ -1049,9 +1027,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.append_applied_transforms(d, key) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -1101,9 +1077,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.append_applied_transforms(d, key, extra_info={"axis": self._axis}) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -1184,9 +1158,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.append_applied_transforms(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype @@ -1319,9 +1291,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.append_applied_transforms(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype @@ -1407,9 +1377,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners @@ -1525,9 +1493,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse( - self, data: Mapping[Hashable, np.ndarray] - ) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: 
Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 02b3ee6c71..2a79b2edf2 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -13,7 +13,7 @@ """ from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple import numpy as np @@ -23,7 +23,6 @@ __all__ = ["apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"] - def apply_transform(transform: Callable, data, map_items: bool = True): """ Transform `data` with `transform`. @@ -48,6 +47,7 @@ def apply_transform(transform: Callable, data, map_items: bool = True): except Exception as e: raise RuntimeError(f"applying transform {transform}") from e + class Randomizable(ABC): """ An interface for handling random state locally, currently based on a class variable `R`, diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index a07a2b8dbd..27c15fb76e 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -10,8 +10,6 @@ # limitations under the License. import itertools -from monai.transforms.transform import MapTransform -from monai.transforms.compose import Compose import random import warnings from typing import Callable, List, Optional, Sequence, Tuple, Union @@ -21,6 +19,8 @@ from monai.config import DtypeLike, IndexSelection from monai.networks.layers import GaussianFilter +from monai.transforms.compose import Compose +from monai.transforms.transform import MapTransform from monai.utils import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, min_version, optional_import measure, _ = optional_import("skimage.measure", "0.14.2", min_version) @@ -39,7 +39,6 @@ "map_binary_to_indices", "weighted_patch_samples", "generate_pos_neg_label_crop_centers", - "apply_transform", "create_grid", "create_control_grid", "create_rotate", @@ -709,7 +708,8 @@ def map_spatial_axes( return spatial_axes_ -class AllowMissingKeysMode(): + +class AllowMissingKeysMode: """Temporarily set all MapTransforms to not throw an error if keys are missing. After, revert to original states. Args: @@ -725,6 +725,7 @@ class AllowMissingKeysMode(): with AllowMissingKeysMode(t): _ = t(data) # OK! 
""" + def __init__(self, transform: Union[MapTransform, Compose]): if isinstance(transform, MapTransform): self.transforms = [transform] diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 89f8fbf6a0..bff11c1f70 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -24,6 +24,7 @@ from monai.transforms import ( AddChannel, AddChanneld, + AllowMissingKeysMode, BorderPadd, CenterSpatialCropd, Compose, @@ -49,10 +50,8 @@ Rotated, Spacingd, SpatialCropd, - SpatialPad, SpatialPadd, Zoomd, - AllowMissingKeysMode, ) from monai.utils import first, optional_import, set_determinism from tests.utils import make_nifti_image, make_rand_affine, test_is_quick diff --git a/tests/test_with_allow_missing_keys.py b/tests/test_with_allow_missing_keys.py index 77fa0a1a90..72d4171c39 100644 --- a/tests/test_with_allow_missing_keys.py +++ b/tests/test_with_allow_missing_keys.py @@ -9,9 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np import unittest -from monai.transforms import Compose, SpatialPadd, SpatialPad, AllowMissingKeysMode + +import numpy as np + +from monai.transforms import AllowMissingKeysMode, Compose, SpatialPad, SpatialPadd + class TestWithAllowMissingKeys(unittest.TestCase): def setUp(self): @@ -45,7 +48,7 @@ def test_compose(self): for _t, amk in zip(t.transforms, amks): self.assertEqual(_t.allow_missing_keys, amk) # should fail because not all amks==True and key is missing - with self.assertRaises([KeyError, RuntimeError]): + with self.assertRaises((KeyError, RuntimeError)): _ = t(self.data) def test_array_transform(self): @@ -54,5 +57,6 @@ def test_array_transform(self): # should work as nothing should have changed _ = t(self.data["image"]) + if __name__ == "__main__": unittest.main() From dbc07703c146663bbd59c0d674e089e084a614c7 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 8 Mar 2021 15:06:29 +0000 Subject: [PATCH 26/64] update for allow_missing_keys_mode Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index bff11c1f70..e7aad3b95a 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -24,7 +24,6 @@ from monai.transforms import ( AddChannel, AddChanneld, - AllowMissingKeysMode, BorderPadd, CenterSpatialCropd, Compose, @@ -52,6 +51,7 @@ SpatialCropd, SpatialPadd, Zoomd, + allow_missing_keys_mode, ) from monai.utils import first, optional_import, set_determinism from tests.utils import make_nifti_image, make_rand_affine, test_is_quick @@ -517,7 +517,7 @@ def test_inverse_inferred_seg(self): # inverse of individual segmentation seg_dict = first(segs_dict_decollated) - with AllowMissingKeysMode(transforms): + with allow_missing_keys_mode(transforms): inv_seg = transforms.inverse(seg_dict)["label"] self.assertEqual(len(data["label_transforms"]), num_invertible_transforms) self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms) From 29153465869a795977b9c76930c75944ad8b2f82 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 8 Mar 2021 15:10:07 +0000 Subject: [PATCH 27/64] inverse to use apply_transform Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/compose.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/monai/transforms/compose.py b/monai/transforms/compose.py index 7d3abef04f..d509ea33a1 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -151,5 +151,5 @@ def inverse(self, data): # loop backwards over transforms for t in reversed(invertible_transforms): - data = t.inverse(data) + data = apply_transform(t.inverse, data) return data From 8efd75e43dc76ecbf324cf202248573282574b0b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 8 Mar 2021 15:30:08 +0000 Subject: [PATCH 28/64] add enums Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 22 +++++----- monai/transforms/inverse.py | 29 ++++++++----- monai/transforms/spatial/dictionary.py | 56 +++++++++++++------------- tests/test_decollate.py | 3 +- 4 files changed, 61 insertions(+), 49 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index cbed5f6e52..97045ccd18 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -131,7 +131,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform["orig_size"] + orig_size = transform[InvertibleTransform.Keys.orig_size] if self.padder.method == Method.SYMMETRIC: current_size = d[key].shape[1:] roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)] @@ -200,7 +200,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform["orig_size"]) + orig_size = np.array(transform[InvertibleTransform.Keys.orig_size]) roi_start = np.array(self.padder.spatial_border) # Need to convert single value to [min1,min2,...] if roi_start.size == 1: @@ -208,7 +208,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # need to convert [min1,max1,min2,...] to [min1,min2,...] 
elif roi_start.size == 2 * orig_size.size: roi_start = roi_start[::2] - roi_end = np.array(transform["orig_size"]) + roi_start + roi_end = np.array(transform[InvertibleTransform.Keys.orig_size]) + roi_start inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end) # Apply inverse transform @@ -266,7 +266,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform["orig_size"]) + orig_size = np.array(transform[InvertibleTransform.Keys.orig_size]) current_size = np.array(d[key].shape[1:]) roi_start = np.floor((current_size - orig_size) / 2) roi_end = orig_size + roi_start @@ -321,7 +321,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform["orig_size"] + orig_size = transform[InvertibleTransform.Keys.orig_size] pad_to_start = self.cropper.roi_start pad_to_end = orig_size - self.cropper.roi_end # interweave mins and maxes @@ -369,7 +369,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform["orig_size"]) + orig_size = np.array(transform[InvertibleTransform.Keys.orig_size]) current_size = np.array(d[key].shape[1:]) pad_to_start = np.floor((orig_size - current_size) / 2) # in each direction, if original size is even and current size is odd, += 1 @@ -452,12 +452,12 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform["orig_size"] + orig_size = transform[InvertibleTransform.Keys.orig_size] random_center = self.random_center pad_to_start = np.empty((len(orig_size)), dtype=np.int32) pad_to_end = np.empty((len(orig_size)), dtype=np.int32) if random_center: - for i, _slice in enumerate(transform["extra_info"]["slices"]): + for i, _slice in enumerate(transform[InvertibleTransform.Keys.extra_info]["slices"]): pad_to_start[i] = _slice[0] pad_to_end[i] = orig_size[i] - _slice[1] else: @@ -599,8 +599,8 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform["orig_size"]) - extra_info = transform["extra_info"] + orig_size = np.array(transform[InvertibleTransform.Keys.orig_size]) + extra_info = transform[InvertibleTransform.Keys.extra_info] pad_to_start = np.array(extra_info["box_start"]) pad_to_end = orig_size - np.array(extra_info["box_end"]) # interweave mins and maxes @@ -834,7 +834,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform["orig_size"] + orig_size = transform[InvertibleTransform.Keys.orig_size] inverse_transform = ResizeWithPadOrCrop(spatial_size=orig_size, mode=self.padcropper.padder.mode) # Apply inverse transform d[key] = inverse_transform(d[key]) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index c326ad1c4e..bd7affa29f 100644 --- 
a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -10,6 +10,7 @@ # limitations under the License. import warnings +from enum import Enum from typing import Dict, Hashable, Optional, Tuple import numpy as np @@ -48,6 +49,16 @@ class InvertibleTransform(Transform): `remove_most_recent_transform` is called. """ + class Keys(Enum): + """Extra meta data keys used for inverse transforms.""" + + class_name = "class" + id = "id" + orig_size = "orig_size" + extra_info = "extra_info" + do_transform = "do_transforms" + key_suffix = "_transform" + def append_applied_transforms( self, data: dict, @@ -56,17 +67,17 @@ def append_applied_transforms( orig_size: Optional[Tuple] = None, ) -> None: """Append to list of applied transforms for that key.""" - key_transform = str(key) + "_transforms" + key_transform = str(key) + str(self.Keys.key_suffix) info = { - "class": self.__class__.__name__, - "id": id(self), - "orig_size": orig_size or data[key].shape[1:], + self.Keys.class_name: self.__class__.__name__, + self.Keys.id: id(self), + self.Keys.orig_size: orig_size or data[key].shape[1:], } if extra_info is not None: - info["extra_info"] = extra_info + info[self.Keys.extra_info] = extra_info # If class is randomizable transform, store whether the transform was actually performed (based on `prob`) if isinstance(self, RandomizableTransform): - info["do_transform"] = self._do_transform + info[self.Keys.do_transform] = self._do_transform # If this is the first, create list if key_transform not in data: data[key_transform] = [] @@ -74,19 +85,19 @@ def append_applied_transforms( def check_transforms_match(self, transform: dict) -> None: # Check transorms are of same type. - if transform["id"] != id(self): + if transform[self.Keys.id] != id(self): raise RuntimeError("Should inverse most recently applied invertible transform first") def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: """Get most recent transform.""" - transform = dict(data[str(key) + "_transforms"][-1]) + transform = dict(data[str(key) + str(self.Keys.key_suffix)][-1]) self.check_transforms_match(transform) return transform @staticmethod def remove_most_recent_transform(data: dict, key: Hashable) -> None: """Remove most recent transform.""" - data[str(key) + "_transforms"].pop() + data[str(key) + str(InvertibleTransform.Keys.key_suffix)].pop() def inverse(self, data: dict) -> Dict[Hashable, np.ndarray]: """ diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 2f142e3ecc..623731d468 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -226,8 +226,8 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar + "Please raise a github issue if you need this feature" ) # Create inverse transform - meta_data = d[transform["extra_info"]["meta_data_key"]] - old_affine = np.array(transform["extra_info"]["old_affine"]) + meta_data = d[transform[InvertibleTransform.Keys.extra_info]["meta_data_key"]] + old_affine = np.array(transform[InvertibleTransform.Keys.extra_info]["old_affine"]) orig_pixdim = np.sqrt(np.sum(np.square(old_affine), 0))[:-1] inverse_transform = Spacing(orig_pixdim, diagonal=self.spacing_transform.diagonal) # Apply inverse @@ -315,8 +315,8 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - meta_data = d[transform["extra_info"]["meta_data_key"]] - 
orig_affine = transform["extra_info"]["old_affine"] + meta_data = d[transform[InvertibleTransform.Keys.extra_info]["meta_data_key"]] + orig_affine = transform[InvertibleTransform.Keys.extra_info]["old_affine"] orig_axcodes = nib.orientations.aff2axcodes(orig_affine) inverse_transform = Orientation( axcodes=orig_axcodes, @@ -432,9 +432,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform["do_transform"]: + if transform[InvertibleTransform.Keys.do_transform]: # Create inverse transform - num_times_rotated = transform["extra_info"]["rand_k"] + num_times_rotated = transform[InvertibleTransform.Keys.extra_info]["rand_k"] num_times_to_rotate = 4 - num_times_rotated inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes) # Might need to convert to numpy @@ -494,7 +494,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar d = deepcopy(dict(data)) for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): transform = self.get_most_recent_transform(d, key) - orig_size = transform["orig_size"] + orig_size = transform[InvertibleTransform.Keys.orig_size] # Create inverse transform inverse_transform = Resize(orig_size, mode, align_corners) # Apply inverse transform @@ -685,9 +685,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): transform = self.get_most_recent_transform(d, key) - orig_size = transform["orig_size"] + orig_size = transform[InvertibleTransform.Keys.orig_size] # Create inverse transform - fwd_affine = transform["extra_info"]["affine"] + fwd_affine = transform[InvertibleTransform.Keys.extra_info]["affine"] inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) @@ -838,22 +838,22 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for idx, (key, mode, padding_mode) in enumerate(self.key_iterator(d, self.mode, self.padding_mode)): transform = self.get_most_recent_transform(d, key) # Create inverse transform - if transform["do_transform"]: - orig_size = transform["orig_size"] + if transform[InvertibleTransform.Keys.do_transform]: + orig_size = transform[InvertibleTransform.Keys.orig_size] # Only need to calculate inverse deformation once as it is the same for all keys if idx == 0: # If magnitude == 0, then non-rigid component is identity -- so just create blank if self.rand_2d_elastic.deform_grid.magnitude == (0.0, 0.0): inv_def_no_affine = create_grid(spatial_size=orig_size) else: - fwd_cpg_no_affine = transform["extra_info"]["cpg"] + fwd_cpg_no_affine = transform[InvertibleTransform.Keys.extra_info]["cpg"] fwd_def_no_affine = self.cpg_to_dvf( fwd_cpg_no_affine, self.rand_2d_elastic.deform_grid.spacing, orig_size ) inv_def_no_affine = self.compute_inverse_deformation(len(orig_size), fwd_def_no_affine) # if inverse did not succeed (sitk or vtk present), data will not be changed. 
if inv_def_no_affine is not None: - fwd_affine = transform["extra_info"]["affine"] + fwd_affine = transform[InvertibleTransform.Keys.extra_info]["affine"] inv_affine = np.linalg.inv(fwd_affine) inv_def_w_affine_wrong_size = AffineGrid(affine=inv_affine, as_tensor_output=False)( grid=inv_def_no_affine @@ -997,15 +997,15 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for idx, (key, mode, padding_mode) in enumerate(self.key_iterator(d, self.mode, self.padding_mode)): transform = self.get_most_recent_transform(d, key) # Create inverse transform - if transform["do_transform"]: - orig_size = transform["orig_size"] + if transform[InvertibleTransform.Keys.do_transform]: + orig_size = transform[InvertibleTransform.Keys.orig_size] # Only need to calculate inverse deformation once as it is the same for all keys if idx == 0: - fwd_def_no_affine = transform["extra_info"]["grid_no_affine"] + fwd_def_no_affine = transform[InvertibleTransform.Keys.extra_info]["grid_no_affine"] inv_def_no_affine = self.compute_inverse_deformation(len(orig_size), fwd_def_no_affine) # if inverse did not succeed (sitk or vtk present), data will not be changed. if inv_def_no_affine is not None: - fwd_affine = transform["extra_info"]["affine"] + fwd_affine = transform[InvertibleTransform.Keys.extra_info]["affine"] inv_affine = np.linalg.inv(fwd_affine) inv_def_w_affine_wrong_size = AffineGrid(affine=inv_affine, as_tensor_output=False)( grid=inv_def_no_affine @@ -1109,7 +1109,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform["do_transform"]: + if transform[InvertibleTransform.Keys.do_transform]: # Might need to convert to numpy if isinstance(d[key], torch.Tensor): d[key] = torch.Tensor(d[key]).cpu().numpy() @@ -1159,8 +1159,8 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform["do_transform"]: - flipper = Flip(spatial_axis=transform["extra_info"]["axis"]) + if transform[InvertibleTransform.Keys.do_transform]: + flipper = Flip(spatial_axis=transform[InvertibleTransform.Keys.extra_info]["axis"]) # Might need to convert to numpy if isinstance(d[key], torch.Tensor): d[key] = torch.Tensor(d[key]).cpu().numpy() @@ -1242,7 +1242,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar ): transform = self.get_most_recent_transform(d, key) # Create inverse transform - fwd_rot_mat = transform["extra_info"]["rot_mat"] + fwd_rot_mat = transform[InvertibleTransform.Keys.extra_info]["rot_mat"] inv_rot_mat = np.linalg.inv(fwd_rot_mat) xform = AffineTransform( @@ -1255,7 +1255,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar output = xform( torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), - spatial_size=transform["orig_size"], + spatial_size=transform[InvertibleTransform.Keys.orig_size], ) d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # Remove the applied transform @@ -1375,9 +1375,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar ): transform = self.get_most_recent_transform(d, key) # 
Check if random transform was actually performed (based on `prob`) - if transform["do_transform"]: + if transform[InvertibleTransform.Keys.do_transform]: # Create inverse transform - fwd_rot_mat = transform["extra_info"]["rot_mat"] + fwd_rot_mat = transform[InvertibleTransform.Keys.extra_info]["rot_mat"] inv_rot_mat = np.linalg.inv(fwd_rot_mat) xform = AffineTransform( @@ -1390,7 +1390,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar output = xform( torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), - spatial_size=transform["orig_size"], + spatial_size=transform[InvertibleTransform.Keys.orig_size], ) d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # Remove the applied transform @@ -1471,7 +1471,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar align_corners=align_corners, ) # Size might be out by 1 voxel so pad - d[key] = SpatialPad(transform["orig_size"])(d[key]) + d[key] = SpatialPad(transform[InvertibleTransform.Keys.orig_size])(d[key]) # Remove the applied transform self.remove_most_recent_transform(d, key) @@ -1577,7 +1577,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar ): transform = self.get_most_recent_transform(d, key) # Create inverse transform - zoom = np.array(transform["extra_info"]["zoom"]) + zoom = np.array(transform[InvertibleTransform.Keys.extra_info]["zoom"]) inverse_transform = Zoom(zoom=1 / zoom, keep_size=self.keep_size) # Apply inverse d[key] = inverse_transform( @@ -1587,7 +1587,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar align_corners=align_corners, ) # Size might be out by 1 voxel so pad - d[key] = SpatialPad(transform["orig_size"])(d[key]) + d[key] = SpatialPad(transform[InvertibleTransform.Keys.orig_size])(d[key]) # Remove the applied transform self.remove_most_recent_transform(d, key) diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 84b92bdb2c..94309cb057 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -18,6 +18,7 @@ from monai.data import CacheDataset, DataLoader, create_test_image_2d from monai.data.utils import decollate_batch from monai.transforms import AddChanneld, Compose, LoadImaged, RandFlipd, SpatialPadd, ToTensord +from monai.transforms.inverse import InvertibleTransform from monai.transforms.post.dictionary import Decollated from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image @@ -38,7 +39,7 @@ def check_match(self, in1, in2): for (k1, v1), (k2, v2) in zip(in1.items(), in2.items()): self.check_match(k1, k2) # Transform ids won't match for windows with multiprocessing - if k1 == "id" and sys.platform == "win32": + if k1 == str(InvertibleTransform.Keys.id) and sys.platform == "win32": continue self.check_match(v1, v2) self.check_match(list(in1.values()), list(in2.values())) From a0ad428407884ab00276d4447ce52ea06973c10d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 8 Mar 2021 15:43:51 +0000 Subject: [PATCH 29/64] push_ and pop_transform Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 34 ++++++------- monai/transforms/inverse.py | 9 ++-- monai/transforms/spatial/dictionary.py | 68 ++++++++++++-------------- 3 files changed, 53 insertions(+), 58 deletions(-) diff --git 
a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 97045ccd18..f7cb2d6d6d 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -122,7 +122,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in self.key_iterator(d, self.mode): - self.append_applied_transforms(d, key) + self.push_transform(d, key) d[key] = self.padder(d[key], mode=m) return d @@ -142,7 +142,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Apply inverse transform d[key] = inverse_transform(d[key]) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -190,7 +190,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in self.key_iterator(d, self.mode): - self.append_applied_transforms(d, key) + self.push_transform(d, key) d[key] = self.padder(d[key], mode=m) return d @@ -214,7 +214,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Apply inverse transform d[key] = inverse_transform(d[key]) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -256,7 +256,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in self.key_iterator(d, self.mode): - self.append_applied_transforms(d, key) + self.push_transform(d, key) d[key] = self.padder(d[key], mode=m) return d @@ -274,7 +274,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Apply inverse transform d[key] = inverse_transform(d[key]) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -311,7 +311,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.key_iterator(d): - self.append_applied_transforms(d, key) + self.push_transform(d, key) d[key] = self.cropper(d[key]) return d @@ -332,7 +332,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Apply inverse transform d[key] = inverse_transform(d[key]) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -360,7 +360,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key in self.key_iterator(d): orig_size = d[key].shape[1:] d[key] = self.cropper(d[key]) - self.append_applied_transforms(d, key, orig_size=orig_size) + self.push_transform(d, key, orig_size=orig_size) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: @@ -382,7 +382,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Apply inverse transform d[key] = inverse_transform(d[key]) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -438,10 +438,10 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda raise AssertionError for key in self.key_iterator(d): if self.random_center: - self.append_applied_transforms(d, key, {"slices": [(i.start, i.stop) for i in self._slices[1:]]}) # type: ignore + self.push_transform(d, key, {"slices": [(i.start, i.stop) for i in self._slices[1:]]}) # type: 
ignore d[key] = d[key][self._slices] else: - self.append_applied_transforms(d, key) + self.push_transform(d, key) cropper = CenterSpatialCrop(self._size) d[key] = cropper(d[key]) return d @@ -476,7 +476,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Apply inverse transform d[key] = inverse_transform(d[key]) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -590,7 +590,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[self.end_coord_key] = np.asarray(box_end) cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) for key in self.key_iterator(d): - self.append_applied_transforms(d, key, extra_info={"box_start": box_start, "box_end": box_end}) + self.push_transform(d, key, extra_info={"box_start": box_start, "box_end": box_end}) d[key] = cropper(d[key]) return d @@ -611,7 +611,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Apply inverse transform d[key] = inverse_transform(d[key]) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -826,7 +826,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key in self.key_iterator(d): orig_size = d[key].shape[1:] d[key] = self.padcropper(d[key]) - self.append_applied_transforms(d, key, orig_size=orig_size) + self.push_transform(d, key, orig_size=orig_size) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: @@ -839,7 +839,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Apply inverse transform d[key] = inverse_transform(d[key]) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index bd7affa29f..ff6cf16d9d 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -40,7 +40,7 @@ class InvertibleTransform(Transform): Note to developers: When converting a transform to an invertible transform, you need to: 1. Inherit from this class. - 2. In `__call__`, add a call to `append_applied_transforms`. + 2. In `__call__`, add a call to `push_transform`. 3. Any extra information that might be needed for the inverse can be included with the dictionary `extra_info`. 
This dictionary should have the same keys regardless of whether `do_transform` was True or False and can only contain objects that are @@ -59,7 +59,7 @@ class Keys(Enum): do_transform = "do_transforms" key_suffix = "_transform" - def append_applied_transforms( + def push_transform( self, data: dict, key: Hashable, @@ -94,10 +94,9 @@ def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: self.check_transforms_match(transform) return transform - @staticmethod - def remove_most_recent_transform(data: dict, key: Hashable) -> None: + def pop_transform(self, data: dict, key: Hashable) -> None: """Remove most recent transform.""" - data[str(key) + str(InvertibleTransform.Keys.key_suffix)].pop() + data[str(key) + str(self.Keys.key_suffix)].pop() def inverse(self, data: dict) -> Dict[Hashable, np.ndarray]: """ diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 623731d468..fe6e3ac496 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -207,9 +207,7 @@ def __call__( align_corners=align_corners, dtype=dtype, ) - self.append_applied_transforms( - d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine} - ) + self.push_transform(d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine}) # set the 'affine' key meta_data["affine"] = new_affine return d @@ -241,7 +239,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar ) meta_data["affine"] = new_affine # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -304,9 +302,7 @@ def __call__( meta_data_key = f"{key}_{self.meta_key_postfix}" meta_data = d[meta_data_key] d[key], old_affine, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) - self.append_applied_transforms( - d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine} - ) + self.push_transform(d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine}) d[meta_data_key]["affine"] = new_affine return d @@ -327,7 +323,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar d[key], _, new_affine = inverse_transform(d[key], affine=meta_data["affine"]) meta_data["affine"] = new_affine # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -353,7 +349,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.key_iterator(d): - self.append_applied_transforms(d, key) + self.push_transform(d, key) d[key] = self.rotator(d[key]) return d @@ -372,7 +368,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Apply inverse d[key] = inverse_transform(d[key]) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -424,7 +420,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. 
for key in self.key_iterator(d): if self._do_transform: d[key] = rotator(d[key]) - self.append_applied_transforms(d, key, extra_info={"rand_k": self._rand_k}) + self.push_transform(d, key, extra_info={"rand_k": self._rand_k}) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: @@ -443,7 +439,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Apply inverse d[key] = inverse_transform(d[key]) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -486,7 +482,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): - self.append_applied_transforms(d, key) + self.push_transform(d, key) d[key] = self.resizer(d[key], mode=mode, align_corners=align_corners) return d @@ -500,7 +496,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Apply inverse transform d[key] = inverse_transform(d[key]) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -676,7 +672,7 @@ def __call__( affine = np.eye(len(sp_size) + 1) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): - self.append_applied_transforms(d, key, extra_info={"affine": affine}) + self.push_transform(d, key, extra_info={"affine": affine}) d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) return d @@ -700,7 +696,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy() # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -826,7 +822,7 @@ def __call__( extra_info = None for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): - self.append_applied_transforms(d, key, extra_info=extra_info) + self.push_transform(d, key, extra_info=extra_info) d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) return d @@ -868,7 +864,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar else: d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -987,7 +983,7 @@ def __call__( affine = np.eye(len(sp_size) + 1) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): - self.append_applied_transforms(d, key, extra_info={"grid_no_affine": grid_no_affine, "affine": affine}) + self.push_transform(d, key, extra_info={"grid_no_affine": grid_no_affine, "affine": affine}) d[key] = self.rand_3d_elastic.resampler(d[key], grid_w_affine, mode=mode, padding_mode=padding_mode) return d @@ -1019,7 +1015,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar else: d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -1049,7 +1045,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.key_iterator(d): - self.append_applied_transforms(d, key) + self.push_transform(d, key) d[key] = self.flipper(d[key]) return d @@ -1063,7 
+1059,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Inverse is same as forward d[key] = self.flipper(d[key]) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -1101,7 +1097,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key in self.key_iterator(d): if self._do_transform: d[key] = self.flipper(d[key]) - self.append_applied_transforms(d, key) + self.push_transform(d, key) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: @@ -1116,7 +1112,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Inverse is same as forward d[key] = self.flipper(d[key]) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -1151,7 +1147,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key in self.key_iterator(d): if self._do_transform: d[key] = flipper(d[key]) - self.append_applied_transforms(d, key, extra_info={"axis": self._axis}) + self.push_transform(d, key, extra_info={"axis": self._axis}) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: @@ -1167,7 +1163,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Inverse is same as forward d[key] = flipper(d[key]) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -1232,7 +1228,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda dtype=dtype, return_rotation_matrix=True, ) - self.append_applied_transforms(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) + self.push_transform(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: @@ -1259,7 +1255,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar ) d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -1346,7 +1342,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d = dict(data) if not self._do_transform: for key in self.keys: - self.append_applied_transforms(d, key, extra_info={"rot_mat": np.eye(4)}) + self.push_transform(d, key, extra_info={"rot_mat": np.eye(4)}) return d angle: Union[Sequence[float], float] = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z) rotator = Rotate( @@ -1365,7 +1361,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda dtype=dtype, return_rotation_matrix=True, ) - self.append_applied_transforms(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) + self.push_transform(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: @@ -1394,7 +1390,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar ) d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -1445,7 +1441,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, 
np.nda for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): - self.append_applied_transforms(d, key) + self.push_transform(d, key) d[key] = self.zoomer( d[key], mode=mode, @@ -1473,7 +1469,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Size might be out by 1 voxel so pad d[key] = SpatialPad(transform[InvertibleTransform.Keys.orig_size])(d[key]) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d @@ -1547,7 +1543,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d = dict(data) if not self._do_transform: for key in self.keys: - self.append_applied_transforms(d, key, extra_info={"zoom": self._zoom}) + self.push_transform(d, key, extra_info={"zoom": self._zoom}) return d img_dims = data[self.keys[0]].ndim @@ -1561,7 +1557,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): - self.append_applied_transforms(d, key, extra_info={"zoom": self._zoom}) + self.push_transform(d, key, extra_info={"zoom": self._zoom}) d[key] = zoomer( d[key], mode=mode, @@ -1589,7 +1585,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Size might be out by 1 voxel so pad d[key] = SpatialPad(transform[InvertibleTransform.Keys.orig_size])(d[key]) # Remove the applied transform - self.remove_most_recent_transform(d, key) + self.pop_transform(d, key) return d From b7b17a033150f8e805a99efd960ad2704e2070ed Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 8 Mar 2021 15:51:22 +0000 Subject: [PATCH 30/64] update doc Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/inverse.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index ff6cf16d9d..53ed5b6b99 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -34,9 +34,28 @@ class InvertibleTransform(Transform): and after be returned to their original size before saving to file for comparison in an external viewer. - When the `__call__` method is called, a serialization of the class is stored. When - the `inverse` method is called, the serialization is then removed. We use last in, - first out for the inverted transforms. + When the `__call__` method is called, the transformation information for each key is + stored. If the transforms were applied to keys "image" and "label", there will be two + extra keys in the dictionary: "image_transforms" and "label_transforms". Each list + contains a list of the transforms applied to that key. When the ``inverse`` method is + called, the inverse is called on each key individually, which allows for different + parameters being passed to each label (e.g., different interpolation for image and + label). + + When the ``inverse`` method is called, the inverse transforms are applied in a last- + in-first-out order. As the inverse is applied, its entry is removed from the list + detailing the applied transformations. That is to say that during the forward pass, + the list of applied transforms grows, and then during the inverse it shrinks back + down to an empty list. 
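A minimal sketch of the bookkeeping described in the docstring above, assuming the invertible pad/crop dictionary transforms from earlier in this series and the ``Compose.inverse`` added in PATCH 27 (illustrative only; the exact tracking-key suffix depends on the enum introduced in PATCH 28):

    import numpy as np

    from monai.transforms import CenterSpatialCropd, Compose, SpatialPadd

    # forward: each invertible transform appends one entry of tracking info per key
    pipeline = Compose(
        [
            SpatialPadd(keys="image", spatial_size=[15, 15]),
            CenterSpatialCropd(keys="image", roi_size=[11, 11]),
        ]
    )
    data = pipeline({"image": np.zeros((1, 10, 10), dtype=np.float32)})

    # inverse: entries are popped last-in-first-out (crop undone first, then pad),
    # so the image is returned to its original (1, 10, 10) shape
    inverted = pipeline.inverse(data)
    print(inverted["image"].shape)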
+ + The information in data[key_transform] will be compatible with the default collate + since it only stores strings, numbers and arrays. + + We currently check that the id() of the transform is the same in the forward and + inverse directions. This is a useful check to ensure that the inverses are being + processed in the correct order. However, this may cause issues if the id() of the + object changes (such as multiprocessing on Windows). If you feel this issue affects + you, please raise a GitHub issue. Note to developers: When converting a transform to an invertible transform, you need to: 1. Inherit from this class. From 9313cffa41b12383e1de7571cb937a64f599ec94 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 8 Mar 2021 16:26:18 +0000 Subject: [PATCH 31/64] basic API Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 1 + monai/transforms/compose.py | 14 +- monai/transforms/croppad/dictionary.py | 25 ++- monai/transforms/inverse.py | 122 +++++++++++++++ tests/test_decollate.py | 13 +- tests/test_inverse.py | 207 +++++++++++++++++++++++++ 6 files changed, 378 insertions(+), 4 deletions(-) create mode 100644 monai/transforms/inverse.py create mode 100644 tests/test_inverse.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 796804df24..5b12da4d21 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -138,6 +138,7 @@ ThresholdIntensityD, ThresholdIntensityDict, ) +from .inverse import InvertibleTransform from .io.array import LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .post.array import ( diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 21e7da068c..d509ea33a1 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -17,6 +17,8 @@ import numpy as np +from monai.transforms.inverse import InvertibleTransform + # For backwards compatiblity (so this still works: from monai.transforms.compose import MapTransform) from monai.transforms.transform import ( # noqa: F401 MapTransform, @@ -30,7 +32,7 @@ __all__ = ["Compose"] -class Compose(RandomizableTransform): +class Compose(RandomizableTransform, InvertibleTransform): """ ``Compose`` provides the ability to chain a series of calls together in a sequence. Each transform in the sequence must take a single argument and @@ -141,3 +143,13 @@ def __call__(self, input_): for _transform in self.transforms: input_ = apply_transform(_transform, input_) return input_ + + def inverse(self, data): + invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)] + if len(invertible_transforms) == 0: + warnings.warn("inverse has been called but no invertible transforms have been supplied") + + # loop backwards over transforms + for t in reversed(invertible_transforms): + data = apply_transform(t.inverse, data) + return data diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 823b2dd3f4..ef47bd1d06 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -15,6 +15,8 @@ Class names are ended with 'd' to denote dictionary-based transforms. 
""" +from copy import deepcopy +from math import floor from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np @@ -30,6 +32,7 @@ SpatialCrop, SpatialPad, ) +from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform from monai.transforms.utils import ( generate_pos_neg_label_crop_centers, @@ -82,7 +85,7 @@ NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str] -class SpatialPadd(MapTransform): +class SpatialPadd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.SpatialPad`. Performs padding to the data, symmetric for all sides or all on one side for each dimension. @@ -119,9 +122,29 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in self.key_iterator(d, self.mode): + self.push_transform(d, key) d[key] = self.padder(d[key], mode=m) return d + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = transform[self.Keys.orig_size.value] + if self.padder.method == Method.SYMMETRIC: + current_size = d[key].shape[1:] + roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)] + else: + roi_center = [floor(r / 2) if r % 2 == 0 else (r - 1) // 2 for r in orig_size] + + inverse_transform = SpatialCrop(roi_center, orig_size) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + + return d class BorderPadd(MapTransform): """ diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py new file mode 100644 index 0000000000..23cf3eb7e5 --- /dev/null +++ b/monai/transforms/inverse.py @@ -0,0 +1,122 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from typing import Dict, Hashable, Optional, Tuple + +import numpy as np + +from monai.transforms.transform import RandomizableTransform, Transform + + +__all__ = ["InvertibleTransform"] + + +class InvertibleTransform(Transform): + """Classes for invertible transforms. + + This class exists so that an ``invert`` method can be implemented. This allows, for + example, images to be cropped, rotated, padded, etc., during training and inference, + and after be returned to their original size before saving to file for comparison in + an external viewer. + + When the `__call__` method is called, the transformation information for each key is + stored. If the transforms were applied to keys "image" and "label", there will be two + extra keys in the dictionary: "image_transforms" and "label_transforms". Each list + contains a list of the transforms applied to that key. 
When the ``inverse`` method is + called, the inverse is called on each key individually, which allows for different + parameters being passed to each label (e.g., different interpolation for image and + label). + + When the ``inverse`` method is called, the inverse transforms are applied in a last- + in-first-out order. As the inverse is applied, its entry is removed from the list + detailing the applied transformations. That is to say that during the forward pass, + the list of applied transforms grows, and then during the inverse it shrinks back + down to an empty list. + + The information in data[key_transform] will be compatible with the default collate + since it only stores strings, numbers and arrays. + + We currently check that the id() of the transform is the same in the forward and + inverse directions. This is a useful check to ensure that the inverses are being + processed in the correct order. However, this may cause issues if the id() of the + object changes (such as multiprocessing on Windows). If you feel this issue affects + you, please raise a GitHub issue. + + Note to developers: When converting a transform to an invertible transform, you need to: + 1. Inherit from this class. + 2. In `__call__`, add a call to `push_transform`. + 3. Any extra information that might be needed for the inverse can be included with the + dictionary `extra_info`. This dictionary should have the same keys regardless of + whether `do_transform` was True or False and can only contain objects that are + accepted in pytorch's batch (e.g., `None` is not allowed). + 4. Implement an `inverse` method. Make sure that after performing the inverse, + `remove_most_recent_transform` is called. + """ + + class Keys(Enum): + """Extra meta data keys used for inverse transforms.""" + + class_name = "class" + id = "id" + orig_size = "orig_size" + extra_info = "extra_info" + do_transform = "do_transforms" + key_suffix = "_transforms" + + def push_transform( + self, + data: dict, + key: Hashable, + extra_info: Optional[dict] = None, + orig_size: Optional[Tuple] = None, + ) -> None: + """Append to list of applied transforms for that key.""" + key_transform = str(key) + self.Keys.key_suffix.value + info = { + self.Keys.class_name.value: self.__class__.__name__, + self.Keys.id.value: id(self), + self.Keys.orig_size.value: orig_size or data[key].shape[1:], + } + if extra_info is not None: + info[self.Keys.extra_info.value] = extra_info + # If class is randomizable transform, store whether the transform was actually performed (based on `prob`) + if isinstance(self, RandomizableTransform): + info[self.Keys.do_transform] = self._do_transform + # If this is the first, create list + if key_transform not in data: + data[key_transform] = [] + data[key_transform].append(info) + + def check_transforms_match(self, transform: dict) -> None: + # Check transorms are of same type. + if transform[self.Keys.id.value] != id(self): + raise RuntimeError("Should inverse most recently applied invertible transform first") + + def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: + """Get most recent transform.""" + transform = dict(data[str(key) + self.Keys.key_suffix.value][-1]) + self.check_transforms_match(transform) + return transform + + def pop_transform(self, data: dict, key: Hashable) -> None: + """Remove most recent transform.""" + data[str(key) + self.Keys.key_suffix.value].pop() + + def inverse(self, data: dict) -> Dict[Hashable, np.ndarray]: + """ + Inverse of ``__call__``. 
+ + Raises: + NotImplementedError: When the subclass does not override this method. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 5c6f04b48e..e525920181 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum +from monai.transforms.inverse import InvertibleTransform import unittest - +import sys import numpy as np import torch from parameterized import parameterized @@ -46,7 +48,14 @@ def tearDown(self) -> None: def check_match(self, in1, in2): if isinstance(in1, dict): self.assertTrue(isinstance(in2, dict)) - self.check_match(list(in1.keys()), list(in2.keys())) + for (k1, v1), (k2, v2) in zip(in1.items(), in2.items()): + if isinstance(k1, Enum) and isinstance(k2, Enum): + k1, k2 = k1.value, k2.value + self.check_match(k1, k2) + # Transform ids won't match for windows with multiprocessing, so don't check values + if k1 == InvertibleTransform.Keys.id.value and sys.platform == "win32": + continue + self.check_match(v1, v2) self.check_match(list(in1.values()), list(in2.values())) elif any(isinstance(in1, i) for i in [list, tuple]): for l1, l2 in zip(in1, in2): diff --git a/tests/test_inverse.py b/tests/test_inverse.py new file mode 100644 index 0000000000..9fb25a0b7d --- /dev/null +++ b/tests/test_inverse.py @@ -0,0 +1,207 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import unittest +from typing import TYPE_CHECKING, List, Tuple + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data import CacheDataset, DataLoader, create_test_image_2d, create_test_image_3d +from monai.data.utils import decollate_batch +from monai.networks.nets import UNet +from monai.transforms import ( + AddChannel, + AddChanneld, + CenterSpatialCropd, + Compose, + InvertibleTransform, + LoadImaged, + ResizeWithPadOrCrop, + SpatialPadd, + allow_missing_keys_mode, +) +from monai.utils import first, optional_import, set_determinism +from tests.utils import make_nifti_image, make_rand_affine + +if TYPE_CHECKING: + import matplotlib.pyplot as plt + + has_matplotlib = True + has_nib = True +else: + plt, has_matplotlib = optional_import("matplotlib.pyplot") + _, has_nib = optional_import("nibabel") + +KEYS = ["image", "label"] + +TESTS: List[Tuple] = [] + +TESTS.append( + ( + "SpatialPadd (x2) 2d", + "2D", + 0.0, + SpatialPadd(KEYS, spatial_size=[111, 113], method="end"), + SpatialPadd(KEYS, spatial_size=[118, 117]), + ) +) + +TESTS.append( + ( + "SpatialPadd 3d", + "3D", + 0.0, + SpatialPadd(KEYS, spatial_size=[112, 113, 116]), + ) +) + +TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] + +TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore + + +def plot_im(orig, fwd_bck, fwd): + diff_orig_fwd_bck = orig - fwd_bck + ims_to_show = [orig, fwd, fwd_bck, diff_orig_fwd_bck] + titles = ["x", "fx", "f⁻¹fx", "x - f⁻¹fx"] + fig, axes = plt.subplots(1, 4, gridspec_kw={"width_ratios": [i.shape[1] for i in ims_to_show]}) + vmin = min(np.array(i).min() for i in [orig, fwd_bck, fwd]) + vmax = max(np.array(i).max() for i in [orig, fwd_bck, fwd]) + for im, title, ax in zip(ims_to_show, titles, axes): + _vmin, _vmax = (vmin, vmax) if id(im) != id(diff_orig_fwd_bck) else (None, None) + im = np.squeeze(np.array(im)) + while im.ndim > 2: + im = im[..., im.shape[-1] // 2] + im_show = ax.imshow(np.squeeze(im), vmin=_vmin, vmax=_vmax) + ax.set_title(title, fontsize=25) + ax.axis("off") + fig.colorbar(im_show, ax=ax) + plt.show() + + +class TestInverse(unittest.TestCase): + def setUp(self): + if not has_nib: + self.skipTest("nibabel required for test_inverse") + + set_determinism(seed=0) + + self.all_data = {} + + affine = make_rand_affine() + affine[0] *= 2 + + im_1d = AddChannel()(np.arange(0, 10)) + self.all_data["1D"] = {"image": im_1d, "label": im_1d, "other": im_1d} + + im_2d_fname, seg_2d_fname = [make_nifti_image(i) for i in create_test_image_2d(101, 100)] + im_3d_fname, seg_3d_fname = [make_nifti_image(i, affine) for i in create_test_image_3d(100, 101, 107)] + + load_ims = Compose([LoadImaged(KEYS), AddChanneld(KEYS)]) + self.all_data["2D"] = load_ims({"image": im_2d_fname, "label": seg_2d_fname}) + self.all_data["3D"] = load_ims({"image": im_3d_fname, "label": seg_3d_fname}) + + def tearDown(self): + set_determinism(seed=None) + + def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff): + for key in keys: + orig = orig_d[key] + fwd_bck = fwd_bck_d[key] + if isinstance(fwd_bck, torch.Tensor): + fwd_bck = fwd_bck.cpu().numpy() + unmodified = unmodified_d[key] + if isinstance(orig, np.ndarray): + mean_diff = np.mean(np.abs(orig - fwd_bck)) + unmodded_diff = np.mean(np.abs(orig - ResizeWithPadOrCrop(orig.shape[1:])(unmodified))) + try: + self.assertLessEqual(mean_diff, acceptable_diff) + except AssertionError: + print( + f"Failed: {name}. 
Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}" + ) + if has_matplotlib and orig[0].ndim > 1: + plot_im(orig, fwd_bck, unmodified) + elif orig[0].ndim == 1: + print(orig) + print(fwd_bck) + raise + + @parameterized.expand(TESTS) + def test_inverse(self, _, data_name, acceptable_diff, *transforms): + name = _ + + data = self.all_data[data_name] + + forwards = [data.copy()] + + # Apply forwards + for t in transforms: + forwards.append(t(forwards[-1])) + + # Check that error is thrown when inverse are used out of order. + t = SpatialPadd("image", [10, 5]) + with self.assertRaises(RuntimeError): + t.inverse(forwards[-1]) + + # Apply inverses + fwd_bck = forwards[-1].copy() + for i, t in enumerate(reversed(transforms)): + if isinstance(t, InvertibleTransform): + fwd_bck = t.inverse(fwd_bck) + self.check_inverse(name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff) + + def test_inverse_inferred_seg(self): + + test_data = [] + for _ in range(20): + image, label = create_test_image_2d(100, 101) + test_data.append({"image": image, "label": label.astype(np.float32)}) + + batch_size = 10 + # num workers = 0 for mac + num_workers = 2 if sys.platform != "darwin" else 0 + transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153))]) + num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, InvertibleTransform)) + + dataset = CacheDataset(test_data, transform=transforms, progress=False) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + + device = "cuda" if torch.cuda.is_available() else "cpu" + model = UNet( + dimensions=2, + in_channels=1, + out_channels=1, + channels=(2, 4), + strides=(2,), + ).to(device) + + data = first(loader) + labels = data["label"].to(device) + segs = model(labels).detach().cpu() + label_transform_key = "label" + InvertibleTransform.Keys.key_suffix.value + segs_dict = {"label": segs, label_transform_key: data[label_transform_key]} + segs_dict_decollated = decollate_batch(segs_dict) + + # inverse of individual segmentation + seg_dict = first(segs_dict_decollated) + with allow_missing_keys_mode(transforms): + inv_seg = transforms.inverse(seg_dict)["label"] + self.assertEqual(len(data["label_transforms"]), num_invertible_transforms) + self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms) + self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape) + + +if __name__ == "__main__": + unittest.main() From bc559af6bb26be35f6e70e34a1827666cf13c194 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 8 Mar 2021 16:30:18 +0000 Subject: [PATCH 32/64] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 1 + monai/transforms/inverse.py | 1 - tests/test_decollate.py | 7 ++++--- tests/test_inverse.py | 1 - 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index ef47bd1d06..0e1d5cad48 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -146,6 +146,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d + class BorderPadd(MapTransform): """ Pad the input data by adding specified borders to every dimension. 
diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 23cf3eb7e5..fe2d0957b5 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -16,7 +16,6 @@ from monai.transforms.transform import RandomizableTransform, Transform - __all__ = ["InvertibleTransform"] diff --git a/tests/test_decollate.py b/tests/test_decollate.py index e525920181..5cde1b2658 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -9,10 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from enum import Enum -from monai.transforms.inverse import InvertibleTransform -import unittest import sys +import unittest +from enum import Enum + import numpy as np import torch from parameterized import parameterized @@ -20,6 +20,7 @@ from monai.data import CacheDataset, DataLoader, create_test_image_2d from monai.data.utils import decollate_batch from monai.transforms import AddChanneld, Compose, LoadImaged, RandFlipd, SpatialPadd, ToTensord +from monai.transforms.inverse import InvertibleTransform from monai.transforms.post.dictionary import Decollated from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 9fb25a0b7d..ff259eefad 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -23,7 +23,6 @@ from monai.transforms import ( AddChannel, AddChanneld, - CenterSpatialCropd, Compose, InvertibleTransform, LoadImaged, From d8ad0eef80105e557e9fc7d4c2d4d40fdf64ae28 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 8 Mar 2021 16:35:07 +0000 Subject: [PATCH 33/64] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/inverse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index fe2d0957b5..6f50a59a6b 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -89,7 +89,7 @@ def push_transform( info[self.Keys.extra_info.value] = extra_info # If class is randomizable transform, store whether the transform was actually performed (based on `prob`) if isinstance(self, RandomizableTransform): - info[self.Keys.do_transform] = self._do_transform + info[self.Keys.do_transform.value] = self._do_transform # If this is the first, create list if key_transform not in data: data[key_transform] = [] From c7d59d50b508a9b180e959c0f5f33725880d703a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 9 Mar 2021 15:26:35 +0000 Subject: [PATCH 34/64] formatting docstring Signed-off-by: Wenqi Li --- docs/source/transforms.rst | 6 ++++++ monai/transforms/inverse.py | 28 +++++++++++++++------------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 3bc8d0899a..dcdeab1ac8 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -38,6 +38,12 @@ Generic Interfaces :members: :special-members: __call__ +`InvertibleTransform` +^^^^^^^^^^^^^^^^^^^^^ +.. 
autoclass:: InvertibleTransform + :members: + + Vanilla Transforms ------------------ diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 6f50a59a6b..d2ca0a78de 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -27,7 +27,7 @@ class InvertibleTransform(Transform): and after be returned to their original size before saving to file for comparison in an external viewer. - When the `__call__` method is called, the transformation information for each key is + When the ``__call__`` method is called, the transformation information for each key is stored. If the transforms were applied to keys "image" and "label", there will be two extra keys in the dictionary: "image_transforms" and "label_transforms". Each list contains a list of the transforms applied to that key. When the ``inverse`` method is @@ -41,24 +41,26 @@ class InvertibleTransform(Transform): the list of applied transforms grows, and then during the inverse it shrinks back down to an empty list. - The information in data[key_transform] will be compatible with the default collate + The information in ``data[key_transform]`` will be compatible with the default collate since it only stores strings, numbers and arrays. - We currently check that the id() of the transform is the same in the forward and + We currently check that the ``id()`` of the transform is the same in the forward and inverse directions. This is a useful check to ensure that the inverses are being - processed in the correct order. However, this may cause issues if the id() of the + processed in the correct order. However, this may cause issues if the ``id()`` of the object changes (such as multiprocessing on Windows). If you feel this issue affects you, please raise a GitHub issue. Note to developers: When converting a transform to an invertible transform, you need to: - 1. Inherit from this class. - 2. In `__call__`, add a call to `push_transform`. - 3. Any extra information that might be needed for the inverse can be included with the - dictionary `extra_info`. This dictionary should have the same keys regardless of - whether `do_transform` was True or False and can only contain objects that are - accepted in pytorch's batch (e.g., `None` is not allowed). - 4. Implement an `inverse` method. Make sure that after performing the inverse, - `remove_most_recent_transform` is called. + + #. Inherit from this class. + #. In ``__call__``, add a call to ``push_transform``. + #. Any extra information that might be needed for the inverse can be included with the + dictionary ``extra_info``. This dictionary should have the same keys regardless of + whether ``do_transform`` was `True` or `False` and can only contain objects that are + accepted in pytorch data loader's collate function (e.g., `None` is not allowed). + #. Implement an ``inverse`` method. Make sure that after performing the inverse, + ``pop_transform`` is called. + """ class Keys(Enum): @@ -96,7 +98,7 @@ def push_transform( data[key_transform].append(info) def check_transforms_match(self, transform: dict) -> None: - # Check transorms are of same type. 
+ """Check transforms are of same instance.""" if transform[self.Keys.id.value] != id(self): raise RuntimeError("Should inverse most recently applied invertible transform first") From 6c913d8f443d3ddf24075f716a3dd0dc3f1b9f4a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 10 Mar 2021 11:15:23 +0000 Subject: [PATCH 35/64] enum changes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 3 ++- monai/transforms/inverse.py | 30 +++++++++----------------- monai/utils/enums.py | 11 ++++++++++ tests/test_decollate.py | 4 ++-- tests/test_inverse.py | 3 ++- 5 files changed, 27 insertions(+), 24 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 0e1d5cad48..667fb7a821 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -41,6 +41,7 @@ weighted_patch_samples, ) from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple +from monai.utils.enums import InverseKeys __all__ = [ "NumpyPadModeSequence", @@ -131,7 +132,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform[self.Keys.orig_size.value] + orig_size = transform[InverseKeys.ORIG_SIZE.value] if self.padder.method == Method.SYMMETRIC: current_size = d[key].shape[1:] roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)] diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index d2ca0a78de..f9de8746ca 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -9,12 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from enum import Enum from typing import Dict, Hashable, Optional, Tuple import numpy as np from monai.transforms.transform import RandomizableTransform, Transform +from monai.utils.enums import InverseKeys __all__ = ["InvertibleTransform"] @@ -63,16 +63,6 @@ class InvertibleTransform(Transform): """ - class Keys(Enum): - """Extra meta data keys used for inverse transforms.""" - - class_name = "class" - id = "id" - orig_size = "orig_size" - extra_info = "extra_info" - do_transform = "do_transforms" - key_suffix = "_transforms" - def push_transform( self, data: dict, @@ -81,17 +71,17 @@ def push_transform( orig_size: Optional[Tuple] = None, ) -> None: """Append to list of applied transforms for that key.""" - key_transform = str(key) + self.Keys.key_suffix.value + key_transform = str(key) + InverseKeys.KEY_SUFFIX.value info = { - self.Keys.class_name.value: self.__class__.__name__, - self.Keys.id.value: id(self), - self.Keys.orig_size.value: orig_size or data[key].shape[1:], + InverseKeys.CLASS_NAME.value: self.__class__.__name__, + InverseKeys.ID.value: id(self), + InverseKeys.ORIG_SIZE.value: orig_size or data[key].shape[1:], } if extra_info is not None: - info[self.Keys.extra_info.value] = extra_info + info[InverseKeys.EXTRA_INFO.value] = extra_info # If class is randomizable transform, store whether the transform was actually performed (based on `prob`) if isinstance(self, RandomizableTransform): - info[self.Keys.do_transform.value] = self._do_transform + info[InverseKeys.DO_TRANSFORM.value] = self._do_transform # If this is the first, create list if key_transform not in data: data[key_transform] = [] @@ -99,18 +89,18 @@ def push_transform( def check_transforms_match(self, transform: dict) -> None: """Check transforms are of same instance.""" - if transform[self.Keys.id.value] != id(self): + if transform[InverseKeys.ID.value] != id(self): raise RuntimeError("Should inverse most recently applied invertible transform first") def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: """Get most recent transform.""" - transform = dict(data[str(key) + self.Keys.key_suffix.value][-1]) + transform = dict(data[str(key) + InverseKeys.KEY_SUFFIX.value][-1]) self.check_transforms_match(transform) return transform def pop_transform(self, data: dict, key: Hashable) -> None: """Remove most recent transform.""" - data[str(key) + self.Keys.key_suffix.value].pop() + data[str(key) + InverseKeys.KEY_SUFFIX.value].pop() def inverse(self, data: dict) -> Dict[Hashable, np.ndarray]: """ diff --git a/monai/utils/enums.py b/monai/utils/enums.py index d1d2d3bcce..3574f0b7e1 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -214,3 +214,14 @@ class Method(Enum): SYMMETRIC = "symmetric" END = "end" + + +class InverseKeys(Enum): + """Extra meta data keys used for inverse transforms.""" + + CLASS_NAME = "class" + ID = "id" + ORIG_SIZE = "orig_size" + EXTRA_INFO = "extra_info" + DO_TRANSFORM = "do_transforms" + KEY_SUFFIX = "_transforms" diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 5cde1b2658..85bfe67b99 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -20,9 +20,9 @@ from monai.data import CacheDataset, DataLoader, create_test_image_2d from monai.data.utils import decollate_batch from monai.transforms import AddChanneld, Compose, LoadImaged, RandFlipd, SpatialPadd, ToTensord -from monai.transforms.inverse import InvertibleTransform from monai.transforms.post.dictionary import Decollated from monai.utils import optional_import, set_determinism 
+from monai.utils.enums import InverseKeys from tests.utils import make_nifti_image _, has_nib = optional_import("nibabel") @@ -54,7 +54,7 @@ def check_match(self, in1, in2): k1, k2 = k1.value, k2.value self.check_match(k1, k2) # Transform ids won't match for windows with multiprocessing, so don't check values - if k1 == InvertibleTransform.Keys.id.value and sys.platform == "win32": + if k1 == InverseKeys.ID.value and sys.platform == "win32": continue self.check_match(v1, v2) self.check_match(list(in1.values()), list(in2.values())) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index ff259eefad..92e3025557 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -31,6 +31,7 @@ allow_missing_keys_mode, ) from monai.utils import first, optional_import, set_determinism +from monai.utils.enums import InverseKeys from tests.utils import make_nifti_image, make_rand_affine if TYPE_CHECKING: @@ -189,7 +190,7 @@ def test_inverse_inferred_seg(self): data = first(loader) labels = data["label"].to(device) segs = model(labels).detach().cpu() - label_transform_key = "label" + InvertibleTransform.Keys.key_suffix.value + label_transform_key = "label" + InverseKeys.KEY_SUFFIX.value segs_dict = {"label": segs, label_transform_key: data[label_transform_key]} segs_dict_decollated = decollate_batch(segs_dict) From 5a8a08b00c4c60f506dc3719fe9148a44137b3c5 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 10 Mar 2021 11:29:31 +0000 Subject: [PATCH 36/64] put matplotlib functionality in docstrings Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 65 +++++++++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 27 deletions(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 92e3025557..46729c7bc6 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -35,12 +35,9 @@ from tests.utils import make_nifti_image, make_rand_affine if TYPE_CHECKING: - import matplotlib.pyplot as plt - has_matplotlib = True has_nib = True else: - plt, has_matplotlib = optional_import("matplotlib.pyplot") _, has_nib = optional_import("nibabel") KEYS = ["image", "label"] @@ -71,26 +68,45 @@ TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore -def plot_im(orig, fwd_bck, fwd): - diff_orig_fwd_bck = orig - fwd_bck - ims_to_show = [orig, fwd, fwd_bck, diff_orig_fwd_bck] - titles = ["x", "fx", "f⁻¹fx", "x - f⁻¹fx"] - fig, axes = plt.subplots(1, 4, gridspec_kw={"width_ratios": [i.shape[1] for i in ims_to_show]}) - vmin = min(np.array(i).min() for i in [orig, fwd_bck, fwd]) - vmax = max(np.array(i).max() for i in [orig, fwd_bck, fwd]) - for im, title, ax in zip(ims_to_show, titles, axes): - _vmin, _vmax = (vmin, vmax) if id(im) != id(diff_orig_fwd_bck) else (None, None) - im = np.squeeze(np.array(im)) - while im.ndim > 2: - im = im[..., im.shape[-1] // 2] - im_show = ax.imshow(np.squeeze(im), vmin=_vmin, vmax=_vmax) - ax.set_title(title, fontsize=25) - ax.axis("off") - fig.colorbar(im_show, ax=ax) - plt.show() - - class TestInverse(unittest.TestCase): + """Test inverse methods. + + If tests are failing, the following function might be useful for displaying + `x`, `fx`, `f⁻¹fx` and `x - f⁻¹fx`. + + .. 
code-block:: python + + def plot_im(orig, fwd_bck, fwd): + import matplotlib.pyplot as plt + diff_orig_fwd_bck = orig - fwd_bck + ims_to_show = [orig, fwd, fwd_bck, diff_orig_fwd_bck] + titles = ["x", "fx", "f⁻¹fx", "x - f⁻¹fx"] + fig, axes = plt.subplots(1, 4, gridspec_kw={"width_ratios": [i.shape[1] for i in ims_to_show]}) + vmin = min(np.array(i).min() for i in [orig, fwd_bck, fwd]) + vmax = max(np.array(i).max() for i in [orig, fwd_bck, fwd]) + for im, title, ax in zip(ims_to_show, titles, axes): + _vmin, _vmax = (vmin, vmax) if id(im) != id(diff_orig_fwd_bck) else (None, None) + im = np.squeeze(np.array(im)) + while im.ndim > 2: + im = im[..., im.shape[-1] // 2] + im_show = ax.imshow(np.squeeze(im), vmin=_vmin, vmax=_vmax) + ax.set_title(title, fontsize=25) + ax.axis("off") + fig.colorbar(im_show, ax=ax) + plt.show() + + This can then be added to the exception: + + .. code-block:: python + + except AssertionError: + print( + f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}" + ) + if orig[0].ndim > 1: + plot_im(orig, fwd_bck, unmodified) + """ + def setUp(self): if not has_nib: self.skipTest("nibabel required for test_inverse") @@ -131,11 +147,6 @@ def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_ print( f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}" ) - if has_matplotlib and orig[0].ndim > 1: - plot_im(orig, fwd_bck, unmodified) - elif orig[0].ndim == 1: - print(orig) - print(fwd_bck) raise @parameterized.expand(TESTS) From 841f19c5d13180bd547930658a485b43a7dfdda8 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 10 Mar 2021 11:40:36 +0000 Subject: [PATCH 37/64] update module list Signed-off-by: Wenqi Li --- monai/utils/__init__.py | 1 + monai/utils/enums.py | 1 + 2 files changed, 2 insertions(+) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 1e17d44029..3c1e7efe24 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -20,6 +20,7 @@ GridSampleMode, GridSamplePadMode, InterpolateMode, + InverseKeys, LossReduction, Method, MetricReduction, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 3574f0b7e1..d661781616 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -28,6 +28,7 @@ "ChannelMatching", "SkipMode", "Method", + "InverseKeys", ] From 595f1cdd28676ac3db2457867695dd32cf5ccb52 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 10 Mar 2021 13:10:08 +0000 Subject: [PATCH 38/64] skip decollate id check for windows and mac Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_decollate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 85bfe67b99..076d1be56c 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -54,7 +54,7 @@ def check_match(self, in1, in2): k1, k2 = k1.value, k2.value self.check_match(k1, k2) # Transform ids won't match for windows with multiprocessing, so don't check values - if k1 == InverseKeys.ID.value and sys.platform == "win32": + if k1 == InverseKeys.ID.value and sys.platform in ["darwin", "win32"]: continue self.check_match(v1, v2) self.check_match(list(in1.values()), list(in2.values())) From 2c6b48e24302d0bee774e087d043e9a764631719 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 10 Mar 2021 14:25:25 +0000 Subject: [PATCH 39/64] 
fix test for windows and mac Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_decollate.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 076d1be56c..24a34482b5 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -57,13 +57,12 @@ def check_match(self, in1, in2): if k1 == InverseKeys.ID.value and sys.platform in ["darwin", "win32"]: continue self.check_match(v1, v2) - self.check_match(list(in1.values()), list(in2.values())) - elif any(isinstance(in1, i) for i in [list, tuple]): + elif isinstance(in1, (list, tuple)): for l1, l2 in zip(in1, in2): self.check_match(l1, l2) - elif any(isinstance(in1, i) for i in [str, int]): + elif isinstance(in1, (str, int)): self.assertEqual(in1, in2) - elif any(isinstance(in1, i) for i in [torch.Tensor, np.ndarray]): + elif isinstance(in1, (torch.Tensor, np.ndarray)): np.testing.assert_array_equal(in1, in2) else: raise RuntimeError(f"Not sure how to compare types. type(in1): {type(in1)}, type(in2): {type(in2)}") From 358d653b56e90ef792b42003cbfc696333e17c56 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 10 Mar 2021 15:12:27 +0000 Subject: [PATCH 40/64] update merge Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/inverse.py | 114 ------------------------------------ tests/test_inverse.py | 55 +++++++++-------- 2 files changed, 32 insertions(+), 137 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 68366817a4..16d3dc85bc 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -19,120 +19,6 @@ __all__ = ["InvertibleTransform"] -class InvertibleTransform(Transform): - """Classes for invertible transforms. - - This class exists so that an ``invert`` method can be implemented. This allows, for - example, images to be cropped, rotated, padded, etc., during training and inference, - and after be returned to their original size before saving to file for comparison in - an external viewer. - - When the ``__call__`` method is called, the transformation information for each key is - stored. If the transforms were applied to keys "image" and "label", there will be two - extra keys in the dictionary: "image_transforms" and "label_transforms". Each list - contains a list of the transforms applied to that key. When the ``inverse`` method is - called, the inverse is called on each key individually, which allows for different - parameters being passed to each label (e.g., different interpolation for image and - label). - - When the ``inverse`` method is called, the inverse transforms are applied in a last- - in-first-out order. As the inverse is applied, its entry is removed from the list - detailing the applied transformations. That is to say that during the forward pass, - the list of applied transforms grows, and then during the inverse it shrinks back - down to an empty list. - - The information in ``data[key_transform]`` will be compatible with the default collate - since it only stores strings, numbers and arrays. - - We currently check that the ``id()`` of the transform is the same in the forward and - inverse directions. This is a useful check to ensure that the inverses are being - processed in the correct order. However, this may cause issues if the ``id()`` of the - object changes (such as multiprocessing on Windows). 
If you feel this issue affects - you, please raise a GitHub issue. - - Note to developers: When converting a transform to an invertible transform, you need to: - - #. Inherit from this class. - #. In ``__call__``, add a call to ``push_transform``. - #. Any extra information that might be needed for the inverse can be included with the - dictionary ``extra_info``. This dictionary should have the same keys regardless of - whether ``do_transform`` was `True` or `False` and can only contain objects that are - accepted in pytorch data loader's collate function (e.g., `None` is not allowed). - #. Implement an ``inverse`` method. Make sure that after performing the inverse, - ``pop_transform`` is called. - - """ - - def push_transform( - self, - data: dict, - key: Hashable, - extra_info: Optional[dict] = None, - orig_size: Optional[Tuple] = None, - ) -> None: - """Append to list of applied transforms for that key.""" - key_transform = str(key) + InverseKeys.KEY_SUFFIX.value - info = { - InverseKeys.CLASS_NAME.value: self.__class__.__name__, - InverseKeys.ID.value: id(self), - InverseKeys.ORIG_SIZE.value: orig_size or data[key].shape[1:], - } - if extra_info is not None: - info[InverseKeys.EXTRA_INFO.value] = extra_info - # If class is randomizable transform, store whether the transform was actually performed (based on `prob`) - if isinstance(self, RandomizableTransform): - info[InverseKeys.DO_TRANSFORM.value] = self._do_transform - # If this is the first, create list - if key_transform not in data: - data[key_transform] = [] - data[key_transform].append(info) - - def check_transforms_match(self, transform: dict) -> None: - """Check transforms are of same instance.""" - if transform[InverseKeys.ID.value] != id(self): - raise RuntimeError("Should inverse most recently applied invertible transform first") - - def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: - """Get most recent transform.""" - transform = dict(data[str(key) + InverseKeys.KEY_SUFFIX.value][-1]) - self.check_transforms_match(transform) - return transform - - def pop_transform(self, data: dict, key: Hashable) -> None: - """Remove most recent transform.""" - data[str(key) + InverseKeys.KEY_SUFFIX.value].pop() - - def inverse(self, data: dict) -> Dict[Hashable, np.ndarray]: - """ - Inverse of ``__call__``. - - Raises: - NotImplementedError: When the subclass does not override this method. - - """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, Hashable, Optional, Tuple - -import numpy as np - -from monai.transforms.transform import RandomizableTransform, Transform -from monai.utils.enums import InverseKeys - -__all__ = ["InvertibleTransform"] - - class InvertibleTransform(Transform): """Classes for invertible transforms. 
diff --git a/tests/test_inverse.py b/tests/test_inverse.py index e7aad3b95a..c1e68f270a 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -57,13 +57,9 @@ from tests.utils import make_nifti_image, make_rand_affine, test_is_quick if TYPE_CHECKING: - import matplotlib.pyplot as plt - - has_matplotlib = True has_vtk = True has_nib = True else: - plt, has_matplotlib = optional_import("matplotlib.pyplot") _, has_vtk = optional_import("vtk") _, has_nib = optional_import("nibabel") @@ -393,26 +389,39 @@ TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore -def plot_im(orig, fwd_bck, fwd): - diff_orig_fwd_bck = orig - fwd_bck - ims_to_show = [orig, fwd, fwd_bck, diff_orig_fwd_bck] - titles = ["x", "fx", "f⁻¹fx", "x - f⁻¹fx"] - fig, axes = plt.subplots(1, 4, gridspec_kw={"width_ratios": [i.shape[1] for i in ims_to_show]}) - vmin = min(np.array(i).min() for i in [orig, fwd_bck, fwd]) - vmax = max(np.array(i).max() for i in [orig, fwd_bck, fwd]) - for im, title, ax in zip(ims_to_show, titles, axes): - _vmin, _vmax = (vmin, vmax) if id(im) != id(diff_orig_fwd_bck) else (None, None) - im = np.squeeze(np.array(im)) - while im.ndim > 2: - im = im[..., im.shape[-1] // 2] - im_show = ax.imshow(np.squeeze(im), vmin=_vmin, vmax=_vmax) - ax.set_title(title, fontsize=25) - ax.axis("off") - fig.colorbar(im_show, ax=ax) - plt.show() - - class TestInverse(unittest.TestCase): + """Test inverse methods. + If tests are failing, the following function might be useful for displaying + `x`, `fx`, `f⁻¹fx` and `x - f⁻¹fx`. + .. code-block:: python + def plot_im(orig, fwd_bck, fwd): + import matplotlib.pyplot as plt + diff_orig_fwd_bck = orig - fwd_bck + ims_to_show = [orig, fwd, fwd_bck, diff_orig_fwd_bck] + titles = ["x", "fx", "f⁻¹fx", "x - f⁻¹fx"] + fig, axes = plt.subplots(1, 4, gridspec_kw={"width_ratios": [i.shape[1] for i in ims_to_show]}) + vmin = min(np.array(i).min() for i in [orig, fwd_bck, fwd]) + vmax = max(np.array(i).max() for i in [orig, fwd_bck, fwd]) + for im, title, ax in zip(ims_to_show, titles, axes): + _vmin, _vmax = (vmin, vmax) if id(im) != id(diff_orig_fwd_bck) else (None, None) + im = np.squeeze(np.array(im)) + while im.ndim > 2: + im = im[..., im.shape[-1] // 2] + im_show = ax.imshow(np.squeeze(im), vmin=_vmin, vmax=_vmax) + ax.set_title(title, fontsize=25) + ax.axis("off") + fig.colorbar(im_show, ax=ax) + plt.show() + This can then be added to the exception: + .. code-block:: python + except AssertionError: + print( + f"Failed: {name}. 
Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}" + ) + if orig[0].ndim > 1: + plot_im(orig, fwd_bck, unmodified) + """ + def setUp(self): if not has_nib: self.skipTest("nibabel required for test_inverse") From 917c47ed1a8f0ecd712679543f3c2b0cfcc052de Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 10 Mar 2021 15:47:38 +0000 Subject: [PATCH 41/64] update merge 2 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 22 +++++----- monai/transforms/inverse.py | 7 ++++ monai/transforms/spatial/dictionary.py | 57 +++++++++++++------------- tests/test_decollate.py | 1 - tests/test_inverse.py | 10 ++--- 5 files changed, 51 insertions(+), 46 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 173bd94dff..6bde7ead1f 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -132,7 +132,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform[InvertibleTransform.Keys.orig_size] + orig_size = transform[InverseKeys.ORIG_SIZE.value] if self.padder.method == Method.SYMMETRIC: current_size = d[key].shape[1:] roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)] @@ -201,7 +201,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InvertibleTransform.Keys.orig_size]) + orig_size = np.array(transform[InverseKeys.ORIG_SIZE.value]) roi_start = np.array(self.padder.spatial_border) # Need to convert single value to [min1,min2,...] if roi_start.size == 1: @@ -209,7 +209,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # need to convert [min1,max1,min2,...] to [min1,min2,...] 
elif roi_start.size == 2 * orig_size.size: roi_start = roi_start[::2] - roi_end = np.array(transform[InvertibleTransform.Keys.orig_size]) + roi_start + roi_end = np.array(transform[InverseKeys.ORIG_SIZE.value]) + roi_start inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end) # Apply inverse transform @@ -267,7 +267,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InvertibleTransform.Keys.orig_size]) + orig_size = np.array(transform[InverseKeys.ORIG_SIZE.value]) current_size = np.array(d[key].shape[1:]) roi_start = np.floor((current_size - orig_size) / 2) roi_end = orig_size + roi_start @@ -322,7 +322,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform[InvertibleTransform.Keys.orig_size] + orig_size = transform[InverseKeys.ORIG_SIZE.value] pad_to_start = self.cropper.roi_start pad_to_end = orig_size - self.cropper.roi_end # interweave mins and maxes @@ -370,7 +370,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InvertibleTransform.Keys.orig_size]) + orig_size = np.array(transform[InverseKeys.ORIG_SIZE.value]) current_size = np.array(d[key].shape[1:]) pad_to_start = np.floor((orig_size - current_size) / 2) # in each direction, if original size is even and current size is odd, += 1 @@ -453,12 +453,12 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform[InvertibleTransform.Keys.orig_size] + orig_size = transform[InverseKeys.ORIG_SIZE.value] random_center = self.random_center pad_to_start = np.empty((len(orig_size)), dtype=np.int32) pad_to_end = np.empty((len(orig_size)), dtype=np.int32) if random_center: - for i, _slice in enumerate(transform[InvertibleTransform.Keys.extra_info]["slices"]): + for i, _slice in enumerate(transform[InverseKeys.EXTRA_INFO.value]["slices"]): pad_to_start[i] = _slice[0] pad_to_end[i] = orig_size[i] - _slice[1] else: @@ -600,8 +600,8 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InvertibleTransform.Keys.orig_size]) - extra_info = transform[InvertibleTransform.Keys.extra_info] + orig_size = np.array(transform[InverseKeys.ORIG_SIZE.value]) + extra_info = transform[InverseKeys.EXTRA_INFO.value] pad_to_start = np.array(extra_info["box_start"]) pad_to_end = orig_size - np.array(extra_info["box_end"]) # interweave mins and maxes @@ -835,7 +835,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform[InvertibleTransform.Keys.orig_size] + orig_size = transform[InverseKeys.ORIG_SIZE.value] inverse_transform = ResizeWithPadOrCrop(spatial_size=orig_size, mode=self.padcropper.padder.mode) # Apply inverse transform d[key] = 
inverse_transform(d[key]) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 16d3dc85bc..fae560669d 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -9,12 +9,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import Dict, Hashable, Optional, Tuple import numpy as np +import torch from monai.transforms.transform import RandomizableTransform, Transform from monai.utils.enums import InverseKeys +from monai.utils.module import optional_import + +sitk, has_sitk = optional_import("SimpleITK") +vtk, has_vtk = optional_import("vtk") +vtk_numpy_support, _ = optional_import("vtk.util.numpy_support") __all__ = ["InvertibleTransform"] diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index fe6e3ac496..53ab683cc6 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -52,6 +52,7 @@ fall_back_tuple, optional_import, ) +from monai.utils.enums import InverseKeys nib, _ = optional_import("nibabel") @@ -224,8 +225,8 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar + "Please raise a github issue if you need this feature" ) # Create inverse transform - meta_data = d[transform[InvertibleTransform.Keys.extra_info]["meta_data_key"]] - old_affine = np.array(transform[InvertibleTransform.Keys.extra_info]["old_affine"]) + meta_data = d[transform[InverseKeys.EXTRA_INFO.value]["meta_data_key"]] + old_affine = np.array(transform[InverseKeys.EXTRA_INFO.value]["old_affine"]) orig_pixdim = np.sqrt(np.sum(np.square(old_affine), 0))[:-1] inverse_transform = Spacing(orig_pixdim, diagonal=self.spacing_transform.diagonal) # Apply inverse @@ -311,8 +312,8 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - meta_data = d[transform[InvertibleTransform.Keys.extra_info]["meta_data_key"]] - orig_affine = transform[InvertibleTransform.Keys.extra_info]["old_affine"] + meta_data = d[transform[InverseKeys.EXTRA_INFO.value]["meta_data_key"]] + orig_affine = transform[InverseKeys.EXTRA_INFO.value]["old_affine"] orig_axcodes = nib.orientations.aff2axcodes(orig_affine) inverse_transform = Orientation( axcodes=orig_axcodes, @@ -428,9 +429,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InvertibleTransform.Keys.do_transform]: + if transform[InverseKeys.DO_TRANSFORM.value]: # Create inverse transform - num_times_rotated = transform[InvertibleTransform.Keys.extra_info]["rand_k"] + num_times_rotated = transform[InverseKeys.EXTRA_INFO.value]["rand_k"] num_times_to_rotate = 4 - num_times_rotated inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes) # Might need to convert to numpy @@ -490,7 +491,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar d = deepcopy(dict(data)) for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): transform = self.get_most_recent_transform(d, key) - orig_size = transform[InvertibleTransform.Keys.orig_size] + orig_size = transform[InverseKeys.ORIG_SIZE.value] # Create inverse transform inverse_transform = Resize(orig_size, mode, 
align_corners) # Apply inverse transform @@ -681,9 +682,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): transform = self.get_most_recent_transform(d, key) - orig_size = transform[InvertibleTransform.Keys.orig_size] + orig_size = transform[InverseKeys.ORIG_SIZE.value] # Create inverse transform - fwd_affine = transform[InvertibleTransform.Keys.extra_info]["affine"] + fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"] inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) @@ -834,22 +835,22 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for idx, (key, mode, padding_mode) in enumerate(self.key_iterator(d, self.mode, self.padding_mode)): transform = self.get_most_recent_transform(d, key) # Create inverse transform - if transform[InvertibleTransform.Keys.do_transform]: - orig_size = transform[InvertibleTransform.Keys.orig_size] + if transform[InverseKeys.DO_TRANSFORM.value]: + orig_size = transform[InverseKeys.ORIG_SIZE.value] # Only need to calculate inverse deformation once as it is the same for all keys if idx == 0: # If magnitude == 0, then non-rigid component is identity -- so just create blank if self.rand_2d_elastic.deform_grid.magnitude == (0.0, 0.0): inv_def_no_affine = create_grid(spatial_size=orig_size) else: - fwd_cpg_no_affine = transform[InvertibleTransform.Keys.extra_info]["cpg"] + fwd_cpg_no_affine = transform[InverseKeys.EXTRA_INFO.value]["cpg"] fwd_def_no_affine = self.cpg_to_dvf( fwd_cpg_no_affine, self.rand_2d_elastic.deform_grid.spacing, orig_size ) inv_def_no_affine = self.compute_inverse_deformation(len(orig_size), fwd_def_no_affine) # if inverse did not succeed (sitk or vtk present), data will not be changed. if inv_def_no_affine is not None: - fwd_affine = transform[InvertibleTransform.Keys.extra_info]["affine"] + fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"] inv_affine = np.linalg.inv(fwd_affine) inv_def_w_affine_wrong_size = AffineGrid(affine=inv_affine, as_tensor_output=False)( grid=inv_def_no_affine @@ -993,15 +994,15 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for idx, (key, mode, padding_mode) in enumerate(self.key_iterator(d, self.mode, self.padding_mode)): transform = self.get_most_recent_transform(d, key) # Create inverse transform - if transform[InvertibleTransform.Keys.do_transform]: - orig_size = transform[InvertibleTransform.Keys.orig_size] + if transform[InverseKeys.DO_TRANSFORM.value]: + orig_size = transform[InverseKeys.ORIG_SIZE.value] # Only need to calculate inverse deformation once as it is the same for all keys if idx == 0: - fwd_def_no_affine = transform[InvertibleTransform.Keys.extra_info]["grid_no_affine"] + fwd_def_no_affine = transform[InverseKeys.EXTRA_INFO.value]["grid_no_affine"] inv_def_no_affine = self.compute_inverse_deformation(len(orig_size), fwd_def_no_affine) # if inverse did not succeed (sitk or vtk present), data will not be changed. 
if inv_def_no_affine is not None: - fwd_affine = transform[InvertibleTransform.Keys.extra_info]["affine"] + fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"] inv_affine = np.linalg.inv(fwd_affine) inv_def_w_affine_wrong_size = AffineGrid(affine=inv_affine, as_tensor_output=False)( grid=inv_def_no_affine @@ -1105,7 +1106,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InvertibleTransform.Keys.do_transform]: + if transform[InverseKeys.DO_TRANSFORM.value]: # Might need to convert to numpy if isinstance(d[key], torch.Tensor): d[key] = torch.Tensor(d[key]).cpu().numpy() @@ -1155,8 +1156,8 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InvertibleTransform.Keys.do_transform]: - flipper = Flip(spatial_axis=transform[InvertibleTransform.Keys.extra_info]["axis"]) + if transform[InverseKeys.DO_TRANSFORM.value]: + flipper = Flip(spatial_axis=transform[InverseKeys.EXTRA_INFO.value]["axis"]) # Might need to convert to numpy if isinstance(d[key], torch.Tensor): d[key] = torch.Tensor(d[key]).cpu().numpy() @@ -1238,7 +1239,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar ): transform = self.get_most_recent_transform(d, key) # Create inverse transform - fwd_rot_mat = transform[InvertibleTransform.Keys.extra_info]["rot_mat"] + fwd_rot_mat = transform[InverseKeys.EXTRA_INFO.value]["rot_mat"] inv_rot_mat = np.linalg.inv(fwd_rot_mat) xform = AffineTransform( @@ -1251,7 +1252,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar output = xform( torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), - spatial_size=transform[InvertibleTransform.Keys.orig_size], + spatial_size=transform[InverseKeys.ORIG_SIZE.value], ) d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # Remove the applied transform @@ -1371,9 +1372,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar ): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InvertibleTransform.Keys.do_transform]: + if transform[InverseKeys.DO_TRANSFORM.value]: # Create inverse transform - fwd_rot_mat = transform[InvertibleTransform.Keys.extra_info]["rot_mat"] + fwd_rot_mat = transform[InverseKeys.EXTRA_INFO.value]["rot_mat"] inv_rot_mat = np.linalg.inv(fwd_rot_mat) xform = AffineTransform( @@ -1386,7 +1387,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar output = xform( torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), - spatial_size=transform[InvertibleTransform.Keys.orig_size], + spatial_size=transform[InverseKeys.ORIG_SIZE.value], ) d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # Remove the applied transform @@ -1467,7 +1468,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar align_corners=align_corners, ) # Size might be out by 1 voxel so pad - d[key] = 
SpatialPad(transform[InvertibleTransform.Keys.orig_size])(d[key]) + d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE.value])(d[key]) # Remove the applied transform self.pop_transform(d, key) @@ -1573,7 +1574,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar ): transform = self.get_most_recent_transform(d, key) # Create inverse transform - zoom = np.array(transform[InvertibleTransform.Keys.extra_info]["zoom"]) + zoom = np.array(transform[InverseKeys.EXTRA_INFO.value]["zoom"]) inverse_transform = Zoom(zoom=1 / zoom, keep_size=self.keep_size) # Apply inverse d[key] = inverse_transform( @@ -1583,7 +1584,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar align_corners=align_corners, ) # Size might be out by 1 voxel so pad - d[key] = SpatialPad(transform[InvertibleTransform.Keys.orig_size])(d[key]) + d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE.value])(d[key]) # Remove the applied transform self.pop_transform(d, key) diff --git a/tests/test_decollate.py b/tests/test_decollate.py index af5631d933..ac0e0f73f1 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -19,7 +19,6 @@ from monai.data import CacheDataset, DataLoader, create_test_image_2d from monai.data.utils import decollate_batch from monai.transforms import AddChanneld, Compose, LoadImaged, RandFlipd, SpatialPadd, ToTensord -from monai.transforms.inverse import InvertibleTransform from monai.transforms.post.dictionary import Decollated from monai.utils import optional_import, set_determinism from monai.utils.enums import InverseKeys diff --git a/tests/test_inverse.py b/tests/test_inverse.py index c1e68f270a..cbb4ddcd97 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -54,6 +54,7 @@ allow_missing_keys_mode, ) from monai.utils import first, optional_import, set_determinism +from monai.utils.enums import InverseKeys from tests.utils import make_nifti_image, make_rand_affine, test_is_quick if TYPE_CHECKING: @@ -462,11 +463,6 @@ def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_ print( f"Failed: {name}. 
Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}" ) - if has_matplotlib and orig[0].ndim > 1: - plot_im(orig, fwd_bck, unmodified) - elif orig[0].ndim == 1: - print(orig) - print(fwd_bck) raise @parameterized.expand(TESTS) @@ -521,7 +517,9 @@ def test_inverse_inferred_seg(self): data = first(loader) labels = data["label"].to(device) segs = model(labels).detach().cpu() - segs_dict = {"label": segs, "label_transforms": data["label_transforms"]} + label_transform_key = "label" + InverseKeys.KEY_SUFFIX.value + segs_dict = {"label": segs, label_transform_key: data[label_transform_key]} + segs_dict_decollated = decollate_batch(segs_dict) # inverse of individual segmentation From 97e0ad3ca10789c35a32f589b528b619fa513f35 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 15 Mar 2021 10:53:58 +0000 Subject: [PATCH 42/64] remove duplicate tests Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 91 ------------------------------------------- 1 file changed, 91 deletions(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 94bcee3a51..c89142a90d 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -237,73 +237,6 @@ ) ) -TESTS.append( - ( - "SpatialCropd 2d", - "2D", - 3e-2, - SpatialCropd(KEYS, [49, 51], [90, 89]), - ) -) - -TESTS.append( - ( - "SpatialCropd 3d", - "3D", - 4e-2, - SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]), - ) -) - -TESTS.append(("RandSpatialCropd 2d", "2D", 5e-2, RandSpatialCropd(KEYS, [96, 93], True, False))) - -TESTS.append(("RandSpatialCropd 3d", "3D", 2e-2, RandSpatialCropd(KEYS, [96, 93, 92], False, False))) - -TESTS.append( - ( - "BorderPadd 2d", - "2D", - 0, - BorderPadd(KEYS, [3, 7, 2, 5]), - ) -) - -TESTS.append( - ( - "BorderPadd 2d", - "2D", - 0, - BorderPadd(KEYS, [3, 7]), - ) -) - -TESTS.append( - ( - "BorderPadd 3d", - "3D", - 0, - BorderPadd(KEYS, [4]), - ) -) - -TESTS.append( - ( - "DivisiblePadd 2d", - "2D", - 0, - DivisiblePadd(KEYS, k=4), - ) -) - -TESTS.append( - ( - "DivisiblePadd 3d", - "3D", - 0, - DivisiblePadd(KEYS, k=[4, 8, 11]), - ) -) - TESTS.append( ( "Flipd 3d", @@ -437,36 +370,12 @@ TESTS.append(("RandZoom 3d", "3D", 9e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True))) -TESTS.append( - ( - "CenterSpatialCropd 2d", - "2D", - 0, - CenterSpatialCropd(KEYS, roi_size=95), - ) -) - -TESTS.append( - ( - "CenterSpatialCropd 3d", - "3D", - 0, - CenterSpatialCropd(KEYS, roi_size=[95, 97, 98]), - ) -) - -TESTS.append(("CropForegroundd 2d", "2D", 0, CropForegroundd(KEYS, source_key="label", margin=2))) - -TESTS.append(("CropForegroundd 3d", "3D", 0, CropForegroundd(KEYS, source_key="label"))) - TESTS.append(("Spacingd 3d", "3D", 3e-2, Spacingd(KEYS, [0.5, 0.7, 0.9], diagonal=False))) TESTS.append(("Resized 2d", "2D", 2e-1, Resized(KEYS, [50, 47]))) TESTS.append(("Resized 3d", "3D", 5e-2, Resized(KEYS, [201, 150, 78]))) -TESTS.append(("ResizeWithPadOrCropd 3d", "3D", 3e-2, ResizeWithPadOrCropd(KEYS, [201, 150, 78]))) - TESTS.append( ( "RandAffine 3d", From ea31dd771beccccac660d6453dd4204c2dde7b00 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 15 Mar 2021 12:30:06 +0000 Subject: [PATCH 43/64] lossless inverse Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 8 +- monai/transforms/spatial/dictionary.py | 135 +++++++++++++++++++++++-- tests/test_inverse.py | 102 
+++++++++++++++++++ 3 files changed, 232 insertions(+), 13 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 2867361b8e..33b8da3ebb 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -281,7 +281,7 @@ def __call__( ornt[:, 0] += 1 # skip channel dim ornt = np.concatenate([np.array([[0, 1]]), ornt]) shape = data_array.shape[1:] - data_array = nib.orientations.apply_orientation(data_array, ornt) + data_array = np.ascontiguousarray(nib.orientations.apply_orientation(data_array, ornt)) new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, shape) new_affine = to_affine_nd(affine, new_affine) return data_array, affine, new_affine @@ -590,7 +590,7 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: If axis is negative it counts from the last to the first axis. """ self.k = k - spatial_axes_ = ensure_tuple(spatial_axes) + spatial_axes_: Tuple[int, int] = ensure_tuple(spatial_axes) # type: ignore if len(spatial_axes_) != 2: raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.") self.spatial_axes = spatial_axes_ @@ -620,7 +620,7 @@ def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: Tuple[int, i spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. """ - RandomizableTransform.__init__(self, min(max(prob, 0.0), 1.0)) + RandomizableTransform.__init__(self, prob) self.max_k = max_k self.spatial_axes = spatial_axes @@ -758,7 +758,7 @@ class RandFlip(RandomizableTransform): """ def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: - RandomizableTransform.__init__(self, min(max(prob, 0.0), 1.0)) + RandomizableTransform.__init__(self, prob) self.flipper = Flip(spatial_axis=spatial_axis) def __call__(self, img: np.ndarray) -> np.ndarray: diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index d9d38242fb..7cc10ccbfd 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -15,16 +15,20 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ +from copy import deepcopy from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch from monai.config import DtypeLike, KeysCollection +from monai.networks.layers import AffineTransform from monai.networks.layers.simplelayers import GaussianFilter -from monai.transforms.croppad.array import CenterSpatialCrop +from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad +from monai.transforms.inverse import InvertibleTransform, NonRigidTransform from monai.transforms.spatial.array import ( Affine, + AffineGrid, Flip, Orientation, Rand2DElastic, @@ -47,6 +51,7 @@ ensure_tuple_rep, fall_back_tuple, ) +from monai.utils.enums import InverseKeys __all__ = [ "Spacingd", @@ -204,7 +209,7 @@ def __call__( return d -class Orientationd(MapTransform): +class Orientationd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Orientation`. 
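The hunks that follow record the pre-transform affine in the key's applied-transform stack so that `inverse` can rebuild an `Orientation` with the original axis codes. A minimal round-trip sketch, not part of the patch (the channel-first array, the `image_meta_dict` key and the LPS-style affine are illustrative choices):

    import numpy as np
    from monai.transforms import Orientationd

    # illustrative inputs: a channel-first 3D image with an LPS-style affine
    data = {
        "image": np.random.rand(1, 8, 9, 10).astype(np.float32),
        "image_meta_dict": {"affine": np.diag([-1.0, -1.0, 1.0, 1.0])},
    }
    orient = Orientationd(keys="image", axcodes="RAS")
    oriented = orient(data)              # reorients and records "old_affine" under "image_transforms"
    restored = orient.inverse(oriented)  # rebuilds an Orientation with the original axcodes
    np.testing.assert_allclose(restored["image"], data["image"])

Because reorientation only permutes and flips axes, the round trip is exact and no interpolation tolerance is needed.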
@@ -259,13 +264,36 @@ def __call__( ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: d: Dict = dict(data) for key in self.key_iterator(d): - meta_data = d[f"{key}_{self.meta_key_postfix}"] - d[key], _, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) + meta_data_key = f"{key}_{self.meta_key_postfix}" + meta_data = d[meta_data_key] + d[key], old_affine, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) + self.push_transform(d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine}) + d[meta_data_key]["affine"] = new_affine + return d + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + meta_data = d[transform[InverseKeys.EXTRA_INFO.value]["meta_data_key"]] + orig_affine = transform[InverseKeys.EXTRA_INFO.value]["old_affine"] + orig_axcodes = nib.orientations.aff2axcodes(orig_affine) + inverse_transform = Orientation( + axcodes=orig_axcodes, + as_closest_canonical=self.ornt_transform.as_closest_canonical, + labels=self.ornt_transform.labels, + ) + # Apply inverse + d[key], _, new_affine = inverse_transform(d[key], affine=meta_data["affine"]) meta_data["affine"] = new_affine + # Remove the applied transform + self.pop_transform(d, key) + return d -class Rotate90d(MapTransform): +class Rotate90d(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`. """ @@ -286,11 +314,31 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.key_iterator(d): + self.push_transform(d, key) d[key] = self.rotator(d[key]) return d + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + _ = self.get_most_recent_transform(d, key) + # Create inverse transform + spatial_axes = self.rotator.spatial_axes + num_times_rotated = self.rotator.k + num_times_to_rotate = 4 - num_times_rotated + inverse_transform = Rotate90(num_times_to_rotate, spatial_axes) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Apply inverse + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + + return d + -class RandRotate90d(RandomizableTransform, MapTransform): +class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRotate90`. With probability `prob`, input arrays are rotated by 90 degrees @@ -337,6 +385,27 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. 
for key in self.key_iterator(d): if self._do_transform: d[key] = rotator(d[key]) + self.push_transform(d, key, extra_info={"rand_k": self._rand_k}) + return d + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform[InverseKeys.DO_TRANSFORM.value]: + # Create inverse transform + num_times_rotated = transform[InverseKeys.EXTRA_INFO.value]["rand_k"] + num_times_to_rotate = 4 - num_times_rotated + inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Apply inverse + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + return d @@ -789,7 +858,7 @@ def __call__( return d -class Flipd(MapTransform): +class Flipd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Flip`. @@ -814,11 +883,26 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.key_iterator(d): + self.push_transform(d, key) d[key] = self.flipper(d[key]) return d + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + _ = self.get_most_recent_transform(d, key) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Inverse is same as forward + d[key] = self.flipper(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + + return d -class RandFlipd(RandomizableTransform, MapTransform): + +class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandFlip`. @@ -851,10 +935,26 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key in self.key_iterator(d): if self._do_transform: d[key] = self.flipper(d[key]) + self.push_transform(d, key) + return d + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform[InverseKeys.DO_TRANSFORM.value]: + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Inverse is same as forward + d[key] = self.flipper(d[key]) + # Remove the applied transform + self.pop_transform(d, key) return d -class RandAxisFlipd(RandomizableTransform, MapTransform): +class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandAxisFlip`. 
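As the comments in these flip inverses note, a flip is its own inverse; the stack entry is only needed to know whether the forward pass actually flipped and along which axes. A short round-trip sketch, not part of the patch (the key name and array contents are illustrative):

    import numpy as np
    from monai.transforms import Flipd

    flip = Flipd(keys="image", spatial_axis=[1, 2])
    data = flip({"image": np.arange(24, dtype=np.float32).reshape(1, 2, 3, 4)})
    restored = flip.inverse(data)  # simply applies the same Flip again
    np.testing.assert_allclose(restored["image"], np.arange(24, dtype=np.float32).reshape(1, 2, 3, 4))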
@@ -885,6 +985,23 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key in self.key_iterator(d): if self._do_transform: d[key] = flipper(d[key]) + self.push_transform(d, key, extra_info={"axis": self._axis}) + return d + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform[InverseKeys.DO_TRANSFORM.value]: + flipper = Flip(spatial_axis=transform[InverseKeys.EXTRA_INFO.value]["axis"]) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Inverse is same as forward + d[key] = flipper(d[key]) + # Remove the applied transform + self.pop_transform(d, key) return d diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 6635a4126f..1f81d71b58 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import random import sys import unittest from functools import partial @@ -28,11 +29,18 @@ Compose, CropForegroundd, DivisiblePadd, + Flipd, InvertibleTransform, LoadImaged, + Orientationd, + RandAxisFlipd, + RandFlipd, + RandRotate90d, + RandRotated, RandSpatialCropd, ResizeWithPadOrCrop, ResizeWithPadOrCropd, + Rotate90d, SpatialCropd, SpatialPadd, allow_missing_keys_mode, @@ -206,6 +214,100 @@ TESTS.append(("ResizeWithPadOrCropd 3d", "3D", 0, ResizeWithPadOrCropd(KEYS, [201, 150, 105]))) +TESTS.append( + ( + "RandRotated, prob 0", + "2D", + 0, + RandRotated(KEYS, prob=0), + ) +) + +TESTS.append( + ( + "Flipd 3d", + "3D", + 0, + Flipd(KEYS, [1, 2]), + ) +) + +TESTS.append( + ( + "Flipd 3d", + "3D", + 0, + Flipd(KEYS, [1, 2]), + ) +) + +TESTS.append( + ( + "RandFlipd 3d", + "3D", + 0, + RandFlipd(KEYS, 1, [1, 2]), + ) +) + +TESTS.append( + ( + "RandAxisFlipd 3d", + "3D", + 0, + RandAxisFlipd(KEYS, 1), + ) +) + +TESTS.append( + ( + "RandRotated 3d", + "3D", + 1e-1, + RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1), # type: ignore + ) +) + +TESTS.append( + ( + "Orientationd 3d", + "3D", + 0, + # For data loader, output needs to be same size, so input must be square/cubic + SpatialPadd(KEYS, 110), + Orientationd(KEYS, "RAS"), + ) +) + +TESTS.append( + ( + "Rotate90d 2d", + "2D", + 0, + Rotate90d(KEYS), + ) +) + +TESTS.append( + ( + "Rotate90d 3d", + "3D", + 0, + Rotate90d(KEYS, k=2, spatial_axes=(1, 2)), + ) +) + +TESTS.append( + ( + "RandRotate90d 3d", + "3D", + 0, + # For data loader, output needs to be same size, so input must be square/cubic + SpatialPadd(KEYS, 110), + RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)), + ) +) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore From 28f5a36615275978d0a7c71d236d26cb23beb048 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 15 Mar 2021 12:42:28 +0000 Subject: [PATCH 44/64] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 7cc10ccbfd..68cb18ae18 100644 --- 
a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -22,13 +22,11 @@ import torch from monai.config import DtypeLike, KeysCollection -from monai.networks.layers import AffineTransform from monai.networks.layers.simplelayers import GaussianFilter -from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad -from monai.transforms.inverse import InvertibleTransform, NonRigidTransform +from monai.transforms.croppad.array import CenterSpatialCrop +from monai.transforms.inverse import InvertibleTransform from monai.transforms.spatial.array import ( Affine, - AffineGrid, Flip, Orientation, Rand2DElastic, @@ -52,6 +50,9 @@ fall_back_tuple, ) from monai.utils.enums import InverseKeys +from monai.utils.module import optional_import + +nib, _ = optional_import("nibabel") __all__ = [ "Spacingd", From fed515e2c2ad4c0bd348711df2b98fe1b430bf7e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 15 Mar 2021 15:32:47 +0000 Subject: [PATCH 45/64] remove extra tests Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 1f81d71b58..d3a5a533cd 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import random import sys import unittest from functools import partial @@ -36,7 +35,6 @@ RandAxisFlipd, RandFlipd, RandRotate90d, - RandRotated, RandSpatialCropd, ResizeWithPadOrCrop, ResizeWithPadOrCropd, @@ -214,15 +212,6 @@ TESTS.append(("ResizeWithPadOrCropd 3d", "3D", 0, ResizeWithPadOrCropd(KEYS, [201, 150, 105]))) -TESTS.append( - ( - "RandRotated, prob 0", - "2D", - 0, - RandRotated(KEYS, prob=0), - ) -) - TESTS.append( ( "Flipd 3d", @@ -259,15 +248,6 @@ ) ) -TESTS.append( - ( - "RandRotated 3d", - "3D", - 1e-1, - RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1), # type: ignore - ) -) - TESTS.append( ( "Orientationd 3d", From c72407ab83432b1e7ebbe3b1e8beb691c4d052d4 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Mar 2021 08:56:06 +0000 Subject: [PATCH 46/64] update tests Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 2 +- tests/test_inverse.py | 19 ++++++++----------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 68cb18ae18..170006ed2b 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -282,7 +282,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar orig_axcodes = nib.orientations.aff2axcodes(orig_affine) inverse_transform = Orientation( axcodes=orig_axcodes, - as_closest_canonical=self.ornt_transform.as_closest_canonical, + as_closest_canonical=False, labels=self.ornt_transform.labels, ) # Apply inverse diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 53c95f11b7..0c29ea7b08 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -249,16 +249,15 @@ ) ) -TESTS.append( - ( - "Orientationd 3d", - "3D", - 0, - # For data loader, output needs to be same size, so input must be square/cubic - SpatialPadd(KEYS, 110), - Orientationd(KEYS, "RAS"), +for acc in [True, False]: 
+ TESTS.append( + ( + "Orientationd 3d", + "3D", + 0, + Orientationd(KEYS, "RAS", as_closest_canonical=acc), + ) ) -) TESTS.append( ( @@ -283,8 +282,6 @@ "RandRotate90d 3d", "3D", 0, - # For data loader, output needs to be same size, so input must be square/cubic - SpatialPadd(KEYS, 110), RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)), ) ) From 070ae46901fbfbab7d4afe568e61c8562e4c8686 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Mar 2021 10:28:20 +0000 Subject: [PATCH 47/64] update after merge Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 1 - tests/test_inverse.py | 8 ++------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index eda9a24776..d12992e95d 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -53,7 +53,6 @@ optional_import, ) from monai.utils.enums import InverseKeys -from monai.utils.module import optional_import nib, _ = optional_import("nibabel") diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 38f60b04dd..7003386b23 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -30,8 +30,6 @@ CropForegroundd, DivisiblePadd, Flipd, - CropForegroundd, - DivisiblePadd, InvertibleTransform, LoadImaged, Orientationd, @@ -45,19 +43,17 @@ RandSpatialCropd, RandZoomd, Resized, - RandSpatialCropd, ResizeWithPadOrCrop, ResizeWithPadOrCropd, Rotate90d, Rotated, Spacingd, SpatialCropd, - ResizeWithPadOrCropd, - SpatialCropd, SpatialPadd, Zoomd, allow_missing_keys_mode, ) +from monai.transforms.transform import Randomizable from monai.utils import first, get_seed, optional_import, set_determinism from monai.utils.enums import InverseKeys from tests.utils import make_nifti_image, make_rand_affine, test_is_quick @@ -417,7 +413,7 @@ TESTS.append( ( "Zoomd 1d", - "1D", + "1D odd", 0, Zoomd(KEYS, zoom=2, keep_size=False), ) From e4b80e3b8e09ce7a68e1ea13334113655b2a35c3 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Mar 2021 10:52:56 +0000 Subject: [PATCH 48/64] test fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 7003386b23..53e6396262 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -342,7 +342,6 @@ ) ) - TESTS.append( ( "Rotated 2d", @@ -484,7 +483,7 @@ ) ) -if not test_is_quick and has_vtk: +if not test_is_quick() and has_vtk: TESTS.append( ( "Rand3DElasticd 3d", From 9907eeb3500a7216b4ce0a116d72ed48a9c63888 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Mar 2021 11:31:17 +0000 Subject: [PATCH 49/64] Zoomd and RandZoomd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 54 +++++++++++++++++++++++++- tests/test_inverse.py | 32 +++++++++++++++ 2 files changed, 84 insertions(+), 2 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 170006ed2b..42163c4ba6 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -1167,7 +1167,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class 
Zoomd(MapTransform): +class Zoomd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Zoom`. @@ -1213,6 +1213,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): + self.push_transform(d, key) d[key] = self.zoomer( d[key], mode=mode, @@ -1221,8 +1222,31 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key, mode, padding_mode, align_corners in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners + ): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + zoom = np.array(self.zoomer.zoom) + inverse_transform = Zoom(zoom=1 / zoom, keep_size=self.zoomer.keep_size) + # Apply inverse + d[key] = inverse_transform( + d[key], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + # Size might be out by 1 voxel so pad + d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE.value])(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + + return d + -class RandZoomd(RandomizableTransform, MapTransform): +class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform): """ Dict-based version :py:class:`monai.transforms.RandZoom`. @@ -1290,6 +1314,8 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.randomize() d = dict(data) if not self._do_transform: + for key in self.keys: + self.push_transform(d, key, extra_info={"zoom": self._zoom}) return d img_dims = data[self.keys[0]].ndim @@ -1303,6 +1329,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): + self.push_transform(d, key, extra_info={"zoom": self._zoom}) d[key] = zoomer( d[key], mode=mode, @@ -1311,6 +1338,29 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key, mode, padding_mode, align_corners in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners + ): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + zoom = np.array(transform[InverseKeys.EXTRA_INFO.value]["zoom"]) + inverse_transform = Zoom(zoom=1 / zoom, keep_size=self.keep_size) + # Apply inverse + d[key] = inverse_transform( + d[key], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + # Size might be out by 1 voxel so pad + d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE.value])(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + + return d + SpacingD = SpacingDict = Spacingd OrientationD = OrientationDict = Orientationd diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 0c29ea7b08..d4e85b7c5a 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -37,11 +37,13 @@ Randomizable, RandRotate90d, RandSpatialCropd, + RandZoomd, ResizeWithPadOrCrop, ResizeWithPadOrCropd, Rotate90d, SpatialCropd, SpatialPadd, + Zoomd, allow_missing_keys_mode, ) from monai.utils import first, get_seed, optional_import, set_determinism @@ -286,6 +288,36 @@ ) ) + +TESTS.append( + ( + 
"Zoomd 1d", + "1D odd", + 0, + Zoomd(KEYS, zoom=2, keep_size=False), + ) +) + +TESTS.append( + ( + "Zoomd 2d", + "2D", + 2e-1, + Zoomd(KEYS, zoom=0.9), + ) +) + +TESTS.append( + ( + "Zoomd 3d", + "3D", + 3e-2, + Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False), + ) +) + +TESTS.append(("RandZoom 3d", "3D", 9e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True))) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore From ae2e60ff4a12bc7861ae59063de71782b12d4406 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Mar 2021 11:35:14 +0000 Subject: [PATCH 50/64] add SpatialPad Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 42163c4ba6..98ad3c9478 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -23,7 +23,7 @@ from monai.config import DtypeLike, KeysCollection from monai.networks.layers.simplelayers import GaussianFilter -from monai.transforms.croppad.array import CenterSpatialCrop +from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad from monai.transforms.inverse import InvertibleTransform from monai.transforms.spatial.array import ( Affine, From 797a94e03fe173a42639b6614ae8f56389144c7a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Mar 2021 11:47:46 +0000 Subject: [PATCH 51/64] Inverse Spacingd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 2 +- monai/transforms/spatial/dictionary.py | 39 ++++++++++++++++++++++++-- tests/test_inverse.py | 3 ++ tests/test_spacingd.py | 14 ++++++--- 4 files changed, 50 insertions(+), 8 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 33b8da3ebb..8dd2692c2d 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -151,7 +151,7 @@ def __call__( ValueError: When ``pixdim`` is nonpositive. Returns: - data_array (resampled into `self.pixdim`), original pixdim, current pixdim. + data_array (resampled into `self.pixdim`), original affine, current affine. """ _dtype = dtype or self.dtype or data_array.dtype diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 170006ed2b..dd385bfc6e 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -111,7 +111,7 @@ NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str] -class Spacingd(MapTransform): +class Spacingd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Spacing`. 
@@ -194,10 +194,11 @@ def __call__( for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): - meta_data = d[f"{key}_{self.meta_key_postfix}"] + meta_data_key = f"{key}_{self.meta_key_postfix}" + meta_data = d[meta_data_key] # resample array of each corresponding key # using affine fetched from d[affine_key] - d[key], _, new_affine = self.spacing_transform( + d[key], old_affine, new_affine = self.spacing_transform( data_array=np.asarray(d[key]), affine=meta_data["affine"], mode=mode, @@ -205,10 +206,42 @@ def __call__( align_corners=align_corners, dtype=dtype, ) + self.push_transform(d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine}) # set the 'affine' key meta_data["affine"] = new_affine return d + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key, mode, padding_mode, align_corners, dtype in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners, self.dtype + ): + transform = self.get_most_recent_transform(d, key) + if self.spacing_transform.diagonal: + raise RuntimeError( + "Spacingd:inverse not yet implemented for diagonal=True. " + + "Please raise a github issue if you need this feature" + ) + # Create inverse transform + meta_data = d[transform[InverseKeys.EXTRA_INFO.value]["meta_data_key"]] + old_affine = np.array(transform[InverseKeys.EXTRA_INFO.value]["old_affine"]) + orig_pixdim = np.sqrt(np.sum(np.square(old_affine), 0))[:-1] + inverse_transform = Spacing(orig_pixdim, diagonal=self.spacing_transform.diagonal) + # Apply inverse + d[key], _, new_affine = inverse_transform( + data_array=np.asarray(d[key]), + affine=meta_data["affine"], + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=dtype, + ) + meta_data["affine"] = new_affine + # Remove the applied transform + self.pop_transform(d, key) + + return d + class Orientationd(MapTransform, InvertibleTransform): """ diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 0c29ea7b08..204e2da723 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -40,6 +40,7 @@ ResizeWithPadOrCrop, ResizeWithPadOrCropd, Rotate90d, + Spacingd, SpatialCropd, SpatialPadd, allow_missing_keys_mode, @@ -286,6 +287,8 @@ ) ) +TESTS.append(("Spacingd 3d", "3D", 3e-2, Spacingd(KEYS, [0.5, 0.7, 0.9], diagonal=False))) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore diff --git a/tests/test_spacingd.py b/tests/test_spacingd.py index ec32563543..e4efe4241d 100644 --- a/tests/test_spacingd.py +++ b/tests/test_spacingd.py @@ -21,7 +21,7 @@ def test_spacingd_3d(self): data = {"image": np.ones((2, 10, 15, 20)), "image_meta_dict": {"affine": np.eye(4)}} spacing = Spacingd(keys="image", pixdim=(1, 2, 1.4)) res = spacing(data) - self.assertEqual(("image", "image_meta_dict"), tuple(sorted(res))) + self.assertEqual(("image", "image_meta_dict", "image_transforms"), tuple(sorted(res))) np.testing.assert_allclose(res["image"].shape, (2, 10, 8, 15)) np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag([1, 2, 1.4, 1.0])) @@ -29,7 +29,7 @@ def test_spacingd_2d(self): data = {"image": np.ones((2, 10, 20)), "image_meta_dict": {"affine": np.eye(3)}} spacing = Spacingd(keys="image", pixdim=(1, 2, 1.4)) res = spacing(data) - self.assertEqual(("image", "image_meta_dict"), tuple(sorted(res))) + self.assertEqual(("image", 
"image_meta_dict", "image_transforms"), tuple(sorted(res))) np.testing.assert_allclose(res["image"].shape, (2, 10, 10)) np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 2, 1))) @@ -49,7 +49,10 @@ def test_interp_all(self): ), ) res = spacing(data) - self.assertEqual(("image", "image_meta_dict", "seg", "seg_meta_dict"), tuple(sorted(res))) + self.assertEqual( + ("image", "image_meta_dict", "image_transforms", "seg", "seg_meta_dict", "seg_transforms"), + tuple(sorted(res)), + ) np.testing.assert_allclose(res["image"].shape, (2, 1, 46)) np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 0.2, 1, 1))) @@ -69,7 +72,10 @@ def test_interp_sep(self): ), ) res = spacing(data) - self.assertEqual(("image", "image_meta_dict", "seg", "seg_meta_dict"), tuple(sorted(res))) + self.assertEqual( + ("image", "image_meta_dict", "image_transforms", "seg", "seg_meta_dict", "seg_transforms"), + tuple(sorted(res)), + ) np.testing.assert_allclose(res["image"].shape, (2, 1, 46)) np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 0.2, 1, 1))) From 08165fb832e82808c53bd25ea9b195c9de2e6b74 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Mar 2021 11:54:24 +0000 Subject: [PATCH 52/64] Inverse Resized Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 2 +- monai/transforms/spatial/dictionary.py | 17 ++++++++++++++++- tests/test_inverse.py | 5 +++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 33b8da3ebb..4841a75397 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -317,7 +317,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: class Resize(Transform): """ - Resize the input image to given spatial size. + Resize the input image to given spatial size (with scaling, not cropping/padding). Implemented using :py:class:`torch.nn.functional.interpolate`. Args: diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 170006ed2b..2f869132b8 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -410,7 +410,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d -class Resized(MapTransform): +class Resized(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Resize`. 
@@ -448,9 +448,24 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): + self.push_transform(d, key) d[key] = self.resizer(d[key], mode=mode, align_corners=align_corners) return d + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): + transform = self.get_most_recent_transform(d, key) + orig_size = transform[InverseKeys.ORIG_SIZE.value] + # Create inverse transform + inverse_transform = Resize(orig_size, mode, align_corners) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + + return d + class Affined(RandomizableTransform, MapTransform): """ diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 0c29ea7b08..ffdaa84293 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -37,6 +37,7 @@ Randomizable, RandRotate90d, RandSpatialCropd, + Resized, ResizeWithPadOrCrop, ResizeWithPadOrCropd, Rotate90d, @@ -286,6 +287,10 @@ ) ) +TESTS.append(("Resized 2d", "2D", 2e-1, Resized(KEYS, [50, 47]))) + +TESTS.append(("Resized 3d", "3D", 5e-2, Resized(KEYS, [201, 150, 78]))) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore From f22f7c00f0ca26dc7ed15bc7f39cf0a927ee8adb Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Mar 2021 12:11:25 +0000 Subject: [PATCH 53/64] inverse RandAffined Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 73 ++++++++++++++++++-------- monai/transforms/spatial/dictionary.py | 33 ++++++++++-- tests/test_inverse.py | 19 +++++++ tests/test_rand_affined.py | 2 + 4 files changed, 102 insertions(+), 25 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 33b8da3ebb..a916880fc7 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -925,6 +925,9 @@ class AffineGrid(Transform): as_tensor_output: whether to output tensor instead of numpy array. defaults to True. device: device to store the output grid data. + affine: If applied, ignore the params (`rotate_params`, etc.) and use the + supplied matrix. Should be square with each side = num of image spatial + dimensions + 1. """ @@ -936,6 +939,7 @@ def __init__( scale_params: Optional[Union[Sequence[float], float]] = None, as_tensor_output: bool = True, device: Optional[torch.device] = None, + affine: Optional[Union[np.ndarray, torch.Tensor]] = None, ) -> None: self.rotate_params = rotate_params self.shear_params = shear_params @@ -945,13 +949,19 @@ def __init__( self.as_tensor_output = as_tensor_output self.device = device + self.affine = affine + def __call__( - self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None - ) -> Union[np.ndarray, torch.Tensor]: + self, + spatial_size: Optional[Sequence[int]] = None, + grid: Optional[Union[np.ndarray, torch.Tensor]] = None, + return_affine: bool = False, + ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: """ Args: spatial_size: output grid size. grid: grid to be transformed. 
Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. + return_affine: boolean as to whether to return the generated affine matrix or not. Raises: ValueError: When ``grid=None`` and ``spatial_size=None``. Incompatible values. @@ -963,16 +973,20 @@ def __call__( else: raise ValueError("Incompatible values: grid=None and spatial_size=None.") - spatial_dims = len(grid.shape) - 1 - affine = np.eye(spatial_dims + 1) - if self.rotate_params: - affine = affine @ create_rotate(spatial_dims, self.rotate_params) - if self.shear_params: - affine = affine @ create_shear(spatial_dims, self.shear_params) - if self.translate_params: - affine = affine @ create_translate(spatial_dims, self.translate_params) - if self.scale_params: - affine = affine @ create_scale(spatial_dims, self.scale_params) + affine: Union[np.ndarray, torch.Tensor] + if self.affine is None: + spatial_dims = len(grid.shape) - 1 + affine = np.eye(spatial_dims + 1) + if self.rotate_params: + affine = affine @ create_rotate(spatial_dims, self.rotate_params) + if self.shear_params: + affine = affine @ create_shear(spatial_dims, self.shear_params) + if self.translate_params: + affine = affine @ create_translate(spatial_dims, self.translate_params) + if self.scale_params: + affine = affine @ create_scale(spatial_dims, self.scale_params) + else: + affine = self.affine affine = torch.as_tensor(np.ascontiguousarray(affine), device=self.device) grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() @@ -981,9 +995,10 @@ def __call__( grid = (affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:])) if grid is None or not isinstance(grid, torch.Tensor): raise ValueError("Unknown grid.") - if self.as_tensor_output: - return grid - return np.asarray(grid.cpu().numpy()) + output: Union[np.ndarray, torch.Tensor] = grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()) + if return_affine: + return output, affine + return output class RandAffineGrid(RandomizableTransform): @@ -1053,12 +1068,16 @@ def randomize(self, data: Optional[Any] = None) -> None: self.scale_params = self._get_rand_param(self.scale_range, 1.0) def __call__( - self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None - ) -> Union[np.ndarray, torch.Tensor]: + self, + spatial_size: Optional[Sequence[int]] = None, + grid: Optional[Union[np.ndarray, torch.Tensor]] = None, + return_affine: bool = False, + ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: """ Args: spatial_size: output grid size. grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. + return_affine: boolean as to whether to return the generated affine matrix or not. Returns: a 2D (3xHxW) or 3D (4xHxWxD) grid. 
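These two options are what the dictionary-level inverses build on: the forward pass asks the grid for the affine it actually applied, and the inverse feeds the inverted matrix straight back into a fresh `AffineGrid`. A standalone sketch, not part of the patch (the rotation/scale values and the 2D grid size are illustrative):

    import numpy as np
    from monai.transforms.spatial.array import AffineGrid

    fwd_grid_gen = AffineGrid(rotate_params=[np.pi / 6], scale_params=[1.2, 0.8], as_tensor_output=False)
    grid, affine = fwd_grid_gen(spatial_size=(32, 32), return_affine=True)

    # build the inverse grid directly from the returned matrix
    inv_grid_gen = AffineGrid(affine=np.linalg.inv(np.asarray(affine)), as_tensor_output=False)
    inv_grid = inv_grid_gen(spatial_size=(32, 32))
    print(grid.shape, inv_grid.shape)  # (3, 32, 32) (3, 32, 32)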
@@ -1072,7 +1091,7 @@ def __call__( as_tensor_output=self.as_tensor_output, device=self.device, ) - return affine_grid(spatial_size, grid) + return affine_grid(spatial_size, grid, return_affine) class RandDeformGrid(RandomizableTransform): @@ -1298,7 +1317,7 @@ def __call__( See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) - grid = self.affine_grid(spatial_size=sp_size) + grid: torch.Tensor = self.affine_grid(spatial_size=sp_size) # type: ignore return self.resampler( img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode ) @@ -1389,7 +1408,8 @@ def __call__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + return_affine: bool = False, + ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1404,17 +1424,26 @@ def __call__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + return_affine: boolean as to whether to return the generated affine matrix or not. """ self.randomize() sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) + affine = np.eye(len(sp_size) + 1) if self._do_transform: - grid = self.rand_affine_grid(spatial_size=sp_size) + out = self.rand_affine_grid(spatial_size=sp_size) + if return_affine: + grid, affine = out + else: + grid = out else: grid = create_grid(spatial_size=sp_size) - return self.resampler( + resampled = self.resampler( img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode ) + if return_affine: + return resampled, affine + return resampled class Rand2DElastic(RandomizableTransform): diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 170006ed2b..3ef30c7dc6 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -22,11 +22,13 @@ import torch from monai.config import DtypeLike, KeysCollection +from monai.networks.layers import AffineTransform from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.croppad.array import CenterSpatialCrop from monai.transforms.inverse import InvertibleTransform from monai.transforms.spatial.array import ( Affine, + AffineGrid, Flip, Orientation, Rand2DElastic, @@ -525,7 +527,7 @@ def __call__( return d -class RandAffined(RandomizableTransform, MapTransform): +class RandAffined(RandomizableTransform, MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.RandAffine`. 
""" @@ -617,16 +619,41 @@ def __call__( sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:]) if self._do_transform: - grid = self.rand_affine.rand_affine_grid(spatial_size=sp_size) + grid, affine = self.rand_affine.rand_affine_grid(spatial_size=sp_size, return_affine=True) else: grid = create_grid(spatial_size=sp_size) + affine = np.eye(len(sp_size) + 1) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): + self.push_transform(d, key, extra_info={"affine": affine}) d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) return d + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): + transform = self.get_most_recent_transform(d, key) + orig_size = transform[InverseKeys.ORIG_SIZE.value] + # Create inverse transform + fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"] + inv_affine = np.linalg.inv(fwd_affine) + + affine_grid = AffineGrid(affine=inv_affine) + grid: torch.Tensor = affine_grid(orig_size) # type: ignore + + # Apply inverse transform + out = self.rand_affine.resampler(d[key], grid, mode, padding_mode) + + # Convert to numpy + d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy() + + # Remove the applied transform + self.pop_transform(d, key) + + return d + -class Rand2DElasticd(RandomizableTransform, MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rand2DElastic`. """ diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 0c29ea7b08..b13803a769 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -32,6 +32,7 @@ InvertibleTransform, LoadImaged, Orientationd, + RandAffined, RandAxisFlipd, RandFlipd, Randomizable, @@ -286,6 +287,24 @@ ) ) + +TESTS.append( + ( + "RandAffine 3d", + "3D", + 1e-1, + RandAffined( + KEYS, + [155, 179, 192], + prob=1, + padding_mode="zeros", + rotate_range=[np.pi / 6, -np.pi / 5, np.pi / 7], + shear_range=[(0.5, 0.5)], + translate_range=[10, 5, -4], + scale_range=[(0.8, 1.2), (0.9, 1.3)], + ), + ) +) TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py index 54d71ad8f7..ae2adbe3b3 100644 --- a/tests/test_rand_affined.py +++ b/tests/test_rand_affined.py @@ -145,6 +145,8 @@ def test_rand_affined(self, input_param, input_data, expected_val): res = g(input_data) for key in res: result = res[key] + if "_transforms" in key: + continue expected = expected_val[key] if isinstance(expected_val, dict) else expected_val self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) if isinstance(result, torch.Tensor): From ff143290e6c485e406a7141a5687699d6054545a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Mar 2021 12:13:04 +0000 Subject: [PATCH 54/64] undo unintentional change Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 3ef30c7dc6..ff097bf8fc 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -653,7 +653,7 @@ def inverse(self, data: Mapping[Hashable, 
np.ndarray]) -> Dict[Hashable, np.ndar return d - +class Rand2DElasticd(RandomizableTransform, MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rand2DElastic`. """ From 3b985b3a9f4703c6f494924c6c03beb834d37d77 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Mar 2021 12:49:38 +0000 Subject: [PATCH 55/64] inverse Affined Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 13 +++++++++-- monai/transforms/spatial/dictionary.py | 32 +++++++++++++++++++++++--- tests/test_inverse.py | 16 +++++++++++++ 3 files changed, 56 insertions(+), 5 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index a916880fc7..3ca3f9cdc9 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1300,6 +1300,7 @@ def __call__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, + return_affine: bool = False, ) -> Union[np.ndarray, torch.Tensor]: """ Args: @@ -1315,12 +1316,20 @@ def __call__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + return_affine: boolean as to whether to return the generated affine matrix or not. """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) - grid: torch.Tensor = self.affine_grid(spatial_size=sp_size) # type: ignore - return self.resampler( + out = self.affine_grid(spatial_size=sp_size, return_affine=return_affine) + if return_affine: + grid, affine = out + else: + grid = out + resampled = self.resampler( img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode ) + if return_affine: + return resampled, affine + return resampled class RandAffine(RandomizableTransform): diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index ff097bf8fc..af0f8612cd 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -22,7 +22,6 @@ import torch from monai.config import DtypeLike, KeysCollection -from monai.networks.layers import AffineTransform from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.croppad.array import CenterSpatialCrop from monai.transforms.inverse import InvertibleTransform @@ -454,7 +453,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class Affined(RandomizableTransform, MapTransform): +class Affined(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Affine`. 
""" @@ -523,7 +522,33 @@ def __call__( ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): - d[key] = self.affine(d[key], mode=mode, padding_mode=padding_mode) + orig_size = d[key].shape[1:] + d[key], affine = self.affine(d[key], mode=mode, padding_mode=padding_mode, return_affine=True) + self.push_transform(d, key, orig_size=orig_size, extra_info={"affine": affine}) + return d + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): + transform = self.get_most_recent_transform(d, key) + orig_size = transform[InverseKeys.ORIG_SIZE.value] + # Create inverse transform + fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"] + inv_affine = np.linalg.inv(fwd_affine) + + affine_grid = AffineGrid(affine=inv_affine) + grid: torch.Tensor = affine_grid(orig_size) # type: ignore + + # Apply inverse transform + out = self.affine.resampler(d[key], grid, mode, padding_mode) + + # Convert to numpy + d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy() + + # Remove the applied transform + self.pop_transform(d, key) + return d @@ -653,6 +678,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d + class Rand2DElasticd(RandomizableTransform, MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rand2DElastic`. diff --git a/tests/test_inverse.py b/tests/test_inverse.py index b13803a769..f3b28a5642 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -23,6 +23,7 @@ from monai.networks.nets import UNet from monai.transforms import ( AddChanneld, + Affined, BorderPadd, CenterSpatialCropd, Compose, @@ -287,6 +288,21 @@ ) ) +TESTS.append( + ( + "Affine 3d", + "3D", + 1e-1, + Affined( + KEYS, + spatial_size=[155, 179, 192], + rotate_params=[np.pi / 6, -np.pi / 5, np.pi / 7], + shear_params=[0.5, 0.5], + translate_params=[10, 5, -4], + scale_params=[0.8, 1.3], + ), + ) +) TESTS.append( ( From 62bbb21a729a11955d4bc6299b153935571b4e39 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Mar 2021 12:55:31 +0000 Subject: [PATCH 56/64] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 3ca3f9cdc9..59f9f4ddcc 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1301,7 +1301,7 @@ def __call__( mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, return_affine: bool = False, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: """ Args: img: shape must be (num_channels, H, W[, D]), From 42e81047f733f7bc35afe047e6ac832a3667edd0 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Mar 2021 13:23:10 +0000 Subject: [PATCH 57/64] move affine Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 
845ae96bf1..28808a1ff4 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -445,6 +445,22 @@ TESTS.append(("Resized 3d", "3D", 5e-2, Resized(KEYS, [201, 150, 78]))) +TESTS.append( + ( + "Affine 3d", + "3D", + 1e-1, + Affined( + KEYS, + spatial_size=[155, 179, 192], + rotate_params=[np.pi / 6, -np.pi / 5, np.pi / 7], + shear_params=[0.5, 0.5], + translate_params=[10, 5, -4], + scale_params=[0.8, 1.3], + ), + ) +) + TESTS.append( ( "RandAffine 3d", @@ -504,22 +520,6 @@ ) ) -TESTS.append( - ( - "Affine 3d", - "3D", - 1e-1, - Affined( - KEYS, - spatial_size=[155, 179, 192], - rotate_params=[np.pi / 6, -np.pi / 5, np.pi / 7], - shear_params=[0.5, 0.5], - translate_params=[10, 5, -4], - scale_params=[0.8, 1.3], - ), - ) -) - TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore From 384013aa1f02531030ac1776f9658efb9ddb8efd Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Mar 2021 14:07:12 +0000 Subject: [PATCH 58/64] remove duplicate tests Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 85 ------------------------------------------- 1 file changed, 85 deletions(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 28808a1ff4..9603784898 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -234,15 +234,6 @@ ) ) -TESTS.append( - ( - "Flipd 3d", - "3D", - 0, - Flipd(KEYS, [1, 2]), - ) -) - TESTS.append( ( "RandFlipd 3d", @@ -307,42 +298,6 @@ ) ) -TESTS.append( - ( - "Flipd 3d", - "3D", - 0, - Flipd(KEYS, [1, 2]), - ) -) - -TESTS.append( - ( - "Flipd 3d", - "3D", - 0, - Flipd(KEYS, [1, 2]), - ) -) - -TESTS.append( - ( - "RandFlipd 3d", - "3D", - 0, - RandFlipd(KEYS, 1, [1, 2]), - ) -) - -TESTS.append( - ( - "RandAxisFlipd 3d", - "3D", - 0, - RandAxisFlipd(KEYS, 1), - ) -) - TESTS.append( ( "Rotated 2d", @@ -370,46 +325,6 @@ ) ) -TESTS.append( - ( - "Orientationd 3d", - "3D", - 0, - # For data loader, output needs to be same size, so input must be square/cubic - SpatialPadd(KEYS, 110), - Orientationd(KEYS, "RAS"), - ) -) - -TESTS.append( - ( - "Rotate90d 2d", - "2D", - 0, - Rotate90d(KEYS), - ) -) - -TESTS.append( - ( - "Rotate90d 3d", - "3D", - 0, - Rotate90d(KEYS, k=2, spatial_axes=(1, 2)), - ) -) - -TESTS.append( - ( - "RandRotate90d 3d", - "3D", - 0, - # For data loader, output needs to be same size, so input must be square/cubic - SpatialPadd(KEYS, 110), - RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)), - ) -) - TESTS.append( ( "Zoomd 1d", From f33a7318c523a45d1f43dcc5c4b13c79ec2ca350 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Mar 2021 14:26:10 +0000 Subject: [PATCH 59/64] Inverse Rotated and RandRotated Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 2 +- monai/transforms/spatial/dictionary.py | 78 ++++++++++++++++++++++++-- tests/test_inverse.py | 38 +++++++++++++ 3 files changed, 112 insertions(+), 6 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 33b8da3ebb..455955c645 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -743,7 +743,7 @@ def __call__( align_corners=self.align_corners if align_corners is None else align_corners, dtype=dtype or self.dtype or img.dtype, ) - return rotator(img) + return np.array(rotator(img)) class 
RandFlip(RandomizableTransform): diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 170006ed2b..7c3c90a05e 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -22,6 +22,7 @@ import torch from monai.config import DtypeLike, KeysCollection +from monai.networks.layers import AffineTransform from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.croppad.array import CenterSpatialCrop from monai.transforms.inverse import InvertibleTransform @@ -1006,7 +1007,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d -class Rotated(MapTransform): +class Rotated(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rotate`. @@ -1058,17 +1059,48 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): - d[key] = self.rotator( + orig_size = d[key].shape[1:] + d[key], rot_mat = self.rotator( d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, + return_rotationinver_matrix=True, ) + self.push_transform(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) + return d + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key, mode, padding_mode, align_corners, dtype in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners, self.dtype + ): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + fwd_rot_mat = transform[InverseKeys.EXTRA_INFO.value]["rot_mat"] + inv_rot_mat = np.linalg.inv(fwd_rot_mat) + + xform = AffineTransform( + normalized=False, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + reverse_indexing=True, + ) + output = xform( + torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), + torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), + spatial_size=transform[InverseKeys.ORIG_SIZE.value], + ) + d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + # Remove the applied transform + self.pop_transform(d, key) + return d -class RandRotated(RandomizableTransform, MapTransform): +class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRotate` Randomly rotates the input arrays. 
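
The Rotated changes above — and the RandRotated hunks that follow — use the same record-and-invert pattern as the rest of this series: the forward pass pushes the applied rotation matrix and the original spatial size onto the per-key transform list, and inverse() resamples through AffineTransform with np.linalg.inv(rot_mat). A minimal round-trip sketch, assuming the complete patch series is applied (the return_rotation_matrix plumbing and the typo fix land in later commits); the key name, array shape and angle are illustrative only:

    import numpy as np
    from monai.transforms import Rotated

    rotate = Rotated(keys="image", angle=np.pi / 6, keep_size=True)
    data = {"image": np.random.rand(1, 64, 64).astype(np.float32)}

    rotated = rotate(data)              # records {"rot_mat": ...} plus the original spatial size
    restored = rotate.inverse(rotated)  # applies np.linalg.inv(rot_mat) back onto that size

    assert restored["image"].shape == data["image"].shape
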
@@ -1149,21 +1181,57 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.randomize() d = dict(data) if not self._do_transform: + for key in self.keys: + self.push_transform(d, key, extra_info={"rot_mat": np.eye(4)}) return d + angle: Union[Sequence[float], float] = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z) rotator = Rotate( - angle=self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z), + angle=angle, keep_size=self.keep_size, ) for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): - d[key] = rotator( + orig_size = d[key].shape[1:] + d[key], rot_mat = rotator( d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, + return_rotation_matrix=True, ) + self.push_transform(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) + return d + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key, mode, padding_mode, align_corners, dtype in self.key_iterator( + d, self.mode, self.padding_mode, self.align_corners, self.dtype + ): + transform = self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform[InverseKeys.DO_TRANSFORM.value]: + # Create inverse transform + fwd_rot_mat = transform[InverseKeys.EXTRA_INFO.value]["rot_mat"] + inv_rot_mat = np.linalg.inv(fwd_rot_mat) + + xform = AffineTransform( + normalized=False, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + reverse_indexing=True, + ) + output = xform( + torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), + torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), + spatial_size=transform[InverseKeys.ORIG_SIZE.value], + ) + d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + # Remove the applied transform + self.pop_transform(d, key) + return d diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 0c29ea7b08..43bbb8a043 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -36,10 +36,12 @@ RandFlipd, Randomizable, RandRotate90d, + RandRotated, RandSpatialCropd, ResizeWithPadOrCrop, ResizeWithPadOrCropd, Rotate90d, + Rotated, SpatialCropd, SpatialPadd, allow_missing_keys_mode, @@ -286,6 +288,42 @@ ) ) +TESTS.append( + ( + "RandRotated, prob 0", + "2D", + 0, + RandRotated(KEYS, prob=0), + ) +) + +TESTS.append( + ( + "Rotated 2d", + "2D", + 8e-2, + Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False), + ) +) + +TESTS.append( + ( + "Rotated 3d", + "3D", + 1e-1, + Rotated(KEYS, [random.uniform(np.pi / 6, np.pi) for _ in range(3)], True), # type: ignore + ) +) + +TESTS.append( + ( + "RandRotated 3d", + "3D", + 1e-1, + RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1), # type: ignore + ) +) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore From 100bc7dd82b68d66bc89f7eed7c471b70f45768d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Mar 2021 14:28:26 +0000 Subject: [PATCH 60/64] add random Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 43bbb8a043..a935b9a3c7 
100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import random import sys import unittest from functools import partial From 067e1616b7c6d560ee43bf23044dc83740116dbd Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Mar 2021 14:29:20 +0000 Subject: [PATCH 61/64] add return rotation matrix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 455955c645..e6ecf6c44b 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -429,7 +429,8 @@ def __call__( padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, dtype: DtypeLike = None, - ) -> np.ndarray: + return_rotation_matrix: bool = False, + ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: """ Args: img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D]. @@ -446,6 +447,7 @@ def __call__( dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. + return_rotation_matrix: whether or not to return the applied rotation matrix. Raises: ValueError: When ``img`` spatially is not one of [2D, 3D]. @@ -482,7 +484,10 @@ def __call__( torch.as_tensor(np.ascontiguousarray(transform).astype(_dtype)), spatial_size=output_shape, ) - return np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + output_np = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + if return_rotation_matrix: + return output_np, transform + return output_np class Zoom(Transform): From 6403f93b7e29cc124b9da568fd54ac02f83a4715 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Mar 2021 14:31:04 +0000 Subject: [PATCH 62/64] remove typo Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 7c3c90a05e..9f3bb32067 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -1066,7 +1066,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, - return_rotationinver_matrix=True, + return_rotation_matrix=True, ) self.push_transform(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) return d From a3f50e216b2389d2d2a0ae68186c782179d696e5 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 17 Mar 2021 16:27:04 +0000 Subject: [PATCH 63/64] remove return_affine Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 62 ++++++++++---------------- monai/transforms/spatial/dictionary.py | 6 ++- tests/test_inverse.py | 1 + 3 files changed, 28 insertions(+), 41 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index a263d0e43a..f22adecbd5 100644 --- 
a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -961,13 +961,11 @@ def __call__( self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - return_affine: bool = False, ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: """ Args: spatial_size: output grid size. grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. - return_affine: boolean as to whether to return the generated affine matrix or not. Raises: ValueError: When ``grid=None`` and ``spatial_size=None``. Incompatible values. @@ -979,7 +977,6 @@ def __call__( else: raise ValueError("Incompatible values: grid=None and spatial_size=None.") - affine: Union[np.ndarray, torch.Tensor] if self.affine is None: spatial_dims = len(grid.shape) - 1 affine = np.eye(spatial_dims + 1) @@ -991,20 +988,21 @@ def __call__( affine = affine @ create_translate(spatial_dims, self.translate_params) if self.scale_params: affine = affine @ create_scale(spatial_dims, self.scale_params) - else: - affine = self.affine - affine = torch.as_tensor(np.ascontiguousarray(affine), device=self.device) + self.affine = affine + + self.affine = torch.as_tensor(np.ascontiguousarray(self.affine), device=self.device) grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() if self.device: grid = grid.to(self.device) - grid = (affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:])) + grid = (self.affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:])) if grid is None or not isinstance(grid, torch.Tensor): raise ValueError("Unknown grid.") - output: Union[np.ndarray, torch.Tensor] = grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()) - if return_affine: - return output, affine - return output + return grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()) + + def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]: + """Get the most recently applied transformation matrix""" + return self.affine class RandAffineGrid(RandomizableTransform): @@ -1055,6 +1053,7 @@ def __init__( self.as_tensor_output = as_tensor_output self.device = device + self.affine: Optional[Union[np.ndarray, torch.Tensor]] = None def _get_rand_param(self, param_range, add_scalar: float = 0.0): out_param = [] @@ -1077,13 +1076,11 @@ def __call__( self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - return_affine: bool = False, ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: """ Args: spatial_size: output grid size. grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. - return_affine: boolean as to whether to return the generated affine matrix or not. Returns: a 2D (3xHxW) or 3D (4xHxWxD) grid. 
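
The AffineGrid refactor above drops the return_affine flag in favour of caching: __call__ now stores whichever homogeneous matrix it applied, and get_transformation_matrix() hands it back afterwards (the RandAffineGrid hunk below mirrors this by reading the matrix off its internal AffineGrid). A short sketch of the resulting calling pattern; the parameter values are illustrative, and it assumes AffineGrid is importable from monai.transforms as on this branch:

    import numpy as np
    from monai.transforms import AffineGrid

    affine_grid = AffineGrid(rotate_params=np.pi / 4, scale_params=(1.2, 1.2))
    grid = affine_grid(spatial_size=(32, 32))          # 2D grid of shape (3, 32, 32)
    affine = affine_grid.get_transformation_matrix()   # 3x3 homogeneous matrix that was just applied
    inv_affine = np.linalg.inv(np.asarray(affine))     # what the dictionary inverse() methods build their grid from
    inv_grid = AffineGrid(affine=inv_affine)(spatial_size=(32, 32))
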
@@ -1097,7 +1094,13 @@ def __call__( as_tensor_output=self.as_tensor_output, device=self.device, ) - return affine_grid(spatial_size, grid, return_affine) + grid = affine_grid(spatial_size, grid) + self.affine = affine_grid.get_transformation_matrix() + return grid + + def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]: + """Get the most recently applied transformation matrix""" + return self.affine class RandDeformGrid(RandomizableTransform): @@ -1306,8 +1309,7 @@ def __call__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - return_affine: bool = False, - ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: + ) -> Union[np.ndarray, torch.Tensor]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1322,20 +1324,12 @@ def __call__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - return_affine: boolean as to whether to return the generated affine matrix or not. """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) - out = self.affine_grid(spatial_size=sp_size, return_affine=return_affine) - if return_affine: - grid, affine = out - else: - grid = out - resampled = self.resampler( + grid = self.affine_grid(spatial_size=sp_size) + return self.resampler( img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode ) - if return_affine: - return resampled, affine - return resampled class RandAffine(RandomizableTransform): @@ -1423,8 +1417,7 @@ def __call__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - return_affine: bool = False, - ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: + ) -> Union[np.ndarray, torch.Tensor]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1439,26 +1432,17 @@ def __call__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - return_affine: boolean as to whether to return the generated affine matrix or not. 
""" self.randomize() sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) - affine = np.eye(len(sp_size) + 1) if self._do_transform: - out = self.rand_affine_grid(spatial_size=sp_size) - if return_affine: - grid, affine = out - else: - grid = out + grid = self.rand_affine_grid(spatial_size=sp_size) else: grid = create_grid(spatial_size=sp_size) - resampled = self.resampler( + return self.resampler( img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode ) - if return_affine: - return resampled, affine - return resampled class Rand2DElastic(RandomizableTransform): diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 560b39bf79..caa1a34e08 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -572,7 +572,8 @@ def __call__( d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): orig_size = d[key].shape[1:] - d[key], affine = self.affine(d[key], mode=mode, padding_mode=padding_mode, return_affine=True) + d[key] = self.affine(d[key], mode=mode, padding_mode=padding_mode) + affine = self.affine.affine_grid.get_transformation_matrix() self.push_transform(d, key, orig_size=orig_size, extra_info={"affine": affine}) return d @@ -693,7 +694,8 @@ def __call__( sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:]) if self._do_transform: - grid, affine = self.rand_affine.rand_affine_grid(spatial_size=sp_size, return_affine=True) + grid = self.rand_affine.rand_affine_grid(spatial_size=sp_size) + affine = self.rand_affine.rand_affine_grid.get_transformation_matrix() else: grid = create_grid(spatial_size=sp_size) affine = np.eye(len(sp_size) + 1) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 42954a6aea..c1225ea11c 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -400,6 +400,7 @@ ), ) ) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore From e2a1dac521442f05d3df324a38bd690156d24ce7 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 17 Mar 2021 16:35:07 +0000 Subject: [PATCH 64/64] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index f22adecbd5..de9bba8e95 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -961,7 +961,7 @@ def __call__( self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: + ) -> Union[np.ndarray, torch.Tensor]: """ Args: spatial_size: output grid size. @@ -1076,7 +1076,7 @@ def __call__( self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: + ) -> Union[np.ndarray, torch.Tensor]: """ Args: spatial_size: output grid size.