diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 4fad271109..57170a33a9 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -504,6 +504,18 @@ Utility :members: :special-members: __call__ +`ConvertToMultiChannelBasedOnBratsClasses` +"""""""""""""""""""""""""""""""""""""""""" +.. autoclass:: ConvertToMultiChannelBasedOnBratsClasses + :members: + :special-members: __call__ + +`AddExtremePointsChannel` +""""""""""""""""""""""""" +.. autoclass:: AddExtremePointsChannel + :members: + :special-members: __call__ + `TorchVision` """"""""""""" .. autoclass:: TorchVision @@ -975,6 +987,18 @@ Utility (Dict) :members: :special-members: __call__ +`ConvertToMultiChannelBasedOnBratsClassesd` +""""""""""""""""""""""""""""""""""""""""""" +.. autoclass:: ConvertToMultiChannelBasedOnBratsClassesd + :members: + :special-members: __call__ + +`AddExtremePointsChanneld` +"""""""""""""""""""""""""" +.. autoclass:: AddExtremePointsChanneld + :members: + :special-members: __call__ + `TorchVisiond` """""""""""""" .. autoclass:: TorchVisiond diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index c7b4c67488..4bfc8acfbd 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -199,6 +199,7 @@ AsChannelFirst, AsChannelLast, CastToType, + ConvertToMultiChannelBasedOnBratsClasses, DataStats, FgBgToIndices, Identity, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 7e7fe816a9..a851a56a44 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -41,6 +41,7 @@ "Lambda", "LabelToMask", "FgBgToIndices", + "ConvertToMultiChannelBasedOnBratsClasses", "AddExtremePointsChannel", "TorchVision", ] @@ -556,6 +557,27 @@ def __call__( return fg_indices, bg_indices +class ConvertToMultiChannelBasedOnBratsClasses(Transform): + """ + Convert labels to multi channels based on brats18 classes: + label 1 is the necrotic and non-enhancing tumor core + label 2 is the the peritumoral edema + label 4 is the GD-enhancing tumor + The possible classes are TC (Tumor core), WT (Whole tumor) + and ET (Enhancing tumor). + """ + + def __call__(self, img: np.ndarray) -> np.ndarray: + result = [] + # merge labels 1 (tumor non-enh) and 4 (tumor enh) to TC + result.append(np.logical_or(img == 1, img == 4)) + # merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT + result.append(np.logical_or(np.logical_or(img == 1, img == 4), img == 2)) + # label 4 is ET + result.append(img == 4) + return np.stack(result, axis=0).astype(np.float32) + + class AddExtremePointsChannel(Transform, Randomizable): """ Add extreme points of label to the image as a new channel. This transform generates extreme diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 5c08f72c92..0ed328be0a 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -29,6 +29,7 @@ AsChannelFirst, AsChannelLast, CastToType, + ConvertToMultiChannelBasedOnBratsClasses, DataStats, FgBgToIndices, Identity, @@ -649,6 +650,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform): """ + Dictionary-based wrapper of :py:class:`monai.transforms.ConvertToMultiChannelBasedOnBratsClasses`. Convert labels to multi channels based on brats18 classes: label 1 is the necrotic and non-enhancing tumor core label 2 is the the peritumoral edema @@ -657,17 +659,14 @@ class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform): and ET (Enhancing tumor). """ + def __init__(self, keys: KeysCollection): + super().__init__(keys) + self.converter = ConvertToMultiChannelBasedOnBratsClasses() + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.keys: - result = [] - # merge labels 1 (tumor non-enh) and 4 (tumor enh) to TC - result.append(np.logical_or(d[key] == 1, d[key] == 4)) - # merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT - result.append(np.logical_or(np.logical_or(d[key] == 1, d[key] == 4), d[key] == 2)) - # label 4 is ET - result.append(d[key] == 4) - d[key] = np.stack(result, axis=0).astype(np.float32) + d[key] = self.converter(d[key]) return d diff --git a/tests/test_convert_to_multi_channel.py b/tests/test_convert_to_multi_channel.py new file mode 100644 index 0000000000..ea27371ac7 --- /dev/null +++ b/tests/test_convert_to_multi_channel.py @@ -0,0 +1,33 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import ConvertToMultiChannelBasedOnBratsClasses + +TEST_CASE = [ + np.array([[0, 1, 2], [1, 2, 4], [0, 1, 4]]), + np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]), +] + + +class TestConvertToMultiChannel(unittest.TestCase): + @parameterized.expand([TEST_CASE]) + def test_type_shape(self, data, expected_result): + result = ConvertToMultiChannelBasedOnBratsClasses()(data) + np.testing.assert_equal(result, expected_result) + + +if __name__ == "__main__": + unittest.main()