From eb59e614896e1e51992a646a0655c275300fc9a8 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Tue, 19 Sep 2023 16:07:16 +0200 Subject: [PATCH 1/7] Add SignalFillEmptyd tranform Signed-off-by: Matthias Hadlich --- monai/transforms/signal/dictionary.py | 54 +++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 monai/transforms/signal/dictionary.py diff --git a/monai/transforms/signal/dictionary.py b/monai/transforms/signal/dictionary.py new file mode 100644 index 0000000000..eb215dc25a --- /dev/null +++ b/monai/transforms/signal/dictionary.py @@ -0,0 +1,54 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A collection of dictionary-based wrappers around the signal operations defined in :py:class:`monai.transforms.signal.array`. + +Class names are ended with 'd' to denote dictionary-based transforms. +""" + +from collections.abc import Hashable, Mapping + +import torch + +from monai.transforms.transform import MapTransform +from monai.config.type_definitions import KeysCollection, NdarrayOrTensor +from monai.transforms.signal.array import SignalFillEmpty + +__all__ = [ + "SignalFillEmptyd", + "SignalFillEmptyD", + "SignalFillEmptyDict", +] + +class SignalFillEmptyd(MapTransform): + """ + Applies the SignalFillEmptyd transform on the input. All NaN values will be replaced with the + replacement value. + + Args: + keys: keys of the corresponding items to model output. + allow_missing_keys: don't raise exception if key is missing. + replacement: The value that the NaN entries shall be mapped to. + """ + + backend = SignalFillEmpty.backend + + def __init__(self, keys: KeysCollection = None, allow_missing_keys: bool = False, replacement=0.0): + super().__init__(keys, allow_missing_keys) + self.signal_fill_empty = SignalFillEmpty(replacement=replacement) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + for key in self.key_iterator(data): + data[key] = self.signal_fill_empty(data[key]) + + return data + +SignalFillEmptyD = SignalFillEmptyDict = SignalFillEmptyd From f7414cff4e8dbd1c65b12426ab68c25590dd2fe5 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Tue, 19 Sep 2023 16:11:22 +0200 Subject: [PATCH 2/7] Reformat code Signed-off-by: Matthias Hadlich --- monai/transforms/signal/dictionary.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/monai/transforms/signal/dictionary.py b/monai/transforms/signal/dictionary.py index eb215dc25a..e95a0de543 100644 --- a/monai/transforms/signal/dictionary.py +++ b/monai/transforms/signal/dictionary.py @@ -14,19 +14,16 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ -from collections.abc import Hashable, Mapping +from __future__ import annotations -import torch +from collections.abc import Hashable, Mapping -from monai.transforms.transform import MapTransform from monai.config.type_definitions import KeysCollection, NdarrayOrTensor from monai.transforms.signal.array import SignalFillEmpty +from monai.transforms.transform import MapTransform + +__all__ = ["SignalFillEmptyd", "SignalFillEmptyD", "SignalFillEmptyDict"] -__all__ = [ - "SignalFillEmptyd", - "SignalFillEmptyD", - "SignalFillEmptyDict", -] class SignalFillEmptyd(MapTransform): """ @@ -48,7 +45,8 @@ def __init__(self, keys: KeysCollection = None, allow_missing_keys: bool = False def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: for key in self.key_iterator(data): data[key] = self.signal_fill_empty(data[key]) - + return data + SignalFillEmptyD = SignalFillEmptyDict = SignalFillEmptyd From 35b34722d26a0dca2710046dd4f083d12f5db1ea Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Wed, 20 Sep 2023 15:34:13 +0200 Subject: [PATCH 3/7] Add SignalFillEmptyd to the transform symbols Signed-off-by: Matthias Hadlich --- monai/transforms/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index eb8c5af19e..51fd5c6288 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -345,6 +345,7 @@ SignalRandShift, SignalRemoveFrequency, ) +from .signal.dictionary import SignalFillEmptyd, SignalFillEmptyD, SignalFillEmptyDict from .smooth_field.array import ( RandSmoothDeform, RandSmoothFieldAdjustContrast, From eec9d0434ae7f9e153f4d4ded35549c84b0155ec Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Wed, 20 Sep 2023 15:36:02 +0200 Subject: [PATCH 4/7] Add SignalFillEmptyd to the Docs Signed-off-by: Matthias Hadlich --- docs/source/transforms.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 051eab9e0e..688796af75 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1713,6 +1713,15 @@ Post-processing (Dict) :members: :special-members: __call__ +Signal (Dict) +^^^^^^^^^^^^^ + +`SignalFillEmptyd` +""""""""""""""""" +.. autoclass:: SignalFillEmptyd + :members: + :special-members: __call__ + Spatial (Dict) ^^^^^^^^^^^^^^ From 3d62726a117e7492f44a1bfaed3b7acd062046d2 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Wed, 20 Sep 2023 15:36:32 +0200 Subject: [PATCH 5/7] Add test for SignalFillEmptyd Signed-off-by: Matthias Hadlich --- tests/test_signal_fillemptyd.py | 58 +++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 tests/test_signal_fillemptyd.py diff --git a/tests/test_signal_fillemptyd.py b/tests/test_signal_fillemptyd.py new file mode 100644 index 0000000000..5b12055e7d --- /dev/null +++ b/tests/test_signal_fillemptyd.py @@ -0,0 +1,58 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import unittest + +import numpy as np +import torch + +from monai.transforms import SignalFillEmptyd +from monai.utils.type_conversion import convert_to_tensor +from tests.utils import SkipIfBeforePyTorchVersion + +TEST_SIGNAL = os.path.join(os.path.dirname(__file__), "testing_data", "signal.npy") + + +@SkipIfBeforePyTorchVersion((1, 9)) +class TestSignalFillEmptyNumpy(unittest.TestCase): + def test_correct_parameters_multi_channels(self): + self.assertIsInstance(SignalFillEmptyd(replacement=0.0), SignalFillEmptyd) + sig = np.load(TEST_SIGNAL) + sig[:, 123] = np.NAN + data = {} + data["signal"] = sig + fillempty = SignalFillEmptyd(keys=("signal",), replacement=0.0) + data_ = fillempty(data) + + self.assertTrue(np.isnan(sig).any()) + self.assertTrue(not np.isnan(data_["signal"]).any()) + + +@SkipIfBeforePyTorchVersion((1, 9)) +class TestSignalFillEmptyTorch(unittest.TestCase): + def test_correct_parameters_multi_channels(self): + self.assertIsInstance(SignalFillEmptyd(replacement=0.0), SignalFillEmptyd) + sig = convert_to_tensor(np.load(TEST_SIGNAL)) + sig[:, 123] = convert_to_tensor(np.NAN) + data = {} + data["signal"] = sig + fillempty = SignalFillEmptyd(keys=("signal",), replacement=0.0) + data_ = fillempty(data) + + self.assertTrue(np.isnan(sig).any()) + self.assertTrue(not torch.isnan(data_["signal"]).any()) + + +if __name__ == "__main__": + unittest.main() From c994bf3998556d03ab4726fcb206a98b48f36c74 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Wed, 20 Sep 2023 15:39:38 +0200 Subject: [PATCH 6/7] Fix docs Signed-off-by: Matthias Hadlich --- docs/source/transforms.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 688796af75..b35fa5d585 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1717,7 +1717,7 @@ Signal (Dict) ^^^^^^^^^^^^^ `SignalFillEmptyd` -""""""""""""""""" +"""""""""""""""""" .. autoclass:: SignalFillEmptyd :members: :special-members: __call__ From 992cecf5ceb9a01fc82a75cdbe5dcbd0654d41f2 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Thu, 21 Sep 2023 19:15:15 +0200 Subject: [PATCH 7/7] Fix broken tests in test_signal_fillempty.py Signed-off-by: Matthias Hadlich --- tests/test_signal_fillempty.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_signal_fillempty.py b/tests/test_signal_fillempty.py index f44e4ba29a..ee606d960c 100644 --- a/tests/test_signal_fillempty.py +++ b/tests/test_signal_fillempty.py @@ -32,7 +32,7 @@ def test_correct_parameters_multi_channels(self): sig[:, 123] = np.NAN fillempty = SignalFillEmpty(replacement=0.0) fillemptysignal = fillempty(sig) - self.assertTrue(not np.isnan(fillemptysignal.any())) + self.assertTrue(not np.isnan(fillemptysignal).any()) @SkipIfBeforePyTorchVersion((1, 9)) @@ -43,7 +43,7 @@ def test_correct_parameters_multi_channels(self): sig[:, 123] = convert_to_tensor(np.NAN) fillempty = SignalFillEmpty(replacement=0.0) fillemptysignal = fillempty(sig) - self.assertTrue(not torch.isnan(fillemptysignal.any())) + self.assertTrue(not torch.isnan(fillemptysignal).any()) if __name__ == "__main__":