From c6650f6b8cc0abaa2bbf32c8faf05ab2f4473b8a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 8 Feb 2020 10:53:38 +0800 Subject: [PATCH 01/10] [DLMED] add mean_dice metric and to_one_hot tool --- monai/networks/metrics/mean_dice.py | 105 ++++++++++++++++++++++++++++ monai/utils/to_onehot.py | 36 ++++++++++ tests/test_to_one_hot.py | 37 ++++++++++ 3 files changed, 178 insertions(+) create mode 100644 monai/networks/metrics/mean_dice.py create mode 100644 monai/utils/to_onehot.py create mode 100644 tests/test_to_one_hot.py diff --git a/monai/networks/metrics/mean_dice.py b/monai/networks/metrics/mean_dice.py new file mode 100644 index 0000000000..012fb7e81f --- /dev/null +++ b/monai/networks/metrics/mean_dice.py @@ -0,0 +1,105 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Union, Optional, Sequence + +import torch + +from ignite.exceptions import NotComputableError +from ignite.metrics import Metric +from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced +from monai.utils.to_onehot import to_onehot + +__all__ = [ + 'MeanDice' +] + + +class MeanDice(Metric): + """Computes dice score metric from full size Tensor and collects average. + + Args: + remove_bg (Bool): skip dice computation on the first channel of the predicted output or not. + logit_thresh (Float): the threshold value to round value to 0.0 and 1.0, default is 0.5. + is_onehot_targets (Bool): whether the label data(y) is already in One-Hot format, will convert if not. + output_transform (Callable): transform the ignite.engine.state.output into [y_pred, y] pair. + device (torch.device): device specification in case of distributed computation usage. + + Note: + This metric extends from Ignite Metric, for more details, please check: + https://github.com/pytorch/ignite/tree/master/ignite/metrics + + """ + def __init__( + self, + remove_bg=True, + logit_thresh=0.5, + is_onehot_targets=False, + output_transform: Callable = lambda x: x, + device: Optional[Union[str, torch.device]] = None + ): + super(MeanDice, self).__init__(output_transform, device=device) + self.remove_bg = remove_bg + self.logit_thresh = logit_thresh + self.is_onehot_targets = is_onehot_targets + + @reinit__is_reduced + def reset(self): + self._sum = 0 + self._num_examples = 0 + + @reinit__is_reduced + def update(self, output: Sequence[Union[torch.Tensor, dict]]): + assert len(output) == 2, 'MeanDice metric can only support y_pred and y.' + y_pred, y = output + + average = self._function(y_pred, y) + + if len(average.shape) != 0: + raise ValueError('_function did not return the average loss.') + + n = len(y) + self._sum += average.item() * n + self._num_examples += n + + @sync_all_reduce("_sum", "_num_examples") + def compute(self): + if self._num_examples == 0: + raise NotComputableError( + 'MeanDice must have at least one example before it can be computed.') + return self._sum / self._num_examples + + def _function(self, y_pred, y): + n_channels_y_pred = y_pred.shape[1] + n_len = len(y_pred.shape) + assert n_len == 4 or n_len == 5, 'unsupported input shape.' + + if self.is_onehot_targets is False: + y = to_onehot(y, n_channels_y_pred) + + if self.remove_bg: + y = y[:, 1:] + y_pred = y_pred[:, 1:] + + y = (y >= self.logit_thresh).float() + y_pred = (y_pred >= self.logit_thresh).float() + + # reducing only spatial dimensions (not batch nor channels) + reduce_axis = list(range(2, n_len)) + intersection = torch.sum(y * y_pred, reduce_axis) + + y_o = torch.sum(y, reduce_axis) + y_pred_o = torch.sum(y_pred, reduce_axis) + denominator = y_o + y_pred_o + + f = (2.0 * intersection) / denominator + # final reduce_mean across batches and channels + return torch.mean(f) diff --git a/monai/utils/to_onehot.py b/monai/utils/to_onehot.py new file mode 100644 index 0000000000..a4f2689319 --- /dev/null +++ b/monai/utils/to_onehot.py @@ -0,0 +1,36 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as f + + +def to_onehot(data, num_classes): + """Util function to convert PyTorch tensor to One-Hot encoding format. + + Args: + data (torch.Tensor): target data to convert One-Hot format. + num_classes (Int): the class number in the Tensor data. + + """ + assert num_classes is not None and type(num_classes) == int, 'must set class number for one-hot.' + + data = torch.squeeze(data, 1) + data = f.one_hot(data.long(), num_classes) + num_dims = len(data.shape) + assert num_dims == 4 or num_dims == 5, 'unsupported input shape.' + if num_dims == 5: + data = data.permute(0, 4, 1, 2, 3) + elif num_dims == 4: + data = data.permute(0, 3, 1, 2) + if data.is_contiguous() is False: + data = data.contiguous() + return data diff --git a/tests/test_to_one_hot.py b/tests/test_to_one_hot.py new file mode 100644 index 0000000000..f988bfca8d --- /dev/null +++ b/tests/test_to_one_hot.py @@ -0,0 +1,37 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.utils.to_onehot import to_onehot + +TEST_CASE_1 = [ # single channel 2D, batch 16 + { + 'data': torch.randint(low=0, high=2, size=(16, 1, 32, 32)), + 'num_classes': 3 + }, + (16, 3, 32, 32), +] + + +class TestToOneHot(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1]) + def test_shape(self, input_data, expected_shape): + result = to_onehot(**input_data) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == '__main__': + unittest.main() From 65baaf5eefc3841afa76e4ca7b49ec68a5c1fe5c Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sun, 9 Feb 2020 22:16:25 +0800 Subject: [PATCH 02/10] [DLMED] add support to multi-classes output --- monai/networks/metrics/mean_dice.py | 43 ++++++---------------------- monai/utils/compute_meandice.py | 44 +++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 35 deletions(-) create mode 100644 monai/utils/compute_meandice.py diff --git a/monai/networks/metrics/mean_dice.py b/monai/networks/metrics/mean_dice.py index 012fb7e81f..e8bf9e9d29 100644 --- a/monai/networks/metrics/mean_dice.py +++ b/monai/networks/metrics/mean_dice.py @@ -9,14 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Union, Optional, Sequence - import torch - +from typing import Callable, Union, Optional, Sequence from ignite.exceptions import NotComputableError from ignite.metrics import Metric from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced -from monai.utils.to_onehot import to_onehot +from monai.utils.compute_meandice import compute_meandice __all__ = [ 'MeanDice' @@ -28,12 +26,14 @@ class MeanDice(Metric): Args: remove_bg (Bool): skip dice computation on the first channel of the predicted output or not. - logit_thresh (Float): the threshold value to round value to 0.0 and 1.0, default is 0.5. is_onehot_targets (Bool): whether the label data(y) is already in One-Hot format, will convert if not. + logit_thresh (Float): the threshold value to round value to 0.0 and 1.0, default is 0.5. output_transform (Callable): transform the ignite.engine.state.output into [y_pred, y] pair. device (torch.device): device specification in case of distributed computation usage. Note: + (1) if this is multi-labels task(One-Hot label), use logit_thresh to convert y_pred to 0 or 1. + (2) if this is multi-classes task(non-Ono-Hot label), use Argmax to select index and convert to One-Hot. This metric extends from Ignite Metric, for more details, please check: https://github.com/pytorch/ignite/tree/master/ignite/metrics @@ -41,15 +41,15 @@ class MeanDice(Metric): def __init__( self, remove_bg=True, - logit_thresh=0.5, is_onehot_targets=False, + logit_thresh=0.5, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None ): super(MeanDice, self).__init__(output_transform, device=device) self.remove_bg = remove_bg - self.logit_thresh = logit_thresh self.is_onehot_targets = is_onehot_targets + self.logit_thresh = logit_thresh @reinit__is_reduced def reset(self): @@ -61,7 +61,7 @@ def update(self, output: Sequence[Union[torch.Tensor, dict]]): assert len(output) == 2, 'MeanDice metric can only support y_pred and y.' y_pred, y = output - average = self._function(y_pred, y) + average = compute_meandice(y_pred, y, self.remove_bg, self.is_onehot_targets, self.logit_thresh) if len(average.shape) != 0: raise ValueError('_function did not return the average loss.') @@ -76,30 +76,3 @@ def compute(self): raise NotComputableError( 'MeanDice must have at least one example before it can be computed.') return self._sum / self._num_examples - - def _function(self, y_pred, y): - n_channels_y_pred = y_pred.shape[1] - n_len = len(y_pred.shape) - assert n_len == 4 or n_len == 5, 'unsupported input shape.' - - if self.is_onehot_targets is False: - y = to_onehot(y, n_channels_y_pred) - - if self.remove_bg: - y = y[:, 1:] - y_pred = y_pred[:, 1:] - - y = (y >= self.logit_thresh).float() - y_pred = (y_pred >= self.logit_thresh).float() - - # reducing only spatial dimensions (not batch nor channels) - reduce_axis = list(range(2, n_len)) - intersection = torch.sum(y * y_pred, reduce_axis) - - y_o = torch.sum(y, reduce_axis) - y_pred_o = torch.sum(y_pred, reduce_axis) - denominator = y_o + y_pred_o - - f = (2.0 * intersection) / denominator - # final reduce_mean across batches and channels - return torch.mean(f) diff --git a/monai/utils/compute_meandice.py b/monai/utils/compute_meandice.py new file mode 100644 index 0000000000..45d4e83c40 --- /dev/null +++ b/monai/utils/compute_meandice.py @@ -0,0 +1,44 @@ +import torch +from monai.utils.to_onehot import to_onehot + + +def compute_meandice(y_pred, y, remove_bg, is_onehot_targets, logit_thresh=0.5): + """Computes dice score metric from full size Tensor and collects average. + + Args: + remove_bg (Bool): skip dice computation on the first channel of the predicted output or not. + is_onehot_targets (Bool): whether the label data(y) is already in One-Hot format, will convert if not. + logit_thresh (Float): the threshold value to round value to 0.0 and 1.0, default is 0.5. + + Note: + (1) if this is multi-labels task(One-Hot label), use logit_thresh to convert y_pred to 0 or 1. + (2) if this is multi-classes task(non-Ono-Hot label), use Argmax to select index and convert to One-Hot. + + """ + n_channels_y_pred = y_pred.shape[1] + n_len = len(y_pred.shape) + assert n_len == 4 or n_len == 5, 'unsupported input shape.' + + if is_onehot_targets is True: + y_pred = (y_pred >= logit_thresh).float() + y = (y >= logit_thresh).float() + else: + y_pred = (torch.argmax(y_pred, dim=1)).float() + y_pred = to_onehot(y_pred, n_channels_y_pred) + y = to_onehot(y, n_channels_y_pred) + + if remove_bg: + y = y[:, 1:] + y_pred = y_pred[:, 1:] + + # reducing only spatial dimensions (not batch nor channels) + reduce_axis = list(range(2, n_len)) + intersection = torch.sum(y * y_pred, reduce_axis) + + y_o = torch.sum(y, reduce_axis) + y_pred_o = torch.sum(y_pred, reduce_axis) + denominator = y_o + y_pred_o + + f = (2.0 * intersection) / denominator + # final reduce_mean across batches and channels + return torch.mean(f) From 190087f5fe5c055d5f535eadd99e2e34f34ce749 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 10 Feb 2020 10:45:37 +0800 Subject: [PATCH 03/10] [DLMED] add unit tests and support to add sigmoid and softmax --- monai/networks/metrics/mean_dice.py | 7 +++- monai/utils/compute_meandice.py | 20 +++++++++- monai/utils/to_onehot.py | 3 ++ tests/test_compute_meandice.py | 58 +++++++++++++++++++++++++++++ tests/test_to_one_hot.py | 4 +- 5 files changed, 88 insertions(+), 4 deletions(-) create mode 100644 tests/test_compute_meandice.py diff --git a/monai/networks/metrics/mean_dice.py b/monai/networks/metrics/mean_dice.py index e8bf9e9d29..936e86471b 100644 --- a/monai/networks/metrics/mean_dice.py +++ b/monai/networks/metrics/mean_dice.py @@ -43,6 +43,8 @@ def __init__( remove_bg=True, is_onehot_targets=False, logit_thresh=0.5, + add_sigmoid=False, + add_softmax=False, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None ): @@ -50,6 +52,8 @@ def __init__( self.remove_bg = remove_bg self.is_onehot_targets = is_onehot_targets self.logit_thresh = logit_thresh + self.add_sigmoid = add_sigmoid + self.add_softmax = add_softmax @reinit__is_reduced def reset(self): @@ -61,7 +65,8 @@ def update(self, output: Sequence[Union[torch.Tensor, dict]]): assert len(output) == 2, 'MeanDice metric can only support y_pred and y.' y_pred, y = output - average = compute_meandice(y_pred, y, self.remove_bg, self.is_onehot_targets, self.logit_thresh) + average = compute_meandice(y_pred, y, self.remove_bg, self.is_onehot_targets, + self.logit_thresh, self.add_sigmoid, self.add_softmax) if len(average.shape) != 0: raise ValueError('_function did not return the average loss.') diff --git a/monai/utils/compute_meandice.py b/monai/utils/compute_meandice.py index 45d4e83c40..e2f3e6fb82 100644 --- a/monai/utils/compute_meandice.py +++ b/monai/utils/compute_meandice.py @@ -2,13 +2,26 @@ from monai.utils.to_onehot import to_onehot -def compute_meandice(y_pred, y, remove_bg, is_onehot_targets, logit_thresh=0.5): +def compute_meandice( + y_pred, + y, + remove_bg, + is_onehot_targets, + logit_thresh=0.5, + add_sigmoid=False, + add_softmax=False +): """Computes dice score metric from full size Tensor and collects average. Args: + y_pred (torch.Tensor): input data to compute, typical segmentation model output. + it must be One-Hot format and first dim is batch, example shape: [16, 3, 32, 32]. + y (torch.Tensor): ground true to compute mean dice metric, the first dim is batch. remove_bg (Bool): skip dice computation on the first channel of the predicted output or not. is_onehot_targets (Bool): whether the label data(y) is already in One-Hot format, will convert if not. logit_thresh (Float): the threshold value to round value to 0.0 and 1.0, default is 0.5. + add_sigmoid (Bool): whether to add sigmoid function to y_pred before computation. + add_softmax (Bool): whether to add softmax function to y_pred before computation. Note: (1) if this is multi-labels task(One-Hot label), use logit_thresh to convert y_pred to 0 or 1. @@ -19,6 +32,11 @@ def compute_meandice(y_pred, y, remove_bg, is_onehot_targets, logit_thresh=0.5): n_len = len(y_pred.shape) assert n_len == 4 or n_len == 5, 'unsupported input shape.' + if add_sigmoid is True: + y_pred = torch.sigmoid(y_pred) + if add_softmax is True: + y_pred = torch.nn.functional.softmax(y_pred, dim=1) + if is_onehot_targets is True: y_pred = (y_pred >= logit_thresh).float() y = (y >= logit_thresh).float() diff --git a/monai/utils/to_onehot.py b/monai/utils/to_onehot.py index a4f2689319..f2b6b97761 100644 --- a/monai/utils/to_onehot.py +++ b/monai/utils/to_onehot.py @@ -15,6 +15,9 @@ def to_onehot(data, num_classes): """Util function to convert PyTorch tensor to One-Hot encoding format. + The input data should only have 1 channel and the first dim is batch. + Example shapes: [16, 1, 96, 96], [16, 1, 96, 96, 32]. + And the data values must match "num_classes", example: num_classes = 10 and values in [0 ... 9]. Args: data (torch.Tensor): target data to convert One-Hot format. diff --git a/tests/test_compute_meandice.py b/tests/test_compute_meandice.py new file mode 100644 index 0000000000..32b1aebc43 --- /dev/null +++ b/tests/test_compute_meandice.py @@ -0,0 +1,58 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.utils.compute_meandice import compute_meandice + +# keep background +TEST_CASE_1 = [ + { + 'y_pred': torch.tensor([[[[1., -1.], [-1., 1.]]]]), + 'y': torch.tensor([[[[1., 0.], [1., 1.]]]]), + 'remove_bg': False, + 'is_onehot_targets': True, + 'logit_thresh': 0.5, + 'add_sigmoid': True, + 'add_softmax': False + }, + 0.8000 +] + +# remove background and not One-Hot target +TEST_CASE_2 = [ + { + 'y_pred': torch.tensor([[[[-1., 3.], [2., -4.]], [[0., -1.], [3., 2.]], [[0., 1.], [2., -1.]]], + [[[-2., 0.], [3., 1.]], [[0., 2.], [1., -2.]], [[-1., 2.], [4., 0.]]]]), + 'y': torch.tensor([[[[1, 2], [1, 0]]], [[[1, 1], [2, 0]]]]), + 'remove_bg': True, + 'is_onehot_targets': False, + 'logit_thresh': None, + 'add_sigmoid': False, + 'add_softmax': True + }, + 0.4583 +] + + +class TestComputeMeanDice(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_value(self, input_data, expected_value): + result = compute_meandice(**input_data) + self.assertAlmostEqual(result.item(), expected_value, places=4) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_to_one_hot.py b/tests/test_to_one_hot.py index f988bfca8d..b65fc8d0f7 100644 --- a/tests/test_to_one_hot.py +++ b/tests/test_to_one_hot.py @@ -18,10 +18,10 @@ TEST_CASE_1 = [ # single channel 2D, batch 16 { - 'data': torch.randint(low=0, high=2, size=(16, 1, 32, 32)), + 'data': torch.tensor([[[[0, 1], [1, 2]]], [[[2, 1], [1, 0]]]]), 'num_classes': 3 }, - (16, 3, 32, 32), + (2, 3, 2, 2), ] From 68e6378af23c60c39052f9f0db3e7d075d30791f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 10 Feb 2020 12:52:51 +0000 Subject: [PATCH 04/10] fixes to_onehot --- monai/utils/to_onehot.py | 22 +++++++--------- .../{test_to_one_hot.py => test_to_onehot.py} | 26 ++++++++++++++----- 2 files changed, 29 insertions(+), 19 deletions(-) rename tests/{test_to_one_hot.py => test_to_onehot.py} (50%) diff --git a/monai/utils/to_onehot.py b/monai/utils/to_onehot.py index f2b6b97761..e9137a5ff7 100644 --- a/monai/utils/to_onehot.py +++ b/monai/utils/to_onehot.py @@ -13,27 +13,25 @@ import torch.nn.functional as f -def to_onehot(data, num_classes): - """Util function to convert PyTorch tensor to One-Hot encoding format. +def to_onehot(data, num_classes: int): + """Utility function to convert PyTorch tensor to One-Hot encoding format. The input data should only have 1 channel and the first dim is batch. Example shapes: [16, 1, 96, 96], [16, 1, 96, 96, 32]. And the data values must match "num_classes", example: num_classes = 10 and values in [0 ... 9]. Args: data (torch.Tensor): target data to convert One-Hot format. - num_classes (Int): the class number in the Tensor data. + num_classes (int): the class number in the Tensor data. """ - assert num_classes is not None and type(num_classes) == int, 'must set class number for one-hot.' + num_dims = data.dim() + if num_dims < 2 or data.shape[1] != 1: + raise ValueError('data should have a channel with length equals to one.') data = torch.squeeze(data, 1) data = f.one_hot(data.long(), num_classes) - num_dims = len(data.shape) - assert num_dims == 4 or num_dims == 5, 'unsupported input shape.' - if num_dims == 5: - data = data.permute(0, 4, 1, 2, 3) - elif num_dims == 4: - data = data.permute(0, 3, 1, 2) - if data.is_contiguous() is False: - data = data.contiguous() + new_axes = [0, -1] + list(range(1, num_dims - 1)) + data = data.permute(*new_axes) + if not data.is_contiguous(): + return data.contiguous() return data diff --git a/tests/test_to_one_hot.py b/tests/test_to_onehot.py similarity index 50% rename from tests/test_to_one_hot.py rename to tests/test_to_onehot.py index b65fc8d0f7..b3a528ee64 100644 --- a/tests/test_to_one_hot.py +++ b/tests/test_to_onehot.py @@ -11,26 +11,38 @@ import unittest +import numpy as np import torch from parameterized import parameterized from monai.utils.to_onehot import to_onehot -TEST_CASE_1 = [ # single channel 2D, batch 16 - { - 'data': torch.tensor([[[[0, 1], [1, 2]]], [[[2, 1], [1, 0]]]]), - 'num_classes': 3 - }, +TEST_CASE_1 = [ # single channel 2D, batch 3, shape (2, 1, 2, 2) + {'data': torch.tensor([[[[0, 1], [1, 2]]], [[[2, 1], [1, 0]]]]), 'num_classes': 3}, (2, 3, 2, 2), ] +TEST_CASE_2 = [ # single channel 1D, batch 2, shape (2, 1, 4) + {'data': torch.tensor([[[1, 2, 2, 0]], [[2, 1, 0, 1]]]), 'num_classes': 3}, + (2, 3, 4), + np.array([[[0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 1, 0]], [[0, 0, 1, 0], [0, 1, 0, 1], [1, 0, 0, 0]]]), +] + +TEST_CASE_3 = [ # single channel 0D, batch 2, shape (2, 1) + {'data': torch.tensor([[1.], [2.]]), 'num_classes': 3}, + (2, 3), + np.array([[0, 1, 0], [0, 0, 1]]), +] + class TestToOneHot(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) - def test_shape(self, input_data, expected_shape): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, input_data, expected_shape, expected_result=None): result = to_onehot(**input_data) self.assertEqual(result.shape, expected_shape) + if expected_result is not None: + self.assertTrue(np.allclose(expected_result, result.numpy())) if __name__ == '__main__': From 3519e71da8e7606ce73dea8b8bb3e802b3cec9ed Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 10 Feb 2020 13:58:30 +0000 Subject: [PATCH 05/10] fixes compute meandice --- monai/utils/compute_meandice.py | 67 +++++++++++++++++---------------- tests/test_compute_meandice.py | 29 +++++++------- 2 files changed, 50 insertions(+), 46 deletions(-) diff --git a/monai/utils/compute_meandice.py b/monai/utils/compute_meandice.py index e2f3e6fb82..2869fa5302 100644 --- a/monai/utils/compute_meandice.py +++ b/monai/utils/compute_meandice.py @@ -1,56 +1,57 @@ import torch + from monai.utils.to_onehot import to_onehot -def compute_meandice( - y_pred, - y, - remove_bg, - is_onehot_targets, - logit_thresh=0.5, - add_sigmoid=False, - add_softmax=False -): +def compute_meandice(y_pred, + y, + exclude_bg=False, + to_onehot_y=True, + mutually_exclusive=True, + add_sigmoid=False, + logit_thresh=None): """Computes dice score metric from full size Tensor and collects average. Args: y_pred (torch.Tensor): input data to compute, typical segmentation model output. it must be One-Hot format and first dim is batch, example shape: [16, 3, 32, 32]. - y (torch.Tensor): ground true to compute mean dice metric, the first dim is batch. - remove_bg (Bool): skip dice computation on the first channel of the predicted output or not. - is_onehot_targets (Bool): whether the label data(y) is already in One-Hot format, will convert if not. - logit_thresh (Float): the threshold value to round value to 0.0 and 1.0, default is 0.5. + y (torch.Tensor): ground truth to compute mean dice metric, the first dim is batch. + exclude_bg (Bool): whether to skip dice computation on the first channel of the predicted output. + to_onehot_y (Bool): whether to convert `y` into the one-hot format. + mutually_exclusive (Bool): if True, `y_pred` will be converted into a binary matrix using + a combination of argmax and to_onehot. add_sigmoid (Bool): whether to add sigmoid function to y_pred before computation. - add_softmax (Bool): whether to add softmax function to y_pred before computation. + logit_thresh (Float): the threshold value used to convert `y_pred` into a binary matrix. Note: - (1) if this is multi-labels task(One-Hot label), use logit_thresh to convert y_pred to 0 or 1. - (2) if this is multi-classes task(non-Ono-Hot label), use Argmax to select index and convert to One-Hot. + This method provide two options to convert `y_pred` into a binary matrix: + (1) when `mutually_exclusive` is True, it uses a combination of argmax and to_onehot + (2) when `mutually_exclusive` is False, it uses a threshold `logit_thresh` + (optionally with a sigmoid function before thresholding). """ n_channels_y_pred = y_pred.shape[1] - n_len = len(y_pred.shape) - assert n_len == 4 or n_len == 5, 'unsupported input shape.' - - if add_sigmoid is True: - y_pred = torch.sigmoid(y_pred) - if add_softmax is True: - y_pred = torch.nn.functional.softmax(y_pred, dim=1) - - if is_onehot_targets is True: - y_pred = (y_pred >= logit_thresh).float() - y = (y >= logit_thresh).float() - else: - y_pred = (torch.argmax(y_pred, dim=1)).float() + + if mutually_exclusive: + if logit_thresh is not None: + raise ValueError('`logit_thresh` is incompatible when mutually_exlcusive is True.') + y_pred = torch.argmax(y_pred, dim=1, keepdim=True) y_pred = to_onehot(y_pred, n_channels_y_pred) + else: # channel-wise thresholding + if add_sigmoid: + y_pred = torch.sigmoid(y_pred) + if logit_thresh is not None: + y_pred = (y_pred >= logit_thresh).float() + + if to_onehot_y: y = to_onehot(y, n_channels_y_pred) - if remove_bg: - y = y[:, 1:] - y_pred = y_pred[:, 1:] + if exclude_bg: + y = y[:, 1:] if y.shape[1] > 1 else y + y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred # reducing only spatial dimensions (not batch nor channels) - reduce_axis = list(range(2, n_len)) + reduce_axis = list(range(2, y_pred.dim())) intersection = torch.sum(y * y_pred, reduce_axis) y_o = torch.sum(y, reduce_axis) diff --git a/tests/test_compute_meandice.py b/tests/test_compute_meandice.py index 32b1aebc43..01084aec77 100644 --- a/tests/test_compute_meandice.py +++ b/tests/test_compute_meandice.py @@ -21,28 +21,31 @@ { 'y_pred': torch.tensor([[[[1., -1.], [-1., 1.]]]]), 'y': torch.tensor([[[[1., 0.], [1., 1.]]]]), - 'remove_bg': False, - 'is_onehot_targets': True, + 'exclude_bg': False, + 'to_onehot_y': False, + 'mutually_exclusive': False, 'logit_thresh': 0.5, 'add_sigmoid': True, - 'add_softmax': False }, - 0.8000 + 0.8000, ] # remove background and not One-Hot target TEST_CASE_2 = [ { - 'y_pred': torch.tensor([[[[-1., 3.], [2., -4.]], [[0., -1.], [3., 2.]], [[0., 1.], [2., -1.]]], - [[[-2., 0.], [3., 1.]], [[0., 2.], [1., -2.]], [[-1., 2.], [4., 0.]]]]), - 'y': torch.tensor([[[[1, 2], [1, 0]]], [[[1, 1], [2, 0]]]]), - 'remove_bg': True, - 'is_onehot_targets': False, - 'logit_thresh': None, - 'add_sigmoid': False, - 'add_softmax': True + 'y_pred': + torch.tensor([[[[-1., 3.], [2., -4.]], [[0., -1.], [3., 2.]], [[0., 1.], [2., -1.]]], + [[[-2., 0.], [3., 1.]], [[0., 2.], [1., -2.]], [[-1., 2.], [4., 0.]]]]), + 'y': + torch.tensor([[[[1, 2], [1, 0]]], [[[1, 1], [2, 0]]]]), + 'exclude_bg': + True, + 'to_onehot_y': + True, + 'mutually_exclusive': + True, }, - 0.4583 + 0.4583, ] From cb47d5ad70f69c55d48e7f04c505936d439e33a4 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 10 Feb 2020 14:10:27 +0000 Subject: [PATCH 06/10] update mean_dice metrics handler --- monai/networks/metrics/mean_dice.py | 57 ++++++++++++++--------------- monai/utils/compute_meandice.py | 2 +- 2 files changed, 28 insertions(+), 31 deletions(-) diff --git a/monai/networks/metrics/mean_dice.py b/monai/networks/metrics/mean_dice.py index 936e86471b..f0f92f203c 100644 --- a/monai/networks/metrics/mean_dice.py +++ b/monai/networks/metrics/mean_dice.py @@ -22,38 +22,39 @@ class MeanDice(Metric): - """Computes dice score metric from full size Tensor and collects average. - - Args: - remove_bg (Bool): skip dice computation on the first channel of the predicted output or not. - is_onehot_targets (Bool): whether the label data(y) is already in One-Hot format, will convert if not. - logit_thresh (Float): the threshold value to round value to 0.0 and 1.0, default is 0.5. - output_transform (Callable): transform the ignite.engine.state.output into [y_pred, y] pair. - device (torch.device): device specification in case of distributed computation usage. - - Note: - (1) if this is multi-labels task(One-Hot label), use logit_thresh to convert y_pred to 0 or 1. - (2) if this is multi-classes task(non-Ono-Hot label), use Argmax to select index and convert to One-Hot. - This metric extends from Ignite Metric, for more details, please check: - https://github.com/pytorch/ignite/tree/master/ignite/metrics - + """Computes dice score metric from full size Tensor and collects average over batch, class-channels, iterations. """ def __init__( self, - remove_bg=True, - is_onehot_targets=False, + exclude_bg=True, + to_onehot_y=False, logit_thresh=0.5, add_sigmoid=False, - add_softmax=False, + mutually_exclusive=False, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None ): + """ + + Args: + exclude_bg (Bool): skip dice computation on the first channel of the predicted output or not. + to_onehot_y (Bool): whether the label data(y) is already in One-Hot format, will convert if not. + logit_thresh (Float): the threshold value to round value to 0.0 and 1.0, default is 0.5. + add_sigmoid (Bool): whether to add sigmoid function to the output prediction before computating Dice. + mutually_exclusive (Bool): if True, the output prediction will be converted into a binary matrix using + a combination of argmax and to_onehot. + output_transform (Callable): transform the ignite.engine.state.output into [y_pred, y] pair. + device (torch.device): device specification in case of distributed computation usage. + """ super(MeanDice, self).__init__(output_transform, device=device) - self.remove_bg = remove_bg - self.is_onehot_targets = is_onehot_targets + self.exclude_bg = exclude_bg + self.to_onehot_y = to_onehot_y self.logit_thresh = logit_thresh self.add_sigmoid = add_sigmoid - self.add_softmax = add_softmax + self.mutually_exclusive = mutually_exclusive + + self._sum = 0 + self._num_examples = 0 @reinit__is_reduced def reset(self): @@ -64,16 +65,12 @@ def reset(self): def update(self, output: Sequence[Union[torch.Tensor, dict]]): assert len(output) == 2, 'MeanDice metric can only support y_pred and y.' y_pred, y = output + average = compute_meandice(y_pred, y, self.exclude_bg, self.to_onehot_y, self.mutually_exclusive, + self.add_sigmoid, self.logit_thresh) - average = compute_meandice(y_pred, y, self.remove_bg, self.is_onehot_targets, - self.logit_thresh, self.add_sigmoid, self.add_softmax) - - if len(average.shape) != 0: - raise ValueError('_function did not return the average loss.') - - n = len(y) - self._sum += average.item() * n - self._num_examples += n + batch_size = len(y) + self._sum += average.item() * batch_size + self._num_examples += batch_size @sync_all_reduce("_sum", "_num_examples") def compute(self): diff --git a/monai/utils/compute_meandice.py b/monai/utils/compute_meandice.py index 2869fa5302..1daafa83b7 100644 --- a/monai/utils/compute_meandice.py +++ b/monai/utils/compute_meandice.py @@ -34,7 +34,7 @@ def compute_meandice(y_pred, if mutually_exclusive: if logit_thresh is not None: - raise ValueError('`logit_thresh` is incompatible when mutually_exlcusive is True.') + raise ValueError('`logit_thresh` is incompatible when mutually_exclusive is True.') y_pred = torch.argmax(y_pred, dim=1, keepdim=True) y_pred = to_onehot(y_pred, n_channels_y_pred) else: # channel-wise thresholding From 007971adaf00a516bb4d4c600c97b1bca3eb29c6 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 10 Feb 2020 16:45:01 +0000 Subject: [PATCH 07/10] merge utils/to_onehot.py and networks/utils/one_hot into networks/utils/one_hot --- monai/networks/losses/dice.py | 5 +---- monai/networks/utils.py | 27 ++++++++++++++++-------- monai/utils/compute_meandice.py | 6 +++--- monai/utils/to_onehot.py | 37 --------------------------------- tests/test_to_onehot.py | 10 ++++----- 5 files changed, 27 insertions(+), 58 deletions(-) delete mode 100644 monai/utils/to_onehot.py diff --git a/monai/networks/losses/dice.py b/monai/networks/losses/dice.py index 5774d0ae11..1aee62521d 100644 --- a/monai/networks/losses/dice.py +++ b/monai/networks/losses/dice.py @@ -46,11 +46,9 @@ def forward(self, pred, ground, smooth=1e-5): psum = pred.float().sigmoid() tsum = ground else: - pinds = (0, 3, 1, 2) if len(ground.shape) == 4 else (0, 4, 1, 2, 3) # multiclass dice loss, use softmax in the first dimension and convert target to one-hot encoding psum = torch.softmax(pred, 1) - tsum = one_hot(ground, pred.shape[1]) # BCHW(D) -> BCHW(D)N - tsum = tsum[:, 0].permute(*pinds).contiguous() # BCHW(D)N -> BNHW(D) + tsum = one_hot(ground, pred.shape[1]) # B1HW(D) -> BNHW(D) assert tsum.shape == pred.shape, ("Ground truth one-hot has differing shape (%r) from source (%r)" % (tsum.shape, pred.shape)) @@ -59,7 +57,6 @@ def forward(self, pred, ground, smooth=1e-5): if not self.includeBackground: tsum = tsum[:, 1:] psum = psum[:, 1:] - pred = pred[:, 1:] batchsize = ground.size(0) tsum = tsum.float().view(batchsize, -1) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 0add2e9b2e..cabbc520b8 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -8,27 +8,36 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ Utilities and types for defining networks, these depend on Pytorch. """ import torch import torch.nn as nn +import torch.nn.functional as f def one_hot(labels, num_classes): """ - For a tensor `labels' of dimensions BC[D][H]W, return a tensor of dimensions BC[D][H]WN for `num_classes' N number of - classes. For every value v = labels[b,c,h,w], the value in the result at [b,c,h,w,v] will be 1 and all others 0. + For a tensor `labels' of dimensions B1[spatial_dims], return a tensor of dimensions BN[spatial_dims] for `num_classes' N number of + classes. + + Example: + For every value v = labels[b,1,h,w], the value in the result at [b,v,h,w] will be 1 and all others 0. + Note that this will include the background label, thus a binary mask should be treated as having 2 classes. """ - onehotshape = tuple(labels.shape) + (num_classes,) - labels = labels % num_classes - y = torch.eye(num_classes, device=labels.device) - onehot = y[labels.view(-1).long()] - - return onehot.reshape(*onehotshape) + num_dims = labels.dim() + if num_dims < 2 or labels.shape[1] != 1: + raise ValueError('labels should have a channel with length equals to one.') + + labels = torch.squeeze(labels, 1) + labels = f.one_hot(labels.long(), num_classes) + new_axes = [0, -1] + list(range(1, num_dims - 1)) + labels = labels.permute(*new_axes) + if not labels.is_contiguous(): + return labels.contiguous() + return labels def slice_channels(tensor, *slicevals): diff --git a/monai/utils/compute_meandice.py b/monai/utils/compute_meandice.py index 1daafa83b7..aa6526585c 100644 --- a/monai/utils/compute_meandice.py +++ b/monai/utils/compute_meandice.py @@ -1,6 +1,6 @@ import torch -from monai.utils.to_onehot import to_onehot +from monai.networks.utils import one_hot def compute_meandice(y_pred, @@ -36,7 +36,7 @@ def compute_meandice(y_pred, if logit_thresh is not None: raise ValueError('`logit_thresh` is incompatible when mutually_exclusive is True.') y_pred = torch.argmax(y_pred, dim=1, keepdim=True) - y_pred = to_onehot(y_pred, n_channels_y_pred) + y_pred = one_hot(y_pred, n_channels_y_pred) else: # channel-wise thresholding if add_sigmoid: y_pred = torch.sigmoid(y_pred) @@ -44,7 +44,7 @@ def compute_meandice(y_pred, y_pred = (y_pred >= logit_thresh).float() if to_onehot_y: - y = to_onehot(y, n_channels_y_pred) + y = one_hot(y, n_channels_y_pred) if exclude_bg: y = y[:, 1:] if y.shape[1] > 1 else y diff --git a/monai/utils/to_onehot.py b/monai/utils/to_onehot.py deleted file mode 100644 index e9137a5ff7..0000000000 --- a/monai/utils/to_onehot.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn.functional as f - - -def to_onehot(data, num_classes: int): - """Utility function to convert PyTorch tensor to One-Hot encoding format. - The input data should only have 1 channel and the first dim is batch. - Example shapes: [16, 1, 96, 96], [16, 1, 96, 96, 32]. - And the data values must match "num_classes", example: num_classes = 10 and values in [0 ... 9]. - - Args: - data (torch.Tensor): target data to convert One-Hot format. - num_classes (int): the class number in the Tensor data. - - """ - num_dims = data.dim() - if num_dims < 2 or data.shape[1] != 1: - raise ValueError('data should have a channel with length equals to one.') - - data = torch.squeeze(data, 1) - data = f.one_hot(data.long(), num_classes) - new_axes = [0, -1] + list(range(1, num_dims - 1)) - data = data.permute(*new_axes) - if not data.is_contiguous(): - return data.contiguous() - return data diff --git a/tests/test_to_onehot.py b/tests/test_to_onehot.py index b3a528ee64..7407edbafb 100644 --- a/tests/test_to_onehot.py +++ b/tests/test_to_onehot.py @@ -15,21 +15,21 @@ import torch from parameterized import parameterized -from monai.utils.to_onehot import to_onehot +from monai.networks.utils import one_hot TEST_CASE_1 = [ # single channel 2D, batch 3, shape (2, 1, 2, 2) - {'data': torch.tensor([[[[0, 1], [1, 2]]], [[[2, 1], [1, 0]]]]), 'num_classes': 3}, + {'labels': torch.tensor([[[[0, 1], [1, 2]]], [[[2, 1], [1, 0]]]]), 'num_classes': 3}, (2, 3, 2, 2), ] TEST_CASE_2 = [ # single channel 1D, batch 2, shape (2, 1, 4) - {'data': torch.tensor([[[1, 2, 2, 0]], [[2, 1, 0, 1]]]), 'num_classes': 3}, + {'labels': torch.tensor([[[1, 2, 2, 0]], [[2, 1, 0, 1]]]), 'num_classes': 3}, (2, 3, 4), np.array([[[0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 1, 0]], [[0, 0, 1, 0], [0, 1, 0, 1], [1, 0, 0, 0]]]), ] TEST_CASE_3 = [ # single channel 0D, batch 2, shape (2, 1) - {'data': torch.tensor([[1.], [2.]]), 'num_classes': 3}, + {'labels': torch.tensor([[1.], [2.]]), 'num_classes': 3}, (2, 3), np.array([[0, 1, 0], [0, 0, 1]]), ] @@ -39,7 +39,7 @@ class TestToOneHot(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, input_data, expected_shape, expected_result=None): - result = to_onehot(**input_data) + result = one_hot(**input_data) self.assertEqual(result.shape, expected_shape) if expected_result is not None: self.assertTrue(np.allclose(expected_result, result.numpy())) From de83ca4f06dcb39c7971f633f29a4471b2d7d907 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 10 Feb 2020 16:50:14 +0000 Subject: [PATCH 08/10] revise arguments and docstrings in monai/networks/metrics/mean_dice.py --- monai/networks/metrics/mean_dice.py | 12 ++++++------ monai/networks/utils.py | 7 +++---- monai/utils/compute_meandice.py | 6 +++--- tests/test_compute_meandice.py | 6 +++--- 4 files changed, 15 insertions(+), 16 deletions(-) diff --git a/monai/networks/metrics/mean_dice.py b/monai/networks/metrics/mean_dice.py index f0f92f203c..53f4aaff38 100644 --- a/monai/networks/metrics/mean_dice.py +++ b/monai/networks/metrics/mean_dice.py @@ -26,7 +26,7 @@ class MeanDice(Metric): """ def __init__( self, - exclude_bg=True, + include_background=True, to_onehot_y=False, logit_thresh=0.5, add_sigmoid=False, @@ -37,17 +37,17 @@ def __init__( """ Args: - exclude_bg (Bool): skip dice computation on the first channel of the predicted output or not. - to_onehot_y (Bool): whether the label data(y) is already in One-Hot format, will convert if not. + include_background (Bool): whether to include dice computation on the first channel of the predicted output. + to_onehot_y (Bool): whether to convert the output prediction into the one-hot format. logit_thresh (Float): the threshold value to round value to 0.0 and 1.0, default is 0.5. - add_sigmoid (Bool): whether to add sigmoid function to the output prediction before computating Dice. + add_sigmoid (Bool): whether to add sigmoid function to the output prediction before computing Dice. mutually_exclusive (Bool): if True, the output prediction will be converted into a binary matrix using a combination of argmax and to_onehot. output_transform (Callable): transform the ignite.engine.state.output into [y_pred, y] pair. device (torch.device): device specification in case of distributed computation usage. """ super(MeanDice, self).__init__(output_transform, device=device) - self.exclude_bg = exclude_bg + self.include_background = include_background self.to_onehot_y = to_onehot_y self.logit_thresh = logit_thresh self.add_sigmoid = add_sigmoid @@ -65,7 +65,7 @@ def reset(self): def update(self, output: Sequence[Union[torch.Tensor, dict]]): assert len(output) == 2, 'MeanDice metric can only support y_pred and y.' y_pred, y = output - average = compute_meandice(y_pred, y, self.exclude_bg, self.to_onehot_y, self.mutually_exclusive, + average = compute_meandice(y_pred, y, self.include_background, self.to_onehot_y, self.mutually_exclusive, self.add_sigmoid, self.logit_thresh) batch_size = len(y) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index cabbc520b8..403553ca40 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -19,13 +19,12 @@ def one_hot(labels, num_classes): """ - For a tensor `labels' of dimensions B1[spatial_dims], return a tensor of dimensions BN[spatial_dims] for `num_classes' N number of - classes. + For a tensor `labels' of dimensions B1[spatial_dims], return a tensor of dimensions BN[spatial_dims] + for `num_classes' N number of classes. Example: For every value v = labels[b,1,h,w], the value in the result at [b,v,h,w] will be 1 and all others 0. - - Note that this will include the background label, thus a binary mask should be treated as having 2 classes. + Note that this will include the background label, thus a binary mask should be treated as having 2 classes. """ num_dims = labels.dim() if num_dims < 2 or labels.shape[1] != 1: diff --git a/monai/utils/compute_meandice.py b/monai/utils/compute_meandice.py index aa6526585c..92b0cafae4 100644 --- a/monai/utils/compute_meandice.py +++ b/monai/utils/compute_meandice.py @@ -5,7 +5,7 @@ def compute_meandice(y_pred, y, - exclude_bg=False, + include_background=False, to_onehot_y=True, mutually_exclusive=True, add_sigmoid=False, @@ -16,7 +16,7 @@ def compute_meandice(y_pred, y_pred (torch.Tensor): input data to compute, typical segmentation model output. it must be One-Hot format and first dim is batch, example shape: [16, 3, 32, 32]. y (torch.Tensor): ground truth to compute mean dice metric, the first dim is batch. - exclude_bg (Bool): whether to skip dice computation on the first channel of the predicted output. + include_background (Bool): whether to skip dice computation on the first channel of the predicted output. to_onehot_y (Bool): whether to convert `y` into the one-hot format. mutually_exclusive (Bool): if True, `y_pred` will be converted into a binary matrix using a combination of argmax and to_onehot. @@ -46,7 +46,7 @@ def compute_meandice(y_pred, if to_onehot_y: y = one_hot(y, n_channels_y_pred) - if exclude_bg: + if not include_background: y = y[:, 1:] if y.shape[1] > 1 else y y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred diff --git a/tests/test_compute_meandice.py b/tests/test_compute_meandice.py index 01084aec77..65f6e8d00e 100644 --- a/tests/test_compute_meandice.py +++ b/tests/test_compute_meandice.py @@ -21,7 +21,7 @@ { 'y_pred': torch.tensor([[[[1., -1.], [-1., 1.]]]]), 'y': torch.tensor([[[[1., 0.], [1., 1.]]]]), - 'exclude_bg': False, + 'include_background': True, 'to_onehot_y': False, 'mutually_exclusive': False, 'logit_thresh': 0.5, @@ -38,8 +38,8 @@ [[[-2., 0.], [3., 1.]], [[0., 2.], [1., -2.]], [[-1., 2.], [4., 0.]]]]), 'y': torch.tensor([[[[1, 2], [1, 0]]], [[[1, 1], [2, 0]]]]), - 'exclude_bg': - True, + 'include_background': + False, 'to_onehot_y': True, 'mutually_exclusive': From 4568645affffae7f4eb32e00f06af1512bb43f74 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 10 Feb 2020 17:15:19 +0000 Subject: [PATCH 09/10] update dice loss --- monai/networks/losses/dice.py | 29 +++++++++++++++++++---------- tests/test_dice_loss.py | 2 +- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/monai/networks/losses/dice.py b/monai/networks/losses/dice.py index 1aee62521d..488fea205c 100644 --- a/monai/networks/losses/dice.py +++ b/monai/networks/losses/dice.py @@ -31,32 +31,41 @@ class DiceLoss(_Loss): size they can get overwhelmed by the signal from the background so excluding it in such cases helps convergence. """ - def __init__(self, include_background=True): + def __init__(self, include_background=True, do_sigmoid=True, do_softmax=False): """ If `include_background` is False channel index 0 (background category) is excluded from the calculation. """ super().__init__() - self.includeBackground = include_background + self.include_background = include_background + self.do_sigmoid = do_sigmoid + self.do_softmax = do_softmax def forward(self, pred, ground, smooth=1e-5): if ground.shape[1] != 1: raise ValueError("Ground truth should have only a single channel, shape is " + str(ground.shape)) + psum = pred.float() + if self.do_sigmoid: + psum = psum.sigmoid() if pred.shape[1] == 1: # binary dice loss, use sigmoid activation - psum = pred.float().sigmoid() + if self.do_softmax: + raise ValueError('do_softmax is not compatible with single channel prediction.') + if not self.include_background: + raise RuntimeWarning('single channel ground truth, `include_background=False` ignored.') tsum = ground else: - # multiclass dice loss, use softmax in the first dimension and convert target to one-hot encoding - psum = torch.softmax(pred, 1) + if self.do_softmax: + if self.do_sigmoid: + raise ValueError('do_sigmoid=True and do_softmax=Ture are not compatible.') + # multiclass dice loss, use softmax in the first dimension and convert target to one-hot encoding + psum = torch.softmax(pred, 1) tsum = one_hot(ground, pred.shape[1]) # B1HW(D) -> BNHW(D) - - assert tsum.shape == pred.shape, ("Ground truth one-hot has differing shape (%r) from source (%r)" % - (tsum.shape, pred.shape)) - # exclude background category so that it doesn't overwhelm the other segmentations if they are small - if not self.includeBackground: + if not self.include_background: tsum = tsum[:, 1:] psum = psum[:, 1:] + assert tsum.shape == pred.shape, ("Ground truth one-hot has differing shape (%r) from source (%r)" % + (tsum.shape, pred.shape)) batchsize = ground.size(0) tsum = tsum.float().view(batchsize, -1) diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index 0e1908b999..5ad05d69d1 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -18,7 +18,7 @@ TEST_CASE_1 = [ { - 'include_background': False, + 'include_background': True, }, { 'pred': torch.tensor([[[[1., -1.], [-1., 1.]]]]), From 54c198a7ab4a676c237795dee1de6069b9c55592 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 10 Feb 2020 17:52:28 +0000 Subject: [PATCH 10/10] additional dice loss test cases --- monai/networks/losses/dice.py | 7 ++++-- tests/test_dice_loss.py | 45 ++++++++++++++++++++++++++++++++--- 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/monai/networks/losses/dice.py b/monai/networks/losses/dice.py index 488fea205c..c193dbd3f8 100644 --- a/monai/networks/losses/dice.py +++ b/monai/networks/losses/dice.py @@ -31,9 +31,12 @@ class DiceLoss(_Loss): size they can get overwhelmed by the signal from the background so excluding it in such cases helps convergence. """ - def __init__(self, include_background=True, do_sigmoid=True, do_softmax=False): + def __init__(self, include_background=True, do_sigmoid=False, do_softmax=False): """ - If `include_background` is False channel index 0 (background category) is excluded from the calculation. + Args: + include_background (bool): If False channel index 0 (background category) is excluded from the calculation. + do_sigmoid (bool): If True, apply a sigmoid function to the prediction. + do_softmax (bool): If True, apply a softmax function to the prediction. """ super().__init__() self.include_background = include_background diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index 5ad05d69d1..bbaa49c8ac 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -16,9 +16,10 @@ from monai.networks.losses.dice import DiceLoss -TEST_CASE_1 = [ +TEST_CASE_1 = [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) { 'include_background': True, + 'do_sigmoid': True, }, { 'pred': torch.tensor([[[[1., -1.], [-1., 1.]]]]), @@ -28,9 +29,10 @@ 0.307576, ] -TEST_CASE_2 = [ +TEST_CASE_2 = [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) { 'include_background': True, + 'do_sigmoid': True, }, { 'pred': torch.tensor([[[[1., -1.], [-1., 1.]]], [[[1., -1.], [-1., 1.]]]]), @@ -40,10 +42,47 @@ 0.416636, ] +TEST_CASE_3 = [ # shape: (2, 2, 3), (2, 1, 3) + { + 'include_background': True, + }, + { + 'pred': torch.tensor([[[1., 1., 0.], [0., 0., 1.]], [[1., 0., 1.], [0., 1., 0.]]]), + 'ground': torch.tensor([[[0., 0., 1.]], [[0., 1., 0.]]]), + 'smooth': 0.0, + }, + 0.0, +] + +TEST_CASE_4 = [ # shape: (2, 2, 3), (2, 1, 3) + { + 'include_background': True, + 'do_sigmoid': True, + }, + { + 'pred': torch.tensor([[[-1., 0., 1.], [1., 0., -1.]], [[0., 0., 0.], [0., 0., 0.]]]), + 'ground': torch.tensor([[[1., 0., 0.]], [[1., 1., 0.]]]), + 'smooth': 1e-4, + }, + 0.422957, +] + +TEST_CASE_5 = [ # shape: (2, 2, 3), (2, 1, 3) + { + 'include_background': True, + 'do_softmax': True, + }, + { + 'pred': torch.tensor([[[-1., 0., 1.], [1., 0., -1.]], [[0., 0., 0.], [0., 0., 0.]]]), + 'ground': torch.tensor([[[1., 0., 0.]], [[1., 1., 0.]]]), + 'smooth': 1e-4, + }, + 0.373045, +] class TestDiceLoss(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_shape(self, input_param, input_data, expected_val): result = DiceLoss(**input_param).forward(**input_data) self.assertAlmostEqual(result.item(), expected_val, places=5)