diff --git a/monai/networks/losses/dice.py b/monai/networks/losses/dice.py index 5774d0ae11..c193dbd3f8 100644 --- a/monai/networks/losses/dice.py +++ b/monai/networks/losses/dice.py @@ -31,35 +31,44 @@ class DiceLoss(_Loss): size they can get overwhelmed by the signal from the background so excluding it in such cases helps convergence. """ - def __init__(self, include_background=True): + def __init__(self, include_background=True, do_sigmoid=False, do_softmax=False): """ - If `include_background` is False channel index 0 (background category) is excluded from the calculation. + Args: + include_background (bool): If False channel index 0 (background category) is excluded from the calculation. + do_sigmoid (bool): If True, apply a sigmoid function to the prediction. + do_softmax (bool): If True, apply a softmax function to the prediction. """ super().__init__() - self.includeBackground = include_background + self.include_background = include_background + self.do_sigmoid = do_sigmoid + self.do_softmax = do_softmax def forward(self, pred, ground, smooth=1e-5): if ground.shape[1] != 1: raise ValueError("Ground truth should have only a single channel, shape is " + str(ground.shape)) + psum = pred.float() + if self.do_sigmoid: + psum = psum.sigmoid() if pred.shape[1] == 1: # binary dice loss, use sigmoid activation - psum = pred.float().sigmoid() + if self.do_softmax: + raise ValueError('do_softmax is not compatible with single channel prediction.') + if not self.include_background: + raise RuntimeWarning('single channel ground truth, `include_background=False` ignored.') tsum = ground else: - pinds = (0, 3, 1, 2) if len(ground.shape) == 4 else (0, 4, 1, 2, 3) - # multiclass dice loss, use softmax in the first dimension and convert target to one-hot encoding - psum = torch.softmax(pred, 1) - tsum = one_hot(ground, pred.shape[1]) # BCHW(D) -> BCHW(D)N - tsum = tsum[:, 0].permute(*pinds).contiguous() # BCHW(D)N -> BNHW(D) - - assert tsum.shape == pred.shape, ("Ground truth one-hot has differing shape (%r) from source (%r)" % - (tsum.shape, pred.shape)) - + if self.do_softmax: + if self.do_sigmoid: + raise ValueError('do_sigmoid=True and do_softmax=Ture are not compatible.') + # multiclass dice loss, use softmax in the first dimension and convert target to one-hot encoding + psum = torch.softmax(pred, 1) + tsum = one_hot(ground, pred.shape[1]) # B1HW(D) -> BNHW(D) # exclude background category so that it doesn't overwhelm the other segmentations if they are small - if not self.includeBackground: + if not self.include_background: tsum = tsum[:, 1:] psum = psum[:, 1:] - pred = pred[:, 1:] + assert tsum.shape == pred.shape, ("Ground truth one-hot has differing shape (%r) from source (%r)" % + (tsum.shape, pred.shape)) batchsize = ground.size(0) tsum = tsum.float().view(batchsize, -1) diff --git a/monai/networks/metrics/mean_dice.py b/monai/networks/metrics/mean_dice.py new file mode 100644 index 0000000000..53f4aaff38 --- /dev/null +++ b/monai/networks/metrics/mean_dice.py @@ -0,0 +1,80 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from typing import Callable, Union, Optional, Sequence +from ignite.exceptions import NotComputableError +from ignite.metrics import Metric +from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced +from monai.utils.compute_meandice import compute_meandice + +__all__ = [ + 'MeanDice' +] + + +class MeanDice(Metric): + """Computes dice score metric from full size Tensor and collects average over batch, class-channels, iterations. + """ + def __init__( + self, + include_background=True, + to_onehot_y=False, + logit_thresh=0.5, + add_sigmoid=False, + mutually_exclusive=False, + output_transform: Callable = lambda x: x, + device: Optional[Union[str, torch.device]] = None + ): + """ + + Args: + include_background (Bool): whether to include dice computation on the first channel of the predicted output. + to_onehot_y (Bool): whether to convert the output prediction into the one-hot format. + logit_thresh (Float): the threshold value to round value to 0.0 and 1.0, default is 0.5. + add_sigmoid (Bool): whether to add sigmoid function to the output prediction before computing Dice. + mutually_exclusive (Bool): if True, the output prediction will be converted into a binary matrix using + a combination of argmax and to_onehot. + output_transform (Callable): transform the ignite.engine.state.output into [y_pred, y] pair. + device (torch.device): device specification in case of distributed computation usage. + """ + super(MeanDice, self).__init__(output_transform, device=device) + self.include_background = include_background + self.to_onehot_y = to_onehot_y + self.logit_thresh = logit_thresh + self.add_sigmoid = add_sigmoid + self.mutually_exclusive = mutually_exclusive + + self._sum = 0 + self._num_examples = 0 + + @reinit__is_reduced + def reset(self): + self._sum = 0 + self._num_examples = 0 + + @reinit__is_reduced + def update(self, output: Sequence[Union[torch.Tensor, dict]]): + assert len(output) == 2, 'MeanDice metric can only support y_pred and y.' + y_pred, y = output + average = compute_meandice(y_pred, y, self.include_background, self.to_onehot_y, self.mutually_exclusive, + self.add_sigmoid, self.logit_thresh) + + batch_size = len(y) + self._sum += average.item() * batch_size + self._num_examples += batch_size + + @sync_all_reduce("_sum", "_num_examples") + def compute(self): + if self._num_examples == 0: + raise NotComputableError( + 'MeanDice must have at least one example before it can be computed.') + return self._sum / self._num_examples diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 0add2e9b2e..403553ca40 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -8,27 +8,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ Utilities and types for defining networks, these depend on Pytorch. """ import torch import torch.nn as nn +import torch.nn.functional as f def one_hot(labels, num_classes): """ - For a tensor `labels' of dimensions BC[D][H]W, return a tensor of dimensions BC[D][H]WN for `num_classes' N number of - classes. For every value v = labels[b,c,h,w], the value in the result at [b,c,h,w,v] will be 1 and all others 0. - Note that this will include the background label, thus a binary mask should be treated as having 2 classes. - """ - onehotshape = tuple(labels.shape) + (num_classes,) - labels = labels % num_classes - y = torch.eye(num_classes, device=labels.device) - onehot = y[labels.view(-1).long()] + For a tensor `labels' of dimensions B1[spatial_dims], return a tensor of dimensions BN[spatial_dims] + for `num_classes' N number of classes. - return onehot.reshape(*onehotshape) + Example: + For every value v = labels[b,1,h,w], the value in the result at [b,v,h,w] will be 1 and all others 0. + Note that this will include the background label, thus a binary mask should be treated as having 2 classes. + """ + num_dims = labels.dim() + if num_dims < 2 or labels.shape[1] != 1: + raise ValueError('labels should have a channel with length equals to one.') + + labels = torch.squeeze(labels, 1) + labels = f.one_hot(labels.long(), num_classes) + new_axes = [0, -1] + list(range(1, num_dims - 1)) + labels = labels.permute(*new_axes) + if not labels.is_contiguous(): + return labels.contiguous() + return labels def slice_channels(tensor, *slicevals): diff --git a/monai/utils/compute_meandice.py b/monai/utils/compute_meandice.py new file mode 100644 index 0000000000..92b0cafae4 --- /dev/null +++ b/monai/utils/compute_meandice.py @@ -0,0 +1,63 @@ +import torch + +from monai.networks.utils import one_hot + + +def compute_meandice(y_pred, + y, + include_background=False, + to_onehot_y=True, + mutually_exclusive=True, + add_sigmoid=False, + logit_thresh=None): + """Computes dice score metric from full size Tensor and collects average. + + Args: + y_pred (torch.Tensor): input data to compute, typical segmentation model output. + it must be One-Hot format and first dim is batch, example shape: [16, 3, 32, 32]. + y (torch.Tensor): ground truth to compute mean dice metric, the first dim is batch. + include_background (Bool): whether to skip dice computation on the first channel of the predicted output. + to_onehot_y (Bool): whether to convert `y` into the one-hot format. + mutually_exclusive (Bool): if True, `y_pred` will be converted into a binary matrix using + a combination of argmax and to_onehot. + add_sigmoid (Bool): whether to add sigmoid function to y_pred before computation. + logit_thresh (Float): the threshold value used to convert `y_pred` into a binary matrix. + + Note: + This method provide two options to convert `y_pred` into a binary matrix: + (1) when `mutually_exclusive` is True, it uses a combination of argmax and to_onehot + (2) when `mutually_exclusive` is False, it uses a threshold `logit_thresh` + (optionally with a sigmoid function before thresholding). + + """ + n_channels_y_pred = y_pred.shape[1] + + if mutually_exclusive: + if logit_thresh is not None: + raise ValueError('`logit_thresh` is incompatible when mutually_exclusive is True.') + y_pred = torch.argmax(y_pred, dim=1, keepdim=True) + y_pred = one_hot(y_pred, n_channels_y_pred) + else: # channel-wise thresholding + if add_sigmoid: + y_pred = torch.sigmoid(y_pred) + if logit_thresh is not None: + y_pred = (y_pred >= logit_thresh).float() + + if to_onehot_y: + y = one_hot(y, n_channels_y_pred) + + if not include_background: + y = y[:, 1:] if y.shape[1] > 1 else y + y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred + + # reducing only spatial dimensions (not batch nor channels) + reduce_axis = list(range(2, y_pred.dim())) + intersection = torch.sum(y * y_pred, reduce_axis) + + y_o = torch.sum(y, reduce_axis) + y_pred_o = torch.sum(y_pred, reduce_axis) + denominator = y_o + y_pred_o + + f = (2.0 * intersection) / denominator + # final reduce_mean across batches and channels + return torch.mean(f) diff --git a/tests/test_compute_meandice.py b/tests/test_compute_meandice.py new file mode 100644 index 0000000000..65f6e8d00e --- /dev/null +++ b/tests/test_compute_meandice.py @@ -0,0 +1,61 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.utils.compute_meandice import compute_meandice + +# keep background +TEST_CASE_1 = [ + { + 'y_pred': torch.tensor([[[[1., -1.], [-1., 1.]]]]), + 'y': torch.tensor([[[[1., 0.], [1., 1.]]]]), + 'include_background': True, + 'to_onehot_y': False, + 'mutually_exclusive': False, + 'logit_thresh': 0.5, + 'add_sigmoid': True, + }, + 0.8000, +] + +# remove background and not One-Hot target +TEST_CASE_2 = [ + { + 'y_pred': + torch.tensor([[[[-1., 3.], [2., -4.]], [[0., -1.], [3., 2.]], [[0., 1.], [2., -1.]]], + [[[-2., 0.], [3., 1.]], [[0., 2.], [1., -2.]], [[-1., 2.], [4., 0.]]]]), + 'y': + torch.tensor([[[[1, 2], [1, 0]]], [[[1, 1], [2, 0]]]]), + 'include_background': + False, + 'to_onehot_y': + True, + 'mutually_exclusive': + True, + }, + 0.4583, +] + + +class TestComputeMeanDice(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_value(self, input_data, expected_value): + result = compute_meandice(**input_data) + self.assertAlmostEqual(result.item(), expected_value, places=4) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index 0e1908b999..bbaa49c8ac 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -16,9 +16,10 @@ from monai.networks.losses.dice import DiceLoss -TEST_CASE_1 = [ +TEST_CASE_1 = [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) { - 'include_background': False, + 'include_background': True, + 'do_sigmoid': True, }, { 'pred': torch.tensor([[[[1., -1.], [-1., 1.]]]]), @@ -28,9 +29,10 @@ 0.307576, ] -TEST_CASE_2 = [ +TEST_CASE_2 = [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) { 'include_background': True, + 'do_sigmoid': True, }, { 'pred': torch.tensor([[[[1., -1.], [-1., 1.]]], [[[1., -1.], [-1., 1.]]]]), @@ -40,10 +42,47 @@ 0.416636, ] +TEST_CASE_3 = [ # shape: (2, 2, 3), (2, 1, 3) + { + 'include_background': True, + }, + { + 'pred': torch.tensor([[[1., 1., 0.], [0., 0., 1.]], [[1., 0., 1.], [0., 1., 0.]]]), + 'ground': torch.tensor([[[0., 0., 1.]], [[0., 1., 0.]]]), + 'smooth': 0.0, + }, + 0.0, +] + +TEST_CASE_4 = [ # shape: (2, 2, 3), (2, 1, 3) + { + 'include_background': True, + 'do_sigmoid': True, + }, + { + 'pred': torch.tensor([[[-1., 0., 1.], [1., 0., -1.]], [[0., 0., 0.], [0., 0., 0.]]]), + 'ground': torch.tensor([[[1., 0., 0.]], [[1., 1., 0.]]]), + 'smooth': 1e-4, + }, + 0.422957, +] + +TEST_CASE_5 = [ # shape: (2, 2, 3), (2, 1, 3) + { + 'include_background': True, + 'do_softmax': True, + }, + { + 'pred': torch.tensor([[[-1., 0., 1.], [1., 0., -1.]], [[0., 0., 0.], [0., 0., 0.]]]), + 'ground': torch.tensor([[[1., 0., 0.]], [[1., 1., 0.]]]), + 'smooth': 1e-4, + }, + 0.373045, +] class TestDiceLoss(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_shape(self, input_param, input_data, expected_val): result = DiceLoss(**input_param).forward(**input_data) self.assertAlmostEqual(result.item(), expected_val, places=5) diff --git a/tests/test_to_onehot.py b/tests/test_to_onehot.py new file mode 100644 index 0000000000..7407edbafb --- /dev/null +++ b/tests/test_to_onehot.py @@ -0,0 +1,49 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.utils import one_hot + +TEST_CASE_1 = [ # single channel 2D, batch 3, shape (2, 1, 2, 2) + {'labels': torch.tensor([[[[0, 1], [1, 2]]], [[[2, 1], [1, 0]]]]), 'num_classes': 3}, + (2, 3, 2, 2), +] + +TEST_CASE_2 = [ # single channel 1D, batch 2, shape (2, 1, 4) + {'labels': torch.tensor([[[1, 2, 2, 0]], [[2, 1, 0, 1]]]), 'num_classes': 3}, + (2, 3, 4), + np.array([[[0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 1, 0]], [[0, 0, 1, 0], [0, 1, 0, 1], [1, 0, 0, 0]]]), +] + +TEST_CASE_3 = [ # single channel 0D, batch 2, shape (2, 1) + {'labels': torch.tensor([[1.], [2.]]), 'num_classes': 3}, + (2, 3), + np.array([[0, 1, 0], [0, 0, 1]]), +] + + +class TestToOneHot(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, input_data, expected_shape, expected_result=None): + result = one_hot(**input_data) + self.assertEqual(result.shape, expected_shape) + if expected_result is not None: + self.assertTrue(np.allclose(expected_result, result.numpy())) + + +if __name__ == '__main__': + unittest.main()