From 97152878bd39a8d14929a1843ec3064f3617720d Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Wed, 24 Feb 2021 13:42:02 +0000 Subject: [PATCH 1/6] pad_collation Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/__init__.py | 1 + monai/data/utils.py | 53 ++++++++++++++++++++++++++++++++- tests/test_pad_collation.py | 58 +++++++++++++++++++++++++++++++++++++ 3 files changed, 111 insertions(+), 1 deletion(-) create mode 100644 tests/test_pad_collation.py diff --git a/monai/data/__init__.py b/monai/data/__init__.py index e0db1e17ae..99990d7f53 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -46,6 +46,7 @@ iter_patch_slices, json_hashing, list_data_collate, + pad_list_data_collate, partition_dataset, partition_dataset_classes, pickle_hashing, diff --git a/monai/data/utils.py b/monai/data/utils.py index acc6d2e97a..3577daadff 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -63,6 +63,7 @@ "json_hashing", "pickle_hashing", "sorted_dict", + "pad_list_data_collate", ] @@ -240,7 +241,57 @@ def list_data_collate(batch: Sequence): """ elem = batch[0] data = [i for k in batch for i in k] if isinstance(elem, list) else batch - return default_collate(data) + try: + return default_collate(data) + except RuntimeError as re: + re_str = str(re) + if "stack expects each tensor to be equal size" in re_str: + re_str += ( + "\nMONAI hint: if your transforms intentionally create images of different shapes, creating your " + + "`DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its " + + "documentation)." + ) + raise RuntimeError(re_str) + + +def pad_list_data_collate(batch: Sequence): + """ + Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest + tensor in each dimension. + + Note: + Need to use this collate if apply some transforms that can generate batch data. + + """ + for key in batch[0].keys(): + max_shapes = [] + for elem in batch: + if not isinstance(elem[key], (torch.Tensor, np.ndarray)): + break + max_shapes.append(elem[key].shape[1:]) + # len > 0 if objects were arrays + if len(max_shapes) == 0: + continue + max_shape = np.array(max_shapes).max(axis=0) + # If all same size, skip + if np.all(np.array(max_shapes).min(axis=0) == max_shape): + continue + # Do we need to convert output to Tensor? + output_to_tensor = isinstance(batch[0][key], torch.Tensor) + + # Use `SpatialPadd` to match sizes + # Default params are central padding, padding with 0's + # Use the dictionary version so that the transformation is recorded + from monai.transforms.croppad.dictionary import SpatialPadd # needs to be here to avoid circular import + + padder = SpatialPadd(key, max_shape) # type: ignore + for idx in range(len(batch)): + batch[idx][key] = padder(batch[idx])[key] + if output_to_tensor: + batch[idx][key] = torch.Tensor(batch[idx][key]) + + # After padding, use default list collator + return list_data_collate(batch) def worker_init_fn(worker_id: int) -> None: diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py new file mode 100644 index 0000000000..10ca1dec01 --- /dev/null +++ b/tests/test_pad_collation.py @@ -0,0 +1,58 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import List, Tuple + +import numpy as np +from parameterized import parameterized + +from monai.data.utils import pad_list_data_collate +from monai.transforms import RandRotate90d, RandRotated, RandSpatialCropd, RandZoomd +from monai.utils import set_determinism + +set_determinism(seed=0) + +from monai.data import CacheDataset, DataLoader + +TESTS: List[Tuple] = [] + +TESTS.append((RandSpatialCropd("image", roi_size=[8, 7], random_size=True),)) +TESTS.append((RandRotated("image", prob=1, range_x=np.pi, keep_size=False),)) +TESTS.append((RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False),)) +TESTS.append((RandRotate90d("image", prob=1, max_k=2),)) + + +class TestPadCollation(unittest.TestCase): + def setUp(self) -> None: + # image is non square to throw rotation errors + im = np.arange(0, 10 * 9).reshape(1, 10, 9) + self.data = [{"image": im} for _ in range(2)] + + @parameterized.expand(TESTS) + def test_pad_collation(self, transform): + + dataset = CacheDataset(self.data, transform, progress=False) + + # Default collation should raise an error + loader_fail = DataLoader(dataset, batch_size=2) + with self.assertRaises(RuntimeError): + for _ in loader_fail: + pass + + # Padded collation shouldn't + loader = DataLoader(dataset, batch_size=2, collate_fn=pad_list_data_collate) + for _ in loader: + pass + + +if __name__ == "__main__": + unittest.main() From d224ea5fed122622f7d510e6686da237890d8ac8 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 24 Feb 2021 14:06:41 +0000 Subject: [PATCH 2/6] increase number of test cases to ensure required testing errors Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_pad_collation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py index 10ca1dec01..0068ae13c8 100644 --- a/tests/test_pad_collation.py +++ b/tests/test_pad_collation.py @@ -35,7 +35,7 @@ class TestPadCollation(unittest.TestCase): def setUp(self) -> None: # image is non square to throw rotation errors im = np.arange(0, 10 * 9).reshape(1, 10, 9) - self.data = [{"image": im} for _ in range(2)] + self.data = [{"image": im} for _ in range(20)] @parameterized.expand(TESTS) def test_pad_collation(self, transform): @@ -43,7 +43,7 @@ def test_pad_collation(self, transform): dataset = CacheDataset(self.data, transform, progress=False) # Default collation should raise an error - loader_fail = DataLoader(dataset, batch_size=2) + loader_fail = DataLoader(dataset, batch_size=10) with self.assertRaises(RuntimeError): for _ in loader_fail: pass From 735f9e0ba68ad9168fa64ab0a265a18c949582df Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Wed, 24 Feb 2021 14:43:57 +0000 Subject: [PATCH 3/6] determinism in setUp Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_pad_collation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py index 0068ae13c8..b8a56abd95 100644 --- a/tests/test_pad_collation.py +++ b/tests/test_pad_collation.py @@ -19,7 +19,6 @@ from monai.transforms import RandRotate90d, RandRotated, RandSpatialCropd, RandZoomd from monai.utils import set_determinism -set_determinism(seed=0) from monai.data import CacheDataset, DataLoader @@ -33,10 +32,14 @@ class TestPadCollation(unittest.TestCase): def setUp(self) -> None: + set_determinism(seed=0) # image is non square to throw rotation errors im = np.arange(0, 10 * 9).reshape(1, 10, 9) self.data = [{"image": im} for _ in range(20)] + def tearDown(self) -> None: + set_determinism(None) + @parameterized.expand(TESTS) def test_pad_collation(self, transform): From e6910c2c71364a94d9f00926b8d7776080709be5 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Wed, 24 Feb 2021 16:00:42 +0000 Subject: [PATCH 4/6] pad collate for list of lists Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/utils.py | 37 ++++++++++++++++++------- tests/test_pad_collation.py | 54 ++++++++++++++++++++++++++++++------- 2 files changed, 71 insertions(+), 20 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 3577daadff..9483d04bca 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -263,12 +263,13 @@ def pad_list_data_collate(batch: Sequence): Need to use this collate if apply some transforms that can generate batch data. """ - for key in batch[0].keys(): + list_of_dicts = isinstance(batch[0], dict) + for key_or_idx in batch[0].keys() if list_of_dicts else range(len(batch[0])): max_shapes = [] for elem in batch: - if not isinstance(elem[key], (torch.Tensor, np.ndarray)): + if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)): break - max_shapes.append(elem[key].shape[1:]) + max_shapes.append(elem[key_or_idx].shape[1:]) # len > 0 if objects were arrays if len(max_shapes) == 0: continue @@ -277,18 +278,34 @@ def pad_list_data_collate(batch: Sequence): if np.all(np.array(max_shapes).min(axis=0) == max_shape): continue # Do we need to convert output to Tensor? - output_to_tensor = isinstance(batch[0][key], torch.Tensor) + output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor) - # Use `SpatialPadd` to match sizes + # Use `SpatialPadd` or `SpatialPad` to match sizes # Default params are central padding, padding with 0's - # Use the dictionary version so that the transformation is recorded - from monai.transforms.croppad.dictionary import SpatialPadd # needs to be here to avoid circular import + # If input is dictionary, use the dictionary version so that the transformation is recorded + if list_of_dicts: + from monai.transforms.croppad.dictionary import SpatialPadd # needs to be here to avoid circular import + + padder = SpatialPadd(key_or_idx, max_shape) # type: ignore + + else: + from monai.transforms.croppad.array import SpatialPad # needs to be here to avoid circular import + + padder = SpatialPad(max_shape) - padder = SpatialPadd(key, max_shape) # type: ignore for idx in range(len(batch)): - batch[idx][key] = padder(batch[idx])[key] + padded = padder(batch[idx])[key_or_idx] if list_of_dicts else padder(batch[idx][key_or_idx]) + # since tuple is immutable we'll have to recreate + if isinstance(batch[idx], tuple): + batch[idx] = list(batch[idx]) + batch[idx][key_or_idx] = padded + batch[idx] = tuple(batch[idx]) + # else, replace + else: + batch[idx][key_or_idx] = padder(batch[idx])[key_or_idx] + if output_to_tensor: - batch[idx][key] = torch.Tensor(batch[idx][key]) + batch[idx][key_or_idx] = torch.Tensor(batch[idx][key_or_idx]) # After padding, use default list collator return list_data_collate(batch) diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py index b8a56abd95..156d2649e0 100644 --- a/tests/test_pad_collation.py +++ b/tests/test_pad_collation.py @@ -9,25 +9,53 @@ # See the License for the specific language governing permissions and # limitations under the License. +import random import unittest from typing import List, Tuple import numpy as np +import torch from parameterized import parameterized +from monai.data import CacheDataset, DataLoader from monai.data.utils import pad_list_data_collate -from monai.transforms import RandRotate90d, RandRotated, RandSpatialCropd, RandZoomd +from monai.transforms import ( + RandRotate, + RandRotate90, + RandRotate90d, + RandRotated, + RandSpatialCrop, + RandSpatialCropd, + RandZoom, + RandZoomd, +) from monai.utils import set_determinism +TESTS: List[Tuple] = [] -from monai.data import CacheDataset, DataLoader -TESTS: List[Tuple] = [] +TESTS.append((dict, RandSpatialCropd("image", roi_size=[8, 7], random_size=True))) +TESTS.append((dict, RandRotated("image", prob=1, range_x=np.pi, keep_size=False))) +TESTS.append((dict, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) +TESTS.append((dict, RandRotate90d("image", prob=1, max_k=2))) + +TESTS.append((list, RandSpatialCrop(roi_size=[8, 7], random_size=True))) +TESTS.append((list, RandRotate(prob=1, range_x=np.pi, keep_size=False))) +TESTS.append((list, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False))) +TESTS.append((list, RandRotate90(prob=1, max_k=2))) + + +class _Dataset(torch.utils.data.Dataset): + def __init__(self, images, labels, transforms): + self.images = images + self.labels = labels + self.transforms = transforms + + def __len__(self): + return len(self.images) -TESTS.append((RandSpatialCropd("image", roi_size=[8, 7], random_size=True),)) -TESTS.append((RandRotated("image", prob=1, range_x=np.pi, keep_size=False),)) -TESTS.append((RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False),)) -TESTS.append((RandRotate90d("image", prob=1, max_k=2),)) + def __getitem__(self, index): + return self.transforms(self.images[index]), self.labels[index] class TestPadCollation(unittest.TestCase): @@ -35,15 +63,21 @@ def setUp(self) -> None: set_determinism(seed=0) # image is non square to throw rotation errors im = np.arange(0, 10 * 9).reshape(1, 10, 9) - self.data = [{"image": im} for _ in range(20)] + num_elements = 20 + self.dict_data = [{"image": im} for _ in range(num_elements)] + self.list_data = [im for _ in range(num_elements)] + self.list_labels = [random.randint(0, 1) for _ in range(num_elements)] def tearDown(self) -> None: set_determinism(None) @parameterized.expand(TESTS) - def test_pad_collation(self, transform): + def test_pad_collation(self, t_type, transform): - dataset = CacheDataset(self.data, transform, progress=False) + if t_type == dict: + dataset = CacheDataset(self.dict_data, transform, progress=False) + else: + dataset = _Dataset(self.list_data, self.list_labels, transform) # Default collation should raise an error loader_fail = DataLoader(dataset, batch_size=10) From 412b1f7549f77bad40fe529a488c724aacf41e64 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Wed, 24 Feb 2021 16:13:30 +0000 Subject: [PATCH 5/6] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 9483d04bca..df74744ac7 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -283,6 +283,7 @@ def pad_list_data_collate(batch: Sequence): # Use `SpatialPadd` or `SpatialPad` to match sizes # Default params are central padding, padding with 0's # If input is dictionary, use the dictionary version so that the transformation is recorded + padder: Union[SpatialPadd, SpatialPad] if list_of_dicts: from monai.transforms.croppad.dictionary import SpatialPadd # needs to be here to avoid circular import @@ -291,15 +292,15 @@ def pad_list_data_collate(batch: Sequence): else: from monai.transforms.croppad.array import SpatialPad # needs to be here to avoid circular import - padder = SpatialPad(max_shape) + padder = SpatialPad(max_shape) # type: ignore for idx in range(len(batch)): padded = padder(batch[idx])[key_or_idx] if list_of_dicts else padder(batch[idx][key_or_idx]) # since tuple is immutable we'll have to recreate if isinstance(batch[idx], tuple): - batch[idx] = list(batch[idx]) + batch[idx] = list(batch[idx]) # type: ignore batch[idx][key_or_idx] = padded - batch[idx] = tuple(batch[idx]) + batch[idx] = tuple(batch[idx]) # type: ignore # else, replace else: batch[idx][key_or_idx] = padder(batch[idx])[key_or_idx] From b66f094106b4cca858b3e4cba0114976f966510e Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Wed, 24 Feb 2021 17:08:44 +0000 Subject: [PATCH 6/6] allow padding options Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/utils.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index df74744ac7..c42e1abefa 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -36,6 +36,7 @@ first, optional_import, ) +from monai.utils.enums import Method nib, _ = optional_import("nibabel") @@ -254,7 +255,11 @@ def list_data_collate(batch: Sequence): raise RuntimeError(re_str) -def pad_list_data_collate(batch: Sequence): +def pad_list_data_collate( + batch: Sequence, + method: Union[Method, str] = Method.SYMMETRIC, + mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, +): """ Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest tensor in each dimension. @@ -262,6 +267,10 @@ def pad_list_data_collate(batch: Sequence): Note: Need to use this collate if apply some transforms that can generate batch data. + Args: + batch: batch of data to pad-collate + method: padding method (see :py:class:`monai.transforms.SpatialPad`) + mode: padding mode (see :py:class:`monai.transforms.SpatialPad`) """ list_of_dicts = isinstance(batch[0], dict) for key_or_idx in batch[0].keys() if list_of_dicts else range(len(batch[0])): @@ -287,12 +296,12 @@ def pad_list_data_collate(batch: Sequence): if list_of_dicts: from monai.transforms.croppad.dictionary import SpatialPadd # needs to be here to avoid circular import - padder = SpatialPadd(key_or_idx, max_shape) # type: ignore + padder = SpatialPadd(key_or_idx, max_shape, method, mode) # type: ignore else: from monai.transforms.croppad.array import SpatialPad # needs to be here to avoid circular import - padder = SpatialPad(max_shape) # type: ignore + padder = SpatialPad(max_shape, method, mode) # type: ignore for idx in range(len(batch)): padded = padder(batch[idx])[key_or_idx] if list_of_dicts else padder(batch[idx][key_or_idx])