diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py
index f21040d58e..0802cc3364 100644
--- a/monai/metrics/meandice.py
+++ b/monai/metrics/meandice.py
@@ -14,7 +14,7 @@
 import torch
 
 from monai.metrics.utils import do_metric_reduction
-from monai.utils import MetricReduction
+from monai.utils import MetricReduction, deprecated_arg
 
 from .metric import CumulativeIterationMetric
 
@@ -23,35 +23,76 @@ class DiceMetric(CumulativeIterationMetric):
     """
-    Compute average Dice score for a set of pairs of prediction-groundtruth segmentations.
+    Computes Dice score for a set of pairs of prediction-groundtruth labels. It supports single-channel label maps
+    or multi-channel images with class segmentations per channel. This allows the computation for both multi-class
+    and multi-label tasks.
 
-    It supports both multi-classes and multi-labels tasks.
-    Input `y_pred` is compared with ground truth `y`.
-    `y_pred` is expected to have binarized predictions and `y` can be single-channel class indices or in the
-    one-hot format. The `include_background` parameter can be set to ``False`` to exclude
-    the first category (channel index 0) which is by convention assumed to be background. If the non-background
-    segmentations are small compared to the total image size they can get overwhelmed by the signal from the
-    background. `y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]),
-    `y` can also be in the format of `B1HW[D]`.
+    If either the prediction ``y_pred`` or the ground truth ``y`` has shape BCHW[D], it is expected to represent
+    one-hot segmentations for C classes. If either has shape B1HW[D], it is expected to be a label map and the
+    number of classes must be specified by the ``num_classes`` parameter. In either case, this metric applies no
+    activations, so non-binary values will produce unexpected results when the metric is meant to measure binary
+    overlap (i.e. when either input was expected to be in one-hot format). Soft labels are thus permitted by this
+    metric. Typically this implies that raw predictions from a network must first be activated and possibly
+    converted to label maps, e.g. for a multi-class prediction tensor, softmax and then argmax should be applied
+    over the channel dimension to produce a label map.
+
+    The ``include_background`` parameter can be set to ``False`` to exclude the first category (channel index 0)
+    which is by convention assumed to be background. If the non-background segmentations are small compared to the
+    total image size they can get overwhelmed by the signal from the background. This assumes the shape of both
+    prediction and ground truth is BCHW[D].
+
+    The typical execution steps of this metric class follow :py:class:`monai.metrics.metric.Cumulative`.
+
+    Further information can be found in the official
+    `MONAI Dice Overview `.
+
+    Example:
+
+        .. code-block:: python
+            import torch
+            from monai.metrics import DiceMetric
+            from monai.losses import DiceLoss
+            from monai.networks import one_hot
+
+            batch_size, n_classes, h, w = 7, 5, 128, 128
+
+            y_pred = torch.rand(batch_size, n_classes, h, w)  # network predictions
+            y_pred = torch.argmax(y_pred, 1, True)  # convert to label map
+
+            # ground truth as label map
+            y = torch.randint(0, n_classes, size=(batch_size, 1, h, w))
+
+            dm = DiceMetric(
+                reduction="mean_batch", return_with_label=True, num_classes=n_classes
+            )
+
+            raw_scores = dm(y_pred, y)
+            print(dm.aggregate())
+
+            # now compute the Dice loss, which should equal 1 - raw_scores
+            dl = DiceLoss(to_onehot_y=True, reduction="none")
+            loss = dl(one_hot(y_pred, n_classes), y).squeeze()
+
+            print(1.0 - loss)  # same as raw_scores
 
-    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
 
     Args:
-        include_background: whether to include Dice computation on the first channel of
-            the predicted output. Defaults to ``True``.
-        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
-            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
-            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
-        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
-            Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.
-        ignore_empty: whether to ignore empty ground truth cases during calculation.
-            If `True`, NaN value will be set for empty ground truth cases.
-            If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.
-        num_classes: number of input channels (always including the background). When this is None,
+        include_background: whether to include Dice computation on the first channel/category of the prediction and
+            ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
+        reduction: defines the mode of reduction applied to the metrics; reduction is only applied to `not-nan`
+            values. The available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`.
+            If ``"none"`` is selected, the metric will not do reduction.
+        get_not_nans: whether to return the `not_nans` count. If True, aggregate() returns `(metric, not_nans)`,
+            where `not_nans` counts the number of valid values in the result and has the same shape as the metric.
+        ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will
+            be set for empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth
+            cases are also empty.
+        num_classes: number of input channels (always including the background). When this is ``None``,
             ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
             single-channel class indices and the number of classes is not automatically inferred from data.
         return_with_label: whether to return the metrics with label, only works when reduction is "mean_batch".
-            If `True`, use "label_{index}" as the key corresponding to C channels; if 'include_background' is True,
+            If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True,
             the index begins at "0", otherwise at "1". It can also take a list of label names.
             The outcome will then be returned as a dictionary.
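A minimal sketch of the one-hot/multi-label path described in the docstring above (not part of the patch; the shapes and the ``reset()`` call are illustrative, using only the ``DiceMetric`` API shown here):

    import torch
    from monai.metrics import DiceMetric

    batch_size, n_classes, h, w = 2, 3, 64, 64

    # binarized one-hot style predictions and labels, shape BCHW
    y_pred = (torch.rand(batch_size, n_classes, h, w) > 0.5).float()
    y = (torch.rand(batch_size, n_classes, h, w) > 0.5).float()

    dm = DiceMetric(include_background=False, reduction="mean")
    dm(y_pred, y)          # accumulate per-iteration scores in the buffer
    print(dm.aggregate())  # mean score over batch items and foreground classes
    dm.reset()             # clear the buffer before the next validation run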
@@ -77,22 +118,21 @@ def __init__(
             include_background=self.include_background,
             reduction=MetricReduction.NONE,
             get_not_nans=False,
-            softmax=False,
+            apply_argmax=False,
             ignore_empty=self.ignore_empty,
             num_classes=self.num_classes,
         )
 
     def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
         """
+        Compute the dice value using ``DiceHelper``.
+
         Args:
-            y_pred: input data to compute, typical segmentation model output.
-                It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
-                should be binarized.
-            y: ground truth to compute mean Dice metric. `y` can be single-channel class indices or
-                in the one-hot format.
+            y_pred: prediction value, see the class docstring for format definitions.
+            y: ground truth label.
 
         Raises:
-            ValueError: when `y_pred` has less than three dimensions.
+            ValueError: when `y_pred` has fewer than three dimensions.
         """
         dims = y_pred.ndimension()
         if dims < 3:
@@ -107,10 +147,8 @@ def aggregate(
         Execute reduction and aggregation logic for the output of `compute_dice`.
 
         Args:
-            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
-                available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
-                ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.
-
+            reduction: defines the mode of reduction as enumerated in :py:class:`monai.utils.enums.MetricReduction`.
+                If no reduction is given here, ``self.reduction`` is used.
         """
         data = self.get_buffer()
         if not isinstance(data, torch.Tensor):
@@ -138,18 +176,20 @@ def compute_dice(
     ignore_empty: bool = True,
     num_classes: int | None = None,
 ) -> torch.Tensor:
-    """Computes Dice score metric for a batch of predictions.
+    """
+    Computes Dice score metric for a batch of predictions. This performs the same computation as
+    :py:class:`monai.metrics.DiceMetric`, which is preferable to this function. For input formats, see the
+    documentation for that class.
 
     Args:
         y_pred: input data to compute, typical segmentation model output.
-            `y_pred` can be single-channel class indices or in the one-hot format.
-        y: ground truth to compute mean dice metric. `y` can be single-channel class indices or in the one-hot format.
-        include_background: whether to include Dice computation on the first channel of
-            the predicted output. Defaults to True.
-        ignore_empty: whether to ignore empty ground truth cases during calculation.
-            If `True`, NaN value will be set for empty ground truth cases.
-            If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.
-        num_classes: number of input channels (always including the background). When this is None,
+        y: ground truth to compute mean dice metric.
+        include_background: whether to include Dice computation on the first channel/category of the prediction and
+            ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
+        ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will
+            be set for empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth
+            cases are also empty.
+        num_classes: number of input channels (always including the background). When this is ``None``,
             ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
             single-channel class indices and the number of classes is not automatically inferred from data.
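For the functional form documented above, a short usage sketch (not part of the patch; shapes are illustrative). Since both inputs here are single-channel class indices, ``num_classes`` must be supplied, as the docstring states:

    import torch
    from monai.metrics import compute_dice

    # class-index label maps of shape B1HW
    y_pred = torch.randint(0, 3, size=(2, 1, 32, 32))
    y = torch.randint(0, 3, size=(2, 1, 32, 32))

    # no reduction is applied: one score per batch item and per included class
    scores = compute_dice(y_pred, y, include_background=False, num_classes=3)
    print(scores.shape)  # torch.Size([2, 2]), background channel excluded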
@@ -161,7 +201,7 @@ def compute_dice(
         include_background=include_background,
         reduction=MetricReduction.NONE,
         get_not_nans=False,
-        softmax=False,
+        apply_argmax=False,
         ignore_empty=ignore_empty,
         num_classes=num_classes,
     )(y_pred=y_pred, y=y)
@@ -169,8 +209,8 @@
 
 class DiceHelper:
     """
-    Compute Dice score between two tensors `y_pred` and `y`.
-    `y_pred` and `y` can be single-channel class indices or in the one-hot format.
+    Compute Dice score between two tensors ``y_pred`` and ``y``. This is used by
+    :py:class:`monai.metrics.DiceMetric`; see the documentation for that class for input formats.
 
     Example:
 
@@ -188,49 +228,65 @@ class DiceHelper:
             score, not_nans = DiceHelper(include_background=False, sigmoid=True, softmax=True)(y_pred, y)
             print(score, not_nans)
 
+    Args:
+        include_background: whether to include Dice computation on the first channel/category of the prediction and
+            ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
+        threshold: if ``True``, ``y_pred`` will be thresholded at a value of 0.5. Defaults to ``False``.
+        apply_argmax: whether ``y_pred`` are softmax-activated outputs. If ``True``, `argmax` will be performed
+            over the channel dimension to get the discrete prediction. Defaults to the value of ``not threshold``.
+        activate: if this and ``threshold`` are ``True``, sigmoid activation is applied to ``y_pred`` before
+            thresholding. Defaults to ``False``.
+        get_not_nans: whether to return the number of not-nan values.
+        reduction: defines the mode of reduction applied to the metrics; reduction is only applied to `not-nan`
+            values. The available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`.
+            If ``"none"`` is selected, the metric will not do reduction.
+        ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will
+            be set for empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth
+            cases are also empty.
+        num_classes: number of input channels (always including the background). When this is ``None``,
+            ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
+            single-channel class indices and the number of classes is not automatically inferred from data.
     """
 
+    @deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax")
+    @deprecated_arg("sigmoid", "1.5", "1.7", "Use `threshold` instead.", new_name="threshold")
     def __init__(
         self,
         include_background: bool | None = None,
-        sigmoid: bool = False,
-        softmax: bool | None = None,
+        threshold: bool = False,
+        apply_argmax: bool | None = None,
         activate: bool = False,
         get_not_nans: bool = True,
         reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
         ignore_empty: bool = True,
         num_classes: int | None = None,
+        sigmoid: bool | None = None,
+        softmax: bool | None = None,
     ) -> None:
-        """
+        # handle the deprecated arguments
+        if sigmoid is not None:
+            threshold = sigmoid
+        if softmax is not None:
+            apply_argmax = softmax
 
-        Args:
-            include_background: whether to include the score on the first channel
-                (default to the value of `sigmoid`, False).
-            sigmoid: whether ``y_pred`` are/will be sigmoid activated outputs. If True, thresholding at 0.5
-                will be performed to get the discrete prediction. Defaults to False.
-            softmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to
-                get the discrete prediction. Defaults to the value of ``not sigmoid``.
-            activate: whether to apply sigmoid to ``y_pred`` if ``sigmoid`` is True. Defaults to False.
-                This option is only valid when ``sigmoid`` is True.
-            get_not_nans: whether to return the number of not-nan values.
-            reduction: define mode of reduction to the metrics
-            ignore_empty: if `True`, NaN value will be set for empty ground truth cases.
-                If `False`, 1 will be set if the Union of ``y_pred`` and ``y`` is empty.
-            num_classes: number of input channels (always including the background). When this is None,
-                ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
-                single-channel class indices and the number of classes is not automatically inferred from data.
-        """
-        self.sigmoid = sigmoid
+        self.threshold = threshold
         self.reduction = reduction
         self.get_not_nans = get_not_nans
-        self.include_background = sigmoid if include_background is None else include_background
-        self.softmax = not sigmoid if softmax is None else softmax
+        self.include_background = threshold if include_background is None else include_background
+        self.apply_argmax = not threshold if apply_argmax is None else apply_argmax
         self.activate = activate
         self.ignore_empty = ignore_empty
         self.num_classes = num_classes
 
     def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        """"""
+        """
+        Compute the dice metric for binary inputs which have only spatial dimensions. This method is called
+        separately for each batch item and for each channel of those items.
+
+        Args:
+            y_pred: input predictions with shape HW[D].
+            y: ground truth with shape HW[D].
+        """
         y_o = torch.sum(y)
         if y_o > 0:
             return (2.0 * torch.sum(torch.masked_select(y, y_pred))) / (y_o + torch.sum(y_pred))
@@ -243,25 +299,25 @@ def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
 
     def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """
+        Compute the metric for the given prediction and ground truth.
+
         Args:
             y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).
                 the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
             y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
""" - _softmax, _sigmoid = self.softmax, self.sigmoid + _apply_argmax, _threshold = self.apply_argmax, self.threshold if self.num_classes is None: n_pred_ch = y_pred.shape[1] # y_pred is in one-hot format or multi-channel scores else: n_pred_ch = self.num_classes if y_pred.shape[1] == 1 and self.num_classes > 1: # y_pred is single-channel class indices - _softmax = _sigmoid = False + _apply_argmax = _threshold = False - if _softmax: - if n_pred_ch > 1: - y_pred = torch.argmax(y_pred, dim=1, keepdim=True) + if _apply_argmax and n_pred_ch > 1: + y_pred = torch.argmax(y_pred, dim=1, keepdim=True) - elif _sigmoid: + elif _threshold: if self.activate: y_pred = torch.sigmoid(y_pred) y_pred = y_pred > 0.5 diff --git a/tests/metrics/test_compute_meandice.py b/tests/metrics/test_compute_meandice.py index aae15483b5..04c81ff9a7 100644 --- a/tests/metrics/test_compute_meandice.py +++ b/tests/metrics/test_compute_meandice.py @@ -267,15 +267,15 @@ def test_nans(self, input_data, expected_value): @parameterized.expand([TEST_CASE_3]) def test_helper(self, input_data, _unused): vals = {"y_pred": dict(input_data).pop("y_pred"), "y": dict(input_data).pop("y")} - result = DiceHelper(sigmoid=True)(**vals) + result = DiceHelper(threshold=True)(**vals) np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4) np.testing.assert_allclose(sorted(result[1].cpu().numpy()), [0.0, 1.0, 2.0], atol=1e-4) - result = DiceHelper(softmax=True, get_not_nans=False)(**vals) + result = DiceHelper(apply_argmax=True, get_not_nans=False)(**vals) np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0], atol=1e-4) num_classes = vals["y_pred"].shape[1] vals["y_pred"] = torch.argmax(vals["y_pred"], dim=1, keepdim=True) - result = DiceHelper(sigmoid=True, num_classes=num_classes)(**vals) + result = DiceHelper(threshold=True, num_classes=num_classes)(**vals) np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4) # DiceMetric class tests