monai/losses/dice.py (28 changes: 19 additions & 9 deletions)
@@ -796,8 +796,6 @@ def __init__(
"""
super().__init__()
self.dice = DiceLoss(
include_background=include_background,
to_onehot_y=to_onehot_y,
sigmoid=sigmoid,
softmax=softmax,
other_act=other_act,
@@ -808,19 +806,15 @@ def __init__(
             smooth_dr=smooth_dr,
             batch=batch,
         )
-        self.focal = FocalLoss(
-            include_background=include_background,
-            to_onehot_y=to_onehot_y,
-            gamma=gamma,
-            weight=focal_weight,
-            reduction=reduction,
-        )
+        self.focal = FocalLoss(gamma=gamma, weight=focal_weight, reduction=reduction)
         if lambda_dice < 0.0:
             raise ValueError("lambda_dice should be no less than 0.0.")
         if lambda_focal < 0.0:
             raise ValueError("lambda_focal should be no less than 0.0.")
         self.lambda_dice = lambda_dice
         self.lambda_focal = lambda_focal
+        self.to_onehot_y = to_onehot_y
+        self.include_background = include_background

     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -837,6 +831,22 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         if len(input.shape) != len(target.shape):
             raise ValueError("the number of dimensions for input and target should be the same.")

+        n_pred_ch = input.shape[1]
+
+        if self.to_onehot_y:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+            else:
+                target = one_hot(target, num_classes=n_pred_ch)
+
+        if not self.include_background:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `include_background=False` ignored.")
+            else:
+                # if skipping background, removing first channel
+                target = target[:, 1:]
+                input = input[:, 1:]
+
         dice_loss = self.dice(input, target)
         focal_loss = self.focal(input, target)
         total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss
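With this change, `DiceFocalLoss` performs the `to_onehot_y` conversion and the `include_background` channel stripping once in its own `forward`, and the already-processed `input`/`target` pair is handed to both the `DiceLoss` and `FocalLoss` terms. Below is a minimal usage sketch of the public API after this PR; the tensor shapes, class count, and parameter values are illustrative assumptions, not taken from the diff:

```python
import torch
from monai.losses import DiceFocalLoss

# Illustrative 3-class segmentation batch: the prediction has one channel per
# class, while the target is a single-channel integer label map that forward()
# one-hot encodes because to_onehot_y=True.
pred = torch.randn(2, 3, 32, 32)              # (batch, classes, H, W) logits
label = torch.randint(0, 3, (2, 1, 32, 32))   # (batch, 1, H, W) class indices

loss_fn = DiceFocalLoss(
    include_background=False,  # background channel stripped once, for both sub-losses
    to_onehot_y=True,          # label map one-hot encoded once, for both sub-losses
    softmax=True,
    lambda_dice=1.0,
    lambda_focal=1.0,
)

loss = loss_fn(pred, label)  # lambda_dice * dice_loss + lambda_focal * focal_loss
```

Because the preprocessing now happens before either sub-loss is called, both terms are guaranteed to see the same target encoding, which is what makes dropping `include_background` and `to_onehot_y` from the two sub-loss constructors safe.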