diff --git a/monai/data/__init__.py b/monai/data/__init__.py
index e0db1e17ae..99990d7f53 100644
--- a/monai/data/__init__.py
+++ b/monai/data/__init__.py
@@ -46,6 +46,7 @@
     iter_patch_slices,
     json_hashing,
     list_data_collate,
+    pad_list_data_collate,
     partition_dataset,
     partition_dataset_classes,
     pickle_hashing,
diff --git a/monai/data/utils.py b/monai/data/utils.py
index acc6d2e97a..c42e1abefa 100644
--- a/monai/data/utils.py
+++ b/monai/data/utils.py
@@ -36,6 +36,7 @@
     first,
     optional_import,
 )
+from monai.utils.enums import Method
 
 nib, _ = optional_import("nibabel")
 
@@ -63,6 +64,7 @@
     "json_hashing",
     "pickle_hashing",
     "sorted_dict",
+    "pad_list_data_collate",
 ]
 
 
@@ -240,7 +242,87 @@ def list_data_collate(batch: Sequence):
     """
     elem = batch[0]
     data = [i for k in batch for i in k] if isinstance(elem, list) else batch
-    return default_collate(data)
+    try:
+        return default_collate(data)
+    except RuntimeError as re:
+        re_str = str(re)
+        if "stack expects each tensor to be equal size" in re_str:
+            re_str += (
+                "\nMONAI hint: if your transforms intentionally create images of different shapes, creating your "
+                + "`DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its "
+                + "documentation)."
+            )
+        # chain the original error so its traceback is not lost
+        raise RuntimeError(re_str) from re
+
+
+def pad_list_data_collate(
+    batch: Sequence,
+    method: Union[Method, str] = Method.SYMMETRIC,
+    mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT,
+):
+    """
+    Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest
+    tensor in each dimension.
+
+    Note:
+        Use this collate function when the transforms produce arrays of different shapes, which
+        ``list_data_collate`` cannot stack. Shapes are compared from the second dimension onwards, and the
+        elements of ``batch`` are modified in place.
+
+    Args:
+        batch: batch of data to pad-collate
+        method: padding method (see :py:class:`monai.transforms.SpatialPad`)
+        mode: padding mode (see :py:class:`monai.transforms.SpatialPad`)
+    """
+    list_of_dicts = isinstance(batch[0], dict)
+    for key_or_idx in batch[0].keys() if list_of_dicts else range(len(batch[0])):
+        max_shapes = []
+        for elem in batch:
+            if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)):
+                break
+            max_shapes.append(elem[key_or_idx].shape[1:])
+        # len > 0 if objects were arrays
+        if not max_shapes:
+            continue
+        max_shape = np.array(max_shapes).max(axis=0)
+        # If all same size, skip
+        if np.all(np.array(max_shapes).min(axis=0) == max_shape):
+            continue
+        # Do we need to convert output to Tensor?
+        output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor)
+
+        # Use `SpatialPadd` or `SpatialPad` to match sizes
+        # Default params are central padding, padding with 0's
+        # If input is dictionary, use the dictionary version so that the transformation is recorded
+        padder: Union[SpatialPadd, SpatialPad]
+        if list_of_dicts:
+            from monai.transforms.croppad.dictionary import SpatialPadd  # needs to be here to avoid circular import
+
+            padder = SpatialPadd(key_or_idx, max_shape, method, mode)  # type: ignore
+
+        else:
+            from monai.transforms.croppad.array import SpatialPad  # needs to be here to avoid circular import
+
+            padder = SpatialPad(max_shape, method, mode)  # type: ignore
+
+        for idx, item in enumerate(batch):
+            padded = padder(item)[key_or_idx] if list_of_dicts else padder(item[key_or_idx])
+            # convert before storing so that immutable (tuple) items are handled too;
+            # `as_tensor` keeps the dtype of the padded array
+            if output_to_tensor:
+                padded = torch.as_tensor(padded)
+            # since tuple is immutable we'll have to recreate
+            if isinstance(item, tuple):
+                item = list(item)  # type: ignore
+                item[key_or_idx] = padded
+                batch[idx] = tuple(item)  # type: ignore
+            # else, replace in place with the value padded above
+            else:
+                item[key_or_idx] = padded
+
+    # After padding, use default list collator
+    return list_data_collate(batch)
 
 
 def worker_init_fn(worker_id: int) -> None:
diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py
new file mode 100644
index 0000000000..156d2649e0
--- /dev/null
+++ b/tests/test_pad_collation.py
@@ -0,0 +1,95 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import unittest
+from typing import List, Tuple
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.data import CacheDataset, DataLoader
+from monai.data.utils import pad_list_data_collate
+from monai.transforms import (
+    RandRotate,
+    RandRotate90,
+    RandRotate90d,
+    RandRotated,
+    RandSpatialCrop,
+    RandSpatialCropd,
+    RandZoom,
+    RandZoomd,
+)
+from monai.utils import set_determinism
+
+TESTS: List[Tuple] = []
+
+
+TESTS.append((dict, RandSpatialCropd("image", roi_size=[8, 7], random_size=True)))
+TESTS.append((dict, RandRotated("image", prob=1, range_x=np.pi, keep_size=False)))
+TESTS.append((dict, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
+TESTS.append((dict, RandRotate90d("image", prob=1, max_k=2)))
+
+TESTS.append((list, RandSpatialCrop(roi_size=[8, 7], random_size=True)))
+TESTS.append((list, RandRotate(prob=1, range_x=np.pi, keep_size=False)))
+TESTS.append((list, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
+TESTS.append((list, RandRotate90(prob=1, max_k=2)))
+
+
+class _Dataset(torch.utils.data.Dataset):
+    def __init__(self, images, labels, transforms):
+        self.images = images
+        self.labels = labels
+        self.transforms = transforms
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, index):
+        return self.transforms(self.images[index]), self.labels[index]
+
+
+class TestPadCollation(unittest.TestCase):
+    def setUp(self) -> None:
+        set_determinism(seed=0)
+        # image is non square to throw rotation errors
+        im = np.arange(0, 10 * 9).reshape(1, 10, 9)
+        num_elements = 20
+        self.dict_data = [{"image": im} for _ in range(num_elements)]
+        self.list_data = [im for _ in range(num_elements)]
+        self.list_labels = [random.randint(0, 1) for _ in range(num_elements)]
+
+    def tearDown(self) -> None:
+        set_determinism(None)
+
+    @parameterized.expand(TESTS)
+    def test_pad_collation(self, t_type, transform):
+
+        if t_type == dict:
+            dataset = CacheDataset(self.dict_data, transform, progress=False)
+        else:
+            dataset = _Dataset(self.list_data, self.list_labels, transform)
+
+        # Default collation should raise an error
+        loader_fail = DataLoader(dataset, batch_size=10)
+        with self.assertRaises(RuntimeError):
+            for _ in loader_fail:
+                pass
+
+        # Padded collation shouldn't
+        loader = DataLoader(dataset, batch_size=2, collate_fn=pad_list_data_collate)
+        for _ in loader:
+            pass
+
+
+if __name__ == "__main__":
+    unittest.main()