From 1eda872e990daca938da780572758134dd43537a Mon Sep 17 00:00:00 2001
From: yiheng-wang-nv
Date: Wed, 6 Jan 2021 15:10:47 +0800
Subject: [PATCH] Implement dice cross entropy loss

Signed-off-by: yiheng-wang-nv
---
 docs/source/losses.rst     |   5 ++
 monai/losses/__init__.py   |   1 +
 monai/losses/dice.py       | 108 +++++++++++++++++++++++++++++++++++++
 tests/test_dice_ce_loss.py |  69 ++++++++++++++++++++++++
 4 files changed, 183 insertions(+)
 create mode 100644 tests/test_dice_ce_loss.py

diff --git a/docs/source/losses.rst b/docs/source/losses.rst
index 3f87f172d5..cffb437277 100644
--- a/docs/source/losses.rst
+++ b/docs/source/losses.rst
@@ -43,6 +43,11 @@ Segmentation Losses
 .. autoclass:: generalized_wasserstein_dice
     :members:
 
+`DiceCELoss`
+~~~~~~~~~~~~
+.. autoclass:: DiceCELoss
+    :members:
+
 `FocalLoss`
 ~~~~~~~~~~~
 .. autoclass:: FocalLoss
     :members:
diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py
index 7c3ca0cfe1..dd358898c3 100644
--- a/monai/losses/__init__.py
+++ b/monai/losses/__init__.py
@@ -11,6 +11,7 @@
 
 from .dice import (
     Dice,
+    DiceCELoss,
     DiceLoss,
     GeneralizedDiceLoss,
     GeneralizedWassersteinDiceLoss,
diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index 998ac38a76..bc3f7238e1 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -14,6 +14,7 @@
 import numpy as np
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.modules.loss import _Loss
 
@@ -594,6 +595,113 @@ def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) ->
         return alpha
 
 
+class DiceCELoss(_Loss):
+    """
+    Compute both Dice loss and Cross Entropy Loss, and return the sum of these two losses.
+    Input logits `input` (BNHW[D] where N is the number of classes) are compared with ground truth `target`
+    (BNHW[D]). Axis N of `input` is expected to hold logit predictions for each class rather than image channels,
+    while the same axis of `target` can be 1 or N (one-hot format). The `smooth_nr` and `smooth_dr` parameters are
+    small values added, for the Dice loss part, to the intersection and union components of the
+    intersection-over-union calculation to smooth the results. The `include_background` class attribute can be set
+    to False for an instance of the loss to exclude the first category (channel index 0), which is by convention
+    assumed to be background. If the non-background segmentations are small compared to the total image size, they
+    can get overwhelmed by the signal from the background, so excluding it in such cases helps convergence.
+    """
+
+    def __init__(
+        self,
+        include_background: bool = True,
+        to_onehot_y: bool = False,
+        sigmoid: bool = False,
+        softmax: bool = False,
+        other_act: Optional[Callable] = None,
+        squared_pred: bool = False,
+        jaccard: bool = False,
+        reduction: str = "mean",
+        smooth_nr: float = 1e-5,
+        smooth_dr: float = 1e-5,
+        batch: bool = False,
+        ce_weight: Optional[torch.Tensor] = None,
+    ) -> None:
+        """
+        Args:
+            ``ce_weight`` is only used for the cross entropy loss, ``reduction`` is used for both losses,
+            and the other parameters are only used for the Dice loss.
+
+            include_background: if False, channel index 0 (background category) is excluded from the calculation.
+            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+            sigmoid: if True, apply a sigmoid function to the prediction.
+            softmax: if True, apply a softmax function to the prediction.
+            other_act: callable function to apply an additional activation layer when neither `sigmoid` nor
+                `softmax` is wanted, for example: `other_act = torch.tanh`. Defaults to ``None``.
+            squared_pred: whether to use squared versions of targets and predictions in the denominator.
+            jaccard: whether to compute the Jaccard Index (soft IoU) instead of Dice.
+            reduction: {``"mean"``, ``"sum"``}
+                Specifies the reduction to apply to the output. Defaults to ``"mean"``. The Dice loss should
+                at least reduce the spatial dimensions, which is different from the cross entropy loss, thus
+                the ``"none"`` option cannot be used here.
+
+                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
+                - ``"sum"``: the output will be summed.
+
+            smooth_nr: a small constant added to the numerator to avoid zero.
+            smooth_dr: a small constant added to the denominator to avoid nan.
+            batch: whether to sum the intersection and union areas over the batch dimension before dividing.
+                Defaults to False, i.e. a Dice loss value is computed independently for each item in the batch
+                before any `reduction`.
+            ce_weight: a rescaling weight given to each class for the cross entropy loss.
+                See ``torch.nn.CrossEntropyLoss()`` for more information.
+
+        """
+        super().__init__()
+        self.dice = DiceLoss(
+            include_background=include_background,
+            to_onehot_y=to_onehot_y,
+            sigmoid=sigmoid,
+            softmax=softmax,
+            other_act=other_act,
+            squared_pred=squared_pred,
+            jaccard=jaccard,
+            reduction=reduction,
+            smooth_nr=smooth_nr,
+            smooth_dr=smooth_dr,
+            batch=batch,
+        )
+        self.cross_entropy = nn.CrossEntropyLoss(
+            weight=ce_weight,
+            reduction=reduction,
+        )
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            input: the shape should be BNH[WD].
+            target: the shape should be BNH[WD] or B1H[WD].
+
+        Raises:
+            ValueError: When the number of dimensions for input and target is different.
+            ValueError: When the number of channels for target is neither 1 nor the same as input.
+
+        """
+        if len(input.shape) != len(target.shape):
+            raise ValueError("the number of dimensions for input and target should be the same.")
+
+        dice_loss = self.dice(input, target)
+
+        n_pred_ch, n_target_ch = input.shape[1], target.shape[1]
+        if n_pred_ch == n_target_ch:
+            # target is in the one-hot format, convert to BH[WD] format to calculate ce loss
+            target = torch.argmax(target, dim=1)
+        elif n_target_ch == 1:
+            target = torch.squeeze(target, dim=1)
+        else:
+            raise ValueError("the number of channels for target should be 1 or the same as input.")
+        target = target.long()
+        ce_loss = self.cross_entropy(input, target)
+        total_loss: torch.Tensor = dice_loss + ce_loss
+        return total_loss
+
+
 dice = Dice = DiceLoss
+dice_ce = DiceCELoss
 generalized_dice = GeneralizedDiceLoss
 generalized_wasserstein_dice = GeneralizedWassersteinDiceLoss
diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py
new file mode 100644
index 0000000000..ff42c8a1ec
--- /dev/null
+++ b/tests/test_dice_ce_loss.py
@@ -0,0 +1,69 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
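+
+# Note on the expected values below: with the default settings, no activation is
+# applied inside the Dice term, so a raw-logit prediction that already equals the
+# one-hot target drives the Dice part to ~0 and the total reduces to the cross
+# entropy part; for logits (1.0, 0.0) with the first class correct, that is
+# -1 + log(1 + e) ~= 0.3133.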
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.losses import DiceCELoss
+
+TEST_CASES = [
+    [  # shape: (2, 2, 3), (2, 1, 3)
+        {"to_onehot_y": True},
+        {
+            "input": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),
+            "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),
+        },
+        0.3133,  # the result equals -1 + np.log(1 + np.exp(1))
+    ],
+    [  # shape: (2, 2, 3), (2, 2, 3), one-hot target
+        {"to_onehot_y": False},
+        {
+            "input": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),
+            "target": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),
+        },
+        0.3133,
+    ],
+    [  # shape: (2, 2, 3), (2, 1, 3)
+        {"include_background": False, "to_onehot_y": True, "ce_weight": torch.tensor([1.0, 1.0])},
+        {
+            "input": torch.tensor([[[100.0, 100.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),
+            "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),
+        },
+        0.2088,
+    ],
+    [  # shape: (2, 2, 3), (2, 1, 3), do not include class 0 in ce_weight
+        {"include_background": False, "to_onehot_y": True, "ce_weight": torch.tensor([0.0, 1.0])},
+        {
+            "input": torch.tensor([[[100.0, 100.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),
+            "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),
+        },
+        0.3133,
+    ],
+]
+
+
+class TestDiceCELoss(unittest.TestCase):
+    @parameterized.expand(TEST_CASES)
+    def test_result(self, input_param, input_data, expected_val):
+        result = DiceCELoss(**input_param).forward(**input_data)
+        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
+
+    def test_ill_shape(self):
+        loss = DiceCELoss()
+        with self.assertRaisesRegex(ValueError, ""):
+            loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+
+
+if __name__ == "__main__":
+    unittest.main()
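
Below is a minimal usage sketch of the loss this patch adds, assuming the patch has been applied; the tensor shapes and parameter values are illustrative assumptions, not taken from the patch itself.

import torch
from monai.losses import DiceCELoss

# hypothetical batch: 2 samples, 4 classes, 8x8x8 voxels of raw (pre-softmax) logits
pred = torch.randn(2, 4, 8, 8, 8)
# class-index ground truth in B1HWD layout; to_onehot_y expands it to one-hot for the Dice term
truth = torch.randint(0, 4, (2, 1, 8, 8, 8)).float()

# softmax=True normalizes the logits for the Dice term; the CE term always works on raw logits
loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
value = loss_fn.forward(pred, truth)  # scalar tensor: Dice part + CE part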
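The 0.3133 constant in the first test case can be reproduced with plain PyTorch; a quick check, reusing the test's logits with the target written in class-index form:

import numpy as np
import torch
import torch.nn as nn

pred = torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]])
target = torch.tensor([[0, 0, 1], [0, 1, 0]])  # class indices matching the test's one-hot target

# the raw-logit prediction already equals the one-hot target, so the Dice part is ~0
# and the total is just the cross entropy; every position scores -1 + log(1 + e)
print(nn.CrossEntropyLoss()(pred, target).item())  # ~0.3133
print(-1 + np.log(1 + np.exp(1)))  # 0.31326..., the value cited in the test comment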