diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index 98c1a071b6..28d1c0cdc9 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -234,9 +234,8 @@ def sigmoid_focal_loss(
     """
     # computing binary cross entropy with logits
     # equivalent to F.binary_cross_entropy_with_logits(input, target, reduction='none')
-    # see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231
-    max_val = (-input).clamp(min=0)
-    loss: torch.Tensor = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
+    # see also https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L363
+    loss: torch.Tensor = input - input * target - F.logsigmoid(input)

     # sigmoid(-i) if t==1; sigmoid(i) if t==0 <=>
     # 1-sigmoid(i) if t==1; sigmoid(i) if t==0 <=>
diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py
index de8d625058..0bb8a078ae 100644
--- a/tests/test_focal_loss.py
+++ b/tests/test_focal_loss.py
@@ -132,7 +132,7 @@ def test_consistency_with_cross_entropy_2d_no_reduction(self):
             error = np.abs(a - b)
             max_error = np.maximum(error, max_error)

-        assert np.allclose(max_error, 0)
+        assert np.allclose(max_error, 0, atol=1e-6)

     def test_consistency_with_cross_entropy_2d_onehot_label(self):
         """For gamma=0 the focal loss reduces to the cross entropy loss"""
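
The diff replaces the hand-rolled log-sum-exp expression with `F.logsigmoid`, which computes the same quantity with better numerical stability. The identity behind it: since `log(sigmoid(-x)) = -x + log(sigmoid(x))`, the per-element binary cross entropy `-t*log(sigmoid(x)) - (1-t)*log(sigmoid(-x))` simplifies to `x - x*t - logsigmoid(x)`. As a quick sanity check (a standalone sketch, not part of the PR's test suite), the new formulation can be compared elementwise against `F.binary_cross_entropy_with_logits`, including at large logit magnitudes where the old `clamp`/`exp`/`log` expression could lose precision:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# large-magnitude logits stress numerical stability
logits = torch.randn(4, 5) * 50
target = torch.randint(0, 2, (4, 5)).float()

# new formulation from the diff
new_loss = logits - logits * target - F.logsigmoid(logits)
# PyTorch's reference implementation
ref_loss = F.binary_cross_entropy_with_logits(logits, target, reduction="none")

assert torch.allclose(new_loss, ref_loss, atol=1e-6)
print("max abs diff:", (new_loss - ref_loss).abs().max().item())
```

The loosened test tolerance (`atol=1e-6`) in `test_focal_loss.py` matches this: the two formulations agree to floating-point precision, not bit-exactly.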