42 changes: 35 additions & 7 deletions monai/losses/dice.py
@@ -778,12 +778,22 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

Raises:
ValueError: When number of dimensions for input and target are different.
ValueError: When number of channels for target is neither 1 nor the same as input.
ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input.

Returns:
torch.Tensor: value of the loss.

"""
if len(input.shape) != len(target.shape):
if input.dim() != target.dim():
raise ValueError(
"the number of dimensions for input and target should be the same, "
f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). "
"if target is not one-hot encoded, please provide a tensor with shape B1H[WD]."
)

if target.shape[1] != 1 and target.shape[1] != input.shape[1]:
raise ValueError(
"number of channels for target is neither 1 (without one-hot encoding) nor the same as input, "
f"got shape {input.shape} and {target.shape}."
)

@@ -899,14 +909,24 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

Raises:
ValueError: When number of dimensions for input and target are different.
ValueError: When number of channels for target is neither 1 nor the same as input.
ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input.

Returns:
torch.Tensor: value of the loss.
"""
if len(input.shape) != len(target.shape):
if input.dim() != target.dim():
raise ValueError(
"the number of dimensions for input and target should be the same, "
f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). "
"if target is not one-hot encoded, please provide a tensor with shape B1H[WD]."
)

if target.shape[1] != 1 and target.shape[1] != input.shape[1]:
raise ValueError(
"number of channels for target is neither 1 (without one-hot encoding) nor the same as input, "
f"got shape {input.shape} and {target.shape}."
)

if self.to_onehot_y:
n_pred_ch = input.shape[1]
if n_pred_ch == 1:
@@ -1015,15 +1035,23 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
target (torch.Tensor): the shape should be BNH[WD] or B1H[WD].

Raises:
ValueError: When the input and target tensors have different numbers of dimensions, or the target
channel isn't either one-hot encoded or categorical with the same shape of the input.
ValueError: When number of dimensions for input and target are different.
ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input.

Returns:
torch.Tensor: value of the loss.
"""
if input.dim() != target.dim():
raise ValueError(
f"Input - {input.shape} - and target - {target.shape} - must have the same number of dimensions."
"the number of dimensions for input and target should be the same, "
f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). "
"if target is not one-hot encoded, please provide a tensor with shape B1H[WD]."
)

if target.shape[1] != 1 and target.shape[1] != input.shape[1]:
raise ValueError(
"number of channels for target is neither 1 (without one-hot encoding) nor the same as input, "
f"got shape {input.shape} and {target.shape}."
)

gdl_loss = self.generalized_dice(input, target)
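For context, here is a minimal sketch (not part of the diff) of how the new validation in dice.py is expected to behave, assuming a MONAI build that includes this patch. The shapes mirror the new test cases below, and the messages quoted in the comments come from the added code above.

```python
import torch
from monai.losses import DiceCELoss

loss = DiceCELoss()

# Dimension mismatch: 3-D input vs 4-D target should trigger the first new check
# ("the number of dimensions for input and target should be the same, ...").
try:
    loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
except ValueError as err:
    print(err)

# Channel mismatch: target has 2 channels, input has 3, and the target is not
# single-channel, so the second new check should raise
# ("number of channels for target is neither 1 (without one-hot encoding) nor the same as input, ...").
try:
    loss(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))
except ValueError as err:
    print(err)
```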
18 changes: 14 additions & 4 deletions tests/test_dice_ce_loss.py
@@ -93,10 +93,20 @@ def test_result(self, input_param, input_data, expected_val):
result = diceceloss(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)

# def test_ill_shape(self):
# loss = DiceCELoss()
# with self.assertRaisesRegex(ValueError, ""):
# loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
def test_ill_shape(self):
loss = DiceCELoss()
with self.assertRaises(AssertionError):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))

def test_ill_shape2(self):
loss = DiceCELoss()
with self.assertRaises(ValueError):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))

def test_ill_shape3(self):
loss = DiceCELoss()
with self.assertRaises(ValueError):
loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))

# def test_ill_reduction(self):
# with self.assertRaisesRegex(ValueError, ""):
14 changes: 12 additions & 2 deletions tests/test_dice_focal_loss.py
@@ -69,8 +69,18 @@ def test_result_no_onehot_no_bg(self, size, onehot):

def test_ill_shape(self):
loss = DiceFocalLoss()
with self.assertRaisesRegex(ValueError, ""):
loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
with self.assertRaises(AssertionError):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))

def test_ill_shape2(self):
loss = DiceFocalLoss()
with self.assertRaises(ValueError):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))

def test_ill_shape3(self):
loss = DiceFocalLoss()
with self.assertRaises(ValueError):
loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))

def test_ill_lambda(self):
with self.assertRaisesRegex(ValueError, ""):
14 changes: 12 additions & 2 deletions tests/test_generalized_dice_focal_loss.py
@@ -59,8 +59,18 @@ def test_result_no_onehot_no_bg(self):

def test_ill_shape(self):
loss = GeneralizedDiceFocalLoss()
with self.assertRaisesRegex(ValueError, ""):
loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
with self.assertRaises(AssertionError):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))

def test_ill_shape2(self):
loss = GeneralizedDiceFocalLoss()
with self.assertRaises(ValueError):
loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))

def test_ill_shape3(self):
loss = GeneralizedDiceFocalLoss()
with self.assertRaises(ValueError):
loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))

def test_ill_lambda(self):
with self.assertRaisesRegex(ValueError, ""):
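A note on the expected exception types in the new tests (a reading of the surrounding DiceLoss code, not something shown in this diff): tensors that match in dimensionality and channel count but differ in spatial size pass both new checks and are assumed to hit a pre-existing shape assertion further down, hence assertRaises(AssertionError) in test_ill_shape, while the other two cases are caught by the new ValueError checks.

```python
import torch
from monai.losses import GeneralizedDiceFocalLoss

loss = GeneralizedDiceFocalLoss()

# Same ndim (3) and same channel count (2), but spatial size 3 vs 5: the new
# checks pass, and the failure is assumed to come from an existing shape
# assertion inside the Dice computation (AssertionError, as the test expects).
try:
    loss(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))
except AssertionError as err:
    print(err)
```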