diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index 43305ca834..79964dbace 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -35,8 +35,9 @@ class SurfaceDiceMetric(CumulativeIterationMetric): Computes the Normalized Surface Dice (NSD) for each batch sample and class of predicted segmentations `y_pred` and corresponding reference segmentations `y` according to equation :eq:`nsd`. This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D and 3D images. - Be aware that the computation of boundaries is different from DeepMind's implementation - https://github.com/deepmind/surface-distance. In this implementation, the length/area of a segmentation boundary is + Be aware that by default (`use_subvoxels=False`), the computation of boundaries is different from DeepMind's + mplementation https://github.com/deepmind/surface-distance. + In this implementation, the length/area of a segmentation boundary is interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430). This issue is discussed here: https://github.com/Project-MONAI/MONAI/issues/4103. @@ -86,7 +87,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D]. y: Reference segmentation. It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D]. - kwargs: additional parameters, e.g. ``spacing`` should be passed to correctly compute the metric. + kwargs: additional parameters: ``spacing`` should be passed to correctly compute the metric. ``spacing``: spacing of pixel (or voxel). This parameter is relevant only if ``distance_metric`` is set to ``"euclidean"``. If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers, @@ -96,6 +97,8 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch, else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used for all images in batch. Defaults to ``None``. + use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``. + Returns: Pytorch Tensor of shape [B,C], containing the NSD values :math:`\operatorname {NSD}_{b,c}` for each batch @@ -108,6 +111,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) include_background=self.include_background, distance_metric=self.distance_metric, spacing=kwargs.get("spacing"), + use_subvoxels=kwargs.get("use_subvoxels", False), ) def aggregate( @@ -141,13 +145,14 @@ def compute_surface_dice( include_background: bool = False, distance_metric: str = "euclidean", spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None, + use_subvoxels: bool = False, ) -> torch.Tensor: r""" This function computes the (Normalized) Surface Dice (NSD) between the two tensors `y_pred` (referred to as :math:`\hat{Y}`) and `y` (referred to as :math:`Y`). This metric determines which fraction of a segmentation boundary is correctly predicted. A boundary element is considered correctly predicted if the closest distance to the - reference boundary is smaller than or equal to the specified threshold related to the acceptable amount of deviation in - pixels. 
The NSD is bounded between 0 and 1. + reference boundary is smaller than or equal to the specified threshold related to the acceptable amount of deviation + in pixels. The NSD is bounded between 0 and 1. This implementation supports multi-class tasks with an individual threshold :math:`\tau_c` for each class :math:`c`. The class-specific NSD for batch index :math:`b`, :math:`\operatorname {NSD}_{b,c}`, is computed using the function: @@ -159,8 +164,8 @@ def compute_surface_dice( :label: nsd with :math:`\mathcal{D}_{Y_{b,c}}` and :math:`\mathcal{D}_{\hat{Y}_{b,c}}` being two sets of nearest-neighbor - distances. :math:`\mathcal{D}_{Y_{b,c}}` is computed from the predicted segmentation boundary towards the reference segmentation - boundary and vice-versa for :math:`\mathcal{D}_{\hat{Y}_{b,c}}`. :math:`\mathcal{D}_{Y_{b,c}}^{'}` and + distances. :math:`\mathcal{D}_{Y_{b,c}}` is computed from the predicted segmentation boundary towards the reference + segmentation boundary and vice-versa for :math:`\mathcal{D}_{\hat{Y}_{b,c}}`. :math:`\mathcal{D}_{Y_{b,c}}^{'}` and :math:`\mathcal{D}_{\hat{Y}_{b,c}}^{'}` refer to the subsets of distances that are smaller or equal to the acceptable distance :math:`\tau_c`: @@ -168,15 +173,14 @@ def compute_surface_dice( \mathcal{D}_{Y_{b,c}}^{'} = \{ d \in \mathcal{D}_{Y_{b,c}} \, | \, d \leq \tau_c \}. - In the case of a class neither being present in the predicted segmentation, nor in the reference segmentation, a nan value - will be returned for this class. In the case of a class being present in only one of predicted segmentation or - reference segmentation, the class NSD will be 0. + In the case of a class neither being present in the predicted segmentation, nor in the reference segmentation, + a nan value will be returned for this class. In the case of a class being present in only one of predicted + segmentation or reference segmentation, the class NSD will be 0. This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D and 3D images. - Be aware that the computation of boundaries is different from DeepMind's implementation - https://github.com/deepmind/surface-distance. In this implementation, the length of a segmentation boundary is - interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary - depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430). + The computation of boundaries follows DeepMind's implementation + https://github.com/deepmind/surface-distance when `use_subvoxels=True`; Otherwise the length of a segmentation + boundary is interpreted as the number of its edge pixels. Args: y_pred: Predicted segmentation, typically segmentation model output. @@ -198,6 +202,7 @@ def compute_surface_dice( If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch, else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used for all images in batch. Defaults to ``None``. + use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``. Raises: ValueError: If `y_pred` and/or `y` are not PyTorch tensors. @@ -227,11 +232,6 @@ def compute_surface_dice( f"y_pred and y should have same shape, but instead, shapes are {y_pred.shape} (y_pred) and {y.shape} (y)." ) - if not torch.all(y_pred.byte() == y_pred) or not torch.all(y.byte() == y): - raise ValueError("y_pred and y should be binarized tensors (e.g. 
torch.int64).") - if torch.any(y_pred > 1) or torch.any(y > 1): - raise ValueError("y_pred and y should be one-hot encoded.") - y = y.float() y_pred = y_pred.float() @@ -254,24 +254,37 @@ def compute_surface_dice( spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim) for b, c in np.ndindex(batch_size, n_class): - (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=False) + if not use_subvoxels: + (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=True) + distances_pred_gt = get_surface_distance( + edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b] + ) + distances_gt_pred = get_surface_distance( + edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b] + ) + + boundary_complete = len(distances_pred_gt) + len(distances_gt_pred) + boundary_correct = np.sum(distances_pred_gt <= class_thresholds[c]) + np.sum( + distances_gt_pred <= class_thresholds[c] + ) + else: + _spacing = spacing_list[b] if spacing_list[b] is not None else [1] * img_dim + areas_pred: np.ndarray + areas_gt: np.ndarray + edges_pred, edges_gt, areas_pred, areas_gt = get_mask_edges( # type: ignore + y_pred[b, c], y[b, c], crop=True, spacing=_spacing # type: ignore + ) + dist_pred_to_gt = get_surface_distance(edges_pred, edges_gt, distance_metric, spacing=spacing_list[b]) + dist_gt_to_pred = get_surface_distance(edges_gt, edges_pred, distance_metric, spacing=spacing_list[b]) + areas_gt, areas_pred = areas_gt[edges_gt], areas_pred[edges_pred] + boundary_complete = areas_gt.sum() + areas_pred.sum() + gt_true = areas_gt[dist_gt_to_pred <= class_thresholds[c]].sum() if len(areas_gt) > 0 else 0.0 + pred_true = areas_pred[dist_pred_to_gt <= class_thresholds[c]].sum() if len(areas_pred) > 0 else 0.0 + boundary_correct = gt_true + pred_true if not np.any(edges_gt): warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.") if not np.any(edges_pred): warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.") - - distances_pred_gt = get_surface_distance( - edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b] - ) - distances_gt_pred = get_surface_distance( - edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b] - ) - - boundary_complete = len(distances_pred_gt) + len(distances_gt_pred) - boundary_correct = np.sum(distances_pred_gt <= class_thresholds[c]) + np.sum( - distances_gt_pred <= class_thresholds[c] - ) - if boundary_complete == 0: # the class is neither present in the prediction, nor in the reference segmentation nsd[b, c] = np.nan diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index be121ef027..2ca69b5540 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -12,22 +12,37 @@ from __future__ import annotations import warnings -from collections.abc import Sequence -from typing import Any +from functools import lru_cache +from typing import Any, Sequence import numpy as np import torch from monai.config import NdarrayOrTensor, NdarrayTensor -from monai.transforms.croppad.array import SpatialCrop -from monai.transforms.utils import generate_spatial_bounding_box -from monai.utils import MetricReduction, convert_data_type, look_up_option, optional_import +from monai.transforms.croppad.dictionary import CropForegroundD +from monai.utils import ( + MetricReduction, + convert_to_numpy, + convert_to_tensor, + ensure_tuple_rep, + look_up_option, + optional_import, +) 
binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion") distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt") distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") -__all__ = ["ignore_background", "do_metric_reduction", "get_mask_edges", "get_surface_distance", "is_binary_tensor"] +__all__ = [ + "ignore_background", + "do_metric_reduction", + "get_mask_edges", + "get_surface_distance", + "is_binary_tensor", + "remap_instance_id", + "prepare_spacing", + "get_code_to_measure_table", +] def ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayTensor, NdarrayTensor]: @@ -110,18 +125,19 @@ def do_metric_reduction( def get_mask_edges( - seg_pred: NdarrayOrTensor, seg_gt: NdarrayOrTensor, label_idx: int = 1, crop: bool = True + seg_pred: NdarrayOrTensor, + seg_gt: NdarrayOrTensor, + label_idx: int = 1, + crop: bool = True, + spacing: Sequence | None = None, ) -> tuple[np.ndarray, np.ndarray]: """ - Do binary erosion and use XOR for input to get the edges. This + Compute edges from binary segmentation masks. This function is helpful to further calculate metrics such as Average Surface Distance and Hausdorff Distance. The input images can be binary or labelfield images. If labelfield images are supplied, they are converted to binary images using `label_idx`. - `scipy`'s binary erosion is used to calculate the edges of the binary - labelfield. - In order to improve the computing efficiency, before getting the edges, the images can be cropped and only keep the foreground if not specifies ``crop = False``. @@ -138,39 +154,52 @@ def get_mask_edges( maintain two inputs' shapes, here the bounding box is achieved by ``(seg_pred | seg_gt)`` which represents the union set of two images. Defaults to ``True``. + spacing: the input spacing. If not None, the subvoxel edges and areas will be computed. + otherwise `scipy`'s binary erosion is used to calculate the edges. 
""" - - # Get both labelfields as np arrays - if isinstance(seg_pred, torch.Tensor): - seg_pred = seg_pred.detach().cpu().numpy() - if isinstance(seg_gt, torch.Tensor): - seg_gt = seg_gt.detach().cpu().numpy() - if seg_pred.shape != seg_gt.shape: raise ValueError(f"seg_pred and seg_gt should have same shapes, got {seg_pred.shape} and {seg_gt.shape}.") # If not binary images, convert them - if seg_pred.dtype != bool: + if seg_pred.dtype not in (bool, torch.bool): seg_pred = seg_pred == label_idx - if seg_gt.dtype != bool: + if seg_gt.dtype not in (bool, torch.bool): seg_gt = seg_gt == label_idx - if crop: - if not np.any(seg_pred | seg_gt): - return np.zeros_like(seg_pred), np.zeros_like(seg_gt) - - channel_dim = 0 - seg_pred, seg_gt = np.expand_dims(seg_pred, axis=channel_dim), np.expand_dims(seg_gt, axis=channel_dim) - box_start, box_end = generate_spatial_bounding_box(np.asarray(seg_pred | seg_gt)) - cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) - seg_pred = convert_data_type(np.squeeze(cropper(seg_pred), axis=channel_dim), np.ndarray)[0] # type: ignore[arg-type] - seg_gt = convert_data_type(np.squeeze(cropper(seg_gt), axis=channel_dim), np.ndarray)[0] # type: ignore[arg-type] - - # Do binary erosion and use XOR to get edges - edges_pred = binary_erosion(seg_pred) ^ seg_pred - edges_gt = binary_erosion(seg_gt) ^ seg_gt - - return edges_pred, edges_gt + or_vol = seg_pred | seg_gt + if not or_vol.any(): + pred, gt = np.zeros(seg_pred.shape, dtype=bool), np.zeros(seg_gt.shape, dtype=bool) + return (pred, gt) if spacing is None else (pred, gt, pred, gt) # type: ignore + channel_first = [seg_pred[None], seg_gt[None], or_vol[None]] + if spacing is None: # cpu only erosion + seg_pred, seg_gt, or_vol = convert_to_tensor(channel_first, device="cpu", dtype=bool) + else: # pytorch subvoxel, maybe on gpu, but croppad boolean values on GPU is not supported + seg_pred, seg_gt, or_vol = convert_to_tensor(channel_first, dtype=torch.float16) + cropper = CropForegroundD( + ["pred", "gt"], source_key="src", margin=1, allow_smaller=True, start_coord_key=None, end_coord_key=None + ) + cropped = cropper({"pred": seg_pred, "gt": seg_gt, "src": or_vol}) # type: ignore + seg_pred, seg_gt = cropped["pred"][0], cropped["gt"][0] + + if spacing is None: # Do binary erosion and use XOR to get edges + seg_pred, seg_gt = convert_to_numpy([seg_pred, seg_gt], dtype=bool) + edges_pred = binary_erosion(seg_pred) ^ seg_pred + edges_gt = binary_erosion(seg_gt) ^ seg_gt + return edges_pred, edges_gt + code_to_area_table, k = get_code_to_measure_table(spacing, device=seg_pred.device) # type: ignore + spatial_dims = len(spacing) + conv = torch.nn.functional.conv3d if spatial_dims == 3 else torch.nn.functional.conv2d + vol = torch.stack([seg_pred[None], seg_gt[None]], dim=0).float() # type: ignore + code_pred, code_gt = conv(vol, k.to(vol)) # type: ignore + # edges + all_ones = len(code_to_area_table) - 1 + edges_pred = (code_pred != 0) & (code_pred != all_ones) + edges_gt = (code_gt != 0) & (code_gt != all_ones) + # areas of edges + areas_pred = torch.index_select(code_to_area_table, 0, code_pred.view(-1).int()).reshape(code_pred.shape) + areas_gt = torch.index_select(code_to_area_table, 0, code_gt.view(-1).int()).reshape(code_gt.shape) + ret = (edges_pred[0], edges_gt[0], areas_pred[0], areas_gt[0]) + return convert_to_numpy(ret, wrap_sequence=False) # type: ignore def get_surface_distance( @@ -259,13 +288,10 @@ def remap_instance_id(pred: torch.Tensor, by_size: bool = False) -> torch.Tensor # the original 
implementation has the limitation that if there is no 0 in pred, error will happen pred_id = [i for i in pred_id if i != 0] - if len(pred_id) == 0: + if not pred_id: return pred - if by_size is True: - instance_size = [] - for instance_id in pred_id: - instance_size.append((pred == instance_id).sum()) - + if by_size: + instance_size = [(pred == instance_id).sum() for instance_id in pred_id] pair_data = zip(pred_id, instance_size) pair_list = sorted(pair_data, key=lambda x: x[1], reverse=True) pred_id, _ = zip(*pair_list) @@ -292,14 +318,15 @@ def prepare_spacing( input spacing = [0.8, 0.7, 1.2, 0.8] -> output spacing = [0.8, 0.7, 1.2, 0.8] (same as input) An example with batch_size = 3 and img_dim = 3: - input spacing = [0.8, 0.5, 0.9] -> output spacing = [[0.8, 0.5, 0.9], [0.8, 0.5, 0.9], [0.8, 0.5, 0.9], [0.8, 0.5, 0.9]] + input spacing = [0.8, 0.5, 0.9] -> + output spacing = [[0.8, 0.5, 0.9], [0.8, 0.5, 0.9], [0.8, 0.5, 0.9], [0.8, 0.5, 0.9]] Args: spacing: can be a float, a sequence of length `img_dim`, or a sequence with length `batch_size` that includes floats or sequences of length `img_dim`. Raises: - AssertionError: when `spacing` is a sequence of sequence, where the outer sequence length does not + ValueError: when `spacing` is a sequence of sequence, where the outer sequence length does not equal `batch_size` or inner sequence length does not equal `img_dim`. Returns: @@ -307,30 +334,397 @@ def prepare_spacing( """ if spacing is None or isinstance(spacing, (int, float)): return list([spacing] * batch_size) - elif isinstance(spacing, (Sequence, np.ndarray)): - assert all( - isinstance(s, type(spacing[0])) for s in list(spacing) - ), "if `spacing` is a sequence, its elements should be of same type." - + if isinstance(spacing, (Sequence, np.ndarray)): + if any(not isinstance(s, type(spacing[0])) for s in list(spacing)): + raise ValueError(f"if `spacing` is a sequence, its elements should be of same type, got {spacing}.") if isinstance(spacing[0], (Sequence, np.ndarray)): - assert ( - len(spacing) == batch_size - ), "if `spacing` is a sequence of sequences, the outer sequence should have same length as batch size." - assert all( - len(s) == img_dim for s in list(spacing) - ), "each element of `spacing` list should either have same length as image dim." - assert all( - isinstance(i, (int, float)) for s in list(spacing) for i in list(s) - ), "if `spacing` is a sequence of sequences or 2D np.ndarray, the elements should be integers or floats." + if len(spacing) != batch_size: + raise ValueError( + "if `spacing` is a sequence of sequences, " + f"the outer sequence should have same length as batch size ({batch_size}), got {spacing}." + ) + if any(len(s) != img_dim for s in list(spacing)): + raise ValueError( + "each element of `spacing` list should either have same length as" + f"image dim ({img_dim}), got {spacing}." + ) + if not all(isinstance(i, (int, float)) for s in list(spacing) for i in list(s)): + raise ValueError( + f"if `spacing` is a sequence of sequences or 2D np.ndarray, " + f"the elements should be integers or floats, got {spacing}." + ) return list(spacing) - elif isinstance(spacing[0], (int, float)): - assert ( - len(spacing) == img_dim - ), "if `spacing` is a sequence of numbers, it should have same length as image dim." + if isinstance(spacing[0], (int, float)): + if len(spacing) != img_dim: + raise ValueError( + f"if `spacing` is a sequence of numbers, " + f"it should have same length as image dim ({img_dim}), got {spacing}." 
+ ) return [spacing for _ in range(batch_size)] # type: ignore - else: - raise AssertionError(f"`spacing` is a sequence of elements with unsupported type: {type(spacing[0])}") - else: - raise AssertionError( - "`spacing` should either be an integer, float, a sequence of numbers or a sequence of sequences." - ) + raise ValueError(f"`spacing` is a sequence of elements with unsupported type: {type(spacing[0])}") + raise ValueError( + f"`spacing` should either be a number, a sequence of numbers or a sequence of sequences, got {spacing}." + ) + + +ENCODING_KERNEL = {2: [[8, 4], [2, 1]], 3: [[[128, 64], [32, 16]], [[8, 4], [2, 1]]]} + + +@lru_cache(maxsize=None) +def _get_neighbour_code_to_normals_table(device=None): + """ + returns a lookup table. For every binary neighbour code (2x2x2 neighbourhood = 8 neighbours = 8 bits = 256 codes) + it contains the surface normals of the triangles. The length of the normal vector encodes the surfel area. + Adapted from https://github.com/deepmind/surface-distance + + created using the marching_cube algorithm see e.g. https://en.wikipedia.org/wiki/Marching_cubes + + Args: + device: torch device to use for the table. + """ + zeros = [0.0, 0.0, 0.0] + ret = [ + [zeros, zeros, zeros, zeros], + [[0.125, 0.125, 0.125], zeros, zeros, zeros], + [[-0.125, -0.125, 0.125], zeros, zeros, zeros], + [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros, zeros], + [[0.125, -0.125, 0.125], zeros, zeros, zeros], + [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros, zeros], + [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros], + [[-0.125, 0.125, 0.125], zeros, zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros], + [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros, zeros], + [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros, zeros], + [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], zeros], + [[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0], zeros, zeros], + [[0.125, -0.125, -0.125], zeros, zeros, zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], zeros, zeros], + [[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros], + [[0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125], zeros], + [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros], + [[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]], + [[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125], zeros], + [[0.125, 0.125, 0.125], [0.375, 0.375, 0.375], [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]], + [[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros], + [[0.375, 0.375, 0.375], [0.0, 0.25, -0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]], + [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]], + [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], zeros], + [[0.125, -0.125, 0.125], zeros, zeros, zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros, zeros], + [[0.0, 
-0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], zeros], + [[0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros], + [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125], zeros], + [[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25], [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]], + [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros], + [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.25, 0.25, -0.25], [0.25, 0.25, -0.25], [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]], + [[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros], + [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]], + [[0.0, 0.25, -0.25], [0.375, -0.375, -0.375], [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]], + [[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros, zeros], + [[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125], zeros], + [[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25], zeros], + [[0.0, 0.5, 0.0], [0.0, -0.5, 0.0], zeros, zeros], + [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 0.125], zeros], + [[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]], + [[0.125, 0.125, 0.125], [0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]], + [[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125], zeros], + [[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros], + [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], + [[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]], + [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros], + [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]], + [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros], + [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros, zeros], + [[-0.125, -0.125, 0.125], zeros, zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], zeros, zeros], + [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.375, -0.375, 0.375], [0.0, -0.25, -0.25], [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]], + [[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros], + [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros], + [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], + [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros], + [[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25], [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]], + [[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375], [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]], + [[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros], + [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros, zeros], + [[-0.0, 
0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros], + [[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]], + [[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros], + [[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5], zeros, zeros], + [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]], + [[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros], + [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125], zeros], + [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], + [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]], + [[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros], + [[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375], [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]], + [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros], + [[0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros, zeros], + [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros], + [[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros], + [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125], zeros], + [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]], + [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125], zeros], + [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros], + [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]], + [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros], + [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125], zeros], + [[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]], + [[0.375, -0.375, 0.375], [0.0, 0.25, 0.25], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]], + [[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0], [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]], + [[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros], + [[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros], + [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros, zeros], + [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]], + [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125], zeros], + [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[0.125, 
0.125, 0.125], [0.125, 0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], zeros, zeros, zeros], + [[0.125, 0.125, 0.125], zeros, zeros, zeros], + [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125], zeros], + [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125], zeros], + [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]], + [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros, zeros], + [[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros], + [[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros], + [[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0], [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]], + [[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[0.375, -0.375, 0.375], [0.0, 0.25, 0.25], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]], + [[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]], + [[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125], zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros], + [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros], + [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], + [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]], + [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros], + [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125], zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]], + [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125], zeros], + [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]], + [[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros], + [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros, zeros], + [[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros], + [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros], + [[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375], [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]], + [[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros], + [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]], + [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]], + [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros], + [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]], + [[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5], zeros, zeros], + [[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros], + [[-0.25, -0.0, -0.25], [-0.375, 
0.375, 0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]], + [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros], + [[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros, zeros], + [[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros], + [[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375], [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]], + [[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25], [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]], + [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros], + [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], + [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros], + [[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros], + [[0.375, -0.375, 0.375], [0.0, -0.25, -0.25], [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros], + [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], zeros, zeros], + [[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros], + [[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[-0.125, -0.125, 0.125], zeros, zeros, zeros], + [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros, zeros], + [[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros], + [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125], zeros], + [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]], + [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros], + [[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]], + [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]], + [[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros], + [[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125], zeros], + [[0.125, 0.125, 0.125], [0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]], + [[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]], + [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 0.125], zeros], + [[0.0, 0.5, 0.0], [0.0, -0.5, 0.0], zeros, zeros], + [[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25], zeros], + [[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125], zeros], + [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros, zeros], + [[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.0, 0.25, -0.25], [0.375, -0.375, -0.375], [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]], + [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]], + [[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros], + [[0.25, 0.25, -0.25], [0.25, 0.25, -0.25], [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]], + [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros], + [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25], [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]], + [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125], zeros], + 
[[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros], + [[0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], zeros], + [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros, zeros], + [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros], + [[0.125, -0.125, 0.125], zeros, zeros, zeros], + [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], zeros], + [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]], + [[0.375, 0.375, 0.375], [0.0, 0.25, -0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]], + [[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros], + [[0.125, 0.125, 0.125], [0.375, 0.375, 0.375], [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]], + [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125], zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]], + [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros], + [[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125], zeros], + [[0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros], + [[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros], + [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], zeros, zeros], + [[0.125, -0.125, -0.125], zeros, zeros, zeros], + [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0], zeros, zeros], + [[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros], + [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], zeros], + [[0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros, zeros], + [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros], + [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros, zeros], + [[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros], + [[-0.125, 0.125, 0.125], zeros, zeros, zeros], + [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros], + [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros], + [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros, zeros], + [[0.125, 0.125, 0.125], zeros, zeros, zeros], + [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros, zeros], + [[0.125, 0.125, 0.125], zeros, zeros, zeros], + [[0.125, 0.125, 0.125], zeros, zeros, zeros], + [zeros, zeros, zeros, zeros], + ] + return torch.as_tensor(ret, device=device) + + +def create_table_neighbour_code_to_surface_area(spacing_mm, device=None): + """ + Returns an array mapping neighbourhood code to the surface elements area. + Adapted from https://github.com/deepmind/surface-distance + + Note that the normals encode the initial surface area. This function computes + the area corresponding to the given `spacing`. + + Args: + spacing_mm: a sequence of 3 numbers. Voxel spacing along the first 3 spatial axes. + device: device to put the table on. + + Returns: + An array of size 256, mapping neighbourhood code to the surface area. + ENCODING_KERNEL[3] which is the kernel used to compute the neighbourhood code. 
+ """ + spacing_mm = ensure_tuple_rep(spacing_mm, 3) + # compute the area for all 256 possible surface elements given a 2x2x2 neighbourhood according to the spacing_mm + c = _get_neighbour_code_to_normals_table(device) + s = torch.as_tensor( + [[[spacing_mm[1] * spacing_mm[2], spacing_mm[0] * spacing_mm[2], spacing_mm[0] * spacing_mm[1]]]], + device=device, + dtype=c.dtype, + ) + norm = torch.linalg.norm(c * s, dim=-1) + neighbour_code_to_surface_area = norm.sum(-1) + return neighbour_code_to_surface_area, torch.as_tensor([[ENCODING_KERNEL[3]]], device=device) + + +def create_table_neighbour_code_to_contour_length(spacing_mm, device=None): + """ + Returns an array mapping neighbourhood code to the contour length. + Adapted from https://github.com/deepmind/surface-distance + + In 2D, each point has 4 neighbors. Thus, are 16 configurations. A + configuration is encoded with '1' meaning "inside the object" and '0' "outside + the object". For example, + "0101" and "1010" both encode an edge along the first spatial axis with length spacing[0] mm; + "0011" and "1100" both encode an edge along the second spatial axis with length spacing[1] mm. + + Args: + spacing_mm: 2-element list-like structure. Pixel spacing along the 1st and 2nd spatial axes. + device: device to put the table on. + + Returns: + A 16-element array mapping neighbourhood code to the contour length. + ENCODING_KERNEL[2] which is the kernel used to compute the neighbourhood code. + """ + spacing_mm = ensure_tuple_rep(spacing_mm, 2) + first, second = spacing_mm # spacing along the first and second spatial dimension respectively + diag = 0.5 * np.linalg.norm(spacing_mm) + + neighbour_code_to_contour_length = np.zeros([16], dtype=diag.dtype) + neighbour_code_to_contour_length[int("0001", 2)] = diag + neighbour_code_to_contour_length[int("0010", 2)] = diag + neighbour_code_to_contour_length[int("0011", 2)] = second + neighbour_code_to_contour_length[int("0100", 2)] = diag + neighbour_code_to_contour_length[int("0101", 2)] = first + neighbour_code_to_contour_length[int("0110", 2)] = 2 * diag + neighbour_code_to_contour_length[int("0111", 2)] = diag + neighbour_code_to_contour_length[int("1000", 2)] = diag + neighbour_code_to_contour_length[int("1001", 2)] = 2 * diag + neighbour_code_to_contour_length[int("1010", 2)] = first + neighbour_code_to_contour_length[int("1011", 2)] = diag + neighbour_code_to_contour_length[int("1100", 2)] = second + neighbour_code_to_contour_length[int("1101", 2)] = diag + neighbour_code_to_contour_length[int("1110", 2)] = diag + neighbour_code_to_contour_length = convert_to_tensor(neighbour_code_to_contour_length, device=device) + return neighbour_code_to_contour_length, torch.as_tensor([[ENCODING_KERNEL[2]]], device=device) + + +def get_code_to_measure_table(spacing, device=None): + """ + returns a table mapping neighbourhood code to the surface area or contour length. + + Args: + spacing: a sequence of 2 or 3 numbers, indicating the spacing in the spatial dimensions. + device: device to put the table on. 
+ """ + spatial_dims = len(spacing) + spacing = ensure_tuple_rep(spacing, look_up_option(spatial_dims, (2, 3))) + if spatial_dims == 2: + return create_table_neighbour_code_to_contour_length(spacing, device) + return create_table_neighbour_code_to_surface_area(spacing, device) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 6b908dda8c..81046aa37a 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -732,8 +732,8 @@ def __init__( allow_smaller: bool = False, k_divisible: Sequence[int] | int = 1, mode: SequenceStr = PytorchPadMode.CONSTANT, - start_coord_key: str = "foreground_start_coord", - end_coord_key: str = "foreground_end_coord", + start_coord_key: str | None = "foreground_start_coord", + end_coord_key: str | None = "foreground_end_coord", allow_missing_keys: bool = False, lazy: bool = False, **pad_kwargs, diff --git a/tests/test_surface_dice.py b/tests/test_surface_dice.py index c5400aea39..53b0d38bb2 100644 --- a/tests/test_surface_dice.py +++ b/tests/test_surface_dice.py @@ -17,7 +17,8 @@ import torch import torch.nn.functional as F -from monai.metrics.surface_dice import SurfaceDiceMetric +from monai.metrics.surface_dice import SurfaceDiceMetric, compute_surface_dice +from tests.utils import assert_allclose _device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -326,26 +327,6 @@ def test_asserts(self): str(context.exception), ) - # input tensors not one-hot encoded - predictions_no_hot = torch.clone(predictions_hot) - predictions_no_hot[0, :, 0, 0] = torch.tensor([2, 0]) - - with self.assertRaises(ValueError) as context: - SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_no_hot, predictions_hot) - self.assertEqual("y_pred and y should be one-hot encoded.", str(context.exception)) - with self.assertRaises(ValueError) as context: - SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, predictions_no_hot) - self.assertEqual("y_pred and y should be one-hot encoded.", str(context.exception)) - - predictions_no_hot = predictions_no_hot.float() - predictions_no_hot[0, :, 0, 0] = torch.tensor([0.5, 0]) - with self.assertRaises(ValueError) as context: - SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_no_hot, predictions_hot) - self.assertEqual("y_pred and y should be binarized tensors (e.g. torch.int64).", str(context.exception)) - with self.assertRaises(ValueError) as context: - SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, predictions_no_hot) - self.assertEqual("y_pred and y should be binarized tensors (e.g. 
torch.int64).", str(context.exception)) - # wrong number of class thresholds with self.assertRaises(ValueError) as context: SurfaceDiceMetric(class_thresholds=[1, 1, 1], include_background=True)(predictions_hot, labels_hot) @@ -401,6 +382,35 @@ def test_not_predicted_not_present(self): np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float)) np.testing.assert_equal(not_nans, torch.tensor([0], dtype=torch.float)) + def test_compute_surface_dice_subvoxel(self): + mask_gt, mask_pred = torch.zeros(1, 1, 128, 128, 128), torch.zeros(1, 1, 128, 128, 128) + mask_gt[0, 0, 50, 60, 70] = 1 + res = compute_surface_dice( + mask_pred, mask_gt, [1.0], include_background=True, spacing=(3, 2, 1), use_subvoxels=True + ) + assert_allclose(res, 0.0, type_test=False) + mask_gt[0, 0, 50, 60, 70] = 0 + mask_pred[0, 0, 50, 60, 72] = 1 + res = compute_surface_dice( + mask_pred, mask_gt, [1.0], include_background=True, spacing=(3, 2, 1), use_subvoxels=True + ) + assert_allclose(res, 0.0, type_test=False) + mask_gt[0, 0, 50, 60, 70] = 1 + mask_pred[0, 0, 50, 60, 72] = 1 + res = compute_surface_dice( + mask_pred, mask_gt, [1.0], include_background=True, spacing=(3, 2, 1), use_subvoxels=True + ) + assert_allclose(res, 0.5, type_test=False) + + d = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + mask_gt, mask_pred = torch.zeros(1, 1, 100, 100, 100, device=d), torch.zeros(1, 1, 100, 100, 100, device=d) + mask_gt[0, 0, 0:50, :, :] = 1 + mask_pred[0, 0, 0:51, :, :] = 1 + res = compute_surface_dice( + mask_pred, mask_gt, [1.0], include_background=True, spacing=(2, 1, 1), use_subvoxels=True + ) + assert_allclose(res, 0.836145, type_test=False, atol=1e-3, rtol=1e-3) + if __name__ == "__main__": unittest.main()