From cb5ed048e8d3c07430fa48ab8d1eaf7ec3a5e55c Mon Sep 17 00:00:00 2001
From: Qingpeng Li 
Date: Wed, 24 May 2023 00:49:12 +0800
Subject: [PATCH 01/11] add softmax version to focal_loss

---
 monai/losses/focal_loss.py | 164 +++++++++++++++++++++++++------------
 1 file changed, 110 insertions(+), 54 deletions(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index 4d1c723402..31224063dd 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -28,10 +28,10 @@ class FocalLoss(_Loss):
     FocalLoss is an extension of BCEWithLogitsLoss that down-weights loss from
     high confidence correct predictions.

-    Reimplementation of the Focal Loss (with a build-in sigmoid activation) described in:
+    Reimplementation of the Focal Loss described in:

-        - "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017
-        - "AnatomyNet: Deep learning for fast and fully automated whole‐volume segmentation of head and neck anatomy",
+        - ["Focal Loss for Dense Object Detection"](https://arxiv.org/abs/1708.02002), T. Lin et al., ICCV 2017
+        - "AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy",
           Zhu et al., Medical Physics 2018

     Example:
@@ -70,19 +70,23 @@ def __init__(
         include_background: bool = True,
         to_onehot_y: bool = False,
         gamma: float = 2.0,
+        alpha: float | None = None,
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
         reduction: LossReduction | str = LossReduction.MEAN,
+        use_softmax: bool = False,
     ) -> None:
         """
         Args:
-            include_background: if False, channel index 0 (background category) is excluded from the calculation.
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-            gamma: value of the exponent gamma in the definition of the Focal loss.
+            include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
+                If False, `alpha` is ignored when using softmax.
+            to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
+            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
+            alpha: value of alpha in the definition of the alpha-balanced Focal loss.
+                The value should be in [0, 1]. Defaults to None.
             weight: weights to apply to the voxels of each class. If None no weights are applied.
-                This corresponds to the weights `\alpha` in [1].
                 The input can be a single value (same weight for all classes), a sequence of values (the length
-                of the sequence should be the same as the number of classes, if not ``include_background``, the
-                number should not include class 0).
+                of the sequence should be the same as the number of classes. If not ``include_background``, 
+                the number of classes should not include the background category class 0).
                 The value/values should be no less than 0. Defaults to None.
             reduction: {``"none"``, ``"mean"``, ``"sum"``}
                 Specifies the reduction to apply to the output. Defaults to ``"mean"``.
@@ -91,6 +95,9 @@ def __init__(
                 - ``"none"``: no reduction will be applied.
                 - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                 - ``"sum"``: the output will be summed.

+            use_softmax: whether to use softmax to transform the original logits into probabilities.
+                If True, softmax is used; otherwise, sigmoid. Defaults to False. 
+ Example: >>> import torch >>> from monai.losses import FocalLoss @@ -103,14 +110,16 @@ def __init__( self.include_background = include_background self.to_onehot_y = to_onehot_y self.gamma = gamma - self.weight: Sequence[float] | float | int | torch.Tensor | None = weight + self.alpha = alpha + self.weight = weight + self.use_softmax = use_softmax def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: input: the shape should be BNH[WD], where N is the number of classes. The input should be the original logits since it will be transformed by - a sigmoid in the forward function. + a sigmoid/softmax in the forward function. target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. Raises: @@ -141,63 +150,110 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != input.shape: raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") - i = input - t = target - - # Change the shape of input and target to B x N x num_voxels. - b, n = t.shape[:2] - i = i.reshape(b, n, -1) - t = t.reshape(b, n, -1) - - # computing binary cross entropy with logits - # see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231 - max_val = (-i).clamp(min=0) - ce = i - i * t + max_val + ((-max_val).exp() + (-i - max_val).exp()).log() + loss: Optional[torch.Tensor] = None + input = input.float() + target = target.float() + if self.use_softmax: + if not self.include_background and self.alpha is not None: + self.alpha = None + warnings.warn("`include_background=False`, `alpha` ignored when using softmax.") + loss = softmax_focal_loss(input, target, self.gamma, self.alpha) + else: + loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha) if self.weight is not None: + # make sure the lengths of weights are equal to the number of classes class_weight: Optional[torch.Tensor] = None + num_of_classes = target.shape[1] if isinstance(self.weight, (float, int)): - class_weight = torch.as_tensor([self.weight] * i.size(1)) + class_weight = torch.as_tensor([self.weight] * num_of_classes) else: class_weight = torch.as_tensor(self.weight) - if class_weight.size(0) != i.size(1): + if class_weight.shape[0] != num_of_classes: raise ValueError( - "the length of the weight sequence should be the same as the number of classes. " - + "If `include_background=False`, the number should not include class 0." + """the length of the `weight` sequence should be the same as the number of classes. + If `include_background=False`, the weight should not include + the background category class 0.""" ) if class_weight.min() < 0: - raise ValueError("the value/values of weights should be no less than 0.") - class_weight = class_weight.to(i) - # Convert the weight to a map in which each voxel - # has the weight associated with the ground-truth label - # associated with this voxel in target. - at = class_weight[None, :, None] # N => 1,N,1 - at = at.expand((t.size(0), -1, t.size(2))) # 1,N,1 => B,N,H*W - # Multiply the log proba by their weights. - ce = ce * at - - # Compute the loss mini-batch. - # (1-p_t)^gamma * log(p_t) with reduced chance of overflow - p = F.logsigmoid(-i * (t * 2.0 - 1.0)) - flat_loss: torch.Tensor = (p * self.gamma).exp() * ce - - # Previously there was a mean over the last dimension, which did not - # return a compatible BCE loss. 
To maintain backwards compatible
-        # behavior we have a flag that performs this extra step, disable or
-        # parameterize if necessary. (Or justify why the mean should be there)
-        average_spatial_dims = True
+            raise ValueError("the value/values of the `weight` should be no less than 0.")
+            # apply class_weight to loss
+            class_weight = class_weight.to(loss)
+            broadcast_dims = [-1] + [1] * len(target.shape[2:])
+            class_weight = class_weight.view(broadcast_dims)
+            loss = class_weight * loss

         if self.reduction == LossReduction.SUM.value:
+            # Previously there was a mean over the last dimension, which did not
+            # return a compatible BCE loss. To maintain backwards compatible
+            # behavior we have a flag that performs this extra step, disable or
+            # parameterize if necessary. (Or justify why the mean should be there)
+            average_spatial_dims = True
             if average_spatial_dims:
-                flat_loss = flat_loss.mean(dim=-1)
-            loss = flat_loss.sum()
+                loss = loss.mean(dim=target.shape[2:])
+            loss = loss.sum()
         elif self.reduction == LossReduction.MEAN.value:
-            if average_spatial_dims:
-                flat_loss = flat_loss.mean(dim=-1)
-            loss = flat_loss.mean()
+            loss = loss.mean()
         elif self.reduction == LossReduction.NONE.value:
-            spacetime_dims = input.shape[2:]
-            loss = flat_loss.reshape([b, n] + list(spacetime_dims))
+            pass
         else:
             raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
         return loss
+
+def softmax_focal_loss(
+    input: torch.Tensor,
+    target: torch.Tensor,
+    gamma: float = 2.0,
+    alpha: Optional[float] = None,
+) -> torch.Tensor:
+    """
+    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
+
+    where p_i = exp(s_i) / sum_j exp(s_j), t is the target (ground truth) class, and
+    s_j is the unnormalized score for class j.
+    """
+    pt = input.softmax(1)
+    loss: torch.Tensor = - (1 - pt).pow(gamma) * input.log_softmax(1) * target
+
+    if alpha is not None:
+        # (1-alpha) for the background class and alpha for the other classes
+        alpha_fac = torch.tensor([1-alpha] + [alpha] * (target.shape[1]-1)).to(loss)
+        broadcast_dims = [-1] + [1] * len(target.shape[2:])
+        alpha_fac = alpha_fac.view(broadcast_dims)
+        loss = alpha_fac * loss
+
+    return loss
+
+def sigmoid_focal_loss(
+    input: torch.Tensor,
+    target: torch.Tensor,
+    gamma: float = 2.0,
+    alpha: Optional[float] = None,
+) -> torch.Tensor:
+    """
+    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
+
+    where p = sigmoid(x), pt = p if label is 1 or 1 - p if label is 0
+    """
+    # computing binary cross entropy with logits
+    # equivalent to F.binary_cross_entropy_with_logits(input, target, reduction='none')
+    # see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231
+    max_val = (-input).clamp(min=0)
+    loss: torch.Tensor = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
+
+    # sigmoid(-i) if t==1; sigmoid(i) if t==0 <=>
+    # 1-sigmoid(i) if t==1; sigmoid(i) if t==0 <=>
+    # 1-p if t==1; p if t==0 <=>
+    # pfac, that is, the term (1 - pt)
+    invprobs = F.logsigmoid(-input * (target * 2 - 1)) # reduced chance of overflow
+    # (pfac.log() * gamma).exp() <=>
+    # pfac.log().exp() ^ gamma <=>
+    # pfac ^ gamma
+    loss = (invprobs * gamma).exp() * loss
+
+    if alpha is not None:
+        # alpha if t==1; (1-alpha) if t==0
+        alpha_factor = target * alpha + (1 - target) * (1 - alpha)
+        loss = alpha_factor * loss
+
+    return loss
\ No newline at end of file
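A quick numerical sanity check of the two helpers introduced above (a sketch, assuming patch 01 is applied and that the helpers can be imported directly from `monai.losses.focal_loss`): with `gamma=0` and `alpha=None` the modulating factor `(invprobs * gamma).exp()` is identically 1, so `sigmoid_focal_loss` should match binary cross entropy with logits elementwise, and `softmax_focal_loss`, summed over the channel dimension, should match standard cross entropy:

import torch
import torch.nn.functional as F

from monai.losses.focal_loss import sigmoid_focal_loss, softmax_focal_loss

logits = torch.randn(2, 3, 4, 4)  # B, N, H, W
target = F.one_hot(torch.randint(0, 3, (2, 4, 4)), num_classes=3).permute(0, 3, 1, 2).float()

# gamma=0, alpha=None: the focal modulation disappears, leaving plain BCE with logits
bce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
assert torch.allclose(sigmoid_focal_loss(logits, target, gamma=0.0), bce, atol=1e-5)

# the softmax variant, summed over channels, likewise reduces to cross entropy
ce = F.cross_entropy(logits, target.argmax(dim=1), reduction="none")
assert torch.allclose(softmax_focal_loss(logits, target, gamma=0.0).sum(dim=1), ce, atol=1e-5)

The same gamma=0 identity underpins the `test_consistency_with_cross_entropy_*` checks in tests/test_focal_loss.py.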
From 5e86c10e1c24bd4f982a7c0cf3d323384163c6a8 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 23 May 2023 16:55:21 +0000
Subject: [PATCH 02/11] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 monai/losses/focal_loss.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index 31224063dd..add4832a08 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -85,7 +85,7 @@ def __init__(
                 The value should be in [0, 1]. Defaults to None.
             weight: weights to apply to the voxels of each class. If None no weights are applied.
                 The input can be a single value (same weight for all classes), a sequence of values (the length
-                of the sequence should be the same as the number of classes. If not ``include_background``, 
+                of the sequence should be the same as the number of classes. If not ``include_background``,
                 the number of classes should not include the background category class 0).
                 The value/values should be no less than 0. Defaults to None.
             reduction: {``"none"``, ``"mean"``, ``"sum"``}
@@ -172,7 +172,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         if class_weight.shape[0] != num_of_classes:
             raise ValueError(
                 """the length of the `weight` sequence should be the same as the number of classes.
-                If `include_background=False`, the weight should not include 
+                If `include_background=False`, the weight should not include
                 the background category class 0."""
             )
         if class_weight.min() < 0:
@@ -221,7 +221,7 @@ def softmax_focal_loss(
         broadcast_dims = [-1] + [1] * len(target.shape[2:])
         alpha_fac = alpha_fac.view(broadcast_dims)
         loss = alpha_fac * loss
-    
+
     return loss

 def sigmoid_focal_loss(
@@ -256,4 +256,4 @@ def sigmoid_focal_loss(
         alpha_factor = target * alpha + (1 - target) * (1 - alpha)
         loss = alpha_factor * loss

-    return loss
\ No newline at end of file
+    return loss
From 78eebafff0295818c3c409414dbd8acf26c09558 Mon Sep 17 00:00:00 2001
From: Qingpeng Li 
Date: Wed, 24 May 2023 01:23:16 +0800
Subject: [PATCH 03/11] fix dim

---
 monai/losses/focal_loss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index 31224063dd..71c948658f 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -190,7 +190,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             # parameterize if necessary. 
(Or justify why the mean should be there) average_spatial_dims = True if average_spatial_dims: - loss = loss.mean(dim=target.shape[2:]) + loss = loss.mean(dim=list(range(2, len(target.shape)))) loss = loss.sum() elif self.reduction == LossReduction.MEAN.value: loss = loss.mean() From a97cf1f2ed30d5efd04c41de899f1cffa05f244c Mon Sep 17 00:00:00 2001 From: Qingpeng Li Date: Wed, 24 May 2023 01:42:06 +0800 Subject: [PATCH 04/11] fix format DCO Remediation Commit for Qingpeng Li I, Qingpeng Li , hereby add my Signed-off-by to this commit: cb5ed048e8d3c07430fa48ab8d1eaf7ec3a5e55c I, Qingpeng Li , hereby add my Signed-off-by to this commit: 78eebafff0295818c3c409414dbd8acf26c09558 Signed-off-by: Qingpeng Li --- monai/losses/focal_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index d1401cacdf..039cc8b05f 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -245,7 +245,7 @@ def sigmoid_focal_loss( # 1-sigmoid(i) if t==1; sigmoid(i) if t==0 <=> # 1-p if t==1; p if t==0 <=> # pfac, that is, the term (1 - pt) - invprobs = F.logsigmoid(-input * (target * 2 - 1)) # reduced chance of overflow + invprobs = F.logsigmoid(-input * (target * 2 - 1)) # reduced chance of overflow # (pfac.log() * gamma).exp() <=> # pfac.log().exp() ^ gamma <=> # pfac ^ gamma From 5db9e58b8b6cf0e2323e938fc8ef125df3d8a5d0 Mon Sep 17 00:00:00 2001 From: Qingpeng Li Date: Wed, 24 May 2023 03:26:11 +0800 Subject: [PATCH 05/11] improve algorithm --- monai/losses/focal_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index 039cc8b05f..5c9bfd5cd1 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -212,8 +212,8 @@ def softmax_focal_loss( where p_i = exp(s_i) / sum_j exp(s_j), t is the target (ground truth) class, and s_j is the unnormalized score for class j. 
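    Here pt = p_t, the model's estimated probability for the target class t.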
""" - pt = input.softmax(1) - loss: torch.Tensor = - (1 - pt).pow(gamma) * input.log_softmax(1) * target + input_ls = input.log_softmax(1) + loss: torch.Tensor = - (1 - input_ls.exp()).pow(gamma) * input_ls * target if alpha is not None: # (1-alpha) for the background class and alpha for the other classes From 37adb2de413a6164fc64a575d6923a3f296288d1 Mon Sep 17 00:00:00 2001 From: Qingpeng Li Date: Fri, 26 May 2023 21:11:05 +0800 Subject: [PATCH 06/11] add unittests for focal_loss.py DCO Remediation Commit for Qingpeng Li I, Qingpeng Li , hereby add my Signed-off-by to this commit: 5db9e58b8b6cf0e2323e938fc8ef125df3d8a5d0 Signed-off-by: Qingpeng Li --- tests/test_focal_loss.py | 250 +++++++++++++++++++++++++++------------ 1 file changed, 173 insertions(+), 77 deletions(-) diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index 5f30b7b07d..9344879f63 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -13,16 +13,64 @@ import unittest +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +from parameterized import parameterized + from monai.losses import FocalLoss from monai.networks import one_hot from tests.utils import test_script_save +TEST_CASES = [] +def get_input_data(device): + return { + "input": torch.tensor([[ + [[1.0, 1.0], [0.5, 0.0]], + [[1.0, 1.0], [0.5, 0.0]], + [[1.0, 1.0], [0.5, 0.0]]]], device=device), # (1, 3, 2, 2) + "target": torch.tensor([[[[0, 1], [2, 0]]]], device=device), # (1, 1, 2, 2) + } +for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: + input_data = get_input_data(device) + TEST_CASES.append([{"to_onehot_y": True}, input_data, 0.34959]) + TEST_CASES.append([{"to_onehot_y": False}, { + "input": input_data['input'], # (1, 3, 2, 2) + "target": F.one_hot(input_data['target'].squeeze(1)).permute(0,3,1,2)}, # (1, 3, 2, 2) + 0.34959 + ]) + TEST_CASES.append([{"to_onehot_y": True, "include_background": False}, input_data, 0.36498]) + TEST_CASES.append([{"to_onehot_y": True, "alpha": 0.8}, input_data, 0.08423]) + TEST_CASES.append([{"to_onehot_y": True, "reduction": "none"}, input_data, + np.array([[ + [[0.02266, 0.70187], + [0.37741, 0.17329]], + [[0.70187, 0.02266], + [0.37741, 0.17329]], + [[0.70187, 0.70187], + [0.06757, 0.17329]]]]) + ]) + TEST_CASES.append([{"to_onehot_y": True, "weight": torch.tensor([0.5, 0.1, 0.2]), "reduction": "none"}, input_data, + np.array([[ + [[0.01133, 0.35093], + [0.18871, 0.08664]], + [[0.07019, 0.00227], + [0.03774, 0.01733]], + [[0.14037, 0.14037], + [0.01352, 0.03466]]]]) + ]) + TEST_CASES.append([{"to_onehot_y": True, "use_softmax": True}, input_data, 0.16276]) + TEST_CASES.append([{"to_onehot_y": True, "alpha": 0.8, "use_softmax": True}, input_data, 0.08138]) class TestFocalLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_result(self, input_param, input_data, expected_val): + focal_loss = FocalLoss(**input_param) + result = focal_loss(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) + def test_consistency_with_cross_entropy_2d(self): """For gamma=0 the focal loss reduces to the cross entropy loss""" focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction="mean", weight=1.0) @@ -144,86 +192,116 @@ def test_consistency_with_cross_entropy_classification_01(self): self.assertNotAlmostEqual(max_error, 0.0, places=3) def test_bin_seg_2d(self): - # define 2d examples - target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]) - # add 
another dimension corresponding to the batch (batch size = 1 here)
-        target = target.unsqueeze(0)  # shape (1, H, W)
-        pred_very_good = 100 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() - 50.0
-
-        # initialize the mean dice loss
-        loss = FocalLoss(to_onehot_y=True)
-
-        # focal loss for pred_very_good should be close to 0
-        target = target.unsqueeze(1)  # shape (1, 1, H, W)
-        focal_loss_good = float(loss(pred_very_good, target).cpu())
-        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+        for use_softmax in [True, False]:
+            # define 2d examples
+            target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
+            # add another dimension corresponding to the batch (batch size = 1 here)
+            target = target.unsqueeze(0)  # shape (1, H, W)
+            pred_very_good = 100 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() - 50.0
+
+            # initialize the focal loss
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
+
+            # focal loss for pred_very_good should be close to 0
+            target = target.unsqueeze(1)  # shape (1, 1, H, W)
+            focal_loss_good = float(loss(pred_very_good, target).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+
+            # with alpha
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            focal_loss_good = float(loss(pred_very_good, target).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

     def test_empty_class_2d(self):
-        num_classes = 2
-        # define 2d examples
-        target = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
-        # add another dimension corresponding to the batch (batch size = 1 here)
-        target = target.unsqueeze(0)  # shape (1, H, W)
-        pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
-
-        # initialize the mean dice loss
-        loss = FocalLoss(to_onehot_y=True)
-
-        # focal loss for pred_very_good should be close to 0
-        target = target.unsqueeze(1)  # shape (1, 1, H, W)
-        focal_loss_good = float(loss(pred_very_good, target).cpu())
-        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+        for use_softmax in [True, False]:
+            num_classes = 2
+            # define 2d examples
+            target = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
+            # add another dimension corresponding to the batch (batch size = 1 here)
+            target = target.unsqueeze(0)  # shape (1, H, W)
+            pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
+
+            # initialize the focal loss
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
+
+            # focal loss for pred_very_good should be close to 0
+            target = target.unsqueeze(1)  # shape (1, 1, H, W)
+            focal_loss_good = float(loss(pred_very_good, target).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+
+            # with alpha
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            focal_loss_good = float(loss(pred_very_good, target).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

     def test_multi_class_seg_2d(self):
-        num_classes = 6  # labels 0 to 5
-        # define 2d examples
-        target = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]])
-        # add another dimension corresponding to the batch (batch size = 1 here)
-        target = target.unsqueeze(0)  # shape (1, H, W)
-        pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
-        # initialize the mean dice loss
-        loss = FocalLoss(to_onehot_y=True)
-        loss_onehot = FocalLoss(to_onehot_y=False)
-
-        # focal loss for pred_very_good should be close to 0
-        target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2)  # test one hot
-        target = target.unsqueeze(1)  # shape (1, 1, H, W)
-
-        focal_loss_good = float(loss(pred_very_good, target).cpu())
-        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
-
-        focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
-        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+        for use_softmax in [True, False]:
+            num_classes = 6  # labels 0 to 5
+            # define 2d examples
+            target = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]])
+            # add another dimension corresponding to the batch (batch size = 1 here)
+            target = target.unsqueeze(0)  # shape (1, H, W)
+            pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
+            # initialize the focal loss
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
+            loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax)
+
+            # focal loss for pred_very_good should be close to 0
+            target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2)  # test one hot
+            target = target.unsqueeze(1)  # shape (1, 1, H, W)
+
+            focal_loss_good = float(loss(pred_very_good, target).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+
+            focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+
+            # with alpha
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            focal_loss_good = float(loss(pred_very_good, target).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+            loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax)
+            focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

     def test_bin_seg_3d(self):
-        num_classes = 2  # labels 0, 1
-        # define 3d examples
-        target = torch.tensor(
-            [
-                # raw 0
-                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
-                # raw 1
-                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
-                # raw 2
-                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
-            ]
-        )
-        # add another dimension corresponding to the batch (batch size = 1 here)
-        target = target.unsqueeze(0)  # shape (1, H, W, D)
-        target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3)  # test one hot
-        pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float() - 500.0
-
-        # initialize the mean dice loss
-        loss = FocalLoss(to_onehot_y=True)
-        loss_onehot = FocalLoss(to_onehot_y=False)
-
-        # focal loss for pred_very_good should be close to 0
-        target = target.unsqueeze(1)  # shape (1, 1, H, W)
-        focal_loss_good = float(loss(pred_very_good, target).cpu())
-        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
-
-        focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
-        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+        for use_softmax in [True, False]:
+            num_classes = 2  # labels 0, 1
+            # define 3d examples
+            target = torch.tensor(
+                [
+                    # row 0
+                    [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
+                    # row 1
+                    [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
+                    # row 2
+                    [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
+                ]
+            )
+            # add another dimension corresponding to the batch (batch size = 1 here)
+            target = target.unsqueeze(0)  # shape (1, H, W, D)
+            target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3)  # test one hot
+            pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float() - 500.0
+
+            # initialize the focal loss
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
+            loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax)
+
+            # focal loss for pred_very_good should be close to 0
+            target = target.unsqueeze(1)  # shape (1, 1, H, W)
+            focal_loss_good = float(loss(pred_very_good, target).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+
+            focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+
+            # with alpha
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            focal_loss_good = float(loss(pred_very_good, target).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
+            loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax)
+            focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
+            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

     def test_foreground(self):
         background = torch.ones(1, 1, 5, 5)
@@ -260,10 +338,28 @@ def test_ill_class_weight(self):
         with self.assertRaisesRegex(ValueError, ""):
             FocalLoss(include_background=False, weight=(1.0, 1.0, -1.0))(chn_input, chn_target)

+    def test_warnings(self):
+        with self.assertWarns(Warning):
+            chn_input = torch.ones((1, 1, 3))
+            chn_target = torch.ones((1, 1, 3))
+            loss = FocalLoss(to_onehot_y=True)
+            loss(chn_input, chn_target)
+        with self.assertWarns(Warning):
+            chn_input = torch.ones((1, 1, 3))
+            chn_target = torch.ones((1, 1, 3))
+            loss = FocalLoss(include_background=False)
+            loss(chn_input, chn_target)
+        with self.assertWarns(Warning):
+            chn_input = torch.ones((1, 3, 3))
+            chn_target = torch.ones((1, 3, 3))
+            loss = FocalLoss(include_background=False, use_softmax=True, alpha=0.5)
+            loss(chn_input, chn_target)
+
     def test_script(self):
-        loss = FocalLoss()
-        test_input = torch.ones(2, 2, 8, 8)
-        test_script_save(loss, test_input, test_input)
+        for use_softmax in [True, False]:
+            loss = FocalLoss(use_softmax=use_softmax)
+            test_input = torch.ones(2, 2, 8, 8)
+            test_script_save(loss, test_input, test_input)

 if __name__ == "__main__":
From fca52faf8752dadcf8c932eb8b1235ea0332a3a1 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 26 May 2023 13:13:01 +0000
Subject: [PATCH 07/11] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/test_focal_loss.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py
index 9344879f63..6b12c8412d 100644
--- a/tests/test_focal_loss.py
+++ b/tests/test_focal_loss.py
@@ -28,8 +28,8 @@ def get_input_data(device):
     return {
         "input": torch.tensor([[
-            [[1.0, 1.0], [0.5, 0.0]], 
-            [[1.0, 1.0], [0.5, 0.0]], 
+            [[1.0, 1.0], [0.5, 0.0]],
+            [[1.0, 1.0], [0.5, 0.0]],
             [[1.0, 1.0], [0.5, 0.0]]]], device=device),  # (1, 3, 2, 2)
         "target": torch.tensor([[[[0, 1], [2, 0]]]], device=device),  # (1, 1, 2, 2)
     }
From 99f0284627d75e9cb260009854cefe809bf61d1e Mon Sep 17 00:00:00 2001
From: monai-bot 
Date: Fri, 26 May 2023 13:42:35 +0000
Subject: [PATCH 08/11] [MONAI] code formatting

Signed-off-by: monai-bot 
---
 monai/losses/focal_loss.py | 16 +++----
 tests/test_focal_loss.py   | 78 ++++++++++++++++++++++++--------------
 2 files changed, 55 insertions(+), 39 deletions(-)

diff 
--git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index 5c9bfd5cd1..d6071edd71 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -200,11 +200,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return loss + def softmax_focal_loss( - input: torch.Tensor, - target: torch.Tensor, - gamma: float = 2.0, - alpha: Optional[float] = None, + input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None ) -> torch.Tensor: """ FL(pt) = -alpha * (1 - pt)**gamma * log(pt) @@ -213,22 +211,20 @@ def softmax_focal_loss( s_j is the unnormalized score for class j. """ input_ls = input.log_softmax(1) - loss: torch.Tensor = - (1 - input_ls.exp()).pow(gamma) * input_ls * target + loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target if alpha is not None: # (1-alpha) for the background class and alpha for the other classes - alpha_fac = torch.tensor([1-alpha] + [alpha] * (target.shape[1]-1)).to(loss) + alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss) broadcast_dims = [-1] + [1] * len(target.shape[2:]) alpha_fac = alpha_fac.view(broadcast_dims) loss = alpha_fac * loss return loss + def sigmoid_focal_loss( - input: torch.Tensor, - target: torch.Tensor, - gamma: float = 2.0, - alpha: Optional[float] = None, + input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None ) -> torch.Tensor: """ FL(pt) = -alpha * (1 - pt)**gamma * log(pt) diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index 6b12c8412d..91feb9c0bb 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -17,7 +17,6 @@ import torch import torch.nn as nn import torch.nn.functional as F - from parameterized import parameterized from monai.losses import FocalLoss @@ -25,45 +24,66 @@ from tests.utils import test_script_save TEST_CASES = [] + + def get_input_data(device): return { - "input": torch.tensor([[ - [[1.0, 1.0], [0.5, 0.0]], - [[1.0, 1.0], [0.5, 0.0]], - [[1.0, 1.0], [0.5, 0.0]]]], device=device), # (1, 3, 2, 2) - "target": torch.tensor([[[[0, 1], [2, 0]]]], device=device), # (1, 1, 2, 2) + "input": torch.tensor( + [[[[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]]]], device=device + ), # (1, 3, 2, 2) + "target": torch.tensor([[[[0, 1], [2, 0]]]], device=device), # (1, 1, 2, 2) } + + for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: input_data = get_input_data(device) TEST_CASES.append([{"to_onehot_y": True}, input_data, 0.34959]) - TEST_CASES.append([{"to_onehot_y": False}, { - "input": input_data['input'], # (1, 3, 2, 2) - "target": F.one_hot(input_data['target'].squeeze(1)).permute(0,3,1,2)}, # (1, 3, 2, 2) - 0.34959 - ]) + TEST_CASES.append( + [ + {"to_onehot_y": False}, + { + "input": input_data["input"], # (1, 3, 2, 2) + "target": F.one_hot(input_data["target"].squeeze(1)).permute(0, 3, 1, 2), + }, # (1, 3, 2, 2) + 0.34959, + ] + ) TEST_CASES.append([{"to_onehot_y": True, "include_background": False}, input_data, 0.36498]) TEST_CASES.append([{"to_onehot_y": True, "alpha": 0.8}, input_data, 0.08423]) - TEST_CASES.append([{"to_onehot_y": True, "reduction": "none"}, input_data, - np.array([[ - [[0.02266, 0.70187], - [0.37741, 0.17329]], - [[0.70187, 0.02266], - [0.37741, 0.17329]], - [[0.70187, 0.70187], - [0.06757, 0.17329]]]]) - ]) - TEST_CASES.append([{"to_onehot_y": 
True, "weight": torch.tensor([0.5, 0.1, 0.2]), "reduction": "none"}, input_data, - np.array([[ - [[0.01133, 0.35093], - [0.18871, 0.08664]], - [[0.07019, 0.00227], - [0.03774, 0.01733]], - [[0.14037, 0.14037], - [0.01352, 0.03466]]]]) - ]) + TEST_CASES.append( + [ + {"to_onehot_y": True, "reduction": "none"}, + input_data, + np.array( + [ + [ + [[0.02266, 0.70187], [0.37741, 0.17329]], + [[0.70187, 0.02266], [0.37741, 0.17329]], + [[0.70187, 0.70187], [0.06757, 0.17329]], + ] + ] + ), + ] + ) + TEST_CASES.append( + [ + {"to_onehot_y": True, "weight": torch.tensor([0.5, 0.1, 0.2]), "reduction": "none"}, + input_data, + np.array( + [ + [ + [[0.01133, 0.35093], [0.18871, 0.08664]], + [[0.07019, 0.00227], [0.03774, 0.01733]], + [[0.14037, 0.14037], [0.01352, 0.03466]], + ] + ] + ), + ] + ) TEST_CASES.append([{"to_onehot_y": True, "use_softmax": True}, input_data, 0.16276]) TEST_CASES.append([{"to_onehot_y": True, "alpha": 0.8, "use_softmax": True}, input_data, 0.08138]) + class TestFocalLoss(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_result(self, input_param, input_data, expected_val): From 564ce2cb5625981f54043d5ee1651ea2927125af Mon Sep 17 00:00:00 2001 From: Qingpeng Li Date: Sat, 27 May 2023 16:32:42 +0800 Subject: [PATCH 09/11] refactor unittests Signed-off-by: Qingpeng Li --- tests/test_focal_loss.py | 40 ++++++++-------------------------------- 1 file changed, 8 insertions(+), 32 deletions(-) diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index 91feb9c0bb..c1cbc2cd34 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -24,27 +24,21 @@ from tests.utils import test_script_save TEST_CASES = [] - - -def get_input_data(device): - return { +for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: + input_data = { "input": torch.tensor( [[[[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]]]], device=device ), # (1, 3, 2, 2) "target": torch.tensor([[[[0, 1], [2, 0]]]], device=device), # (1, 1, 2, 2) } - - -for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: - input_data = get_input_data(device) TEST_CASES.append([{"to_onehot_y": True}, input_data, 0.34959]) TEST_CASES.append( [ {"to_onehot_y": False}, { "input": input_data["input"], # (1, 3, 2, 2) - "target": F.one_hot(input_data["target"].squeeze(1)).permute(0, 3, 1, 2), - }, # (1, 3, 2, 2) + "target": F.one_hot(input_data["target"].squeeze(1)).permute(0, 3, 1, 2), # (1, 3, 2, 2) + }, 0.34959, ] ) @@ -52,32 +46,14 @@ def get_input_data(device): TEST_CASES.append([{"to_onehot_y": True, "alpha": 0.8}, input_data, 0.08423]) TEST_CASES.append( [ - {"to_onehot_y": True, "reduction": "none"}, - input_data, - np.array( - [ - [ - [[0.02266, 0.70187], [0.37741, 0.17329]], - [[0.70187, 0.02266], [0.37741, 0.17329]], - [[0.70187, 0.70187], [0.06757, 0.17329]], - ] - ] - ), + {"to_onehot_y": True, "reduction": "none"}, input_data, + np.array([[[[0.02266, 0.70187], [0.37741, 0.17329]], [[0.70187, 0.02266], [0.37741, 0.17329]], [[0.70187, 0.70187], [0.06757, 0.17329]]]]) ] ) TEST_CASES.append( [ - {"to_onehot_y": True, "weight": torch.tensor([0.5, 0.1, 0.2]), "reduction": "none"}, - input_data, - np.array( - [ - [ - [[0.01133, 0.35093], [0.18871, 0.08664]], - [[0.07019, 0.00227], [0.03774, 0.01733]], - [[0.14037, 0.14037], [0.01352, 0.03466]], - ] - ] - ), + {"to_onehot_y": True, "weight": torch.tensor([0.5, 0.1, 0.2]), "reduction": "none"}, input_data, + np.array([[[[0.01133, 0.35093], [0.18871, 0.08664]], [[0.07019, 0.00227], [0.03774, 
0.01733]], [[0.14037, 0.14037], [0.01352, 0.03466]]]]) ] ) TEST_CASES.append([{"to_onehot_y": True, "use_softmax": True}, input_data, 0.16276]) From 92f2ad67a0095ab98302305887c40d0455a80956 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 27 May 2023 08:33:53 +0000 Subject: [PATCH 10/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_focal_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index c1cbc2cd34..1aa0851838 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -38,7 +38,7 @@ { "input": input_data["input"], # (1, 3, 2, 2) "target": F.one_hot(input_data["target"].squeeze(1)).permute(0, 3, 1, 2), # (1, 3, 2, 2) - }, + }, 0.34959, ] ) @@ -52,7 +52,7 @@ ) TEST_CASES.append( [ - {"to_onehot_y": True, "weight": torch.tensor([0.5, 0.1, 0.2]), "reduction": "none"}, input_data, + {"to_onehot_y": True, "weight": torch.tensor([0.5, 0.1, 0.2]), "reduction": "none"}, input_data, np.array([[[[0.01133, 0.35093], [0.18871, 0.08664]], [[0.07019, 0.00227], [0.03774, 0.01733]], [[0.14037, 0.14037], [0.01352, 0.03466]]]]) ] ) From 443c2788dc03e2b6f76cd31f8661417b0892c936 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Sat, 27 May 2023 08:40:00 +0000 Subject: [PATCH 11/11] [MONAI] code formatting Signed-off-by: monai-bot --- tests/test_focal_loss.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index 1aa0851838..46a947ea7c 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -46,14 +46,32 @@ TEST_CASES.append([{"to_onehot_y": True, "alpha": 0.8}, input_data, 0.08423]) TEST_CASES.append( [ - {"to_onehot_y": True, "reduction": "none"}, input_data, - np.array([[[[0.02266, 0.70187], [0.37741, 0.17329]], [[0.70187, 0.02266], [0.37741, 0.17329]], [[0.70187, 0.70187], [0.06757, 0.17329]]]]) + {"to_onehot_y": True, "reduction": "none"}, + input_data, + np.array( + [ + [ + [[0.02266, 0.70187], [0.37741, 0.17329]], + [[0.70187, 0.02266], [0.37741, 0.17329]], + [[0.70187, 0.70187], [0.06757, 0.17329]], + ] + ] + ), ] ) TEST_CASES.append( [ - {"to_onehot_y": True, "weight": torch.tensor([0.5, 0.1, 0.2]), "reduction": "none"}, input_data, - np.array([[[[0.01133, 0.35093], [0.18871, 0.08664]], [[0.07019, 0.00227], [0.03774, 0.01733]], [[0.14037, 0.14037], [0.01352, 0.03466]]]]) + {"to_onehot_y": True, "weight": torch.tensor([0.5, 0.1, 0.2]), "reduction": "none"}, + input_data, + np.array( + [ + [ + [[0.01133, 0.35093], [0.18871, 0.08664]], + [[0.07019, 0.00227], [0.03774, 0.01733]], + [[0.14037, 0.14037], [0.01352, 0.03466]], + ] + ] + ), ] ) TEST_CASES.append([{"to_onehot_y": True, "use_softmax": True}, input_data, 0.16276])
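For reference, a minimal usage sketch of the API this series adds (assuming all eleven patches are applied); the expected values are the ones recorded in TEST_CASES above:

import torch
from monai.losses import FocalLoss

logits = torch.tensor(
    [[[[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]]]]
)  # (1, 3, 2, 2)
target = torch.tensor([[[[0, 1], [2, 0]]]])  # (1, 1, 2, 2)

# default sigmoid focal loss; the integer label map is converted to one-hot internally
print(FocalLoss(to_onehot_y=True)(logits, target))  # ~0.34959

# softmax focal loss, without and with alpha balancing
print(FocalLoss(to_onehot_y=True, use_softmax=True)(logits, target))  # ~0.16276
print(FocalLoss(to_onehot_y=True, alpha=0.8, use_softmax=True)(logits, target))  # ~0.08138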