39 changes: 24 additions & 15 deletions monai/networks/losses/dice.py
@@ -31,35 +31,44 @@ class DiceLoss(_Loss):
size they can get overwhelmed by the signal from the background so excluding it in such cases helps convergence.
"""

def __init__(self, include_background=True):
def __init__(self, include_background=True, do_sigmoid=False, do_softmax=False):
"""
If `include_background` is False channel index 0 (background category) is excluded from the calculation.
Args:
include_background (bool): If False channel index 0 (background category) is excluded from the calculation.
do_sigmoid (bool): If True, apply a sigmoid function to the prediction.
do_softmax (bool): If True, apply a softmax function to the prediction.
"""
super().__init__()
self.includeBackground = include_background
self.include_background = include_background
self.do_sigmoid = do_sigmoid
self.do_softmax = do_softmax

def forward(self, pred, ground, smooth=1e-5):
if ground.shape[1] != 1:
raise ValueError("Ground truth should have only a single channel, shape is " + str(ground.shape))

psum = pred.float()
if self.do_sigmoid:
psum = psum.sigmoid()
if pred.shape[1] == 1: # binary dice loss, use sigmoid activation
psum = pred.float().sigmoid()
if self.do_softmax:
raise ValueError('do_softmax is not compatible with single channel prediction.')
if not self.include_background:
warnings.warn('single channel prediction, `include_background=False` ignored.')
tsum = ground
else:
pinds = (0, 3, 1, 2) if len(ground.shape) == 4 else (0, 4, 1, 2, 3)
# multiclass dice loss, use softmax in the first dimension and convert target to one-hot encoding
psum = torch.softmax(pred, 1)
tsum = one_hot(ground, pred.shape[1]) # BCHW(D) -> BCHW(D)N
tsum = tsum[:, 0].permute(*pinds).contiguous() # BCHW(D)N -> BNHW(D)

assert tsum.shape == pred.shape, ("Ground truth one-hot has differing shape (%r) from source (%r)" %
(tsum.shape, pred.shape))

if self.do_softmax:
if self.do_sigmoid:
raise ValueError('do_sigmoid=True and do_softmax=True are not compatible.')
# multiclass dice loss, use softmax in the first dimension and convert target to one-hot encoding
psum = torch.softmax(pred, 1)
tsum = one_hot(ground, pred.shape[1]) # B1HW(D) -> BNHW(D)
# exclude background category so that it doesn't overwhelm the other segmentations if they are small
if not self.includeBackground:
if not self.include_background:
tsum = tsum[:, 1:]
psum = psum[:, 1:]
pred = pred[:, 1:]
assert tsum.shape == pred.shape, ("Ground truth one-hot has differing shape (%r) from source (%r)" %
(tsum.shape, pred.shape))

batchsize = ground.size(0)
tsum = tsum.float().view(batchsize, -1)
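A minimal usage sketch of the new flags (assuming the module path `monai.networks.losses.dice` from the file header above; the final reduction in `forward` is hidden in the collapsed part of the diff):

import torch

from monai.networks.losses.dice import DiceLoss

# binary segmentation: single-channel logits, sigmoid applied inside the loss
binary_loss = DiceLoss(include_background=True, do_sigmoid=True)
pred = torch.randn(4, 1, 32, 32)                    # B1HW logits
ground = (torch.rand(4, 1, 32, 32) > 0.5).float()   # single-channel binary mask
print(binary_loss(pred, ground))

# multiclass segmentation: one logit channel per class, softmax applied inside,
# background channel excluded from the overlap computation
multi_loss = DiceLoss(include_background=False, do_softmax=True)
pred = torch.randn(4, 3, 32, 32)                    # BNHW logits, N=3 classes
ground = torch.randint(0, 3, (4, 1, 32, 32))        # B1HW label indices
print(multi_loss(pred, ground))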
80 changes: 80 additions & 0 deletions monai/networks/metrics/mean_dice.py
@@ -0,0 +1,80 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from typing import Callable, Union, Optional, Sequence
from ignite.exceptions import NotComputableError
from ignite.metrics import Metric
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced
from monai.utils.compute_meandice import compute_meandice

__all__ = [
'MeanDice'
]


class MeanDice(Metric):
"""Computes dice score metric from full size Tensor and collects average over batch, class-channels, iterations.
"""
def __init__(
self,
include_background=True,
to_onehot_y=False,
logit_thresh=0.5,
add_sigmoid=False,
mutually_exclusive=False,
output_transform: Callable = lambda x: x,
device: Optional[Union[str, torch.device]] = None
):
"""

Args:
include_background (Bool): whether to include dice computation on the first channel of the predicted output.
to_onehot_y (Bool): whether to convert the output prediction into the one-hot format.
logit_thresh (Float): the threshold value used to round predictions to 0.0 or 1.0, default is 0.5.
add_sigmoid (Bool): whether to apply a sigmoid function to the output prediction before computing Dice.
mutually_exclusive (Bool): if True, the output prediction will be converted into a binary matrix using
a combination of argmax and to_onehot.
output_transform (Callable): transform the ignite.engine.state.output into [y_pred, y] pair.
device (torch.device): device specification in case of distributed computation usage.
"""
super(MeanDice, self).__init__(output_transform, device=device)
self.include_background = include_background
self.to_onehot_y = to_onehot_y
self.logit_thresh = logit_thresh
self.add_sigmoid = add_sigmoid
self.mutually_exclusive = mutually_exclusive

self._sum = 0
self._num_examples = 0

@reinit__is_reduced
def reset(self):
self._sum = 0
self._num_examples = 0

@reinit__is_reduced
def update(self, output: Sequence[Union[torch.Tensor, dict]]):
assert len(output) == 2, 'MeanDice metric expects a (y_pred, y) pair as output.'
y_pred, y = output
average = compute_meandice(y_pred, y, self.include_background, self.to_onehot_y, self.mutually_exclusive,
self.add_sigmoid, self.logit_thresh)

batch_size = len(y)
self._sum += average.item() * batch_size
self._num_examples += batch_size

@sync_all_reduce("_sum", "_num_examples")
def compute(self):
if self._num_examples == 0:
raise NotComputableError(
'MeanDice must have at least one example before it can be computed.')
return self._sum / self._num_examples
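A rough standalone sketch of the accumulate/compute cycle (in practice the metric would be attached to an ignite `Engine`, which drives `update` and `compute` through the `Metric` base class; the module path is assumed from the file header):

import torch

from monai.networks.metrics.mean_dice import MeanDice

metric = MeanDice(include_background=True, logit_thresh=0.5)
metric.reset()

y_pred = torch.tensor([[[[0.9, 0.1], [0.4, 0.9]]]])  # probabilities, thresholded at 0.5
y = torch.tensor([[[[1., 0.], [1., 1.]]]])
metric.update([y_pred, y])

print(metric.compute())  # 2 * 2 / (3 + 2) = 0.8 for this single example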
28 changes: 18 additions & 10 deletions monai/networks/utils.py
@@ -8,27 +8,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Utilities and types for defining networks; these depend on PyTorch.
"""

import torch
import torch.nn as nn
import torch.nn.functional as f


def one_hot(labels, num_classes):
"""
For a tensor `labels' of dimensions BC[D][H]W, return a tensor of dimensions BC[D][H]WN for `num_classes' N number of
classes. For every value v = labels[b,c,h,w], the value in the result at [b,c,h,w,v] will be 1 and all others 0.
Note that this will include the background label, thus a binary mask should be treated as having 2 classes.
"""
onehotshape = tuple(labels.shape) + (num_classes,)
labels = labels % num_classes
y = torch.eye(num_classes, device=labels.device)
onehot = y[labels.view(-1).long()]
For a tensor `labels' of dimensions B1[spatial_dims], return a tensor of dimensions BN[spatial_dims]
where N is the number of classes given by `num_classes'.

return onehot.reshape(*onehotshape)
Example:
For every value v = labels[b,0,h,w], the value in the result at [b,v,h,w] will be 1 and all others 0.
Note that this will include the background label, thus a binary mask should be treated as having 2 classes.
"""
num_dims = labels.dim()
if num_dims < 2 or labels.shape[1] != 1:
raise ValueError('labels should have a channel dimension of length one.')

labels = torch.squeeze(labels, 1)
labels = f.one_hot(labels.long(), num_classes)
new_axes = [0, -1] + list(range(1, num_dims - 1))
labels = labels.permute(*new_axes)
if not labels.is_contiguous():
return labels.contiguous()
return labels


def slice_channels(tensor, *slicevals):
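A quick sketch of the expected `one_hot` behaviour above, on a 2x2 label image with three classes:

import torch

from monai.networks.utils import one_hot

labels = torch.tensor([[[[0, 1], [2, 1]]]])   # shape (1, 1, 2, 2): B1HW
result = one_hot(labels, num_classes=3)       # shape (1, 3, 2, 2): BNHW
# result[0, 0] == [[1, 0], [0, 0]]  (background channel)
# result[0, 1] == [[0, 1], [0, 1]]
# result[0, 2] == [[0, 0], [1, 0]]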
63 changes: 63 additions & 0 deletions monai/utils/compute_meandice.py
@@ -0,0 +1,63 @@
import torch

from monai.networks.utils import one_hot


def compute_meandice(y_pred,
y,
include_background=False,
to_onehot_y=True,
mutually_exclusive=True,
add_sigmoid=False,
logit_thresh=None):
"""Computes dice score metric from full size Tensor and collects average.

Args:
y_pred (torch.Tensor): input data to compute on, typically the output of a segmentation model.
It must have one channel per class (one-hot layout) with batch as the first dim, example shape: [16, 3, 32, 32].
y (torch.Tensor): ground truth to compute mean dice metric, the first dim is batch.
include_background (Bool): whether to include dice computation on the first channel (background) of the predicted output.
to_onehot_y (Bool): whether to convert `y` into the one-hot format.
mutually_exclusive (Bool): if True, `y_pred` will be converted into a binary matrix using
a combination of argmax and to_onehot.
add_sigmoid (Bool): whether to add sigmoid function to y_pred before computation.
logit_thresh (Float): the threshold value used to convert `y_pred` into a binary matrix.

Note:
This method provides two options to convert `y_pred` into a binary matrix:
(1) when `mutually_exclusive` is True, it uses a combination of argmax and to_onehot
(2) when `mutually_exclusive` is False, it uses a threshold `logit_thresh`
(optionally with a sigmoid function before thresholding).

"""
n_channels_y_pred = y_pred.shape[1]

if mutually_exclusive:
if logit_thresh is not None:
raise ValueError('`logit_thresh` is incompatible with `mutually_exclusive=True`.')
y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
y_pred = one_hot(y_pred, n_channels_y_pred)
else: # channel-wise thresholding
if add_sigmoid:
y_pred = torch.sigmoid(y_pred)
if logit_thresh is not None:
y_pred = (y_pred >= logit_thresh).float()

if to_onehot_y:
y = one_hot(y, n_channels_y_pred)

if not include_background:
y = y[:, 1:] if y.shape[1] > 1 else y
y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred

# reducing only spatial dimensions (not batch nor channels)
reduce_axis = list(range(2, y_pred.dim()))
intersection = torch.sum(y * y_pred, reduce_axis)

y_o = torch.sum(y, reduce_axis)
y_pred_o = torch.sum(y_pred, reduce_axis)
denominator = y_o + y_pred_o

f = (2.0 * intersection) / denominator
# final reduce_mean across batches and channels
return torch.mean(f)
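A quick sketch of the mutually-exclusive path on a perfect two-class prediction (the logits are argmax-ed and one-hot encoded before the overlap is computed):

import torch

from monai.utils.compute_meandice import compute_meandice

# batch 1, 2 channels, 2x2 image; argmax over channels recovers `y` exactly
y_pred = torch.tensor([[[[2., -1.], [-1., 2.]],
                        [[-2., 1.], [1., -2.]]]])
y = torch.tensor([[[[0, 1], [1, 0]]]])
score = compute_meandice(y_pred, y, include_background=True,
                         to_onehot_y=True, mutually_exclusive=True)
print(score)  # tensor(1.) since prediction and ground truth coincide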
61 changes: 61 additions & 0 deletions tests/test_compute_meandice.py
@@ -0,0 +1,61 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.utils.compute_meandice import compute_meandice

# keep background
TEST_CASE_1 = [
{
'y_pred': torch.tensor([[[[1., -1.], [-1., 1.]]]]),
'y': torch.tensor([[[[1., 0.], [1., 1.]]]]),
'include_background': True,
'to_onehot_y': False,
'mutually_exclusive': False,
'logit_thresh': 0.5,
'add_sigmoid': True,
},
0.8000,
]
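# Worked by hand for TEST_CASE_1: sigmoid(y_pred) thresholded at 0.5 gives
# [[1., 0.], [0., 1.]]; the intersection with y is 2, |y| = 3, |y_pred| = 2,
# so dice = 2 * 2 / (3 + 2) = 0.8.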

# remove background; target is label indices (converted to one-hot inside)
TEST_CASE_2 = [
{
'y_pred': torch.tensor([[[[-1., 3.], [2., -4.]], [[0., -1.], [3., 2.]], [[0., 1.], [2., -1.]]],
[[[-2., 0.], [3., 1.]], [[0., 2.], [1., -2.]], [[-1., 2.], [4., 0.]]]]),
'y': torch.tensor([[[[1, 2], [1, 0]]], [[[1, 1], [2, 0]]]]),
'include_background': False,
'to_onehot_y': True,
'mutually_exclusive': True,
},
0.4583,
]


class TestComputeMeanDice(unittest.TestCase):

@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_value(self, input_data, expected_value):
result = compute_meandice(**input_data)
self.assertAlmostEqual(result.item(), expected_value, places=4)


if __name__ == '__main__':
unittest.main()
47 changes: 43 additions & 4 deletions tests/test_dice_loss.py
@@ -16,9 +16,10 @@

from monai.networks.losses.dice import DiceLoss

TEST_CASE_1 = [
TEST_CASE_1 = [ # shape: (1, 1, 2, 2), (1, 1, 2, 2)
{
'include_background': False,
'include_background': True,
'do_sigmoid': True,
},
{
'pred': torch.tensor([[[[1., -1.], [-1., 1.]]]]),
@@ -28,9 +29,10 @@
0.307576,
]

TEST_CASE_2 = [
TEST_CASE_2 = [ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
{
'include_background': True,
'do_sigmoid': True,
},
{
'pred': torch.tensor([[[[1., -1.], [-1., 1.]]], [[[1., -1.], [-1., 1.]]]]),
@@ -40,10 +42,47 @@
0.416636,
]

TEST_CASE_3 = [ # shape: (2, 2, 3), (2, 1, 3)
{
'include_background': True,
},
{
'pred': torch.tensor([[[1., 1., 0.], [0., 0., 1.]], [[1., 0., 1.], [0., 1., 0.]]]),
'ground': torch.tensor([[[0., 0., 1.]], [[0., 1., 0.]]]),
'smooth': 0.0,
},
0.0,
]
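# TEST_CASE_3 is the degenerate check: `pred` already equals the one-hot encoding
# of `ground`, so with smooth=0.0 the dice score is exactly 1 and the loss is 0.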

TEST_CASE_4 = [ # shape: (2, 2, 3), (2, 1, 3)
{
'include_background': True,
'do_sigmoid': True,
},
{
'pred': torch.tensor([[[-1., 0., 1.], [1., 0., -1.]], [[0., 0., 0.], [0., 0., 0.]]]),
'ground': torch.tensor([[[1., 0., 0.]], [[1., 1., 0.]]]),
'smooth': 1e-4,
},
0.422957,
]
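# In TEST_CASE_4 every logit of the second sample is 0, so its sigmoid output is a
# uniform 0.5 map (dice loss ~0.5); averaged with the first sample this gives ~0.422957.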

TEST_CASE_5 = [ # shape: (2, 2, 3), (2, 1, 3)
{
'include_background': True,
'do_softmax': True,
},
{
'pred': torch.tensor([[[-1., 0., 1.], [1., 0., -1.]], [[0., 0., 0.], [0., 0., 0.]]]),
'ground': torch.tensor([[[1., 0., 0.]], [[1., 1., 0.]]]),
'smooth': 1e-4,
},
0.373045,
]


class TestDiceLoss(unittest.TestCase):

@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_shape(self, input_param, input_data, expected_val):
result = DiceLoss(**input_param).forward(**input_data)
self.assertAlmostEqual(result.item(), expected_val, places=5)