diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index ed5db8a7f3..7fc7b6df57 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -12,6 +12,7 @@ from __future__ import annotations from .average_precision import AveragePrecision +from .calibration import CalibrationError from .checkpoint_loader import CheckpointLoader from .checkpoint_saver import CheckpointSaver from .classification_saver import ClassificationSaver diff --git a/monai/handlers/calibration.py b/monai/handlers/calibration.py new file mode 100644 index 0000000000..afc4f45a50 --- /dev/null +++ b/monai/handlers/calibration.py @@ -0,0 +1,71 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable + +from monai.handlers.ignite_metric import IgniteMetricHandler +from monai.metrics import CalibrationErrorMetric, CalibrationReduction +from monai.utils import MetricReduction + +__all__ = ["CalibrationError"] + + +class CalibrationError(IgniteMetricHandler): + """ + Computes Calibration Error and reports the aggregated value according to `metric_reduction` + over all accumulated iterations. Can return the expected, average, or maximum calibration error. + + Args: + num_bins: number of bins to calculate calibration. Defaults to 20. + include_background: whether to include calibration error computation on the first channel of + the predicted output. Defaults to True. + calibration_reduction: Method for calculating calibration error values from binned data. + Available modes are `"expected"`, `"average"`, and `"maximum"`. Defaults to `"expected"`. + metric_reduction: Mode of reduction to apply to the metrics. + Reduction is only applied to non-NaN values. + Available reduction modes are `"none"`, `"mean"`, `"sum"`, `"mean_batch"`, + `"sum_batch"`, `"mean_channel"`, and `"sum_channel"`. + Defaults to `"mean"`. If set to `"none"`, no reduction will be performed. + output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then + construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or + lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. + save_details: whether to save metric computation details per image, for example: calibration error + of every image. default to True, will save to `engine.state.metric_details` dict with the + metric name as key. 
+ + """ + + def __init__( + self, + num_bins: int = 20, + include_background: bool = True, + calibration_reduction: CalibrationReduction | str = CalibrationReduction.EXPECTED, + metric_reduction: MetricReduction | str = MetricReduction.MEAN, + output_transform: Callable = lambda x: x, + save_details: bool = True, + ) -> None: + metric_fn = CalibrationErrorMetric( + num_bins=num_bins, + include_background=include_background, + calibration_reduction=calibration_reduction, + metric_reduction=metric_reduction, + ) + + super().__init__( + metric_fn=metric_fn, + output_transform=output_transform, + save_details=save_details, + ) diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index ae20903cfd..0da25feca9 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -13,6 +13,7 @@ from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score from .average_precision import AveragePrecisionMetric, compute_average_precision +from .calibration import CalibrationErrorMetric, CalibrationReduction, calibration_binning from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix from .cumulative_average import CumulativeAverage from .f_beta_score import FBetaScore diff --git a/monai/metrics/calibration.py b/monai/metrics/calibration.py new file mode 100644 index 0000000000..8d7b5729b9 --- /dev/null +++ b/monai/metrics/calibration.py @@ -0,0 +1,260 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any + +import torch + +from monai.metrics.metric import CumulativeIterationMetric +from monai.metrics.utils import do_metric_reduction, ignore_background +from monai.utils import MetricReduction +from monai.utils.enums import StrEnum + +__all__ = [ + "calibration_binning", + "CalibrationErrorMetric", + "CalibrationReduction", +] + + +def calibration_binning( + y_pred: torch.Tensor, y: torch.Tensor, num_bins: int = 20, right: bool = False +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute calibration bins for predicted probabilities and ground truth labels. + This function calculates the mean predicted probabilities, mean ground truths, + and bin counts for each bin using a hard binning calibration approach. + + The function operates on input and target tensors with batch and channel dimensions, + handling each batch and channel separately. For bins that do not contain any elements, + the mean predicted values and mean ground truth values are set to NaN. + + Args: + y_pred: predicted tensor with shape [batch, channel, spatial], where spatial + can be any number of dimensions. The y_pred tensor represents probabilities. + Values should be in the range [0, 1] (probabilities). + y: Target tensor with the same shape as y_pred. It represents ground truth values. + num_bins: The number of bins to use for calibration. Defaults to 20. Must be >= 1. 
+ right: If False (default), the bins include the left boundary and exclude the right boundary. + If True, the bins exclude the left boundary and include the right boundary. + + Returns: + A tuple of three tensors: + - mean_p_per_bin: Tensor of shape [batch_size, num_channels, num_bins] containing + the mean predicted values in each bin. + - mean_gt_per_bin: Tensor of shape [batch_size, num_channels, num_bins] containing + the mean ground truth values in each bin. + - bin_counts: Tensor of shape [batch_size, num_channels, num_bins] containing + the count of elements in each bin. + + Raises: + ValueError: If the input and target shapes do not match, if the input has fewer than 3 dimensions, + or if num_bins < 1. + + Note: + This function currently uses nested for loops over batch and channel dimensions + for binning operations. Future improvements may include vectorizing these operations + for enhanced performance. + """ + # Input validation + if y_pred.shape != y.shape: + raise ValueError(f"y_pred and y must have the same shape, got {y_pred.shape} and {y.shape}.") + if y_pred.ndim < 3: + raise ValueError(f"y_pred must have shape (B, C, spatial...), got ndim={y_pred.ndim}.") + if num_bins < 1: + raise ValueError(f"num_bins must be >= 1, got {num_bins}.") + + batch_size, num_channels = y_pred.shape[:2] + boundaries = torch.linspace( + start=0.0, + end=1.0 + torch.finfo(torch.float32).eps, + steps=num_bins + 1, + device=y_pred.device, + ) + + mean_p_per_bin = torch.zeros(batch_size, num_channels, num_bins, device=y_pred.device) + mean_gt_per_bin = torch.zeros_like(mean_p_per_bin) + bin_counts = torch.zeros_like(mean_p_per_bin) + + y_pred_flat = y_pred.flatten(start_dim=2).float() + y_flat = y.flatten(start_dim=2).float() + + for b in range(batch_size): + for c in range(num_channels): + values_p = y_pred_flat[b, c, :] + values_gt = y_flat[b, c, :] + + # Compute bin indices and clamp to valid range to handle out-of-range values + bin_idx = torch.bucketize(values_p, boundaries[1:], right=right) + bin_idx = bin_idx.clamp(max=num_bins - 1) + + # Compute bin counts using scatter_add + counts = torch.zeros(num_bins, device=y_pred.device, dtype=torch.float32) + counts.scatter_add_(0, bin_idx, torch.ones_like(values_p)) + bin_counts[b, c, :] = counts + + # Compute sums for mean calculation using scatter_add (more compatible than scatter_reduce) + sum_p = torch.zeros(num_bins, device=y_pred.device, dtype=torch.float32) + sum_p.scatter_add_(0, bin_idx, values_p) + + sum_gt = torch.zeros(num_bins, device=y_pred.device, dtype=torch.float32) + sum_gt.scatter_add_(0, bin_idx, values_gt) + + # Compute means, avoiding division by zero + safe_counts = counts.clamp(min=1) + mean_p_per_bin[b, c, :] = sum_p / safe_counts + mean_gt_per_bin[b, c, :] = sum_gt / safe_counts + + # Set empty bins to NaN + mean_p_per_bin[bin_counts == 0] = torch.nan + mean_gt_per_bin[bin_counts == 0] = torch.nan + + return mean_p_per_bin, mean_gt_per_bin, bin_counts + + +class CalibrationReduction(StrEnum): + """ + Enumeration of calibration error reduction methods. + + - EXPECTED: Expected Calibration Error (ECE) - weighted average by bin count + - AVERAGE: Average Calibration Error (ACE) - simple average across bins + - MAXIMUM: Maximum Calibration Error (MCE) - maximum error across bins + """ + + EXPECTED = "expected" + AVERAGE = "average" + MAXIMUM = "maximum" + + +class CalibrationErrorMetric(CumulativeIterationMetric): + """ + Compute the Calibration Error between predicted probabilities and ground truth labels. 
+ This metric is suitable for multi-class tasks and supports batched inputs. + + The input `y_pred` represents the model's predicted probabilities, and `y` represents the ground truth labels. + `y_pred` is expected to have probabilities, and `y` should be in one-hot format. You can use suitable transforms + in `monai.transforms.post` to achieve the desired format. + + The `include_background` parameter can be set to `False` to exclude the first category (channel index 0), + which is conventionally assumed to be the background. This is particularly useful in segmentation tasks where + the background class might skew the calibration results. + + The metric supports both single-channel and multi-channel data. For multi-channel data, the input tensors + should be in the format of BCHW[D], where B is the batch size, C is the number of channels, and HW[D] + are the spatial dimensions. + + Args: + num_bins: Number of bins to divide probabilities into for calibration calculation. Defaults to 20. + include_background: Whether to include computation on the first channel of the predicted output. + Defaults to `True`. + calibration_reduction: Method for calculating calibration error values from binned data. + Available modes are `"expected"`, `"average"`, and `"maximum"`. Defaults to `"expected"`. + metric_reduction: Mode of reduction to apply to the metrics. + Reduction is only applied to non-NaN values. + Available reduction modes are `"none"`, `"mean"`, `"sum"`, `"mean_batch"`, + `"sum_batch"`, `"mean_channel"`, and `"sum_channel"`. + Defaults to `"mean"`. If set to `"none"`, no reduction will be performed. + get_not_nans: Whether to return the count of non-NaN values. + If `True`, `aggregate()` returns a tuple (metric, not_nans). Defaults to `False`. + right: Whether to use the right or left bin edge for binning. Defaults to `False` (left). + + Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. + + Example: + >>> from monai.transforms import Activations, AsDiscrete + >>> # Transforms to convert model outputs to probabilities and labels to one-hot + >>> softmax = Activations(softmax=True) # or sigmoid=True for binary/multi-label + >>> to_onehot = AsDiscrete(to_onehot=num_classes) + >>> metric = CalibrationErrorMetric(num_bins=15, include_background=False, calibration_reduction="expected") + >>> for batch_data in dataloader: + >>> logits, labels = model(batch_data) + >>> preds = softmax(logits) # convert logits to probabilities + >>> labels_onehot = to_onehot(labels) # convert labels to one-hot format + >>> metric(y_pred=preds, y=labels_onehot) + >>> ece = metric.aggregate() + """ + + def __init__( + self, + num_bins: int = 20, + include_background: bool = True, + calibration_reduction: CalibrationReduction | str = CalibrationReduction.EXPECTED, + metric_reduction: MetricReduction | str = MetricReduction.MEAN, + get_not_nans: bool = False, + right: bool = False, + ) -> None: + super().__init__() + self.num_bins = num_bins + self.include_background = include_background + self.calibration_reduction = CalibrationReduction(calibration_reduction) + self.metric_reduction = metric_reduction + self.get_not_nans = get_not_nans + self.right = right + + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override] + """ + Compute calibration error for the given predictions and ground truth. + + Args: + y_pred: input data to compute. 
It should be in the format of (batch, channel, spatial...).
+                It represents probability predictions of the model.
+            y: ground truth in one-hot format. It should be in the format of (batch, channel, spatial...).
+                The values should be binarized.
+            **kwargs: additional keyword arguments (unused, for API compatibility).
+
+        Returns:
+            Calibration error tensor with shape (batch, channel).
+        """
+        if not self.include_background:
+            y_pred, y = ignore_background(y_pred=y_pred, y=y)
+
+        mean_p_per_bin, mean_gt_per_bin, bin_counts = calibration_binning(
+            y_pred=y_pred, y=y, num_bins=self.num_bins, right=self.right
+        )
+
+        # Calculate the per-bin absolute differences; empty bins are NaN
+        abs_diff = torch.abs(mean_p_per_bin - mean_gt_per_bin)
+
+        if self.calibration_reduction == CalibrationReduction.EXPECTED:
+            # Weighted average of per-bin differences, weighted by bin counts
+            return torch.nansum(abs_diff * bin_counts, dim=-1) / torch.sum(bin_counts, dim=-1)
+        elif self.calibration_reduction == CalibrationReduction.AVERAGE:
+            return torch.nanmean(abs_diff, dim=-1)  # average over bins, ignoring NaN (empty) bins
+        elif self.calibration_reduction == CalibrationReduction.MAXIMUM:
+            abs_diff_no_nan = torch.nan_to_num(abs_diff, nan=0.0)
+            return torch.max(abs_diff_no_nan, dim=-1).values  # maximum over bins
+        else:
+            raise ValueError(f"Unsupported calibration reduction: {self.calibration_reduction}")
+
+    def aggregate(
+        self, reduction: MetricReduction | str | None = None
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        """
+        Execute reduction logic for the output of `_compute_tensor`.
+
+        Args:
+            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
+                available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
+                ``"mean_channel"``, ``"sum_channel"``}, default to `self.metric_reduction`. if "none", will not
+                do reduction.
+
+        Returns:
+            If `get_not_nans` is True, returns a tuple (metric, not_nans), otherwise returns only the metric.
+        """
+        data = self.get_buffer()
+        if not isinstance(data, torch.Tensor):
+            raise ValueError("the data to aggregate must be PyTorch Tensor.")
+
+        # do metric reduction
+        f, not_nans = do_metric_reduction(data, reduction or self.metric_reduction)
+        return (f, not_nans) if self.get_not_nans else f
diff --git a/tests/handlers/test_handler_calibration_error.py b/tests/handlers/test_handler_calibration_error.py
new file mode 100644
index 0000000000..5cc0c2609c
--- /dev/null
+++ b/tests/handlers/test_handler_calibration_error.py
@@ -0,0 +1,184 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
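+#
+# Note on expected values: the handler aggregates the per-channel expected calibration
+# errors of the shared test tensors, which are [[0.2000, 0.3500], [0.2000, 0.1500]]
+# (see tests/metrics/test_calibration_metric.py). TEST_CASE_1 therefore expects their
+# mean, 0.2250, and TEST_CASE_2 (background excluded) expects the mean of the second
+# channel only, 0.2500.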
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.handlers import CalibrationError, from_engine +from monai.utils import IgniteInfo, min_version, optional_import +from tests.test_utils import assert_allclose + +Engine, has_ignite = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + +_device = "cuda:0" if torch.cuda.is_available() else "cpu" + +# Test cases for handler +# Format: [input_params, expected_value, expected_rows, expected_channels] +TEST_CASE_1 = [ + { + "num_bins": 5, + "include_background": True, + "calibration_reduction": "expected", + "metric_reduction": "mean", + "output_transform": from_engine(["pred", "label"]), + }, + 0.2250, + 4, # 2 batches * 2 iterations + 2, # 2 channels +] + +TEST_CASE_2 = [ + { + "num_bins": 5, + "include_background": False, + "calibration_reduction": "expected", + "metric_reduction": "mean", + "output_transform": from_engine(["pred", "label"]), + }, + 0.2500, + 4, # 2 batches * 2 iterations + 1, # 1 channel (background excluded) +] + +TEST_CASE_3 = [ + { + "num_bins": 5, + "include_background": True, + "calibration_reduction": "average", + "metric_reduction": "mean", + "output_transform": from_engine(["pred", "label"]), + }, + 0.2584, # Mean of [[0.2000, 0.4667], [0.2000, 0.1667]] + 4, + 2, +] + +TEST_CASE_4 = [ + { + "num_bins": 5, + "include_background": True, + "calibration_reduction": "maximum", + "metric_reduction": "mean", + "output_transform": from_engine(["pred", "label"]), + }, + 0.4000, # Mean of [[0.3000, 0.7000], [0.3000, 0.3000]] + 4, + 2, +] + + +@unittest.skipUnless(has_ignite, "Requires pytorch-ignite") +class TestHandlerCalibrationError(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + def test_compute(self, input_params, expected_value, expected_rows, expected_channels): + calibration_metric = CalibrationError(**input_params) + + # Test data: 2 batches with 2 channels each + y_pred = torch.tensor( + [ + [[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], + [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]], + ] + ).to(_device) + y = torch.tensor( + [ + [[[1, 0], [0, 1]], [[0, 1], [1, 0]]], + [[[1, 1], [0, 0]], [[0, 0], [1, 1]]], + ] + ).to(_device) + + # Create data as list of batches (2 iterations) + data = [{"pred": y_pred, "label": y}, {"pred": y_pred, "label": y}] + + def _val_func(engine, batch): + return batch + + engine = Engine(_val_func) + calibration_metric.attach(engine=engine, name="calibration_error") + + engine.run(data, max_epochs=1) + + assert_allclose( + engine.state.metrics["calibration_error"], expected_value, atol=1e-4, rtol=1e-4, type_test=False + ) + + # Check details shape using invariants rather than exact tuple + details = engine.state.metric_details["calibration_error"] + self.assertEqual(details.shape[0], expected_rows) + self.assertEqual(details.shape[-1], expected_channels) + + +@unittest.skipUnless(has_ignite, "Requires pytorch-ignite") +class TestHandlerCalibrationErrorEdgeCases(unittest.TestCase): + + def test_single_iteration(self): + """Test handler with single iteration.""" + calibration_metric = CalibrationError( + num_bins=5, + include_background=True, + calibration_reduction="expected", + metric_reduction="mean", + output_transform=from_engine(["pred", "label"]), + ) + + y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device) + + data = [{"pred": y_pred, "label": 
y}] + + def _val_func(engine, batch): + return batch + + engine = Engine(_val_func) + calibration_metric.attach(engine=engine, name="calibration_error") + + engine.run(data, max_epochs=1) + + assert_allclose(engine.state.metrics["calibration_error"], 0.2, atol=1e-4, rtol=1e-4, type_test=False) + + def test_save_details_false(self): + """Test handler with save_details=False.""" + calibration_metric = CalibrationError( + num_bins=5, + include_background=True, + calibration_reduction="expected", + metric_reduction="mean", + output_transform=from_engine(["pred", "label"]), + save_details=False, + ) + + y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device) + + data = [{"pred": y_pred, "label": y}] + + def _val_func(engine, batch): + return batch + + engine = Engine(_val_func) + calibration_metric.attach(engine=engine, name="calibration_error") + + engine.run(data, max_epochs=1) + + assert_allclose(engine.state.metrics["calibration_error"], 0.2, atol=1e-4, rtol=1e-4, type_test=False) + + # When save_details=False, metric_details should not exist or should not have the metric key + if hasattr(engine.state, "metric_details"): + self.assertNotIn("calibration_error", engine.state.metric_details or {}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/metrics/test_calibration_metric.py b/tests/metrics/test_calibration_metric.py new file mode 100644 index 0000000000..f220525793 --- /dev/null +++ b/tests/metrics/test_calibration_metric.py @@ -0,0 +1,357 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
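+#
+# Worked example behind the smallest cases below (5 bins of width 0.2):
+#   y_pred = [0.7, 0.3, 0.1, 0.9], y = [1, 0, 0, 1]
+#   per-bin mean_p  = [0.1, 0.3, nan, 0.7, 0.9]
+#   per-bin mean_gt = [0.0, 0.0, nan, 1.0, 1.0]
+#   per-bin counts  = [1, 1, 0, 1, 1]
+#   |mean_p - mean_gt| = [0.1, 0.3, nan, 0.3, 0.1]
+#   ECE = (0.1 + 0.3 + 0.3 + 0.1) / 4 = 0.2, ACE = nanmean = 0.2, MCE = max = 0.3
+# These values match TEST_BINNING_SMALL_MID and TEST_VALUE_1B1C.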
+ +from __future__ import annotations + +import unittest +from unittest import mock + +import torch +from parameterized import parameterized + +from monai.metrics import CalibrationErrorMetric, CalibrationReduction, calibration_binning +from monai.utils import MetricReduction +from tests.test_utils import assert_allclose + +_device = "cuda:0" if torch.cuda.is_available() else "cpu" + +# Test cases for calibration binning +# Format: [name, y_pred, y, num_bins, right, expected_mean_p, expected_mean_gt, expected_counts] +TEST_BINNING_SMALL_MID = [ + "small_mid", + torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]), + torch.tensor([[[[1, 0], [0, 1]]]]), + 5, + False, + torch.tensor([[[0.1, 0.3, float("nan"), 0.7, 0.9]]]), + torch.tensor([[[0.0, 0.0, float("nan"), 1.0, 1.0]]]), + torch.tensor([[[1.0, 1.0, 0.0, 1.0, 1.0]]]), +] + +TEST_BINNING_LARGE_MID = [ + "large_mid", + torch.tensor( + [ + [[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], + [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]], + ] + ), + torch.tensor( + [ + [[[1, 0], [0, 1]], [[0, 1], [1, 0]]], + [[[1, 1], [0, 0]], [[0, 0], [1, 1]]], + ] + ), + 5, + False, + torch.tensor( + [ + [[0.1, 0.3, float("nan"), 0.7, 0.9], [float("nan"), 0.3, 0.5, 0.7, float("nan")]], + [[float("nan"), 0.3, float("nan"), float("nan"), 0.9], [0.1, float("nan"), float("nan"), 0.7, 0.9]], + ] + ), + torch.tensor( + [ + [[0.0, 0.0, float("nan"), 1.0, 1.0], [float("nan"), 1.0, 0.5, 0.0, float("nan")]], + [[float("nan"), 0.0, float("nan"), float("nan"), 1.0], [0.0, float("nan"), float("nan"), 1.0, 1.0]], + ] + ), + torch.tensor( + [ + [[1.0, 1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 2.0, 1.0, 0.0]], + [[0.0, 2.0, 0.0, 0.0, 2.0], [2.0, 0.0, 0.0, 1.0, 1.0]], + ] + ), +] + +TEST_BINNING_SMALL_LEFT_EDGE = [ + "small_left_edge", + torch.tensor([[[[0.8, 0.2], [0.4, 0.6]]]]), + torch.tensor([[[[1, 0], [0, 1]]]]), + 5, + False, + torch.tensor([[[0.2, 0.4, 0.6, 0.8, float("nan")]]]), + torch.tensor([[[0.0, 0.0, 1.0, 1.0, float("nan")]]]), + torch.tensor([[[1.0, 1.0, 1.0, 1.0, 0.0]]]), +] + +TEST_BINNING_SMALL_RIGHT_EDGE = [ + "small_right_edge", + torch.tensor([[[[0.8, 0.2], [0.4, 0.6]]]]), + torch.tensor([[[[1, 0], [0, 1]]]]), + 5, + True, + torch.tensor([[[float("nan"), 0.2, 0.4, 0.6, 0.8]]]), + torch.tensor([[[float("nan"), 0.0, 0.0, 1.0, 1.0]]]), + torch.tensor([[[0.0, 1.0, 1.0, 1.0, 1.0]]]), +] + +BINNING_TEST_CASES = [ + TEST_BINNING_SMALL_MID, + TEST_BINNING_LARGE_MID, + TEST_BINNING_SMALL_LEFT_EDGE, + TEST_BINNING_SMALL_RIGHT_EDGE, +] + +# Test cases for calibration error metric values +# Format: [name, y_pred, y, num_bins, expected_expected, expected_average, expected_maximum] +TEST_VALUE_1B1C = [ + "1b1c", + torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]), + torch.tensor([[[[1, 0], [0, 1]]]]), + 5, + torch.tensor([[0.2]]), + torch.tensor([[0.2]]), + torch.tensor([[0.3]]), +] + +TEST_VALUE_2B2C = [ + "2b2c", + torch.tensor( + [ + [[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], + [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]], + ] + ), + torch.tensor( + [ + [[[1, 0], [0, 1]], [[0, 1], [1, 0]]], + [[[1, 1], [0, 0]], [[0, 0], [1, 1]]], + ] + ), + 5, + torch.tensor([[0.2000, 0.3500], [0.2000, 0.1500]]), + torch.tensor([[0.2000, 0.4667], [0.2000, 0.1667]]), + torch.tensor([[0.3000, 0.7000], [0.3000, 0.3000]]), +] + +VALUE_TEST_CASES = [ + TEST_VALUE_1B1C, + TEST_VALUE_2B2C, +] + + +class TestCalibrationBinning(unittest.TestCase): + + @parameterized.expand(BINNING_TEST_CASES) + def test_binning(self, _name, y_pred, y, num_bins, right, expected_mean_p, expected_mean_gt, 
expected_counts): + y_pred = y_pred.to(_device) + y = y.to(_device) + expected_mean_p = expected_mean_p.to(_device) + expected_mean_gt = expected_mean_gt.to(_device) + expected_counts = expected_counts.to(_device) + + # Use mock.patch to replace torch.linspace + # This is to avoid floating point precision issues when looking at edge conditions + mock_boundaries = torch.tensor([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], device=_device) + with mock.patch("monai.metrics.calibration.torch.linspace", return_value=mock_boundaries): + mean_p_per_bin, mean_gt_per_bin, bin_counts = calibration_binning(y_pred, y, num_bins=num_bins, right=right) + + # Handle NaN comparisons: compare NaN masks separately, then compare non-NaN values + # mean_p_per_bin + self.assertTrue(torch.equal(torch.isnan(mean_p_per_bin), torch.isnan(expected_mean_p))) + mask_p = ~torch.isnan(expected_mean_p) + if mask_p.any(): + assert_allclose(mean_p_per_bin[mask_p], expected_mean_p[mask_p], atol=1e-4, rtol=1e-4) + + # mean_gt_per_bin + self.assertTrue(torch.equal(torch.isnan(mean_gt_per_bin), torch.isnan(expected_mean_gt))) + mask_gt = ~torch.isnan(expected_mean_gt) + if mask_gt.any(): + assert_allclose(mean_gt_per_bin[mask_gt], expected_mean_gt[mask_gt], atol=1e-4, rtol=1e-4) + + # bin_counts (no NaNs) + assert_allclose(bin_counts, expected_counts, atol=1e-4, rtol=1e-4) + + def test_shape_mismatch_raises(self): + """Test that mismatched shapes raise ValueError.""" + y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device) + y = torch.tensor([[[[1, 0], [0, 1], [0, 0]]]]).to(_device) # Different shape + with self.assertRaises(ValueError) as context: + calibration_binning(y_pred, y, num_bins=5) + self.assertIn("same shape", str(context.exception)) + + def test_insufficient_ndim_raises(self): + """Test that tensors with ndim < 3 raise ValueError.""" + y_pred = torch.tensor([[0.7, 0.3]]).to(_device) # Only 2D + y = torch.tensor([[1, 0]]).to(_device) + with self.assertRaises(ValueError) as context: + calibration_binning(y_pred, y, num_bins=5) + self.assertIn("ndim", str(context.exception)) + + def test_invalid_num_bins_raises(self): + """Test that num_bins < 1 raises ValueError.""" + y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device) + with self.assertRaises(ValueError) as context: + calibration_binning(y_pred, y, num_bins=0) + self.assertIn("num_bins", str(context.exception)) + + +class TestCalibrationErrorMetricValue(unittest.TestCase): + + @parameterized.expand(VALUE_TEST_CASES) + def test_expected_reduction(self, _name, y_pred, y, num_bins, expected_expected, _expected_average, _expected_max): + y_pred = y_pred.to(_device) + y = y.to(_device) + expected_expected = expected_expected.to(_device) + + metric = CalibrationErrorMetric( + num_bins=num_bins, + include_background=True, + calibration_reduction=CalibrationReduction.EXPECTED, + metric_reduction=MetricReduction.NONE, + ) + + metric(y_pred=y_pred, y=y) + result = metric.aggregate() + + assert_allclose(result, expected_expected, atol=1e-4, rtol=1e-4) + + @parameterized.expand(VALUE_TEST_CASES) + def test_average_reduction(self, _name, y_pred, y, num_bins, _expected_expected, expected_average, _expected_max): + y_pred = y_pred.to(_device) + y = y.to(_device) + expected_average = expected_average.to(_device) + + metric = CalibrationErrorMetric( + num_bins=num_bins, + include_background=True, + calibration_reduction=CalibrationReduction.AVERAGE, + metric_reduction=MetricReduction.NONE, + ) + + metric(y_pred=y_pred, y=y) 
+ result = metric.aggregate() + + assert_allclose(result, expected_average, atol=1e-4, rtol=1e-4) + + @parameterized.expand(VALUE_TEST_CASES) + def test_maximum_reduction(self, _name, y_pred, y, num_bins, _expected_expected, _expected_average, expected_max): + y_pred = y_pred.to(_device) + y = y.to(_device) + expected_max = expected_max.to(_device) + + metric = CalibrationErrorMetric( + num_bins=num_bins, + include_background=True, + calibration_reduction=CalibrationReduction.MAXIMUM, + metric_reduction=MetricReduction.NONE, + ) + + metric(y_pred=y_pred, y=y) + result = metric.aggregate() + + assert_allclose(result, expected_max, atol=1e-4, rtol=1e-4) + + +class TestCalibrationErrorMetricOptions(unittest.TestCase): + + def test_include_background_false(self): + y_pred = torch.tensor( + [ + [[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], + [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]], + ] + ).to(_device) + y = torch.tensor( + [ + [[[1, 0], [0, 1]], [[0, 1], [1, 0]]], + [[[1, 1], [0, 0]], [[0, 0], [1, 1]]], + ] + ).to(_device) + + metric = CalibrationErrorMetric( + num_bins=5, + include_background=False, + calibration_reduction=CalibrationReduction.EXPECTED, + metric_reduction=MetricReduction.MEAN, + ) + + metric(y_pred=y_pred, y=y) + result = metric.aggregate() + + assert_allclose(result, torch.tensor(0.2500, device=_device), atol=1e-4, rtol=1e-4) + + def test_metric_reduction_mean(self): + y_pred = torch.tensor( + [ + [[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], + [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]], + ] + ).to(_device) + y = torch.tensor( + [ + [[[1, 0], [0, 1]], [[0, 1], [1, 0]]], + [[[1, 1], [0, 0]], [[0, 0], [1, 1]]], + ] + ).to(_device) + + metric = CalibrationErrorMetric( + num_bins=5, + include_background=True, + calibration_reduction=CalibrationReduction.EXPECTED, + metric_reduction=MetricReduction.MEAN, + ) + + metric(y_pred=y_pred, y=y) + result = metric.aggregate() + + # Mean of [[0.2000, 0.3500], [0.2000, 0.1500]] = 0.225 + assert_allclose(result, torch.tensor(0.2250, device=_device), atol=1e-4, rtol=1e-4) + + def test_get_not_nans(self): + y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device) + + metric = CalibrationErrorMetric( + num_bins=5, + include_background=True, + calibration_reduction=CalibrationReduction.EXPECTED, + metric_reduction=MetricReduction.MEAN, + get_not_nans=True, + ) + + metric(y_pred=y_pred, y=y) + result, not_nans = metric.aggregate() + + assert_allclose(result, torch.tensor(0.2, device=_device), atol=1e-4, rtol=1e-4) + self.assertEqual(not_nans.item(), 1) + + def test_cumulative_iterations(self): + """Test that the metric correctly accumulates over multiple iterations.""" + y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device) + + metric = CalibrationErrorMetric( + num_bins=5, + include_background=True, + calibration_reduction=CalibrationReduction.EXPECTED, + metric_reduction=MetricReduction.MEAN, + ) + + # First iteration + metric(y_pred=y_pred, y=y) + # Second iteration + metric(y_pred=y_pred, y=y) + + result = metric.aggregate() + # Should still be 0.2 since both iterations have the same data + assert_allclose(result, torch.tensor(0.2, device=_device), atol=1e-4, rtol=1e-4) + + # Test reset + metric.reset() + data = metric.get_buffer() + self.assertIsNone(data) + + +if __name__ == "__main__": + unittest.main()
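For quick reference, a minimal usage sketch of the new metric and handler (illustrative only, not part of the patch itself): the tensors and the expected value of roughly 0.2 mirror TEST_VALUE_1B1C and the single-iteration handler test above, and pytorch-ignite is assumed to be available as in the handler tests.

import torch
from ignite.engine import Engine

from monai.handlers import CalibrationError, from_engine
from monai.metrics import CalibrationErrorMetric, CalibrationReduction

# Probabilities and one-hot ground truth in (batch, channel, spatial...) format,
# taken from TEST_VALUE_1B1C above.
y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]])
y = torch.tensor([[[[1, 0], [0, 1]]]])

# Direct metric usage: accumulate one or more batches, then aggregate.
metric = CalibrationErrorMetric(num_bins=5, calibration_reduction=CalibrationReduction.EXPECTED)
metric(y_pred=y_pred, y=y)
print(metric.aggregate())  # tensor(0.2000), matching the expected test value
metric.reset()

# Handler usage: attach to an ignite Engine whose output is a {"pred", "label"} dict.
handler = CalibrationError(num_bins=5, output_transform=from_engine(["pred", "label"]))
engine = Engine(lambda e, b: b)  # the engine simply passes each batch through
handler.attach(engine=engine, name="calibration_error")
engine.run([{"pred": y_pred, "label": y}], max_epochs=1)
print(engine.state.metrics["calibration_error"])  # approximately 0.2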