diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 5578b93077..3499afcf95 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -250,6 +250,7 @@ Identity, LabelToMask, Lambda, + RemoveRepeatedChannel, RepeatChannel, SimulateDelay, SplitChannel, @@ -305,6 +306,9 @@ RandLambdad, RandLambdaD, RandLambdaDict, + RemoveRepeatedChanneld, + RemoveRepeatedChannelD, + RemoveRepeatedChannelDict, RepeatChanneld, RepeatChannelD, RepeatChannelDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 8b161a9223..fb9ae3c089 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -31,6 +31,7 @@ "AsChannelLast", "AddChannel", "RepeatChannel", + "RemoveRepeatedChannel", "SplitChannel", "CastToType", "ToTensor", @@ -161,6 +162,32 @@ def __call__(self, img: np.ndarray) -> np.ndarray: return np.repeat(img, self.repeats, 0) +class RemoveRepeatedChannel(Transform): + """ + RemoveRepeatedChannel data to undo RepeatChannel + The `repeats` count specifies the deletion of the origin data, for example: + ``RemoveRepeatedChannel(repeats=2)([[1, 2], [1, 2], [3, 4], [3, 4]])`` generates: ``[[1, 2], [3, 4]]`` + + Args: + repeats: the number of repetitions to be deleted for each element. + """ + + def __init__(self, repeats: int) -> None: + if repeats <= 0: + raise AssertionError("repeats count must be greater than 0.") + + self.repeats = repeats + + def __call__(self, img: np.ndarray) -> np.ndarray: + """ + Apply the transform to `img`, assuming `img` is a "channel-first" array. + """ + if np.shape(img)[0] < 2: + raise AssertionError("Image must have more than one channel") + + return np.array(img[:: self.repeats, :]) + + class SplitChannel(Transform): """ Split Numpy array or PyTorch Tensor data according to the channel dim. diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index c4bd7d4cba..83426734eb 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -35,6 +35,7 @@ Identity, LabelToMask, Lambda, + RemoveRepeatedChannel, RepeatChannel, SimulateDelay, SplitChannel, @@ -52,6 +53,7 @@ "AsChannelLastd", "AddChanneld", "RepeatChanneld", + "RemoveRepeatedChanneld", "SplitChanneld", "CastToTyped", "ToTensord", @@ -82,6 +84,8 @@ "RandLambdaDict", "RepeatChannelD", "RepeatChannelDict", + "RemoveRepeatedChannelD", + "RemoveRepeatedChannelDict", "SplitChannelD", "SplitChannelDict", "CastToTypeD", @@ -226,6 +230,28 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d +class RemoveRepeatedChanneld(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.RemoveRepeatedChannel`. + """ + + def __init__(self, keys: KeysCollection, repeats: int) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + repeats: the number of repetitions for each element. + """ + super().__init__(keys) + self.repeater = RemoveRepeatedChannel(repeats) + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = dict(data) + for key in self.keys: + d[key] = self.repeater(d[key]) + return d + + class SplitChanneld(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.SplitChannel`. @@ -836,6 +862,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd AddChannelD = AddChannelDict = AddChanneld +RemoveRepeatedChannelD = RemoveRepeatedChannelDict = RemoveRepeatedChanneld RepeatChannelD = RepeatChannelDict = RepeatChanneld SplitChannelD = SplitChannelDict = SplitChanneld CastToTypeD = CastToTypeDict = CastToTyped diff --git a/tests/test_remove_repeated_channel.py b/tests/test_remove_repeated_channel.py new file mode 100644 index 0000000000..070e0e2b8d --- /dev/null +++ b/tests/test_remove_repeated_channel.py @@ -0,0 +1,30 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RemoveRepeatedChannel + +TEST_CASE_1 = [{"repeats": 2}, np.array([[1, 2], [1, 2], [3, 4], [3, 4]]), (2, 2)] + + +class TestRemoveRepeatedChannel(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) + def test_shape(self, input_param, input_data, expected_shape): + result = RemoveRepeatedChannel(**input_param)(input_data) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_remove_repeated_channeld.py b/tests/test_remove_repeated_channeld.py new file mode 100644 index 0000000000..46c68bbdc2 --- /dev/null +++ b/tests/test_remove_repeated_channeld.py @@ -0,0 +1,34 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import RemoveRepeatedChanneld + +TEST_CASE_1 = [ + {"keys": ["img"], "repeats": 2}, + {"img": np.array([[1, 2], [1, 2], [3, 4], [3, 4]]), "seg": np.array([[1, 2], [1, 2], [3, 4], [3, 4]])}, + (2, 2), +] + + +class TestRemoveRepeatedChanneld(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) + def test_shape(self, input_param, input_data, expected_shape): + result = RemoveRepeatedChanneld(**input_param)(input_data) + self.assertEqual(result["img"].shape, expected_shape) + + +if __name__ == "__main__": + unittest.main()