diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index 4d1c723402..d6071edd71 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -28,10 +28,10 @@ class FocalLoss(_Loss):
     FocalLoss is an extension of BCEWithLogitsLoss that down-weights loss from
     high confidence correct predictions.
 
-    Reimplementation of the Focal Loss (with a build-in sigmoid activation) described in:
+    Reimplementation of the Focal Loss described in:
 
-        - "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017
-        - "AnatomyNet: Deep learning for fast and fully automated whole‐volume segmentation of head and neck anatomy",
+        - ["Focal Loss for Dense Object Detection"](https://arxiv.org/abs/1708.02002), T. Lin et al., ICCV 2017
+        - "AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy",
           Zhu et al., Medical Physics 2018
 
     Example:
@@ -70,19 +70,23 @@ def __init__(
         include_background: bool = True,
         to_onehot_y: bool = False,
         gamma: float = 2.0,
+        alpha: float | None = None,
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
         reduction: LossReduction | str = LossReduction.MEAN,
+        use_softmax: bool = False,
     ) -> None:
         """
         Args:
-            include_background: if False, channel index 0 (background category) is excluded from the calculation.
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-            gamma: value of the exponent gamma in the definition of the Focal loss.
+            include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
+                If False, `alpha` is ignored when using softmax.
+            to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
+            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
+            alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
+                The value should be in [0, 1]. Defaults to None.
             weight: weights to apply to the voxels of each class. If None no weights are applied.
-                This corresponds to the weights `\alpha` in [1].
                 The input can be a single value (same weight for all classes), a sequence of values (the length
-                of the sequence should be the same as the number of classes, if not ``include_background``, the
-                number should not include class 0).
+                of the sequence should be the same as the number of classes. If not ``include_background``,
+                the number of classes should not include the background category class 0).
                 The value/values should be no less than 0. Defaults to None.
             reduction: {``"none"``, ``"mean"``, ``"sum"``}
                 Specifies the reduction to apply to the output. Defaults to ``"mean"``.
@@ -91,6 +95,9 @@ def __init__(
                 - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                 - ``"sum"``: the output will be summed.
 
+            use_softmax: whether to use softmax to transform the original logits into probabilities.
+                If True, softmax is used. If False, sigmoid is used. Defaults to False.
+
         Example:
             >>> import torch
             >>> from monai.losses import FocalLoss
@@ -103,14 +110,16 @@ def __init__(
         self.include_background = include_background
         self.to_onehot_y = to_onehot_y
         self.gamma = gamma
-        self.weight: Sequence[float] | float | int | torch.Tensor | None = weight
+        self.alpha = alpha
+        self.weight = weight
+        self.use_softmax = use_softmax
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
         Args:
             input: the shape should be BNH[WD], where N is the number of classes.
                 The input should be the original logits since it will be transformed by
-                a sigmoid in the forward function.
+                a sigmoid/softmax in the forward function.
             target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
 
         Raises:
@@ -141,63 +150,106 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         if target.shape != input.shape:
             raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")
 
-        i = input
-        t = target
-
-        # Change the shape of input and target to B x N x num_voxels.
-        b, n = t.shape[:2]
-        i = i.reshape(b, n, -1)
-        t = t.reshape(b, n, -1)
-
-        # computing binary cross entropy with logits
-        # see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231
-        max_val = (-i).clamp(min=0)
-        ce = i - i * t + max_val + ((-max_val).exp() + (-i - max_val).exp()).log()
+        loss: Optional[torch.Tensor] = None
+        input = input.float()
+        target = target.float()
+        if self.use_softmax:
+            if not self.include_background and self.alpha is not None:
+                self.alpha = None
+                warnings.warn("`include_background=False`, `alpha` ignored when using softmax.")
+            loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
+        else:
+            loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
 
         if self.weight is not None:
+            # make sure the lengths of weights are equal to the number of classes
             class_weight: Optional[torch.Tensor] = None
+            num_of_classes = target.shape[1]
             if isinstance(self.weight, (float, int)):
-                class_weight = torch.as_tensor([self.weight] * i.size(1))
+                class_weight = torch.as_tensor([self.weight] * num_of_classes)
             else:
                 class_weight = torch.as_tensor(self.weight)
-            if class_weight.size(0) != i.size(1):
+            if class_weight.shape[0] != num_of_classes:
                 raise ValueError(
-                    "the length of the weight sequence should be the same as the number of classes. "
-                    + "If `include_background=False`, the number should not include class 0."
+                    """the length of the `weight` sequence should be the same as the number of classes.
+                    If `include_background=False`, the weight should not include
+                    the background category class 0."""
                 )
             if class_weight.min() < 0:
-                raise ValueError("the value/values of weights should be no less than 0.")
-            class_weight = class_weight.to(i)
-            # Convert the weight to a map in which each voxel
-            # has the weight associated with the ground-truth label
-            # associated with this voxel in target.
-            at = class_weight[None, :, None]  # N => 1,N,1
-            at = at.expand((t.size(0), -1, t.size(2)))  # 1,N,1 => B,N,H*W
-            # Multiply the log proba by their weights.
-            ce = ce * at
-
-        # Compute the loss mini-batch.
-        # (1-p_t)^gamma * log(p_t) with reduced chance of overflow
-        p = F.logsigmoid(-i * (t * 2.0 - 1.0))
-        flat_loss: torch.Tensor = (p * self.gamma).exp() * ce
-
-        # Previously there was a mean over the last dimension, which did not
-        # return a compatible BCE loss. To maintain backwards compatible
-        # behavior we have a flag that performs this extra step, disable or
-        # parameterize if necessary. (Or justify why the mean should be there)
-        average_spatial_dims = True
+                raise ValueError("the value/values of the `weight` should be no less than 0.")
+            # apply class_weight to loss
+            class_weight = class_weight.to(loss)
+            broadcast_dims = [-1] + [1] * len(target.shape[2:])
+            class_weight = class_weight.view(broadcast_dims)
+            loss = class_weight * loss
 
         if self.reduction == LossReduction.SUM.value:
+            # Previously there was a mean over the last dimension, which did not
+            # return a compatible BCE loss. To maintain backwards compatible
+            # behavior we have a flag that performs this extra step, disable or
+            # parameterize if necessary. (Or justify why the mean should be there)
+            average_spatial_dims = True
             if average_spatial_dims:
-                flat_loss = flat_loss.mean(dim=-1)
-            loss = flat_loss.sum()
+                loss = loss.mean(dim=list(range(2, len(target.shape))))
+            loss = loss.sum()
         elif self.reduction == LossReduction.MEAN.value:
-            if average_spatial_dims:
-                flat_loss = flat_loss.mean(dim=-1)
-            loss = flat_loss.mean()
+            loss = loss.mean()
         elif self.reduction == LossReduction.NONE.value:
-            spacetime_dims = input.shape[2:]
-            loss = flat_loss.reshape([b, n] + list(spacetime_dims))
+            pass
         else:
             raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
         return loss
+
+
+def softmax_focal_loss(
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
+) -> torch.Tensor:
+    """
+    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
+
+    where p_i = exp(s_i) / sum_j exp(s_j), t is the target (ground truth) class, and
+    s_j is the unnormalized score for class j.
+    """
+    input_ls = input.log_softmax(1)
+    loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target
+
+    if alpha is not None:
+        # (1-alpha) for the background class and alpha for the other classes
+        alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
+        broadcast_dims = [-1] + [1] * len(target.shape[2:])
+        alpha_fac = alpha_fac.view(broadcast_dims)
+        loss = alpha_fac * loss
+
+    return loss
+
+
+def sigmoid_focal_loss(
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
+) -> torch.Tensor:
+    """
+    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
+
+    where p = sigmoid(x), pt = p if label is 1 or 1 - p if label is 0
+    """
+    # computing binary cross entropy with logits
+    # equivalent to F.binary_cross_entropy_with_logits(input, target, reduction='none')
+    # see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231
+    max_val = (-input).clamp(min=0)
+    loss: torch.Tensor = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
+
+    # sigmoid(-i) if t==1; sigmoid(i) if t==0 <=>
+    # 1-sigmoid(i) if t==1; sigmoid(i) if t==0 <=>
+    # 1-p if t==1; p if t==0 <=>
+    # pfac, that is, the term (1 - pt)
+    invprobs = F.logsigmoid(-input * (target * 2 - 1))  # reduced chance of overflow
+    # (pfac.log() * gamma).exp() <=>
+    # pfac.log().exp() ^ gamma <=>
+    # pfac ^ gamma
+    loss = (invprobs * gamma).exp() * loss
+
+    if alpha is not None:
+        # alpha if t==1; (1-alpha) if t==0
+        alpha_factor = target * alpha + (1 - target) * (1 - alpha)
+        loss = alpha_factor * loss
+
+    return loss
diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py
index 5f30b7b07d..46a947ea7c 100644
--- a/tests/test_focal_loss.py
+++ b/tests/test_focal_loss.py
@@ -13,16 +13,78 @@
 
 import unittest
 
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from parameterized import parameterized
 
 from monai.losses import FocalLoss
 from monai.networks import one_hot
 from tests.utils import test_script_save
 
+TEST_CASES = []
+for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]:
+    input_data = {
+        "input": torch.tensor(
+            [[[[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]]]], device=device
+        ),  # (1, 3, 2, 2)
+        "target": torch.tensor([[[[0, 1], [2, 0]]]], device=device),  # (1, 1, 2, 2)
+    }
+    TEST_CASES.append([{"to_onehot_y": True}, input_data, 0.34959])
+    TEST_CASES.append(
+        [
+            {"to_onehot_y": False},
+            {
+                "input": input_data["input"],  # (1, 3, 2, 2)
+                "target": F.one_hot(input_data["target"].squeeze(1)).permute(0, 3, 1, 2),  # (1, 3, 2, 2)
+            },
+            0.34959,
+        ]
+    )
+    TEST_CASES.append([{"to_onehot_y": True, "include_background": False}, input_data, 0.36498])
+    TEST_CASES.append([{"to_onehot_y": True, "alpha": 0.8}, input_data, 0.08423])
+    TEST_CASES.append(
+        [
+            {"to_onehot_y": True, "reduction": "none"},
+            input_data,
+            np.array(
+                [
+                    [
+                        [[0.02266, 0.70187], [0.37741, 0.17329]],
+                        [[0.70187, 0.02266], [0.37741, 0.17329]],
+                        [[0.70187, 0.70187], [0.06757, 0.17329]],
+                    ]
+                ]
+            ),
+        ]
+    )
+    TEST_CASES.append(
+        [
+            {"to_onehot_y": True, "weight": torch.tensor([0.5, 0.1, 0.2]), "reduction": "none"},
+            input_data,
+            np.array(
+                [
+                    [
+                        [[0.01133, 0.35093], [0.18871, 0.08664]],
+                        [[0.07019, 0.00227], [0.03774, 0.01733]],
+                        [[0.14037, 0.14037], [0.01352, 0.03466]],
+                    ]
+                ]
+            ),
+        ]
+    )
+    TEST_CASES.append([{"to_onehot_y": True, "use_softmax": True}, input_data, 0.16276])
+    TEST_CASES.append([{"to_onehot_y": True, "alpha": 0.8, "use_softmax": True}, input_data, 0.08138])
+
 
 class TestFocalLoss(unittest.TestCase):
+    @parameterized.expand(TEST_CASES)
+    def test_result(self, input_param, input_data, expected_val):
+        focal_loss = FocalLoss(**input_param)
+        result = focal_loss(**input_data)
+        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
+
     def test_consistency_with_cross_entropy_2d(self):
         """For gamma=0 the focal loss reduces to the cross entropy loss"""
         focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction="mean", weight=1.0)
@@ -144,86 +206,116 @@ def test_consistency_with_cross_entropy_classification_01(self):
         self.assertNotAlmostEqual(max_error, 0.0, places=3)
 
     def test_bin_seg_2d(self):
-        # define 2d examples
-        target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
-        # add another dimension corresponding to the batch (batch size = 1 here)
-        target = target.unsqueeze(0)  # shape (1, H, W)
-        pred_very_good = 100 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() - 50.0
-
-        # initialize the mean dice loss
-        loss = FocalLoss(to_onehot_y=True)
-
-        # focal loss for pred_very_good should be close to 0
-        target = target.unsqueeze(1)  # shape (1, 1, H, W)
-        focal_loss_good = float(loss(pred_very_good, target).cpu())
-        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+        for use_softmax in [True, False]:
+            # define 2d examples
+            target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
+            # add another dimension corresponding to the batch (batch size = 1 here)
+            target = target.unsqueeze(0)  # shape (1, H, W)
+            pred_very_good = 100 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() - 50.0
+
+            # initialize the mean dice loss
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
+
+            # focal loss for pred_very_good should be close to 0
+            target = target.unsqueeze(1)  # shape (1, 1, H, W)
+            focal_loss_good = float(loss(pred_very_good, target).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+
+            # with alpha
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            focal_loss_good = float(loss(pred_very_good, target).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
 
     def test_empty_class_2d(self):
-        num_classes = 2
-        # define 2d examples
-        target = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
-        # add another dimension corresponding to the batch (batch size = 1 here)
-        target = target.unsqueeze(0)  # shape (1, H, W)
-        pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
-
-        # initialize the mean dice loss
-        loss = FocalLoss(to_onehot_y=True)
-
-        # focal loss for pred_very_good should be close to 0
-        target = target.unsqueeze(1)  # shape (1, 1, H, W)
-        focal_loss_good = float(loss(pred_very_good, target).cpu())
-        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+        for use_softmax in [True, False]:
+            num_classes = 2
+            # define 2d examples
+            target = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
+            # add another dimension corresponding to the batch (batch size = 1 here)
+            target = target.unsqueeze(0)  # shape (1, H, W)
+            pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
+
+            # initialize the mean dice loss
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
+
+            # focal loss for pred_very_good should be close to 0
+            target = target.unsqueeze(1)  # shape (1, 1, H, W)
+            focal_loss_good = float(loss(pred_very_good, target).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+
+            # with alpha
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            focal_loss_good = float(loss(pred_very_good, target).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
 
     def test_multi_class_seg_2d(self):
-        num_classes = 6  # labels 0 to 5
-        # define 2d examples
-        target = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]])
-        # add another dimension corresponding to the batch (batch size = 1 here)
-        target = target.unsqueeze(0)  # shape (1, H, W)
-        pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
-        # initialize the mean dice loss
-        loss = FocalLoss(to_onehot_y=True)
-        loss_onehot = FocalLoss(to_onehot_y=False)
-
-        # focal loss for pred_very_good should be close to 0
-        target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2)  # test one hot
-        target = target.unsqueeze(1)  # shape (1, 1, H, W)
-
-        focal_loss_good = float(loss(pred_very_good, target).cpu())
-        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
-
-        focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
-        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+        for use_softmax in [True, False]:
+            num_classes = 6  # labels 0 to 5
+            # define 2d examples
+            target = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]])
+            # add another dimension corresponding to the batch (batch size = 1 here)
+            target = target.unsqueeze(0)  # shape (1, H, W)
+            pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
+            # initialize the mean dice loss
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
+            loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax)
+
+            # focal loss for pred_very_good should be close to 0
+            target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2)  # test one hot
+            target = target.unsqueeze(1)  # shape (1, 1, H, W)
+
+            focal_loss_good = float(loss(pred_very_good, target).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+
+            focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+
+            # with alpha
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            focal_loss_good = float(loss(pred_very_good, target).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+            loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax)
+            focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
 
     def test_bin_seg_3d(self):
-        num_classes = 2  # labels 0, 1
-        # define 3d examples
-        target = torch.tensor(
-            [
-                # raw 0
-                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
-                # raw 1
-                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
-                # raw 2
-                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
-            ]
-        )
-        # add another dimension corresponding to the batch (batch size = 1 here)
-        target = target.unsqueeze(0)  # shape (1, H, W, D)
-        target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3)  # test one hot
-        pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float() - 500.0
-
-        # initialize the mean dice loss
-        loss = FocalLoss(to_onehot_y=True)
-        loss_onehot = FocalLoss(to_onehot_y=False)
-
-        # focal loss for pred_very_good should be close to 0
-        target = target.unsqueeze(1)  # shape (1, 1, H, W)
-        focal_loss_good = float(loss(pred_very_good, target).cpu())
-        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
-
-        focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
-        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+        for use_softmax in [True, False]:
+            num_classes = 2  # labels 0, 1
+            # define 3d examples
+            target = torch.tensor(
+                [
+                    # raw 0
+                    [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
+                    # raw 1
+                    [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
+                    # raw 2
+                    [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
+                ]
+            )
+            # add another dimension corresponding to the batch (batch size = 1 here)
+            target = target.unsqueeze(0)  # shape (1, H, W, D)
+            target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3)  # test one hot
+            pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float() - 500.0
+
+            # initialize the mean dice loss
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
+            loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax)
+
+            # focal loss for pred_very_good should be close to 0
+            target = target.unsqueeze(1)  # shape (1, 1, H, W)
+            focal_loss_good = float(loss(pred_very_good, target).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+
+            focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+
+            # with alpha
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            focal_loss_good = float(loss(pred_very_good, target).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+            loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax)
+            focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
 
     def test_foreground(self):
         background = torch.ones(1, 1, 5, 5)
@@ -260,10 +352,28 @@ def test_ill_class_weight(self):
         with self.assertRaisesRegex(ValueError, ""):
             FocalLoss(include_background=False, weight=(1.0, 1.0, -1.0))(chn_input, chn_target)
 
+    def test_warnings(self):
+        with self.assertWarns(Warning):
+            chn_input = torch.ones((1, 1, 3))
+            chn_target = torch.ones((1, 1, 3))
+            loss = FocalLoss(to_onehot_y=True)
+            loss(chn_input, chn_target)
+        with self.assertWarns(Warning):
+            chn_input = torch.ones((1, 1, 3))
+            chn_target = torch.ones((1, 1, 3))
+            loss = FocalLoss(include_background=False)
+            loss(chn_input, chn_target)
+        with self.assertWarns(Warning):
+            chn_input = torch.ones((1, 3, 3))
+            chn_target = torch.ones((1, 3, 3))
+            loss = FocalLoss(include_background=False, use_softmax=True, alpha=0.5)
+            loss(chn_input, chn_target)
+
     def test_script(self):
-        loss = FocalLoss()
-        test_input = torch.ones(2, 2, 8, 8)
-        test_script_save(loss, test_input, test_input)
+        for use_softmax in [True, False]:
+            loss = FocalLoss(use_softmax=use_softmax)
+            test_input = torch.ones(2, 2, 8, 8)
+            test_script_save(loss, test_input, test_input)
 
 
 if __name__ == "__main__":
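
Note (not part of the patch): a minimal usage sketch of the new `alpha` and `use_softmax` options, assuming the `FocalLoss` signature added above; tensor values and shapes are illustrative only, mirroring the parameterized test cases.

    import torch
    from monai.losses import FocalLoss

    # alpha-balanced focal loss with a softmax over the channel (class) dimension
    loss_fn = FocalLoss(to_onehot_y=True, gamma=2.0, alpha=0.8, use_softmax=True)
    logits = torch.randn(1, 3, 2, 2)            # raw scores, shape (B, N, H, W); no activation applied beforehand
    labels = torch.randint(0, 3, (1, 1, 2, 2))  # class-index labels, converted to one-hot internally
    print(loss_fn(logits, labels))              # scalar tensor, since reduction defaults to "mean"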