1 change: 1 addition & 0 deletions monai/handlers/__init__.py
@@ -12,6 +12,7 @@
from __future__ import annotations

from .average_precision import AveragePrecision
from .calibration import CalibrationError
from .checkpoint_loader import CheckpointLoader
from .checkpoint_saver import CheckpointSaver
from .classification_saver import ClassificationSaver
71 changes: 71 additions & 0 deletions monai/handlers/calibration.py
@@ -0,0 +1,71 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import CalibrationErrorMetric, CalibrationReduction
from monai.utils import MetricReduction

__all__ = ["CalibrationError"]


class CalibrationError(IgniteMetricHandler):
"""
Computes Calibration Error and reports the aggregated value according to `metric_reduction`
over all accumulated iterations. Can return the expected, average, or maximum calibration error.

Args:
num_bins: number of bins to calculate calibration. Defaults to 20.
include_background: whether to include calibration error computation on the first channel of
the predicted output. Defaults to True.
calibration_reduction: Method for calculating calibration error values from binned data.
Available modes are `"expected"`, `"average"`, and `"maximum"`. Defaults to `"expected"`.
metric_reduction: Mode of reduction to apply to the metrics.
Reduction is only applied to non-NaN values.
Available reduction modes are `"none"`, `"mean"`, `"sum"`, `"mean_batch"`,
`"sum_batch"`, `"mean_channel"`, and `"sum_channel"`.
Defaults to `"mean"`. If set to `"none"`, no reduction will be performed.
output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` and
construct the `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
lists of `channel-first` Tensors. The form `(y_pred, y)` is required by `update()`.
`engine.state` and `output_transform` inherit from the ignite concept:
https://pytorch.org/ignite/concepts.html#state; an explanation and usage example are in the tutorial:
https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
save_details: whether to save metric computation details per image, for example the calibration
error of every image. Defaults to `True`; details are saved to the `engine.state.metric_details`
dict with the metric name as key.

"""

def __init__(
self,
num_bins: int = 20,
include_background: bool = True,
calibration_reduction: CalibrationReduction | str = CalibrationReduction.EXPECTED,
metric_reduction: MetricReduction | str = MetricReduction.MEAN,
output_transform: Callable = lambda x: x,
save_details: bool = True,
) -> None:
metric_fn = CalibrationErrorMetric(
num_bins=num_bins,
include_background=include_background,
calibration_reduction=calibration_reduction,
metric_reduction=metric_reduction,
)

super().__init__(
metric_fn=metric_fn,
output_transform=output_transform,
save_details=save_details,
)
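For context, a minimal usage sketch (not part of this PR): attaching the handler to an Ignite evaluator whose `state.output` is assumed to be a dict holding probability predictions and one-hot labels under the illustrative keys `"pred"` and `"label"`, extracted with `monai.handlers.from_engine`:

from monai.handlers import CalibrationError, from_engine

calibration_handler = CalibrationError(
    num_bins=20,
    include_background=False,
    calibration_reduction="expected",
    output_transform=from_engine(["pred", "label"]),  # pulls (y_pred, y) from engine.state.output
)
calibration_handler.attach(evaluator, name="val_calibration_error")  # `evaluator` is an assumed ignite Engine
# after a run, the aggregated value appears in evaluator.state.metrics["val_calibration_error"]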
1 change: 1 addition & 0 deletions monai/metrics/__init__.py
@@ -13,6 +13,7 @@

from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score
from .average_precision import AveragePrecisionMetric, compute_average_precision
from .calibration import CalibrationErrorMetric, CalibrationReduction, calibration_binning
from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
from .cumulative_average import CumulativeAverage
from .f_beta_score import FBetaScore
260 changes: 260 additions & 0 deletions monai/metrics/calibration.py
@@ -0,0 +1,260 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any

import torch

from monai.metrics.metric import CumulativeIterationMetric
from monai.metrics.utils import do_metric_reduction, ignore_background
from monai.utils import MetricReduction
from monai.utils.enums import StrEnum

__all__ = [
"calibration_binning",
"CalibrationErrorMetric",
"CalibrationReduction",
]


def calibration_binning(
y_pred: torch.Tensor, y: torch.Tensor, num_bins: int = 20, right: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute calibration bins for predicted probabilities and ground truth labels.
This function calculates the mean predicted probability, the mean ground-truth value,
and the element count for each bin, using a hard-binning calibration approach.

The function operates on input and target tensors with batch and channel dimensions,
handling each batch and channel separately. For bins that do not contain any elements,
the mean predicted values and mean ground truth values are set to NaN.

Args:
y_pred: predicted tensor with shape [batch, channel, spatial], where spatial
can be any number of dimensions. Values are probabilities and should lie in the range [0, 1].
y: target tensor with the same shape as y_pred, representing ground truth values.
num_bins: The number of bins to use for calibration. Defaults to 20. Must be >= 1.
right: If False (default), the bins include the left boundary and exclude the right boundary.
If True, the bins exclude the left boundary and include the right boundary.

Returns:
A tuple of three tensors:
- mean_p_per_bin: Tensor of shape [batch_size, num_channels, num_bins] containing
the mean predicted values in each bin.
- mean_gt_per_bin: Tensor of shape [batch_size, num_channels, num_bins] containing
the mean ground truth values in each bin.
- bin_counts: Tensor of shape [batch_size, num_channels, num_bins] containing
the count of elements in each bin.

Raises:
ValueError: If the input and target shapes do not match, if the input has fewer than 3 dimensions,
or if num_bins < 1.

Note:
This function currently uses nested for loops over batch and channel dimensions
for binning operations. Future improvements may include vectorizing these operations
for enhanced performance.
"""
# Input validation
if y_pred.shape != y.shape:
raise ValueError(f"y_pred and y must have the same shape, got {y_pred.shape} and {y.shape}.")
if y_pred.ndim < 3:
raise ValueError(f"y_pred must have shape (B, C, spatial...), got ndim={y_pred.ndim}.")
if num_bins < 1:
raise ValueError(f"num_bins must be >= 1, got {num_bins}.")

batch_size, num_channels = y_pred.shape[:2]
boundaries = torch.linspace(
start=0.0,
end=1.0 + torch.finfo(torch.float32).eps,
steps=num_bins + 1,
device=y_pred.device,
)

mean_p_per_bin = torch.zeros(batch_size, num_channels, num_bins, device=y_pred.device)
mean_gt_per_bin = torch.zeros_like(mean_p_per_bin)
bin_counts = torch.zeros_like(mean_p_per_bin)

y_pred_flat = y_pred.flatten(start_dim=2).float()
y_flat = y.flatten(start_dim=2).float()

for b in range(batch_size):
for c in range(num_channels):
values_p = y_pred_flat[b, c, :]
values_gt = y_flat[b, c, :]

# Compute bin indices and clamp to valid range to handle out-of-range values
bin_idx = torch.bucketize(values_p, boundaries[1:], right=right)
bin_idx = bin_idx.clamp(max=num_bins - 1)

# Compute bin counts using scatter_add
counts = torch.zeros(num_bins, device=y_pred.device, dtype=torch.float32)
counts.scatter_add_(0, bin_idx, torch.ones_like(values_p))
bin_counts[b, c, :] = counts

# Compute sums for mean calculation using scatter_add (more compatible than scatter_reduce)
sum_p = torch.zeros(num_bins, device=y_pred.device, dtype=torch.float32)
sum_p.scatter_add_(0, bin_idx, values_p)

sum_gt = torch.zeros(num_bins, device=y_pred.device, dtype=torch.float32)
sum_gt.scatter_add_(0, bin_idx, values_gt)

# Compute means, avoiding division by zero
safe_counts = counts.clamp(min=1)
mean_p_per_bin[b, c, :] = sum_p / safe_counts
mean_gt_per_bin[b, c, :] = sum_gt / safe_counts

# Set empty bins to NaN
mean_p_per_bin[bin_counts == 0] = torch.nan
mean_gt_per_bin[bin_counts == 0] = torch.nan

return mean_p_per_bin, mean_gt_per_bin, bin_counts
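A small worked example (illustrative, not part of the PR) of what `calibration_binning` returns. With `num_bins=2` the bins split at 0.5, so two predictions land in each bin:

import torch

from monai.metrics import calibration_binning

y_pred = torch.tensor([[[0.1, 0.4, 0.6, 0.9]]])  # shape (batch=1, channel=1, spatial=4)
y = torch.tensor([[[0.0, 0.0, 1.0, 1.0]]])
mean_p, mean_gt, counts = calibration_binning(y_pred, y, num_bins=2)
# counts  -> tensor([[[2., 2.]]])      two elements per bin
# mean_p  -> tensor([[[0.25, 0.75]]])  mean prediction in each bin
# mean_gt -> tensor([[[0., 1.]]])      mean ground truth in each bin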


class CalibrationReduction(StrEnum):
"""
Enumeration of calibration error reduction methods.

- EXPECTED: Expected Calibration Error (ECE) - weighted average by bin count
- AVERAGE: Average Calibration Error (ACE) - simple average across bins
- MAXIMUM: Maximum Calibration Error (MCE) - maximum error across bins
"""

EXPECTED = "expected"
AVERAGE = "average"
MAXIMUM = "maximum"
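For reference, writing $n_k$ for the count, $\bar{p}_k$ for the mean prediction, and $\bar{y}_k$ for the mean ground truth in bin $k$, with $N = \sum_k n_k$ and the average/maximum taken over non-empty bins only (matching the NaN handling in the metric below):

$$\mathrm{ECE} = \sum_{k} \frac{n_k}{N} \left| \bar{p}_k - \bar{y}_k \right|, \qquad
\mathrm{ACE} = \frac{1}{|\{k : n_k > 0\}|} \sum_{k : n_k > 0} \left| \bar{p}_k - \bar{y}_k \right|, \qquad
\mathrm{MCE} = \max_{k : n_k > 0} \left| \bar{p}_k - \bar{y}_k \right|$$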


class CalibrationErrorMetric(CumulativeIterationMetric):
"""
Compute the Calibration Error between predicted probabilities and ground truth labels.
This metric is suitable for multi-class tasks and supports batched inputs.

The input `y_pred` represents the model's predicted probabilities, and `y` represents the ground truth labels.
`y_pred` is expected to have probabilities, and `y` should be in one-hot format. You can use suitable transforms
in `monai.transforms.post` to achieve the desired format.

The `include_background` parameter can be set to `False` to exclude the first category (channel index 0),
which is conventionally assumed to be the background. This is particularly useful in segmentation tasks where
the background class might skew the calibration results.

The metric supports both single-channel and multi-channel data. For multi-channel data, the input tensors
should be in the format of BCHW[D], where B is the batch size, C is the number of channels, and HW[D]
are the spatial dimensions.

Args:
num_bins: Number of bins to divide probabilities into for calibration calculation. Defaults to 20.
include_background: Whether to include computation on the first channel of the predicted output.
Defaults to `True`.
calibration_reduction: Method for calculating calibration error values from binned data.
Available modes are `"expected"`, `"average"`, and `"maximum"`. Defaults to `"expected"`.
metric_reduction: Mode of reduction to apply to the metrics.
Reduction is only applied to non-NaN values.
Available reduction modes are `"none"`, `"mean"`, `"sum"`, `"mean_batch"`,
`"sum_batch"`, `"mean_channel"`, and `"sum_channel"`.
Defaults to `"mean"`. If set to `"none"`, no reduction will be performed.
get_not_nans: Whether to return the count of non-NaN values.
If `True`, `aggregate()` returns a tuple (metric, not_nans). Defaults to `False`.
right: Whether to use the right or left bin edge for binning. Defaults to `False` (left).

The typical execution steps of this metric class follow :py:class:`monai.metrics.metric.Cumulative`.

Example:
>>> from monai.transforms import Activations, AsDiscrete
>>> # Transforms to convert model outputs to probabilities and labels to one-hot
>>> softmax = Activations(softmax=True) # or sigmoid=True for binary/multi-label
>>> to_onehot = AsDiscrete(to_onehot=num_classes)
>>> metric = CalibrationErrorMetric(num_bins=15, include_background=False, calibration_reduction="expected")
>>> for images, labels in dataloader:
...     logits = model(images)
...     preds = softmax(logits)  # convert logits to probabilities
...     labels_onehot = to_onehot(labels)  # convert labels to one-hot format
...     metric(y_pred=preds, y=labels_onehot)
>>> ece = metric.aggregate()
"""

def __init__(
self,
num_bins: int = 20,
include_background: bool = True,
calibration_reduction: CalibrationReduction | str = CalibrationReduction.EXPECTED,
metric_reduction: MetricReduction | str = MetricReduction.MEAN,
get_not_nans: bool = False,
right: bool = False,
) -> None:
super().__init__()
self.num_bins = num_bins
self.include_background = include_background
self.calibration_reduction = CalibrationReduction(calibration_reduction)
self.metric_reduction = metric_reduction
self.get_not_nans = get_not_nans
self.right = right

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override]
"""
Compute calibration error for the given predictions and ground truth.

Args:
y_pred: input data to compute. It should be in the format of (batch, channel, spatial...).
It represents probability predictions of the model.
y: ground truth in one-hot format. It should be in the format of (batch, channel, spatial...).
The values should be binarized.
**kwargs: additional keyword arguments (unused, for API compatibility).

Returns:
Calibration error tensor with shape (batch, channel).
"""
if not self.include_background:
y_pred, y = ignore_background(y_pred=y_pred, y=y)

mean_p_per_bin, mean_gt_per_bin, bin_counts = calibration_binning(
y_pred=y_pred, y=y, num_bins=self.num_bins, right=self.right
)

# Calculate the absolute differences, ignoring nan values
abs_diff = torch.abs(mean_p_per_bin - mean_gt_per_bin)

if self.calibration_reduction == CalibrationReduction.EXPECTED:
# Calculate the weighted sum of absolute differences
return torch.nansum(abs_diff * bin_counts, dim=-1) / torch.sum(bin_counts, dim=-1)
elif self.calibration_reduction == CalibrationReduction.AVERAGE:
return torch.nanmean(abs_diff, dim=-1)  # Average across bins, ignoring NaN
elif self.calibration_reduction == CalibrationReduction.MAXIMUM:
abs_diff_no_nan = torch.nan_to_num(abs_diff, nan=0.0)
return torch.max(abs_diff_no_nan, dim=-1).values  # Maximum across bins
Comment on lines +228 to +235

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

find . -type f -name "calibration.py" | grep metrics

Repository: Project-MONAI/MONAI

Length of output: 93


🏁 Script executed:

cat -n monai/metrics/calibration.py | head -250 | tail -50

Repository: Project-MONAI/MONAI

Length of output: 2965


🏁 Script executed:

rg "CalibrationReduction" --type py -B 2 -A 2

Repository: Project-MONAI/MONAI

Length of output: 6926


🏁 Script executed:

fd "test.*calibration" --type f

Repository: Project-MONAI/MONAI

Length of output: 152


🏁 Script executed:

cat tests/metrics/test_calibration_metric.py

Repository: Project-MONAI/MONAI

Length of output: 12660


🏁 Script executed:

rg "calibration_binning" -B 5 -A 10 monai/metrics/calibration.py | head -80

Repository: Project-MONAI/MONAI

Length of output: 1757


🏁 Script executed:

python3 << 'EOF'
import torch

# Test torch.max behavior with all NaN tensor
all_nan = torch.tensor([[float('nan'), float('nan')], [float('nan'), float('nan')]])
print("All NaN tensor:", all_nan)
print("torch.max(all_nan, dim=-1):", torch.max(all_nan, dim=-1))

# Test nan_to_num with nan=0.0
nan_to_zero = torch.nan_to_num(all_nan, nan=0.0)
print("\nnan_to_num with nan=0.0:", nan_to_zero)
print("torch.max(nan_to_zero, dim=-1):", torch.max(nan_to_zero, dim=-1))

# Test nan_to_num with nan=-inf
nan_to_neginf = torch.nan_to_num(all_nan, nan=-float('inf'))
print("\nnan_to_num with nan=-inf:", nan_to_neginf)
print("torch.max(nan_to_neginf, dim=-1):", torch.max(nan_to_neginf, dim=-1))

# Test torch.nanmax
print("\ntorch.nanmax(all_nan, dim=-1):", torch.nanmax(all_nan, dim=-1))

# Test comparison for restoration
max_vals = torch.max(nan_to_neginf, dim=-1).values
print("\nmax_vals after max(nan_to_neginf):", max_vals)
print("max_vals == -inf:", max_vals == -float('inf'))
max_vals[max_vals == -float('inf')] = float('nan')
print("After restoring NaN:", max_vals)

EOF

Repository: Project-MONAI/MONAI

Length of output: 149


🏁 Script executed:

python3 << 'EOF'
import torch

# Test 1: torch.max behavior with all NaN tensor
print("=== Test 1: All NaN tensor ===")
all_nan = torch.tensor([[float('nan'), float('nan')], [float('nan'), float('nan')]])
print("Input:", all_nan)
result = torch.max(all_nan, dim=-1)
print("torch.max(all_nan, dim=-1).values:", result.values)
print()

# Test 2: nan_to_num with nan=0.0
print("=== Test 2: nan_to_num with nan=0.0 ===")
nan_to_zero = torch.nan_to_num(all_nan, nan=0.0)
print("After nan_to_num(nan=0.0):", nan_to_zero)
print("torch.max result:", torch.max(nan_to_zero, dim=-1).values)
print()

# Test 3: nan_to_num with nan=-inf
print("=== Test 3: nan_to_num with nan=-inf ===")
nan_to_neginf = torch.nan_to_num(all_nan, nan=-float('inf'))
print("After nan_to_num(nan=-inf):", nan_to_neginf)
max_result = torch.max(nan_to_neginf, dim=-1).values
print("torch.max result:", max_result)
print()

# Test 4: Restore NaN after max
print("=== Test 4: Restore NaN ===")
max_vals = max_result.clone()
print("Original max_vals:", max_vals)
max_vals[max_vals == -float('inf')] = float('nan')
print("After restoring NaN:", max_vals)
print()

# Test 5: Check if torch.nanmax exists and works
print("=== Test 5: torch.nanmax ===")
try:
    result = torch.nanmax(all_nan, dim=-1)
    print("torch.nanmax(all_nan, dim=-1):", result)
except Exception as e:
    print("Error:", e)

EOF

Repository: Project-MONAI/MONAI

Length of output: 180


🏁 Script executed:

rg "all.*nan|empty.*bin|nan.*maximum" tests/metrics/test_calibration_metric.py -i

Repository: Project-MONAI/MONAI

Length of output: 45


🏁 Script executed:

sed -n '1,150p' monai/metrics/calibration.py | tail -80

Repository: Project-MONAI/MONAI

Length of output: 3534


Use sentinel value to preserve NaN when all bins are empty in MAXIMUM reduction

When all bins are empty (all NaN), nan_to_num(..., nan=0.0) converts NaN values to 0, causing MCE to return 0 instead of NaN. This misrepresents "no valid data" as "zero error". Replace with -inf as sentinel and restore NaN for all-NaN cases:

Suggested fix
        elif self.calibration_reduction == CalibrationReduction.MAXIMUM:
-            abs_diff_no_nan = torch.nan_to_num(abs_diff, nan=0.0)
-            return torch.max(abs_diff_no_nan, dim=-1).values  # Maximum across bins
+            abs_diff_no_nan = torch.nan_to_num(abs_diff, nan=-torch.inf)
+            max_vals = torch.max(abs_diff_no_nan, dim=-1).values
+            max_vals[max_vals == -torch.inf] = torch.nan
+            return max_vals  # Maximum across valid bins

Additionally, add a test case for the all-empty-bins edge case to prevent regression.
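A sketch of such a test (hypothetical; it assumes the suggested fix above is applied). An empty spatial dimension leaves every bin empty, so `abs_diff` is all NaN and the MAXIMUM reduction should report NaN rather than zero error:

import unittest

import torch

from monai.metrics import CalibrationErrorMetric


class TestMaximumReductionEmptyBins(unittest.TestCase):
    def test_all_empty_bins_return_nan(self):
        # zero spatial elements -> every bin is empty -> all-NaN abs_diff
        y_pred = torch.zeros(1, 1, 0)
        y = torch.zeros(1, 1, 0)
        metric = CalibrationErrorMetric(
            num_bins=4, calibration_reduction="maximum", metric_reduction="none"
        )
        metric(y_pred=y_pred, y=y)
        # "no valid data" should surface as NaN, not 0.0
        self.assertTrue(torch.isnan(metric.aggregate()).all())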

🤖 Prompt for AI Agents
In `@monai/metrics/calibration.py` around lines 228 - 235: in the
CalibrationReduction.MAXIMUM branch, don't convert NaN to 0 (which hides "no
data"); instead pass a -inf sentinel to torch.nan_to_num on abs_diff
(e.g. nan=-torch.inf), take the max along dim=-1, then detect positions that were
all-NaN (e.g. all_nan_mask = torch.isnan(abs_diff).all(dim=-1)) and restore
those positions in the result to NaN. Update the MAXIMUM branch that uses
abs_diff_no_nan accordingly, and add a unit test covering the "all bins empty"
case to prevent regressions.

else:
raise ValueError(f"Unsupported calibration reduction: {self.calibration_reduction}")

def aggregate(
self, reduction: MetricReduction | str | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Execute reduction logic for the output of `_compute_tensor`.

Args:
reduction: mode of reduction to apply to the metrics; reduction is only applied to non-NaN values.
Available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
``"mean_channel"``, ``"sum_channel"``}. Defaults to `self.metric_reduction`. If `"none"`, no
reduction is performed.

Returns:
If `get_not_nans` is True, returns a tuple (metric, not_nans), otherwise returns only the metric.
"""
data = self.get_buffer()
if not isinstance(data, torch.Tensor):
raise ValueError("the data to aggregate must be a PyTorch Tensor.")

# do metric reduction
f, not_nans = do_metric_reduction(data, reduction or self.metric_reduction)
return (f, not_nans) if self.get_not_nans else f
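To tie the pieces together, a quick end-to-end check (illustrative only) reusing the toy tensors from the binning example above; by hand, ECE = (2/4)*|0.25 - 0| + (2/4)*|0.75 - 1| = 0.25:

import torch

from monai.metrics import CalibrationErrorMetric

y_pred = torch.tensor([[[0.1, 0.4, 0.6, 0.9]]])
y = torch.tensor([[[0.0, 0.0, 1.0, 1.0]]])
metric = CalibrationErrorMetric(num_bins=2, calibration_reduction="expected", metric_reduction="none")
metric(y_pred=y_pred, y=y)
print(metric.aggregate())  # tensor([[0.2500]])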