Description
Describe the bug
`compute_generalized_dice` does not return values with the shape described in the docs.
```python
import torch
from monai.metrics import compute_generalized_dice, GeneralizedDiceScore

# Inputs shaped [batch=20, classes=10, D, H, W]
a = torch.ones((20, 10, 64, 128, 128))
b = torch.ones((20, 10, 64, 128, 128))
compute_generalized_dice(a, b).shape
```

This returns `torch.Size([20])`, whereas according to the documentation it should return `torch.Size([20, 10])`.
https://docs.monai.io/en/stable/metrics.html#generalized-dice-score
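To make the expected shapes concrete, here is a minimal pure-Python sketch (not MONAI's implementation) of the generalized Dice computation on inputs laid out as [batch][class][voxel]. It assumes the default "square" weighting, w_c = 1 / (ground-truth volume of class c)^2, and omits zero-division handling; all function names are hypothetical. Summing the weighted terms over classes before dividing yields one score per batch item, while keeping the class axis yields a [B][C] table.

```python
# Hypothetical pure-Python sketch (not MONAI's code): generalized Dice on
# inputs laid out as [batch][class][voxel], i.e. spatial dims already flattened.
# Division-by-zero handling is omitted; inputs are assumed non-empty.


def per_class_terms(y_pred, y):
    """Per-(batch, class) intersection, denominator and weight, each shaped [B][C]."""
    inter, denom, weights = [], [], []
    for pred_b, gt_b in zip(y_pred, y):
        inter.append([sum(p * g for p, g in zip(pc, gc)) for pc, gc in zip(pred_b, gt_b)])
        denom.append([sum(pc) + sum(gc) for pc, gc in zip(pred_b, gt_b)])
        # Assumed "square" weighting: 1 / (ground-truth volume)^2, volume assumed > 0
        weights.append([1.0 / sum(gc) ** 2 for gc in gt_b])
    return inter, denom, weights


def gdice_summed_over_classes(y_pred, y):
    """Sum over classes before dividing: one score per batch item, shape [B]."""
    inter, denom, w = per_class_terms(y_pred, y)
    scores = []
    for i_b, d_b, w_b in zip(inter, denom, w):
        numer = 2.0 * sum(i * wc for i, wc in zip(i_b, w_b))
        den = sum(d * wc for d, wc in zip(d_b, w_b))
        scores.append(numer / den)
    return scores


def gdice_per_class(y_pred, y):
    """No sum over classes: one score per (batch, class), shape [B][C]."""
    inter, denom, w = per_class_terms(y_pred, y)
    return [
        [2.0 * i * wc / (d * wc) for i, d, wc in zip(i_b, d_b, w_b)]
        for i_b, d_b, w_b in zip(inter, denom, w)
    ]


# Small all-ones inputs, analogous to the report: B=2, C=3, 4 voxels each.
ones = [[[1.0] * 4 for _ in range(3)] for _ in range(2)]
summed = gdice_summed_over_classes(ones, ones)
per_class = gdice_per_class(ones, ones)
print(len(summed))                        # 2    (like torch.Size([20]))
print(len(per_class), len(per_class[0]))  # 2 3  (like torch.Size([20, 10]))
```

With the class sum removed, each entry is 2·I·w / (D·w), so the weight cancels and, wherever w > 0, the entry equals the ordinary per-class Dice.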
The problem propagates to `GeneralizedDiceScore`:
```python
generalized_dice_score = GeneralizedDiceScore()
generalized_dice_score(a, b)
```

Aggregating the result with different reductions gives unexpected results and errors:
Case 1: `generalized_dice_score.aggregate(reduction="sum")` and `generalized_dice_score.aggregate(reduction="mean")` both raise `IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)`.
Case 2: `generalized_dice_score.aggregate(reduction="sum_batch")` returns a single-element tensor, but it should return a tensor of size `num_classes` containing, for each class, the Dice scores summed across the batch. The same happens with `mean_batch`.
Case 3: `generalized_dice_score.aggregate(reduction="none")` should return the unreduced per-batch, per-class scores (as `compute_generalized_dice` is documented to do), but it returns values already reduced over all classes.
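For reference, the sketch below shows the shape each reduction would be expected to return if the metric buffer held a [B][C] table of scores. `reduce_scores` is a hypothetical pure-Python helper that uses the reduction names from the cases above; unlike MONAI's reductions, it does not skip NaN entries.

```python
# Hypothetical sketch of the expected output of each reduction for a [B][C]
# table of per-(batch, class) scores. NaN handling is deliberately left out.


def reduce_scores(scores, reduction):
    if not scores or not all(isinstance(row, list) for row in scores):
        # A flat [B] input has no class axis to reduce over.
        raise ValueError("expected a [B][C] table of scores")
    n_batch, n_class = len(scores), len(scores[0])
    columns = [[row[c] for row in scores] for c in range(n_class)]
    if reduction == "none":
        return [list(row) for row in scores]                          # [B][C], unchanged
    if reduction == "sum_batch":
        return [sum(col) for col in columns]                          # [C]
    if reduction == "mean_batch":
        return [sum(col) / n_batch for col in columns]                # [C]
    if reduction == "sum":
        return sum(sum(row) for row in scores)                        # scalar
    if reduction == "mean":
        return sum(sum(row) for row in scores) / (n_batch * n_class)  # scalar
    raise ValueError(f"unsupported reduction: {reduction}")


scores = [[1.0, 0.5], [0.0, 0.5]]  # B=2, C=2
print(reduce_scores(scores, "sum_batch"))   # [1.0, 1.0]
print(reduce_scores(scores, "mean_batch"))  # [0.5, 0.5]
print(reduce_scores(scores, "mean"))        # 0.5
```

In PyTorch, reducing a 1-D tensor over `dim=1` raises an `IndexError` of the form reported in Case 1, which is consistent with the metric buffer holding a [B] tensor instead of [B, C].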
Environment
================================
Printing MONAI config...
================================
MONAI version: 1.3.2
Numpy version: 1.24.4
Pytorch version: 2.4.0+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 59a7211070538586369afd4a01eca0a7fe2e742e
MONAI __file__: /home/<username>/miniconda3/lib/python3.10/site-packages/monai/__init__.py
Suggested Solution
Removing the `.sum()` calls in the following code appears to give the expected behaviour:
`MONAI/monai/metrics/generalized_dice.py`, lines 162 to 163 (commit 59a7211):

```python
numer = 2.0 * (intersection * w).sum(dim=1)
denom = (denominator * w).sum(dim=1)
```

`MONAI/monai/metrics/generalized_dice.py`, line 170 (commit 59a7211):

```python
y_pred_o = y_pred_o.sum(dim=-1)
```
I tried this change locally, and both the shapes and the values look as expected.
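The change on line 170 presumably matters because, once `numer` and `denom` keep one entry per class, the predicted volume used for the zero-denominator fallback must also keep one entry per class so that the boolean mask lines up. The pure-Python sketch below illustrates this, assuming the fallback rule in that part of `generalized_dice.py` is "score 1 when the prediction is empty, 0 otherwise"; `divide_with_fallback` is a hypothetical helper, not MONAI code.

```python
# Hypothetical pure-Python sketch (not MONAI's code) of a per-class
# zero-division fallback. All three inputs are [B][C] tables, so every
# (batch, class) entry is checked against its own predicted volume.


def divide_with_fallback(numer, denom, pred_volume):
    scores = []
    for n_b, d_b, v_b in zip(numer, denom, pred_volume):
        row = []
        for n, d, v in zip(n_b, d_b, v_b):
            if d == 0:
                # Assumed rule: empty prediction -> 1.0, otherwise 0.0
                row.append(1.0 if v == 0 else 0.0)
            else:
                row.append(n / d)
        scores.append(row)
    return scores


numer = [[0.0, 1.5, 0.0]]
denom = [[0.0, 2.0, 0.0]]
pred_volume = [[0.0, 3.0, 5.0]]  # per-class volumes, not summed over classes
print(divide_with_fallback(numer, denom, pred_volume))  # [[1.0, 0.75, 0.0]]
```

If `pred_volume` were summed over classes (one value per batch item), there would be no way to tell the first and third classes apart, even though only the third has a non-empty prediction.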