Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6dfd7b6
batch inverse
rijobro Mar 4, 2021
c4ff072
add batch_inverse tests
rijobro Feb 26, 2021
81c9aa8
Merge remote-tracking branch 'MONAI/master' into batch_inverse
rijobro Mar 17, 2021
0682344
update
rijobro Mar 17, 2021
b7aec08
autofix
rijobro Mar 17, 2021
de51114
Merge remote-tracking branch 'MONAI/master' into batch_inverse
rijobro Mar 17, 2021
77ff3cd
Merge remote-tracking branch 'MONAI/master' into batch_inverse
rijobro Mar 18, 2021
b78ee22
cant pickle lambda
rijobro Mar 18, 2021
dac89ba
No ID check for windows
rijobro Mar 18, 2021
2d1ca33
win32 change
rijobro Mar 18, 2021
0a8d321
autofix
rijobro Mar 18, 2021
385b9ce
code format
rijobro Mar 18, 2021
a03ce5f
more changes
rijobro Mar 18, 2021
ba6ecbb
changes
rijobro Mar 18, 2021
a7cf63e
Merge remote-tracking branch 'MONAI/master' into batch_inverse
rijobro Mar 18, 2021
f63ae66
Merge branch 'master' into batch_inverse
wyli Mar 19, 2021
e69f25a
create PadListDataCollate transform
rijobro Mar 19, 2021
79c7097
rst
rijobro Mar 19, 2021
4aa837f
change default collation of batch inversion
rijobro Mar 19, 2021
3975a86
Merge remote-tracking branch 'MONAI/master' into batch_inverse
rijobro Mar 19, 2021
3f8c630
Merge remote-tracking branch 'MONAI/master' into batch_inverse
rijobro Mar 19, 2021
d1bb7df
Merge branch 'master' into batch_inverse
rijobro Mar 20, 2021
a1c9952
Merge branch 'master' into batch_inverse
rijobro Mar 20, 2021
924901b
update docs
rijobro Mar 20, 2021
56d57c8
Merge branch 'batch_inverse' of https://github.com/rijobro/MONAI into…
rijobro Mar 20, 2021
708bf39
Merge branch 'batch_inverse' into rijobro/batch_inverse
rijobro Mar 20, 2021
6e3ca06
update docs
rijobro Mar 20, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,8 @@ DataLoader
ThreadBuffer
~~~~~~~~~~~~
.. autoclass:: monai.data.ThreadBuffer


BatchInverseTransform
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: monai.data.BatchInverseTransform
1 change: 1 addition & 0 deletions monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter
from .image_dataset import ImageDataset
from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader
from .inverse_batch_transform import BatchInverseTransform
from .iterable_dataset import IterableDataset
from .nifti_saver import NiftiSaver
from .nifti_writer import write_nifti
Expand Down
84 changes: 84 additions & 0 deletions monai/data/inverse_batch_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, Hashable, Optional, Sequence

import numpy as np
from torch.utils.data.dataloader import DataLoader as TorchDataLoader

from monai.data.dataloader import DataLoader
from monai.data.dataset import Dataset
from monai.data.utils import decollate_batch, pad_list_data_collate
from monai.transforms.croppad.batch import PadListDataCollate
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import Transform
from monai.utils import first

__all__ = ["BatchInverseTransform"]


class _BatchInverseDataset(Dataset):
def __init__(
self,
data: Sequence[Any],
transform: InvertibleTransform,
pad_collation_used: bool,
) -> None:
super().__init__(data, transform)
self.invertible_transform = transform
self.pad_collation_used = pad_collation_used

def __getitem__(self, index: int) -> Dict[Hashable, np.ndarray]:
data = dict(self.data[index])
# If pad collation was used, then we need to undo this first
if self.pad_collation_used:
data = PadListDataCollate.inverse(data)

return self.invertible_transform.inverse(data)


def no_collation(x):
return x


class BatchInverseTransform(Transform):
"""Perform inverse on a batch of data. This is useful if you have inferred a batch of images and want to invert them all."""

def __init__(
self, transform: InvertibleTransform, loader: TorchDataLoader, collate_fn: Optional[Callable] = no_collation
) -> None:
"""
Args:
transform: a callable data transform on input data.
loader: data loader used to generate the batch of data.
collate_fn: how to collate data after inverse transformations. Default won't do any collation, so the output will be a
list of size batch size.
"""
self.transform = transform
self.batch_size = loader.batch_size
self.num_workers = loader.num_workers
self.collate_fn = collate_fn
self.pad_collation_used = loader.collate_fn == pad_list_data_collate

def __call__(self, data: Dict[str, Any]) -> Any:

decollated_data = decollate_batch(data)
inv_ds = _BatchInverseDataset(decollated_data, self.transform, self.pad_collation_used)
inv_loader = DataLoader(
inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn
)
try:
return first(inv_loader)
except RuntimeError as re:
re_str = str(re)
if "equal size" in re_str:
re_str += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`."
raise RuntimeError(re_str)
11 changes: 11 additions & 0 deletions tests/test_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from parameterized import parameterized

from monai.data import CacheDataset, DataLoader, create_test_image_2d, create_test_image_3d
from monai.data.inverse_batch_transform import BatchInverseTransform
from monai.data.utils import decollate_batch
from monai.networks.nets import UNet
from monai.transforms import (
Expand Down Expand Up @@ -407,6 +408,10 @@
TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore


def no_collation(x):
return x


class TestInverse(unittest.TestCase):
"""Test inverse methods.

Expand Down Expand Up @@ -573,6 +578,12 @@ def test_inverse_inferred_seg(self):
self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms)
self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)

# Inverse of batch
batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation)
with allow_missing_keys_mode(transforms):
inv_batch = batch_inverter(segs_dict)
self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape)


if __name__ == "__main__":
unittest.main()