diff --git a/monai/losses/dice.py b/monai/losses/dice.py index f1c357d31f..07a38d9572 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -811,7 +811,7 @@ class DiceFocalLoss(_Loss): The details of Focal Loss is shown in ``monai.losses.FocalLoss``. ``gamma`` and ``lambda_focal`` are only used for the focal loss. - ``include_background``, ``weight`` and ``reduction`` are used for both losses + ``include_background``, ``weight``, ``reduction``, and ``alpha`` are used for both losses, and other parameters are only used for dice loss. """ @@ -837,6 +837,7 @@ def __init__( weight: Sequence[float] | float | int | torch.Tensor | None = None, lambda_dice: float = 1.0, lambda_focal: float = 1.0, + alpha: float | None = None, ) -> None: """ Args: @@ -871,7 +872,8 @@ def __init__( Defaults to 1.0. lambda_focal: the trade-off weight value for focal loss. The value should be no less than 0.0. Defaults to 1.0. - + alpha: value of the alpha in the definition of the alpha-balanced Focal loss. The value should be in + [0, 1]. Defaults to None. """ super().__init__() weight = focal_weight if focal_weight is not None else weight @@ -890,7 +892,12 @@ def __init__( weight=weight, ) self.focal = FocalLoss( - include_background=include_background, to_onehot_y=False, gamma=gamma, weight=weight, reduction=reduction + include_background=include_background, + to_onehot_y=False, + gamma=gamma, + weight=weight, + alpha=alpha, + reduction=reduction, ) if lambda_dice < 0.0: raise ValueError("lambda_dice should be no less than 0.0.") diff --git a/tests/test_dice_focal_loss.py b/tests/test_dice_focal_loss.py index 814a174762..f769aac69f 100644 --- a/tests/test_dice_focal_loss.py +++ b/tests/test_dice_focal_loss.py @@ -91,6 +91,35 @@ def test_script(self): test_input = torch.ones(2, 1, 8, 8) test_script_save(loss, test_input, test_input) + @parameterized.expand( + [ + ("sum_None_0.5_0.25", "sum", None, 0.5, 0.25), + ("sum_weight_0.5_0.25", "sum", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25), + ("sum_weight_tuple_0.5_0.25", "sum", (3, 2.0, 1), 0.5, 0.25), + ("mean_None_0.5_0.25", "mean", None, 0.5, 0.25), + ("mean_weight_0.5_0.25", "mean", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25), + ("mean_weight_tuple_0.5_0.25", "mean", (3, 2.0, 1), 0.5, 0.25), + ("none_None_0.5_0.25", "none", None, 0.5, 0.25), + ("none_weight_0.5_0.25", "none", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25), + ("none_weight_tuple_0.5_0.25", "none", (3, 2.0, 1), 0.5, 0.25), + ] + ) + def test_with_alpha(self, name, reduction, weight, lambda_focal, alpha): + size = [3, 3, 5, 5] + label = torch.randint(low=0, high=2, size=size) + pred = torch.randn(size) + + common_params = {"include_background": True, "to_onehot_y": False, "reduction": reduction, "weight": weight} + + dice_focal = DiceFocalLoss(gamma=1.0, lambda_focal=lambda_focal, alpha=alpha, **common_params) + dice = DiceLoss(**common_params) + focal = FocalLoss(gamma=1.0, alpha=alpha, **common_params) + + result = dice_focal(pred, label) + expected_val = dice(pred, label) + lambda_focal * focal(pred, label) + + np.testing.assert_allclose(result, expected_val, err_msg=f"Failed on case: {name}") + if __name__ == "__main__": unittest.main()