Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def generate_apidocs(*args):
{"name": "Twitter", "url": "https://twitter.com/projectmonai", "icon": "fab fa-twitter-square"},
],
"collapse_navigation": True,
"navigation_with_keys": True,
"navigation_depth": 1,
"show_toc_level": 1,
"footer_start": ["copyright"],
Expand Down
12 changes: 11 additions & 1 deletion monai/handlers/mean_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
num_classes: int | None = None,
output_transform: Callable = lambda x: x,
save_details: bool = True,
return_with_label: bool | list[str] = False,
) -> None:
"""

Expand All @@ -50,9 +51,18 @@ def __init__(
https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
save_details: whether to save metric computation details per image, for example: mean dice of every image.
default to True, will save to `engine.state.metric_details` dict with the metric name as key.
return_with_label: whether to return the metrics with label, only works when reduction is "mean_batch".
If `True`, use "label_{index}" as the key corresponding to C channels; if 'include_background' is True,
the index begins at "0", otherwise at "1". It can also take a list of label names.
The outcome will then be returned as a dictionary.

See also:
:py:meth:`monai.metrics.meandice.compute_dice`
"""
metric_fn = DiceMetric(include_background=include_background, reduction=reduction, num_classes=num_classes)
metric_fn = DiceMetric(
include_background=include_background,
reduction=reduction,
num_classes=num_classes,
return_with_label=return_with_label,
)
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)
16 changes: 16 additions & 0 deletions monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ class DiceMetric(CumulativeIterationMetric):
num_classes: number of input channels (always including the background). When this is None,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
return_with_label: whether to return the metrics with label, only works when reduction is "mean_batch".
If `True`, use "label_{index}" as the key corresponding to C channels; if 'include_background' is True,
the index begins at "0", otherwise at "1". It can also take a list of label names.
The outcome will then be returned as a dictionary.

"""

Expand All @@ -60,13 +64,15 @@ def __init__(
get_not_nans: bool = False,
ignore_empty: bool = True,
num_classes: int | None = None,
return_with_label: bool | list[str] = False,
) -> None:
super().__init__()
self.include_background = include_background
self.reduction = reduction
self.get_not_nans = get_not_nans
self.ignore_empty = ignore_empty
self.num_classes = num_classes
self.return_with_label = return_with_label
self.dice_helper = DiceHelper(
include_background=self.include_background,
reduction=MetricReduction.NONE,
Expand Down Expand Up @@ -112,6 +118,16 @@ def aggregate(

# do metric reduction
f, not_nans = do_metric_reduction(data, reduction or self.reduction)
if self.reduction == MetricReduction.MEAN_BATCH and self.return_with_label:
_f = {}
if isinstance(self.return_with_label, bool):
for i, v in enumerate(f):
_label_key = f"label_{i+1}" if not self.include_background else f"label_{i}"
_f[_label_key] = round(v.item(), 4)
else:
for key, v in zip(self.return_with_label, f):
_f[key] = round(v.item(), 4)
f = _f
return (f, not_nans) if self.get_not_nans else f


Expand Down
74 changes: 72 additions & 2 deletions tests/test_compute_meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,71 @@
[[0.0000, 0.0000], [0.0000, 0.0000]],
]

# test return_with_label
TEST_CASE_13 = [
{
"include_background": True,
"reduction": "mean_batch",
"get_not_nans": True,
"return_with_label": ["bg", "fg0", "fg1"],
},
{
"y_pred": torch.tensor(
[
[[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],
[[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],
]
),
"y": torch.tensor(
[
[[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
[[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],
]
),
},
{"bg": 0.6786, "fg0": 0.4000, "fg1": 0.6667},
]

# test return_with_label, include_background
TEST_CASE_14 = [
{"include_background": True, "reduction": "mean_batch", "get_not_nans": True, "return_with_label": True},
{
"y_pred": torch.tensor(
[
[[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],
[[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],
]
),
"y": torch.tensor(
[
[[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
[[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],
]
),
},
{"label_0": 0.6786, "label_1": 0.4000, "label_2": 0.6667},
]

# test return_with_label, not include_background
TEST_CASE_15 = [
{"include_background": False, "reduction": "mean_batch", "get_not_nans": True, "return_with_label": True},
{
"y_pred": torch.tensor(
[
[[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],
[[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],
]
),
"y": torch.tensor(
[
[[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
[[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],
]
),
},
{"label_1": 0.4000, "label_2": 0.6667},
]


class TestComputeMeanDice(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12])
Expand Down Expand Up @@ -223,12 +288,17 @@ def test_value_class(self, input_data, expected_value):
result = dice_metric.aggregate(reduction="none")
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

@parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8])
@parameterized.expand(
[TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15]
)
def test_nans_class(self, params, input_data, expected_value):
dice_metric = DiceMetric(**params)
dice_metric(**input_data)
result, _ = dice_metric.aggregate()
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
if isinstance(result, dict):
self.assertEqual(result, expected_value)
else:
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)


if __name__ == "__main__":
Expand Down