From 756073fe7d12cc9a5e7e49372a3ffe62001d3ffa Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 8 Sep 2022 21:51:45 +0000 Subject: [PATCH 1/4] Implement SplitDimToListd Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/utility/dictionary.py | 40 ++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index cde6bd8cc2..77b009eff0 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -420,6 +420,45 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc return d +class SplitDimToListd(MapTransform): + """ + Split a dictionary of tensors with given keys along a dimension to the list of dictionaries with those keys. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + dim: which dimension of input image is the channel, default to 0. + keepdim: if `True`, output will have singleton in the split dimension. If `False`, this + dimension will be squeezed. + update_meta: if `True`, copy `[key]_meta_dict` for each output and update affine to + reflect the cropped image + allow_missing_keys: don't raise exception if key is missing. + """ + + def __init__( + self, + keys: KeysCollection, + dim: int = 0, + keepdim: bool = False, + update_meta: bool = True, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.splitter = SplitDim(dim, keepdim, update_meta) + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable, torch.Tensor]]: + d = dict(data) + output = [] + results = [self.splitter(d[key]) for key in self.keys] + for row in zip(*results): + new_dict = {k: v for k, v in zip(self.keys, row)} + # fill in the extra keys with unmodified data + for k in set(d.keys()).difference(set(self.keys)): + new_dict[k] = deepcopy(d[k]) + output.append(new_dict) + return output + + @deprecated(since="0.8", msg_suffix="please use `SplitDimd` instead.") class SplitChanneld(SplitDimd): """ @@ -1674,6 +1713,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N RepeatChannelD = RepeatChannelDict = RepeatChanneld SplitChannelD = SplitChannelDict = SplitChanneld SplitDimD = SplitDimDict = SplitDimd +SplitDimToListD = SplitDimToListDict = SplitDimToListd CastToTypeD = CastToTypeDict = CastToTyped ToTensorD = ToTensorDict = ToTensord EnsureTypeD = EnsureTypeDict = EnsureTyped From 1d25028f1d0417e6ceab22086f6fb087f27e8936 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 8 Sep 2022 21:52:36 +0000 Subject: [PATCH 2/4] Update init Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- monai/transforms/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 7df5c4f075..ac49a19fc3 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -572,6 +572,9 @@ SplitDimd, SplitDimD, SplitDimDict, + SplitDimToListd, + SplitDimToListD, + SplitDimToListDict, SqueezeDimd, SqueezeDimD, SqueezeDimDict, From dcc2a9cf759817cb10f5cdbbc66428ee7fb622a0 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 8 Sep 2022 21:57:56 +0000 Subject: [PATCH 3/4] Add to docs Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- docs/source/transforms.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index ccd050534d..7f7ffa3deb 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1885,6 +1885,12 @@ Utility (Dict) :members: :special-members: __call__ +`SplitDimToListd` +""""""""""""""""" +.. autoclass:: SplitDimToListd + :members: + :special-members: __call__ + `SplitChanneld` """"""""""""""" .. autoclass:: SplitChanneld From 78353e444d918f6aebf3aba16fece40b7fbcfab4 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 8 Sep 2022 22:03:38 +0000 Subject: [PATCH 4/4] Add unittests Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- tests/min_tests.py | 1 + tests/test_splitdimtolistd.py | 82 +++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 tests/test_splitdimtolistd.py diff --git a/tests/min_tests.py b/tests/min_tests.py index 0f1e4e61ec..b5d25ea113 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -155,6 +155,7 @@ def run_testsuit(): "test_spacing", "test_spacingd", "test_splitdimd", + "test_splitdimtolistd", "test_surface_distance", "test_surface_dice", "test_testtimeaugmentation", diff --git a/tests/test_splitdimtolistd.py b/tests/test_splitdimtolistd.py new file mode 100644 index 0000000000..633a4f95aa --- /dev/null +++ b/tests/test_splitdimtolistd.py @@ -0,0 +1,82 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from copy import deepcopy + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data.meta_tensor import MetaTensor +from monai.transforms import LoadImaged +from monai.transforms.utility.dictionary import SplitDimToListd +from tests.utils import TEST_NDARRAYS, assert_allclose, make_nifti_image, make_rand_affine + +TESTS = [] +for p in TEST_NDARRAYS: + for keepdim in (True, False): + for update_meta in (True, False): + TESTS.append((keepdim, p, update_meta)) + + +class TestSplitDimToListd(unittest.TestCase): + @classmethod + def setUpClass(cls): + arr = np.random.rand(2, 10, 8, 7) + affine = make_rand_affine() + data = {"image": make_nifti_image(arr, affine)} + + loader = LoadImaged("image") + cls.data: MetaTensor = loader(data) + + @parameterized.expand(TESTS) + def test_correct(self, keepdim, im_type, update_meta): + data = deepcopy(self.data) + data["image"] = im_type(data["image"]) + arr = data["image"] + for dim in range(arr.ndim): + out = SplitDimToListd("image", dim=dim, keepdim=keepdim, update_meta=update_meta)(data) + self.assertIsInstance(out, list) + self.assertEqual(len(out), arr.shape[dim]) + # if updating metadata, pick some random points and + # check same world coordinates between input and output + if update_meta: + for _ in range(10): + idx = [np.random.choice(i) for i in arr.shape] + split_im_idx = idx[dim] + split_idx = deepcopy(idx) + split_idx[dim] = 0 + split_im = out[split_im_idx]["image"] + if isinstance(data, MetaTensor) and isinstance(split_im, MetaTensor): + # idx[1:] to remove channel and then add 1 for 4th element + real_world = data.affine @ torch.tensor(idx[1:] + [1]).double() + real_world2 = split_im.affine @ torch.tensor(split_idx[1:] + [1]).double() + assert_allclose(real_world, real_world2) + + img_0 = out[0]["image"] + expected_ndim = arr.ndim if keepdim else arr.ndim - 1 + self.assertEqual(img_0.ndim, expected_ndim) + # assert is a shallow copy + arr[0, 0, 0, 0] *= 2 + self.assertEqual(arr.flatten()[0], img_0.flatten()[0]) + + def test_error(self): + """Should fail because splitting along singleton dimension""" + shape = (2, 1, 8, 7) + for p in TEST_NDARRAYS: + arr = p(np.random.rand(*shape)) + with self.assertRaises(RuntimeError): + _ = SplitDimToListd("image", dim=1)({"image": arr}) + + +if __name__ == "__main__": + unittest.main()