From 087cf74ee51681b097b433842dfe271d5b8ab3ca Mon Sep 17 00:00:00 2001 From: Zifu Wang Date: Thu, 10 Oct 2024 20:40:18 +0200 Subject: [PATCH 1/9] Modify Dice, Jaccard and Tversky losses Signed-off-by: Zifu Wang --- monai/losses/dice.py | 21 ++++++++++++++++----- monai/losses/tversky.py | 19 +++++++++++-------- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 3f02fae6b8..e6f6bbb6cf 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -17,6 +17,7 @@ import numpy as np import torch +import torch.linalg as LA import torch.nn as nn import torch.nn.functional as F from torch.nn.modules.loss import _Loss @@ -39,8 +40,16 @@ class DiceLoss(_Loss): The `smooth_nr` and `smooth_dr` parameters are values added to the intersection and union components of the inter-over-union calculation to smooth results respectively, these values should be small. - The original paper: Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks forVolumetric - Medical Image Segmentation, 3DV, 2016. + The original papers: + + Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks for Volumetric + Medical Image Segmentation. 3DV 2016. + + Wang, Z. et. al. (2023) Jaccard Metric Losses: Optimizing the Jaccard Index with + Soft Labels. NeurIPS 2023. + + Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with + Soft Labels. MICCAI 2023. """ @@ -174,16 +183,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis - intersection = torch.sum(target * input, dim=reduce_axis) - if self.squared_pred: ground_o = torch.sum(target**2, dim=reduce_axis) pred_o = torch.sum(input**2, dim=reduce_axis) + difference = LA.vector_norm(input - target, ord=2, dim=reduce_axis) ** 2 else: ground_o = torch.sum(target, dim=reduce_axis) pred_o = torch.sum(input, dim=reduce_axis) + difference = LA.vector_norm(input - target, ord=1, dim=reduce_axis) denominator = ground_o + pred_o + intersection = (denominator - difference) / 2 if self.jaccard: denominator = 2.0 * (denominator - intersection) @@ -370,12 +380,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: reduce_axis = [0] + reduce_axis - intersection = torch.sum(target * input, reduce_axis) ground_o = torch.sum(target, reduce_axis) pred_o = torch.sum(input, reduce_axis) + difference = LA.vector_norm(input - target, ord=1, dim=reduce_axis) denominator = ground_o + pred_o + intersection = (denominator - difference) / 2 w = self.w_func(ground_o.float()) infs = torch.isinf(w) diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index 4f22bf84b4..a498882adf 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -15,6 +15,7 @@ from collections.abc import Callable import torch +import torch.linalg as LA from torch.nn.modules.loss import _Loss from monai.networks import one_hot @@ -28,6 +29,9 @@ class TverskyLoss(_Loss): Sadegh et al. (2017) Tversky loss function for image segmentation using 3D fully convolutional deep networks. (https://arxiv.org/abs/1706.05721) + Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with + Soft Labels. MICCAI 2023. + Adapted from: https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/layer/loss_segmentation.py#L631 @@ -134,20 +138,19 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != input.shape: raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") - p0 = input - p1 = 1 - p0 - g0 = target - g1 = 1 - g0 - # reducing only spatial dimensions (not batch nor channels) reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis - tp = torch.sum(p0 * g0, reduce_axis) - fp = self.alpha * torch.sum(p0 * g1, reduce_axis) - fn = self.beta * torch.sum(p1 * g0, reduce_axis) + pred_o = torch.sum(input, reduce_axis) + ground_o = torch.sum(target, reduce_axis) + difference = LA.vector_norm(input - target, ord=1, dim=reduce_axis) + + tp = (pred_o + ground_o - difference) / 2 + fp = self.alpha * (pred_o - tp) + fn = self.beta * (ground_o - tp) numerator = tp + self.smooth_nr denominator = tp + fp + fn + self.smooth_dr From 3f74183c5776d4746229006acad5b52c8c9df6ec Mon Sep 17 00:00:00 2001 From: Zifu Wang Date: Mon, 21 Oct 2024 10:06:28 +0200 Subject: [PATCH 2/9] Add helper function --- monai/losses/dice.py | 44 +++++++++++++++--------------- monai/losses/tversky.py | 15 +++++------ monai/losses/utils.py | 60 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 31 deletions(-) create mode 100644 monai/losses/utils.py diff --git a/monai/losses/dice.py b/monai/losses/dice.py index e6f6bbb6cf..ec21baaa85 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -17,13 +17,13 @@ import numpy as np import torch -import torch.linalg as LA import torch.nn as nn import torch.nn.functional as F from torch.nn.modules.loss import _Loss from monai.losses.focal_loss import FocalLoss from monai.losses.spatial_mask import MaskedLoss +from monai.losses.utils import compute_tp_fp_fn from monai.networks import one_hot from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after @@ -67,6 +67,7 @@ def __init__( smooth_dr: float = 1e-5, batch: bool = False, weight: Sequence[float] | float | int | torch.Tensor | None = None, + soft_label: bool = False, ) -> None: """ Args: @@ -98,6 +99,7 @@ def __init__( of the sequence should be the same as the number of classes. If not ``include_background``, the number of classes should not include the background category class 0). The value/values should be no less than 0. Defaults to None. + soft_label: whether the target contains non-binary values or not Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -123,6 +125,7 @@ def __init__( weight = torch.as_tensor(weight) if weight is not None else None self.register_buffer("class_weight", weight) self.class_weight: None | torch.Tensor + self.soft_label = soft_label def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -183,22 +186,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis - if self.squared_pred: - ground_o = torch.sum(target**2, dim=reduce_axis) - pred_o = torch.sum(input**2, dim=reduce_axis) - difference = LA.vector_norm(input - target, ord=2, dim=reduce_axis) ** 2 - else: - ground_o = torch.sum(target, dim=reduce_axis) - pred_o = torch.sum(input, dim=reduce_axis) - difference = LA.vector_norm(input - target, ord=1, dim=reduce_axis) - - denominator = ground_o + pred_o - intersection = (denominator - difference) / 2 - - if self.jaccard: - denominator = 2.0 * (denominator - intersection) + ord = 2 if self.squared_pred else 1 + tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, ord, self.soft_label) + if not self.jaccard: + fp *= 0.5 + fn *= 0.5 + numerator = 2 * tp + self.smooth_nr + denominator = 2 * (tp + fp + fn) + self.smooth_dr - f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr) + f = 1 - numerator / denominator num_of_classes = target.shape[1] if self.class_weight is not None and num_of_classes != 1: @@ -282,6 +278,7 @@ def __init__( smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, batch: bool = False, + soft_label: bool = False, ) -> None: """ Args: @@ -305,6 +302,7 @@ def __init__( batch: whether to sum the intersection and union areas over the batch dimension before the dividing. Defaults to False, intersection over union is computed from each item in the batch. If True, the class-weighted intersection and union areas are first summed across the batches. + soft_label: whether the target contains non-binary values or not Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -329,6 +327,7 @@ def __init__( self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) self.batch = batch + self.soft_label = soft_label def w_func(self, grnd): if self.w_type == str(Weight.SIMPLE): @@ -381,13 +380,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if self.batch: reduce_axis = [0] + reduce_axis - ground_o = torch.sum(target, reduce_axis) - pred_o = torch.sum(input, reduce_axis) - difference = LA.vector_norm(input - target, ord=1, dim=reduce_axis) - - denominator = ground_o + pred_o - intersection = (denominator - difference) / 2 + tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label) + fp *= 0.5 + fn *= 0.5 + denominator = 2 * (tp + fp + fn) + ground_o = torch.sum(target, reduce_axis) w = self.w_func(ground_o.float()) infs = torch.isinf(w) if self.batch: @@ -399,7 +397,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: w = w + infs * max_values final_reduce_dim = 0 if self.batch else 1 - numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr + numer = 2.0 * (tp * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr f: torch.Tensor = 1.0 - (numer / denom) diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index a498882adf..d21aa76537 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -15,9 +15,9 @@ from collections.abc import Callable import torch -import torch.linalg as LA from torch.nn.modules.loss import _Loss +from monai.losses.utils import compute_tp_fp_fn from monai.networks import one_hot from monai.utils import LossReduction @@ -50,6 +50,7 @@ def __init__( smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, batch: bool = False, + soft_label: bool = False, ) -> None: """ Args: @@ -74,6 +75,7 @@ def __init__( batch: whether to sum the intersection and union areas over the batch dimension before the dividing. Defaults to False, a Dice loss value is computed independently from each item in the batch before any `reduction`. + soft_label: whether the target contains non-binary values or not Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -97,6 +99,7 @@ def __init__( self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) self.batch = batch + self.soft_label = soft_label def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -144,13 +147,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis - pred_o = torch.sum(input, reduce_axis) - ground_o = torch.sum(target, reduce_axis) - difference = LA.vector_norm(input - target, ord=1, dim=reduce_axis) - - tp = (pred_o + ground_o - difference) / 2 - fp = self.alpha * (pred_o - tp) - fn = self.beta * (ground_o - tp) + tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label, False) + fp *= self.alpha + fn *= self.beta numerator = tp + self.smooth_nr denominator = tp + fp + fn + self.smooth_dr diff --git a/monai/losses/utils.py b/monai/losses/utils.py new file mode 100644 index 0000000000..646ebd6d62 --- /dev/null +++ b/monai/losses/utils.py @@ -0,0 +1,60 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings + +import torch +import torch.linalg as LA + + +def compute_tp_fp_fn( + input: torch.Tensor, + target: torch.Tensor, + reduce_axis: list[int], + ord: int, + soft_label: bool, + decoupled: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Adapted from: + https://github.com/zifuwanggg/JDTLosses + """ + if torch.unique(target).shape[0] > 2 and not soft_label: + warnings.warn("soft labels are used, but `soft_label == False`.") + + # the original implementation that is erroneous with soft labels + if ord == 1 and not soft_label: + tp = torch.sum(input * target, dim=reduce_axis) + # the original implementation of Dice and Jaccard loss + if decoupled: + fp = torch.sum(input, dim=reduce_axis) - tp + fn = torch.sum(target, dim=reduce_axis) - tp + # the original implementation of Tversky loss + else: + fp = torch.sum(input * (1 - target), dim=reduce_axis) + fn = torch.sum((1 - input) * target, dim=reduce_axis) + else: + pred_o = LA.vector_norm(input, ord=ord, dim=reduce_axis) + ground_o = LA.vector_norm(target, ord=ord, dim=reduce_axis) + difference = LA.vector_norm(input - target, ord=ord, dim=reduce_axis) + + if ord > 1: + pred_o = torch.pow(pred_o, exponent=ord) + ground_o = torch.pow(ground_o, exponent=ord) + difference = torch.pow(difference, exponent=ord) + + tp = (pred_o + ground_o - difference) / 2 + fp = pred_o - tp + fn = ground_o - tp + + return tp, fp, fn From 3e4f714461fd03653b9ffea194e3d5919edf884c Mon Sep 17 00:00:00 2001 From: Zifu Wang Date: Tue, 22 Oct 2024 11:05:51 +0200 Subject: [PATCH 3/9] Fix mypy error Signed-off-by: Zifu Wang --- monai/losses/dice.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index ec21baaa85..76fc87127d 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -194,7 +194,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: numerator = 2 * tp + self.smooth_nr denominator = 2 * (tp + fp + fn) + self.smooth_dr - f = 1 - numerator / denominator + f: torch.Tensor = 1 - numerator / denominator num_of_classes = target.shape[1] if self.class_weight is not None and num_of_classes != 1: @@ -413,7 +413,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: else: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') - return f + return f # type: ignore[arg-type] class GeneralizedWassersteinDiceLoss(_Loss): From f3ab6797b0a6c8e39f43161bc6b08cf6794661c0 Mon Sep 17 00:00:00 2001 From: Zifu Wang Date: Tue, 22 Oct 2024 11:07:51 +0200 Subject: [PATCH 4/9] Fix mypy error Signed-off-by: Zifu Wang --- monai/losses/dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 76fc87127d..92e9e0c684 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -413,7 +413,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: else: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') - return f # type: ignore[arg-type] + return f class GeneralizedWassersteinDiceLoss(_Loss): From a778e583d32082480811a4bec3bed9a3622b5871 Mon Sep 17 00:00:00 2001 From: Zifu Wang Date: Sat, 30 Nov 2024 21:30:29 +0100 Subject: [PATCH 5/9] Add test cases --- tests/test_dice_loss.py | 16 ++++++++++++++++ tests/test_generalized_dice_loss.py | 16 ++++++++++++++++ tests/test_tversky_loss.py | 16 ++++++++++++++++ 3 files changed, 48 insertions(+) diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index 14aa6ec241..cea6ccf113 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -34,6 +34,22 @@ }, 0.416657, ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + }, + 0.0, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + }, + 0.307773, + ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0}, { diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index 5738f4a089..9706c2e746 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -34,6 +34,22 @@ }, 0.416597, ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + }, + 0.0, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + }, + 0.307748, + ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": False, "to_onehot_y": True, "smooth_nr": 0.0, "smooth_dr": 0.0}, { diff --git a/tests/test_tversky_loss.py b/tests/test_tversky_loss.py index 0365503ea2..73a841a55d 100644 --- a/tests/test_tversky_loss.py +++ b/tests/test_tversky_loss.py @@ -34,6 +34,22 @@ }, 0.416657, ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + }, + 0.0, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + }, + 0.307773, + ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0}, { From aeef0afccd762f9b0f90b4b42ac12db8fb8b9031 Mon Sep 17 00:00:00 2001 From: Zifu Wang Date: Sat, 30 Nov 2024 21:33:30 +0100 Subject: [PATCH 6/9] Modify args description and remove check --- monai/losses/dice.py | 6 ++++-- monai/losses/tversky.py | 3 ++- monai/losses/utils.py | 2 -- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 92e9e0c684..4108820bec 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -99,7 +99,8 @@ def __init__( of the sequence should be the same as the number of classes. If not ``include_background``, the number of classes should not include the background category class 0). The value/values should be no less than 0. Defaults to None. - soft_label: whether the target contains non-binary values or not + soft_label: whether the target contains non-binary values (soft labels) or not. + If True a soft label formulation of the loss will be used. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -302,7 +303,8 @@ def __init__( batch: whether to sum the intersection and union areas over the batch dimension before the dividing. Defaults to False, intersection over union is computed from each item in the batch. If True, the class-weighted intersection and union areas are first summed across the batches. - soft_label: whether the target contains non-binary values or not + soft_label: whether the target contains non-binary values (soft labels) or not. + If True a soft label formulation of the loss will be used. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index d21aa76537..154f34c526 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -75,7 +75,8 @@ def __init__( batch: whether to sum the intersection and union areas over the batch dimension before the dividing. Defaults to False, a Dice loss value is computed independently from each item in the batch before any `reduction`. - soft_label: whether the target contains non-binary values or not + soft_label: whether the target contains non-binary values (soft labels) or not. + If True a soft label formulation of the loss will be used. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. diff --git a/monai/losses/utils.py b/monai/losses/utils.py index 646ebd6d62..88ae3059cf 100644 --- a/monai/losses/utils.py +++ b/monai/losses/utils.py @@ -29,8 +29,6 @@ def compute_tp_fp_fn( Adapted from: https://github.com/zifuwanggg/JDTLosses """ - if torch.unique(target).shape[0] > 2 and not soft_label: - warnings.warn("soft labels are used, but `soft_label == False`.") # the original implementation that is erroneous with soft labels if ord == 1 and not soft_label: From 2f77b1ff3d851f923222a8399df1a456dc9d0b21 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 30 Nov 2024 20:33:59 +0000 Subject: [PATCH 7/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/losses/utils.py b/monai/losses/utils.py index 88ae3059cf..40d9264d61 100644 --- a/monai/losses/utils.py +++ b/monai/losses/utils.py @@ -11,7 +11,6 @@ from __future__ import annotations -import warnings import torch import torch.linalg as LA From 58c5396a341edfebaf109a24bfb6753bb2b7a171 Mon Sep 17 00:00:00 2001 From: Zifu Wang Date: Sat, 30 Nov 2024 21:46:45 +0100 Subject: [PATCH 8/9] Fix code format --- monai/losses/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/losses/utils.py b/monai/losses/utils.py index 40d9264d61..8b460f74b3 100644 --- a/monai/losses/utils.py +++ b/monai/losses/utils.py @@ -11,7 +11,6 @@ from __future__ import annotations - import torch import torch.linalg as LA From 185d2e142122ad0de2a131a3cb74831eb1db2238 Mon Sep 17 00:00:00 2001 From: Zifu Wang Date: Mon, 2 Dec 2024 11:46:14 +0100 Subject: [PATCH 9/9] DCO Remediation Commit for Zifu Wang I, Zifu Wang , hereby add my Signed-off-by to this commit: 3f74183c5776d4746229006acad5b52c8c9df6ec I, Zifu Wang , hereby add my Signed-off-by to this commit: a778e583d32082480811a4bec3bed9a3622b5871 I, Zifu Wang , hereby add my Signed-off-by to this commit: aeef0afccd762f9b0f90b4b42ac12db8fb8b9031 I, Zifu Wang , hereby add my Signed-off-by to this commit: 58c5396a341edfebaf109a24bfb6753bb2b7a171 Signed-off-by: Zifu Wang --- monai/losses/utils.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/monai/losses/utils.py b/monai/losses/utils.py index 8b460f74b3..782fd9c9c2 100644 --- a/monai/losses/utils.py +++ b/monai/losses/utils.py @@ -24,6 +24,16 @@ def compute_tp_fp_fn( decoupled: bool = True, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ + Args: + input: the shape should be BNH[WD], where N is the number of classes. + target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. + reduce_axis: the axis to be reduced. + ord: the order of the vector norm. + soft_label: whether the target contains non-binary values (soft labels) or not. + If True a soft label formulation of the loss will be used. + decoupled: whether the input and the target should be decoupled when computing fp and fn. + Only for the original implementation when soft_label is False. + Adapted from: https://github.com/zifuwanggg/JDTLosses """ @@ -39,6 +49,8 @@ def compute_tp_fp_fn( else: fp = torch.sum(input * (1 - target), dim=reduce_axis) fn = torch.sum((1 - input) * target, dim=reduce_axis) + # the new implementation that is correct with soft labels + # and it is identical to the original implementation with hard labels else: pred_o = LA.vector_norm(input, ord=ord, dim=reduce_axis) ground_o = LA.vector_norm(target, ord=ord, dim=reduce_axis)