diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index ad0aeb332b..4dce4d0133 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -32,10 +32,14 @@ class SurfaceDiceMetric(CumulativeIterationMetric): """ - Computes the Normalized Surface Distance (NSD) for each batch sample and class of + Computes the Normalized Surface Dice (NSD) for each batch sample and class of predicted segmentations `y_pred` and corresponding reference segmentations `y` according to equation :eq:`nsd`. - This implementation supports 2D images. For 3D images, please refer to DeepMind's implementation - https://github.com/deepmind/surface-distance. + This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D and 3D images. + Be aware that the computation of boundaries is different from DeepMind's implementation + https://github.com/deepmind/surface-distance. In this implementation, the length/area of a segmentation boundary is + interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary + depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430). + This issue is discussed here: https://github.com/Project-MONAI/MONAI/issues/4103. The class- and batch sample-wise NSD values can be aggregated with the function `aggregate`. @@ -79,9 +83,9 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) r""" Args: y_pred: Predicted segmentation, typically segmentation model output. - It must be a one-hot encoded, batch-first tensor [B,C,H,W]. + It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D]. y: Reference segmentation. - It must be a one-hot encoded, batch-first tensor [B,C,H,W]. + It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D]. kwargs: additional parameters, e.g. ``spacing`` should be passed to correctly compute the metric. ``spacing``: spacing of pixel (or voxel). This parameter is relevant only if ``distance_metric`` is set to ``"euclidean"``. @@ -168,7 +172,7 @@ def compute_surface_dice( will be returned for this class. In the case of a class being present in only one of predicted segmentation or reference segmentation, the class NSD will be 0. - This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D images. + This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D and 3D images. Be aware that the computation of boundaries is different from DeepMind's implementation https://github.com/deepmind/surface-distance. In this implementation, the length of a segmentation boundary is interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary @@ -176,9 +180,9 @@ def compute_surface_dice( Args: y_pred: Predicted segmentation, typically segmentation model output. - It must be a one-hot encoded, batch-first tensor [B,C,H,W]. + It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D]. y: Reference segmentation. - It must be a one-hot encoded, batch-first tensor [B,C,H,W]. + It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D]. class_thresholds: List of class-specific thresholds. The thresholds relate to the acceptable amount of deviation in the segmentation boundary in pixels. Each threshold needs to be a finite, non-negative number. @@ -215,8 +219,8 @@ def compute_surface_dice( if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor): raise ValueError("y_pred and y must be PyTorch Tensor.") - if y_pred.ndimension() != 4 or y.ndimension() != 4: - raise ValueError("y_pred and y should have four dimensions: [B,C,H,W].") + if y_pred.ndimension() not in (4, 5) or y.ndimension() not in (4, 5): + raise ValueError("y_pred and y should be one-hot encoded: [B,C,H,W] or [B,C,H,W,D].") if y_pred.shape != y.shape: raise ValueError( diff --git a/tests/test_surface_dice.py b/tests/test_surface_dice.py index 15e6245619..c5400aea39 100644 --- a/tests/test_surface_dice.py +++ b/tests/test_surface_dice.py @@ -128,6 +128,53 @@ def test_tolerance_euclidean_distance(self): np.testing.assert_array_equal(agg0.cpu(), np.nanmean(np.nanmean(expected_res0, axis=1), axis=0)) np.testing.assert_equal(not_nans.cpu(), torch.tensor(2)) + def test_tolerance_euclidean_distance_3d(self): + batch_size = 2 + n_class = 2 + predictions = torch.zeros((batch_size, 200, 110, 80), dtype=torch.int64, device=_device) + labels = torch.zeros((batch_size, 200, 110, 80), dtype=torch.int64, device=_device) + predictions[0, :, :, 20:] = 1 + labels[0, :, :, 30:] = 1 # offset by 10 + predictions_hot = F.one_hot(predictions, num_classes=n_class).permute(0, 4, 1, 2, 3) + labels_hot = F.one_hot(labels, num_classes=n_class).permute(0, 4, 1, 2, 3) + + sd0 = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True) + res0 = sd0(predictions_hot, labels_hot) + agg0 = sd0.aggregate() # aggregation: nanmean across image then nanmean across batch + sd0_nans = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True, get_not_nans=True) + res0_nans = sd0_nans(predictions_hot, labels_hot) + agg0_nans, not_nans = sd0_nans.aggregate() + + np.testing.assert_array_equal(res0.cpu(), res0_nans.cpu()) + np.testing.assert_equal(res0.device, predictions.device) + np.testing.assert_array_equal(agg0.cpu(), agg0_nans.cpu()) + np.testing.assert_equal(agg0.device, predictions.device) + + res1 = SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, labels_hot) + res10 = SurfaceDiceMetric(class_thresholds=[10, 10], include_background=True)(predictions_hot, labels_hot) + res11 = SurfaceDiceMetric(class_thresholds=[11, 11], include_background=True)(predictions_hot, labels_hot) + + for res in [res0, res1, res10, res11]: + assert res.shape == torch.Size([2, 2]) + + assert res0[0, 0] < res1[0, 0] < res10[0, 0] + assert res0[0, 1] < res1[0, 1] < res10[0, 1] + np.testing.assert_array_equal(res10.cpu(), res11.cpu()) + + expected_res0 = np.zeros((batch_size, n_class)) + expected_res0[0, 1] = 1 - (200 * 110 + 198 * 108 + 9 * 200 * 2 + 9 * 108 * 2) / ( + 200 * 110 * 4 + (58 + 48) * 200 * 2 + (58 + 48) * 108 * 2 + ) + expected_res0[0, 0] = 1 - (200 * 110 + 198 * 108 + 9 * 200 * 2 + 9 * 108 * 2) / ( + 200 * 110 * 4 + (28 + 18) * 200 * 2 + (28 + 18) * 108 * 2 + ) + expected_res0[1, 0] = 1 + expected_res0[1, 1] = np.nan + for b, c in np.ndindex(batch_size, n_class): + np.testing.assert_allclose(expected_res0[b, c], res0[b, c].cpu()) + np.testing.assert_array_equal(agg0.cpu(), np.nanmean(np.nanmean(expected_res0, axis=1), axis=0)) + np.testing.assert_equal(not_nans.cpu(), torch.tensor(2)) + def test_tolerance_all_distances(self): batch_size = 1 n_class = 2 @@ -262,10 +309,10 @@ def test_asserts(self): # wrong dimensions with self.assertRaises(ValueError) as context: SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions, labels_hot) - self.assertEqual("y_pred and y should have four dimensions: [B,C,H,W].", str(context.exception)) + self.assertEqual("y_pred and y should be one-hot encoded: [B,C,H,W] or [B,C,H,W,D].", str(context.exception)) with self.assertRaises(ValueError) as context: SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, labels) - self.assertEqual("y_pred and y should have four dimensions: [B,C,H,W].", str(context.exception)) + self.assertEqual("y_pred and y should be one-hot encoded: [B,C,H,W] or [B,C,H,W,D].", str(context.exception)) # mismatch of shape of input tensors input_bad_shape = torch.clone(predictions_hot)