From 18cc88c39616ec70eed4d9d46371eb24004d3f1b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 22 Jan 2021 11:02:03 +0000 Subject: [PATCH 01/80] move Transfrom out of compose file Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/networks/layers/spatial_transforms.py | 5 ++ monai/transforms/__init__.py | 3 +- monai/transforms/compose.py | 51 +-------------------- 3 files changed, 9 insertions(+), 50 deletions(-) diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index c0f22502c8..4130ed0d4e 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -544,3 +544,8 @@ def forward( align_corners=self.align_corners, ) return dst + + def forward( + self, src: torch.Tensor, theta: torch.Tensor, spatial_size: Optional[Union[Sequence[int], int]] = None + ) -> torch.Tensor: + pass diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 9eaedd6b15..11a6d79935 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -10,7 +10,7 @@ # limitations under the License. from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs -from .compose import Compose, MapTransform, Randomizable, Transform +from .compose import Compose, MapTransform, Randomizable from .croppad.array import ( BorderPad, BoundingRect, @@ -234,6 +234,7 @@ ZoomD, ZoomDict, ) +from .transform import Transform from .utility.array import ( AddChannel, AddExtremePointsChannel, diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 3e23377b36..8c4f107f61 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -19,58 +19,11 @@ import numpy as np from monai.config import KeysCollection +from monai.transforms.transform import Transform from monai.transforms.utils import apply_transform from monai.utils import MAX_SEED, ensure_tuple, get_seed -__all__ = ["Transform", "Randomizable", "Compose", "MapTransform"] - - -class Transform(ABC): - """ - An abstract class of a ``Transform``. - A transform is callable that processes ``data``. - - It could be stateful and may modify ``data`` in place, - the implementation should be aware of: - - #. thread safety when mutating its own states. - When used from a multi-process context, transform's instance variables are read-only. - #. ``data`` content unused by this transform may still be used in the - subsequent transforms in a composed transform. - #. storing too much information in ``data`` may not scale. - - See Also - - :py:class:`monai.transforms.Compose` - """ - - @abstractmethod - def __call__(self, data: Any): - """ - ``data`` is an element which often comes from an iteration over an - iterable, such as :py:class:`torch.utils.data.Dataset`. This method should - return an updated version of ``data``. - To simplify the input validations, most of the transforms assume that - - - ``data`` is a Numpy ndarray, PyTorch Tensor or string - - the data shape can be: - - #. string data without shape, `LoadImage` transform expects file paths - #. most of the pre-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``, - except that `AddChannel` expects (spatial_dim_1[, spatial_dim_2, ...]) and - `AsChannelFirst` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels) - #. 
most of the post-processing transforms expect - ``(batch_size, num_channels, spatial_dim_1[, spatial_dim_2, ...])`` - - - the channel dimension is not omitted even if number of channels is one - - This method can optionally take additional arguments to help execute transformation operation. - - Raises: - NotImplementedError: When the subclass does not override this method. - - """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") +__all__ = ["Randomizable", "Compose", "MapTransform"] class Randomizable(ABC): From 18b53f533e8d9348352d54ecb37917c2f21acda4 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 22 Jan 2021 13:33:35 +0000 Subject: [PATCH 02/80] add transform file Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/transform.py | 76 +++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 monai/transforms/transform.py diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py new file mode 100644 index 0000000000..cd8ededda2 --- /dev/null +++ b/monai/transforms/transform.py @@ -0,0 +1,76 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A collection of generic interfaces for MONAI transforms. +""" + +from abc import ABC, abstractmethod +from typing import Any + +__all__ = ["Transform"] + + +class Transform(ABC): + """ + An abstract class of a ``Transform``. + A transform is callable that processes ``data``. + + It could be stateful and may modify ``data`` in place, + the implementation should be aware of: + + #. thread safety when mutating its own states. + When used from a multi-process context, transform's instance variables are read-only. + #. ``data`` content unused by this transform may still be used in the + subsequent transforms in a composed transform. + #. storing too much information in ``data`` may not scale. + + See Also + + :py:class:`monai.transforms.Compose` + """ + + @abstractmethod + def __call__(self, data: Any): + """ + ``data`` is an element which often comes from an iteration over an + iterable, such as :py:class:`torch.utils.data.Dataset`. This method should + return an updated version of ``data``. + To simplify the input validations, most of the transforms assume that + + - ``data`` is a Numpy ndarray, PyTorch Tensor or string + - the data shape can be: + + #. string data without shape, `LoadImage` transform expects file paths + #. most of the pre-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``, + except that `AddChannel` expects (spatial_dim_1[, spatial_dim_2, ...]) and + `AsChannelFirst` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels) + #. 
most of the post-processing transforms expect + ``(batch_size, num_channels, spatial_dim_1[, spatial_dim_2, ...])`` + + - the channel dimension is not omitted even if number of channels is one + + This method can optionally take additional arguments to help execute transformation operation. + + Raises: + NotImplementedError: When the subclass does not override this method. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + def inverse(self, data: Any): + """ + Inverse of ``__call__``. + + Raises: + NotImplementedError: When the subclass does not override this method. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") From 827dd84f3218875416e79e9f65f6b814f19ec08f Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 28 Jan 2021 13:57:03 +0000 Subject: [PATCH 03/80] inverse compose and spatialpadd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/networks/layers/spatial_transforms.py | 5 - monai/transforms/__init__.py | 4 +- monai/transforms/compose.py | 154 ++--------------- monai/transforms/croppad/dictionary.py | 28 +++- monai/transforms/intensity/dictionary.py | 2 +- monai/transforms/io/dictionary.py | 2 +- monai/transforms/post/dictionary.py | 2 +- monai/transforms/spatial/dictionary.py | 2 +- monai/transforms/transform.py | 174 +++++++++++++++++++- monai/transforms/utility/dictionary.py | 2 +- tests/test_inverse.py | 88 ++++++++++ 11 files changed, 307 insertions(+), 156 deletions(-) create mode 100644 tests/test_inverse.py diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index 4130ed0d4e..c0f22502c8 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -544,8 +544,3 @@ def forward( align_corners=self.align_corners, ) return dst - - def forward( - self, src: torch.Tensor, theta: torch.Tensor, spatial_size: Optional[Union[Sequence[int], int]] = None - ) -> torch.Tensor: - pass diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 11a6d79935..50730b28e0 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -10,7 +10,7 @@ # limitations under the License. 
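As background for the ``Transform``/``inverse`` interface factored out above, a minimal sketch of a conforming subclass (illustrative only; the ``AddConstant`` name and its ``offset`` parameter are hypothetical and not part of this patch series):

    import numpy as np

    from monai.transforms import Transform


    class AddConstant(Transform):
        """Hypothetical transform: add a fixed offset to a channel-first array."""

        def __init__(self, offset: float = 1.0) -> None:
            self.offset = offset

        def __call__(self, img: np.ndarray) -> np.ndarray:
            # img is (num_channels, spatial_dim_1[, spatial_dim_2, ...]); the channel dim is kept even if 1
            return img + self.offset


    out = AddConstant(offset=5.0)(np.zeros((1, 32, 32)))  # shape stays (1, 32, 32)
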
from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs -from .compose import Compose, MapTransform, Randomizable +from .compose import Compose from .croppad.array import ( BorderPad, BoundingRect, @@ -234,7 +234,7 @@ ZoomD, ZoomDict, ) -from .transform import Transform +from .transform import MapTransform, Randomizable, SpatialMapTransform, Transform from .utility.array import ( AddChannel, AddExtremePointsChannel, diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 8c4f107f61..82c8b88dcb 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -13,88 +13,15 @@ """ import warnings -from abc import ABC, abstractmethod -from typing import Any, Callable, Hashable, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Sequence, Union, Hashable, Mapping import numpy as np -from monai.config import KeysCollection -from monai.transforms.transform import Transform +from monai.transforms.transform import Randomizable, Transform from monai.transforms.utils import apply_transform from monai.utils import MAX_SEED, ensure_tuple, get_seed -__all__ = ["Randomizable", "Compose", "MapTransform"] - - -class Randomizable(ABC): - """ - An interface for handling random state locally, currently based on a class variable `R`, - which is an instance of `np.random.RandomState`. - This is mainly for randomized data augmentation transforms. For example:: - - class RandShiftIntensity(Randomizable): - def randomize(): - self._offset = self.R.uniform(low=0, high=100) - def __call__(self, img): - self.randomize() - return img + self._offset - - transform = RandShiftIntensity() - transform.set_random_state(seed=0) - - """ - - R: np.random.RandomState = np.random.RandomState() - - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "Randomizable": - """ - Set the random state locally, to control the randomness, the derived - classes should use :py:attr:`self.R` instead of `np.random` to introduce random - factors. - - Args: - seed: set the random state with an integer seed. - state: set the random state with a `np.random.RandomState` object. - - Raises: - TypeError: When ``state`` is not an ``Optional[np.random.RandomState]``. - - Returns: - a Randomizable instance. - - """ - if seed is not None: - _seed = id(seed) if not isinstance(seed, (int, np.integer)) else seed - _seed = _seed % MAX_SEED - self.R = np.random.RandomState(_seed) - return self - - if state is not None: - if not isinstance(state, np.random.RandomState): - raise TypeError(f"state must be None or a np.random.RandomState but is {type(state).__name__}.") - self.R = state - return self - - self.R = np.random.RandomState() - return self - - @abstractmethod - def randomize(self, data: Any) -> None: - """ - Within this method, :py:attr:`self.R` should be used, instead of `np.random`, to introduce random factors. - - all :py:attr:`self.R` calls happen here so that we have a better chance to - identify errors of sync the random state. - - This method can generate the random factors based on properties of the input data. - - Raises: - NotImplementedError: When the subclass does not override this method. 
- - """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") +__all__ = ["Compose"] class Compose(Randomizable, Transform): @@ -189,66 +116,15 @@ def __call__(self, input_): input_ = apply_transform(_transform, input_) return input_ - -class MapTransform(Transform): - """ - A subclass of :py:class:`monai.transforms.Transform` with an assumption - that the ``data`` input of ``self.__call__`` is a MutableMapping such as ``dict``. - - The ``keys`` parameter will be used to get and set the actual data - item to transform. That is, the callable of this transform should - follow the pattern: - - .. code-block:: python - - def __call__(self, data): - for key in self.keys: - if key in data: - # update output data with some_transform_function(data[key]). - else: - # do nothing or some exceptions handling. - return data - - Raises: - ValueError: When ``keys`` is an empty iterable. - TypeError: When ``keys`` type is not in ``Union[Hashable, Iterable[Hashable]]``. - - """ - - def __init__(self, keys: KeysCollection) -> None: - self.keys: Tuple[Hashable, ...] = ensure_tuple(keys) - if not self.keys: - raise ValueError("keys must be non empty.") - for key in self.keys: - if not isinstance(key, Hashable): - raise TypeError(f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}.") - - @abstractmethod - def __call__(self, data): - """ - ``data`` often comes from an iteration over an iterable, - such as :py:class:`torch.utils.data.Dataset`. - - To simplify the input validations, this method assumes: - - - ``data`` is a Python dictionary - - ``data[key]`` is a Numpy ndarray, PyTorch Tensor or string, where ``key`` is an element - of ``self.keys``, the data shape can be: - - #. string data without shape, `LoadImaged` transform expects file paths - #. most of the pre-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``, - except that `AddChanneld` expects (spatial_dim_1[, spatial_dim_2, ...]) and - `AsChannelFirstd` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels) - #. most of the post-processing transforms expect - ``(batch_size, num_channels, spatial_dim_1[, spatial_dim_2, ...])`` - - - the channel dimension is not omitted even if number of channels is one - - Raises: - NotImplementedError: When the subclass does not override this method. - - returns: - An updated dictionary version of ``data`` by applying the transform. 
- - """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + def inverse(self, data): + if not isinstance(data, Mapping): + raise RuntimeError("Inverse method only available for dictionary transforms") + # loop over data elements + for k in data: + transform_key = k + "_transforms" + if transform_key not in data: + continue + for t in reversed(data[transform_key]): + transform = t["obj"] + data = transform.inverse(data) + return data diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 8bf33dd632..3b4306dc09 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -18,10 +18,10 @@ from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np +from math import floor from monai.config import IndexSelection, KeysCollection from monai.data.utils import get_random_patch, get_valid_patch_size -from monai.transforms.compose import MapTransform, Randomizable from monai.transforms.croppad.array import ( BorderPad, BoundingRect, @@ -31,6 +31,7 @@ SpatialCrop, SpatialPad, ) +from monai.transforms.transform import MapTransform, Randomizable, SpatialMapTransform from monai.transforms.utils import ( generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, @@ -82,7 +83,7 @@ NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str] -class SpatialPadd(MapTransform): +class SpatialPadd(SpatialMapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.SpatialPad`. Performs padding to the data, symmetric for all sides or all on one side for each dimension. @@ -106,7 +107,7 @@ def __init__( mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. 
""" @@ -117,9 +118,30 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in zip(self.keys, self.mode): + d = self.append_applied_transforms(d, key, {"obj":self, "orig_size": d[key].shape}) d[key] = self.padder(d[key], mode=m) return d + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = dict(data) + for key, m in zip(self.keys, self.mode): + transform = self.get_most_recent_transform(d, key) + if transform["obj"] != self: + raise RuntimeError( + "Should inverse most recently applied inverse-able transform first") + # Create inverse transform + roi_size = transform["orig_size"][1:] + im_shape = d[key].shape[1:] if self.padder.method == Method.SYMMETRIC else transform["orig_size"][1:] + roi_center = [floor(i/2) if r % 2 == 0 else (i-1)/2 for r, i in zip(roi_size, im_shape)] + + inverse_transform = SpatialCrop(roi_center, roi_size) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + class BorderPadd(MapTransform): """ diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 34d75faf63..b510c0a072 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -22,7 +22,6 @@ import torch from monai.config import KeysCollection -from monai.transforms.compose import MapTransform, Randomizable from monai.transforms.intensity.array import ( AdjustContrast, GaussianSharpen, @@ -35,6 +34,7 @@ ShiftIntensity, ThresholdIntensity, ) +from monai.transforms.transform import MapTransform, Randomizable from monai.utils import dtype_torch_to_numpy, ensure_tuple_size __all__ = [ diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 62ac4c8562..703e0dc478 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -21,8 +21,8 @@ from monai.config import KeysCollection from monai.data.image_reader import ImageReader -from monai.transforms.compose import MapTransform from monai.transforms.io.array import LoadImage +from monai.transforms.transform import MapTransform __all__ = [ "LoadImaged", diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 60cda11a91..aff4ae3572 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -21,7 +21,6 @@ import torch from monai.config import KeysCollection -from monai.transforms.compose import MapTransform from monai.transforms.post.array import ( Activations, AsDiscrete, @@ -30,6 +29,7 @@ MeanEnsemble, VoteEnsemble, ) +from monai.transforms.transform import MapTransform from monai.utils import ensure_tuple_rep __all__ = [ diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 615a327d90..a6f1b53996 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -22,7 +22,6 @@ from monai.config import KeysCollection from monai.networks.layers.simplelayers import GaussianFilter -from monai.transforms.compose import MapTransform, Randomizable from monai.transforms.croppad.array import CenterSpatialCrop from monai.transforms.spatial.array import ( Flip, @@ -36,6 +35,7 @@ Spacing, Zoom, ) +from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.utils import create_grid from monai.utils import ( GridSampleMode, diff --git 
a/monai/transforms/transform.py b/monai/transforms/transform.py index cd8ededda2..826c5a0774 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -13,9 +13,85 @@ """ from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Hashable, Optional, Tuple -__all__ = ["Transform"] +import numpy as np + +from monai.config import KeysCollection +from monai.utils import MAX_SEED, ensure_tuple + +__all__ = ["Randomizable", "Transform", "MapTransform", "SpatialMapTransform"] + + +class Randomizable(ABC): + """ + An interface for handling random state locally, currently based on a class variable `R`, + which is an instance of `np.random.RandomState`. + This is mainly for randomized data augmentation transforms. For example:: + + class RandShiftIntensity(Randomizable): + def randomize(): + self._offset = self.R.uniform(low=0, high=100) + def __call__(self, img): + self.randomize() + return img + self._offset + + transform = RandShiftIntensity() + transform.set_random_state(seed=0) + + """ + + R: np.random.RandomState = np.random.RandomState() + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "Randomizable": + """ + Set the random state locally, to control the randomness, the derived + classes should use :py:attr:`self.R` instead of `np.random` to introduce random + factors. + + Args: + seed: set the random state with an integer seed. + state: set the random state with a `np.random.RandomState` object. + + Raises: + TypeError: When ``state`` is not an ``Optional[np.random.RandomState]``. + + Returns: + a Randomizable instance. + + """ + if seed is not None: + _seed = id(seed) if not isinstance(seed, (int, np.integer)) else seed + _seed = _seed % MAX_SEED + self.R = np.random.RandomState(_seed) + return self + + if state is not None: + if not isinstance(state, np.random.RandomState): + raise TypeError(f"state must be None or a np.random.RandomState but is {type(state).__name__}.") + self.R = state + return self + + self.R = np.random.RandomState() + return self + + @abstractmethod + def randomize(self, data: Any) -> None: + """ + Within this method, :py:attr:`self.R` should be used, instead of `np.random`, to introduce random factors. + + all :py:attr:`self.R` calls happen here so that we have a better chance to + identify errors of sync the random state. + + This method can generate the random factors based on properties of the input data. + + Raises: + NotImplementedError: When the subclass does not override this method. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") class Transform(ABC): @@ -65,6 +141,100 @@ def __call__(self, data: Any): """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + +class MapTransform(Transform): + """ + A subclass of :py:class:`monai.transforms.Transform` with an assumption + that the ``data`` input of ``self.__call__`` is a MutableMapping such as ``dict``. + + The ``keys`` parameter will be used to get and set the actual data + item to transform. That is, the callable of this transform should + follow the pattern: + + .. code-block:: python + + def __call__(self, data): + for key in self.keys: + if key in data: + # update output data with some_transform_function(data[key]). + else: + # do nothing or some exceptions handling. + return data + + Raises: + ValueError: When ``keys`` is an empty iterable. 
+ TypeError: When ``keys`` type is not in ``Union[Hashable, Iterable[Hashable]]``. + + """ + + def __init__(self, keys: KeysCollection) -> None: + self.keys: Tuple[Hashable, ...] = ensure_tuple(keys) + if not self.keys: + raise ValueError("keys must be non empty.") + for key in self.keys: + if not isinstance(key, Hashable): + raise TypeError(f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}.") + + @abstractmethod + def __call__(self, data): + """ + ``data`` often comes from an iteration over an iterable, + such as :py:class:`torch.utils.data.Dataset`. + + To simplify the input validations, this method assumes: + + - ``data`` is a Python dictionary + - ``data[key]`` is a Numpy ndarray, PyTorch Tensor or string, where ``key`` is an element + of ``self.keys``, the data shape can be: + + #. string data without shape, `LoadImaged` transform expects file paths + #. most of the pre-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``, + except that `AddChanneld` expects (spatial_dim_1[, spatial_dim_2, ...]) and + `AsChannelFirstd` expects (spatial_dim_1[, spatial_dim_2, ...], num_channels) + #. most of the post-processing transforms expect + ``(batch_size, num_channels, spatial_dim_1[, spatial_dim_2, ...])`` + + - the channel dimension is not omitted even if number of channels is one + + Raises: + NotImplementedError: When the subclass does not override this method. + + returns: + An updated dictionary version of ``data`` by applying the transform. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + +class SpatialMapTransform(MapTransform): + """Sub-class of ``MapTransform`` for deterministic, spatial transforms. + + This class exists so that an ``invert`` method can be implemented. This allows + images to be cropped, rotated, padded, etc., and then for segmentations to be + returned to their original size before saving to file for comparison in an external + viewer. + """ + + @staticmethod + def append_applied_transforms(data, key, args): + """Append to list of applied transforms for that key.""" + key += "_transforms" + # If this is the first, create list + if key not in data: + data[key] = [] + data[key].append(args) + return data + + @staticmethod + def get_most_recent_transform(data, key): + """Get all applied transforms.""" + return data[key + "_transforms"][-1] + + @staticmethod + def remove_most_recent_transform(data, key): + """Get all applied transforms.""" + data[key + "_transforms"].pop() + def inverse(self, data: Any): """ Inverse of ``__call__``. diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 1427f24356..c3d4f9ff1b 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -23,7 +23,7 @@ import torch from monai.config import KeysCollection -from monai.transforms.compose import MapTransform, Randomizable +from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.utility.array import ( AddChannel, AsChannelFirst, diff --git a/tests/test_inverse.py b/tests/test_inverse.py new file mode 100644 index 0000000000..4df7fd0ff3 --- /dev/null +++ b/tests/test_inverse.py @@ -0,0 +1,88 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +# from parameterized import parameterized + +from monai.transforms import Compose, SpatialPadd, SpatialPad + +TEST_0 = [ + {"image": np.arange(0,10).reshape(1, 10)}, + [ + SpatialPadd(keys="image", spatial_size=[15]), + SpatialPadd(keys="image", spatial_size=[21]), + SpatialPadd(keys="image", spatial_size=[24]), + ] +] + +TEST_1 = [ + {"image": np.arange(0,10*9).reshape(1, 10, 9)}, + [ + SpatialPadd(keys="image", spatial_size=[11, 12]), + SpatialPadd(keys="image", spatial_size=[12, 21]), + SpatialPadd(keys="image", spatial_size=[14, 25]), + ] +] + +TEST_2 = [ + {"image": np.arange(0,10).reshape(1, 10)}, + [ + Compose([ + SpatialPadd(keys="image", spatial_size=[15]), + SpatialPadd(keys="image", spatial_size=[21]), + SpatialPadd(keys="image", spatial_size=[24]), + ]) + ] +] + +TEST_FAIL = [ + np.arange(0,10).reshape(1, 10), + Compose([ + SpatialPad(spatial_size=[15]), + ]) +] + +TESTS = [TEST_0, TEST_1, TEST_2] + +class TestInverse(unittest.TestCase): + def test_inverse(self, data, transforms): + d = data.copy() + + # Apply forwards + for t in transforms: + d = t(d) + + # Check that error is thrown when inverse are used out of order. + t = transforms[0] if len(transforms) > 1 else SpatialPadd("image", [10, 5]) + with self.assertRaises(RuntimeError): + t.inverse(d) + + # Apply inverses + for t in reversed(transforms): + d = t.inverse(d) + + self.assertTrue(np.all(d["image"] == data["image"])) + + def test_fail(self, data, transform): + d = transform(data) + with self.assertRaises(RuntimeError): + d = transform.inverse(d) + + + +if __name__ == "__main__": + # unittest.main() + a = TestInverse() + a.test_fail(TEST_FAIL[0], TEST_FAIL[1]) + for t in TESTS: + a.test_inverse(t[0], t[1]) From 7851e4f757090295b68c437120bafd90e18c3d0d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 28 Jan 2021 14:10:21 +0000 Subject: [PATCH 04/80] autofixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/compose.py | 9 +++-- monai/transforms/croppad/dictionary.py | 11 +++--- tests/test_inverse.py | 51 ++++++++++++++------------ 3 files changed, 37 insertions(+), 34 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 82c8b88dcb..529404fb60 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -13,7 +13,7 @@ """ import warnings -from typing import Any, Callable, Optional, Sequence, Union, Hashable, Mapping +from typing import Any, Callable, Mapping, Optional, Sequence, Union import numpy as np @@ -119,12 +119,13 @@ def __call__(self, input_): def inverse(self, data): if not isinstance(data, Mapping): raise RuntimeError("Inverse method only available for dictionary transforms") + d = dict(data) # loop over data elements - for k in data: + for k in d: transform_key = k + "_transforms" if transform_key not in data: continue for t in reversed(data[transform_key]): transform = t["obj"] - data = transform.inverse(data) - return data + d = transform.inverse(d) + return d diff --git a/monai/transforms/croppad/dictionary.py 
b/monai/transforms/croppad/dictionary.py index 3b4306dc09..850d411d68 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -15,10 +15,10 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ +from math import floor from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np -from math import floor from monai.config import IndexSelection, KeysCollection from monai.data.utils import get_random_patch, get_valid_patch_size @@ -118,21 +118,20 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in zip(self.keys, self.mode): - d = self.append_applied_transforms(d, key, {"obj":self, "orig_size": d[key].shape}) + d = self.append_applied_transforms(d, key, {"obj": self, "orig_size": d[key].shape}) d[key] = self.padder(d[key], mode=m) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key, m in zip(self.keys, self.mode): + for key in self.keys: transform = self.get_most_recent_transform(d, key) if transform["obj"] != self: - raise RuntimeError( - "Should inverse most recently applied inverse-able transform first") + raise RuntimeError("Should inverse most recently applied inverse-able transform first") # Create inverse transform roi_size = transform["orig_size"][1:] im_shape = d[key].shape[1:] if self.padder.method == Method.SYMMETRIC else transform["orig_size"][1:] - roi_center = [floor(i/2) if r % 2 == 0 else (i-1)/2 for r, i in zip(roi_size, im_shape)] + roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) / 2 for r, i in zip(roi_size, im_shape)] inverse_transform = SpatialCrop(roi_center, roi_size) # Apply inverse transform diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 4df7fd0ff3..3a8d92de9a 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -12,49 +12,56 @@ import unittest import numpy as np -# from parameterized import parameterized +from parameterized import parameterized -from monai.transforms import Compose, SpatialPadd, SpatialPad +from monai.transforms import Compose, SpatialPad, SpatialPadd TEST_0 = [ - {"image": np.arange(0,10).reshape(1, 10)}, + {"image": np.arange(0, 10).reshape(1, 10)}, [ SpatialPadd(keys="image", spatial_size=[15]), SpatialPadd(keys="image", spatial_size=[21]), SpatialPadd(keys="image", spatial_size=[24]), - ] + ], ] TEST_1 = [ - {"image": np.arange(0,10*9).reshape(1, 10, 9)}, + {"image": np.arange(0, 10 * 9).reshape(1, 10, 9)}, [ SpatialPadd(keys="image", spatial_size=[11, 12]), SpatialPadd(keys="image", spatial_size=[12, 21]), SpatialPadd(keys="image", spatial_size=[14, 25]), - ] + ], ] TEST_2 = [ - {"image": np.arange(0,10).reshape(1, 10)}, + {"image": np.arange(0, 10).reshape(1, 10)}, [ - Compose([ - SpatialPadd(keys="image", spatial_size=[15]), - SpatialPadd(keys="image", spatial_size=[21]), - SpatialPadd(keys="image", spatial_size=[24]), - ]) - ] + Compose( + [ + SpatialPadd(keys="image", spatial_size=[15]), + SpatialPadd(keys="image", spatial_size=[21]), + SpatialPadd(keys="image", spatial_size=[24]), + ] + ) + ], ] -TEST_FAIL = [ - np.arange(0,10).reshape(1, 10), - Compose([ - SpatialPad(spatial_size=[15]), - ]) +TEST_FAIL_0 = [ + np.arange(0, 10).reshape(1, 10), + Compose( + [ + SpatialPad(spatial_size=[15]), + ] + ), ] TESTS = [TEST_0, TEST_1, TEST_2] +TEST_FAILS = [TEST_FAIL_0] + class TestInverse(unittest.TestCase): + @parameterized.expand(TESTS) def 
test_inverse(self, data, transforms): d = data.copy() @@ -73,16 +80,12 @@ def test_inverse(self, data, transforms): self.assertTrue(np.all(d["image"] == data["image"])) + @parameterized.expand(TEST_FAILS) def test_fail(self, data, transform): d = transform(data) with self.assertRaises(RuntimeError): d = transform.inverse(d) - if __name__ == "__main__": - # unittest.main() - a = TestInverse() - a.test_fail(TEST_FAIL[0], TEST_FAIL[1]) - for t in TESTS: - a.test_inverse(t[0], t[1]) + unittest.main() From 9d08f2640df3f1c9cb712754716b550a770b17e7 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 28 Jan 2021 14:30:31 +0000 Subject: [PATCH 05/80] extra test Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 3a8d92de9a..2ad92ef4b4 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -15,12 +15,13 @@ from parameterized import parameterized from monai.transforms import Compose, SpatialPad, SpatialPadd +from monai.utils import Method TEST_0 = [ {"image": np.arange(0, 10).reshape(1, 10)}, [ SpatialPadd(keys="image", spatial_size=[15]), - SpatialPadd(keys="image", spatial_size=[21]), + SpatialPadd(keys="image", spatial_size=[21], method=Method.END), SpatialPadd(keys="image", spatial_size=[24]), ], ] @@ -30,7 +31,7 @@ [ SpatialPadd(keys="image", spatial_size=[11, 12]), SpatialPadd(keys="image", spatial_size=[12, 21]), - SpatialPadd(keys="image", spatial_size=[14, 25]), + SpatialPadd(keys="image", spatial_size=[14, 25], method=Method.END), ], ] @@ -63,22 +64,25 @@ class TestInverse(unittest.TestCase): @parameterized.expand(TESTS) def test_inverse(self, data, transforms): - d = data.copy() + forwards = [data.copy()] # Apply forwards for t in transforms: - d = t(d) + forwards.append(t(forwards[-1])) # Check that error is thrown when inverse are used out of order. 
t = transforms[0] if len(transforms) > 1 else SpatialPadd("image", [10, 5]) with self.assertRaises(RuntimeError): - t.inverse(d) + t.inverse(forwards[-1]) # Apply inverses - for t in reversed(transforms): - d = t.inverse(d) + backwards = [forwards[-1].copy()] + for i, t in enumerate(reversed(transforms)): + backwards.append(t.inverse(backwards[-1])) + self.assertTrue(np.all(backwards[-1]["image"] == forwards[len(forwards) - i - 2]["image"])) - self.assertTrue(np.all(d["image"] == data["image"])) + # Check we got back to beginning + self.assertTrue(np.all(backwards[-1]["image"] == forwards[0]["image"])) @parameterized.expand(TEST_FAILS) def test_fail(self, data, transform): From a54869bb5e534e1d205c88d03c04c6101371e7a4 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 28 Jan 2021 16:42:43 +0000 Subject: [PATCH 06/80] serialisation of transform Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 2 +- monai/transforms/compose.py | 6 +++--- monai/transforms/croppad/dictionary.py | 24 +++++++++++++++------- monai/transforms/transform.py | 28 ++++++++++++++++---------- 4 files changed, 38 insertions(+), 22 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 50730b28e0..817244ea99 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -234,7 +234,7 @@ ZoomD, ZoomDict, ) -from .transform import MapTransform, Randomizable, SpatialMapTransform, Transform +from .transform import InvertibleTransform, MapTransform, Randomizable, Transform from .utility.array import ( AddChannel, AddExtremePointsChannel, diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 529404fb60..ccf80a5f15 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -17,14 +17,14 @@ import numpy as np -from monai.transforms.transform import Randomizable, Transform +from monai.transforms.transform import InvertibleTransform, Randomizable, Transform from monai.transforms.utils import apply_transform from monai.utils import MAX_SEED, ensure_tuple, get_seed __all__ = ["Compose"] -class Compose(Randomizable, Transform): +class Compose(Randomizable, Transform, InvertibleTransform): """ ``Compose`` provides the ability to chain a series of calls together in a sequence. Each transform in the sequence must take a single argument and @@ -126,6 +126,6 @@ def inverse(self, data): if transform_key not in data: continue for t in reversed(data[transform_key]): - transform = t["obj"] + transform = t["class"](**t["init_args"]) d = transform.inverse(d) return d diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 850d411d68..438286072b 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -31,7 +31,7 @@ SpatialCrop, SpatialPad, ) -from monai.transforms.transform import MapTransform, Randomizable, SpatialMapTransform +from monai.transforms.transform import InvertibleTransform, MapTransform, Randomizable from monai.transforms.utils import ( generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, @@ -83,7 +83,7 @@ NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str] -class SpatialPadd(SpatialMapTransform): +class SpatialPadd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.SpatialPad`. 
Performs padding to the data, symmetric for all sides or all on one side for each dimension. @@ -118,19 +118,29 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in zip(self.keys, self.mode): - d = self.append_applied_transforms(d, key, {"obj": self, "orig_size": d[key].shape}) + orig_size = d[key].shape d[key] = self.padder(d[key], mode=m) + self.append_applied_transforms(d, key, {"orig_size": orig_size}) return d + def get_input_args(self): + return { + "keys": self.keys, + "method": self.padder.method, + "mode": self.mode, + "spatial_size": self.padder.spatial_size, + } + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: transform = self.get_most_recent_transform(d, key) - if transform["obj"] != self: - raise RuntimeError("Should inverse most recently applied inverse-able transform first") + if transform["class"] != type(self) or transform["init_args"] != self.get_input_args(): + raise RuntimeError("Should inverse most recently applied invertible transform first") # Create inverse transform - roi_size = transform["orig_size"][1:] - im_shape = d[key].shape[1:] if self.padder.method == Method.SYMMETRIC else transform["orig_size"][1:] + extra_info = transform["extra_info"] + roi_size = extra_info["orig_size"][1:] + im_shape = d[key].shape[1:] if self.padder.method == Method.SYMMETRIC else extra_info["orig_size"][1:] roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) / 2 for r, i in zip(roi_size, im_shape)] inverse_transform = SpatialCrop(roi_center, roi_size) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 826c5a0774..f0d7259019 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -20,7 +20,7 @@ from monai.config import KeysCollection from monai.utils import MAX_SEED, ensure_tuple -__all__ = ["Randomizable", "Transform", "MapTransform", "SpatialMapTransform"] +__all__ = ["Randomizable", "Transform", "MapTransform", "InvertibleTransform"] class Randomizable(ABC): @@ -206,24 +206,26 @@ def __call__(self, data): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") -class SpatialMapTransform(MapTransform): - """Sub-class of ``MapTransform`` for deterministic, spatial transforms. +class InvertibleTransform(ABC): + """Classes for invertible transforms. - This class exists so that an ``invert`` method can be implemented. This allows - images to be cropped, rotated, padded, etc., and then for segmentations to be - returned to their original size before saving to file for comparison in an external - viewer. + This class exists so that an ``invert`` method can be implemented. This allows, for + example, images to be cropped, rotated, padded, etc., during training and inference, + and after be returned to their original size before saving to file for comparison in + an external viewer. + + When the `__call__` method is called, a serialization of the class is stored. When + the `inverse` method is called, the serialization is then removed. We use last in, + first out for the inverted transforms. 
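A brief usage sketch of this contract, using ``SpatialPadd`` (the only transform made invertible so far in this series); the sizes are illustrative:

    import numpy as np

    from monai.transforms import SpatialPadd

    data = {"image": np.zeros((1, 10, 10))}
    pad = SpatialPadd(keys="image", spatial_size=[15, 15])

    padded = pad(data)              # records the applied transform under "image_transforms"
    assert padded["image"].shape == (1, 15, 15)

    restored = pad.inverse(padded)  # pops that record and crops back to the original size
    assert restored["image"].shape == (1, 10, 10)
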
""" - @staticmethod - def append_applied_transforms(data, key, args): + def append_applied_transforms(self, data, key, extra_args): """Append to list of applied transforms for that key.""" key += "_transforms" # If this is the first, create list if key not in data: data[key] = [] - data[key].append(args) - return data + data[key].append({"class": type(self), "init_args": self.get_input_args(), "extra_info": extra_args}) @staticmethod def get_most_recent_transform(data, key): @@ -235,6 +237,10 @@ def remove_most_recent_transform(data, key): """Get all applied transforms.""" data[key + "_transforms"].pop() + def get_input_args(self): + """Return dictionary of input arguments.""" + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + def inverse(self, data: Any): """ Inverse of ``__call__``. From b310c196c7a24680e5db54565eeb6dde00801e46 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 28 Jan 2021 17:42:20 +0000 Subject: [PATCH 07/80] add rotate Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 39 ++++++++++++- monai/transforms/transform.py | 16 +++--- tests/test_inverse.py | 80 ++++++++++++++++++++------ 3 files changed, 109 insertions(+), 26 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index a6f1b53996..42f4409076 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -35,7 +35,7 @@ Spacing, Zoom, ) -from monai.transforms.transform import MapTransform, Randomizable +from monai.transforms.transform import InvertibleTransform, MapTransform, Randomizable from monai.transforms.utils import create_grid from monai.utils import ( GridSampleMode, @@ -761,7 +761,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class Rotated(MapTransform): +class Rotated(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rotate`. 
@@ -816,6 +816,41 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda align_corners=self.align_corners[idx], dtype=self.dtype[idx], ) + self.append_applied_transforms(d, key) + return d + + def get_input_args(self) -> dict: + return { + "keys": self.keys, + "angle": self.rotator.angle, + "keep_size": self.rotator.keep_size, + "mode": self.mode, + "padding_mode": self.padding_mode, + "align_corners": self.align_corners, + "dtype": self.dtype, + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = dict(data) + for idx, key in enumerate(self.keys): + transform = self.get_most_recent_transform(d, key) + if transform["class"] != type(self) or transform["init_args"] != self.get_input_args(): + raise RuntimeError("Should inverse most recently applied invertible transform first") + # Create inverse transform + in_angle = transform["init_args"]["angle"] + angle = [-a for a in in_angle] if isinstance(in_angle, Sequence) else -in_angle + inverse_rotator = Rotate(angle=angle, keep_size=transform["init_args"]["keep_size"]) + # Apply inverse transform + d[key] = inverse_rotator( + d[key], + mode=self.mode[idx], + padding_mode=self.padding_mode[idx], + align_corners=self.align_corners[idx], + dtype=self.dtype[idx], + ) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + return d diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index f0d7259019..02e5879fb9 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -219,29 +219,29 @@ class InvertibleTransform(ABC): first out for the inverted transforms. """ - def append_applied_transforms(self, data, key, extra_args): + def append_applied_transforms(self, data:dict, key:Hashable, extra_args: Optional[dict]=None) -> None: """Append to list of applied transforms for that key.""" - key += "_transforms" + key = str(key) + "_transforms" # If this is the first, create list if key not in data: data[key] = [] data[key].append({"class": type(self), "init_args": self.get_input_args(), "extra_info": extra_args}) @staticmethod - def get_most_recent_transform(data, key): + def get_most_recent_transform(data:dict, key:Hashable) -> dict: """Get all applied transforms.""" - return data[key + "_transforms"][-1] + return dict(data[str(key) + "_transforms"][-1]) @staticmethod - def remove_most_recent_transform(data, key): + def remove_most_recent_transform(data:dict, key:Hashable) -> None: """Get all applied transforms.""" - data[key + "_transforms"].pop() + data[str(key) + "_transforms"].pop() - def get_input_args(self): + def get_input_args(self) -> dict: """Return dictionary of input arguments.""" raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - def inverse(self, data: Any): + def inverse(self, data:dict): """ Inverse of ``__call__``. diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 2ad92ef4b4..fda58f7fea 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -9,29 +9,32 @@ # See the License for the specific language governing permissions and # limitations under the License. 
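The inverse implemented above simply re-applies ``Rotate`` with the negated angle; a usage sketch (values are illustrative, and because rotation resampling is lossy, exact voxel equality is not expected on the round trip):

    import numpy as np

    from monai.transforms import Rotated

    d = {"image": np.random.rand(1, 64, 64)}
    rot = Rotated(keys="image", angle=np.pi / 6, keep_size=True)

    rotated = rot(d)                 # the rotation is recorded under "image_transforms"
    restored = rot.inverse(rotated)  # applies Rotate(angle=-np.pi / 6) and pops the record

    assert restored["image"].shape == d["image"].shape
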
+from monai.transforms.transform import InvertibleTransform import unittest import numpy as np from parameterized import parameterized -from monai.transforms import Compose, SpatialPad, SpatialPadd +from monai.transforms import Compose, SpatialPad, SpatialPadd, Rotated, AddChanneld +from monai.data import create_test_image_2d from monai.utils import Method +import matplotlib.pyplot as plt TEST_0 = [ {"image": np.arange(0, 10).reshape(1, 10)}, [ - SpatialPadd(keys="image", spatial_size=[15]), - SpatialPadd(keys="image", spatial_size=[21], method=Method.END), - SpatialPadd(keys="image", spatial_size=[24]), + SpatialPadd("image", spatial_size=[15]), + SpatialPadd("image", spatial_size=[21], method=Method.END), + SpatialPadd("image", spatial_size=[24]), ], ] TEST_1 = [ {"image": np.arange(0, 10 * 9).reshape(1, 10, 9)}, [ - SpatialPadd(keys="image", spatial_size=[11, 12]), - SpatialPadd(keys="image", spatial_size=[12, 21]), - SpatialPadd(keys="image", spatial_size=[14, 25], method=Method.END), + SpatialPadd("image", spatial_size=[11, 12]), + SpatialPadd("image", spatial_size=[12, 21]), + SpatialPadd("image", spatial_size=[14, 25], method=Method.END), ], ] @@ -40,9 +43,9 @@ [ Compose( [ - SpatialPadd(keys="image", spatial_size=[15]), - SpatialPadd(keys="image", spatial_size=[21]), - SpatialPadd(keys="image", spatial_size=[24]), + SpatialPadd("image", spatial_size=[15]), + SpatialPadd("image", spatial_size=[21]), + SpatialPadd("image", spatial_size=[24]), ] ) ], @@ -57,13 +60,22 @@ ), ] -TESTS = [TEST_0, TEST_1, TEST_2] +TEST_ROTATE = [ + {"image": create_test_image_2d(100, 100)[0]}, + [ + AddChanneld("image"), + Rotated("image", -np.pi / 6, True, "bilinear", "border", False), + ] +] + +TESTS_LOSSLESS = [TEST_0, TEST_1, TEST_2] +TESTS_LOSSY = [TEST_ROTATE] TEST_FAILS = [TEST_FAIL_0] class TestInverse(unittest.TestCase): - @parameterized.expand(TESTS) - def test_inverse(self, data, transforms): + @parameterized.expand(TESTS_LOSSLESS) + def test_inverse_lossless(self, data, transforms): forwards = [data.copy()] # Apply forwards @@ -78,12 +90,46 @@ def test_inverse(self, data, transforms): # Apply inverses backwards = [forwards[-1].copy()] for i, t in enumerate(reversed(transforms)): - backwards.append(t.inverse(backwards[-1])) - self.assertTrue(np.all(backwards[-1]["image"] == forwards[len(forwards) - i - 2]["image"])) + if isinstance(t, InvertibleTransform): + backwards.append(t.inverse(backwards[-1])) + self.assertTrue(np.all(backwards[-1]["image"] == forwards[len(forwards) - i - 2]["image"])) # Check we got back to beginning self.assertTrue(np.all(backwards[-1]["image"] == forwards[0]["image"])) + def test_inverse_lossy(self, data, transforms): + forwards = [data.copy()] + + # Apply forwards + for t in transforms: + forwards.append(t(forwards[-1])) + + # Check that error is thrown when inverse are used out of order. 
+ t = SpatialPadd("image", [10, 5]) + with self.assertRaises(RuntimeError): + t.inverse(forwards[-1]) + + # Apply inverses + backwards = [forwards[-1].copy()] + for i, t in enumerate(reversed(transforms)): + if isinstance(t, InvertibleTransform): + backwards.append(t.inverse(backwards[-1])) + # self.assertTrue(np.all(backwards[-1]["image"] == forwards[len(forwards) - i - 2]["image"])) + + # Check we got back to beginning + # self.assertTrue(np.all(backwards[-1]["image"] == forwards[0]["image"])) + fig, axes = plt.subplots(1, 3) + pre = forwards[0]["image"] + post = backwards[-1]["image"][0] + diff = post - pre + for i, (im, title) in enumerate(zip([pre, post, diff],["pre", "post", "diff"])): + ax = axes[i] + _ = ax.imshow(im) + ax.set_title(title, fontsize=25) + ax.axis('off') + fig.show() + pass + @parameterized.expand(TEST_FAILS) def test_fail(self, data, transform): d = transform(data) @@ -92,4 +138,6 @@ def test_fail(self, data, transform): if __name__ == "__main__": - unittest.main() + # unittest.main() + a = TestInverse() + a.test_inverse_lossy(*TEST_ROTATE) From f2edc9ce96cfd624c0ae1cd0dc0e881530cdaaa7 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 28 Jan 2021 18:16:07 +0000 Subject: [PATCH 08/80] rotate w/ keep_size=True Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/transform.py | 8 ++++---- tests/test_inverse.py | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 02e5879fb9..0b8c9097f2 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -219,7 +219,7 @@ class InvertibleTransform(ABC): first out for the inverted transforms. """ - def append_applied_transforms(self, data:dict, key:Hashable, extra_args: Optional[dict]=None) -> None: + def append_applied_transforms(self, data: dict, key: Hashable, extra_args: Optional[dict] = None) -> None: """Append to list of applied transforms for that key.""" key = str(key) + "_transforms" # If this is the first, create list @@ -228,12 +228,12 @@ def append_applied_transforms(self, data:dict, key:Hashable, extra_args: Optiona data[key].append({"class": type(self), "init_args": self.get_input_args(), "extra_info": extra_args}) @staticmethod - def get_most_recent_transform(data:dict, key:Hashable) -> dict: + def get_most_recent_transform(data: dict, key: Hashable) -> dict: """Get all applied transforms.""" return dict(data[str(key) + "_transforms"][-1]) @staticmethod - def remove_most_recent_transform(data:dict, key:Hashable) -> None: + def remove_most_recent_transform(data: dict, key: Hashable) -> None: """Get all applied transforms.""" data[str(key) + "_transforms"].pop() @@ -241,7 +241,7 @@ def get_input_args(self) -> dict: """Return dictionary of input arguments.""" raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - def inverse(self, data:dict): + def inverse(self, data: dict): """ Inverse of ``__call__``. diff --git a/tests/test_inverse.py b/tests/test_inverse.py index fda58f7fea..5ddc6eaf46 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -9,16 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from monai.transforms.transform import InvertibleTransform import unittest +import matplotlib.pyplot as plt import numpy as np from parameterized import parameterized -from monai.transforms import Compose, SpatialPad, SpatialPadd, Rotated, AddChanneld from monai.data import create_test_image_2d +from monai.transforms import AddChanneld, Compose, Rotated, SpatialPad, SpatialPadd +from monai.transforms.transform import InvertibleTransform from monai.utils import Method -import matplotlib.pyplot as plt TEST_0 = [ {"image": np.arange(0, 10).reshape(1, 10)}, @@ -65,7 +65,7 @@ [ AddChanneld("image"), Rotated("image", -np.pi / 6, True, "bilinear", "border", False), - ] + ], ] TESTS_LOSSLESS = [TEST_0, TEST_1, TEST_2] @@ -122,11 +122,11 @@ def test_inverse_lossy(self, data, transforms): pre = forwards[0]["image"] post = backwards[-1]["image"][0] diff = post - pre - for i, (im, title) in enumerate(zip([pre, post, diff],["pre", "post", "diff"])): + for i, (im, title) in enumerate(zip([pre, post, diff], ["pre", "post", "diff"])): ax = axes[i] _ = ax.imshow(im) ax.set_title(title, fontsize=25) - ax.axis('off') + ax.axis("off") fig.show() pass From 41bec26911be90d493f519b91fc727296ef945d1 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Fri, 29 Jan 2021 10:27:17 +0000 Subject: [PATCH 09/80] rotate w/ keep_size=False Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 6 +- tests/test_inverse.py | 106 +++++++++++++++++-------- 2 files changed, 77 insertions(+), 35 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 42f4409076..5659aa25a2 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -809,6 +809,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for idx, key in enumerate(self.keys): + self.append_applied_transforms(d, key, {"orig_size": d[key].shape[1:]}) d[key] = self.rotator( d[key], mode=self.mode[idx], @@ -816,7 +817,6 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda align_corners=self.align_corners[idx], dtype=self.dtype[idx], ) - self.append_applied_transforms(d, key) return d def get_input_args(self) -> dict: @@ -848,6 +848,10 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar align_corners=self.align_corners[idx], dtype=self.dtype[idx], ) + # If the keep_size==False, need to crop image + if not transform["init_args"]["keep_size"]: + d[key] = CenterSpatialCrop(transform["extra_info"]["orig_size"])(d[key]) + # Remove the applied transform self.remove_most_recent_transform(d, key) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 5ddc6eaf46..bc8c2658bf 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -9,16 +9,27 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import random import unittest +from typing import TYPE_CHECKING import matplotlib.pyplot as plt import numpy as np -from parameterized import parameterized from monai.data import create_test_image_2d from monai.transforms import AddChanneld, Compose, Rotated, SpatialPad, SpatialPadd from monai.transforms.transform import InvertibleTransform -from monai.utils import Method +from monai.utils import Method, optional_import + +# from parameterized import parameterized + + +if TYPE_CHECKING: + import matplotlib.pyplot as plt + + has_matplotlib = True +else: + plt, has_matplotlib = optional_import("matplotlib.pyplot") TEST_0 = [ {"image": np.arange(0, 10).reshape(1, 10)}, @@ -60,21 +71,50 @@ ), ] -TEST_ROTATE = [ - {"image": create_test_image_2d(100, 100)[0]}, - [ - AddChanneld("image"), - Rotated("image", -np.pi / 6, True, "bilinear", "border", False), - ], -] +TEST_ROTATES = [] +for k in [True, False]: + for a in [False, True]: + TEST_ROTATE = [ + {"image": create_test_image_2d(100, 100)[0]}, + [ + AddChanneld("image"), + Rotated("image", random.uniform(np.pi / 6, np.pi), k, "bilinear", "border", a), + ], + ] + TEST_ROTATES.append(TEST_ROTATE) TESTS_LOSSLESS = [TEST_0, TEST_1, TEST_2] -TESTS_LOSSY = [TEST_ROTATE] -TEST_FAILS = [TEST_FAIL_0] +TESTS_LOSSY = [*TEST_ROTATES] +TESTS_FAIL = [TEST_FAIL_0] + + +def get_percent_diff_im(array_true, array): + return 100 * (array_true - array) / (array_true + 1e-5) + + +def get_mean_percent_diff(array_true, array): + return abs(np.mean(get_percent_diff_im(array_true, array))) + + +def plot_im(orig, fwd_bck, fwd): + diff_orig_fwd_bck = get_percent_diff_im(orig, fwd_bck) + fig, axes = plt.subplots( + 1, 4, gridspec_kw={"width_ratios": [orig.shape[1], fwd_bck.shape[1], diff_orig_fwd_bck.shape[1], fwd.shape[1]]} + ) + for i, (im, title) in enumerate( + zip([orig, fwd_bck, diff_orig_fwd_bck, fwd], ["orig", "fwd_bck", "%% diff", "fwd"]) + ): + ax = axes[i] + vmax = max(np.max(i) for i in [orig, fwd_bck, fwd]) if i != 2 else None + im_show = ax.imshow(np.squeeze(im), vmin=0, vmax=vmax) + ax.set_title(title, fontsize=25) + ax.axis("off") + fig.colorbar(im_show, ax=ax) + plt.show() class TestInverse(unittest.TestCase): - @parameterized.expand(TESTS_LOSSLESS) + # @parameterized.expand(TESTS_LOSSLESS) def test_inverse_lossless(self, data, transforms): forwards = [data.copy()] @@ -94,10 +134,8 @@ def test_inverse_lossless(self, data, transforms): backwards.append(t.inverse(backwards[-1])) self.assertTrue(np.all(backwards[-1]["image"] == forwards[len(forwards) - i - 2]["image"])) - # Check we got back to beginning - self.assertTrue(np.all(backwards[-1]["image"] == forwards[0]["image"])) - - def test_inverse_lossy(self, data, transforms): + # @parameterized.expand(TESTS_LOSSY) + def test_inverse_lossy(self, data, transforms, visualise=False): forwards = [data.copy()] # Apply forwards @@ -114,23 +152,18 @@ def test_inverse_lossy(self, data, transforms): for i, t in enumerate(reversed(transforms)): if isinstance(t, InvertibleTransform): backwards.append(t.inverse(backwards[-1])) - # self.assertTrue(np.all(backwards[-1]["image"] == forwards[len(forwards) - i - 2]["image"])) - - # Check we got back to beginning - # self.assertTrue(np.all(backwards[-1]["image"] == forwards[0]["image"])) - fig, axes = plt.subplots(1, 3) - pre = forwards[0]["image"] - post = backwards[-1]["image"][0] - diff = post - pre - for i, (im, title) in enumerate(zip([pre, post, diff], ["pre", "post", "diff"])): - ax = axes[i] - _ = ax.imshow(im) - ax.set_title(title, fontsize=25) - ax.axis("off") - 
fig.show() - pass - - @parameterized.expand(TEST_FAILS) + mean_percent_diff = get_mean_percent_diff(backwards[-1]["image"], forwards[-i - 2]["image"]) + self.assertLess(mean_percent_diff, 10) + + if has_matplotlib and visualise: + plot_im(forwards[1]["image"], backwards[-1]["image"], forwards[-1]["image"]) + + # Check that if the inverse hadn't been called, mean_percent_diff would have been greater + if forwards[1]["image"].shape == forwards[-1]["image"].shape: + mean_percent_diff = get_mean_percent_diff(forwards[1]["image"], forwards[-1]["image"]) + self.assertGreater(mean_percent_diff, 50) + + # @parameterized.expand(TESTS_FAIL) def test_fail(self, data, transform): d = transform(data) with self.assertRaises(RuntimeError): @@ -140,4 +173,9 @@ def test_fail(self, data, transform): if __name__ == "__main__": # unittest.main() a = TestInverse() - a.test_inverse_lossy(*TEST_ROTATE) + for t in TESTS_LOSSY: + a.test_inverse_lossy(*t) + for t in TESTS_LOSSLESS: + a.test_inverse_lossless(*t) + for t in TESTS_FAIL: + a.test_fail(*t) From 8768d5785f70593b5249c0847e235ee7f5450d01 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 29 Jan 2021 10:34:20 +0000 Subject: [PATCH 10/80] autofix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index bc8c2658bf..a30a26ef5d 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -13,7 +13,6 @@ import unittest from typing import TYPE_CHECKING -import matplotlib.pyplot as plt import numpy as np from monai.data import create_test_image_2d @@ -172,10 +171,10 @@ def test_fail(self, data, transform): if __name__ == "__main__": # unittest.main() - a = TestInverse() + test = TestInverse() for t in TESTS_LOSSY: - a.test_inverse_lossy(*t) + test.test_inverse_lossy(*t) for t in TESTS_LOSSLESS: - a.test_inverse_lossless(*t) + test.test_inverse_lossless(*t) for t in TESTS_FAIL: - a.test_fail(*t) + test.test_fail(*t) From 3b91c44ae93f089a36f2e77bb974ebf7e7644a12 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 29 Jan 2021 11:09:08 +0000 Subject: [PATCH 11/80] randrotated Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 52 +++++++++++++++++++++++++- tests/test_inverse.py | 46 +++++++++++++++-------- 2 files changed, 80 insertions(+), 18 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 5659aa25a2..64e5cb85f6 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -858,7 +858,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d -class RandRotated(Randomizable, MapTransform): +class RandRotated(Randomizable, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRotate` Randomly rotates the input arrays. 
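The bookkeeping pattern these patches build up — push a record per key on the forward pass, pop it on the inverse, last in / first out — can be pictured with a small self-contained sketch. The Pad1D class below is hypothetical and only mirrors the idea behind append_applied_transforms / get_most_recent_transform / remove_most_recent_transform; it is not part of MONAI:

    from typing import Dict, Hashable


    class Pad1D:
        """Toy invertible transform (illustration only): pad a 1D list on the
        right, and undo the padding on inverse using the stored record."""

        def __init__(self, key: Hashable, pad: int) -> None:
            self.key, self.pad = key, pad

        def __call__(self, data: Dict) -> Dict:
            d = dict(data)
            # push a record with enough info to undo this transform later
            d.setdefault(f"{self.key}_transforms", []).append({"orig_size": len(d[self.key])})
            d[self.key] = list(d[self.key]) + [0] * self.pad
            return d

        def inverse(self, data: Dict) -> Dict:
            d = dict(data)
            info = d[f"{self.key}_transforms"].pop()  # most recently applied transform
            d[self.key] = d[self.key][: info["orig_size"]]
            return d


    t1, t2 = Pad1D("image", 2), Pad1D("image", 5)
    out = t2(t1({"image": [1, 2, 3]}))    # forward: apply in order
    back = t1.inverse(t2.inverse(out))    # inverse: last in, first out
    assert back["image"] == [1, 2, 3]

Keeping the records under a per-key "*_transforms" entry makes the data dictionary self-describing, which is what allows Compose.inverse to replay the stack in reverse without holding extra state on the transform objects themselves.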
@@ -938,12 +938,16 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.randomize() d = dict(data) if not self._do_transform: + for key in self.keys: + self.append_applied_transforms(d, key, {"do_transform": False}) return d + angle=self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z), rotator = Rotate( - angle=self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z), + angle=angle, keep_size=self.keep_size, ) for idx, key in enumerate(self.keys): + self.append_applied_transforms(d, key, {"angle": angle, "orig_size": d[key].shape[1:]}) d[key] = rotator( d[key], mode=self.mode[idx], @@ -953,6 +957,50 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d + def get_input_args(self) -> dict: + return { + "keys": self.keys, + "range_x": self.range_x, + "range_y": self.range_y, + "range_z": self.range_z, + "prob": self.prob, + "keep_size": self.keep_size, + "mode": self.mode, + "padding_mode": self.padding_mode, + "align_corners": self.align_corners, + "dtype": self.dtype, + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = dict(data) + for idx, key in enumerate(self.keys): + transform = self.get_most_recent_transform(d, key) + if transform["class"] != type(self) or transform["init_args"] != self.get_input_args(): + raise RuntimeError("Should inverse most recently applied invertible transform first") + # If the transform wasn't applied (because of `prob`), nothing to do + if "do_transform" in transform["extra_info"]: + return d + # Create inverse transform + in_angle = transform["extra_info"]["angle"] + angle = [-a for a in in_angle] if isinstance(in_angle, Sequence) else -in_angle + inverse_rotator = Rotate(angle=angle, keep_size=transform["init_args"]["keep_size"]) + # Apply inverse transform + d[key] = inverse_rotator( + d[key], + mode=self.mode[idx], + padding_mode=self.padding_mode[idx], + align_corners=self.align_corners[idx], + dtype=self.dtype[idx], + ) + # If the keep_size==False, need to crop image + if not transform["init_args"]["keep_size"]: + d[key] = CenterSpatialCrop(transform["extra_info"]["orig_size"])(d[key]) + + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + class Zoomd(MapTransform): """ diff --git a/tests/test_inverse.py b/tests/test_inverse.py index a30a26ef5d..75326be5f7 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -16,7 +16,7 @@ import numpy as np from monai.data import create_test_image_2d -from monai.transforms import AddChanneld, Compose, Rotated, SpatialPad, SpatialPadd +from monai.transforms import AddChanneld, Compose, Rotated, RandRotated, SpatialPad, SpatialPadd from monai.transforms.transform import InvertibleTransform from monai.utils import Method, optional_import @@ -71,16 +71,30 @@ ] TEST_ROTATES = [] -for k in [True, False]: - for a in [False, True]: - TEST_ROTATE = [ - {"image": create_test_image_2d(100, 100)[0]}, - [ - AddChanneld("image"), - Rotated("image", random.uniform(np.pi / 6, np.pi), k, "bilinear", "border", a), - ], - ] - TEST_ROTATES.append(TEST_ROTATE) +# for k in [True, False]: +# for a in [False, True]: +# TEST_ROTATE = [ +# {"image": create_test_image_2d(100, 100)[0]}, +# [ +# AddChanneld("image"), +# Rotated("image", random.uniform(np.pi / 6, np.pi), k, "bilinear", "border", a), +# ], +# ] +# TEST_ROTATES.append(TEST_ROTATE) +for p in [0, 1]: + TEST_ROTATE = [ + {"image": create_test_image_2d(100, 100)[0]}, + [ + 
AddChanneld("image"), + RandRotated( + "image", + random.uniform(np.pi / 6, np.pi), + random.uniform(np.pi / 6, np.pi), + random.uniform(np.pi / 6, np.pi), + p, True, "bilinear", "border", False), + ], + ] + TEST_ROTATES.append(TEST_ROTATE) TESTS_LOSSLESS = [TEST_0, TEST_1, TEST_2] TESTS_LOSSY = [*TEST_ROTATES] @@ -172,9 +186,9 @@ def test_fail(self, data, transform): if __name__ == "__main__": # unittest.main() test = TestInverse() + # for t in TESTS_LOSSLESS: + # test.test_inverse_lossless(*t) for t in TESTS_LOSSY: - test.test_inverse_lossy(*t) - for t in TESTS_LOSSLESS: - test.test_inverse_lossless(*t) - for t in TESTS_FAIL: - test.test_fail(*t) + test.test_inverse_lossy(*t, True) + # for t in TESTS_FAIL: + # test.test_fail(*t) From e569b82322f1e66d8193879c4cfaf7e4d6d904a6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 29 Jan 2021 11:29:19 +0000 Subject: [PATCH 12/80] randrotated Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 37 +++++++++++++------------- monai/transforms/transform.py | 3 +++ tests/test_inverse.py | 35 +++++++++++------------- 3 files changed, 36 insertions(+), 39 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 64e5cb85f6..dcb5cbac1d 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -939,7 +939,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d = dict(data) if not self._do_transform: for key in self.keys: - self.append_applied_transforms(d, key, {"do_transform": False}) + self.append_applied_transforms(d, key) return d angle=self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z), rotator = Rotate( @@ -977,24 +977,23 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar transform = self.get_most_recent_transform(d, key) if transform["class"] != type(self) or transform["init_args"] != self.get_input_args(): raise RuntimeError("Should inverse most recently applied invertible transform first") - # If the transform wasn't applied (because of `prob`), nothing to do - if "do_transform" in transform["extra_info"]: - return d - # Create inverse transform - in_angle = transform["extra_info"]["angle"] - angle = [-a for a in in_angle] if isinstance(in_angle, Sequence) else -in_angle - inverse_rotator = Rotate(angle=angle, keep_size=transform["init_args"]["keep_size"]) - # Apply inverse transform - d[key] = inverse_rotator( - d[key], - mode=self.mode[idx], - padding_mode=self.padding_mode[idx], - align_corners=self.align_corners[idx], - dtype=self.dtype[idx], - ) - # If the keep_size==False, need to crop image - if not transform["init_args"]["keep_size"]: - d[key] = CenterSpatialCrop(transform["extra_info"]["orig_size"])(d[key]) + # Check if random transform was actually performed (based on `prob`) + if transform["do_transform"]: + # Create inverse transform + in_angle = transform["extra_info"]["angle"] + angle = [-a for a in in_angle] if isinstance(in_angle, Sequence) else -in_angle + inverse_rotator = Rotate(angle=angle, keep_size=transform["init_args"]["keep_size"]) + # Apply inverse transform + d[key] = inverse_rotator( + d[key], + mode=self.mode[idx], + padding_mode=self.padding_mode[idx], + align_corners=self.align_corners[idx], + dtype=self.dtype[idx], + ) + # If the keep_size==False, need to crop image + if not transform["init_args"]["keep_size"]: + d[key] = 
CenterSpatialCrop(transform["extra_info"]["orig_size"])(d[key]) # Remove the applied transform self.remove_most_recent_transform(d, key) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 0b8c9097f2..50d2ea8d2d 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -226,6 +226,9 @@ def append_applied_transforms(self, data: dict, key: Hashable, extra_args: Optio if key not in data: data[key] = [] data[key].append({"class": type(self), "init_args": self.get_input_args(), "extra_info": extra_args}) + # If class is randomizable, store whether the transform was actually performed (based on `prob`) + if isinstance(self, Randomizable): + data[key][-1]["do_transform"] = self._do_transform @staticmethod def get_most_recent_transform(data: dict, key: Hashable) -> dict: diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 75326be5f7..55c8ffd99d 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -71,16 +71,16 @@ ] TEST_ROTATES = [] -# for k in [True, False]: -# for a in [False, True]: -# TEST_ROTATE = [ -# {"image": create_test_image_2d(100, 100)[0]}, -# [ -# AddChanneld("image"), -# Rotated("image", random.uniform(np.pi / 6, np.pi), k, "bilinear", "border", a), -# ], -# ] -# TEST_ROTATES.append(TEST_ROTATE) +for k in [True, False]: + for a in [False, True]: + TEST_ROTATE = [ + {"image": create_test_image_2d(100, 100)[0]}, + [ + AddChanneld("image"), + Rotated("image", random.uniform(np.pi / 6, np.pi), k, "bilinear", "border", a), + ], + ] + TEST_ROTATES.append(TEST_ROTATE) for p in [0, 1]: TEST_ROTATE = [ {"image": create_test_image_2d(100, 100)[0]}, @@ -171,11 +171,6 @@ def test_inverse_lossy(self, data, transforms, visualise=False): if has_matplotlib and visualise: plot_im(forwards[1]["image"], backwards[-1]["image"], forwards[-1]["image"]) - # Check that if the inverse hadn't been called, mean_percent_diff would have been greater - if forwards[1]["image"].shape == forwards[-1]["image"].shape: - mean_percent_diff = get_mean_percent_diff(forwards[1]["image"], forwards[-1]["image"]) - self.assertGreater(mean_percent_diff, 50) - # @parameterized.expand(TESTS_FAIL) def test_fail(self, data, transform): d = transform(data) @@ -186,9 +181,9 @@ def test_fail(self, data, transform): if __name__ == "__main__": # unittest.main() test = TestInverse() - # for t in TESTS_LOSSLESS: - # test.test_inverse_lossless(*t) + for t in TESTS_LOSSLESS: + test.test_inverse_lossless(*t) for t in TESTS_LOSSY: - test.test_inverse_lossy(*t, True) - # for t in TESTS_FAIL: - # test.test_fail(*t) + test.test_inverse_lossy(*t) + for t in TESTS_FAIL: + test.test_fail(*t) From 6cff0487c203c6ee3ce6ce12bfcbf125e399d13d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 29 Jan 2021 12:33:33 +0000 Subject: [PATCH 13/80] 3d rotation not working Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utils.py | 13 ++---- tests/test_inverse.py | 95 ++++++++++++++++++++++++--------------- 2 files changed, 62 insertions(+), 46 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 23c6bd100a..13fb4d11fb 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -378,17 +378,10 @@ def apply_transform(transform: Callable, data, map_items: bool = True): data: an object to be transformed. map_items: whether to apply transform to each item in `data`, if `data` is a list or tuple. Defaults to True. 
- - Raises: - Exception: When ``transform`` raises an exception. - """ - try: - if isinstance(data, (list, tuple)) and map_items: - return [transform(item) for item in data] - return transform(data) - except Exception as e: - raise RuntimeError(f"applying transform {transform}") from e + if isinstance(data, (list, tuple)) and map_items: + return [transform(item) for item in data] + return transform(data) def create_grid( diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 55c8ffd99d..d58bac03d5 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -9,13 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial import random import unittest from typing import TYPE_CHECKING import numpy as np -from monai.data import create_test_image_2d +from monai.data import create_test_image_2d, create_test_image_3d from monai.transforms import AddChanneld, Compose, Rotated, RandRotated, SpatialPad, SpatialPadd from monai.transforms.transform import InvertibleTransform from monai.utils import Method, optional_import @@ -30,32 +31,44 @@ else: plt, has_matplotlib = optional_import("matplotlib.pyplot") -TEST_0 = [ +TEST_SPATIALS = [] +TEST_SPATIALS.append([ + "Spatial 1d", {"image": np.arange(0, 10).reshape(1, 10)}, [ SpatialPadd("image", spatial_size=[15]), SpatialPadd("image", spatial_size=[21], method=Method.END), SpatialPadd("image", spatial_size=[24]), ], -] +]) -TEST_1 = [ +TEST_SPATIALS.append([ + "Spatial 2d", {"image": np.arange(0, 10 * 9).reshape(1, 10, 9)}, [ SpatialPadd("image", spatial_size=[11, 12]), SpatialPadd("image", spatial_size=[12, 21]), SpatialPadd("image", spatial_size=[14, 25], method=Method.END), ], -] +]) -TEST_2 = [ - {"image": np.arange(0, 10).reshape(1, 10)}, +TEST_SPATIALS.append([ + "Spatial 3d", + {"image": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8)}, + [ + SpatialPadd("image", spatial_size=[55, 50, 45]), + ], +]) + +TEST_COMPOSE = [ + "Compose", + {"image": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8)}, [ Compose( [ - SpatialPadd("image", spatial_size=[15]), - SpatialPadd("image", spatial_size=[21]), - SpatialPadd("image", spatial_size=[24]), + SpatialPadd("image", spatial_size=[15, 12, 4]), + SpatialPadd("image", spatial_size=[21, 32, 1]), + SpatialPadd("image", spatial_size=[55, 50, 45]), ] ) ], @@ -70,33 +83,36 @@ ), ] +# TODO: add 3D TEST_ROTATES = [] -for k in [True, False]: - for a in [False, True]: +for create_im in [create_test_image_2d]: #, partial(create_test_image_3d, 100)]: + for keep_size in [True, False]: + for align_corners in [False, True]: + im, _ = create_im(100, 100) + angle = random.uniform(np.pi / 6, np.pi) + TEST_ROTATE = [ + f"Rotate{im.ndim}d, keep_size={keep_size}, align_corners={align_corners}", + {"image": im}, + [ + AddChanneld("image"), + Rotated("image", angle, keep_size, "bilinear", "border", align_corners), + ], + ] + TEST_ROTATES.append(TEST_ROTATE) + for prob in [0, 1]: + im, _ = create_im(100, 100) + angles = [random.uniform(np.pi / 6, np.pi) for _ in range(3)] TEST_ROTATE = [ - {"image": create_test_image_2d(100, 100)[0]}, + f"RandRotate{im.ndim}d, prob={prob}", + {"image": im}, [ AddChanneld("image"), - Rotated("image", random.uniform(np.pi / 6, np.pi), k, "bilinear", "border", a), + RandRotated("image", *angles, prob, True, "bilinear", "border", False), ], ] TEST_ROTATES.append(TEST_ROTATE) -for p in [0, 1]: - TEST_ROTATE = [ - {"image": create_test_image_2d(100, 100)[0]}, - [ - AddChanneld("image"), - RandRotated( - "image", 
- random.uniform(np.pi / 6, np.pi), - random.uniform(np.pi / 6, np.pi), - random.uniform(np.pi / 6, np.pi), - p, True, "bilinear", "border", False), - ], - ] - TEST_ROTATES.append(TEST_ROTATE) - -TESTS_LOSSLESS = [TEST_0, TEST_1, TEST_2] + +TESTS_LOSSLESS = [*TEST_SPATIALS, TEST_COMPOSE] TESTS_LOSSY = [*TEST_ROTATES] TESTS_FAIL = [TEST_FAIL_0] @@ -119,6 +135,9 @@ def plot_im(orig, fwd_bck, fwd): ): ax = axes[i] vmax = max(np.max(i) for i in [orig, fwd_bck, fwd]) if i != 2 else None + im = np.squeeze(im) + while im.ndim > 2: + im = im[..., im.shape[-1] // 2] im_show = ax.imshow(np.squeeze(im), vmin=0, vmax=vmax) ax.set_title(title, fontsize=25) ax.axis("off") @@ -128,7 +147,8 @@ def plot_im(orig, fwd_bck, fwd): class TestInverse(unittest.TestCase): # @parameterized.expand(TESTS_LOSSLESS) - def test_inverse_lossless(self, data, transforms): + def test_inverse_lossless(self, desc, data, transforms): + print(f"testing: {desc}...") forwards = [data.copy()] # Apply forwards @@ -148,7 +168,8 @@ def test_inverse_lossless(self, data, transforms): self.assertTrue(np.all(backwards[-1]["image"] == forwards[len(forwards) - i - 2]["image"])) # @parameterized.expand(TESTS_LOSSY) - def test_inverse_lossy(self, data, transforms, visualise=False): + def test_inverse_lossy(self, desc, data, transforms): + print("testing: " + desc) forwards = [data.copy()] # Apply forwards @@ -166,10 +187,12 @@ def test_inverse_lossy(self, data, transforms, visualise=False): if isinstance(t, InvertibleTransform): backwards.append(t.inverse(backwards[-1])) mean_percent_diff = get_mean_percent_diff(backwards[-1]["image"], forwards[-i - 2]["image"]) - self.assertLess(mean_percent_diff, 10) - - if has_matplotlib and visualise: - plot_im(forwards[1]["image"], backwards[-1]["image"], forwards[-1]["image"]) + try: + self.assertLess(mean_percent_diff, 10) + except AssertionError: + if has_matplotlib: + plot_im(forwards[1]["image"], backwards[-1]["image"], forwards[-1]["image"]) + raise # @parameterized.expand(TESTS_FAIL) def test_fail(self, data, transform): From c10b7e322e29dda050fd145e1370dc4010577d01 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 29 Jan 2021 14:32:04 +0000 Subject: [PATCH 14/80] works for dataloader Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 8 +- monai/transforms/spatial/dictionary.py | 12 +- monai/transforms/transform.py | 26 +++-- tests/test_inverse.py | 148 +++++++++++++++---------- 4 files changed, 113 insertions(+), 81 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 438286072b..76125f059d 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -123,11 +123,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.append_applied_transforms(d, key, {"orig_size": orig_size}) return d - def get_input_args(self): + def get_input_args(self, key): return { - "keys": self.keys, + "keys": key, "method": self.padder.method, - "mode": self.mode, + "mode": self.mode[0], "spatial_size": self.padder.spatial_size, } @@ -135,8 +135,6 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar d = dict(data) for key in self.keys: transform = self.get_most_recent_transform(d, key) - if transform["class"] != type(self) or transform["init_args"] != self.get_input_args(): - raise RuntimeError("Should inverse most recently applied invertible 
transform first") # Create inverse transform extra_info = transform["extra_info"] roi_size = extra_info["orig_size"][1:] diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index dcb5cbac1d..43376cc538 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -819,9 +819,9 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def get_input_args(self) -> dict: + def get_input_args(self, key) -> dict: return { - "keys": self.keys, + "keys": key, "angle": self.rotator.angle, "keep_size": self.rotator.keep_size, "mode": self.mode, @@ -834,8 +834,6 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar d = dict(data) for idx, key in enumerate(self.keys): transform = self.get_most_recent_transform(d, key) - if transform["class"] != type(self) or transform["init_args"] != self.get_input_args(): - raise RuntimeError("Should inverse most recently applied invertible transform first") # Create inverse transform in_angle = transform["init_args"]["angle"] angle = [-a for a in in_angle] if isinstance(in_angle, Sequence) else -in_angle @@ -957,9 +955,9 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def get_input_args(self) -> dict: + def get_input_args(self, key) -> dict: return { - "keys": self.keys, + "keys": key, "range_x": self.range_x, "range_y": self.range_y, "range_z": self.range_z, @@ -975,8 +973,6 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar d = dict(data) for idx, key in enumerate(self.keys): transform = self.get_most_recent_transform(d, key) - if transform["class"] != type(self) or transform["init_args"] != self.get_input_args(): - raise RuntimeError("Should inverse most recently applied invertible transform first") # Check if random transform was actually performed (based on `prob`) if transform["do_transform"]: # Create inverse transform diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 50d2ea8d2d..738b718ce0 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -221,27 +221,29 @@ class InvertibleTransform(ABC): def append_applied_transforms(self, data: dict, key: Hashable, extra_args: Optional[dict] = None) -> None: """Append to list of applied transforms for that key.""" - key = str(key) + "_transforms" + key_transform = str(key) + "_transforms" # If this is the first, create list - if key not in data: - data[key] = [] - data[key].append({"class": type(self), "init_args": self.get_input_args(), "extra_info": extra_args}) + if key_transform not in data: + data[key_transform] = [] + data[key_transform].append({"class": type(self), "init_args": self.get_input_args(key), "extra_info": extra_args}) # If class is randomizable, store whether the transform was actually performed (based on `prob`) if isinstance(self, Randomizable): - data[key][-1]["do_transform"] = self._do_transform + data[key_transform][-1]["do_transform"] = self._do_transform - @staticmethod - def get_most_recent_transform(data: dict, key: Hashable) -> dict: - """Get all applied transforms.""" - return dict(data[str(key) + "_transforms"][-1]) + def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: + """Get most recent transform.""" + transform = dict(data[str(key) + "_transforms"][-1]) + if transform["class"] != type(self) or transform["init_args"] != self.get_input_args(key): + raise RuntimeError("Should inverse 
most recently applied invertible transform first") + return transform @staticmethod def remove_most_recent_transform(data: dict, key: Hashable) -> None: - """Get all applied transforms.""" + """Remove most recent transform.""" data[str(key) + "_transforms"].pop() - def get_input_args(self) -> dict: - """Return dictionary of input arguments.""" + def get_input_args(self, key) -> dict: + """Get input arguments for a single key.""" raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") def inverse(self, data: dict): diff --git a/tests/test_inverse.py b/tests/test_inverse.py index d58bac03d5..359ef4f467 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -15,7 +15,7 @@ from typing import TYPE_CHECKING import numpy as np - +from monai.data import Dataset from monai.data import create_test_image_2d, create_test_image_3d from monai.transforms import AddChanneld, Compose, Rotated, RandRotated, SpatialPad, SpatialPadd from monai.transforms.transform import InvertibleTransform @@ -40,6 +40,7 @@ SpatialPadd("image", spatial_size=[21], method=Method.END), SpatialPadd("image", spatial_size=[24]), ], + True, ]) TEST_SPATIALS.append([ @@ -50,6 +51,7 @@ SpatialPadd("image", spatial_size=[12, 21]), SpatialPadd("image", spatial_size=[14, 25], method=Method.END), ], + True, ]) TEST_SPATIALS.append([ @@ -58,21 +60,46 @@ [ SpatialPadd("image", spatial_size=[55, 50, 45]), ], + True ]) -TEST_COMPOSE = [ - "Compose", - {"image": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8)}, +TEST_COMPOSES = [] +TEST_COMPOSES.append([ + "Compose 2d", + { + "image": np.arange(0, 10 * 9).reshape(1, 10, 9), + "label": np.arange(0, 10 * 9).reshape(1, 10, 9), + "other": np.arange(0, 10 * 9).reshape(1, 10, 9), + }, [ Compose( [ - SpatialPadd("image", spatial_size=[15, 12, 4]), - SpatialPadd("image", spatial_size=[21, 32, 1]), - SpatialPadd("image", spatial_size=[55, 50, 45]), + SpatialPadd(["image", "label"], spatial_size=[15, 12]), + SpatialPadd(["label"], spatial_size=[21, 32]), + SpatialPadd(["image"], spatial_size=[55, 50]), ] ) ], -] + True, +]) +TEST_COMPOSES.append([ + "Compose 3d", + { + "image": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8), + "label": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8), + "other": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8), + }, + [ + Compose( + [ + SpatialPadd(["image", "label"], spatial_size=[15, 12, 4]), + SpatialPadd(["label"], spatial_size=[21, 32, 1]), + SpatialPadd(["image"], spatial_size=[55, 50, 45]), + ] + ) + ], + True, +]) TEST_FAIL_0 = [ np.arange(0, 10).reshape(1, 10), @@ -81,6 +108,7 @@ SpatialPad(spatial_size=[15]), ] ), + True, ] # TODO: add 3D @@ -97,6 +125,7 @@ AddChanneld("image"), Rotated("image", angle, keep_size, "bilinear", "border", align_corners), ], + False, ] TEST_ROTATES.append(TEST_ROTATE) for prob in [0, 1]: @@ -109,24 +138,22 @@ AddChanneld("image"), RandRotated("image", *angles, prob, True, "bilinear", "border", False), ], + False, ] TEST_ROTATES.append(TEST_ROTATE) -TESTS_LOSSLESS = [*TEST_SPATIALS, TEST_COMPOSE] -TESTS_LOSSY = [*TEST_ROTATES] +TESTS = [*TEST_SPATIALS, *TEST_COMPOSES, *TEST_ROTATES] TESTS_FAIL = [TEST_FAIL_0] -def get_percent_diff_im(array_true, array): - return 100 * (array_true - array) / (array_true + 1e-5) - - -def get_mean_percent_diff(array_true, array): - return abs(np.mean(get_percent_diff_im(array_true, array))) +def get_fractional_diff_im(array_true, array): + diff = array_true - array + avg = (array_true + array) / 2 + return diff / (avg + 1e-10) def plot_im(orig, fwd_bck, fwd): - 
diff_orig_fwd_bck = get_percent_diff_im(orig, fwd_bck) + diff_orig_fwd_bck = 100 * get_fractional_diff_im(orig, fwd_bck) fig, axes = plt.subplots( 1, 4, gridspec_kw={"width_ratios": [orig.shape[1], fwd_bck.shape[1], diff_orig_fwd_bck.shape[1], fwd.shape[1]]} ) @@ -138,38 +165,34 @@ def plot_im(orig, fwd_bck, fwd): im = np.squeeze(im) while im.ndim > 2: im = im[..., im.shape[-1] // 2] - im_show = ax.imshow(np.squeeze(im), vmin=0, vmax=vmax) + im_show = ax.imshow(np.squeeze(im), vmax=vmax) ax.set_title(title, fontsize=25) ax.axis("off") fig.colorbar(im_show, ax=ax) plt.show() - class TestInverse(unittest.TestCase): - # @parameterized.expand(TESTS_LOSSLESS) - def test_inverse_lossless(self, desc, data, transforms): - print(f"testing: {desc}...") - forwards = [data.copy()] - # Apply forwards - for t in transforms: - forwards.append(t(forwards[-1])) - - # Check that error is thrown when inverse are used out of order. - t = transforms[0] if len(transforms) > 1 else SpatialPadd("image", [10, 5]) - with self.assertRaises(RuntimeError): - t.inverse(forwards[-1]) - - # Apply inverses - backwards = [forwards[-1].copy()] - for i, t in enumerate(reversed(transforms)): - if isinstance(t, InvertibleTransform): - backwards.append(t.inverse(backwards[-1])) - self.assertTrue(np.all(backwards[-1]["image"] == forwards[len(forwards) - i - 2]["image"])) - - # @parameterized.expand(TESTS_LOSSY) - def test_inverse_lossy(self, desc, data, transforms): - print("testing: " + desc) + def check_inverse(self, keys, orig_d, fwd_bck_d, unmodified_d, lossless): + for key in keys: + orig = orig_d[key] + fwd_bck = fwd_bck_d[key] + unmodified = unmodified_d[key] + try: + if lossless: + self.assertTrue(np.all(orig == fwd_bck)) + else: + fractional_diff_im = get_fractional_diff_im(orig, fwd_bck) + mean_percent_diff = 100 * np.mean(np.abs(fractional_diff_im)) + self.assertLess(mean_percent_diff, 10) + except AssertionError: + if has_matplotlib: + plot_im(orig, fwd_bck, unmodified) + raise + + # @parameterized.expand(TESTS) + def test_inverse(self, desc, data, transforms, lossless): + print(f"testing: {desc} (lossless: {lossless})...") forwards = [data.copy()] # Apply forwards @@ -182,31 +205,44 @@ def test_inverse_lossy(self, desc, data, transforms): t.inverse(forwards[-1]) # Apply inverses - backwards = [forwards[-1].copy()] + fwd_bck = forwards[-1].copy() for i, t in enumerate(reversed(transforms)): if isinstance(t, InvertibleTransform): - backwards.append(t.inverse(backwards[-1])) - mean_percent_diff = get_mean_percent_diff(backwards[-1]["image"], forwards[-i - 2]["image"]) - try: - self.assertLess(mean_percent_diff, 10) - except AssertionError: - if has_matplotlib: - plot_im(forwards[1]["image"], backwards[-1]["image"], forwards[-1]["image"]) - raise + fwd_bck = t.inverse(fwd_bck) + self.check_inverse( + data.keys(), forwards[- i - 2], fwd_bck, + forwards[-1], lossless + ) # @parameterized.expand(TESTS_FAIL) - def test_fail(self, data, transform): + def test_fail(self, data, transform, _): d = transform(data) with self.assertRaises(RuntimeError): d = transform.inverse(d) + # @parameterized.expand(TEST_COMPOSES) + def test_w_data_loader(self, desc, data, transforms, lossless): + print(f"testing: {desc}...") + transform = transforms[0] + numel = 2 + test_data = [data for _ in range(numel)] + + dataset = Dataset(data=test_data, transform=transform) + self.assertEqual(len(dataset), 2) + for data_fwd in dataset: + data_fwd_bck = transform.inverse(data_fwd) + self.check_inverse( + data.keys(), data, data_fwd_bck, + data_fwd, 
lossless + ) + if __name__ == "__main__": # unittest.main() test = TestInverse() - for t in TESTS_LOSSLESS: - test.test_inverse_lossless(*t) - for t in TESTS_LOSSY: - test.test_inverse_lossy(*t) + for t in TESTS: + test.test_inverse(*t) + for t in TEST_COMPOSES: + test.test_w_data_loader(*t) for t in TESTS_FAIL: test.test_fail(*t) From 1800c6a1f6fb5ce44622f8f812ee57c19b18c1a2 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 29 Jan 2021 14:35:53 +0000 Subject: [PATCH 15/80] throw error for 3d rotation inverse Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 4 ++++ tests/test_inverse.py | 9 +++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 43376cc538..0aabd3d6cf 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -833,6 +833,8 @@ def get_input_args(self, key) -> dict: def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for idx, key in enumerate(self.keys): + if d[key][1:].ndim != 2: + raise NotImplementedError("inverse rotation only currently implemented for 2D") transform = self.get_most_recent_transform(d, key) # Create inverse transform in_angle = transform["init_args"]["angle"] @@ -972,6 +974,8 @@ def get_input_args(self, key) -> dict: def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for idx, key in enumerate(self.keys): + if d[key][1:].ndim != 2: + raise NotImplementedError("inverse rotation only currently implemented for 2D") transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform["do_transform"]: diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 359ef4f467..b31aca29e7 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -9,18 +9,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from functools import partial + import random import unittest from typing import TYPE_CHECKING import numpy as np -from monai.data import Dataset -from monai.data import create_test_image_2d, create_test_image_3d +from monai.data import CacheDataset +from monai.data import create_test_image_2d #, create_test_image_3d from monai.transforms import AddChanneld, Compose, Rotated, RandRotated, SpatialPad, SpatialPadd from monai.transforms.transform import InvertibleTransform from monai.utils import Method, optional_import +# from functools import partial # from parameterized import parameterized @@ -227,7 +228,7 @@ def test_w_data_loader(self, desc, data, transforms, lossless): numel = 2 test_data = [data for _ in range(numel)] - dataset = Dataset(data=test_data, transform=transform) + dataset = CacheDataset(data=test_data, transform=transform) self.assertEqual(len(dataset), 2) for data_fwd in dataset: data_fwd_bck = transform.inverse(data_fwd) From be5388851221ae36500c6eaa1c4ad786d9e21cf4 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 29 Jan 2021 14:51:50 +0000 Subject: [PATCH 16/80] dataloader Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/compose.py | 4 ++-- monai/transforms/croppad/dictionary.py | 4 ++-- monai/transforms/spatial/dictionary.py | 10 ++++----- tests/test_inverse.py | 31 ++++++++++++-------------- 4 files changed, 23 insertions(+), 26 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index ccf80a5f15..01c11e688a 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -16,7 +16,7 @@ from typing import Any, Callable, Mapping, Optional, Sequence, Union import numpy as np - +from copy import deepcopy from monai.transforms.transform import InvertibleTransform, Randomizable, Transform from monai.transforms.utils import apply_transform from monai.utils import MAX_SEED, ensure_tuple, get_seed @@ -119,7 +119,7 @@ def __call__(self, input_): def inverse(self, data): if not isinstance(data, Mapping): raise RuntimeError("Inverse method only available for dictionary transforms") - d = dict(data) + d = deepcopy(dict(data)) # loop over data elements for k in d: transform_key = k + "_transforms" diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 76125f059d..30f594414e 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -17,7 +17,7 @@ from math import floor from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union - +from copy import deepcopy import numpy as np from monai.config import IndexSelection, KeysCollection @@ -132,7 +132,7 @@ def get_input_args(self, key): } def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - d = dict(data) + d = deepcopy(dict(data)) for key in self.keys: transform = self.get_most_recent_transform(d, key) # Create inverse transform diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 0aabd3d6cf..7803930f8e 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -19,7 +19,7 @@ import numpy as np import torch - +from copy import deepcopy from monai.config import KeysCollection from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.croppad.array import CenterSpatialCrop @@ -831,9 +831,9 @@ def get_input_args(self, key) -> dict: } 
def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - d = dict(data) + d = deepcopy(dict(data)) for idx, key in enumerate(self.keys): - if d[key][1:].ndim != 2: + if d[key][0].ndim != 2: raise NotImplementedError("inverse rotation only currently implemented for 2D") transform = self.get_most_recent_transform(d, key) # Create inverse transform @@ -972,9 +972,9 @@ def get_input_args(self, key) -> dict: } def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - d = dict(data) + d = deepcopy(dict(data)) for idx, key in enumerate(self.keys): - if d[key][1:].ndim != 2: + if d[key][0].ndim != 2: raise NotImplementedError("inverse rotation only currently implemented for 2D") transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index b31aca29e7..d9562d6459 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -144,17 +144,12 @@ TEST_ROTATES.append(TEST_ROTATE) TESTS = [*TEST_SPATIALS, *TEST_COMPOSES, *TEST_ROTATES] +TESTS_DATALOADER = [*TEST_COMPOSES, *TEST_SPATIALS] TESTS_FAIL = [TEST_FAIL_0] -def get_fractional_diff_im(array_true, array): - diff = array_true - array - avg = (array_true + array) / 2 - return diff / (avg + 1e-10) - - def plot_im(orig, fwd_bck, fwd): - diff_orig_fwd_bck = 100 * get_fractional_diff_im(orig, fwd_bck) + diff_orig_fwd_bck = orig - fwd_bck fig, axes = plt.subplots( 1, 4, gridspec_kw={"width_ratios": [orig.shape[1], fwd_bck.shape[1], diff_orig_fwd_bck.shape[1], fwd.shape[1]]} ) @@ -183,9 +178,9 @@ def check_inverse(self, keys, orig_d, fwd_bck_d, unmodified_d, lossless): if lossless: self.assertTrue(np.all(orig == fwd_bck)) else: - fractional_diff_im = get_fractional_diff_im(orig, fwd_bck) - mean_percent_diff = 100 * np.mean(np.abs(fractional_diff_im)) - self.assertLess(mean_percent_diff, 10) + mean_diff = np.mean(np.abs(orig - fwd_bck)) + print(f"Mean diff = {mean_diff}") + self.assertLess(mean_diff, 1.5e-2) except AssertionError: if has_matplotlib: plot_im(orig, fwd_bck, unmodified) @@ -230,12 +225,14 @@ def test_w_data_loader(self, desc, data, transforms, lossless): dataset = CacheDataset(data=test_data, transform=transform) self.assertEqual(len(dataset), 2) - for data_fwd in dataset: - data_fwd_bck = transform.inverse(data_fwd) - self.check_inverse( - data.keys(), data, data_fwd_bck, - data_fwd, lossless - ) + num_epochs = 2 + for _ in range(num_epochs): + for data_fwd in dataset: + data_fwd_bck = transform.inverse(data_fwd) + self.check_inverse( + data.keys(), data, data_fwd_bck, + data_fwd, lossless + ) if __name__ == "__main__": @@ -243,7 +240,7 @@ def test_w_data_loader(self, desc, data, transforms, lossless): test = TestInverse() for t in TESTS: test.test_inverse(*t) - for t in TEST_COMPOSES: + for t in TESTS_DATALOADER: test.test_w_data_loader(*t) for t in TESTS_FAIL: test.test_fail(*t) From 58fdcf4371783446a67efee6b504787045bdf960 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 29 Jan 2021 15:11:29 +0000 Subject: [PATCH 17/80] autofixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/compose.py | 3 +- monai/transforms/croppad/dictionary.py | 3 +- monai/transforms/spatial/dictionary.py | 5 +- monai/transforms/transform.py | 4 +- tests/test_inverse.py | 151 +++++++++++++------------ 5 files changed, 88 insertions(+), 78 deletions(-) diff --git 
a/monai/transforms/compose.py b/monai/transforms/compose.py index 01c11e688a..7529e119cb 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -13,10 +13,11 @@ """ import warnings +from copy import deepcopy from typing import Any, Callable, Mapping, Optional, Sequence, Union import numpy as np -from copy import deepcopy + from monai.transforms.transform import InvertibleTransform, Randomizable, Transform from monai.transforms.utils import apply_transform from monai.utils import MAX_SEED, ensure_tuple, get_seed diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 30f594414e..3285e82523 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -15,9 +15,10 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ +from copy import deepcopy from math import floor from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union -from copy import deepcopy + import numpy as np from monai.config import IndexSelection, KeysCollection diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 7803930f8e..0dfaf04304 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -15,11 +15,12 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ +from copy import deepcopy from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch -from copy import deepcopy + from monai.config import KeysCollection from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.croppad.array import CenterSpatialCrop @@ -941,7 +942,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key in self.keys: self.append_applied_transforms(d, key) return d - angle=self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z), + angle = (self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z),) rotator = Rotate( angle=angle, keep_size=self.keep_size, diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 738b718ce0..d130b4b0e7 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -225,7 +225,9 @@ def append_applied_transforms(self, data: dict, key: Hashable, extra_args: Optio # If this is the first, create list if key_transform not in data: data[key_transform] = [] - data[key_transform].append({"class": type(self), "init_args": self.get_input_args(key), "extra_info": extra_args}) + data[key_transform].append( + {"class": type(self), "init_args": self.get_input_args(key), "extra_info": extra_args} + ) # If class is randomizable, store whether the transform was actually performed (based on `prob`) if isinstance(self, Randomizable): data[key_transform][-1]["do_transform"] = self._do_transform diff --git a/tests/test_inverse.py b/tests/test_inverse.py index d9562d6459..0d0b7d6cec 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -15,9 +15,10 @@ from typing import TYPE_CHECKING import numpy as np + +from monai.data import create_test_image_2d # , create_test_image_3d from monai.data import CacheDataset -from monai.data import create_test_image_2d #, create_test_image_3d -from monai.transforms import AddChanneld, Compose, Rotated, RandRotated, SpatialPad, SpatialPadd +from monai.transforms import AddChanneld, Compose, RandRotated, Rotated, SpatialPad, SpatialPadd from 
monai.transforms.transform import InvertibleTransform from monai.utils import Method, optional_import @@ -33,74 +34,84 @@ plt, has_matplotlib = optional_import("matplotlib.pyplot") TEST_SPATIALS = [] -TEST_SPATIALS.append([ - "Spatial 1d", - {"image": np.arange(0, 10).reshape(1, 10)}, +TEST_SPATIALS.append( [ - SpatialPadd("image", spatial_size=[15]), - SpatialPadd("image", spatial_size=[21], method=Method.END), - SpatialPadd("image", spatial_size=[24]), - ], - True, -]) - -TEST_SPATIALS.append([ - "Spatial 2d", - {"image": np.arange(0, 10 * 9).reshape(1, 10, 9)}, + "Spatial 1d", + {"image": np.arange(0, 10).reshape(1, 10)}, + [ + SpatialPadd("image", spatial_size=[15]), + SpatialPadd("image", spatial_size=[21], method=Method.END), + SpatialPadd("image", spatial_size=[24]), + ], + True, + ] +) + +TEST_SPATIALS.append( [ - SpatialPadd("image", spatial_size=[11, 12]), - SpatialPadd("image", spatial_size=[12, 21]), - SpatialPadd("image", spatial_size=[14, 25], method=Method.END), - ], - True, -]) - -TEST_SPATIALS.append([ - "Spatial 3d", - {"image": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8)}, + "Spatial 2d", + {"image": np.arange(0, 10 * 9).reshape(1, 10, 9)}, + [ + SpatialPadd("image", spatial_size=[11, 12]), + SpatialPadd("image", spatial_size=[12, 21]), + SpatialPadd("image", spatial_size=[14, 25], method=Method.END), + ], + True, + ] +) + +TEST_SPATIALS.append( [ - SpatialPadd("image", spatial_size=[55, 50, 45]), - ], - True -]) + "Spatial 3d", + {"image": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8)}, + [ + SpatialPadd("image", spatial_size=[55, 50, 45]), + ], + True, + ] +) TEST_COMPOSES = [] -TEST_COMPOSES.append([ - "Compose 2d", - { - "image": np.arange(0, 10 * 9).reshape(1, 10, 9), - "label": np.arange(0, 10 * 9).reshape(1, 10, 9), - "other": np.arange(0, 10 * 9).reshape(1, 10, 9), - }, +TEST_COMPOSES.append( [ - Compose( - [ - SpatialPadd(["image", "label"], spatial_size=[15, 12]), - SpatialPadd(["label"], spatial_size=[21, 32]), - SpatialPadd(["image"], spatial_size=[55, 50]), - ] - ) - ], - True, -]) -TEST_COMPOSES.append([ - "Compose 3d", - { - "image": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8), - "label": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8), - "other": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8), - }, + "Compose 2d", + { + "image": np.arange(0, 10 * 9).reshape(1, 10, 9), + "label": np.arange(0, 10 * 9).reshape(1, 10, 9), + "other": np.arange(0, 10 * 9).reshape(1, 10, 9), + }, + [ + Compose( + [ + SpatialPadd(["image", "label"], spatial_size=[15, 12]), + SpatialPadd(["label"], spatial_size=[21, 32]), + SpatialPadd(["image"], spatial_size=[55, 50]), + ] + ) + ], + True, + ] +) +TEST_COMPOSES.append( [ - Compose( - [ - SpatialPadd(["image", "label"], spatial_size=[15, 12, 4]), - SpatialPadd(["label"], spatial_size=[21, 32, 1]), - SpatialPadd(["image"], spatial_size=[55, 50, 45]), - ] - ) - ], - True, -]) + "Compose 3d", + { + "image": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8), + "label": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8), + "other": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8), + }, + [ + Compose( + [ + SpatialPadd(["image", "label"], spatial_size=[15, 12, 4]), + SpatialPadd(["label"], spatial_size=[21, 32, 1]), + SpatialPadd(["image"], spatial_size=[55, 50, 45]), + ] + ) + ], + True, + ] +) TEST_FAIL_0 = [ np.arange(0, 10).reshape(1, 10), @@ -114,7 +125,7 @@ # TODO: add 3D TEST_ROTATES = [] -for create_im in [create_test_image_2d]: #, partial(create_test_image_3d, 100)]: +for create_im in [create_test_image_2d]: # , 
partial(create_test_image_3d, 100)]: for keep_size in [True, False]: for align_corners in [False, True]: im, _ = create_im(100, 100) @@ -167,8 +178,8 @@ def plot_im(orig, fwd_bck, fwd): fig.colorbar(im_show, ax=ax) plt.show() -class TestInverse(unittest.TestCase): +class TestInverse(unittest.TestCase): def check_inverse(self, keys, orig_d, fwd_bck_d, unmodified_d, lossless): for key in keys: orig = orig_d[key] @@ -205,10 +216,7 @@ def test_inverse(self, desc, data, transforms, lossless): for i, t in enumerate(reversed(transforms)): if isinstance(t, InvertibleTransform): fwd_bck = t.inverse(fwd_bck) - self.check_inverse( - data.keys(), forwards[- i - 2], fwd_bck, - forwards[-1], lossless - ) + self.check_inverse(data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], lossless) # @parameterized.expand(TESTS_FAIL) def test_fail(self, data, transform, _): @@ -229,10 +237,7 @@ def test_w_data_loader(self, desc, data, transforms, lossless): for _ in range(num_epochs): for data_fwd in dataset: data_fwd_bck = transform.inverse(data_fwd) - self.check_inverse( - data.keys(), data, data_fwd_bck, - data_fwd, lossless - ) + self.check_inverse(data.keys(), data, data_fwd_bck, data_fwd, lossless) if __name__ == "__main__": From 5bd4d1cfd5e0b840abb10d0f1b78dd3ed5892df6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 29 Jan 2021 16:28:47 +0000 Subject: [PATCH 18/80] add constructors for Randomizable class Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/apps/datasets.py | 8 +++-- monai/transforms/intensity/array.py | 23 +++++---------- monai/transforms/intensity/dictionary.py | 37 +++++++++--------------- monai/transforms/spatial/array.py | 36 +++++++++-------------- monai/transforms/spatial/dictionary.py | 28 ++++++++---------- monai/transforms/transform.py | 4 +++ tests/test_inverse.py | 6 ++-- 7 files changed, 60 insertions(+), 82 deletions(-) diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index d8fd815ce9..f0416b8c4f 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -94,7 +94,9 @@ def __init__( data = self._generate_data_list(dataset_dir) if transform == (): transform = LoadImaged("image") - super().__init__(data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers) + CacheDataset.__init__( + self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers + ) def randomize(self, data: Optional[Any] = None) -> None: self.rann = self.R.random() @@ -275,7 +277,9 @@ def __init__( self._properties = load_decathlon_properties(os.path.join(dataset_dir, "dataset.json"), property_keys) if transform == (): transform = LoadImaged(["image", "label"]) - super().__init__(data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers) + CacheDataset.__init__( + self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers + ) def get_indices(self) -> np.ndarray: """ diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 205b719246..e8162d315c 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -59,10 +59,9 @@ class RandGaussianNoise(Randomizable, Transform): """ def __init__(self, prob: float = 0.1, mean: Union[Sequence[float], float] = 0.0, std: float = 0.1) -> None: - self.prob = prob + Randomizable.__init__(self, prob) self.mean = mean self.std = std - self._do_transform = False self._noise = None def randomize(self, 
im_shape: Sequence[int]) -> None: @@ -112,6 +111,7 @@ def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1 if single number, offset value is picked from (-offsets, offsets). prob: probability of shift. """ + Randomizable.__init__(self, prob) if isinstance(offsets, (int, float)): self.offsets = (min(-offsets, offsets), max(-offsets, offsets)) else: @@ -119,9 +119,6 @@ def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1 raise AssertionError("offsets should be a number or pair of numbers.") self.offsets = (min(offsets), max(offsets)) - self.prob = prob - self._do_transform = False - def randomize(self, data: Optional[Any] = None) -> None: self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) self._do_transform = self.R.random() < self.prob @@ -185,6 +182,7 @@ def __init__(self, factors: Union[Tuple[float, float], float], prob: float = 0.1 prob: probability of scale. """ + Randomizable.__init__(self, prob) if isinstance(factors, (int, float)): self.factors = (min(-factors, factors), max(-factors, factors)) else: @@ -192,9 +190,6 @@ def __init__(self, factors: Union[Tuple[float, float], float], prob: float = 0.1 raise AssertionError("factors should be a number or pair of numbers.") self.factors = (min(factors), max(factors)) - self.prob = prob - self._do_transform = False - def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) self._do_transform = self.R.random() < self.prob @@ -381,7 +376,7 @@ class RandAdjustContrast(Randomizable, Transform): """ def __init__(self, prob: float = 0.1, gamma: Union[Sequence[float], float] = (0.5, 4.5)) -> None: - self.prob = prob + Randomizable.__init__(self, prob) if isinstance(gamma, (int, float)): if gamma <= 0.5: @@ -394,7 +389,6 @@ def __init__(self, prob: float = 0.1, gamma: Union[Sequence[float], float] = (0. 
raise AssertionError("gamma should be a number or pair of numbers.") self.gamma = (min(gamma), max(gamma)) - self._do_transform = False self.gamma_value = None def randomize(self, data: Optional[Any] = None) -> None: @@ -669,12 +663,11 @@ def __init__( prob: float = 0.1, approx: str = "erf", ) -> None: + Randomizable.__init__(self, prob) self.sigma_x = sigma_x self.sigma_y = sigma_y self.sigma_z = sigma_z - self.prob = prob self.approx = approx - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: self._do_transform = self.R.random_sample() < self.prob @@ -772,6 +765,7 @@ def __init__( approx: str = "erf", prob: float = 0.1, ) -> None: + Randomizable.__init__(self, prob) self.sigma1_x = sigma1_x self.sigma1_y = sigma1_y self.sigma1_z = sigma1_z @@ -780,8 +774,6 @@ def __init__( self.sigma2_z = sigma2_z self.alpha = alpha self.approx = approx - self.prob = prob - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: self._do_transform = self.R.random_sample() < self.prob @@ -817,6 +809,7 @@ class RandHistogramShift(Randomizable, Transform): """ def __init__(self, num_control_points: Union[Tuple[int, int], int] = 10, prob: float = 0.1) -> None: + Randomizable.__init__(self, prob) if isinstance(num_control_points, int): if num_control_points <= 2: @@ -828,8 +821,6 @@ def __init__(self, num_control_points: Union[Tuple[int, int], int] = 10, prob: f if min(num_control_points) <= 2: raise AssertionError("num_control_points should be greater than or equal to 3") self.num_control_points = (min(num_control_points), max(num_control_points)) - self.prob = prob - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: self._do_transform = self.R.random() < self.prob diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 555e157db9..019988016e 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -108,11 +108,10 @@ class RandGaussianNoised(Randomizable, MapTransform): def __init__( self, keys: KeysCollection, prob: float = 0.1, mean: Union[Sequence[float], float] = 0.0, std: float = 0.1 ) -> None: - super().__init__(keys) - self.prob = prob + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) self.mean = ensure_tuple_size(mean, len(self.keys)) self.std = std - self._do_transform = False self._noise: Optional[np.ndarray] = None def randomize(self, im_shape: Sequence[int]) -> None: @@ -171,7 +170,8 @@ def __init__(self, keys: KeysCollection, offsets: Union[Tuple[float, float], flo prob: probability of rotating. (Default 0.1, with 10% probability it returns a rotated array.) """ - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) if isinstance(offsets, (int, float)): self.offsets = (min(-offsets, offsets), max(-offsets, offsets)) @@ -180,9 +180,6 @@ def __init__(self, keys: KeysCollection, offsets: Union[Tuple[float, float], flo raise AssertionError("offsets should be a number or pair of numbers.") self.offsets = (min(offsets), max(offsets)) - self.prob = prob - self._do_transform = False - def randomize(self, data: Optional[Any] = None) -> None: self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) self._do_transform = self.R.random() < self.prob @@ -243,7 +240,8 @@ def __init__(self, keys: KeysCollection, factors: Union[Tuple[float, float], flo (Default 0.1, with 10% probability it returns a rotated array.) 
""" - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) if isinstance(factors, (int, float)): self.factors = (min(-factors, factors), max(-factors, factors)) @@ -252,9 +250,6 @@ def __init__(self, keys: KeysCollection, factors: Union[Tuple[float, float], flo raise AssertionError("factors should be a number or pair of numbers.") self.factors = (min(factors), max(factors)) - self.prob = prob - self._do_transform = False - def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) self._do_transform = self.R.random() < self.prob @@ -398,8 +393,8 @@ class RandAdjustContrastd(Randomizable, MapTransform): def __init__( self, keys: KeysCollection, prob: float = 0.1, gamma: Union[Tuple[float, float], float] = (0.5, 4.5) ) -> None: - super().__init__(keys) - self.prob: float = prob + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) if isinstance(gamma, (int, float)): if gamma <= 0.5: @@ -412,7 +407,6 @@ def __init__( raise AssertionError("gamma should be a number or pair of numbers.") self.gamma = (min(gamma), max(gamma)) - self._do_transform = False self.gamma_value: Optional[float] = None def randomize(self, data: Optional[Any] = None) -> None: @@ -552,13 +546,12 @@ def __init__( approx: str = "erf", prob: float = 0.1, ) -> None: - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) self.sigma_x = sigma_x self.sigma_y = sigma_y self.sigma_z = sigma_z self.approx = approx - self.prob = prob - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: self._do_transform = self.R.random_sample() < self.prob @@ -650,7 +643,8 @@ def __init__( approx: str = "erf", prob: float = 0.1, ): - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) self.sigma1_x = sigma1_x self.sigma1_y = sigma1_y self.sigma1_z = sigma1_z @@ -659,8 +653,6 @@ def __init__( self.sigma2_z = sigma2_z self.alpha = alpha self.approx = approx - self.prob = prob - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: self._do_transform = self.R.random_sample() < self.prob @@ -704,7 +696,8 @@ class RandHistogramShiftd(Randomizable, MapTransform): def __init__( self, keys: KeysCollection, num_control_points: Union[Tuple[int, int], int] = 10, prob: float = 0.1 ) -> None: - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) if isinstance(num_control_points, int): if num_control_points <= 2: raise AssertionError("num_control_points should be greater than or equal to 3") @@ -715,8 +708,6 @@ def __init__( if min(num_control_points) <= 2: raise AssertionError("num_control_points should be greater than or equal to 3") self.num_control_points = (min(num_control_points), max(num_control_points)) - self.prob = prob - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: self._do_transform = self.R.random() < self.prob diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 3e1ded4e94..df5dfb6161 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -609,11 +609,10 @@ def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: Tuple[int, i spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. 
""" - self.prob = min(max(prob, 0.0), 1.0) + Randomizable.__init__(self, min(max(prob, 0.0), 1.0)) self.max_k = max_k self.spatial_axes = spatial_axes - self._do_transform = False self._rand_k = 0 def randomize(self, data: Optional[Any] = None) -> None: @@ -672,6 +671,7 @@ def __init__( align_corners: bool = False, dtype: Optional[np.dtype] = np.float64, ) -> None: + Randomizable.__init__(self, prob) self.range_x = ensure_tuple(range_x) if len(self.range_x) == 1: self.range_x = tuple(sorted([-self.range_x[0], self.range_x[0]])) @@ -682,14 +682,12 @@ def __init__( if len(self.range_z) == 1: self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]])) - self.prob = prob self.keep_size = keep_size self.mode: GridSampleMode = GridSampleMode(mode) self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) self.align_corners = align_corners self.dtype = dtype - self._do_transform = False self.x = 0.0 self.y = 0.0 self.z = 0.0 @@ -749,9 +747,8 @@ class RandFlip(Randomizable, Transform): """ def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: - self.prob = prob + Randomizable.__init__(self, min(max(prob, 0.0), 1.0)) self.flipper = Flip(spatial_axis=spatial_axis) - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: self._do_transform = self.R.random_sample() < self.prob @@ -806,17 +803,16 @@ def __init__( align_corners: Optional[bool] = None, keep_size: bool = True, ) -> None: + Randomizable.__init__(self, prob) self.min_zoom = ensure_tuple(min_zoom) self.max_zoom = ensure_tuple(max_zoom) if len(self.min_zoom) != len(self.max_zoom): raise AssertionError("min_zoom and max_zoom must have same length.") - self.prob = prob self.mode: InterpolateMode = InterpolateMode(mode) self.padding_mode: NumpyPadMode = NumpyPadMode(padding_mode) self.align_corners = align_corners self.keep_size = keep_size - self._do_transform = False self._zoom: Sequence[float] = [1.0] def randomize(self, data: Optional[Any] = None) -> None: @@ -1319,6 +1315,7 @@ def __init__( - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. """ + Randomizable.__init__(self, prob) self.rand_affine_grid = RandAffineGrid( rotate_range=rotate_range, @@ -1334,9 +1331,6 @@ def __init__( self.mode: GridSampleMode = GridSampleMode(mode) self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) - self.do_transform = False - self.prob = prob - def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None ) -> "RandAffine": @@ -1345,7 +1339,7 @@ def set_random_state( return self def randomize(self, data: Optional[Any] = None) -> None: - self.do_transform = self.R.rand() < self.prob + self._do_transform = self.R.rand() < self.prob self.rand_affine_grid.randomize() def __call__( @@ -1373,7 +1367,7 @@ def __call__( self.randomize() sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) - if self.do_transform: + if self._do_transform: grid = self.rand_affine_grid(spatial_size=sp_size) else: grid = create_grid(spatial_size=sp_size) @@ -1440,6 +1434,7 @@ def __init__( - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. 
""" + Randomizable.__init__(self, prob) self.deform_grid = RandDeformGrid( spacing=spacing, magnitude_range=magnitude_range, as_tensor_output=True, device=device ) @@ -1456,8 +1451,6 @@ def __init__( self.spatial_size = spatial_size self.mode: GridSampleMode = GridSampleMode(mode) self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) - self.prob = prob - self.do_transform = False def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None @@ -1468,7 +1461,7 @@ def set_random_state( return self def randomize(self, spatial_size: Sequence[int]) -> None: - self.do_transform = self.R.rand() < self.prob + self._do_transform = self.R.rand() < self.prob self.deform_grid.randomize(spatial_size) self.rand_affine_grid.randomize() @@ -1494,7 +1487,7 @@ def __call__( """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) self.randomize(spatial_size=sp_size) - if self.do_transform: + if self._do_transform: grid = self.deform_grid(spatial_size=sp_size) grid = self.rand_affine_grid(grid=grid) grid = torch.nn.functional.interpolate( # type: ignore @@ -1572,6 +1565,7 @@ def __init__( - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. """ + Randomizable.__init__(self, prob) self.rand_affine_grid = RandAffineGrid(rotate_range, shear_range, translate_range, scale_range, True, device) self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) @@ -1582,8 +1576,6 @@ def __init__( self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) self.device = device - self.prob = prob - self.do_transform = False self.rand_offset = None self.magnitude = 1.0 self.sigma = 1.0 @@ -1596,8 +1588,8 @@ def set_random_state( return self def randomize(self, grid_size: Sequence[int]) -> None: - self.do_transform = self.R.rand() < self.prob - if self.do_transform: + self._do_transform = self.R.rand() < self.prob + if self._do_transform: self.rand_offset = self.R.uniform(-1.0, 1.0, [3] + list(grid_size)).astype(np.float32) self.magnitude = self.R.uniform(self.magnitude_range[0], self.magnitude_range[1]) self.sigma = self.R.uniform(self.sigma_range[0], self.sigma_range[1]) @@ -1626,7 +1618,7 @@ def __call__( sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) self.randomize(grid_size=sp_size) grid = create_grid(spatial_size=sp_size) - if self.do_transform: + if self._do_transform: if self.rand_offset is None: raise AssertionError grid = torch.as_tensor(np.ascontiguousarray(grid), device=self.device) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 0dfaf04304..d793ca96c8 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -300,13 +300,12 @@ def __init__( spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. 
""" - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, min(max(prob, 0.0), 1.0)) - self.prob = min(max(prob, 0.0), 1.0) self.max_k = max_k self.spatial_axes = spatial_axes - self._do_transform = False self._rand_k = 0 def randomize(self, data: Optional[Any] = None) -> None: @@ -457,7 +456,7 @@ def __call__( self.randomize() sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:]) - if self.rand_affine.do_transform: + if self.rand_affine._do_transform: grid = self.rand_affine.rand_affine_grid(spatial_size=sp_size) else: grid = create_grid(spatial_size=sp_size) @@ -564,7 +563,7 @@ def __call__( sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, data[self.keys[0]].shape[1:]) self.randomize(spatial_size=sp_size) - if self.rand_2d_elastic.do_transform: + if self.rand_2d_elastic._do_transform: grid = self.rand_2d_elastic.deform_grid(spatial_size=sp_size) grid = self.rand_2d_elastic.rand_affine_grid(grid=grid) grid = torch.nn.functional.interpolate( # type: ignore @@ -685,7 +684,7 @@ def __call__( self.randomize(grid_size=sp_size) grid = create_grid(spatial_size=sp_size) - if self.rand_3d_elastic.do_transform: + if self.rand_3d_elastic._do_transform: device = self.rand_3d_elastic.device grid = torch.tensor(grid).to(device) gaussian = GaussianFilter(spatial_dims=3, sigma=self.rand_3d_elastic.sigma, truncated=3.0).to(device) @@ -742,11 +741,10 @@ def __init__( prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int], int]] = None, ) -> None: - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) self.spatial_axis = spatial_axis - self.prob = prob - self._do_transform = False self.flipper = Flip(spatial_axis=spatial_axis) def randomize(self, data: Optional[Any] = None) -> None: @@ -906,7 +904,8 @@ def __init__( align_corners: Union[Sequence[bool], bool] = False, dtype: Union[Sequence[Optional[np.dtype]], Optional[np.dtype]] = np.float64, ) -> None: - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) self.range_x = ensure_tuple(range_x) if len(self.range_x) == 1: self.range_x = tuple(sorted([-self.range_x[0], self.range_x[0]])) @@ -917,14 +916,12 @@ def __init__( if len(self.range_z) == 1: self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]])) - self.prob = prob self.keep_size = keep_size self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - self._do_transform = False self.x = 0.0 self.y = 0.0 self.z = 0.0 @@ -942,7 +939,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key in self.keys: self.append_applied_transforms(d, key) return d - angle = (self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z),) + angle: Sequence = (self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z),) rotator = Rotate( angle=angle, keep_size=self.keep_size, @@ -1096,19 +1093,18 @@ def __init__( align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, keep_size: bool = True, ) -> None: - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) self.min_zoom = ensure_tuple(min_zoom) self.max_zoom = ensure_tuple(max_zoom) if len(self.min_zoom) != len(self.max_zoom): raise AssertionError("min_zoom and max_zoom must have same length.") - self.prob = 
prob self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.keep_size = keep_size - self._do_transform = False self._zoom: Sequence[float] = [1.0] def randomize(self, data: Optional[Any] = None) -> None: diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index d130b4b0e7..256e751c74 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -43,6 +43,10 @@ def __call__(self, img): R: np.random.RandomState = np.random.RandomState() + def __init__(self, prob): + self._do_transform = False + self.prob = prob + def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None ) -> "Randomizable": diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 0d0b7d6cec..4f4491f7f2 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -142,13 +142,13 @@ TEST_ROTATES.append(TEST_ROTATE) for prob in [0, 1]: im, _ = create_im(100, 100) - angles = [random.uniform(np.pi / 6, np.pi) for _ in range(3)] + x, y, z = (random.uniform(np.pi / 6, np.pi) for _ in range(3)) TEST_ROTATE = [ f"RandRotate{im.ndim}d, prob={prob}", {"image": im}, [ AddChanneld("image"), - RandRotated("image", *angles, prob, True, "bilinear", "border", False), + RandRotated("image", x, y, z, prob, True, "bilinear", "border", False), ], False, ] @@ -191,7 +191,7 @@ def check_inverse(self, keys, orig_d, fwd_bck_d, unmodified_d, lossless): else: mean_diff = np.mean(np.abs(orig - fwd_bck)) print(f"Mean diff = {mean_diff}") - self.assertLess(mean_diff, 1.5e-2) + self.assertLess(mean_diff, 3e-2) except AssertionError: if has_matplotlib: plot_im(orig, fwd_bck, unmodified) From 50e723837f42c56b21c5a0044f28d216806c175a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 29 Jan 2021 16:38:39 +0000 Subject: [PATCH 19/80] testing Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 4f4491f7f2..f7d7d42d2c 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -43,7 +43,7 @@ SpatialPadd("image", spatial_size=[21], method=Method.END), SpatialPadd("image", spatial_size=[24]), ], - True, + 0.0, ] ) @@ -56,7 +56,7 @@ SpatialPadd("image", spatial_size=[12, 21]), SpatialPadd("image", spatial_size=[14, 25], method=Method.END), ], - True, + 0.0, ] ) @@ -67,7 +67,7 @@ [ SpatialPadd("image", spatial_size=[55, 50, 45]), ], - True, + 0.0, ] ) @@ -89,7 +89,7 @@ ] ) ], - True, + 0.0, ] ) TEST_COMPOSES.append( @@ -109,7 +109,7 @@ ] ) ], - True, + 0.0, ] ) @@ -120,7 +120,7 @@ SpatialPad(spatial_size=[15]), ] ), - True, + 0.0, ] # TODO: add 3D @@ -137,7 +137,7 @@ AddChanneld("image"), Rotated("image", angle, keep_size, "bilinear", "border", align_corners), ], - False, + 5e-2, ] TEST_ROTATES.append(TEST_ROTATE) for prob in [0, 1]: @@ -150,7 +150,7 @@ AddChanneld("image"), RandRotated("image", x, y, z, prob, True, "bilinear", "border", False), ], - False, + 5e-2, ] TEST_ROTATES.append(TEST_ROTATE) @@ -180,26 +180,24 @@ def plot_im(orig, fwd_bck, fwd): class TestInverse(unittest.TestCase): - def check_inverse(self, keys, orig_d, fwd_bck_d, unmodified_d, lossless): + def check_inverse(self, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff): for key in keys: orig = 
orig_d[key] fwd_bck = fwd_bck_d[key] unmodified = unmodified_d[key] try: - if lossless: - self.assertTrue(np.all(orig == fwd_bck)) - else: - mean_diff = np.mean(np.abs(orig - fwd_bck)) + mean_diff = np.mean(np.abs(orig - fwd_bck)) + if acceptable_diff > 0: print(f"Mean diff = {mean_diff}") - self.assertLess(mean_diff, 3e-2) + self.assertLessEqual(mean_diff, acceptable_diff) except AssertionError: if has_matplotlib: plot_im(orig, fwd_bck, unmodified) raise # @parameterized.expand(TESTS) - def test_inverse(self, desc, data, transforms, lossless): - print(f"testing: {desc} (lossless: {lossless})...") + def test_inverse(self, desc, data, transforms, acceptable_diff): + print(f"testing: {desc} (acceptable diff: {acceptable_diff})") forwards = [data.copy()] # Apply forwards @@ -216,7 +214,7 @@ def test_inverse(self, desc, data, transforms, lossless): for i, t in enumerate(reversed(transforms)): if isinstance(t, InvertibleTransform): fwd_bck = t.inverse(fwd_bck) - self.check_inverse(data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], lossless) + self.check_inverse(data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff) # @parameterized.expand(TESTS_FAIL) def test_fail(self, data, transform, _): @@ -225,7 +223,7 @@ def test_fail(self, data, transform, _): d = transform.inverse(d) # @parameterized.expand(TEST_COMPOSES) - def test_w_data_loader(self, desc, data, transforms, lossless): + def test_w_data_loader(self, desc, data, transforms, acceptable_diff): print(f"testing: {desc}...") transform = transforms[0] numel = 2 @@ -237,7 +235,7 @@ def test_w_data_loader(self, desc, data, transforms, lossless): for _ in range(num_epochs): for data_fwd in dataset: data_fwd_bck = transform.inverse(data_fwd) - self.check_inverse(data.keys(), data, data_fwd_bck, data_fwd, lossless) + self.check_inverse(data.keys(), data, data_fwd_bck, data_fwd, acceptable_diff) if __name__ == "__main__": From 9241cf75900bb33ecf1e79f06c1359cbdd7b9219 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 29 Jan 2021 17:47:39 +0000 Subject: [PATCH 20/80] start adding spatialcropd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 30 ++++- monai/transforms/transform.py | 25 +++- tests/test_inverse.py | 153 ++++++++++++------------- 3 files changed, 125 insertions(+), 83 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 3285e82523..1c8deee509 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -232,7 +232,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class SpatialCropd(MapTransform): +class SpatialCropd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.SpatialCrop`. 
     Either a spatial center and size must be provided, or alternatively if center and size
@@ -262,9 +262,37 @@ def __init__(
     def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
         d = dict(data)
         for key in self.keys:
+            self.append_applied_transforms(d, key, {"orig_size": d[key].shape})
             d[key] = self.cropper(d[key])
         return d
 
+    def get_input_args(self, key):
+        return {
+            "keys": key,
+            "roi_start": self.cropper.roi_start,
+            "roi_end": self.cropper.roi_end,
+        }
+
+    def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+        d = deepcopy(dict(data))
+
+        for key in self.keys:
+            transform = self.get_most_recent_transform(d, key)
+            # Create inverse transform
+            extra_info = transform["extra_info"]
+            orig_size = extra_info["orig_size"][1:]
+            raise NotImplementedError("TODO")
+            # im_shape = d[key].shape[1:] if self.padder.method == Method.SYMMETRIC else extra_info["orig_size"][1:]
+            # roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) / 2 for r, i in zip(roi_size, im_shape)]
+
+            # inverse_transform = SpatialCrop(roi_center, roi_size)
+            # # Apply inverse transform
+            # d[key] = inverse_transform(d[key])
+            # # Remove the applied transform
+            # self.remove_most_recent_transform(d, key)
+
+        return d
+
 
 class CenterSpatialCropd(MapTransform):
     """
diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py
index 256e751c74..619033ae31 100644
--- a/monai/transforms/transform.py
+++ b/monai/transforms/transform.py
@@ -236,11 +236,32 @@ def append_applied_transforms(self, data: dict, key: Hashable, extra_args: Optio
         if isinstance(self, Randomizable):
             data[key_transform][-1]["do_transform"] = self._do_transform
 
+    def check_transforms_match(self, transform: dict, key: Hashable) -> None:
+        explanation = "Should inverse most recently applied invertible transform first"
+        # Check transforms are of the same type.
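+        # Comparing both the stored class and the recorded init args guards against calling
+        # ``inverse`` for anything other than the most recently applied invertible transform;
+        # on a mismatch, fail loudly rather than silently undo the wrong transform.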
+ if transform["class"] != type(self): + raise RuntimeError(explanation) + + def check_dictionaries_match(dict1, dict2): + if dict1.keys() != dict2.keys(): + raise RuntimeError(explanation) + for k in dict1.keys(): + if dict1[k] != dict2[k]: + raise RuntimeError(explanation) + + t1 = transform["init_args"] + t2 = self.get_input_args(key) + + if t1.keys() != t2.keys(): + raise RuntimeError(explanation) + for k in t1.keys(): + if np.any(t1[k] != t2[k]): + raise RuntimeError(explanation) + def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: """Get most recent transform.""" transform = dict(data[str(key) + "_transforms"][-1]) - if transform["class"] != type(self) or transform["init_args"] != self.get_input_args(key): - raise RuntimeError("Should inverse most recently applied invertible transform first") + self.check_transforms_match(transform, key) return transform @staticmethod diff --git a/tests/test_inverse.py b/tests/test_inverse.py index f7d7d42d2c..9220d99539 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -18,8 +18,7 @@ from monai.data import create_test_image_2d # , create_test_image_3d from monai.data import CacheDataset -from monai.transforms import AddChanneld, Compose, RandRotated, Rotated, SpatialPad, SpatialPadd -from monai.transforms.transform import InvertibleTransform +from monai.transforms import InvertibleTransform, AddChanneld, Compose, RandRotated, Rotated, SpatialPad, SpatialPadd, SpatialCropd from monai.utils import Method, optional_import # from functools import partial @@ -34,92 +33,86 @@ plt, has_matplotlib = optional_import("matplotlib.pyplot") TEST_SPATIALS = [] -TEST_SPATIALS.append( +TEST_SPATIALS.append([ + "Spatial 1d", + {"image": np.arange(0, 10).reshape(1, 10)}, [ - "Spatial 1d", - {"image": np.arange(0, 10).reshape(1, 10)}, - [ - SpatialPadd("image", spatial_size=[15]), - SpatialPadd("image", spatial_size=[21], method=Method.END), - SpatialPadd("image", spatial_size=[24]), - ], - 0.0, - ] -) - -TEST_SPATIALS.append( - [ - "Spatial 2d", - {"image": np.arange(0, 10 * 9).reshape(1, 10, 9)}, - [ - SpatialPadd("image", spatial_size=[11, 12]), - SpatialPadd("image", spatial_size=[12, 21]), - SpatialPadd("image", spatial_size=[14, 25], method=Method.END), - ], - 0.0, - ] -) - -TEST_SPATIALS.append( + SpatialPadd("image", spatial_size=[15]), + SpatialPadd("image", spatial_size=[21], method=Method.END), + SpatialPadd("image", spatial_size=[24]), + ], + 0.0, +]) + +TEST_SPATIALS.append([ + "Spatial 2d", + {"image": np.arange(0, 10 * 9).reshape(1, 10, 9)}, [ - "Spatial 3d", - {"image": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8)}, - [ - SpatialPadd("image", spatial_size=[55, 50, 45]), - ], - 0.0, - ] -) + SpatialPadd("image", spatial_size=[11, 12]), + SpatialPadd("image", spatial_size=[12, 21]), + SpatialPadd("image", spatial_size=[14, 25], method=Method.END), + ], + 0.0, +]) + +TEST_SPATIALS.append([ + "Spatial 3d", + {"image": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8)}, + [SpatialPadd("image", spatial_size=[55, 50, 45])], + 0.0, +]) + +TEST_CROPS = [] +for im_size in [100, 101]: + for center in [im_size // 2, 40]: + TEST_CROPS.append([ + f"Spatial crop 2d, input size: {im_size, im_size + 1}, crop center: {center, center + 1}, crop size: {90, 91}", + {"image": create_test_image_2d(im_size, im_size + 1)[0]}, + [SpatialCropd("image", [center, center + 1], [90, 91])], + 0.0, + ]) TEST_COMPOSES = [] -TEST_COMPOSES.append( +TEST_COMPOSES.append([ + "Compose 2d", + { + "image": np.arange(0, 10 * 9).reshape(1, 10, 9), + "label": 
np.arange(0, 10 * 9).reshape(1, 10, 9), + "other": np.arange(0, 10 * 9).reshape(1, 10, 9), + }, [ - "Compose 2d", - { - "image": np.arange(0, 10 * 9).reshape(1, 10, 9), - "label": np.arange(0, 10 * 9).reshape(1, 10, 9), - "other": np.arange(0, 10 * 9).reshape(1, 10, 9), - }, - [ - Compose( - [ - SpatialPadd(["image", "label"], spatial_size=[15, 12]), - SpatialPadd(["label"], spatial_size=[21, 32]), - SpatialPadd(["image"], spatial_size=[55, 50]), - ] - ) - ], - 0.0, - ] -) -TEST_COMPOSES.append( + Compose( + [ + SpatialPadd(["image", "label"], spatial_size=[15, 12]), + SpatialPadd(["label"], spatial_size=[21, 32]), + SpatialPadd(["image"], spatial_size=[55, 50]), + ] + ) + ], + 0.0, +]) +TEST_COMPOSES.append([ + "Compose 3d", + { + "image": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8), + "label": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8), + "other": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8), + }, [ - "Compose 3d", - { - "image": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8), - "label": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8), - "other": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8), - }, - [ - Compose( - [ - SpatialPadd(["image", "label"], spatial_size=[15, 12, 4]), - SpatialPadd(["label"], spatial_size=[21, 32, 1]), - SpatialPadd(["image"], spatial_size=[55, 50, 45]), - ] - ) - ], - 0.0, - ] -) + Compose( + [ + SpatialPadd(["image", "label"], spatial_size=[15, 12, 4]), + SpatialPadd(["label"], spatial_size=[21, 32, 1]), + SpatialPadd(["image"], spatial_size=[55, 50, 45]), + ] + ) + ], + 0.0, +]) TEST_FAIL_0 = [ np.arange(0, 10).reshape(1, 10), - Compose( - [ - SpatialPad(spatial_size=[15]), - ] - ), + Compose([SpatialPad(spatial_size=[15])]), 0.0, ] @@ -154,7 +147,7 @@ ] TEST_ROTATES.append(TEST_ROTATE) -TESTS = [*TEST_SPATIALS, *TEST_COMPOSES, *TEST_ROTATES] +TESTS = [*TEST_CROPS, *TEST_SPATIALS, *TEST_COMPOSES, *TEST_ROTATES] TESTS_DATALOADER = [*TEST_COMPOSES, *TEST_SPATIALS] TESTS_FAIL = [TEST_FAIL_0] From 46471a61f6588b4bdd2c04fb39da0e065393f87c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 1 Feb 2021 13:34:11 +0000 Subject: [PATCH 21/80] update tests Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/dataset.py | 21 ++- monai/transforms/croppad/dictionary.py | 12 +- monai/transforms/spatial/dictionary.py | 24 +-- monai/transforms/transform.py | 6 +- tests/test_inverse.py | 198 +++++++++---------------- 5 files changed, 107 insertions(+), 154 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index e67c7a2954..82f7289611 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -51,14 +51,16 @@ class Dataset(_TorchDataset): }, }, }] """ - def __init__(self, data: Sequence, transform: Optional[Callable] = None) -> None: + def __init__(self, data: Sequence, transform: Optional[Callable] = None, progress: bool = True) -> None: """ Args: data: input data to load and transform to generate dataset for model. transform: a callable data transform on input data. + progress: whether to display a progress bar. """ self.data = data self.transform = transform + self.progress = progress def __len__(self) -> int: return len(self.data) @@ -115,6 +117,7 @@ def __init__( transform: Union[Sequence[Callable], Callable], cache_dir: Optional[Union[Path, str]] = None, hash_func: Callable[..., bytes] = pickle_hashing, + progress: bool = True, ) -> None: """ Args: @@ -129,10 +132,11 @@ def __init__( If the cache_dir doesn't exist, will automatically create it. 
hash_func: a callable to compute hash from data items to be cached. defaults to `monai.data.utils.pickle_hashing`. + progress: whether to display a progress bar. """ if not isinstance(transform, Compose): transform = Compose(transform) - super().__init__(data=data, transform=transform) + super().__init__(data=data, transform=transform, progress=progress) self.cache_dir = Path(cache_dir) if cache_dir is not None else None self.hash_func = hash_func if self.cache_dir is not None: @@ -345,7 +349,7 @@ def __init__( lmdb_kwargs: additional keyword arguments to the lmdb environment. for more details please visit: https://lmdb.readthedocs.io/en/release/#environment-class """ - super().__init__(data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func) + super().__init__(data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func, progress=progress) if not self.cache_dir: raise ValueError("cache_dir must be specified.") self.db_file = self.cache_dir / f"{db_name}.lmdb" @@ -354,14 +358,13 @@ def __init__( if not self.lmdb_kwargs.get("map_size", 0): self.lmdb_kwargs["map_size"] = 1024 ** 4 # default map_size self._read_env = None - self.progress = progress print(f"Accessing lmdb file: {self.db_file.absolute()}.") def _fill_cache_start_reader(self): # create cache self.lmdb_kwargs["readonly"] = False env = lmdb.open(path=f"{self.db_file}", subdir=False, **self.lmdb_kwargs) - if not has_tqdm: + if self.progress and not has_tqdm: warnings.warn("LMDBDataset: tqdm is not installed. not displaying the caching progress.") for item in tqdm(self.data) if has_tqdm and self.progress else self.data: key = self.hash_func(item) @@ -470,6 +473,7 @@ def __init__( cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_workers: Optional[int] = None, + progress: bool = True, ) -> None: """ Args: @@ -481,10 +485,11 @@ def __init__( will take the minimum of (cache_num, data_length x cache_rate, data_length). num_workers: the number of worker processes to use. If num_workers is None then the number returned by os.cpu_count() is used. + progress: whether to display a progress bar. 
""" if not isinstance(transform, Compose): transform = Compose(transform) - super().__init__(data=data, transform=transform) + super().__init__(data=data, transform=transform, progress=progress) self.cache_num = min(int(cache_num), int(len(data) * cache_rate), len(data)) self.num_workers = num_workers if self.num_workers is not None: @@ -494,10 +499,10 @@ def __init__( def _fill_cache(self) -> List: if self.cache_num <= 0: return [] - if not has_tqdm: + if self.progress and not has_tqdm: warnings.warn("tqdm is not installed, will not show the caching progress bar.") with ThreadPool(self.num_workers) as p: - if has_tqdm: + if self.progress and has_tqdm: return list( tqdm( p.imap(self._load_cache_item, range(self.cache_num)), diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 1c8deee509..88c7d2f074 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -118,17 +118,17 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for key, m in zip(self.keys, self.mode): + for idx, (key, m) in enumerate(zip(self.keys, self.mode)): orig_size = d[key].shape d[key] = self.padder(d[key], mode=m) - self.append_applied_transforms(d, key, {"orig_size": orig_size}) + self.append_applied_transforms(d, key, idx, {"orig_size": orig_size}) return d - def get_input_args(self, key): + def get_input_args(self, key, idx = 0): return { "keys": key, "method": self.padder.method, - "mode": self.mode[0], + "mode": self.mode[idx], "spatial_size": self.padder.spatial_size, } @@ -262,11 +262,11 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: - self.append_applied_transforms(d, key, {"orig_size": d[key].shape}) + self.append_applied_transforms(d, key, extra_args={"orig_size": d[key].shape}) d[key] = self.cropper(d[key]) return d - def get_input_args(self, key): + def get_input_args(self, key, _): return { "keys": key, "roi_start": self.cropper.roi_start, diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index d793ca96c8..f9e21685cb 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -808,7 +808,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for idx, key in enumerate(self.keys): - self.append_applied_transforms(d, key, {"orig_size": d[key].shape[1:]}) + self.append_applied_transforms(d, key, idx, {"orig_size": d[key].shape[1:]}) d[key] = self.rotator( d[key], mode=self.mode[idx], @@ -818,15 +818,15 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def get_input_args(self, key) -> dict: + def get_input_args(self, key, idx = 0) -> dict: return { "keys": key, "angle": self.rotator.angle, "keep_size": self.rotator.keep_size, - "mode": self.mode, - "padding_mode": self.padding_mode, - "align_corners": self.align_corners, - "dtype": self.dtype, + "mode": self.mode[idx], + "padding_mode": self.padding_mode[idx], + "align_corners": self.align_corners[idx], + "dtype": self.dtype[idx], } def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: @@ -945,7 +945,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda keep_size=self.keep_size, ) for idx, key in enumerate(self.keys): - self.append_applied_transforms(d, key, 
{"angle": angle, "orig_size": d[key].shape[1:]}) + self.append_applied_transforms(d, key, idx, {"angle": angle, "orig_size": d[key].shape[1:]}) d[key] = rotator( d[key], mode=self.mode[idx], @@ -955,7 +955,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def get_input_args(self, key) -> dict: + def get_input_args(self, key, idx = 0) -> dict: return { "keys": key, "range_x": self.range_x, @@ -963,10 +963,10 @@ def get_input_args(self, key) -> dict: "range_z": self.range_z, "prob": self.prob, "keep_size": self.keep_size, - "mode": self.mode, - "padding_mode": self.padding_mode, - "align_corners": self.align_corners, - "dtype": self.dtype, + "mode": self.mode[idx], + "padding_mode": self.padding_mode[idx], + "align_corners": self.align_corners[idx], + "dtype": self.dtype[idx], } def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 619033ae31..08fef85542 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -223,14 +223,14 @@ class InvertibleTransform(ABC): first out for the inverted transforms. """ - def append_applied_transforms(self, data: dict, key: Hashable, extra_args: Optional[dict] = None) -> None: + def append_applied_transforms(self, data: dict, key: Hashable, idx: int = 0, extra_args: Optional[dict] = None) -> None: """Append to list of applied transforms for that key.""" key_transform = str(key) + "_transforms" # If this is the first, create list if key_transform not in data: data[key_transform] = [] data[key_transform].append( - {"class": type(self), "init_args": self.get_input_args(key), "extra_info": extra_args} + {"class": type(self), "init_args": self.get_input_args(key, idx), "extra_info": extra_args} ) # If class is randomizable, store whether the transform was actually performed (based on `prob`) if isinstance(self, Randomizable): @@ -269,7 +269,7 @@ def remove_most_recent_transform(data: dict, key: Hashable) -> None: """Remove most recent transform.""" data[str(key) + "_transforms"].pop() - def get_input_args(self, key) -> dict: + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: """Get input arguments for a single key.""" raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 9220d99539..6fde7921b1 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -14,14 +14,14 @@ import unittest from typing import TYPE_CHECKING -import numpy as np -from monai.data import create_test_image_2d # , create_test_image_3d +import numpy as np +from typing import List, Tuple +from monai.data import create_test_image_2d, create_test_image_3d from monai.data import CacheDataset -from monai.transforms import InvertibleTransform, AddChanneld, Compose, RandRotated, Rotated, SpatialPad, SpatialPadd, SpatialCropd +from monai.transforms import InvertibleTransform, AddChannel, Compose, RandRotated, Rotated, SpatialPad, SpatialPadd, SpatialCropd from monai.utils import Method, optional_import -# from functools import partial # from parameterized import parameterized @@ -32,125 +32,77 @@ else: plt, has_matplotlib = optional_import("matplotlib.pyplot") -TEST_SPATIALS = [] -TEST_SPATIALS.append([ - "Spatial 1d", - {"image": np.arange(0, 10).reshape(1, 10)}, - [ - SpatialPadd("image", spatial_size=[15]), - SpatialPadd("image", spatial_size=[21], method=Method.END), - SpatialPadd("image", 
spatial_size=[24]), - ], - 0.0, -]) -TEST_SPATIALS.append([ + +IM_2D = AddChannel()(create_test_image_2d(100, 101)[0]) +IM_3D = AddChannel()(create_test_image_3d(100, 101, 107)[0]) + +DATA_2D = {"image": IM_2D, "label": IM_2D, "other": IM_2D} +DATA_3D = {"image": IM_3D, "label": IM_3D, "other": IM_3D} +KEYS = ["image", "label"] + +TESTS: List[Tuple] = [] + +TESTS.append(( "Spatial 2d", - {"image": np.arange(0, 10 * 9).reshape(1, 10, 9)}, - [ - SpatialPadd("image", spatial_size=[11, 12]), - SpatialPadd("image", spatial_size=[12, 21]), - SpatialPadd("image", spatial_size=[14, 25], method=Method.END), - ], + DATA_2D, 0.0, -]) + SpatialPadd(KEYS, spatial_size=[111, 113], method=Method.END), + SpatialPadd(KEYS, spatial_size=[118, 117]), +)) -TEST_SPATIALS.append([ +TESTS.append(( "Spatial 3d", - {"image": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8)}, - [SpatialPadd("image", spatial_size=[55, 50, 45])], - 0.0, -]) - -TEST_CROPS = [] -for im_size in [100, 101]: - for center in [im_size // 2, 40]: - TEST_CROPS.append([ - f"Spatial crop 2d, input size: {im_size, im_size + 1}, crop center: {center, center + 1}, crop size: {90, 91}", - {"image": create_test_image_2d(im_size, im_size + 1)[0]}, - [SpatialCropd("image", [center, center + 1], [90, 91])], - 0.0, - ]) - -TEST_COMPOSES = [] -TEST_COMPOSES.append([ - "Compose 2d", - { - "image": np.arange(0, 10 * 9).reshape(1, 10, 9), - "label": np.arange(0, 10 * 9).reshape(1, 10, 9), - "other": np.arange(0, 10 * 9).reshape(1, 10, 9), - }, - [ - Compose( - [ - SpatialPadd(["image", "label"], spatial_size=[15, 12]), - SpatialPadd(["label"], spatial_size=[21, 32]), - SpatialPadd(["image"], spatial_size=[55, 50]), - ] - ) - ], - 0.0, -]) -TEST_COMPOSES.append([ - "Compose 3d", - { - "image": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8), - "label": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8), - "other": np.arange(0, 10 * 9 * 8).reshape(1, 10, 9, 8), - }, - [ - Compose( - [ - SpatialPadd(["image", "label"], spatial_size=[15, 12, 4]), - SpatialPadd(["label"], spatial_size=[21, 32, 1]), - SpatialPadd(["image"], spatial_size=[55, 50, 45]), - ] - ) - ], + DATA_3D, 0.0, -]) - -TEST_FAIL_0 = [ - np.arange(0, 10).reshape(1, 10), - Compose([SpatialPad(spatial_size=[15])]), - 0.0, -] - -# TODO: add 3D -TEST_ROTATES = [] -for create_im in [create_test_image_2d]: # , partial(create_test_image_3d, 100)]: + SpatialPadd(KEYS, spatial_size=[112, 113, 116]), +)) + +TESTS.append(( + "Rand, prob 0", + DATA_2D, + 0, + RandRotated(KEYS, prob=0), +)) + +# # TEST_CROPS = [] +# # for im_size in [100, 101]: +# # for center in [im_size // 2, 40]: +# # TEST_CROPS.append([ +# # f"Spatial crop 2d, input size: {im_size, im_size + 1}, crop center: {center, center + 1}, crop size: {90, 91}", +# # {"image": create_test_image_2d(im_size, im_size + 1)[0]}, +# # 0.0, +# # SpatialCropd(KEYS, [center, center + 1], [90, 91]), +# # ]) + +# # TODO: add 3D +for data in [DATA_2D]: # , DATA_3D]: + ndim = data['image'].ndim for keep_size in [True, False]: for align_corners in [False, True]: - im, _ = create_im(100, 100) angle = random.uniform(np.pi / 6, np.pi) - TEST_ROTATE = [ - f"Rotate{im.ndim}d, keep_size={keep_size}, align_corners={align_corners}", - {"image": im}, - [ - AddChanneld("image"), - Rotated("image", angle, keep_size, "bilinear", "border", align_corners), - ], + TESTS.append(( + f"Rotate{ndim}d, keep_size={keep_size}, align_corners={align_corners}", + data, 5e-2, - ] - TEST_ROTATES.append(TEST_ROTATE) - for prob in [0, 1]: - im, _ = create_im(100, 100) - x, y, z = (random.uniform(np.pi 
/ 6, np.pi) for _ in range(3)) - TEST_ROTATE = [ - f"RandRotate{im.ndim}d, prob={prob}", - {"image": im}, - [ - AddChanneld("image"), - RandRotated("image", x, y, z, prob, True, "bilinear", "border", False), - ], - 5e-2, - ] - TEST_ROTATES.append(TEST_ROTATE) - -TESTS = [*TEST_CROPS, *TEST_SPATIALS, *TEST_COMPOSES, *TEST_ROTATES] -TESTS_DATALOADER = [*TEST_COMPOSES, *TEST_SPATIALS] -TESTS_FAIL = [TEST_FAIL_0] + Rotated(KEYS, angle, keep_size, "bilinear", "border", align_corners), + )) + + x, y, z = (random.uniform(np.pi / 6, np.pi) for _ in range(3)) + TESTS.append(( + f"RandRotate{ndim}d", + data, + 5e-2, + RandRotated(KEYS, x, y, z, 1), + )) +TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] + +TESTS = [*TESTS, *TESTS_COMPOSE_X2] + + +TEST_FAIL_0 = (IM_2D, 0.0, Compose([SpatialPad(spatial_size=[101,103])])) +TESTS_FAIL = [TEST_FAIL_0] def plot_im(orig, fwd_bck, fwd): diff_orig_fwd_bck = orig - fwd_bck @@ -178,19 +130,17 @@ def check_inverse(self, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff): orig = orig_d[key] fwd_bck = fwd_bck_d[key] unmodified = unmodified_d[key] + mean_diff = np.mean(np.abs(orig - fwd_bck)) try: - mean_diff = np.mean(np.abs(orig - fwd_bck)) - if acceptable_diff > 0: - print(f"Mean diff = {mean_diff}") self.assertLessEqual(mean_diff, acceptable_diff) except AssertionError: if has_matplotlib: + print(f"Mean diff = {mean_diff} (expected <= {acceptable_diff})") plot_im(orig, fwd_bck, unmodified) - raise + raise # @parameterized.expand(TESTS) - def test_inverse(self, desc, data, transforms, acceptable_diff): - print(f"testing: {desc} (acceptable diff: {acceptable_diff})") + def test_inverse(self, _, data, acceptable_diff, *transforms): forwards = [data.copy()] # Apply forwards @@ -210,19 +160,18 @@ def test_inverse(self, desc, data, transforms, acceptable_diff): self.check_inverse(data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff) # @parameterized.expand(TESTS_FAIL) - def test_fail(self, data, transform, _): - d = transform(data) + def test_fail(self, data, _, *transform): + d = transform[0](data) with self.assertRaises(RuntimeError): - d = transform.inverse(d) + d = transform[0].inverse(d) # @parameterized.expand(TEST_COMPOSES) - def test_w_data_loader(self, desc, data, transforms, acceptable_diff): - print(f"testing: {desc}...") + def test_w_data_loader(self, _, data, acceptable_diff, *transforms): transform = transforms[0] numel = 2 test_data = [data for _ in range(numel)] - dataset = CacheDataset(data=test_data, transform=transform) + dataset = CacheDataset(test_data, transform, progress=False) self.assertEqual(len(dataset), 2) num_epochs = 2 for _ in range(num_epochs): @@ -236,7 +185,6 @@ def test_w_data_loader(self, desc, data, transforms, acceptable_diff): test = TestInverse() for t in TESTS: test.test_inverse(*t) - for t in TESTS_DATALOADER: test.test_w_data_loader(*t) for t in TESTS_FAIL: test.test_fail(*t) From 901f56bfe037889c15bf33f24c1518cf4abeb513 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 1 Feb 2021 15:24:55 +0000 Subject: [PATCH 22/80] crop Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 27 +++--- monai/transforms/spatial/dictionary.py | 4 +- tests/test_inverse.py | 109 +++++++++++++------------ 3 files changed, 75 insertions(+), 65 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 
88c7d2f074..c28f9df57e 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -124,7 +124,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.append_applied_transforms(d, key, idx, {"orig_size": orig_size}) return d - def get_input_args(self, key, idx = 0): + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: return { "keys": key, "method": self.padder.method, @@ -266,7 +266,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.cropper(d[key]) return d - def get_input_args(self, key, _): + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: return { "keys": key, "roi_start": self.cropper.roi_start, @@ -279,17 +279,18 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.keys: transform = self.get_most_recent_transform(d, key) # Create inverse transform - extra_info = transform["extra_info"] - orig_size = extra_info["orig_size"][1:] - raise NotImplementedError("TODO") - # im_shape = d[key].shape[1:] if self.padder.method == Method.SYMMETRIC else extra_info["orig_size"][1:] - # roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) / 2 for r, i in zip(roi_size, im_shape)] - - # inverse_transform = SpatialCrop(roi_center, roi_size) - # # Apply inverse transform - # d[key] = inverse_transform(d[key]) - # # Remove the applied transform - # self.remove_most_recent_transform(d, key) + orig_size = transform["extra_info"]["orig_size"][1:] + pad_to_start = transform["init_args"]["roi_start"] + pad_to_end = orig_size - transform["init_args"]["roi_end"] + # interweave mins and maxes + pad = np.empty((2 * len(orig_size)), dtype=np.int32) + pad[0::2] = pad_to_start + pad[1::2] = pad_to_end + inverse_transform = BorderPad(pad.tolist()) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) return d diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index f9e21685cb..73fe9d0708 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -818,7 +818,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def get_input_args(self, key, idx = 0) -> dict: + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: return { "keys": key, "angle": self.rotator.angle, @@ -955,7 +955,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def get_input_args(self, key, idx = 0) -> dict: + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: return { "keys": key, "range_x": self.range_x, diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 6fde7921b1..a9cddba658 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -43,62 +43,70 @@ TESTS: List[Tuple] = [] -TESTS.append(( - "Spatial 2d", - DATA_2D, - 0.0, - SpatialPadd(KEYS, spatial_size=[111, 113], method=Method.END), - SpatialPadd(KEYS, spatial_size=[118, 117]), -)) +# TESTS.append(( +# "Spatial 2d", +# DATA_2D, +# 0.0, +# SpatialPadd(KEYS, spatial_size=[111, 113], method=Method.END), +# SpatialPadd(KEYS, spatial_size=[118, 117]), +# )) + +# TESTS.append(( +# "Spatial 3d", +# DATA_3D, +# 0.0, +# SpatialPadd(KEYS, spatial_size=[112, 113, 116]), +# )) + +# TESTS.append(( +# "Rand, prob 0", +# DATA_2D, +# 0, +# RandRotated(KEYS, prob=0), +# )) + -TESTS.append(( - "Spatial 3d", - 
DATA_3D, - 0.0, - SpatialPadd(KEYS, spatial_size=[112, 113, 116]), -)) TESTS.append(( - "Rand, prob 0", + f"Spatial crop 2d", DATA_2D, - 0, - RandRotated(KEYS, prob=0), + 0.0, + SpatialCropd("image", [49, 51], [96, 97]), )) -# # TEST_CROPS = [] -# # for im_size in [100, 101]: -# # for center in [im_size // 2, 40]: -# # TEST_CROPS.append([ -# # f"Spatial crop 2d, input size: {im_size, im_size + 1}, crop center: {center, center + 1}, crop size: {90, 91}", -# # {"image": create_test_image_2d(im_size, im_size + 1)[0]}, -# # 0.0, -# # SpatialCropd(KEYS, [center, center + 1], [90, 91]), -# # ]) - -# # TODO: add 3D -for data in [DATA_2D]: # , DATA_3D]: - ndim = data['image'].ndim - for keep_size in [True, False]: - for align_corners in [False, True]: - angle = random.uniform(np.pi / 6, np.pi) - TESTS.append(( - f"Rotate{ndim}d, keep_size={keep_size}, align_corners={align_corners}", - data, - 5e-2, - Rotated(KEYS, angle, keep_size, "bilinear", "border", align_corners), - )) - - x, y, z = (random.uniform(np.pi / 6, np.pi) for _ in range(3)) - TESTS.append(( - f"RandRotate{ndim}d", - data, - 5e-2, - RandRotated(KEYS, x, y, z, 1), - )) - -TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] - -TESTS = [*TESTS, *TESTS_COMPOSE_X2] +# for im_size in [100, 101]: +# for center in [im_size // 2, 40]: +# TESTS.append([ +# f"Spatial crop 2d, input size: {im_size, im_size + 1}, crop center: {center, center + 1}, crop size: {90, 91}", +# DATA_2D, +# 0.0, +# SpatialCropd(KEYS, [center, center + 1], [90, 91]), +# ]) + +# # # TODO: add 3D +# for data in [DATA_2D]: # , DATA_3D]: +# ndim = data['image'].ndim +# for keep_size in [True, False]: +# for align_corners in [False, True]: +# angle = random.uniform(np.pi / 6, np.pi) +# TESTS.append(( +# f"Rotate{ndim}d, keep_size={keep_size}, align_corners={align_corners}", +# data, +# 5e-2, +# Rotated(KEYS, angle, keep_size, "bilinear", "border", align_corners), +# )) + +# x, y, z = (random.uniform(np.pi / 6, np.pi) for _ in range(3)) +# TESTS.append(( +# f"RandRotate{ndim}d", +# data, +# 5e-2, +# RandRotated(KEYS, x, y, z, 1), +# )) + +# TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] + +TESTS = [*TESTS] #, *TESTS_COMPOSE_X2] TEST_FAIL_0 = (IM_2D, 0.0, Compose([SpatialPad(spatial_size=[101,103])])) @@ -131,6 +139,7 @@ def check_inverse(self, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff): fwd_bck = fwd_bck_d[key] unmodified = unmodified_d[key] mean_diff = np.mean(np.abs(orig - fwd_bck)) + plot_im(orig, fwd_bck, unmodified) try: self.assertLessEqual(mean_diff, acceptable_diff) except AssertionError: From e8981880eaaf73929dd42402df9d5d529ec95fc7 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 1 Feb 2021 15:32:00 +0000 Subject: [PATCH 23/80] crop finished Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 67 ++++++++++++++++++++----------------------- 1 file changed, 31 insertions(+), 36 deletions(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index a9cddba658..5b4d522525 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -43,45 +43,41 @@ TESTS: List[Tuple] = [] -# TESTS.append(( -# "Spatial 2d", -# DATA_2D, -# 0.0, -# SpatialPadd(KEYS, spatial_size=[111, 113], method=Method.END), -# SpatialPadd(KEYS, spatial_size=[118, 117]), -# )) - -# TESTS.append(( -# "Spatial 3d", -# DATA_3D, -# 0.0, -# SpatialPadd(KEYS, spatial_size=[112, 113, 116]), -# )) - -# 
TESTS.append(( -# "Rand, prob 0", -# DATA_2D, -# 0, -# RandRotated(KEYS, prob=0), -# )) +TESTS.append(( + "Spatial 2d", + DATA_2D, + 0.0, + SpatialPadd(KEYS, spatial_size=[111, 113], method=Method.END), + SpatialPadd(KEYS, spatial_size=[118, 117]), +)) +TESTS.append(( + "Spatial 3d", + DATA_3D, + 0.0, + SpatialPadd(KEYS, spatial_size=[112, 113, 116]), +)) +TESTS.append(( + "Rand, prob 0", + DATA_2D, + 0, + RandRotated(KEYS, prob=0), +)) TESTS.append(( - f"Spatial crop 2d", + "Spatial crop 2d", DATA_2D, - 0.0, - SpatialCropd("image", [49, 51], [96, 97]), + 2e-2, + SpatialCropd("image", [49, 51], [90, 89]), )) -# for im_size in [100, 101]: -# for center in [im_size // 2, 40]: -# TESTS.append([ -# f"Spatial crop 2d, input size: {im_size, im_size + 1}, crop center: {center, center + 1}, crop size: {90, 91}", -# DATA_2D, -# 0.0, -# SpatialCropd(KEYS, [center, center + 1], [90, 91]), -# ]) +TESTS.append(( + "Spatial crop 3d", + DATA_3D, + 2e-2, + SpatialCropd("image", [49, 51, 44], [90, 89, 93]), +)) # # # TODO: add 3D # for data in [DATA_2D]: # , DATA_3D]: @@ -104,12 +100,12 @@ # RandRotated(KEYS, x, y, z, 1), # )) -# TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] +TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] -TESTS = [*TESTS] #, *TESTS_COMPOSE_X2] +TESTS = [*TESTS, *TESTS_COMPOSE_X2] -TEST_FAIL_0 = (IM_2D, 0.0, Compose([SpatialPad(spatial_size=[101,103])])) +TEST_FAIL_0 = (IM_2D, 0.0, Compose([SpatialPad(spatial_size=[101, 103])])) TESTS_FAIL = [TEST_FAIL_0] def plot_im(orig, fwd_bck, fwd): @@ -139,7 +135,6 @@ def check_inverse(self, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff): fwd_bck = fwd_bck_d[key] unmodified = unmodified_d[key] mean_diff = np.mean(np.abs(orig - fwd_bck)) - plot_im(orig, fwd_bck, unmodified) try: self.assertLessEqual(mean_diff, acceptable_diff) except AssertionError: From 31ebc76bbc82283e19791c66f81ac951e61e9004 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 1 Feb 2021 17:56:12 +0000 Subject: [PATCH 24/80] RandSpatialCropd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 56 ++++++++++++++++++-- monai/transforms/spatial/dictionary.py | 8 +-- monai/transforms/transform.py | 10 +--- tests/test_inverse.py | 73 +++++++++++++++++--------- 4 files changed, 106 insertions(+), 41 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index c28f9df57e..373eff4d96 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -262,7 +262,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: - self.append_applied_transforms(d, key, extra_args={"orig_size": d[key].shape}) + self.append_applied_transforms(d, key) d[key] = self.cropper(d[key]) return d @@ -279,7 +279,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.keys: transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform["extra_info"]["orig_size"][1:] + orig_size = transform["orig_size"] pad_to_start = transform["init_args"]["roi_start"] pad_to_end = orig_size - transform["init_args"]["roi_end"] # interweave mins and maxes @@ -317,7 +317,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d -class 
RandSpatialCropd(Randomizable, MapTransform): +class RandSpatialCropd(Randomizable, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandSpatialCrop`. Crop image with random size or specific size ROI. It can crop at a random position as @@ -342,7 +342,9 @@ def __init__( random_center: bool = True, random_size: bool = True, ) -> None: - super().__init__(keys) + Randomizable.__init__(self, prob=1.0) + MapTransform.__init__(self, keys) + self._do_transform = True self.roi_size = roi_size self.random_center = random_center self.random_size = random_size @@ -356,13 +358,15 @@ def randomize(self, img_size: Sequence[int]) -> None: if self.random_center: valid_size = get_valid_patch_size(img_size, self._size) self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) + pass def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) self.randomize(d[self.keys[0]].shape[1:]) # image shape from the first data key if self._size is None: raise AssertionError - for key in self.keys: + for idx, key in enumerate(self.keys): + self.append_applied_transforms(d, key, idx, {"slices": self._slices}) if self.random_center: d[key] = d[key][self._slices] else: @@ -370,6 +374,48 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = cropper(d[key]) return d + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + return { + "keys": key, + "roi_size": self.roi_size, + "random_center": self.random_center, + "random_size": self.random_size, + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key in self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = transform["orig_size"] + random_center = transform["init_args"]["random_center"] + pad_to_start = np.empty((len(orig_size)), dtype=np.int32) + pad_to_end = np.empty((len(orig_size)), dtype=np.int32) + if random_center: + for i, _slice in enumerate(transform["extra_info"]["slices"][1:]): + pad_to_start[i] = _slice.start + pad_to_end[i] = orig_size[i] - _slice.stop + else: + current_size = d[key].shape[1:] + for i, (o_s, c_s) in enumerate(zip(orig_size, current_size)): + pad_to_start[i] = pad_to_end[i] = (o_s - c_s) / 2 + if o_s % 2 == 0 and c_s % 2 == 1: + pad_to_start[i] += 1 + elif o_s % 2 == 1 and c_s % 2 == 0: + pad_to_end[i] += 1 + # interweave mins and maxes + pad = np.empty((2 * len(orig_size)), dtype=np.int32) + pad[0::2] = pad_to_start + pad[1::2] = pad_to_end + inverse_transform = BorderPad(pad.tolist()) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + class RandSpatialCropSamplesd(Randomizable, MapTransform): """ diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 73fe9d0708..d620b22b5e 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -808,7 +808,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for idx, key in enumerate(self.keys): - self.append_applied_transforms(d, key, idx, {"orig_size": d[key].shape[1:]}) + self.append_applied_transforms(d, key, idx) d[key] = self.rotator( d[key], mode=self.mode[idx], @@ -849,7 +849,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar ) # 
If the keep_size==False, need to crop image if not transform["init_args"]["keep_size"]: - d[key] = CenterSpatialCrop(transform["extra_info"]["orig_size"])(d[key]) + d[key] = CenterSpatialCrop(transform["orig_size"])(d[key]) # Remove the applied transform self.remove_most_recent_transform(d, key) @@ -945,7 +945,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda keep_size=self.keep_size, ) for idx, key in enumerate(self.keys): - self.append_applied_transforms(d, key, idx, {"angle": angle, "orig_size": d[key].shape[1:]}) + self.append_applied_transforms(d, key, idx, {"angle": angle}) d[key] = rotator( d[key], mode=self.mode[idx], @@ -991,7 +991,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar ) # If the keep_size==False, need to crop image if not transform["init_args"]["keep_size"]: - d[key] = CenterSpatialCrop(transform["extra_info"]["orig_size"])(d[key]) + d[key] = CenterSpatialCrop(transform["orig_size"])(d[key]) # Remove the applied transform self.remove_most_recent_transform(d, key) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 08fef85542..7a97a148a1 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -230,7 +230,8 @@ def append_applied_transforms(self, data: dict, key: Hashable, idx: int = 0, ext if key_transform not in data: data[key_transform] = [] data[key_transform].append( - {"class": type(self), "init_args": self.get_input_args(key, idx), "extra_info": extra_args} + {"class": type(self), "init_args": self.get_input_args(key, idx), + "orig_size": data[key].shape[1:], "extra_info": extra_args} ) # If class is randomizable, store whether the transform was actually performed (based on `prob`) if isinstance(self, Randomizable): @@ -242,13 +243,6 @@ def check_transforms_match(self, transform: dict, key: Hashable) -> None: if transform["class"] != type(self): raise RuntimeError(explanation) - def check_dictionaries_match(dict1, dict2): - if dict1.keys() != dict2.keys(): - raise RuntimeError(explanation) - for k in dict1.keys(): - if dict1[k] != dict2[k]: - raise RuntimeError(explanation) - t1 = transform["init_args"] t2 = self.get_input_args(key) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 5b4d522525..a62675295a 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -19,7 +19,17 @@ from typing import List, Tuple from monai.data import create_test_image_2d, create_test_image_3d from monai.data import CacheDataset -from monai.transforms import InvertibleTransform, AddChannel, Compose, RandRotated, Rotated, SpatialPad, SpatialPadd, SpatialCropd +from monai.transforms import ( + InvertibleTransform, + AddChannel, + Compose, + RandRotated, + RandSpatialCropd, + Rotated, + SpatialPad, + SpatialPadd, + SpatialCropd +) from monai.utils import Method, optional_import # from parameterized import parameterized @@ -33,10 +43,11 @@ plt, has_matplotlib = optional_import("matplotlib.pyplot") - +IM_1D = AddChannel()(np.arange(0, 11)) IM_2D = AddChannel()(create_test_image_2d(100, 101)[0]) IM_3D = AddChannel()(create_test_image_3d(100, 101, 107)[0]) +DATA_1D = {"image": IM_1D, "label": IM_1D, "other": IM_1D} DATA_2D = {"image": IM_2D, "label": IM_2D, "other": IM_2D} DATA_3D = {"image": IM_3D, "label": IM_3D, "other": IM_3D} KEYS = ["image", "label"] @@ -69,36 +80,50 @@ "Spatial crop 2d", DATA_2D, 2e-2, - SpatialCropd("image", [49, 51], [90, 89]), + SpatialCropd(KEYS, [49, 51], [90, 89]), )) TESTS.append(( "Spatial crop 3d", DATA_3D, 2e-2, 
- SpatialCropd("image", [49, 51, 44], [90, 89, 93]), + SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]), +)) + +TESTS.append(( + "Rand spatial crop 2d", + DATA_2D, + 2e-2, + RandSpatialCropd(KEYS, [96, 93], True, False) +)) + +TESTS.append(( + "Rand spatial crop 3d", + DATA_3D, + 2e-2, + RandSpatialCropd(KEYS, [96, 93, 92], False, True) )) -# # # TODO: add 3D -# for data in [DATA_2D]: # , DATA_3D]: -# ndim = data['image'].ndim -# for keep_size in [True, False]: -# for align_corners in [False, True]: -# angle = random.uniform(np.pi / 6, np.pi) -# TESTS.append(( -# f"Rotate{ndim}d, keep_size={keep_size}, align_corners={align_corners}", -# data, -# 5e-2, -# Rotated(KEYS, angle, keep_size, "bilinear", "border", align_corners), -# )) - -# x, y, z = (random.uniform(np.pi / 6, np.pi) for _ in range(3)) -# TESTS.append(( -# f"RandRotate{ndim}d", -# data, -# 5e-2, -# RandRotated(KEYS, x, y, z, 1), -# )) +# TODO: add 3D +for data in [DATA_2D]: # , DATA_3D]: + ndim = data['image'].ndim + for keep_size in [True, False]: + for align_corners in [False, True]: + angle = random.uniform(np.pi / 6, np.pi) + TESTS.append(( + f"Rotate{ndim}d, keep_size={keep_size}, align_corners={align_corners}", + data, + 5e-2, + Rotated(KEYS, angle, keep_size, "bilinear", "border", align_corners), + )) + + x, y, z = (random.uniform(np.pi / 6, np.pi) for _ in range(3)) + TESTS.append(( + f"RandRotate{ndim}d", + data, + 5e-2, + RandRotated(KEYS, x, y, z, 1), + )) TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] From 40bc7887408e3606045ff29e6e53e683f3a161db Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 1 Feb 2021 18:34:39 +0000 Subject: [PATCH 25/80] BorderPadd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 34 +++++++++++++++++++++++++- tests/test_inverse.py | 26 ++++++++++++++++++-- 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 373eff4d96..11d16c086f 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -151,7 +151,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d -class BorderPadd(MapTransform): +class BorderPadd(MapTransform, InvertibleTransform): """ Pad the input data by adding specified borders to every dimension. Dictionary-based wrapper of :py:class:`monai.transforms.BorderPad`. @@ -192,9 +192,41 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in zip(self.keys, self.mode): + self.append_applied_transforms(d, key) d[key] = self.padder(d[key], mode=m) return d + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + return { + "keys": key, + "spatial_border": self.padder.spatial_border, + "mode": self.mode[idx], + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key in self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = np.array(transform["orig_size"]) + roi_start = np.array(transform["init_args"]["spatial_border"]) + # Need to convert single value to [min1,min2,...] + if roi_start.size == 1: + roi_start = np.full((len(orig_size)), roi_start) + # need to convert [min1,max1,min2,...] to [min1,min2,...] 
+ elif roi_start.size == 2 * orig_size.size: + roi_start = roi_start[::2] + roi_end = np.array(transform["orig_size"]) + roi_start + + inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + class DivisiblePadd(MapTransform): """ diff --git a/tests/test_inverse.py b/tests/test_inverse.py index a62675295a..db4b7b19fd 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -28,7 +28,8 @@ Rotated, SpatialPad, SpatialPadd, - SpatialCropd + SpatialCropd, + BorderPadd, ) from monai.utils import Method, optional_import @@ -43,7 +44,7 @@ plt, has_matplotlib = optional_import("matplotlib.pyplot") -IM_1D = AddChannel()(np.arange(0, 11)) +IM_1D = AddChannel()(np.arange(0, 10)) IM_2D = AddChannel()(create_test_image_2d(100, 101)[0]) IM_3D = AddChannel()(create_test_image_3d(100, 101, 107)[0]) @@ -104,6 +105,27 @@ RandSpatialCropd(KEYS, [96, 93, 92], False, True) )) +TESTS.append(( + "BorderPadd 2d", + DATA_2D, + 0, + BorderPadd(KEYS, [3, 7, 2, 5]), +)) + +TESTS.append(( + "BorderPadd 2d", + DATA_2D, + 0, + BorderPadd(KEYS, [3, 7]), +)) + +TESTS.append(( + "BorderPadd 3d", + DATA_3D, + 0, + BorderPadd(KEYS, [4]), +)) + # TODO: add 3D for data in [DATA_2D]: # , DATA_3D]: ndim = data['image'].ndim From 7405c363af8154e917da4b61640e33fa98eeef2f Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Feb 2021 10:14:35 +0000 Subject: [PATCH 26/80] update after git merge Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 2 +- monai/transforms/intensity/array.py | 2 +- monai/transforms/intensity/dictionary.py | 1 - monai/transforms/io/array.py | 2 +- monai/transforms/post/array.py | 2 +- monai/transforms/spatial/array.py | 2 +- monai/transforms/utility/array.py | 2 +- monai/transforms/utility/dictionary.py | 2 +- 8 files changed, 7 insertions(+), 8 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index b4444803a4..ef5e0019bd 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -20,7 +20,7 @@ from monai.config import IndexSelection from monai.data.utils import get_random_patch, get_valid_patch_size -from monai.transforms.compose import Randomizable, Transform +from monai.transforms.transform import Randomizable, Transform from monai.transforms.utils import ( generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 2ff9372d8d..82847749f3 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -22,7 +22,7 @@ from monai.config import DtypeLike from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter -from monai.transforms.compose import Randomizable, Transform +from monai.transforms.transform import Randomizable, Transform from monai.transforms.utils import rescale_array from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, dtype_torch_to_numpy, ensure_tuple_size diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 1470912b7c..43df15c9d7 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -22,7 +22,6 @@ import torch from monai.config import DtypeLike, KeysCollection 
-from monai.transforms.compose import MapTransform, Randomizable from monai.transforms.intensity.array import ( AdjustContrast, GaussianSharpen, diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 772c7cf74f..670654281f 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -19,7 +19,7 @@ from monai.config import DtypeLike from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader -from monai.transforms.compose import Transform +from monai.transforms.transform import Transform from monai.utils import ensure_tuple, optional_import nib, _ = optional_import("nibabel") diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 0c60b0cc89..8b4f71093b 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -21,7 +21,7 @@ import torch.nn.functional as F from monai.networks import one_hot -from monai.transforms.compose import Transform +from monai.transforms.transform import Transform from monai.transforms.utils import get_largest_connected_component_mask from monai.utils import ensure_tuple diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 0dc7276de8..6fdab71959 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -22,7 +22,7 @@ from monai.config import USE_COMPILED, DtypeLike from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull -from monai.transforms.compose import Randomizable, Transform +from monai.transforms.transform import Randomizable, Transform from monai.transforms.croppad.array import CenterSpatialCrop from monai.transforms.utils import ( create_control_grid, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index c0ae40de59..8b161a9223 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -21,7 +21,7 @@ import torch from monai.config import DtypeLike, NdarrayTensor -from monai.transforms.compose import Randomizable, Transform +from monai.transforms.transform import Randomizable, Transform from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices from monai.utils import ensure_tuple, min_version, optional_import diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 951c9dd459..d2b0aeb8b4 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -23,7 +23,7 @@ import torch from monai.config import DtypeLike, KeysCollection, NdarrayTensor -from monai.transforms.compose import MapTransform, Randomizable +from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.utility.array import ( AddChannel, AsChannelFirst, From 1ff5729e443f8e287db583c633b3bf71c976d0c6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Feb 2021 10:14:51 +0000 Subject: [PATCH 27/80] DivisiblePadd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 28 +++++++++++++++++++++++++- tests/test_inverse.py | 15 ++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index de88c9cb00..354c1e2a32 100644 --- a/monai/transforms/croppad/dictionary.py +++ 
b/monai/transforms/croppad/dictionary.py @@ -228,7 +228,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d -class DivisiblePadd(MapTransform): +class DivisiblePadd(MapTransform, InvertibleTransform): """ Pad the input data, so that the spatial sizes are divisible by `k`. Dictionary-based wrapper of :py:class:`monai.transforms.DivisiblePad`. @@ -260,9 +260,35 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in zip(self.keys, self.mode): + self.append_applied_transforms(d, key) d[key] = self.padder(d[key], mode=m) return d + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + return { + "keys": key, + "k": self.padder.k, + "mode": self.mode[idx], + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key in self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = np.array(transform["orig_size"]) + current_size = np.array(d[key].shape[1:]) + roi_start = np.floor((current_size - orig_size) / 2) + roi_end = orig_size + roi_start + inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + class SpatialCropd(MapTransform, InvertibleTransform): """ diff --git a/tests/test_inverse.py b/tests/test_inverse.py index db4b7b19fd..7cc370b391 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -30,6 +30,7 @@ SpatialPadd, SpatialCropd, BorderPadd, + DivisiblePadd, ) from monai.utils import Method, optional_import @@ -126,6 +127,20 @@ BorderPadd(KEYS, [4]), )) +TESTS.append(( + "DivisiblePadd 2d", + DATA_2D, + 0, + DivisiblePadd(KEYS, k=4), +)) + +TESTS.append(( + "DivisiblePadd 3d", + DATA_3D, + 0, + DivisiblePadd(KEYS, k=[4, 8, 11]), +)) + # TODO: add 3D for data in [DATA_2D]: # , DATA_3D]: ndim = data['image'].ndim From b73df2d6bed0113e3cde7b98af03a7b6aefcf9db Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Feb 2021 10:21:09 +0000 Subject: [PATCH 28/80] Flipd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 20 +++++++++++++++++++- tests/test_inverse.py | 8 ++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 5ee0d87cc4..d2d03a0aaf 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -699,7 +699,7 @@ def __call__( return d -class Flipd(MapTransform): +class Flipd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Flip`. 
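
Flipping is an involution: applying the same flip a second time restores the original array exactly. A minimal numpy sketch of that property (illustrative only, not part of the patch), which is what the Flipd.inverse added in the next hunk relies on:

    import numpy as np

    img = np.arange(24).reshape(1, 4, 6)       # (channel, height, width)
    flipped = np.flip(img, axis=(1, 2))        # forward: flip both spatial axes
    restored = np.flip(flipped, axis=(1, 2))   # inverse: flip again
    assert np.array_equal(img, restored)
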
@@ -718,9 +718,27 @@ def __init__(self, keys: KeysCollection, spatial_axis: Optional[Union[Sequence[i def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: + self.append_applied_transforms(d, key) d[key] = self.flipper(d[key]) return d + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + return { + "keys": key, + "spatial_axis": self.flipper.spatial_axis, + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.keys: + _ = self.get_most_recent_transform(d, key) + # Inverse is same as forward + d[key] = self.flipper(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + class RandFlipd(Randomizable, MapTransform): """ diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 7cc370b391..821b78df5f 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -31,6 +31,7 @@ SpatialCropd, BorderPadd, DivisiblePadd, + Flipd, ) from monai.utils import Method, optional_import @@ -141,6 +142,13 @@ DivisiblePadd(KEYS, k=[4, 8, 11]), )) +TESTS.append(( + "Flipd 3d", + DATA_3D, + 0, + Flipd(KEYS, [1, 2]), +)) + # TODO: add 3D for data in [DATA_2D]: # , DATA_3D]: ndim = data['image'].ndim From 667273acee4c3dfbb2622df93b553adcc6c05c71 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Feb 2021 11:19:46 +0000 Subject: [PATCH 29/80] start adding orientationd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 26 ++++++++++++++- tests/test_inverse.py | 45 ++++++++++++++++++-------- 2 files changed, 56 insertions(+), 15 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index d2d03a0aaf..a3f0c43e6d 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -194,7 +194,7 @@ def __call__( return d -class Orientationd(MapTransform): +class Orientationd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Orientation`. 
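
Inverting an orientation change only needs the affine recorded at call time: its axis codes can be recovered with nibabel and fed back into an Orientation transform, which is how the Orientationd inverse in this series re-orients data to its original layout. A small illustrative sketch (assumed example values):

    import numpy as np
    import nibabel as nib

    old_affine = np.diag([-1.0, -1.0, 1.0, 1.0])             # an LPS-like example affine
    orig_axcodes = nib.orientations.aff2axcodes(old_affine)
    print(orig_axcodes)                                       # ('L', 'P', 'S')
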
@@ -249,9 +249,33 @@ def __call__( for key in self.keys: meta_data = d[f"{key}_{self.meta_key_postfix}"] d[key], _, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) + self.append_applied_transforms(d, key, {"old_affine": meta_data["affine"], "new_affine": new_affine}) meta_data["affine"] = new_affine return d + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + return { + "keys": key, + "axcodes": self.ornt_transform.axcodes, + "as_closest_canonical": self.ornt_transform.as_closest_canonical, + "labels": self.ornt_transform.labels, + "meta_key_postfix": self.meta_key_postfix, + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.keys: + transform = self.get_most_recent_transform(d, key) + old_affine = transform["extra_info"]["old_affine"] + new_affine = transform["extra_info"]["new_affine"] + + # Inverse is same as forward + d[key] = self.flipper(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + class Rotate90d(MapTransform): """ diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 821b78df5f..1b5b5a3524 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -21,6 +21,7 @@ from monai.data import CacheDataset from monai.transforms import ( InvertibleTransform, + AddChanneld, AddChannel, Compose, RandRotated, @@ -32,8 +33,10 @@ BorderPadd, DivisiblePadd, Flipd, + LoadImaged, ) from monai.utils import Method, optional_import +from tests.utils import make_nifti_image # from parameterized import parameterized @@ -47,13 +50,14 @@ IM_1D = AddChannel()(np.arange(0, 10)) -IM_2D = AddChannel()(create_test_image_2d(100, 101)[0]) -IM_3D = AddChannel()(create_test_image_3d(100, 101, 107)[0]) +IM_2D_FNAME, SEG_2D_FNAME = [make_nifti_image(i) for i in create_test_image_2d(100, 101)] +IM_3D_FNAME, SEG_3D_FNAME = [make_nifti_image(i) for i in create_test_image_3d(100, 101, 107)] -DATA_1D = {"image": IM_1D, "label": IM_1D, "other": IM_1D} -DATA_2D = {"image": IM_2D, "label": IM_2D, "other": IM_2D} -DATA_3D = {"image": IM_3D, "label": IM_3D, "other": IM_3D} KEYS = ["image", "label"] +DATA_1D = {"image": IM_1D, "label": IM_1D, "other": IM_1D} +LOAD_IMS = Compose([LoadImaged(KEYS), AddChanneld(KEYS)]) +DATA_2D = LOAD_IMS({"image": IM_2D_FNAME, "label": SEG_2D_FNAME}) +DATA_3D = LOAD_IMS({"image": IM_3D_FNAME, "label": SEG_3D_FNAME}) TESTS: List[Tuple] = [] @@ -149,6 +153,14 @@ Flipd(KEYS, [1, 2]), )) +TESTS.append(( + "Flipd 3d", + DATA_3D, + 0, + Flipd(KEYS, [1, 2]), +)) + + # TODO: add 3D for data in [DATA_2D]: # , DATA_3D]: ndim = data['image'].ndim @@ -175,7 +187,7 @@ TESTS = [*TESTS, *TESTS_COMPOSE_X2] -TEST_FAIL_0 = (IM_2D, 0.0, Compose([SpatialPad(spatial_size=[101, 103])])) +TEST_FAIL_0 = (DATA_2D["image"], 0.0, Compose([SpatialPad(spatial_size=[101, 103])])) TESTS_FAIL = [TEST_FAIL_0] def plot_im(orig, fwd_bck, fwd): @@ -204,14 +216,19 @@ def check_inverse(self, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff): orig = orig_d[key] fwd_bck = fwd_bck_d[key] unmodified = unmodified_d[key] - mean_diff = np.mean(np.abs(orig - fwd_bck)) - try: - self.assertLessEqual(mean_diff, acceptable_diff) - except AssertionError: - if has_matplotlib: - print(f"Mean diff = {mean_diff} (expected <= {acceptable_diff})") - plot_im(orig, fwd_bck, unmodified) - raise + if isinstance(orig, dict): + self.assertEqual(orig.keys(), fwd_bck.keys()) + for a, b in zip(orig.values(), fwd_bck.values()): + self.assertTrue(np.all(a == b) or 
np.all(np.isnan(a) & np.isnan(b))) + else: + mean_diff = np.mean(np.abs(orig - fwd_bck)) + try: + self.assertLessEqual(mean_diff, acceptable_diff) + except AssertionError: + if has_matplotlib: + print(f"Mean diff = {mean_diff} (expected <= {acceptable_diff})") + plot_im(orig, fwd_bck, unmodified) + raise # @parameterized.expand(TESTS) def test_inverse(self, _, data, acceptable_diff, *transforms): From b1e077c707e7430d6ce095ea731e741ea1ae1eb9 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Feb 2021 11:48:36 +0000 Subject: [PATCH 30/80] tidy rotate3d Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 2 +- tests/test_inverse.py | 60 +++++++++++++------------- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index a3f0c43e6d..1852ff99ad 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -981,7 +981,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key in self.keys: self.append_applied_transforms(d, key) return d - angle: Sequence = (self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z),) + angle: Sequence = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z) rotator = Rotate( angle=angle, keep_size=self.keep_size, diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 1b5b5a3524..647f67d907 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -35,7 +35,7 @@ Flipd, LoadImaged, ) -from monai.utils import Method, optional_import +from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image # from parameterized import parameterized @@ -48,6 +48,8 @@ else: plt, has_matplotlib = optional_import("matplotlib.pyplot") +set_determinism(seed=0) + IM_1D = AddChannel()(np.arange(0, 10)) IM_2D_FNAME, SEG_2D_FNAME = [make_nifti_image(i) for i in create_test_image_2d(100, 101)] @@ -62,50 +64,50 @@ TESTS: List[Tuple] = [] TESTS.append(( - "Spatial 2d", + "SpatialPadd (x2) 2d", DATA_2D, 0.0, - SpatialPadd(KEYS, spatial_size=[111, 113], method=Method.END), + SpatialPadd(KEYS, spatial_size=[111, 113], method="end"), SpatialPadd(KEYS, spatial_size=[118, 117]), )) TESTS.append(( - "Spatial 3d", + "SpatialPadd 3d", DATA_3D, 0.0, SpatialPadd(KEYS, spatial_size=[112, 113, 116]), )) TESTS.append(( - "Rand, prob 0", + "RandRotated, prob 0", DATA_2D, 0, RandRotated(KEYS, prob=0), )) TESTS.append(( - "Spatial crop 2d", + "SpatialCropd 2d", DATA_2D, 2e-2, SpatialCropd(KEYS, [49, 51], [90, 89]), )) TESTS.append(( - "Spatial crop 3d", + "SpatialCropd 3d", DATA_3D, 2e-2, SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]), )) TESTS.append(( - "Rand spatial crop 2d", + "RandSpatialCropd 2d", DATA_2D, 2e-2, RandSpatialCropd(KEYS, [96, 93], True, False) )) TESTS.append(( - "Rand spatial crop 3d", + "RandSpatialCropd 3d", DATA_3D, 2e-2, RandSpatialCropd(KEYS, [96, 93, 92], False, True) @@ -160,27 +162,27 @@ Flipd(KEYS, [1, 2]), )) +TESTS.append(( + "Rotated 2d", + DATA_2D, + 6e-2, + Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False), +)) + +TESTS.append(( + "RandRotated 2d", + DATA_2D, + 6e-2, + RandRotated(KEYS, random.uniform(np.pi / 6, np.pi)), +)) -# TODO: add 3D -for data in [DATA_2D]: # , DATA_3D]: - ndim = data['image'].ndim - for keep_size in [True, False]: - for align_corners in [False, True]: - angle = 
random.uniform(np.pi / 6, np.pi) - TESTS.append(( - f"Rotate{ndim}d, keep_size={keep_size}, align_corners={align_corners}", - data, - 5e-2, - Rotated(KEYS, angle, keep_size, "bilinear", "border", align_corners), - )) - - x, y, z = (random.uniform(np.pi / 6, np.pi) for _ in range(3)) - TESTS.append(( - f"RandRotate{ndim}d", - data, - 5e-2, - RandRotated(KEYS, x, y, z, 1), - )) +# TODO: add 3D (can replace RandRotated 2d) +# TESTS.append(( +# "RandRotated 3d", +# DATA_3D, +# 5e-2, +# RandRotated(KEYS, *(random.uniform(np.pi / 6, np.pi) for _ in range(3)), 1), +# )) TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] From 898f2d0f8f93e9c2bb19a18efe3890874f02c10c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Feb 2021 14:49:27 +0000 Subject: [PATCH 31/80] Orientationd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 29 ++++++++++++++++++-------- monai/transforms/transform.py | 10 ++++----- tests/test_inverse.py | 19 ++++++++++------- tests/utils.py | 9 ++++++++ 4 files changed, 46 insertions(+), 21 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 1852ff99ad..c05710fd95 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -46,8 +46,11 @@ ensure_tuple, ensure_tuple_rep, fall_back_tuple, + optional_import, ) +nib, _ = optional_import("nibabel") + __all__ = [ "Spacingd", "Orientationd", @@ -247,10 +250,11 @@ def __call__( ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: d: Dict = dict(data) for key in self.keys: - meta_data = d[f"{key}_{self.meta_key_postfix}"] - d[key], _, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) - self.append_applied_transforms(d, key, {"old_affine": meta_data["affine"], "new_affine": new_affine}) - meta_data["affine"] = new_affine + meta_data_key = f"{key}_{self.meta_key_postfix}" + meta_data = d[meta_data_key] + d[key], old_affine, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) + self.append_applied_transforms(d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine}) + d[meta_data_key]["affine"] = new_affine return d def get_input_args(self, key: Hashable, idx: int = 0) -> dict: @@ -266,11 +270,18 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar d = deepcopy(dict(data)) for key in self.keys: transform = self.get_most_recent_transform(d, key) - old_affine = transform["extra_info"]["old_affine"] - new_affine = transform["extra_info"]["new_affine"] - - # Inverse is same as forward - d[key] = self.flipper(d[key]) + meta_data = d[transform["extra_info"]["meta_data_key"]] + orig_affine = transform["extra_info"]["old_affine"] + orig_axcodes = nib.orientations.aff2axcodes(orig_affine) + inverse_transform = Orientation( + axcodes=orig_axcodes, + as_closest_canonical=transform["init_args"]["as_closest_canonical"], + labels=transform["init_args"]["labels"], + ) + # Apply inverse + d[key], _, new_affine = inverse_transform(d[key], affine=meta_data["affine"]) + self.append_applied_transforms(d, key, extra_info={"old_affine": meta_data["affine"]}) + meta_data["affine"] = new_affine # Remove the applied transform self.remove_most_recent_transform(d, key) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 7a97a148a1..095e732ffd 100644 --- 
a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -223,16 +223,16 @@ class InvertibleTransform(ABC): first out for the inverted transforms. """ - def append_applied_transforms(self, data: dict, key: Hashable, idx: int = 0, extra_args: Optional[dict] = None) -> None: + def append_applied_transforms(self, data: dict, key: Hashable, idx: int = 0, extra_info: Optional[dict] = None) -> None: """Append to list of applied transforms for that key.""" key_transform = str(key) + "_transforms" # If this is the first, create list if key_transform not in data: data[key_transform] = [] - data[key_transform].append( - {"class": type(self), "init_args": self.get_input_args(key, idx), - "orig_size": data[key].shape[1:], "extra_info": extra_args} - ) + data[key_transform].append({ + "class": type(self), "init_args": self.get_input_args(key, idx), + "orig_size": data[key].shape[1:], "extra_info": extra_info + }) # If class is randomizable, store whether the transform was actually performed (based on `prob`) if isinstance(self, Randomizable): data[key_transform][-1]["do_transform"] = self._do_transform diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 647f67d907..549d5aa0a7 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -10,6 +10,7 @@ # limitations under the License. +from monai.transforms.spatial.dictionary import Orientationd import random import unittest from typing import TYPE_CHECKING @@ -36,7 +37,7 @@ LoadImaged, ) from monai.utils import optional_import, set_determinism -from tests.utils import make_nifti_image +from tests.utils import make_nifti_image, make_rand_affine # from parameterized import parameterized @@ -50,10 +51,11 @@ set_determinism(seed=0) +AFFINE = make_rand_affine() IM_1D = AddChannel()(np.arange(0, 10)) IM_2D_FNAME, SEG_2D_FNAME = [make_nifti_image(i) for i in create_test_image_2d(100, 101)] -IM_3D_FNAME, SEG_3D_FNAME = [make_nifti_image(i) for i in create_test_image_3d(100, 101, 107)] +IM_3D_FNAME, SEG_3D_FNAME = [make_nifti_image(i, AFFINE) for i in create_test_image_3d(100, 101, 107)] KEYS = ["image", "label"] DATA_1D = {"image": IM_1D, "label": IM_1D, "other": IM_1D} @@ -184,6 +186,13 @@ # RandRotated(KEYS, *(random.uniform(np.pi / 6, np.pi) for _ in range(3)), 1), # )) +TESTS.append(( + "Orientationd 3d", + DATA_3D, + 0, + Orientationd(KEYS, 'RAS'), +)) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = [*TESTS, *TESTS_COMPOSE_X2] @@ -218,11 +227,7 @@ def check_inverse(self, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff): orig = orig_d[key] fwd_bck = fwd_bck_d[key] unmodified = unmodified_d[key] - if isinstance(orig, dict): - self.assertEqual(orig.keys(), fwd_bck.keys()) - for a, b in zip(orig.values(), fwd_bck.values()): - self.assertTrue(np.all(a == b) or np.all(np.isnan(a) & np.isnan(b))) - else: + if isinstance(orig, np.ndarray): mean_diff = np.mean(np.abs(orig - fwd_bck)) try: self.assertLessEqual(mean_diff, acceptable_diff) diff --git a/tests/utils.py b/tests/utils.py index ebc9bff99f..a99e2fc4b4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -150,6 +150,15 @@ def make_nifti_image(array, affine=None): os.close(temp_f) return image_name +def make_rand_affine(ndim: int = 3): + """Create random affine transformation (with values == -1, 0 or 1).""" + vals = np.random.choice([-1, 1], size=ndim) + positions = np.random.choice([0, 1, 2], size=ndim, replace=False) + af = np.zeros([ndim + 1, ndim + 1]) + af[ndim, ndim] = 1 + for i, (v, p) in enumerate(zip(vals, 
positions)): + af[i, p] = v + return af class DistTestCase(unittest.TestCase): """ From b8c1b8ae657cbbc41832096c11cfb7560bd88aef Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Feb 2021 14:56:59 +0000 Subject: [PATCH 32/80] Rotate90d Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 25 ++++++++++++++++++++++++- tests/test_inverse.py | 15 +++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index c05710fd95..a4a53df3e8 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -288,7 +288,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d -class Rotate90d(MapTransform): +class Rotate90d(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`. """ @@ -306,9 +306,32 @@ def __init__(self, keys: KeysCollection, k: int = 1, spatial_axes: Tuple[int, in def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: + self.append_applied_transforms(d, key) d[key] = self.rotator(d[key]) return d + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + return { + "keys": key, + "k": self.rotator.k, + "spatial_axes": self.rotator.spatial_axes, + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.keys: + transform = self.get_most_recent_transform(d, key) + spatial_axes = transform["init_args"]["spatial_axes"] + num_times_rotated = transform["init_args"]["k"] + num_times_to_rotate = 4 - num_times_rotated + inverse_transform = Rotate90(num_times_to_rotate, spatial_axes) + # Apply inverse + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + class RandRotate90d(Randomizable, MapTransform): """ diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 549d5aa0a7..41523592dd 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -35,6 +35,7 @@ DivisiblePadd, Flipd, LoadImaged, + Rotate90d, ) from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image, make_rand_affine @@ -193,6 +194,20 @@ Orientationd(KEYS, 'RAS'), )) +TESTS.append(( + "Rotate90d 2d", + DATA_2D, + 0, + Rotate90d(KEYS), +)) + +TESTS.append(( + "Rotate90d 3d", + DATA_3D, + 0, + Rotate90d(KEYS, k=2, spatial_axes=[1, 2]), +)) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = [*TESTS, *TESTS_COMPOSE_X2] From 3020fd4fd727e0eb5bb063b0bb947d3b954d4ec8 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Feb 2021 15:31:05 +0000 Subject: [PATCH 33/80] Zoomd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 39 ++++++++++++++++++++++++-- tests/test_inverse.py | 25 +++++++++++++++++ 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index a4a53df3e8..6607774faa 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -23,7 +23,7 @@ from monai.config import DtypeLike, KeysCollection from 
monai.networks.layers.simplelayers import GaussianFilter -from monai.transforms.croppad.array import CenterSpatialCrop +from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad from monai.transforms.spatial.array import ( Flip, Orientation, @@ -270,6 +270,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar d = deepcopy(dict(data)) for key in self.keys: transform = self.get_most_recent_transform(d, key) + # Create inverse transform meta_data = d[transform["extra_info"]["meta_data_key"]] orig_affine = transform["extra_info"]["old_affine"] orig_axcodes = nib.orientations.aff2axcodes(orig_affine) @@ -280,7 +281,6 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar ) # Apply inverse d[key], _, new_affine = inverse_transform(d[key], affine=meta_data["affine"]) - self.append_applied_transforms(d, key, extra_info={"old_affine": meta_data["affine"]}) meta_data["affine"] = new_affine # Remove the applied transform self.remove_most_recent_transform(d, key) @@ -321,6 +321,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar d = deepcopy(dict(data)) for key in self.keys: transform = self.get_most_recent_transform(d, key) + # Create inverse transform spatial_axes = transform["init_args"]["spatial_axes"] num_times_rotated = transform["init_args"]["k"] num_times_to_rotate = 4 - num_times_rotated @@ -1075,7 +1076,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d -class Zoomd(MapTransform): +class Zoomd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Zoom`. @@ -1117,6 +1118,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for idx, key in enumerate(self.keys): + self.append_applied_transforms(d, key) d[key] = self.zoomer( d[key], mode=self.mode[idx], @@ -1125,6 +1127,37 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + return { + "keys": key, + "zoom": self.zoomer.zoom, + "mode": self.mode[idx], + "padding_mode": self.padding_mode[idx], + "align_corners": self.align_corners[idx], + "keep_size": self.zoomer.keep_size, + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + init_args = transform["init_args"] + zoom = np.array(init_args["zoom"]) + inverse_transform = Zoom(zoom=1 / zoom, keep_size=init_args["keep_size"]) + # Apply inverse + d[key] = inverse_transform( + d[key], + mode=init_args["mode"], + padding_mode=init_args["padding_mode"], + align_corners=init_args["align_corners"], + ) + # Size might be out by 1 voxel so pad + d[key] = SpatialPad(transform["orig_size"])(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d class RandZoomd(Randomizable, MapTransform): """ diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 41523592dd..e742c15f35 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -36,6 +36,7 @@ Flipd, LoadImaged, Rotate90d, + Zoomd, ) from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image, make_rand_affine @@ -208,11 +209,33 @@ Rotate90d(KEYS, k=2, spatial_axes=[1, 2]), )) +TESTS.append(( + "Zoomd 1d", + DATA_1D, + 0, + 
Zoomd(KEYS, zoom=2, keep_size=False), +)) + +TESTS.append(( + "Zoomd 2d", + DATA_2D, + 8e-2, + Zoomd(KEYS, zoom=0.5), +)) + +TESTS.append(( + "Zoomd 3d", + DATA_3D, + 2e-2, + Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False), +)) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = [*TESTS, *TESTS_COMPOSE_X2] +# Should fail because uses an array transform (SpatialPad), as opposed to dictionary TEST_FAIL_0 = (DATA_2D["image"], 0.0, Compose([SpatialPad(spatial_size=[101, 103])])) TESTS_FAIL = [TEST_FAIL_0] @@ -251,6 +274,8 @@ def check_inverse(self, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff): print(f"Mean diff = {mean_diff} (expected <= {acceptable_diff})") plot_im(orig, fwd_bck, unmodified) raise + else: + np.testing.assert_equal(orig, fwd_bck) # @parameterized.expand(TESTS) def test_inverse(self, _, data, acceptable_diff, *transforms): From 9034d7b22a6ab08d1ef1d0147b2ee638ed742618 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Feb 2021 16:53:54 +0000 Subject: [PATCH 34/80] CenterSpatialCropd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 33 +++++++++++++++++++++++++- monai/transforms/transform.py | 5 ++-- tests/test_inverse.py | 25 ++++++++++++++----- 3 files changed, 54 insertions(+), 9 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 354c1e2a32..02310a48e0 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -353,7 +353,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d -class CenterSpatialCropd(MapTransform): +class CenterSpatialCropd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.CenterSpatialCrop`. 
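
The inverse of a center crop is a border pad that restores half of the removed voxels on each side, shifting the odd voxel when the original and cropped sizes have different parity. A standalone numpy sketch of that arithmetic (illustrative values; it mirrors the logic added for CenterSpatialCropd.inverse in the next hunk):

    import numpy as np

    orig_size = np.array([101, 100])
    roi_size = np.array([95, 97])
    pad_to_start = np.floor((orig_size - roi_size) / 2)
    # parity adjustment: if original size is even and cropped size is odd, pad one more at the start
    pad_to_start[np.logical_and(orig_size % 2 == 0, roi_size % 2 == 1)] += 1
    pad_to_end = orig_size - roi_size - pad_to_start
    pad = np.empty(2 * len(orig_size), dtype=np.int32)
    pad[0::2] = pad_to_start                 # interleave as [start_1, end_1, start_2, end_2]
    pad[1::2] = pad_to_end
    print(pad.tolist())                      # [3, 3, 2, 1]
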
@@ -371,9 +371,40 @@ def __init__(self, keys: KeysCollection, roi_size: Union[Sequence[int], int]) -> def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: + self.append_applied_transforms(d, key) d[key] = self.cropper(d[key]) + # cropper will modify `roi_size` key, so update it + self.get_most_recent_transform(d, key, False)["init_args"]["roi_size"] = self.cropper.roi_size return d + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + return { + "keys": key, + "roi_size": self.cropper.roi_size, + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key in self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = np.array(transform["orig_size"]) + current_size = np.array(transform["init_args"]["roi_size"]) + pad_to_start = np.floor((orig_size - current_size) / 2) + # in each direction, if original size is even and current size is odd, += 1 + pad_to_start[np.logical_and(orig_size % 2 == 0, current_size % 2 == 1)] += 1 + pad_to_end = orig_size - current_size - pad_to_start + pad = np.empty((2 * len(orig_size)), dtype=np.int32) + pad[0::2] = pad_to_start + pad[1::2] = pad_to_end + inverse_transform = BorderPad(pad.tolist()) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d class RandSpatialCropd(Randomizable, MapTransform, InvertibleTransform): """ diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 095e732ffd..a2a2efe3b7 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -252,10 +252,11 @@ def check_transforms_match(self, transform: dict, key: Hashable) -> None: if np.any(t1[k] != t2[k]): raise RuntimeError(explanation) - def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: + def get_most_recent_transform(self, data: dict, key: Hashable, check: bool = True) -> dict: """Get most recent transform.""" transform = dict(data[str(key) + "_transforms"][-1]) - self.check_transforms_match(transform, key) + if check: + self.check_transforms_match(transform, key) return transform @staticmethod diff --git a/tests/test_inverse.py b/tests/test_inverse.py index e742c15f35..27b561bb23 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -37,6 +37,7 @@ LoadImaged, Rotate90d, Zoomd, + CenterSpatialCropd, ) from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image, make_rand_affine @@ -51,7 +52,7 @@ else: plt, has_matplotlib = optional_import("matplotlib.pyplot") -set_determinism(seed=0) +# set_determinism(seed=0) AFFINE = make_rand_affine() @@ -92,7 +93,7 @@ TESTS.append(( "SpatialCropd 2d", DATA_2D, - 2e-2, + 3e-2, SpatialCropd(KEYS, [49, 51], [90, 89]), )) @@ -106,7 +107,7 @@ TESTS.append(( "RandSpatialCropd 2d", DATA_2D, - 2e-2, + 5e-2, RandSpatialCropd(KEYS, [96, 93], True, False) )) @@ -226,10 +227,24 @@ TESTS.append(( "Zoomd 3d", DATA_3D, - 2e-2, + 3e-2, Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False), )) +TESTS.append(( + "CenterSpatialCropd 2d", + DATA_2D, + 0, + CenterSpatialCropd(KEYS, roi_size=95), +)) + +TESTS.append(( + "CenterSpatialCropd 3d", + DATA_3D, + 0, + CenterSpatialCropd(KEYS, roi_size=[95, 97, 98]), +)) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = [*TESTS, *TESTS_COMPOSE_X2] @@ -274,8 +289,6 @@ 
def check_inverse(self, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff): print(f"Mean diff = {mean_diff} (expected <= {acceptable_diff})") plot_im(orig, fwd_bck, unmodified) raise - else: - np.testing.assert_equal(orig, fwd_bck) # @parameterized.expand(TESTS) def test_inverse(self, _, data, acceptable_diff, *transforms): From 69927eae10ad68aadfbeb51243e01509ad8e0cf8 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 2 Feb 2021 17:16:21 +0000 Subject: [PATCH 35/80] CropForegroundd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 35 +++++++++++++++++++++++++- tests/test_inverse.py | 30 +++++++++++++++++----- 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 02310a48e0..321f22f23e 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -558,7 +558,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n return [self.cropper(data) for _ in range(self.num_samples)] -class CropForegroundd(MapTransform): +class CropForegroundd(MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.CropForeground`. Crop only the foreground object of the expected images. @@ -610,9 +610,42 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[self.end_coord_key] = np.asarray(box_end) cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) for key in self.keys: + self.append_applied_transforms(d, key, extra_info={"box_start": box_start, "box_end": box_end}) d[key] = cropper(d[key]) return d + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + return { + "keys": key, + "source_key": self.source_key, + "select_fn": self.select_fn, + "channel_indices": self.channel_indices, + "margin": self.margin, + "start_coord_key": self.start_coord_key, + "end_coord_key": self.end_coord_key, + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = np.array(transform["orig_size"]) + extra_info = transform["extra_info"] + pad_to_start = np.array(extra_info["box_start"]) + pad_to_end = orig_size - np.array(extra_info["box_end"]) + # interweave mins and maxes + pad = np.empty((2 * len(orig_size)), dtype=np.int32) + pad[0::2] = pad_to_start + pad[1::2] = pad_to_end + inverse_transform = BorderPad(pad.tolist()) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + class RandWeightedCropd(Randomizable, MapTransform): """ diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 27b561bb23..7c25ad651e 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -38,6 +38,7 @@ Rotate90d, Zoomd, CenterSpatialCropd, + CropForegroundd, ) from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image, make_rand_affine @@ -100,7 +101,7 @@ TESTS.append(( "SpatialCropd 3d", DATA_3D, - 2e-2, + 4e-2, SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]), )) @@ -170,7 +171,7 @@ TESTS.append(( "Rotated 2d", DATA_2D, - 6e-2, + 8e-2, Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False), )) @@ -245,6 +246,20 @@ 
CenterSpatialCropd(KEYS, roi_size=[95, 97, 98]), )) +TESTS.append(( + "CropForegroundd 2d", + DATA_2D, + 0, + CropForegroundd(KEYS, source_key="label", margin=[2, 1]) +)) + +TESTS.append(( + "CropForegroundd 3d", + DATA_3D, + 0, + CropForegroundd(KEYS, source_key="label") +)) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = [*TESTS, *TESTS_COMPOSE_X2] @@ -275,7 +290,7 @@ def plot_im(orig, fwd_bck, fwd): class TestInverse(unittest.TestCase): - def check_inverse(self, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff): + def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff): for key in keys: orig = orig_d[key] fwd_bck = fwd_bck_d[key] @@ -285,13 +300,15 @@ def check_inverse(self, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff): try: self.assertLessEqual(mean_diff, acceptable_diff) except AssertionError: + print(f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff})") if has_matplotlib: - print(f"Mean diff = {mean_diff} (expected <= {acceptable_diff})") plot_im(orig, fwd_bck, unmodified) raise # @parameterized.expand(TESTS) def test_inverse(self, _, data, acceptable_diff, *transforms): + name = _ + forwards = [data.copy()] # Apply forwards @@ -308,7 +325,7 @@ def test_inverse(self, _, data, acceptable_diff, *transforms): for i, t in enumerate(reversed(transforms)): if isinstance(t, InvertibleTransform): fwd_bck = t.inverse(fwd_bck) - self.check_inverse(data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff) + self.check_inverse(name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff) # @parameterized.expand(TESTS_FAIL) def test_fail(self, data, _, *transform): @@ -318,6 +335,7 @@ def test_fail(self, data, _, *transform): # @parameterized.expand(TEST_COMPOSES) def test_w_data_loader(self, _, data, acceptable_diff, *transforms): + name = _ transform = transforms[0] numel = 2 test_data = [data for _ in range(numel)] @@ -328,7 +346,7 @@ def test_w_data_loader(self, _, data, acceptable_diff, *transforms): for _ in range(num_epochs): for data_fwd in dataset: data_fwd_bck = transform.inverse(data_fwd) - self.check_inverse(data.keys(), data, data_fwd_bck, data_fwd, acceptable_diff) + self.check_inverse(name, data.keys(), data, data_fwd_bck, data_fwd, acceptable_diff) if __name__ == "__main__": From 2f48ad78953f8684169d79abf0125bb7ca2d5af3 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 3 Feb 2021 11:08:40 +0000 Subject: [PATCH 36/80] Spacingd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 2 +- monai/transforms/spatial/dictionary.py | 48 ++++++++++++++++++++++++-- tests/test_inverse.py | 17 ++++++--- 3 files changed, 58 insertions(+), 9 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 6fdab71959..2794c38859 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -146,7 +146,7 @@ def __call__( ValueError: When ``pixdim`` is nonpositive. Returns: - data_array (resampled into `self.pixdim`), original pixdim, current pixdim. + data_array (resampled into `self.pixdim`), original affine, current affine. 
""" _dtype = dtype or self.dtype or data_array.dtype diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 6607774faa..58fa9da961 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -102,7 +102,7 @@ NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str] -class Spacingd(MapTransform): +class Spacingd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Spacing`. @@ -181,10 +181,11 @@ def __call__( ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: d: Dict = dict(data) for idx, key in enumerate(self.keys): - meta_data = d[f"{key}_{self.meta_key_postfix}"] + meta_data_key = f"{key}_{self.meta_key_postfix}" + meta_data = d[meta_data_key] # resample array of each corresponding key # using affine fetched from d[affine_key] - d[key], _, new_affine = self.spacing_transform( + d[key], old_affine, new_affine = self.spacing_transform( data_array=np.asarray(d[key]), affine=meta_data["affine"], mode=self.mode[idx], @@ -192,10 +193,51 @@ def __call__( align_corners=self.align_corners[idx], dtype=self.dtype[idx], ) + self.append_applied_transforms(d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine}) # set the 'affine' key meta_data["affine"] = new_affine return d + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + return { + "keys": key, + "pixdim": self.spacing_transform.pixdim, + "diagonal": self.spacing_transform.diagonal, + "mode": self.mode[idx], + "padding_mode": self.padding_mode[idx], + "align_corners": self.align_corners[idx], + "dtype": self.dtype[idx], + "meta_key_postfix": self.meta_key_postfix + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.keys: + transform = self.get_most_recent_transform(d, key) + init_args = transform["init_args"] + if init_args["diagonal"]: + raise RuntimeError( + "Spacingd:inverse not yet implemented for diagonal=True. " + + "Please raise a github issue if you need this feature") + # Create inverse transform + meta_data = d[transform["extra_info"]["meta_data_key"]] + orig_pixdim = np.sqrt(np.sum(np.square(transform["extra_info"]["old_affine"]), 0))[:-1] + inverse_transform = Spacing(orig_pixdim, diagonal=init_args["diagonal"]) + # Apply inverse + d[key], _, new_affine = inverse_transform( + data_array=np.asarray(d[key]), + affine=meta_data["affine"], + mode=init_args["mode"], + padding_mode=init_args["padding_mode"], + align_corners=init_args["align_corners"], + dtype=init_args["dtype"], + ) + meta_data["affine"] = new_affine + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + class Orientationd(MapTransform, InvertibleTransform): """ diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 7c25ad651e..8a612f9d50 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -9,13 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- -from monai.transforms.spatial.dictionary import Orientationd import random import unittest from typing import TYPE_CHECKING - import numpy as np from typing import List, Tuple from monai.data import create_test_image_2d, create_test_image_3d @@ -39,6 +36,8 @@ Zoomd, CenterSpatialCropd, CropForegroundd, + Orientationd, + Spacingd, ) from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image, make_rand_affine @@ -56,6 +55,7 @@ # set_determinism(seed=0) AFFINE = make_rand_affine() +AFFINE[0] *= 2 IM_1D = AddChannel()(np.arange(0, 10)) IM_2D_FNAME, SEG_2D_FNAME = [make_nifti_image(i) for i in create_test_image_2d(100, 101)] @@ -221,8 +221,8 @@ TESTS.append(( "Zoomd 2d", DATA_2D, - 8e-2, - Zoomd(KEYS, zoom=0.5), + 1e-1, + Zoomd(KEYS, zoom=0.9), )) TESTS.append(( @@ -260,6 +260,13 @@ CropForegroundd(KEYS, source_key="label") )) +TESTS.append(( + "Spacingd 3d", + DATA_3D, + 3e-2, + Spacingd(KEYS, [0.5, 0.7, 0.9], diagonal=False) +)) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = [*TESTS, *TESTS_COMPOSE_X2] From 43155c4e03761ad1fed27a35175a58ff1b17fb8b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 3 Feb 2021 11:26:19 +0000 Subject: [PATCH 37/80] Resized Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 2 +- monai/transforms/spatial/dictionary.py | 27 +++++++++++++++++++++++++- tests/test_inverse.py | 17 +++++++++++++++- 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 2794c38859..f130c9a832 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -308,7 +308,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: class Resize(Transform): """ - Resize the input image to given spatial size. + Resize the input image to given spatial size (with scaling, not cropping/padding). Implemented using :py:class:`torch.nn.functional.interpolate`. Args: diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 58fa9da961..6a31ce8748 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -425,7 +425,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. return d -class Resized(MapTransform): +class Resized(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Resize`. 
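
Resized is inverted by resizing back to the recorded original spatial size; the shape change is undone exactly, but interpolation is lossy, which is why the Resized test cases below use non-zero tolerances. A rough round-trip sketch with plain PyTorch (assumed example sizes):

    import torch
    import torch.nn.functional as F

    img = torch.rand(1, 1, 100, 101)                                         # NCHW
    down = F.interpolate(img, size=(50, 47), mode="bilinear", align_corners=False)
    back = F.interpolate(down, size=(100, 101), mode="bilinear", align_corners=False)
    print(float((img - back).abs().mean()))                                  # small, but not zero
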
@@ -461,9 +461,34 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for idx, key in enumerate(self.keys): + self.append_applied_transforms(d, key) d[key] = self.resizer(d[key], mode=self.mode[idx], align_corners=self.align_corners[idx]) return d + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + return { + "keys": key, + "spatial_size": self.resizer.spatial_size, + "mode": self.mode[idx], + "align_corners": self.align_corners[idx], + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.keys: + transform = self.get_most_recent_transform(d, key) + orig_size = transform["orig_size"] + mode = transform["init_args"]["mode"] + align_corners = transform["init_args"]["align_corners"] + # Create inverse transform + inverse_transform = Resize(orig_size, mode, align_corners) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + class RandAffined(Randomizable, MapTransform): """ diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 8a612f9d50..1f5e143868 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -38,6 +38,7 @@ CropForegroundd, Orientationd, Spacingd, + Resized, ) from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image, make_rand_affine @@ -221,7 +222,7 @@ TESTS.append(( "Zoomd 2d", DATA_2D, - 1e-1, + 2e-1, Zoomd(KEYS, zoom=0.9), )) @@ -267,6 +268,20 @@ Spacingd(KEYS, [0.5, 0.7, 0.9], diagonal=False) )) +TESTS.append(( + "Resized 2d", + DATA_2D, + 1e-1, + Resized(KEYS, [50, 47]) +)) + +TESTS.append(( + "Resized 3d", + DATA_3D, + 1e-2, + Resized(KEYS, [201, 150, 78]) +)) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = [*TESTS, *TESTS_COMPOSE_X2] From 1bf0015693064b2ebe1bcf75f01cc1fce8f557a7 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 3 Feb 2021 12:15:12 +0000 Subject: [PATCH 38/80] ResizeWithPadOrCropd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 27 +++++++++++++++++++++++++- tests/test_inverse.py | 12 ++++++++++-- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 321f22f23e..49fed069a6 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -819,7 +819,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n return results -class ResizeWithPadOrCropd(MapTransform): +class ResizeWithPadOrCropd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.ResizeWithPadOrCrop`. 
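The ResizeWithPadOrCropd inverse added below follows the same idea, re-applying the array transform with the recorded original size. A rough sketch with illustrative shapes; the spatial size round-trips exactly, although voxels removed by the crop come back only as padding values:

    import numpy as np
    from monai.transforms import ResizeWithPadOrCrop

    x = np.zeros((1, 201, 150, 78))
    y = ResizeWithPadOrCrop(spatial_size=[100, 160, 78])(x)  # crop dim 0, pad dim 1, keep dim 2
    z = ResizeWithPadOrCrop(spatial_size=x.shape[1:])(y)     # shape back to (1, 201, 150, 78)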
@@ -847,7 +847,32 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: + self.append_applied_transforms(d, key) d[key] = self.padcropper(d[key]) + # padder will modify `spatial_size`, so update it + self.get_most_recent_transform(d, key, False)["init_args"]["spatial_size"] = self.padcropper.padder.spatial_size + return d + + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + return { + "keys": key, + "spatial_size": self.padcropper.padder.spatial_size, + "mode": self.padcropper.padder.mode, + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = transform["orig_size"] + mode = transform["init_args"]["mode"] + inverse_transform = ResizeWithPadOrCrop(spatial_size=orig_size, mode=mode) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + return d diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 1f5e143868..4f865983d4 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -39,6 +39,7 @@ Orientationd, Spacingd, Resized, + ResizeWithPadOrCropd, ) from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image, make_rand_affine @@ -271,17 +272,24 @@ TESTS.append(( "Resized 2d", DATA_2D, - 1e-1, + 2e-1, Resized(KEYS, [50, 47]) )) TESTS.append(( "Resized 3d", DATA_3D, - 1e-2, + 5e-2, Resized(KEYS, [201, 150, 78]) )) +TESTS.append(( + "ResizeWithPadOrCropd 3d", + DATA_3D, + 0, + ResizeWithPadOrCropd(KEYS, [201, 150, 78]) +)) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = [*TESTS, *TESTS_COMPOSE_X2] From a2e13d48a72f9671f1fec4d8b24418888b2a437c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 3 Feb 2021 13:17:59 +0000 Subject: [PATCH 39/80] Rotated and RandRotated 3d Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 10 ++-- monai/transforms/spatial/array.py | 7 ++- monai/transforms/spatial/dictionary.py | 83 ++++++++++++++------------ monai/transforms/transform.py | 23 +++---- tests/test_inverse.py | 23 ++++--- 5 files changed, 78 insertions(+), 68 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 49fed069a6..6efcce197c 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -371,10 +371,9 @@ def __init__(self, keys: KeysCollection, roi_size: Union[Sequence[int], int]) -> def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: - self.append_applied_transforms(d, key) + orig_size = d[key].shape[1:] d[key] = self.cropper(d[key]) - # cropper will modify `roi_size` key, so update it - self.get_most_recent_transform(d, key, False)["init_args"]["roi_size"] = self.cropper.roi_size + self.append_applied_transforms(d, key, orig_size=orig_size) return d def get_input_args(self, key: Hashable, idx: int = 0) -> dict: @@ -847,10 +846,9 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: - self.append_applied_transforms(d, key) + orig_size = 
d[key].shape[1:] d[key] = self.padcropper(d[key]) - # padder will modify `spatial_size`, so update it - self.get_most_recent_transform(d, key, False)["init_args"]["spatial_size"] = self.padcropper.padder.spatial_size + self.append_applied_transforms(d, key, orig_size=orig_size) return d def get_input_args(self, key: Hashable, idx: int = 0) -> dict: diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index f130c9a832..ca7e3a223d 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -420,6 +420,7 @@ def __call__( padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, dtype: DtypeLike = None, + return_rotation_matrix: bool = False, ) -> np.ndarray: """ Args: @@ -437,6 +438,7 @@ def __call__( dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. + return_rotation_matrix: whether or not to return the applied rotation matrix. Raises: ValueError: When ``img`` spatially is not one of [2D, 3D]. @@ -473,7 +475,10 @@ def __call__( torch.as_tensor(np.ascontiguousarray(transform).astype(_dtype)), spatial_size=output_shape, ) - return np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + output_np = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + if return_rotation_matrix: + return output_np, transform + return output_np class Zoom(Transform): diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 6a31ce8748..c962e54e00 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -22,6 +22,7 @@ import torch from monai.config import DtypeLike, KeysCollection +from monai.networks.layers import AffineTransform from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad from monai.transforms.spatial.array import ( @@ -952,14 +953,16 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for idx, key in enumerate(self.keys): - self.append_applied_transforms(d, key, idx) - d[key] = self.rotator( + orig_size = d[key].shape[1:] + d[key], rot_mat = self.rotator( d[key], mode=self.mode[idx], padding_mode=self.padding_mode[idx], align_corners=self.align_corners[idx], dtype=self.dtype[idx], + return_rotation_matrix=True, ) + self.append_applied_transforms(d, key, idx, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) return d def get_input_args(self, key: Hashable, idx: int = 0) -> dict: @@ -975,26 +978,27 @@ def get_input_args(self, key: Hashable, idx: int = 0) -> dict: def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for idx, key in enumerate(self.keys): - if d[key][0].ndim != 2: - raise NotImplementedError("inverse rotation only currently implemented for 2D") + for key in self.keys: transform = self.get_most_recent_transform(d, key) + init_args = transform["init_args"] # Create inverse transform - in_angle = transform["init_args"]["angle"] - angle = [-a for a in in_angle] if isinstance(in_angle, Sequence) else -in_angle - inverse_rotator = Rotate(angle=angle, keep_size=transform["init_args"]["keep_size"]) - # Apply inverse transform - d[key] = inverse_rotator( - d[key], - mode=self.mode[idx], - padding_mode=self.padding_mode[idx], - 
align_corners=self.align_corners[idx], - dtype=self.dtype[idx], - ) - # If the keep_size==False, need to crop image - if not transform["init_args"]["keep_size"]: - d[key] = CenterSpatialCrop(transform["orig_size"])(d[key]) + fwd_rot_mat = transform["extra_info"]["rot_mat"] + inv_rot_mat = np.linalg.inv(fwd_rot_mat) + xform = AffineTransform( + normalized=False, + mode=init_args["mode"], + padding_mode=init_args["padding_mode"], + align_corners=init_args["align_corners"], + reverse_indexing=True, + ) + dtype = init_args["dtype"] + output = xform( + torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), + torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), + spatial_size=transform["orig_size"], + ) + d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # Remove the applied transform self.remove_most_recent_transform(d, key) @@ -1089,14 +1093,16 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda keep_size=self.keep_size, ) for idx, key in enumerate(self.keys): - self.append_applied_transforms(d, key, idx, {"angle": angle}) - d[key] = rotator( + orig_size = d[key].shape[1:] + d[key], rot_mat = rotator( d[key], mode=self.mode[idx], padding_mode=self.padding_mode[idx], align_corners=self.align_corners[idx], dtype=self.dtype[idx], + return_rotation_matrix=True, ) + self.append_applied_transforms(d, key, idx, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) return d def get_input_args(self, key: Hashable, idx: int = 0) -> dict: @@ -1115,28 +1121,29 @@ def get_input_args(self, key: Hashable, idx: int = 0) -> dict: def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for idx, key in enumerate(self.keys): - if d[key][0].ndim != 2: - raise NotImplementedError("inverse rotation only currently implemented for 2D") + for key in self.keys: transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform["do_transform"]: + init_args = transform["init_args"] # Create inverse transform - in_angle = transform["extra_info"]["angle"] - angle = [-a for a in in_angle] if isinstance(in_angle, Sequence) else -in_angle - inverse_rotator = Rotate(angle=angle, keep_size=transform["init_args"]["keep_size"]) - # Apply inverse transform - d[key] = inverse_rotator( - d[key], - mode=self.mode[idx], - padding_mode=self.padding_mode[idx], - align_corners=self.align_corners[idx], - dtype=self.dtype[idx], + fwd_rot_mat = transform["extra_info"]["rot_mat"] + inv_rot_mat = np.linalg.inv(fwd_rot_mat) + + xform = AffineTransform( + normalized=False, + mode=init_args["mode"], + padding_mode=init_args["padding_mode"], + align_corners=init_args["align_corners"], + reverse_indexing=True, ) - # If the keep_size==False, need to crop image - if not transform["init_args"]["keep_size"]: - d[key] = CenterSpatialCrop(transform["orig_size"])(d[key]) - + dtype = init_args["dtype"] + output = xform( + torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), + torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), + spatial_size=transform["orig_size"], + ) + d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # Remove the applied transform self.remove_most_recent_transform(d, key) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index a2a2efe3b7..5f6a6f100f 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ 
-223,19 +223,21 @@ class InvertibleTransform(ABC): first out for the inverted transforms. """ - def append_applied_transforms(self, data: dict, key: Hashable, idx: int = 0, extra_info: Optional[dict] = None) -> None: + def append_applied_transforms(self, data: dict, key: Hashable, idx: int = 0, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None) -> None: """Append to list of applied transforms for that key.""" key_transform = str(key) + "_transforms" + info = {} + info["class"] = type(self) + info["init_args"] = self.get_input_args(key, idx) + info["orig_size"] = orig_size or data[key].shape[1:] + info["extra_info"] = extra_info + # If class is randomizable, store whether the transform was actually performed (based on `prob`) + if isinstance(self, Randomizable): + info["do_transform"] = self._do_transform # If this is the first, create list if key_transform not in data: data[key_transform] = [] - data[key_transform].append({ - "class": type(self), "init_args": self.get_input_args(key, idx), - "orig_size": data[key].shape[1:], "extra_info": extra_info - }) - # If class is randomizable, store whether the transform was actually performed (based on `prob`) - if isinstance(self, Randomizable): - data[key_transform][-1]["do_transform"] = self._do_transform + data[key_transform].append(info) def check_transforms_match(self, transform: dict, key: Hashable) -> None: explanation = "Should inverse most recently applied invertible transform first" @@ -252,11 +254,10 @@ def check_transforms_match(self, transform: dict, key: Hashable) -> None: if np.any(t1[k] != t2[k]): raise RuntimeError(explanation) - def get_most_recent_transform(self, data: dict, key: Hashable, check: bool = True) -> dict: + def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: """Get most recent transform.""" transform = dict(data[str(key) + "_transforms"][-1]) - if check: - self.check_transforms_match(transform, key) + self.check_transforms_match(transform, key) return transform @staticmethod diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 4f865983d4..9cff6d0c62 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -178,19 +178,18 @@ )) TESTS.append(( - "RandRotated 2d", - DATA_2D, - 6e-2, - RandRotated(KEYS, random.uniform(np.pi / 6, np.pi)), + "Rotated 3d", + DATA_3D, + 5e-2, + Rotated(KEYS, [random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1), )) -# TODO: add 3D (can replace RandRotated 2d) -# TESTS.append(( -# "RandRotated 3d", -# DATA_3D, -# 5e-2, -# RandRotated(KEYS, *(random.uniform(np.pi / 6, np.pi) for _ in range(3)), 1), -# )) +TESTS.append(( + "RandRotated 3d", + DATA_3D, + 5e-2, + RandRotated(KEYS, *(random.uniform(np.pi / 6, np.pi) for _ in range(3)), 1), +)) TESTS.append(( "Orientationd 3d", @@ -286,7 +285,7 @@ TESTS.append(( "ResizeWithPadOrCropd 3d", DATA_3D, - 0, + 1e-2, ResizeWithPadOrCropd(KEYS, [201, 150, 78]) )) From 749d92531e257e83f03256a338aaf8ec85b8bafe Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 3 Feb 2021 13:30:46 +0000 Subject: [PATCH 40/80] RandZoomd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 39 +++++++++++++++++++++++++- tests/test_inverse.py | 8 ++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index c962e54e00..9fefab5d5a 100644 --- a/monai/transforms/spatial/dictionary.py +++ 
b/monai/transforms/spatial/dictionary.py @@ -1233,7 +1233,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d -class RandZoomd(Randomizable, MapTransform): +class RandZoomd(Randomizable, MapTransform, InvertibleTransform): """ Dict-based version :py:class:`monai.transforms.RandZoom`. @@ -1299,6 +1299,8 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.randomize() d = dict(data) if not self._do_transform: + for key in self.keys: + self.append_applied_transforms(d, key) return d img_dims = data[self.keys[0]].ndim @@ -1310,6 +1312,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self._zoom = ensure_tuple_rep(self._zoom[0], img_dims - 2) + ensure_tuple(self._zoom[-1]) zoomer = Zoom(self._zoom, keep_size=self.keep_size) for idx, key in enumerate(self.keys): + self.append_applied_transforms(d, key, extra_info={"zoom": self._zoom}) d[key] = zoomer( d[key], mode=self.mode[idx], @@ -1318,6 +1321,40 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + return { + "keys": key, + "prob": self.prob, + "min_zoom": self.min_zoom, + "max_zoom": self.max_zoom, + "mode": self.mode[idx], + "padding_mode": self.padding_mode[idx], + "align_corners": self.align_corners[idx], + "keep_size": self.keep_size, + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + init_args = transform["init_args"] + zoom = np.array(transform["extra_info"]["zoom"]) + inverse_transform = Zoom(zoom=1 / zoom, keep_size=init_args["keep_size"]) + # Apply inverse + d[key] = inverse_transform( + d[key], + mode=init_args["mode"], + padding_mode=init_args["padding_mode"], + align_corners=init_args["align_corners"], + ) + # Size might be out by 1 voxel so pad + d[key] = SpatialPad(transform["orig_size"])(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + SpacingD = SpacingDict = Spacingd OrientationD = OrientationDict = Orientationd diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 9cff6d0c62..c771f362d8 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -40,6 +40,7 @@ Spacingd, Resized, ResizeWithPadOrCropd, + RandZoomd, ) from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image, make_rand_affine @@ -289,6 +290,13 @@ ResizeWithPadOrCropd(KEYS, [201, 150, 78]) )) +TESTS.append(( + "RandZoom 3d", + DATA_3D, + 5e-2, + RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [3, 4.2, 6.1], keep_size=False) +)) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = [*TESTS, *TESTS_COMPOSE_X2] From 944e4f124de5809b1a4a5dfd176fd3312de2105d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 3 Feb 2021 13:40:48 +0000 Subject: [PATCH 41/80] RandFlipd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 29 ++++++++++++++++++++++---- tests/test_inverse.py | 22 ++++++++++++------- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 9fefab5d5a..b22fa1d232 100644 --- a/monai/transforms/spatial/dictionary.py +++ 
b/monai/transforms/spatial/dictionary.py @@ -867,7 +867,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d -class RandFlipd(Randomizable, MapTransform): +class RandFlipd(Randomizable, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandFlip`. @@ -898,10 +898,31 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: self.randomize() d = dict(data) - if not self._do_transform: - return d for key in self.keys: - d[key] = self.flipper(d[key]) + if self._do_transform: + d[key] = self.flipper(d[key]) + self.append_applied_transforms(d, key) + + return d + + + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + return { + "keys": key, + "spatial_axis": self.flipper.spatial_axis, + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.keys: + transform = self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform["do_transform"]: + # Inverse is same as forward + d[key] = self.flipper(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + return d diff --git a/tests/test_inverse.py b/tests/test_inverse.py index c771f362d8..ed7fa8a97b 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -41,6 +41,7 @@ Resized, ResizeWithPadOrCropd, RandZoomd, + RandFlipd, ) from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image, make_rand_affine @@ -171,6 +172,13 @@ Flipd(KEYS, [1, 2]), )) +TESTS.append(( + "RandFlipd 3d", + DATA_3D, + 0, + RandFlipd(KEYS, 1, [1, 2]), +)) + TESTS.append(( "Rotated 2d", DATA_2D, @@ -234,6 +242,13 @@ Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False), )) +TESTS.append(( + "RandZoom 3d", + DATA_3D, + 5e-2, + RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [3, 4.2, 6.1], keep_size=False) +)) + TESTS.append(( "CenterSpatialCropd 2d", DATA_2D, @@ -290,13 +305,6 @@ ResizeWithPadOrCropd(KEYS, [201, 150, 78]) )) -TESTS.append(( - "RandZoom 3d", - DATA_3D, - 5e-2, - RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [3, 4.2, 6.1], keep_size=False) -)) - TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = [*TESTS, *TESTS_COMPOSE_X2] From 2b5450a400b7ceeb82c42875fcaf92b53768292e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 3 Feb 2021 14:17:07 +0000 Subject: [PATCH 42/80] RandRotate90d Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 31 +++++++++++++++++++++++++- tests/test_inverse.py | 8 +++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index b22fa1d232..6922ff9d2e 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -377,7 +377,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d -class RandRotate90d(Randomizable, MapTransform): +class RandRotate90d(Randomizable, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRotate90`. 
With probability `prob`, input arrays are rotated by 90 degrees @@ -417,12 +417,41 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np.ndarray]: self.randomize() if not self._do_transform: + for key in self.keys: + self.append_applied_transforms(data, key) return data rotator = Rotate90(self._rand_k, self.spatial_axes) d = dict(data) for key in self.keys: d[key] = rotator(d[key]) + self.append_applied_transforms(d, key, extra_info={"rand_k": self._rand_k}) + return d + + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + return { + "keys": key, + "prob": self.prob, + "max_k": self.max_k, + "spatial_axes": self.spatial_axes, + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.keys: + transform = self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform["do_transform"]: + # Create inverse transform + spatial_axes = transform["init_args"]["spatial_axes"] + num_times_rotated = transform["extra_info"]["rand_k"] + num_times_to_rotate = 4 - num_times_rotated + inverse_transform = Rotate90(num_times_to_rotate, spatial_axes) + # Apply inverse + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + return d diff --git a/tests/test_inverse.py b/tests/test_inverse.py index ed7fa8a97b..5b93deef74 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -42,6 +42,7 @@ ResizeWithPadOrCropd, RandZoomd, RandFlipd, + RandRotate90d, ) from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image, make_rand_affine @@ -221,6 +222,13 @@ Rotate90d(KEYS, k=2, spatial_axes=[1, 2]), )) +TESTS.append(( + "RandRotate90d 3d", + DATA_3D, + 0, + RandRotate90d(KEYS, prob=1, spatial_axes=[1, 2]), +)) + TESTS.append(( "Zoomd 1d", DATA_1D, From c798e2cce69dbcc6e5e9c2549f3cb5e765f06c3f Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 3 Feb 2021 15:16:47 +0000 Subject: [PATCH 43/80] RandAffined to call correct constructor Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 6922ff9d2e..fce62cc646 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -581,7 +581,7 @@ def __init__( - :py:class:`monai.transforms.compose.MapTransform` - :py:class:`RandAffineGrid` for the random affine parameters configurations. 
""" - super().__init__(keys) + MapTransform().__init__(self, keys) self.rand_affine = RandAffine( prob=prob, rotate_range=rotate_range, From 9b51e6b5a3078a4c1ff65c84bc91de06dea33cd4 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 3 Feb 2021 17:47:40 +0000 Subject: [PATCH 44/80] RandAffined (not finished) Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 71 +++++++++++++++++++++++++- tests/test_inverse.py | 18 +++++-- 2 files changed, 83 insertions(+), 6 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index fce62cc646..21294163df 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -36,6 +36,7 @@ Rotate90, Spacing, Zoom, + AffineGrid, ) from monai.transforms.transform import InvertibleTransform, MapTransform, Randomizable from monai.transforms.utils import create_grid @@ -520,7 +521,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d -class RandAffined(Randomizable, MapTransform): +class RandAffined(Randomizable, MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.RandAffine`. """ @@ -581,7 +582,8 @@ def __init__( - :py:class:`monai.transforms.compose.MapTransform` - :py:class:`RandAffineGrid` for the random affine parameters configurations. """ - MapTransform().__init__(self, keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) self.rand_affine = RandAffine( prob=prob, rotate_range=rotate_range, @@ -604,6 +606,8 @@ def set_random_state( def randomize(self, data: Optional[Any] = None) -> None: self.rand_affine.randomize() + self.prob = self.rand_affine.prob + self._do_transform = self.rand_affine._do_transform def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] @@ -618,9 +622,72 @@ def __call__( grid = create_grid(spatial_size=sp_size) for idx, key in enumerate(self.keys): + rag = self.rand_affine.rand_affine_grid + extra_info = { + "rotate_params": rag.rotate_params, + "shear_params": rag.shear_params, + "translate_params": rag.translate_params, + "scale_params": rag.scale_params, + "orig_was_numpy": isinstance(d[key], np.ndarray), + } + # rotate_params, shear_params, translate_params, scale_params + self.append_applied_transforms(d, key, idx, extra_info=extra_info) d[key] = self.rand_affine.resampler(d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx]) return d + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + return { + "keys": key, + "spatial_size": self.rand_affine.spatial_size, + "prob": self.rand_affine.prob, + "rotate_range": self.rand_affine.rand_affine_grid.rotate_range, + "shear_range": self.rand_affine.rand_affine_grid.shear_range, + "translate_range": self.rand_affine.rand_affine_grid.translate_range, + "scale_range": self.rand_affine.rand_affine_grid.scale_range, + "mode": self.rand_affine.mode, + "padding_mode": self.rand_affine.padding_mode, + "as_tensor_output": self.rand_affine.resampler.as_tensor_output, + "device": self.rand_affine.resampler.device, + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key in self.keys: + transform = self.get_most_recent_transform(d, key) + extra_info = transform["extra_info"] + init_args = transform["init_args"] + orig_size = transform["orig_size"] + # Create inverse 
transform + if transform["do_transform"]: + rotate_params = - np.array(extra_info["rotate_params"]) + shear_params = - np.array(extra_info["shear_params"]) + translate_params = - np.array(extra_info["translate_params"]) + scale_params = 1 / np.array(extra_info["scale_params"]) + if np.sum(rotate_params != 0) >= 2: + raise RuntimeError("RandAffined:inverse not yet implemented for >= 2 rotation directions") + + affine_grid = AffineGrid( + rotate_params=rotate_params.tolist(), + shear_params=shear_params.tolist(), + translate_params=translate_params.tolist(), + scale_params=scale_params.tolist(), + as_tensor_output=init_args["as_tensor_output"], + device=init_args["device"], + ) + grid = affine_grid(orig_size) + else: + grid = create_grid(spatial_size=orig_size) + + # Apply inverse transform + d[key] = self.rand_affine.resampler(d[key], grid, init_args["mode"], init_args["padding_mode"]) + if extra_info["orig_was_numpy"]: + d[key] = d[key].cpu().numpy() + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + class Rand2DElasticd(Randomizable, MapTransform): """ diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 5b93deef74..c81db8644f 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -43,6 +43,7 @@ RandZoomd, RandFlipd, RandRotate90d, + RandAffined, ) from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image, make_rand_affine @@ -313,6 +314,13 @@ ResizeWithPadOrCropd(KEYS, [201, 150, 78]) )) +TESTS.append(( + "RandAffine 3d", + DATA_3D, + 5e-2, + RandAffined(KEYS, [98, 96, 105], 1, rotate_range=np.pi / 6, shear_range=[1, 1, 1], translate_range=[10, 5, -4], scale_range=[0.9, 1, 1.1]) +)) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = [*TESTS, *TESTS_COMPOSE_X2] @@ -327,15 +335,17 @@ def plot_im(orig, fwd_bck, fwd): fig, axes = plt.subplots( 1, 4, gridspec_kw={"width_ratios": [orig.shape[1], fwd_bck.shape[1], diff_orig_fwd_bck.shape[1], fwd.shape[1]]} ) + vmin = min(np.array(i).min() for i in [orig, fwd_bck, fwd]) + vmax = max(np.array(i).max() for i in [orig, fwd_bck, fwd]) for i, (im, title) in enumerate( - zip([orig, fwd_bck, diff_orig_fwd_bck, fwd], ["orig", "fwd_bck", "%% diff", "fwd"]) + zip([orig, fwd_bck, diff_orig_fwd_bck, fwd], ["x", "f⁻¹fx", "diff", "fx"]) ): ax = axes[i] - vmax = max(np.max(i) for i in [orig, fwd_bck, fwd]) if i != 2 else None - im = np.squeeze(im) + im = np.squeeze(np.array(im)) while im.ndim > 2: im = im[..., im.shape[-1] // 2] - im_show = ax.imshow(np.squeeze(im), vmax=vmax) + _vmin, _vmax = (vmin, vmax) if i != 2 else (None, None) + im_show = ax.imshow(np.squeeze(im), vmin=_vmin, vmax=_vmax) ax.set_title(title, fontsize=25) ax.axis("off") fig.colorbar(im_show, ax=ax) From 521184b3426202dade7b5523a073ca8f781793b5 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 4 Feb 2021 12:35:06 +0000 Subject: [PATCH 45/80] RandAffined Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 168 +++++++++++++------------ monai/transforms/spatial/dictionary.py | 104 ++++++--------- tests/test_inverse.py | 13 +- 3 files changed, 135 insertions(+), 150 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index ca7e3a223d..fe9bfd93c1 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -24,6 +24,7 @@ from monai.networks.layers import 
AffineTransform, GaussianFilter, grid_pull from monai.transforms.transform import Randomizable, Transform from monai.transforms.croppad.array import CenterSpatialCrop +from monai.utils import issequenceiterable from monai.transforms.utils import ( create_control_grid, create_grid, @@ -890,6 +891,10 @@ class AffineGrid(Transform): as_tensor_output: whether to output tensor instead of numpy array. defaults to True. device: device to store the output grid data. + affine: If applied, ignore the params (`rotate_params`, etc.) and use the + supplied matrix. Should be square with each side = num of image spatial + dimensions + 1. + return_affine: boolean as to whether to return the generated affine matrix or not. """ @@ -901,6 +906,8 @@ def __init__( scale_params: Optional[Union[Sequence[float], float]] = None, as_tensor_output: bool = True, device: Optional[torch.device] = None, + affine: Optional[Union[np.array, torch.Tensor]] = None, + return_affine: bool = False, ) -> None: self.rotate_params = rotate_params self.shear_params = shear_params @@ -910,6 +917,9 @@ def __init__( self.as_tensor_output = as_tensor_output self.device = device + self.affine = affine + self.return_affine = return_affine + def __call__( self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None ) -> Union[np.ndarray, torch.Tensor]: @@ -928,16 +938,19 @@ def __call__( else: raise ValueError("Incompatible values: grid=None and spatial_size=None.") - spatial_dims = len(grid.shape) - 1 - affine = np.eye(spatial_dims + 1) - if self.rotate_params: - affine = affine @ create_rotate(spatial_dims, self.rotate_params) - if self.shear_params: - affine = affine @ create_shear(spatial_dims, self.shear_params) - if self.translate_params: - affine = affine @ create_translate(spatial_dims, self.translate_params) - if self.scale_params: - affine = affine @ create_scale(spatial_dims, self.scale_params) + if self.affine is None: + spatial_dims = len(grid.shape) - 1 + affine = np.eye(spatial_dims + 1) + if self.rotate_params: + affine = affine @ create_rotate(spatial_dims, self.rotate_params) + if self.shear_params: + affine = affine @ create_shear(spatial_dims, self.shear_params) + if self.translate_params: + affine = affine @ create_translate(spatial_dims, self.translate_params) + if self.scale_params: + affine = affine @ create_scale(spatial_dims, self.scale_params) + else: + affine = self.affine affine = torch.as_tensor(np.ascontiguousarray(affine), device=self.device) grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() @@ -946,9 +959,10 @@ def __call__( grid = (affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:])) if grid is None or not isinstance(grid, torch.Tensor): raise ValueError("Unknown grid.") - if self.as_tensor_output: - return grid - return np.asarray(grid.cpu().numpy()) + output = grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()) + if self.return_affine: + return output, affine + return output class RandAffineGrid(Randomizable, Transform): @@ -958,33 +972,29 @@ class RandAffineGrid(Randomizable, Transform): def __init__( self, - rotate_range: Optional[Union[Sequence[float], float]] = None, - shear_range: Optional[Union[Sequence[float], float]] = None, - translate_range: Optional[Union[Sequence[float], float]] = None, - scale_range: Optional[Union[Sequence[float], float]] = None, + rotate_range: Optional[Union[Sequence[Sequence[float]], Sequence[float], float]] = None, + 
shear_range: Optional[Union[Sequence[Sequence[float]], Sequence[float], float]] = None, + translate_range: Optional[Union[Sequence[Sequence[float]], Sequence[float], float]] = None, + scale_range: Optional[Union[Sequence[Sequence[float]], Sequence[float], float]] = None, as_tensor_output: bool = True, device: Optional[torch.device] = None, + return_affine: bool = False, ) -> None: """ Args: - rotate_range: angle range in radians. rotate_range[0] with be used to generate the 1st rotation - parameter from `uniform[-rotate_range[0], rotate_range[0])`. Similarly, `rotate_range[1]` and - `rotate_range[2]` are used in 3D affine for the range of 2nd and 3rd axes. - shear_range: shear_range[0] with be used to generate the 1st shearing parameter from - `uniform[-shear_range[0], shear_range[0])`. Similarly, `shear_range[1]` to - `shear_range[N]` controls the range of the uniform distribution used to generate the 2nd to - N-th parameter. - translate_range : translate_range[0] with be used to generate the 1st shift parameter from - `uniform[-translate_range[0], translate_range[0])`. Similarly, `translate_range[1]` - to `translate_range[N]` controls the range of the uniform distribution used to generate - the 2nd to N-th parameter. - scale_range: scaling_range[0] with be used to generate the 1st scaling factor from - `uniform[-scale_range[0], scale_range[0]) + 1.0`. Similarly, `scale_range[1]` to - `scale_range[N]` controls the range of the uniform distribution used to generate the 2nd to - N-th parameter. + rotate_range: angle range in radians. If element `i` is iterable, then + `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter + for the ith dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used. This can + be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be in range + `[0, 3]`, and for dim1 `[-1, 1]` will be used. Setting a single value will use `[-x, x]` for dim0 + and nothing for the remaining dimensions. + shear_range: shear_range with format matching `rotate_range`. + translate_range: translate_range with format matching `rotate_range`. + scale_range: scaling_range with format matching `rotate_range`. as_tensor_output: whether to output tensor instead of numpy array. defaults to True. device: device to store the output grid data. + return_affine: boolean as to whether to return the generated affine matrix or not. 
See also: - :py:meth:`monai.transforms.utils.create_rotate` @@ -1005,15 +1015,25 @@ def __init__( self.as_tensor_output = as_tensor_output self.device = device + self.return_affine = return_affine + + def _get_rand_param(self, param_range): + out_param = [] + for f in param_range: + if issequenceiterable(f): + if len(f) != 2: + raise ValueError("If giving range as [min,max], should only have two elements per dim.") + out_param.append(self.R.uniform(f[0], f[1])) + elif f is not None: + out_param.append(self.R.uniform(-f, f)) + return out_param + + def randomize(self, data: Optional[Any] = None) -> None: - if self.rotate_range: - self.rotate_params = [self.R.uniform(-f, f) for f in self.rotate_range if f is not None] - if self.shear_range: - self.shear_params = [self.R.uniform(-f, f) for f in self.shear_range if f is not None] - if self.translate_range: - self.translate_params = [self.R.uniform(-f, f) for f in self.translate_range if f is not None] - if self.scale_range: - self.scale_params = [self.R.uniform(-f, f) + 1.0 for f in self.scale_range if f is not None] + self.rotate_params = self._get_rand_param(self.rotate_range) + self.shear_params = self._get_rand_param(self.shear_range) + self.translate_params = self._get_rand_param(self.translate_range) + self.scale_params = self._get_rand_param(self.scale_range) def __call__( self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None @@ -1034,6 +1054,7 @@ def __call__( scale_params=self.scale_params, as_tensor_output=self.as_tensor_output, device=self.device, + return_affine=self.return_affine, ) return affine_grid(spatial_size, grid) @@ -1289,21 +1310,15 @@ def __init__( Args: prob: probability of returning a randomized affine grid. defaults to 0.1, with 10% chance returns a randomized grid. - rotate_range: angle range in radians. rotate_range[0] with be used to generate the 1st rotation - parameter from `uniform[-rotate_range[0], rotate_range[0])`. Similarly, `rotate_range[1]` and - `rotate_range[2]` are used in 3D affine for the range of 2nd and 3rd axes. - shear_range: shear_range[0] with be used to generate the 1st shearing parameter from - `uniform[-shear_range[0], shear_range[0])`. Similarly, `shear_range[1]` to - `shear_range[N]` controls the range of the uniform distribution used to generate the 2nd to - N-th parameter. - translate_range : translate_range[0] with be used to generate the 1st shift parameter from - `uniform[-translate_range[0], translate_range[0])`. Similarly, `translate_range[1]` - to `translate_range[N]` controls the range of the uniform distribution used to generate - the 2nd to N-th parameter. - scale_range: scaling_range[0] with be used to generate the 1st scaling factor from - `uniform[-scale_range[0], scale_range[0]) + 1.0`. Similarly, `scale_range[1]` to - `scale_range[N]` controls the range of the uniform distribution used to generate the 2nd to - N-th parameter. + rotate_range: angle range in radians. If element `i` is iterable, then + `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter + for the ith dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used. This can + be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be in range + `[0, 3]`, and for dim1 `[-1, 1]` will be used. Setting a single value will use `[-x, x]` for dim0 + and nothing for the remaining dimensions. + shear_range: shear_range with format matching `rotate_range`. 
+ translate_range: translate_range with format matching `rotate_range`. + scale_range: scaling_range with format matching `rotate_range`. spatial_size: output image spatial size. if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, the transform will use the spatial size of `img`. @@ -1333,6 +1348,7 @@ def __init__( scale_range=scale_range, as_tensor_output=True, device=device, + return_affine=True, ) self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) @@ -1412,17 +1428,15 @@ def __init__( prob: probability of returning a randomized elastic transform. defaults to 0.1, with 10% chance returns a randomized elastic transform, otherwise returns a ``spatial_size`` centered area extracted from the input image. - rotate_range: angle range in radians. rotate_range[0] with be used to generate the 1st rotation - parameter from `uniform[-rotate_range[0], rotate_range[0])`. - shear_range: shear_range[0] with be used to generate the 1st shearing parameter from - `uniform[-shear_range[0], shear_range[0])`. Similarly, `shear_range[1]` controls - the range of the uniform distribution used to generate the 2nd parameter. - translate_range : translate_range[0] with be used to generate the 1st shift parameter from - `uniform[-translate_range[0], translate_range[0])`. Similarly, `translate_range[1]` controls - the range of the uniform distribution used to generate the 2nd parameter. - scale_range: scaling_range[0] with be used to generate the 1st scaling factor from - `uniform[-scale_range[0], scale_range[0]) + 1.0`. Similarly, `scale_range[1]` controls - the range of the uniform distribution used to generate the 2nd parameter. + rotate_range: angle range in radians. If element `i` is iterable, then + `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter + for the ith dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used. This can + be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be in range + `[0, 3]`, and for dim1 `[-1, 1]` will be used. Setting a single value will use `[-x, x]` for dim0 + and nothing for the remaining dimensions. + shear_range: shear_range with format matching `rotate_range`. + translate_range: translate_range with format matching `rotate_range`. + scale_range: scaling_range with format matching `rotate_range`. spatial_size: specifying output image spatial size [h, w]. if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, the transform will use the spatial size of `img`. @@ -1541,19 +1555,15 @@ def __init__( prob: probability of returning a randomized elastic transform. defaults to 0.1, with 10% chance returns a randomized elastic transform, otherwise returns a ``spatial_size`` centered area extracted from the input image. - rotate_range: angle range in radians. rotate_range[0] with be used to generate the 1st rotation - parameter from `uniform[-rotate_range[0], rotate_range[0])`. Similarly, `rotate_range[1]` and - `rotate_range[2]` are used in 3D affine for the range of 2nd and 3rd axes. - shear_range: shear_range[0] with be used to generate the 1st shearing parameter from - `uniform[-shear_range[0], shear_range[0])`. Similarly, `shear_range[1]` and `shear_range[2]` - controls the range of the uniform distribution used to generate the 2nd and 3rd parameters. - translate_range : translate_range[0] with be used to generate the 1st shift parameter from - `uniform[-translate_range[0], translate_range[0])`. 
Similarly, `translate_range[1]` and - `translate_range[2]` controls the range of the uniform distribution used to generate - the 2nd and 3rd parameters. - scale_range: scaling_range[0] with be used to generate the 1st scaling factor from - `uniform[-scale_range[0], scale_range[0]) + 1.0`. Similarly, `scale_range[1]` and `scale_range[2]` - controls the range of the uniform distribution used to generate the 2nd and 3rd parameters. + rotate_range: angle range in radians. If element `i` is iterable, then + `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter + for the ith dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used. This can + be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be in range + `[0, 3]`, and for dim1 `[-1, 1]` will be used. Setting a single value will use `[-x, x]` for dim0 + and nothing for the remaining dimensions. + shear_range: shear_range with format matching `rotate_range`. + translate_range: translate_range with format matching `rotate_range`. + scale_range: scaling_range with format matching `rotate_range`. spatial_size: specifying output image spatial size [h, w, d]. if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, the transform will use the spatial size of `img`. diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 21294163df..f6e344715a 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -551,21 +551,15 @@ def __init__( to `(32, 64)` if the second spatial dimension size of img is `64`. prob: probability of returning a randomized affine grid. defaults to 0.1, with 10% chance returns a randomized grid. - rotate_range: angle range in radians. rotate_range[0] with be used to generate the 1st rotation - parameter from `uniform[-rotate_range[0], rotate_range[0])`. Similarly, `rotate_range[1]` and - `rotate_range[2]` are used in 3D affine for the range of 2nd and 3rd axes. - shear_range: shear_range[0] with be used to generate the 1st shearing parameter from - `uniform[-shear_range[0], shear_range[0])`. Similarly, `shear_range[1]` to - `shear_range[N]` controls the range of the uniform distribution used to generate the 2nd to - N-th parameter. - translate_range : translate_range[0] with be used to generate the 1st shift parameter from - `uniform[-translate_range[0], translate_range[0])`. Similarly, `translate_range[1]` - to `translate_range[N]` controls the range of the uniform distribution used to generate - the 2nd to N-th parameter. - scale_range: scaling_range[0] with be used to generate the 1st scaling factor from - `uniform[-scale_range[0], scale_range[0]) + 1.0`. Similarly, `scale_range[1]` to - `scale_range[N]` controls the range of the uniform distribution used to generate the 2nd to - N-th parameter. + rotate_range: angle range in radians. If element `i` is iterable, then + `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter + for the ith dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used. This can + be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be in range + `[0, 3]`, and for dim1 `[-1, 1]` will be used. Setting a single value will use `[-x, x]` for dim0 + and nothing for the remaining dimensions. + shear_range: shear_range with format matching `rotate_range`. 
+ translate_range: translate_range with format matching `rotate_range`. + scale_range: scaling_range with format matching `rotate_range`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample @@ -617,21 +611,13 @@ def __call__( sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:]) if self.rand_affine._do_transform: - grid = self.rand_affine.rand_affine_grid(spatial_size=sp_size) + grid, affine = self.rand_affine.rand_affine_grid(spatial_size=sp_size) else: grid = create_grid(spatial_size=sp_size) + affine = np.eye(len(sp_size) + 1) for idx, key in enumerate(self.keys): - rag = self.rand_affine.rand_affine_grid - extra_info = { - "rotate_params": rag.rotate_params, - "shear_params": rag.shear_params, - "translate_params": rag.translate_params, - "scale_params": rag.scale_params, - "orig_was_numpy": isinstance(d[key], np.ndarray), - } - # rotate_params, shear_params, translate_params, scale_params - self.append_applied_transforms(d, key, idx, extra_info=extra_info) + self.append_applied_transforms(d, key, idx, extra_info={"affine": affine, "orig_was_numpy": isinstance(d[key], np.ndarray)}) d[key] = self.rand_affine.resampler(d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx]) return d @@ -659,25 +645,11 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar init_args = transform["init_args"] orig_size = transform["orig_size"] # Create inverse transform - if transform["do_transform"]: - rotate_params = - np.array(extra_info["rotate_params"]) - shear_params = - np.array(extra_info["shear_params"]) - translate_params = - np.array(extra_info["translate_params"]) - scale_params = 1 / np.array(extra_info["scale_params"]) - if np.sum(rotate_params != 0) >= 2: - raise RuntimeError("RandAffined:inverse not yet implemented for >= 2 rotation directions") - - affine_grid = AffineGrid( - rotate_params=rotate_params.tolist(), - shear_params=shear_params.tolist(), - translate_params=translate_params.tolist(), - scale_params=scale_params.tolist(), - as_tensor_output=init_args["as_tensor_output"], - device=init_args["device"], - ) - grid = affine_grid(orig_size) - else: - grid = create_grid(spatial_size=orig_size) + fwd_affine = extra_info["affine"] + inv_affine = np.linalg.inv(fwd_affine) + + affine_grid = AffineGrid(affine=inv_affine) + grid = affine_grid(orig_size) # Apply inverse transform d[key] = self.rand_affine.resampler(d[key], grid, init_args["mode"], init_args["padding_mode"]) @@ -725,17 +697,15 @@ def __init__( prob: probability of returning a randomized affine grid. defaults to 0.1, with 10% chance returns a randomized grid, otherwise returns a ``spatial_size`` centered area extracted from the input image. - rotate_range: angle range in radians. rotate_range[0] with be used to generate the 1st rotation - parameter from `uniform[-rotate_range[0], rotate_range[0])`. - shear_range: shear_range[0] with be used to generate the 1st shearing parameter from - `uniform[-shear_range[0], shear_range[0])`. Similarly, `shear_range[1]` controls - the range of the uniform distribution used to generate the 2nd parameter. - translate_range : translate_range[0] with be used to generate the 1st shift parameter from - `uniform[-translate_range[0], translate_range[0])`. Similarly, `translate_range[1]` controls - the range of the uniform distribution used to generate the 2nd parameter. 
- scale_range: scaling_range[0] with be used to generate the 1st scaling factor from - `uniform[-scale_range[0], scale_range[0]) + 1.0`. Similarly, `scale_range[1]` controls - the range of the uniform distribution used to generate the 2nd parameter. + rotate_range: angle range in radians. If element `i` is iterable, then + `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter + for the ith dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used. This can + be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be in range + `[0, 3]`, and for dim1 `[-1, 1]` will be used. Setting a single value will use `[-x, x]` for dim0 + and nothing for the remaining dimensions. + shear_range: shear_range with format matching `rotate_range`. + translate_range: translate_range with format matching `rotate_range`. + scale_range: scaling_range with format matching `rotate_range`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample @@ -844,19 +814,15 @@ def __init__( prob: probability of returning a randomized affine grid. defaults to 0.1, with 10% chance returns a randomized grid, otherwise returns a ``spatial_size`` centered area extracted from the input image. - rotate_range: angle range in radians. rotate_range[0] with be used to generate the 1st rotation - parameter from `uniform[-rotate_range[0], rotate_range[0])`. Similarly, `rotate_range[1]` and - `rotate_range[2]` are used in 3D affine for the range of 2nd and 3rd axes. - shear_range: shear_range[0] with be used to generate the 1st shearing parameter from - `uniform[-shear_range[0], shear_range[0])`. Similarly, `shear_range[1]` and `shear_range[2]` - controls the range of the uniform distribution used to generate the 2nd and 3rd parameters. - translate_range : translate_range[0] with be used to generate the 1st shift parameter from - `uniform[-translate_range[0], translate_range[0])`. Similarly, `translate_range[1]` and - `translate_range[2]` controls the range of the uniform distribution used to generate - the 2nd and 3rd parameters. - scale_range: scaling_range[0] with be used to generate the 1st scaling factor from - `uniform[-scale_range[0], scale_range[0]) + 1.0`. Similarly, `scale_range[1]` and `scale_range[2]` - controls the range of the uniform distribution used to generate the 2nd and 3rd parameters. + rotate_range: angle range in radians. If element `i` is iterable, then + `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter + for the ith dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used. This can + be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be in range + `[0, 3]`, and for dim1 `[-1, 1]` will be used. Setting a single value will use `[-x, x]` for dim0 + and nothing for the remaining dimensions. + shear_range: shear_range with format matching `rotate_range`. + translate_range: translate_range with format matching `rotate_range`. + scale_range: scaling_range with format matching `rotate_range`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. 
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample diff --git a/tests/test_inverse.py b/tests/test_inverse.py index c81db8644f..ba9e385a61 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -58,7 +58,7 @@ else: plt, has_matplotlib = optional_import("matplotlib.pyplot") -# set_determinism(seed=0) +set_determinism(seed=0) AFFINE = make_rand_affine() AFFINE[0] *= 2 @@ -318,7 +318,16 @@ "RandAffine 3d", DATA_3D, 5e-2, - RandAffined(KEYS, [98, 96, 105], 1, rotate_range=np.pi / 6, shear_range=[1, 1, 1], translate_range=[10, 5, -4], scale_range=[0.9, 1, 1.1]) + RandAffined( + KEYS, + [155, 179, 192], + prob=1, + padding_mode="zeros", + rotate_range=[np.pi / 6, -np.pi / 5, np.pi / 7], + shear_range=[[0.5, 0.5]], + translate_range=[10, 5, -4], + scale_range=[[0.8, 1.2], [0.9, 1.3]], + ) )) TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] From 5a0712f9cad328b67b417fba3c2fedbf4e8ff652 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 4 Feb 2021 15:58:07 +0000 Subject: [PATCH 46/80] correct returning of affine for RandAffined Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 39 ++++++++++++++++---------- monai/transforms/spatial/dictionary.py | 2 +- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index fe9bfd93c1..71a28ed99f 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -894,7 +894,6 @@ class AffineGrid(Transform): affine: If applied, ignore the params (`rotate_params`, etc.) and use the supplied matrix. Should be square with each side = num of image spatial dimensions + 1. - return_affine: boolean as to whether to return the generated affine matrix or not. """ @@ -907,7 +906,6 @@ def __init__( as_tensor_output: bool = True, device: Optional[torch.device] = None, affine: Optional[Union[np.array, torch.Tensor]] = None, - return_affine: bool = False, ) -> None: self.rotate_params = rotate_params self.shear_params = shear_params @@ -918,15 +916,18 @@ def __init__( self.device = device self.affine = affine - self.return_affine = return_affine def __call__( - self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None + self, + spatial_size: Optional[Sequence[int]] = None, + grid: Optional[Union[np.ndarray, torch.Tensor]] = None, + return_affine: bool = False, ) -> Union[np.ndarray, torch.Tensor]: """ Args: spatial_size: output grid size. grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. + return_affine: boolean as to whether to return the generated affine matrix or not. Raises: ValueError: When ``grid=None`` and ``spatial_size=None``. Incompatible values. @@ -960,7 +961,7 @@ def __call__( if grid is None or not isinstance(grid, torch.Tensor): raise ValueError("Unknown grid.") output = grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()) - if self.return_affine: + if return_affine: return output, affine return output @@ -978,7 +979,6 @@ def __init__( scale_range: Optional[Union[Sequence[Sequence[float]], Sequence[float], float]] = None, as_tensor_output: bool = True, device: Optional[torch.device] = None, - return_affine: bool = False, ) -> None: """ Args: @@ -994,7 +994,6 @@ def __init__( as_tensor_output: whether to output tensor instead of numpy array. defaults to True. 
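The per-dimension range format introduced here accepts either a scalar or a (min, max) pair for each spatial dimension. A minimal sketch of that sampling rule, mirroring the `_get_rand_param` helper added in this series (the name `sample_range_param` is illustrative only):

    import numpy as np

    def sample_range_param(param_range, rng=np.random):
        # scalar x    -> uniform[-x, x)
        # pair (a, b) -> uniform[a, b)
        # None        -> no parameter generated for that dimension
        out = []
        for f in param_range:
            if isinstance(f, (tuple, list)):
                out.append(rng.uniform(f[0], f[1]))
            elif f is not None:
                out.append(rng.uniform(-f, f))
        return out

    # e.g. sample_range_param(((0, 3), 1)) -> [angle in [0, 3), angle in [-1, 1)]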
device: device to store the output grid data. - return_affine: boolean as to whether to return the generated affine matrix or not. See also: - :py:meth:`monai.transforms.utils.create_rotate` @@ -1015,8 +1014,6 @@ def __init__( self.as_tensor_output = as_tensor_output self.device = device - self.return_affine = return_affine - def _get_rand_param(self, param_range): out_param = [] for f in param_range: @@ -1036,12 +1033,16 @@ def randomize(self, data: Optional[Any] = None) -> None: self.scale_params = self._get_rand_param(self.scale_range) def __call__( - self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None + self, + spatial_size: Optional[Sequence[int]] = None, + grid: Optional[Union[np.ndarray, torch.Tensor]] = None, + return_affine: bool = False, ) -> Union[np.ndarray, torch.Tensor]: """ Args: spatial_size: output grid size. grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. + return_affine: boolean as to whether to return the generated affine matrix or not. Returns: a 2D (3xHxW) or 3D (4xHxWxD) grid. @@ -1054,9 +1055,8 @@ def __call__( scale_params=self.scale_params, as_tensor_output=self.as_tensor_output, device=self.device, - return_affine=self.return_affine, ) - return affine_grid(spatial_size, grid) + return affine_grid(spatial_size, grid, return_affine) class RandDeformGrid(Randomizable, Transform): @@ -1348,7 +1348,6 @@ def __init__( scale_range=scale_range, as_tensor_output=True, device=device, - return_affine=True, ) self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) @@ -1373,6 +1372,7 @@ def __call__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, + return_affine: bool = False, ) -> Union[np.ndarray, torch.Tensor]: """ Args: @@ -1388,17 +1388,26 @@ def __call__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + return_affine: boolean as to whether to return the generated affine matrix or not. 
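Returning the forward affine is what makes the dictionary transform invertible: the matrix recorded in `extra_info` can simply be inverted and turned back into a resampling grid. A condensed sketch of that round trip, based on the `inverse` method shown earlier in this patch (the names `img_t`, `orig_size` and `resampler` are illustrative):

    import numpy as np

    fwd_affine = transform["extra_info"]["affine"]    # recorded during the forward pass
    inv_affine = np.linalg.inv(fwd_affine)            # exact inverse of the homogeneous affine
    grid = AffineGrid(affine=inv_affine)(orig_size)   # grid that undoes the forward transform
    img_restored = resampler(img_t, grid, mode="bilinear", padding_mode="zeros")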
""" self.randomize() sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) + affine = np.eye(len(sp_size) + 1) if self._do_transform: - grid = self.rand_affine_grid(spatial_size=sp_size) + out = self.rand_affine_grid(spatial_size=sp_size) + if return_affine: + grid, affine = out + else: + grid = out else: grid = create_grid(spatial_size=sp_size) - return self.resampler( + resampled = self.resampler( img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode ) + if return_affine: + return resampled, affine + return resampled class Rand2DElastic(Randomizable, Transform): diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index f6e344715a..8b8bd26f00 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -611,7 +611,7 @@ def __call__( sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:]) if self.rand_affine._do_transform: - grid, affine = self.rand_affine.rand_affine_grid(spatial_size=sp_size) + grid, affine = self.rand_affine.rand_affine_grid(spatial_size=sp_size, return_affine=True) else: grid = create_grid(spatial_size=sp_size) affine = np.eye(len(sp_size) + 1) From 936eaf5349d46c8c1fd93daa5a2527ec0ad89247 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 5 Feb 2021 15:00:53 +0000 Subject: [PATCH 47/80] RandElastic Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 2 +- monai/transforms/spatial/dictionary.py | 100 +++++++++++++++++++++++-- monai/transforms/transform.py | 39 +++++++++- requirements-dev.txt | 1 + tests/test_inverse.py | 58 +++++++++++--- 5 files changed, 182 insertions(+), 18 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 817244ea99..13c31ede8d 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -234,7 +234,7 @@ ZoomD, ZoomDict, ) -from .transform import InvertibleTransform, MapTransform, Randomizable, Transform +from .transform import InvertibleTransform, MapTransform, Randomizable, Transform, NonRigidTransform from .utility.array import ( AddChannel, AddExtremePointsChannel, diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 8b8bd26f00..b84083ba3f 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -38,7 +38,7 @@ Zoom, AffineGrid, ) -from monai.transforms.transform import InvertibleTransform, MapTransform, Randomizable +from monai.transforms.transform import InvertibleTransform, MapTransform, Randomizable, NonRigidTransform from monai.transforms.utils import create_grid from monai.utils import ( GridSampleMode, @@ -600,8 +600,6 @@ def set_random_state( def randomize(self, data: Optional[Any] = None) -> None: self.rand_affine.randomize() - self.prob = self.rand_affine.prob - self._do_transform = self.rand_affine._do_transform def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] @@ -661,7 +659,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d -class Rand2DElasticd(Randomizable, MapTransform): +class Rand2DElasticd(Randomizable, MapTransform, InvertibleTransform, NonRigidTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rand2DElastic`. 
""" @@ -722,7 +720,8 @@ def __init__( - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. """ - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) self.rand_2d_elastic = Rand2DElastic( spacing=spacing, magnitude_range=magnitude_range, @@ -771,13 +770,55 @@ def __call__( grid = create_grid(spatial_size=sp_size) for idx, key in enumerate(self.keys): + self.append_applied_transforms(d, key, idx, extra_info={"grid": deepcopy(grid)}) d[key] = self.rand_2d_elastic.resampler( d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx] ) return d + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + return { + "keys": key, + "spacing": self.rand_2d_elastic.deform_grid.spacing, + "magnitude_range": self.rand_2d_elastic.deform_grid.magnitude, + "spatial_size": self.rand_2d_elastic.spatial_size, + "prob": self.rand_2d_elastic.prob, + "rotate_range": self.rand_2d_elastic.rand_affine_grid.rotate_range, + "shear_range": self.rand_2d_elastic.rand_affine_grid.shear_range, + "translate_range": self.rand_2d_elastic.rand_affine_grid.translate_range, + "scale_range": self.rand_2d_elastic.rand_affine_grid.scale_range, + "mode": self.mode[idx], + "padding_mode": self.padding_mode[idx], + "as_tensor_output": self.rand_2d_elastic.resampler.as_tensor_output, + "device": self.rand_2d_elastic.resampler.device, + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key in self.keys: + transform = self.get_most_recent_transform(d, key) + + extra_info = transform["extra_info"] + init_args = transform["init_args"] + orig_size = transform["orig_size"] + # Create inverse transform + inv_def_grid = self.compute_inverse_deformation(orig_size, extra_info["grid"], init_args["spacing"]) + # if no sitk, `inv_def_grid` will be `None`, and data will not be changed. + if inv_def_grid is not None: + # Apply inverse transform + d[key] = self.rand_2d_elastic.resampler( + d[key], inv_def_grid, init_args["mode"], init_args["padding_mode"] + ) + # Back to original size + d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + -class Rand3DElasticd(Randomizable, MapTransform): +class Rand3DElasticd(Randomizable, MapTransform, InvertibleTransform, NonRigidTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rand3DElastic`. """ @@ -839,7 +880,8 @@ def __init__( - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. 
""" - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) self.rand_3d_elastic = Rand3DElastic( sigma_range=sigma_range, magnitude_range=magnitude_range, @@ -864,6 +906,8 @@ def set_random_state( def randomize(self, grid_size: Sequence[int]) -> None: self.rand_3d_elastic.randomize(grid_size) + self.prob = self.rand_3d_elastic.prob + self._do_transform = self.rand_3d_elastic._do_transform def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] @@ -882,11 +926,53 @@ def __call__( grid = self.rand_3d_elastic.rand_affine_grid(grid=grid) for idx, key in enumerate(self.keys): + self.append_applied_transforms(d, key, idx, extra_info={"grid": grid.cpu().numpy()}) d[key] = self.rand_3d_elastic.resampler( d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx] ) return d + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + return { + "keys": key, + "sigma_range": self.rand_3d_elastic.sigma, + "magnitude_range": self.rand_3d_elastic.magnitude_range, + "spatial_size": self.rand_3d_elastic.spatial_size, + "prob": self.rand_3d_elastic.prob, + "rotate_range": self.rand_3d_elastic.rand_affine_grid.rotate_range, + "shear_range": self.rand_3d_elastic.rand_affine_grid.shear_range, + "translate_range": self.rand_3d_elastic.rand_affine_grid.translate_range, + "scale_range": self.rand_3d_elastic.rand_affine_grid.scale_range, + "mode": self.mode[idx], + "padding_mode": self.padding_mode[idx], + "as_tensor_output": self.rand_3d_elastic.resampler.as_tensor_output, + "device": self.rand_3d_elastic.resampler.device, + } + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key in self.keys: + transform = self.get_most_recent_transform(d, key) + + extra_info = transform["extra_info"] + init_args = transform["init_args"] + orig_size = transform["orig_size"] + # Create inverse transform + inv_def_grid = self.compute_inverse_deformation(orig_size, extra_info["grid"]) + # if no sitk, `inv_def_grid` will be `None`, and data will not be changed. + if inv_def_grid is not None: + # Back to original size + inv_def_grid = CenterSpatialCrop(roi_size=orig_size)(inv_def_grid) + # Apply inverse transform + d[key] = self.rand_3d_elastic.resampler( + d[key], inv_def_grid, init_args["mode"], init_args["padding_mode"] + ) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + class Flipd(MapTransform, InvertibleTransform): """ diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 5f6a6f100f..c040ea59ed 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -12,6 +12,7 @@ A collection of generic interfaces for MONAI transforms. 
""" +import warnings from abc import ABC, abstractmethod from typing import Any, Hashable, Optional, Tuple @@ -19,8 +20,11 @@ from monai.config import KeysCollection from monai.utils import MAX_SEED, ensure_tuple +from monai.utils import optional_import -__all__ = ["Randomizable", "Transform", "MapTransform", "InvertibleTransform"] +sitk, has_sitk = optional_import("SimpleITK") + +__all__ = ["Randomizable", "Transform", "MapTransform", "InvertibleTransform", "NonRigidTransform"] class Randomizable(ABC): @@ -278,3 +282,36 @@ def inverse(self, data: dict): """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + +class NonRigidTransform(ABC): + @staticmethod + def compute_inverse_deformation(output_size, fwd_def_grid_orig, spacing=None): + if not has_sitk: + warnings.warn("Please install SimpleITK to estimate inverse of non-rigid transforms. Data has not been modified") + return None + # return fwd_def_grid_orig + # Remove any extra dimensions (we'll add them back in at the end) + fwd_def_grid = fwd_def_grid_orig[:len(output_size)] + # Def -> disp + def_to_disp = np.mgrid[[slice(0, i) for i in fwd_def_grid.shape[1:]]].astype(np.float64) + for idx, i in enumerate(fwd_def_grid.shape[1:]): + def_to_disp[idx] -= (i - 1) / 2 + fwd_disp_grid = fwd_def_grid - def_to_disp + # move tensor component to end (T,H,W,[D])->(H,W,[D],T) + fwd_disp_grid = np.moveaxis(fwd_disp_grid, 0, -1) + # Inverse with SimpleITK + fwd_disp_grid_sitk = sitk.GetImageFromArray(fwd_disp_grid, isVector=True) + if spacing is not None: + fwd_disp_grid_sitk.SetSpacing(spacing) + inv_disp_grid_sitk = sitk.InvertDisplacementField(fwd_disp_grid_sitk) + inv_disp_grid = sitk.GetArrayFromImage(inv_disp_grid_sitk) + # move tensor component back to beginning + inv_disp_grid = np.moveaxis(inv_disp_grid, -1, 0) + # Disp -> def + inv_def_grid = inv_disp_grid + def_to_disp + # Add back in any removed dimensions + ndim_in = fwd_def_grid_orig.shape[0] + ndim_out = inv_def_grid.shape[0] + inv_def_grid = np.concatenate([inv_def_grid, fwd_def_grid_orig[ndim_out:ndim_in]]) + + return inv_def_grid diff --git a/requirements-dev.txt b/requirements-dev.txt index 3de7365d16..42d07dff5a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -30,3 +30,4 @@ Sphinx==3.3.0 recommonmark==0.6.0 sphinx-autodoc-typehints==1.11.1 sphinx-rtd-theme==0.5.0 +SimpleITK diff --git a/tests/test_inverse.py b/tests/test_inverse.py index ba9e385a61..7db1d28914 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -44,6 +44,9 @@ RandFlipd, RandRotate90d, RandAffined, + Rand2DElasticd, + Rand3DElasticd, + ResizeWithPadOrCrop, ) from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image, make_rand_affine @@ -58,13 +61,13 @@ else: plt, has_matplotlib = optional_import("matplotlib.pyplot") -set_determinism(seed=0) +# set_determinism(seed=0) AFFINE = make_rand_affine() AFFINE[0] *= 2 IM_1D = AddChannel()(np.arange(0, 10)) -IM_2D_FNAME, SEG_2D_FNAME = [make_nifti_image(i) for i in create_test_image_2d(100, 101)] +IM_2D_FNAME, SEG_2D_FNAME = [make_nifti_image(i) for i in create_test_image_2d(101, 101)] IM_3D_FNAME, SEG_3D_FNAME = [make_nifti_image(i, AFFINE) for i in create_test_image_3d(100, 101, 107)] KEYS = ["image", "label"] @@ -330,6 +333,42 @@ ) )) +TESTS.append(( + "Rand2DElasticd 2d", + DATA_2D, + 0, + Rand2DElasticd( + KEYS, + spacing=[10, 10], + magnitude_range=[1, 1], + # spatial_size=[155, 192], + prob=1, + padding_mode="zeros", + rotate_range=[np.pi / 6, np.pi / 7], 
+ # shear_range=[[0.5, 0.5]], + # translate_range=[10, 5], + # scale_range=[[0.8, 1.2], [0.9, 1.3]], + ) +)) + +TESTS.append(( + "Rand3DElasticd 3d", + DATA_3D, + 0, + Rand3DElasticd( + KEYS, + sigma_range=[1, 3], + magnitude_range=[1., 2., 1.], + spatial_size=[155, 192, 200], + prob=1, + padding_mode="zeros", + rotate_range=[np.pi / 6, np.pi / 7], + shear_range=[[0.5, 0.5]], + translate_range=[10, 5], + scale_range=[[0.8, 1.2], [0.9, 1.3]], + ) +)) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = [*TESTS, *TESTS_COMPOSE_X2] @@ -341,19 +380,18 @@ def plot_im(orig, fwd_bck, fwd): diff_orig_fwd_bck = orig - fwd_bck + ims_to_show = [orig, fwd, fwd_bck, diff_orig_fwd_bck] + titles = ["x", "fx", "f⁻¹fx", "x - f⁻¹fx"] fig, axes = plt.subplots( - 1, 4, gridspec_kw={"width_ratios": [orig.shape[1], fwd_bck.shape[1], diff_orig_fwd_bck.shape[1], fwd.shape[1]]} + 1, 4, gridspec_kw={"width_ratios": [i.shape[1] for i in ims_to_show]} ) vmin = min(np.array(i).min() for i in [orig, fwd_bck, fwd]) vmax = max(np.array(i).max() for i in [orig, fwd_bck, fwd]) - for i, (im, title) in enumerate( - zip([orig, fwd_bck, diff_orig_fwd_bck, fwd], ["x", "f⁻¹fx", "diff", "fx"]) - ): - ax = axes[i] + for im, title, ax in zip(ims_to_show, titles, axes): + _vmin, _vmax = (vmin, vmax) if id(im) != id(diff_orig_fwd_bck) else (None, None) im = np.squeeze(np.array(im)) while im.ndim > 2: im = im[..., im.shape[-1] // 2] - _vmin, _vmax = (vmin, vmax) if i != 2 else (None, None) im_show = ax.imshow(np.squeeze(im), vmin=_vmin, vmax=_vmax) ax.set_title(title, fontsize=25) ax.axis("off") @@ -369,10 +407,12 @@ def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_ unmodified = unmodified_d[key] if isinstance(orig, np.ndarray): mean_diff = np.mean(np.abs(orig - fwd_bck)) + unmodded_diff = np.mean(np.abs(orig - ResizeWithPadOrCrop(orig.shape[1:])(unmodified))) try: self.assertLessEqual(mean_diff, acceptable_diff) + self.assertLessEqual(mean_diff, unmodded_diff) except AssertionError: - print(f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff})") + print(f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}") if has_matplotlib: plot_im(orig, fwd_bck, unmodified) raise From c0decd59686b9562d37662f8a805e67f59850016 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Mon, 8 Feb 2021 09:39:12 +0000 Subject: [PATCH 48/80] cpg -> disp Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/transform.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index c040ea59ed..be61c1efa6 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -285,25 +285,23 @@ def inverse(self, data: dict): class NonRigidTransform(ABC): @staticmethod - def compute_inverse_deformation(output_size, fwd_def_grid_orig, spacing=None): + def compute_inverse_deformation(num_spatial_dims, fwd_def_grid_orig, spacing, num_iters: int = 10): if not has_sitk: warnings.warn("Please install SimpleITK to estimate inverse of non-rigid transforms. 
Data has not been modified") return None - # return fwd_def_grid_orig # Remove any extra dimensions (we'll add them back in at the end) - fwd_def_grid = fwd_def_grid_orig[:len(output_size)] + fwd_def_grid = fwd_def_grid_orig[:num_spatial_dims].cpu().numpy() # Def -> disp def_to_disp = np.mgrid[[slice(0, i) for i in fwd_def_grid.shape[1:]]].astype(np.float64) for idx, i in enumerate(fwd_def_grid.shape[1:]): def_to_disp[idx] -= (i - 1) / 2 + def_to_disp[idx] *= spacing[idx] fwd_disp_grid = fwd_def_grid - def_to_disp # move tensor component to end (T,H,W,[D])->(H,W,[D],T) fwd_disp_grid = np.moveaxis(fwd_disp_grid, 0, -1) # Inverse with SimpleITK fwd_disp_grid_sitk = sitk.GetImageFromArray(fwd_disp_grid, isVector=True) - if spacing is not None: - fwd_disp_grid_sitk.SetSpacing(spacing) - inv_disp_grid_sitk = sitk.InvertDisplacementField(fwd_disp_grid_sitk) + inv_disp_grid_sitk = sitk.InvertDisplacementField(fwd_disp_grid_sitk, num_iters) inv_disp_grid = sitk.GetArrayFromImage(inv_disp_grid_sitk) # move tensor component back to beginning inv_disp_grid = np.moveaxis(inv_disp_grid, -1, 0) From 31ddf65a703e62ce89250699e6b5a981d09e8af8 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 8 Feb 2021 15:22:43 +0000 Subject: [PATCH 49/80] SimpleITK for inverse nonrigid Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 50 +++++---- monai/transforms/transform.py | 142 +++++++++++++++++++++---- tests/test_inverse.py | 10 +- 3 files changed, 156 insertions(+), 46 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index b84083ba3f..134508073c 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -756,13 +756,13 @@ def __call__( self.randomize(spatial_size=sp_size) if self.rand_2d_elastic._do_transform: - grid = self.rand_2d_elastic.deform_grid(spatial_size=sp_size) - grid = self.rand_2d_elastic.rand_affine_grid(grid=grid) + cpg = self.rand_2d_elastic.deform_grid(spatial_size=sp_size) + cpg = self.rand_2d_elastic.rand_affine_grid(grid=cpg) grid = torch.nn.functional.interpolate( # type: ignore recompute_scale_factor=True, - input=grid.unsqueeze(0), + input=cpg.unsqueeze(0), scale_factor=ensure_tuple_rep(self.rand_2d_elastic.deform_grid.spacing, 2), - mode=InterpolateMode.BICUBIC.value, + mode=InterpolateMode.BILINEAR.value, align_corners=False, ) grid = CenterSpatialCrop(roi_size=sp_size)(grid[0]) @@ -803,15 +803,19 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar init_args = transform["init_args"] orig_size = transform["orig_size"] # Create inverse transform - inv_def_grid = self.compute_inverse_deformation(orig_size, extra_info["grid"], init_args["spacing"]) - # if no sitk, `inv_def_grid` will be `None`, and data will not be changed. - if inv_def_grid is not None: - # Apply inverse transform - d[key] = self.rand_2d_elastic.resampler( - d[key], inv_def_grid, init_args["mode"], init_args["padding_mode"] - ) - # Back to original size + fwd_def = extra_info["grid"] + if fwd_def is None: d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) + else: + inv_def = self.compute_inverse_deformation(len(orig_size), fwd_def) + # if no sitk, `inv_def` will be `None`, and data will not be changed. 
+ if inv_def is not None: + # Back to original size + inv_def = CenterSpatialCrop(roi_size=orig_size)(inv_def) + # Apply inverse transform + d[key] = self.rand_2d_elastic.resampler( + d[key], inv_def, init_args["mode"], init_args["padding_mode"] + ) # Remove the applied transform self.remove_most_recent_transform(d, key) @@ -959,15 +963,19 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar init_args = transform["init_args"] orig_size = transform["orig_size"] # Create inverse transform - inv_def_grid = self.compute_inverse_deformation(orig_size, extra_info["grid"]) - # if no sitk, `inv_def_grid` will be `None`, and data will not be changed. - if inv_def_grid is not None: - # Back to original size - inv_def_grid = CenterSpatialCrop(roi_size=orig_size)(inv_def_grid) - # Apply inverse transform - d[key] = self.rand_3d_elastic.resampler( - d[key], inv_def_grid, init_args["mode"], init_args["padding_mode"] - ) + fwd_def = extra_info["grid"] + if fwd_def is None: + d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) + else: + inv_def = self.compute_inverse_deformation(len(orig_size), fwd_def) + # if no sitk, `inv_def` will be `None`, and data will not be changed. + if inv_def is not None: + # Back to original size + inv_def = CenterSpatialCrop(roi_size=orig_size)(inv_def) + # Apply inverse transform + d[key] = self.rand_3d_elastic.resampler( + d[key], inv_def, init_args["mode"], init_args["padding_mode"] + ) # Remove the applied transform self.remove_most_recent_transform(d, key) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index be61c1efa6..ac04634a37 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -15,14 +15,17 @@ import warnings from abc import ABC, abstractmethod from typing import Any, Hashable, Optional, Tuple - +import torch import numpy as np +from itertools import chain from monai.config import KeysCollection from monai.utils import MAX_SEED, ensure_tuple from monai.utils import optional_import sitk, has_sitk = optional_import("SimpleITK") +vtk, has_vtk = optional_import("vtk") +vtk_numpy_support, _ = optional_import("vtk.util.numpy_support") __all__ = ["Randomizable", "Transform", "MapTransform", "InvertibleTransform", "NonRigidTransform"] @@ -285,31 +288,130 @@ def inverse(self, data: dict): class NonRigidTransform(ABC): @staticmethod - def compute_inverse_deformation(num_spatial_dims, fwd_def_grid_orig, spacing, num_iters: int = 10): - if not has_sitk: + def _get_disp_to_def_arr(shape, spacing): + def_to_disp = np.mgrid[[slice(0, i) for i in shape]].astype(np.float64) + for idx, i in enumerate(shape): + # shift for origin (in MONAI, center of image) + def_to_disp[idx] -= (i - 1) / 2 + # if supplied, account for spacing (e.g., for control point grids) + if spacing is not None: + def_to_disp[idx] *= spacing[idx] + return def_to_disp + + @staticmethod + def _inv_disp_w_sitk(fwd_disp, num_iters): + fwd_disp_sitk = sitk.GetImageFromArray(fwd_disp, isVector=True) + inv_disp_sitk = sitk.InvertDisplacementField(fwd_disp_sitk, num_iters) + inv_disp = sitk.GetArrayFromImage(inv_disp_sitk) + return inv_disp + + @staticmethod + def _inv_disp_w_vtk(fwd_disp): + while fwd_disp.shape[-1] < 3: + fwd_disp = np.append(fwd_disp, np.zeros(fwd_disp.shape[:-1] + (1, )), axis=-1) + fwd_disp = fwd_disp[..., None, :] + # fwd_disp_vtk = vtk.vtkImageImport() + # # The previously created array is converted to a string of chars and imported. 
+ # data_string = fwd_disp.tostring() + # fwd_disp_vtk.CopyImportVoidPointer(data_string, len(data_string)) + # # The type of the newly imported data is set to unsigned char (uint8) + # fwd_disp_vtk.SetDataScalarTypeToUnsignedChar() + # fwd_disp_vtk.SetNumberOfScalarComponents(3) + extent = list(chain.from_iterable(zip([0, 0, 0], fwd_disp.shape[:-1]))) + # fwd_disp_vtk.SetWholeExtent(extent) + # fwd_disp_vtk.SetDataExtentToWholeExtent() + # fwd_disp_vtk.Update() + # fwd_disp_vtk = fwd_disp_vtk.GetOutput() + + fwd_disp_flattened = fwd_disp.flatten() # need to keep this in memory + vtk_data_array = vtk_numpy_support.numpy_to_vtk(fwd_disp_flattened) + + # Generating the vtkImageData + fwd_disp_vtk = vtk.vtkImageData() + fwd_disp_vtk.SetOrigin(0, 0, 0) + fwd_disp_vtk.SetSpacing(1, 1, 1) + fwd_disp_vtk.SetDimensions(*fwd_disp.shape[:-1]) + + fwd_disp_vtk.AllocateScalars(vtk_numpy_support.get_vtk_array_type(fwd_disp.dtype), 3) + fwd_disp_vtk.SetExtent(extent) + fwd_disp_vtk.GetPointData().AddArray(vtk_data_array) + + # # create b-spline coefficients for the displacement grid + # bspline_filter = vtk.vtkImageBSplineCoefficients() + # bspline_filter.SetInputData(fwd_disp_vtk) + # bspline_filter.Update() + + # # use these b-spline coefficients to create a transform + # bspline_transform = vtk.vtkBSplineTransform() + # bspline_transform.SetCoefficientData(bspline_filter.GetOutput()) + # bspline_transform.Update() + + # # invert the b-spline transform onto a new grid + # grid_maker = vtk.vtkTransformToGrid() + # grid_maker.SetInput(bspline_transform.GetInverse()) + # grid_maker.SetGridOrigin(fwd_disp_vtk.GetOrigin()) + # grid_maker.SetGridSpacing(fwd_disp_vtk.GetSpacing()) + # grid_maker.SetGridExtent(fwd_disp_vtk.GetExtent()) + # grid_maker.SetGridScalarTypeToFloat() + # grid_maker.Update() + + # # Get inverse displacement as an image + # inv_disp_vtk = grid_maker.GetOutput() + + # from vtk.util.numpy_support import vtk_to_numpy + # inv_disp = vtk_numpy_support.vtk_to_numpy(inv_disp_vtk.GetPointData().GetScalars()) + inv_disp = vtk_numpy_support.vtk_to_numpy(fwd_disp_vtk.GetPointData().GetArray(0)) + inv_disp = inv_disp.reshape(fwd_disp.shape) + + return inv_disp + + + @staticmethod + def compute_inverse_deformation(num_spatial_dims, fwd_def_orig, spacing=None, num_iters: int = 100, use_package: str = "sitk"): + """Package can be vtk or sitk.""" + if use_package.lower() == "vtk" and not has_vtk: + warnings.warn("Please install VTK to estimate inverse of non-rigid transforms. Data has not been modified") + return None + if use_package.lower() == "sitk" and not has_sitk: warnings.warn("Please install SimpleITK to estimate inverse of non-rigid transforms. 
Data has not been modified") return None + + # Convert to numpy if necessary + if isinstance(fwd_def_orig, torch.Tensor): + fwd_def_orig = fwd_def_orig.cpu().numpy() # Remove any extra dimensions (we'll add them back in at the end) - fwd_def_grid = fwd_def_grid_orig[:num_spatial_dims].cpu().numpy() + fwd_def = fwd_def_orig[:num_spatial_dims] # Def -> disp - def_to_disp = np.mgrid[[slice(0, i) for i in fwd_def_grid.shape[1:]]].astype(np.float64) - for idx, i in enumerate(fwd_def_grid.shape[1:]): - def_to_disp[idx] -= (i - 1) / 2 - def_to_disp[idx] *= spacing[idx] - fwd_disp_grid = fwd_def_grid - def_to_disp + def_to_disp = NonRigidTransform._get_disp_to_def_arr(fwd_def.shape[1:], spacing) + fwd_disp = fwd_def - def_to_disp # move tensor component to end (T,H,W,[D])->(H,W,[D],T) - fwd_disp_grid = np.moveaxis(fwd_disp_grid, 0, -1) - # Inverse with SimpleITK - fwd_disp_grid_sitk = sitk.GetImageFromArray(fwd_disp_grid, isVector=True) - inv_disp_grid_sitk = sitk.InvertDisplacementField(fwd_disp_grid_sitk, num_iters) - inv_disp_grid = sitk.GetArrayFromImage(inv_disp_grid_sitk) + fwd_disp = np.moveaxis(fwd_disp, 0, -1) + + # If using vtk... + if use_package.lower() == "vtk": + inv_disp = NonRigidTransform._inv_disp_w_vtk(fwd_disp) + # If using sitk... + else: + inv_disp = NonRigidTransform._inv_disp_w_sitk(fwd_disp, num_iters) + + + import matplotlib.pyplot as plt + fig, axes = plt.subplots(2, 2) + for i, direc1 in enumerate(["x", "y"]): + for j, (im, direc2) in enumerate(zip([fwd_disp, inv_disp], ["fwd", "inv"])): + ax = axes[i, j] + im_show = ax.imshow(im[..., i]) + ax.set_title(f"{direc2}{direc1}", fontsize=25) + ax.axis("off") + fig.colorbar(im_show, ax=ax) + plt.show() # move tensor component back to beginning - inv_disp_grid = np.moveaxis(inv_disp_grid, -1, 0) + inv_disp = np.moveaxis(inv_disp, -1, 0) # Disp -> def - inv_def_grid = inv_disp_grid + def_to_disp + inv_def = inv_disp + def_to_disp # Add back in any removed dimensions - ndim_in = fwd_def_grid_orig.shape[0] - ndim_out = inv_def_grid.shape[0] - inv_def_grid = np.concatenate([inv_def_grid, fwd_def_grid_orig[ndim_out:ndim_in]]) + ndim_in = fwd_def_orig.shape[0] + ndim_out = inv_def.shape[0] + inv_def = np.concatenate([inv_def, fwd_def_orig[ndim_out:ndim_in]]) - return inv_def_grid + return inv_def diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 7db1d28914..c0cecb418f 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -61,13 +61,13 @@ else: plt, has_matplotlib = optional_import("matplotlib.pyplot") -# set_determinism(seed=0) +set_determinism(seed=0) AFFINE = make_rand_affine() AFFINE[0] *= 2 IM_1D = AddChannel()(np.arange(0, 10)) -IM_2D_FNAME, SEG_2D_FNAME = [make_nifti_image(i) for i in create_test_image_2d(101, 101)] +IM_2D_FNAME, SEG_2D_FNAME = [make_nifti_image(i) for i in create_test_image_2d(101, 100)] IM_3D_FNAME, SEG_3D_FNAME = [make_nifti_image(i, AFFINE) for i in create_test_image_3d(100, 101, 107)] KEYS = ["image", "label"] @@ -336,15 +336,15 @@ TESTS.append(( "Rand2DElasticd 2d", DATA_2D, - 0, + 8e-2, Rand2DElasticd( KEYS, spacing=[10, 10], - magnitude_range=[1, 1], + magnitude_range=[2, 2], # spatial_size=[155, 192], prob=1, padding_mode="zeros", - rotate_range=[np.pi / 6, np.pi / 7], + rotate_range=[[np.pi / 6, np.pi / 6], np.pi / 7], # shear_range=[[0.5, 0.5]], # translate_range=[10, 5], # scale_range=[[0.8, 1.2], [0.9, 1.3]], From 7ec45965010e5d4c6050f0bf68eff4f829dd0469 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 8 Feb 2021 
17:03:51 +0000 Subject: [PATCH 50/80] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 2 +- monai/transforms/croppad/dictionary.py | 1 + monai/transforms/spatial/array.py | 60 +-- monai/transforms/spatial/dictionary.py | 76 +-- monai/transforms/transform.py | 35 +- tests/test_inverse.py | 641 +++++++++++++------------ tests/utils.py | 2 + 7 files changed, 428 insertions(+), 389 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 6f99cfe3d2..ed157dcf45 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -234,7 +234,7 @@ ZoomD, ZoomDict, ) -from .transform import InvertibleTransform, MapTransform, Randomizable, Transform, NonRigidTransform +from .transform import InvertibleTransform, MapTransform, NonRigidTransform, Randomizable, Transform from .utility.array import ( AddChannel, AddExtremePointsChannel, diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 6efcce197c..109f66b759 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -405,6 +405,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d + class RandSpatialCropd(Randomizable, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandSpatialCrop`. diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index ef377b185f..bbbd7be26e 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -22,9 +22,8 @@ from monai.config import USE_COMPILED, DtypeLike from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull -from monai.transforms.transform import Randomizable, Transform from monai.transforms.croppad.array import CenterSpatialCrop -from monai.utils import issequenceiterable +from monai.transforms.transform import Randomizable, Transform from monai.transforms.utils import ( create_control_grid, create_grid, @@ -42,6 +41,7 @@ ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, + issequenceiterable, optional_import, ) @@ -427,7 +427,7 @@ def __call__( align_corners: Optional[bool] = None, dtype: DtypeLike = None, return_rotation_matrix: bool = False, - ) -> np.ndarray: + ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: """ Args: img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D]. @@ -746,7 +746,7 @@ def __call__( align_corners=self.align_corners if align_corners is None else align_corners, dtype=dtype or self.dtype or img.dtype, ) - return rotator(img) + return np.array(rotator(img)) class RandFlip(Randomizable, Transform): @@ -914,7 +914,7 @@ def __init__( scale_params: Optional[Union[Sequence[float], float]] = None, as_tensor_output: bool = True, device: Optional[torch.device] = None, - affine: Optional[Union[np.array, torch.Tensor]] = None, + affine: Optional[Union[np.ndarray, torch.Tensor]] = None, ) -> None: self.rotate_params = rotate_params self.shear_params = shear_params @@ -931,7 +931,7 @@ def __call__( spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None, return_affine: bool = False, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: """ Args: spatial_size: output grid size. 
@@ -948,6 +948,7 @@ def __call__( else: raise ValueError("Incompatible values: grid=None and spatial_size=None.") + affine: Union[np.ndarray, torch.Tensor] if self.affine is None: spatial_dims = len(grid.shape) - 1 affine = np.eye(spatial_dims + 1) @@ -969,7 +970,7 @@ def __call__( grid = (affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:])) if grid is None or not isinstance(grid, torch.Tensor): raise ValueError("Unknown grid.") - output = grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()) + output: Union[np.ndarray, torch.Tensor] = grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()) if return_affine: return output, affine return output @@ -982,10 +983,10 @@ class RandAffineGrid(Randomizable, Transform): def __init__( self, - rotate_range: Optional[Union[Sequence[Sequence[float]], Sequence[float], float]] = None, - shear_range: Optional[Union[Sequence[Sequence[float]], Sequence[float], float]] = None, - translate_range: Optional[Union[Sequence[Sequence[float]], Sequence[float], float]] = None, - scale_range: Optional[Union[Sequence[Sequence[float]], Sequence[float], float]] = None, + rotate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + shear_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + translate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + scale_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, as_tensor_output: bool = True, device: Optional[torch.device] = None, ) -> None: @@ -1034,7 +1035,6 @@ def _get_rand_param(self, param_range): out_param.append(self.R.uniform(-f, f)) return out_param - def randomize(self, data: Optional[Any] = None) -> None: self.rotate_params = self._get_rand_param(self.rotate_range) self.shear_params = self._get_rand_param(self.shear_range) @@ -1046,7 +1046,7 @@ def __call__( spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None, return_affine: bool = False, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: """ Args: spatial_size: output grid size. 
@@ -1291,7 +1291,7 @@ def __call__( See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) - grid = self.affine_grid(spatial_size=sp_size) + grid: torch.Tensor = self.affine_grid(spatial_size=sp_size) # type: ignore return self.resampler( img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode ) @@ -1305,11 +1305,11 @@ class RandAffine(Randomizable, Transform): def __init__( self, prob: float = 0.1, - rotate_range: Optional[Union[Sequence[float], float]] = None, - shear_range: Optional[Union[Sequence[float], float]] = None, - translate_range: Optional[Union[Sequence[float], float]] = None, - scale_range: Optional[Union[Sequence[float], float]] = None, - spatial_size: Optional[Union[Sequence[float], float]] = None, + rotate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + shear_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + translate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + scale_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + spatial_size: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, as_tensor_output: bool = True, @@ -1382,7 +1382,7 @@ def __call__( mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, return_affine: bool = False, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1429,11 +1429,11 @@ def __init__( spacing: Union[Tuple[float, float], float], magnitude_range: Tuple[float, float], prob: float = 0.1, - rotate_range: Optional[Union[Sequence[float], float]] = None, - shear_range: Optional[Union[Sequence[float], float]] = None, - translate_range: Optional[Union[Sequence[float], float]] = None, - scale_range: Optional[Union[Sequence[float], float]] = None, - spatial_size: Optional[Union[Sequence[int], int]] = None, + rotate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + shear_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + translate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + scale_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + spatial_size: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, as_tensor_output: bool = False, @@ -1554,11 +1554,11 @@ def __init__( sigma_range: Tuple[float, float], magnitude_range: Tuple[float, float], prob: float = 0.1, - rotate_range: Optional[Union[Sequence[float], float]] = None, - shear_range: Optional[Union[Sequence[float], float]] = None, - translate_range: Optional[Union[Sequence[float], float]] = None, - scale_range: Optional[Union[Sequence[float], float]] = None, - spatial_size: Optional[Union[Sequence[int], int]] = None, + rotate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + shear_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], 
float]] = None, + translate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + scale_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + spatial_size: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, as_tensor_output: bool = False, diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 134508073c..7b116d68a5 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -26,6 +26,7 @@ from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad from monai.transforms.spatial.array import ( + AffineGrid, Flip, Orientation, Rand2DElastic, @@ -36,9 +37,8 @@ Rotate90, Spacing, Zoom, - AffineGrid, ) -from monai.transforms.transform import InvertibleTransform, MapTransform, Randomizable, NonRigidTransform +from monai.transforms.transform import InvertibleTransform, MapTransform, NonRigidTransform, Randomizable from monai.transforms.utils import create_grid from monai.utils import ( GridSampleMode, @@ -195,7 +195,9 @@ def __call__( align_corners=self.align_corners[idx], dtype=self.dtype[idx], ) - self.append_applied_transforms(d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine}) + self.append_applied_transforms( + d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine} + ) # set the 'affine' key meta_data["affine"] = new_affine return d @@ -209,7 +211,7 @@ def get_input_args(self, key: Hashable, idx: int = 0) -> dict: "padding_mode": self.padding_mode[idx], "align_corners": self.align_corners[idx], "dtype": self.dtype[idx], - "meta_key_postfix": self.meta_key_postfix + "meta_key_postfix": self.meta_key_postfix, } def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: @@ -219,8 +221,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar init_args = transform["init_args"] if init_args["diagonal"]: raise RuntimeError( - "Spacingd:inverse not yet implemented for diagonal=True. " + - "Please raise a github issue if you need this feature") + "Spacingd:inverse not yet implemented for diagonal=True. 
" + + "Please raise a github issue if you need this feature" + ) # Create inverse transform meta_data = d[transform["extra_info"]["meta_data_key"]] orig_pixdim = np.sqrt(np.sum(np.square(transform["extra_info"]["old_affine"]), 0))[:-1] @@ -297,7 +300,9 @@ def __call__( meta_data_key = f"{key}_{self.meta_key_postfix}" meta_data = d[meta_data_key] d[key], old_affine, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) - self.append_applied_transforms(d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine}) + self.append_applied_transforms( + d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine} + ) d[meta_data_key]["affine"] = new_affine return d @@ -417,13 +422,13 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np.ndarray]: self.randomize() + d = dict(data) if not self._do_transform: for key in self.keys: - self.append_applied_transforms(data, key) + self.append_applied_transforms(d, key) return data rotator = Rotate90(self._rand_k, self.spatial_axes) - d = dict(data) for key in self.keys: d[key] = rotator(d[key]) self.append_applied_transforms(d, key, extra_info={"rand_k": self._rand_k}) @@ -531,10 +536,10 @@ def __init__( keys: KeysCollection, spatial_size: Optional[Union[Sequence[int], int]] = None, prob: float = 0.1, - rotate_range: Optional[Union[Sequence[float], float]] = None, - shear_range: Optional[Union[Sequence[float], float]] = None, - translate_range: Optional[Union[Sequence[float], float]] = None, - scale_range: Optional[Union[Sequence[float], float]] = None, + rotate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + shear_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + translate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + scale_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, as_tensor_output: bool = True, @@ -615,7 +620,7 @@ def __call__( affine = np.eye(len(sp_size) + 1) for idx, key in enumerate(self.keys): - self.append_applied_transforms(d, key, idx, extra_info={"affine": affine, "orig_was_numpy": isinstance(d[key], np.ndarray)}) + self.append_applied_transforms(d, key, idx, extra_info={"affine": affine}) d[key] = self.rand_affine.resampler(d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx]) return d @@ -647,12 +652,13 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) - grid = affine_grid(orig_size) + grid: torch.Tensor = affine_grid(orig_size) # type: ignore # Apply inverse transform - d[key] = self.rand_affine.resampler(d[key], grid, init_args["mode"], init_args["padding_mode"]) - if extra_info["orig_was_numpy"]: - d[key] = d[key].cpu().numpy() + out = self.rand_affine.resampler(d[key], grid, init_args["mode"], init_args["padding_mode"]) + # Convert to original output type + if isinstance(out, torch.Tensor): + d[key] = out.cpu().numpy() # Remove the applied transform self.remove_most_recent_transform(d, key) @@ -671,10 +677,10 @@ def __init__( magnitude_range: Tuple[float, float], spatial_size: Optional[Union[Sequence[int], int]] = None, prob: float = 0.1, - rotate_range: Optional[Union[Sequence[float], float]] = 
None, - shear_range: Optional[Union[Sequence[float], float]] = None, - translate_range: Optional[Union[Sequence[float], float]] = None, - scale_range: Optional[Union[Sequence[float], float]] = None, + rotate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + shear_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + translate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + scale_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, as_tensor_output: bool = False, @@ -813,9 +819,13 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Back to original size inv_def = CenterSpatialCrop(roi_size=orig_size)(inv_def) # Apply inverse transform - d[key] = self.rand_2d_elastic.resampler( + out = self.rand_2d_elastic.resampler( d[key], inv_def, init_args["mode"], init_args["padding_mode"] ) + if isinstance(out, torch.Tensor): + d[key] = out.cpu().numpy() + else: + d[key] = out # Remove the applied transform self.remove_most_recent_transform(d, key) @@ -832,12 +842,12 @@ def __init__( keys: KeysCollection, sigma_range: Tuple[float, float], magnitude_range: Tuple[float, float], - spatial_size: Optional[Union[Sequence[int], int]] = None, + spatial_size: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, prob: float = 0.1, - rotate_range: Optional[Union[Sequence[float], float]] = None, - shear_range: Optional[Union[Sequence[float], float]] = None, - translate_range: Optional[Union[Sequence[float], float]] = None, - scale_range: Optional[Union[Sequence[float], float]] = None, + rotate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + shear_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + translate_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, + scale_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, as_tensor_output: bool = False, @@ -973,9 +983,13 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Back to original size inv_def = CenterSpatialCrop(roi_size=orig_size)(inv_def) # Apply inverse transform - d[key] = self.rand_3d_elastic.resampler( + out = self.rand_3d_elastic.resampler( d[key], inv_def, init_args["mode"], init_args["padding_mode"] ) + if isinstance(out, torch.Tensor): + d[key] = out.cpu().numpy() + else: + d[key] = out # Remove the applied transform self.remove_most_recent_transform(d, key) @@ -1061,7 +1075,6 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: return { "keys": key, @@ -1264,7 +1277,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key in self.keys: self.append_applied_transforms(d, key) return d - angle: Sequence = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z) + angle: Union[Sequence[float], float] = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z) rotator = Rotate( angle=angle, keep_size=self.keep_size, @@ -1410,6 +1423,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> 
Dict[Hashable, np.ndar return d + class RandZoomd(Randomizable, MapTransform, InvertibleTransform): """ Dict-based version :py:class:`monai.transforms.RandZoom`. diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index ac04634a37..8fcf5a42e7 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -14,14 +14,14 @@ import warnings from abc import ABC, abstractmethod -from typing import Any, Hashable, Optional, Tuple -import torch -import numpy as np from itertools import chain +from typing import Any, Dict, Hashable, Optional, Tuple + +import numpy as np +import torch from monai.config import KeysCollection -from monai.utils import MAX_SEED, ensure_tuple -from monai.utils import optional_import +from monai.utils import MAX_SEED, ensure_tuple, optional_import sitk, has_sitk = optional_import("SimpleITK") vtk, has_vtk = optional_import("vtk") @@ -230,10 +230,17 @@ class InvertibleTransform(ABC): first out for the inverted transforms. """ - def append_applied_transforms(self, data: dict, key: Hashable, idx: int = 0, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None) -> None: + def append_applied_transforms( + self, + data: dict, + key: Hashable, + idx: int = 0, + extra_info: Optional[dict] = None, + orig_size: Optional[Tuple] = None, + ) -> None: """Append to list of applied transforms for that key.""" key_transform = str(key) + "_transforms" - info = {} + info: Dict[str, Any] = {} info["class"] = type(self) info["init_args"] = self.get_input_args(key, idx) info["orig_size"] = orig_size or data[key].shape[1:] @@ -286,6 +293,7 @@ def inverse(self, data: dict): """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + class NonRigidTransform(ABC): @staticmethod def _get_disp_to_def_arr(shape, spacing): @@ -308,7 +316,7 @@ def _inv_disp_w_sitk(fwd_disp, num_iters): @staticmethod def _inv_disp_w_vtk(fwd_disp): while fwd_disp.shape[-1] < 3: - fwd_disp = np.append(fwd_disp, np.zeros(fwd_disp.shape[:-1] + (1, )), axis=-1) + fwd_disp = np.append(fwd_disp, np.zeros(fwd_disp.shape[:-1] + (1,)), axis=-1) fwd_disp = fwd_disp[..., None, :] # fwd_disp_vtk = vtk.vtkImageImport() # # The previously created array is converted to a string of chars and imported. @@ -365,15 +373,18 @@ def _inv_disp_w_vtk(fwd_disp): return inv_disp - @staticmethod - def compute_inverse_deformation(num_spatial_dims, fwd_def_orig, spacing=None, num_iters: int = 100, use_package: str = "sitk"): + def compute_inverse_deformation( + num_spatial_dims, fwd_def_orig, spacing=None, num_iters: int = 100, use_package: str = "sitk" + ): """Package can be vtk or sitk.""" if use_package.lower() == "vtk" and not has_vtk: warnings.warn("Please install VTK to estimate inverse of non-rigid transforms. Data has not been modified") return None if use_package.lower() == "sitk" and not has_sitk: - warnings.warn("Please install SimpleITK to estimate inverse of non-rigid transforms. Data has not been modified") + warnings.warn( + "Please install SimpleITK to estimate inverse of non-rigid transforms. 
Data has not been modified" + ) return None # Convert to numpy if necessary @@ -394,8 +405,8 @@ def compute_inverse_deformation(num_spatial_dims, fwd_def_orig, spacing=None, nu else: inv_disp = NonRigidTransform._inv_disp_w_sitk(fwd_disp, num_iters) - import matplotlib.pyplot as plt + fig, axes = plt.subplots(2, 2) for i, direc1 in enumerate(["x", "y"]): for j, (im, direc2) in enumerate(zip([fwd_disp, inv_disp], ["fwd", "inv"])): diff --git a/tests/test_inverse.py b/tests/test_inverse.py index c0cecb418f..497e57192f 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -11,42 +11,41 @@ import random import unittest -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List, Tuple import numpy as np -from typing import List, Tuple -from monai.data import create_test_image_2d, create_test_image_3d -from monai.data import CacheDataset + +from monai.data import CacheDataset, create_test_image_2d, create_test_image_3d from monai.transforms import ( - InvertibleTransform, - AddChanneld, AddChannel, - Compose, - RandRotated, - RandSpatialCropd, - Rotated, - SpatialPad, - SpatialPadd, - SpatialCropd, + AddChanneld, BorderPadd, + CenterSpatialCropd, + Compose, + CropForegroundd, DivisiblePadd, Flipd, + InvertibleTransform, LoadImaged, - Rotate90d, - Zoomd, - CenterSpatialCropd, - CropForegroundd, Orientationd, - Spacingd, - Resized, - ResizeWithPadOrCropd, - RandZoomd, - RandFlipd, - RandRotate90d, - RandAffined, Rand2DElasticd, Rand3DElasticd, + RandAffined, + RandFlipd, + RandRotate90d, + RandRotated, + RandSpatialCropd, + RandZoomd, + Resized, ResizeWithPadOrCrop, + ResizeWithPadOrCropd, + Rotate90d, + Rotated, + Spacingd, + SpatialCropd, + SpatialPad, + SpatialPadd, + Zoomd, ) from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image, make_rand_affine @@ -78,296 +77,307 @@ TESTS: List[Tuple] = [] -TESTS.append(( - "SpatialPadd (x2) 2d", - DATA_2D, - 0.0, - SpatialPadd(KEYS, spatial_size=[111, 113], method="end"), - SpatialPadd(KEYS, spatial_size=[118, 117]), -)) - -TESTS.append(( - "SpatialPadd 3d", - DATA_3D, - 0.0, - SpatialPadd(KEYS, spatial_size=[112, 113, 116]), -)) - -TESTS.append(( - "RandRotated, prob 0", - DATA_2D, - 0, - RandRotated(KEYS, prob=0), -)) - -TESTS.append(( - "SpatialCropd 2d", - DATA_2D, - 3e-2, - SpatialCropd(KEYS, [49, 51], [90, 89]), -)) - -TESTS.append(( - "SpatialCropd 3d", - DATA_3D, - 4e-2, - SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]), -)) - -TESTS.append(( - "RandSpatialCropd 2d", - DATA_2D, - 5e-2, - RandSpatialCropd(KEYS, [96, 93], True, False) -)) - -TESTS.append(( - "RandSpatialCropd 3d", - DATA_3D, - 2e-2, - RandSpatialCropd(KEYS, [96, 93, 92], False, True) -)) - -TESTS.append(( - "BorderPadd 2d", - DATA_2D, - 0, - BorderPadd(KEYS, [3, 7, 2, 5]), -)) - -TESTS.append(( - "BorderPadd 2d", - DATA_2D, - 0, - BorderPadd(KEYS, [3, 7]), -)) - -TESTS.append(( - "BorderPadd 3d", - DATA_3D, - 0, - BorderPadd(KEYS, [4]), -)) - -TESTS.append(( - "DivisiblePadd 2d", - DATA_2D, - 0, - DivisiblePadd(KEYS, k=4), -)) - -TESTS.append(( - "DivisiblePadd 3d", - DATA_3D, - 0, - DivisiblePadd(KEYS, k=[4, 8, 11]), -)) - -TESTS.append(( - "Flipd 3d", - DATA_3D, - 0, - Flipd(KEYS, [1, 2]), -)) - -TESTS.append(( - "Flipd 3d", - DATA_3D, - 0, - Flipd(KEYS, [1, 2]), -)) - -TESTS.append(( - "RandFlipd 3d", - DATA_3D, - 0, - RandFlipd(KEYS, 1, [1, 2]), -)) - -TESTS.append(( - "Rotated 2d", - DATA_2D, - 8e-2, - Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False), -)) - -TESTS.append(( - 
"Rotated 3d", - DATA_3D, - 5e-2, - Rotated(KEYS, [random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1), -)) - -TESTS.append(( - "RandRotated 3d", - DATA_3D, - 5e-2, - RandRotated(KEYS, *(random.uniform(np.pi / 6, np.pi) for _ in range(3)), 1), -)) - -TESTS.append(( - "Orientationd 3d", - DATA_3D, - 0, - Orientationd(KEYS, 'RAS'), -)) - -TESTS.append(( - "Rotate90d 2d", - DATA_2D, - 0, - Rotate90d(KEYS), -)) - -TESTS.append(( - "Rotate90d 3d", - DATA_3D, - 0, - Rotate90d(KEYS, k=2, spatial_axes=[1, 2]), -)) - -TESTS.append(( - "RandRotate90d 3d", - DATA_3D, - 0, - RandRotate90d(KEYS, prob=1, spatial_axes=[1, 2]), -)) - -TESTS.append(( - "Zoomd 1d", - DATA_1D, - 0, - Zoomd(KEYS, zoom=2, keep_size=False), -)) - -TESTS.append(( - "Zoomd 2d", - DATA_2D, - 2e-1, - Zoomd(KEYS, zoom=0.9), -)) - -TESTS.append(( - "Zoomd 3d", - DATA_3D, - 3e-2, - Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False), -)) - -TESTS.append(( - "RandZoom 3d", - DATA_3D, - 5e-2, - RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [3, 4.2, 6.1], keep_size=False) -)) - -TESTS.append(( - "CenterSpatialCropd 2d", - DATA_2D, - 0, - CenterSpatialCropd(KEYS, roi_size=95), -)) - -TESTS.append(( - "CenterSpatialCropd 3d", - DATA_3D, - 0, - CenterSpatialCropd(KEYS, roi_size=[95, 97, 98]), -)) - -TESTS.append(( - "CropForegroundd 2d", - DATA_2D, - 0, - CropForegroundd(KEYS, source_key="label", margin=[2, 1]) -)) - -TESTS.append(( - "CropForegroundd 3d", - DATA_3D, - 0, - CropForegroundd(KEYS, source_key="label") -)) - -TESTS.append(( - "Spacingd 3d", - DATA_3D, - 3e-2, - Spacingd(KEYS, [0.5, 0.7, 0.9], diagonal=False) -)) - -TESTS.append(( - "Resized 2d", - DATA_2D, - 2e-1, - Resized(KEYS, [50, 47]) -)) - -TESTS.append(( - "Resized 3d", - DATA_3D, - 5e-2, - Resized(KEYS, [201, 150, 78]) -)) - -TESTS.append(( - "ResizeWithPadOrCropd 3d", - DATA_3D, - 1e-2, - ResizeWithPadOrCropd(KEYS, [201, 150, 78]) -)) - -TESTS.append(( - "RandAffine 3d", - DATA_3D, - 5e-2, - RandAffined( - KEYS, - [155, 179, 192], - prob=1, - padding_mode="zeros", - rotate_range=[np.pi / 6, -np.pi / 5, np.pi / 7], - shear_range=[[0.5, 0.5]], - translate_range=[10, 5, -4], - scale_range=[[0.8, 1.2], [0.9, 1.3]], +TESTS.append( + ( + "SpatialPadd (x2) 2d", + DATA_2D, + 0.0, + SpatialPadd(KEYS, spatial_size=[111, 113], method="end"), + SpatialPadd(KEYS, spatial_size=[118, 117]), + ) +) + +TESTS.append( + ( + "SpatialPadd 3d", + DATA_3D, + 0.0, + SpatialPadd(KEYS, spatial_size=[112, 113, 116]), + ) +) + +TESTS.append( + ( + "RandRotated, prob 0", + DATA_2D, + 0, + RandRotated(KEYS, prob=0), + ) +) + +TESTS.append( + ( + "SpatialCropd 2d", + DATA_2D, + 3e-2, + SpatialCropd(KEYS, [49, 51], [90, 89]), + ) +) + +TESTS.append( + ( + "SpatialCropd 3d", + DATA_3D, + 4e-2, + SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]), + ) +) + +TESTS.append(("RandSpatialCropd 2d", DATA_2D, 5e-2, RandSpatialCropd(KEYS, [96, 93], True, False))) + +TESTS.append(("RandSpatialCropd 3d", DATA_3D, 2e-2, RandSpatialCropd(KEYS, [96, 93, 92], False, True))) + +TESTS.append( + ( + "BorderPadd 2d", + DATA_2D, + 0, + BorderPadd(KEYS, [3, 7, 2, 5]), + ) +) + +TESTS.append( + ( + "BorderPadd 2d", + DATA_2D, + 0, + BorderPadd(KEYS, [3, 7]), + ) +) + +TESTS.append( + ( + "BorderPadd 3d", + DATA_3D, + 0, + BorderPadd(KEYS, [4]), + ) +) + +TESTS.append( + ( + "DivisiblePadd 2d", + DATA_2D, + 0, + DivisiblePadd(KEYS, k=4), + ) +) + +TESTS.append( + ( + "DivisiblePadd 3d", + DATA_3D, + 0, + DivisiblePadd(KEYS, k=[4, 8, 11]), + ) +) + +TESTS.append( + ( + "Flipd 3d", + DATA_3D, + 0, + Flipd(KEYS, [1, 2]), + ) +) + 
+TESTS.append( + ( + "Flipd 3d", + DATA_3D, + 0, + Flipd(KEYS, [1, 2]), + ) +) + +TESTS.append( + ( + "RandFlipd 3d", + DATA_3D, + 0, + RandFlipd(KEYS, 1, [1, 2]), + ) +) + +TESTS.append( + ( + "Rotated 2d", + DATA_2D, + 8e-2, + Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False), + ) +) + +TESTS.append( + ( + "Rotated 3d", + DATA_3D, + 5e-2, + Rotated(KEYS, [random.uniform(np.pi / 6, np.pi) for _ in range(3)], True), # type: ignore + ) +) + +TESTS.append( + ( + "RandRotated 3d", + DATA_3D, + 5e-2, + RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1), # type: ignore + ) +) + +TESTS.append( + ( + "Orientationd 3d", + DATA_3D, + 0, + Orientationd(KEYS, "RAS"), + ) +) + +TESTS.append( + ( + "Rotate90d 2d", + DATA_2D, + 0, + Rotate90d(KEYS), + ) +) + +TESTS.append( + ( + "Rotate90d 3d", + DATA_3D, + 0, + Rotate90d(KEYS, k=2, spatial_axes=(1, 2)), + ) +) + +TESTS.append( + ( + "RandRotate90d 3d", + DATA_3D, + 0, + RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)), ) -)) - -TESTS.append(( - "Rand2DElasticd 2d", - DATA_2D, - 8e-2, - Rand2DElasticd( - KEYS, - spacing=[10, 10], - magnitude_range=[2, 2], - # spatial_size=[155, 192], - prob=1, - padding_mode="zeros", - rotate_range=[[np.pi / 6, np.pi / 6], np.pi / 7], - # shear_range=[[0.5, 0.5]], - # translate_range=[10, 5], - # scale_range=[[0.8, 1.2], [0.9, 1.3]], +) + +TESTS.append( + ( + "Zoomd 1d", + DATA_1D, + 0, + Zoomd(KEYS, zoom=2, keep_size=False), ) -)) - -TESTS.append(( - "Rand3DElasticd 3d", - DATA_3D, - 0, - Rand3DElasticd( - KEYS, - sigma_range=[1, 3], - magnitude_range=[1., 2., 1.], - spatial_size=[155, 192, 200], - prob=1, - padding_mode="zeros", - rotate_range=[np.pi / 6, np.pi / 7], - shear_range=[[0.5, 0.5]], - translate_range=[10, 5], - scale_range=[[0.8, 1.2], [0.9, 1.3]], +) + +TESTS.append( + ( + "Zoomd 2d", + DATA_2D, + 2e-1, + Zoomd(KEYS, zoom=0.9), ) -)) +) + +TESTS.append( + ( + "Zoomd 3d", + DATA_3D, + 3e-2, + Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False), + ) +) + +TESTS.append(("RandZoom 3d", DATA_3D, 5e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [3, 4.2, 6.1], keep_size=False))) + +TESTS.append( + ( + "CenterSpatialCropd 2d", + DATA_2D, + 0, + CenterSpatialCropd(KEYS, roi_size=95), + ) +) + +TESTS.append( + ( + "CenterSpatialCropd 3d", + DATA_3D, + 0, + CenterSpatialCropd(KEYS, roi_size=[95, 97, 98]), + ) +) + +TESTS.append(("CropForegroundd 2d", DATA_2D, 0, CropForegroundd(KEYS, source_key="label", margin=2))) + +TESTS.append(("CropForegroundd 3d", DATA_3D, 0, CropForegroundd(KEYS, source_key="label"))) + +TESTS.append(("Spacingd 3d", DATA_3D, 3e-2, Spacingd(KEYS, [0.5, 0.7, 0.9], diagonal=False))) + +TESTS.append(("Resized 2d", DATA_2D, 2e-1, Resized(KEYS, [50, 47]))) + +TESTS.append(("Resized 3d", DATA_3D, 5e-2, Resized(KEYS, [201, 150, 78]))) + +TESTS.append(("ResizeWithPadOrCropd 3d", DATA_3D, 1e-2, ResizeWithPadOrCropd(KEYS, [201, 150, 78]))) + +TESTS.append( + ( + "RandAffine 3d", + DATA_3D, + 5e-2, + RandAffined( + KEYS, + [155, 179, 192], + prob=1, + padding_mode="zeros", + rotate_range=[np.pi / 6, -np.pi / 5, np.pi / 7], + shear_range=[(0.5, 0.5)], + translate_range=[10, 5, -4], + scale_range=[(0.8, 1.2), (0.9, 1.3)], + ), + ) +) + +TESTS.append( + ( + "Rand2DElasticd 2d", + DATA_2D, + 8e-2, + Rand2DElasticd( + KEYS, + spacing=(10.0, 10.0), + magnitude_range=(2, 2), + # spatial_size=[155, 192], + prob=1, + padding_mode="zeros", + rotate_range=[(np.pi / 6, np.pi / 6), np.pi / 7], + # shear_range=[(0.5, 0.5)], + # translate_range=[10, 5], + # 
scale_range=[(0.8, 1.2), (0.9, 1.3)], + ), + ) +) + +TESTS.append( + ( + "Rand3DElasticd 3d", + DATA_3D, + 0, + Rand3DElasticd( + KEYS, + sigma_range=(1, 3), + magnitude_range=(1.0, 2.0), + spatial_size=[155, 192, 200], + prob=1, + padding_mode="zeros", + rotate_range=[np.pi / 6, np.pi / 7], + shear_range=[(0.5, 0.5)], + translate_range=[10, 5], + scale_range=[(0.8, 1.2), (0.9, 1.3)], + ), + ) +) TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] @@ -378,13 +388,12 @@ TEST_FAIL_0 = (DATA_2D["image"], 0.0, Compose([SpatialPad(spatial_size=[101, 103])])) TESTS_FAIL = [TEST_FAIL_0] + def plot_im(orig, fwd_bck, fwd): diff_orig_fwd_bck = orig - fwd_bck ims_to_show = [orig, fwd, fwd_bck, diff_orig_fwd_bck] titles = ["x", "fx", "f⁻¹fx", "x - f⁻¹fx"] - fig, axes = plt.subplots( - 1, 4, gridspec_kw={"width_ratios": [i.shape[1] for i in ims_to_show]} - ) + fig, axes = plt.subplots(1, 4, gridspec_kw={"width_ratios": [i.shape[1] for i in ims_to_show]}) vmin = min(np.array(i).min() for i in [orig, fwd_bck, fwd]) vmax = max(np.array(i).max() for i in [orig, fwd_bck, fwd]) for im, title, ax in zip(ims_to_show, titles, axes): @@ -412,7 +421,9 @@ def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_ self.assertLessEqual(mean_diff, acceptable_diff) self.assertLessEqual(mean_diff, unmodded_diff) except AssertionError: - print(f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}") + print( + f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}" + ) if has_matplotlib: plot_im(orig, fwd_bck, unmodified) raise diff --git a/tests/utils.py b/tests/utils.py index a99e2fc4b4..56e8f40a1d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -150,6 +150,7 @@ def make_nifti_image(array, affine=None): os.close(temp_f) return image_name + def make_rand_affine(ndim: int = 3): """Create random affine transformation (with values == -1, 0 or 1).""" vals = np.random.choice([-1, 1], size=ndim) @@ -160,6 +161,7 @@ def make_rand_affine(ndim: int = 3): af[i, p] = v return af + class DistTestCase(unittest.TestCase): """ testcase without _outcome, so that it's picklable. From d03afd7c2ab9cdbb138a673dd0cf3893e10f9093 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 9 Feb 2021 15:23:37 +0000 Subject: [PATCH 51/80] decollate batch Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/__init__.py | 1 + monai/data/utils.py | 55 +++++++++++++++++++++++++++++++++ tests/test_decollate.py | 68 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 124 insertions(+) create mode 100644 tests/test_decollate.py diff --git a/monai/data/__init__.py b/monai/data/__init__.py index e0db1e17ae..7b1e60a30f 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -56,4 +56,5 @@ to_affine_nd, worker_init_fn, zoom_affine, + decollate_batch, ) diff --git a/monai/data/utils.py b/monai/data/utils.py index acc6d2e97a..416a601dfa 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -63,6 +63,7 @@ "json_hashing", "pickle_hashing", "sorted_dict", + "decollate_batch", ] @@ -242,6 +243,60 @@ def list_data_collate(batch: Sequence): data = [i for k in batch for i in k] if isinstance(elem, list) else batch return default_collate(data) +def decollate_batch(data: dict, batch_size: Optional[int] = None): + """De-collate a batch of data (for example, as produced by a `DataLoader`). 
+ + Returns a list of dictionaries. Each dictionary will only contain the data for a given batch. + + Images originally stored as (B,C,H,W,[D]) will be returned as (C,H,W,[D]). Other information, + such as metadata, may have been stored in a list (or a list inside nested dictionaries). In + this case we return the element of the list corresponding to the batch idx. + + For example: + + ``` + batch_data = { + "image": torch.rand((2,1,10,10)), + "image_meta_dict": {"scl_slope": torch.Tensor([0.0, 0.0])} + } + out = decollate_batch(batch_data) + print(len(out)) + >>> 2 + + print(out[0]) + >>> {'image': tensor([[[4.3549e-01...43e-01]]]), 'image_meta_dict': {'scl_slope': 0.0}} + ``` + + Args: + data: data to be de-collated + batch_size: number of batches in data. If `None` is passed, try to figure out batch size. + """ + if not isinstance(data, dict): + raise RuntimeError("Only currently implemented for dictionary data (might be trivial to adapt).") + if batch_size is None: + for v in data.values(): + if isinstance(v, torch.Tensor): + batch_size = v.shape[0] + if batch_size is None: + raise RuntimeError("Couldn't determine batch size, please specify as argument.") + + def decollate(data, idx, batch_size): + if isinstance(data, torch.Tensor): + out = data[idx] + return out if out.numel() > 1 else out.item() + elif isinstance(data, dict): + return {k: decollate(v, idx, batch_size) for k, v in data.items()} + elif isinstance(data, list): + if len(data) == batch_size: + return data[idx] + if len(data) == 1: + return [decollate(data[0], idx, batch_size)] + if isinstance(data[0], torch.Tensor): + return [d[idx] if d[idx].numel() > 1 else d[idx].item() for d in data] + raise TypeError(f"Not sure how to de-collate list of len: {len(data)}") + raise TypeError(f"Not sure how to de-collate type: {type(data)}") + + return [{key: decollate(data[key], idx, batch_size) for key in data.keys()} for idx in range(batch_size)] def worker_init_fn(worker_id: int) -> None: """ diff --git a/tests/test_decollate.py b/tests/test_decollate.py new file mode 100644 index 0000000000..ebb34dafcc --- /dev/null +++ b/tests/test_decollate.py @@ -0,0 +1,68 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
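A minimal usage sketch of the `decollate_batch` helper introduced above, mirroring its docstring example (illustrative only, not part of the patch):

```python
# Split a collated batch back into per-item dictionaries.
import torch

from monai.data.utils import decollate_batch

batch = {
    "image": torch.rand(2, 1, 10, 10),  # collated as (B, C, H, W)
    "image_meta_dict": {"scl_slope": torch.tensor([0.0, 0.0])},
}
items = decollate_batch(batch)

assert len(items) == 2                                  # one dict per batch element
assert items[0]["image"].shape == (1, 10, 10)           # batch dimension removed
assert items[0]["image_meta_dict"]["scl_slope"] == 0.0  # 1-element tensors become scalars
```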
+ +import unittest +import numpy as np + +from monai.data import DataLoader +from monai.data import CacheDataset, create_test_image_2d, create_test_image_3d +from monai.transforms import AddChanneld, Compose, LoadImaged, ToTensord +from monai.data.utils import decollate_batch +from monai.utils import set_determinism +from tests.utils import make_nifti_image + +from parameterized import parameterized + + +set_determinism(seed=0) + +IM_2D_FNAME = make_nifti_image(create_test_image_2d(100, 101)[0]) +IM_3D_FNAME = make_nifti_image(create_test_image_3d(100, 101, 107)[0]) + +TRANSFORMS = Compose([LoadImaged("image"), AddChanneld("image"), ToTensord("image")]) +DATA_2D = {"image": IM_2D_FNAME} +DATA_3D = {"image": IM_3D_FNAME} + +TESTS = [] +TESTS.append(( + "2D", + [DATA_2D for _ in range(5)], + TRANSFORMS, +)) +TESTS.append(( + "3D", + [DATA_3D for _ in range(9)], + TRANSFORMS, +)) + +class TestDeCollate(unittest.TestCase): + def check_dictionaries_match(self, d1, d2): + self.assertEqual(d1.keys(), d2.keys()) + for v1, v2 in zip(d1.values(), d2.values()): + if isinstance(v1, dict): + self.check_dictionaries_match(v1, v2) + else: + np.testing.assert_array_equal(v1, v2) + + @parameterized.expand(TESTS) + def test_decollation(self, _, data, transforms, batch_size=2, num_workers=0): + dataset = CacheDataset(data, transforms, progress=False) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) + + for b, batch_data in enumerate(loader): + decollated = decollate_batch(batch_data) + + for i, d in enumerate(decollated): + self.check_dictionaries_match(d, dataset[b * batch_size + i]) + + +if __name__ == "__main__": + unittest.main() From e20078987f1727431ae06309ec3763543df0b903 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 9 Feb 2021 17:58:06 +0000 Subject: [PATCH 52/80] decollate2 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/utils.py | 32 ++++++++++++++++----------- tests/test_decollate.py | 48 +++++++++++++++++++++++------------------ 2 files changed, 47 insertions(+), 33 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 416a601dfa..2dd948a189 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -12,13 +12,14 @@ import hashlib import json import math +from monai.utils.misc import issequenceiterable import os import pickle import warnings from collections import defaultdict from itertools import product, starmap from pathlib import PurePath -from typing import Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -252,6 +253,9 @@ def decollate_batch(data: dict, batch_size: Optional[int] = None): such as metadata, may have been stored in a list (or a list inside nested dictionaries). In this case we return the element of the list corresponding to the batch idx. 
+ Return types aren't guaranteed to be the same as the original, since numpy arrays will have been + converted to torch.Tensor, and tuples/lists may have been converted to lists of tensors + For example: ``` @@ -277,26 +281,30 @@ def decollate_batch(data: dict, batch_size: Optional[int] = None): for v in data.values(): if isinstance(v, torch.Tensor): batch_size = v.shape[0] + break if batch_size is None: raise RuntimeError("Couldn't determine batch size, please specify as argument.") - def decollate(data, idx, batch_size): + def torch_to_single(d: torch.Tensor): + """If input is a torch.Tensor with only 1 element, return just the element.""" + return d if d.numel() > 1 else d.item() + + def decollate(data: Any, idx: int): + """Recursively de-collate.""" + if isinstance(data, dict): + return {k: decollate(v, idx) for k, v in data.items()} if isinstance(data, torch.Tensor): out = data[idx] - return out if out.numel() > 1 else out.item() - elif isinstance(data, dict): - return {k: decollate(v, idx, batch_size) for k, v in data.items()} + return torch_to_single(out) elif isinstance(data, list): - if len(data) == batch_size: - return data[idx] - if len(data) == 1: - return [decollate(data[0], idx, batch_size)] if isinstance(data[0], torch.Tensor): - return [d[idx] if d[idx].numel() > 1 else d[idx].item() for d in data] - raise TypeError(f"Not sure how to de-collate list of len: {len(data)}") + return [torch_to_single(d[idx]) for d in data] + if issequenceiterable(data[0]): + return [decollate(d, idx) for d in data] + return data[idx] raise TypeError(f"Not sure how to de-collate type: {type(data)}") - return [{key: decollate(data[key], idx, batch_size) for key in data.keys()} for idx in range(batch_size)] + return [{key: decollate(data[key], idx) for key in data.keys()} for idx in range(batch_size)] def worker_init_fn(worker_id: int) -> None: """ diff --git a/tests/test_decollate.py b/tests/test_decollate.py index ebb34dafcc..758289cd43 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -9,12 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
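A small sketch of the recursive de-collation behaviour after this refinement, using an assumed nested-metadata layout (the field names are illustrative, not from the patch):

```python
# Nested dictionaries are recursed into; plain lists are indexed per item,
# and batched tensors are split along the first dimension.
import torch

from monai.data.utils import decollate_batch

batch = {
    "image": torch.rand(2, 1, 8, 8),
    "meta": {
        "filename": ["a.nii", "b.nii"],             # plain list: indexed per item
        "affine": torch.stack([torch.eye(4)] * 2),  # (B, 4, 4) tensor: split per item
    },
}
items = decollate_batch(batch)

assert items[1]["meta"]["filename"] == "b.nii"
assert items[1]["meta"]["affine"].shape == (4, 4)
```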
+import torch import unittest import numpy as np from monai.data import DataLoader -from monai.data import CacheDataset, create_test_image_2d, create_test_image_3d -from monai.transforms import AddChanneld, Compose, LoadImaged, ToTensord +from monai.data import CacheDataset, create_test_image_2d +from monai.transforms import AddChanneld, Compose, LoadImaged, ToTensord, SpatialPadd, RandFlipd from monai.data.utils import decollate_batch from monai.utils import set_determinism from tests.utils import make_nifti_image @@ -25,35 +26,40 @@ set_determinism(seed=0) IM_2D_FNAME = make_nifti_image(create_test_image_2d(100, 101)[0]) -IM_3D_FNAME = make_nifti_image(create_test_image_3d(100, 101, 107)[0]) -TRANSFORMS = Compose([LoadImaged("image"), AddChanneld("image"), ToTensord("image")]) DATA_2D = {"image": IM_2D_FNAME} -DATA_3D = {"image": IM_3D_FNAME} TESTS = [] TESTS.append(( "2D", - [DATA_2D for _ in range(5)], - TRANSFORMS, -)) -TESTS.append(( - "3D", - [DATA_3D for _ in range(9)], - TRANSFORMS, + [DATA_2D for _ in range(6)], )) class TestDeCollate(unittest.TestCase): - def check_dictionaries_match(self, d1, d2): - self.assertEqual(d1.keys(), d2.keys()) - for v1, v2 in zip(d1.values(), d2.values()): - if isinstance(v1, dict): - self.check_dictionaries_match(v1, v2) - else: - np.testing.assert_array_equal(v1, v2) + def check_match(self, in1, in2): + if isinstance(in1, dict): + self.assertTrue(isinstance(in2, dict)) + self.check_match(list(in1.keys()), list(in2.keys())) + self.check_match(list(in1.values()), list(in2.values())) + elif any(isinstance(in1, i) for i in [list, tuple]): + for l1, l2 in zip(in1, in2): + self.check_match(l1, l2) + elif any(isinstance(in1, i) for i in [str, int]): + self.assertEqual(in1, in2) + elif any(isinstance(in1, i) for i in [torch.Tensor, np.ndarray]): + np.testing.assert_array_equal(in1, in2) + else: + raise RuntimeError(f"Not sure how to compare types. 
type(in1): {type(in1)}, type(in2): {type(in2)}") @parameterized.expand(TESTS) - def test_decollation(self, _, data, transforms, batch_size=2, num_workers=0): + def test_decollation(self, _, data, batch_size=2, num_workers=0): + transforms = Compose([ + LoadImaged("image"), + AddChanneld("image"), + SpatialPadd("image", 150), + RandFlipd("image", prob=1., spatial_axis=1), + ToTensord("image"), + ]) dataset = CacheDataset(data, transforms, progress=False) loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) @@ -61,7 +67,7 @@ def test_decollation(self, _, data, transforms, batch_size=2, num_workers=0): decollated = decollate_batch(batch_data) for i, d in enumerate(decollated): - self.check_dictionaries_match(d, dataset[b * batch_size + i]) + self.check_match(dataset[b * batch_size + i], d) if __name__ == "__main__": From c8aaf8cbbcbf72f20a5ffffc67f69158c46e6dbf Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 10 Feb 2021 11:11:45 +0000 Subject: [PATCH 53/80] need to remove all init_args Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/__init__.py | 1 + monai/data/inverse_batch_transform.py | 46 ++++ monai/data/utils.py | 2 + monai/transforms/__init__.py | 3 +- monai/transforms/compose.py | 17 +- monai/transforms/croppad/dictionary.py | 50 ++--- monai/transforms/inverse_batch_transform.py | 46 ++++ monai/transforms/inverse_transform.py | 229 ++++++++++++++++++++ monai/transforms/spatial/dictionary.py | 3 +- monai/transforms/transform.py | 226 +------------------ tests/test_decollate.py | 2 +- tests/test_inverse.py | 60 +++-- 12 files changed, 402 insertions(+), 283 deletions(-) create mode 100644 monai/data/inverse_batch_transform.py create mode 100644 monai/transforms/inverse_batch_transform.py create mode 100644 monai/transforms/inverse_transform.py diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 7b1e60a30f..90b65bc347 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -25,6 +25,7 @@ from .grid_dataset import GridPatchDataset, PatchDataset from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader +from .inverse_batch_transform import BatchInverseTransform from .iterable_dataset import IterableDataset from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py new file mode 100644 index 0000000000..485213a6ab --- /dev/null +++ b/monai/data/inverse_batch_transform.py @@ -0,0 +1,46 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
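For context, the new `BatchInverseTransform` (defined just below) decollates a batch, applies `transform.inverse` to each item and re-collates the results. The following is a hedged sketch of the intended usage, assuming a toy in-memory dataset of identically sized images so the inverted items can be collated again:

```python
# Invert a whole batch produced by a DataLoader: decollate, apply the
# per-item inverse, then re-collate into a batch.
from monai.data import BatchInverseTransform, CacheDataset, DataLoader, create_test_image_2d
from monai.transforms import AddChanneld, Compose, SpatialPadd

im, _ = create_test_image_2d(100, 101)
data = [{"image": im} for _ in range(4)]

transforms = Compose([AddChanneld("image"), SpatialPadd("image", spatial_size=[128, 128])])
dataset = CacheDataset(data, transforms, progress=False)
loader = DataLoader(dataset, batch_size=2, num_workers=0)
inverter = BatchInverseTransform(transforms, loader)

for batch in loader:                  # batch["image"]: (2, 1, 128, 128)
    fwd_bck = inverter(batch)         # padding undone per item, re-collated
    print(fwd_bck["image"].shape)     # expected: (2, 1, 100, 101), the original size
```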
+ +from typing import Callable +from monai.data.utils import decollate_batch + +from monai.data.dataloader import DataLoader +from monai.data.dataset import Dataset + + +__all__ = ["BatchInverseTransform"] + +class _BatchInverseDataset(Dataset): + def __init__(self, data, transform) -> None: + self.data = decollate_batch(data) + self.transform = transform + + def __getitem__(self, index: int): + data = self.data[index] + return self.transform.inverse(data) + + +class BatchInverseTransform: + """something""" + def __init__(self, transform: Callable, loader) -> None: + """ + Args: + transform: a callable data transform on input data. + loader: data loader used to generate the batch of data. + """ + self.transform = transform + self.batch_size = loader.batch_size + self.num_workers = loader.num_workers + + def __call__(self, data): + inv_ds = _BatchInverseDataset(data, self.transform) + inv_loader = DataLoader(inv_ds, batch_size=self.batch_size, num_workers=self.num_workers) + return next(iter(inv_loader)) diff --git a/monai/data/utils.py b/monai/data/utils.py index 2dd948a189..6c14d431ae 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -297,6 +297,8 @@ def decollate(data: Any, idx: int): out = data[idx] return torch_to_single(out) elif isinstance(data, list): + if len(data) == 0: + return data if isinstance(data[0], torch.Tensor): return [torch_to_single(d[idx]) for d in data] if issequenceiterable(data[0]): diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index ed157dcf45..e86b036ba9 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -234,7 +234,8 @@ ZoomD, ZoomDict, ) -from .transform import InvertibleTransform, MapTransform, NonRigidTransform, Randomizable, Transform +from .transform import MapTransform, Randomizable, Transform +from .inverse_transform import InvertibleTransform, NonRigidTransform from .utility.array import ( AddChannel, AddExtremePointsChannel, diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 7529e119cb..713df6ade7 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -18,7 +18,8 @@ import numpy as np -from monai.transforms.transform import InvertibleTransform, Randomizable, Transform +from monai.transforms.transform import Randomizable, Transform +from monai.transforms.inverse_transform import InvertibleTransform from monai.transforms.utils import apply_transform from monai.utils import MAX_SEED, ensure_tuple, get_seed @@ -121,12 +122,10 @@ def inverse(self, data): if not isinstance(data, Mapping): raise RuntimeError("Inverse method only available for dictionary transforms") d = deepcopy(dict(data)) - # loop over data elements - for k in d: - transform_key = k + "_transforms" - if transform_key not in data: - continue - for t in reversed(data[transform_key]): - transform = t["class"](**t["init_args"]) - d = transform.inverse(d) + + # loop backwards over transforms + for t in reversed(self.transforms): + # check if transform is one of the invertible ones + if isinstance(t, InvertibleTransform): + d = t.inverse(d) return d diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 109f66b759..e04141673b 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -32,7 +32,8 @@ SpatialCrop, SpatialPad, ) -from monai.transforms.transform import InvertibleTransform, MapTransform, Randomizable +from monai.transforms.transform import MapTransform, Randomizable +from 
monai.transforms.inverse_transform import InvertibleTransform from monai.transforms.utils import ( generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, @@ -119,30 +120,23 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for idx, (key, m) in enumerate(zip(self.keys, self.mode)): - orig_size = d[key].shape + self.append_applied_transforms(d, key, idx) d[key] = self.padder(d[key], mode=m) - self.append_applied_transforms(d, key, idx, {"orig_size": orig_size}) return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "method": self.padder.method, - "mode": self.mode[idx], - "spatial_size": self.padder.spatial_size, - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.keys: transform = self.get_most_recent_transform(d, key) # Create inverse transform - extra_info = transform["extra_info"] - roi_size = extra_info["orig_size"][1:] - im_shape = d[key].shape[1:] if self.padder.method == Method.SYMMETRIC else extra_info["orig_size"][1:] - roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) / 2 for r, i in zip(roi_size, im_shape)] + orig_size = transform["orig_size"] + if self.padder.method == Method.SYMMETRIC: + current_size = d[key].shape[1:] + roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) / 2 for r, i in zip(orig_size, current_size)] + else: + roi_center = [floor(r / 2) if r % 2 == 0 else (r - 1) / 2 for r in orig_size] - inverse_transform = SpatialCrop(roi_center, roi_size) + inverse_transform = SpatialCrop(roi_center, orig_size) # Apply inverse transform d[key] = inverse_transform(d[key]) # Remove the applied transform @@ -210,7 +204,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = np.array(transform["orig_size"]) - roi_start = np.array(transform["init_args"]["spatial_border"]) + roi_start = np.array(self.padder.spatial_border) # Need to convert single value to [min1,min2,...] 
if roi_start.size == 1: roi_start = np.full((len(orig_size)), roi_start) @@ -264,13 +258,6 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.padder(d[key], mode=m) return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "k": self.padder.k, - "mode": self.mode[idx], - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) @@ -338,8 +325,8 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = transform["orig_size"] - pad_to_start = transform["init_args"]["roi_start"] - pad_to_end = orig_size - transform["init_args"]["roi_end"] + pad_to_start = self.cropper.roi_start + pad_to_end = orig_size - self.cropper.roi_end # interweave mins and maxes pad = np.empty((2 * len(orig_size)), dtype=np.int32) pad[0::2] = pad_to_start @@ -455,10 +442,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda if self._size is None: raise AssertionError for idx, key in enumerate(self.keys): - self.append_applied_transforms(d, key, idx, {"slices": self._slices}) if self.random_center: + self.append_applied_transforms(d, key, idx, {"slices": [(i.start, i.stop) for i in self._slices[1:]]}) d[key] = d[key][self._slices] else: + self.append_applied_transforms(d, key, idx) cropper = CenterSpatialCrop(self._size) d[key] = cropper(d[key]) return d @@ -478,13 +466,13 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = transform["orig_size"] - random_center = transform["init_args"]["random_center"] + random_center = self.random_center pad_to_start = np.empty((len(orig_size)), dtype=np.int32) pad_to_end = np.empty((len(orig_size)), dtype=np.int32) if random_center: - for i, _slice in enumerate(transform["extra_info"]["slices"][1:]): - pad_to_start[i] = _slice.start - pad_to_end[i] = orig_size[i] - _slice.stop + for i, _slice in enumerate(transform["extra_info"]["slices"]): + pad_to_start[i] = _slice[0] + pad_to_end[i] = orig_size[i] - _slice[1] else: current_size = d[key].shape[1:] for i, (o_s, c_s) in enumerate(zip(orig_size, current_size)): diff --git a/monai/transforms/inverse_batch_transform.py b/monai/transforms/inverse_batch_transform.py new file mode 100644 index 0000000000..485213a6ab --- /dev/null +++ b/monai/transforms/inverse_batch_transform.py @@ -0,0 +1,46 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
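The crop/pad inverses above can also be exercised end-to-end through `Compose.inverse`; a minimal round-trip sketch with illustrative sizes (not taken from the patch):

```python
# Forward pass records each applied transform per key; inverse undoes them
# last-in, first-out, restoring the original spatial size.
import numpy as np

from monai.transforms import CenterSpatialCropd, Compose, SpatialPadd

transforms = Compose(
    [
        SpatialPadd("image", spatial_size=[128, 128]),
        CenterSpatialCropd("image", roi_size=[96, 96]),
    ]
)
data = {"image": np.random.rand(1, 100, 101)}

fwd = transforms(data)             # each step appends a record to fwd["image_transforms"]
fwd_bck = transforms.inverse(fwd)  # records are popped and undone in reverse order
assert fwd_bck["image"].shape == data["image"].shape
```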
+ +from typing import Callable +from monai.data.utils import decollate_batch + +from monai.data.dataloader import DataLoader +from monai.data.dataset import Dataset + + +__all__ = ["BatchInverseTransform"] + +class _BatchInverseDataset(Dataset): + def __init__(self, data, transform) -> None: + self.data = decollate_batch(data) + self.transform = transform + + def __getitem__(self, index: int): + data = self.data[index] + return self.transform.inverse(data) + + +class BatchInverseTransform: + """something""" + def __init__(self, transform: Callable, loader) -> None: + """ + Args: + transform: a callable data transform on input data. + loader: data loader used to generate the batch of data. + """ + self.transform = transform + self.batch_size = loader.batch_size + self.num_workers = loader.num_workers + + def __call__(self, data): + inv_ds = _BatchInverseDataset(data, self.transform) + inv_loader = DataLoader(inv_ds, batch_size=self.batch_size, num_workers=self.num_workers) + return next(iter(inv_loader)) diff --git a/monai/transforms/inverse_transform.py b/monai/transforms/inverse_transform.py new file mode 100644 index 0000000000..265399f9bb --- /dev/null +++ b/monai/transforms/inverse_transform.py @@ -0,0 +1,229 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from abc import ABC +from itertools import chain +from typing import Any, Dict, Hashable, Optional, Tuple +import numpy as np +import torch + +from monai.utils import optional_import +from monai.transforms.transform import Randomizable + +sitk, has_sitk = optional_import("SimpleITK") +vtk, has_vtk = optional_import("vtk") +vtk_numpy_support, _ = optional_import("vtk.util.numpy_support") + +__all__ = ["InvertibleTransform", "NonRigidTransform"] + +class InvertibleTransform(ABC): + """Classes for invertible transforms. + + This class exists so that an ``invert`` method can be implemented. This allows, for + example, images to be cropped, rotated, padded, etc., during training and inference, + and after be returned to their original size before saving to file for comparison in + an external viewer. + + When the `__call__` method is called, a serialization of the class is stored. When + the `inverse` method is called, the serialization is then removed. We use last in, + first out for the inverted transforms. 
+ """ + + def append_applied_transforms( + self, + data: dict, + key: Hashable, + idx: int = 0, + extra_info: Optional[dict] = None, + orig_size: Optional[Tuple] = None, + ) -> None: + """Append to list of applied transforms for that key.""" + key_transform = str(key) + "_transforms" + info: Dict[str, Any] = {} + info["id"] = id(self) + # info["init_args"] = self.get_input_args(key, idx) + info["orig_size"] = orig_size or data[key].shape[1:] + if extra_info is not None: + info["extra_info"] = extra_info + # If class is randomizable, store whether the transform was actually performed (based on `prob`) + if isinstance(self, Randomizable): + info["do_transform"] = self._do_transform + # If this is the first, create list + if key_transform not in data: + data[key_transform] = [] + data[key_transform].append(info) + + + def check_transforms_match(self, transform: dict, key: Hashable) -> None: + explanation = "Should inverse most recently applied invertible transform first" + # Check transorms are of same type. + if transform["id"] != id(self): + raise RuntimeError(explanation) + + def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: + """Get most recent transform.""" + transform = dict(data[str(key) + "_transforms"][-1]) + self.check_transforms_match(transform, key) + return transform + + @staticmethod + def remove_most_recent_transform(data: dict, key: Hashable) -> None: + """Remove most recent transform.""" + data[str(key) + "_transforms"].pop() + + def get_input_args(self, key: Hashable, idx: int = 0) -> dict: + """Get input arguments for a single key.""" + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + def inverse(self, data: dict): + """ + Inverse of ``__call__``. + + Raises: + NotImplementedError: When the subclass does not override this method. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + +class NonRigidTransform(ABC): + @staticmethod + def _get_disp_to_def_arr(shape, spacing): + def_to_disp = np.mgrid[[slice(0, i) for i in shape]].astype(np.float64) + for idx, i in enumerate(shape): + # shift for origin (in MONAI, center of image) + def_to_disp[idx] -= (i - 1) / 2 + # if supplied, account for spacing (e.g., for control point grids) + if spacing is not None: + def_to_disp[idx] *= spacing[idx] + return def_to_disp + + @staticmethod + def _inv_disp_w_sitk(fwd_disp, num_iters): + fwd_disp_sitk = sitk.GetImageFromArray(fwd_disp, isVector=True) + inv_disp_sitk = sitk.InvertDisplacementField(fwd_disp_sitk, num_iters) + inv_disp = sitk.GetArrayFromImage(inv_disp_sitk) + return inv_disp + + @staticmethod + def _inv_disp_w_vtk(fwd_disp): + while fwd_disp.shape[-1] < 3: + fwd_disp = np.append(fwd_disp, np.zeros(fwd_disp.shape[:-1] + (1,)), axis=-1) + fwd_disp = fwd_disp[..., None, :] + # fwd_disp_vtk = vtk.vtkImageImport() + # # The previously created array is converted to a string of chars and imported. 
+ # data_string = fwd_disp.tostring() + # fwd_disp_vtk.CopyImportVoidPointer(data_string, len(data_string)) + # # The type of the newly imported data is set to unsigned char (uint8) + # fwd_disp_vtk.SetDataScalarTypeToUnsignedChar() + # fwd_disp_vtk.SetNumberOfScalarComponents(3) + extent = list(chain.from_iterable(zip([0, 0, 0], fwd_disp.shape[:-1]))) + # fwd_disp_vtk.SetWholeExtent(extent) + # fwd_disp_vtk.SetDataExtentToWholeExtent() + # fwd_disp_vtk.Update() + # fwd_disp_vtk = fwd_disp_vtk.GetOutput() + + fwd_disp_flattened = fwd_disp.flatten() # need to keep this in memory + vtk_data_array = vtk_numpy_support.numpy_to_vtk(fwd_disp_flattened) + + # Generating the vtkImageData + fwd_disp_vtk = vtk.vtkImageData() + fwd_disp_vtk.SetOrigin(0, 0, 0) + fwd_disp_vtk.SetSpacing(1, 1, 1) + fwd_disp_vtk.SetDimensions(*fwd_disp.shape[:-1]) + + fwd_disp_vtk.AllocateScalars(vtk_numpy_support.get_vtk_array_type(fwd_disp.dtype), 3) + fwd_disp_vtk.SetExtent(extent) + fwd_disp_vtk.GetPointData().AddArray(vtk_data_array) + + # # create b-spline coefficients for the displacement grid + # bspline_filter = vtk.vtkImageBSplineCoefficients() + # bspline_filter.SetInputData(fwd_disp_vtk) + # bspline_filter.Update() + + # # use these b-spline coefficients to create a transform + # bspline_transform = vtk.vtkBSplineTransform() + # bspline_transform.SetCoefficientData(bspline_filter.GetOutput()) + # bspline_transform.Update() + + # # invert the b-spline transform onto a new grid + # grid_maker = vtk.vtkTransformToGrid() + # grid_maker.SetInput(bspline_transform.GetInverse()) + # grid_maker.SetGridOrigin(fwd_disp_vtk.GetOrigin()) + # grid_maker.SetGridSpacing(fwd_disp_vtk.GetSpacing()) + # grid_maker.SetGridExtent(fwd_disp_vtk.GetExtent()) + # grid_maker.SetGridScalarTypeToFloat() + # grid_maker.Update() + + # # Get inverse displacement as an image + # inv_disp_vtk = grid_maker.GetOutput() + + # from vtk.util.numpy_support import vtk_to_numpy + # inv_disp = vtk_numpy_support.vtk_to_numpy(inv_disp_vtk.GetPointData().GetScalars()) + inv_disp = vtk_numpy_support.vtk_to_numpy(fwd_disp_vtk.GetPointData().GetArray(0)) + inv_disp = inv_disp.reshape(fwd_disp.shape) + + return inv_disp + + @staticmethod + def compute_inverse_deformation( + num_spatial_dims, fwd_def_orig, spacing=None, num_iters: int = 100, use_package: str = "sitk" + ): + """Package can be vtk or sitk.""" + if use_package.lower() == "vtk" and not has_vtk: + warnings.warn("Please install VTK to estimate inverse of non-rigid transforms. Data has not been modified") + return None + if use_package.lower() == "sitk" and not has_sitk: + warnings.warn( + "Please install SimpleITK to estimate inverse of non-rigid transforms. Data has not been modified" + ) + return None + + # Convert to numpy if necessary + if isinstance(fwd_def_orig, torch.Tensor): + fwd_def_orig = fwd_def_orig.cpu().numpy() + # Remove any extra dimensions (we'll add them back in at the end) + fwd_def = fwd_def_orig[:num_spatial_dims] + # Def -> disp + def_to_disp = NonRigidTransform._get_disp_to_def_arr(fwd_def.shape[1:], spacing) + fwd_disp = fwd_def - def_to_disp + # move tensor component to end (T,H,W,[D])->(H,W,[D],T) + fwd_disp = np.moveaxis(fwd_disp, 0, -1) + + # If using vtk... + if use_package.lower() == "vtk": + inv_disp = NonRigidTransform._inv_disp_w_vtk(fwd_disp) + # If using sitk... 
+ else: + inv_disp = NonRigidTransform._inv_disp_w_sitk(fwd_disp, num_iters) + + import matplotlib.pyplot as plt + + fig, axes = plt.subplots(2, 2) + for i, direc1 in enumerate(["x", "y"]): + for j, (im, direc2) in enumerate(zip([fwd_disp, inv_disp], ["fwd", "inv"])): + ax = axes[i, j] + im_show = ax.imshow(im[..., i]) + ax.set_title(f"{direc2}{direc1}", fontsize=25) + ax.axis("off") + fig.colorbar(im_show, ax=ax) + plt.show() + # move tensor component back to beginning + inv_disp = np.moveaxis(inv_disp, -1, 0) + # Disp -> def + inv_def = inv_disp + def_to_disp + # Add back in any removed dimensions + ndim_in = fwd_def_orig.shape[0] + ndim_out = inv_def.shape[0] + inv_def = np.concatenate([inv_def, fwd_def_orig[ndim_out:ndim_in]]) + + return inv_def diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 7b116d68a5..a626c87ae1 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -38,7 +38,8 @@ Spacing, Zoom, ) -from monai.transforms.transform import InvertibleTransform, MapTransform, NonRigidTransform, Randomizable +from monai.transforms.transform import MapTransform, Randomizable +from monai.transforms.inverse_transform import InvertibleTransform, NonRigidTransform from monai.transforms.utils import create_grid from monai.utils import ( GridSampleMode, diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 8fcf5a42e7..36c7445cf1 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -12,22 +12,13 @@ A collection of generic interfaces for MONAI transforms. """ -import warnings from abc import ABC, abstractmethod -from itertools import chain -from typing import Any, Dict, Hashable, Optional, Tuple - +from typing import Any, Hashable, Optional, Tuple import numpy as np -import torch - from monai.config import KeysCollection -from monai.utils import MAX_SEED, ensure_tuple, optional_import - -sitk, has_sitk = optional_import("SimpleITK") -vtk, has_vtk = optional_import("vtk") -vtk_numpy_support, _ = optional_import("vtk.util.numpy_support") +from monai.utils import MAX_SEED, ensure_tuple -__all__ = ["Randomizable", "Transform", "MapTransform", "InvertibleTransform", "NonRigidTransform"] +__all__ = ["Randomizable", "Transform", "MapTransform"] class Randomizable(ABC): @@ -215,214 +206,3 @@ def __call__(self, data): """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - - -class InvertibleTransform(ABC): - """Classes for invertible transforms. - - This class exists so that an ``invert`` method can be implemented. This allows, for - example, images to be cropped, rotated, padded, etc., during training and inference, - and after be returned to their original size before saving to file for comparison in - an external viewer. - - When the `__call__` method is called, a serialization of the class is stored. When - the `inverse` method is called, the serialization is then removed. We use last in, - first out for the inverted transforms. 
- """ - - def append_applied_transforms( - self, - data: dict, - key: Hashable, - idx: int = 0, - extra_info: Optional[dict] = None, - orig_size: Optional[Tuple] = None, - ) -> None: - """Append to list of applied transforms for that key.""" - key_transform = str(key) + "_transforms" - info: Dict[str, Any] = {} - info["class"] = type(self) - info["init_args"] = self.get_input_args(key, idx) - info["orig_size"] = orig_size or data[key].shape[1:] - info["extra_info"] = extra_info - # If class is randomizable, store whether the transform was actually performed (based on `prob`) - if isinstance(self, Randomizable): - info["do_transform"] = self._do_transform - # If this is the first, create list - if key_transform not in data: - data[key_transform] = [] - data[key_transform].append(info) - - def check_transforms_match(self, transform: dict, key: Hashable) -> None: - explanation = "Should inverse most recently applied invertible transform first" - # Check transorms are of same type. - if transform["class"] != type(self): - raise RuntimeError(explanation) - - t1 = transform["init_args"] - t2 = self.get_input_args(key) - - if t1.keys() != t2.keys(): - raise RuntimeError(explanation) - for k in t1.keys(): - if np.any(t1[k] != t2[k]): - raise RuntimeError(explanation) - - def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: - """Get most recent transform.""" - transform = dict(data[str(key) + "_transforms"][-1]) - self.check_transforms_match(transform, key) - return transform - - @staticmethod - def remove_most_recent_transform(data: dict, key: Hashable) -> None: - """Remove most recent transform.""" - data[str(key) + "_transforms"].pop() - - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - """Get input arguments for a single key.""" - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - - def inverse(self, data: dict): - """ - Inverse of ``__call__``. - - Raises: - NotImplementedError: When the subclass does not override this method. - - """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - - -class NonRigidTransform(ABC): - @staticmethod - def _get_disp_to_def_arr(shape, spacing): - def_to_disp = np.mgrid[[slice(0, i) for i in shape]].astype(np.float64) - for idx, i in enumerate(shape): - # shift for origin (in MONAI, center of image) - def_to_disp[idx] -= (i - 1) / 2 - # if supplied, account for spacing (e.g., for control point grids) - if spacing is not None: - def_to_disp[idx] *= spacing[idx] - return def_to_disp - - @staticmethod - def _inv_disp_w_sitk(fwd_disp, num_iters): - fwd_disp_sitk = sitk.GetImageFromArray(fwd_disp, isVector=True) - inv_disp_sitk = sitk.InvertDisplacementField(fwd_disp_sitk, num_iters) - inv_disp = sitk.GetArrayFromImage(inv_disp_sitk) - return inv_disp - - @staticmethod - def _inv_disp_w_vtk(fwd_disp): - while fwd_disp.shape[-1] < 3: - fwd_disp = np.append(fwd_disp, np.zeros(fwd_disp.shape[:-1] + (1,)), axis=-1) - fwd_disp = fwd_disp[..., None, :] - # fwd_disp_vtk = vtk.vtkImageImport() - # # The previously created array is converted to a string of chars and imported. 
- # data_string = fwd_disp.tostring() - # fwd_disp_vtk.CopyImportVoidPointer(data_string, len(data_string)) - # # The type of the newly imported data is set to unsigned char (uint8) - # fwd_disp_vtk.SetDataScalarTypeToUnsignedChar() - # fwd_disp_vtk.SetNumberOfScalarComponents(3) - extent = list(chain.from_iterable(zip([0, 0, 0], fwd_disp.shape[:-1]))) - # fwd_disp_vtk.SetWholeExtent(extent) - # fwd_disp_vtk.SetDataExtentToWholeExtent() - # fwd_disp_vtk.Update() - # fwd_disp_vtk = fwd_disp_vtk.GetOutput() - - fwd_disp_flattened = fwd_disp.flatten() # need to keep this in memory - vtk_data_array = vtk_numpy_support.numpy_to_vtk(fwd_disp_flattened) - - # Generating the vtkImageData - fwd_disp_vtk = vtk.vtkImageData() - fwd_disp_vtk.SetOrigin(0, 0, 0) - fwd_disp_vtk.SetSpacing(1, 1, 1) - fwd_disp_vtk.SetDimensions(*fwd_disp.shape[:-1]) - - fwd_disp_vtk.AllocateScalars(vtk_numpy_support.get_vtk_array_type(fwd_disp.dtype), 3) - fwd_disp_vtk.SetExtent(extent) - fwd_disp_vtk.GetPointData().AddArray(vtk_data_array) - - # # create b-spline coefficients for the displacement grid - # bspline_filter = vtk.vtkImageBSplineCoefficients() - # bspline_filter.SetInputData(fwd_disp_vtk) - # bspline_filter.Update() - - # # use these b-spline coefficients to create a transform - # bspline_transform = vtk.vtkBSplineTransform() - # bspline_transform.SetCoefficientData(bspline_filter.GetOutput()) - # bspline_transform.Update() - - # # invert the b-spline transform onto a new grid - # grid_maker = vtk.vtkTransformToGrid() - # grid_maker.SetInput(bspline_transform.GetInverse()) - # grid_maker.SetGridOrigin(fwd_disp_vtk.GetOrigin()) - # grid_maker.SetGridSpacing(fwd_disp_vtk.GetSpacing()) - # grid_maker.SetGridExtent(fwd_disp_vtk.GetExtent()) - # grid_maker.SetGridScalarTypeToFloat() - # grid_maker.Update() - - # # Get inverse displacement as an image - # inv_disp_vtk = grid_maker.GetOutput() - - # from vtk.util.numpy_support import vtk_to_numpy - # inv_disp = vtk_numpy_support.vtk_to_numpy(inv_disp_vtk.GetPointData().GetScalars()) - inv_disp = vtk_numpy_support.vtk_to_numpy(fwd_disp_vtk.GetPointData().GetArray(0)) - inv_disp = inv_disp.reshape(fwd_disp.shape) - - return inv_disp - - @staticmethod - def compute_inverse_deformation( - num_spatial_dims, fwd_def_orig, spacing=None, num_iters: int = 100, use_package: str = "sitk" - ): - """Package can be vtk or sitk.""" - if use_package.lower() == "vtk" and not has_vtk: - warnings.warn("Please install VTK to estimate inverse of non-rigid transforms. Data has not been modified") - return None - if use_package.lower() == "sitk" and not has_sitk: - warnings.warn( - "Please install SimpleITK to estimate inverse of non-rigid transforms. Data has not been modified" - ) - return None - - # Convert to numpy if necessary - if isinstance(fwd_def_orig, torch.Tensor): - fwd_def_orig = fwd_def_orig.cpu().numpy() - # Remove any extra dimensions (we'll add them back in at the end) - fwd_def = fwd_def_orig[:num_spatial_dims] - # Def -> disp - def_to_disp = NonRigidTransform._get_disp_to_def_arr(fwd_def.shape[1:], spacing) - fwd_disp = fwd_def - def_to_disp - # move tensor component to end (T,H,W,[D])->(H,W,[D],T) - fwd_disp = np.moveaxis(fwd_disp, 0, -1) - - # If using vtk... - if use_package.lower() == "vtk": - inv_disp = NonRigidTransform._inv_disp_w_vtk(fwd_disp) - # If using sitk... 
- else: - inv_disp = NonRigidTransform._inv_disp_w_sitk(fwd_disp, num_iters) - - import matplotlib.pyplot as plt - - fig, axes = plt.subplots(2, 2) - for i, direc1 in enumerate(["x", "y"]): - for j, (im, direc2) in enumerate(zip([fwd_disp, inv_disp], ["fwd", "inv"])): - ax = axes[i, j] - im_show = ax.imshow(im[..., i]) - ax.set_title(f"{direc2}{direc1}", fontsize=25) - ax.axis("off") - fig.colorbar(im_show, ax=ax) - plt.show() - # move tensor component back to beginning - inv_disp = np.moveaxis(inv_disp, -1, 0) - # Disp -> def - inv_def = inv_disp + def_to_disp - # Add back in any removed dimensions - ndim_in = fwd_def_orig.shape[0] - ndim_out = inv_def.shape[0] - inv_def = np.concatenate([inv_def, fwd_def_orig[ndim_out:ndim_in]]) - - return inv_def diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 758289cd43..29271d6659 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -52,7 +52,7 @@ def check_match(self, in1, in2): raise RuntimeError(f"Not sure how to compare types. type(in1): {type(in1)}, type(in2): {type(in2)}") @parameterized.expand(TESTS) - def test_decollation(self, _, data, batch_size=2, num_workers=0): + def test_decollation(self, _, data, batch_size=2, num_workers=2): transforms = Compose([ LoadImaged("image"), AddChanneld("image"), diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 497e57192f..99c0b6124d 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -13,9 +13,12 @@ import unittest from typing import TYPE_CHECKING, List, Tuple +from monai.data.utils import decollate_batch import numpy as np - +import torch +from monai.data import DataLoader from monai.data import CacheDataset, create_test_image_2d, create_test_image_3d +from monai.networks.nets import UNet from monai.transforms import ( AddChannel, AddChanneld, @@ -47,6 +50,7 @@ SpatialPadd, Zoomd, ) +from monai.data import BatchInverseTransform from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image, make_rand_affine @@ -125,7 +129,7 @@ TESTS.append(("RandSpatialCropd 2d", DATA_2D, 5e-2, RandSpatialCropd(KEYS, [96, 93], True, False))) -TESTS.append(("RandSpatialCropd 3d", DATA_3D, 2e-2, RandSpatialCropd(KEYS, [96, 93, 92], False, True))) +TESTS.append(("RandSpatialCropd 3d", DATA_3D, 2e-2, RandSpatialCropd(KEYS, [96, 93, 92], False, False))) TESTS.append( ( @@ -351,7 +355,7 @@ # spatial_size=[155, 192], prob=1, padding_mode="zeros", - rotate_range=[(np.pi / 6, np.pi / 6), np.pi / 7], + # rotate_range=[(np.pi / 6, np.pi / 6), np.pi / 7], # shear_range=[(0.5, 0.5)], # translate_range=[10, 5], # scale_range=[(0.8, 1.2), (0.9, 1.3)], @@ -381,7 +385,7 @@ TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] -TESTS = [*TESTS, *TESTS_COMPOSE_X2] +TESTS = TESTS + TESTS_COMPOSE_X2 # Should fail because uses an array transform (SpatialPad), as opposed to dictionary @@ -413,6 +417,8 @@ def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_ for key in keys: orig = orig_d[key] fwd_bck = fwd_bck_d[key] + if isinstance(fwd_bck, torch.Tensor): + fwd_bck = fwd_bck.cpu().numpy() unmodified = unmodified_d[key] if isinstance(orig, np.ndarray): mean_diff = np.mean(np.abs(orig - fwd_bck)) @@ -424,8 +430,11 @@ def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_ print( f"Failed: {name}. 
Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}" ) - if has_matplotlib: + if has_matplotlib and orig[0].ndim > 1: plot_im(orig, fwd_bck, unmodified) + elif orig[0].ndim == 1: + print(orig) + print(fwd_bck) raise # @parameterized.expand(TESTS) @@ -456,20 +465,37 @@ def test_fail(self, data, _, *transform): with self.assertRaises(RuntimeError): d = transform[0].inverse(d) - # @parameterized.expand(TEST_COMPOSES) - def test_w_data_loader(self, _, data, acceptable_diff, *transforms): + def test_w_dataloader(self, _, data, acceptable_diff, *transforms): name = _ - transform = transforms[0] - numel = 2 + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(transforms, tuple): + transforms = Compose(transforms) + numel = 4 test_data = [data for _ in range(numel)] - dataset = CacheDataset(test_data, transform, progress=False) - self.assertEqual(len(dataset), 2) - num_epochs = 2 - for _ in range(num_epochs): - for data_fwd in dataset: - data_fwd_bck = transform.inverse(data_fwd) - self.check_inverse(name, data.keys(), data, data_fwd_bck, data_fwd, acceptable_diff) + ndims = len(data["image"].shape[1:]) + batch_size = 2 + num_workers = 0 + + dataset = CacheDataset(test_data, transforms, progress=False) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) + inv_batch = BatchInverseTransform(transforms, loader) + + model = UNet(ndims, 1, 1, channels=(6, 6), strides=(2, 2)).to(device) + for batch_data in loader: + inputs, _ = ( + batch_data["image"].to(device), + batch_data["label"].to(device), + ) + + fwd_bck_batch = inv_batch(batch_data) + fwd_bck = decollate_batch(fwd_bck_batch) + + for _test_data, _fwd_bck, _dataset in zip(test_data, fwd_bck, dataset): + self.check_inverse(name, data.keys(), _test_data, _fwd_bck, _dataset, acceptable_diff) + + if torch.cuda.is_available(): + _ = model(inputs) if __name__ == "__main__": @@ -477,6 +503,6 @@ def test_w_data_loader(self, _, data, acceptable_diff, *transforms): test = TestInverse() for t in TESTS: test.test_inverse(*t) - test.test_w_data_loader(*t) + test.test_w_dataloader(*t) for t in TESTS_FAIL: test.test_fail(*t) From f1ebd7a7dfcf8e2cbfb82c9a15fb85146adddeb6 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Wed, 10 Feb 2021 12:03:24 +0000 Subject: [PATCH 54/80] remove init_args Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 51 +---- monai/transforms/inverse_transform.py | 28 ++- monai/transforms/spatial/dictionary.py | 277 ++++++------------------- tests/test_inverse.py | 6 +- 4 files changed, 78 insertions(+), 284 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index e04141673b..ed4597ea6d 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -190,13 +190,6 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.padder(d[key], mode=m) return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "spatial_border": self.padder.spatial_border, - "mode": self.mode[idx], - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) @@ -311,13 +304,6 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.cropper(d[key]) return d - def get_input_args(self, key: Hashable, 
idx: int = 0) -> dict: - return { - "keys": key, - "roi_start": self.cropper.roi_start, - "roi_end": self.cropper.roi_end, - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) @@ -363,12 +349,6 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.append_applied_transforms(d, key, orig_size=orig_size) return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "roi_size": self.cropper.roi_size, - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) @@ -376,7 +356,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = np.array(transform["orig_size"]) - current_size = np.array(transform["init_args"]["roi_size"]) + current_size = np.array(d[key].shape[1:]) pad_to_start = np.floor((orig_size - current_size) / 2) # in each direction, if original size is even and current size is odd, += 1 pad_to_start[np.logical_and(orig_size % 2 == 0, current_size % 2 == 1)] += 1 @@ -451,14 +431,6 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = cropper(d[key]) return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "roi_size": self.roi_size, - "random_center": self.random_center, - "random_size": self.random_size, - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) @@ -602,17 +574,6 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = cropper(d[key]) return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "source_key": self.source_key, - "select_fn": self.select_fn, - "channel_indices": self.channel_indices, - "margin": self.margin, - "start_coord_key": self.start_coord_key, - "end_coord_key": self.end_coord_key, - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.keys: @@ -840,21 +801,13 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.append_applied_transforms(d, key, orig_size=orig_size) return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "spatial_size": self.padcropper.padder.spatial_size, - "mode": self.padcropper.padder.mode, - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.keys: transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = transform["orig_size"] - mode = transform["init_args"]["mode"] - inverse_transform = ResizeWithPadOrCrop(spatial_size=orig_size, mode=mode) + inverse_transform = ResizeWithPadOrCrop(spatial_size=orig_size, mode=self.padcropper.padder.mode) # Apply inverse transform d[key] = inverse_transform(d[key]) # Remove the applied transform diff --git a/monai/transforms/inverse_transform.py b/monai/transforms/inverse_transform.py index 265399f9bb..05d4f658d1 100644 --- a/monai/transforms/inverse_transform.py +++ b/monai/transforms/inverse_transform.py @@ -50,7 +50,6 @@ def append_applied_transforms( key_transform = str(key) + "_transforms" info: Dict[str, Any] = {} info["id"] = id(self) - # info["init_args"] = self.get_input_args(key, idx) 
info["orig_size"] = orig_size or data[key].shape[1:] if extra_info is not None: info["extra_info"] = extra_info @@ -80,10 +79,6 @@ def remove_most_recent_transform(data: dict, key: Hashable) -> None: """Remove most recent transform.""" data[str(key) + "_transforms"].pop() - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - """Get input arguments for a single key.""" - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - def inverse(self, data: dict): """ Inverse of ``__call__``. @@ -206,17 +201,18 @@ def compute_inverse_deformation( else: inv_disp = NonRigidTransform._inv_disp_w_sitk(fwd_disp, num_iters) - import matplotlib.pyplot as plt - - fig, axes = plt.subplots(2, 2) - for i, direc1 in enumerate(["x", "y"]): - for j, (im, direc2) in enumerate(zip([fwd_disp, inv_disp], ["fwd", "inv"])): - ax = axes[i, j] - im_show = ax.imshow(im[..., i]) - ax.set_title(f"{direc2}{direc1}", fontsize=25) - ax.axis("off") - fig.colorbar(im_show, ax=ax) - plt.show() + if False: + import matplotlib.pyplot as plt + fig, axes = plt.subplots(2, 2) + for i, direc1 in enumerate(["x", "y"]): + for j, (im, direc2) in enumerate(zip([fwd_disp, inv_disp], ["fwd", "inv"])): + ax = axes[i, j] + im_show = ax.imshow(im[..., i]) + ax.set_title(f"{direc2}{direc1}", fontsize=25) + ax.axis("off") + fig.colorbar(im_show, ax=ax) + plt.show() + # move tensor component back to beginning inv_disp = np.moveaxis(inv_disp, -1, 0) # Disp -> def diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index a626c87ae1..1ae1994db7 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -203,40 +203,28 @@ def __call__( meta_data["affine"] = new_affine return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "pixdim": self.spacing_transform.pixdim, - "diagonal": self.spacing_transform.diagonal, - "mode": self.mode[idx], - "padding_mode": self.padding_mode[idx], - "align_corners": self.align_corners[idx], - "dtype": self.dtype[idx], - "meta_key_postfix": self.meta_key_postfix, - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for idx, key in enumerate(self.keys): transform = self.get_most_recent_transform(d, key) - init_args = transform["init_args"] - if init_args["diagonal"]: + if self.spacing_transform.diagonal: raise RuntimeError( "Spacingd:inverse not yet implemented for diagonal=True. 
" + "Please raise a github issue if you need this feature" ) # Create inverse transform meta_data = d[transform["extra_info"]["meta_data_key"]] - orig_pixdim = np.sqrt(np.sum(np.square(transform["extra_info"]["old_affine"]), 0))[:-1] - inverse_transform = Spacing(orig_pixdim, diagonal=init_args["diagonal"]) + old_affine = np.array(transform["extra_info"]["old_affine"]) + orig_pixdim = np.sqrt(np.sum(np.square(old_affine), 0))[:-1] + inverse_transform = Spacing(orig_pixdim, diagonal=self.spacing_transform.diagonal) # Apply inverse d[key], _, new_affine = inverse_transform( data_array=np.asarray(d[key]), affine=meta_data["affine"], - mode=init_args["mode"], - padding_mode=init_args["padding_mode"], - align_corners=init_args["align_corners"], - dtype=init_args["dtype"], + mode=self.mode[idx], + padding_mode=self.padding_mode[idx], + align_corners=self.align_corners[idx], + dtype=self.dtype[idx], ) meta_data["affine"] = new_affine # Remove the applied transform @@ -307,15 +295,6 @@ def __call__( d[meta_data_key]["affine"] = new_affine return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "axcodes": self.ornt_transform.axcodes, - "as_closest_canonical": self.ornt_transform.as_closest_canonical, - "labels": self.ornt_transform.labels, - "meta_key_postfix": self.meta_key_postfix, - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.keys: @@ -326,8 +305,8 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar orig_axcodes = nib.orientations.aff2axcodes(orig_affine) inverse_transform = Orientation( axcodes=orig_axcodes, - as_closest_canonical=transform["init_args"]["as_closest_canonical"], - labels=transform["init_args"]["labels"], + as_closest_canonical=self.ornt_transform.as_closest_canonical, + labels=self.ornt_transform.labels, ) # Apply inverse d[key], _, new_affine = inverse_transform(d[key], affine=meta_data["affine"]) @@ -360,22 +339,18 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.rotator(d[key]) return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "k": self.rotator.k, - "spatial_axes": self.rotator.spatial_axes, - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.keys: - transform = self.get_most_recent_transform(d, key) + _ = self.get_most_recent_transform(d, key) # Create inverse transform - spatial_axes = transform["init_args"]["spatial_axes"] - num_times_rotated = transform["init_args"]["k"] + spatial_axes = self.rotator.spatial_axes + num_times_rotated = self.rotator.k num_times_to_rotate = 4 - num_times_rotated inverse_transform = Rotate90(num_times_to_rotate, spatial_axes) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = d[key].cpu().numpy() # Apply inverse d[key] = inverse_transform(d[key]) # Remove the applied transform @@ -435,13 +410,6 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. 
self.append_applied_transforms(d, key, extra_info={"rand_k": self._rand_k}) return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "prob": self.prob, - "max_k": self.max_k, - "spatial_axes": self.spatial_axes, - } def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) @@ -450,10 +418,12 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Check if random transform was actually performed (based on `prob`) if transform["do_transform"]: # Create inverse transform - spatial_axes = transform["init_args"]["spatial_axes"] num_times_rotated = transform["extra_info"]["rand_k"] num_times_to_rotate = 4 - num_times_rotated - inverse_transform = Rotate90(num_times_to_rotate, spatial_axes) + inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = d[key].cpu().numpy() # Apply inverse d[key] = inverse_transform(d[key]) # Remove the applied transform @@ -502,21 +472,13 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.resizer(d[key], mode=self.mode[idx], align_corners=self.align_corners[idx]) return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "spatial_size": self.resizer.spatial_size, - "mode": self.mode[idx], - "align_corners": self.align_corners[idx], - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for idx, key in enumerate(self.keys): transform = self.get_most_recent_transform(d, key) orig_size = transform["orig_size"] - mode = transform["init_args"]["mode"] - align_corners = transform["init_args"]["align_corners"] + mode = self.mode[idx] + align_corners = self.align_corners[idx] # Create inverse transform inverse_transform = Resize(orig_size, mode, align_corners) # Apply inverse transform @@ -625,38 +587,21 @@ def __call__( d[key] = self.rand_affine.resampler(d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx]) return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "spatial_size": self.rand_affine.spatial_size, - "prob": self.rand_affine.prob, - "rotate_range": self.rand_affine.rand_affine_grid.rotate_range, - "shear_range": self.rand_affine.rand_affine_grid.shear_range, - "translate_range": self.rand_affine.rand_affine_grid.translate_range, - "scale_range": self.rand_affine.rand_affine_grid.scale_range, - "mode": self.rand_affine.mode, - "padding_mode": self.rand_affine.padding_mode, - "as_tensor_output": self.rand_affine.resampler.as_tensor_output, - "device": self.rand_affine.resampler.device, - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for idx, key in enumerate(self.keys): transform = self.get_most_recent_transform(d, key) - extra_info = transform["extra_info"] - init_args = transform["init_args"] orig_size = transform["orig_size"] # Create inverse transform - fwd_affine = extra_info["affine"] + fwd_affine = transform["extra_info"]["affine"] inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) grid: torch.Tensor = affine_grid(orig_size) # type: ignore # Apply inverse transform - out = self.rand_affine.resampler(d[key], grid, init_args["mode"], init_args["padding_mode"]) + out = 
self.rand_affine.resampler(d[key], grid, self.mode[idx], self.padding_mode[idx]) # Convert to original output type if isinstance(out, torch.Tensor): d[key] = out.cpu().numpy() @@ -783,34 +728,14 @@ def __call__( ) return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "spacing": self.rand_2d_elastic.deform_grid.spacing, - "magnitude_range": self.rand_2d_elastic.deform_grid.magnitude, - "spatial_size": self.rand_2d_elastic.spatial_size, - "prob": self.rand_2d_elastic.prob, - "rotate_range": self.rand_2d_elastic.rand_affine_grid.rotate_range, - "shear_range": self.rand_2d_elastic.rand_affine_grid.shear_range, - "translate_range": self.rand_2d_elastic.rand_affine_grid.translate_range, - "scale_range": self.rand_2d_elastic.rand_affine_grid.scale_range, - "mode": self.mode[idx], - "padding_mode": self.padding_mode[idx], - "as_tensor_output": self.rand_2d_elastic.resampler.as_tensor_output, - "device": self.rand_2d_elastic.resampler.device, - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for idx, key in enumerate(self.keys): transform = self.get_most_recent_transform(d, key) - - extra_info = transform["extra_info"] - init_args = transform["init_args"] orig_size = transform["orig_size"] # Create inverse transform - fwd_def = extra_info["grid"] + fwd_def = transform["extra_info"]["grid"] if fwd_def is None: d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) else: @@ -821,12 +746,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar inv_def = CenterSpatialCrop(roi_size=orig_size)(inv_def) # Apply inverse transform out = self.rand_2d_elastic.resampler( - d[key], inv_def, init_args["mode"], init_args["padding_mode"] + d[key], inv_def, self.mode[idx], self.padding_mode[idx] ) - if isinstance(out, torch.Tensor): - d[key] = out.cpu().numpy() - else: - d[key] = out + d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out # Remove the applied transform self.remove_most_recent_transform(d, key) @@ -947,34 +869,14 @@ def __call__( ) return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "sigma_range": self.rand_3d_elastic.sigma, - "magnitude_range": self.rand_3d_elastic.magnitude_range, - "spatial_size": self.rand_3d_elastic.spatial_size, - "prob": self.rand_3d_elastic.prob, - "rotate_range": self.rand_3d_elastic.rand_affine_grid.rotate_range, - "shear_range": self.rand_3d_elastic.rand_affine_grid.shear_range, - "translate_range": self.rand_3d_elastic.rand_affine_grid.translate_range, - "scale_range": self.rand_3d_elastic.rand_affine_grid.scale_range, - "mode": self.mode[idx], - "padding_mode": self.padding_mode[idx], - "as_tensor_output": self.rand_3d_elastic.resampler.as_tensor_output, - "device": self.rand_3d_elastic.resampler.device, - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for idx, key in enumerate(self.keys): transform = self.get_most_recent_transform(d, key) - - extra_info = transform["extra_info"] - init_args = transform["init_args"] orig_size = transform["orig_size"] # Create inverse transform - fwd_def = extra_info["grid"] + fwd_def = transform["extra_info"]["grid"] if fwd_def is None: d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) else: @@ -985,7 +887,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar inv_def = 
CenterSpatialCrop(roi_size=orig_size)(inv_def) # Apply inverse transform out = self.rand_3d_elastic.resampler( - d[key], inv_def, init_args["mode"], init_args["padding_mode"] + d[key], inv_def, self.mode[idx], self.padding_mode[idx] ) if isinstance(out, torch.Tensor): d[key] = out.cpu().numpy() @@ -1020,16 +922,13 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.flipper(d[key]) return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "spatial_axis": self.flipper.spatial_axis, - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.keys: _ = self.get_most_recent_transform(d, key) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = d[key].cpu().numpy() # Inverse is same as forward d[key] = self.flipper(d[key]) # Remove the applied transform @@ -1076,18 +975,15 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "spatial_axis": self.flipper.spatial_axis, - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.keys: transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform["do_transform"]: + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = d[key].cpu().numpy() # Inverse is same as forward d[key] = self.flipper(d[key]) # Remove the applied transform @@ -1156,34 +1052,22 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.append_applied_transforms(d, key, idx, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "angle": self.rotator.angle, - "keep_size": self.rotator.keep_size, - "mode": self.mode[idx], - "padding_mode": self.padding_mode[idx], - "align_corners": self.align_corners[idx], - "dtype": self.dtype[idx], - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for idx, key in enumerate(self.keys): transform = self.get_most_recent_transform(d, key) - init_args = transform["init_args"] # Create inverse transform fwd_rot_mat = transform["extra_info"]["rot_mat"] inv_rot_mat = np.linalg.inv(fwd_rot_mat) xform = AffineTransform( normalized=False, - mode=init_args["mode"], - padding_mode=init_args["padding_mode"], - align_corners=init_args["align_corners"], + mode=self.mode[idx], + padding_mode=self.padding_mode[idx], + align_corners=self.align_corners[idx], reverse_indexing=True, ) - dtype = init_args["dtype"] + dtype = self.dtype[idx] output = xform( torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), @@ -1296,39 +1180,24 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.append_applied_transforms(d, key, idx, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "range_x": self.range_x, - "range_y": self.range_y, - "range_z": self.range_z, - "prob": self.prob, - "keep_size": self.keep_size, - "mode": self.mode[idx], - 
"padding_mode": self.padding_mode[idx], - "align_corners": self.align_corners[idx], - "dtype": self.dtype[idx], - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for idx, key in enumerate(self.keys): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform["do_transform"]: - init_args = transform["init_args"] # Create inverse transform fwd_rot_mat = transform["extra_info"]["rot_mat"] inv_rot_mat = np.linalg.inv(fwd_rot_mat) xform = AffineTransform( normalized=False, - mode=init_args["mode"], - padding_mode=init_args["padding_mode"], - align_corners=init_args["align_corners"], + mode=self.mode[idx], + padding_mode=self.padding_mode[idx], + align_corners=self.align_corners[idx], reverse_indexing=True, ) - dtype = init_args["dtype"] + dtype = self.dtype[idx] output = xform( torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), @@ -1392,30 +1261,19 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "zoom": self.zoomer.zoom, - "mode": self.mode[idx], - "padding_mode": self.padding_mode[idx], - "align_corners": self.align_corners[idx], - "keep_size": self.zoomer.keep_size, - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for idx, key in enumerate(self.keys): transform = self.get_most_recent_transform(d, key) # Create inverse transform - init_args = transform["init_args"] - zoom = np.array(init_args["zoom"]) - inverse_transform = Zoom(zoom=1 / zoom, keep_size=init_args["keep_size"]) + zoom = np.array(self.zoomer.zoom) + inverse_transform = Zoom(zoom=1 / zoom, keep_size=self.zoomer.keep_size) # Apply inverse d[key] = inverse_transform( d[key], - mode=init_args["mode"], - padding_mode=init_args["padding_mode"], - align_corners=init_args["align_corners"], + mode=self.mode[idx], + padding_mode=self.padding_mode[idx], + align_corners=self.align_corners[idx], ) # Size might be out by 1 voxel so pad d[key] = SpatialPad(transform["orig_size"])(d[key]) @@ -1513,32 +1371,19 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def get_input_args(self, key: Hashable, idx: int = 0) -> dict: - return { - "keys": key, - "prob": self.prob, - "min_zoom": self.min_zoom, - "max_zoom": self.max_zoom, - "mode": self.mode[idx], - "padding_mode": self.padding_mode[idx], - "align_corners": self.align_corners[idx], - "keep_size": self.keep_size, - } - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for idx, key in enumerate(self.keys): transform = self.get_most_recent_transform(d, key) # Create inverse transform - init_args = transform["init_args"] zoom = np.array(transform["extra_info"]["zoom"]) - inverse_transform = Zoom(zoom=1 / zoom, keep_size=init_args["keep_size"]) + inverse_transform = Zoom(zoom=1 / zoom, keep_size=self.keep_size) # Apply inverse d[key] = inverse_transform( d[key], - mode=init_args["mode"], - padding_mode=init_args["padding_mode"], - align_corners=init_args["align_corners"], + mode=self.mode[idx], + padding_mode=self.padding_mode[idx], + align_corners=self.align_corners[idx], ) # Size might 
be out by 1 voxel so pad d[key] = SpatialPad(transform["orig_size"])(d[key]) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 99c0b6124d..369275a7bd 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -293,7 +293,7 @@ ) ) -TESTS.append(("RandZoom 3d", DATA_3D, 5e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [3, 4.2, 6.1], keep_size=False))) +TESTS.append(("RandZoom 3d", DATA_3D, 5e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True))) TESTS.append( ( @@ -347,7 +347,7 @@ ( "Rand2DElasticd 2d", DATA_2D, - 8e-2, + 1e-1, Rand2DElasticd( KEYS, spacing=(10.0, 10.0), @@ -367,7 +367,7 @@ ( "Rand3DElasticd 3d", DATA_3D, - 0, + 2e-1, Rand3DElasticd( KEYS, sigma_range=(1, 3), From 48fbff2283a8817463a0b54f1bc25aed45df2d7d Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Wed, 10 Feb 2021 15:44:10 +0000 Subject: [PATCH 55/80] working with data loader Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 8 +-- monai/transforms/inverse_transform.py | 18 +++--- monai/transforms/spatial/array.py | 2 +- monai/transforms/spatial/dictionary.py | 12 ++-- tests/test_inverse.py | 82 +++++++++++++------------- tests/utils.py | 6 +- 6 files changed, 67 insertions(+), 61 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index ed4597ea6d..4aaac7e022 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -119,8 +119,8 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) - for idx, (key, m) in enumerate(zip(self.keys, self.mode)): - self.append_applied_transforms(d, key, idx) + for key, m in zip(self.keys, self.mode): + self.append_applied_transforms(d, key) d[key] = self.padder(d[key], mode=m) return d @@ -423,10 +423,10 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda raise AssertionError for idx, key in enumerate(self.keys): if self.random_center: - self.append_applied_transforms(d, key, idx, {"slices": [(i.start, i.stop) for i in self._slices[1:]]}) + self.append_applied_transforms(d, key, {"slices": [(i.start, i.stop) for i in self._slices[1:]]}) d[key] = d[key][self._slices] else: - self.append_applied_transforms(d, key, idx) + self.append_applied_transforms(d, key) cropper = CenterSpatialCrop(self._size) d[key] = cropper(d[key]) return d diff --git a/monai/transforms/inverse_transform.py b/monai/transforms/inverse_transform.py index 05d4f658d1..809b278963 100644 --- a/monai/transforms/inverse_transform.py +++ b/monai/transforms/inverse_transform.py @@ -42,15 +42,16 @@ def append_applied_transforms( self, data: dict, key: Hashable, - idx: int = 0, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None, ) -> None: """Append to list of applied transforms for that key.""" key_transform = str(key) + "_transforms" - info: Dict[str, Any] = {} - info["id"] = id(self) - info["orig_size"] = orig_size or data[key].shape[1:] + info = { + "class": self.__class__.__name__, + "id": id(self), + "orig_size": orig_size or data[key].shape[1:], + } if extra_info is not None: info["extra_info"] = extra_info # If class is randomizable, store whether the transform was actually performed (based on `prob`) @@ -62,16 +63,15 @@ def append_applied_transforms( data[key_transform].append(info) - def check_transforms_match(self, transform: dict, key: Hashable) -> None: - explanation = 
"Should inverse most recently applied invertible transform first" + def check_transforms_match(self, transform: dict) -> None: # Check transorms are of same type. if transform["id"] != id(self): - raise RuntimeError(explanation) + raise RuntimeError("Should inverse most recently applied invertible transform first") def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: """Get most recent transform.""" transform = dict(data[str(key) + "_transforms"][-1]) - self.check_transforms_match(transform, key) + self.check_transforms_match(transform) return transform @staticmethod @@ -116,7 +116,7 @@ def _inv_disp_w_vtk(fwd_disp): fwd_disp = fwd_disp[..., None, :] # fwd_disp_vtk = vtk.vtkImageImport() # # The previously created array is converted to a string of chars and imported. - # data_string = fwd_disp.tostring() + # data_string = fwd_disp.tobytes() # fwd_disp_vtk.CopyImportVoidPointer(data_string, len(data_string)) # # The type of the newly imported data is set to unsigned char (uint8) # fwd_disp_vtk.SetDataScalarTypeToUnsignedChar() diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 3210c81ef1..63dc384f97 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -278,7 +278,7 @@ def __call__( ornt[:, 0] += 1 # skip channel dim ornt = np.concatenate([np.array([[0, 1]]), ornt]) shape = data_array.shape[1:] - data_array = nib.orientations.apply_orientation(data_array, ornt) + data_array = np.ascontiguousarray(nib.orientations.apply_orientation(data_array, ornt)) new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, shape) new_affine = to_affine_nd(affine, new_affine) return data_array, affine, new_affine diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 1ae1994db7..1f2860d5bd 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -402,7 +402,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. 
if not self._do_transform: for key in self.keys: self.append_applied_transforms(d, key) - return data + return d rotator = Rotate90(self._rand_k, self.spatial_axes) for key in self.keys: @@ -583,7 +583,7 @@ def __call__( affine = np.eye(len(sp_size) + 1) for idx, key in enumerate(self.keys): - self.append_applied_transforms(d, key, idx, extra_info={"affine": affine}) + self.append_applied_transforms(d, key, extra_info={"affine": affine}) d[key] = self.rand_affine.resampler(d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx]) return d @@ -722,7 +722,7 @@ def __call__( grid = create_grid(spatial_size=sp_size) for idx, key in enumerate(self.keys): - self.append_applied_transforms(d, key, idx, extra_info={"grid": deepcopy(grid)}) + self.append_applied_transforms(d, key, extra_info={"grid": deepcopy(grid)}) d[key] = self.rand_2d_elastic.resampler( d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx] ) @@ -863,7 +863,7 @@ def __call__( grid = self.rand_3d_elastic.rand_affine_grid(grid=grid) for idx, key in enumerate(self.keys): - self.append_applied_transforms(d, key, idx, extra_info={"grid": grid.cpu().numpy()}) + self.append_applied_transforms(d, key, extra_info={"grid": grid.cpu().numpy()}) d[key] = self.rand_3d_elastic.resampler( d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx] ) @@ -1049,7 +1049,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda dtype=self.dtype[idx], return_rotation_matrix=True, ) - self.append_applied_transforms(d, key, idx, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) + self.append_applied_transforms(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: @@ -1177,7 +1177,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda dtype=self.dtype[idx], return_rotation_matrix=True, ) - self.append_applied_transforms(d, key, idx, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) + self.append_applied_transforms(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 369275a7bd..dde4f8c11a 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -12,7 +12,7 @@ import random import unittest from typing import TYPE_CHECKING, List, Tuple - +import sys from monai.data.utils import decollate_batch import numpy as np import torch @@ -50,11 +50,12 @@ SpatialPadd, Zoomd, ) + from monai.data import BatchInverseTransform from monai.utils import optional_import, set_determinism -from tests.utils import make_nifti_image, make_rand_affine +from tests.utils import make_nifti_image, make_rand_affine, skip_if_quick, test_is_quick -# from parameterized import parameterized +from parameterized import parameterized if TYPE_CHECKING: @@ -235,6 +236,8 @@ "Orientationd 3d", DATA_3D, 0, + # For data loader, output needs to be same size, so input must be square/cubic + SpatialPadd(KEYS, max(DATA_3D["image"].shape)), Orientationd(KEYS, "RAS"), ) ) @@ -262,6 +265,8 @@ "RandRotate90d 3d", DATA_3D, 0, + # For data loader, output needs to be same size, so input must be square/cubic + SpatialPadd(KEYS, max(DATA_3D["image"].shape)), RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)), ) ) @@ -293,7 +298,7 @@ ) ) -TESTS.append(("RandZoom 3d", DATA_3D, 5e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 
1.05], keep_size=True))) +TESTS.append(("RandZoom 3d", DATA_3D, 9e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True))) TESTS.append( ( @@ -363,25 +368,26 @@ ) ) -TESTS.append( - ( - "Rand3DElasticd 3d", - DATA_3D, - 2e-1, - Rand3DElasticd( - KEYS, - sigma_range=(1, 3), - magnitude_range=(1.0, 2.0), - spatial_size=[155, 192, 200], - prob=1, - padding_mode="zeros", - rotate_range=[np.pi / 6, np.pi / 7], - shear_range=[(0.5, 0.5)], - translate_range=[10, 5], - scale_range=[(0.8, 1.2), (0.9, 1.3)], - ), +if not test_is_quick: + TESTS.append( + ( + "Rand3DElasticd 3d", + DATA_3D, + 2e-1, + Rand3DElasticd( + KEYS, + sigma_range=(1, 3), + magnitude_range=(1.0, 2.0), + spatial_size=[155, 192, 200], + prob=1, + padding_mode="zeros", + rotate_range=[np.pi / 6, np.pi / 7], + shear_range=[(0.5, 0.5)], + translate_range=[10, 5], + scale_range=[(0.8, 1.2), (0.9, 1.3)], + ), + ) ) -) TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] @@ -425,7 +431,6 @@ def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_ unmodded_diff = np.mean(np.abs(orig - ResizeWithPadOrCrop(orig.shape[1:])(unmodified))) try: self.assertLessEqual(mean_diff, acceptable_diff) - self.assertLessEqual(mean_diff, unmodded_diff) except AssertionError: print( f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}" @@ -437,7 +442,7 @@ def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_ print(fwd_bck) raise - # @parameterized.expand(TESTS) + @parameterized.expand(TESTS) def test_inverse(self, _, data, acceptable_diff, *transforms): name = _ @@ -459,12 +464,7 @@ def test_inverse(self, _, data, acceptable_diff, *transforms): fwd_bck = t.inverse(fwd_bck) self.check_inverse(name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff) - # @parameterized.expand(TESTS_FAIL) - def test_fail(self, data, _, *transform): - d = transform[0](data) - with self.assertRaises(RuntimeError): - d = transform[0].inverse(d) - + @parameterized.expand(TESTS) def test_w_dataloader(self, _, data, acceptable_diff, *transforms): name = _ device = "cuda" if torch.cuda.is_available() else "cpu" @@ -475,7 +475,8 @@ def test_w_dataloader(self, _, data, acceptable_diff, *transforms): ndims = len(data["image"].shape[1:]) batch_size = 2 - num_workers = 0 + # num workers = 0 for mac + num_workers = 2 if sys.platform != "darwin" else 0 dataset = CacheDataset(test_data, transforms, progress=False) loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) @@ -491,18 +492,19 @@ def test_w_dataloader(self, _, data, acceptable_diff, *transforms): fwd_bck_batch = inv_batch(batch_data) fwd_bck = decollate_batch(fwd_bck_batch) - for _test_data, _fwd_bck, _dataset in zip(test_data, fwd_bck, dataset): - self.check_inverse(name, data.keys(), _test_data, _fwd_bck, _dataset, acceptable_diff) + for idx, (_test_data, _fwd_bck) in enumerate(zip(test_data, fwd_bck)): + _fwd = transforms(test_data[idx]) + self.check_inverse(name, data.keys(), _test_data, _fwd_bck, _fwd, acceptable_diff) if torch.cuda.is_available(): _ = model(inputs) + @parameterized.expand(TESTS_FAIL) + def test_fail(self, data, _, *transform): + d = transform[0](data) + with self.assertRaises(RuntimeError): + d = transform[0].inverse(d) + if __name__ == "__main__": - # unittest.main() - test = TestInverse() - for t in TESTS: - test.test_inverse(*t) - test.test_w_dataloader(*t) - for t in TESTS_FAIL: - 
test.test_fail(*t) + unittest.main() diff --git a/tests/utils.py b/tests/utils.py index 861669ac12..e13d34eb22 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -47,12 +47,16 @@ def test_pretrained_networks(network, input_param, device): return net +def test_is_quick(): + return os.environ.get(quick_test_var, "").lower() == "true" + + def skip_if_quick(obj): """ Skip the unit tests if environment variable `quick_test_var=true`. For example, the user can skip the relevant tests by setting ``export QUICKTEST=true``. """ - is_quick = os.environ.get(quick_test_var, "").lower() == "true" + is_quick = test_is_quick() return unittest.skipIf(is_quick, "Skipping slow tests")(obj) From 1cbc8f76a88acd907b87cd5c0a56bbc724c054c4 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Wed, 10 Feb 2021 16:22:48 +0000 Subject: [PATCH 56/80] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/__init__.py | 2 +- monai/data/inverse_batch_transform.py | 13 +++--- monai/data/utils.py | 4 +- monai/transforms/__init__.py | 2 +- monai/transforms/compose.py | 2 +- monai/transforms/croppad/dictionary.py | 10 ++--- monai/transforms/inverse_batch_transform.py | 46 --------------------- monai/transforms/inverse_transform.py | 8 ++-- monai/transforms/spatial/array.py | 2 +- monai/transforms/spatial/dictionary.py | 19 ++++----- monai/transforms/transform.py | 2 + tests/test_decollate.py | 39 +++++++++-------- tests/test_inverse.py | 19 ++++----- 13 files changed, 62 insertions(+), 106 deletions(-) delete mode 100644 monai/transforms/inverse_batch_transform.py diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 90b65bc347..d63b604ecd 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -39,6 +39,7 @@ compute_shape_offset, correct_nifti_header_if_necessary, create_file_basename, + decollate_batch, dense_patch_slices, get_random_patch, get_valid_patch_size, @@ -57,5 +58,4 @@ to_affine_nd, worker_init_fn, zoom_affine, - decollate_batch, ) diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index 485213a6ab..0f10ec6c41 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -9,17 +9,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable -from monai.data.utils import decollate_batch +from torch.utils.data.dataset import Dataset from monai.data.dataloader import DataLoader -from monai.data.dataset import Dataset - +from monai.data.utils import decollate_batch +from monai.transforms.inverse_transform import InvertibleTransform __all__ = ["BatchInverseTransform"] + class _BatchInverseDataset(Dataset): - def __init__(self, data, transform) -> None: + def __init__(self, data, transform: InvertibleTransform) -> None: self.data = decollate_batch(data) self.transform = transform @@ -30,7 +30,8 @@ def __getitem__(self, index: int): class BatchInverseTransform: """something""" - def __init__(self, transform: Callable, loader) -> None: + + def __init__(self, transform: InvertibleTransform, loader) -> None: """ Args: transform: a callable data transform on input data. 
diff --git a/monai/data/utils.py b/monai/data/utils.py index 6c14d431ae..b6e0da8db2 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -12,7 +12,6 @@ import hashlib import json import math -from monai.utils.misc import issequenceiterable import os import pickle import warnings @@ -37,6 +36,7 @@ first, optional_import, ) +from monai.utils.misc import issequenceiterable nib, _ = optional_import("nibabel") @@ -244,6 +244,7 @@ def list_data_collate(batch: Sequence): data = [i for k in batch for i in k] if isinstance(elem, list) else batch return default_collate(data) + def decollate_batch(data: dict, batch_size: Optional[int] = None): """De-collate a batch of data (for example, as produced by a `DataLoader`). @@ -308,6 +309,7 @@ def decollate(data: Any, idx: int): return [{key: decollate(data[key], idx) for key in data.keys()} for idx in range(batch_size)] + def worker_init_fn(worker_id: int) -> None: """ Callback function for PyTorch DataLoader `worker_init_fn`. diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 83921b5e68..5f47d9336f 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -138,6 +138,7 @@ ThresholdIntensityD, ThresholdIntensityDict, ) +from .inverse_transform import InvertibleTransform, NonRigidTransform from .io.array import LoadImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict from .post.array import ( @@ -235,7 +236,6 @@ ZoomDict, ) from .transform import MapTransform, Randomizable, Transform -from .inverse_transform import InvertibleTransform, NonRigidTransform from .utility.array import ( AddChannel, AddExtremePointsChannel, diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 713df6ade7..b3997f1197 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -18,8 +18,8 @@ import numpy as np -from monai.transforms.transform import Randomizable, Transform from monai.transforms.inverse_transform import InvertibleTransform +from monai.transforms.transform import Randomizable, Transform from monai.transforms.utils import apply_transform from monai.utils import MAX_SEED, ensure_tuple, get_seed diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 4aaac7e022..f9c9fb8670 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -32,8 +32,8 @@ SpatialCrop, SpatialPad, ) -from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.inverse_transform import InvertibleTransform +from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.utils import ( generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, @@ -132,9 +132,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar orig_size = transform["orig_size"] if self.padder.method == Method.SYMMETRIC: current_size = d[key].shape[1:] - roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) / 2 for r, i in zip(orig_size, current_size)] + roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)] else: - roi_center = [floor(r / 2) if r % 2 == 0 else (r - 1) / 2 for r in orig_size] + roi_center = [floor(r / 2) if r % 2 == 0 else (r - 1) // 2 for r in orig_size] inverse_transform = SpatialCrop(roi_center, orig_size) # Apply inverse transform @@ -421,9 +421,9 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.randomize(d[self.keys[0]].shape[1:]) # image 
shape from the first data key if self._size is None: raise AssertionError - for idx, key in enumerate(self.keys): + for key in self.keys: if self.random_center: - self.append_applied_transforms(d, key, {"slices": [(i.start, i.stop) for i in self._slices[1:]]}) + self.append_applied_transforms(d, key, {"slices": [(i.start, i.stop) for i in self._slices[1:]]}) # type: ignore d[key] = d[key][self._slices] else: self.append_applied_transforms(d, key) diff --git a/monai/transforms/inverse_batch_transform.py b/monai/transforms/inverse_batch_transform.py deleted file mode 100644 index 485213a6ab..0000000000 --- a/monai/transforms/inverse_batch_transform.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable -from monai.data.utils import decollate_batch - -from monai.data.dataloader import DataLoader -from monai.data.dataset import Dataset - - -__all__ = ["BatchInverseTransform"] - -class _BatchInverseDataset(Dataset): - def __init__(self, data, transform) -> None: - self.data = decollate_batch(data) - self.transform = transform - - def __getitem__(self, index: int): - data = self.data[index] - return self.transform.inverse(data) - - -class BatchInverseTransform: - """something""" - def __init__(self, transform: Callable, loader) -> None: - """ - Args: - transform: a callable data transform on input data. - loader: data loader used to generate the batch of data. - """ - self.transform = transform - self.batch_size = loader.batch_size - self.num_workers = loader.num_workers - - def __call__(self, data): - inv_ds = _BatchInverseDataset(data, self.transform) - inv_loader = DataLoader(inv_ds, batch_size=self.batch_size, num_workers=self.num_workers) - return next(iter(inv_loader)) diff --git a/monai/transforms/inverse_transform.py b/monai/transforms/inverse_transform.py index 809b278963..49e49ddd6d 100644 --- a/monai/transforms/inverse_transform.py +++ b/monai/transforms/inverse_transform.py @@ -12,12 +12,13 @@ import warnings from abc import ABC from itertools import chain -from typing import Any, Dict, Hashable, Optional, Tuple +from typing import Hashable, Optional, Tuple + import numpy as np import torch -from monai.utils import optional_import from monai.transforms.transform import Randomizable +from monai.utils import optional_import sitk, has_sitk = optional_import("SimpleITK") vtk, has_vtk = optional_import("vtk") @@ -25,6 +26,7 @@ __all__ = ["InvertibleTransform", "NonRigidTransform"] + class InvertibleTransform(ABC): """Classes for invertible transforms. @@ -62,7 +64,6 @@ def append_applied_transforms( data[key_transform] = [] data[key_transform].append(info) - def check_transforms_match(self, transform: dict) -> None: # Check transorms are of same type. 
if transform["id"] != id(self): @@ -203,6 +204,7 @@ def compute_inverse_deformation( if False: import matplotlib.pyplot as plt + fig, axes = plt.subplots(2, 2) for i, direc1 in enumerate(["x", "y"]): for j, (im, direc2) in enumerate(zip([fwd_disp, inv_disp], ["fwd", "inv"])): diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 63dc384f97..c229c1d85b 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -592,7 +592,7 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: If axis is negative it counts from the last to the first axis. """ self.k = k - spatial_axes_ = ensure_tuple(spatial_axes) + spatial_axes_: Tuple[int, int] = ensure_tuple(spatial_axes) # type: ignore if len(spatial_axes_) != 2: raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.") self.spatial_axes = spatial_axes_ diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 1f2860d5bd..6068dee5c8 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -25,6 +25,7 @@ from monai.networks.layers import AffineTransform from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad +from monai.transforms.inverse_transform import InvertibleTransform, NonRigidTransform from monai.transforms.spatial.array import ( AffineGrid, Flip, @@ -39,7 +40,6 @@ Zoom, ) from monai.transforms.transform import MapTransform, Randomizable -from monai.transforms.inverse_transform import InvertibleTransform, NonRigidTransform from monai.transforms.utils import create_grid from monai.utils import ( GridSampleMode, @@ -350,7 +350,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar inverse_transform = Rotate90(num_times_to_rotate, spatial_axes) # Might need to convert to numpy if isinstance(d[key], torch.Tensor): - d[key] = d[key].cpu().numpy() + d[key] = torch.Tensor(d[key]).cpu().numpy() # Apply inverse d[key] = inverse_transform(d[key]) # Remove the applied transform @@ -410,7 +410,6 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. 
self.append_applied_transforms(d, key, extra_info={"rand_k": self._rand_k}) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) for key in self.keys: @@ -423,7 +422,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes) # Might need to convert to numpy if isinstance(d[key], torch.Tensor): - d[key] = d[key].cpu().numpy() + d[key] = torch.Tensor(d[key]).cpu().numpy() # Apply inverse d[key] = inverse_transform(d[key]) # Remove the applied transform @@ -745,9 +744,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Back to original size inv_def = CenterSpatialCrop(roi_size=orig_size)(inv_def) # Apply inverse transform - out = self.rand_2d_elastic.resampler( - d[key], inv_def, self.mode[idx], self.padding_mode[idx] - ) + out = self.rand_2d_elastic.resampler(d[key], inv_def, self.mode[idx], self.padding_mode[idx]) d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out # Remove the applied transform self.remove_most_recent_transform(d, key) @@ -886,9 +883,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # Back to original size inv_def = CenterSpatialCrop(roi_size=orig_size)(inv_def) # Apply inverse transform - out = self.rand_3d_elastic.resampler( - d[key], inv_def, self.mode[idx], self.padding_mode[idx] - ) + out = self.rand_3d_elastic.resampler(d[key], inv_def, self.mode[idx], self.padding_mode[idx]) if isinstance(out, torch.Tensor): d[key] = out.cpu().numpy() else: @@ -928,7 +923,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar _ = self.get_most_recent_transform(d, key) # Might need to convert to numpy if isinstance(d[key], torch.Tensor): - d[key] = d[key].cpu().numpy() + d[key] = torch.Tensor(d[key]).cpu().numpy() # Inverse is same as forward d[key] = self.flipper(d[key]) # Remove the applied transform @@ -983,7 +978,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar if transform["do_transform"]: # Might need to convert to numpy if isinstance(d[key], torch.Tensor): - d[key] = d[key].cpu().numpy() + d[key] = torch.Tensor(d[key]).cpu().numpy() # Inverse is same as forward d[key] = self.flipper(d[key]) # Remove the applied transform diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 36c7445cf1..4f0b2eca79 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -14,7 +14,9 @@ from abc import ABC, abstractmethod from typing import Any, Hashable, Optional, Tuple + import numpy as np + from monai.config import KeysCollection from monai.utils import MAX_SEED, ensure_tuple diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 29271d6659..dcf94d06fc 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -9,20 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import torch import unittest + import numpy as np +import torch +from parameterized import parameterized -from monai.data import DataLoader -from monai.data import CacheDataset, create_test_image_2d -from monai.transforms import AddChanneld, Compose, LoadImaged, ToTensord, SpatialPadd, RandFlipd +from monai.data import CacheDataset, DataLoader, create_test_image_2d from monai.data.utils import decollate_batch +from monai.transforms import AddChanneld, Compose, LoadImaged, RandFlipd, SpatialPadd, ToTensord from monai.utils import set_determinism from tests.utils import make_nifti_image -from parameterized import parameterized - - set_determinism(seed=0) IM_2D_FNAME = make_nifti_image(create_test_image_2d(100, 101)[0]) @@ -30,10 +28,13 @@ DATA_2D = {"image": IM_2D_FNAME} TESTS = [] -TESTS.append(( - "2D", - [DATA_2D for _ in range(6)], -)) +TESTS.append( + ( + "2D", + [DATA_2D for _ in range(6)], + ) +) + class TestDeCollate(unittest.TestCase): def check_match(self, in1, in2): @@ -53,13 +54,15 @@ def check_match(self, in1, in2): @parameterized.expand(TESTS) def test_decollation(self, _, data, batch_size=2, num_workers=2): - transforms = Compose([ - LoadImaged("image"), - AddChanneld("image"), - SpatialPadd("image", 150), - RandFlipd("image", prob=1., spatial_axis=1), - ToTensord("image"), - ]) + transforms = Compose( + [ + LoadImaged("image"), + AddChanneld("image"), + SpatialPadd("image", 150), + RandFlipd("image", prob=1.0, spatial_axis=1), + ToTensord("image"), + ] + ) dataset = CacheDataset(data, transforms, progress=False) loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index dde4f8c11a..818a7583a7 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -10,14 +10,16 @@ # limitations under the License. 
import random +import sys import unittest from typing import TYPE_CHECKING, List, Tuple -import sys -from monai.data.utils import decollate_batch + import numpy as np import torch -from monai.data import DataLoader -from monai.data import CacheDataset, create_test_image_2d, create_test_image_3d +from parameterized import parameterized + +from monai.data import BatchInverseTransform, CacheDataset, DataLoader, create_test_image_2d, create_test_image_3d +from monai.data.utils import decollate_batch from monai.networks.nets import UNet from monai.transforms import ( AddChannel, @@ -50,13 +52,8 @@ SpatialPadd, Zoomd, ) - -from monai.data import BatchInverseTransform from monai.utils import optional_import, set_determinism -from tests.utils import make_nifti_image, make_rand_affine, skip_if_quick, test_is_quick - -from parameterized import parameterized - +from tests.utils import make_nifti_image, make_rand_affine, test_is_quick if TYPE_CHECKING: import matplotlib.pyplot as plt @@ -391,7 +388,7 @@ TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] -TESTS = TESTS + TESTS_COMPOSE_X2 +TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore # Should fail because uses an array transform (SpatialPad), as opposed to dictionary From b60d206597194dac5e5449ef68cae262bf04bbb3 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 11 Feb 2021 08:58:16 +0000 Subject: [PATCH 57/80] RandCropByPosNegLabeld to call correct parent constructor Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index f9c9fb8670..3dd7667856 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -710,7 +710,7 @@ def __init__( fg_indices_key: Optional[str] = None, bg_indices_key: Optional[str] = None, ) -> None: - super().__init__(keys) + MapTransform.__init__(self, keys) self.label_key = label_key self.spatial_size: Union[Tuple[int, ...], Sequence[int], int] = spatial_size if pos < 0 or neg < 0: From 61e56bec33653b94a77cfe80c7df08cdaa21aef0 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Thu, 11 Feb 2021 13:34:27 +0000 Subject: [PATCH 58/80] vtk attempt Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/inverse_transform.py | 86 +++++++++++++-------------- 1 file changed, 42 insertions(+), 44 deletions(-) diff --git a/monai/transforms/inverse_transform.py b/monai/transforms/inverse_transform.py index 49e49ddd6d..dfd5e94c47 100644 --- a/monai/transforms/inverse_transform.py +++ b/monai/transforms/inverse_transform.py @@ -115,19 +115,10 @@ def _inv_disp_w_vtk(fwd_disp): while fwd_disp.shape[-1] < 3: fwd_disp = np.append(fwd_disp, np.zeros(fwd_disp.shape[:-1] + (1,)), axis=-1) fwd_disp = fwd_disp[..., None, :] - # fwd_disp_vtk = vtk.vtkImageImport() - # # The previously created array is converted to a string of chars and imported. 
- # data_string = fwd_disp.tobytes() - # fwd_disp_vtk.CopyImportVoidPointer(data_string, len(data_string)) - # # The type of the newly imported data is set to unsigned char (uint8) - # fwd_disp_vtk.SetDataScalarTypeToUnsignedChar() - # fwd_disp_vtk.SetNumberOfScalarComponents(3) - extent = list(chain.from_iterable(zip([0, 0, 0], fwd_disp.shape[:-1]))) - # fwd_disp_vtk.SetWholeExtent(extent) - # fwd_disp_vtk.SetDataExtentToWholeExtent() - # fwd_disp_vtk.Update() - # fwd_disp_vtk = fwd_disp_vtk.GetOutput() - + # if any spatial dimensions have size == 1, double them along that axis (required by vtk) + for i, s in enumerate(fwd_disp.shape[:-1]): + if s == 1: + fwd_disp = np.repeat(fwd_disp, repeats=2, axis=i) fwd_disp_flattened = fwd_disp.flatten() # need to keep this in memory vtk_data_array = vtk_numpy_support.numpy_to_vtk(fwd_disp_flattened) @@ -136,43 +127,50 @@ def _inv_disp_w_vtk(fwd_disp): fwd_disp_vtk.SetOrigin(0, 0, 0) fwd_disp_vtk.SetSpacing(1, 1, 1) fwd_disp_vtk.SetDimensions(*fwd_disp.shape[:-1]) - - fwd_disp_vtk.AllocateScalars(vtk_numpy_support.get_vtk_array_type(fwd_disp.dtype), 3) - fwd_disp_vtk.SetExtent(extent) - fwd_disp_vtk.GetPointData().AddArray(vtk_data_array) - - # # create b-spline coefficients for the displacement grid - # bspline_filter = vtk.vtkImageBSplineCoefficients() - # bspline_filter.SetInputData(fwd_disp_vtk) - # bspline_filter.Update() - - # # use these b-spline coefficients to create a transform - # bspline_transform = vtk.vtkBSplineTransform() - # bspline_transform.SetCoefficientData(bspline_filter.GetOutput()) - # bspline_transform.Update() - - # # invert the b-spline transform onto a new grid - # grid_maker = vtk.vtkTransformToGrid() - # grid_maker.SetInput(bspline_transform.GetInverse()) - # grid_maker.SetGridOrigin(fwd_disp_vtk.GetOrigin()) - # grid_maker.SetGridSpacing(fwd_disp_vtk.GetSpacing()) - # grid_maker.SetGridExtent(fwd_disp_vtk.GetExtent()) - # grid_maker.SetGridScalarTypeToFloat() - # grid_maker.Update() - - # # Get inverse displacement as an image - # inv_disp_vtk = grid_maker.GetOutput() - - # from vtk.util.numpy_support import vtk_to_numpy - # inv_disp = vtk_numpy_support.vtk_to_numpy(inv_disp_vtk.GetPointData().GetScalars()) - inv_disp = vtk_numpy_support.vtk_to_numpy(fwd_disp_vtk.GetPointData().GetArray(0)) + if False: + fwd_disp_vtk.SetNumberOfScalarComponents(3) + fwd_disp_vtk.GetPointData().AddArray(vtk_data_array) + else: + fwd_disp_vtk.AllocateScalars(vtk_numpy_support.get_vtk_array_type(fwd_disp.dtype), 3) + fwd_disp_vtk.GetPointData().GetArray(0).SetArray(vtk_data_array) + + + if __debug__: + fwd_disp_vtk_np = vtk_numpy_support.vtk_to_numpy(fwd_disp_vtk.GetPointData().GetArray(0)) + assert fwd_disp_vtk_np.size == fwd_disp.size + assert fwd_disp_vtk_np.min() == fwd_disp.min() + assert fwd_disp_vtk_np.max() == fwd_disp.max() + + # create b-spline coefficients for the displacement grid + bspline_filter = vtk.vtkImageBSplineCoefficients() + bspline_filter.SetInputData(fwd_disp_vtk) + bspline_filter.Update() + + # use these b-spline coefficients to create a transform + bspline_transform = vtk.vtkBSplineTransform() + bspline_transform.SetCoefficientData(bspline_filter.GetOutput()) + bspline_transform.Update() + + # invert the b-spline transform onto a new grid + grid_maker = vtk.vtkTransformToGrid() + grid_maker.SetInput(bspline_transform.GetInverse()) + grid_maker.SetGridOrigin(fwd_disp_vtk.GetOrigin()) + grid_maker.SetGridSpacing(fwd_disp_vtk.GetSpacing()) + grid_maker.SetGridExtent(fwd_disp_vtk.GetExtent()) + 
grid_maker.SetGridScalarTypeToFloat() + grid_maker.Update() + + # Get inverse displacement as an image + inv_disp_vtk = grid_maker.GetOutput() + + inv_disp = vtk_numpy_support.vtk_to_numpy(inv_disp_vtk.GetPointData().GetArray(0)) inv_disp = inv_disp.reshape(fwd_disp.shape) return inv_disp @staticmethod def compute_inverse_deformation( - num_spatial_dims, fwd_def_orig, spacing=None, num_iters: int = 100, use_package: str = "sitk" + num_spatial_dims, fwd_def_orig, spacing=None, num_iters: int = 100, use_package: str = "vtk" ): """Package can be vtk or sitk.""" if use_package.lower() == "vtk" and not has_vtk: From 83502e5e4c4b3c944c49b72d00f0af73b57df612 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Thu, 11 Feb 2021 15:47:21 +0000 Subject: [PATCH 59/80] more vtk progress Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/inverse_transform.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/monai/transforms/inverse_transform.py b/monai/transforms/inverse_transform.py index dfd5e94c47..924540296e 100644 --- a/monai/transforms/inverse_transform.py +++ b/monai/transforms/inverse_transform.py @@ -119,7 +119,7 @@ def _inv_disp_w_vtk(fwd_disp): for i, s in enumerate(fwd_disp.shape[:-1]): if s == 1: fwd_disp = np.repeat(fwd_disp, repeats=2, axis=i) - fwd_disp_flattened = fwd_disp.flatten() # need to keep this in memory + fwd_disp_flattened = fwd_disp.reshape(-1, 3) # need to keep this in memory vtk_data_array = vtk_numpy_support.numpy_to_vtk(fwd_disp_flattened) # Generating the vtkImageData @@ -127,19 +127,14 @@ def _inv_disp_w_vtk(fwd_disp): fwd_disp_vtk.SetOrigin(0, 0, 0) fwd_disp_vtk.SetSpacing(1, 1, 1) fwd_disp_vtk.SetDimensions(*fwd_disp.shape[:-1]) - if False: - fwd_disp_vtk.SetNumberOfScalarComponents(3) - fwd_disp_vtk.GetPointData().AddArray(vtk_data_array) - else: - fwd_disp_vtk.AllocateScalars(vtk_numpy_support.get_vtk_array_type(fwd_disp.dtype), 3) - fwd_disp_vtk.GetPointData().GetArray(0).SetArray(vtk_data_array) - + fwd_disp_vtk.GetPointData().SetScalars(vtk_data_array) if __debug__: fwd_disp_vtk_np = vtk_numpy_support.vtk_to_numpy(fwd_disp_vtk.GetPointData().GetArray(0)) assert fwd_disp_vtk_np.size == fwd_disp.size assert fwd_disp_vtk_np.min() == fwd_disp.min() assert fwd_disp_vtk_np.max() == fwd_disp.max() + assert fwd_disp_vtk.GetNumberOfScalarComponents() == 3 # create b-spline coefficients for the displacement grid bspline_filter = vtk.vtkImageBSplineCoefficients() From 3284a0740a1bee56ed1612acfe90af5fb5b3b587 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Thu, 11 Feb 2021 15:57:29 +0000 Subject: [PATCH 60/80] no error Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/inverse_transform.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/monai/transforms/inverse_transform.py b/monai/transforms/inverse_transform.py index 924540296e..ee2695e164 100644 --- a/monai/transforms/inverse_transform.py +++ b/monai/transforms/inverse_transform.py @@ -112,13 +112,14 @@ def _inv_disp_w_sitk(fwd_disp, num_iters): @staticmethod def _inv_disp_w_vtk(fwd_disp): + orig_shape = fwd_disp.shape + # VTK requires 3 tensor components, so if shape was (H, W, 2), make it + # (H, W, 1, 3) (i.e., depth 1 with a 3rd tensor component of 0s) while fwd_disp.shape[-1] < 3: fwd_disp = np.append(fwd_disp, np.zeros(fwd_disp.shape[:-1] + (1,)), axis=-1) fwd_disp = fwd_disp[..., None, :] - # if 
any spatial dimensions have size == 1, double them along that axis (required by vtk) - for i, s in enumerate(fwd_disp.shape[:-1]): - if s == 1: - fwd_disp = np.repeat(fwd_disp, repeats=2, axis=i) + + # Create VTKDoubleArray. Shape needs to be (H*W*D, 3) fwd_disp_flattened = fwd_disp.reshape(-1, 3) # need to keep this in memory vtk_data_array = vtk_numpy_support.numpy_to_vtk(fwd_disp_flattened) @@ -158,8 +159,12 @@ def _inv_disp_w_vtk(fwd_disp): # Get inverse displacement as an image inv_disp_vtk = grid_maker.GetOutput() + # Convert back to numpy and reshape inv_disp = vtk_numpy_support.vtk_to_numpy(inv_disp_vtk.GetPointData().GetArray(0)) - inv_disp = inv_disp.reshape(fwd_disp.shape) + # if there were originally < 3 tensor components, remove the zeros we added at the start + inv_disp = inv_disp[..., :orig_shape[-1]] + # reshape to original + inv_disp = inv_disp.reshape(orig_shape) return inv_disp From 884a4079d0596a9595400a89594ef8b71f5cf0e3 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Thu, 11 Feb 2021 17:08:17 +0000 Subject: [PATCH 61/80] batchdataset inherits from monai dataset Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/inverse_batch_transform.py | 3 +-- monai/transforms/inverse_transform.py | 9 +++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index 0f10ec6c41..8ac9301b8a 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -9,9 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from torch.utils.data.dataset import Dataset - from monai.data.dataloader import DataLoader +from monai.data.dataset import Dataset from monai.data.utils import decollate_batch from monai.transforms.inverse_transform import InvertibleTransform diff --git a/monai/transforms/inverse_transform.py b/monai/transforms/inverse_transform.py index ee2695e164..f9a8719ce6 100644 --- a/monai/transforms/inverse_transform.py +++ b/monai/transforms/inverse_transform.py @@ -113,14 +113,15 @@ def _inv_disp_w_sitk(fwd_disp, num_iters): @staticmethod def _inv_disp_w_vtk(fwd_disp): orig_shape = fwd_disp.shape + required_num_tensor_components = 3 # VTK requires 3 tensor components, so if shape was (H, W, 2), make it # (H, W, 1, 3) (i.e., depth 1 with a 3rd tensor component of 0s) - while fwd_disp.shape[-1] < 3: + while fwd_disp.shape[-1] < required_num_tensor_components: fwd_disp = np.append(fwd_disp, np.zeros(fwd_disp.shape[:-1] + (1,)), axis=-1) fwd_disp = fwd_disp[..., None, :] # Create VTKDoubleArray. 
Shape needs to be (H*W*D, 3) - fwd_disp_flattened = fwd_disp.reshape(-1, 3) # need to keep this in memory + fwd_disp_flattened = fwd_disp.reshape(-1, required_num_tensor_components) # need to keep this in memory vtk_data_array = vtk_numpy_support.numpy_to_vtk(fwd_disp_flattened) # Generating the vtkImageData @@ -135,7 +136,7 @@ def _inv_disp_w_vtk(fwd_disp): assert fwd_disp_vtk_np.size == fwd_disp.size assert fwd_disp_vtk_np.min() == fwd_disp.min() assert fwd_disp_vtk_np.max() == fwd_disp.max() - assert fwd_disp_vtk.GetNumberOfScalarComponents() == 3 + assert fwd_disp_vtk.GetNumberOfScalarComponents() == required_num_tensor_components # create b-spline coefficients for the displacement grid bspline_filter = vtk.vtkImageBSplineCoefficients() @@ -200,7 +201,7 @@ def compute_inverse_deformation( else: inv_disp = NonRigidTransform._inv_disp_w_sitk(fwd_disp, num_iters) - if False: + if True: import matplotlib.pyplot as plt fig, axes = plt.subplots(2, 2) From 0834b0693cb17349cef937a86e885cb8b059f542 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Feb 2021 14:25:22 +0000 Subject: [PATCH 62/80] inverse with vtk or sitk Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 2 +- monai/transforms/inverse_transform.py | 19 +--- monai/transforms/spatial/dictionary.py | 121 +++++++++++++++---------- requirements-dev.txt | 2 +- tests/test_inverse.py | 39 ++++---- 5 files changed, 97 insertions(+), 86 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 82847749f3..7fc7b23d3c 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -238,7 +238,7 @@ def __init__( self.dtype = dtype def _normalize(self, img: np.ndarray, sub=None, div=None) -> np.ndarray: - slices = (img != 0) if self.nonzero else np.ones(img.shape, dtype=np.bool_) + slices = (img != 0) if self.nonzero else np.ones(img.shape, dtype=bool) if not np.any(slices): return img diff --git a/monai/transforms/inverse_transform.py b/monai/transforms/inverse_transform.py index f9a8719ce6..38b5dd9ce4 100644 --- a/monai/transforms/inverse_transform.py +++ b/monai/transforms/inverse_transform.py @@ -128,7 +128,7 @@ def _inv_disp_w_vtk(fwd_disp): fwd_disp_vtk = vtk.vtkImageData() fwd_disp_vtk.SetOrigin(0, 0, 0) fwd_disp_vtk.SetSpacing(1, 1, 1) - fwd_disp_vtk.SetDimensions(*fwd_disp.shape[:-1]) + fwd_disp_vtk.SetDimensions(*fwd_disp.shape[:-1][::-1]) # VTK spacing opposite order to numpy fwd_disp_vtk.GetPointData().SetScalars(vtk_data_array) if __debug__: @@ -198,21 +198,10 @@ def compute_inverse_deformation( if use_package.lower() == "vtk": inv_disp = NonRigidTransform._inv_disp_w_vtk(fwd_disp) # If using sitk... 
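(For orientation: an "inverse displacement" is a field that undoes the forward one, i.e. warping by fwd_disp and then by inv_disp returns the original positions. For a constant translation the inverse is just the negation, which the numpy-only toy below verifies; spatially varying fields need the B-spline (VTK) or iterative (SimpleITK) machinery used in this class. The names and sizes in the sketch are local and illustrative.)

import numpy as np

# constant translation of +2 voxels along x: its exact inverse is -2 everywhere
fwd_disp = np.zeros((8, 8, 2))
fwd_disp[..., 0] = 2.0
inv_disp = -fwd_disp

ys, xs = np.meshgrid(np.arange(8), np.arange(8), indexing="ij")
pts = np.stack([xs, ys], axis=-1).astype(float)

warped = pts + fwd_disp            # forward warp of the identity grid
restored = warped + inv_disp       # inverse displacement applied at the warped points (constant field)
assert np.allclose(restored, pts)  # the round trip recovers the original grid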
- else: + elif use_package.lower() == "sitk": inv_disp = NonRigidTransform._inv_disp_w_sitk(fwd_disp, num_iters) - - if True: - import matplotlib.pyplot as plt - - fig, axes = plt.subplots(2, 2) - for i, direc1 in enumerate(["x", "y"]): - for j, (im, direc2) in enumerate(zip([fwd_disp, inv_disp], ["fwd", "inv"])): - ax = axes[i, j] - im_show = ax.imshow(im[..., i]) - ax.set_title(f"{direc2}{direc1}", fontsize=25) - ax.axis("off") - fig.colorbar(im_show, ax=ax) - plt.show() + else: + raise RuntimeError("Enter vtk or sitk for inverse calculation") # move tensor component back to beginning inv_disp = np.moveaxis(inv_disp, -1, 0) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 6068dee5c8..14ab17cc4f 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -676,7 +676,7 @@ def __init__( self.rand_2d_elastic = Rand2DElastic( spacing=spacing, magnitude_range=magnitude_range, - prob=prob, + prob=1.0, # because probability controlled by this class rotate_range=rotate_range, shear_range=shear_range, translate_range=translate_range, @@ -696,8 +696,20 @@ def set_random_state( return self def randomize(self, spatial_size: Sequence[int]) -> None: + self._do_transform = self.R.rand() < self.prob self.rand_2d_elastic.randomize(spatial_size) + @staticmethod + def cpg_to_dvf(cpg, spacing, output_shape): + grid = torch.nn.functional.interpolate( + recompute_scale_factor=True, + input=cpg.unsqueeze(0), + scale_factor=ensure_tuple_rep(spacing, 2), + mode=InterpolateMode.BILINEAR.value, + align_corners=False, + ) + return CenterSpatialCrop(roi_size=output_shape)(grid[0]) + def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: @@ -706,22 +718,17 @@ def __call__( sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, data[self.keys[0]].shape[1:]) self.randomize(spatial_size=sp_size) - if self.rand_2d_elastic._do_transform: + if self._do_transform: cpg = self.rand_2d_elastic.deform_grid(spatial_size=sp_size) - cpg = self.rand_2d_elastic.rand_affine_grid(grid=cpg) - grid = torch.nn.functional.interpolate( # type: ignore - recompute_scale_factor=True, - input=cpg.unsqueeze(0), - scale_factor=ensure_tuple_rep(self.rand_2d_elastic.deform_grid.spacing, 2), - mode=InterpolateMode.BILINEAR.value, - align_corners=False, - ) - grid = CenterSpatialCrop(roi_size=sp_size)(grid[0]) + cpg_w_affine, affine = self.rand_2d_elastic.rand_affine_grid(grid=cpg, return_affine=True) + grid = self.cpg_to_dvf(cpg_w_affine, self.rand_2d_elastic.deform_grid.spacing, sp_size) + extra_info = {"cpg": deepcopy(cpg), "affine": deepcopy(affine)} else: grid = create_grid(spatial_size=sp_size) + extra_info = None for idx, key in enumerate(self.keys): - self.append_applied_transforms(d, key, extra_info={"grid": deepcopy(grid)}) + self.append_applied_transforms(d, key, extra_info=extra_info) d[key] = self.rand_2d_elastic.resampler( d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx] ) @@ -729,23 +736,35 @@ def __call__( def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) + # This variable will be `not None` if vtk or sitk is present + inv_def_w_affine = None for idx, key in enumerate(self.keys): transform = self.get_most_recent_transform(d, key) - orig_size = transform["orig_size"] # Create inverse transform - fwd_def = transform["extra_info"]["grid"] - if fwd_def is None: - d[key] = 
CenterSpatialCrop(roi_size=orig_size)(d[key]) + if transform["do_transform"]: + orig_size = transform["orig_size"] + # Only need to calculate inverse deformation once as it is the same for all keys + if idx == 0: + # If magnitude == 0, then non-rigid component is identity -- so just create blank + if self.rand_2d_elastic.deform_grid.magnitude == (0.0, 0.0): + inv_def_no_affine = create_grid(spatial_size=orig_size) + else: + fwd_cpg_no_affine = transform["extra_info"]["cpg"] + fwd_def_no_affine = self.cpg_to_dvf(fwd_cpg_no_affine, self.rand_2d_elastic.deform_grid.spacing, orig_size) + inv_def_no_affine = self.compute_inverse_deformation(len(orig_size), fwd_def_no_affine) + # if inverse did not succeed (sitk or vtk present), data will not be changed. + if inv_def_no_affine is not None: + fwd_affine = transform["extra_info"]["affine"] + inv_affine = np.linalg.inv(fwd_affine) + inv_def_w_affine = AffineGrid(affine=inv_affine)(grid=inv_def_no_affine) + # Back to original size + inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine) + # Apply inverse transform + out = self.rand_2d_elastic.resampler(d[key], inv_def_w_affine, self.mode[idx], self.padding_mode[idx]) + d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out else: - inv_def = self.compute_inverse_deformation(len(orig_size), fwd_def) - # if no sitk, `inv_def` will be `None`, and data will not be changed. - if inv_def is not None: - # Back to original size - inv_def = CenterSpatialCrop(roi_size=orig_size)(inv_def) - # Apply inverse transform - out = self.rand_2d_elastic.resampler(d[key], inv_def, self.mode[idx], self.padding_mode[idx]) - d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out + d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) # Remove the applied transform self.remove_most_recent_transform(d, key) @@ -819,7 +838,7 @@ def __init__( self.rand_3d_elastic = Rand3DElastic( sigma_range=sigma_range, magnitude_range=magnitude_range, - prob=prob, + prob=1.0, # because probability controlled by this class rotate_range=rotate_range, shear_range=shear_range, translate_range=translate_range, @@ -839,9 +858,8 @@ def set_random_state( return self def randomize(self, grid_size: Sequence[int]) -> None: + self._do_transform = self.R.rand() < self.prob self.rand_3d_elastic.randomize(grid_size) - self.prob = self.rand_3d_elastic.prob - self._do_transform = self.rand_3d_elastic._do_transform def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] @@ -850,44 +868,49 @@ def __call__( sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, data[self.keys[0]].shape[1:]) self.randomize(grid_size=sp_size) - grid = create_grid(spatial_size=sp_size) - if self.rand_3d_elastic._do_transform: + grid_no_affine = create_grid(spatial_size=sp_size) + if self._do_transform: device = self.rand_3d_elastic.device - grid = torch.tensor(grid).to(device) + grid_no_affine = torch.tensor(grid_no_affine).to(device) gaussian = GaussianFilter(spatial_dims=3, sigma=self.rand_3d_elastic.sigma, truncated=3.0).to(device) offset = torch.tensor(self.rand_3d_elastic.rand_offset, device=device).unsqueeze(0) - grid[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude - grid = self.rand_3d_elastic.rand_affine_grid(grid=grid) + grid_no_affine[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude + grid_w_affine, affine = self.rand_3d_elastic.rand_affine_grid(grid=grid_no_affine, return_affine=True) for idx, key in enumerate(self.keys): - self.append_applied_transforms(d, key, 
extra_info={"grid": grid.cpu().numpy()}) + self.append_applied_transforms(d, key, extra_info={"grid_no_affine": grid_no_affine.cpu().numpy(), "affine": affine}) d[key] = self.rand_3d_elastic.resampler( - d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx] + d[key], grid_w_affine, mode=self.mode[idx], padding_mode=self.padding_mode[idx] ) return d def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) + # This variable will be `not None` if vtk or sitk is present + inv_def_w_affine = None for idx, key in enumerate(self.keys): transform = self.get_most_recent_transform(d, key) - orig_size = transform["orig_size"] # Create inverse transform - fwd_def = transform["extra_info"]["grid"] - if fwd_def is None: - d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) + if transform["do_transform"]: + orig_size = transform["orig_size"] + # Only need to calculate inverse deformation once as it is the same for all keys + if idx == 0: + fwd_def_no_affine = transform["extra_info"]["grid_no_affine"] + inv_def_no_affine = self.compute_inverse_deformation(len(orig_size), fwd_def_no_affine) + # if inverse did not succeed (sitk or vtk present), data will not be changed. + if inv_def_no_affine is not None: + fwd_affine = transform["extra_info"]["affine"] + inv_affine = np.linalg.inv(fwd_affine) + inv_def_w_affine = AffineGrid(affine=inv_affine)(grid=inv_def_no_affine) + # Back to original size + inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine) + # Apply inverse transform + if inv_def_w_affine is not None: + out = self.rand_3d_elastic.resampler(d[key], inv_def_w_affine, self.mode[idx], self.padding_mode[idx]) + d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out else: - inv_def = self.compute_inverse_deformation(len(orig_size), fwd_def) - # if no sitk, `inv_def` will be `None`, and data will not be changed. 
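(On the forward pass above: Rand3DElastic builds its dense displacement by Gaussian-smoothing a per-voxel random offset and scaling it by the sampled magnitude, and that smoothed field is what the inverse below has to undo. A small standalone sketch of the construction, using MONAI's GaussianFilter as the patch does; the tensor size, sigma and magnitude are illustrative only.)

import torch
from monai.networks.layers import GaussianFilter

offset = torch.rand(1, 3, 32, 32, 32) * 2 - 1    # random offsets in [-1, 1], shape (N, C=3, H, W, D)
gaussian = GaussianFilter(spatial_dims=3, sigma=4.0, truncated=3.0)
dvf = gaussian(offset)[0] * 50.0                 # smooth 3-channel displacement, scaled by a magnitude

# in the transform this field is added to the first three channels of the sampling grid:
# grid_no_affine[:3] += dvf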
- if inv_def is not None: - # Back to original size - inv_def = CenterSpatialCrop(roi_size=orig_size)(inv_def) - # Apply inverse transform - out = self.rand_3d_elastic.resampler(d[key], inv_def, self.mode[idx], self.padding_mode[idx]) - if isinstance(out, torch.Tensor): - d[key] = out.cpu().numpy() - else: - d[key] = out + d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) # Remove the applied transform self.remove_most_recent_transform(d, key) diff --git a/requirements-dev.txt b/requirements-dev.txt index 42d07dff5a..6748964f5c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -30,4 +30,4 @@ Sphinx==3.3.0 recommonmark==0.6.0 sphinx-autodoc-typehints==1.11.1 sphinx-rtd-theme==0.5.0 -SimpleITK +vtk \ No newline at end of file diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 818a7583a7..e3859c271d 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -349,18 +349,18 @@ ( "Rand2DElasticd 2d", DATA_2D, - 1e-1, + 2e-1, Rand2DElasticd( KEYS, spacing=(10.0, 10.0), - magnitude_range=(2, 2), - # spatial_size=[155, 192], + magnitude_range=(1, 1), + spatial_size=[155, 192], prob=1, padding_mode="zeros", - # rotate_range=[(np.pi / 6, np.pi / 6), np.pi / 7], - # shear_range=[(0.5, 0.5)], - # translate_range=[10, 5], - # scale_range=[(0.8, 1.2), (0.9, 1.3)], + rotate_range=[(np.pi / 6, np.pi / 6)], + shear_range=[(0.5, 0.5)], + translate_range=[10, 5], + scale_range=[(1.2, 1.2), (1.3, 1.3)], ), ) ) @@ -370,19 +370,18 @@ ( "Rand3DElasticd 3d", DATA_3D, - 2e-1, - Rand3DElasticd( - KEYS, - sigma_range=(1, 3), - magnitude_range=(1.0, 2.0), - spatial_size=[155, 192, 200], - prob=1, - padding_mode="zeros", - rotate_range=[np.pi / 6, np.pi / 7], - shear_range=[(0.5, 0.5)], - translate_range=[10, 5], - scale_range=[(0.8, 1.2), (0.9, 1.3)], - ), + 1e-1, + Rand3DElasticd( + KEYS, + sigma_range=(3, 5), + magnitude_range=(100, 100), + prob=1, + padding_mode="zeros", + rotate_range=[np.pi / 6, np.pi / 7], + shear_range=[(0.5, 0.5), 0.2], + translate_range=[10, 5, 3], + scale_range=[(0.8, 1.2), (0.9, 1.3)], + ), ) ) From cf238d2542478427e4037483036b9b198289923b Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Feb 2021 17:14:23 +0000 Subject: [PATCH 63/80] option for no collation after batch inverse Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/inverse_batch_transform.py | 22 ++++++++++++++++++---- tests/test_inverse.py | 25 +++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index 8ac9301b8a..1c5f27a8cc 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -9,9 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
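(Usage sketch for the collate_fn option this patch adds: once a batch has been inverted, the per-sample results usually have different spatial sizes again, so the default stacking collation fails and a pass-through collate is used instead. This mirrors the test_diff_sized_inputs test added in the same patch; the toy data shapes are illustrative.)

import numpy as np
from monai.data import BatchInverseTransform, CacheDataset, DataLoader
from monai.data.utils import decollate_batch
from monai.transforms import Compose, SpatialPadd

data = [{"image": np.zeros((1, 100 + i, 101 + i), dtype=np.float32)} for i in range(4)]
transforms = Compose([SpatialPadd("image", (150, 153))])

dataset = CacheDataset(data, transform=transforms, progress=False)
loader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0)

# pass-through collation, because the un-padded (inverted) images differ in size
inverter = BatchInverseTransform(transforms, loader, collate_fn=lambda x: x)
for batch in loader:
    fwd = decollate_batch(batch)   # list of padded per-sample dicts
    bck = inverter(batch)          # list of per-sample dicts back at their original sizes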
+from typing import Callable from monai.data.dataloader import DataLoader +from torch.utils.data.dataloader import DataLoader as TorchDataLoader from monai.data.dataset import Dataset -from monai.data.utils import decollate_batch +from monai.data.utils import decollate_batch, list_data_collate from monai.transforms.inverse_transform import InvertibleTransform __all__ = ["BatchInverseTransform"] @@ -30,17 +32,29 @@ def __getitem__(self, index: int): class BatchInverseTransform: """something""" - def __init__(self, transform: InvertibleTransform, loader) -> None: + def __init__( + self, transform: InvertibleTransform, loader: TorchDataLoader, collate_fn: Callable = None + ) -> None: """ Args: transform: a callable data transform on input data. loader: data loader used to generate the batch of data. + collate_fn: how to collate data after inverse transformations. Default will use the DataLoader's default collation method. + If returning images of different sizes, this will likely create an error (since the collation will concatenate arrays, + requiring them to be the same size). In this case, using `collate_fn=lambda x: x` might solve the problem. """ self.transform = transform self.batch_size = loader.batch_size self.num_workers = loader.num_workers + self.collate_fn = collate_fn def __call__(self, data): inv_ds = _BatchInverseDataset(data, self.transform) - inv_loader = DataLoader(inv_ds, batch_size=self.batch_size, num_workers=self.num_workers) - return next(iter(inv_loader)) + inv_loader = DataLoader(inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn) + try: + return next(iter(inv_loader)) + except RuntimeError as re: + re = str(re) + if "stack expects each tensor to be equal size" in re: + re += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`." 
+ raise RuntimeError(re) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index e3859c271d..6e41791c2d 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -495,6 +495,31 @@ def test_w_dataloader(self, _, data, acceptable_diff, *transforms): if torch.cuda.is_available(): _ = model(inputs) + def test_diff_sized_inputs(self): + + key = "image" + test_data = [{key: AddChannel()(create_test_image_2d(100 + i, 101 + i)[0])} for i in range(4)] + + batch_size = 2 + # num workers = 0 for mac + num_workers = 2 if sys.platform != "darwin" else 0 + transforms = Compose([SpatialPadd(key, (150, 153))]) + + dataset = CacheDataset(test_data, transform=transforms, progress=False) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + # blank collate function since input are different size + inv_batch = BatchInverseTransform(transforms, loader, collate_fn=lambda x: x) + + for batch_idx, batch_data in enumerate(loader): + fwd = decollate_batch(batch_data) + fwd_bck = inv_batch(batch_data) + + for idx, (_fwd, _fwd_bck) in enumerate(zip(fwd, fwd_bck)): + unmodified = test_data[batch_idx * batch_size + idx] + self.check_inverse("diff_sized_inputs", [key], unmodified, _fwd_bck, _fwd, 0) + + + @parameterized.expand(TESTS_FAIL) def test_fail(self, data, _, *transform): d = transform[0](data) From 0047275ebe3a350d697e029d7b2b16238eb27504 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 17 Feb 2021 11:08:35 +0000 Subject: [PATCH 64/80] create same keys whether random transform used or not Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 14ab17cc4f..dee37265a1 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -399,14 +399,11 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np.ndarray]: self.randomize() d = dict(data) - if not self._do_transform: - for key in self.keys: - self.append_applied_transforms(d, key) - return d rotator = Rotate90(self._rand_k, self.spatial_axes) for key in self.keys: - d[key] = rotator(d[key]) + if self._do_transform: + d[key] = rotator(d[key]) self.append_applied_transforms(d, key, extra_info={"rand_k": self._rand_k}) return d @@ -1178,7 +1175,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d = dict(data) if not self._do_transform: for key in self.keys: - self.append_applied_transforms(d, key) + self.append_applied_transforms(d, key, extra_info={"rot_mat": np.eye(4)}) return d angle: Union[Sequence[float], float] = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z) rotator = Rotate( @@ -1368,7 +1365,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d = dict(data) if not self._do_transform: for key in self.keys: - self.append_applied_transforms(d, key) + self.append_applied_transforms(d, key, extra_info={"zoom": self._zoom}) return d img_dims = data[self.keys[0]].ndim From a3055bf2f10d23c57ff111c2d3591d67ac517735 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 19 Feb 2021 10:21:59 +0000 Subject: [PATCH 65/80] pad collate Signed-off-by: Richard Brown 
<33289025+rijobro@users.noreply.github.com> --- monai/data/__init__.py | 1 + monai/data/image_reader.py | 8 ++--- monai/data/inverse_batch_transform.py | 4 +-- monai/data/utils.py | 48 +++++++++++++++++++++++++- monai/transforms/croppad/array.py | 6 ++-- monai/transforms/croppad/dictionary.py | 6 ++-- monai/transforms/spatial/array.py | 18 +++++----- 7 files changed, 69 insertions(+), 22 deletions(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index d63b604ecd..380afb7773 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -58,4 +58,5 @@ to_affine_nd, worker_init_fn, zoom_affine, + pad_list_data_collate, ) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index d0f5f4aefc..a897fc670e 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -17,10 +17,10 @@ from torch.utils.data._utils.collate import np_str_obj_array_pattern from monai.config import DtypeLike, KeysCollection -from monai.data.utils import correct_nifti_header_if_necessary +import monai.data.utils from monai.utils import ensure_tuple, optional_import -from .utils import is_supported_format +import monai.data.utils if TYPE_CHECKING: import itk # type: ignore @@ -311,7 +311,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: """ suffixes: Sequence[str] = ["nii", "nii.gz"] - return has_nib and is_supported_format(filename, suffixes) + return has_nib and monai.data.is_supported_format(filename, suffixes) def read(self, data: Union[Sequence[str], str], **kwargs): """ @@ -332,7 +332,7 @@ def read(self, data: Union[Sequence[str], str], **kwargs): kwargs_.update(kwargs) for name in filenames: img = nib.load(name, **kwargs_) - img = correct_nifti_header_if_necessary(img) + img = monai.data.utils.correct_nifti_header_if_necessary(img) img_.append(img) return img_ if len(filenames) > 1 else img_[0] diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index 1c5f27a8cc..71f9166fad 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -9,11 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
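(Usage sketch for the pad_list_data_collate helper added further down in this patch: transforms that legitimately produce samples of different spatial sizes can still be batched, because each key is centrally padded to the largest shape in the batch before the default collation runs. This mirrors the pad-collation test added later in the series; the toy data is illustrative.)

import numpy as np
from monai.data import CacheDataset, DataLoader
from monai.data.utils import pad_list_data_collate
from monai.transforms import RandSpatialCropd

data = [{"image": np.arange(90, dtype=np.float32).reshape(1, 10, 9)} for _ in range(2)]
transform = RandSpatialCropd("image", roi_size=[8, 7], random_size=True)  # per-sample output sizes vary

dataset = CacheDataset(data, transform, progress=False)
loader = DataLoader(dataset, batch_size=2, collate_fn=pad_list_data_collate)
for batch in loader:
    print(batch["image"].shape)  # every sample padded to the batch-wide maximum before stacking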
-from typing import Callable +from typing import Callable, Optional from monai.data.dataloader import DataLoader from torch.utils.data.dataloader import DataLoader as TorchDataLoader from monai.data.dataset import Dataset -from monai.data.utils import decollate_batch, list_data_collate +from monai.data.utils import decollate_batch from monai.transforms.inverse_transform import InvertibleTransform __all__ = ["BatchInverseTransform"] diff --git a/monai/data/utils.py b/monai/data/utils.py index b6e0da8db2..233196fbd9 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -24,6 +24,7 @@ import torch from torch.utils.data import DistributedSampler as _TorchDistributedSampler from torch.utils.data._utils.collate import default_collate +import monai.transforms.croppad.dictionary from monai.networks.layers.simplelayers import GaussianFilter from monai.utils import ( @@ -65,6 +66,7 @@ "pickle_hashing", "sorted_dict", "decollate_batch", + "pad_list_data_collate", ] @@ -242,7 +244,51 @@ def list_data_collate(batch: Sequence): """ elem = batch[0] data = [i for k in batch for i in k] if isinstance(elem, list) else batch - return default_collate(data) + try: + return default_collate(data) + except RuntimeError as re: + re_str = str(re) + if "stack expects each tensor to be equal size" in re_str: + re_str += "\nMONAI hint: if your transforms intentionally create images of different shapes, creating your `DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its documentation)." + raise RuntimeError(re_str) + + +def pad_list_data_collate(batch: Sequence): + """ + Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest tensor in each dimension + + Note: + Need to use this collate if apply some transforms that can generate batch data. + + """ + for key in batch[0].keys(): + max_shapes = [] + for elem in batch: + if not isinstance(elem[key], (torch.Tensor, np.ndarray)): + break + max_shapes.append(elem[key].shape[1:]) + # len > 0 if objects were arrays + if len(max_shapes) == 0: + continue + max_shape = np.array(max_shapes).max(axis=0) + # If all same size, skip + if np.all(np.array(max_shapes).min(axis=0) == max_shape): + continue + # Do we need to convert output to Tensor? 
+ output_to_tensor = isinstance(batch[0][key], torch.Tensor) + + # Use `SpatialPadd` to match sizes + # Default params are central padding, padding with 0's + # Use the dictionary version so that the transformation is recorded + padder = monai.transforms.croppad.dictionary.SpatialPadd(key, max_shape) # type: ignore + for idx in range(len(batch)): + batch[idx][key] = padder(batch[idx])[key] + if output_to_tensor: + batch[idx][key] = torch.Tensor(batch[idx][key]) + + # After padding, use default list collator + return list_data_collate(batch) + def decollate_batch(data: dict, batch_size: Optional[int] = None): diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index ef5e0019bd..e7b7f04138 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -19,7 +19,7 @@ import torch from monai.config import IndexSelection -from monai.data.utils import get_random_patch, get_valid_patch_size +import monai.data.utils from monai.transforms.transform import Randomizable, Transform from monai.transforms.utils import ( generate_pos_neg_label_crop_centers, @@ -304,8 +304,8 @@ def randomize(self, img_size: Sequence[int]) -> None: if self.random_size: self._size = tuple((self.R.randint(low=self._size[i], high=img_size[i] + 1) for i in range(len(img_size)))) if self.random_center: - valid_size = get_valid_patch_size(img_size, self._size) - self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) + valid_size = monai.data.utils.get_valid_patch_size(img_size, self._size) + self._slices = (slice(None),) + monai.data.utils.get_random_patch(img_size, valid_size, self.R) def __call__(self, img: np.ndarray): """ diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 3dd7667856..4bf2d2080a 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -22,7 +22,7 @@ import numpy as np from monai.config import IndexSelection, KeysCollection -from monai.data.utils import get_random_patch, get_valid_patch_size +import monai.data.utils from monai.transforms.croppad.array import ( BorderPad, BoundingRect, @@ -412,8 +412,8 @@ def randomize(self, img_size: Sequence[int]) -> None: if self.random_size: self._size = [self.R.randint(low=self._size[i], high=img_size[i] + 1) for i in range(len(img_size))] if self.random_center: - valid_size = get_valid_patch_size(img_size, self._size) - self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) + valid_size = monai.data.utils.get_valid_patch_size(img_size, self._size) + self._slices = (slice(None),) + monai.data.utils.get_random_patch(img_size, valid_size, self.R) pass def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index c229c1d85b..d4f190125d 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -20,7 +20,7 @@ import torch from monai.config import USE_COMPILED, DtypeLike -from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine +import monai.data.utils from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.transforms.croppad.array import CenterSpatialCrop from monai.transforms.transform import Randomizable, Transform @@ -160,24 +160,24 @@ def __call__( affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: - affine_ = to_affine_nd(sr, affine) + affine_ = 
monai.data.utils.to_affine_nd(sr, affine) out_d = self.pixdim[:sr] if out_d.size < sr: out_d = np.append(out_d, [1.0] * (out_d.size - sr)) if np.any(out_d <= 0): raise ValueError(f"pixdim must be positive, got {out_d}.") # compute output affine, shape and offset - new_affine = zoom_affine(affine_, out_d, diagonal=self.diagonal) - output_shape, offset = compute_shape_offset(data_array.shape[1:], affine_, new_affine) + new_affine = monai.data.utils.zoom_affine(affine_, out_d, diagonal=self.diagonal) + output_shape, offset = monai.data.utils.compute_shape_offset(data_array.shape[1:], affine_, new_affine) new_affine[:sr, -1] = offset[:sr] transform = np.linalg.inv(affine_) @ new_affine # adapt to the actual rank - transform = to_affine_nd(sr, transform) + transform = monai.data.utils.to_affine_nd(sr, transform) # no resampling if it's identity transform if np.allclose(transform, np.diag(np.ones(len(transform))), atol=1e-3): output_data = data_array.copy().astype(np.float32) - new_affine = to_affine_nd(affine, new_affine) + new_affine = monai.data.utils.to_affine_nd(affine, new_affine) return output_data, affine, new_affine # resample @@ -195,7 +195,7 @@ def __call__( spatial_size=output_shape, ) output_data = np.asarray(output_data.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # type: ignore - new_affine = to_affine_nd(affine, new_affine) + new_affine = monai.data.utils.to_affine_nd(affine, new_affine) return output_data, affine, new_affine @@ -261,7 +261,7 @@ def __call__( affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: - affine_ = to_affine_nd(sr, affine) + affine_ = monai.data.utils.to_affine_nd(sr, affine) src = nib.io_orientation(affine_) if self.as_closest_canonical: spatial_ornt = src @@ -280,7 +280,7 @@ def __call__( shape = data_array.shape[1:] data_array = np.ascontiguousarray(nib.orientations.apply_orientation(data_array, ornt)) new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, shape) - new_affine = to_affine_nd(affine, new_affine) + new_affine = monai.data.utils.to_affine_nd(affine, new_affine) return data_array, affine, new_affine From ebee3c572c31211fcf3fa68df5eb3b4154da5276 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 19 Feb 2021 10:22:39 +0000 Subject: [PATCH 66/80] batch inverse improvement Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/inverse_batch_transform.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index 71f9166fad..18ccc89193 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -33,7 +33,7 @@ class BatchInverseTransform: """something""" def __init__( - self, transform: InvertibleTransform, loader: TorchDataLoader, collate_fn: Callable = None + self, transform: InvertibleTransform, loader: TorchDataLoader, collate_fn: Optional[Callable] = None ) -> None: """ Args: @@ -54,7 +54,7 @@ def __call__(self, data): try: return next(iter(inv_loader)) except RuntimeError as re: - re = str(re) - if "stack expects each tensor to be equal size" in re: - re += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`." - raise RuntimeError(re) + re_str = str(re) + if "stack expects each tensor to be equal size" in re_str: + re_str += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`." 
+ raise RuntimeError(re_str) From 0f955d3b835be065227a0218beed279c7875cf92 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 19 Feb 2021 10:23:05 +0000 Subject: [PATCH 67/80] bug fix in inverse RandAffined Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index dee37265a1..db8c27142a 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -543,7 +543,7 @@ def __init__( MapTransform.__init__(self, keys) Randomizable.__init__(self, prob) self.rand_affine = RandAffine( - prob=prob, + prob=1.0, # because probability handled in this class rotate_range=rotate_range, shear_range=shear_range, translate_range=translate_range, @@ -563,6 +563,7 @@ def set_random_state( return self def randomize(self, data: Optional[Any] = None) -> None: + self._do_transform = self.R.rand() < self.prob self.rand_affine.randomize() def __call__( @@ -572,7 +573,7 @@ def __call__( self.randomize() sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:]) - if self.rand_affine._do_transform: + if self._do_transform: grid, affine = self.rand_affine.rand_affine_grid(spatial_size=sp_size, return_affine=True) else: grid = create_grid(spatial_size=sp_size) @@ -597,10 +598,12 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar grid: torch.Tensor = affine_grid(orig_size) # type: ignore # Apply inverse transform - out = self.rand_affine.resampler(d[key], grid, self.mode[idx], self.padding_mode[idx]) - # Convert to original output type - if isinstance(out, torch.Tensor): - d[key] = out.cpu().numpy() + d[key] = self.rand_affine.resampler(torch.Tensor(d[key]), grid, self.mode[idx], self.padding_mode[idx]) + + # Convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = d[key].cpu().numpy() + # Remove the applied transform self.remove_most_recent_transform(d, key) From 917bb17d753dafc648c9d447bbd7fcd1c499241c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 19 Feb 2021 10:24:08 +0000 Subject: [PATCH 68/80] test for pad_collation Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_pad_collation.py | 63 +++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 tests/test_pad_collation.py diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py new file mode 100644 index 0000000000..f91f9d4b97 --- /dev/null +++ b/tests/test_pad_collation.py @@ -0,0 +1,63 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
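(On the inverses above: the reverse resampling grid is built from np.linalg.inv of the affine recorded during the forward pass, so forward followed by inverse is the identity mapping up to interpolation and padding effects. A numpy-only toy check of that matrix fact, independent of MONAI; the matrix values are arbitrary.)

import numpy as np

fwd_affine = np.array([[0.9, -0.2,  5.0],
                       [0.1,  1.1, -3.0],
                       [0.0,  0.0,  1.0]])
inv_affine = np.linalg.inv(fwd_affine)

point = np.array([12.0, -4.0, 1.0])  # homogeneous 2D point
assert np.allclose(inv_affine @ fwd_affine, np.eye(3))
assert np.allclose(inv_affine @ (fwd_affine @ point), point)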
+ +from monai.data.utils import pad_list_data_collate +import unittest +from typing import List, Tuple + +import numpy as np +from parameterized import parameterized + +from monai.transforms import ( + RandRotated, + RandSpatialCropd, + RandZoomd, + RandRotate90d, +) +from monai.utils import set_determinism + +set_determinism(seed=0) + +import numpy as np +from monai.data import CacheDataset, DataLoader +from monai.transforms import RandSpatialCropd, RandRotated + +TESTS: List[Tuple] = [] + +TESTS.append((RandSpatialCropd("image", roi_size=[8, 7], random_size=True),)) +TESTS.append((RandRotated("image", prob=1, range_x=np.pi, keep_size=False),)) +TESTS.append((RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False),)) +TESTS.append((RandRotate90d("image", prob=1, max_k=2),)) + +class TestPadCollation(unittest.TestCase): + def setUp(self) -> None: + # image is non square to throw rotation errors + im = np.arange(0, 10 * 9).reshape(1, 10, 9) + self.data = [{"image": im} for _ in range(2)] + + @parameterized.expand(TESTS) + def test_pad_collation(self, transform): + + dataset = CacheDataset(self.data, transform, progress=False) + + # Default collation should raise an error + loader_fail = DataLoader(dataset, batch_size=2) + with self.assertRaises(RuntimeError): + for _ in loader_fail: + pass + + # Padded collation shouldn't + loader = DataLoader(dataset, batch_size=2, collate_fn=pad_list_data_collate) + for _ in loader: + pass + +if __name__ == "__main__": + unittest.main() From 8de158f4f7686bc3164bc2415a19af9cab804a8d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 19 Feb 2021 10:24:08 +0000 Subject: [PATCH 69/80] Revert "test for pad_collation" This reverts commit 917bb17d753dafc648c9d447bbd7fcd1c499241c. Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_pad_collation.py | 63 ------------------------------------- 1 file changed, 63 deletions(-) delete mode 100644 tests/test_pad_collation.py diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py deleted file mode 100644 index f91f9d4b97..0000000000 --- a/tests/test_pad_collation.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from monai.data.utils import pad_list_data_collate -import unittest -from typing import List, Tuple - -import numpy as np -from parameterized import parameterized - -from monai.transforms import ( - RandRotated, - RandSpatialCropd, - RandZoomd, - RandRotate90d, -) -from monai.utils import set_determinism - -set_determinism(seed=0) - -import numpy as np -from monai.data import CacheDataset, DataLoader -from monai.transforms import RandSpatialCropd, RandRotated - -TESTS: List[Tuple] = [] - -TESTS.append((RandSpatialCropd("image", roi_size=[8, 7], random_size=True),)) -TESTS.append((RandRotated("image", prob=1, range_x=np.pi, keep_size=False),)) -TESTS.append((RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False),)) -TESTS.append((RandRotate90d("image", prob=1, max_k=2),)) - -class TestPadCollation(unittest.TestCase): - def setUp(self) -> None: - # image is non square to throw rotation errors - im = np.arange(0, 10 * 9).reshape(1, 10, 9) - self.data = [{"image": im} for _ in range(2)] - - @parameterized.expand(TESTS) - def test_pad_collation(self, transform): - - dataset = CacheDataset(self.data, transform, progress=False) - - # Default collation should raise an error - loader_fail = DataLoader(dataset, batch_size=2) - with self.assertRaises(RuntimeError): - for _ in loader_fail: - pass - - # Padded collation shouldn't - loader = DataLoader(dataset, batch_size=2, collate_fn=pad_list_data_collate) - for _ in loader: - pass - -if __name__ == "__main__": - unittest.main() From a3194bdad48295cf73448b37016290e74f16d65c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 19 Feb 2021 10:21:59 +0000 Subject: [PATCH 70/80] Revert "pad collate" This reverts commit a3055bf2f10d23c57ff111c2d3591d67ac517735. 
Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/__init__.py | 1 - monai/data/image_reader.py | 8 ++--- monai/data/inverse_batch_transform.py | 4 +-- monai/data/utils.py | 48 +------------------------- monai/transforms/croppad/array.py | 6 ++-- monai/transforms/croppad/dictionary.py | 6 ++-- monai/transforms/spatial/array.py | 18 +++++----- 7 files changed, 22 insertions(+), 69 deletions(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 380afb7773..d63b604ecd 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -58,5 +58,4 @@ to_affine_nd, worker_init_fn, zoom_affine, - pad_list_data_collate, ) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index a897fc670e..d0f5f4aefc 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -17,10 +17,10 @@ from torch.utils.data._utils.collate import np_str_obj_array_pattern from monai.config import DtypeLike, KeysCollection -import monai.data.utils +from monai.data.utils import correct_nifti_header_if_necessary from monai.utils import ensure_tuple, optional_import -import monai.data.utils +from .utils import is_supported_format if TYPE_CHECKING: import itk # type: ignore @@ -311,7 +311,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: """ suffixes: Sequence[str] = ["nii", "nii.gz"] - return has_nib and monai.data.is_supported_format(filename, suffixes) + return has_nib and is_supported_format(filename, suffixes) def read(self, data: Union[Sequence[str], str], **kwargs): """ @@ -332,7 +332,7 @@ def read(self, data: Union[Sequence[str], str], **kwargs): kwargs_.update(kwargs) for name in filenames: img = nib.load(name, **kwargs_) - img = monai.data.utils.correct_nifti_header_if_necessary(img) + img = correct_nifti_header_if_necessary(img) img_.append(img) return img_ if len(filenames) > 1 else img_[0] diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index 18ccc89193..219bc602e4 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -9,11 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, Optional +from typing import Callable from monai.data.dataloader import DataLoader from torch.utils.data.dataloader import DataLoader as TorchDataLoader from monai.data.dataset import Dataset -from monai.data.utils import decollate_batch +from monai.data.utils import decollate_batch, list_data_collate from monai.transforms.inverse_transform import InvertibleTransform __all__ = ["BatchInverseTransform"] diff --git a/monai/data/utils.py b/monai/data/utils.py index 233196fbd9..b6e0da8db2 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -24,7 +24,6 @@ import torch from torch.utils.data import DistributedSampler as _TorchDistributedSampler from torch.utils.data._utils.collate import default_collate -import monai.transforms.croppad.dictionary from monai.networks.layers.simplelayers import GaussianFilter from monai.utils import ( @@ -66,7 +65,6 @@ "pickle_hashing", "sorted_dict", "decollate_batch", - "pad_list_data_collate", ] @@ -244,51 +242,7 @@ def list_data_collate(batch: Sequence): """ elem = batch[0] data = [i for k in batch for i in k] if isinstance(elem, list) else batch - try: - return default_collate(data) - except RuntimeError as re: - re_str = str(re) - if "stack expects each tensor to be equal size" in re_str: - re_str += "\nMONAI hint: if your transforms intentionally create images of different shapes, creating your `DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its documentation)." - raise RuntimeError(re_str) - - -def pad_list_data_collate(batch: Sequence): - """ - Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest tensor in each dimension - - Note: - Need to use this collate if apply some transforms that can generate batch data. - - """ - for key in batch[0].keys(): - max_shapes = [] - for elem in batch: - if not isinstance(elem[key], (torch.Tensor, np.ndarray)): - break - max_shapes.append(elem[key].shape[1:]) - # len > 0 if objects were arrays - if len(max_shapes) == 0: - continue - max_shape = np.array(max_shapes).max(axis=0) - # If all same size, skip - if np.all(np.array(max_shapes).min(axis=0) == max_shape): - continue - # Do we need to convert output to Tensor? 
- output_to_tensor = isinstance(batch[0][key], torch.Tensor) - - # Use `SpatialPadd` to match sizes - # Default params are central padding, padding with 0's - # Use the dictionary version so that the transformation is recorded - padder = monai.transforms.croppad.dictionary.SpatialPadd(key, max_shape) # type: ignore - for idx in range(len(batch)): - batch[idx][key] = padder(batch[idx])[key] - if output_to_tensor: - batch[idx][key] = torch.Tensor(batch[idx][key]) - - # After padding, use default list collator - return list_data_collate(batch) - + return default_collate(data) def decollate_batch(data: dict, batch_size: Optional[int] = None): diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index e7b7f04138..ef5e0019bd 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -19,7 +19,7 @@ import torch from monai.config import IndexSelection -import monai.data.utils +from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.transform import Randomizable, Transform from monai.transforms.utils import ( generate_pos_neg_label_crop_centers, @@ -304,8 +304,8 @@ def randomize(self, img_size: Sequence[int]) -> None: if self.random_size: self._size = tuple((self.R.randint(low=self._size[i], high=img_size[i] + 1) for i in range(len(img_size)))) if self.random_center: - valid_size = monai.data.utils.get_valid_patch_size(img_size, self._size) - self._slices = (slice(None),) + monai.data.utils.get_random_patch(img_size, valid_size, self.R) + valid_size = get_valid_patch_size(img_size, self._size) + self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) def __call__(self, img: np.ndarray): """ diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 4bf2d2080a..3dd7667856 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -22,7 +22,7 @@ import numpy as np from monai.config import IndexSelection, KeysCollection -import monai.data.utils +from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.croppad.array import ( BorderPad, BoundingRect, @@ -412,8 +412,8 @@ def randomize(self, img_size: Sequence[int]) -> None: if self.random_size: self._size = [self.R.randint(low=self._size[i], high=img_size[i] + 1) for i in range(len(img_size))] if self.random_center: - valid_size = monai.data.utils.get_valid_patch_size(img_size, self._size) - self._slices = (slice(None),) + monai.data.utils.get_random_patch(img_size, valid_size, self.R) + valid_size = get_valid_patch_size(img_size, self._size) + self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) pass def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index d4f190125d..c229c1d85b 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -20,7 +20,7 @@ import torch from monai.config import USE_COMPILED, DtypeLike -import monai.data.utils +from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.transforms.croppad.array import CenterSpatialCrop from monai.transforms.transform import Randomizable, Transform @@ -160,24 +160,24 @@ def __call__( affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: - affine_ = 
monai.data.utils.to_affine_nd(sr, affine) + affine_ = to_affine_nd(sr, affine) out_d = self.pixdim[:sr] if out_d.size < sr: out_d = np.append(out_d, [1.0] * (out_d.size - sr)) if np.any(out_d <= 0): raise ValueError(f"pixdim must be positive, got {out_d}.") # compute output affine, shape and offset - new_affine = monai.data.utils.zoom_affine(affine_, out_d, diagonal=self.diagonal) - output_shape, offset = monai.data.utils.compute_shape_offset(data_array.shape[1:], affine_, new_affine) + new_affine = zoom_affine(affine_, out_d, diagonal=self.diagonal) + output_shape, offset = compute_shape_offset(data_array.shape[1:], affine_, new_affine) new_affine[:sr, -1] = offset[:sr] transform = np.linalg.inv(affine_) @ new_affine # adapt to the actual rank - transform = monai.data.utils.to_affine_nd(sr, transform) + transform = to_affine_nd(sr, transform) # no resampling if it's identity transform if np.allclose(transform, np.diag(np.ones(len(transform))), atol=1e-3): output_data = data_array.copy().astype(np.float32) - new_affine = monai.data.utils.to_affine_nd(affine, new_affine) + new_affine = to_affine_nd(affine, new_affine) return output_data, affine, new_affine # resample @@ -195,7 +195,7 @@ def __call__( spatial_size=output_shape, ) output_data = np.asarray(output_data.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # type: ignore - new_affine = monai.data.utils.to_affine_nd(affine, new_affine) + new_affine = to_affine_nd(affine, new_affine) return output_data, affine, new_affine @@ -261,7 +261,7 @@ def __call__( affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: - affine_ = monai.data.utils.to_affine_nd(sr, affine) + affine_ = to_affine_nd(sr, affine) src = nib.io_orientation(affine_) if self.as_closest_canonical: spatial_ornt = src @@ -280,7 +280,7 @@ def __call__( shape = data_array.shape[1:] data_array = np.ascontiguousarray(nib.orientations.apply_orientation(data_array, ornt)) new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, shape) - new_affine = monai.data.utils.to_affine_nd(affine, new_affine) + new_affine = to_affine_nd(affine, new_affine) return data_array, affine, new_affine From e4d6f00a4417686648a0169417eeed8203dad586 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Fri, 19 Feb 2021 10:26:59 +0000 Subject: [PATCH 71/80] pad_collation Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/__init__.py | 1 + monai/data/image_reader.py | 8 ++-- monai/data/inverse_batch_transform.py | 4 +- monai/data/utils.py | 48 +++++++++++++++++++- monai/transforms/croppad/array.py | 6 +-- monai/transforms/croppad/dictionary.py | 6 +-- monai/transforms/spatial/array.py | 18 ++++---- tests/test_pad_collation.py | 63 ++++++++++++++++++++++++++ 8 files changed, 132 insertions(+), 22 deletions(-) create mode 100644 tests/test_pad_collation.py diff --git a/monai/data/__init__.py b/monai/data/__init__.py index d63b604ecd..380afb7773 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -58,4 +58,5 @@ to_affine_nd, worker_init_fn, zoom_affine, + pad_list_data_collate, ) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index d0f5f4aefc..a897fc670e 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -17,10 +17,10 @@ from torch.utils.data._utils.collate import np_str_obj_array_pattern from monai.config import DtypeLike, KeysCollection -from monai.data.utils import correct_nifti_header_if_necessary +import monai.data.utils 
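(The switch here from `from monai.data.utils import ...` to a bare `import monai.data.utils` defers attribute lookup to call time, the usual way to break an import cycle, presumably the one created by pad_list_data_collate using SpatialPadd from monai.transforms. A generic sketch of the pattern; the package and function names below are made up for illustration.)

# hypothetical pkg/a.py
import pkg.b  # module import only: nothing is looked up on pkg.b at import time

def helper():
    # the attribute is resolved when helper() runs, after pkg.b has finished importing
    return pkg.b.do_work()

# the eager form below can fail part-way through a circular import,
# because do_work may not exist yet when pkg/a.py is first executed:
# from pkg.b import do_work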
from monai.utils import ensure_tuple, optional_import -from .utils import is_supported_format +import monai.data.utils if TYPE_CHECKING: import itk # type: ignore @@ -311,7 +311,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: """ suffixes: Sequence[str] = ["nii", "nii.gz"] - return has_nib and is_supported_format(filename, suffixes) + return has_nib and monai.data.is_supported_format(filename, suffixes) def read(self, data: Union[Sequence[str], str], **kwargs): """ @@ -332,7 +332,7 @@ def read(self, data: Union[Sequence[str], str], **kwargs): kwargs_.update(kwargs) for name in filenames: img = nib.load(name, **kwargs_) - img = correct_nifti_header_if_necessary(img) + img = monai.data.utils.correct_nifti_header_if_necessary(img) img_.append(img) return img_ if len(filenames) > 1 else img_[0] diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index 219bc602e4..18ccc89193 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -9,11 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable +from typing import Callable, Optional from monai.data.dataloader import DataLoader from torch.utils.data.dataloader import DataLoader as TorchDataLoader from monai.data.dataset import Dataset -from monai.data.utils import decollate_batch, list_data_collate +from monai.data.utils import decollate_batch from monai.transforms.inverse_transform import InvertibleTransform __all__ = ["BatchInverseTransform"] diff --git a/monai/data/utils.py b/monai/data/utils.py index b6e0da8db2..233196fbd9 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -24,6 +24,7 @@ import torch from torch.utils.data import DistributedSampler as _TorchDistributedSampler from torch.utils.data._utils.collate import default_collate +import monai.transforms.croppad.dictionary from monai.networks.layers.simplelayers import GaussianFilter from monai.utils import ( @@ -65,6 +66,7 @@ "pickle_hashing", "sorted_dict", "decollate_batch", + "pad_list_data_collate", ] @@ -242,7 +244,51 @@ def list_data_collate(batch: Sequence): """ elem = batch[0] data = [i for k in batch for i in k] if isinstance(elem, list) else batch - return default_collate(data) + try: + return default_collate(data) + except RuntimeError as re: + re_str = str(re) + if "stack expects each tensor to be equal size" in re_str: + re_str += "\nMONAI hint: if your transforms intentionally create images of different shapes, creating your `DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its documentation)." + raise RuntimeError(re_str) + + +def pad_list_data_collate(batch: Sequence): + """ + Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest tensor in each dimension + + Note: + Need to use this collate if apply some transforms that can generate batch data. + + """ + for key in batch[0].keys(): + max_shapes = [] + for elem in batch: + if not isinstance(elem[key], (torch.Tensor, np.ndarray)): + break + max_shapes.append(elem[key].shape[1:]) + # len > 0 if objects were arrays + if len(max_shapes) == 0: + continue + max_shape = np.array(max_shapes).max(axis=0) + # If all same size, skip + if np.all(np.array(max_shapes).min(axis=0) == max_shape): + continue + # Do we need to convert output to Tensor? 
+ output_to_tensor = isinstance(batch[0][key], torch.Tensor) + + # Use `SpatialPadd` to match sizes + # Default params are central padding, padding with 0's + # Use the dictionary version so that the transformation is recorded + padder = monai.transforms.croppad.dictionary.SpatialPadd(key, max_shape) # type: ignore + for idx in range(len(batch)): + batch[idx][key] = padder(batch[idx])[key] + if output_to_tensor: + batch[idx][key] = torch.Tensor(batch[idx][key]) + + # After padding, use default list collator + return list_data_collate(batch) + def decollate_batch(data: dict, batch_size: Optional[int] = None): diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index ef5e0019bd..e7b7f04138 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -19,7 +19,7 @@ import torch from monai.config import IndexSelection -from monai.data.utils import get_random_patch, get_valid_patch_size +import monai.data.utils from monai.transforms.transform import Randomizable, Transform from monai.transforms.utils import ( generate_pos_neg_label_crop_centers, @@ -304,8 +304,8 @@ def randomize(self, img_size: Sequence[int]) -> None: if self.random_size: self._size = tuple((self.R.randint(low=self._size[i], high=img_size[i] + 1) for i in range(len(img_size)))) if self.random_center: - valid_size = get_valid_patch_size(img_size, self._size) - self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) + valid_size = monai.data.utils.get_valid_patch_size(img_size, self._size) + self._slices = (slice(None),) + monai.data.utils.get_random_patch(img_size, valid_size, self.R) def __call__(self, img: np.ndarray): """ diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 3dd7667856..4bf2d2080a 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -22,7 +22,7 @@ import numpy as np from monai.config import IndexSelection, KeysCollection -from monai.data.utils import get_random_patch, get_valid_patch_size +import monai.data.utils from monai.transforms.croppad.array import ( BorderPad, BoundingRect, @@ -412,8 +412,8 @@ def randomize(self, img_size: Sequence[int]) -> None: if self.random_size: self._size = [self.R.randint(low=self._size[i], high=img_size[i] + 1) for i in range(len(img_size))] if self.random_center: - valid_size = get_valid_patch_size(img_size, self._size) - self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) + valid_size = monai.data.utils.get_valid_patch_size(img_size, self._size) + self._slices = (slice(None),) + monai.data.utils.get_random_patch(img_size, valid_size, self.R) pass def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index c229c1d85b..d4f190125d 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -20,7 +20,7 @@ import torch from monai.config import USE_COMPILED, DtypeLike -from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine +import monai.data.utils from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.transforms.croppad.array import CenterSpatialCrop from monai.transforms.transform import Randomizable, Transform @@ -160,24 +160,24 @@ def __call__( affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: - affine_ = to_affine_nd(sr, affine) + affine_ = 
monai.data.utils.to_affine_nd(sr, affine) out_d = self.pixdim[:sr] if out_d.size < sr: out_d = np.append(out_d, [1.0] * (out_d.size - sr)) if np.any(out_d <= 0): raise ValueError(f"pixdim must be positive, got {out_d}.") # compute output affine, shape and offset - new_affine = zoom_affine(affine_, out_d, diagonal=self.diagonal) - output_shape, offset = compute_shape_offset(data_array.shape[1:], affine_, new_affine) + new_affine = monai.data.utils.zoom_affine(affine_, out_d, diagonal=self.diagonal) + output_shape, offset = monai.data.utils.compute_shape_offset(data_array.shape[1:], affine_, new_affine) new_affine[:sr, -1] = offset[:sr] transform = np.linalg.inv(affine_) @ new_affine # adapt to the actual rank - transform = to_affine_nd(sr, transform) + transform = monai.data.utils.to_affine_nd(sr, transform) # no resampling if it's identity transform if np.allclose(transform, np.diag(np.ones(len(transform))), atol=1e-3): output_data = data_array.copy().astype(np.float32) - new_affine = to_affine_nd(affine, new_affine) + new_affine = monai.data.utils.to_affine_nd(affine, new_affine) return output_data, affine, new_affine # resample @@ -195,7 +195,7 @@ def __call__( spatial_size=output_shape, ) output_data = np.asarray(output_data.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # type: ignore - new_affine = to_affine_nd(affine, new_affine) + new_affine = monai.data.utils.to_affine_nd(affine, new_affine) return output_data, affine, new_affine @@ -261,7 +261,7 @@ def __call__( affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: - affine_ = to_affine_nd(sr, affine) + affine_ = monai.data.utils.to_affine_nd(sr, affine) src = nib.io_orientation(affine_) if self.as_closest_canonical: spatial_ornt = src @@ -280,7 +280,7 @@ def __call__( shape = data_array.shape[1:] data_array = np.ascontiguousarray(nib.orientations.apply_orientation(data_array, ornt)) new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, shape) - new_affine = to_affine_nd(affine, new_affine) + new_affine = monai.data.utils.to_affine_nd(affine, new_affine) return data_array, affine, new_affine diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py new file mode 100644 index 0000000000..f91f9d4b97 --- /dev/null +++ b/tests/test_pad_collation.py @@ -0,0 +1,63 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from monai.data.utils import pad_list_data_collate +import unittest +from typing import List, Tuple + +import numpy as np +from parameterized import parameterized + +from monai.transforms import ( + RandRotated, + RandSpatialCropd, + RandZoomd, + RandRotate90d, +) +from monai.utils import set_determinism + +set_determinism(seed=0) + +import numpy as np +from monai.data import CacheDataset, DataLoader +from monai.transforms import RandSpatialCropd, RandRotated + +TESTS: List[Tuple] = [] + +TESTS.append((RandSpatialCropd("image", roi_size=[8, 7], random_size=True),)) +TESTS.append((RandRotated("image", prob=1, range_x=np.pi, keep_size=False),)) +TESTS.append((RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False),)) +TESTS.append((RandRotate90d("image", prob=1, max_k=2),)) + +class TestPadCollation(unittest.TestCase): + def setUp(self) -> None: + # image is non square to throw rotation errors + im = np.arange(0, 10 * 9).reshape(1, 10, 9) + self.data = [{"image": im} for _ in range(2)] + + @parameterized.expand(TESTS) + def test_pad_collation(self, transform): + + dataset = CacheDataset(self.data, transform, progress=False) + + # Default collation should raise an error + loader_fail = DataLoader(dataset, batch_size=2) + with self.assertRaises(RuntimeError): + for _ in loader_fail: + pass + + # Padded collation shouldn't + loader = DataLoader(dataset, batch_size=2, collate_fn=pad_list_data_collate) + for _ in loader: + pass + +if __name__ == "__main__": + unittest.main() From 15bbf9ad4f57fb49edefdf2ec091808c3aae7062 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 19 Feb 2021 11:43:27 +0000 Subject: [PATCH 72/80] codeformate Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/__init__.py | 2 +- monai/data/image_reader.py | 8 +++----- monai/data/inverse_batch_transform.py | 19 +++++++++++------- monai/data/utils.py | 12 ++++++++---- monai/transforms/croppad/array.py | 2 +- monai/transforms/croppad/dictionary.py | 2 +- monai/transforms/inverse_transform.py | 3 +-- monai/transforms/spatial/array.py | 2 +- monai/transforms/spatial/dictionary.py | 27 ++++++++++++++------------ tests/test_inverse.py | 26 ++++++++++++------------- tests/test_pad_collation.py | 13 ++++--------- 11 files changed, 59 insertions(+), 57 deletions(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 380afb7773..679f88c2ab 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -48,6 +48,7 @@ iter_patch_slices, json_hashing, list_data_collate, + pad_list_data_collate, partition_dataset, partition_dataset_classes, pickle_hashing, @@ -58,5 +59,4 @@ to_affine_nd, worker_init_fn, zoom_affine, - pad_list_data_collate, ) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index a897fc670e..2d3e06f3c2 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -16,12 +16,10 @@ import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern -from monai.config import DtypeLike, KeysCollection import monai.data.utils +from monai.config import DtypeLike, KeysCollection from monai.utils import ensure_tuple, optional_import -import monai.data.utils - if TYPE_CHECKING: import itk # type: ignore import nibabel as nib @@ -442,7 +440,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: if a list of files, verify all the suffixes. 
""" suffixes: Sequence[str] = ["npz", "npy"] - return is_supported_format(filename, suffixes) + return monai.data.is_supported_format(filename, suffixes) def read(self, data: Union[Sequence[str], str], **kwargs): """ @@ -526,7 +524,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: if a list of files, verify all the suffixes. """ suffixes: Sequence[str] = ["png", "jpg", "jpeg", "bmp"] - return has_pil and is_supported_format(filename, suffixes) + return has_pil and monai.data.is_supported_format(filename, suffixes) def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs): """ diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index 18ccc89193..426322e829 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -10,8 +10,10 @@ # limitations under the License. from typing import Callable, Optional -from monai.data.dataloader import DataLoader + from torch.utils.data.dataloader import DataLoader as TorchDataLoader + +from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset from monai.data.utils import decollate_batch from monai.transforms.inverse_transform import InvertibleTransform @@ -22,11 +24,11 @@ class _BatchInverseDataset(Dataset): def __init__(self, data, transform: InvertibleTransform) -> None: self.data = decollate_batch(data) - self.transform = transform + self.invertible_transform = transform def __getitem__(self, index: int): data = self.data[index] - return self.transform.inverse(data) + return self.invertible_transform.inverse(data) class BatchInverseTransform: @@ -39,9 +41,10 @@ def __init__( Args: transform: a callable data transform on input data. loader: data loader used to generate the batch of data. - collate_fn: how to collate data after inverse transformations. Default will use the DataLoader's default collation method. - If returning images of different sizes, this will likely create an error (since the collation will concatenate arrays, - requiring them to be the same size). In this case, using `collate_fn=lambda x: x` might solve the problem. + collate_fn: how to collate data after inverse transformations. Default will use the DataLoader's default + collation method. If returning images of different sizes, this will likely create an error (since the + collation will concatenate arrays, requiring them to be the same size). In this case, using + `collate_fn=lambda x: x` might solve the problem. 
""" self.transform = transform self.batch_size = loader.batch_size @@ -50,7 +53,9 @@ def __init__( def __call__(self, data): inv_ds = _BatchInverseDataset(data, self.transform) - inv_loader = DataLoader(inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn) + inv_loader = DataLoader( + inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn + ) try: return next(iter(inv_loader)) except RuntimeError as re: diff --git a/monai/data/utils.py b/monai/data/utils.py index 233196fbd9..26923a650d 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -24,8 +24,8 @@ import torch from torch.utils.data import DistributedSampler as _TorchDistributedSampler from torch.utils.data._utils.collate import default_collate -import monai.transforms.croppad.dictionary +import monai.transforms.croppad.dictionary from monai.networks.layers.simplelayers import GaussianFilter from monai.utils import ( MAX_SEED, @@ -249,13 +249,18 @@ def list_data_collate(batch: Sequence): except RuntimeError as re: re_str = str(re) if "stack expects each tensor to be equal size" in re_str: - re_str += "\nMONAI hint: if your transforms intentionally create images of different shapes, creating your `DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its documentation)." + re_str += ( + "\nMONAI hint: if your transforms intentionally create images of different shapes, creating your " + + "`DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its " + + "documentation)." + ) raise RuntimeError(re_str) def pad_list_data_collate(batch: Sequence): """ - Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest tensor in each dimension + Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest + tensor in each dimension. Note: Need to use this collate if apply some transforms that can generate batch data. @@ -290,7 +295,6 @@ def pad_list_data_collate(batch: Sequence): return list_data_collate(batch) - def decollate_batch(data: dict, batch_size: Optional[int] = None): """De-collate a batch of data (for example, as produced by a `DataLoader`). 
diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index e7b7f04138..060065f81b 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -18,8 +18,8 @@ import numpy as np import torch -from monai.config import IndexSelection import monai.data.utils +from monai.config import IndexSelection from monai.transforms.transform import Randomizable, Transform from monai.transforms.utils import ( generate_pos_neg_label_crop_centers, diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 4bf2d2080a..1b48b8e684 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -21,8 +21,8 @@ import numpy as np -from monai.config import IndexSelection, KeysCollection import monai.data.utils +from monai.config import IndexSelection, KeysCollection from monai.transforms.croppad.array import ( BorderPad, BoundingRect, diff --git a/monai/transforms/inverse_transform.py b/monai/transforms/inverse_transform.py index 38b5dd9ce4..e5a9772d22 100644 --- a/monai/transforms/inverse_transform.py +++ b/monai/transforms/inverse_transform.py @@ -11,7 +11,6 @@ import warnings from abc import ABC -from itertools import chain from typing import Hashable, Optional, Tuple import numpy as np @@ -163,7 +162,7 @@ def _inv_disp_w_vtk(fwd_disp): # Convert back to numpy and reshape inv_disp = vtk_numpy_support.vtk_to_numpy(inv_disp_vtk.GetPointData().GetArray(0)) # if there were originally < 3 tensor components, remove the zeros we added at the start - inv_disp = inv_disp[..., :orig_shape[-1]] + inv_disp = inv_disp[..., : orig_shape[-1]] # reshape to original inv_disp = inv_disp.reshape(orig_shape) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index d4f190125d..62462b8f8d 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -19,8 +19,8 @@ import numpy as np import torch -from monai.config import USE_COMPILED, DtypeLike import monai.data.utils +from monai.config import USE_COMPILED, DtypeLike from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.transforms.croppad.array import CenterSpatialCrop from monai.transforms.transform import Randomizable, Transform diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index db8c27142a..ec9b0188b4 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -598,11 +598,10 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar grid: torch.Tensor = affine_grid(orig_size) # type: ignore # Apply inverse transform - d[key] = self.rand_affine.resampler(torch.Tensor(d[key]), grid, self.mode[idx], self.padding_mode[idx]) + out = self.rand_affine.resampler(d[key], grid, self.mode[idx], self.padding_mode[idx]) # Convert to numpy - if isinstance(d[key], torch.Tensor): - d[key] = d[key].cpu().numpy() + d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy() # Remove the applied transform self.remove_most_recent_transform(d, key) @@ -722,7 +721,7 @@ def __call__( cpg = self.rand_2d_elastic.deform_grid(spatial_size=sp_size) cpg_w_affine, affine = self.rand_2d_elastic.rand_affine_grid(grid=cpg, return_affine=True) grid = self.cpg_to_dvf(cpg_w_affine, self.rand_2d_elastic.deform_grid.spacing, sp_size) - extra_info = {"cpg": deepcopy(cpg), "affine": deepcopy(affine)} + extra_info: Optional[Dict] = {"cpg": deepcopy(cpg), "affine": 
deepcopy(affine)} else: grid = create_grid(spatial_size=sp_size) extra_info = None @@ -737,7 +736,7 @@ def __call__( def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) # This variable will be `not None` if vtk or sitk is present - inv_def_w_affine = None + inv_def_no_affine = None for idx, key in enumerate(self.keys): transform = self.get_most_recent_transform(d, key) @@ -751,13 +750,15 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar inv_def_no_affine = create_grid(spatial_size=orig_size) else: fwd_cpg_no_affine = transform["extra_info"]["cpg"] - fwd_def_no_affine = self.cpg_to_dvf(fwd_cpg_no_affine, self.rand_2d_elastic.deform_grid.spacing, orig_size) + fwd_def_no_affine = self.cpg_to_dvf( + fwd_cpg_no_affine, self.rand_2d_elastic.deform_grid.spacing, orig_size + ) inv_def_no_affine = self.compute_inverse_deformation(len(orig_size), fwd_def_no_affine) # if inverse did not succeed (sitk or vtk present), data will not be changed. if inv_def_no_affine is not None: fwd_affine = transform["extra_info"]["affine"] inv_affine = np.linalg.inv(fwd_affine) - inv_def_w_affine = AffineGrid(affine=inv_affine)(grid=inv_def_no_affine) + inv_def_w_affine: np.ndarray = AffineGrid(affine=inv_affine, as_tensor_output=False)(grid=inv_def_no_affine) # type: ignore # Back to original size inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine) # Apply inverse transform @@ -878,7 +879,9 @@ def __call__( grid_w_affine, affine = self.rand_3d_elastic.rand_affine_grid(grid=grid_no_affine, return_affine=True) for idx, key in enumerate(self.keys): - self.append_applied_transforms(d, key, extra_info={"grid_no_affine": grid_no_affine.cpu().numpy(), "affine": affine}) + self.append_applied_transforms( + d, key, extra_info={"grid_no_affine": grid_no_affine.cpu().numpy(), "affine": affine} + ) d[key] = self.rand_3d_elastic.resampler( d[key], grid_w_affine, mode=self.mode[idx], padding_mode=self.padding_mode[idx] ) @@ -886,8 +889,6 @@ def __call__( def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - # This variable will be `not None` if vtk or sitk is present - inv_def_w_affine = None for idx, key in enumerate(self.keys): transform = self.get_most_recent_transform(d, key) @@ -902,12 +903,14 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar if inv_def_no_affine is not None: fwd_affine = transform["extra_info"]["affine"] inv_affine = np.linalg.inv(fwd_affine) - inv_def_w_affine = AffineGrid(affine=inv_affine)(grid=inv_def_no_affine) + inv_def_w_affine: np.ndarray = AffineGrid(affine=inv_affine, as_tensor_output=False)(grid=inv_def_no_affine) # type: ignore # Back to original size inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine) # Apply inverse transform if inv_def_w_affine is not None: - out = self.rand_3d_elastic.resampler(d[key], inv_def_w_affine, self.mode[idx], self.padding_mode[idx]) + out = self.rand_3d_elastic.resampler( + d[key], inv_def_w_affine, self.mode[idx], self.padding_mode[idx] + ) d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out else: d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 6e41791c2d..fba876fffe 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -370,18 +370,18 @@ ( "Rand3DElasticd 3d", DATA_3D, - 1e-1, - Rand3DElasticd( - KEYS, - sigma_range=(3, 5), - 
magnitude_range=(100, 100), - prob=1, - padding_mode="zeros", - rotate_range=[np.pi / 6, np.pi / 7], - shear_range=[(0.5, 0.5), 0.2], - translate_range=[10, 5, 3], - scale_range=[(0.8, 1.2), (0.9, 1.3)], - ), + 1e-1, + Rand3DElasticd( + KEYS, + sigma_range=(3, 5), + magnitude_range=(100, 100), + prob=1, + padding_mode="zeros", + rotate_range=[np.pi / 6, np.pi / 7], + shear_range=[(0.5, 0.5), 0.2], + translate_range=[10, 5, 3], + scale_range=[(0.8, 1.2), (0.9, 1.3)], + ), ) ) @@ -518,8 +518,6 @@ def test_diff_sized_inputs(self): unmodified = test_data[batch_idx * batch_size + idx] self.check_inverse("diff_sized_inputs", [key], unmodified, _fwd_bck, _fwd, 0) - - @parameterized.expand(TESTS_FAIL) def test_fail(self, data, _, *transform): d = transform[0](data) diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py index f91f9d4b97..10ca1dec01 100644 --- a/tests/test_pad_collation.py +++ b/tests/test_pad_collation.py @@ -9,26 +9,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from monai.data.utils import pad_list_data_collate import unittest from typing import List, Tuple import numpy as np from parameterized import parameterized -from monai.transforms import ( - RandRotated, - RandSpatialCropd, - RandZoomd, - RandRotate90d, -) +from monai.data.utils import pad_list_data_collate +from monai.transforms import RandRotate90d, RandRotated, RandSpatialCropd, RandZoomd from monai.utils import set_determinism set_determinism(seed=0) -import numpy as np from monai.data import CacheDataset, DataLoader -from monai.transforms import RandSpatialCropd, RandRotated TESTS: List[Tuple] = [] @@ -37,6 +30,7 @@ TESTS.append((RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False),)) TESTS.append((RandRotate90d("image", prob=1, max_k=2),)) + class TestPadCollation(unittest.TestCase): def setUp(self) -> None: # image is non square to throw rotation errors @@ -59,5 +53,6 @@ def test_pad_collation(self, transform): for _ in loader: pass + if __name__ == "__main__": unittest.main() From 595119f3665e9d1e534fd88b395550035ada3dd3 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 19 Feb 2021 16:45:32 +0000 Subject: [PATCH 73/80] Compose len Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/compose.py | 10 ++- monai/transforms/croppad/dictionary.py | 48 +++++++++----- monai/transforms/inverse_transform.py | 2 +- monai/transforms/spatial/dictionary.py | 92 +++++++++++++++++--------- tests/test_compose.py | 7 ++ 5 files changed, 109 insertions(+), 50 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index b3997f1197..b3e2968a86 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -14,7 +14,7 @@ import warnings from copy import deepcopy -from typing import Any, Callable, Mapping, Optional, Sequence, Union +from typing import Any, Callable, Hashable, Mapping, Optional, Sequence, Tuple, Union import numpy as np @@ -113,12 +113,16 @@ def randomize(self, data: Optional[Any] = None) -> None: f'Transform "{tfm_name}" in Compose not randomized\n{tfm_name}.{type_error}.', RuntimeWarning ) + def __len__(self): + """Return number of transformations.""" + return sum(len(t) if isinstance(t, Compose) else 1 for t in self.transforms) + def __call__(self, input_): for _transform in self.transforms: input_ = apply_transform(_transform, input_) return input_ - def inverse(self, data): + def 
inverse(self, data, keys: Optional[Tuple[Hashable, ...]] = None): if not isinstance(data, Mapping): raise RuntimeError("Inverse method only available for dictionary transforms") d = deepcopy(dict(data)) @@ -127,5 +131,5 @@ def inverse(self, data): for t in reversed(self.transforms): # check if transform is one of the invertible ones if isinstance(t, InvertibleTransform): - d = t.inverse(d) + d = t.inverse(d, keys) return d diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 1b48b8e684..0a615d14be 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -124,9 +124,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.padder(d[key], mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for key in keys or self.keys: transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = transform["orig_size"] @@ -190,10 +192,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.padder(d[key], mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for key in keys or self.keys: transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = np.array(transform["orig_size"]) @@ -251,10 +255,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.padder(d[key], mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for key in keys or self.keys: transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = np.array(transform["orig_size"]) @@ -304,10 +310,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.cropper(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for key in keys or self.keys: transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = transform["orig_size"] @@ -349,10 +357,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.append_applied_transforms(d, key, orig_size=orig_size) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for key in keys or self.keys: transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = np.array(transform["orig_size"]) @@ -431,10 
+441,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = cropper(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for key in keys or self.keys: transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = transform["orig_size"] @@ -574,9 +586,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = cropper(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for key in keys or self.keys: transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = np.array(transform["orig_size"]) @@ -801,9 +815,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.append_applied_transforms(d, key, orig_size=orig_size) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for key in keys or self.keys: transform = self.get_most_recent_transform(d, key) # Create inverse transform orig_size = transform["orig_size"] diff --git a/monai/transforms/inverse_transform.py b/monai/transforms/inverse_transform.py index e5a9772d22..8833fc610d 100644 --- a/monai/transforms/inverse_transform.py +++ b/monai/transforms/inverse_transform.py @@ -79,7 +79,7 @@ def remove_most_recent_transform(data: dict, key: Hashable) -> None: """Remove most recent transform.""" data[str(key) + "_transforms"].pop() - def inverse(self, data: dict): + def inverse(self, data: dict, keys: Optional[Tuple[Hashable, ...]] = None): """ Inverse of ``__call__``. 
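
A short illustration of the two `Compose` additions above (`__len__` and the optional `keys` argument to `inverse`), assuming the API in this patch; the data and transform chain are illustrative only:

    import numpy as np

    from monai.transforms import AddChanneld, CenterSpatialCropd, Compose, SpatialPadd

    KEYS = ("image", "label")
    transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), CenterSpatialCropd(KEYS, (110, 99))])
    print(len(transforms))  # 3: nested Compose instances are counted transform by transform

    data = {"image": np.random.rand(100, 101), "label": np.random.rand(100, 101)}
    fwd = transforms(data)  # invertible transforms append their record to "<key>_transforms"

    # invert only the "label" key (e.g. a prediction stored under that key);
    # "image" and its transform record are left untouched
    inv = transforms.inverse(fwd, keys="label")
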
diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index ec9b0188b4..a40281c75b 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -203,9 +203,11 @@ def __call__( meta_data["affine"] = new_affine return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for idx, key in enumerate(self.keys): + for idx, key in enumerate(keys or self.keys): transform = self.get_most_recent_transform(d, key) if self.spacing_transform.diagonal: raise RuntimeError( @@ -295,9 +297,11 @@ def __call__( d[meta_data_key]["affine"] = new_affine return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for key in keys or self.keys: transform = self.get_most_recent_transform(d, key) # Create inverse transform meta_data = d[transform["extra_info"]["meta_data_key"]] @@ -339,9 +343,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.rotator(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for key in keys or self.keys: _ = self.get_most_recent_transform(d, key) # Create inverse transform spatial_axes = self.rotator.spatial_axes @@ -407,9 +413,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. 
self.append_applied_transforms(d, key, extra_info={"rand_k": self._rand_k}) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for key in keys or self.keys: transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform["do_transform"]: @@ -468,9 +476,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.resizer(d[key], mode=self.mode[idx], align_corners=self.align_corners[idx]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for idx, key in enumerate(self.keys): + for idx, key in enumerate(keys or self.keys): transform = self.get_most_recent_transform(d, key) orig_size = transform["orig_size"] mode = self.mode[idx] @@ -584,10 +594,12 @@ def __call__( d[key] = self.rand_affine.resampler(d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for idx, key in enumerate(self.keys): + for idx, key in enumerate(keys or self.keys): transform = self.get_most_recent_transform(d, key) orig_size = transform["orig_size"] # Create inverse transform @@ -733,12 +745,14 @@ def __call__( ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) # This variable will be `not None` if vtk or sitk is present inv_def_no_affine = None - for idx, key in enumerate(self.keys): + for idx, key in enumerate(keys or self.keys): transform = self.get_most_recent_transform(d, key) # Create inverse transform if transform["do_transform"]: @@ -758,7 +772,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar if inv_def_no_affine is not None: fwd_affine = transform["extra_info"]["affine"] inv_affine = np.linalg.inv(fwd_affine) - inv_def_w_affine: np.ndarray = AffineGrid(affine=inv_affine, as_tensor_output=False)(grid=inv_def_no_affine) # type: ignore + inv_def_w_affine: np.ndarray = AffineGrid(affine=inv_affine, as_tensor_output=False)( + grid=inv_def_no_affine + ) # type: ignore # Back to original size inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine) # Apply inverse transform @@ -887,10 +903,12 @@ def __call__( ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for idx, key in enumerate(self.keys): + for idx, key in enumerate(keys or self.keys): transform = self.get_most_recent_transform(d, key) # Create inverse transform if transform["do_transform"]: @@ -903,7 +921,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) 
-> Dict[Hashable, np.ndar if inv_def_no_affine is not None: fwd_affine = transform["extra_info"]["affine"] inv_affine = np.linalg.inv(fwd_affine) - inv_def_w_affine: np.ndarray = AffineGrid(affine=inv_affine, as_tensor_output=False)(grid=inv_def_no_affine) # type: ignore + inv_def_w_affine: np.ndarray = AffineGrid(affine=inv_affine, as_tensor_output=False)( + grid=inv_def_no_affine + ) # type: ignore # Back to original size inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine) # Apply inverse transform @@ -943,9 +963,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.flipper(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for key in keys or self.keys: _ = self.get_most_recent_transform(d, key) # Might need to convert to numpy if isinstance(d[key], torch.Tensor): @@ -996,9 +1018,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for key in self.keys: + for key in keys or self.keys: transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform["do_transform"]: @@ -1073,9 +1097,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.append_applied_transforms(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for idx, key in enumerate(self.keys): + for idx, key in enumerate(keys or self.keys): transform = self.get_most_recent_transform(d, key) # Create inverse transform fwd_rot_mat = transform["extra_info"]["rot_mat"] @@ -1201,9 +1227,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.append_applied_transforms(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for idx, key in enumerate(self.keys): + for idx, key in enumerate(keys or self.keys): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform["do_transform"]: @@ -1282,9 +1310,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for idx, key in enumerate(self.keys): + for idx, key in enumerate(keys or self.keys): transform = self.get_most_recent_transform(d, key) # Create inverse 
transform zoom = np.array(self.zoomer.zoom) @@ -1392,9 +1422,11 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: d = deepcopy(dict(data)) - for idx, key in enumerate(self.keys): + for idx, key in enumerate(keys or self.keys): transform = self.get_most_recent_transform(d, key) # Create inverse transform zoom = np.array(transform["extra_info"]["zoom"]) diff --git a/tests/test_compose.py b/tests/test_compose.py index 3585b3453c..2103cdaa36 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -156,6 +156,13 @@ def test_data_loader_2(self): self.assertAlmostEqual(out_1.cpu().item(), 0.131966779) set_determinism(None) + def test_len(self): + x = AddChannel() + t = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])]) + + # test len + self.assertEqual(len(t), 8) + if __name__ == "__main__": unittest.main() From e2d63ae4ffb5698fd8f3028c8c51f704e5878455 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 19 Feb 2021 18:06:47 +0000 Subject: [PATCH 74/80] inverse batch and fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/inverse_batch_transform.py | 18 +++-- monai/transforms/compose.py | 4 +- monai/transforms/inverse_transform.py | 4 +- monai/transforms/spatial/dictionary.py | 6 +- tests/test_inverse.py | 95 +++++++++++++++++++------- 5 files changed, 90 insertions(+), 37 deletions(-) diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index 426322e829..49166b7fc1 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, Optional +from typing import Any, Callable, Dict, Hashable, Optional, Tuple from torch.utils.data.dataloader import DataLoader as TorchDataLoader @@ -17,18 +17,20 @@ from monai.data.dataset import Dataset from monai.data.utils import decollate_batch from monai.transforms.inverse_transform import InvertibleTransform +from monai.utils import first __all__ = ["BatchInverseTransform"] class _BatchInverseDataset(Dataset): - def __init__(self, data, transform: InvertibleTransform) -> None: + def __init__(self, data: Dict[str, Any], transform: InvertibleTransform, keys: Optional[Tuple[Hashable, ...]] = None) -> None: self.data = decollate_batch(data) self.invertible_transform = transform + self.keys = keys - def __getitem__(self, index: int): + def __getitem__(self, index: int) -> Dict[str, Any]: data = self.data[index] - return self.invertible_transform.inverse(data) + return self.invertible_transform.inverse(data, self.keys) class BatchInverseTransform: @@ -51,13 +53,15 @@ def __init__( self.num_workers = loader.num_workers self.collate_fn = collate_fn - def __call__(self, data): - inv_ds = _BatchInverseDataset(data, self.transform) + def __call__(self, data: Dict[str, Any], keys: Optional[Tuple[Hashable, ...]] = None) -> Dict[str, Any]: + + inv_ds = _BatchInverseDataset(data, self.transform, keys) inv_loader = DataLoader( inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn ) try: - return next(iter(inv_loader)) + # Only need to return first as only 1 batch of data + return first(inv_loader) # type: ignore except RuntimeError as re: re_str = str(re) if "stack expects each tensor to be equal size" in re_str: diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index b3e2968a86..70b306262d 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -14,7 +14,7 @@ import warnings from copy import deepcopy -from typing import Any, Callable, Hashable, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np @@ -126,6 +126,8 @@ def inverse(self, data, keys: Optional[Tuple[Hashable, ...]] = None): if not isinstance(data, Mapping): raise RuntimeError("Inverse method only available for dictionary transforms") d = deepcopy(dict(data)) + if keys: + keys = ensure_tuple(keys) # loop backwards over transforms for t in reversed(self.transforms): diff --git a/monai/transforms/inverse_transform.py b/monai/transforms/inverse_transform.py index 8833fc610d..8b368264cc 100644 --- a/monai/transforms/inverse_transform.py +++ b/monai/transforms/inverse_transform.py @@ -11,7 +11,7 @@ import warnings from abc import ABC -from typing import Hashable, Optional, Tuple +from typing import Any, Dict, Hashable, Optional, Tuple import numpy as np import torch @@ -79,7 +79,7 @@ def remove_most_recent_transform(data: dict, key: Hashable) -> None: """Remove most recent transform.""" data[str(key) + "_transforms"].pop() - def inverse(self, data: dict, keys: Optional[Tuple[Hashable, ...]] = None): + def inverse(self, data: dict, keys: Optional[Tuple[Hashable, ...]] = None) -> Dict[str, Any]: """ Inverse of ``__call__``. 
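
To make the intent of `BatchInverseTransform` concrete, here is a minimal sketch of inverting a collated batch back to the pre-transform space, assuming the API added in this patch (data, keys and sizes are illustrative):

    from monai.data import BatchInverseTransform, CacheDataset, DataLoader, create_test_image_2d
    from monai.transforms import AddChanneld, Compose, SpatialPadd

    KEYS = ("image",)
    transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153))])
    data = [{"image": create_test_image_2d(100, 101)[0]} for _ in range(4)]
    loader = DataLoader(CacheDataset(data, transforms, progress=False), batch_size=2)

    batch = next(iter(loader))  # batch["image"]: (2, 1, 150, 153), plus collated "image_transforms"

    # decollate the batch, run each item's inverse through a small DataLoader, re-collate;
    # collate_fn=lambda x: x avoids re-collation errors when inverted items differ in size
    batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=lambda x: x)
    inverted = batch_inverter(batch, keys="image")
    # inverted[0]["image"].shape[1:] == (100, 101), i.e. the original spatial size
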
diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index a40281c75b..2e3e15244f 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -778,8 +778,10 @@ def inverse( # Back to original size inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine) # Apply inverse transform - out = self.rand_2d_elastic.resampler(d[key], inv_def_w_affine, self.mode[idx], self.padding_mode[idx]) - d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out + if inv_def_no_affine is not None: + out = self.rand_2d_elastic.resampler(d[key], inv_def_w_affine, self.mode[idx], self.padding_mode[idx]) + d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out + else: d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) # Remove the applied transform diff --git a/tests/test_inverse.py b/tests/test_inverse.py index fba876fffe..b271dfb856 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -52,15 +52,18 @@ SpatialPadd, Zoomd, ) -from monai.utils import optional_import, set_determinism +from monai.utils import first, optional_import, set_determinism from tests.utils import make_nifti_image, make_rand_affine, test_is_quick if TYPE_CHECKING: import matplotlib.pyplot as plt + import vtk has_matplotlib = True + has_vtk = True else: plt, has_matplotlib = optional_import("matplotlib.pyplot") + _, has_vtk = optional_import("vtk") set_determinism(seed=0) @@ -345,27 +348,28 @@ ) ) -TESTS.append( - ( - "Rand2DElasticd 2d", - DATA_2D, - 2e-1, - Rand2DElasticd( - KEYS, - spacing=(10.0, 10.0), - magnitude_range=(1, 1), - spatial_size=[155, 192], - prob=1, - padding_mode="zeros", - rotate_range=[(np.pi / 6, np.pi / 6)], - shear_range=[(0.5, 0.5)], - translate_range=[10, 5], - scale_range=[(1.2, 1.2), (1.3, 1.3)], - ), +if has_vtk: + TESTS.append( + ( + "Rand2DElasticd 2d", + DATA_2D, + 2e-1, + Rand2DElasticd( + KEYS, + spacing=(10.0, 10.0), + magnitude_range=(1, 1), + spatial_size=[155, 192], + prob=1, + padding_mode="zeros", + rotate_range=[(np.pi / 6, np.pi / 6)], + shear_range=[(0.5, 0.5)], + translate_range=[10, 5], + scale_range=[(1.2, 1.2), (1.3, 1.3)], + ), + ) ) -) -if not test_is_quick: +if not test_is_quick and has_vtk: TESTS.append( ( "Rand3DElasticd 3d", @@ -463,7 +467,7 @@ def test_inverse(self, _, data, acceptable_diff, *transforms): @parameterized.expand(TESTS) def test_w_dataloader(self, _, data, acceptable_diff, *transforms): name = _ - device = "cuda" if torch.cuda.is_available() else "cpu" + device = "cpu" if isinstance(transforms, tuple): transforms = Compose(transforms) numel = 4 @@ -471,8 +475,7 @@ def test_w_dataloader(self, _, data, acceptable_diff, *transforms): ndims = len(data["image"].shape[1:]) batch_size = 2 - # num workers = 0 for mac - num_workers = 2 if sys.platform != "darwin" else 0 + num_workers = 0 dataset = CacheDataset(test_data, transforms, progress=False) loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) @@ -501,8 +504,7 @@ def test_diff_sized_inputs(self): test_data = [{key: AddChannel()(create_test_image_2d(100 + i, 101 + i)[0])} for i in range(4)] batch_size = 2 - # num workers = 0 for mac - num_workers = 2 if sys.platform != "darwin" else 0 + num_workers = 0 transforms = Compose([SpatialPadd(key, (150, 153))]) dataset = CacheDataset(test_data, transform=transforms, progress=False) @@ -524,6 +526,49 @@ def test_fail(self, data, _, *transform): with self.assertRaises(RuntimeError): d = transform[0].inverse(d) + 
def test_inverse_inferred_seg(self): + + test_data = [] + for _ in range(4): + image, label = create_test_image_2d(100, 101) + test_data.append({"image": image, "label": label.astype(np.float32)}) + + batch_size = 2 + # num workers = 0 for mac + num_workers = 2 if sys.platform != "darwin" else 0 + transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), CenterSpatialCropd(KEYS, (110, 99))]) + num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, InvertibleTransform)) + + dataset = CacheDataset(test_data, transform=transforms, progress=False) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + + device = "cuda" if torch.cuda.is_available() else "cpu" + model = UNet( + dimensions=2, + in_channels=1, + out_channels=1, + channels=(2, 4), + strides=(2,), + ).to(device) + + data = first(loader) + labels = data["label"].to(device) + segs = model(labels).detach().cpu() + segs_dict = {"label": segs, "label_transforms": data["label_transforms"]} + segs_dict_decollated = decollate_batch(segs_dict) + + # inverse of individual segmentation + seg_dict = first(segs_dict_decollated) + inv_seg = transforms.inverse(seg_dict, "label")["label"] + self.assertEqual(len(data["label_transforms"]), num_invertible_transforms) + self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms) + self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape) + + # Inverse of batch + batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=lambda x: x) + inv_batch = batch_inverter(segs_dict, "label") + self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape) + if __name__ == "__main__": unittest.main() From e515a78ec86f11e82d5167f101fd53fdd76004b4 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 19 Feb 2021 18:18:20 +0000 Subject: [PATCH 75/80] codeformat Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/inverse_batch_transform.py | 9 ++++++--- monai/transforms/compose.py | 2 +- monai/transforms/inverse_transform.py | 4 ++-- monai/transforms/spatial/dictionary.py | 4 +++- tests/test_inverse.py | 1 - 5 files changed, 12 insertions(+), 8 deletions(-) diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index 49166b7fc1..dfef9bf4e7 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -11,6 +11,7 @@ from typing import Any, Callable, Dict, Hashable, Optional, Tuple +import numpy as np from torch.utils.data.dataloader import DataLoader as TorchDataLoader from monai.data.dataloader import DataLoader @@ -23,12 +24,14 @@ class _BatchInverseDataset(Dataset): - def __init__(self, data: Dict[str, Any], transform: InvertibleTransform, keys: Optional[Tuple[Hashable, ...]] = None) -> None: + def __init__( + self, data: Dict[str, Any], transform: InvertibleTransform, keys: Optional[Tuple[Hashable, ...]] = None + ) -> None: self.data = decollate_batch(data) self.invertible_transform = transform self.keys = keys - def __getitem__(self, index: int) -> Dict[str, Any]: + def __getitem__(self, index: int) -> Dict[Hashable, np.ndarray]: data = self.data[index] return self.invertible_transform.inverse(data, self.keys) @@ -53,7 +56,7 @@ def __init__( self.num_workers = loader.num_workers self.collate_fn = collate_fn - def __call__(self, data: Dict[str, Any], keys: Optional[Tuple[Hashable, ...]] = None) -> Dict[str, Any]: + def 
__call__(self, data: Dict[str, Any], keys: Optional[Tuple[Hashable, ...]] = None) -> Dict[Hashable, np.ndarray]: inv_ds = _BatchInverseDataset(data, self.transform, keys) inv_loader = DataLoader( diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 70b306262d..08d3466dab 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -14,7 +14,7 @@ import warnings from copy import deepcopy -from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Hashable, Mapping, Optional, Sequence, Tuple, Union import numpy as np diff --git a/monai/transforms/inverse_transform.py b/monai/transforms/inverse_transform.py index 8b368264cc..a20e233b57 100644 --- a/monai/transforms/inverse_transform.py +++ b/monai/transforms/inverse_transform.py @@ -11,7 +11,7 @@ import warnings from abc import ABC -from typing import Any, Dict, Hashable, Optional, Tuple +from typing import Dict, Hashable, Optional, Tuple import numpy as np import torch @@ -79,7 +79,7 @@ def remove_most_recent_transform(data: dict, key: Hashable) -> None: """Remove most recent transform.""" data[str(key) + "_transforms"].pop() - def inverse(self, data: dict, keys: Optional[Tuple[Hashable, ...]] = None) -> Dict[str, Any]: + def inverse(self, data: dict, keys: Optional[Tuple[Hashable, ...]] = None) -> Dict[Hashable, np.ndarray]: """ Inverse of ``__call__``. diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 2e3e15244f..aaaa0012bf 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -779,7 +779,9 @@ def inverse( inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine) # Apply inverse transform if inv_def_no_affine is not None: - out = self.rand_2d_elastic.resampler(d[key], inv_def_w_affine, self.mode[idx], self.padding_mode[idx]) + out = self.rand_2d_elastic.resampler( + d[key], inv_def_w_affine, self.mode[idx], self.padding_mode[idx] + ) d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out else: diff --git a/tests/test_inverse.py b/tests/test_inverse.py index b271dfb856..ac22ff0fcd 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -57,7 +57,6 @@ if TYPE_CHECKING: import matplotlib.pyplot as plt - import vtk has_matplotlib = True has_vtk = True From b0830673b5aafa0b7726b1d23c0e3e7c01f149fc Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 22 Feb 2021 15:23:11 +0000 Subject: [PATCH 76/80] inverse pad collation Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/inverse_batch_transform.py | 29 ++++++++++++++++++++++----- tests/test_inverse.py | 26 +++++++++++++++++++++++- 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index dfef9bf4e7..49cee7f43e 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from monai.transforms.croppad.array import CenterSpatialCrop +from monai.utils.misc import ensure_tuple from typing import Any, Callable, Dict, Hashable, Optional, Tuple import numpy as np @@ -16,7 +18,7 @@ from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset -from monai.data.utils import decollate_batch +from monai.data.utils import decollate_batch, pad_list_data_collate from monai.transforms.inverse_transform import InvertibleTransform from monai.utils import first @@ -25,14 +27,30 @@ class _BatchInverseDataset(Dataset): def __init__( - self, data: Dict[str, Any], transform: InvertibleTransform, keys: Optional[Tuple[Hashable, ...]] = None + self, + data: Dict[str, Any], + transform: InvertibleTransform, + keys: Optional[Tuple[Hashable, ...]], + pad_collation_used: bool, ) -> None: self.data = decollate_batch(data) self.invertible_transform = transform - self.keys = keys + self.keys = ensure_tuple(keys) if keys else None + self.pad_collation_used = pad_collation_used def __getitem__(self, index: int) -> Dict[Hashable, np.ndarray]: - data = self.data[index] + data = dict(self.data[index]) + # If pad collation was used, then we need to undo this first + if self.pad_collation_used: + keys = self.keys or [key for key in data.keys() if str(key) + "_transforms" in data.keys()] + for key in keys: + transform_key = str(key) + "_transforms" + transform = data[transform_key].pop() + if transform["class"] != "SpatialPadd": + raise RuntimeError("Expected most recent transform to have been SpatialPadd because " + + "pad_list_data_collate was used. Instead, found " + transform["class"]) + data[key] = CenterSpatialCrop(transform["orig_size"])(data[key]) + return self.invertible_transform.inverse(data, self.keys) @@ -55,10 +73,11 @@ def __init__( self.batch_size = loader.batch_size self.num_workers = loader.num_workers self.collate_fn = collate_fn + self.pad_collation_used = loader.collate_fn == pad_list_data_collate def __call__(self, data: Dict[str, Any], keys: Optional[Tuple[Hashable, ...]] = None) -> Dict[Hashable, np.ndarray]: - inv_ds = _BatchInverseDataset(data, self.transform, keys) + inv_ds = _BatchInverseDataset(data, self.transform, keys, self.pad_collation_used) inv_loader = DataLoader( inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn ) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index ac22ff0fcd..b30635ae42 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -19,7 +19,7 @@ from parameterized import parameterized from monai.data import BatchInverseTransform, CacheDataset, DataLoader, create_test_image_2d, create_test_image_3d -from monai.data.utils import decollate_batch +from monai.data.utils import decollate_batch, pad_list_data_collate from monai.networks.nets import UNet from monai.transforms import ( AddChannel, @@ -519,6 +519,30 @@ def test_diff_sized_inputs(self): unmodified = test_data[batch_idx * batch_size + idx] self.check_inverse("diff_sized_inputs", [key], unmodified, _fwd_bck, _fwd, 0) + def test_inverse_w_pad_list_data_collate(self): + + test_data = [] + for _ in range(4): + image, label = [AddChannel()(i) for i in create_test_image_2d(100, 101)] + test_data.append({"image": image, "label": label.astype(np.float32)}) + + batch_size = 2 + num_workers = 0 + transforms = Compose([CropForegroundd(KEYS, source_key="label")]) + + dataset = CacheDataset(test_data, transform=transforms, progress=False) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, 
num_workers=num_workers, collate_fn=pad_list_data_collate) + # blank collate function since input are different size + inv_batch = BatchInverseTransform(transforms, loader) + + for batch_idx, batch_data in enumerate(loader): + fwd = decollate_batch(batch_data) + fwd_bck = decollate_batch(inv_batch(batch_data)) + + for idx, (_fwd, _fwd_bck) in enumerate(zip(fwd, fwd_bck)): + unmodified = test_data[batch_idx * batch_size + idx] + self.check_inverse("diff_sized_inputs", KEYS, unmodified, _fwd_bck, _fwd, 2e-2) + @parameterized.expand(TESTS_FAIL) def test_fail(self, data, _, *transform): d = transform[0](data) From c4b81e511f85df3d66b0ca49deec28fb2766d5e0 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Mon, 22 Feb 2021 15:46:42 +0000 Subject: [PATCH 77/80] update test threshold Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index b30635ae42..1f31043489 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -541,7 +541,7 @@ def test_inverse_w_pad_list_data_collate(self): for idx, (_fwd, _fwd_bck) in enumerate(zip(fwd, fwd_bck)): unmodified = test_data[batch_idx * batch_size + idx] - self.check_inverse("diff_sized_inputs", KEYS, unmodified, _fwd_bck, _fwd, 2e-2) + self.check_inverse("diff_sized_inputs", KEYS, unmodified, _fwd_bck, _fwd, 1e-1) @parameterized.expand(TESTS_FAIL) def test_fail(self, data, _, *transform): From 70e18cab9ac4097bb9c518839cf334b247f68937 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Tue, 23 Feb 2021 09:27:12 +0000 Subject: [PATCH 78/80] update transform location after deepgrow merge Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/apps/deepgrow/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index f178360031..80b0d1648d 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -17,7 +17,7 @@ from monai.config import IndexSelection, KeysCollection from monai.networks.layers import GaussianFilter from monai.transforms import SpatialCrop -from monai.transforms.compose import MapTransform, Randomizable, Transform +from monai.transforms.transform import MapTransform, Randomizable, Transform from monai.transforms.utils import generate_spatial_bounding_box from monai.utils import min_version, optional_import From ad7e9ff41e7526b9e3385e2fb75b246061b86687 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 23 Feb 2021 13:48:48 +0000 Subject: [PATCH 79/80] TTA Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/__init__.py | 1 + monai/data/inverse_batch_transform.py | 24 +++-- monai/data/nifti_saver.py | 4 +- monai/data/nifti_writer.py | 14 +-- monai/data/png_saver.py | 3 +- monai/data/test_time_augmentation.py | 108 ++++++++++++++++++++++ tests/test_testtimeaugmentation.py | 125 ++++++++++++++++++++++++++ 7 files changed, 259 insertions(+), 20 deletions(-) create mode 100644 monai/data/test_time_augmentation.py create mode 100644 tests/test_testtimeaugmentation.py diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 679f88c2ab..56862de9e8 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -26,6 +26,7 @@ from .image_dataset import ImageDataset from 
.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from .inverse_batch_transform import BatchInverseTransform +from .test_time_augmentation import TestTimeAugmentation from .iterable_dataset import IterableDataset from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index 49cee7f43e..74b31df334 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -11,7 +11,7 @@ from monai.transforms.croppad.array import CenterSpatialCrop from monai.utils.misc import ensure_tuple -from typing import Any, Callable, Dict, Hashable, Optional, Tuple +from typing import Any, Callable, Dict, Hashable, Optional, Tuple, Union import numpy as np from torch.utils.data.dataloader import DataLoader as TorchDataLoader @@ -45,11 +45,11 @@ def __getitem__(self, index: int) -> Dict[Hashable, np.ndarray]: keys = self.keys or [key for key in data.keys() if str(key) + "_transforms" in data.keys()] for key in keys: transform_key = str(key) + "_transforms" - transform = data[transform_key].pop() - if transform["class"] != "SpatialPadd": - raise RuntimeError("Expected most recent transform to have been SpatialPadd because " + - "pad_list_data_collate was used. Instead, found " + transform["class"]) - data[key] = CenterSpatialCrop(transform["orig_size"])(data[key]) + transform = data[transform_key][-1] + if transform["class"] == "SpatialPadd": + data[key] = CenterSpatialCrop(transform["orig_size"])(data[key]) + # remove transform + data[transform_key].pop() return self.invertible_transform.inverse(data, self.keys) @@ -75,17 +75,23 @@ def __init__( self.collate_fn = collate_fn self.pad_collation_used = loader.collate_fn == pad_list_data_collate - def __call__(self, data: Dict[str, Any], keys: Optional[Tuple[Hashable, ...]] = None) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Dict[str, Any], keys: Optional[Tuple[Hashable, ...]] = None) -> Union[Dict[Hashable, np.ndarray], np.ndarray]: inv_ds = _BatchInverseDataset(data, self.transform, keys, self.pad_collation_used) inv_loader = DataLoader( inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn ) try: - # Only need to return first as only 1 batch of data - return first(inv_loader) # type: ignore + output = first(inv_loader) except RuntimeError as re: re_str = str(re) if "stack expects each tensor to be equal size" in re_str: re_str += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`." 
raise RuntimeError(re_str) + + # Only need to return first as only 1 batch of data + if keys is not None: + keys_tuple = ensure_tuple(keys) + if len(keys_tuple) == 1: + return output[keys_tuple[0]] + return output # type: ignore diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index 01e701b1a6..b75d7218e6 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -14,9 +14,9 @@ import numpy as np import torch +import monai from monai.config import DtypeLike from monai.data.nifti_writer import write_nifti -from monai.data.utils import create_file_basename from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key @@ -104,7 +104,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] if isinstance(data, torch.Tensor): data = data.detach().cpu().numpy() - filename = create_file_basename(self.output_postfix, filename, self.output_dir) + filename = monai.data.utils.create_file_basename(self.output_postfix, filename, self.output_dir) filename = f"{filename}{self.output_ext}" # change data shape to be (channel, h, w, d) while len(data.shape) < 4: diff --git a/monai/data/nifti_writer.py b/monai/data/nifti_writer.py index f530482b14..13f01cd7e7 100644 --- a/monai/data/nifti_writer.py +++ b/monai/data/nifti_writer.py @@ -14,8 +14,8 @@ import numpy as np import torch +import monai from monai.config import DtypeLike -from monai.data.utils import compute_shape_offset, to_affine_nd from monai.networks.layers import AffineTransform from monai.utils import GridSampleMode, GridSamplePadMode, optional_import @@ -95,15 +95,15 @@ def write_nifti( sr = min(data.ndim, 3) if affine is None: affine = np.eye(4, dtype=np.float64) - affine = to_affine_nd(sr, affine) + affine = monai.data.utils.to_affine_nd(sr, affine) if target_affine is None: target_affine = affine - target_affine = to_affine_nd(sr, target_affine) + target_affine = monai.data.utils.to_affine_nd(sr, target_affine) if np.allclose(affine, target_affine, atol=1e-3): # no affine changes, save (data, affine) - results_img = nib.Nifti1Image(data.astype(output_dtype), to_affine_nd(3, target_affine)) + results_img = nib.Nifti1Image(data.astype(output_dtype), monai.data.utils.to_affine_nd(3, target_affine)) nib.save(results_img, file_name) return @@ -115,7 +115,7 @@ def write_nifti( data = nib.orientations.apply_orientation(data, ornt_transform) _affine = affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) if np.allclose(_affine, target_affine, atol=1e-3) or not resample: - results_img = nib.Nifti1Image(data.astype(output_dtype), to_affine_nd(3, target_affine)) + results_img = nib.Nifti1Image(data.astype(output_dtype), monai.data.utils.to_affine_nd(3, target_affine)) nib.save(results_img, file_name) return @@ -125,7 +125,7 @@ def write_nifti( ) transform = np.linalg.inv(_affine) @ target_affine if output_spatial_shape is None: - output_spatial_shape, _ = compute_shape_offset(data.shape, _affine, target_affine) + output_spatial_shape, _ = monai.data.utils.compute_shape_offset(data.shape, _affine, target_affine) output_spatial_shape_ = list(output_spatial_shape) if output_spatial_shape is not None else [] if data.ndim > 3: # multi channel, resampling each channel while len(output_spatial_shape_) < 3: @@ -151,6 +151,6 @@ def write_nifti( ) data_np = data_torch.squeeze(0).squeeze(0).detach().cpu().numpy() - results_img = nib.Nifti1Image(data_np.astype(output_dtype), to_affine_nd(3, target_affine)) + results_img = 
nib.Nifti1Image(data_np.astype(output_dtype), monai.data.utils.to_affine_nd(3, target_affine)) nib.save(results_img, file_name) return diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py index 4c4c847824..84beeb35e9 100644 --- a/monai/data/png_saver.py +++ b/monai/data/png_saver.py @@ -15,7 +15,6 @@ import torch from monai.data.png_writer import write_png -from monai.data.utils import create_file_basename from monai.utils import ImageMetaKey as Key from monai.utils import InterpolateMode @@ -90,7 +89,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] if isinstance(data, torch.Tensor): data = data.detach().cpu().numpy() - filename = create_file_basename(self.output_postfix, filename, self.output_dir) + filename = monai.data.utils.create_file_basename(self.output_postfix, filename, self.output_dir) filename = f"{filename}{self.output_ext}" if data.shape[0] == 1: diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py new file mode 100644 index 0000000000..c4f1dd1434 --- /dev/null +++ b/monai/data/test_time_augmentation.py @@ -0,0 +1,108 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from typing import Any, Dict +import torch + +from monai.data.dataloader import DataLoader +from monai.data.dataset import Dataset +from monai.data.inverse_batch_transform import BatchInverseTransform +from monai.data.utils import pad_list_data_collate +from monai.transforms.compose import Compose +from monai.transforms.inverse_transform import InvertibleTransform +from monai.transforms.transform import Randomizable + + +__all__ = ["TestTimeAugmentation"] + +def is_transform_rand(transform): + if not isinstance(transform, Compose): + return isinstance(transform, Randomizable) + # call recursively for each sub-transform + return any(is_transform_rand(t) for t in transform.transforms) + + +class TestTimeAugmentation: + def __init__( + self, + transform: InvertibleTransform, + batch_size, + num_workers, + inferrer_fn, + device, + ) -> None: + self.transform = transform + self.batch_size = batch_size + self.num_workers = num_workers + self.inferrer_fn = inferrer_fn + self.device = device + + # check that the transform has at least one random component + if not is_transform_rand(self.transform): + raise RuntimeError(type(self).__name__ + " requires a `Randomizable` transform or a" + + " `Compose` containing at least one `Randomizable` transform.") + + def __call__(self, data: Dict[str, Any], num_examples=10, image_key="image", label_key="label", return_full_data=False): + d = dict(data) + + # check num examples is multiple of batch size + if num_examples % self.batch_size != 0: + raise ValueError("num_examples should be multiple of batch size.") + + # generate batch of data of size == batch_size, dataset and dataloader + data_in = [d for _ in range(num_examples)] + ds = Dataset(data_in, self.transform) + dl = DataLoader(ds, self.num_workers, batch_size=self.batch_size, 
collate_fn=pad_list_data_collate) + + label_transform_key = label_key + "_transforms" + + # create inverter + inverter = BatchInverseTransform(self.transform, dl) + + outputs = [] + + for batch_data in dl: + + batch_images = batch_data[image_key].to(self.device) + + # do model forward pass + batch_output = self.inferrer_fn(batch_images) + if isinstance(batch_output, torch.Tensor): + batch_output = batch_output.detach().cpu() + if isinstance(batch_output, np.ndarray): + batch_output = torch.Tensor(batch_output) + + # check binary labels are extracted + if not all(torch.unique(batch_output.int()) == torch.Tensor([0, 1])): + raise RuntimeError("Test-time augmentation requires binary channels. If this is " + "not binary segmentation, then you should one-hot your output.") + + # create a dictionary containing the inferred batch and their transforms + inferred_dict = {label_key: batch_output, label_transform_key: batch_data[label_transform_key]} + + # do inverse transformation (only for the label key) + inv_batch = inverter(inferred_dict, label_key) + + # append + outputs.append(inv_batch) + + # calculate mean and standard deviation + output = np.concatenate(outputs) + + if return_full_data: + return output + + mode = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=0, arr=output.astype(np.int64)) + mean = np.mean(output, axis=0) + std = np.std(output, axis=0) + vvc = np.std(output) / np.mean(output) + return mode, mean, std, vvc diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py new file mode 100644 index 0000000000..879f4e1e55 --- /dev/null +++ b/tests/test_testtimeaugmentation.py @@ -0,0 +1,125 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
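+#
+# This test trains a small UNet on synthetic 2D images, then runs
+# TestTimeAugmentation with a random-affine pipeline and checks the shape and
+# value range of the mean/std maps produced from the inverted predictions.
+# A second test checks that a purely deterministic pipeline is rejected.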
+ +from monai.transforms.croppad.dictionary import SpatialPadd +from monai.data.test_time_augmentation import TestTimeAugmentation +import unittest +from typing import TYPE_CHECKING + +import numpy as np +from torch._C import has_cuda +from functools import partial +from monai.data import CacheDataset, DataLoader, create_test_image_2d +from monai.data.utils import pad_list_data_collate +from monai.networks.nets import UNet +from monai.transforms import ( + Activations, + AddChanneld, + AsDiscrete, + Compose, + CropForegroundd, + DivisiblePadd, + KeepLargestConnectedComponent, + RandAffined, +) +import torch +from monai.losses import DiceLoss +from monai.utils import optional_import, set_determinism + + +if TYPE_CHECKING: + import tqdm + + has_tqdm = True +else: + tqdm, has_tqdm = optional_import("tqdm") + +trange = partial(tqdm.trange, desc="training") if has_tqdm else range + +set_determinism(seed=0) + + +class TestTestTimeAugmentation(unittest.TestCase): + def test_test_time_augmentation(self): + input_size = (20, 20) + device = "cuda" if has_cuda else "cpu" + num_training_ims = 10 + data = [] + custom_create_test_image_2d = partial(create_test_image_2d, *input_size, rad_max=7, num_seg_classes=1, num_objs=1) + keys = ["image", "label"] + + for i in range(num_training_ims): + im, label = custom_create_test_image_2d() + data.append({"image": im, "label": label}) + + transforms = Compose([ + AddChanneld(keys), + RandAffined( + keys, + prob=1.0, + spatial_size=(30, 30), + rotate_range=(np.pi/3, np.pi/3), + translate_range=(3, 3), + scale_range=((0.8, 1), (0.8, 1)), + padding_mode="zeros", + mode=("bilinear", "nearest"), + as_tensor_output=False, + ), + CropForegroundd(keys, source_key="image"), + DivisiblePadd(keys, 4), + ]) + + train_ds = CacheDataset(data, transforms) + # output might be different size, so pad so that they match + train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate) + + model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device) + loss_function = DiceLoss(sigmoid=True) + optimizer = torch.optim.Adam(model.parameters(), 1e-3) + + num_epochs = 10 + for _ in trange(num_epochs): + epoch_loss = 0 + + for batch_data in train_loader: + inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_function(outputs, labels) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + + epoch_loss /= len(train_loader) + + image, label = custom_create_test_image_2d() + test_data = {"image": image, "label": label} + + post_trans = Compose([ + Activations(sigmoid=True), + AsDiscrete(threshold_values=True), + KeepLargestConnectedComponent(applied_labels=1), + ]) + inferrer_fn = lambda x: post_trans(model(x)) + tt_aug = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=inferrer_fn, device=device) + mean, std = tt_aug(test_data) + self.assertEqual(mean.shape, (1,) + input_size) + self.assertEqual((mean.min(), mean.max()), (0.0, 1.0)) + self.assertEqual(std.shape, (1,) + input_size) + + + def test_fail_non_random(self): + transforms = Compose([AddChanneld("im"), SpatialPadd("im", 1)]) + with self.assertRaises(RuntimeError): + TestTimeAugmentation(transforms, None, None, None, None) + +if __name__ == "__main__": + unittest.main() From 19cf1d07cafaaffbd56aa3c10bb9f6aac1170211 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Tue, 23 Feb 2021 14:52:35 +0000 Subject: [PATCH 80/80] code format changes 
Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/__init__.py | 2 +- monai/data/inverse_batch_transform.py | 17 ++---- monai/data/nifti_saver.py | 2 +- monai/data/nifti_writer.py | 2 +- monai/data/png_saver.py | 1 + monai/data/test_time_augmentation.py | 24 ++++++--- monai/transforms/compose.py | 4 +- monai/transforms/inverse_transform.py | 7 ++- monai/transforms/io/dictionary.py | 2 +- monai/transforms/spatial/dictionary.py | 12 ++--- tests/test_inverse.py | 4 +- tests/test_testtimeaugmentation.py | 71 +++++++++++++++----------- 12 files changed, 80 insertions(+), 68 deletions(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 56862de9e8..5ba4a990af 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -26,13 +26,13 @@ from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from .inverse_batch_transform import BatchInverseTransform -from .test_time_augmentation import TestTimeAugmentation from .iterable_dataset import IterableDataset from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti from .png_saver import PNGSaver from .png_writer import write_png from .synthetic import create_test_image_2d, create_test_image_3d +from .test_time_augmentation import TestTimeAugmentation from .thread_buffer import ThreadBuffer from .utils import ( DistributedSampler, diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index 74b31df334..1f6c903e36 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -9,9 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from monai.transforms.croppad.array import CenterSpatialCrop -from monai.utils.misc import ensure_tuple -from typing import Any, Callable, Dict, Hashable, Optional, Tuple, Union +from typing import Any, Callable, Dict, Hashable, Optional, Tuple import numpy as np from torch.utils.data.dataloader import DataLoader as TorchDataLoader @@ -19,8 +17,10 @@ from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset from monai.data.utils import decollate_batch, pad_list_data_collate +from monai.transforms.croppad.array import CenterSpatialCrop from monai.transforms.inverse_transform import InvertibleTransform from monai.utils import first +from monai.utils.misc import ensure_tuple __all__ = ["BatchInverseTransform"] @@ -75,23 +75,16 @@ def __init__( self.collate_fn = collate_fn self.pad_collation_used = loader.collate_fn == pad_list_data_collate - def __call__(self, data: Dict[str, Any], keys: Optional[Tuple[Hashable, ...]] = None) -> Union[Dict[Hashable, np.ndarray], np.ndarray]: + def __call__(self, data: Dict[str, Any], keys: Optional[Tuple[Hashable, ...]] = None) -> Any: inv_ds = _BatchInverseDataset(data, self.transform, keys, self.pad_collation_used) inv_loader = DataLoader( inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn ) try: - output = first(inv_loader) + return first(inv_loader) except RuntimeError as re: re_str = str(re) if "stack expects each tensor to be equal size" in re_str: re_str += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`." 
raise RuntimeError(re_str) - - # Only need to return first as only 1 batch of data - if keys is not None: - keys_tuple = ensure_tuple(keys) - if len(keys_tuple) == 1: - return output[keys_tuple[0]] - return output # type: ignore diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index b75d7218e6..155666a168 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -14,7 +14,7 @@ import numpy as np import torch -import monai +import monai.data.utils from monai.config import DtypeLike from monai.data.nifti_writer import write_nifti from monai.utils import GridSampleMode, GridSamplePadMode diff --git a/monai/data/nifti_writer.py b/monai/data/nifti_writer.py index 13f01cd7e7..be9bafc765 100644 --- a/monai/data/nifti_writer.py +++ b/monai/data/nifti_writer.py @@ -14,7 +14,7 @@ import numpy as np import torch -import monai +import monai.data.utils from monai.config import DtypeLike from monai.networks.layers import AffineTransform from monai.utils import GridSampleMode, GridSamplePadMode, optional_import diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py index 84beeb35e9..114524de91 100644 --- a/monai/data/png_saver.py +++ b/monai/data/png_saver.py @@ -14,6 +14,7 @@ import numpy as np import torch +import monai.data.utils from monai.data.png_writer import write_png from monai.utils import ImageMetaKey as Key from monai.utils import InterpolateMode diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index c4f1dd1434..48cbc54843 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -9,8 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np from typing import Any, Dict + +import numpy as np import torch from monai.data.dataloader import DataLoader @@ -21,9 +22,9 @@ from monai.transforms.inverse_transform import InvertibleTransform from monai.transforms.transform import Randomizable - __all__ = ["TestTimeAugmentation"] + def is_transform_rand(transform): if not isinstance(transform, Compose): return isinstance(transform, Randomizable) @@ -48,10 +49,15 @@ def __init__( # check that the transform has at least one random component if not is_transform_rand(self.transform): - raise RuntimeError(type(self).__name__ + " requires a `Randomizable` transform or a" - + " `Compose` containing at least one `Randomizable` transform.") - - def __call__(self, data: Dict[str, Any], num_examples=10, image_key="image", label_key="label", return_full_data=False): + raise RuntimeError( + type(self).__name__ + + " requires a `Randomizable` transform or a" + + " `Compose` containing at least one `Randomizable` transform." + ) + + def __call__( + self, data: Dict[str, Any], num_examples=10, image_key="image", label_key="label", return_full_data=False + ): d = dict(data) # check num examples is multiple of batch size @@ -83,8 +89,10 @@ def __call__(self, data: Dict[str, Any], num_examples=10, image_key="image", lab # check binary labels are extracted if not all(torch.unique(batch_output.int()) == torch.Tensor([0, 1])): - raise RuntimeError("Test-time augmentation requires binary channels. If this is " - "not binary segmentation, then you should one-hot your output.") + raise RuntimeError( + "Test-time augmentation requires binary channels. If this is " + "not binary segmentation, then you should one-hot your output." 
+ ) # create a dictionary containing the inferred batch and their transforms inferred_dict = {label_key: batch_output, label_transform_key: batch_data[label_transform_key]} diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 08d3466dab..d313be2fbc 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -19,14 +19,14 @@ import numpy as np from monai.transforms.inverse_transform import InvertibleTransform -from monai.transforms.transform import Randomizable, Transform +from monai.transforms.transform import Randomizable from monai.transforms.utils import apply_transform from monai.utils import MAX_SEED, ensure_tuple, get_seed __all__ = ["Compose"] -class Compose(Randomizable, Transform, InvertibleTransform): +class Compose(Randomizable, InvertibleTransform): """ ``Compose`` provides the ability to chain a series of calls together in a sequence. Each transform in the sequence must take a single argument and diff --git a/monai/transforms/inverse_transform.py b/monai/transforms/inverse_transform.py index a20e233b57..8b525de4aa 100644 --- a/monai/transforms/inverse_transform.py +++ b/monai/transforms/inverse_transform.py @@ -10,13 +10,12 @@ # limitations under the License. import warnings -from abc import ABC from typing import Dict, Hashable, Optional, Tuple import numpy as np import torch -from monai.transforms.transform import Randomizable +from monai.transforms.transform import Randomizable, Transform from monai.utils import optional_import sitk, has_sitk = optional_import("SimpleITK") @@ -26,7 +25,7 @@ __all__ = ["InvertibleTransform", "NonRigidTransform"] -class InvertibleTransform(ABC): +class InvertibleTransform(Transform): """Classes for invertible transforms. This class exists so that an ``invert`` method can be implemented. 
This allows, for @@ -90,7 +89,7 @@ def inverse(self, data: dict, keys: Optional[Tuple[Hashable, ...]] = None) -> Di raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") -class NonRigidTransform(ABC): +class NonRigidTransform(Transform): @staticmethod def _get_disp_to_def_arr(shape, spacing): def_to_disp = np.mgrid[[slice(0, i) for i in shape]].astype(np.float64) diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 2c704a86b6..55707f750e 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -22,8 +22,8 @@ from monai.config import DtypeLike, KeysCollection from monai.data.image_reader import ImageReader from monai.transforms.io.array import LoadImage, SaveImage -from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode from monai.transforms.transform import MapTransform +from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode __all__ = [ "LoadImaged", diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index aaaa0012bf..c90742f4f9 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -772,11 +772,11 @@ def inverse( if inv_def_no_affine is not None: fwd_affine = transform["extra_info"]["affine"] inv_affine = np.linalg.inv(fwd_affine) - inv_def_w_affine: np.ndarray = AffineGrid(affine=inv_affine, as_tensor_output=False)( + inv_def_w_affine_wrong_size = AffineGrid(affine=inv_affine, as_tensor_output=False)( grid=inv_def_no_affine - ) # type: ignore + ) # Back to original size - inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine) + inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine_wrong_size) # type: ignore # Apply inverse transform if inv_def_no_affine is not None: out = self.rand_2d_elastic.resampler( @@ -925,11 +925,11 @@ def inverse( if inv_def_no_affine is not None: fwd_affine = transform["extra_info"]["affine"] inv_affine = np.linalg.inv(fwd_affine) - inv_def_w_affine: np.ndarray = AffineGrid(affine=inv_affine, as_tensor_output=False)( + inv_def_w_affine_wrong_size = AffineGrid(affine=inv_affine, as_tensor_output=False)( grid=inv_def_no_affine - ) # type: ignore + ) # Back to original size - inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine) + inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine_wrong_size) # type: ignore # Apply inverse transform if inv_def_w_affine is not None: out = self.rand_3d_elastic.resampler( diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 1f31043489..1ca0e3e8cf 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -531,7 +531,9 @@ def test_inverse_w_pad_list_data_collate(self): transforms = Compose([CropForegroundd(KEYS, source_key="label")]) dataset = CacheDataset(test_data, transform=transforms, progress=False) - loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=pad_list_data_collate) + loader = DataLoader( + dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=pad_list_data_collate + ) # blank collate function since input are different size inv_batch = BatchInverseTransform(transforms, loader) diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index 879f4e1e55..93027e1ac8 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -9,16 +9,18 @@ # See the License 
for the specific language governing permissions and # limitations under the License. -from monai.transforms.croppad.dictionary import SpatialPadd -from monai.data.test_time_augmentation import TestTimeAugmentation import unittest +from functools import partial from typing import TYPE_CHECKING import numpy as np +import torch from torch._C import has_cuda -from functools import partial + from monai.data import CacheDataset, DataLoader, create_test_image_2d +from monai.data.test_time_augmentation import TestTimeAugmentation from monai.data.utils import pad_list_data_collate +from monai.losses import DiceLoss from monai.networks.nets import UNet from monai.transforms import ( Activations, @@ -30,11 +32,9 @@ KeepLargestConnectedComponent, RandAffined, ) -import torch -from monai.losses import DiceLoss +from monai.transforms.croppad.dictionary import SpatialPadd from monai.utils import optional_import, set_determinism - if TYPE_CHECKING: import tqdm @@ -53,29 +53,33 @@ def test_test_time_augmentation(self): device = "cuda" if has_cuda else "cpu" num_training_ims = 10 data = [] - custom_create_test_image_2d = partial(create_test_image_2d, *input_size, rad_max=7, num_seg_classes=1, num_objs=1) + custom_create_test_image_2d = partial( + create_test_image_2d, *input_size, rad_max=7, num_seg_classes=1, num_objs=1 + ) keys = ["image", "label"] - for i in range(num_training_ims): + for _ in range(num_training_ims): im, label = custom_create_test_image_2d() data.append({"image": im, "label": label}) - transforms = Compose([ - AddChanneld(keys), - RandAffined( - keys, - prob=1.0, - spatial_size=(30, 30), - rotate_range=(np.pi/3, np.pi/3), - translate_range=(3, 3), - scale_range=((0.8, 1), (0.8, 1)), - padding_mode="zeros", - mode=("bilinear", "nearest"), - as_tensor_output=False, - ), - CropForegroundd(keys, source_key="image"), - DivisiblePadd(keys, 4), - ]) + transforms = Compose( + [ + AddChanneld(keys), + RandAffined( + keys, + prob=1.0, + spatial_size=(30, 30), + rotate_range=(np.pi / 3, np.pi / 3), + translate_range=(3, 3), + scale_range=((0.8, 1), (0.8, 1)), + padding_mode="zeros", + mode=("bilinear", "nearest"), + as_tensor_output=False, + ), + CropForegroundd(keys, source_key="image"), + DivisiblePadd(keys, 4), + ] + ) train_ds = CacheDataset(data, transforms) # output might be different size, so pad so that they match @@ -103,23 +107,28 @@ def test_test_time_augmentation(self): image, label = custom_create_test_image_2d() test_data = {"image": image, "label": label} - post_trans = Compose([ - Activations(sigmoid=True), - AsDiscrete(threshold_values=True), - KeepLargestConnectedComponent(applied_labels=1), - ]) - inferrer_fn = lambda x: post_trans(model(x)) + post_trans = Compose( + [ + Activations(sigmoid=True), + AsDiscrete(threshold_values=True), + KeepLargestConnectedComponent(applied_labels=1), + ] + ) + + def inferrer_fn(x): + return post_trans(model(x)) + tt_aug = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=inferrer_fn, device=device) mean, std = tt_aug(test_data) self.assertEqual(mean.shape, (1,) + input_size) self.assertEqual((mean.min(), mean.max()), (0.0, 1.0)) self.assertEqual(std.shape, (1,) + input_size) - def test_fail_non_random(self): transforms = Compose([AddChanneld("im"), SpatialPadd("im", 1)]) with self.assertRaises(RuntimeError): TestTimeAugmentation(transforms, None, None, None, None) + if __name__ == "__main__": unittest.main()
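
The patches above only exercise the new TestTimeAugmentation class through its unit test, so a minimal usage sketch follows. It assumes the API as defined in monai/data/test_time_augmentation.py in this series (constructor arguments transform, batch_size, num_workers, inferrer_fn and device; __call__ returning mode, mean, std and a volume variation coefficient), and it uses a hypothetical thresholding inferrer_fn as a stand-in for a trained network that produces binary segmentations.

import numpy as np
import torch

from monai.data import TestTimeAugmentation, create_test_image_2d
from monai.transforms import AddChanneld, Compose, RandAffined

keys = ["image", "label"]

# The pipeline must contain at least one Randomizable transform,
# otherwise TestTimeAugmentation raises a RuntimeError at construction.
transforms = Compose(
    [
        AddChanneld(keys),
        RandAffined(
            keys,
            prob=1.0,
            spatial_size=(30, 30),
            rotate_range=(np.pi / 3, np.pi / 3),
            translate_range=(3, 3),
            scale_range=((0.8, 1), (0.8, 1)),
            padding_mode="zeros",
            mode=("bilinear", "nearest"),
        ),
    ]
)


def inferrer_fn(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for `post_trans(model(x))`: any callable returning
    # a binary segmentation of the same spatial size will do.
    return (x > x.mean()).float()


device = "cuda" if torch.cuda.is_available() else "cpu"
image, label = create_test_image_2d(20, 20, rad_max=7, num_seg_classes=1, num_objs=1)

tt_aug = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=inferrer_fn, device=device)

# num_examples must be a multiple of batch_size; each example is an independently
# augmented copy of the input, inferred and then inverted back to image space.
mode, mean, std, vvc = tt_aug({"image": image, "label": label}, num_examples=10)
print(mode.shape, mean.shape, std.shape, vvc)

The mode map is the voxel-wise majority vote over the inverted predictions, mean and std are the corresponding per-voxel statistics, and vvc (the volume variation coefficient) is a single scalar, std/mean over the whole stack, summarising how much the augmented predictions disagree.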