From 3abcb116c0b58bc5053f8b70e4902180261637f7 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Wed, 23 Nov 2022 11:39:07 -0500 Subject: [PATCH 1/2] Sqaush to fix DCO issue Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- docs/source/transforms.rst | 6 +++ monai/transforms/__init__.py | 3 ++ monai/transforms/utility/dictionary.py | 55 ++++++++++++++++++++++- tests/test_flatten_sub_keysd.py | 62 ++++++++++++++++++++++++++ tests/test_module_list.py | 2 +- 5 files changed, 126 insertions(+), 2 deletions(-) create mode 100644 tests/test_flatten_sub_keysd.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 7b728fde48..6932536907 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1975,6 +1975,12 @@ Utility (Dict) :members: :special-members: __call__ +`FlattenSubKeysd` +"""""""""""""""""""" +.. autoclass:: FlattenSubKeysd + :members: + :special-members: __call__ + `Transposed` """""""""""" .. autoclass:: Transposed diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 9cabc167a7..a754cb9479 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -545,6 +545,9 @@ FgBgToIndicesd, FgBgToIndicesD, FgBgToIndicesDict, + FlattenSubKeysd, + FlattenSubKeysD, + FlattenSubKeysDict, Identityd, IdentityD, IdentityDict, diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 49cdb6d97b..1c88347a4f 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -127,6 +127,9 @@ "MapLabelValueD", "MapLabelValueDict", "MapLabelValued", + "FlattenSubKeysd", + "FlattenSubKeysD", + "FlattenSubKeysDict", "RandCuCIMd", "RandCuCIMD", "RandCuCIMDict", @@ -770,7 +773,7 @@ def _delete_item(keys, d, use_re: bool = False): class SelectItemsd(MapTransform): """ Select only specified items from data dictionary to release memory. - It will copy the selected key-values and construct and new dictionary. + It will copy the selected key-values and construct a new dictionary. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -779,6 +782,55 @@ def __call__(self, data): return {key: data[key] for key in self.key_iterator(data)} +class FlattenSubKeysd(MapTransform): + """ + If an item is dictionary, it flatten the item by moving the sub-items (defined by sub-keys) to the top level. + {"pred": {"a": ..., "b", ... }} --> {"a": ..., "b", ... } + + Args: + keys: keys of the corresponding items to be flatten + sub_keys: the sub-keys of items to be flatten. If not provided all the sub-keys are flattened. + delete_keys: whether to delete the key of the items that their sub-keys are flattened. Default to True. + prefix: optional prefix to be added to the sub-keys when moving to the top level. + By default no prefix will be added. + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + keys: KeysCollection, + sub_keys: Optional[KeysCollection] = None, + delete_keys: bool = True, + prefix: Optional[str] = None, + ) -> None: + super().__init__(keys) + self.sub_keys = sub_keys + self.delete_keys = delete_keys + self.prefix = prefix + + def __call__(self, data): + d = dict(data) + for key in self.key_iterator(d): + # set the sub-keys for the specified key + sub_keys = d[key].keys() if self.sub_keys is None else self.sub_keys + + # move all the sub-keys to the top level + for sk in sub_keys: + # set the top-level key for the sub-key + sk_top = f"{self.prefix}_{sk}" if self.prefix else sk + if sk_top in d: + raise ValueError( + f"'{sk_top}' already exists in the top-level keys. Please change `prefix` to avoid duplicity." + ) + d[sk_top] = d[key][sk] + + # delete top level key that is flattened + if self.delete_keys: + del d[key] + return d + + class SqueezeDimd(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.SqueezeDim`. @@ -1742,3 +1794,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N CuCIMD = CuCIMDict = CuCIMd RandCuCIMD = RandCuCIMDict = RandCuCIMd AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd +FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd diff --git a/tests/test_flatten_sub_keysd.py b/tests/test_flatten_sub_keysd.py new file mode 100644 index 0000000000..24f0e88620 --- /dev/null +++ b/tests/test_flatten_sub_keysd.py @@ -0,0 +1,62 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms import FlattenSubKeysd + +A = torch.randn(2, 2) +B = torch.randn(3, 3) +C = torch.randn(1, 3) +I = torch.randn(2, 3) +D1 = {"a": A, "b": B} +D2 = {"a": A, "b": B, "c": C} + + +TEST_CASE_0 = [{"keys": "pred"}, {"image": I, "pred": D1}, {"a": A, "b": B, "image": I}] +TEST_CASE_1 = [{"keys": "pred"}, {"image": I, "pred": D2}, {"a": A, "b": B, "c": C, "image": I}] +TEST_CASE_2 = [{"keys": "pred", "sub_keys": ["a", "b"]}, {"image": I, "pred": D1}, {"a": A, "b": B, "image": I}] +TEST_CASE_3 = [{"keys": "pred", "sub_keys": ["a", "b"]}, {"image": I, "pred": D2}, {"a": A, "b": B, "image": I}] +TEST_CASE_4 = [ + {"keys": "pred", "sub_keys": ["a", "b"], "delete_keys": False}, + {"image": I, "pred": D1}, + {"a": A, "b": B, "image": I, "pred": D1}, +] +TEST_CASE_5 = [ + {"keys": "pred", "sub_keys": ["a", "b"], "prefix": "new"}, + {"image": I, "pred": D2}, + {"new_a": A, "new_b": B, "image": I}, +] +TEST_CASE_ERROR_1 = [ # error for duplicate key + {"keys": "pred", "sub_keys": ["a", "b"]}, + {"image": I, "pred": D2, "a": None}, +] + + +class TestFlattenSubKeysd(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + def test_dict(self, params, input_data, expected): + result = FlattenSubKeysd(**params)(input_data) + self.assertSetEqual(set(result.keys()), set(expected.keys())) + for k in expected: + self.assertEqual(id(result[k]), id(expected[k])) + + @parameterized.expand([TEST_CASE_ERROR_1]) + def test_error(self, params, input_data): + with self.assertRaises(ValueError): + FlattenSubKeysd(**params)(input_data) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_module_list.py b/tests/test_module_list.py index d0b5aaf26b..acd574d463 100644 --- a/tests/test_module_list.py +++ b/tests/test_module_list.py @@ -39,7 +39,7 @@ def test_transform_api(self): """monai subclasses of MapTransforms must have alias names ending with 'd', 'D', 'Dict'""" to_exclude = {"MapTransform"} # except for these transforms to_exclude_docs = {"Decollate", "Ensemble", "Invert", "SaveClassification", "RandTorchVision", "RandCrop"} - to_exclude_docs.update({"DeleteItems", "SelectItems", "CopyItems", "ConcatItems"}) + to_exclude_docs.update({"DeleteItems", "SelectItems", "FlattenSubKeys", "CopyItems", "ConcatItems"}) to_exclude_docs.update({"ToMetaTensor", "FromMetaTensor"}) xforms = { name: obj From 78d679b8e986b3d40a2a1c2f2353736ba364f19f Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Wed, 23 Nov 2022 11:45:46 -0500 Subject: [PATCH 2/2] Remove extra quote Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- docs/source/transforms.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 6932536907..84959c702a 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1976,7 +1976,7 @@ Utility (Dict) :special-members: __call__ `FlattenSubKeysd` -"""""""""""""""""""" +""""""""""""""""" .. autoclass:: FlattenSubKeysd :members: :special-members: __call__