diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 610327ef63..67802abc0a 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -796,8 +796,6 @@ def __init__( """ super().__init__() self.dice = DiceLoss( - include_background=include_background, - to_onehot_y=to_onehot_y, sigmoid=sigmoid, softmax=softmax, other_act=other_act, @@ -808,19 +806,15 @@ def __init__( smooth_dr=smooth_dr, batch=batch, ) - self.focal = FocalLoss( - include_background=include_background, - to_onehot_y=to_onehot_y, - gamma=gamma, - weight=focal_weight, - reduction=reduction, - ) + self.focal = FocalLoss(gamma=gamma, weight=focal_weight, reduction=reduction) if lambda_dice < 0.0: raise ValueError("lambda_dice should be no less than 0.0.") if lambda_focal < 0.0: raise ValueError("lambda_focal should be no less than 0.0.") self.lambda_dice = lambda_dice self.lambda_focal = lambda_focal + self.to_onehot_y = to_onehot_y + self.include_background = include_background def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -837,6 +831,22 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if len(input.shape) != len(target.shape): raise ValueError("the number of dimensions for input and target should be the same.") + n_pred_ch = input.shape[1] + + if self.to_onehot_y: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + else: + target = one_hot(target, num_classes=n_pred_ch) + + if not self.include_background: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `include_background=False` ignored.") + else: + # if skipping background, removing first channel + target = target[:, 1:] + input = input[:, 1:] + dice_loss = self.dice(input, target) focal_loss = self.focal(input, target) total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss