From f885f1a832a58e4378c39b1fdd50bcd50644c722 Mon Sep 17 00:00:00 2001 From: Isaac Yang Date: Thu, 10 Dec 2020 14:26:10 -0800 Subject: [PATCH] Add ConvertToMultiChannelBasedOnBratsClassesd to handle BraTS18 dataset label Signed-off-by: Isaac Yang --- monai/transforms/utility/dictionary.py | 27 ++++++++++++++++++++ tests/test_convert_to_multi_channeld.py | 34 +++++++++++++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 tests/test_convert_to_multi_channeld.py diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index e6a9da8076..0ee3a399e2 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -636,6 +636,30 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d +class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform): + """ + Convert labels to multi channels based on brats18 classes: + label 1 is the necrotic and non-enhancing tumor core + label 2 is the the peritumoral edema + label 4 is the GD-enhancing tumor + The possible classes are TC (Tumor core), WT (Whole tumor) + and ET (Enhancing tumor). + """ + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = dict(data) + for key in self.keys: + result = list() + # merge labels 1 (tumor non-enh) and 4 (tumor enh) to TC + result.append(np.logical_or(d[key] == 1, d[key] == 4)) + # merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT + result.append(np.logical_or(np.logical_or(d[key] == 1, d[key] == 4), d[key] == 2)) + # label 4 is ET + result.append(d[key] == 4) + d[key] = np.stack(result, axis=0).astype(np.float32) + return d + + IdentityD = IdentityDict = Identityd AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd @@ -653,3 +677,6 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda LambdaD = LambdaDict = Lambdad LabelToMaskD = LabelToMaskDict = LabelToMaskd FgBgToIndicesD = FgBgToIndicesDict = FgBgToIndicesd +ConvertToMultiChannelBasedOnBratsClassesD = ( + ConvertToMultiChannelBasedOnBratsClassesDict +) = ConvertToMultiChannelBasedOnBratsClassesd diff --git a/tests/test_convert_to_multi_channeld.py b/tests/test_convert_to_multi_channeld.py new file mode 100644 index 0000000000..2de3ee7394 --- /dev/null +++ b/tests/test_convert_to_multi_channeld.py @@ -0,0 +1,34 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import ConvertToMultiChannelBasedOnBratsClassesd + +TEST_CASE = [ + {"keys": "label"}, + {"label": np.array([[0, 1, 2], [1, 2, 4], [0, 1, 4]])}, + np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]), +] + + +class TestConvertToMultiChanneld(unittest.TestCase): + @parameterized.expand([TEST_CASE]) + def test_type_shape(self, keys, data, expected_result): + result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data) + np.testing.assert_equal(result["label"], expected_result) + + +if __name__ == "__main__": + unittest.main()