diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py index f86afd546e..d8cc708af7 100644 --- a/monai/transforms/composables.py +++ b/monai/transforms/composables.py @@ -883,3 +883,25 @@ def __call__(self, data): for key in self.keys: d[key] = zoomer(d[key]) return d + + +@export +@alias('DeleteKeysD', 'DeleteKeysDict') +class DeleteKeysd(MapTransform): + """ + Delete specified keys from data dictionary to release memory. + It will remove the key-values and copy the others to construct a new dictionary. + """ + + def __init__(self, keys): + """ + Args: + keys (hashable items): keys of the corresponding items to be transformed. + See also: monai.transform.composables.MapTransform + """ + MapTransform.__init__(self, keys) + + def __call__(self, data): + for key in self.keys: + del data[key] + return dict(data) diff --git a/tests/test_delete_keys.py b/tests/test_delete_keys.py new file mode 100644 index 0000000000..35917e36f5 --- /dev/null +++ b/tests/test_delete_keys.py @@ -0,0 +1,38 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import time +import sys +from parameterized import parameterized +from monai.transforms.composables import DeleteKeysd + +TEST_CASE_1 = [ + {'keys': [str(i) for i in range(30)]}, + 20, + 648, +] + + +class TestDeleteKeysd(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1]) + def test_memory(self, input_param, expected_key_size, expected_mem_size): + input_data = dict() + for i in range(50): + input_data[str(i)] = [time.time()] * 100000 + result = DeleteKeysd(**input_param)(input_data) + self.assertEqual(len(result.keys()), expected_key_size) + self.assertEqual(sys.getsizeof(result), expected_mem_size) + + +if __name__ == '__main__': + unittest.main()