diff --git a/docs/source/data.rst b/docs/source/data.rst index 6ed6be9702..8071bb1585 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -183,3 +183,8 @@ DataLoader ThreadBuffer ~~~~~~~~~~~~ .. autoclass:: monai.data.ThreadBuffer + + +BatchInverseTransform +~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: monai.data.BatchInverseTransform diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 2a7647e527..2001ccfc8f 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -26,6 +26,7 @@ from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader +from .inverse_batch_transform import BatchInverseTransform from .iterable_dataset import IterableDataset from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py new file mode 100644 index 0000000000..fbc42c6ce1 --- /dev/null +++ b/monai/data/inverse_batch_transform.py @@ -0,0 +1,84 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, Hashable, Optional, Sequence + +import numpy as np +from torch.utils.data.dataloader import DataLoader as TorchDataLoader + +from monai.data.dataloader import DataLoader +from monai.data.dataset import Dataset +from monai.data.utils import decollate_batch, pad_list_data_collate +from monai.transforms.croppad.batch import PadListDataCollate +from monai.transforms.inverse import InvertibleTransform +from monai.transforms.transform import Transform +from monai.utils import first + +__all__ = ["BatchInverseTransform"] + + +class _BatchInverseDataset(Dataset): + def __init__( + self, + data: Sequence[Any], + transform: InvertibleTransform, + pad_collation_used: bool, + ) -> None: + super().__init__(data, transform) + self.invertible_transform = transform + self.pad_collation_used = pad_collation_used + + def __getitem__(self, index: int) -> Dict[Hashable, np.ndarray]: + data = dict(self.data[index]) + # If pad collation was used, then we need to undo this first + if self.pad_collation_used: + data = PadListDataCollate.inverse(data) + + return self.invertible_transform.inverse(data) + + +def no_collation(x): + return x + + +class BatchInverseTransform(Transform): + """Perform inverse on a batch of data. This is useful if you have inferred a batch of images and want to invert them all.""" + + def __init__( + self, transform: InvertibleTransform, loader: TorchDataLoader, collate_fn: Optional[Callable] = no_collation + ) -> None: + """ + Args: + transform: a callable data transform on input data. + loader: data loader used to generate the batch of data. + collate_fn: how to collate data after inverse transformations. Default won't do any collation, so the output will be a + list of size batch size. + """ + self.transform = transform + self.batch_size = loader.batch_size + self.num_workers = loader.num_workers + self.collate_fn = collate_fn + self.pad_collation_used = loader.collate_fn == pad_list_data_collate + + def __call__(self, data: Dict[str, Any]) -> Any: + + decollated_data = decollate_batch(data) + inv_ds = _BatchInverseDataset(decollated_data, self.transform, self.pad_collation_used) + inv_loader = DataLoader( + inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn + ) + try: + return first(inv_loader) + except RuntimeError as re: + re_str = str(re) + if "equal size" in re_str: + re_str += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`." + raise RuntimeError(re_str) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 8ce4e3bbf3..f548b53f11 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -21,6 +21,7 @@ from parameterized import parameterized from monai.data import CacheDataset, DataLoader, create_test_image_2d, create_test_image_3d +from monai.data.inverse_batch_transform import BatchInverseTransform from monai.data.utils import decollate_batch from monai.networks.nets import UNet from monai.transforms import ( @@ -407,6 +408,10 @@ TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore +def no_collation(x): + return x + + class TestInverse(unittest.TestCase): """Test inverse methods. @@ -573,6 +578,12 @@ def test_inverse_inferred_seg(self): self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms) self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape) + # Inverse of batch + batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation) + with allow_missing_keys_mode(transforms): + inv_batch = batch_inverter(segs_dict) + self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape) + if __name__ == "__main__": unittest.main()