From 16b55dd2af32284b2a734ff075542986b169f632 Mon Sep 17 00:00:00 2001
From: Yiheng Wang
Date: Thu, 1 Apr 2021 17:38:05 +0800
Subject: [PATCH] Implement dice_focal loss

Signed-off-by: Yiheng Wang
---
 docs/source/losses.rst | 5 ++
 monai/losses/__init__.py | 3 +
 monai/losses/dice.py | 150 ++++++++++++++++++++++++++++++----
 monai/losses/focal_loss.py | 14 +++-
 monai/networks/nets/senet.py | 4 +-
 tests/test_dice_ce_loss.py | 14 ++++
 tests/test_dice_focal_loss.py | 80 ++++++++++++++++++
 tests/test_focal_loss.py | 10 +++
 8 files changed, 263 insertions(+), 17 deletions(-)
 create mode 100644 tests/test_dice_focal_loss.py

diff --git a/docs/source/losses.rst b/docs/source/losses.rst
index 5e19219fee..eea6656a24 100644
--- a/docs/source/losses.rst
+++ b/docs/source/losses.rst
@@ -48,6 +48,11 @@ Segmentation Losses
 .. autoclass:: DiceCELoss
 :members:

+`DiceFocalLoss`
+~~~~~~~~~~~~~~~
+.. autoclass:: DiceFocalLoss
+ :members:
+
 `FocalLoss`
 ~~~~~~~~~~~
 .. autoclass:: FocalLoss
diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py
index b9146a6962..78a0fbc191 100644
--- a/monai/losses/__init__.py
+++ b/monai/losses/__init__.py
@@ -13,11 +13,14 @@ from .dice import (
 Dice,
 DiceCELoss,
+ DiceFocalLoss,
 DiceLoss,
 GeneralizedDiceLoss,
 GeneralizedWassersteinDiceLoss,
 MaskedDiceLoss,
 dice,
+ dice_ce,
+ dice_focal,
 generalized_dice,
 generalized_wasserstein_dice,
 )
diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index 65bf47f388..47af8ea171 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -10,7 +10,7 @@
 # limitations under the License.

 import warnings
-from typing import Callable, List, Optional, Union
+from typing import Callable, List, Optional, Sequence, Union

 import numpy as np
 import torch
@@ -18,6 +18,7 @@ import torch.nn.functional as F
 from torch.nn.modules.loss import _Loss

+from monai.losses.focal_loss import FocalLoss
 from monai.networks import one_hot
 from monai.utils import LossReduction, Weight

@@ -600,15 +601,12 @@ def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) -

 class DiceCELoss(_Loss):
 """
- Compute both Dice loss and Cross Entropy Loss, and return the sum of these two losses.
- Input logits `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]).
- Axis N of `input` is expected to have logit predictions for each class rather than being image channels,
- while the same axis of `target` can be 1 or N (one-hot format). The `smooth_nr` and `smooth_dr` parameters are
- values added for dice loss part to the intersection and union components of the inter-over-union calculation
- to smooth results respectively, these values should be small. The `include_background` class attribute can be
- set to False for an instance of the loss to exclude the first category (channel index 0) which is by convention
- assumed to be background. If the non-background segmentations are small compared to the total image size they can get
- overwhelmed by the signal from the background so excluding it in such cases helps convergence.
+ Compute both Dice loss and Cross Entropy Loss, and return the weighted sum of these two losses.
+ The details of Dice loss are shown in ``monai.losses.DiceLoss``.
+ The details of Cross Entropy Loss are shown in ``torch.nn.CrossEntropyLoss``. In this implementation,
+ two deprecated parameters ``size_average`` and ``reduce``, and the parameter ``ignore_index`` are
+ not supported.
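+
+ Example (a minimal illustrative sketch; the shapes and settings below are arbitrary)::
+
+ import torch
+ from monai.losses import DiceCELoss
+
+ # logits for a batch of 2 with 3 classes, and class-index ground truth
+ pred = torch.randn(2, 3, 16, 16)
+ grnd = torch.randint(0, 3, (2, 1, 16, 16))
+ loss = DiceCELoss(to_onehot_y=True, softmax=True, lambda_dice=1.0, lambda_ce=1.0)
+ # returns lambda_dice * dice_loss + lambda_ce * ce_loss
+ value = loss(pred, grnd)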
+
 """

 def __init__(
@@ -625,11 +623,13 @@ def __init__(
 smooth_dr: float = 1e-5,
 batch: bool = False,
 ce_weight: Optional[torch.Tensor] = None,
+ lambda_dice: float = 1.0,
+ lambda_ce: float = 1.0,
 ) -> None:
 """
 Args:
- ``ce_weight`` is only used for cross entropy loss, ``reduction`` is used for both losses and other
- parameters are only used for dice loss.
+ ``ce_weight`` and ``lambda_ce`` are only used for cross entropy loss.
+ ``reduction`` is used for both losses and other parameters are only used for dice loss.
 include_background: if False channel index 0 (background category) is excluded from the calculation.
 to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
@@ -655,6 +655,10 @@ def __init__(
 before any `reduction`.
 ce_weight: a rescaling weight given to each class for cross entropy loss.
 See ``torch.nn.CrossEntropyLoss()`` for more information.
+ lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
+ Defaults to 1.0.
+ lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0.
+ Defaults to 1.0.

 """
 super().__init__()
@@ -675,6 +679,12 @@ def __init__(
 weight=ce_weight,
 reduction=reduction,
 )
+ if lambda_dice < 0.0:
+ raise ValueError("lambda_dice should be no less than 0.0.")
+ if lambda_ce < 0.0:
+ raise ValueError("lambda_ce should be no less than 0.0.")
+ self.lambda_dice = lambda_dice
+ self.lambda_ce = lambda_ce

 def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 """
@@ -684,7 +694,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

 Raises:
 ValueError: When number of dimensions for input and target are different.
- ValueError: When number of channels for target is nither 1 or the same as input.
+ ValueError: When number of channels for target is neither 1 nor the same as input.

 """
 if len(input.shape) != len(target.shape):
@@ -700,11 +710,123 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 target = torch.squeeze(target, dim=1)
 target = target.long()
 ce_loss = self.cross_entropy(input, target)
- total_loss: torch.Tensor = dice_loss + ce_loss
+ total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss
 return total_loss
+
+
+class DiceFocalLoss(_Loss):
+ """
+ Compute both Dice loss and Focal Loss, and return the weighted sum of these two losses.
+ The details of Dice loss are shown in ``monai.losses.DiceLoss``.
+ The details of Focal Loss are shown in ``monai.losses.FocalLoss``.
+
+ """
+
+ def __init__(
+ self,
+ include_background: bool = True,
+ to_onehot_y: bool = False,
+ sigmoid: bool = False,
+ softmax: bool = False,
+ other_act: Optional[Callable] = None,
+ squared_pred: bool = False,
+ jaccard: bool = False,
+ reduction: str = "mean",
+ smooth_nr: float = 1e-5,
+ smooth_dr: float = 1e-5,
+ batch: bool = False,
+ gamma: float = 2.0,
+ focal_weight: Optional[Union[Sequence[float], float, int, torch.Tensor]] = None,
+ lambda_dice: float = 1.0,
+ lambda_focal: float = 1.0,
+ ) -> None:
+ """
+ Args:
+ ``gamma``, ``focal_weight`` and ``lambda_focal`` are only used for focal loss.
+ ``include_background``, ``to_onehot_y`` and ``reduction`` are used for both losses,
+ and other parameters are only used for dice loss.
+ include_background: if False channel index 0 (background category) is excluded from the calculation.
+ to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+ sigmoid: if True, apply a sigmoid function to the prediction.
+ softmax: if True, apply a softmax function to the prediction.
+ other_act: if you don't want to use `sigmoid` or `softmax`, specify another callable function to
+ execute other activation layers. Defaults to ``None``. For example: `other_act = torch.tanh`.
+ squared_pred: whether to use squared versions of targets and predictions in the denominator.
+ jaccard: whether to compute the Jaccard Index (soft IoU) instead of dice.
+ reduction: {``"none"``, ``"mean"``, ``"sum"``}
+ Specifies the reduction to apply to the output. Defaults to ``"mean"``.
+
+ - ``"none"``: no reduction will be applied.
+ - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
+ - ``"sum"``: the output will be summed.
+
+ smooth_nr: a small constant added to the numerator to avoid zero.
+ smooth_dr: a small constant added to the denominator to avoid nan.
+ batch: whether to sum the intersection and union areas over the batch dimension before the division.
+ Defaults to False, a Dice loss value is computed independently for each item in the batch
+ before any `reduction`.
+ gamma: value of the exponent gamma in the definition of the Focal loss.
+ focal_weight: weights to apply to the voxels of each class. If None, no weights are applied.
+ The input can be a single value (same weight for all classes) or a sequence of values (the length
+ of the sequence should be the same as the number of classes).
+ lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
+ Defaults to 1.0.
+ lambda_focal: the trade-off weight value for focal loss. The value should be no less than 0.0.
+ Defaults to 1.0.
+
+ """
+ super().__init__()
+ self.dice = DiceLoss(
+ include_background=include_background,
+ to_onehot_y=to_onehot_y,
+ sigmoid=sigmoid,
+ softmax=softmax,
+ other_act=other_act,
+ squared_pred=squared_pred,
+ jaccard=jaccard,
+ reduction=reduction,
+ smooth_nr=smooth_nr,
+ smooth_dr=smooth_dr,
+ batch=batch,
+ )
+ self.focal = FocalLoss(
+ include_background=include_background,
+ to_onehot_y=to_onehot_y,
+ gamma=gamma,
+ weight=focal_weight,
+ reduction=reduction,
+ )
+ if lambda_dice < 0.0:
+ raise ValueError("lambda_dice should be no less than 0.0.")
+ if lambda_focal < 0.0:
+ raise ValueError("lambda_focal should be no less than 0.0.")
+ self.lambda_dice = lambda_dice
+ self.lambda_focal = lambda_focal
+
+ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ input: the shape should be BNH[WD]. The input should be the original logits
+ due to the restriction of ``monai.losses.FocalLoss``.
+ target: the shape should be BNH[WD] or B1H[WD].
+
+ Raises:
+ ValueError: When number of dimensions for input and target are different.
+ ValueError: When number of channels for target is neither 1 nor the same as input.
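+
+ The returned value is ``self.lambda_dice * dice_loss + self.lambda_focal * focal_loss``,
+ i.e. the weighted combination computed in the method body below.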
+
+ """
+ if len(input.shape) != len(target.shape):
+ raise ValueError("the number of dimensions for input and target should be the same.")
+
+ dice_loss = self.dice(input, target)
+ focal_loss = self.focal(input, target)
+ total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss
+ return total_loss


 dice = Dice = DiceLoss
 dice_ce = DiceCELoss
+dice_focal = DiceFocalLoss
 generalized_dice = GeneralizedDiceLoss
 generalized_wasserstein_dice = GeneralizedWassersteinDiceLoss
diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index 664e7673a4..5e0ccd3179 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -45,7 +45,9 @@ def __init__(
 weight: weights to apply to the voxels of each class. If None no weights are applied.
 This corresponds to the weights `\alpha` in [1].
 The input can be a single value (same weight for all classes), a sequence of values (the length
- of the sequence should be the same as the number of classes).
+ of the sequence should be the same as the number of classes; if ``include_background`` is False,
+ the number should not include class 0).
+ The value/values should be no less than 0. Defaults to None.
 reduction: {``"none"``, ``"mean"``, ``"sum"``}
 Specifies the reduction to apply to the output. Defaults to ``"mean"``.
@@ -83,6 +85,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 AssertionError: When input and target (after one hot transform if setted) have different shapes.
 ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
+ ValueError: When ``self.weight`` is a sequence and the length is not equal to the
+ number of classes.
+ ValueError: When ``self.weight`` is/contains a value that is less than 0.

 """
 n_pred_ch = input.shape[1]
@@ -122,6 +127,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 class_weight = torch.as_tensor([self.weight] * i.size(1))
 else:
 class_weight = torch.as_tensor(self.weight)
+ if class_weight.size(0) != i.size(1):
+ raise ValueError(
+ "the length of the weight sequence should be the same as the number of classes. "
+ + "If `include_background=False`, the number should not include class 0."
+ )
+ if class_weight.min() < 0:
+ raise ValueError("the value/values of weights should be no less than 0.")
 class_weight = class_weight.to(i)
 # Convert the weight to a map in which each voxel
 # has the weight associated with the ground-truth label
diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py
index f5738edeeb..1e04e02973 100644
--- a/monai/networks/nets/senet.py
+++ b/monai/networks/nets/senet.py
@@ -263,8 +263,8 @@ def _load_state_dict(model, arch, progress):
 model_url = model_urls[arch]
 else:
 raise ValueError(
- "only 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', \
- and se_resnext101_32x4d are supported to load pretrained weights."
+ "only 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', "
+ + "and se_resnext101_32x4d are supported to load pretrained weights."
 )

 pattern_conv = re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$")
diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py
index 8627c6d130..3423e1425b 100644
--- a/tests/test_dice_ce_loss.py
+++ b/tests/test_dice_ce_loss.py
@@ -43,6 +43,20 @@
 },
 0.2088,
 ],
+ [ # shape: (2, 2, 3), (2, 1, 3); lambda_dice: 1.0, lambda_ce: 2.0
+ {
+ "include_background": False,
+ "to_onehot_y": True,
+ "ce_weight": torch.tensor([1.0, 1.0]),
+ "lambda_dice": 1.0,
+ "lambda_ce": 2.0,
+ },
+ {
+ "input": torch.tensor([[[100.0, 100.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),
+ "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),
+ },
+ 0.4176,
+ ],
 [ # shape: (2, 2, 3), (2, 1, 3), do not include class 0
 {"include_background": False, "to_onehot_y": True, "ce_weight": torch.tensor([0.0, 1.0])},
 {
diff --git a/tests/test_dice_focal_loss.py b/tests/test_dice_focal_loss.py
new file mode 100644
index 0000000000..4bab68131c
--- /dev/null
+++ b/tests/test_dice_focal_loss.py
@@ -0,0 +1,80 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+
+from monai.losses import DiceFocalLoss, DiceLoss, FocalLoss
+from tests.utils import SkipIfBeforePyTorchVersion, test_script_save
+
+
+class TestDiceFocalLoss(unittest.TestCase):
+ def test_result_onehot_target_include_bg(self):
+ size = [3, 3, 5, 5]
+ label = torch.randint(low=0, high=2, size=size)
+ pred = torch.randn(size)
+ for reduction in ["sum", "mean", "none"]:
+ common_params = {
+ "include_background": True,
+ "to_onehot_y": False,
+ "reduction": reduction,
+ }
+ for focal_weight in [None, torch.tensor([1.0, 1.0, 2.0]), (3, 2.0, 1)]:
+ for lambda_focal in [0.5, 1.0, 1.5]:
+ dice_focal = DiceFocalLoss(
+ focal_weight=focal_weight, gamma=1.0, lambda_focal=lambda_focal, **common_params
+ )
+ dice = DiceLoss(**common_params)
+ focal = FocalLoss(weight=focal_weight, gamma=1.0, **common_params)
+ result = dice_focal(pred, label)
+ expected_val = dice(pred, label) + lambda_focal * focal(pred, label)
+ np.testing.assert_allclose(result, expected_val)
+
+ def test_result_no_onehot_no_bg(self):
+ size = [3, 3, 5, 5]
+ label = torch.randint(low=0, high=2, size=size)
+ label = torch.argmax(label, dim=1, keepdim=True)
+ pred = torch.randn(size)
+ for reduction in ["sum", "mean", "none"]:
+ common_params = {
+ "include_background": False,
+ "to_onehot_y": True,
+ "reduction": reduction,
+ }
+ for focal_weight in [2.0, torch.tensor([1.0, 2.0]), (2.0, 1)]:
+ for lambda_focal in [0.5, 1.0, 1.5]:
+ dice_focal = DiceFocalLoss(focal_weight=focal_weight, lambda_focal=lambda_focal, **common_params)
+ dice = DiceLoss(**common_params)
+ focal = FocalLoss(weight=focal_weight, **common_params)
+ result = dice_focal(pred, label)
+ expected_val = dice(pred, label) + lambda_focal * focal(pred, label)
+ np.testing.assert_allclose(result, expected_val)
+
+ def test_ill_shape(self):
+ loss = DiceFocalLoss()
+ with self.assertRaisesRegex(ValueError, ""):
+ loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+
+ def test_ill_lambda(self):
+ with self.assertRaisesRegex(ValueError, ""):
+ loss = DiceFocalLoss(lambda_dice=-1.0)
+
+ @SkipIfBeforePyTorchVersion((1, 7, 0))
+ def test_script(self):
+ loss = DiceFocalLoss()
+ test_input = torch.ones(2, 1, 8, 8)
+ test_script_save(loss, test_input, test_input)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py
index 4512dac4b9..66665774ef 100644
--- a/tests/test_focal_loss.py
+++ b/tests/test_focal_loss.py
@@ -187,6 +187,16 @@ def test_ill_shape(self):
 with self.assertRaisesRegex(AssertionError, ""):
 FocalLoss(reduction="mean")(chn_input, chn_target)

+ def test_ill_class_weight(self):
+ chn_input = torch.ones((1, 4, 3, 3))
+ chn_target = torch.ones((1, 4, 3, 3))
+ with self.assertRaisesRegex(ValueError, ""):
+ FocalLoss(include_background=True, weight=(1.0, 1.0, 2.0))(chn_input, chn_target)
+ with self.assertRaisesRegex(ValueError, ""):
+ FocalLoss(include_background=False, weight=(1.0, 1.0, 1.0, 1.0))(chn_input, chn_target)
+ with self.assertRaisesRegex(ValueError, ""):
+ FocalLoss(include_background=False, weight=(1.0, 1.0, -1.0))(chn_input, chn_target)
+
 @SkipIfBeforePyTorchVersion((1, 7, 0))
 def test_script(self):
 loss = FocalLoss()
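
Usage sketch (illustrative only; the shapes and weights below are arbitrary). The new `DiceFocalLoss` is a weighted combination of the existing `DiceLoss` and `FocalLoss`, which is the identity the new tests assert:

    import torch
    from monai.losses import DiceFocalLoss, DiceLoss, FocalLoss

    # raw logits and a random one-hot style target, as in tests/test_dice_focal_loss.py
    pred = torch.randn(2, 3, 8, 8)
    grnd = torch.randint(low=0, high=2, size=(2, 3, 8, 8)).float()

    params = {"include_background": True, "to_onehot_y": False, "reduction": "mean"}
    combined = DiceFocalLoss(lambda_dice=1.0, lambda_focal=2.0, **params)(pred, grnd)
    separate = DiceLoss(**params)(pred, grnd) + 2.0 * FocalLoss(**params)(pred, grnd)
    assert torch.allclose(combined, separate)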