From 54effecf8723d4c343aad61b1f602ef3625232b7 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Tue, 23 Feb 2021 16:28:02 +0000 Subject: [PATCH 1/9] decollate batch Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/__init__.py | 1 + monai/data/utils.py | 66 +++++++++++++++++++++++++++++++++++ tests/test_decollate.py | 77 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 144 insertions(+) create mode 100644 tests/test_decollate.py diff --git a/monai/data/__init__.py b/monai/data/__init__.py index e0db1e17ae..c0dbb6302d 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -38,6 +38,7 @@ compute_shape_offset, correct_nifti_header_if_necessary, create_file_basename, + decollate_batch, dense_patch_slices, get_random_patch, get_valid_patch_size, diff --git a/monai/data/utils.py b/monai/data/utils.py index acc6d2e97a..7141d13647 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -63,6 +63,7 @@ "json_hashing", "pickle_hashing", "sorted_dict", + "decollate_batch", ] @@ -243,6 +244,71 @@ def list_data_collate(batch: Sequence): return default_collate(data) +def decollate_batch(data: dict, batch_size: Optional[int] = None): + """De-collate a batch of data (for example, as produced by a `DataLoader`). + + Returns a list of dictionaries. Each dictionary will only contain the data for a given batch. + + Images originally stored as (B,C,H,W,[D]) will be returned as (C,H,W,[D]). Other information, + such as metadata, may have been stored in a list (or a list inside nested dictionaries). In + this case we return the element of the list corresponding to the batch idx. + + Return types aren't guaranteed to be the same as the original, since numpy arrays will have been + converted to torch.Tensor, and tuples/lists may have been converted to lists of tensors + + For example: + + ``` + batch_data = { + "image": torch.rand((2,1,10,10)), + "image_meta_dict": {"scl_slope": torch.Tensor([0.0, 0.0])} + } + out = decollate_batch(batch_data) + print(len(out)) + >>> 2 + + print(out[0]) + >>> {'image': tensor([[[4.3549e-01...43e-01]]]), 'image_meta_dict': {'scl_slope': 0.0}} + ``` + + Args: + data: data to be de-collated + batch_size: number of batches in data. If `None` is passed, try to figure out batch size. + """ + if not isinstance(data, dict): + raise RuntimeError("Only currently implemented for dictionary data (might be trivial to adapt).") + if batch_size is None: + for v in data.values(): + if isinstance(v, torch.Tensor): + batch_size = v.shape[0] + break + if batch_size is None: + raise RuntimeError("Couldn't determine batch size, please specify as argument.") + + def torch_to_single(d: torch.Tensor): + """If input is a torch.Tensor with only 1 element, return just the element.""" + return d if d.numel() > 1 else d.item() + + def decollate(data: Any, idx: int): + """Recursively de-collate.""" + if isinstance(data, dict): + return {k: decollate(v, idx) for k, v in data.items()} + if isinstance(data, torch.Tensor): + out = data[idx] + return torch_to_single(out) + elif isinstance(data, list): + if len(data) == 0: + return data + if isinstance(data[0], torch.Tensor): + return [torch_to_single(d[idx]) for d in data] + if issequenceiterable(data[0]): + return [decollate(d, idx) for d in data] + return data[idx] + raise TypeError(f"Not sure how to de-collate type: {type(data)}") + + return [{key: decollate(data[key], idx) for key in data.keys()} for idx in range(batch_size)] + + def worker_init_fn(worker_id: int) -> None: """ Callback function for PyTorch DataLoader `worker_init_fn`. diff --git a/tests/test_decollate.py b/tests/test_decollate.py new file mode 100644 index 0000000000..dcf94d06fc --- /dev/null +++ b/tests/test_decollate.py @@ -0,0 +1,77 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data import CacheDataset, DataLoader, create_test_image_2d +from monai.data.utils import decollate_batch +from monai.transforms import AddChanneld, Compose, LoadImaged, RandFlipd, SpatialPadd, ToTensord +from monai.utils import set_determinism +from tests.utils import make_nifti_image + +set_determinism(seed=0) + +IM_2D_FNAME = make_nifti_image(create_test_image_2d(100, 101)[0]) + +DATA_2D = {"image": IM_2D_FNAME} + +TESTS = [] +TESTS.append( + ( + "2D", + [DATA_2D for _ in range(6)], + ) +) + + +class TestDeCollate(unittest.TestCase): + def check_match(self, in1, in2): + if isinstance(in1, dict): + self.assertTrue(isinstance(in2, dict)) + self.check_match(list(in1.keys()), list(in2.keys())) + self.check_match(list(in1.values()), list(in2.values())) + elif any(isinstance(in1, i) for i in [list, tuple]): + for l1, l2 in zip(in1, in2): + self.check_match(l1, l2) + elif any(isinstance(in1, i) for i in [str, int]): + self.assertEqual(in1, in2) + elif any(isinstance(in1, i) for i in [torch.Tensor, np.ndarray]): + np.testing.assert_array_equal(in1, in2) + else: + raise RuntimeError(f"Not sure how to compare types. type(in1): {type(in1)}, type(in2): {type(in2)}") + + @parameterized.expand(TESTS) + def test_decollation(self, _, data, batch_size=2, num_workers=2): + transforms = Compose( + [ + LoadImaged("image"), + AddChanneld("image"), + SpatialPadd("image", 150), + RandFlipd("image", prob=1.0, spatial_axis=1), + ToTensord("image"), + ] + ) + dataset = CacheDataset(data, transforms, progress=False) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) + + for b, batch_data in enumerate(loader): + decollated = decollate_batch(batch_data) + + for i, d in enumerate(decollated): + self.check_match(dataset[b * batch_size + i], d) + + +if __name__ == "__main__": + unittest.main() From 4bdb951bd0d478209a5d2855785bc22175e0553c Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Tue, 23 Feb 2021 16:54:02 +0000 Subject: [PATCH 2/9] add imports Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 7141d13647..b6e0da8db2 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -18,7 +18,7 @@ from collections import defaultdict from itertools import product, starmap from pathlib import PurePath -from typing import Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -36,6 +36,7 @@ first, optional_import, ) +from monai.utils.misc import issequenceiterable nib, _ = optional_import("nibabel") From d35c69aa51cc6eca556a3b3e7db1222f964bdf4e Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Tue, 23 Feb 2021 17:08:07 +0000 Subject: [PATCH 3/9] docstring fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/utils.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index b6e0da8db2..9752bed93b 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -259,18 +259,18 @@ def decollate_batch(data: dict, batch_size: Optional[int] = None): For example: - ``` - batch_data = { - "image": torch.rand((2,1,10,10)), - "image_meta_dict": {"scl_slope": torch.Tensor([0.0, 0.0])} - } - out = decollate_batch(batch_data) - print(len(out)) - >>> 2 - - print(out[0]) - >>> {'image': tensor([[[4.3549e-01...43e-01]]]), 'image_meta_dict': {'scl_slope': 0.0}} - ``` + .. code-block:: python + + batch_data = { + "image": torch.rand((2,1,10,10)), + "image_meta_dict": {"scl_slope": torch.Tensor([0.0, 0.0])} + } + out = decollate_batch(batch_data) + print(len(out)) + >>> 2 + + print(out[0]) + >>> {'image': tensor([[[4.3549e-01...43e-01]]]), 'image_meta_dict': {'scl_slope': 0.0}} Args: data: data to be de-collated From db2760604c40021ca5aa0c6cfd45710020250ac1 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 24 Feb 2021 11:08:59 +0000 Subject: [PATCH 4/9] if no nibabel dont save nii to disk Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_decollate.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/test_decollate.py b/tests/test_decollate.py index dcf94d06fc..c56ce4271e 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -18,14 +18,15 @@ from monai.data import CacheDataset, DataLoader, create_test_image_2d from monai.data.utils import decollate_batch from monai.transforms import AddChanneld, Compose, LoadImaged, RandFlipd, SpatialPadd, ToTensord -from monai.utils import set_determinism +from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image -set_determinism(seed=0) +_, has_nib = optional_import("nibabel") -IM_2D_FNAME = make_nifti_image(create_test_image_2d(100, 101)[0]) +set_determinism(seed=0) -DATA_2D = {"image": IM_2D_FNAME} +IM_2D = create_test_image_2d(100, 101)[0] +DATA_2D = {"image": make_nifti_image(IM_2D) if has_nib else IM_2D} TESTS = [] TESTS.append( @@ -56,13 +57,16 @@ def check_match(self, in1, in2): def test_decollation(self, _, data, batch_size=2, num_workers=2): transforms = Compose( [ - LoadImaged("image"), AddChanneld("image"), SpatialPadd("image", 150), RandFlipd("image", prob=1.0, spatial_axis=1), ToTensord("image"), ] ) + # If nibabel present, read from disk + if has_nib: + transforms.transforms.insert(0, LoadImaged("image")) + dataset = CacheDataset(data, transforms, progress=False) loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) From ea445eb1cb34134b7a3671b0ac30fbc4caef6996 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 24 Feb 2021 11:24:47 +0000 Subject: [PATCH 5/9] fix nibabel read Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_decollate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_decollate.py b/tests/test_decollate.py index c56ce4271e..538849f686 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -65,7 +65,7 @@ def test_decollation(self, _, data, batch_size=2, num_workers=2): ) # If nibabel present, read from disk if has_nib: - transforms.transforms.insert(0, LoadImaged("image")) + transforms = Compose([LoadImaged("image"), transforms]) dataset = CacheDataset(data, transforms, progress=False) loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) From 6f19f44e39966d119fae2e45d00a41faaa9bbc7f Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Wed, 24 Feb 2021 14:45:01 +0000 Subject: [PATCH 6/9] determinism in setUp Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_decollate.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 538849f686..a82991e70c 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -23,8 +23,6 @@ _, has_nib = optional_import("nibabel") -set_determinism(seed=0) - IM_2D = create_test_image_2d(100, 101)[0] DATA_2D = {"image": make_nifti_image(IM_2D) if has_nib else IM_2D} @@ -38,6 +36,13 @@ class TestDeCollate(unittest.TestCase): + + def setUp(self) -> None: + set_determinism(seed=0) + + def tearDown(self) -> None: + set_determinism(None) + def check_match(self, in1, in2): if isinstance(in1, dict): self.assertTrue(isinstance(in2, dict)) From 5a6800aaef07a7a5f2560a1bf1d4188dd67e88c3 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Wed, 24 Feb 2021 16:15:38 +0000 Subject: [PATCH 7/9] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_decollate.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_decollate.py b/tests/test_decollate.py index a82991e70c..e9e7c8fdde 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -36,7 +36,6 @@ class TestDeCollate(unittest.TestCase): - def setUp(self) -> None: set_determinism(seed=0) From b9efa4d5208126e3853ce89f70661289a3c57e43 Mon Sep 17 00:00:00 2001 From: Rich <33289025+rijobro@users.noreply.github.com> Date: Wed, 24 Feb 2021 17:56:10 +0000 Subject: [PATCH 8/9] add post transform Decollated Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/utils.py | 4 ++-- monai/transforms/__init__.py | 3 +++ monai/transforms/post/dictionary.py | 23 +++++++++++++++++++++++ tests/test_decollate.py | 9 ++++++--- 4 files changed, 34 insertions(+), 5 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 9752bed93b..98310bb83f 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -245,7 +245,7 @@ def list_data_collate(batch: Sequence): return default_collate(data) -def decollate_batch(data: dict, batch_size: Optional[int] = None): +def decollate_batch(data: dict, batch_size: Optional[int] = None) -> List[dict]: """De-collate a batch of data (for example, as produced by a `DataLoader`). Returns a list of dictionaries. Each dictionary will only contain the data for a given batch. @@ -273,7 +273,7 @@ def decollate_batch(data: dict, batch_size: Optional[int] = None): >>> {'image': tensor([[[4.3549e-01...43e-01]]]), 'image_meta_dict': {'scl_slope': 0.0}} Args: - data: data to be de-collated + data: data to be de-collated. batch_size: number of batches in data. If `None` is passed, try to figure out batch size. """ if not isinstance(data, dict): diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 6f7c2a4f61..a38e6cd637 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -155,6 +155,9 @@ AsDiscreted, AsDiscreteD, AsDiscreteDict, + Decollated, + DecollateD, + DecollateDict, Ensembled, KeepLargestConnectedComponentd, KeepLargestConnectedComponentD, diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 60cda11a91..78b7542861 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -20,6 +20,7 @@ import numpy as np import torch +import monai.data from monai.config import KeysCollection from monai.transforms.compose import MapTransform from monai.transforms.post.array import ( @@ -52,6 +53,9 @@ "MeanEnsembleDict", "VoteEnsembleD", "VoteEnsembleDict", + "DecollateD", + "DecollateDict", + "Decollated", ] @@ -306,9 +310,28 @@ def __init__( super().__init__(keys, ensemble, output_key) +class Decollated(MapTransform): + """ + Decollate a batch of data. + + Note that unlike most MapTransforms, this will decollate all data, so keys are not needed. + + Args: + batch_size: if not supplied, we try to determine it based on array lengths. Will raise an error if + it fails to determine it automatically. + """ + + def __init__(self, batch_size: Optional[int] = None) -> None: + self.batch_size = batch_size + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + return monai.data.decollate_batch(data, self.batch_size) + + ActivationsD = ActivationsDict = Activationsd AsDiscreteD = AsDiscreteDict = AsDiscreted KeepLargestConnectedComponentD = KeepLargestConnectedComponentDict = KeepLargestConnectedComponentd LabelToContourD = LabelToContourDict = LabelToContourd MeanEnsembleD = MeanEnsembleDict = MeanEnsembled VoteEnsembleD = VoteEnsembleDict = VoteEnsembled +DecollateD = DecollateDict = Decollated diff --git a/tests/test_decollate.py b/tests/test_decollate.py index e9e7c8fdde..5c6f04b48e 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -18,6 +18,7 @@ from monai.data import CacheDataset, DataLoader, create_test_image_2d from monai.data.utils import decollate_batch from monai.transforms import AddChanneld, Compose, LoadImaged, RandFlipd, SpatialPadd, ToTensord +from monai.transforms.post.dictionary import Decollated from monai.utils import optional_import, set_determinism from tests.utils import make_nifti_image @@ -75,10 +76,12 @@ def test_decollation(self, _, data, batch_size=2, num_workers=2): loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) for b, batch_data in enumerate(loader): - decollated = decollate_batch(batch_data) + decollated_1 = decollate_batch(batch_data) + decollated_2 = Decollated()(batch_data) - for i, d in enumerate(decollated): - self.check_match(dataset[b * batch_size + i], d) + for decollated in [decollated_1, decollated_2]: + for i, d in enumerate(decollated): + self.check_match(dataset[b * batch_size + i], d) if __name__ == "__main__": From 5754fdb11c14837ec05f566932e14c8e2361fa08 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 25 Feb 2021 08:54:23 +0000 Subject: [PATCH 9/9] codeformat Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/post/dictionary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 3406d4be81..85abdac0ac 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -324,7 +324,7 @@ class Decollated(MapTransform): def __init__(self, batch_size: Optional[int] = None) -> None: self.batch_size = batch_size - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: dict) -> List[dict]: return monai.data.decollate_batch(data, self.batch_size)