From 1a700c4d71074731658fe941741dd0101558a8ba Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 19 Mar 2021 18:42:53 +0000 Subject: [PATCH] PadListDataCollate transform Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/utils.py | 59 +++---------- monai/transforms/__init__.py | 1 + monai/transforms/croppad/batch.py | 129 ++++++++++++++++++++++++++++ monai/transforms/post/dictionary.py | 1 + tests/test_decollate.py | 35 +++++--- tests/test_pad_collation.py | 33 ++++--- 6 files changed, 181 insertions(+), 77 deletions(-) create mode 100644 monai/transforms/croppad/batch.py diff --git a/monai/data/utils.py b/monai/data/utils.py index ae0180f4b5..bdbfa5c636 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -337,64 +337,25 @@ def pad_list_data_collate( mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, ): """ + Function version of :py:class:`monai.transforms.croppad.batch.PadListDataCollate`. + Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest - tensor in each dimension. + tensor in each dimension. This transform is useful if some of the applied transforms generate batch data of + different sizes. - Note: - Need to use this collate if apply some transforms that can generate batch data. + This can be used on both list and dictionary data. In the case of the dictionary data, this transform will be added + to the list of invertible transforms. + + The inverse can be called using the static method: `monai.transforms.croppad.batch.PadListDataCollate.inverse`. Args: batch: batch of data to pad-collate method: padding method (see :py:class:`monai.transforms.SpatialPad`) mode: padding mode (see :py:class:`monai.transforms.SpatialPad`) """ - list_of_dicts = isinstance(batch[0], dict) - for key_or_idx in batch[0].keys() if list_of_dicts else range(len(batch[0])): - max_shapes = [] - for elem in batch: - if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)): - break - max_shapes.append(elem[key_or_idx].shape[1:]) - # len > 0 if objects were arrays - if len(max_shapes) == 0: - continue - max_shape = np.array(max_shapes).max(axis=0) - # If all same size, skip - if np.all(np.array(max_shapes).min(axis=0) == max_shape): - continue - # Do we need to convert output to Tensor? - output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor) - - # Use `SpatialPadd` or `SpatialPad` to match sizes - # Default params are central padding, padding with 0's - # If input is dictionary, use the dictionary version so that the transformation is recorded - padder: Union[SpatialPadd, SpatialPad] - if list_of_dicts: - from monai.transforms.croppad.dictionary import SpatialPadd # needs to be here to avoid circular import + from monai.transforms.croppad.batch import PadListDataCollate # needs to be here to avoid circular import - padder = SpatialPadd(key_or_idx, max_shape, method, mode) # type: ignore - - else: - from monai.transforms.croppad.array import SpatialPad # needs to be here to avoid circular import - - padder = SpatialPad(max_shape, method, mode) # type: ignore - - for idx in range(len(batch)): - padded = padder(batch[idx])[key_or_idx] if list_of_dicts else padder(batch[idx][key_or_idx]) - # since tuple is immutable we'll have to recreate - if isinstance(batch[idx], tuple): - batch[idx] = list(batch[idx]) # type: ignore - batch[idx][key_or_idx] = padded - batch[idx] = tuple(batch[idx]) # type: ignore - # else, replace - else: - batch[idx][key_or_idx] = padder(batch[idx])[key_or_idx] - - if output_to_tensor: - batch[idx][key_or_idx] = torch.Tensor(batch[idx][key_or_idx]) - - # After padding, use default list collator - return list_data_collate(batch) + return PadListDataCollate(method, mode)(batch) def worker_init_fn(worker_id: int) -> None: diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 0ce09e69d2..22311cdca6 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -25,6 +25,7 @@ SpatialCrop, SpatialPad, ) +from .croppad.batch import PadListDataCollate from .croppad.dictionary import ( BorderPadd, BorderPadD, diff --git a/monai/transforms/croppad/batch.py b/monai/transforms/croppad/batch.py new file mode 100644 index 0000000000..7cbf39597c --- /dev/null +++ b/monai/transforms/croppad/batch.py @@ -0,0 +1,129 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A collection of "vanilla" transforms for crop and pad operations acting on batches of data +https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design +""" + +from copy import deepcopy +from typing import Any, Dict, Hashable, Union + +import numpy as np +import torch + +from monai.data.utils import list_data_collate +from monai.transforms.compose import Compose +from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad +from monai.transforms.inverse import InvertibleTransform +from monai.transforms.utility.array import ToTensor +from monai.utils.enums import InverseKeys, Method, NumpyPadMode + +__all__ = [ + "PadListDataCollate", +] + + +def replace_element(to_replace, batch, idx, key_or_idx): + # since tuple is immutable we'll have to recreate + if isinstance(batch[idx], tuple): + batch_idx_list = list(batch[idx]) + batch_idx_list[key_or_idx] = to_replace + batch[idx] = tuple(batch_idx_list) + # else, replace + else: + batch[idx][key_or_idx] = to_replace + return batch + + +class PadListDataCollate(InvertibleTransform): + """ + Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest + tensor in each dimension. This transform is useful if some of the applied transforms generate batch data of + different sizes. + + This can be used on both list and dictionary data. In the case of the dictionary data, this transform will be added + to the list of invertible transforms. + + Note that normally, a user won't explicitly use the `__call__` method. Rather this would be passed to the `DataLoader`. + This means that `__call__` handles data as it comes out of a `DataLoader`, containing batch dimension. However, the + `inverse` operates on dictionaries containing images of shape `C,H,W,[D]`. This asymmetry is necessary so that we can + pass the inverse through multiprocessing. + + Args: + batch: batch of data to pad-collate + method: padding method (see :py:class:`monai.transforms.SpatialPad`) + mode: padding mode (see :py:class:`monai.transforms.SpatialPad`) + """ + + def __init__( + self, + method: Union[Method, str] = Method.SYMMETRIC, + mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + ) -> None: + self.method = method + self.mode = mode + + def __call__(self, batch: Any): + # data is either list of dicts or list of lists + is_list_of_dicts = isinstance(batch[0], dict) + # loop over items inside of each element in a batch + for key_or_idx in batch[0].keys() if is_list_of_dicts else range(len(batch[0])): + # calculate max size of each dimension + max_shapes = [] + for elem in batch: + if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)): + break + max_shapes.append(elem[key_or_idx].shape[1:]) + # len > 0 if objects were arrays, else skip as no padding to be done + if len(max_shapes) == 0: + continue + max_shape = np.array(max_shapes).max(axis=0) + # If all same size, skip + if np.all(np.array(max_shapes).min(axis=0) == max_shape): + continue + # Do we need to convert output to Tensor? + output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor) + + # Use `SpatialPadd` or `SpatialPad` to match sizes + # Default params are central padding, padding with 0's + # If input is dictionary, use the dictionary version so that the transformation is recorded + + padder = SpatialPad(max_shape, self.method, self.mode) # type: ignore + transform = padder if not output_to_tensor else Compose([padder, ToTensor()]) + + for idx in range(len(batch)): + im = batch[idx][key_or_idx] + orig_size = im.shape[1:] + padded = transform(batch[idx][key_or_idx]) + batch = replace_element(padded, batch, idx, key_or_idx) + + # If we have a dictionary of data, append to list + if is_list_of_dicts: + self.push_transform(batch[idx], key_or_idx, orig_size=orig_size) + + # After padding, use default list collator + return list_data_collate(batch) + + @staticmethod + def inverse(data: dict) -> Dict[Hashable, np.ndarray]: + if not isinstance(data, dict): + raise RuntimeError("Inverse can only currently be applied on dictionaries.") + + d = deepcopy(data) + for key in d.keys(): + transform_key = str(key) + InverseKeys.KEY_SUFFIX.value + if transform_key in d.keys(): + transform = d[transform_key][-1] + if transform[InverseKeys.CLASS_NAME.value] == PadListDataCollate.__name__: + d[key] = CenterSpatialCrop(transform["orig_size"])(d[key]) + # remove transform + d[transform_key].pop() + return d diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 42796e2412..6d28f780d4 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -333,6 +333,7 @@ class Decollated(MapTransform): """ def __init__(self, batch_size: Optional[int] = None) -> None: + super().__init__(None) self.batch_size = batch_size def __call__(self, data: dict) -> List[dict]: diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 4ed8de6bbb..4dc5a217a7 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -12,25 +12,38 @@ import sys import unittest from enum import Enum +from typing import List, Tuple import numpy as np import torch +from parameterized import parameterized from monai.data import CacheDataset, DataLoader, create_test_image_2d from monai.data.utils import decollate_batch from monai.transforms import AddChanneld, Compose, LoadImaged, RandFlipd, SpatialPadd, ToTensord from monai.transforms.post.dictionary import Decollated +from monai.transforms.spatial.dictionary import RandAffined, RandRotate90d from monai.utils import optional_import, set_determinism from monai.utils.enums import InverseKeys from tests.utils import make_nifti_image _, has_nib = optional_import("nibabel") +KEYS = ["image"] + +TESTS: List[Tuple] = [] +TESTS.append((SpatialPadd(KEYS, 150), RandFlipd(KEYS, prob=1.0, spatial_axis=1))) +TESTS.append((RandRotate90d(KEYS, prob=0.0, max_k=1),)) +TESTS.append((RandAffined(KEYS, prob=0.0, translate_range=10),)) + class TestDeCollate(unittest.TestCase): def setUp(self) -> None: set_determinism(seed=0) + im = create_test_image_2d(100, 101)[0] + self.data = [{"image": make_nifti_image(im) if has_nib else im} for _ in range(6)] + def tearDown(self) -> None: set_determinism(None) @@ -55,24 +68,18 @@ def check_match(self, in1, in2): else: raise RuntimeError(f"Not sure how to compare types. type(in1): {type(in1)}, type(in2): {type(in2)}") - def test_decollation(self, batch_size=2, num_workers=2): + @parameterized.expand(TESTS) + def test_decollation(self, *transforms): - im = create_test_image_2d(100, 101)[0] - data = [{"image": make_nifti_image(im) if has_nib else im} for _ in range(6)] - - transforms = Compose( - [ - AddChanneld("image"), - SpatialPadd("image", 150), - RandFlipd("image", prob=1.0, spatial_axis=1), - ToTensord("image"), - ] - ) + batch_size = 2 + num_workers = 2 + + t_compose = Compose([AddChanneld(KEYS), Compose(transforms), ToTensord(KEYS)]) # If nibabel present, read from disk if has_nib: - transforms = Compose([LoadImaged("image"), transforms]) + t_compose = Compose([LoadImaged("image"), t_compose]) - dataset = CacheDataset(data, transforms, progress=False) + dataset = CacheDataset(self.data, t_compose, progress=False) loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) for b, batch_data in enumerate(loader): diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py index 156d2649e0..3835dc8895 100644 --- a/tests/test_pad_collation.py +++ b/tests/test_pad_collation.py @@ -18,8 +18,9 @@ from parameterized import parameterized from monai.data import CacheDataset, DataLoader -from monai.data.utils import pad_list_data_collate +from monai.data.utils import decollate_batch, pad_list_data_collate from monai.transforms import ( + PadListDataCollate, RandRotate, RandRotate90, RandRotate90d, @@ -33,16 +34,16 @@ TESTS: List[Tuple] = [] +for pad_collate in [pad_list_data_collate, PadListDataCollate()]: + TESTS.append((dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True))) + TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False))) + TESTS.append((dict, pad_collate, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) + TESTS.append((dict, pad_collate, RandRotate90d("image", prob=1, max_k=2))) -TESTS.append((dict, RandSpatialCropd("image", roi_size=[8, 7], random_size=True))) -TESTS.append((dict, RandRotated("image", prob=1, range_x=np.pi, keep_size=False))) -TESTS.append((dict, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) -TESTS.append((dict, RandRotate90d("image", prob=1, max_k=2))) - -TESTS.append((list, RandSpatialCrop(roi_size=[8, 7], random_size=True))) -TESTS.append((list, RandRotate(prob=1, range_x=np.pi, keep_size=False))) -TESTS.append((list, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) -TESTS.append((list, RandRotate90(prob=1, max_k=2))) + TESTS.append((list, pad_collate, RandSpatialCrop(roi_size=[8, 7], random_size=True))) + TESTS.append((list, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False))) + TESTS.append((list, pad_collate, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) + TESTS.append((list, pad_collate, RandRotate90(prob=1, max_k=2))) class _Dataset(torch.utils.data.Dataset): @@ -72,7 +73,7 @@ def tearDown(self) -> None: set_determinism(None) @parameterized.expand(TESTS) - def test_pad_collation(self, t_type, transform): + def test_pad_collation(self, t_type, collate_method, transform): if t_type == dict: dataset = CacheDataset(self.dict_data, transform, progress=False) @@ -86,9 +87,13 @@ def test_pad_collation(self, t_type, transform): pass # Padded collation shouldn't - loader = DataLoader(dataset, batch_size=2, collate_fn=pad_list_data_collate) - for _ in loader: - pass + loader = DataLoader(dataset, batch_size=10, collate_fn=collate_method) + # check collation in forward direction + for data in loader: + if t_type == dict: + decollated_data = decollate_batch(data) + for d in decollated_data: + PadListDataCollate.inverse(d) if __name__ == "__main__":