docs/source/losses.rst (5 additions, 0 deletions)
@@ -48,6 +48,11 @@ Segmentation Losses
.. autoclass:: DiceCELoss
:members:

`DiceFocalLoss`
~~~~~~~~~~~~~~~
.. autoclass:: DiceFocalLoss
:members:

`FocalLoss`
~~~~~~~~~~~
.. autoclass:: FocalLoss
monai/losses/__init__.py (3 additions, 0 deletions)
@@ -13,11 +13,14 @@
from .dice import (
Dice,
DiceCELoss,
DiceFocalLoss,
DiceLoss,
GeneralizedDiceLoss,
GeneralizedWassersteinDiceLoss,
MaskedDiceLoss,
dice,
dice_ce,
dice_focal,
generalized_dice,
generalized_wasserstein_dice,
)
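As a quick orientation (a sketch, not part of the diff): once this change lands, the new names should be importable from the package root, and the lowercase `dice_focal` alias mirrors the existing `dice_ce = DiceCELoss` convention.

```python
# Sketch only: exercising the new exports added to monai/losses/__init__.py.
from monai.losses import DiceFocalLoss, dice_focal

# `dice_focal` is a lowercase alias bound to the class itself (see the alias
# block at the bottom of monai/losses/dice.py in this diff).
assert dice_focal is DiceFocalLoss
```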
monai/losses/dice.py (136 additions, 14 deletions)
@@ -10,14 +10,15 @@
# limitations under the License.

import warnings
from typing import Callable, List, Optional, Union
from typing import Callable, List, Optional, Sequence, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from monai.losses.focal_loss import FocalLoss
from monai.networks import one_hot
from monai.utils import LossReduction, Weight

@@ -600,15 +601,12 @@ def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) -

class DiceCELoss(_Loss):
"""
Compute both Dice loss and Cross Entropy Loss, and return the sum of these two losses.
Input logits `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]).
Axis N of `input` is expected to have logit predictions for each class rather than being image channels,
while the same axis of `target` can be 1 or N (one-hot format). The `smooth_nr` and `smooth_dr` parameters are
values added for dice loss part to the intersection and union components of the inter-over-union calculation
to smooth results respectively, these values should be small. The `include_background` class attribute can be
set to False for an instance of the loss to exclude the first category (channel index 0) which is by convention
assumed to be background. If the non-background segmentations are small compared to the total image size they can get
overwhelmed by the signal from the background so excluding it in such cases helps convergence.
Compute both Dice loss and Cross Entropy Loss, and return the weighted sum of these two losses.
The details of Dice loss are described in ``monai.losses.DiceLoss``.
The details of Cross Entropy Loss are described in ``torch.nn.CrossEntropyLoss``. In this implementation,
the two deprecated parameters ``size_average`` and ``reduce``, and the parameter ``ignore_index``,
are not supported.

"""

def __init__(
@@ -625,11 +623,13 @@ def __init__(
smooth_dr: float = 1e-5,
batch: bool = False,
ce_weight: Optional[torch.Tensor] = None,
lambda_dice: float = 1.0,
lambda_ce: float = 1.0,
) -> None:
"""
Args:
``ce_weight`` is only used for cross entropy loss, ``reduction`` is used for both losses and other
parameters are only used for dice loss.
``ce_weight`` and ``lambda_ce`` are only used for the cross entropy loss.
``reduction`` is used for both losses, and the other parameters are only used for the dice loss.

include_background: if False channel index 0 (background category) is excluded from the calculation.
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
@@ -655,6 +655,10 @@ def __init__(
before any `reduction`.
ce_weight: a rescaling weight given to each class for cross entropy loss.
See ``torch.nn.CrossEntropyLoss()`` for more information.
lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
Defaults to 1.0.
lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0.
Defaults to 1.0.

"""
super().__init__()
@@ -675,6 +679,12 @@ def __init__(
weight=ce_weight,
reduction=reduction,
)
if lambda_dice < 0.0:
raise ValueError("lambda_dice should be no less than 0.0.")
if lambda_ce < 0.0:
raise ValueError("lambda_ce should be no less than 0.0.")
self.lambda_dice = lambda_dice
self.lambda_ce = lambda_ce

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
@@ -684,7 +694,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

Raises:
ValueError: When number of dimensions for input and target are different.
ValueError: When number of channels for target is nither 1 or the same as input.
ValueError: When number of channels for target is neither 1 nor the same as input.

"""
if len(input.shape) != len(target.shape):
@@ -700,11 +710,123 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
target = torch.squeeze(target, dim=1)
target = target.long()
ce_loss = self.cross_entropy(input, target)
total_loss: torch.Tensor = dice_loss + ce_loss
total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss
return total_loss


class DiceFocalLoss(_Loss):
"""
Compute both Dice loss and Focal Loss, and return the weighted sum of these two losses.
The details of Dice loss are described in ``monai.losses.DiceLoss``.
The details of Focal Loss are described in ``monai.losses.FocalLoss``.

"""

def __init__(
self,
include_background: bool = True,
to_onehot_y: bool = False,
sigmoid: bool = False,
softmax: bool = False,
other_act: Optional[Callable] = None,
squared_pred: bool = False,
jaccard: bool = False,
reduction: str = "mean",
smooth_nr: float = 1e-5,
smooth_dr: float = 1e-5,
batch: bool = False,
gamma: float = 2.0,
focal_weight: Optional[Union[Sequence[float], float, int, torch.Tensor]] = None,
lambda_dice: float = 1.0,
lambda_focal: float = 1.0,
) -> None:
"""
Args:
``gamma``, ``focal_weight`` and ``lambda_focal`` are only used for focal loss.
``include_background``, ``to_onehot_y`` and ``reduction`` are used for both losses,
and the other parameters are only used for dice loss.
include_background: if False, channel index 0 (background category) is excluded from the calculation.
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
sigmoid: if True, apply a sigmoid function to the prediction.
softmax: if True, apply a softmax function to the prediction.
other_act: if you don't want to use `sigmoid` or `softmax`, specify another callable
function to apply as the activation layer. Defaults to ``None``. For example:
`other_act = torch.tanh`.
squared_pred: use squared versions of targets and predictions in the denominator or not.
jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.

- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.

smooth_nr: a small constant added to the numerator to avoid zero.
smooth_dr: a small constant added to the denominator to avoid nan.
batch: whether to sum the intersection and union areas over the batch dimension before the division.
Defaults to False; a Dice loss value is computed independently for each item in the batch
before any `reduction`.
gamma: value of the exponent gamma in the definition of the Focal loss.
focal_weight: weights to apply to the voxels of each class. If None, no weights are applied.
The input can be a single value (same weight for all classes) or a sequence of values
(the length of the sequence should be the same as the number of classes).
lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
Defaults to 1.0.
lambda_focal: the trade-off weight value for focal loss. The value should be no less than 0.0.
Defaults to 1.0.

"""
super().__init__()
self.dice = DiceLoss(
include_background=include_background,
to_onehot_y=to_onehot_y,
sigmoid=sigmoid,
softmax=softmax,
other_act=other_act,
squared_pred=squared_pred,
jaccard=jaccard,
reduction=reduction,
smooth_nr=smooth_nr,
smooth_dr=smooth_dr,
batch=batch,
)
self.focal = FocalLoss(
include_background=include_background,
to_onehot_y=to_onehot_y,
gamma=gamma,
weight=focal_weight,
reduction=reduction,
)
if lambda_dice < 0.0:
raise ValueError("lambda_dice should be no less than 0.0.")
if lambda_focal < 0.0:
raise ValueError("lambda_focal should be no less than 0.0.")
self.lambda_dice = lambda_dice
self.lambda_focal = lambda_focal

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Args:
input: the shape should be BNH[WD]. The input should be the original logits
due to the restriction of ``monai.losses.FocalLoss``.
target: the shape should be BNH[WD] or B1H[WD].

Raises:
ValueError: When number of dimensions for input and target are different.
ValueError: When number of channels for target is neither 1 nor the same as input.

"""
if len(input.shape) != len(target.shape):
raise ValueError("the number of dimensions for input and target should be the same.")

dice_loss = self.dice(input, target)
focal_loss = self.focal(input, target)
total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss
return total_loss


dice = Dice = DiceLoss
dice_ce = DiceCELoss
dice_focal = DiceFocalLoss
generalized_dice = GeneralizedDiceLoss
generalized_wasserstein_dice = GeneralizedWassersteinDiceLoss
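To make the new composition concrete, here is a minimal usage sketch (not part of the diff; the shapes, `softmax=True`, and the lambda values are arbitrary choices). It checks that `DiceFocalLoss` equals the manual weighted sum of its two component losses, which is also what the new unit tests below assert:

```python
import torch

from monai.losses import DiceFocalLoss, DiceLoss, FocalLoss

pred = torch.randn(2, 3, 16, 16)              # logits, shape BNHW (N = 3 classes)
target = torch.randint(0, 3, (2, 1, 16, 16))  # class indices, shape B1HW

combined = DiceFocalLoss(to_onehot_y=True, softmax=True, lambda_dice=1.0, lambda_focal=0.5)

# Manual composition: `softmax` only affects the dice term; FocalLoss always
# consumes the raw logits, as noted in DiceFocalLoss.forward's docstring.
dice = DiceLoss(to_onehot_y=True, softmax=True)
focal = FocalLoss(to_onehot_y=True)

assert torch.allclose(
    combined(pred, target),
    1.0 * dice(pred, target) + 0.5 * focal(pred, target),
)
```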
monai/losses/focal_loss.py (13 additions, 1 deletion)
@@ -45,7 +45,9 @@ def __init__(
weight: weights to apply to the voxels of each class. If None no weights are applied.
This corresponds to the weights `\alpha` in [1].
The input can be a single value (same weight for all classes), a sequence of values (the length
of the sequence should be the same as the number of classes).
of the sequence should be the same as the number of classes; if ``include_background``
is False, the sequence should not include a weight for class 0).
The value/values should be no less than 0. Defaults to None.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.

@@ -83,6 +85,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
AssertionError: When input and target (after one-hot transform, if set)
have different shapes.
ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
ValueError: When ``self.weight`` is a sequence and the length is not equal to the
number of classes.
ValueError: When ``self.weight`` is/contains a value that is less than 0.

"""
n_pred_ch = input.shape[1]
@@ -122,6 +127,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
class_weight = torch.as_tensor([self.weight] * i.size(1))
else:
class_weight = torch.as_tensor(self.weight)
if class_weight.size(0) != i.size(1):
raise ValueError(
"the length of the weight sequence should be the same as the number of classes. "
+ "If `include_background=False`, the number should not include class 0."
)
if class_weight.min() < 0:
raise ValueError("the value/values of weights should be no less than 0.")
class_weight = class_weight.to(i)
# Convert the weight to a map in which each voxel
# has the weight associated with the ground-truth label
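A small sketch (not part of the diff) of the two new validation paths in `FocalLoss.forward`: a weight sequence whose length does not match the number of classes, or one containing a negative value, now fails fast with a `ValueError`:

```python
import torch

from monai.losses import FocalLoss

pred = torch.randn(1, 3, 8, 8)              # logits for 3 classes
target = torch.randint(0, 3, (1, 1, 8, 8))  # class indices

try:
    # only 2 weights supplied for 3 classes (include_background defaults to True)
    FocalLoss(to_onehot_y=True, weight=(1.0, 2.0))(pred, target)
except ValueError as err:
    print(err)  # the length of the weight sequence should match the number of classes

try:
    # negative weights are rejected
    FocalLoss(to_onehot_y=True, weight=(-1.0, 1.0, 1.0))(pred, target)
except ValueError as err:
    print(err)  # the value/values of weights should be no less than 0
```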
monai/networks/nets/senet.py (2 additions, 2 deletions)
@@ -263,8 +263,8 @@ def _load_state_dict(model, arch, progress):
model_url = model_urls[arch]
else:
raise ValueError(
"only 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', \
and se_resnext101_32x4d are supported to load pretrained weights."
"only 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', "
+ "and se_resnext101_32x4d are supported to load pretrained weights."
)

pattern_conv = re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$")
tests/test_dice_ce_loss.py (14 additions, 0 deletions)
@@ -43,6 +43,20 @@
},
0.2088,
],
[ # shape: (2, 2, 3), (2, 1, 3); lambda_dice: 1.0, lambda_ce: 2.0
{
"include_background": False,
"to_onehot_y": True,
"ce_weight": torch.tensor([1.0, 1.0]),
"lambda_dice": 1.0,
"lambda_ce": 2.0,
},
{
"input": torch.tensor([[[100.0, 100.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),
"target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),
},
0.4176,
],
[ # shape: (2, 2, 3), (2, 1, 3), do not include class 0
{"include_background": False, "to_onehot_y": True, "ce_weight": torch.tensor([0.0, 1.0])},
{
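For intuition about the new expected value (a re-derivation sketch, not part of the diff): this test case sets no activation for the dice part, so the near-saturated logits drive the dice term to essentially zero, and the 0.4176 total is effectively the cross-entropy term scaled by `lambda_ce=2.0`:

```python
import torch
import torch.nn as nn

from monai.losses import DiceCELoss, DiceLoss

inp = torch.tensor([[[100.0, 100.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]])
tgt = torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]])

loss = DiceCELoss(
    include_background=False, to_onehot_y=True,
    ce_weight=torch.tensor([1.0, 1.0]), lambda_dice=1.0, lambda_ce=2.0,
)

# Manual composition mirroring DiceCELoss.forward:
dice = DiceLoss(include_background=False, to_onehot_y=True)(inp, tgt)   # ~0.0 here
ce = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0]))(inp, tgt.squeeze(1).long())
print(1.0 * dice + 2.0 * ce)  # ~0.4176, matching the expected value above
```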
tests/test_dice_focal_loss.py (80 additions, 0 deletions)
@@ -0,0 +1,80 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch

from monai.losses import DiceFocalLoss, DiceLoss, FocalLoss
from tests.utils import SkipIfBeforePyTorchVersion, test_script_save


class TestDiceFocalLoss(unittest.TestCase):
def test_result_onehot_target_include_bg(self):
size = [3, 3, 5, 5]
label = torch.randint(low=0, high=2, size=size)
pred = torch.randn(size)
for reduction in ["sum", "mean", "none"]:
common_params = {
"include_background": True,
"to_onehot_y": False,
"reduction": reduction,
}
for focal_weight in [None, torch.tensor([1.0, 1.0, 2.0]), (3, 2.0, 1)]:
for lambda_focal in [0.5, 1.0, 1.5]:
dice_focal = DiceFocalLoss(
focal_weight=focal_weight, gamma=1.0, lambda_focal=lambda_focal, **common_params
)
dice = DiceLoss(**common_params)
focal = FocalLoss(weight=focal_weight, gamma=1.0, **common_params)
result = dice_focal(pred, label)
expected_val = dice(pred, label) + lambda_focal * focal(pred, label)
np.testing.assert_allclose(result, expected_val)

def test_result_no_onehot_no_bg(self):
size = [3, 3, 5, 5]
label = torch.randint(low=0, high=2, size=size)
label = torch.argmax(label, dim=1, keepdim=True)
pred = torch.randn(size)
for reduction in ["sum", "mean", "none"]:
common_params = {
"include_background": False,
"to_onehot_y": True,
"reduction": reduction,
}
for focal_weight in [2.0, torch.tensor([1.0, 2.0]), (2.0, 1)]:
for lambda_focal in [0.5, 1.0, 1.5]:
dice_focal = DiceFocalLoss(focal_weight=focal_weight, lambda_focal=lambda_focal, **common_params)
dice = DiceLoss(**common_params)
focal = FocalLoss(weight=focal_weight, **common_params)
result = dice_focal(pred, label)
expected_val = dice(pred, label) + lambda_focal * focal(pred, label)
np.testing.assert_allclose(result, expected_val)

def test_ill_shape(self):
loss = DiceFocalLoss()
with self.assertRaisesRegex(ValueError, ""):
loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))

def test_ill_lambda(self):
with self.assertRaisesRegex(ValueError, ""):
loss = DiceFocalLoss(lambda_dice=-1.0)

@SkipIfBeforePyTorchVersion((1, 7, 0))
def test_script(self):
loss = DiceFocalLoss()
test_input = torch.ones(2, 1, 8, 8)
test_script_save(loss, test_input, test_input)


if __name__ == "__main__":
unittest.main()