From 6f161973e6a8d0e2f2d6b435c7b92a9421aa5a24 Mon Sep 17 00:00:00 2001 From: Ryan Clanton <55164720+ryancinsight@users.noreply.github.com> Date: Thu, 5 May 2022 10:06:43 -0400 Subject: [PATCH 1/3] Update dice.py reduce redundant operations in DiceFocalLoss, initially caused oom Signed-off-by: Ryan Clanton <55164720+ryancinsight@users.noreply.github.com> --- monai/losses/dice.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 610327ef63..5504f25360 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -796,8 +796,6 @@ def __init__( """ super().__init__() self.dice = DiceLoss( - include_background=include_background, - to_onehot_y=to_onehot_y, sigmoid=sigmoid, softmax=softmax, other_act=other_act, @@ -809,8 +807,6 @@ def __init__( batch=batch, ) self.focal = FocalLoss( - include_background=include_background, - to_onehot_y=to_onehot_y, gamma=gamma, weight=focal_weight, reduction=reduction, @@ -821,6 +817,8 @@ def __init__( raise ValueError("lambda_focal should be no less than 0.0.") self.lambda_dice = lambda_dice self.lambda_focal = lambda_focal + self.to_onehot_y = to_onehot_y + self.include_background = include_background def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -836,6 +834,22 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ if len(input.shape) != len(target.shape): raise ValueError("the number of dimensions for input and target should be the same.") + + n_pred_ch = input.shape[1] + + if self.to_onehot_y: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + else: + target = one_hot(target, num_classes=n_pred_ch) + + if not self.include_background: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `include_background=False` ignored.") + else: + # if skipping background, removing first channel + target = target[:, 1:] + input = input[:, 1:] dice_loss = self.dice(input, target) focal_loss = self.focal(input, target) From be14db40edd97ef25d8436c44322ce9663e6aefb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 May 2022 14:13:13 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Ryan Clanton <55164720+ryancinsight@users.noreply.github.com> --- monai/losses/dice.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 5504f25360..b65ea97676 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -834,15 +834,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ if len(input.shape) != len(target.shape): raise ValueError("the number of dimensions for input and target should be the same.") - + n_pred_ch = input.shape[1] - + if self.to_onehot_y: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: target = one_hot(target, num_classes=n_pred_ch) - + if not self.include_background: if n_pred_ch == 1: warnings.warn("single channel prediction, `include_background=False` ignored.") From 4018a99e7d3de73a501f3cba4b8768ad6f19eaf6 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Fri, 6 May 2022 07:17:15 +0000 Subject: [PATCH 3/3] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/losses/dice.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index b65ea97676..67802abc0a 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -806,11 +806,7 @@ def __init__( smooth_dr=smooth_dr, batch=batch, ) - self.focal = FocalLoss( - gamma=gamma, - weight=focal_weight, - reduction=reduction, - ) + self.focal = FocalLoss(gamma=gamma, weight=focal_weight, reduction=reduction) if lambda_dice < 0.0: raise ValueError("lambda_dice should be no less than 0.0.") if lambda_focal < 0.0: