Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1975,6 +1975,12 @@ Utility (Dict)
:members:
:special-members: __call__

`FlattenSubKeysd`
"""""""""""""""""
.. autoclass:: FlattenSubKeysd
:members:
:special-members: __call__

`Transposed`
""""""""""""
.. autoclass:: Transposed
Expand Down
3 changes: 3 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,9 @@
FgBgToIndicesd,
FgBgToIndicesD,
FgBgToIndicesDict,
FlattenSubKeysd,
FlattenSubKeysD,
FlattenSubKeysDict,
Identityd,
IdentityD,
IdentityDict,
Expand Down
55 changes: 54 additions & 1 deletion monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@
"MapLabelValueD",
"MapLabelValueDict",
"MapLabelValued",
"FlattenSubKeysd",
"FlattenSubKeysD",
"FlattenSubKeysDict",
"RandCuCIMd",
"RandCuCIMD",
"RandCuCIMDict",
Expand Down Expand Up @@ -770,7 +773,7 @@ def _delete_item(keys, d, use_re: bool = False):
class SelectItemsd(MapTransform):
"""
Select only specified items from data dictionary to release memory.
It will copy the selected key-values and construct and new dictionary.
It will copy the selected key-values and construct a new dictionary.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
Expand All @@ -779,6 +782,55 @@ def __call__(self, data):
return {key: data[key] for key in self.key_iterator(data)}


class FlattenSubKeysd(MapTransform):
"""
If an item is dictionary, it flatten the item by moving the sub-items (defined by sub-keys) to the top level.
{"pred": {"a": ..., "b", ... }} --> {"a": ..., "b", ... }

Args:
keys: keys of the corresponding items to be flatten
sub_keys: the sub-keys of items to be flatten. If not provided all the sub-keys are flattened.
delete_keys: whether to delete the key of the items that their sub-keys are flattened. Default to True.
prefix: optional prefix to be added to the sub-keys when moving to the top level.
By default no prefix will be added.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
keys: KeysCollection,
sub_keys: Optional[KeysCollection] = None,
delete_keys: bool = True,
prefix: Optional[str] = None,
) -> None:
super().__init__(keys)
self.sub_keys = sub_keys
self.delete_keys = delete_keys
self.prefix = prefix

def __call__(self, data):
d = dict(data)
for key in self.key_iterator(d):
# set the sub-keys for the specified key
sub_keys = d[key].keys() if self.sub_keys is None else self.sub_keys

# move all the sub-keys to the top level
for sk in sub_keys:
# set the top-level key for the sub-key
sk_top = f"{self.prefix}_{sk}" if self.prefix else sk
if sk_top in d:
raise ValueError(
f"'{sk_top}' already exists in the top-level keys. Please change `prefix` to avoid duplicity."
)
d[sk_top] = d[key][sk]

# delete top level key that is flattened
if self.delete_keys:
del d[key]
return d


class SqueezeDimd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.SqueezeDim`.
Expand Down Expand Up @@ -1742,3 +1794,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
CuCIMD = CuCIMDict = CuCIMd
RandCuCIMD = RandCuCIMDict = RandCuCIMd
AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd
FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd
62 changes: 62 additions & 0 deletions tests/test_flatten_sub_keysd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.transforms import FlattenSubKeysd

A = torch.randn(2, 2)
B = torch.randn(3, 3)
C = torch.randn(1, 3)
I = torch.randn(2, 3)
D1 = {"a": A, "b": B}
D2 = {"a": A, "b": B, "c": C}


TEST_CASE_0 = [{"keys": "pred"}, {"image": I, "pred": D1}, {"a": A, "b": B, "image": I}]
TEST_CASE_1 = [{"keys": "pred"}, {"image": I, "pred": D2}, {"a": A, "b": B, "c": C, "image": I}]
TEST_CASE_2 = [{"keys": "pred", "sub_keys": ["a", "b"]}, {"image": I, "pred": D1}, {"a": A, "b": B, "image": I}]
TEST_CASE_3 = [{"keys": "pred", "sub_keys": ["a", "b"]}, {"image": I, "pred": D2}, {"a": A, "b": B, "image": I}]
TEST_CASE_4 = [
{"keys": "pred", "sub_keys": ["a", "b"], "delete_keys": False},
{"image": I, "pred": D1},
{"a": A, "b": B, "image": I, "pred": D1},
]
TEST_CASE_5 = [
{"keys": "pred", "sub_keys": ["a", "b"], "prefix": "new"},
{"image": I, "pred": D2},
{"new_a": A, "new_b": B, "image": I},
]
TEST_CASE_ERROR_1 = [ # error for duplicate key
{"keys": "pred", "sub_keys": ["a", "b"]},
{"image": I, "pred": D2, "a": None},
]


class TestFlattenSubKeysd(unittest.TestCase):
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_dict(self, params, input_data, expected):
result = FlattenSubKeysd(**params)(input_data)
self.assertSetEqual(set(result.keys()), set(expected.keys()))
for k in expected:
self.assertEqual(id(result[k]), id(expected[k]))

@parameterized.expand([TEST_CASE_ERROR_1])
def test_error(self, params, input_data):
with self.assertRaises(ValueError):
FlattenSubKeysd(**params)(input_data)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion tests/test_module_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_transform_api(self):
"""monai subclasses of MapTransforms must have alias names ending with 'd', 'D', 'Dict'"""
to_exclude = {"MapTransform"} # except for these transforms
to_exclude_docs = {"Decollate", "Ensemble", "Invert", "SaveClassification", "RandTorchVision", "RandCrop"}
to_exclude_docs.update({"DeleteItems", "SelectItems", "CopyItems", "ConcatItems"})
to_exclude_docs.update({"DeleteItems", "SelectItems", "FlattenSubKeys", "CopyItems", "ConcatItems"})
to_exclude_docs.update({"ToMetaTensor", "FromMetaTensor"})
xforms = {
name: obj
Expand Down