Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
iter_patch_slices,
json_hashing,
list_data_collate,
pad_list_data_collate,
partition_dataset,
partition_dataset_classes,
pickle_hashing,
Expand Down
80 changes: 79 additions & 1 deletion monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
first,
optional_import,
)
from monai.utils.enums import Method

nib, _ = optional_import("nibabel")

Expand Down Expand Up @@ -63,6 +64,7 @@
"json_hashing",
"pickle_hashing",
"sorted_dict",
"pad_list_data_collate",
]


Expand Down Expand Up @@ -240,7 +242,83 @@ def list_data_collate(batch: Sequence):
"""
elem = batch[0]
data = [i for k in batch for i in k] if isinstance(elem, list) else batch
return default_collate(data)
try:
return default_collate(data)
except RuntimeError as re:
re_str = str(re)
if "stack expects each tensor to be equal size" in re_str:
re_str += (
"\nMONAI hint: if your transforms intentionally create images of different shapes, creating your "
+ "`DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its "
+ "documentation)."
)
raise RuntimeError(re_str)


def pad_list_data_collate(
batch: Sequence,
method: Union[Method, str] = Method.SYMMETRIC,
mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT,
):
"""
Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest
tensor in each dimension.

Note:
Need to use this collate if apply some transforms that can generate batch data.

Args:
batch: batch of data to pad-collate
method: padding method (see :py:class:`monai.transforms.SpatialPad`)
mode: padding mode (see :py:class:`monai.transforms.SpatialPad`)
"""
list_of_dicts = isinstance(batch[0], dict)
for key_or_idx in batch[0].keys() if list_of_dicts else range(len(batch[0])):
max_shapes = []
for elem in batch:
if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)):
break
max_shapes.append(elem[key_or_idx].shape[1:])
# len > 0 if objects were arrays
if len(max_shapes) == 0:
continue
max_shape = np.array(max_shapes).max(axis=0)
# If all same size, skip
if np.all(np.array(max_shapes).min(axis=0) == max_shape):
continue
# Do we need to convert output to Tensor?
output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor)

# Use `SpatialPadd` or `SpatialPad` to match sizes
# Default params are central padding, padding with 0's
# If input is dictionary, use the dictionary version so that the transformation is recorded
padder: Union[SpatialPadd, SpatialPad]
if list_of_dicts:
from monai.transforms.croppad.dictionary import SpatialPadd # needs to be here to avoid circular import

padder = SpatialPadd(key_or_idx, max_shape, method, mode) # type: ignore

else:
from monai.transforms.croppad.array import SpatialPad # needs to be here to avoid circular import

padder = SpatialPad(max_shape, method, mode) # type: ignore

for idx in range(len(batch)):
padded = padder(batch[idx])[key_or_idx] if list_of_dicts else padder(batch[idx][key_or_idx])
# since tuple is immutable we'll have to recreate
if isinstance(batch[idx], tuple):
batch[idx] = list(batch[idx]) # type: ignore
batch[idx][key_or_idx] = padded
batch[idx] = tuple(batch[idx]) # type: ignore
# else, replace
else:
batch[idx][key_or_idx] = padder(batch[idx])[key_or_idx]

if output_to_tensor:
batch[idx][key_or_idx] = torch.Tensor(batch[idx][key_or_idx])

# After padding, use default list collator
return list_data_collate(batch)


def worker_init_fn(worker_id: int) -> None:
Expand Down
95 changes: 95 additions & 0 deletions tests/test_pad_collation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest
from typing import List, Tuple

import numpy as np
import torch
from parameterized import parameterized

from monai.data import CacheDataset, DataLoader
from monai.data.utils import pad_list_data_collate
from monai.transforms import (
RandRotate,
RandRotate90,
RandRotate90d,
RandRotated,
RandSpatialCrop,
RandSpatialCropd,
RandZoom,
RandZoomd,
)
from monai.utils import set_determinism

TESTS: List[Tuple] = []


TESTS.append((dict, RandSpatialCropd("image", roi_size=[8, 7], random_size=True)))
TESTS.append((dict, RandRotated("image", prob=1, range_x=np.pi, keep_size=False)))
TESTS.append((dict, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
TESTS.append((dict, RandRotate90d("image", prob=1, max_k=2)))

TESTS.append((list, RandSpatialCrop(roi_size=[8, 7], random_size=True)))
TESTS.append((list, RandRotate(prob=1, range_x=np.pi, keep_size=False)))
TESTS.append((list, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
TESTS.append((list, RandRotate90(prob=1, max_k=2)))


class _Dataset(torch.utils.data.Dataset):
def __init__(self, images, labels, transforms):
self.images = images
self.labels = labels
self.transforms = transforms

def __len__(self):
return len(self.images)

def __getitem__(self, index):
return self.transforms(self.images[index]), self.labels[index]


class TestPadCollation(unittest.TestCase):
def setUp(self) -> None:
set_determinism(seed=0)
# image is non square to throw rotation errors
im = np.arange(0, 10 * 9).reshape(1, 10, 9)
num_elements = 20
self.dict_data = [{"image": im} for _ in range(num_elements)]
self.list_data = [im for _ in range(num_elements)]
self.list_labels = [random.randint(0, 1) for _ in range(num_elements)]

def tearDown(self) -> None:
set_determinism(None)

@parameterized.expand(TESTS)
def test_pad_collation(self, t_type, transform):

if t_type == dict:
dataset = CacheDataset(self.dict_data, transform, progress=False)
else:
dataset = _Dataset(self.list_data, self.list_labels, transform)

# Default collation should raise an error
loader_fail = DataLoader(dataset, batch_size=10)
with self.assertRaises(RuntimeError):
for _ in loader_fail:
pass

# Padded collation shouldn't
loader = DataLoader(dataset, batch_size=2, collate_fn=pad_list_data_collate)
for _ in loader:
pass


if __name__ == "__main__":
unittest.main()