diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 3dd0a980ef..5ba4a990af 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -25,12 +25,14 @@ from .grid_dataset import GridPatchDataset, PatchDataset from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader +from .inverse_batch_transform import BatchInverseTransform from .iterable_dataset import IterableDataset from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti from .png_saver import PNGSaver from .png_writer import write_png from .synthetic import create_test_image_2d, create_test_image_3d +from .test_time_augmentation import TestTimeAugmentation from .thread_buffer import ThreadBuffer from .utils import ( DistributedSampler, diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index e458833979..343f5b742f 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -17,12 +17,10 @@ import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern +import monai.data.utils from monai.config import DtypeLike, KeysCollection -from monai.data.utils import correct_nifti_header_if_necessary from monai.utils import ensure_tuple, optional_import -from .utils import is_supported_format - if TYPE_CHECKING: import itk # type: ignore import nibabel as nib @@ -322,7 +320,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: """ suffixes: Sequence[str] = ["nii", "nii.gz"] - return has_nib and is_supported_format(filename, suffixes) + return has_nib and monai.data.is_supported_format(filename, suffixes) def read(self, data: Union[Sequence[str], str], **kwargs): """ @@ -343,7 +341,7 @@ def read(self, data: Union[Sequence[str], str], **kwargs): kwargs_.update(kwargs) for name in filenames: img = nib.load(name, **kwargs_) - img = correct_nifti_header_if_necessary(img) + img = monai.data.utils.correct_nifti_header_if_necessary(img) 
img_.append(img) return img_ if len(filenames) > 1 else img_[0] @@ -453,7 +451,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: if a list of files, verify all the suffixes. """ suffixes: Sequence[str] = ["npz", "npy"] - return is_supported_format(filename, suffixes) + return monai.data.is_supported_format(filename, suffixes) def read(self, data: Union[Sequence[str], str], **kwargs): """ @@ -537,7 +535,7 @@ def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: if a list of files, verify all the suffixes. """ suffixes: Sequence[str] = ["png", "jpg", "jpeg", "bmp"] - return has_pil and is_supported_format(filename, suffixes) + return has_pil and monai.data.is_supported_format(filename, suffixes) def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs): """ diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py new file mode 100644 index 0000000000..1f6c903e36 --- /dev/null +++ b/monai/data/inverse_batch_transform.py @@ -0,0 +1,90 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Callable, Dict, Hashable, Optional, Tuple + +import numpy as np +from torch.utils.data.dataloader import DataLoader as TorchDataLoader + +from monai.data.dataloader import DataLoader +from monai.data.dataset import Dataset +from monai.data.utils import decollate_batch, pad_list_data_collate +from monai.transforms.croppad.array import CenterSpatialCrop +from monai.transforms.inverse_transform import InvertibleTransform +from monai.utils import first +from monai.utils.misc import ensure_tuple + +__all__ = ["BatchInverseTransform"] + + +class _BatchInverseDataset(Dataset): + def __init__( + self, + data: Dict[str, Any], + transform: InvertibleTransform, + keys: Optional[Tuple[Hashable, ...]], + pad_collation_used: bool, + ) -> None: + self.data = decollate_batch(data) + self.invertible_transform = transform + self.keys = ensure_tuple(keys) if keys else None + self.pad_collation_used = pad_collation_used + + def __getitem__(self, index: int) -> Dict[Hashable, np.ndarray]: + data = dict(self.data[index]) + # If pad collation was used, then we need to undo this first + if self.pad_collation_used: + keys = self.keys or [key for key in data.keys() if str(key) + "_transforms" in data.keys()] + for key in keys: + transform_key = str(key) + "_transforms" + transform = data[transform_key][-1] + if transform["class"] == "SpatialPadd": + data[key] = CenterSpatialCrop(transform["orig_size"])(data[key]) + # remove transform + data[transform_key].pop() + + return self.invertible_transform.inverse(data, self.keys) + + +class BatchInverseTransform: + """Perform the inverse of `transform` on a batch: the batch is de-collated, items are inverted one by one through a DataLoader, and the first re-collated batch is returned.""" + + def __init__( + self, transform: InvertibleTransform, loader: TorchDataLoader, collate_fn: Optional[Callable] = None + ) -> None: + """ + Args: + transform: a callable data transform on input data. + loader: data loader used to generate the batch of data. + collate_fn: how to collate data after inverse transformations. Default will use the DataLoader's default + collation method. 
If returning images of different sizes, this will likely create an error (since the + collation will concatenate arrays, requiring them to be the same size). In this case, using + `collate_fn=lambda x: x` might solve the problem. + """ + self.transform = transform + self.batch_size = loader.batch_size + self.num_workers = loader.num_workers + self.collate_fn = collate_fn + self.pad_collation_used = loader.collate_fn == pad_list_data_collate + + def __call__(self, data: Dict[str, Any], keys: Optional[Tuple[Hashable, ...]] = None) -> Any: + + inv_ds = _BatchInverseDataset(data, self.transform, keys, self.pad_collation_used) + inv_loader = DataLoader( + inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn + ) + try: + return first(inv_loader) + except RuntimeError as re: + re_str = str(re) + if "stack expects each tensor to be equal size" in re_str: + re_str += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`." + raise RuntimeError(re_str) diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index 01e701b1a6..155666a168 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -14,9 +14,9 @@ import numpy as np import torch +import monai.data.utils from monai.config import DtypeLike from monai.data.nifti_writer import write_nifti -from monai.data.utils import create_file_basename from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key @@ -104,7 +104,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] if isinstance(data, torch.Tensor): data = data.detach().cpu().numpy() - filename = create_file_basename(self.output_postfix, filename, self.output_dir) + filename = monai.data.utils.create_file_basename(self.output_postfix, filename, self.output_dir) filename = f"{filename}{self.output_ext}" # change data shape to be (channel, h, w, d) while len(data.shape) < 4: diff --git 
a/monai/data/nifti_writer.py b/monai/data/nifti_writer.py index f530482b14..be9bafc765 100644 --- a/monai/data/nifti_writer.py +++ b/monai/data/nifti_writer.py @@ -14,8 +14,8 @@ import numpy as np import torch +import monai.data.utils from monai.config import DtypeLike -from monai.data.utils import compute_shape_offset, to_affine_nd from monai.networks.layers import AffineTransform from monai.utils import GridSampleMode, GridSamplePadMode, optional_import @@ -95,15 +95,15 @@ def write_nifti( sr = min(data.ndim, 3) if affine is None: affine = np.eye(4, dtype=np.float64) - affine = to_affine_nd(sr, affine) + affine = monai.data.utils.to_affine_nd(sr, affine) if target_affine is None: target_affine = affine - target_affine = to_affine_nd(sr, target_affine) + target_affine = monai.data.utils.to_affine_nd(sr, target_affine) if np.allclose(affine, target_affine, atol=1e-3): # no affine changes, save (data, affine) - results_img = nib.Nifti1Image(data.astype(output_dtype), to_affine_nd(3, target_affine)) + results_img = nib.Nifti1Image(data.astype(output_dtype), monai.data.utils.to_affine_nd(3, target_affine)) nib.save(results_img, file_name) return @@ -115,7 +115,7 @@ def write_nifti( data = nib.orientations.apply_orientation(data, ornt_transform) _affine = affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) if np.allclose(_affine, target_affine, atol=1e-3) or not resample: - results_img = nib.Nifti1Image(data.astype(output_dtype), to_affine_nd(3, target_affine)) + results_img = nib.Nifti1Image(data.astype(output_dtype), monai.data.utils.to_affine_nd(3, target_affine)) nib.save(results_img, file_name) return @@ -125,7 +125,7 @@ def write_nifti( ) transform = np.linalg.inv(_affine) @ target_affine if output_spatial_shape is None: - output_spatial_shape, _ = compute_shape_offset(data.shape, _affine, target_affine) + output_spatial_shape, _ = monai.data.utils.compute_shape_offset(data.shape, _affine, target_affine) output_spatial_shape_ = 
list(output_spatial_shape) if output_spatial_shape is not None else [] if data.ndim > 3: # multi channel, resampling each channel while len(output_spatial_shape_) < 3: @@ -151,6 +151,6 @@ def write_nifti( ) data_np = data_torch.squeeze(0).squeeze(0).detach().cpu().numpy() - results_img = nib.Nifti1Image(data_np.astype(output_dtype), to_affine_nd(3, target_affine)) + results_img = nib.Nifti1Image(data_np.astype(output_dtype), monai.data.utils.to_affine_nd(3, target_affine)) nib.save(results_img, file_name) return diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py index 4c4c847824..114524de91 100644 --- a/monai/data/png_saver.py +++ b/monai/data/png_saver.py @@ -14,8 +14,8 @@ import numpy as np import torch +import monai.data.utils from monai.data.png_writer import write_png -from monai.data.utils import create_file_basename from monai.utils import ImageMetaKey as Key from monai.utils import InterpolateMode @@ -90,7 +90,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] if isinstance(data, torch.Tensor): data = data.detach().cpu().numpy() - filename = create_file_basename(self.output_postfix, filename, self.output_dir) + filename = monai.data.utils.create_file_basename(self.output_postfix, filename, self.output_dir) filename = f"{filename}{self.output_ext}" if data.shape[0] == 1: diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py new file mode 100644 index 0000000000..48cbc54843 --- /dev/null +++ b/monai/data/test_time_augmentation.py @@ -0,0 +1,116 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict + +import numpy as np +import torch + +from monai.data.dataloader import DataLoader +from monai.data.dataset import Dataset +from monai.data.inverse_batch_transform import BatchInverseTransform +from monai.data.utils import pad_list_data_collate +from monai.transforms.compose import Compose +from monai.transforms.inverse_transform import InvertibleTransform +from monai.transforms.transform import Randomizable + +__all__ = ["TestTimeAugmentation"] + + +def is_transform_rand(transform): + if not isinstance(transform, Compose): + return isinstance(transform, Randomizable) + # call recursively for each sub-transform + return any(is_transform_rand(t) for t in transform.transforms) + + +class TestTimeAugmentation: + def __init__( + self, + transform: InvertibleTransform, + batch_size, + num_workers, + inferrer_fn, + device, + ) -> None: + self.transform = transform + self.batch_size = batch_size + self.num_workers = num_workers + self.inferrer_fn = inferrer_fn + self.device = device + + # check that the transform has at least one random component + if not is_transform_rand(self.transform): + raise RuntimeError( + type(self).__name__ + + " requires a `Randomizable` transform or a" + + " `Compose` containing at least one `Randomizable` transform." 
+ ) + + def __call__( + self, data: Dict[str, Any], num_examples=10, image_key="image", label_key="label", return_full_data=False + ): + d = dict(data) + + # check num examples is multiple of batch size + if num_examples % self.batch_size != 0: + raise ValueError("num_examples should be multiple of batch size.") + + # generate batch of data of size == batch_size, dataset and dataloader + data_in = [d for _ in range(num_examples)] + ds = Dataset(data_in, self.transform) + dl = DataLoader(ds, self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate) + + label_transform_key = label_key + "_transforms" + + # create inverter + inverter = BatchInverseTransform(self.transform, dl) + + outputs = [] + + for batch_data in dl: + + batch_images = batch_data[image_key].to(self.device) + + # do model forward pass + batch_output = self.inferrer_fn(batch_images) + if isinstance(batch_output, torch.Tensor): + batch_output = batch_output.detach().cpu() + if isinstance(batch_output, np.ndarray): + batch_output = torch.Tensor(batch_output) + + # check binary labels are extracted + if not all(torch.unique(batch_output.int()) == torch.Tensor([0, 1])): + raise RuntimeError( + "Test-time augmentation requires binary channels. If this is " + "not binary segmentation, then you should one-hot your output." 
+ ) + + # create a dictionary containing the inferred batch and their transforms + inferred_dict = {label_key: batch_output, label_transform_key: batch_data[label_transform_key]} + + # do inverse transformation (only for the label key) + inv_batch = inverter(inferred_dict, label_key) + + # append + outputs.append(inv_batch) + + # calculate mean and standard deviation + output = np.concatenate(outputs) + + if return_full_data: + return output + + mode = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=0, arr=output.astype(np.int64)) + mean = np.mean(output, axis=0) + std = np.std(output, axis=0) + vvc = np.std(output) / np.mean(output) + return mode, mean, std, vvc diff --git a/monai/data/utils.py b/monai/data/utils.py index 7717ddf3aa..2693f0107b 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -25,6 +25,7 @@ from torch.utils.data import DistributedSampler as _TorchDistributedSampler from torch.utils.data._utils.collate import default_collate +import monai.transforms.croppad.dictionary from monai.networks.layers.simplelayers import GaussianFilter from monai.utils import ( MAX_SEED, @@ -388,6 +389,71 @@ def pad_list_data_collate( return list_data_collate(batch) +def decollate_batch(data: dict, batch_size: Optional[int] = None): + """De-collate a batch of data (for example, as produced by a `DataLoader`). + + Returns a list of dictionaries. Each dictionary will only contain the data for a single item of the batch. + + Images originally stored as (B,C,H,W,[D]) will be returned as (C,H,W,[D]). Other information, + such as metadata, may have been stored in a list (or a list inside nested dictionaries). In + this case we return the element of the list corresponding to the batch idx. 
+ + Return types aren't guaranteed to be the same as the original, since numpy arrays will have been + converted to torch.Tensor, and tuples/lists may have been converted to lists of tensors. + + For example: + + ``` + batch_data = { + "image": torch.rand((2,1,10,10)), + "image_meta_dict": {"scl_slope": torch.Tensor([0.0, 0.0])} + } + out = decollate_batch(batch_data) + print(len(out)) + >>> 2 + + print(out[0]) + >>> {'image': tensor([[[4.3549e-01...43e-01]]]), 'image_meta_dict': {'scl_slope': 0.0}} + ``` + + Args: + data: data to be de-collated + batch_size: number of items in the batch. If `None` is passed, it is taken from the first dimension of the first `torch.Tensor` value found in `data`. + """ + if not isinstance(data, dict): + raise RuntimeError("Only currently implemented for dictionary data (might be trivial to adapt).") + if batch_size is None: + for v in data.values(): + if isinstance(v, torch.Tensor): + batch_size = v.shape[0] + break + if batch_size is None: + raise RuntimeError("Couldn't determine batch size, please specify as argument.") + + def torch_to_single(d: torch.Tensor): + """If input is a torch.Tensor with only 1 element, return just the element.""" + return d if d.numel() > 1 else d.item() + + def decollate(data: Any, idx: int): + """Recursively de-collate.""" + if isinstance(data, dict): + return {k: decollate(v, idx) for k, v in data.items()} + if isinstance(data, torch.Tensor): + out = data[idx] + return torch_to_single(out) + elif isinstance(data, list): + if len(data) == 0: + return data + if isinstance(data[0], torch.Tensor): + return [torch_to_single(d[idx]) for d in data] + if issequenceiterable(data[0]): + return [decollate(d, idx) for d in data] + return data[idx] + raise TypeError(f"Not sure how to de-collate type: {type(data)}") + + return [{key: decollate(data[key], idx) for key in data.keys()} for idx in range(batch_size)] + + def worker_init_fn(worker_id: int) -> None: """ Callback function for PyTorch DataLoader `worker_init_fn`. 
diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 5578b93077..79c514a48e 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -138,6 +138,7 @@ ThresholdIntensityD, ThresholdIntensityDict, ) +from .inverse_transform import InvertibleTransform, NonRigidTransform from .io.array import LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .post.array import ( diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 2d612ad2e3..e30c467b2e 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -13,10 +13,12 @@ """ import warnings -from typing import Any, Callable, Optional, Sequence, Union +from copy import deepcopy +from typing import Any, Callable, Hashable, Mapping, Optional, Sequence, Tuple, Union import numpy as np +from monai.transforms.inverse_transform import InvertibleTransform # For backwards compatiblity (so this still works: from monai.transforms.compose import MapTransform) from monai.transforms.transform import MapTransform # noqa: F401 from monai.transforms.transform import Randomizable, Transform @@ -26,7 +28,7 @@ __all__ = ["Compose"] -class Compose(Randomizable, Transform): +class Compose(Randomizable, InvertibleTransform): """ ``Compose`` provides the ability to chain a series of calls together in a sequence. 
Each transform in the sequence must take a single argument and @@ -137,3 +139,17 @@ def __call__(self, input_): for _transform in self.transforms: input_ = apply_transform(_transform, input_) return input_ + + def inverse(self, data, keys: Optional[Tuple[Hashable, ...]] = None): + if not isinstance(data, Mapping): + raise RuntimeError("Inverse method only available for dictionary transforms") + d = deepcopy(dict(data)) + if keys: + keys = ensure_tuple(keys) + + # loop backwards over transforms + for t in reversed(self.transforms): + # check if transform is one of the invertible ones + if isinstance(t, InvertibleTransform): + d = t.inverse(d, keys) + return d diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index ef5e0019bd..f779f38e1c 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -18,6 +18,7 @@ import numpy as np import torch +import monai.data.utils from monai.config import IndexSelection from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.transform import Randomizable, Transform @@ -304,8 +305,8 @@ def randomize(self, img_size: Sequence[int]) -> None: if self.random_size: self._size = tuple((self.R.randint(low=self._size[i], high=img_size[i] + 1) for i in range(len(img_size)))) if self.random_center: - valid_size = get_valid_patch_size(img_size, self._size) - self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) + valid_size = monai.data.utils.get_valid_patch_size(img_size, self._size) + self._slices = (slice(None),) + monai.data.utils.get_random_patch(img_size, valid_size, self.R) def __call__(self, img: np.ndarray): """ diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 20ae6ac1ed..5401b3cbd8 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -15,10 +15,13 @@ Class names are ended with 'd' to denote dictionary-based transforms. 
""" +from copy import deepcopy +from math import floor from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np +import monai.data.utils from monai.config import IndexSelection, KeysCollection from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.croppad.array import ( @@ -30,6 +33,7 @@ SpatialCrop, SpatialPad, ) +from monai.transforms.inverse_transform import InvertibleTransform from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.utils import ( generate_pos_neg_label_crop_centers, @@ -82,7 +86,7 @@ NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str] -class SpatialPadd(MapTransform): +class SpatialPadd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.SpatialPad`. Performs padding to the data, symmetric for all sides or all on one side for each dimension. @@ -106,7 +110,7 @@ def __init__( mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} One of the listed string values or a user supplied function. Defaults to ``"constant"``. - See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. 
""" @@ -117,11 +121,34 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in zip(self.keys, self.mode): + self.append_applied_transforms(d, key) d[key] = self.padder(d[key], mode=m) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = transform["orig_size"] + if self.padder.method == Method.SYMMETRIC: + current_size = d[key].shape[1:] + roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)] + else: + roi_center = [floor(r / 2) if r % 2 == 0 else (r - 1) // 2 for r in orig_size] + + inverse_transform = SpatialCrop(roi_center, orig_size) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + -class BorderPadd(MapTransform): +class BorderPadd(MapTransform, InvertibleTransform): """ Pad the input data by adding specified borders to every dimension. Dictionary-based wrapper of :py:class:`monai.transforms.BorderPad`. @@ -162,11 +189,38 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in zip(self.keys, self.mode): + self.append_applied_transforms(d, key) d[key] = self.padder(d[key], mode=m) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = np.array(transform["orig_size"]) + roi_start = np.array(self.padder.spatial_border) + # Need to convert single value to [min1,min2,...] 
+ if roi_start.size == 1: + roi_start = np.full((len(orig_size)), roi_start) + # need to convert [min1,max1,min2,...] to [min1,min2,...] + elif roi_start.size == 2 * orig_size.size: + roi_start = roi_start[::2] + roi_end = np.array(transform["orig_size"]) + roi_start + + inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d -class DivisiblePadd(MapTransform): + +class DivisiblePadd(MapTransform, InvertibleTransform): """ Pad the input data, so that the spatial sizes are divisible by `k`. Dictionary-based wrapper of :py:class:`monai.transforms.DivisiblePad`. @@ -198,11 +252,32 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key, m in zip(self.keys, self.mode): + self.append_applied_transforms(d, key) d[key] = self.padder(d[key], mode=m) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = np.array(transform["orig_size"]) + current_size = np.array(d[key].shape[1:]) + roi_start = np.floor((current_size - orig_size) / 2) + roi_end = orig_size + roi_start + inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + -class SpatialCropd(MapTransform): +class SpatialCropd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.SpatialCrop`. 
Either a spatial center and size must be provided, or alternatively if center and size @@ -232,11 +307,35 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: + self.append_applied_transforms(d, key) d[key] = self.cropper(d[key]) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = transform["orig_size"] + pad_to_start = self.cropper.roi_start + pad_to_end = orig_size - self.cropper.roi_end + # interweave mins and maxes + pad = np.empty((2 * len(orig_size)), dtype=np.int32) + pad[0::2] = pad_to_start + pad[1::2] = pad_to_end + inverse_transform = BorderPad(pad.tolist()) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + -class CenterSpatialCropd(MapTransform): +class CenterSpatialCropd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.CenterSpatialCrop`. 
@@ -254,11 +353,38 @@ def __init__(self, keys: KeysCollection, roi_size: Union[Sequence[int], int]) -> def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: + orig_size = d[key].shape[1:] d[key] = self.cropper(d[key]) + self.append_applied_transforms(d, key, orig_size=orig_size) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = np.array(transform["orig_size"]) + current_size = np.array(d[key].shape[1:]) + pad_to_start = np.floor((orig_size - current_size) / 2) + # in each direction, if original size is even and current size is odd, += 1 + pad_to_start[np.logical_and(orig_size % 2 == 0, current_size % 2 == 1)] += 1 + pad_to_end = orig_size - current_size - pad_to_start + pad = np.empty((2 * len(orig_size)), dtype=np.int32) + pad[0::2] = pad_to_start + pad[1::2] = pad_to_end + inverse_transform = BorderPad(pad.tolist()) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) -class RandSpatialCropd(Randomizable, MapTransform): + return d + + +class RandSpatialCropd(Randomizable, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandSpatialCrop`. Crop image with random size or specific size ROI. 
It can crop at a random position as @@ -283,7 +409,9 @@ def __init__( random_center: bool = True, random_size: bool = True, ) -> None: - super().__init__(keys) + Randomizable.__init__(self, prob=1.0) + MapTransform.__init__(self, keys) + self._do_transform = True self.roi_size = roi_size self.random_center = random_center self.random_size = random_size @@ -295,8 +423,9 @@ def randomize(self, img_size: Sequence[int]) -> None: if self.random_size: self._size = [self.R.randint(low=self._size[i], high=img_size[i] + 1) for i in range(len(img_size))] if self.random_center: - valid_size = get_valid_patch_size(img_size, self._size) - self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) + valid_size = monai.data.utils.get_valid_patch_size(img_size, self._size) + self._slices = (slice(None),) + monai.data.utils.get_random_patch(img_size, valid_size, self.R) + pass def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) @@ -305,12 +434,50 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda raise AssertionError for key in self.keys: if self.random_center: + self.append_applied_transforms(d, key, {"slices": [(i.start, i.stop) for i in self._slices[1:]]}) # type: ignore d[key] = d[key][self._slices] else: + self.append_applied_transforms(d, key) cropper = CenterSpatialCrop(self._size) d[key] = cropper(d[key]) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = transform["orig_size"] + random_center = self.random_center + pad_to_start = np.empty((len(orig_size)), dtype=np.int32) + pad_to_end = np.empty((len(orig_size)), dtype=np.int32) + if random_center: + for i, _slice in enumerate(transform["extra_info"]["slices"]): + 
pad_to_start[i] = _slice[0] + pad_to_end[i] = orig_size[i] - _slice[1] + else: + current_size = d[key].shape[1:] + for i, (o_s, c_s) in enumerate(zip(orig_size, current_size)): + pad_to_start[i] = pad_to_end[i] = (o_s - c_s) / 2 + if o_s % 2 == 0 and c_s % 2 == 1: + pad_to_start[i] += 1 + elif o_s % 2 == 1 and c_s % 2 == 0: + pad_to_end[i] += 1 + # interweave mins and maxes + pad = np.empty((2 * len(orig_size)), dtype=np.int32) + pad[0::2] = pad_to_start + pad[1::2] = pad_to_end + inverse_transform = BorderPad(pad.tolist()) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + class RandSpatialCropSamplesd(Randomizable, MapTransform): """ @@ -364,7 +531,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n return [self.cropper(data) for _ in range(self.num_samples)] -class CropForegroundd(MapTransform): +class CropForegroundd(MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.CropForeground`. Crop only the foreground object of the expected images. 
@@ -416,9 +583,33 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[self.end_coord_key] = np.asarray(box_end) cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) for key in self.keys: + self.append_applied_transforms(d, key, extra_info={"box_start": box_start, "box_end": box_end}) d[key] = cropper(d[key]) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = np.array(transform["orig_size"]) + extra_info = transform["extra_info"] + pad_to_start = np.array(extra_info["box_start"]) + pad_to_end = orig_size - np.array(extra_info["box_end"]) + # interweave mins and maxes + pad = np.empty((2 * len(orig_size)), dtype=np.int32) + pad[0::2] = pad_to_start + pad[1::2] = pad_to_end + inverse_transform = BorderPad(pad.tolist()) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + class RandWeightedCropd(Randomizable, MapTransform): """ @@ -534,7 +725,7 @@ def __init__( fg_indices_key: Optional[str] = None, bg_indices_key: Optional[str] = None, ) -> None: - super().__init__(keys) + MapTransform.__init__(self, keys) self.label_key = label_key self.spatial_size: Union[Tuple[int, ...], Sequence[int], int] = spatial_size if pos < 0 or neg < 0: @@ -592,7 +783,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n return results -class ResizeWithPadOrCropd(MapTransform): +class ResizeWithPadOrCropd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.ResizeWithPadOrCrop`. 
@@ -620,7 +811,25 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: + orig_size = d[key].shape[1:] d[key] = self.padcropper(d[key]) + self.append_applied_transforms(d, key, orig_size=orig_size) + return d + + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = transform["orig_size"] + inverse_transform = ResizeWithPadOrCrop(spatial_size=orig_size, mode=self.padcropper.padder.mode) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + return d diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 40bef064eb..7fc7b23d3c 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -60,10 +60,9 @@ class RandGaussianNoise(Randomizable, Transform): """ def __init__(self, prob: float = 0.1, mean: Union[Sequence[float], float] = 0.0, std: float = 0.1) -> None: - self.prob = prob + Randomizable.__init__(self, prob) self.mean = mean self.std = std - self._do_transform = False self._noise = None def randomize(self, im_shape: Sequence[int]) -> None: @@ -113,6 +112,7 @@ def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1 if single number, offset value is picked from (-offsets, offsets). prob: probability of shift. 
""" + Randomizable.__init__(self, prob) if isinstance(offsets, (int, float)): self.offsets = (min(-offsets, offsets), max(-offsets, offsets)) else: @@ -120,9 +120,6 @@ def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1 raise AssertionError("offsets should be a number or pair of numbers.") self.offsets = (min(offsets), max(offsets)) - self.prob = prob - self._do_transform = False - def randomize(self, data: Optional[Any] = None) -> None: self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) self._do_transform = self.R.random() < self.prob @@ -186,6 +183,7 @@ def __init__(self, factors: Union[Tuple[float, float], float], prob: float = 0.1 prob: probability of scale. """ + Randomizable.__init__(self, prob) if isinstance(factors, (int, float)): self.factors = (min(-factors, factors), max(-factors, factors)) else: @@ -193,9 +191,6 @@ def __init__(self, factors: Union[Tuple[float, float], float], prob: float = 0.1 raise AssertionError("factors should be a number or pair of numbers.") self.factors = (min(factors), max(factors)) - self.prob = prob - self._do_transform = False - def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) self._do_transform = self.R.random() < self.prob @@ -243,7 +238,7 @@ def __init__( self.dtype = dtype def _normalize(self, img: np.ndarray, sub=None, div=None) -> np.ndarray: - slices = (img != 0) if self.nonzero else np.ones(img.shape, dtype=np.bool_) + slices = (img != 0) if self.nonzero else np.ones(img.shape, dtype=bool) if not np.any(slices): return img @@ -383,7 +378,7 @@ class RandAdjustContrast(Randomizable, Transform): """ def __init__(self, prob: float = 0.1, gamma: Union[Sequence[float], float] = (0.5, 4.5)) -> None: - self.prob = prob + Randomizable.__init__(self, prob) if isinstance(gamma, (int, float)): if gamma <= 0.5: @@ -396,7 +391,6 @@ def __init__(self, prob: float = 0.1, gamma: Union[Sequence[float], 
float] = (0. raise AssertionError("gamma should be a number or pair of numbers.") self.gamma = (min(gamma), max(gamma)) - self._do_transform = False self.gamma_value = None def randomize(self, data: Optional[Any] = None) -> None: @@ -679,12 +673,11 @@ def __init__( prob: float = 0.1, approx: str = "erf", ) -> None: + Randomizable.__init__(self, prob) self.sigma_x = sigma_x self.sigma_y = sigma_y self.sigma_z = sigma_z - self.prob = prob self.approx = approx - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: self._do_transform = self.R.random_sample() < self.prob @@ -782,6 +775,7 @@ def __init__( approx: str = "erf", prob: float = 0.1, ) -> None: + Randomizable.__init__(self, prob) self.sigma1_x = sigma1_x self.sigma1_y = sigma1_y self.sigma1_z = sigma1_z @@ -790,8 +784,6 @@ def __init__( self.sigma2_z = sigma2_z self.alpha = alpha self.approx = approx - self.prob = prob - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: self._do_transform = self.R.random_sample() < self.prob @@ -827,6 +819,7 @@ class RandHistogramShift(Randomizable, Transform): """ def __init__(self, num_control_points: Union[Tuple[int, int], int] = 10, prob: float = 0.1) -> None: + Randomizable.__init__(self, prob) if isinstance(num_control_points, int): if num_control_points <= 2: @@ -838,8 +831,6 @@ def __init__(self, num_control_points: Union[Tuple[int, int], int] = 10, prob: f if min(num_control_points) <= 2: raise AssertionError("num_control_points should be greater than or equal to 3") self.num_control_points = (min(num_control_points), max(num_control_points)) - self.prob = prob - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: self._do_transform = self.R.random() < self.prob diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 54a85a57b0..deaf6823d6 100644 --- a/monai/transforms/intensity/dictionary.py +++ 
b/monai/transforms/intensity/dictionary.py @@ -108,11 +108,10 @@ class RandGaussianNoised(Randomizable, MapTransform): def __init__( self, keys: KeysCollection, prob: float = 0.1, mean: Union[Sequence[float], float] = 0.0, std: float = 0.1 ) -> None: - super().__init__(keys) - self.prob = prob + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) self.mean = ensure_tuple_rep(mean, len(self.keys)) self.std = std - self._do_transform = False self._noise: List[np.ndarray] = [] def randomize(self, im_shape: Sequence[int]) -> None: @@ -173,7 +172,8 @@ def __init__(self, keys: KeysCollection, offsets: Union[Tuple[float, float], flo prob: probability of rotating. (Default 0.1, with 10% probability it returns a rotated array.) """ - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) if isinstance(offsets, (int, float)): self.offsets = (min(-offsets, offsets), max(-offsets, offsets)) @@ -182,9 +182,6 @@ def __init__(self, keys: KeysCollection, offsets: Union[Tuple[float, float], flo raise AssertionError("offsets should be a number or pair of numbers.") self.offsets = (min(offsets), max(offsets)) - self.prob = prob - self._do_transform = False - def randomize(self, data: Optional[Any] = None) -> None: self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) self._do_transform = self.R.random() < self.prob @@ -245,7 +242,8 @@ def __init__(self, keys: KeysCollection, factors: Union[Tuple[float, float], flo (Default 0.1, with 10% probability it returns a rotated array.) 
""" - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) if isinstance(factors, (int, float)): self.factors = (min(-factors, factors), max(-factors, factors)) @@ -254,9 +252,6 @@ def __init__(self, keys: KeysCollection, factors: Union[Tuple[float, float], flo raise AssertionError("factors should be a number or pair of numbers.") self.factors = (min(factors), max(factors)) - self.prob = prob - self._do_transform = False - def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) self._do_transform = self.R.random() < self.prob @@ -400,8 +395,8 @@ class RandAdjustContrastd(Randomizable, MapTransform): def __init__( self, keys: KeysCollection, prob: float = 0.1, gamma: Union[Tuple[float, float], float] = (0.5, 4.5) ) -> None: - super().__init__(keys) - self.prob: float = prob + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) if isinstance(gamma, (int, float)): if gamma <= 0.5: @@ -414,7 +409,6 @@ def __init__( raise AssertionError("gamma should be a number or pair of numbers.") self.gamma = (min(gamma), max(gamma)) - self._do_transform = False self.gamma_value: Optional[float] = None def randomize(self, data: Optional[Any] = None) -> None: @@ -554,13 +548,12 @@ def __init__( approx: str = "erf", prob: float = 0.1, ) -> None: - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) self.sigma_x = sigma_x self.sigma_y = sigma_y self.sigma_z = sigma_z self.approx = approx - self.prob = prob - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: self._do_transform = self.R.random_sample() < self.prob @@ -652,7 +645,8 @@ def __init__( approx: str = "erf", prob: float = 0.1, ): - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) self.sigma1_x = sigma1_x self.sigma1_y = sigma1_y self.sigma1_z = sigma1_z @@ -661,8 +655,6 @@ def 
__init__( self.sigma2_z = sigma2_z self.alpha = alpha self.approx = approx - self.prob = prob - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: self._do_transform = self.R.random_sample() < self.prob @@ -706,7 +698,8 @@ class RandHistogramShiftd(Randomizable, MapTransform): def __init__( self, keys: KeysCollection, num_control_points: Union[Tuple[int, int], int] = 10, prob: float = 0.1 ) -> None: - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) if isinstance(num_control_points, int): if num_control_points <= 2: raise AssertionError("num_control_points should be greater than or equal to 3") @@ -717,8 +710,6 @@ def __init__( if min(num_control_points) <= 2: raise AssertionError("num_control_points should be greater than or equal to 3") self.num_control_points = (min(num_control_points), max(num_control_points)) - self.prob = prob - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: self._do_transform = self.R.random() < self.prob diff --git a/monai/transforms/inverse_transform.py b/monai/transforms/inverse_transform.py new file mode 100644 index 0000000000..8b525de4aa --- /dev/null +++ b/monai/transforms/inverse_transform.py @@ -0,0 +1,213 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import warnings +from typing import Dict, Hashable, Optional, Tuple + +import numpy as np +import torch + +from monai.transforms.transform import Randomizable, Transform +from monai.utils import optional_import + +sitk, has_sitk = optional_import("SimpleITK") +vtk, has_vtk = optional_import("vtk") +vtk_numpy_support, _ = optional_import("vtk.util.numpy_support") + +__all__ = ["InvertibleTransform", "NonRigidTransform"] + + +class InvertibleTransform(Transform): + """Classes for invertible transforms. + + This class exists so that an ``inverse`` method can be implemented. This allows, for + example, images to be cropped, rotated, padded, etc., during training and inference, + and afterwards be returned to their original size before saving to file for comparison in + an external viewer. + + When the `__call__` method is called, a serialization of the class is stored. When + the `inverse` method is called, the serialization is then removed. We use last in, + first out for the inverted transforms. + """ + + def append_applied_transforms( + self, + data: dict, + key: Hashable, + extra_info: Optional[dict] = None, + orig_size: Optional[Tuple] = None, + ) -> None: + """Append to list of applied transforms for that key.""" + key_transform = str(key) + "_transforms" + info = { + "class": self.__class__.__name__, + "id": id(self), + "orig_size": orig_size or data[key].shape[1:], + } + if extra_info is not None: + info["extra_info"] = extra_info + # If class is randomizable, store whether the transform was actually performed (based on `prob`) + if isinstance(self, Randomizable): + info["do_transform"] = self._do_transform + # If this is the first, create list + if key_transform not in data: + data[key_transform] = [] + data[key_transform].append(info) + + def check_transforms_match(self, transform: dict) -> None: + # Check transforms are of same type.
+ if transform["id"] != id(self): + raise RuntimeError("Should inverse most recently applied invertible transform first") + + def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: + """Get most recent transform.""" + transform = dict(data[str(key) + "_transforms"][-1]) + self.check_transforms_match(transform) + return transform + + @staticmethod + def remove_most_recent_transform(data: dict, key: Hashable) -> None: + """Remove most recent transform.""" + data[str(key) + "_transforms"].pop() + + def inverse(self, data: dict, keys: Optional[Tuple[Hashable, ...]] = None) -> Dict[Hashable, np.ndarray]: + """ + Inverse of ``__call__``. + + Raises: + NotImplementedError: When the subclass does not override this method. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + +class NonRigidTransform(Transform): + @staticmethod + def _get_disp_to_def_arr(shape, spacing): + def_to_disp = np.mgrid[[slice(0, i) for i in shape]].astype(np.float64) + for idx, i in enumerate(shape): + # shift for origin (in MONAI, center of image) + def_to_disp[idx] -= (i - 1) / 2 + # if supplied, account for spacing (e.g., for control point grids) + if spacing is not None: + def_to_disp[idx] *= spacing[idx] + return def_to_disp + + @staticmethod + def _inv_disp_w_sitk(fwd_disp, num_iters): + fwd_disp_sitk = sitk.GetImageFromArray(fwd_disp, isVector=True) + inv_disp_sitk = sitk.InvertDisplacementField(fwd_disp_sitk, num_iters) + inv_disp = sitk.GetArrayFromImage(inv_disp_sitk) + return inv_disp + + @staticmethod + def _inv_disp_w_vtk(fwd_disp): + orig_shape = fwd_disp.shape + required_num_tensor_components = 3 + # VTK requires 3 tensor components, so if shape was (H, W, 2), make it + # (H, W, 1, 3) (i.e., depth 1 with a 3rd tensor component of 0s) + while fwd_disp.shape[-1] < required_num_tensor_components: + fwd_disp = np.append(fwd_disp, np.zeros(fwd_disp.shape[:-1] + (1,)), axis=-1) + fwd_disp = fwd_disp[..., None, :] + + # 
Create VTKDoubleArray. Shape needs to be (H*W*D, 3) + fwd_disp_flattened = fwd_disp.reshape(-1, required_num_tensor_components) # need to keep this in memory + vtk_data_array = vtk_numpy_support.numpy_to_vtk(fwd_disp_flattened) + + # Generating the vtkImageData + fwd_disp_vtk = vtk.vtkImageData() + fwd_disp_vtk.SetOrigin(0, 0, 0) + fwd_disp_vtk.SetSpacing(1, 1, 1) + fwd_disp_vtk.SetDimensions(*fwd_disp.shape[:-1][::-1]) # VTK spacing opposite order to numpy + fwd_disp_vtk.GetPointData().SetScalars(vtk_data_array) + + if __debug__: + fwd_disp_vtk_np = vtk_numpy_support.vtk_to_numpy(fwd_disp_vtk.GetPointData().GetArray(0)) + assert fwd_disp_vtk_np.size == fwd_disp.size + assert fwd_disp_vtk_np.min() == fwd_disp.min() + assert fwd_disp_vtk_np.max() == fwd_disp.max() + assert fwd_disp_vtk.GetNumberOfScalarComponents() == required_num_tensor_components + + # create b-spline coefficients for the displacement grid + bspline_filter = vtk.vtkImageBSplineCoefficients() + bspline_filter.SetInputData(fwd_disp_vtk) + bspline_filter.Update() + + # use these b-spline coefficients to create a transform + bspline_transform = vtk.vtkBSplineTransform() + bspline_transform.SetCoefficientData(bspline_filter.GetOutput()) + bspline_transform.Update() + + # invert the b-spline transform onto a new grid + grid_maker = vtk.vtkTransformToGrid() + grid_maker.SetInput(bspline_transform.GetInverse()) + grid_maker.SetGridOrigin(fwd_disp_vtk.GetOrigin()) + grid_maker.SetGridSpacing(fwd_disp_vtk.GetSpacing()) + grid_maker.SetGridExtent(fwd_disp_vtk.GetExtent()) + grid_maker.SetGridScalarTypeToFloat() + grid_maker.Update() + + # Get inverse displacement as an image + inv_disp_vtk = grid_maker.GetOutput() + + # Convert back to numpy and reshape + inv_disp = vtk_numpy_support.vtk_to_numpy(inv_disp_vtk.GetPointData().GetArray(0)) + # if there were originally < 3 tensor components, remove the zeros we added at the start + inv_disp = inv_disp[..., : orig_shape[-1]] + # reshape to original + inv_disp = 
inv_disp.reshape(orig_shape) + + return inv_disp + + @staticmethod + def compute_inverse_deformation( + num_spatial_dims, fwd_def_orig, spacing=None, num_iters: int = 100, use_package: str = "vtk" + ): + """Package can be vtk or sitk.""" + if use_package.lower() == "vtk" and not has_vtk: + warnings.warn("Please install VTK to estimate inverse of non-rigid transforms. Data has not been modified") + return None + if use_package.lower() == "sitk" and not has_sitk: + warnings.warn( + "Please install SimpleITK to estimate inverse of non-rigid transforms. Data has not been modified" + ) + return None + + # Convert to numpy if necessary + if isinstance(fwd_def_orig, torch.Tensor): + fwd_def_orig = fwd_def_orig.cpu().numpy() + # Remove any extra dimensions (we'll add them back in at the end) + fwd_def = fwd_def_orig[:num_spatial_dims] + # Def -> disp + def_to_disp = NonRigidTransform._get_disp_to_def_arr(fwd_def.shape[1:], spacing) + fwd_disp = fwd_def - def_to_disp + # move tensor component to end (T,H,W,[D])->(H,W,[D],T) + fwd_disp = np.moveaxis(fwd_disp, 0, -1) + + # If using vtk... + if use_package.lower() == "vtk": + inv_disp = NonRigidTransform._inv_disp_w_vtk(fwd_disp) + # If using sitk... 
+ elif use_package.lower() == "sitk": + inv_disp = NonRigidTransform._inv_disp_w_sitk(fwd_disp, num_iters) + else: + raise RuntimeError("Enter vtk or sitk for inverse calculation") + + # move tensor component back to beginning + inv_disp = np.moveaxis(inv_disp, -1, 0) + # Disp -> def + inv_def = inv_disp + def_to_disp + # Add back in any removed dimensions + ndim_in = fwd_def_orig.shape[0] + ndim_out = inv_def.shape[0] + inv_def = np.concatenate([inv_def, fwd_def_orig[ndim_out:ndim_in]]) + + return inv_def diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index d6dbe56f01..b4c6a22629 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -19,8 +19,8 @@ import numpy as np import torch +import monai.data.utils from monai.config import USE_COMPILED, DtypeLike -from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.transforms.croppad.array import CenterSpatialCrop from monai.transforms.transform import Randomizable, Transform @@ -150,7 +150,7 @@ def __call__( ValueError: When ``pixdim`` is nonpositive. Returns: - data_array (resampled into `self.pixdim`), original pixdim, current pixdim. + data_array (resampled into `self.pixdim`), original affine, current affine. 
""" _dtype = dtype or self.dtype or data_array.dtype @@ -162,24 +162,24 @@ def __call__( affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: - affine_ = to_affine_nd(sr, affine) + affine_ = monai.data.utils.to_affine_nd(sr, affine) out_d = self.pixdim[:sr] if out_d.size < sr: out_d = np.append(out_d, [1.0] * (out_d.size - sr)) if np.any(out_d <= 0): raise ValueError(f"pixdim must be positive, got {out_d}.") # compute output affine, shape and offset - new_affine = zoom_affine(affine_, out_d, diagonal=self.diagonal) - output_shape, offset = compute_shape_offset(data_array.shape[1:], affine_, new_affine) + new_affine = monai.data.utils.zoom_affine(affine_, out_d, diagonal=self.diagonal) + output_shape, offset = monai.data.utils.compute_shape_offset(data_array.shape[1:], affine_, new_affine) new_affine[:sr, -1] = offset[:sr] transform = np.linalg.inv(affine_) @ new_affine # adapt to the actual rank - transform = to_affine_nd(sr, transform) + transform = monai.data.utils.to_affine_nd(sr, transform) # no resampling if it's identity transform if np.allclose(transform, np.diag(np.ones(len(transform))), atol=1e-3): output_data = data_array.copy().astype(np.float32) - new_affine = to_affine_nd(affine, new_affine) + new_affine = monai.data.utils.to_affine_nd(affine, new_affine) return output_data, affine, new_affine # resample @@ -197,7 +197,7 @@ def __call__( spatial_size=output_shape, ) output_data = np.asarray(output_data.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # type: ignore - new_affine = to_affine_nd(affine, new_affine) + new_affine = monai.data.utils.to_affine_nd(affine, new_affine) return output_data, affine, new_affine @@ -263,7 +263,7 @@ def __call__( affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: - affine_ = to_affine_nd(sr, affine) + affine_ = monai.data.utils.to_affine_nd(sr, affine) src = nib.io_orientation(affine_) if self.as_closest_canonical: spatial_ornt = src @@ 
-280,9 +280,9 @@ def __call__( ornt[:, 0] += 1 # skip channel dim ornt = np.concatenate([np.array([[0, 1]]), ornt]) shape = data_array.shape[1:] - data_array = nib.orientations.apply_orientation(data_array, ornt) + data_array = np.ascontiguousarray(nib.orientations.apply_orientation(data_array, ornt)) new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, shape) - new_affine = to_affine_nd(affine, new_affine) + new_affine = monai.data.utils.to_affine_nd(affine, new_affine) return data_array, affine, new_affine @@ -316,7 +316,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: class Resize(Transform): """ - Resize the input image to given spatial size. + Resize the input image to given spatial size (with scaling, not cropping/padding). Implemented using :py:class:`torch.nn.functional.interpolate`. Args: @@ -428,7 +428,8 @@ def __call__( padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, dtype: DtypeLike = None, - ) -> np.ndarray: + return_rotation_matrix: bool = False, + ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: """ Args: img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D]. @@ -445,6 +446,7 @@ def __call__( dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. + return_rotation_matrix: whether or not to return the applied rotation matrix. Raises: ValueError: When ``img`` spatially is not one of [2D, 3D]. 
@@ -481,7 +483,10 @@ def __call__( torch.as_tensor(np.ascontiguousarray(transform).astype(_dtype)), spatial_size=output_shape, ) - return np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + output_np = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + if return_rotation_matrix: + return output_np, transform + return output_np class Zoom(Transform): @@ -589,7 +594,7 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: If axis is negative it counts from the last to the first axis. """ self.k = k - spatial_axes_ = ensure_tuple(spatial_axes) + spatial_axes_: Tuple[int, int] = ensure_tuple(spatial_axes) # type: ignore if len(spatial_axes_) != 2: raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.") self.spatial_axes = spatial_axes_ @@ -619,11 +624,10 @@ def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: Tuple[int, i spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. 
""" - self.prob = min(max(prob, 0.0), 1.0) + Randomizable.__init__(self, min(max(prob, 0.0), 1.0)) self.max_k = max_k self.spatial_axes = spatial_axes - self._do_transform = False self._rand_k = 0 def randomize(self, data: Optional[Any] = None) -> None: @@ -682,6 +686,7 @@ def __init__( align_corners: bool = False, dtype: DtypeLike = np.float64, ) -> None: + Randomizable.__init__(self, prob) self.range_x = ensure_tuple(range_x) if len(self.range_x) == 1: self.range_x = tuple(sorted([-self.range_x[0], self.range_x[0]])) @@ -692,14 +697,12 @@ def __init__( if len(self.range_z) == 1: self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]])) - self.prob = prob self.keep_size = keep_size self.mode: GridSampleMode = GridSampleMode(mode) self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) self.align_corners = align_corners self.dtype = dtype - self._do_transform = False self.x = 0.0 self.y = 0.0 self.z = 0.0 @@ -744,7 +747,7 @@ def __call__( align_corners=self.align_corners if align_corners is None else align_corners, dtype=dtype or self.dtype or img.dtype, ) - return rotator(img) + return np.array(rotator(img)) class RandFlip(Randomizable, Transform): @@ -759,9 +762,8 @@ class RandFlip(Randomizable, Transform): """ def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: - self.prob = prob + Randomizable.__init__(self, min(max(prob, 0.0), 1.0)) self.flipper = Flip(spatial_axis=spatial_axis) - self._do_transform = False def randomize(self, data: Optional[Any] = None) -> None: self._do_transform = self.R.random_sample() < self.prob @@ -816,17 +818,16 @@ def __init__( align_corners: Optional[bool] = None, keep_size: bool = True, ) -> None: + Randomizable.__init__(self, prob) self.min_zoom = ensure_tuple(min_zoom) self.max_zoom = ensure_tuple(max_zoom) if len(self.min_zoom) != len(self.max_zoom): raise AssertionError("min_zoom and max_zoom must have same length.") - self.prob = prob self.mode: 
InterpolateMode = InterpolateMode(mode) self.padding_mode: NumpyPadMode = NumpyPadMode(padding_mode) self.align_corners = align_corners self.keep_size = keep_size - self._do_transform = False self._zoom: Sequence[float] = [1.0] def randomize(self, data: Optional[Any] = None) -> None: @@ -900,6 +901,9 @@ class AffineGrid(Transform): as_tensor_output: whether to output tensor instead of numpy array. defaults to True. device: device to store the output grid data. + affine: If applied, ignore the params (`rotate_params`, etc.) and use the + supplied matrix. Should be square with each side = num of image spatial + dimensions + 1. """ @@ -911,6 +915,7 @@ def __init__( scale_params: Optional[Union[Sequence[float], float]] = None, as_tensor_output: bool = True, device: Optional[torch.device] = None, + affine: Optional[Union[np.ndarray, torch.Tensor]] = None, ) -> None: self.rotate_params = rotate_params self.shear_params = shear_params @@ -920,13 +925,19 @@ def __init__( self.as_tensor_output = as_tensor_output self.device = device + self.affine = affine + def __call__( - self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None - ) -> Union[np.ndarray, torch.Tensor]: + self, + spatial_size: Optional[Sequence[int]] = None, + grid: Optional[Union[np.ndarray, torch.Tensor]] = None, + return_affine: bool = False, + ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: """ Args: spatial_size: output grid size. grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. + return_affine: boolean as to whether to return the generated affine matrix or not. Raises: ValueError: When ``grid=None`` and ``spatial_size=None``. Incompatible values. 
@@ -938,16 +949,20 @@ def __call__( else: raise ValueError("Incompatible values: grid=None and spatial_size=None.") - spatial_dims = len(grid.shape) - 1 - affine = np.eye(spatial_dims + 1) - if self.rotate_params: - affine = affine @ create_rotate(spatial_dims, self.rotate_params) - if self.shear_params: - affine = affine @ create_shear(spatial_dims, self.shear_params) - if self.translate_params: - affine = affine @ create_translate(spatial_dims, self.translate_params) - if self.scale_params: - affine = affine @ create_scale(spatial_dims, self.scale_params) + affine: Union[np.ndarray, torch.Tensor] + if self.affine is None: + spatial_dims = len(grid.shape) - 1 + affine = np.eye(spatial_dims + 1) + if self.rotate_params: + affine = affine @ create_rotate(spatial_dims, self.rotate_params) + if self.shear_params: + affine = affine @ create_shear(spatial_dims, self.shear_params) + if self.translate_params: + affine = affine @ create_translate(spatial_dims, self.translate_params) + if self.scale_params: + affine = affine @ create_scale(spatial_dims, self.scale_params) + else: + affine = self.affine affine = torch.as_tensor(np.ascontiguousarray(affine), device=self.device) grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() @@ -956,9 +971,10 @@ def __call__( grid = (affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:])) if grid is None or not isinstance(grid, torch.Tensor): raise ValueError("Unknown grid.") - if self.as_tensor_output: - return grid - return np.asarray(grid.cpu().numpy()) + output: Union[np.ndarray, torch.Tensor] = grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()) + if return_affine: + return output, affine + return output class RandAffineGrid(Randomizable, Transform): @@ -1028,12 +1044,16 @@ def randomize(self, data: Optional[Any] = None) -> None: self.scale_params = self._get_rand_param(self.scale_range, 1.0) def __call__( - self, spatial_size: 
Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None - ) -> Union[np.ndarray, torch.Tensor]: + self, + spatial_size: Optional[Sequence[int]] = None, + grid: Optional[Union[np.ndarray, torch.Tensor]] = None, + return_affine: bool = False, + ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: """ Args: spatial_size: output grid size. grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. + return_affine: boolean as to whether to return the generated affine matrix or not. Returns: a 2D (3xHxW) or 3D (4xHxWxD) grid. @@ -1047,7 +1067,7 @@ def __call__( as_tensor_output=self.as_tensor_output, device=self.device, ) - return affine_grid(spatial_size, grid) + return affine_grid(spatial_size, grid, return_affine) class RandDeformGrid(Randomizable, Transform): @@ -1273,7 +1293,7 @@ def __call__( See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) - grid = self.affine_grid(spatial_size=sp_size) + grid: torch.Tensor = self.affine_grid(spatial_size=sp_size) # type: ignore return self.resampler( img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode ) @@ -1331,6 +1351,7 @@ def __init__( - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. 
""" + Randomizable.__init__(self, prob) self.rand_affine_grid = RandAffineGrid( rotate_range=rotate_range, @@ -1346,9 +1367,6 @@ def __init__( self.mode: GridSampleMode = GridSampleMode(mode) self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) - self.do_transform = False - self.prob = prob - def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None ) -> "RandAffine": @@ -1357,7 +1375,7 @@ def set_random_state( return self def randomize(self, data: Optional[Any] = None) -> None: - self.do_transform = self.R.rand() < self.prob + self._do_transform = self.R.rand() < self.prob self.rand_affine_grid.randomize() def __call__( @@ -1366,7 +1384,8 @@ def __call__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + return_affine: bool = False, + ) -> Union[np.ndarray, torch.Tensor, Tuple[Union[np.ndarray, torch.Tensor], torch.Tensor]]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1381,17 +1400,26 @@ def __call__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + return_affine: boolean as to whether to return the generated affine matrix or not. 
""" self.randomize() sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) - if self.do_transform: - grid = self.rand_affine_grid(spatial_size=sp_size) + affine = np.eye(len(sp_size) + 1) + if self._do_transform: + out = self.rand_affine_grid(spatial_size=sp_size) + if return_affine: + grid, affine = out + else: + grid = out else: grid = create_grid(spatial_size=sp_size) - return self.resampler( + resampled = self.resampler( img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode ) + if return_affine: + return resampled, affine + return resampled class Rand2DElastic(Randomizable, Transform): @@ -1451,6 +1479,7 @@ def __init__( - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. """ + Randomizable.__init__(self, prob) self.deform_grid = RandDeformGrid( spacing=spacing, magnitude_range=magnitude_range, as_tensor_output=True, device=device ) @@ -1467,8 +1496,6 @@ def __init__( self.spatial_size = spatial_size self.mode: GridSampleMode = GridSampleMode(mode) self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) - self.prob = prob - self.do_transform = False def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None @@ -1479,7 +1506,7 @@ def set_random_state( return self def randomize(self, spatial_size: Sequence[int]) -> None: - self.do_transform = self.R.rand() < self.prob + self._do_transform = self.R.rand() < self.prob self.deform_grid.randomize(spatial_size) self.rand_affine_grid.randomize() @@ -1505,7 +1532,7 @@ def __call__( """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) self.randomize(spatial_size=sp_size) - if self.do_transform: + if self._do_transform: grid = self.deform_grid(spatial_size=sp_size) grid = self.rand_affine_grid(grid=grid) grid = torch.nn.functional.interpolate( # type: ignore @@ -1580,6 +1607,7 
@@ def __init__( - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. """ + Randomizable.__init__(self, prob) self.rand_affine_grid = RandAffineGrid(rotate_range, shear_range, translate_range, scale_range, True, device) self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) @@ -1590,8 +1618,6 @@ def __init__( self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) self.device = device - self.prob = prob - self.do_transform = False self.rand_offset = None self.magnitude = 1.0 self.sigma = 1.0 @@ -1604,8 +1630,8 @@ def set_random_state( return self def randomize(self, grid_size: Sequence[int]) -> None: - self.do_transform = self.R.rand() < self.prob - if self.do_transform: + self._do_transform = self.R.rand() < self.prob + if self._do_transform: self.rand_offset = self.R.uniform(-1.0, 1.0, [3] + list(grid_size)).astype(np.float32) self.magnitude = self.R.uniform(self.magnitude_range[0], self.magnitude_range[1]) self.sigma = self.R.uniform(self.sigma_range[0], self.sigma_range[1]) @@ -1634,7 +1660,7 @@ def __call__( sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) self.randomize(grid_size=sp_size) grid = create_grid(spatial_size=sp_size) - if self.do_transform: + if self._do_transform: if self.rand_offset is None: raise AssertionError grid = torch.as_tensor(np.ascontiguousarray(grid), device=self.device) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 2c66cd5f50..55beb6e58b 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -15,15 +15,19 @@ Class names are ended with 'd' to denote dictionary-based transforms. 
""" +from copy import deepcopy from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch from monai.config import DtypeLike, KeysCollection +from monai.networks.layers import AffineTransform from monai.networks.layers.simplelayers import GaussianFilter -from monai.transforms.croppad.array import CenterSpatialCrop +from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad +from monai.transforms.inverse_transform import InvertibleTransform, NonRigidTransform from monai.transforms.spatial.array import ( + AffineGrid, Flip, Orientation, Rand2DElastic, @@ -45,8 +49,11 @@ ensure_tuple, ensure_tuple_rep, fall_back_tuple, + optional_import, ) +nib, _ = optional_import("nibabel") + __all__ = [ "Spacingd", "Orientationd", @@ -98,7 +105,7 @@ NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str] -class Spacingd(MapTransform): +class Spacingd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Spacing`. 
@@ -177,10 +184,11 @@ def __call__( ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: d: Dict = dict(data) for idx, key in enumerate(self.keys): - meta_data = d[f"{key}_{self.meta_key_postfix}"] + meta_data_key = f"{key}_{self.meta_key_postfix}" + meta_data = d[meta_data_key] # resample array of each corresponding key # using affine fetched from d[affine_key] - d[key], _, new_affine = self.spacing_transform( + d[key], old_affine, new_affine = self.spacing_transform( data_array=np.asarray(d[key]), affine=meta_data["affine"], mode=self.mode[idx], @@ -188,12 +196,46 @@ def __call__( align_corners=self.align_corners[idx], dtype=self.dtype[idx], ) + self.append_applied_transforms( + d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine} + ) # set the 'affine' key meta_data["affine"] = new_affine return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for idx, key in enumerate(keys or self.keys): + transform = self.get_most_recent_transform(d, key) + if self.spacing_transform.diagonal: + raise RuntimeError( + "Spacingd:inverse not yet implemented for diagonal=True. 
" + + "Please raise a github issue if you need this feature" + ) + # Create inverse transform + meta_data = d[transform["extra_info"]["meta_data_key"]] + old_affine = np.array(transform["extra_info"]["old_affine"]) + orig_pixdim = np.sqrt(np.sum(np.square(old_affine), 0))[:-1] + inverse_transform = Spacing(orig_pixdim, diagonal=self.spacing_transform.diagonal) + # Apply inverse + d[key], _, new_affine = inverse_transform( + data_array=np.asarray(d[key]), + affine=meta_data["affine"], + mode=self.mode[idx], + padding_mode=self.padding_mode[idx], + align_corners=self.align_corners[idx], + dtype=self.dtype[idx], + ) + meta_data["affine"] = new_affine + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + -class Orientationd(MapTransform): +class Orientationd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Orientation`. @@ -246,13 +288,40 @@ def __call__( ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: d: Dict = dict(data) for key in self.keys: - meta_data = d[f"{key}_{self.meta_key_postfix}"] - d[key], _, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) + meta_data_key = f"{key}_{self.meta_key_postfix}" + meta_data = d[meta_data_key] + d[key], old_affine, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) + self.append_applied_transforms( + d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine} + ) + d[meta_data_key]["affine"] = new_affine + return d + + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + meta_data = d[transform["extra_info"]["meta_data_key"]] + orig_affine = transform["extra_info"]["old_affine"] + orig_axcodes = nib.orientations.aff2axcodes(orig_affine) + 
inverse_transform = Orientation( + axcodes=orig_axcodes, + as_closest_canonical=self.ornt_transform.as_closest_canonical, + labels=self.ornt_transform.labels, + ) + # Apply inverse + d[key], _, new_affine = inverse_transform(d[key], affine=meta_data["affine"]) meta_data["affine"] = new_affine + # Remove the applied transform + self.remove_most_recent_transform(d, key) + return d -class Rotate90d(MapTransform): +class Rotate90d(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`. """ @@ -270,11 +339,33 @@ def __init__(self, keys: KeysCollection, k: int = 1, spatial_axes: Tuple[int, in def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: + self.append_applied_transforms(d, key) d[key] = self.rotator(d[key]) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in keys or self.keys: + _ = self.get_most_recent_transform(d, key) + # Create inverse transform + spatial_axes = self.rotator.spatial_axes + num_times_rotated = self.rotator.k + num_times_to_rotate = 4 - num_times_rotated + inverse_transform = Rotate90(num_times_to_rotate, spatial_axes) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Apply inverse + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) -class RandRotate90d(Randomizable, MapTransform): + return d + + +class RandRotate90d(Randomizable, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRotate90`. With probability `prob`, input arrays are rotated by 90 degrees @@ -299,13 +390,12 @@ def __init__( spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. 
Default: (0, 1), this is the first two axis in spatial dimensions. """ - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, min(max(prob, 0.0), 1.0)) - self.prob = min(max(prob, 0.0), 1.0) self.max_k = max_k self.spatial_axes = spatial_axes - self._do_transform = False self._rand_k = 0 def randomize(self, data: Optional[Any] = None) -> None: @@ -314,17 +404,39 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np.ndarray]: self.randomize() - if not self._do_transform: - return data + d = dict(data) rotator = Rotate90(self._rand_k, self.spatial_axes) - d = dict(data) for key in self.keys: - d[key] = rotator(d[key]) + if self._do_transform: + d[key] = rotator(d[key]) + self.append_applied_transforms(d, key, extra_info={"rand_k": self._rand_k}) + return d + + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform["do_transform"]: + # Create inverse transform + num_times_rotated = transform["extra_info"]["rand_k"] + num_times_to_rotate = 4 - num_times_rotated + inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Apply inverse + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + return d -class Resized(MapTransform): +class Resized(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Resize`. 
@@ -360,11 +472,30 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for idx, key in enumerate(self.keys): + self.append_applied_transforms(d, key) d[key] = self.resizer(d[key], mode=self.mode[idx], align_corners=self.align_corners[idx]) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for idx, key in enumerate(keys or self.keys): + transform = self.get_most_recent_transform(d, key) + orig_size = transform["orig_size"] + mode = self.mode[idx] + align_corners = self.align_corners[idx] + # Create inverse transform + inverse_transform = Resize(orig_size, mode, align_corners) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + -class RandAffined(Randomizable, MapTransform): +class RandAffined(Randomizable, MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.RandAffine`. """ @@ -420,9 +551,10 @@ def __init__( - :py:class:`monai.transforms.compose.MapTransform` - :py:class:`RandAffineGrid` for the random affine parameters configurations. 
""" - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) self.rand_affine = RandAffine( - prob=prob, + prob=1.0, # because probability handled in this class rotate_range=rotate_range, shear_range=shear_range, translate_range=translate_range, @@ -442,6 +574,7 @@ def set_random_state( return self def randomize(self, data: Optional[Any] = None) -> None: + self._do_transform = self.R.rand() < self.prob self.rand_affine.randomize() def __call__( @@ -451,17 +584,45 @@ def __call__( self.randomize() sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:]) - if self.rand_affine.do_transform: - grid = self.rand_affine.rand_affine_grid(spatial_size=sp_size) + if self._do_transform: + grid, affine = self.rand_affine.rand_affine_grid(spatial_size=sp_size, return_affine=True) else: grid = create_grid(spatial_size=sp_size) + affine = np.eye(len(sp_size) + 1) for idx, key in enumerate(self.keys): + self.append_applied_transforms(d, key, extra_info={"affine": affine}) d[key] = self.rand_affine.resampler(d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx]) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for idx, key in enumerate(keys or self.keys): + transform = self.get_most_recent_transform(d, key) + orig_size = transform["orig_size"] + # Create inverse transform + fwd_affine = transform["extra_info"]["affine"] + inv_affine = np.linalg.inv(fwd_affine) + + affine_grid = AffineGrid(affine=inv_affine) + grid: torch.Tensor = affine_grid(orig_size) # type: ignore + + # Apply inverse transform + out = self.rand_affine.resampler(d[key], grid, self.mode[idx], self.padding_mode[idx]) + + # Convert to numpy + d[key] = out if isinstance(out, np.ndarray) else out.cpu().numpy() -class Rand2DElasticd(Randomizable, MapTransform): + # Remove the applied transform + 
self.remove_most_recent_transform(d, key) + + return d + + +class Rand2DElasticd(Randomizable, MapTransform, InvertibleTransform, NonRigidTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rand2DElastic`. """ @@ -523,11 +684,12 @@ def __init__( - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. """ - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) self.rand_2d_elastic = Rand2DElastic( spacing=spacing, magnitude_range=magnitude_range, - prob=prob, + prob=1.0, # because probability controlled by this class rotate_range=rotate_range, shear_range=shear_range, translate_range=translate_range, @@ -547,8 +709,20 @@ def set_random_state( return self def randomize(self, spatial_size: Sequence[int]) -> None: + self._do_transform = self.R.rand() < self.prob self.rand_2d_elastic.randomize(spatial_size) + @staticmethod + def cpg_to_dvf(cpg, spacing, output_shape): + grid = torch.nn.functional.interpolate( + recompute_scale_factor=True, + input=cpg.unsqueeze(0), + scale_factor=ensure_tuple_rep(spacing, 2), + mode=InterpolateMode.BILINEAR.value, + align_corners=False, + ) + return CenterSpatialCrop(roi_size=output_shape)(grid[0]) + def __call__( self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: @@ -557,28 +731,70 @@ def __call__( sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, data[self.keys[0]].shape[1:]) self.randomize(spatial_size=sp_size) - if self.rand_2d_elastic.do_transform: - grid = self.rand_2d_elastic.deform_grid(spatial_size=sp_size) - grid = self.rand_2d_elastic.rand_affine_grid(grid=grid) - grid = torch.nn.functional.interpolate( # type: ignore - recompute_scale_factor=True, - input=grid.unsqueeze(0), - scale_factor=ensure_tuple_rep(self.rand_2d_elastic.deform_grid.spacing, 2), - mode=InterpolateMode.BICUBIC.value, - 
align_corners=False, - ) - grid = CenterSpatialCrop(roi_size=sp_size)(grid[0]) + if self._do_transform: + cpg = self.rand_2d_elastic.deform_grid(spatial_size=sp_size) + cpg_w_affine, affine = self.rand_2d_elastic.rand_affine_grid(grid=cpg, return_affine=True) + grid = self.cpg_to_dvf(cpg_w_affine, self.rand_2d_elastic.deform_grid.spacing, sp_size) + extra_info: Optional[Dict] = {"cpg": deepcopy(cpg), "affine": deepcopy(affine)} else: grid = create_grid(spatial_size=sp_size) + extra_info = None for idx, key in enumerate(self.keys): + self.append_applied_transforms(d, key, extra_info=extra_info) d[key] = self.rand_2d_elastic.resampler( d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx] ) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + # This variable will be `not None` if vtk or sitk is present + inv_def_no_affine = None + + for idx, key in enumerate(keys or self.keys): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + if transform["do_transform"]: + orig_size = transform["orig_size"] + # Only need to calculate inverse deformation once as it is the same for all keys + if idx == 0: + # If magnitude == 0, then non-rigid component is identity -- so just create blank + if self.rand_2d_elastic.deform_grid.magnitude == (0.0, 0.0): + inv_def_no_affine = create_grid(spatial_size=orig_size) + else: + fwd_cpg_no_affine = transform["extra_info"]["cpg"] + fwd_def_no_affine = self.cpg_to_dvf( + fwd_cpg_no_affine, self.rand_2d_elastic.deform_grid.spacing, orig_size + ) + inv_def_no_affine = self.compute_inverse_deformation(len(orig_size), fwd_def_no_affine) + # if inverse did not succeed (sitk or vtk present), data will not be changed. 
+ if inv_def_no_affine is not None: + fwd_affine = transform["extra_info"]["affine"] + inv_affine = np.linalg.inv(fwd_affine) + inv_def_w_affine_wrong_size = AffineGrid(affine=inv_affine, as_tensor_output=False)( + grid=inv_def_no_affine + ) + # Back to original size + inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine_wrong_size) # type: ignore + # Apply inverse transform + if inv_def_no_affine is not None: + out = self.rand_2d_elastic.resampler( + d[key], inv_def_w_affine, self.mode[idx], self.padding_mode[idx] + ) + d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out + + else: + d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + -class Rand3DElasticd(Randomizable, MapTransform): +class Rand3DElasticd(Randomizable, MapTransform, InvertibleTransform, NonRigidTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rand3DElastic`. """ @@ -641,11 +857,12 @@ def __init__( - :py:class:`RandAffineGrid` for the random affine parameters configurations. - :py:class:`Affine` for the affine transformation parameters configurations. 
""" - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) self.rand_3d_elastic = Rand3DElastic( sigma_range=sigma_range, magnitude_range=magnitude_range, - prob=prob, + prob=1.0, # because probability controlled by this class rotate_range=rotate_range, shear_range=shear_range, translate_range=translate_range, @@ -665,6 +882,7 @@ def set_random_state( return self def randomize(self, grid_size: Sequence[int]) -> None: + self._do_transform = self.R.rand() < self.prob self.rand_3d_elastic.randomize(grid_size) def __call__( @@ -674,23 +892,62 @@ def __call__( sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, data[self.keys[0]].shape[1:]) self.randomize(grid_size=sp_size) - grid = create_grid(spatial_size=sp_size) - if self.rand_3d_elastic.do_transform: + grid_no_affine = create_grid(spatial_size=sp_size) + if self._do_transform: device = self.rand_3d_elastic.device - grid = torch.tensor(grid).to(device) + grid_no_affine = torch.tensor(grid_no_affine).to(device) gaussian = GaussianFilter(spatial_dims=3, sigma=self.rand_3d_elastic.sigma, truncated=3.0).to(device) offset = torch.tensor(self.rand_3d_elastic.rand_offset, device=device).unsqueeze(0) - grid[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude - grid = self.rand_3d_elastic.rand_affine_grid(grid=grid) + grid_no_affine[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude + grid_w_affine, affine = self.rand_3d_elastic.rand_affine_grid(grid=grid_no_affine, return_affine=True) for idx, key in enumerate(self.keys): + self.append_applied_transforms( + d, key, extra_info={"grid_no_affine": grid_no_affine.cpu().numpy(), "affine": affine} + ) d[key] = self.rand_3d_elastic.resampler( - d[key], grid, mode=self.mode[idx], padding_mode=self.padding_mode[idx] + d[key], grid_w_affine, mode=self.mode[idx], padding_mode=self.padding_mode[idx] ) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> 
Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + + for idx, key in enumerate(keys or self.keys): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + if transform["do_transform"]: + orig_size = transform["orig_size"] + # Only need to calculate inverse deformation once as it is the same for all keys + if idx == 0: + fwd_def_no_affine = transform["extra_info"]["grid_no_affine"] + inv_def_no_affine = self.compute_inverse_deformation(len(orig_size), fwd_def_no_affine) + # if inverse did not succeed (sitk or vtk present), data will not be changed. + if inv_def_no_affine is not None: + fwd_affine = transform["extra_info"]["affine"] + inv_affine = np.linalg.inv(fwd_affine) + inv_def_w_affine_wrong_size = AffineGrid(affine=inv_affine, as_tensor_output=False)( + grid=inv_def_no_affine + ) + # Back to original size + inv_def_w_affine = CenterSpatialCrop(roi_size=orig_size)(inv_def_w_affine_wrong_size) # type: ignore + # Apply inverse transform + if inv_def_w_affine is not None: + out = self.rand_3d_elastic.resampler( + d[key], inv_def_w_affine, self.mode[idx], self.padding_mode[idx] + ) + d[key] = out.cpu().numpy() if isinstance(out, torch.Tensor) else out + else: + d[key] = CenterSpatialCrop(roi_size=orig_size)(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + -class Flipd(MapTransform): +class Flipd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Flip`. 
@@ -709,11 +966,28 @@ def __init__(self, keys: KeysCollection, spatial_axis: Optional[Union[Sequence[i def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: + self.append_applied_transforms(d, key) d[key] = self.flipper(d[key]) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in keys or self.keys: + _ = self.get_most_recent_transform(d, key) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Inverse is same as forward + d[key] = self.flipper(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d -class RandFlipd(Randomizable, MapTransform): + +class RandFlipd(Randomizable, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandFlip`. 
@@ -732,11 +1006,10 @@ def __init__( prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int], int]] = None, ) -> None: - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) self.spatial_axis = spatial_axis - self.prob = prob - self._do_transform = False self.flipper = Flip(spatial_axis=spatial_axis) def randomize(self, data: Optional[Any] = None) -> None: @@ -745,14 +1018,33 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: self.randomize() d = dict(data) - if not self._do_transform: - return d for key in self.keys: - d[key] = self.flipper(d[key]) + if self._do_transform: + d[key] = self.flipper(d[key]) + self.append_applied_transforms(d, key) + + return d + + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in keys or self.keys: + transform = self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform["do_transform"]: + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Inverse is same as forward + d[key] = self.flipper(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + return d -class Rotated(MapTransform): +class Rotated(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rotate`. 
@@ -800,17 +1092,49 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for idx, key in enumerate(self.keys): - d[key] = self.rotator( + orig_size = d[key].shape[1:] + d[key], rot_mat = self.rotator( d[key], mode=self.mode[idx], padding_mode=self.padding_mode[idx], align_corners=self.align_corners[idx], dtype=self.dtype[idx], + return_rotation_matrix=True, ) + self.append_applied_transforms(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for idx, key in enumerate(keys or self.keys): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + fwd_rot_mat = transform["extra_info"]["rot_mat"] + inv_rot_mat = np.linalg.inv(fwd_rot_mat) + + xform = AffineTransform( + normalized=False, + mode=self.mode[idx], + padding_mode=self.padding_mode[idx], + align_corners=self.align_corners[idx], + reverse_indexing=True, + ) + dtype = self.dtype[idx] + output = xform( + torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), + torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), + spatial_size=transform["orig_size"], + ) + d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + # Remove the applied transform + self.remove_most_recent_transform(d, key) -class RandRotated(Randomizable, MapTransform): + return d + + +class RandRotated(Randomizable, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRotate` Randomly rotates the input arrays. 
@@ -857,7 +1181,8 @@ def __init__( align_corners: Union[Sequence[bool], bool] = False, dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64, ) -> None: - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) self.range_x = ensure_tuple(range_x) if len(self.range_x) == 1: self.range_x = tuple(sorted([-self.range_x[0], self.range_x[0]])) @@ -868,14 +1193,12 @@ def __init__( if len(self.range_z) == 1: self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]])) - self.prob = prob self.keep_size = keep_size self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - self._do_transform = False self.x = 0.0 self.y = 0.0 self.z = 0.0 @@ -890,23 +1213,60 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.randomize() d = dict(data) if not self._do_transform: + for key in self.keys: + self.append_applied_transforms(d, key, extra_info={"rot_mat": np.eye(4)}) return d + angle: Union[Sequence[float], float] = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z) rotator = Rotate( - angle=self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z), + angle=angle, keep_size=self.keep_size, ) for idx, key in enumerate(self.keys): - d[key] = rotator( + orig_size = d[key].shape[1:] + d[key], rot_mat = rotator( d[key], mode=self.mode[idx], padding_mode=self.padding_mode[idx], align_corners=self.align_corners[idx], dtype=self.dtype[idx], + return_rotation_matrix=True, ) + self.append_applied_transforms(d, key, orig_size=orig_size, extra_info={"rot_mat": rot_mat}) + return d + + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for idx, key in enumerate(keys or self.keys): + transform = 
self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform["do_transform"]: + # Create inverse transform + fwd_rot_mat = transform["extra_info"]["rot_mat"] + inv_rot_mat = np.linalg.inv(fwd_rot_mat) + + xform = AffineTransform( + normalized=False, + mode=self.mode[idx], + padding_mode=self.padding_mode[idx], + align_corners=self.align_corners[idx], + reverse_indexing=True, + ) + dtype = self.dtype[idx] + output = xform( + torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), + torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), + spatial_size=transform["orig_size"], + ) + d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + return d -class Zoomd(MapTransform): +class Zoomd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Zoom`. @@ -948,6 +1308,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for idx, key in enumerate(self.keys): + self.append_applied_transforms(d, key) d[key] = self.zoomer( d[key], mode=self.mode[idx], @@ -956,8 +1317,31 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for idx, key in enumerate(keys or self.keys): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + zoom = np.array(self.zoomer.zoom) + inverse_transform = Zoom(zoom=1 / zoom, keep_size=self.zoomer.keep_size) + # Apply inverse + d[key] = inverse_transform( + d[key], + mode=self.mode[idx], + padding_mode=self.padding_mode[idx], + align_corners=self.align_corners[idx], + ) + # Size might be out by 1 voxel so pad + d[key] = 
SpatialPad(transform["orig_size"])(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) -class RandZoomd(Randomizable, MapTransform): + return d + + +class RandZoomd(Randomizable, MapTransform, InvertibleTransform): """ Dict-based version :py:class:`monai.transforms.RandZoom`. @@ -1000,19 +1384,18 @@ def __init__( align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None, keep_size: bool = True, ) -> None: - super().__init__(keys) + MapTransform.__init__(self, keys) + Randomizable.__init__(self, prob) self.min_zoom = ensure_tuple(min_zoom) self.max_zoom = ensure_tuple(max_zoom) if len(self.min_zoom) != len(self.max_zoom): raise AssertionError("min_zoom and max_zoom must have same length.") - self.prob = prob self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.keep_size = keep_size - self._do_transform = False self._zoom: Sequence[float] = [1.0] def randomize(self, data: Optional[Any] = None) -> None: @@ -1024,6 +1407,8 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.randomize() d = dict(data) if not self._do_transform: + for key in self.keys: + self.append_applied_transforms(d, key, extra_info={"zoom": self._zoom}) return d img_dims = data[self.keys[0]].ndim @@ -1035,6 +1420,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self._zoom = ensure_tuple_rep(self._zoom[0], img_dims - 2) + ensure_tuple(self._zoom[-1]) zoomer = Zoom(self._zoom, keep_size=self.keep_size) for idx, key in enumerate(self.keys): + self.append_applied_transforms(d, key, extra_info={"zoom": self._zoom}) d[key] = zoomer( d[key], mode=self.mode[idx], @@ -1043,6 +1429,29 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d + def inverse( + self, data: Mapping[Hashable, np.ndarray], keys: 
Optional[Tuple[Hashable, ...]] = None + ) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for idx, key in enumerate(keys or self.keys): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + zoom = np.array(transform["extra_info"]["zoom"]) + inverse_transform = Zoom(zoom=1 / zoom, keep_size=self.keep_size) + # Apply inverse + d[key] = inverse_transform( + d[key], + mode=self.mode[idx], + padding_mode=self.padding_mode[idx], + align_corners=self.align_corners[idx], + ) + # Size might be out by 1 voxel so pad + d[key] = SpatialPad(transform["orig_size"])(d[key]) + # Remove the applied transform + self.remove_most_recent_transform(d, key) + + return d + SpacingD = SpacingDict = Spacingd OrientationD = OrientationDict = Orientationd diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index e5841cbe97..4f0b2eca79 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -43,6 +43,10 @@ def __call__(self, img): R: np.random.RandomState = np.random.RandomState() + def __init__(self, prob): + self._do_transform = False + self.prob = prob + def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None ) -> "Randomizable": diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 9a84eb00d9..841570be93 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -375,17 +375,10 @@ def apply_transform(transform: Callable, data, map_items: bool = True): data: an object to be transformed. map_items: whether to apply transform to each item in `data`, if `data` is a list or tuple. Defaults to True. - - Raises: - Exception: When ``transform`` raises an exception. 
- """ - try: - if isinstance(data, (list, tuple)) and map_items: - return [transform(item) for item in data] - return transform(data) - except Exception as e: - raise RuntimeError(f"applying transform {transform}") from e + if isinstance(data, (list, tuple)) and map_items: + return [transform(item) for item in data] + return transform(data) def create_grid( diff --git a/requirements-dev.txt b/requirements-dev.txt index 2a43e63d73..1f7b608f79 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -30,3 +30,4 @@ Sphinx==3.3.0 recommonmark==0.6.0 sphinx-autodoc-typehints==1.11.1 sphinx-rtd-theme==0.5.0 +vtk \ No newline at end of file diff --git a/tests/test_inverse.py b/tests/test_inverse.py new file mode 100644 index 0000000000..1ca0e3e8cf --- /dev/null +++ b/tests/test_inverse.py @@ -0,0 +1,599 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random +import sys +import unittest +from typing import TYPE_CHECKING, List, Tuple + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data import BatchInverseTransform, CacheDataset, DataLoader, create_test_image_2d, create_test_image_3d +from monai.data.utils import decollate_batch, pad_list_data_collate +from monai.networks.nets import UNet +from monai.transforms import ( + AddChannel, + AddChanneld, + BorderPadd, + CenterSpatialCropd, + Compose, + CropForegroundd, + DivisiblePadd, + Flipd, + InvertibleTransform, + LoadImaged, + Orientationd, + Rand2DElasticd, + Rand3DElasticd, + RandAffined, + RandFlipd, + RandRotate90d, + RandRotated, + RandSpatialCropd, + RandZoomd, + Resized, + ResizeWithPadOrCrop, + ResizeWithPadOrCropd, + Rotate90d, + Rotated, + Spacingd, + SpatialCropd, + SpatialPad, + SpatialPadd, + Zoomd, +) +from monai.utils import first, optional_import, set_determinism +from tests.utils import make_nifti_image, make_rand_affine, test_is_quick + +if TYPE_CHECKING: + import matplotlib.pyplot as plt + + has_matplotlib = True + has_vtk = True +else: + plt, has_matplotlib = optional_import("matplotlib.pyplot") + _, has_vtk = optional_import("vtk") + +set_determinism(seed=0) + +AFFINE = make_rand_affine() +AFFINE[0] *= 2 + +IM_1D = AddChannel()(np.arange(0, 10)) +IM_2D_FNAME, SEG_2D_FNAME = [make_nifti_image(i) for i in create_test_image_2d(101, 100)] +IM_3D_FNAME, SEG_3D_FNAME = [make_nifti_image(i, AFFINE) for i in create_test_image_3d(100, 101, 107)] + +KEYS = ["image", "label"] +DATA_1D = {"image": IM_1D, "label": IM_1D, "other": IM_1D} +LOAD_IMS = Compose([LoadImaged(KEYS), AddChanneld(KEYS)]) +DATA_2D = LOAD_IMS({"image": IM_2D_FNAME, "label": SEG_2D_FNAME}) +DATA_3D = LOAD_IMS({"image": IM_3D_FNAME, "label": SEG_3D_FNAME}) + +TESTS: List[Tuple] = [] + +TESTS.append( + ( + "SpatialPadd (x2) 2d", + DATA_2D, + 0.0, + SpatialPadd(KEYS, spatial_size=[111, 113], method="end"), + SpatialPadd(KEYS, 
spatial_size=[118, 117]), + ) +) + +TESTS.append( + ( + "SpatialPadd 3d", + DATA_3D, + 0.0, + SpatialPadd(KEYS, spatial_size=[112, 113, 116]), + ) +) + +TESTS.append( + ( + "RandRotated, prob 0", + DATA_2D, + 0, + RandRotated(KEYS, prob=0), + ) +) + +TESTS.append( + ( + "SpatialCropd 2d", + DATA_2D, + 3e-2, + SpatialCropd(KEYS, [49, 51], [90, 89]), + ) +) + +TESTS.append( + ( + "SpatialCropd 3d", + DATA_3D, + 4e-2, + SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]), + ) +) + +TESTS.append(("RandSpatialCropd 2d", DATA_2D, 5e-2, RandSpatialCropd(KEYS, [96, 93], True, False))) + +TESTS.append(("RandSpatialCropd 3d", DATA_3D, 2e-2, RandSpatialCropd(KEYS, [96, 93, 92], False, False))) + +TESTS.append( + ( + "BorderPadd 2d", + DATA_2D, + 0, + BorderPadd(KEYS, [3, 7, 2, 5]), + ) +) + +TESTS.append( + ( + "BorderPadd 2d", + DATA_2D, + 0, + BorderPadd(KEYS, [3, 7]), + ) +) + +TESTS.append( + ( + "BorderPadd 3d", + DATA_3D, + 0, + BorderPadd(KEYS, [4]), + ) +) + +TESTS.append( + ( + "DivisiblePadd 2d", + DATA_2D, + 0, + DivisiblePadd(KEYS, k=4), + ) +) + +TESTS.append( + ( + "DivisiblePadd 3d", + DATA_3D, + 0, + DivisiblePadd(KEYS, k=[4, 8, 11]), + ) +) + +TESTS.append( + ( + "Flipd 3d", + DATA_3D, + 0, + Flipd(KEYS, [1, 2]), + ) +) + +TESTS.append( + ( + "Flipd 3d", + DATA_3D, + 0, + Flipd(KEYS, [1, 2]), + ) +) + +TESTS.append( + ( + "RandFlipd 3d", + DATA_3D, + 0, + RandFlipd(KEYS, 1, [1, 2]), + ) +) + +TESTS.append( + ( + "Rotated 2d", + DATA_2D, + 8e-2, + Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False), + ) +) + +TESTS.append( + ( + "Rotated 3d", + DATA_3D, + 5e-2, + Rotated(KEYS, [random.uniform(np.pi / 6, np.pi) for _ in range(3)], True), # type: ignore + ) +) + +TESTS.append( + ( + "RandRotated 3d", + DATA_3D, + 5e-2, + RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1), # type: ignore + ) +) + +TESTS.append( + ( + "Orientationd 3d", + DATA_3D, + 0, + # For data loader, output needs to be same size, so input 
must be square/cubic + SpatialPadd(KEYS, max(DATA_3D["image"].shape)), + Orientationd(KEYS, "RAS"), + ) +) + +TESTS.append( + ( + "Rotate90d 2d", + DATA_2D, + 0, + Rotate90d(KEYS), + ) +) + +TESTS.append( + ( + "Rotate90d 3d", + DATA_3D, + 0, + Rotate90d(KEYS, k=2, spatial_axes=(1, 2)), + ) +) + +TESTS.append( + ( + "RandRotate90d 3d", + DATA_3D, + 0, + # For data loader, output needs to be same size, so input must be square/cubic + SpatialPadd(KEYS, max(DATA_3D["image"].shape)), + RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)), + ) +) + +TESTS.append( + ( + "Zoomd 1d", + DATA_1D, + 0, + Zoomd(KEYS, zoom=2, keep_size=False), + ) +) + +TESTS.append( + ( + "Zoomd 2d", + DATA_2D, + 2e-1, + Zoomd(KEYS, zoom=0.9), + ) +) + +TESTS.append( + ( + "Zoomd 3d", + DATA_3D, + 3e-2, + Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False), + ) +) + +TESTS.append(("RandZoom 3d", DATA_3D, 9e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True))) + +TESTS.append( + ( + "CenterSpatialCropd 2d", + DATA_2D, + 0, + CenterSpatialCropd(KEYS, roi_size=95), + ) +) + +TESTS.append( + ( + "CenterSpatialCropd 3d", + DATA_3D, + 0, + CenterSpatialCropd(KEYS, roi_size=[95, 97, 98]), + ) +) + +TESTS.append(("CropForegroundd 2d", DATA_2D, 0, CropForegroundd(KEYS, source_key="label", margin=2))) + +TESTS.append(("CropForegroundd 3d", DATA_3D, 0, CropForegroundd(KEYS, source_key="label"))) + +TESTS.append(("Spacingd 3d", DATA_3D, 3e-2, Spacingd(KEYS, [0.5, 0.7, 0.9], diagonal=False))) + +TESTS.append(("Resized 2d", DATA_2D, 2e-1, Resized(KEYS, [50, 47]))) + +TESTS.append(("Resized 3d", DATA_3D, 5e-2, Resized(KEYS, [201, 150, 78]))) + +TESTS.append(("ResizeWithPadOrCropd 3d", DATA_3D, 1e-2, ResizeWithPadOrCropd(KEYS, [201, 150, 78]))) + +TESTS.append( + ( + "RandAffine 3d", + DATA_3D, + 5e-2, + RandAffined( + KEYS, + [155, 179, 192], + prob=1, + padding_mode="zeros", + rotate_range=[np.pi / 6, -np.pi / 5, np.pi / 7], + shear_range=[(0.5, 0.5)], + translate_range=[10, 5, -4], + 
scale_range=[(0.8, 1.2), (0.9, 1.3)], + ), + ) +) + +if has_vtk: + TESTS.append( + ( + "Rand2DElasticd 2d", + DATA_2D, + 2e-1, + Rand2DElasticd( + KEYS, + spacing=(10.0, 10.0), + magnitude_range=(1, 1), + spatial_size=[155, 192], + prob=1, + padding_mode="zeros", + rotate_range=[(np.pi / 6, np.pi / 6)], + shear_range=[(0.5, 0.5)], + translate_range=[10, 5], + scale_range=[(1.2, 1.2), (1.3, 1.3)], + ), + ) + ) + +if not test_is_quick and has_vtk: + TESTS.append( + ( + "Rand3DElasticd 3d", + DATA_3D, + 1e-1, + Rand3DElasticd( + KEYS, + sigma_range=(3, 5), + magnitude_range=(100, 100), + prob=1, + padding_mode="zeros", + rotate_range=[np.pi / 6, np.pi / 7], + shear_range=[(0.5, 0.5), 0.2], + translate_range=[10, 5, 3], + scale_range=[(0.8, 1.2), (0.9, 1.3)], + ), + ) + ) + +TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] + +TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore + + +# Should fail because uses an array transform (SpatialPad), as opposed to dictionary +TEST_FAIL_0 = (DATA_2D["image"], 0.0, Compose([SpatialPad(spatial_size=[101, 103])])) +TESTS_FAIL = [TEST_FAIL_0] + + +def plot_im(orig, fwd_bck, fwd): + diff_orig_fwd_bck = orig - fwd_bck + ims_to_show = [orig, fwd, fwd_bck, diff_orig_fwd_bck] + titles = ["x", "fx", "f⁻¹fx", "x - f⁻¹fx"] + fig, axes = plt.subplots(1, 4, gridspec_kw={"width_ratios": [i.shape[1] for i in ims_to_show]}) + vmin = min(np.array(i).min() for i in [orig, fwd_bck, fwd]) + vmax = max(np.array(i).max() for i in [orig, fwd_bck, fwd]) + for im, title, ax in zip(ims_to_show, titles, axes): + _vmin, _vmax = (vmin, vmax) if id(im) != id(diff_orig_fwd_bck) else (None, None) + im = np.squeeze(np.array(im)) + while im.ndim > 2: + im = im[..., im.shape[-1] // 2] + im_show = ax.imshow(np.squeeze(im), vmin=_vmin, vmax=_vmax) + ax.set_title(title, fontsize=25) + ax.axis("off") + fig.colorbar(im_show, ax=ax) + plt.show() + + +class TestInverse(unittest.TestCase): + def check_inverse(self, name, keys, 
orig_d, fwd_bck_d, unmodified_d, acceptable_diff): + for key in keys: + orig = orig_d[key] + fwd_bck = fwd_bck_d[key] + if isinstance(fwd_bck, torch.Tensor): + fwd_bck = fwd_bck.cpu().numpy() + unmodified = unmodified_d[key] + if isinstance(orig, np.ndarray): + mean_diff = np.mean(np.abs(orig - fwd_bck)) + unmodded_diff = np.mean(np.abs(orig - ResizeWithPadOrCrop(orig.shape[1:])(unmodified))) + try: + self.assertLessEqual(mean_diff, acceptable_diff) + except AssertionError: + print( + f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}" + ) + if has_matplotlib and orig[0].ndim > 1: + plot_im(orig, fwd_bck, unmodified) + elif orig[0].ndim == 1: + print(orig) + print(fwd_bck) + raise + + @parameterized.expand(TESTS) + def test_inverse(self, _, data, acceptable_diff, *transforms): + name = _ + + forwards = [data.copy()] + + # Apply forwards + for t in transforms: + forwards.append(t(forwards[-1])) + + # Check that error is thrown when inverse are used out of order. 
+ t = SpatialPadd("image", [10, 5]) + with self.assertRaises(RuntimeError): + t.inverse(forwards[-1]) + + # Apply inverses + fwd_bck = forwards[-1].copy() + for i, t in enumerate(reversed(transforms)): + if isinstance(t, InvertibleTransform): + fwd_bck = t.inverse(fwd_bck) + self.check_inverse(name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff) + + @parameterized.expand(TESTS) + def test_w_dataloader(self, _, data, acceptable_diff, *transforms): + name = _ + device = "cpu" + if isinstance(transforms, tuple): + transforms = Compose(transforms) + numel = 4 + test_data = [data for _ in range(numel)] + + ndims = len(data["image"].shape[1:]) + batch_size = 2 + num_workers = 0 + + dataset = CacheDataset(test_data, transforms, progress=False) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) + inv_batch = BatchInverseTransform(transforms, loader) + + model = UNet(ndims, 1, 1, channels=(6, 6), strides=(2, 2)).to(device) + for batch_data in loader: + inputs, _ = ( + batch_data["image"].to(device), + batch_data["label"].to(device), + ) + + fwd_bck_batch = inv_batch(batch_data) + fwd_bck = decollate_batch(fwd_bck_batch) + + for idx, (_test_data, _fwd_bck) in enumerate(zip(test_data, fwd_bck)): + _fwd = transforms(test_data[idx]) + self.check_inverse(name, data.keys(), _test_data, _fwd_bck, _fwd, acceptable_diff) + + if torch.cuda.is_available(): + _ = model(inputs) + + def test_diff_sized_inputs(self): + + key = "image" + test_data = [{key: AddChannel()(create_test_image_2d(100 + i, 101 + i)[0])} for i in range(4)] + + batch_size = 2 + num_workers = 0 + transforms = Compose([SpatialPadd(key, (150, 153))]) + + dataset = CacheDataset(test_data, transform=transforms, progress=False) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + # blank collate function since input are different size + inv_batch = BatchInverseTransform(transforms, loader, collate_fn=lambda x: x) + 
+ for batch_idx, batch_data in enumerate(loader): + fwd = decollate_batch(batch_data) + fwd_bck = inv_batch(batch_data) + + for idx, (_fwd, _fwd_bck) in enumerate(zip(fwd, fwd_bck)): + unmodified = test_data[batch_idx * batch_size + idx] + self.check_inverse("diff_sized_inputs", [key], unmodified, _fwd_bck, _fwd, 0) + + def test_inverse_w_pad_list_data_collate(self): + + test_data = [] + for _ in range(4): + image, label = [AddChannel()(i) for i in create_test_image_2d(100, 101)] + test_data.append({"image": image, "label": label.astype(np.float32)}) + + batch_size = 2 + num_workers = 0 + transforms = Compose([CropForegroundd(KEYS, source_key="label")]) + + dataset = CacheDataset(test_data, transform=transforms, progress=False) + loader = DataLoader( + dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=pad_list_data_collate + ) + # blank collate function since input are different size + inv_batch = BatchInverseTransform(transforms, loader) + + for batch_idx, batch_data in enumerate(loader): + fwd = decollate_batch(batch_data) + fwd_bck = decollate_batch(inv_batch(batch_data)) + + for idx, (_fwd, _fwd_bck) in enumerate(zip(fwd, fwd_bck)): + unmodified = test_data[batch_idx * batch_size + idx] + self.check_inverse("diff_sized_inputs", KEYS, unmodified, _fwd_bck, _fwd, 1e-1) + + @parameterized.expand(TESTS_FAIL) + def test_fail(self, data, _, *transform): + d = transform[0](data) + with self.assertRaises(RuntimeError): + d = transform[0].inverse(d) + + def test_inverse_inferred_seg(self): + + test_data = [] + for _ in range(4): + image, label = create_test_image_2d(100, 101) + test_data.append({"image": image, "label": label.astype(np.float32)}) + + batch_size = 2 + # num workers = 0 for mac + num_workers = 2 if sys.platform != "darwin" else 0 + transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), CenterSpatialCropd(KEYS, (110, 99))]) + num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, 
InvertibleTransform)) + + dataset = CacheDataset(test_data, transform=transforms, progress=False) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + + device = "cuda" if torch.cuda.is_available() else "cpu" + model = UNet( + dimensions=2, + in_channels=1, + out_channels=1, + channels=(2, 4), + strides=(2,), + ).to(device) + + data = first(loader) + labels = data["label"].to(device) + segs = model(labels).detach().cpu() + segs_dict = {"label": segs, "label_transforms": data["label_transforms"]} + segs_dict_decollated = decollate_batch(segs_dict) + + # inverse of individual segmentation + seg_dict = first(segs_dict_decollated) + inv_seg = transforms.inverse(seg_dict, "label")["label"] + self.assertEqual(len(data["label_transforms"]), num_invertible_transforms) + self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms) + self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape) + + # Inverse of batch + batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=lambda x: x) + inv_batch = batch_inverter(segs_dict, "label") + self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py new file mode 100644 index 0000000000..93027e1ac8 --- /dev/null +++ b/tests/test_testtimeaugmentation.py @@ -0,0 +1,134 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial +from typing import TYPE_CHECKING + +import numpy as np +import torch +from torch._C import has_cuda + +from monai.data import CacheDataset, DataLoader, create_test_image_2d +from monai.data.test_time_augmentation import TestTimeAugmentation +from monai.data.utils import pad_list_data_collate +from monai.losses import DiceLoss +from monai.networks.nets import UNet +from monai.transforms import ( + Activations, + AddChanneld, + AsDiscrete, + Compose, + CropForegroundd, + DivisiblePadd, + KeepLargestConnectedComponent, + RandAffined, +) +from monai.transforms.croppad.dictionary import SpatialPadd +from monai.utils import optional_import, set_determinism + +if TYPE_CHECKING: + import tqdm + + has_tqdm = True +else: + tqdm, has_tqdm = optional_import("tqdm") + +trange = partial(tqdm.trange, desc="training") if has_tqdm else range + +set_determinism(seed=0) + + +class TestTestTimeAugmentation(unittest.TestCase): + def test_test_time_augmentation(self): + input_size = (20, 20) + device = "cuda" if has_cuda else "cpu" + num_training_ims = 10 + data = [] + custom_create_test_image_2d = partial( + create_test_image_2d, *input_size, rad_max=7, num_seg_classes=1, num_objs=1 + ) + keys = ["image", "label"] + + for _ in range(num_training_ims): + im, label = custom_create_test_image_2d() + data.append({"image": im, "label": label}) + + transforms = Compose( + [ + AddChanneld(keys), + RandAffined( + keys, + prob=1.0, + spatial_size=(30, 30), + rotate_range=(np.pi / 3, np.pi / 3), + translate_range=(3, 3), + scale_range=((0.8, 1), (0.8, 1)), + padding_mode="zeros", + mode=("bilinear", "nearest"), + as_tensor_output=False, + ), + CropForegroundd(keys, source_key="image"), + DivisiblePadd(keys, 4), + ] + ) + + train_ds = CacheDataset(data, transforms) + # output might be different size, so pad so that they match + train_loader = 
DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate) + + model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device) + loss_function = DiceLoss(sigmoid=True) + optimizer = torch.optim.Adam(model.parameters(), 1e-3) + + num_epochs = 10 + for _ in trange(num_epochs): + epoch_loss = 0 + + for batch_data in train_loader: + inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_function(outputs, labels) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + + epoch_loss /= len(train_loader) + + image, label = custom_create_test_image_2d() + test_data = {"image": image, "label": label} + + post_trans = Compose( + [ + Activations(sigmoid=True), + AsDiscrete(threshold_values=True), + KeepLargestConnectedComponent(applied_labels=1), + ] + ) + + def inferrer_fn(x): + return post_trans(model(x)) + + tt_aug = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=inferrer_fn, device=device) + mean, std = tt_aug(test_data) + self.assertEqual(mean.shape, (1,) + input_size) + self.assertEqual((mean.min(), mean.max()), (0.0, 1.0)) + self.assertEqual(std.shape, (1,) + input_size) + + def test_fail_non_random(self): + transforms = Compose([AddChanneld("im"), SpatialPadd("im", 1)]) + with self.assertRaises(RuntimeError): + TestTimeAugmentation(transforms, None, None, None, None) + + +if __name__ == "__main__": + unittest.main()