From adff16230ad4d4c4e32149eb0e72eea706b40d2e Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 16 Jan 2026 11:03:23 +0800
Subject: [PATCH] Adjust execution order of activation and masking in
 MaskedDiceLoss

Signed-off-by: ytl0623
---
 monai/losses/dice.py                  | 30 ++++++++++++++++++++++++++-
 tests/losses/test_masked_dice_loss.py |  6 +++---
 2 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index 948749606b..3d810bc1fe 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -244,7 +244,22 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         Args follow :py:class:`monai.losses.DiceLoss`.
         """
         super().__init__(*args, **kwargs)
-        self.spatial_weighted = MaskedLoss(loss=super().forward)
+        self.dice = DiceLoss(
+            include_background=self.include_background,
+            to_onehot_y=self.to_onehot_y,
+            sigmoid=False,
+            softmax=False,
+            other_act=None,
+            squared_pred=self.squared_pred,
+            jaccard=self.jaccard,
+            reduction=self.reduction,
+            smooth_nr=self.smooth_nr,
+            smooth_dr=self.smooth_dr,
+            batch=self.batch,
+            weight=self.class_weight,
+            soft_label=self.soft_label,
+        )
+        self.spatial_weighted = MaskedLoss(loss=self.dice.forward)
 
     def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
         """
@@ -253,6 +268,19 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
             target: the shape should be BNH[WD].
             mask: the shape should B1H[WD] or 11H[WD].
         """
+
+        if self.sigmoid:
+            input = torch.sigmoid(input)
+
+        n_pred_ch = input.shape[1]
+        if self.softmax:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `softmax=True` ignored.")
+            else:
+                input = torch.softmax(input, 1)
+
+        if self.other_act is not None:
+            input = self.other_act(input)
         return self.spatial_weighted(input=input, target=target, mask=mask)  # type: ignore[no-any-return]
 
 
diff --git a/tests/losses/test_masked_dice_loss.py b/tests/losses/test_masked_dice_loss.py
index c971723615..ea08254ee5 100644
--- a/tests/losses/test_masked_dice_loss.py
+++ b/tests/losses/test_masked_dice_loss.py
@@ -27,7 +27,7 @@
             "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]),
             "mask": torch.tensor([[[[0.0, 0.0], [1.0, 1.0]]]]),
         },
-        0.500,
+        0.333333,
     ],
     [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
         {"include_background": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4},
@@ -36,7 +36,7 @@
             "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),
             "mask": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 1.0], [0.0, 0.0]]]]),
         },
-        0.422969,
+        0.301128,
     ],
     [  # shape: (2, 2, 3), (2, 1, 3)
         {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0},
@@ -54,7 +54,7 @@
             "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
             "mask": torch.tensor([[[1.0, 1.0, 0.0]]]),
         },
-        0.47033,
+        0.579184,
     ],
     [  # shape: (2, 2, 3), (2, 1, 3)
         {
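Note (not part of the patch): a minimal sketch of the behaviour this change is meant to produce. With the activation moved ahead of masking, `MaskedDiceLoss(sigmoid=True)` should match a plain `DiceLoss` without activation applied to `sigmoid(logits) * mask` and `target * mask`. The equivalence assumes `MaskedLoss` multiplies both input and target by the mask; tensor values are taken from the first test case in this patch, and the default `smooth_nr`/`smooth_dr`/`reduction` are used on both sides.

import torch
from monai.losses import DiceLoss, MaskedDiceLoss

# raw logits, label, and spatial mask from the first test case above
logits = torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]])
target = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])
mask = torch.tensor([[[[0.0, 0.0], [1.0, 1.0]]]])

# patched ordering: sigmoid on the raw logits, then masking inside MaskedLoss
masked_loss = MaskedDiceLoss(include_background=True, sigmoid=True)(logits, target, mask)

# reference: apply sigmoid first, then mask, then an activation-free DiceLoss
reference = DiceLoss(include_background=True, sigmoid=False)(torch.sigmoid(logits) * mask, target * mask)

print(torch.allclose(masked_loss, reference))  # expected True under the patched ordering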