From eed6f445623623852cf8f73d321c099478ddf707 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 1 Jul 2023 17:54:49 +0100 Subject: [PATCH 01/10] enhance surface Dice to use subvoxel areas Signed-off-by: Wenqi Li --- monai/metrics/surface_dice.py | 88 +++-- monai/metrics/utils.py | 508 ++++++++++++++++++++++--- monai/transforms/croppad/dictionary.py | 4 +- tests/test_surface_dice.py | 37 +- 4 files changed, 531 insertions(+), 106 deletions(-) diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index 43305ca834..a4d1cc6cdd 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -86,7 +86,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D]. y: Reference segmentation. It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D]. - kwargs: additional parameters, e.g. ``spacing`` should be passed to correctly compute the metric. + kwargs: additional parameters: ``spacing`` should be passed to correctly compute the metric. ``spacing``: spacing of pixel (or voxel). This parameter is relevant only if ``distance_metric`` is set to ``"euclidean"``. If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers, @@ -96,6 +96,8 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch, else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used for all images in batch. Defaults to ``None``. + use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``. + Returns: Pytorch Tensor of shape [B,C], containing the NSD values :math:`\operatorname {NSD}_{b,c}` for each batch @@ -108,6 +110,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) include_background=self.include_background, distance_metric=self.distance_metric, spacing=kwargs.get("spacing"), + use_subvoxels=kwargs.get("use_subvoxels", False), ) def aggregate( @@ -141,13 +144,14 @@ def compute_surface_dice( include_background: bool = False, distance_metric: str = "euclidean", spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None, + use_subvoxels: bool = False, ) -> torch.Tensor: r""" This function computes the (Normalized) Surface Dice (NSD) between the two tensors `y_pred` (referred to as :math:`\hat{Y}`) and `y` (referred to as :math:`Y`). This metric determines which fraction of a segmentation boundary is correctly predicted. A boundary element is considered correctly predicted if the closest distance to the - reference boundary is smaller than or equal to the specified threshold related to the acceptable amount of deviation in - pixels. The NSD is bounded between 0 and 1. + reference boundary is smaller than or equal to the specified threshold related to the acceptable amount of deviation + in pixels. The NSD is bounded between 0 and 1. This implementation supports multi-class tasks with an individual threshold :math:`\tau_c` for each class :math:`c`. The class-specific NSD for batch index :math:`b`, :math:`\operatorname {NSD}_{b,c}`, is computed using the function: @@ -159,8 +163,8 @@ def compute_surface_dice( :label: nsd with :math:`\mathcal{D}_{Y_{b,c}}` and :math:`\mathcal{D}_{\hat{Y}_{b,c}}` being two sets of nearest-neighbor - distances. :math:`\mathcal{D}_{Y_{b,c}}` is computed from the predicted segmentation boundary towards the reference segmentation - boundary and vice-versa for :math:`\mathcal{D}_{\hat{Y}_{b,c}}`. :math:`\mathcal{D}_{Y_{b,c}}^{'}` and + distances. :math:`\mathcal{D}_{Y_{b,c}}` is computed from the predicted segmentation boundary towards the reference + segmentation boundary and vice-versa for :math:`\mathcal{D}_{\hat{Y}_{b,c}}`. :math:`\mathcal{D}_{Y_{b,c}}^{'}` and :math:`\mathcal{D}_{\hat{Y}_{b,c}}^{'}` refer to the subsets of distances that are smaller or equal to the acceptable distance :math:`\tau_c`: @@ -168,15 +172,14 @@ def compute_surface_dice( \mathcal{D}_{Y_{b,c}}^{'} = \{ d \in \mathcal{D}_{Y_{b,c}} \, | \, d \leq \tau_c \}. - In the case of a class neither being present in the predicted segmentation, nor in the reference segmentation, a nan value - will be returned for this class. In the case of a class being present in only one of predicted segmentation or - reference segmentation, the class NSD will be 0. + In the case of a class neither being present in the predicted segmentation, nor in the reference segmentation, + a nan value will be returned for this class. In the case of a class being present in only one of predicted + segmentation or reference segmentation, the class NSD will be 0. This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D and 3D images. - Be aware that the computation of boundaries is different from DeepMind's implementation - https://github.com/deepmind/surface-distance. In this implementation, the length of a segmentation boundary is - interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary - depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430). + The computation of boundaries follows DeepMind's implementation + https://github.com/deepmind/surface-distance when `use_subvoxels=True`; Otherwise the length of a segmentation + boundary is interpreted as the number of its edge pixels. Args: y_pred: Predicted segmentation, typically segmentation model output. @@ -198,6 +201,7 @@ def compute_surface_dice( If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch, else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used for all images in batch. Defaults to ``None``. + use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``. Raises: ValueError: If `y_pred` and/or `y` are not PyTorch tensors. @@ -227,11 +231,6 @@ def compute_surface_dice( f"y_pred and y should have same shape, but instead, shapes are {y_pred.shape} (y_pred) and {y.shape} (y)." ) - if not torch.all(y_pred.byte() == y_pred) or not torch.all(y.byte() == y): - raise ValueError("y_pred and y should be binarized tensors (e.g. torch.int64).") - if torch.any(y_pred > 1) or torch.any(y > 1): - raise ValueError("y_pred and y should be one-hot encoded.") - y = y.float() y_pred = y_pred.float() @@ -254,23 +253,44 @@ def compute_surface_dice( spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim) for b, c in np.ndindex(batch_size, n_class): - (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=False) - if not np.any(edges_gt): - warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.") - if not np.any(edges_pred): - warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.") - - distances_pred_gt = get_surface_distance( - edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b] - ) - distances_gt_pred = get_surface_distance( - edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b] - ) - - boundary_complete = len(distances_pred_gt) + len(distances_gt_pred) - boundary_correct = np.sum(distances_pred_gt <= class_thresholds[c]) + np.sum( - distances_gt_pred <= class_thresholds[c] - ) + if not use_subvoxels: + (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=True) + if not np.any(edges_gt): + warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.") + if not np.any(edges_pred): + warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.") + + distances_pred_gt = get_surface_distance( + edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b] + ) + distances_gt_pred = get_surface_distance( + edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b] + ) + + boundary_complete = len(distances_pred_gt) + len(distances_gt_pred) + boundary_correct = np.sum(distances_pred_gt <= class_thresholds[c]) + np.sum( + distances_gt_pred <= class_thresholds[c] + ) + else: + _spacing = spacing_list[b] if spacing_list[b] is not None else [1] * img_dim + areas_pred: np.ndarray + areas_gt: np.ndarray + edges_pred, edges_gt, areas_pred, areas_gt = get_mask_edges( # type: ignore + y_pred[b, c], y[b, c], crop=True, spacing=_spacing # type: ignore + ) + if not np.any(edges_gt): + warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.") + if not np.any(edges_pred): + warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.") + dist_gt = get_surface_distance(None, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b]) + dist_pred = get_surface_distance(None, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b]) + dist_gt_to_pred, dist_pred_to_gt = dist_pred[edges_gt], dist_gt[edges_pred] + areas_gt, areas_pred = areas_gt[edges_gt], areas_pred[edges_pred] + boundary_complete = areas_gt.sum() + areas_pred.sum() + boundary_correct = ( + areas_gt[dist_gt_to_pred <= class_thresholds[c]].sum() + + areas_pred[dist_pred_to_gt <= class_thresholds[c]].sum() + ) if boundary_complete == 0: # the class is neither present in the prediction, nor in the reference segmentation diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index be121ef027..35bf3feba6 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -12,22 +12,37 @@ from __future__ import annotations import warnings -from collections.abc import Sequence -from typing import Any +from functools import cache +from typing import Any, Sequence import numpy as np import torch from monai.config import NdarrayOrTensor, NdarrayTensor -from monai.transforms.croppad.array import SpatialCrop -from monai.transforms.utils import generate_spatial_bounding_box -from monai.utils import MetricReduction, convert_data_type, look_up_option, optional_import +from monai.transforms.croppad.dictionary import BorderPadD, CropForegroundD +from monai.utils import ( + MetricReduction, + convert_data_type, + convert_to_tensor, + ensure_tuple_rep, + look_up_option, + optional_import, +) binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion") distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt") distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") -__all__ = ["ignore_background", "do_metric_reduction", "get_mask_edges", "get_surface_distance", "is_binary_tensor"] +__all__ = [ + "ignore_background", + "do_metric_reduction", + "get_mask_edges", + "get_surface_distance", + "is_binary_tensor", + "remap_instance_id", + "prepare_spacing", + "get_code_to_measure_table", +] def ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayTensor, NdarrayTensor]: @@ -110,7 +125,11 @@ def do_metric_reduction( def get_mask_edges( - seg_pred: NdarrayOrTensor, seg_gt: NdarrayOrTensor, label_idx: int = 1, crop: bool = True + seg_pred: NdarrayOrTensor, + seg_gt: NdarrayOrTensor, + label_idx: int = 1, + crop: bool = True, + spacing: Sequence | None = None, ) -> tuple[np.ndarray, np.ndarray]: """ Do binary erosion and use XOR for input to get the edges. This @@ -157,24 +176,38 @@ def get_mask_edges( if crop: if not np.any(seg_pred | seg_gt): - return np.zeros_like(seg_pred), np.zeros_like(seg_gt) - - channel_dim = 0 - seg_pred, seg_gt = np.expand_dims(seg_pred, axis=channel_dim), np.expand_dims(seg_gt, axis=channel_dim) - box_start, box_end = generate_spatial_bounding_box(np.asarray(seg_pred | seg_gt)) - cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) - seg_pred = convert_data_type(np.squeeze(cropper(seg_pred), axis=channel_dim), np.ndarray)[0] # type: ignore[arg-type] - seg_gt = convert_data_type(np.squeeze(cropper(seg_gt), axis=channel_dim), np.ndarray)[0] # type: ignore[arg-type] - - # Do binary erosion and use XOR to get edges - edges_pred = binary_erosion(seg_pred) ^ seg_pred - edges_gt = binary_erosion(seg_gt) ^ seg_gt - - return edges_pred, edges_gt + pred, gt = np.zeros_like(seg_pred), np.zeros_like(seg_gt) + return (pred, gt) if spacing is None else (pred, gt, pred, gt) # type: ignore + cropper = CropForegroundD(keys=["pred", "gt"], source_key="src", start_coord_key=None, end_coord_key=None) + pad = BorderPadD(keys=["pred", "gt"], spatial_border=1, mode="constant") + mask = np.asarray(seg_pred | seg_gt) + cropped = pad(cropper({"pred": seg_pred[None], "gt": seg_gt[None], "src": mask[None]})) # type: ignore + seg_pred = cropped["pred"] + seg_gt = cropped["gt"] + + if spacing is None: + # Do binary erosion and use XOR to get edges + seg_pred = convert_data_type(seg_pred[0], np.ndarray)[0] + seg_gt = convert_data_type(seg_gt[0], np.ndarray)[0] + edges_pred = binary_erosion(seg_pred) ^ seg_pred + edges_gt = binary_erosion(seg_gt) ^ seg_gt + return edges_pred, edges_gt + code_to_area_table, k = get_code_to_measure_table(spacing) + sptial_dims = len(spacing) + conv = torch.nn.functional.conv3d if sptial_dims == 3 else torch.nn.functional.conv2d + code_pred, code_gt = conv(torch.stack([seg_pred, seg_gt], dim=0).float(), k.float()) # type: ignore + # edges + all_ones = len(code_to_area_table) - 1 + edges_pred = (code_pred != 0) & (code_pred != all_ones) + edges_gt = (code_gt != 0) & (code_gt != all_ones) + # areas of edges + areas_pred = torch.index_select(code_to_area_table, 0, code_pred.view(-1).int()).reshape(code_pred.shape) + areas_gt = torch.index_select(code_to_area_table, 0, code_gt.view(-1).int()).reshape(code_gt.shape) + return edges_pred.array[0], edges_gt.array[0], areas_pred.array[0], areas_gt.array[0] # type: ignore def get_surface_distance( - seg_pred: np.ndarray, + seg_pred: np.ndarray | None, seg_gt: np.ndarray, distance_metric: str = "euclidean", spacing: int | float | np.ndarray | Sequence[int | float] | None = None, @@ -207,7 +240,7 @@ def get_surface_distance( if not np.any(seg_gt): dis = np.inf * np.ones_like(seg_gt) else: - if not np.any(seg_pred): + if seg_pred is not None and not np.any(seg_pred): dis = np.inf * np.ones_like(seg_gt) return np.asarray(dis[seg_gt]) if distance_metric == "euclidean": @@ -217,7 +250,7 @@ def get_surface_distance( else: raise ValueError(f"distance_metric {distance_metric} is not implemented.") - return np.asarray(dis[seg_pred]) + return np.asarray(dis[seg_pred]) if seg_pred is not None else np.asarray(dis) def is_binary_tensor(input: torch.Tensor, name: str) -> None: @@ -259,13 +292,10 @@ def remap_instance_id(pred: torch.Tensor, by_size: bool = False) -> torch.Tensor # the original implementation has the limitation that if there is no 0 in pred, error will happen pred_id = [i for i in pred_id if i != 0] - if len(pred_id) == 0: + if not pred_id: return pred - if by_size is True: - instance_size = [] - for instance_id in pred_id: - instance_size.append((pred == instance_id).sum()) - + if by_size: + instance_size = [(pred == instance_id).sum() for instance_id in pred_id] pair_data = zip(pred_id, instance_size) pair_list = sorted(pair_data, key=lambda x: x[1], reverse=True) pred_id, _ = zip(*pair_list) @@ -292,14 +322,15 @@ def prepare_spacing( input spacing = [0.8, 0.7, 1.2, 0.8] -> output spacing = [0.8, 0.7, 1.2, 0.8] (same as input) An example with batch_size = 3 and img_dim = 3: - input spacing = [0.8, 0.5, 0.9] -> output spacing = [[0.8, 0.5, 0.9], [0.8, 0.5, 0.9], [0.8, 0.5, 0.9], [0.8, 0.5, 0.9]] + input spacing = [0.8, 0.5, 0.9] -> + output spacing = [[0.8, 0.5, 0.9], [0.8, 0.5, 0.9], [0.8, 0.5, 0.9], [0.8, 0.5, 0.9]] Args: spacing: can be a float, a sequence of length `img_dim`, or a sequence with length `batch_size` that includes floats or sequences of length `img_dim`. Raises: - AssertionError: when `spacing` is a sequence of sequence, where the outer sequence length does not + ValueError: when `spacing` is a sequence of sequence, where the outer sequence length does not equal `batch_size` or inner sequence length does not equal `img_dim`. Returns: @@ -307,30 +338,395 @@ def prepare_spacing( """ if spacing is None or isinstance(spacing, (int, float)): return list([spacing] * batch_size) - elif isinstance(spacing, (Sequence, np.ndarray)): - assert all( - isinstance(s, type(spacing[0])) for s in list(spacing) - ), "if `spacing` is a sequence, its elements should be of same type." - + if isinstance(spacing, (Sequence, np.ndarray)): + if any(not isinstance(s, type(spacing[0])) for s in list(spacing)): + raise ValueError(f"if `spacing` is a sequence, its elements should be of same type, got {spacing}.") if isinstance(spacing[0], (Sequence, np.ndarray)): - assert ( - len(spacing) == batch_size - ), "if `spacing` is a sequence of sequences, the outer sequence should have same length as batch size." - assert all( - len(s) == img_dim for s in list(spacing) - ), "each element of `spacing` list should either have same length as image dim." - assert all( - isinstance(i, (int, float)) for s in list(spacing) for i in list(s) - ), "if `spacing` is a sequence of sequences or 2D np.ndarray, the elements should be integers or floats." + if len(spacing) != batch_size: + raise ValueError( + "if `spacing` is a sequence of sequences, " + f"the outer sequence should have same length as batch size ({batch_size}), got {spacing}." + ) + if any(len(s) != img_dim for s in list(spacing)): + raise ValueError( + "each element of `spacing` list should either have same length as" + f"image dim ({img_dim}), got {spacing}." + ) + if not all(isinstance(i, (int, float)) for s in list(spacing) for i in list(s)): + raise ValueError( + f"if `spacing` is a sequence of sequences or 2D np.ndarray, " + f"the elements should be integers or floats, got {spacing}." + ) return list(spacing) - elif isinstance(spacing[0], (int, float)): - assert ( - len(spacing) == img_dim - ), "if `spacing` is a sequence of numbers, it should have same length as image dim." + if isinstance(spacing[0], (int, float)): + if len(spacing) != img_dim: + raise ValueError( + f"if `spacing` is a sequence of numbers, " + f"it should have same length as image dim ({img_dim}), got {spacing}." + ) return [spacing for _ in range(batch_size)] # type: ignore - else: - raise AssertionError(f"`spacing` is a sequence of elements with unsupported type: {type(spacing[0])}") - else: - raise AssertionError( - "`spacing` should either be an integer, float, a sequence of numbers or a sequence of sequences." - ) + raise ValueError(f"`spacing` is a sequence of elements with unsupported type: {type(spacing[0])}") + raise ValueError( + f"`spacing` should either be a number, a sequence of numbers or a sequence of sequences, got {spacing}." + ) + + +ENCODING_KERNEL = {2: [[8, 4], [2, 1]], 3: [[[128, 64], [32, 16]], [[8, 4], [2, 1]]]} + + +@cache +def _get_neighbour_code_to_normals_table(device=None): + """ + returns a lookup table. For every binary neighbour code (2x2x2 neighbourhood = 8 neighbours = 8 bits = 256 codes) + it contains the surface normals of the triangles. The length of the normal vector encodes the surfel area. + + created using the marching_cube algorithm see e.g. https://en.wikipedia.org/wiki/Marching_cubes + + Args: + device: torch device to use for the table. + """ + zeros = [0.0, 0.0, 0.0] + ret = [ + [zeros, zeros, zeros, zeros], + [[0.125, 0.125, 0.125], zeros, zeros, zeros], + [[-0.125, -0.125, 0.125], zeros, zeros, zeros], + [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros, zeros], + [[0.125, -0.125, 0.125], zeros, zeros, zeros], + [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros, zeros], + [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros], + [[-0.125, 0.125, 0.125], zeros, zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros], + [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros, zeros], + [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros, zeros], + [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], zeros], + [[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0], zeros, zeros], + [[0.125, -0.125, -0.125], zeros, zeros, zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], zeros, zeros], + [[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros], + [[0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125], zeros], + [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros], + [[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]], + [[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125], zeros], + [[0.125, 0.125, 0.125], [0.375, 0.375, 0.375], [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]], + [[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros], + [[0.375, 0.375, 0.375], [0.0, 0.25, -0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]], + [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]], + [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], zeros], + [[0.125, -0.125, 0.125], zeros, zeros, zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros, zeros], + [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], zeros], + [[0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros], + [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125], zeros], + [[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25], [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]], + [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros], + [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.25, 0.25, -0.25], [0.25, 0.25, -0.25], [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]], + [[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros], + [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]], + [[0.0, 0.25, -0.25], [0.375, -0.375, -0.375], [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]], + [[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros, zeros], + [[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125], zeros], + [[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25], zeros], + [[0.0, 0.5, 0.0], [0.0, -0.5, 0.0], zeros, zeros], + [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 0.125], zeros], + [[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]], + [[0.125, 0.125, 0.125], [0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]], + [[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125], zeros], + [[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros], + [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], + [[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]], + [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros], + [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]], + [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros], + [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros, zeros], + [[-0.125, -0.125, 0.125], zeros, zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], zeros, zeros], + [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.375, -0.375, 0.375], [0.0, -0.25, -0.25], [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]], + [[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros], + [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros], + [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], + [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros], + [[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25], [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]], + [[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375], [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]], + [[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros], + [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros, zeros], + [[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros], + [[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]], + [[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros], + [[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5], zeros, zeros], + [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]], + [[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros], + [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125], zeros], + [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], + [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]], + [[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros], + [[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375], [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]], + [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros], + [[0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros, zeros], + [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros], + [[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros], + [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125], zeros], + [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]], + [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125], zeros], + [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros], + [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]], + [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros], + [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125], zeros], + [[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]], + [[0.375, -0.375, 0.375], [0.0, 0.25, 0.25], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]], + [[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0], [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]], + [[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros], + [[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros], + [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros, zeros], + [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]], + [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125], zeros], + [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], zeros, zeros, zeros], + [[0.125, 0.125, 0.125], zeros, zeros, zeros], + [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125], zeros], + [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125], zeros], + [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]], + [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros, zeros], + [[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros], + [[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros], + [[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0], [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]], + [[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[0.375, -0.375, 0.375], [0.0, 0.25, 0.25], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]], + [[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]], + [[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125], zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros], + [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros], + [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], + [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]], + [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros], + [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125], zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]], + [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125], zeros], + [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]], + [[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros], + [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros, zeros], + [[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros], + [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros], + [[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375], [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]], + [[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros], + [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]], + [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], + [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros], + [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]], + [[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5], zeros, zeros], + [[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros], + [[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]], + [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros], + [[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros, zeros], + [[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros], + [[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375], [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]], + [[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25], [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]], + [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros], + [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], + [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros], + [[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros], + [[0.375, -0.375, 0.375], [0.0, -0.25, -0.25], [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], zeros, zeros], + [[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros], + [[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[-0.125, -0.125, 0.125], zeros, zeros, zeros], + [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros, zeros], + [[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros], + [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125], zeros], + [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]], + [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros], + [[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]], + [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], + [[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros], + [[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125], zeros], + [[0.125, 0.125, 0.125], [0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]], + [[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]], + [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 0.125], zeros], + [[0.0, 0.5, 0.0], [0.0, -0.5, 0.0], zeros, zeros], + [[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25], zeros], + [[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125], zeros], + [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros, zeros], + [[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.0, 0.25, -0.25], [0.375, -0.375, -0.375], [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]], + [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]], + [[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros], + [[0.25, 0.25, -0.25], [0.25, 0.25, -0.25], [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]], + [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros], + [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25], [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]], + [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125], zeros], + [[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros], + [[0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], zeros], + [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros, zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.125, -0.125, 0.125], zeros, zeros, zeros], + [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], zeros], + [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]], + [[0.375, 0.375, 0.375], [0.0, 0.25, -0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]], + [[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros], + [[0.125, 0.125, 0.125], [0.375, 0.375, 0.375], [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]], + [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125], zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]], + [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros], + [[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125], zeros], + [[0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros], + [[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], zeros, zeros], + [[0.125, -0.125, -0.125], zeros, zeros, zeros], + [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0], zeros, zeros], + [[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], zeros], + [[0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros, zeros], + [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros], + [[-0.125, 0.125, 0.125], zeros, zeros, zeros], + [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros], + [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros, zeros], + [[0.125, 0.125, 0.125], zeros, zeros, zeros], + [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros, zeros], + [[0.125, 0.125, 0.125], zeros, zeros, zeros], + [[0.125, 0.125, 0.125], zeros, zeros, zeros], + [zeros, zeros, zeros, zeros], + ] + return torch.as_tensor(ret, device=device) + + +def create_table_neighbour_code_to_surface_area(spacing_mm, device=None): + """ + Returns an array mapping neighbourhood code to the surface elements area. + Adapted from https://github.com/deepmind/surface-distance + + Note that the normals encode the initial surface area. This function computes + the area corresponding to the given `spacing`. + + Args: + spacing_mm: a sequence of 3 numbers. Voxel spacing along the first 3 spatial axes. + deivce: device to put the table on. + + Returns: + An array of size 256, mapping neighbourhood code to the surface area. + ENCODING_KERNEL[3] which is the kernel used to compute the neighbourhood code. + """ + # compute the area for all 256 possible surface elements given a 2x2x2 neighbourhood) according to the spacing_mm + neighbour_code_to_surface_area = np.zeros([256]) + c = _get_neighbour_code_to_normals_table(device) + s = torch.as_tensor( + [[[spacing_mm[1] * spacing_mm[2], spacing_mm[0] * spacing_mm[2], spacing_mm[0] * spacing_mm[1]]]], + device=device, + dtype=c.dtype, + ) + norm = torch.linalg.norm(c * s, dim=-1) + neighbour_code_to_surface_area = norm.sum(-1) + return neighbour_code_to_surface_area, torch.as_tensor([[ENCODING_KERNEL[3]]], device=device) + + +def create_table_neighbour_code_to_contour_length(spacing_mm, device=None): + """ + Returns an array mapping neighbourhood code to the contour length. + Adapted from https://github.com/deepmind/surface-distance + + In 2D, each point has 4 neighbors. Thus, are 16 configurations. A + configuration is encoded with '1' meaning "inside the object" and '0' "outside + the object". For example, + "0101" and "1010" both encode an edge along the first spatial axis with length spacing[0] mm; + "0011" and "1100" both encode an edge along the second spatial axis with length spacing[1] mm. + + Args: + spacing_mm: 2-element list-like structure. Pixel spacing along the 1st and 2nd spatial axes. + device: device to put the table on. + + Returns: + A 16-element array mapping neighbourhood code to the contour length. + ENCODING_KERNEL[2] which is the kernel used to compute the neighbourhood code. + """ + first, second = spacing_mm # spacing along the first and second spatial dimension respectively + diag = 0.5 * np.linalg.norm(spacing_mm) + + neighbour_code_to_contour_length = np.zeros([16], dtype=diag.dtype) + neighbour_code_to_contour_length[int("0001", 2)] = diag + neighbour_code_to_contour_length[int("0010", 2)] = diag + neighbour_code_to_contour_length[int("0011", 2)] = second + neighbour_code_to_contour_length[int("0100", 2)] = diag + neighbour_code_to_contour_length[int("0101", 2)] = first + neighbour_code_to_contour_length[int("0110", 2)] = 2 * diag + neighbour_code_to_contour_length[int("0111", 2)] = diag + neighbour_code_to_contour_length[int("1000", 2)] = diag + neighbour_code_to_contour_length[int("1001", 2)] = 2 * diag + neighbour_code_to_contour_length[int("1010", 2)] = first + neighbour_code_to_contour_length[int("1011", 2)] = diag + neighbour_code_to_contour_length[int("1100", 2)] = second + neighbour_code_to_contour_length[int("1101", 2)] = diag + neighbour_code_to_contour_length[int("1110", 2)] = diag + neighbour_code_to_contour_length = convert_to_tensor(neighbour_code_to_contour_length, device=device) + return neighbour_code_to_contour_length, torch.as_tensor([[ENCODING_KERNEL[2]]], device=device) + + +def get_code_to_measure_table(spacing, device=None): + """ + returns a table mapping neighbourhood code to the surface area or contour length. + + Args: + spacing: a sequence of 2 or 3 numbers, indicating the spacing in the spatial dimensions. + device: device to put the table on. + """ + spatial_dims = len(spacing) + spacing = ensure_tuple_rep(spacing, look_up_option(spatial_dims, (2, 3))) + if spatial_dims == 2: + return create_table_neighbour_code_to_contour_length(spacing, device) + return create_table_neighbour_code_to_surface_area(spacing, device) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 6dc8f10c32..a58db7cc62 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -732,8 +732,8 @@ def __init__( allow_smaller: bool = True, k_divisible: Sequence[int] | int = 1, mode: SequenceStr = PytorchPadMode.CONSTANT, - start_coord_key: str = "foreground_start_coord", - end_coord_key: str = "foreground_end_coord", + start_coord_key: str | None = "foreground_start_coord", + end_coord_key: str | None = "foreground_end_coord", allow_missing_keys: bool = False, lazy: bool = False, **pad_kwargs, diff --git a/tests/test_surface_dice.py b/tests/test_surface_dice.py index c5400aea39..e4c6cf6934 100644 --- a/tests/test_surface_dice.py +++ b/tests/test_surface_dice.py @@ -17,7 +17,8 @@ import torch import torch.nn.functional as F -from monai.metrics.surface_dice import SurfaceDiceMetric +from monai.metrics.surface_dice import SurfaceDiceMetric, compute_surface_dice +from tests.utils import assert_allclose _device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -330,21 +331,8 @@ def test_asserts(self): predictions_no_hot = torch.clone(predictions_hot) predictions_no_hot[0, :, 0, 0] = torch.tensor([2, 0]) - with self.assertRaises(ValueError) as context: - SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_no_hot, predictions_hot) - self.assertEqual("y_pred and y should be one-hot encoded.", str(context.exception)) - with self.assertRaises(ValueError) as context: - SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, predictions_no_hot) - self.assertEqual("y_pred and y should be one-hot encoded.", str(context.exception)) - predictions_no_hot = predictions_no_hot.float() predictions_no_hot[0, :, 0, 0] = torch.tensor([0.5, 0]) - with self.assertRaises(ValueError) as context: - SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_no_hot, predictions_hot) - self.assertEqual("y_pred and y should be binarized tensors (e.g. torch.int64).", str(context.exception)) - with self.assertRaises(ValueError) as context: - SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, predictions_no_hot) - self.assertEqual("y_pred and y should be binarized tensors (e.g. torch.int64).", str(context.exception)) # wrong number of class thresholds with self.assertRaises(ValueError) as context: @@ -401,6 +389,27 @@ def test_not_predicted_not_present(self): np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float)) np.testing.assert_equal(not_nans, torch.tensor([0], dtype=torch.float)) + def test_compute_surface_dice_subvoxel(self): + mask_gt, mask_pred = torch.zeros(1, 1, 128, 128, 128), torch.zeros(1, 1, 128, 128, 128) + mask_gt[0, 0, 50, 60, 70] = 1 + res = compute_surface_dice( + mask_pred, mask_gt, class_thresholds=[1.0], include_background=True, spacing=(3, 2, 1), use_subvoxels=True + ) + assert_allclose(res, 0.0, type_test=False) + mask_pred[0, 0, 50, 60, 72] = 1 + res = compute_surface_dice( + mask_pred, mask_gt, class_thresholds=[1.0], include_background=True, spacing=(3, 2, 1), use_subvoxels=True + ) + assert_allclose(res, 0.5, type_test=False) + + mask_gt, mask_pred = torch.zeros(1, 1, 100, 100, 100), torch.zeros(1, 1, 100, 100, 100) + mask_gt[0, 0, 0:50, :, :] = 1 + mask_pred[0, 0, 0:51, :, :] = 1 + res = compute_surface_dice( + mask_pred, mask_gt, class_thresholds=[1.0], include_background=True, spacing=(2, 1, 1), use_subvoxels=True + ) + assert_allclose(res, 0.836145, type_test=False, atol=1e-3, rtol=1e-3) + if __name__ == "__main__": unittest.main() From b3a44f4d253cf77cb64362cf028ccd565f834df6 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 1 Jul 2023 17:57:11 +0100 Subject: [PATCH 02/10] update docstring Signed-off-by: Wenqi Li --- monai/metrics/surface_dice.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index a4d1cc6cdd..5237512922 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -35,8 +35,9 @@ class SurfaceDiceMetric(CumulativeIterationMetric): Computes the Normalized Surface Dice (NSD) for each batch sample and class of predicted segmentations `y_pred` and corresponding reference segmentations `y` according to equation :eq:`nsd`. This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D and 3D images. - Be aware that the computation of boundaries is different from DeepMind's implementation - https://github.com/deepmind/surface-distance. In this implementation, the length/area of a segmentation boundary is + Be aware that by default (`use_subvoxels=False`), the computation of boundaries is different from DeepMind's + mplementation https://github.com/deepmind/surface-distance. + In this implementation, the length/area of a segmentation boundary is interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430). This issue is discussed here: https://github.com/Project-MONAI/MONAI/issues/4103. From 831c940b9ed7342ed8ed0fac19931e060cfe6daa Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 1 Jul 2023 18:08:43 +0100 Subject: [PATCH 03/10] py38 Signed-off-by: Wenqi Li --- monai/metrics/utils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 35bf3feba6..a45e4886db 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -12,7 +12,7 @@ from __future__ import annotations import warnings -from functools import cache +from functools import lru_cache from typing import Any, Sequence import numpy as np @@ -374,7 +374,7 @@ def prepare_spacing( ENCODING_KERNEL = {2: [[8, 4], [2, 1]], 3: [[[128, 64], [32, 16]], [[8, 4], [2, 1]]]} -@cache +@lru_cache(maxsize=None) def _get_neighbour_code_to_normals_table(device=None): """ returns a lookup table. For every binary neighbour code (2x2x2 neighbourhood = 8 neighbours = 8 bits = 256 codes) @@ -657,14 +657,13 @@ def create_table_neighbour_code_to_surface_area(spacing_mm, device=None): Args: spacing_mm: a sequence of 3 numbers. Voxel spacing along the first 3 spatial axes. - deivce: device to put the table on. + device: device to put the table on. Returns: An array of size 256, mapping neighbourhood code to the surface area. ENCODING_KERNEL[3] which is the kernel used to compute the neighbourhood code. """ - # compute the area for all 256 possible surface elements given a 2x2x2 neighbourhood) according to the spacing_mm - neighbour_code_to_surface_area = np.zeros([256]) + # compute the area for all 256 possible surface elements given a 2x2x2 neighbourhood according to the spacing_mm c = _get_neighbour_code_to_normals_table(device) s = torch.as_tensor( [[[spacing_mm[1] * spacing_mm[2], spacing_mm[0] * spacing_mm[2], spacing_mm[0] * spacing_mm[1]]]], From fecc094c15e07b5bd5923290486ed15f8d3225d3 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 1 Jul 2023 18:52:36 +0100 Subject: [PATCH 04/10] simplify Signed-off-by: Wenqi Li --- monai/metrics/surface_dice.py | 26 +++++++++----------------- monai/metrics/utils.py | 6 +++--- tests/test_surface_dice.py | 13 ++++++++++--- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index 5237512922..79964dbace 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -256,11 +256,6 @@ def compute_surface_dice( for b, c in np.ndindex(batch_size, n_class): if not use_subvoxels: (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=True) - if not np.any(edges_gt): - warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.") - if not np.any(edges_pred): - warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.") - distances_pred_gt = get_surface_distance( edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b] ) @@ -279,20 +274,17 @@ def compute_surface_dice( edges_pred, edges_gt, areas_pred, areas_gt = get_mask_edges( # type: ignore y_pred[b, c], y[b, c], crop=True, spacing=_spacing # type: ignore ) - if not np.any(edges_gt): - warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.") - if not np.any(edges_pred): - warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.") - dist_gt = get_surface_distance(None, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b]) - dist_pred = get_surface_distance(None, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b]) - dist_gt_to_pred, dist_pred_to_gt = dist_pred[edges_gt], dist_gt[edges_pred] + dist_pred_to_gt = get_surface_distance(edges_pred, edges_gt, distance_metric, spacing=spacing_list[b]) + dist_gt_to_pred = get_surface_distance(edges_gt, edges_pred, distance_metric, spacing=spacing_list[b]) areas_gt, areas_pred = areas_gt[edges_gt], areas_pred[edges_pred] boundary_complete = areas_gt.sum() + areas_pred.sum() - boundary_correct = ( - areas_gt[dist_gt_to_pred <= class_thresholds[c]].sum() - + areas_pred[dist_pred_to_gt <= class_thresholds[c]].sum() - ) - + gt_true = areas_gt[dist_gt_to_pred <= class_thresholds[c]].sum() if len(areas_gt) > 0 else 0.0 + pred_true = areas_pred[dist_pred_to_gt <= class_thresholds[c]].sum() if len(areas_pred) > 0 else 0.0 + boundary_correct = gt_true + pred_true + if not np.any(edges_gt): + warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.") + if not np.any(edges_pred): + warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.") if boundary_complete == 0: # the class is neither present in the prediction, nor in the reference segmentation nsd[b, c] = np.nan diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index a45e4886db..c23d96d90b 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -207,7 +207,7 @@ def get_mask_edges( def get_surface_distance( - seg_pred: np.ndarray | None, + seg_pred: np.ndarray, seg_gt: np.ndarray, distance_metric: str = "euclidean", spacing: int | float | np.ndarray | Sequence[int | float] | None = None, @@ -240,7 +240,7 @@ def get_surface_distance( if not np.any(seg_gt): dis = np.inf * np.ones_like(seg_gt) else: - if seg_pred is not None and not np.any(seg_pred): + if not np.any(seg_pred): dis = np.inf * np.ones_like(seg_gt) return np.asarray(dis[seg_gt]) if distance_metric == "euclidean": @@ -250,7 +250,7 @@ def get_surface_distance( else: raise ValueError(f"distance_metric {distance_metric} is not implemented.") - return np.asarray(dis[seg_pred]) if seg_pred is not None else np.asarray(dis) + return np.asarray(dis[seg_pred]) def is_binary_tensor(input: torch.Tensor, name: str) -> None: diff --git a/tests/test_surface_dice.py b/tests/test_surface_dice.py index e4c6cf6934..64ab6bfee9 100644 --- a/tests/test_surface_dice.py +++ b/tests/test_surface_dice.py @@ -393,12 +393,19 @@ def test_compute_surface_dice_subvoxel(self): mask_gt, mask_pred = torch.zeros(1, 1, 128, 128, 128), torch.zeros(1, 1, 128, 128, 128) mask_gt[0, 0, 50, 60, 70] = 1 res = compute_surface_dice( - mask_pred, mask_gt, class_thresholds=[1.0], include_background=True, spacing=(3, 2, 1), use_subvoxels=True + mask_pred, mask_gt, [1.0], include_background=True, spacing=(3, 2, 1), use_subvoxels=True ) assert_allclose(res, 0.0, type_test=False) + mask_gt[0, 0, 50, 60, 70] = 0 mask_pred[0, 0, 50, 60, 72] = 1 res = compute_surface_dice( - mask_pred, mask_gt, class_thresholds=[1.0], include_background=True, spacing=(3, 2, 1), use_subvoxels=True + mask_pred, mask_gt, [1.0], include_background=True, spacing=(3, 2, 1), use_subvoxels=True + ) + assert_allclose(res, 0.0, type_test=False) + mask_gt[0, 0, 50, 60, 70] = 1 + mask_pred[0, 0, 50, 60, 72] = 1 + res = compute_surface_dice( + mask_pred, mask_gt, [1.0], include_background=True, spacing=(3, 2, 1), use_subvoxels=True ) assert_allclose(res, 0.5, type_test=False) @@ -406,7 +413,7 @@ def test_compute_surface_dice_subvoxel(self): mask_gt[0, 0, 0:50, :, :] = 1 mask_pred[0, 0, 0:51, :, :] = 1 res = compute_surface_dice( - mask_pred, mask_gt, class_thresholds=[1.0], include_background=True, spacing=(2, 1, 1), use_subvoxels=True + mask_pred, mask_gt, [1.0], include_background=True, spacing=(2, 1, 1), use_subvoxels=True ) assert_allclose(res, 0.836145, type_test=False, atol=1e-3, rtol=1e-3) From 0fbfa6162e7116705aa9891ec6518b95e210a7e0 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 1 Jul 2023 18:56:30 +0100 Subject: [PATCH 05/10] update docs Signed-off-by: Wenqi Li --- monai/metrics/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index c23d96d90b..07edc1217e 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -132,14 +132,12 @@ def get_mask_edges( spacing: Sequence | None = None, ) -> tuple[np.ndarray, np.ndarray]: """ - Do binary erosion and use XOR for input to get the edges. This + Compute edges from binary segmentation masks. This function is helpful to further calculate metrics such as Average Surface Distance and Hausdorff Distance. The input images can be binary or labelfield images. If labelfield images are supplied, they are converted to binary images using `label_idx`. - `scipy`'s binary erosion is used to calculate the edges of the binary - labelfield. In order to improve the computing efficiency, before getting the edges, the images can be cropped and only keep the foreground if not specifies @@ -157,6 +155,8 @@ def get_mask_edges( maintain two inputs' shapes, here the bounding box is achieved by ``(seg_pred | seg_gt)`` which represents the union set of two images. Defaults to ``True``. + spacing: the input spacing. If not None, the subvoxel edges and areas will be computed. + otherwise `scipy`'s binary erosion is used to calculate the edges. """ # Get both labelfields as np arrays From b8860c12d23fff931946ffe3c517bbb806a723c8 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 2 Jul 2023 16:36:26 +0100 Subject: [PATCH 06/10] allow_smaller=False Signed-off-by: Wenqi Li --- monai/metrics/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 07edc1217e..30cfc7b516 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -19,7 +19,7 @@ import torch from monai.config import NdarrayOrTensor, NdarrayTensor -from monai.transforms.croppad.dictionary import BorderPadD, CropForegroundD +from monai.transforms.croppad.dictionary import CropForegroundD from monai.utils import ( MetricReduction, convert_data_type, @@ -178,10 +178,11 @@ def get_mask_edges( if not np.any(seg_pred | seg_gt): pred, gt = np.zeros_like(seg_pred), np.zeros_like(seg_gt) return (pred, gt) if spacing is None else (pred, gt, pred, gt) # type: ignore - cropper = CropForegroundD(keys=["pred", "gt"], source_key="src", start_coord_key=None, end_coord_key=None) - pad = BorderPadD(keys=["pred", "gt"], spatial_border=1, mode="constant") + cropper = CropForegroundD( + ["pred", "gt"], source_key="src", margin=1, allow_smaller=False, start_coord_key=None, end_coord_key=None + ) mask = np.asarray(seg_pred | seg_gt) - cropped = pad(cropper({"pred": seg_pred[None], "gt": seg_gt[None], "src": mask[None]})) # type: ignore + cropped = cropper({"pred": seg_pred[None], "gt": seg_gt[None], "src": mask[None]}) # type: ignore seg_pred = cropped["pred"] seg_gt = cropped["gt"] From 3cb38631a58f2beea8e009570ccd37a21a9f16e3 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 3 Jul 2023 11:09:35 +0100 Subject: [PATCH 07/10] allow_smaller=True Signed-off-by: Wenqi Li --- monai/metrics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 30cfc7b516..a99165d040 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -179,7 +179,7 @@ def get_mask_edges( pred, gt = np.zeros_like(seg_pred), np.zeros_like(seg_gt) return (pred, gt) if spacing is None else (pred, gt, pred, gt) # type: ignore cropper = CropForegroundD( - ["pred", "gt"], source_key="src", margin=1, allow_smaller=False, start_coord_key=None, end_coord_key=None + ["pred", "gt"], source_key="src", margin=1, allow_smaller=True, start_coord_key=None, end_coord_key=None ) mask = np.asarray(seg_pred | seg_gt) cropped = cropper({"pred": seg_pred[None], "gt": seg_gt[None], "src": mask[None]}) # type: ignore From 01ad65dd7dbb7effc6c3684d245ba3f662c197db Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 4 Jul 2023 09:26:04 +0100 Subject: [PATCH 08/10] update based on comments, simplify datatype converting Signed-off-by: Wenqi Li --- monai/metrics/utils.py | 33 ++++++++++++++------------------- tests/test_surface_dice.py | 7 ------- 2 files changed, 14 insertions(+), 26 deletions(-) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index a99165d040..899e682c1f 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -138,7 +138,6 @@ def get_mask_edges( The input images can be binary or labelfield images. If labelfield images are supplied, they are converted to binary images using `label_idx`. - In order to improve the computing efficiency, before getting the edges, the images can be cropped and only keep the foreground if not specifies ``crop = False``. @@ -158,45 +157,38 @@ def get_mask_edges( spacing: the input spacing. If not None, the subvoxel edges and areas will be computed. otherwise `scipy`'s binary erosion is used to calculate the edges. """ - - # Get both labelfields as np arrays - if isinstance(seg_pred, torch.Tensor): - seg_pred = seg_pred.detach().cpu().numpy() - if isinstance(seg_gt, torch.Tensor): - seg_gt = seg_gt.detach().cpu().numpy() - if seg_pred.shape != seg_gt.shape: raise ValueError(f"seg_pred and seg_gt should have same shapes, got {seg_pred.shape} and {seg_gt.shape}.") # If not binary images, convert them - if seg_pred.dtype != bool: + if seg_pred.dtype not in (bool, torch.bool): seg_pred = seg_pred == label_idx - if seg_gt.dtype != bool: + if seg_gt.dtype not in (bool, torch.bool): seg_gt = seg_gt == label_idx if crop: - if not np.any(seg_pred | seg_gt): + if not (seg_pred | seg_gt).any(): pred, gt = np.zeros_like(seg_pred), np.zeros_like(seg_gt) return (pred, gt) if spacing is None else (pred, gt, pred, gt) # type: ignore cropper = CropForegroundD( ["pred", "gt"], source_key="src", margin=1, allow_smaller=True, start_coord_key=None, end_coord_key=None ) - mask = np.asarray(seg_pred | seg_gt) + mask = seg_pred | seg_gt cropped = cropper({"pred": seg_pred[None], "gt": seg_gt[None], "src": mask[None]}) # type: ignore - seg_pred = cropped["pred"] - seg_gt = cropped["gt"] + seg_pred = cropped["pred"][0] + seg_gt = cropped["gt"][0] if spacing is None: # Do binary erosion and use XOR to get edges - seg_pred = convert_data_type(seg_pred[0], np.ndarray)[0] - seg_gt = convert_data_type(seg_gt[0], np.ndarray)[0] + seg_pred = convert_data_type(seg_pred, np.ndarray)[0] + seg_gt = convert_data_type(seg_gt, np.ndarray)[0] edges_pred = binary_erosion(seg_pred) ^ seg_pred edges_gt = binary_erosion(seg_gt) ^ seg_gt return edges_pred, edges_gt code_to_area_table, k = get_code_to_measure_table(spacing) - sptial_dims = len(spacing) - conv = torch.nn.functional.conv3d if sptial_dims == 3 else torch.nn.functional.conv2d - code_pred, code_gt = conv(torch.stack([seg_pred, seg_gt], dim=0).float(), k.float()) # type: ignore + spatial_dims = len(spacing) + conv = torch.nn.functional.conv3d if spatial_dims == 3 else torch.nn.functional.conv2d + code_pred, code_gt = conv(torch.stack([seg_pred[None], seg_gt[None]], dim=0).float(), k.float()) # type: ignore # edges all_ones = len(code_to_area_table) - 1 edges_pred = (code_pred != 0) & (code_pred != all_ones) @@ -380,6 +372,7 @@ def _get_neighbour_code_to_normals_table(device=None): """ returns a lookup table. For every binary neighbour code (2x2x2 neighbourhood = 8 neighbours = 8 bits = 256 codes) it contains the surface normals of the triangles. The length of the normal vector encodes the surfel area. + Adapted from https://github.com/deepmind/surface-distance created using the marching_cube algorithm see e.g. https://en.wikipedia.org/wiki/Marching_cubes @@ -664,6 +657,7 @@ def create_table_neighbour_code_to_surface_area(spacing_mm, device=None): An array of size 256, mapping neighbourhood code to the surface area. ENCODING_KERNEL[3] which is the kernel used to compute the neighbourhood code. """ + spacing_mm = ensure_tuple_rep(spacing_mm, 3) # compute the area for all 256 possible surface elements given a 2x2x2 neighbourhood according to the spacing_mm c = _get_neighbour_code_to_normals_table(device) s = torch.as_tensor( @@ -695,6 +689,7 @@ def create_table_neighbour_code_to_contour_length(spacing_mm, device=None): A 16-element array mapping neighbourhood code to the contour length. ENCODING_KERNEL[2] which is the kernel used to compute the neighbourhood code. """ + spacing_mm = ensure_tuple_rep(spacing_mm, 2) first, second = spacing_mm # spacing along the first and second spatial dimension respectively diag = 0.5 * np.linalg.norm(spacing_mm) diff --git a/tests/test_surface_dice.py b/tests/test_surface_dice.py index 64ab6bfee9..1d0744461d 100644 --- a/tests/test_surface_dice.py +++ b/tests/test_surface_dice.py @@ -327,13 +327,6 @@ def test_asserts(self): str(context.exception), ) - # input tensors not one-hot encoded - predictions_no_hot = torch.clone(predictions_hot) - predictions_no_hot[0, :, 0, 0] = torch.tensor([2, 0]) - - predictions_no_hot = predictions_no_hot.float() - predictions_no_hot[0, :, 0, 0] = torch.tensor([0.5, 0]) - # wrong number of class thresholds with self.assertRaises(ValueError) as context: SurfaceDiceMetric(class_thresholds=[1, 1, 1], include_background=True)(predictions_hot, labels_hot) From c774e92e542ab99d1dffc89b8cd5b381252e0396 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 4 Jul 2023 05:59:05 -0400 Subject: [PATCH 09/10] test gpu Signed-off-by: Wenqi Li --- monai/metrics/utils.py | 34 +++++++++++++++++++--------------- tests/test_surface_dice.py | 3 ++- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 899e682c1f..10e4eba416 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -24,6 +24,7 @@ MetricReduction, convert_data_type, convert_to_tensor, + convert_to_numpy, ensure_tuple_rep, look_up_option, optional_import, @@ -165,30 +166,32 @@ def get_mask_edges( seg_pred = seg_pred == label_idx if seg_gt.dtype not in (bool, torch.bool): seg_gt = seg_gt == label_idx - if crop: - if not (seg_pred | seg_gt).any(): - pred, gt = np.zeros_like(seg_pred), np.zeros_like(seg_gt) + or_vol = seg_pred | seg_gt + if not or_vol.any(): + pred, gt = np.zeros(seg_pred.shape, dtype=bool), np.zeros(seg_gt.shape, dtype=bool) return (pred, gt) if spacing is None else (pred, gt, pred, gt) # type: ignore + channel_first = [seg_pred[None], seg_gt[None], or_vol[None]] + if spacing is None: # cpu only erosion + seg_pred, seg_gt, or_vol = convert_to_tensor(channel_first, device="cpu", dtype=bool) + else: # pytorch subvoxel, maybe on gpu, but croppad boolean values on GPU is not supported + seg_pred, seg_gt, or_vol = convert_to_tensor(channel_first, dtype=torch.float16) cropper = CropForegroundD( ["pred", "gt"], source_key="src", margin=1, allow_smaller=True, start_coord_key=None, end_coord_key=None ) - mask = seg_pred | seg_gt - cropped = cropper({"pred": seg_pred[None], "gt": seg_gt[None], "src": mask[None]}) # type: ignore - seg_pred = cropped["pred"][0] - seg_gt = cropped["gt"][0] - - if spacing is None: - # Do binary erosion and use XOR to get edges - seg_pred = convert_data_type(seg_pred, np.ndarray)[0] - seg_gt = convert_data_type(seg_gt, np.ndarray)[0] + cropped = cropper({"pred": seg_pred, "gt": seg_gt, "src": or_vol}) # type: ignore + seg_pred, seg_gt = cropped["pred"][0], cropped["gt"][0] + + if spacing is None: # Do binary erosion and use XOR to get edges + seg_pred, seg_gt = convert_to_numpy([seg_pred, seg_gt], dtype=bool) edges_pred = binary_erosion(seg_pred) ^ seg_pred edges_gt = binary_erosion(seg_gt) ^ seg_gt return edges_pred, edges_gt - code_to_area_table, k = get_code_to_measure_table(spacing) + code_to_area_table, k = get_code_to_measure_table(spacing, device=seg_pred.device) spatial_dims = len(spacing) conv = torch.nn.functional.conv3d if spatial_dims == 3 else torch.nn.functional.conv2d - code_pred, code_gt = conv(torch.stack([seg_pred[None], seg_gt[None]], dim=0).float(), k.float()) # type: ignore + vol = torch.stack([seg_pred[None], seg_gt[None]], dim=0).float() + code_pred, code_gt = conv(vol, k.to(vol)) # type: ignore # edges all_ones = len(code_to_area_table) - 1 edges_pred = (code_pred != 0) & (code_pred != all_ones) @@ -196,7 +199,8 @@ def get_mask_edges( # areas of edges areas_pred = torch.index_select(code_to_area_table, 0, code_pred.view(-1).int()).reshape(code_pred.shape) areas_gt = torch.index_select(code_to_area_table, 0, code_gt.view(-1).int()).reshape(code_gt.shape) - return edges_pred.array[0], edges_gt.array[0], areas_pred.array[0], areas_gt.array[0] # type: ignore + ret = (edges_pred[0], edges_gt[0], areas_pred[0], areas_gt[0]) + return convert_to_numpy(ret, wrap_sequence=False) def get_surface_distance( diff --git a/tests/test_surface_dice.py b/tests/test_surface_dice.py index 1d0744461d..53b0d38bb2 100644 --- a/tests/test_surface_dice.py +++ b/tests/test_surface_dice.py @@ -402,7 +402,8 @@ def test_compute_surface_dice_subvoxel(self): ) assert_allclose(res, 0.5, type_test=False) - mask_gt, mask_pred = torch.zeros(1, 1, 100, 100, 100), torch.zeros(1, 1, 100, 100, 100) + d = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + mask_gt, mask_pred = torch.zeros(1, 1, 100, 100, 100, device=d), torch.zeros(1, 1, 100, 100, 100, device=d) mask_gt[0, 0, 0:50, :, :] = 1 mask_pred[0, 0, 0:51, :, :] = 1 res = compute_surface_dice( From 046b27a33eaa91296f4bb8030643850cd354e92b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 4 Jul 2023 11:09:40 +0100 Subject: [PATCH 10/10] style Signed-off-by: Wenqi Li --- monai/metrics/utils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 10e4eba416..2ca69b5540 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -22,9 +22,8 @@ from monai.transforms.croppad.dictionary import CropForegroundD from monai.utils import ( MetricReduction, - convert_data_type, - convert_to_tensor, convert_to_numpy, + convert_to_tensor, ensure_tuple_rep, look_up_option, optional_import, @@ -187,10 +186,10 @@ def get_mask_edges( edges_pred = binary_erosion(seg_pred) ^ seg_pred edges_gt = binary_erosion(seg_gt) ^ seg_gt return edges_pred, edges_gt - code_to_area_table, k = get_code_to_measure_table(spacing, device=seg_pred.device) + code_to_area_table, k = get_code_to_measure_table(spacing, device=seg_pred.device) # type: ignore spatial_dims = len(spacing) conv = torch.nn.functional.conv3d if spatial_dims == 3 else torch.nn.functional.conv2d - vol = torch.stack([seg_pred[None], seg_gt[None]], dim=0).float() + vol = torch.stack([seg_pred[None], seg_gt[None]], dim=0).float() # type: ignore code_pred, code_gt = conv(vol, k.to(vol)) # type: ignore # edges all_ones = len(code_to_area_table) - 1 @@ -200,7 +199,7 @@ def get_mask_edges( areas_pred = torch.index_select(code_to_area_table, 0, code_pred.view(-1).int()).reshape(code_pred.shape) areas_gt = torch.index_select(code_to_area_table, 0, code_gt.view(-1).int()).reshape(code_gt.shape) ret = (edges_pred[0], edges_gt[0], areas_pred[0], areas_gt[0]) - return convert_to_numpy(ret, wrap_sequence=False) + return convert_to_numpy(ret, wrap_sequence=False) # type: ignore def get_surface_distance(