Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 10 additions & 49 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,64 +337,25 @@ def pad_list_data_collate(
mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT,
):
"""
Function version of :py:class:`monai.transforms.croppad.batch.PadListDataCollate`.

Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest
tensor in each dimension.
tensor in each dimension. This transform is useful if some of the applied transforms generate batch data of
different sizes.

Note:
Need to use this collate if apply some transforms that can generate batch data.
This can be used on both list and dictionary data. In the case of the dictionary data, this transform will be added
to the list of invertible transforms.

The inverse can be called using the static method: `monai.transforms.croppad.batch.PadListDataCollate.inverse`.

Args:
batch: batch of data to pad-collate
method: padding method (see :py:class:`monai.transforms.SpatialPad`)
mode: padding mode (see :py:class:`monai.transforms.SpatialPad`)
"""
list_of_dicts = isinstance(batch[0], dict)
for key_or_idx in batch[0].keys() if list_of_dicts else range(len(batch[0])):
max_shapes = []
for elem in batch:
if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)):
break
max_shapes.append(elem[key_or_idx].shape[1:])
# len > 0 if objects were arrays
if len(max_shapes) == 0:
continue
max_shape = np.array(max_shapes).max(axis=0)
# If all same size, skip
if np.all(np.array(max_shapes).min(axis=0) == max_shape):
continue
# Do we need to convert output to Tensor?
output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor)

# Use `SpatialPadd` or `SpatialPad` to match sizes
# Default params are central padding, padding with 0's
# If input is dictionary, use the dictionary version so that the transformation is recorded
padder: Union[SpatialPadd, SpatialPad]
if list_of_dicts:
from monai.transforms.croppad.dictionary import SpatialPadd # needs to be here to avoid circular import
from monai.transforms.croppad.batch import PadListDataCollate # needs to be here to avoid circular import

padder = SpatialPadd(key_or_idx, max_shape, method, mode) # type: ignore

else:
from monai.transforms.croppad.array import SpatialPad # needs to be here to avoid circular import

padder = SpatialPad(max_shape, method, mode) # type: ignore

for idx in range(len(batch)):
padded = padder(batch[idx])[key_or_idx] if list_of_dicts else padder(batch[idx][key_or_idx])
# since tuple is immutable we'll have to recreate
if isinstance(batch[idx], tuple):
batch[idx] = list(batch[idx]) # type: ignore
batch[idx][key_or_idx] = padded
batch[idx] = tuple(batch[idx]) # type: ignore
# else, replace
else:
batch[idx][key_or_idx] = padder(batch[idx])[key_or_idx]

if output_to_tensor:
batch[idx][key_or_idx] = torch.Tensor(batch[idx][key_or_idx])

# After padding, use default list collator
return list_data_collate(batch)
return PadListDataCollate(method, mode)(batch)


def worker_init_fn(worker_id: int) -> None:
Expand Down
1 change: 1 addition & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
SpatialCrop,
SpatialPad,
)
from .croppad.batch import PadListDataCollate
from .croppad.dictionary import (
BorderPadd,
BorderPadD,
Expand Down
129 changes: 129 additions & 0 deletions monai/transforms/croppad/batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A collection of "vanilla" transforms for crop and pad operations acting on batches of data
https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design
"""

from copy import deepcopy
from typing import Any, Dict, Hashable, Union

import numpy as np
import torch

from monai.data.utils import list_data_collate
from monai.transforms.compose import Compose
from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.utility.array import ToTensor
from monai.utils.enums import InverseKeys, Method, NumpyPadMode

__all__ = [
"PadListDataCollate",
]


def replace_element(to_replace, batch, idx, key_or_idx):
    """
    Store ``to_replace`` at position ``key_or_idx`` of the ``idx``-th batch element.

    Args:
        to_replace: the new value to store.
        batch: mutable sequence of batch elements; it is modified in place.
        idx: index of the element inside ``batch`` to update.
        key_or_idx: dictionary key or sequence index inside ``batch[idx]``.

    Returns:
        The same ``batch`` object that was passed in.
    """
    entry = batch[idx]
    if not isinstance(entry, tuple):
        # mutable container (e.g. dict or list): assign directly
        entry[key_or_idx] = to_replace
        return batch

    # tuples are immutable, so rebuild the element with the new value in place
    rebuilt = [*entry]
    rebuilt[key_or_idx] = to_replace
    batch[idx] = tuple(rebuilt)
    return batch


class PadListDataCollate(InvertibleTransform):
    """
    Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest
    tensor in each dimension. This transform is useful if some of the applied transforms generate batch data of
    different sizes.

    This can be used on both list and dictionary data. In the case of the dictionary data, this transform will be added
    to the list of invertible transforms.

    Note that normally, a user won't explicitly use the `__call__` method. Rather this would be passed to the `DataLoader`.
    This means that `__call__` handles data as it comes out of a `DataLoader`, containing batch dimension. However, the
    `inverse` operates on dictionaries containing images of shape `C,H,W,[D]`. This asymmetry is necessary so that we can
    pass the inverse through multiprocessing.

    Args:
        method: padding method (see :py:class:`monai.transforms.SpatialPad`)
        mode: padding mode (see :py:class:`monai.transforms.SpatialPad`)
    """

    def __init__(
        self,
        method: Union[Method, str] = Method.SYMMETRIC,
        mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT,
    ) -> None:
        self.method = method
        self.mode = mode

    def __call__(self, batch: Any):
        """
        Pad every array item of the batch to a common spatial shape, then collate.

        Args:
            batch: batch of data to pad-collate, either a list of dictionaries or a list of
                lists/tuples. The elements of ``batch`` are updated in place.

        Returns:
            The result of ``list_data_collate`` applied to the padded batch.
        """
        # data is either list of dicts or list of lists
        is_list_of_dicts = isinstance(batch[0], dict)
        # Snapshot the keys before looping: `push_transform` below is called on the batch
        # dictionaries (including `batch[0]`) and presumably may insert a new trace key,
        # which would invalidate a live `keys()` view mid-iteration.
        keys_or_indices = list(batch[0]) if is_list_of_dicts else range(len(batch[0]))

        # loop over items inside of each element in a batch
        for key_or_idx in keys_or_indices:
            # collect the shape (excluding the first dimension) of this item for each element
            max_shapes = []
            for elem in batch:
                if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)):
                    break
                max_shapes.append(elem[key_or_idx].shape[1:])
            # len > 0 if objects were arrays, else skip as no padding to be done
            if len(max_shapes) == 0:
                continue
            all_shapes = np.array(max_shapes)
            max_shape = all_shapes.max(axis=0)
            # If all same size, skip
            if np.all(all_shapes.min(axis=0) == max_shape):
                continue
            # Do we need to convert output to Tensor?
            output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor)

            # Use the array version `SpatialPad` for both list and dictionary data; for
            # dictionary data the padding is recorded explicitly via `push_transform` below.
            padder = SpatialPad(max_shape, self.method, self.mode)  # type: ignore
            transform = padder if not output_to_tensor else Compose([padder, ToTensor()])

            for idx in range(len(batch)):
                im = batch[idx][key_or_idx]
                # remember the pre-padding size so that `inverse` can crop back to it
                orig_size = im.shape[1:]
                padded = transform(im)
                batch = replace_element(padded, batch, idx, key_or_idx)

                # If we have a dictionary of data, append to list
                if is_list_of_dicts:
                    self.push_transform(batch[idx], key_or_idx, orig_size=orig_size)

        # After padding, use default list collator
        return list_data_collate(batch)

    @staticmethod
    def inverse(data: dict) -> Dict[Hashable, np.ndarray]:
        """
        Undo the padding recorded by this transform on a single (decollated) dictionary.

        For every key whose most recent recorded transform is a ``PadListDataCollate``, the
        value is centre-cropped back to the recorded ``orig_size`` and that record is removed.
        Keys with no recorded transforms, or whose latest record belongs to another transform,
        are left untouched.

        Args:
            data: dictionary of data; it is deep-copied, the input is not modified.

        Returns:
            The deep copy with the padding removed.

        Raises:
            RuntimeError: if ``data`` is not a dictionary.
        """
        if not isinstance(data, dict):
            raise RuntimeError("Inverse can only currently be applied on dictionaries.")

        d = deepcopy(data)
        for key in list(d):
            transform_key = str(key) + InverseKeys.KEY_SUFFIX.value
            applied = d.get(transform_key)
            # nothing recorded for this key (missing or empty list): nothing to invert
            if applied is None or len(applied) == 0:
                continue
            transform = applied[-1]
            if transform[InverseKeys.CLASS_NAME.value] == PadListDataCollate.__name__:
                d[key] = CenterSpatialCrop(transform["orig_size"])(d[key])
                # remove transform
                applied.pop()
        return d
1 change: 1 addition & 0 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ class Decollated(MapTransform):
"""

def __init__(self, batch_size: Optional[int] = None) -> None:
super().__init__(None)
self.batch_size = batch_size

def __call__(self, data: dict) -> List[dict]:
Expand Down
35 changes: 21 additions & 14 deletions tests/test_decollate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,38 @@
import sys
import unittest
from enum import Enum
from typing import List, Tuple

import numpy as np
import torch
from parameterized import parameterized

from monai.data import CacheDataset, DataLoader, create_test_image_2d
from monai.data.utils import decollate_batch
from monai.transforms import AddChanneld, Compose, LoadImaged, RandFlipd, SpatialPadd, ToTensord
from monai.transforms.post.dictionary import Decollated
from monai.transforms.spatial.dictionary import RandAffined, RandRotate90d
from monai.utils import optional_import, set_determinism
from monai.utils.enums import InverseKeys
from tests.utils import make_nifti_image

_, has_nib = optional_import("nibabel")

KEYS = ["image"]

TESTS: List[Tuple] = []
TESTS.append((SpatialPadd(KEYS, 150), RandFlipd(KEYS, prob=1.0, spatial_axis=1)))
TESTS.append((RandRotate90d(KEYS, prob=0.0, max_k=1),))
TESTS.append((RandAffined(KEYS, prob=0.0, translate_range=10),))


class TestDeCollate(unittest.TestCase):
def setUp(self) -> None:
set_determinism(seed=0)

im = create_test_image_2d(100, 101)[0]
self.data = [{"image": make_nifti_image(im) if has_nib else im} for _ in range(6)]

def tearDown(self) -> None:
set_determinism(None)

Expand All @@ -55,24 +68,18 @@ def check_match(self, in1, in2):
else:
raise RuntimeError(f"Not sure how to compare types. type(in1): {type(in1)}, type(in2): {type(in2)}")

def test_decollation(self, batch_size=2, num_workers=2):
@parameterized.expand(TESTS)
def test_decollation(self, *transforms):

im = create_test_image_2d(100, 101)[0]
data = [{"image": make_nifti_image(im) if has_nib else im} for _ in range(6)]

transforms = Compose(
[
AddChanneld("image"),
SpatialPadd("image", 150),
RandFlipd("image", prob=1.0, spatial_axis=1),
ToTensord("image"),
]
)
batch_size = 2
num_workers = 2

t_compose = Compose([AddChanneld(KEYS), Compose(transforms), ToTensord(KEYS)])
# If nibabel present, read from disk
if has_nib:
transforms = Compose([LoadImaged("image"), transforms])
t_compose = Compose([LoadImaged("image"), t_compose])

dataset = CacheDataset(data, transforms, progress=False)
dataset = CacheDataset(self.data, t_compose, progress=False)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

for b, batch_data in enumerate(loader):
Expand Down
33 changes: 19 additions & 14 deletions tests/test_pad_collation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
from parameterized import parameterized

from monai.data import CacheDataset, DataLoader
from monai.data.utils import pad_list_data_collate
from monai.data.utils import decollate_batch, pad_list_data_collate
from monai.transforms import (
PadListDataCollate,
RandRotate,
RandRotate90,
RandRotate90d,
Expand All @@ -33,16 +34,16 @@

TESTS: List[Tuple] = []

for pad_collate in [pad_list_data_collate, PadListDataCollate()]:
TESTS.append((dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True)))
TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False)))
TESTS.append((dict, pad_collate, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
TESTS.append((dict, pad_collate, RandRotate90d("image", prob=1, max_k=2)))

TESTS.append((dict, RandSpatialCropd("image", roi_size=[8, 7], random_size=True)))
TESTS.append((dict, RandRotated("image", prob=1, range_x=np.pi, keep_size=False)))
TESTS.append((dict, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
TESTS.append((dict, RandRotate90d("image", prob=1, max_k=2)))

TESTS.append((list, RandSpatialCrop(roi_size=[8, 7], random_size=True)))
TESTS.append((list, RandRotate(prob=1, range_x=np.pi, keep_size=False)))
TESTS.append((list, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
TESTS.append((list, RandRotate90(prob=1, max_k=2)))
TESTS.append((list, pad_collate, RandSpatialCrop(roi_size=[8, 7], random_size=True)))
TESTS.append((list, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False)))
TESTS.append((list, pad_collate, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
TESTS.append((list, pad_collate, RandRotate90(prob=1, max_k=2)))


class _Dataset(torch.utils.data.Dataset):
Expand Down Expand Up @@ -72,7 +73,7 @@ def tearDown(self) -> None:
set_determinism(None)

@parameterized.expand(TESTS)
def test_pad_collation(self, t_type, transform):
def test_pad_collation(self, t_type, collate_method, transform):

if t_type == dict:
dataset = CacheDataset(self.dict_data, transform, progress=False)
Expand All @@ -86,9 +87,13 @@ def test_pad_collation(self, t_type, transform):
pass

# Padded collation shouldn't
loader = DataLoader(dataset, batch_size=2, collate_fn=pad_list_data_collate)
for _ in loader:
pass
loader = DataLoader(dataset, batch_size=10, collate_fn=collate_method)
# check collation in forward direction
for data in loader:
if t_type == dict:
decollated_data = decollate_batch(data)
for d in decollated_data:
PadListDataCollate.inverse(d)


if __name__ == "__main__":
Expand Down