diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 99990d7f53..3dd0a980ef 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -38,6 +38,7 @@ compute_shape_offset, correct_nifti_header_if_necessary, create_file_basename, + decollate_batch, dense_patch_slices, get_random_patch, get_valid_patch_size, diff --git a/monai/data/utils.py b/monai/data/utils.py index c42e1abefa..7717ddf3aa 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -18,7 +18,7 @@ from collections import defaultdict from itertools import product, starmap from pathlib import PurePath -from typing import Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -37,6 +37,7 @@ optional_import, ) from monai.utils.enums import Method +from monai.utils.misc import issequenceiterable nib, _ = optional_import("nibabel") @@ -64,6 +65,7 @@ "json_hashing", "pickle_hashing", "sorted_dict", + "decollate_batch", "pad_list_data_collate", ] @@ -255,6 +257,71 @@ def list_data_collate(batch: Sequence): raise RuntimeError(re_str) +def decollate_batch(data: dict, batch_size: Optional[int] = None) -> List[dict]: + """De-collate a batch of data (for example, as produced by a `DataLoader`). + + Returns a list of dictionaries. Each dictionary will only contain the data for a given batch. + + Images originally stored as (B,C,H,W,[D]) will be returned as (C,H,W,[D]). Other information, + such as metadata, may have been stored in a list (or a list inside nested dictionaries). In + this case we return the element of the list corresponding to the batch idx. + + Return types aren't guaranteed to be the same as the original, since numpy arrays will have been + converted to torch.Tensor, and tuples/lists may have been converted to lists of tensors + + For example: + + .. code-block:: python + + batch_data = { + "image": torch.rand((2,1,10,10)), + "image_meta_dict": {"scl_slope": torch.Tensor([0.0, 0.0])} + } + out = decollate_batch(batch_data) + print(len(out)) + >>> 2 + + print(out[0]) + >>> {'image': tensor([[[4.3549e-01...43e-01]]]), 'image_meta_dict': {'scl_slope': 0.0}} + + Args: + data: data to be de-collated. + batch_size: number of batches in data. If `None` is passed, try to figure out batch size. + """ + if not isinstance(data, dict): + raise RuntimeError("Only currently implemented for dictionary data (might be trivial to adapt).") + if batch_size is None: + for v in data.values(): + if isinstance(v, torch.Tensor): + batch_size = v.shape[0] + break + if batch_size is None: + raise RuntimeError("Couldn't determine batch size, please specify as argument.") + + def torch_to_single(d: torch.Tensor): + """If input is a torch.Tensor with only 1 element, return just the element.""" + return d if d.numel() > 1 else d.item() + + def decollate(data: Any, idx: int): + """Recursively de-collate.""" + if isinstance(data, dict): + return {k: decollate(v, idx) for k, v in data.items()} + if isinstance(data, torch.Tensor): + out = data[idx] + return torch_to_single(out) + elif isinstance(data, list): + if len(data) == 0: + return data + if isinstance(data[0], torch.Tensor): + return [torch_to_single(d[idx]) for d in data] + if issequenceiterable(data[0]): + return [decollate(d, idx) for d in data] + return data[idx] + raise TypeError(f"Not sure how to de-collate type: {type(data)}") + + return [{key: decollate(data[key], idx) for key in data.keys()} for idx in range(batch_size)] + + def pad_list_data_collate( batch: Sequence, method: Union[Method, str] = Method.SYMMETRIC, diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 357e00c6dd..5578b93077 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -155,6 +155,9 @@ AsDiscreted, AsDiscreteD, AsDiscreteDict, + Decollated, + DecollateD, + DecollateDict, Ensembled, KeepLargestConnectedComponentd, KeepLargestConnectedComponentD, diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index aff4ae3572..85abdac0ac 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -20,6 +20,7 @@ import numpy as np import torch +import monai.data from monai.config import KeysCollection from monai.transforms.post.array import ( Activations, @@ -52,6 +53,9 @@ "MeanEnsembleDict", "VoteEnsembleD", "VoteEnsembleDict", + "DecollateD", + "DecollateDict", + "Decollated", ] @@ -306,9 +310,28 @@ def __init__( super().__init__(keys, ensemble, output_key) +class Decollated(MapTransform): + """ + Decollate a batch of data. + + Note that unlike most MapTransforms, this will decollate all data, so keys are not needed. + + Args: + batch_size: if not supplied, we try to determine it based on array lengths. Will raise an error if + it fails to determine it automatically. + """ + + def __init__(self, batch_size: Optional[int] = None) -> None: + self.batch_size = batch_size + + def __call__(self, data: dict) -> List[dict]: + return monai.data.decollate_batch(data, self.batch_size) + + ActivationsD = ActivationsDict = Activationsd AsDiscreteD = AsDiscreteDict = AsDiscreted KeepLargestConnectedComponentD = KeepLargestConnectedComponentDict = KeepLargestConnectedComponentd LabelToContourD = LabelToContourDict = LabelToContourd MeanEnsembleD = MeanEnsembleDict = MeanEnsembled VoteEnsembleD = VoteEnsembleDict = VoteEnsembled +DecollateD = DecollateDict = Decollated diff --git a/tests/test_decollate.py b/tests/test_decollate.py new file mode 100644 index 0000000000..5c6f04b48e --- /dev/null +++ b/tests/test_decollate.py @@ -0,0 +1,88 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data import CacheDataset, DataLoader, create_test_image_2d +from monai.data.utils import decollate_batch +from monai.transforms import AddChanneld, Compose, LoadImaged, RandFlipd, SpatialPadd, ToTensord +from monai.transforms.post.dictionary import Decollated +from monai.utils import optional_import, set_determinism +from tests.utils import make_nifti_image + +_, has_nib = optional_import("nibabel") + +IM_2D = create_test_image_2d(100, 101)[0] +DATA_2D = {"image": make_nifti_image(IM_2D) if has_nib else IM_2D} + +TESTS = [] +TESTS.append( + ( + "2D", + [DATA_2D for _ in range(6)], + ) +) + + +class TestDeCollate(unittest.TestCase): + def setUp(self) -> None: + set_determinism(seed=0) + + def tearDown(self) -> None: + set_determinism(None) + + def check_match(self, in1, in2): + if isinstance(in1, dict): + self.assertTrue(isinstance(in2, dict)) + self.check_match(list(in1.keys()), list(in2.keys())) + self.check_match(list(in1.values()), list(in2.values())) + elif any(isinstance(in1, i) for i in [list, tuple]): + for l1, l2 in zip(in1, in2): + self.check_match(l1, l2) + elif any(isinstance(in1, i) for i in [str, int]): + self.assertEqual(in1, in2) + elif any(isinstance(in1, i) for i in [torch.Tensor, np.ndarray]): + np.testing.assert_array_equal(in1, in2) + else: + raise RuntimeError(f"Not sure how to compare types. type(in1): {type(in1)}, type(in2): {type(in2)}") + + @parameterized.expand(TESTS) + def test_decollation(self, _, data, batch_size=2, num_workers=2): + transforms = Compose( + [ + AddChanneld("image"), + SpatialPadd("image", 150), + RandFlipd("image", prob=1.0, spatial_axis=1), + ToTensord("image"), + ] + ) + # If nibabel present, read from disk + if has_nib: + transforms = Compose([LoadImaged("image"), transforms]) + + dataset = CacheDataset(data, transforms, progress=False) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) + + for b, batch_data in enumerate(loader): + decollated_1 = decollate_batch(batch_data) + decollated_2 = Decollated()(batch_data) + + for decollated in [decollated_1, decollated_2]: + for i, d in enumerate(decollated): + self.check_match(dataset[b * batch_size + i], d) + + +if __name__ == "__main__": + unittest.main()