From a9a2d3a0139ff7f2c54fe762ea1ccfa841ab6e8a Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Tue, 12 Mar 2024 21:56:54 +0800
Subject: [PATCH 1/2] fix #7533

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 monai/losses/focal_loss.py | 5 ++---
 tests/test_focal_loss.py   | 2 +-
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index 98c1a071b6..0b1e7fba06 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -234,9 +234,8 @@ def sigmoid_focal_loss(
     """
     # computing binary cross entropy with logits
     # equivalent to F.binary_cross_entropy_with_logits(input, target, reduction='none')
-    # see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231
-    max_val = (-input).clamp(min=0)
-    loss: torch.Tensor = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
+    # see also https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L363
+    loss: torch.Tensor = input - input * target - F.logsigmoid(input)
 
     # sigmoid(-i) if t==1; sigmoid(i) if t==0 <=>
     # 1-sigmoid(i) if t==1; sigmoid(i) if t==0 <=>
diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py
index de8d625058..0bb8a078ae 100644
--- a/tests/test_focal_loss.py
+++ b/tests/test_focal_loss.py
@@ -132,7 +132,7 @@ def test_consistency_with_cross_entropy_2d_no_reduction(self):
             error = np.abs(a - b)
             max_error = np.maximum(error, max_error)
 
-        assert np.allclose(max_error, 0)
+        assert np.allclose(max_error, 0, atol=1e-6)
 
     def test_consistency_with_cross_entropy_2d_onehot_label(self):
         """For gamma=0 the focal loss reduces to the cross entropy loss"""

From 9e22a553174b6c154210a2654df2ceeb9203f4cb Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Tue, 12 Mar 2024 22:11:35 +0800
Subject: [PATCH 2/2] fix flake8

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 monai/losses/focal_loss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index 0b1e7fba06..28d1c0cdc9 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -235,7 +235,7 @@ def sigmoid_focal_loss(
     # computing binary cross entropy with logits
     # equivalent to F.binary_cross_entropy_with_logits(input, target, reduction='none')
    # see also https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L363
-    loss: torch.Tensor = input - input * target - F.logsigmoid(input) 
+    loss: torch.Tensor = input - input * target - F.logsigmoid(input)
 
     # sigmoid(-i) if t==1; sigmoid(i) if t==0 <=>
     # 1-sigmoid(i) if t==1; sigmoid(i) if t==0 <=>
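
For reference, the identity PATCH 1/2 relies on can be checked standalone: `input - input * target - F.logsigmoid(input)` is a numerically stable form of binary cross entropy with logits, so it should match `F.binary_cross_entropy_with_logits(input, target, reduction="none")`. The sketch below is not part of the patch; the shapes, seed, and scale factor are illustrative choices. It also recomputes the removed `max_val` formulation, whose rounding-level disagreement with the new form appears to be why the test's `np.allclose` now carries `atol=1e-6`.

```python
# Standalone sanity check for the patches above; not part of the MONAI diff.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
input = torch.randn(4, 8) * 20.0              # include large-magnitude logits
target = torch.randint(0, 2, (4, 8)).float()  # binary targets

# New expression from PATCH 1/2: stable BCE-with-logits via logsigmoid.
new_loss = input - input * target - F.logsigmoid(input)

# Reference the in-code comment claims equivalence with.
ref = F.binary_cross_entropy_with_logits(input, target, reduction="none")
assert torch.allclose(new_loss, ref, atol=1e-6)

# Removed formulation, recomputed for comparison. Mathematically identical,
# since max(-x, 0) + log(exp(-max) + exp(-x - max)) == -logsigmoid(x),
# but the two forms can differ by floating-point rounding.
max_val = (-input).clamp(min=0)
old_loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
print((new_loss - old_loss).abs().max())  # rounding-level difference between the two forms
```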