From 55899660a7ba0fc32f2ffff58a54519b2189dd17 Mon Sep 17 00:00:00 2001 From: Lucas Robinet Date: Wed, 27 Mar 2024 08:47:20 +0100 Subject: [PATCH 1/5] harmonization and clarification of dice losses variants docs and associated tests Signed-off-by: Lucas Robinet --- monai/losses/dice.py | 53 +++++++++++++++++++---- tests/test_dice_ce_loss.py | 18 ++++++-- tests/test_dice_focal_loss.py | 14 +++++- tests/test_generalized_dice_focal_loss.py | 14 +++++- 4 files changed, 82 insertions(+), 17 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index b3c0f57c6e..794dfcd8bd 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -776,14 +776,23 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: input: the shape should be BNH[WD]. target: the shape should be BNH[WD] or B1H[WD]. - Raises: - ValueError: When number of dimensions for input and target are different. - ValueError: When number of channels for target is neither 1 nor the same as input. + ValueError: When number of dimensions for input and target are different. + ValueError: When the target channel isn't either one-hot encoded or categorical with the same shape of the input. + + Returns: + torch.Tensor: value of the loss. """ - if len(input.shape) != len(target.shape): + if input.dim() != target.dim(): raise ValueError( "the number of dimensions for input and target should be the same, " + f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). " + "if target is not one-hot encoded, please provide a tensor with shape B1H[WD]." + ) + + if target.shape[1] != 1 and target.shape[1] != input.shape[1]: + raise ValueError( + "number of channels for target is neither 1 (with no one-hot encoding) nor the same as input, " f"got shape {input.shape} and {target.shape}." ) @@ -899,14 +908,24 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Raises: ValueError: When number of dimensions for input and target are different. - ValueError: When number of channels for target is neither 1 nor the same as input. + ValueError: When the target channel isn't either one-hot encoded or categorical with the same shape of the input. + Returns: + torch.Tensor: value of the loss. """ - if len(input.shape) != len(target.shape): + if input.dim() != target.dim(): raise ValueError( "the number of dimensions for input and target should be the same, " + f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). " + "if target is not one-hot encoded, please provide a tensor with shape B1H[WD]." + ) + + if target.shape[1] != 1 and target.shape[1] != input.shape[1]: + raise ValueError( + "number of channels for target is neither 1 (with no one-hot encoding) nor the same as input, " f"got shape {input.shape} and {target.shape}." ) + if self.to_onehot_y: n_pred_ch = input.shape[1] if n_pred_ch == 1: @@ -1015,15 +1034,23 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: target (torch.Tensor): the shape should be BNH[WD] or B1H[WD]. Raises: - ValueError: When the input and target tensors have different numbers of dimensions, or the target - channel isn't either one-hot encoded or categorical with the same shape of the input. + ValueError: When number of dimensions for input and target are different. + ValueError: When the target channel isn't either one-hot encoded or categorical with the same shape of the input. Returns: torch.Tensor: value of the loss. """ if input.dim() != target.dim(): raise ValueError( - f"Input - {input.shape} - and target - {target.shape} - must have the same number of dimensions." + "the number of dimensions for input and target should be the same, " + f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). " + "if target is not one-hot encoded, please provide a tensor with shape B1H[WD]." + ) + + if target.shape[1] != 1 and target.shape[1] != input.shape[1]: + raise ValueError( + "number of channels for target is neither 1 (with no one-hot encoding) nor the same as input, " + f"got shape {input.shape} and {target.shape}." ) gdl_loss = self.generalized_dice(input, target) @@ -1038,3 +1065,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: generalized_dice = GeneralizedDiceLoss generalized_dice_focal = GeneralizedDiceFocalLoss generalized_wasserstein_dice = GeneralizedWassersteinDiceLoss + +if __name__ == "__main__": + loss = DiceFocalLoss(to_onehot_y=False, softmax=True) + input = torch.randn(1, 3, 128, 128, 128) + target = torch.randn(1, 4, 128, 128, 128) + print(input.dim()) + print(target.dim()) + loss(input, target) diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py index 225618ed2c..97c7ae5050 100644 --- a/tests/test_dice_ce_loss.py +++ b/tests/test_dice_ce_loss.py @@ -93,10 +93,20 @@ def test_result(self, input_param, input_data, expected_val): result = diceceloss(**input_data) np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) - # def test_ill_shape(self): - # loss = DiceCELoss() - # with self.assertRaisesRegex(ValueError, ""): - # loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + def test_ill_shape(self): + loss = DiceCELoss() + with self.assertRaises(AssertionError): + loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5))) + + def test_ill_shape2(self): + loss = DiceCELoss() + with self.assertRaises(ValueError): + loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + + def test_ill_shape3(self): + loss = DiceCELoss() + with self.assertRaises(ValueError): + loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4))) # def test_ill_reduction(self): # with self.assertRaisesRegex(ValueError, ""): diff --git a/tests/test_dice_focal_loss.py b/tests/test_dice_focal_loss.py index 13899da003..814a174762 100644 --- a/tests/test_dice_focal_loss.py +++ b/tests/test_dice_focal_loss.py @@ -69,8 +69,18 @@ def test_result_no_onehot_no_bg(self, size, onehot): def test_ill_shape(self): loss = DiceFocalLoss() - with self.assertRaisesRegex(ValueError, ""): - loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + with self.assertRaises(AssertionError): + loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5))) + + def test_ill_shape2(self): + loss = DiceFocalLoss() + with self.assertRaises(ValueError): + loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + + def test_ill_shape3(self): + loss = DiceFocalLoss() + with self.assertRaises(ValueError): + loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4))) def test_ill_lambda(self): with self.assertRaisesRegex(ValueError, ""): diff --git a/tests/test_generalized_dice_focal_loss.py b/tests/test_generalized_dice_focal_loss.py index 8a0a80865e..65252611ca 100644 --- a/tests/test_generalized_dice_focal_loss.py +++ b/tests/test_generalized_dice_focal_loss.py @@ -59,8 +59,18 @@ def test_result_no_onehot_no_bg(self): def test_ill_shape(self): loss = GeneralizedDiceFocalLoss() - with self.assertRaisesRegex(ValueError, ""): - loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + with self.assertRaises(AssertionError): + loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5))) + + def test_ill_shape2(self): + loss = GeneralizedDiceFocalLoss() + with self.assertRaises(ValueError): + loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + + def test_ill_shape3(self): + loss = GeneralizedDiceFocalLoss() + with self.assertRaises(ValueError): + loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4))) def test_ill_lambda(self): with self.assertRaisesRegex(ValueError, ""): From b7320abb18d9a5b7f6f698fc26c13df0952ff0f6 Mon Sep 17 00:00:00 2001 From: Lucas Robinet Date: Wed, 27 Mar 2024 09:07:50 +0100 Subject: [PATCH 2/5] fixing missing raises in diceceloss forward Signed-off-by: Lucas Robinet --- monai/losses/dice.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 794dfcd8bd..5d061d4a24 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -776,8 +776,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: input: the shape should be BNH[WD]. target: the shape should be BNH[WD] or B1H[WD]. - ValueError: When number of dimensions for input and target are different. - ValueError: When the target channel isn't either one-hot encoded or categorical with the same shape of the input. + Raises: + ValueError: When number of dimensions for input and target are different. + ValueError: When the target channel isn't either one-hot encoded or categorical with the same shape of the input. Returns: torch.Tensor: value of the loss. From 768f1f4ed5f339b2c96b5ae04386d09f577189dd Mon Sep 17 00:00:00 2001 From: Lucas Robinet Date: Fri, 5 Apr 2024 09:09:44 +0200 Subject: [PATCH 3/5] Update monai/losses/dice.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Lucas Robinet --- monai/losses/dice.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 5d061d4a24..6fcfd68b5d 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -1067,10 +1067,3 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: generalized_dice_focal = GeneralizedDiceFocalLoss generalized_wasserstein_dice = GeneralizedWassersteinDiceLoss -if __name__ == "__main__": - loss = DiceFocalLoss(to_onehot_y=False, softmax=True) - input = torch.randn(1, 3, 128, 128, 128) - target = torch.randn(1, 4, 128, 128, 128) - print(input.dim()) - print(target.dim()) - loss(input, target) From 3e959ee67b4510936cd01de63cd95683eef564e3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 Apr 2024 07:10:12 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/dice.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 6fcfd68b5d..c095726c2b 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -1066,4 +1066,3 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: generalized_dice = GeneralizedDiceLoss generalized_dice_focal = GeneralizedDiceFocalLoss generalized_wasserstein_dice = GeneralizedWassersteinDiceLoss - From b73c8ca5a026fe4f344bfdf11dfe95a6b82f14a5 Mon Sep 17 00:00:00 2001 From: Lucas Robinet Date: Fri, 5 Apr 2024 09:33:03 +0200 Subject: [PATCH 5/5] docstring and error clarification Signed-off-by: Lucas Robinet --- monai/losses/dice.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index c095726c2b..f1c357d31f 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -778,7 +778,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Raises: ValueError: When number of dimensions for input and target are different. - ValueError: When the target channel isn't either one-hot encoded or categorical with the same shape of the input. + ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input. Returns: torch.Tensor: value of the loss. @@ -793,7 +793,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape[1] != 1 and target.shape[1] != input.shape[1]: raise ValueError( - "number of channels for target is neither 1 (with no one-hot encoding) nor the same as input, " + "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, " f"got shape {input.shape} and {target.shape}." ) @@ -909,7 +909,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Raises: ValueError: When number of dimensions for input and target are different. - ValueError: When the target channel isn't either one-hot encoded or categorical with the same shape of the input. + ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input. Returns: torch.Tensor: value of the loss. @@ -923,7 +923,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape[1] != 1 and target.shape[1] != input.shape[1]: raise ValueError( - "number of channels for target is neither 1 (with no one-hot encoding) nor the same as input, " + "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, " f"got shape {input.shape} and {target.shape}." ) @@ -1036,7 +1036,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Raises: ValueError: When number of dimensions for input and target are different. - ValueError: When the target channel isn't either one-hot encoded or categorical with the same shape of the input. + ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input. Returns: torch.Tensor: value of the loss. @@ -1050,7 +1050,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape[1] != 1 and target.shape[1] != input.shape[1]: raise ValueError( - "number of channels for target is neither 1 (with no one-hot encoding) nor the same as input, " + "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, " f"got shape {input.shape} and {target.shape}." )