From db73f0de472fe28bdead6845d4c891de02d82edb Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 28 Oct 2020 15:37:06 +0000 Subject: [PATCH 1/2] update dice and the relevant loss Signed-off-by: Wenqi Li --- monai/losses/dice.py | 75 +++++++++++++++++++++-------- monai/losses/tversky.py | 13 +++-- tests/test_dice_loss.py | 60 ++++++++++++++--------- tests/test_generalized_dice_loss.py | 63 +++++++++++++++--------- tests/test_masked_dice_loss.py | 44 +++++++++-------- tests/test_seg_loss_integration.py | 3 +- tests/test_tversky_loss.py | 53 +++++++++++--------- 7 files changed, 197 insertions(+), 114 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 0d8bf96764..1b21a709aa 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -26,12 +26,12 @@ class DiceLoss(_Loss): Compute average Dice loss between two tensors. It can support both multi-classes and multi-labels tasks. Input logits `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]). Axis N of `input` is expected to have logit predictions for each class rather than being image channels, - while the same axis of `target` can be 1 or N (one-hot format). The `smooth` parameter is a value added to the - intersection and union components of the inter-over-union calculation to smooth results and prevent divide by 0, - this value should be small. The `include_background` class attribute can be set to False for an instance of - DiceLoss to exclude the first category (channel index 0) which is by convention assumed to be background. - If the non-background segmentations are small compared to the total image size they can get overwhelmed by - the signal from the background so excluding it in such cases helps convergence. + while the same axis of `target` can be 1 or N (one-hot format). The `smooth_nr` and `smooth_dr` parameters are + values added to the intersection and union components of the inter-over-union calculation to smooth results + respectively, these values should be small. The `include_background` class attribute can be set to False for + an instance of DiceLoss to exclude the first category (channel index 0) which is by convention assumed to be + background. If the non-background segmentations are small compared to the total image size they can get + overwhelmed by the signal from the background so excluding it in such cases helps convergence. Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks forVolumetric Medical Image Segmentation, 3DV, 2016. @@ -47,6 +47,9 @@ def __init__( squared_pred: bool = False, jaccard: bool = False, reduction: Union[LossReduction, str] = LossReduction.MEAN, + smooth_nr: float = 1e-5, + smooth_dr: float = 1e-5, + batch: bool = False, ) -> None: """ Args: @@ -66,6 +69,12 @@ def __init__( - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. + smooth_nr: a small constant added to the numerator to avoid zero. + smooth_dr: a small constant added to the denominator to avoid nan. + batch: whether to sum the intersection and union areas over the batch dimension before the dividing. + Defaults to False, a Dice loss value is computed independently from each item in the batch + before any `reduction`. + Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. @@ -84,13 +93,15 @@ def __init__( self.other_act = other_act self.squared_pred = squared_pred self.jaccard = jaccard + self.smooth_nr = float(smooth_nr) + self.smooth_dr = float(smooth_dr) + self.batch = batch - def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-5) -> torch.Tensor: + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: input: the shape should be BNH[WD]. target: the shape should be BNH[WD]. - smooth: a small constant to avoid nan. Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. @@ -129,6 +140,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e- # reducing only spatial dimensions (not batch nor channels) reduce_axis = list(range(2, len(input.shape))) + if self.batch: + # reducing spatial dimensions and batch + reduce_axis = [0] + reduce_axis + intersection = torch.sum(target * input, dim=reduce_axis) if self.squared_pred: @@ -143,7 +158,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e- if self.jaccard: denominator = 2.0 * (denominator - intersection) - f: torch.Tensor = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth) + f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr) if self.reduction == LossReduction.MEAN.value: f = torch.mean(f) # the batch and channel average @@ -167,14 +182,11 @@ class MaskedDiceLoss(DiceLoss): """ - def forward( - self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-5, mask: Optional[torch.Tensor] = None - ) -> torch.Tensor: + def forward(self, input: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: """ Args: input: the shape should be BNH[WD]. target: the shape should be BNH[WD]. - smooth: a small constant to avoid nan. mask: the shape should B1H[WD] or 11H[WD]. """ if mask is not None: @@ -195,7 +207,7 @@ def forward( else: warnings.warn("no mask value specified for the MaskedDiceLoss.") - return super().forward(input=input, target=target, smooth=smooth) + return super().forward(input=input, target=target) class GeneralizedDiceLoss(_Loss): @@ -218,6 +230,9 @@ def __init__( other_act: Optional[Callable] = None, w_type: Union[Weight, str] = Weight.SQUARE, reduction: Union[LossReduction, str] = LossReduction.MEAN, + smooth_nr: float = 1e-5, + smooth_dr: float = 1e-5, + batch: bool = False, ) -> None: """ Args: @@ -237,6 +252,10 @@ def __init__( - ``"none"``: no reduction will be applied. - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. + smooth_nr: a small constant added to the numerator to avoid zero. + smooth_dr: a small constant added to the denominator to avoid nan. + batch: whether to sum the intersection and union areas over the batch dimension before the dividing. + Defaults to False, intersection over union is computed from each item in the batch. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -262,12 +281,15 @@ def __init__( elif w_type == Weight.SQUARE: self.w_func = lambda x: torch.reciprocal(x * x) - def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-5) -> torch.Tensor: + self.smooth_nr = float(smooth_nr) + self.smooth_dr = float(smooth_dr) + self.batch = batch + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: input: the shape should be BNH[WD]. target: the shape should be BNH[WD]. - smooth: a small constant to avoid nan. Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. @@ -305,6 +327,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e- # reducing only spatial dimensions (not batch nor channels) reduce_axis = list(range(2, len(input.shape))) + if self.batch: + reduce_axis = [0] + reduce_axis intersection = torch.sum(target * input, reduce_axis) ground_o = torch.sum(target, reduce_axis) @@ -318,7 +342,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e- b[infs] = 0.0 b[infs] = torch.max(b) - f: torch.Tensor = 1.0 - (2.0 * (intersection * w).sum(1) + smooth) / ((denominator * w).sum(1) + smooth) + f: torch.Tensor = 1.0 - (2.0 * (intersection * w).sum(1) + self.smooth_nr) / ( + (denominator * w).sum(1) + self.smooth_dr + ) if self.reduction == LossReduction.MEAN.value: f = torch.mean(f) # the batch and channel average @@ -369,7 +395,11 @@ class GeneralizedWassersteinDiceLoss(_Loss): """ def __init__( - self, dist_matrix: Union[np.ndarray, torch.Tensor], reduction: Union[LossReduction, str] = LossReduction.MEAN + self, + dist_matrix: Union[np.ndarray, torch.Tensor], + reduction: Union[LossReduction, str] = LossReduction.MEAN, + smooth_nr: float = 1e-5, + smooth_dr: float = 1e-5, ) -> None: """ Args: @@ -377,6 +407,8 @@ def __init__( between the classes. It must have dimension C x C where C is the number of classes. reduction: str; reduction mode. + smooth_nr: a small constant added to the numerator to avoid zero. + smooth_dr: a small constant added to the denominator to avoid nan. Raises: ValueError: When ``dist_matrix`` is not a square matrix. @@ -393,13 +425,14 @@ def __init__( if torch.max(self.m) != 1: self.m = self.m / torch.max(self.m) self.num_classes = self.m.size(0) + self.smooth_nr = float(smooth_nr) + self.smooth_dr = float(smooth_dr) - def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-5) -> torch.Tensor: + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: input: the shape should be BNH[WD]. target: the shape should be BNH[WD]. - smooth: a small constant to avoid nan. """ # Aggregate spatial dimensions @@ -418,7 +451,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e- denom = self.compute_denominator(alpha, flat_target, wass_dist_map) # Compute and return the final loss - wass_dice: torch.Tensor = (2.0 * true_pos + smooth) / (denom + smooth) + wass_dice: torch.Tensor = (2.0 * true_pos + self.smooth_nr) / (denom + self.smooth_dr) wass_dice_loss: torch.Tensor = 1.0 - wass_dice return wass_dice_loss.mean() diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index aa9b28f1d2..3ddc9c0e17 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -42,6 +42,8 @@ def __init__( alpha: float = 0.5, beta: float = 0.5, reduction: Union[LossReduction, str] = LossReduction.MEAN, + smooth_nr: float = 1e-5, + smooth_dr: float = 1e-5, ) -> None: """ Args: @@ -60,6 +62,8 @@ def __init__( - ``"none"``: no reduction will be applied. - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. + smooth_nr: a small constant added to the numerator to avoid zero. + smooth_dr: a small constant added to the denominator to avoid nan. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -80,13 +84,14 @@ def __init__( self.other_act = other_act self.alpha = alpha self.beta = beta + self.smooth_nr = float(smooth_nr) + self.smooth_dr = float(smooth_dr) - def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-5) -> torch.Tensor: + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: input: the shape should be BNH[WD]. target: the shape should be BNH[WD]. - smooth: a small constant to avoid nan. Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. @@ -135,8 +140,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e- fp = self.alpha * torch.sum(p0 * g1, reduce_axis) fn = self.beta * torch.sum(p1 * g0, reduce_axis) - numerator = tp + smooth - denominator = tp + fp + fn + smooth + numerator = tp + self.smooth_nr + denominator = tp + fp + fn + self.smooth_dr score: torch.Tensor = 1.0 - numerator / denominator diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index 7ac863e796..6c2b3a4579 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -19,74 +19,80 @@ TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) - {"include_background": True, "sigmoid": True}, + {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - "smooth": 1e-6, }, 0.307576, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) - {"include_background": True, "sigmoid": True}, + {"include_background": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), - "smooth": 1e-4, }, 0.416657, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": False, "to_onehot_y": True}, + {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0}, { "input": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]), "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]), - "smooth": 0.0, }, 0.0, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "sigmoid": True}, + {"include_background": True, "to_onehot_y": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, 0.435050, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "sigmoid": True, "reduction": "none"}, + { + "include_background": True, + "to_onehot_y": True, + "sigmoid": True, + "reduction": "none", + "smooth_nr": 1e-4, + "smooth_dr": 1e-4, + }, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, [[0.296529, 0.415136], [0.599976, 0.428559]], ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "softmax": True}, + {"include_background": True, "to_onehot_y": True, "softmax": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, 0.383713, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "softmax": True, "reduction": "sum"}, + { + "include_background": True, + "to_onehot_y": True, + "softmax": True, + "reduction": "sum", + "smooth_nr": 1e-4, + "smooth_dr": 1e-4, + }, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, 1.534853, ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) - {"include_background": True, "sigmoid": True}, + {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - "smooth": 1e-6, }, 0.307576, ], @@ -95,7 +101,6 @@ { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - "smooth": 1e-5, }, 0.178337, ], @@ -104,28 +109,39 @@ { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - "smooth": 1e-5, }, 0.470451, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) - {"include_background": True, "other_act": torch.tanh}, + {"include_background": True, "other_act": torch.tanh, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), - "smooth": 1e-4, }, 0.999963, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "other_act": lambda x: torch.log_softmax(x, dim=1)}, + { + "include_background": True, + "to_onehot_y": True, + "other_act": lambda x: torch.log_softmax(x, dim=1), + "smooth_nr": 1e-4, + "smooth_dr": 1e-4, + }, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, -8.522593, ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "batch": True}, + { + "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + }, + 0.774718, + ], ] diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index d6cdc40493..3ddc68ae4e 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -19,110 +19,125 @@ TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) - {"include_background": True, "sigmoid": True}, + {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - "smooth": 1e-6, }, 0.307576, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) - {"include_background": True, "sigmoid": True}, + {"include_background": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), - "smooth": 1e-4, }, 0.416597, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": False, "to_onehot_y": True}, + {"include_background": False, "to_onehot_y": True, "smooth_nr": 0.0, "smooth_dr": 0.0}, { "input": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]), "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]), - "smooth": 0.0, }, 0.0, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "sigmoid": True}, + {"include_background": True, "to_onehot_y": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, 0.469964, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "softmax": True}, + {"include_background": True, "to_onehot_y": True, "softmax": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, 0.414507, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "softmax": True, "reduction": "sum"}, + { + "include_background": True, + "to_onehot_y": True, + "softmax": True, + "reduction": "sum", + "smooth_nr": 1e-4, + "smooth_dr": 1e-4, + }, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, 0.829015, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "softmax": True, "reduction": "none"}, + { + "include_background": True, + "to_onehot_y": True, + "softmax": True, + "reduction": "none", + "smooth_nr": 1e-4, + "smooth_dr": 1e-4, + }, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, [0.273476, 0.555539], ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": False, "to_onehot_y": True}, + {"include_background": False, "to_onehot_y": True, "smooth_nr": 1e-8, "smooth_dr": 1e-8}, { "input": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]), "target": torch.tensor([[[0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0]]]), - "smooth": 1e-8, }, 0.0, ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) - {"include_background": True, "sigmoid": True}, + {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - "smooth": 1e-6, }, 0.307576, ], [ # shape: (1, 2, 4), (1, 1, 4) - {"include_background": True, "to_onehot_y": True, "softmax": True, "w_type": "simple"}, + { + "include_background": True, + "to_onehot_y": True, + "softmax": True, + "w_type": "simple", + "smooth_nr": 0, + "smooth_dr": 0, + }, { "input": torch.tensor([[[0.0, 10.0, 10.0, 10.0], [10.0, 0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1, 1, 0, 0]]]), - "smooth": 0.0, }, 0.250023, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) - {"include_background": True, "other_act": torch.tanh}, + {"include_background": True, "other_act": torch.tanh, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), - "smooth": 1e-4, }, 0.99970, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "other_act": lambda x: torch.log_softmax(x, dim=1)}, + { + "include_background": True, + "to_onehot_y": True, + "other_act": lambda x: torch.log_softmax(x, dim=1), + "smooth_nr": 1e-4, + "smooth_dr": 1e-4, + }, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, -0.097833, ], diff --git a/tests/test_masked_dice_loss.py b/tests/test_masked_dice_loss.py index 10738be7f2..3ea3151bd7 100644 --- a/tests/test_masked_dice_loss.py +++ b/tests/test_masked_dice_loss.py @@ -19,96 +19,100 @@ TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) - {"include_background": True, "sigmoid": True}, + {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), "mask": torch.tensor([[[[0.0, 0.0], [1.0, 1.0]]]]), - "smooth": 1e-6, }, 0.500, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) - {"include_background": True, "sigmoid": True}, + {"include_background": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), "mask": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 1.0], [0.0, 0.0]]]]), - "smooth": 1e-4, }, 0.422969, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": False, "to_onehot_y": True}, + {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0}, { "input": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]), "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]), "mask": torch.tensor([[[1.0, 1.0, 1.0]], [[0.0, 1.0, 0.0]]]), - "smooth": 0.0, }, 0.0, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "sigmoid": True}, + {"include_background": True, "to_onehot_y": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), "mask": torch.tensor([[[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, 0.47033, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "sigmoid": True, "reduction": "none"}, + { + "include_background": True, + "to_onehot_y": True, + "sigmoid": True, + "reduction": "none", + "smooth_nr": 1e-4, + "smooth_dr": 1e-4, + }, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, [[0.296529, 0.415136], [0.599976, 0.428559]], ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "softmax": True}, + {"include_background": True, "to_onehot_y": True, "softmax": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, 0.383713, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "softmax": True, "reduction": "sum"}, + { + "include_background": True, + "to_onehot_y": True, + "softmax": True, + "reduction": "sum", + "smooth_nr": 1e-4, + "smooth_dr": 1e-4, + }, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, 1.534853, ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) - {"include_background": True, "sigmoid": True}, + {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - "smooth": 1e-6, }, 0.307576, ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) - {"include_background": True, "sigmoid": True, "squared_pred": True}, + {"include_background": True, "sigmoid": True, "squared_pred": True, "smooth_nr": 1e-5, "smooth_dr": 1e-5}, { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - "smooth": 1e-5, }, 0.178337, ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) - {"include_background": True, "sigmoid": True, "jaccard": True}, + {"include_background": True, "sigmoid": True, "jaccard": True, "smooth_nr": 1e-5, "smooth_dr": 1e-5}, { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - "smooth": 1e-5, }, 0.470451, ], diff --git a/tests/test_seg_loss_integration.py b/tests/test_seg_loss_integration.py index 9503d12aaa..5024867a9a 100644 --- a/tests/test_seg_loss_integration.py +++ b/tests/test_seg_loss_integration.py @@ -20,7 +20,8 @@ from monai.losses import DiceLoss, FocalLoss, GeneralizedDiceLoss, TverskyLoss TEST_CASES = [ - [DiceLoss, {"to_onehot_y": True, "squared_pred": True}, {"smooth": 1e-4}], + [DiceLoss, {"to_onehot_y": True, "squared_pred": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, {}], + [DiceLoss, {"to_onehot_y": True, "squared_pred": True, "batch": True}, {}], [DiceLoss, {"to_onehot_y": True, "sigmoid": True}, {}], [DiceLoss, {"to_onehot_y": True, "softmax": True}, {}], [FocalLoss, {"gamma": 1.5, "weight": torch.tensor([1, 2])}, {}], diff --git a/tests/test_tversky_loss.py b/tests/test_tversky_loss.py index ad7c089bc0..f8cda51193 100644 --- a/tests/test_tversky_loss.py +++ b/tests/test_tversky_loss.py @@ -19,101 +19,110 @@ TEST_CASES = [ [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) - {"include_background": True, "sigmoid": True}, + {"include_background": True, "sigmoid": True, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - "smooth": 1e-6, }, 0.307576, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) - {"include_background": True, "sigmoid": True}, + {"include_background": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), - "smooth": 1e-4, }, 0.416657, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": False, "to_onehot_y": True}, + {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0}, { "input": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]), "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]), - "smooth": 0.0, }, 0.0, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "sigmoid": True}, + {"include_background": True, "to_onehot_y": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, 0.435050, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "sigmoid": True, "reduction": "sum"}, + { + "include_background": True, + "to_onehot_y": True, + "sigmoid": True, + "reduction": "sum", + "smooth_nr": 1e-4, + "smooth_dr": 1e-4, + }, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, 1.74013, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "softmax": True}, + {"include_background": True, "to_onehot_y": True, "softmax": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, 0.383713, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "softmax": True, "reduction": "none"}, + { + "include_background": True, + "to_onehot_y": True, + "softmax": True, + "reduction": "none", + "smooth_nr": 1e-4, + "smooth_dr": 1e-4, + }, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, [[0.210961, 0.295339], [0.599952, 0.428547]], ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) - {"include_background": True, "sigmoid": True, "alpha": 0.3, "beta": 0.7}, + {"include_background": True, "sigmoid": True, "alpha": 0.3, "beta": 0.7, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - "smooth": 1e-6, }, 0.3589, ], [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) - {"include_background": True, "sigmoid": True, "alpha": 0.7, "beta": 0.3}, + {"include_background": True, "sigmoid": True, "alpha": 0.7, "beta": 0.3, "smooth_nr": 1e-6, "smooth_dr": 1e-6}, { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]), - "smooth": 1e-6, }, 0.247366, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) - {"include_background": True, "other_act": torch.tanh}, + {"include_background": True, "other_act": torch.tanh, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, { "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), - "smooth": 1e-4, }, 0.999963, ], [ # shape: (2, 2, 3), (2, 1, 3) - {"include_background": True, "to_onehot_y": True, "other_act": lambda x: torch.log_softmax(x, dim=1)}, + { + "include_background": True, + "to_onehot_y": True, + "other_act": lambda x: torch.log_softmax(x, dim=1), + "smooth_nr": 1e-4, + "smooth_dr": 1e-4, + }, { "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), - "smooth": 1e-4, }, -8.533317, ], From 0a81f97637d99e4be9abcd1672ca6c902f7a51d1 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 28 Oct 2020 17:57:36 +0000 Subject: [PATCH 2/2] update based on the comments Signed-off-by: Wenqi Li --- monai/losses/tversky.py | 10 +++++++++- tests/test_dice_loss.py | 16 ++++++++++++++++ tests/test_seg_loss_integration.py | 2 ++ tests/test_tversky_loss.py | 24 ++++++++++++++++++++++++ 4 files changed, 51 insertions(+), 1 deletion(-) diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index 3ddc9c0e17..62b937d680 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -44,6 +44,7 @@ def __init__( reduction: Union[LossReduction, str] = LossReduction.MEAN, smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, + batch: bool = False, ) -> None: """ Args: @@ -62,8 +63,12 @@ def __init__( - ``"none"``: no reduction will be applied. - ``"mean"``: the sum of the output will be divided by the number of elements in the output. - ``"sum"``: the output will be summed. + smooth_nr: a small constant added to the numerator to avoid zero. smooth_dr: a small constant added to the denominator to avoid nan. + batch: whether to sum the intersection and union areas over the batch dimension before the dividing. + Defaults to False, a Dice loss value is computed independently from each item in the batch + before any `reduction`. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -86,6 +91,7 @@ def __init__( self.beta = beta self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) + self.batch = batch def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -135,11 +141,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # reducing only spatial dimensions (not batch nor channels) reduce_axis = list(range(2, len(input.shape))) + if self.batch: + # reducing spatial dimensions and batch + reduce_axis = [0] + reduce_axis tp = torch.sum(p0 * g0, reduce_axis) fp = self.alpha * torch.sum(p0 * g1, reduce_axis) fn = self.beta * torch.sum(p1 * g0, reduce_axis) - numerator = tp + self.smooth_nr denominator = tp + fp + fn + self.smooth_dr diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index 6c2b3a4579..e2361df7a6 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -142,6 +142,22 @@ }, 0.774718, ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh, "smooth_nr": 0, "smooth_dr": 1e-4, "batch": True}, + { + "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + }, + 0.774733, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh, "smooth_nr": 0, "smooth_dr": 1e-4, "batch": False}, + { + "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + }, + 0.840058, + ], ] diff --git a/tests/test_seg_loss_integration.py b/tests/test_seg_loss_integration.py index 5024867a9a..c170d16cb6 100644 --- a/tests/test_seg_loss_integration.py +++ b/tests/test_seg_loss_integration.py @@ -21,6 +21,7 @@ TEST_CASES = [ [DiceLoss, {"to_onehot_y": True, "squared_pred": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, {}], + [DiceLoss, {"to_onehot_y": True, "squared_pred": True, "smooth_nr": 0, "smooth_dr": 1e-3}, {}], [DiceLoss, {"to_onehot_y": True, "squared_pred": True, "batch": True}, {}], [DiceLoss, {"to_onehot_y": True, "sigmoid": True}, {}], [DiceLoss, {"to_onehot_y": True, "softmax": True}, {}], @@ -31,6 +32,7 @@ [GeneralizedDiceLoss, {"to_onehot_y": True, "sigmoid": True, "w_type": "simple"}, {}], [GeneralizedDiceLoss, {"to_onehot_y": True, "sigmoid": True, "w_type": "uniform"}, {}], [TverskyLoss, {"to_onehot_y": True, "softmax": True, "alpha": 0.8, "beta": 0.2}, {}], + [TverskyLoss, {"to_onehot_y": True, "softmax": True, "alpha": 0.8, "beta": 0.2, "batch": True}, {}], [TverskyLoss, {"to_onehot_y": True, "softmax": True, "alpha": 1.0, "beta": 0.0}, {}], ] diff --git a/tests/test_tversky_loss.py b/tests/test_tversky_loss.py index f8cda51193..bf4b3f8f0a 100644 --- a/tests/test_tversky_loss.py +++ b/tests/test_tversky_loss.py @@ -42,6 +42,14 @@ }, 0.0, ], + [ # shape: (2, 2, 3), (2, 1, 3) + {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 1e-3}, + { + "input": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]), + "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]), + }, + 0.000999, + ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": True, "to_onehot_y": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, { @@ -50,6 +58,14 @@ }, 0.435050, ], + [ # shape: (2, 2, 3), (2, 1, 3) + {"include_background": True, "to_onehot_y": True, "sigmoid": True, "batch": True}, + { + "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), + "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), + }, + 0.422979, + ], [ # shape: (2, 2, 3), (2, 1, 3) { "include_background": True, @@ -112,6 +128,14 @@ }, 0.999963, ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "other_act": torch.tanh, "smooth_nr": 0, "smooth_dr": 1e-3, "batch": True}, + { + "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]), + "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]), + }, + 0.999963, + ], [ # shape: (2, 2, 3), (2, 1, 3) { "include_background": True,