diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 34edaef681..b98d87baef 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -12,11 +12,19 @@ from __future__ import annotations import warnings +from collections.abc import Sequence +from typing import Any import numpy as np import torch -from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background +from monai.metrics.utils import ( + do_metric_reduction, + get_mask_edges, + get_surface_distance, + ignore_background, + prepare_spacing, +) from monai.utils import MetricReduction, convert_data_type from .metric import CumulativeIterationMetric @@ -70,7 +78,7 @@ def __init__( self.reduction = reduction self.get_not_nans = get_not_nans - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override] """ Args: y_pred: input data to compute, typical segmentation model output. @@ -78,6 +86,16 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor should be binarized. y: ground truth to compute the distance. It must be one-hot format and first dim is batch. The values should be binarized. + kwargs: additional parameters, e.g. ``spacing`` should be passed to correctly compute the metric. + ``spacing``: spacing of pixel (or voxel). This parameter is relevant only + if ``distance_metric`` is set to ``"euclidean"``. + If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers, + the length of the sequence must be equal to the image dimensions. + This spacing will be used for all images in the batch. + If a sequence of sequences, the length of the outer sequence must be equal to the batch size. + If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch, + else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used + for all images in batch. Defaults to ``None``. Raises: ValueError: when `y_pred` has less than three dimensions. @@ -85,6 +103,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor dims = y_pred.ndimension() if dims < 3: raise ValueError("y_pred should have at least three dimensions.") + # compute (BxC) for each channel for each batch return compute_hausdorff_distance( y_pred=y_pred, @@ -93,6 +112,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor distance_metric=self.distance_metric, percentile=self.percentile, directed=self.directed, + spacing=kwargs.get("spacing"), ) def aggregate( @@ -123,6 +143,7 @@ def compute_hausdorff_distance( distance_metric: str = "euclidean", percentile: float | None = None, directed: bool = False, + spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None, ) -> torch.Tensor: """ Compute the Hausdorff distance. @@ -141,6 +162,13 @@ def compute_hausdorff_distance( percentile of the Hausdorff Distance rather than the maximum result will be achieved. Defaults to ``None``. directed: whether to calculate directed Hausdorff distance. Defaults to ``False``. + spacing: spacing of pixel (or voxel). This parameter is relevant only if ``distance_metric`` is set to ``"euclidean"``. + If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers, + the length of the sequence must be equal to the image dimensions. This spacing will be used for all images in the batch. + If a sequence of sequences, the length of the outer sequence must be equal to the batch size. + If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch, + else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used + for all images in batch. Defaults to ``None``. """ if not include_background: @@ -153,6 +181,10 @@ def compute_hausdorff_distance( batch_size, n_class = y_pred.shape[:2] hd = np.empty((batch_size, n_class)) + + img_dim = y_pred.ndim - 2 + spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim) + for b, c in np.ndindex(batch_size, n_class): (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c]) if not np.any(edges_gt): @@ -160,23 +192,31 @@ def compute_hausdorff_distance( if not np.any(edges_pred): warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.") - distance_1 = compute_percent_hausdorff_distance(edges_pred, edges_gt, distance_metric, percentile) + distance_1 = compute_percent_hausdorff_distance( + edges_pred, edges_gt, distance_metric, percentile, spacing_list[b] + ) if directed: hd[b, c] = distance_1 else: - distance_2 = compute_percent_hausdorff_distance(edges_gt, edges_pred, distance_metric, percentile) + distance_2 = compute_percent_hausdorff_distance( + edges_gt, edges_pred, distance_metric, percentile, spacing_list[b] + ) hd[b, c] = max(distance_1, distance_2) return convert_data_type(hd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0] def compute_percent_hausdorff_distance( - edges_pred: np.ndarray, edges_gt: np.ndarray, distance_metric: str = "euclidean", percentile: float | None = None + edges_pred: np.ndarray, + edges_gt: np.ndarray, + distance_metric: str = "euclidean", + percentile: float | None = None, + spacing: int | float | np.ndarray | Sequence[int | float] | None = None, ) -> float: """ This function is used to compute the directed Hausdorff distance. """ - surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric) + surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing) # for both pred and gt do not have foreground if surface_distance.shape == (0,): diff --git a/monai/metrics/loss_metric.py b/monai/metrics/loss_metric.py index 2cc9755e36..3136f42f4a 100644 --- a/monai/metrics/loss_metric.py +++ b/monai/metrics/loss_metric.py @@ -11,6 +11,8 @@ from __future__ import annotations +from typing import Any + import torch from torch.nn.modules.loss import _Loss @@ -92,7 +94,7 @@ def aggregate( f, not_nans = do_metric_reduction(data, reduction or self.reduction) return (f, not_nans) if self.get_not_nans else f - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor | None = None) -> TensorOrList: + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor | None = None, **kwargs: Any) -> TensorOrList: """ Input `y_pred` is compared with ground truth `y`. Both `y_pred` and `y` are expected to be a batch-first Tensor (BC[HWD]). diff --git a/monai/metrics/metric.py b/monai/metrics/metric.py index 608d914808..a6dc1a49a2 100644 --- a/monai/metrics/metric.py +++ b/monai/metrics/metric.py @@ -49,7 +49,7 @@ class IterationMetric(Metric): """ def __call__( - self, y_pred: TensorOrList, y: TensorOrList | None = None + self, y_pred: TensorOrList, y: TensorOrList | None = None, **kwargs: Any ) -> torch.Tensor | Sequence[torch.Tensor | Sequence[torch.Tensor]]: """ Execute basic computation for model prediction `y_pred` and ground truth `y` (optional). @@ -60,6 +60,7 @@ def __call__( or a `batch-first` Tensor. y: the ground truth to compute, must be a list of `channel-first` Tensor or a `batch-first` Tensor. + kwargs: additional parameters for specific metric computation logic (e.g. ``spacing`` for SurfaceDistanceMetric, etc.). Returns: The computed metric values at the iteration level. @@ -69,15 +70,15 @@ def __call__( """ # handling a list of channel-first data if isinstance(y_pred, (list, tuple)) or isinstance(y, (list, tuple)): - return self._compute_list(y_pred, y) + return self._compute_list(y_pred, y, **kwargs) # handling a single batch-first data if isinstance(y_pred, torch.Tensor): y_ = y.detach() if isinstance(y, torch.Tensor) else None - return self._compute_tensor(y_pred.detach(), y_) + return self._compute_tensor(y_pred.detach(), y_, **kwargs) raise ValueError("y_pred or y must be a list/tuple of `channel-first` Tensors or a `batch-first` Tensor.") def _compute_list( - self, y_pred: TensorOrList, y: TensorOrList | None = None + self, y_pred: TensorOrList, y: TensorOrList | None = None, **kwargs: Any ) -> torch.Tensor | list[torch.Tensor | Sequence[torch.Tensor]]: """ Execute the metric computation for `y_pred` and `y` in a list of "channel-first" tensors. @@ -93,9 +94,12 @@ def _compute_list( Note: subclass may enhance the operation to have multi-thread support. """ if y is not None: - ret = [self._compute_tensor(p.detach().unsqueeze(0), y_.detach().unsqueeze(0)) for p, y_ in zip(y_pred, y)] + ret = [ + self._compute_tensor(p.detach().unsqueeze(0), y_.detach().unsqueeze(0), **kwargs) + for p, y_ in zip(y_pred, y) + ] else: - ret = [self._compute_tensor(p_.detach().unsqueeze(0), None) for p_ in y_pred] + ret = [self._compute_tensor(p_.detach().unsqueeze(0), None, **kwargs) for p_ in y_pred] # concat the list of results (e.g. a batch of evaluation scores) if isinstance(ret[0], torch.Tensor): @@ -106,7 +110,7 @@ def _compute_list( return ret @abstractmethod - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor | None = None) -> TensorOrList: + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor | None = None, **kwargs: Any) -> TensorOrList: """ Computation logic for `y_pred` and `y` of an iteration, the data should be "batch-first" Tensors. A subclass should implement its own computation logic. @@ -318,7 +322,7 @@ class CumulativeIterationMetric(Cumulative, IterationMetric): """ def __call__( - self, y_pred: TensorOrList, y: TensorOrList | None = None + self, y_pred: TensorOrList, y: TensorOrList | None = None, **kwargs: Any ) -> torch.Tensor | Sequence[torch.Tensor | Sequence[torch.Tensor]]: """ Execute basic computation for model prediction and ground truth. @@ -331,12 +335,13 @@ def __call__( or a `batch-first` Tensor. y: the ground truth to compute, must be a list of `channel-first` Tensor or a `batch-first` Tensor. + kwargs: additional parameters for specific metric computation logic (e.g. ``spacing`` for SurfaceDistanceMetric, etc.). Returns: The computed metric values at the iteration level. The output shape should be a `batch-first` tensor (BC[HWD]) or a list of `batch-first` tensors. """ - ret = super().__call__(y_pred=y_pred, y=y) + ret = super().__call__(y_pred=y_pred, y=y, **kwargs) if isinstance(ret, (tuple, list)): self.extend(*ret) else: diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index 12c47dec8d..ad0aeb332b 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -12,11 +12,19 @@ from __future__ import annotations import warnings +from collections.abc import Sequence +from typing import Any import numpy as np import torch -from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background +from monai.metrics.utils import ( + do_metric_reduction, + get_mask_edges, + get_surface_distance, + ignore_background, + prepare_spacing, +) from monai.utils import MetricReduction, convert_data_type from .metric import CumulativeIterationMetric @@ -67,13 +75,23 @@ def __init__( self.reduction = reduction self.get_not_nans = get_not_nans - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override] r""" Args: y_pred: Predicted segmentation, typically segmentation model output. It must be a one-hot encoded, batch-first tensor [B,C,H,W]. y: Reference segmentation. It must be a one-hot encoded, batch-first tensor [B,C,H,W]. + kwargs: additional parameters, e.g. ``spacing`` should be passed to correctly compute the metric. + ``spacing``: spacing of pixel (or voxel). This parameter is relevant only + if ``distance_metric`` is set to ``"euclidean"``. + If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers, + the length of the sequence must be equal to the image dimensions. + This spacing will be used for all images in the batch. + If a sequence of sequences, the length of the outer sequence must be equal to the batch size. + If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch, + else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used + for all images in batch. Defaults to ``None``. Returns: Pytorch Tensor of shape [B,C], containing the NSD values :math:`\operatorname {NSD}_{b,c}` for each batch @@ -85,6 +103,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor class_thresholds=self.class_thresholds, include_background=self.include_background, distance_metric=self.distance_metric, + spacing=kwargs.get("spacing"), ) def aggregate( @@ -117,6 +136,7 @@ def compute_surface_dice( class_thresholds: list[float], include_background: bool = False, distance_metric: str = "euclidean", + spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None, ) -> torch.Tensor: r""" This function computes the (Normalized) Surface Dice (NSD) between the two tensors `y_pred` (referred to as @@ -167,6 +187,13 @@ def compute_surface_dice( distance_metric: The metric used to compute surface distances. One of [``"euclidean"``, ``"chessboard"``, ``"taxicab"``]. Defaults to ``"euclidean"``. + spacing: spacing of pixel (or voxel). This parameter is relevant only if ``distance_metric`` is set to ``"euclidean"``. + If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers, + the length of the sequence must be equal to the image dimensions. This spacing will be used for all images in the batch. + If a sequence of sequences, the length of the outer sequence must be equal to the batch size. + If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch, + else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used + for all images in batch. Defaults to ``None``. Raises: ValueError: If `y_pred` and/or `y` are not PyTorch tensors. @@ -219,6 +246,9 @@ def compute_surface_dice( nsd = np.empty((batch_size, n_class)) + img_dim = y_pred.ndim - 2 + spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim) + for b, c in np.ndindex(batch_size, n_class): (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=False) if not np.any(edges_gt): @@ -226,8 +256,12 @@ def compute_surface_dice( if not np.any(edges_pred): warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.") - distances_pred_gt = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric) - distances_gt_pred = get_surface_distance(edges_gt, edges_pred, distance_metric=distance_metric) + distances_pred_gt = get_surface_distance( + edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b] + ) + distances_gt_pred = get_surface_distance( + edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b] + ) boundary_complete = len(distances_pred_gt) + len(distances_gt_pred) boundary_correct = np.sum(distances_pred_gt <= class_thresholds[c]) + np.sum( diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 2a079525ac..cb3fa89190 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -12,11 +12,19 @@ from __future__ import annotations import warnings +from collections.abc import Sequence +from typing import Any import numpy as np import torch -from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background +from monai.metrics.utils import ( + do_metric_reduction, + get_mask_edges, + get_surface_distance, + ignore_background, + prepare_spacing, +) from monai.utils import MetricReduction, convert_data_type from .metric import CumulativeIterationMetric @@ -63,7 +71,7 @@ def __init__( self.reduction = reduction self.get_not_nans = get_not_nans - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override] """ Args: y_pred: input data to compute, typical segmentation model output. @@ -71,12 +79,23 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor should be binarized. y: ground truth to compute the distance. It must be one-hot format and first dim is batch. The values should be binarized. + kwargs: additional parameters, e.g. ``spacing`` should be passed to correctly compute the metric. + ``spacing``: spacing of pixel (or voxel). This parameter is relevant only + if ``distance_metric`` is set to ``"euclidean"``. + If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers, + the length of the sequence must be equal to the image dimensions. + This spacing will be used for all images in the batch. + If a sequence of sequences, the length of the outer sequence must be equal to the batch size. + If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch, + else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used + for all images in batch. Defaults to ``None``. Raises: ValueError: when `y_pred` has less than three dimensions. """ if y_pred.dim() < 3: raise ValueError("y_pred should have at least three dimensions.") + # compute (BxC) for each channel for each batch return compute_average_surface_distance( y_pred=y_pred, @@ -84,6 +103,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor include_background=self.include_background, symmetric=self.symmetric, distance_metric=self.distance_metric, + spacing=kwargs.get("spacing"), ) def aggregate( @@ -113,6 +133,7 @@ def compute_average_surface_distance( include_background: bool = False, symmetric: bool = False, distance_metric: str = "euclidean", + spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None, ) -> torch.Tensor: """ This function is used to compute the Average Surface Distance from `y_pred` to `y` @@ -133,6 +154,13 @@ def compute_average_surface_distance( `seg_pred` and `seg_gt`. Defaults to ``False``. distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] the metric used to compute surface distance. Defaults to ``"euclidean"``. + spacing: spacing of pixel (or voxel). This parameter is relevant only if ``distance_metric`` is set to ``"euclidean"``. + If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers, + the length of the sequence must be equal to the image dimensions. This spacing will be used for all images in the batch. + If a sequence of sequences, the length of the outer sequence must be equal to the batch size. + If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch, + else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used + for all images in batch. Defaults to ``None``. """ if not include_background: @@ -147,15 +175,22 @@ def compute_average_surface_distance( batch_size, n_class = y_pred.shape[:2] asd = np.empty((batch_size, n_class)) + img_dim = y_pred.ndim - 2 + spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim) + for b, c in np.ndindex(batch_size, n_class): (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c]) if not np.any(edges_gt): warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.") if not np.any(edges_pred): warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.") - surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric) + surface_distance = get_surface_distance( + edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b] + ) if symmetric: - surface_distance_2 = get_surface_distance(edges_gt, edges_pred, distance_metric=distance_metric) + surface_distance_2 = get_surface_distance( + edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b] + ) surface_distance = np.concatenate([surface_distance, surface_distance_2]) asd[b, c] = np.nan if surface_distance.shape == (0,) else surface_distance.mean() diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index d0b5c28744..f585cfd9aa 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -12,6 +12,7 @@ from __future__ import annotations import warnings +from collections.abc import Sequence from typing import Any import numpy as np @@ -172,7 +173,12 @@ def get_mask_edges( return edges_pred, edges_gt -def get_surface_distance(seg_pred: np.ndarray, seg_gt: np.ndarray, distance_metric: str = "euclidean") -> np.ndarray: +def get_surface_distance( + seg_pred: np.ndarray, + seg_gt: np.ndarray, + distance_metric: str = "euclidean", + spacing: int | float | np.ndarray | Sequence[int | float] | None = None, +) -> np.ndarray: """ This function is used to compute the surface distances from `seg_pred` to `seg_gt`. @@ -185,6 +191,13 @@ def get_surface_distance(seg_pred: np.ndarray, seg_gt: np.ndarray, distance_metr - ``"euclidean"``, uses Exact Euclidean distance transform. - ``"chessboard"``, uses `chessboard` metric in chamfer type of transform. - ``"taxicab"``, uses `taxicab` metric in chamfer type of transform. + spacing: spacing of pixel (or voxel) along each axis. If a sequence, must be of + length equal to the image dimensions; if a single number, this is used for all axes. + If ``None``, spacing of unity is used. Defaults to ``None``. + spacing: spacing of pixel (or voxel). This parameter is relevant only if ``distance_metric`` is set to ``"euclidean"``. + Several input options are allowed: (1) If a single number, isotropic spacing with that value is used. + (2) If a sequence of numbers, the length of the sequence must be equal to the image dimensions. + (3) If ``None``, spacing of unity is used. Defaults to ``None``. Note: If seg_pred or seg_gt is all 0, may result in nan/inf distance. @@ -198,7 +211,7 @@ def get_surface_distance(seg_pred: np.ndarray, seg_gt: np.ndarray, distance_metr dis = np.inf * np.ones_like(seg_gt) return np.asarray(dis[seg_gt]) if distance_metric == "euclidean": - dis = distance_transform_edt(~seg_gt) + dis = distance_transform_edt(~seg_gt, sampling=spacing) elif distance_metric in {"chessboard", "taxicab"}: dis = distance_transform_cdt(~seg_gt, metric=distance_metric) else: @@ -261,3 +274,63 @@ def remap_instance_id(pred: torch.Tensor, by_size: bool = False) -> torch.Tensor for idx, instance_id in enumerate(pred_id): new_pred[pred == instance_id] = idx + 1 return new_pred + + +def prepare_spacing( + spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None, + batch_size: int, + img_dim: int, +) -> Sequence[None | int | float | np.ndarray | Sequence[int | float]]: + """ + This function is used to prepare the `spacing` parameter to include batch dimension for the computation of + surface distance, hausdorff distance or surface dice. + + An example with batch_size = 4 and img_dim = 3: + input spacing = None -> output spacing = [None, None, None, None] + input spacing = 0.8 -> output spacing = [0.8, 0.8, 0.8, 0.8] + input spacing = [0.8, 0.5, 0.9] -> output spacing = [[0.8, 0.5, 0.9], [0.8, 0.5, 0.9], [0.8, 0.5, 0.9], [0.8, 0.5, 0.9]] + input spacing = [0.8, 0.7, 1.2, 0.8] -> output spacing = [0.8, 0.7, 1.2, 0.8] (same as input) + + An example with batch_size = 3 and img_dim = 3: + input spacing = [0.8, 0.5, 0.9] -> output spacing = [[0.8, 0.5, 0.9], [0.8, 0.5, 0.9], [0.8, 0.5, 0.9], [0.8, 0.5, 0.9]] + + Args: + spacing: can be a float, a sequence of length `img_dim`, or a sequence with length `batch_size` + that includes floats or sequences of length `img_dim`. + + Raises: + AssertionError: when `spacing` is a sequence of sequence, where the outer sequence length does not + equal `batch_size` or inner sequence length does not equal `img_dim`. + + Returns: + spacing: a sequence with length `batch_size` that includes integers, floats or sequences of length `img_dim`. + """ + if spacing is None or isinstance(spacing, (int, float)): + return list([spacing] * batch_size) + elif isinstance(spacing, (Sequence, np.ndarray)): + assert all( + [isinstance(s, type(spacing[0])) for s in list(spacing)] + ), "if `spacing` is a sequence, its elements should be of same type." + + if isinstance(spacing[0], (Sequence, np.ndarray)): + assert ( + len(spacing) == batch_size + ), "if `spacing` is a sequence of sequences, the outer sequence should have same length as batch size." + assert all( + [len(s) == img_dim for s in list(spacing)] + ), "each element of `spacing` list should either have same length as image dim." + assert all( + [isinstance(i, (int, float)) for s in list(spacing) for i in list(s)] + ), "if `spacing` is a sequence of sequences or 2D np.ndarray, the elements should be integers or floats." + return list(spacing) + elif isinstance(spacing[0], (int, float)): + assert ( + len(spacing) == img_dim + ), "if `spacing` is a sequence of numbers, it should have same length as image dim." + return [spacing for _ in range(batch_size)] # type: ignore + else: + raise AssertionError(f"`spacing` is a sequence of elements with unsupported type: {type(spacing[0])}") + else: + raise AssertionError( + "`spacing` should either be an integer, float, a sequence of numbers or a sequence of sequences." + ) diff --git a/tests/test_hausdorff_distance.py b/tests/test_hausdorff_distance.py index 40f5b187d0..a50b27b79e 100644 --- a/tests/test_hausdorff_distance.py +++ b/tests/test_hausdorff_distance.py @@ -23,7 +23,10 @@ def create_spherical_seg_3d( - radius: float = 20.0, centre: tuple[int, int, int] = (49, 49, 49), im_shape: tuple[int, int, int] = (99, 99, 99) + radius: float = 20.0, + centre: tuple[int, int, int] = (49, 49, 49), + im_shape: tuple[int, int, int] = (99, 99, 99), + im_spacing: tuple[float, float, float] = (1.0, 1.0, 1.0), ) -> np.ndarray: """ Return a 3D image with a sphere inside. Voxel values will be @@ -32,16 +35,23 @@ def create_spherical_seg_3d( Args: radius: radius of sphere (in terms of number of voxels, can be partial) centre: location of sphere centre. - im_shape: shape of image to create + im_shape: shape of image to create. + im_spacing: spacing of image to create. See also: :py:meth:`~create_test_image_3d` """ # Create image image = np.zeros(im_shape, dtype=np.int32) - spy, spx, spz = np.ogrid[ - -centre[0] : im_shape[0] - centre[0], -centre[1] : im_shape[1] - centre[1], -centre[2] : im_shape[2] - centre[2] - ] + spy, spx, spz = np.ogrid[: im_shape[0], : im_shape[1], : im_shape[2]] + spy = spy.astype(float) * im_spacing[0] + spx = spx.astype(float) * im_spacing[1] + spz = spz.astype(float) * im_spacing[2] + + spy -= centre[0] + spx -= centre[1] + spz -= centre[2] + circle = (spx * spx + spy * spy + spz * spz) <= radius * radius image[circle] = 1 @@ -49,12 +59,14 @@ def create_spherical_seg_3d( return image +test_spacing = (0.85, 1.2, 0.9) TEST_CASES = [ - [[create_spherical_seg_3d(), create_spherical_seg_3d(), 1], [0, 0, 0, 0, 0, 0]], + [[create_spherical_seg_3d(), create_spherical_seg_3d(), None, 1], [0, 0, 0, 0, 0, 0]], [ [ create_spherical_seg_3d(radius=20, centre=(20, 20, 20)), create_spherical_seg_3d(radius=20, centre=(19, 19, 19)), + None, ], [1.7320508075688772, 1.7320508075688772, 1, 1, 3, 3], ], @@ -62,6 +74,7 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(radius=33, centre=(19, 33, 22)), create_spherical_seg_3d(radius=33, centre=(20, 33, 22)), + None, ], [1, 1, 1, 1, 1, 1], ], @@ -69,6 +82,7 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(radius=20, centre=(20, 33, 22)), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), + None, ], [20.09975124224178, 20.223748416156685, 15, 20, 24, 35], ], @@ -77,6 +91,7 @@ def create_spherical_seg_3d( # pred does not have foreground (but gt has), the metric should be inf np.zeros([99, 99, 99]), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), + None, ], [np.inf, np.inf, np.inf, np.inf, np.inf, np.inf], ], @@ -85,6 +100,7 @@ def create_spherical_seg_3d( # gt does not have foreground (but pred has), the metric should be inf create_spherical_seg_3d(), np.zeros([99, 99, 99]), + None, ], [np.inf, np.inf, np.inf, np.inf, np.inf, np.inf], ], @@ -92,20 +108,46 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(radius=20, centre=(20, 33, 22)), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), + None, 95, ], [19.924858845171276, 20.09975124224178, 14, 18, 22, 33], ], + [ + [ + create_spherical_seg_3d(radius=20, centre=(20, 20, 20), im_spacing=test_spacing), + create_spherical_seg_3d(radius=20, centre=(19, 19, 19), im_spacing=test_spacing), + test_spacing, + ], + [2.0808651447296143, 2.2671568, 2, 2, 3, 4], + ], + [ + [ + create_spherical_seg_3d(radius=15, centre=(20, 33, 22), im_spacing=test_spacing), + create_spherical_seg_3d(radius=30, centre=(20, 33, 22), im_spacing=test_spacing), + test_spacing, + ], + [15.439640998840332, 15.62594, 11, 17, 20, 28], + ], ] TEST_CASES_NANS = [ + [ + [ + # both pred and gt do not have foreground, spacing is None, metric and not_nans should be 0 + np.zeros([99, 99, 99]), + np.zeros([99, 99, 99]), + None, + ] + ], [ [ # both pred and gt do not have foreground, metric and not_nans should be 0 np.zeros([99, 99, 99]), np.zeros([99, 99, 99]), + test_spacing, ] - ] + ], ] @@ -113,10 +155,10 @@ class TestHausdorffDistance(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_value(self, input_data, expected_value): percentile = None - if len(input_data) == 3: - [seg_1, seg_2, percentile] = input_data + if len(input_data) == 4: + [seg_1, seg_2, spacing, percentile] = input_data else: - [seg_1, seg_2] = input_data + [seg_1, seg_2, spacing] = input_data ct = 0 seg_1 = torch.tensor(seg_1, device=_device) seg_2 = torch.tensor(seg_2, device=_device) @@ -129,7 +171,7 @@ def test_value(self, input_data, expected_value): batch, n_class = 2, 3 batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) - hd_metric(batch_seg_1, batch_seg_2) + hd_metric(batch_seg_1, batch_seg_2, spacing=spacing) result = hd_metric.aggregate(reduction="mean") expected_value_curr = expected_value[ct] np.testing.assert_allclose(expected_value_curr, result.cpu(), rtol=1e-7) @@ -138,13 +180,13 @@ def test_value(self, input_data, expected_value): @parameterized.expand(TEST_CASES_NANS) def test_nans(self, input_data): - [seg_1, seg_2] = input_data + [seg_1, seg_2, spacing] = input_data seg_1 = torch.tensor(seg_1) seg_2 = torch.tensor(seg_2) hd_metric = HausdorffDistanceMetric(include_background=False, get_not_nans=True) batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0) batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0) - hd_metric(batch_seg_1, batch_seg_2) + hd_metric(batch_seg_1, batch_seg_2, spacing=spacing) result, not_nans = hd_metric.aggregate() np.testing.assert_allclose(0, result, rtol=1e-7) np.testing.assert_allclose(0, not_nans, rtol=1e-7) diff --git a/tests/test_surface_dice.py b/tests/test_surface_dice.py index 3ee54e5903..15e6245619 100644 --- a/tests/test_surface_dice.py +++ b/tests/test_surface_dice.py @@ -23,6 +23,67 @@ class TestAllSurfaceDiceMetrics(unittest.TestCase): + def test_tolerance_euclidean_distance_with_spacing(self): + batch_size = 2 + n_class = 2 + test_spacing = (0.85, 1.2) + predictions = torch.zeros((batch_size, 480, 640), dtype=torch.int64, device=_device) + labels = torch.zeros((batch_size, 480, 640), dtype=torch.int64, device=_device) + predictions[0, :, 50:] = 1 + labels[0, :, 60:] = 1 # 10 px shift + predictions_hot = F.one_hot(predictions, num_classes=n_class).permute(0, 3, 1, 2) + labels_hot = F.one_hot(labels, num_classes=n_class).permute(0, 3, 1, 2) + + sd0 = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True) + res0 = sd0(predictions_hot, labels_hot, spacing=test_spacing) + agg0 = sd0.aggregate() # aggregation: nanmean across image then nanmean across batch + sd0_nans = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True, get_not_nans=True) + res0_nans = sd0_nans(predictions_hot, labels_hot) + agg0_nans, not_nans = sd0_nans.aggregate() + + np.testing.assert_array_equal(res0.cpu(), res0_nans.cpu()) + np.testing.assert_equal(res0.device, predictions.device) + np.testing.assert_array_equal(agg0.cpu(), agg0_nans.cpu()) + np.testing.assert_equal(agg0.device, predictions.device) + + res1 = SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)( + predictions_hot, labels_hot, spacing=test_spacing + ) + res9 = SurfaceDiceMetric(class_thresholds=[9, 9], include_background=True)( + predictions_hot, labels_hot, spacing=test_spacing + ) + res10 = SurfaceDiceMetric(class_thresholds=[10, 10], include_background=True)( + predictions_hot, labels_hot, spacing=test_spacing + ) + res11 = SurfaceDiceMetric(class_thresholds=[11, 11], include_background=True)( + predictions_hot, labels_hot, spacing=test_spacing + ) + # because spacing is (0.85, 1.2) and we moved 10 pixels in the columns direction, + # everything with tolerance 12 or more should be the same as tolerance 12 (surface dice is 1.0) + res12 = SurfaceDiceMetric(class_thresholds=[12, 12], include_background=True)( + predictions_hot, labels_hot, spacing=test_spacing + ) + res13 = SurfaceDiceMetric(class_thresholds=[13, 13], include_background=True)( + predictions_hot, labels_hot, spacing=test_spacing + ) + + for res in [res0, res9, res10, res11, res12, res13]: + assert res.shape == torch.Size([2, 2]) + + assert res0[0, 0] < res1[0, 0] < res9[0, 0] < res10[0, 0] < res11[0, 0] + assert res0[0, 1] < res1[0, 1] < res9[0, 1] < res10[0, 1] < res11[0, 1] + np.testing.assert_array_equal(res12.cpu(), res13.cpu()) + + expected_res0 = np.zeros((batch_size, n_class)) + expected_res0[0, 1] = 1 - (478 + 480 + 9 * 2) / (480 * 4 + 588 * 2 + 578 * 2) + expected_res0[0, 0] = 1 - (478 + 480 + 9 * 2) / (480 * 4 + 48 * 2 + 58 * 2) + expected_res0[1, 0] = 1 + expected_res0[1, 1] = np.nan + for b, c in np.ndindex(batch_size, n_class): + np.testing.assert_allclose(expected_res0[b, c], res0[b, c].cpu()) + np.testing.assert_array_equal(agg0.cpu(), np.nanmean(np.nanmean(expected_res0, axis=1), axis=0)) + np.testing.assert_equal(not_nans.cpu(), torch.tensor(2)) + def test_tolerance_euclidean_distance(self): batch_size = 2 n_class = 2 diff --git a/tests/test_surface_distance.py b/tests/test_surface_distance.py index f2e2ea7144..81ddee107b 100644 --- a/tests/test_surface_distance.py +++ b/tests/test_surface_distance.py @@ -23,7 +23,10 @@ def create_spherical_seg_3d( - radius: float = 20.0, centre: tuple[int, int, int] = (49, 49, 49), im_shape: tuple[int, int, int] = (99, 99, 99) + radius: float = 20.0, + centre: tuple[int, int, int] = (49, 49, 49), + im_shape: tuple[int, int, int] = (99, 99, 99), + im_spacing: tuple[float, float, float] = (1.0, 1.0, 1.0), ) -> np.ndarray: """ Return a 3D image with a sphere inside. Voxel values will be @@ -32,16 +35,24 @@ def create_spherical_seg_3d( Args: radius: radius of sphere (in terms of number of voxels, can be partial) centre: location of sphere centre. - im_shape: shape of image to create + im_shape: shape of image to create. + im_spacing: spacing of image to create. See also: :py:meth:`~create_test_image_3d` """ # Create image image = np.zeros(im_shape, dtype=np.int32) - spy, spx, spz = np.ogrid[ - -centre[0] : im_shape[0] - centre[0], -centre[1] : im_shape[1] - centre[1], -centre[2] : im_shape[2] - centre[2] - ] + spy, spx, spz = np.ogrid[: im_shape[0], : im_shape[1], : im_shape[2]] + + spy = spy.astype(float) * im_spacing[0] + spx = spx.astype(float) * im_spacing[1] + spz = spz.astype(float) * im_spacing[2] + + spy -= centre[0] + spx -= centre[1] + spz -= centre[2] + circle = (spx * spx + spy * spy + spz * spz) <= radius * radius image[circle] = 1 @@ -49,8 +60,9 @@ def create_spherical_seg_3d( return image +test_spacing = (0.85, 1.2, 0.9) TEST_CASES = [ - [[create_spherical_seg_3d(), create_spherical_seg_3d()], [0, 0]], + [[create_spherical_seg_3d(), create_spherical_seg_3d(), "euclidean", None], [0, 0]], [ [ create_spherical_seg_3d(radius=20, centre=(20, 20, 20)), @@ -63,6 +75,8 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(radius=33, centre=(19, 33, 22)), create_spherical_seg_3d(radius=33, centre=(20, 33, 22)), + "euclidean", + None, ], [0.350217, 0.3483278807706289], ], @@ -70,6 +84,8 @@ def create_spherical_seg_3d( [ create_spherical_seg_3d(radius=20, centre=(20, 33, 22)), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), + "euclidean", + None, ], [15.117741, 12.040033513150455], ], @@ -89,18 +105,39 @@ def create_spherical_seg_3d( ], [20.214613, 12.432687531048186], ], - [[np.zeros([99, 99, 99]), create_spherical_seg_3d(radius=40, centre=(20, 33, 22))], [np.inf, np.inf]], + [ + [np.zeros([99, 99, 99]), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), "euclidean", None], + [np.inf, np.inf], + ], [[create_spherical_seg_3d(), np.zeros([99, 99, 99]), "taxicab"], [np.inf, np.inf]], + [ + [ + create_spherical_seg_3d(radius=33, centre=(42, 45, 52), im_spacing=test_spacing), + create_spherical_seg_3d(radius=33, centre=(43, 45, 52), im_spacing=test_spacing), + "euclidean", + test_spacing, + ], + [0.4951, 0.4951], + ], ] TEST_CASES_NANS = [ [ [ - # both pred and gt do not have foreground, metric and not_nans should be 0 + # both pred and gt do not have foreground, spacing is None, metric and not_nans should be 0 np.zeros([99, 99, 99]), np.zeros([99, 99, 99]), + None, ] - ] + ], + [ + [ + # both pred and gt do not have foreground, spacing is not None, metric and not_nans should be 0 + np.zeros([99, 99, 99]), + np.zeros([99, 99, 99]), + test_spacing, + ] + ], ] @@ -109,9 +146,10 @@ class TestAllSurfaceMetrics(unittest.TestCase): def test_value(self, input_data, expected_value): if len(input_data) == 3: [seg_1, seg_2, metric] = input_data + spacing = None else: - [seg_1, seg_2] = input_data - metric = "euclidean" + [seg_1, seg_2, metric, spacing] = input_data + ct = 0 seg_1 = torch.tensor(seg_1, device=_device) seg_2 = torch.tensor(seg_2, device=_device) @@ -121,7 +159,7 @@ def test_value(self, input_data, expected_value): batch, n_class = 2, 3 batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) - sur_metric(batch_seg_1, batch_seg_2) + sur_metric(batch_seg_1, batch_seg_2, spacing=spacing) result = sur_metric.aggregate() expected_value_curr = expected_value[ct] np.testing.assert_allclose(expected_value_curr, result.cpu(), rtol=1e-5) @@ -130,14 +168,14 @@ def test_value(self, input_data, expected_value): @parameterized.expand(TEST_CASES_NANS) def test_nans(self, input_data): - [seg_1, seg_2] = input_data + [seg_1, seg_2, spacing] = input_data seg_1 = torch.tensor(seg_1) seg_2 = torch.tensor(seg_2) sur_metric = SurfaceDistanceMetric(include_background=False, get_not_nans=True) # test list of channel-first Tensor batch_seg_1 = [seg_1.unsqueeze(0)] batch_seg_2 = [seg_2.unsqueeze(0)] - sur_metric(batch_seg_1, batch_seg_2) + sur_metric(batch_seg_1, batch_seg_2, spacing=spacing) result, not_nans = sur_metric.aggregate(reduction="mean") np.testing.assert_allclose(0, result, rtol=1e-5) np.testing.assert_allclose(0, not_nans, rtol=1e-5)