From 6dfd7b6da54e6fa129cd5baf5aede111ff4fa111 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 4 Mar 2021 15:49:48 +0000 Subject: [PATCH 01/16] batch inverse Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/__init__.py | 1 + monai/data/inverse_batch_transform.py | 90 +++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) create mode 100644 monai/data/inverse_batch_transform.py diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 3dd0a980ef..679f88c2ab 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -25,6 +25,7 @@ from .grid_dataset import GridPatchDataset, PatchDataset from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader +from .inverse_batch_transform import BatchInverseTransform from .iterable_dataset import IterableDataset from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py new file mode 100644 index 0000000000..1f6c903e36 --- /dev/null +++ b/monai/data/inverse_batch_transform.py @@ -0,0 +1,90 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Callable, Dict, Hashable, Optional, Tuple + +import numpy as np +from torch.utils.data.dataloader import DataLoader as TorchDataLoader + +from monai.data.dataloader import DataLoader +from monai.data.dataset import Dataset +from monai.data.utils import decollate_batch, pad_list_data_collate +from monai.transforms.croppad.array import CenterSpatialCrop +from monai.transforms.inverse_transform import InvertibleTransform +from monai.utils import first +from monai.utils.misc import ensure_tuple + +__all__ = ["BatchInverseTransform"] + + +class _BatchInverseDataset(Dataset): + def __init__( + self, + data: Dict[str, Any], + transform: InvertibleTransform, + keys: Optional[Tuple[Hashable, ...]], + pad_collation_used: bool, + ) -> None: + self.data = decollate_batch(data) + self.invertible_transform = transform + self.keys = ensure_tuple(keys) if keys else None + self.pad_collation_used = pad_collation_used + + def __getitem__(self, index: int) -> Dict[Hashable, np.ndarray]: + data = dict(self.data[index]) + # If pad collation was used, then we need to undo this first + if self.pad_collation_used: + keys = self.keys or [key for key in data.keys() if str(key) + "_transforms" in data.keys()] + for key in keys: + transform_key = str(key) + "_transforms" + transform = data[transform_key][-1] + if transform["class"] == "SpatialPadd": + data[key] = CenterSpatialCrop(transform["orig_size"])(data[key]) + # remove transform + data[transform_key].pop() + + return self.invertible_transform.inverse(data, self.keys) + + +class BatchInverseTransform: + """something""" + + def __init__( + self, transform: InvertibleTransform, loader: TorchDataLoader, collate_fn: Optional[Callable] = None + ) -> None: + """ + Args: + transform: a callable data transform on input data. + loader: data loader used to generate the batch of data. + collate_fn: how to collate data after inverse transformations. Default will use the DataLoader's default + collation method. 
If returning images of different sizes, this will likely create an error (since the + collation will concatenate arrays, requiring them to be the same size). In this case, using + `collate_fn=lambda x: x` might solve the problem. + """ + self.transform = transform + self.batch_size = loader.batch_size + self.num_workers = loader.num_workers + self.collate_fn = collate_fn + self.pad_collation_used = loader.collate_fn == pad_list_data_collate + + def __call__(self, data: Dict[str, Any], keys: Optional[Tuple[Hashable, ...]] = None) -> Any: + + inv_ds = _BatchInverseDataset(data, self.transform, keys, self.pad_collation_used) + inv_loader = DataLoader( + inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn + ) + try: + return first(inv_loader) + except RuntimeError as re: + re_str = str(re) + if "stack expects each tensor to be equal size" in re_str: + re_str += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`." + raise RuntimeError(re_str) From c4ff0722e9e0ecd363b74a6fbfbf131293953a1e Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Fri, 26 Feb 2021 18:01:57 +0000 Subject: [PATCH 02/16] add batch_inverse tests Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 599 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 599 insertions(+) create mode 100644 tests/test_inverse.py diff --git a/tests/test_inverse.py b/tests/test_inverse.py new file mode 100644 index 0000000000..1ca0e3e8cf --- /dev/null +++ b/tests/test_inverse.py @@ -0,0 +1,599 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import sys +import unittest +from typing import TYPE_CHECKING, List, Tuple + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data import BatchInverseTransform, CacheDataset, DataLoader, create_test_image_2d, create_test_image_3d +from monai.data.utils import decollate_batch, pad_list_data_collate +from monai.networks.nets import UNet +from monai.transforms import ( + AddChannel, + AddChanneld, + BorderPadd, + CenterSpatialCropd, + Compose, + CropForegroundd, + DivisiblePadd, + Flipd, + InvertibleTransform, + LoadImaged, + Orientationd, + Rand2DElasticd, + Rand3DElasticd, + RandAffined, + RandFlipd, + RandRotate90d, + RandRotated, + RandSpatialCropd, + RandZoomd, + Resized, + ResizeWithPadOrCrop, + ResizeWithPadOrCropd, + Rotate90d, + Rotated, + Spacingd, + SpatialCropd, + SpatialPad, + SpatialPadd, + Zoomd, +) +from monai.utils import first, optional_import, set_determinism +from tests.utils import make_nifti_image, make_rand_affine, test_is_quick + +if TYPE_CHECKING: + import matplotlib.pyplot as plt + + has_matplotlib = True + has_vtk = True +else: + plt, has_matplotlib = optional_import("matplotlib.pyplot") + _, has_vtk = optional_import("vtk") + +set_determinism(seed=0) + +AFFINE = make_rand_affine() +AFFINE[0] *= 2 + +IM_1D = AddChannel()(np.arange(0, 10)) +IM_2D_FNAME, SEG_2D_FNAME = [make_nifti_image(i) for i in create_test_image_2d(101, 100)] +IM_3D_FNAME, SEG_3D_FNAME = [make_nifti_image(i, AFFINE) for i in create_test_image_3d(100, 101, 107)] + +KEYS = ["image", "label"] +DATA_1D = {"image": 
IM_1D, "label": IM_1D, "other": IM_1D} +LOAD_IMS = Compose([LoadImaged(KEYS), AddChanneld(KEYS)]) +DATA_2D = LOAD_IMS({"image": IM_2D_FNAME, "label": SEG_2D_FNAME}) +DATA_3D = LOAD_IMS({"image": IM_3D_FNAME, "label": SEG_3D_FNAME}) + +TESTS: List[Tuple] = [] + +TESTS.append( + ( + "SpatialPadd (x2) 2d", + DATA_2D, + 0.0, + SpatialPadd(KEYS, spatial_size=[111, 113], method="end"), + SpatialPadd(KEYS, spatial_size=[118, 117]), + ) +) + +TESTS.append( + ( + "SpatialPadd 3d", + DATA_3D, + 0.0, + SpatialPadd(KEYS, spatial_size=[112, 113, 116]), + ) +) + +TESTS.append( + ( + "RandRotated, prob 0", + DATA_2D, + 0, + RandRotated(KEYS, prob=0), + ) +) + +TESTS.append( + ( + "SpatialCropd 2d", + DATA_2D, + 3e-2, + SpatialCropd(KEYS, [49, 51], [90, 89]), + ) +) + +TESTS.append( + ( + "SpatialCropd 3d", + DATA_3D, + 4e-2, + SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]), + ) +) + +TESTS.append(("RandSpatialCropd 2d", DATA_2D, 5e-2, RandSpatialCropd(KEYS, [96, 93], True, False))) + +TESTS.append(("RandSpatialCropd 3d", DATA_3D, 2e-2, RandSpatialCropd(KEYS, [96, 93, 92], False, False))) + +TESTS.append( + ( + "BorderPadd 2d", + DATA_2D, + 0, + BorderPadd(KEYS, [3, 7, 2, 5]), + ) +) + +TESTS.append( + ( + "BorderPadd 2d", + DATA_2D, + 0, + BorderPadd(KEYS, [3, 7]), + ) +) + +TESTS.append( + ( + "BorderPadd 3d", + DATA_3D, + 0, + BorderPadd(KEYS, [4]), + ) +) + +TESTS.append( + ( + "DivisiblePadd 2d", + DATA_2D, + 0, + DivisiblePadd(KEYS, k=4), + ) +) + +TESTS.append( + ( + "DivisiblePadd 3d", + DATA_3D, + 0, + DivisiblePadd(KEYS, k=[4, 8, 11]), + ) +) + +TESTS.append( + ( + "Flipd 3d", + DATA_3D, + 0, + Flipd(KEYS, [1, 2]), + ) +) + +TESTS.append( + ( + "Flipd 3d", + DATA_3D, + 0, + Flipd(KEYS, [1, 2]), + ) +) + +TESTS.append( + ( + "RandFlipd 3d", + DATA_3D, + 0, + RandFlipd(KEYS, 1, [1, 2]), + ) +) + +TESTS.append( + ( + "Rotated 2d", + DATA_2D, + 8e-2, + Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False), + ) +) + +TESTS.append( + ( + "Rotated 
3d", + DATA_3D, + 5e-2, + Rotated(KEYS, [random.uniform(np.pi / 6, np.pi) for _ in range(3)], True), # type: ignore + ) +) + +TESTS.append( + ( + "RandRotated 3d", + DATA_3D, + 5e-2, + RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1), # type: ignore + ) +) + +TESTS.append( + ( + "Orientationd 3d", + DATA_3D, + 0, + # For data loader, output needs to be same size, so input must be square/cubic + SpatialPadd(KEYS, max(DATA_3D["image"].shape)), + Orientationd(KEYS, "RAS"), + ) +) + +TESTS.append( + ( + "Rotate90d 2d", + DATA_2D, + 0, + Rotate90d(KEYS), + ) +) + +TESTS.append( + ( + "Rotate90d 3d", + DATA_3D, + 0, + Rotate90d(KEYS, k=2, spatial_axes=(1, 2)), + ) +) + +TESTS.append( + ( + "RandRotate90d 3d", + DATA_3D, + 0, + # For data loader, output needs to be same size, so input must be square/cubic + SpatialPadd(KEYS, max(DATA_3D["image"].shape)), + RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)), + ) +) + +TESTS.append( + ( + "Zoomd 1d", + DATA_1D, + 0, + Zoomd(KEYS, zoom=2, keep_size=False), + ) +) + +TESTS.append( + ( + "Zoomd 2d", + DATA_2D, + 2e-1, + Zoomd(KEYS, zoom=0.9), + ) +) + +TESTS.append( + ( + "Zoomd 3d", + DATA_3D, + 3e-2, + Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False), + ) +) + +TESTS.append(("RandZoom 3d", DATA_3D, 9e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True))) + +TESTS.append( + ( + "CenterSpatialCropd 2d", + DATA_2D, + 0, + CenterSpatialCropd(KEYS, roi_size=95), + ) +) + +TESTS.append( + ( + "CenterSpatialCropd 3d", + DATA_3D, + 0, + CenterSpatialCropd(KEYS, roi_size=[95, 97, 98]), + ) +) + +TESTS.append(("CropForegroundd 2d", DATA_2D, 0, CropForegroundd(KEYS, source_key="label", margin=2))) + +TESTS.append(("CropForegroundd 3d", DATA_3D, 0, CropForegroundd(KEYS, source_key="label"))) + +TESTS.append(("Spacingd 3d", DATA_3D, 3e-2, Spacingd(KEYS, [0.5, 0.7, 0.9], diagonal=False))) + +TESTS.append(("Resized 2d", DATA_2D, 2e-1, Resized(KEYS, [50, 47]))) + +TESTS.append(("Resized 3d", 
DATA_3D, 5e-2, Resized(KEYS, [201, 150, 78]))) + +TESTS.append(("ResizeWithPadOrCropd 3d", DATA_3D, 1e-2, ResizeWithPadOrCropd(KEYS, [201, 150, 78]))) + +TESTS.append( + ( + "RandAffine 3d", + DATA_3D, + 5e-2, + RandAffined( + KEYS, + [155, 179, 192], + prob=1, + padding_mode="zeros", + rotate_range=[np.pi / 6, -np.pi / 5, np.pi / 7], + shear_range=[(0.5, 0.5)], + translate_range=[10, 5, -4], + scale_range=[(0.8, 1.2), (0.9, 1.3)], + ), + ) +) + +if has_vtk: + TESTS.append( + ( + "Rand2DElasticd 2d", + DATA_2D, + 2e-1, + Rand2DElasticd( + KEYS, + spacing=(10.0, 10.0), + magnitude_range=(1, 1), + spatial_size=[155, 192], + prob=1, + padding_mode="zeros", + rotate_range=[(np.pi / 6, np.pi / 6)], + shear_range=[(0.5, 0.5)], + translate_range=[10, 5], + scale_range=[(1.2, 1.2), (1.3, 1.3)], + ), + ) + ) + +if not test_is_quick and has_vtk: + TESTS.append( + ( + "Rand3DElasticd 3d", + DATA_3D, + 1e-1, + Rand3DElasticd( + KEYS, + sigma_range=(3, 5), + magnitude_range=(100, 100), + prob=1, + padding_mode="zeros", + rotate_range=[np.pi / 6, np.pi / 7], + shear_range=[(0.5, 0.5), 0.2], + translate_range=[10, 5, 3], + scale_range=[(0.8, 1.2), (0.9, 1.3)], + ), + ) + ) + +TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] + +TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore + + +# Should fail because uses an array transform (SpatialPad), as opposed to dictionary +TEST_FAIL_0 = (DATA_2D["image"], 0.0, Compose([SpatialPad(spatial_size=[101, 103])])) +TESTS_FAIL = [TEST_FAIL_0] + + +def plot_im(orig, fwd_bck, fwd): + diff_orig_fwd_bck = orig - fwd_bck + ims_to_show = [orig, fwd, fwd_bck, diff_orig_fwd_bck] + titles = ["x", "fx", "f⁻¹fx", "x - f⁻¹fx"] + fig, axes = plt.subplots(1, 4, gridspec_kw={"width_ratios": [i.shape[1] for i in ims_to_show]}) + vmin = min(np.array(i).min() for i in [orig, fwd_bck, fwd]) + vmax = max(np.array(i).max() for i in [orig, fwd_bck, fwd]) + for im, title, ax in zip(ims_to_show, titles, axes): + _vmin, _vmax = 
(vmin, vmax) if id(im) != id(diff_orig_fwd_bck) else (None, None) + im = np.squeeze(np.array(im)) + while im.ndim > 2: + im = im[..., im.shape[-1] // 2] + im_show = ax.imshow(np.squeeze(im), vmin=_vmin, vmax=_vmax) + ax.set_title(title, fontsize=25) + ax.axis("off") + fig.colorbar(im_show, ax=ax) + plt.show() + + +class TestInverse(unittest.TestCase): + def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff): + for key in keys: + orig = orig_d[key] + fwd_bck = fwd_bck_d[key] + if isinstance(fwd_bck, torch.Tensor): + fwd_bck = fwd_bck.cpu().numpy() + unmodified = unmodified_d[key] + if isinstance(orig, np.ndarray): + mean_diff = np.mean(np.abs(orig - fwd_bck)) + unmodded_diff = np.mean(np.abs(orig - ResizeWithPadOrCrop(orig.shape[1:])(unmodified))) + try: + self.assertLessEqual(mean_diff, acceptable_diff) + except AssertionError: + print( + f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}" + ) + if has_matplotlib and orig[0].ndim > 1: + plot_im(orig, fwd_bck, unmodified) + elif orig[0].ndim == 1: + print(orig) + print(fwd_bck) + raise + + @parameterized.expand(TESTS) + def test_inverse(self, _, data, acceptable_diff, *transforms): + name = _ + + forwards = [data.copy()] + + # Apply forwards + for t in transforms: + forwards.append(t(forwards[-1])) + + # Check that error is thrown when inverse are used out of order. 
+ t = SpatialPadd("image", [10, 5]) + with self.assertRaises(RuntimeError): + t.inverse(forwards[-1]) + + # Apply inverses + fwd_bck = forwards[-1].copy() + for i, t in enumerate(reversed(transforms)): + if isinstance(t, InvertibleTransform): + fwd_bck = t.inverse(fwd_bck) + self.check_inverse(name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff) + + @parameterized.expand(TESTS) + def test_w_dataloader(self, _, data, acceptable_diff, *transforms): + name = _ + device = "cpu" + if isinstance(transforms, tuple): + transforms = Compose(transforms) + numel = 4 + test_data = [data for _ in range(numel)] + + ndims = len(data["image"].shape[1:]) + batch_size = 2 + num_workers = 0 + + dataset = CacheDataset(test_data, transforms, progress=False) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) + inv_batch = BatchInverseTransform(transforms, loader) + + model = UNet(ndims, 1, 1, channels=(6, 6), strides=(2, 2)).to(device) + for batch_data in loader: + inputs, _ = ( + batch_data["image"].to(device), + batch_data["label"].to(device), + ) + + fwd_bck_batch = inv_batch(batch_data) + fwd_bck = decollate_batch(fwd_bck_batch) + + for idx, (_test_data, _fwd_bck) in enumerate(zip(test_data, fwd_bck)): + _fwd = transforms(test_data[idx]) + self.check_inverse(name, data.keys(), _test_data, _fwd_bck, _fwd, acceptable_diff) + + if torch.cuda.is_available(): + _ = model(inputs) + + def test_diff_sized_inputs(self): + + key = "image" + test_data = [{key: AddChannel()(create_test_image_2d(100 + i, 101 + i)[0])} for i in range(4)] + + batch_size = 2 + num_workers = 0 + transforms = Compose([SpatialPadd(key, (150, 153))]) + + dataset = CacheDataset(test_data, transform=transforms, progress=False) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + # blank collate function since input are different size + inv_batch = BatchInverseTransform(transforms, loader, collate_fn=lambda x: x) + 
+ for batch_idx, batch_data in enumerate(loader): + fwd = decollate_batch(batch_data) + fwd_bck = inv_batch(batch_data) + + for idx, (_fwd, _fwd_bck) in enumerate(zip(fwd, fwd_bck)): + unmodified = test_data[batch_idx * batch_size + idx] + self.check_inverse("diff_sized_inputs", [key], unmodified, _fwd_bck, _fwd, 0) + + def test_inverse_w_pad_list_data_collate(self): + + test_data = [] + for _ in range(4): + image, label = [AddChannel()(i) for i in create_test_image_2d(100, 101)] + test_data.append({"image": image, "label": label.astype(np.float32)}) + + batch_size = 2 + num_workers = 0 + transforms = Compose([CropForegroundd(KEYS, source_key="label")]) + + dataset = CacheDataset(test_data, transform=transforms, progress=False) + loader = DataLoader( + dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=pad_list_data_collate + ) + # blank collate function since input are different size + inv_batch = BatchInverseTransform(transforms, loader) + + for batch_idx, batch_data in enumerate(loader): + fwd = decollate_batch(batch_data) + fwd_bck = decollate_batch(inv_batch(batch_data)) + + for idx, (_fwd, _fwd_bck) in enumerate(zip(fwd, fwd_bck)): + unmodified = test_data[batch_idx * batch_size + idx] + self.check_inverse("diff_sized_inputs", KEYS, unmodified, _fwd_bck, _fwd, 1e-1) + + @parameterized.expand(TESTS_FAIL) + def test_fail(self, data, _, *transform): + d = transform[0](data) + with self.assertRaises(RuntimeError): + d = transform[0].inverse(d) + + def test_inverse_inferred_seg(self): + + test_data = [] + for _ in range(4): + image, label = create_test_image_2d(100, 101) + test_data.append({"image": image, "label": label.astype(np.float32)}) + + batch_size = 2 + # num workers = 0 for mac + num_workers = 2 if sys.platform != "darwin" else 0 + transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), CenterSpatialCropd(KEYS, (110, 99))]) + num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, 
InvertibleTransform)) + + dataset = CacheDataset(test_data, transform=transforms, progress=False) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + + device = "cuda" if torch.cuda.is_available() else "cpu" + model = UNet( + dimensions=2, + in_channels=1, + out_channels=1, + channels=(2, 4), + strides=(2,), + ).to(device) + + data = first(loader) + labels = data["label"].to(device) + segs = model(labels).detach().cpu() + segs_dict = {"label": segs, "label_transforms": data["label_transforms"]} + segs_dict_decollated = decollate_batch(segs_dict) + + # inverse of individual segmentation + seg_dict = first(segs_dict_decollated) + inv_seg = transforms.inverse(seg_dict, "label")["label"] + self.assertEqual(len(data["label_transforms"]), num_invertible_transforms) + self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms) + self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape) + + # Inverse of batch + batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=lambda x: x) + inv_batch = batch_inverter(segs_dict, "label") + self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape) + + +if __name__ == "__main__": + unittest.main() From 06823448ed3a982cff874cc6bc0c6afe7194ea3c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 17 Mar 2021 17:07:32 +0000 Subject: [PATCH 03/16] update Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/inverse_batch_transform.py | 16 ++++++---------- tests/test_inverse.py | 4 +++- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index 1f6c903e36..c973597786 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the 
License. -from typing import Any, Callable, Dict, Hashable, Optional, Tuple +from typing import Any, Callable, Dict, Hashable, Optional import numpy as np from torch.utils.data.dataloader import DataLoader as TorchDataLoader @@ -18,9 +18,8 @@ from monai.data.dataset import Dataset from monai.data.utils import decollate_batch, pad_list_data_collate from monai.transforms.croppad.array import CenterSpatialCrop -from monai.transforms.inverse_transform import InvertibleTransform +from monai.transforms.inverse import InvertibleTransform from monai.utils import first -from monai.utils.misc import ensure_tuple __all__ = ["BatchInverseTransform"] @@ -30,20 +29,17 @@ def __init__( self, data: Dict[str, Any], transform: InvertibleTransform, - keys: Optional[Tuple[Hashable, ...]], pad_collation_used: bool, ) -> None: self.data = decollate_batch(data) self.invertible_transform = transform - self.keys = ensure_tuple(keys) if keys else None self.pad_collation_used = pad_collation_used def __getitem__(self, index: int) -> Dict[Hashable, np.ndarray]: data = dict(self.data[index]) # If pad collation was used, then we need to undo this first if self.pad_collation_used: - keys = self.keys or [key for key in data.keys() if str(key) + "_transforms" in data.keys()] - for key in keys: + for key in data.keys(): transform_key = str(key) + "_transforms" transform = data[transform_key][-1] if transform["class"] == "SpatialPadd": @@ -51,7 +47,7 @@ def __getitem__(self, index: int) -> Dict[Hashable, np.ndarray]: # remove transform data[transform_key].pop() - return self.invertible_transform.inverse(data, self.keys) + return self.invertible_transform.inverse(data) class BatchInverseTransform: @@ -75,9 +71,9 @@ def __init__( self.collate_fn = collate_fn self.pad_collation_used = loader.collate_fn == pad_list_data_collate - def __call__(self, data: Dict[str, Any], keys: Optional[Tuple[Hashable, ...]] = None) -> Any: + def __call__(self, data: Dict[str, Any]) -> Any: - inv_ds = 
_BatchInverseDataset(data, self.transform, keys, self.pad_collation_used) + inv_ds = _BatchInverseDataset(data, self.transform, self.pad_collation_used) inv_loader = DataLoader( inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn ) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 432090f334..d2ce57dedc 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from monai.data.inverse_batch_transform import BatchInverseTransform import random import sys import unittest @@ -531,7 +532,8 @@ def test_inverse_inferred_seg(self): # Inverse of batch batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=lambda x: x) - inv_batch = batch_inverter(segs_dict, "label") + with allow_missing_keys_mode(transforms): + inv_batch = batch_inverter(segs_dict) self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape) From b7aec08d7f8650303d61fdf11882288a163f554e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 17 Mar 2021 17:22:01 +0000 Subject: [PATCH 04/16] autofix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index d2ce57dedc..a05304a52e 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from monai.data.inverse_batch_transform import BatchInverseTransform import random import sys import unittest @@ -21,6 +20,7 @@ from parameterized import parameterized from monai.data import CacheDataset, DataLoader, create_test_image_2d, create_test_image_3d +from monai.data.inverse_batch_transform import BatchInverseTransform from monai.data.utils import decollate_batch from monai.networks.nets import UNet from monai.transforms import ( From b78ee22529cdc84d9e5f50bb05b35d70f415cd0d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 18 Mar 2021 09:54:41 +0000 Subject: [PATCH 05/16] cant pickle lambda Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 467c3e2f79..6f2380e268 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -407,6 +407,10 @@ TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore +def no_collation(x): + return x + + class TestInverse(unittest.TestCase): """Test inverse methods. 
@@ -567,7 +571,7 @@ def test_inverse_inferred_seg(self): self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape) # Inverse of batch - batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=lambda x: x) + batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation) with allow_missing_keys_mode(transforms): inv_batch = batch_inverter(segs_dict) self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape) From dac89baff8cc14038543c1f94cb8514d755b958a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 18 Mar 2021 11:12:39 +0000 Subject: [PATCH 06/16] No ID check for windows Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/inverse.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index f9de8746ca..322a0db45a 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import sys from typing import Dict, Hashable, Optional, Tuple import numpy as np @@ -89,8 +90,12 @@ def push_transform( def check_transforms_match(self, transform: dict) -> None: """Check transforms are of same instance.""" - if transform[InverseKeys.ID.value] != id(self): - raise RuntimeError("Should inverse most recently applied invertible transform first") + if transform[InverseKeys.ID.value] == id(self): + return + # basic check if windows because of multiprocessing differences (objects get recreated so don't have same ID) + if sys.platform == "win32" and transform[InverseKeys.CLASS_NAME.value] == self.__class__.__name__: + return + raise RuntimeError("Should inverse most recently applied invertible transform first") def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: """Get most recent transform.""" From 2d1ca33ca07323dad366c8f3192501cda43298d3 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 18 Mar 2021 12:17:03 +0000 Subject: [PATCH 07/16] win32 change Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 6f2380e268..6ae92bc802 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from monai.transforms.croppad.array import RandSpatialCropSamples import random import sys import unittest @@ -519,6 +520,9 @@ def test_inverse(self, _, data_name, acceptable_diff, *transforms): # Check that error is thrown when inverse are used out of order. 
t = SpatialPadd("image", [10, 5]) + # on windows we only check the name, so in this case we need to use a different transform + if sys.platform == "win32" and isinstance(t[-1], SpatialPadd): + t = ResizeWithPadOrCropd("image", [10, 5]) with self.assertRaises(RuntimeError): t.inverse(forwards[-1]) From 0a8d32123695b58f1bb8bc362b0abd61bc920180 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 18 Mar 2021 12:54:57 +0000 Subject: [PATCH 08/16] autofix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 6ae92bc802..e47c8c05cf 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from monai.transforms.croppad.array import RandSpatialCropSamples import random import sys import unittest @@ -55,6 +54,7 @@ Zoomd, allow_missing_keys_mode, ) +from monai.transforms.croppad.array import RandSpatialCropSamples from monai.utils import first, get_seed, optional_import, set_determinism from monai.utils.enums import InverseKeys from tests.utils import make_nifti_image, make_rand_affine From 385b9ceda35e92569f1c01b610238b215827bdb2 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 18 Mar 2021 12:59:58 +0000 Subject: [PATCH 09/16] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index e47c8c05cf..ad95405a45 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -54,7 +54,6 @@ Zoomd, allow_missing_keys_mode, ) -from monai.transforms.croppad.array import RandSpatialCropSamples from monai.utils import first, get_seed, optional_import, 
set_determinism from monai.utils.enums import InverseKeys from tests.utils import make_nifti_image, make_rand_affine From a03ce5f5532cc86598e5d82e2f50e236c43afbf7 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 18 Mar 2021 13:52:25 +0000 Subject: [PATCH 10/16] more changes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index ad95405a45..1303595d38 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -520,7 +520,7 @@ def test_inverse(self, _, data_name, acceptable_diff, *transforms): # Check that error is thrown when inverse are used out of order. t = SpatialPadd("image", [10, 5]) # on windows we only check the name, so in this case we need to use a different transform - if sys.platform == "win32" and isinstance(t[-1], SpatialPadd): + if sys.platform == "win32" and isinstance(transforms[-1], SpatialPadd): t = ResizeWithPadOrCropd("image", [10, 5]) with self.assertRaises(RuntimeError): t.inverse(forwards[-1]) From ba6ecbb72bc7bb675eaaaf964f66fd05510dc3e4 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 18 Mar 2021 15:38:42 +0000 Subject: [PATCH 11/16] changes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/inverse.py | 9 +++++---- tests/test_inverse.py | 13 ++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 322a0db45a..ed6834bc08 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -9,10 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import sys from typing import Dict, Hashable, Optional, Tuple import numpy as np +import torch from monai.transforms.transform import RandomizableTransform, Transform from monai.utils.enums import InverseKeys @@ -92,9 +92,10 @@ def check_transforms_match(self, transform: dict) -> None: """Check transforms are of same instance.""" if transform[InverseKeys.ID.value] == id(self): return - # basic check if windows because of multiprocessing differences (objects get recreated so don't have same ID) - if sys.platform == "win32" and transform[InverseKeys.CLASS_NAME.value] == self.__class__.__name__: - return + # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID) + if torch.multiprocessing.get_start_method(allow_none=True) == "spawn": + if transform[InverseKeys.CLASS_NAME.value] == self.__class__.__name__: + return raise RuntimeError("Should inverse most recently applied invertible transform first") def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 1303595d38..e33113e13b 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -517,13 +517,12 @@ def test_inverse(self, _, data_name, acceptable_diff, *transforms): t.set_random_state(seed=get_seed()) forwards.append(t(forwards[-1])) - # Check that error is thrown when inverse are used out of order. - t = SpatialPadd("image", [10, 5]) - # on windows we only check the name, so in this case we need to use a different transform - if sys.platform == "win32" and isinstance(transforms[-1], SpatialPadd): - t = ResizeWithPadOrCropd("image", [10, 5]) - with self.assertRaises(RuntimeError): - t.inverse(forwards[-1]) + # skip this test if multiprocessing uses 'spawn', as the check is only basic anyway + if torch.multiprocessing.get_start_method(allow_none=True) == "spawn": + # Check that error is thrown when inverse are used out of order. 
+ t = SpatialPadd("image", [10, 5]) + with self.assertRaises(RuntimeError): + t.inverse(forwards[-1]) # Apply inverses fwd_bck = forwards[-1].copy() From e69f25aa86d8e8814c938401cfc9e6314114eb26 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 19 Mar 2021 18:29:52 +0000 Subject: [PATCH 12/16] create PadListDataCollate transform Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/inverse_batch_transform.py | 23 ++--- monai/data/utils.py | 59 ++---------- monai/transforms/__init__.py | 1 + monai/transforms/croppad/batch.py | 129 ++++++++++++++++++++++++++ monai/transforms/post/dictionary.py | 1 + tests/test_decollate.py | 35 ++++--- tests/test_pad_collation.py | 33 ++++--- 7 files changed, 190 insertions(+), 91 deletions(-) create mode 100644 monai/transforms/croppad/batch.py diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index c973597786..09f3fc1f3e 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, Hashable, Optional +from typing import Any, Callable, Dict, Hashable, Optional, Sequence import numpy as np from torch.utils.data.dataloader import DataLoader as TorchDataLoader @@ -17,7 +17,7 @@ from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset from monai.data.utils import decollate_batch, pad_list_data_collate -from monai.transforms.croppad.array import CenterSpatialCrop +from monai.transforms.croppad.batch import PadListDataCollate from monai.transforms.inverse import InvertibleTransform from monai.utils import first @@ -27,11 +27,11 @@ class _BatchInverseDataset(Dataset): def __init__( self, - data: Dict[str, Any], + data: Sequence[Any], transform: InvertibleTransform, pad_collation_used: bool, ) -> None: - self.data = decollate_batch(data) + super().__init__(data, transform) self.invertible_transform = transform self.pad_collation_used = pad_collation_used @@ -39,19 +39,13 @@ def __getitem__(self, index: int) -> Dict[Hashable, np.ndarray]: data = dict(self.data[index]) # If pad collation was used, then we need to undo this first if self.pad_collation_used: - for key in data.keys(): - transform_key = str(key) + "_transforms" - transform = data[transform_key][-1] - if transform["class"] == "SpatialPadd": - data[key] = CenterSpatialCrop(transform["orig_size"])(data[key]) - # remove transform - data[transform_key].pop() + data = PadListDataCollate.inverse(data) return self.invertible_transform.inverse(data) class BatchInverseTransform: - """something""" + """Perform inverse on a batch of data. 
This is useful if you have inferred a batch of images and want to invert them all.""" def __init__( self, transform: InvertibleTransform, loader: TorchDataLoader, collate_fn: Optional[Callable] = None @@ -73,7 +67,8 @@ def __init__( def __call__(self, data: Dict[str, Any]) -> Any: - inv_ds = _BatchInverseDataset(data, self.transform, self.pad_collation_used) + decollated_data = decollate_batch(data) + inv_ds = _BatchInverseDataset(decollated_data, self.transform, self.pad_collation_used) inv_loader = DataLoader( inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn ) @@ -81,6 +76,6 @@ def __call__(self, data: Dict[str, Any]) -> Any: return first(inv_loader) except RuntimeError as re: re_str = str(re) - if "stack expects each tensor to be equal size" in re_str: + if "equal size" in re_str: re_str += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`." raise RuntimeError(re_str) diff --git a/monai/data/utils.py b/monai/data/utils.py index ae0180f4b5..bdbfa5c636 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -337,64 +337,25 @@ def pad_list_data_collate( mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, ): """ + Function version of :py:class:`monai.transforms.croppad.batch.PadListDataCollate`. + Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest - tensor in each dimension. + tensor in each dimension. This transform is useful if some of the applied transforms generate batch data of + different sizes. - Note: - Need to use this collate if apply some transforms that can generate batch data. + This can be used on both list and dictionary data. In the case of the dictionary data, this transform will be added + to the list of invertible transforms. + + The inverse can be called using the static method: `monai.transforms.croppad.batch.PadListDataCollate.inverse`. 
Args: batch: batch of data to pad-collate method: padding method (see :py:class:`monai.transforms.SpatialPad`) mode: padding mode (see :py:class:`monai.transforms.SpatialPad`) """ - list_of_dicts = isinstance(batch[0], dict) - for key_or_idx in batch[0].keys() if list_of_dicts else range(len(batch[0])): - max_shapes = [] - for elem in batch: - if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)): - break - max_shapes.append(elem[key_or_idx].shape[1:]) - # len > 0 if objects were arrays - if len(max_shapes) == 0: - continue - max_shape = np.array(max_shapes).max(axis=0) - # If all same size, skip - if np.all(np.array(max_shapes).min(axis=0) == max_shape): - continue - # Do we need to convert output to Tensor? - output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor) - - # Use `SpatialPadd` or `SpatialPad` to match sizes - # Default params are central padding, padding with 0's - # If input is dictionary, use the dictionary version so that the transformation is recorded - padder: Union[SpatialPadd, SpatialPad] - if list_of_dicts: - from monai.transforms.croppad.dictionary import SpatialPadd # needs to be here to avoid circular import + from monai.transforms.croppad.batch import PadListDataCollate # needs to be here to avoid circular import - padder = SpatialPadd(key_or_idx, max_shape, method, mode) # type: ignore - - else: - from monai.transforms.croppad.array import SpatialPad # needs to be here to avoid circular import - - padder = SpatialPad(max_shape, method, mode) # type: ignore - - for idx in range(len(batch)): - padded = padder(batch[idx])[key_or_idx] if list_of_dicts else padder(batch[idx][key_or_idx]) - # since tuple is immutable we'll have to recreate - if isinstance(batch[idx], tuple): - batch[idx] = list(batch[idx]) # type: ignore - batch[idx][key_or_idx] = padded - batch[idx] = tuple(batch[idx]) # type: ignore - # else, replace - else: - batch[idx][key_or_idx] = padder(batch[idx])[key_or_idx] - - if output_to_tensor: - 
batch[idx][key_or_idx] = torch.Tensor(batch[idx][key_or_idx]) - - # After padding, use default list collator - return list_data_collate(batch) + return PadListDataCollate(method, mode)(batch) def worker_init_fn(worker_id: int) -> None: diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 5b12da4d21..9866fe1b6a 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -25,6 +25,7 @@ SpatialCrop, SpatialPad, ) +from .croppad.batch import PadListDataCollate from .croppad.dictionary import ( BorderPadd, BorderPadD, diff --git a/monai/transforms/croppad/batch.py b/monai/transforms/croppad/batch.py new file mode 100644 index 0000000000..7cbf39597c --- /dev/null +++ b/monai/transforms/croppad/batch.py @@ -0,0 +1,129 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +A collection of "vanilla" transforms for crop and pad operations acting on batches of data +https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design +""" + +from copy import deepcopy +from typing import Any, Dict, Hashable, Union + +import numpy as np +import torch + +from monai.data.utils import list_data_collate +from monai.transforms.compose import Compose +from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad +from monai.transforms.inverse import InvertibleTransform +from monai.transforms.utility.array import ToTensor +from monai.utils.enums import InverseKeys, Method, NumpyPadMode + +__all__ = [ + "PadListDataCollate", +] + + +def replace_element(to_replace, batch, idx, key_or_idx): + # since tuple is immutable we'll have to recreate + if isinstance(batch[idx], tuple): + batch_idx_list = list(batch[idx]) + batch_idx_list[key_or_idx] = to_replace + batch[idx] = tuple(batch_idx_list) + # else, replace + else: + batch[idx][key_or_idx] = to_replace + return batch + + +class PadListDataCollate(InvertibleTransform): + """ + Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest + tensor in each dimension. This transform is useful if some of the applied transforms generate batch data of + different sizes. + + This can be used on both list and dictionary data. In the case of the dictionary data, this transform will be added + to the list of invertible transforms. + + Note that normally, a user won't explicitly use the `__call__` method. Rather this would be passed to the `DataLoader`. + This means that `__call__` handles data as it comes out of a `DataLoader`, containing batch dimension. However, the + `inverse` operates on dictionaries containing images of shape `C,H,W,[D]`. This asymmetry is necessary so that we can + pass the inverse through multiprocessing. 
+ + Args: + batch: batch of data to pad-collate + method: padding method (see :py:class:`monai.transforms.SpatialPad`) + mode: padding mode (see :py:class:`monai.transforms.SpatialPad`) + """ + + def __init__( + self, + method: Union[Method, str] = Method.SYMMETRIC, + mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + ) -> None: + self.method = method + self.mode = mode + + def __call__(self, batch: Any): + # data is either list of dicts or list of lists + is_list_of_dicts = isinstance(batch[0], dict) + # loop over items inside of each element in a batch + for key_or_idx in batch[0].keys() if is_list_of_dicts else range(len(batch[0])): + # calculate max size of each dimension + max_shapes = [] + for elem in batch: + if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)): + break + max_shapes.append(elem[key_or_idx].shape[1:]) + # len > 0 if objects were arrays, else skip as no padding to be done + if len(max_shapes) == 0: + continue + max_shape = np.array(max_shapes).max(axis=0) + # If all same size, skip + if np.all(np.array(max_shapes).min(axis=0) == max_shape): + continue + # Do we need to convert output to Tensor? 
+ output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor) + + # Use `SpatialPadd` or `SpatialPad` to match sizes + # Default params are central padding, padding with 0's + # If input is dictionary, use the dictionary version so that the transformation is recorded + + padder = SpatialPad(max_shape, self.method, self.mode) # type: ignore + transform = padder if not output_to_tensor else Compose([padder, ToTensor()]) + + for idx in range(len(batch)): + im = batch[idx][key_or_idx] + orig_size = im.shape[1:] + padded = transform(batch[idx][key_or_idx]) + batch = replace_element(padded, batch, idx, key_or_idx) + + # If we have a dictionary of data, append to list + if is_list_of_dicts: + self.push_transform(batch[idx], key_or_idx, orig_size=orig_size) + + # After padding, use default list collator + return list_data_collate(batch) + + @staticmethod + def inverse(data: dict) -> Dict[Hashable, np.ndarray]: + if not isinstance(data, dict): + raise RuntimeError("Inverse can only currently be applied on dictionaries.") + + d = deepcopy(data) + for key in d.keys(): + transform_key = str(key) + InverseKeys.KEY_SUFFIX.value + if transform_key in d.keys(): + transform = d[transform_key][-1] + if transform[InverseKeys.CLASS_NAME.value] == PadListDataCollate.__name__: + d[key] = CenterSpatialCrop(transform["orig_size"])(d[key]) + # remove transform + d[transform_key].pop() + return d diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 42796e2412..6d28f780d4 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -333,6 +333,7 @@ class Decollated(MapTransform): """ def __init__(self, batch_size: Optional[int] = None) -> None: + super().__init__(None) self.batch_size = batch_size def __call__(self, data: dict) -> List[dict]: diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 4ed8de6bbb..4dc5a217a7 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -12,25 
+12,38 @@ import sys import unittest from enum import Enum +from typing import List, Tuple import numpy as np import torch +from parameterized import parameterized from monai.data import CacheDataset, DataLoader, create_test_image_2d from monai.data.utils import decollate_batch from monai.transforms import AddChanneld, Compose, LoadImaged, RandFlipd, SpatialPadd, ToTensord from monai.transforms.post.dictionary import Decollated +from monai.transforms.spatial.dictionary import RandAffined, RandRotate90d from monai.utils import optional_import, set_determinism from monai.utils.enums import InverseKeys from tests.utils import make_nifti_image _, has_nib = optional_import("nibabel") +KEYS = ["image"] + +TESTS: List[Tuple] = [] +TESTS.append((SpatialPadd(KEYS, 150), RandFlipd(KEYS, prob=1.0, spatial_axis=1))) +TESTS.append((RandRotate90d(KEYS, prob=0.0, max_k=1),)) +TESTS.append((RandAffined(KEYS, prob=0.0, translate_range=10),)) + class TestDeCollate(unittest.TestCase): def setUp(self) -> None: set_determinism(seed=0) + im = create_test_image_2d(100, 101)[0] + self.data = [{"image": make_nifti_image(im) if has_nib else im} for _ in range(6)] + def tearDown(self) -> None: set_determinism(None) @@ -55,24 +68,18 @@ def check_match(self, in1, in2): else: raise RuntimeError(f"Not sure how to compare types. 
type(in1): {type(in1)}, type(in2): {type(in2)}") - def test_decollation(self, batch_size=2, num_workers=2): + @parameterized.expand(TESTS) + def test_decollation(self, *transforms): - im = create_test_image_2d(100, 101)[0] - data = [{"image": make_nifti_image(im) if has_nib else im} for _ in range(6)] - - transforms = Compose( - [ - AddChanneld("image"), - SpatialPadd("image", 150), - RandFlipd("image", prob=1.0, spatial_axis=1), - ToTensord("image"), - ] - ) + batch_size = 2 + num_workers = 2 + + t_compose = Compose([AddChanneld(KEYS), Compose(transforms), ToTensord(KEYS)]) # If nibabel present, read from disk if has_nib: - transforms = Compose([LoadImaged("image"), transforms]) + t_compose = Compose([LoadImaged("image"), t_compose]) - dataset = CacheDataset(data, transforms, progress=False) + dataset = CacheDataset(self.data, t_compose, progress=False) loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) for b, batch_data in enumerate(loader): diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py index 156d2649e0..3835dc8895 100644 --- a/tests/test_pad_collation.py +++ b/tests/test_pad_collation.py @@ -18,8 +18,9 @@ from parameterized import parameterized from monai.data import CacheDataset, DataLoader -from monai.data.utils import pad_list_data_collate +from monai.data.utils import decollate_batch, pad_list_data_collate from monai.transforms import ( + PadListDataCollate, RandRotate, RandRotate90, RandRotate90d, @@ -33,16 +34,16 @@ TESTS: List[Tuple] = [] +for pad_collate in [pad_list_data_collate, PadListDataCollate()]: + TESTS.append((dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True))) + TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False))) + TESTS.append((dict, pad_collate, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) + TESTS.append((dict, pad_collate, RandRotate90d("image", prob=1, max_k=2))) 
-TESTS.append((dict, RandSpatialCropd("image", roi_size=[8, 7], random_size=True))) -TESTS.append((dict, RandRotated("image", prob=1, range_x=np.pi, keep_size=False))) -TESTS.append((dict, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) -TESTS.append((dict, RandRotate90d("image", prob=1, max_k=2))) - -TESTS.append((list, RandSpatialCrop(roi_size=[8, 7], random_size=True))) -TESTS.append((list, RandRotate(prob=1, range_x=np.pi, keep_size=False))) -TESTS.append((list, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) -TESTS.append((list, RandRotate90(prob=1, max_k=2))) + TESTS.append((list, pad_collate, RandSpatialCrop(roi_size=[8, 7], random_size=True))) + TESTS.append((list, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False))) + TESTS.append((list, pad_collate, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) + TESTS.append((list, pad_collate, RandRotate90(prob=1, max_k=2))) class _Dataset(torch.utils.data.Dataset): @@ -72,7 +73,7 @@ def tearDown(self) -> None: set_determinism(None) @parameterized.expand(TESTS) - def test_pad_collation(self, t_type, transform): + def test_pad_collation(self, t_type, collate_method, transform): if t_type == dict: dataset = CacheDataset(self.dict_data, transform, progress=False) @@ -86,9 +87,13 @@ def test_pad_collation(self, t_type, transform): pass # Padded collation shouldn't - loader = DataLoader(dataset, batch_size=2, collate_fn=pad_list_data_collate) - for _ in loader: - pass + loader = DataLoader(dataset, batch_size=10, collate_fn=collate_method) + # check collation in forward direction + for data in loader: + if t_type == dict: + decollated_data = decollate_batch(data) + for d in decollated_data: + PadListDataCollate.inverse(d) if __name__ == "__main__": From 79c70972ddef501adba7ee199f3b1b8ace28e130 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 19 Mar 2021 18:35:34 +0000 Subject: [PATCH 13/16] rst 
Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- docs/source/data.rst | 5 +++++ monai/data/inverse_batch_transform.py | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/source/data.rst b/docs/source/data.rst index c95659bc6e..5191ce4312 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -177,3 +177,8 @@ DataLoader ThreadBuffer ~~~~~~~~~~~~ .. autoclass:: monai.data.ThreadBuffer + + +`BatchInverseTransform` +~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: BatchInverseTransform \ No newline at end of file diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py index 09f3fc1f3e..5a33b2e2c1 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/data/inverse_batch_transform.py @@ -19,6 +19,7 @@ from monai.data.utils import decollate_batch, pad_list_data_collate from monai.transforms.croppad.batch import PadListDataCollate from monai.transforms.inverse import InvertibleTransform +from monai.transforms.transform import Transform from monai.utils import first __all__ = ["BatchInverseTransform"] @@ -44,7 +45,7 @@ def __getitem__(self, index: int) -> Dict[Hashable, np.ndarray]: return self.invertible_transform.inverse(data) -class BatchInverseTransform: +class BatchInverseTransform(Transform): """Perform inverse on a batch of data. 
This is useful if you have inferred a batch of images and want to invert them all."""
 
     def __init__(

From 4aa837fde2e57b49616c772b08292a16e63678af Mon Sep 17 00:00:00 2001
From: Richard Brown <33289025+rijobro@users.noreply.github.com>
Date: Fri, 19 Mar 2021 18:41:12 +0000
Subject: [PATCH 14/16] change default collation of batch inversion

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>
---
 monai/data/inverse_batch_transform.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/monai/data/inverse_batch_transform.py b/monai/data/inverse_batch_transform.py
index 5a33b2e2c1..fbc42c6ce1 100644
--- a/monai/data/inverse_batch_transform.py
+++ b/monai/data/inverse_batch_transform.py
@@ -45,20 +45,22 @@ def __getitem__(self, index: int) -> Dict[Hashable, np.ndarray]:
         return self.invertible_transform.inverse(data)
 
 
+def no_collation(x):
+    return x
+
+
 class BatchInverseTransform(Transform):
     """Perform inverse on a batch of data.
     This is useful if you have inferred a batch of images and want to invert them all."""
 
     def __init__(
-        self, transform: InvertibleTransform, loader: TorchDataLoader, collate_fn: Optional[Callable] = None
+        self, transform: InvertibleTransform, loader: TorchDataLoader, collate_fn: Optional[Callable] = no_collation
     ) -> None:
         """
         Args:
             transform: a callable data transform on input data.
             loader: data loader used to generate the batch of data.
-            collate_fn: how to collate data after inverse transformations. Default will use the DataLoader's default
-                collation method. If returning images of different sizes, this will likely create an error (since the
-                collation will concatenate arrays, requiring them to be the same size). In this case, using
-                `collate_fn=lambda x: x` might solve the problem.
+            collate_fn: how to collate data after inverse transformations. By default no collation is performed, so
+                the output will be a list whose length equals the batch size.
""" self.transform = transform self.batch_size = loader.batch_size From 924901b2ba75a11d35a6fcd471ca48bf2affca97 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Sat, 20 Mar 2021 09:00:56 +0000 Subject: [PATCH 15/16] update docs Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- docs/source/data.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/data.rst b/docs/source/data.rst index 63b3a8c595..8071bb1585 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -185,6 +185,6 @@ ThreadBuffer .. autoclass:: monai.data.ThreadBuffer -`BatchInverseTransform` -~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: BatchInverseTransform \ No newline at end of file +BatchInverseTransform +~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: monai.data.BatchInverseTransform From 6e3ca066f80df3ea725c1a433b1b1cbb7d93a7e7 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Sat, 20 Mar 2021 09:00:56 +0000 Subject: [PATCH 16/16] update docs Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>