From 8693adae25e4bc4d80598818c7c19f9fd0127b2b Mon Sep 17 00:00:00 2001 From: Kyle Harrington Date: Wed, 12 Jun 2024 16:13:09 -0400 Subject: [PATCH 1/9] Add alpha parameter to DiceFocalLoss --- monai/losses/dice.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index f1c357d31f..92f629062f 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -811,7 +811,7 @@ class DiceFocalLoss(_Loss): The details of Focal Loss is shown in ``monai.losses.FocalLoss``. ``gamma`` and ``lambda_focal`` are only used for the focal loss. - ``include_background``, ``weight`` and ``reduction`` are used for both losses + ``include_background``, ``weight``, ``reduction``, and ``alpha`` are used for both losses, and other parameters are only used for dice loss. """ @@ -835,6 +835,7 @@ def __init__( gamma: float = 2.0, focal_weight: Sequence[float] | float | int | torch.Tensor | None = None, weight: Sequence[float] | float | int | torch.Tensor | None = None, + alpha: float | None = None, lambda_dice: float = 1.0, lambda_focal: float = 1.0, ) -> None: @@ -867,6 +868,7 @@ def __init__( weight: weights to apply to the voxels of each class. If None no weights are applied. The input can be a single value (same weight for all classes), a sequence of values (the length of the sequence should be the same as the number of classes). + alpha: value of the alpha in the definition of the alpha-balanced Focal loss. The value should be in [0, 1]. Defaults to None. lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0. Defaults to 1.0. lambda_focal: the trade-off weight value for focal loss. The value should be no less than 0.0. @@ -890,7 +892,12 @@ def __init__( weight=weight, ) self.focal = FocalLoss( - include_background=include_background, to_onehot_y=False, gamma=gamma, weight=weight, reduction=reduction + include_background=include_background, + to_onehot_y=False, + gamma=gamma, + weight=weight, + alpha=alpha, + reduction=reduction ) if lambda_dice < 0.0: raise ValueError("lambda_dice should be no less than 0.0.") From 06a2509ba2a0ff774bebd620d544b07720990373 Mon Sep 17 00:00:00 2001 From: Kyle Harrington Date: Wed, 12 Jun 2024 16:21:21 -0400 Subject: [PATCH 2/9] Add a test for alpha usage in DiceFocalLoss --- tests/test_dice_focal_loss.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/test_dice_focal_loss.py b/tests/test_dice_focal_loss.py index 814a174762..6032c176a9 100644 --- a/tests/test_dice_focal_loss.py +++ b/tests/test_dice_focal_loss.py @@ -91,6 +91,28 @@ def test_script(self): test_input = torch.ones(2, 1, 8, 8) test_script_save(loss, test_input, test_input) + def test_result_with_alpha(self): + size = [3, 3, 5, 5] + label = torch.randint(low=0, high=2, size=size) + pred = torch.randn(size) + alpha_values = [0.25, 0.5, 0.75] + for reduction in ["sum", "mean", "none"]: + for weight in [None, torch.tensor([1.0, 1.0, 2.0]), (3, 2.0, 1)]: + common_params = { + "include_background": True, + "to_onehot_y": False, + "reduction": reduction, + "weight": weight, + } + for lambda_focal in [0.5, 1.0, 1.5]: + for alpha in alpha_values: + dice_focal = DiceFocalLoss(gamma=1.0, lambda_focal=lambda_focal, alpha=alpha, **common_params) + dice = DiceLoss(**common_params) + focal = FocalLoss(gamma=1.0, alpha=alpha, **common_params) + result = dice_focal(pred, label) + expected_val = dice(pred, label) + lambda_focal * focal(pred, label) + np.testing.assert_allclose(result, expected_val) + if __name__ == "__main__": unittest.main() From 7eec525a4e5b03d8889ed9fcb96b46f10af137c0 Mon Sep 17 00:00:00 2001 From: Kyle Harrington Date: Wed, 12 Jun 2024 21:28:26 -0400 Subject: [PATCH 3/9] Apply black --- monai/losses/dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 92f629062f..08b9550dfb 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -897,7 +897,7 @@ def __init__( gamma=gamma, weight=weight, alpha=alpha, - reduction=reduction + reduction=reduction, ) if lambda_dice < 0.0: raise ValueError("lambda_dice should be no less than 0.0.") From ec45af206d8a8b67868d08af4022af7c6dc22b28 Mon Sep 17 00:00:00 2001 From: Kyle Harrington Date: Wed, 12 Jun 2024 21:31:48 -0400 Subject: [PATCH 4/9] DCO Remediation Commit for Kyle Harrington I, Kyle Harrington , hereby add my Signed-off-by to this commit: 8693adae25e4bc4d80598818c7c19f9fd0127b2b I, Kyle Harrington , hereby add my Signed-off-by to this commit: 06a2509ba2a0ff774bebd620d544b07720990373 I, Kyle Harrington , hereby add my Signed-off-by to this commit: 7eec525a4e5b03d8889ed9fcb96b46f10af137c0 Signed-off-by: Kyle Harrington --- monai/losses/dice.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 08b9550dfb..96d012c194 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -868,7 +868,8 @@ def __init__( weight: weights to apply to the voxels of each class. If None no weights are applied. The input can be a single value (same weight for all classes), a sequence of values (the length of the sequence should be the same as the number of classes). - alpha: value of the alpha in the definition of the alpha-balanced Focal loss. The value should be in [0, 1]. Defaults to None. + alpha: value of the alpha in the definition of the alpha-balanced Focal loss. The value should be in + [0, 1]. Defaults to None. lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0. Defaults to 1.0. lambda_focal: the trade-off weight value for focal loss. The value should be no less than 0.0. From 10c82c672df10f7ac134c9b78f4d1149289b652f Mon Sep 17 00:00:00 2001 From: Kyle Harrington Date: Tue, 18 Jun 2024 08:54:42 -0400 Subject: [PATCH 5/9] Fix argument ordering, and make alpha test parameterized --- monai/losses/dice.py | 2 +- tests/test_dice_focal_loss.py | 45 +++++++++++++++++++++-------------- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 96d012c194..f1b06eff93 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -835,9 +835,9 @@ def __init__( gamma: float = 2.0, focal_weight: Sequence[float] | float | int | torch.Tensor | None = None, weight: Sequence[float] | float | int | torch.Tensor | None = None, - alpha: float | None = None, lambda_dice: float = 1.0, lambda_focal: float = 1.0, + alpha: float | None = None, ) -> None: """ Args: diff --git a/tests/test_dice_focal_loss.py b/tests/test_dice_focal_loss.py index 6032c176a9..c7a2eb4a0b 100644 --- a/tests/test_dice_focal_loss.py +++ b/tests/test_dice_focal_loss.py @@ -91,28 +91,37 @@ def test_script(self): test_input = torch.ones(2, 1, 8, 8) test_script_save(loss, test_input, test_input) - def test_result_with_alpha(self): + @parameterized.expand([ + ("sum_None_0.5_0.25", "sum", None, 0.5, 0.25), + ("sum_weight_0.5_0.25", "sum", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25), + ("sum_weight_tuple_0.5_0.25", "sum", (3, 2.0, 1), 0.5, 0.25), + ("mean_None_0.5_0.25", "mean", None, 0.5, 0.25), + ("mean_weight_0.5_0.25", "mean", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25), + ("mean_weight_tuple_0.5_0.25", "mean", (3, 2.0, 1), 0.5, 0.25), + ("none_None_0.5_0.25", "none", None, 0.5, 0.25), + ("none_weight_0.5_0.25", "none", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25), + ("none_weight_tuple_0.5_0.25", "none", (3, 2.0, 1), 0.5, 0.25), + ]) + def test_with_alpha(self, name, reduction, weight, lambda_focal, alpha): size = [3, 3, 5, 5] label = torch.randint(low=0, high=2, size=size) pred = torch.randn(size) - alpha_values = [0.25, 0.5, 0.75] - for reduction in ["sum", "mean", "none"]: - for weight in [None, torch.tensor([1.0, 1.0, 2.0]), (3, 2.0, 1)]: - common_params = { - "include_background": True, - "to_onehot_y": False, - "reduction": reduction, - "weight": weight, - } - for lambda_focal in [0.5, 1.0, 1.5]: - for alpha in alpha_values: - dice_focal = DiceFocalLoss(gamma=1.0, lambda_focal=lambda_focal, alpha=alpha, **common_params) - dice = DiceLoss(**common_params) - focal = FocalLoss(gamma=1.0, alpha=alpha, **common_params) - result = dice_focal(pred, label) - expected_val = dice(pred, label) + lambda_focal * focal(pred, label) - np.testing.assert_allclose(result, expected_val) + common_params = { + "include_background": True, + "to_onehot_y": False, + "reduction": reduction, + "weight": weight, + } + + dice_focal = DiceFocalLoss(gamma=1.0, lambda_focal=lambda_focal, alpha=alpha, **common_params) + dice = DiceLoss(**common_params) + focal = FocalLoss(gamma=1.0, alpha=alpha, **common_params) + + result = dice_focal(pred, label) + expected_val = dice(pred, label) + lambda_focal * focal(pred, label) + + np.testing.assert_allclose(result, expected_val, err_msg=f"Failed on case: {name}") if __name__ == "__main__": unittest.main() From 303a257ba7fcadeaceddfca858f6b8662fc2a058 Mon Sep 17 00:00:00 2001 From: Kyle Harrington Date: Tue, 18 Jun 2024 08:56:40 -0400 Subject: [PATCH 6/9] Fix formatting --- monai/losses/dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index f1b06eff93..9817577c6b 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -837,7 +837,7 @@ def __init__( weight: Sequence[float] | float | int | torch.Tensor | None = None, lambda_dice: float = 1.0, lambda_focal: float = 1.0, - alpha: float | None = None, + alpha: float | None = None, ) -> None: """ Args: From f7286e24e77cc13ce2d429bbcd344c7b64692cf1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Jun 2024 12:56:54 +0000 Subject: [PATCH 7/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_dice_focal_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_dice_focal_loss.py b/tests/test_dice_focal_loss.py index c7a2eb4a0b..aa1cce128b 100644 --- a/tests/test_dice_focal_loss.py +++ b/tests/test_dice_focal_loss.py @@ -117,10 +117,10 @@ def test_with_alpha(self, name, reduction, weight, lambda_focal, alpha): dice_focal = DiceFocalLoss(gamma=1.0, lambda_focal=lambda_focal, alpha=alpha, **common_params) dice = DiceLoss(**common_params) focal = FocalLoss(gamma=1.0, alpha=alpha, **common_params) - + result = dice_focal(pred, label) expected_val = dice(pred, label) + lambda_focal * focal(pred, label) - + np.testing.assert_allclose(result, expected_val, err_msg=f"Failed on case: {name}") if __name__ == "__main__": From 0f4747681f0eab4a2f89fcebe9e85afcfdf491fe Mon Sep 17 00:00:00 2001 From: Kyle Harrington Date: Tue, 18 Jun 2024 09:09:31 -0400 Subject: [PATCH 8/9] Fix test formatting with autofix --- tests/test_dice_focal_loss.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/tests/test_dice_focal_loss.py b/tests/test_dice_focal_loss.py index aa1cce128b..f769aac69f 100644 --- a/tests/test_dice_focal_loss.py +++ b/tests/test_dice_focal_loss.py @@ -91,28 +91,25 @@ def test_script(self): test_input = torch.ones(2, 1, 8, 8) test_script_save(loss, test_input, test_input) - @parameterized.expand([ - ("sum_None_0.5_0.25", "sum", None, 0.5, 0.25), - ("sum_weight_0.5_0.25", "sum", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25), - ("sum_weight_tuple_0.5_0.25", "sum", (3, 2.0, 1), 0.5, 0.25), - ("mean_None_0.5_0.25", "mean", None, 0.5, 0.25), - ("mean_weight_0.5_0.25", "mean", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25), - ("mean_weight_tuple_0.5_0.25", "mean", (3, 2.0, 1), 0.5, 0.25), - ("none_None_0.5_0.25", "none", None, 0.5, 0.25), - ("none_weight_0.5_0.25", "none", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25), - ("none_weight_tuple_0.5_0.25", "none", (3, 2.0, 1), 0.5, 0.25), - ]) + @parameterized.expand( + [ + ("sum_None_0.5_0.25", "sum", None, 0.5, 0.25), + ("sum_weight_0.5_0.25", "sum", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25), + ("sum_weight_tuple_0.5_0.25", "sum", (3, 2.0, 1), 0.5, 0.25), + ("mean_None_0.5_0.25", "mean", None, 0.5, 0.25), + ("mean_weight_0.5_0.25", "mean", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25), + ("mean_weight_tuple_0.5_0.25", "mean", (3, 2.0, 1), 0.5, 0.25), + ("none_None_0.5_0.25", "none", None, 0.5, 0.25), + ("none_weight_0.5_0.25", "none", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25), + ("none_weight_tuple_0.5_0.25", "none", (3, 2.0, 1), 0.5, 0.25), + ] + ) def test_with_alpha(self, name, reduction, weight, lambda_focal, alpha): size = [3, 3, 5, 5] label = torch.randint(low=0, high=2, size=size) pred = torch.randn(size) - common_params = { - "include_background": True, - "to_onehot_y": False, - "reduction": reduction, - "weight": weight, - } + common_params = {"include_background": True, "to_onehot_y": False, "reduction": reduction, "weight": weight} dice_focal = DiceFocalLoss(gamma=1.0, lambda_focal=lambda_focal, alpha=alpha, **common_params) dice = DiceLoss(**common_params) @@ -123,5 +120,6 @@ def test_with_alpha(self, name, reduction, weight, lambda_focal, alpha): np.testing.assert_allclose(result, expected_val, err_msg=f"Failed on case: {name}") + if __name__ == "__main__": unittest.main() From a590b494ac416b1ee04fd6beaaedea4b686ea957 Mon Sep 17 00:00:00 2001 From: Kyle Harrington Date: Sun, 23 Jun 2024 16:33:11 -0400 Subject: [PATCH 9/9] DCO Remediation Commit for Kyle Harrington I, Kyle Harrington , hereby add my Signed-off-by to this commit: 10c82c672df10f7ac134c9b78f4d1149289b652f I, Kyle Harrington , hereby add my Signed-off-by to this commit: 303a257ba7fcadeaceddfca858f6b8662fc2a058 I, Kyle Harrington , hereby add my Signed-off-by to this commit: 0f4747681f0eab4a2f89fcebe9e85afcfdf491fe Signed-off-by: Kyle Harrington --- monai/losses/dice.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 9817577c6b..07a38d9572 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -868,13 +868,12 @@ def __init__( weight: weights to apply to the voxels of each class. If None no weights are applied. The input can be a single value (same weight for all classes), a sequence of values (the length of the sequence should be the same as the number of classes). - alpha: value of the alpha in the definition of the alpha-balanced Focal loss. The value should be in - [0, 1]. Defaults to None. lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0. Defaults to 1.0. lambda_focal: the trade-off weight value for focal loss. The value should be no less than 0.0. Defaults to 1.0. - + alpha: value of the alpha in the definition of the alpha-balanced Focal loss. The value should be in + [0, 1]. Defaults to None. """ super().__init__() weight = focal_weight if focal_weight is not None else weight