diff --git a/monai/losses/hausdorff_loss.py b/monai/losses/hausdorff_loss.py index f3fd87c22a..6117f27741 100644 --- a/monai/losses/hausdorff_loss.py +++ b/monai/losses/hausdorff_loss.py @@ -19,12 +19,11 @@ import warnings from typing import Callable -import numpy as np import torch from torch.nn.modules.loss import _Loss -from monai.metrics.utils import distance_transform_edt from monai.networks import one_hot +from monai.transforms.utils import distance_transform_edt from monai.utils import LossReduction @@ -95,7 +94,7 @@ def __init__( self.batch = batch @torch.no_grad() - def distance_field(self, img: np.ndarray) -> np.ndarray: + def distance_field(self, img: torch.Tensor) -> torch.Tensor: """Generate distance transform. Args: @@ -104,18 +103,20 @@ def distance_field(self, img: np.ndarray) -> np.ndarray: Returns: np.ndarray: Distance field. """ - field = np.zeros_like(img) + field = torch.zeros_like(img) - for batch in range(len(img)): - fg_mask = img[batch] > 0.5 + for batch_idx in range(len(img)): + fg_mask = img[batch_idx] > 0.5 - if fg_mask.any(): + # For cases where the mask is entirely background or entirely foreground + # the distance transform is not well defined for all 1s, + # which always would happen on either foreground or background, so skip + if fg_mask.any() and not fg_mask.all(): + fg_dist: torch.Tensor = distance_transform_edt(fg_mask) # type: ignore bg_mask = ~fg_mask + bg_dist: torch.Tensor = distance_transform_edt(bg_mask) # type: ignore - fg_dist = distance_transform_edt(fg_mask) - bg_dist = distance_transform_edt(bg_mask) - - field[batch] = fg_dist + bg_dist + field[batch_idx] = fg_dist + bg_dist return field @@ -181,8 +182,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: for i in range(input.shape[1]): ch_input = input[:, [i]] ch_target = target[:, [i]] - pred_dt = torch.from_numpy(self.distance_field(ch_input.detach().cpu().numpy())).float() - target_dt = torch.from_numpy(self.distance_field(ch_target.detach().cpu().numpy())).float() + pred_dt = self.distance_field(ch_input.detach()).float() + target_dt = self.distance_field(ch_target.detach()).float() pred_error = (ch_input - ch_target) ** 2 distance = pred_dt**self.alpha + target_dt**self.alpha diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index e79cc24325..c390629a7b 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -11,7 +11,6 @@ from __future__ import annotations -import warnings from collections.abc import Sequence from typing import Any @@ -20,12 +19,12 @@ from monai.metrics.utils import ( do_metric_reduction, - get_mask_edges, + get_edge_surface_distance, get_surface_distance, ignore_background, prepare_spacing, ) -from monai.utils import MetricReduction, convert_data_type +from monai.utils import MetricReduction, convert_data_type, deprecated from .metric import CumulativeIterationMetric @@ -180,31 +179,46 @@ def compute_hausdorff_distance( raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") batch_size, n_class = y_pred.shape[:2] - hd = np.empty((batch_size, n_class)) + hd = torch.empty((batch_size, n_class), dtype=torch.float, device=y_pred.device) img_dim = y_pred.ndim - 2 spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim) for b, c in np.ndindex(batch_size, n_class): - (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c]) - if not np.any(edges_gt): - warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.") - if not np.any(edges_pred): - warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.") - - distance_1 = compute_percent_hausdorff_distance( - edges_pred, edges_gt, distance_metric, percentile, spacing_list[b] + _, distances, _ = get_edge_surface_distance( + y_pred[b, c], + y[b, c], + distance_metric=distance_metric, + spacing=spacing_list[b], + symetric=not directed, + class_index=c, ) - if directed: - hd[b, c] = distance_1 - else: - distance_2 = compute_percent_hausdorff_distance( - edges_gt, edges_pred, distance_metric, percentile, spacing_list[b] - ) - hd[b, c] = max(distance_1, distance_2) - return convert_data_type(hd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0] + percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances] + max_distance = torch.max(torch.stack(percentile_distances)) + hd[b, c] = max_distance + return hd +def _compute_percentile_hausdorff_distance( + surface_distance: torch.Tensor, percentile: float | None = None +) -> torch.Tensor: + """ + This function is used to compute the Hausdorff distance. + """ + + # for both pred and gt do not have foreground + if surface_distance.shape == (0,): + return torch.tensor(torch.nan, dtype=torch.float, device=surface_distance.device) + + if not percentile: + return surface_distance.max() # type: ignore[no-any-return] + + if 0 <= percentile <= 100: + return torch.quantile(surface_distance, percentile / 100) # type: ignore[no-any-return] + raise ValueError(f"percentile should be a value between 0 and 100, get {percentile}.") + + +@deprecated(since="1.3.0", removed="1.5.0") def compute_percent_hausdorff_distance( edges_pred: np.ndarray, edges_gt: np.ndarray, @@ -216,7 +230,9 @@ def compute_percent_hausdorff_distance( This function is used to compute the directed Hausdorff distance. """ - surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing) + surface_distance: np.ndarray = get_surface_distance( + edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing + ) # type: ignore # for both pred and gt do not have foreground if surface_distance.shape == (0,): diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py index 4b7e80a8b5..f6edc5f598 100644 --- a/monai/metrics/surface_dice.py +++ b/monai/metrics/surface_dice.py @@ -11,21 +11,14 @@ from __future__ import annotations -import warnings from collections.abc import Sequence from typing import Any import numpy as np import torch -from monai.metrics.utils import ( - do_metric_reduction, - get_mask_edges, - get_surface_distance, - ignore_background, - prepare_spacing, -) -from monai.utils import MetricReduction, convert_data_type +from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing +from monai.utils import MetricReduction from .metric import CumulativeIterationMetric @@ -251,47 +244,39 @@ def compute_surface_dice( if any(np.array(class_thresholds) < 0): raise ValueError("All class thresholds need to be >= 0.") - nsd = np.empty((batch_size, n_class)) + nsd = torch.empty((batch_size, n_class), device=y_pred.device, dtype=torch.float) img_dim = y_pred.ndim - 2 spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim) for b, c in np.ndindex(batch_size, n_class): + (edges_pred, edges_gt), (distances_pred_gt, distances_gt_pred), areas = get_edge_surface_distance( # type: ignore + y_pred[b, c], + y[b, c], + distance_metric=distance_metric, + spacing=spacing_list[b], + use_subvoxels=use_subvoxels, + symetric=True, + class_index=c, + ) + boundary_correct: int | torch.Tensor | float + boundary_complete: int | torch.Tensor | float if not use_subvoxels: - (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=True) - distances_pred_gt = get_surface_distance( - edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b] - ) - distances_gt_pred = get_surface_distance( - edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b] - ) - boundary_complete = len(distances_pred_gt) + len(distances_gt_pred) - boundary_correct = np.sum(distances_pred_gt <= class_thresholds[c]) + np.sum( + boundary_correct = torch.sum(distances_pred_gt <= class_thresholds[c]) + torch.sum( distances_gt_pred <= class_thresholds[c] ) else: - _spacing = spacing_list[b] if spacing_list[b] is not None else [1] * img_dim - areas_pred: np.ndarray - areas_gt: np.ndarray - edges_pred, edges_gt, areas_pred, areas_gt = get_mask_edges( # type: ignore - y_pred[b, c], y[b, c], crop=True, spacing=_spacing # type: ignore - ) - dist_pred_to_gt = get_surface_distance(edges_pred, edges_gt, distance_metric, spacing=spacing_list[b]) - dist_gt_to_pred = get_surface_distance(edges_gt, edges_pred, distance_metric, spacing=spacing_list[b]) + areas_pred, areas_gt = areas # type: ignore areas_gt, areas_pred = areas_gt[edges_gt], areas_pred[edges_pred] - boundary_complete = areas_gt.sum() + areas_pred.sum() - gt_true = areas_gt[dist_gt_to_pred <= class_thresholds[c]].sum() if len(areas_gt) > 0 else 0.0 - pred_true = areas_pred[dist_pred_to_gt <= class_thresholds[c]].sum() if len(areas_pred) > 0 else 0.0 + boundary_complete = areas_gt.sum() + areas_pred.sum() # type: ignore + gt_true = areas_gt[distances_gt_pred <= class_thresholds[c]].sum() if len(areas_gt) > 0 else 0.0 + pred_true = areas_pred[distances_pred_gt <= class_thresholds[c]].sum() if len(areas_pred) > 0 else 0.0 boundary_correct = gt_true + pred_true - if not np.any(edges_gt): - warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.") - if not np.any(edges_pred): - warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.") if boundary_complete == 0: # the class is neither present in the prediction, nor in the reference segmentation - nsd[b, c] = np.nan + nsd[b, c] = torch.nan else: nsd[b, c] = boundary_correct / boundary_complete - return convert_data_type(nsd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0] + return nsd diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index f56ee94119..bdc4395562 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -11,20 +11,13 @@ from __future__ import annotations -import warnings from collections.abc import Sequence from typing import Any import numpy as np import torch -from monai.metrics.utils import ( - do_metric_reduction, - get_mask_edges, - get_surface_distance, - ignore_background, - prepare_spacing, -) +from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing from monai.utils import MetricReduction, convert_data_type from .metric import CumulativeIterationMetric @@ -173,25 +166,21 @@ def compute_average_surface_distance( raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") batch_size, n_class = y_pred.shape[:2] - asd = np.empty((batch_size, n_class)) + asd = torch.empty((batch_size, n_class), dtype=torch.float32, device=y_pred.device) img_dim = y_pred.ndim - 2 spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim) for b, c in np.ndindex(batch_size, n_class): - (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c]) - if not np.any(edges_gt): - warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.") - if not np.any(edges_pred): - warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.") - surface_distance = get_surface_distance( - edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b] + _, distances, _ = get_edge_surface_distance( + y_pred[b, c], + y[b, c], + distance_metric=distance_metric, + spacing=spacing_list[b], + symetric=symmetric, + class_index=c, ) - if symmetric: - surface_distance_2 = get_surface_distance( - edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b] - ) - surface_distance = np.concatenate([surface_distance, surface_distance_2]) - asd[b, c] = np.nan if surface_distance.shape == (0,) else surface_distance.mean() + surface_distance = torch.cat(distances) + asd[b, c] = torch.nan if surface_distance.shape == (0,) else surface_distance.mean() return convert_data_type(asd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0] diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 3213689b33..547390e03c 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -12,7 +12,8 @@ from __future__ import annotations import warnings -from functools import lru_cache +from functools import lru_cache, partial +from types import ModuleType from typing import Any, Sequence import numpy as np @@ -20,10 +21,15 @@ from monai.config import NdarrayOrTensor, NdarrayTensor from monai.transforms.croppad.dictionary import CropForegroundD +from monai.transforms.utils import distance_transform_edt as monai_distance_transform_edt from monai.utils import ( MetricReduction, + convert_to_cupy, + convert_to_dst_type, convert_to_numpy, convert_to_tensor, + deprecated_arg, + deprecated_arg_default, ensure_tuple_rep, look_up_option, optional_import, @@ -32,6 +38,10 @@ binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion") distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt") distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") +cucim_binary_erosion, has_cucim_binary_erosion = optional_import("cucim.skimage.morphology", name="binary_erosion") +cucim_distance_transform_edt, has_cucim_distance_transform_edt = optional_import( + "cucim.core.operations.morphology", name="distance_transform_edt" +) __all__ = [ "ignore_background", @@ -124,13 +134,23 @@ def do_metric_reduction( return f, not_nans +@deprecated_arg_default( + name="always_return_as_numpy", since="1.3.0", replaced="1.5.0", old_default=True, new_default=False +) +@deprecated_arg( + name="always_return_as_numpy", + since="1.5.0", + removed="1.7.0", + msg_suffix="The option is removed and the return type will always be equal to the input type.", +) def get_mask_edges( seg_pred: NdarrayOrTensor, seg_gt: NdarrayOrTensor, label_idx: int = 1, crop: bool = True, spacing: Sequence | None = None, -) -> tuple[np.ndarray, np.ndarray]: + always_return_as_numpy: bool = True, +) -> tuple[NdarrayTensor, NdarrayTensor]: """ Compute edges from binary segmentation masks. This function is helpful to further calculate metrics such as Average Surface @@ -156,9 +176,25 @@ def get_mask_edges( images. Defaults to ``True``. spacing: the input spacing. If not None, the subvoxel edges and areas will be computed. otherwise `scipy`'s binary erosion is used to calculate the edges. + always_return_as_numpy: whether to a numpy array regardless of the input type. + If False, return the same type as inputs. """ if seg_pred.shape != seg_gt.shape: raise ValueError(f"seg_pred and seg_gt should have same shapes, got {seg_pred.shape} and {seg_gt.shape}.") + converter: Any + lib: ModuleType + if isinstance(seg_pred, torch.Tensor) and not always_return_as_numpy: + converter = partial(convert_to_tensor, device=seg_pred.device) + lib = torch + else: + converter = convert_to_numpy + lib = np + use_cucim = ( + spacing is None + and has_cucim_binary_erosion + and isinstance(seg_pred, torch.Tensor) + and seg_pred.device.type == "cuda" + ) # If not binary images, convert them if seg_pred.dtype not in (bool, torch.bool): @@ -168,10 +204,10 @@ def get_mask_edges( if crop: or_vol = seg_pred | seg_gt if not or_vol.any(): - pred, gt = np.zeros(seg_pred.shape, dtype=bool), np.zeros(seg_gt.shape, dtype=bool) + pred, gt = lib.zeros(seg_pred.shape, dtype=bool), lib.zeros(seg_gt.shape, dtype=bool) # type: ignore return (pred, gt) if spacing is None else (pred, gt, pred, gt) # type: ignore channel_first = [seg_pred[None], seg_gt[None], or_vol[None]] - if spacing is None: # cpu only erosion + if spacing is None and not use_cucim: # cpu only erosion seg_pred, seg_gt, or_vol = convert_to_tensor(channel_first, device="cpu", dtype=bool) else: # pytorch subvoxel, maybe on gpu, but croppad boolean values on GPU is not supported seg_pred, seg_gt, or_vol = convert_to_tensor(channel_first, dtype=torch.float16) @@ -182,10 +218,15 @@ def get_mask_edges( seg_pred, seg_gt = cropped["pred"][0], cropped["gt"][0] if spacing is None: # Do binary erosion and use XOR to get edges - seg_pred, seg_gt = convert_to_numpy([seg_pred, seg_gt], dtype=bool) - edges_pred = binary_erosion(seg_pred) ^ seg_pred - edges_gt = binary_erosion(seg_gt) ^ seg_gt - return edges_pred, edges_gt + if not use_cucim: + seg_pred, seg_gt = convert_to_numpy([seg_pred, seg_gt], dtype=bool) + edges_pred = binary_erosion(seg_pred) ^ seg_pred + edges_gt = binary_erosion(seg_gt) ^ seg_gt + else: + seg_pred, seg_gt = convert_to_cupy([seg_pred, seg_gt], dtype=bool) # type: ignore[arg-type] + edges_pred = cucim_binary_erosion(seg_pred) ^ seg_pred + edges_gt = cucim_binary_erosion(seg_gt) ^ seg_gt + return converter((edges_pred, edges_gt), dtype=bool) # type: ignore code_to_area_table, k = get_code_to_measure_table(spacing, device=seg_pred.device) # type: ignore spatial_dims = len(spacing) conv = torch.nn.functional.conv3d if spatial_dims == 3 else torch.nn.functional.conv2d @@ -199,15 +240,15 @@ def get_mask_edges( areas_pred = torch.index_select(code_to_area_table, 0, code_pred.view(-1).int()).reshape(code_pred.shape) areas_gt = torch.index_select(code_to_area_table, 0, code_gt.view(-1).int()).reshape(code_gt.shape) ret = (edges_pred[0], edges_gt[0], areas_pred[0], areas_gt[0]) - return convert_to_numpy(ret, wrap_sequence=False) # type: ignore + return converter(ret, wrap_sequence=False) # type: ignore def get_surface_distance( - seg_pred: np.ndarray, - seg_gt: np.ndarray, + seg_pred: NdarrayOrTensor, + seg_gt: NdarrayOrTensor, distance_metric: str = "euclidean", spacing: int | float | np.ndarray | Sequence[int | float] | None = None, -) -> np.ndarray: +) -> NdarrayOrTensor: """ This function is used to compute the surface distances from `seg_pred` to `seg_gt`. @@ -232,21 +273,81 @@ def get_surface_distance( If seg_pred or seg_gt is all 0, may result in nan/inf distance. """ - - if not np.any(seg_gt): - dis = np.inf * np.ones_like(seg_gt) + lib: ModuleType = torch if isinstance(seg_pred, torch.Tensor) else np + if not seg_gt.any(): + dis = lib.inf * lib.ones_like(seg_gt, dtype=lib.float32) # type: ignore else: - if not np.any(seg_pred): - dis = np.inf * np.ones_like(seg_gt) - return np.asarray(dis[seg_gt]) + if not lib.any(seg_pred): # type: ignore + dis = lib.inf * lib.ones_like(seg_gt, dtype=lib.float32) # type: ignore + dis = dis[seg_gt] # type: ignore + return convert_to_dst_type(dis, seg_pred, dtype=dis.dtype)[0] if distance_metric == "euclidean": - dis = distance_transform_edt(~seg_gt, sampling=spacing) + dis = monai_distance_transform_edt((~seg_gt)[None, ...], sampling=spacing)[0] # type: ignore elif distance_metric in {"chessboard", "taxicab"}: - dis = distance_transform_cdt(~seg_gt, metric=distance_metric) + dis = distance_transform_cdt(convert_to_numpy(~seg_gt), metric=distance_metric) else: raise ValueError(f"distance_metric {distance_metric} is not implemented.") + dis = convert_to_dst_type(dis, seg_pred, dtype=lib.float32)[0] + return dis[seg_pred] # type: ignore + - return np.asarray(dis[seg_pred]) +def get_edge_surface_distance( + y_pred: torch.Tensor, + y: torch.Tensor, + distance_metric: str = "euclidean", + spacing: int | float | np.ndarray | Sequence[int | float] | None = None, + use_subvoxels: bool = False, + symetric: bool = False, + class_index: int = -1, +) -> tuple[ + tuple[torch.Tensor, torch.Tensor], + tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor], + tuple[torch.Tensor, torch.Tensor] | tuple[()], +]: + """ + This function is used to compute the surface distance from `y_pred` to `y` using the edges of the masks. + + Args: + y_pred: the predicted binary or labelfield image. Expected to be in format (H, W[, D]). + y: the actual binary or labelfield image. Expected to be in format (H, W[, D]). + distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] + See :py:func:`monai.metrics.utils.get_surface_distance`. + spacing: spacing of pixel (or voxel). This parameter is relevant only if ``distance_metric`` is set to ``"euclidean"``. + See :py:func:`monai.metrics.utils.get_surface_distance`. + use_subvoxels: whether to use subvoxel resolution (using the spacing). + This will return the areas of the edges. + symetric: whether to compute the surface distance from `y_pred` to `y` and from `y` to `y_pred`. + class_index: The class-index used for context when warning about empty ground truth or prediction. + + Returns: + (edges_pred, edges_gt), (distances_pred_to_gt, [distances_gt_to_pred]), (areas_pred, areas_gt) | tuple() + + """ + edges_spacing = None + if use_subvoxels: + edges_spacing = spacing if spacing is not None else ([1] * len(y_pred.shape)) + (edges_pred, edges_gt, *areas) = get_mask_edges( + y_pred, y, crop=True, spacing=edges_spacing, always_return_as_numpy=False + ) + if not edges_gt.any(): + warnings.warn( + f"the ground truth of class {class_index if class_index != -1 else 'Unknown'} is all 0," + " this may result in nan/inf distance." + ) + if not edges_pred.any(): + warnings.warn( + f"the prediction of class {class_index if class_index != -1 else 'Unknown'} is all 0," + " this may result in nan/inf distance." + ) + distances: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor] + if symetric: + distances = ( + get_surface_distance(edges_pred, edges_gt, distance_metric, spacing), + get_surface_distance(edges_gt, edges_pred, distance_metric, spacing), + ) # type: ignore + else: + distances = (get_surface_distance(edges_pred, edges_gt, distance_metric, spacing),) # type: ignore + return convert_to_tensor(((edges_pred, edges_gt), distances, tuple(areas)), device=y_pred.device) # type: ignore[no-any-return] def is_binary_tensor(input: torch.Tensor, name: str) -> None: diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index d2c06dfd93..44e5b25a34 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -2099,9 +2099,10 @@ def distance_transform_edt( Returns: distances: The calculated distance transform. Returned only when `return_distances` is True and `distances` is not supplied. - It will have the same shape as image. For cuCIM: Will have dtype torch.float64 if float64_distances is True, + It will have the same shape and type as image. For cuCIM: Will have dtype torch.float64 if float64_distances is True, otherwise it will have dtype torch.float32. For SciPy: Will have dtype np.float64. indices: The calculated feature transform. It has an image-shaped array for each dimension of the image. + The type will be equal to the type of the image. Returned only when `return_indices` is True and `indices` is not supplied. dtype np.float64. """ @@ -2109,7 +2110,6 @@ def distance_transform_edt( "cucim.core.operations.morphology", name="distance_transform_edt" ) use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device.type == "cuda" - if not return_distances and not return_indices: raise RuntimeError("Neither return_distances nor return_indices True") @@ -2190,9 +2190,8 @@ def distance_transform_edt( r_vals.append(indices) if not r_vals: return None - if len(r_vals) == 1: - return r_vals[0] - return tuple(r_vals) # type: ignore + device = img.device if isinstance(img, torch.Tensor) else None + return convert_data_type(r_vals[0] if len(r_vals) == 1 else r_vals, output_type=type(img), device=device)[0] # type: ignore if __name__ == "__main__": diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 734d8a2b17..e4f97fc4a6 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -237,7 +237,14 @@ def convert_to_cupy(data: Any, dtype: np.dtype | None = None, wrap_sequence: boo if safe: data = safe_dtype_range(data, dtype) # direct calls - if isinstance(data, (cp_ndarray, np.ndarray, torch.Tensor, float, int, bool)): + if isinstance(data, torch.Tensor) and data.device.type == "cuda": + # This is needed because of https://github.com/cupy/cupy/issues/7874#issuecomment-1727511030 + if data.dtype == torch.bool: + data = data.detach().to(torch.uint8) + if dtype is None: + dtype = bool # type: ignore + data = cp.asarray(data, dtype) + elif isinstance(data, (cp_ndarray, np.ndarray, torch.Tensor, float, int, bool)): data = cp.asarray(data, dtype) elif isinstance(data, list): list_ret = [convert_to_cupy(i, dtype) for i in data] diff --git a/tests/test_hausdorff_distance.py b/tests/test_hausdorff_distance.py index a50b27b79e..71bbad36d2 100644 --- a/tests/test_hausdorff_distance.py +++ b/tests/test_hausdorff_distance.py @@ -12,6 +12,7 @@ from __future__ import annotations import unittest +from itertools import product import numpy as np import torch @@ -19,7 +20,9 @@ from monai.metrics import HausdorffDistanceMetric -_device = "cuda:0" if torch.cuda.is_available() else "cpu" +_devices = ["cpu"] +if torch.cuda.is_available(): + _devices.append("cuda") def create_spherical_seg_3d( @@ -150,33 +153,42 @@ def create_spherical_seg_3d( ], ] +TEST_CASES_EXPANDED = [] +for test_case in TEST_CASES: + test_output: list[float | int] + test_input, test_output = test_case # type: ignore + for _device in _devices: + for i, (metric, directed) in enumerate(product(["euclidean", "chessboard", "taxicab"], [True, False])): + TEST_CASES_EXPANDED.append((_device, metric, directed, test_input, test_output[i])) + + +def _describe_test_case(test_func, test_number, params): + _device, metric, directed, test_input, test_output = params.args + return f"device: {_device} metric: {metric} directed:{directed} expected: {test_output}" + class TestHausdorffDistance(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_value(self, input_data, expected_value): + @parameterized.expand(TEST_CASES_EXPANDED, doc_func=_describe_test_case) + def test_value(self, device, metric, directed, input_data, expected_value): percentile = None if len(input_data) == 4: [seg_1, seg_2, spacing, percentile] = input_data else: [seg_1, seg_2, spacing] = input_data - ct = 0 - seg_1 = torch.tensor(seg_1, device=_device) - seg_2 = torch.tensor(seg_2, device=_device) - for metric in ["euclidean", "chessboard", "taxicab"]: - for directed in [True, False]: - hd_metric = HausdorffDistanceMetric( - include_background=False, distance_metric=metric, percentile=percentile, directed=directed - ) - # shape of seg_1, seg_2 are: HWD, converts to BNHWD - batch, n_class = 2, 3 - batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) - batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) - hd_metric(batch_seg_1, batch_seg_2, spacing=spacing) - result = hd_metric.aggregate(reduction="mean") - expected_value_curr = expected_value[ct] - np.testing.assert_allclose(expected_value_curr, result.cpu(), rtol=1e-7) - np.testing.assert_equal(result.device, seg_1.device) - ct += 1 + + seg_1 = torch.tensor(seg_1, device=device) + seg_2 = torch.tensor(seg_2, device=device) + hd_metric = HausdorffDistanceMetric( + include_background=False, distance_metric=metric, percentile=percentile, directed=directed + ) + # shape of seg_1, seg_2 are: HWD, converts to BNHWD + batch, n_class = 2, 3 + batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) + batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) + hd_metric(batch_seg_1, batch_seg_2, spacing=spacing) + result: torch.Tensor = hd_metric.aggregate(reduction="mean") # type: ignore + np.testing.assert_allclose(expected_value, result.cpu(), rtol=1e-6) + np.testing.assert_equal(result.device, seg_1.device) @parameterized.expand(TEST_CASES_NANS) def test_nans(self, input_data): diff --git a/tests/test_hausdorff_loss.py b/tests/test_hausdorff_loss.py index 1cecce4da2..5ed20f5f3b 100644 --- a/tests/test_hausdorff_loss.py +++ b/tests/test_hausdorff_loss.py @@ -72,7 +72,7 @@ "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), }, - 0.758470, + 0.455082, ] ) TEST_CASES.append( @@ -141,7 +141,7 @@ "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), }, - 3.450064, + 1.870039, ] ) TEST_CASES.append( @@ -164,7 +164,7 @@ "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), }, - 2.661359, + 1.607137, ] ) TEST_CASES.append( @@ -174,7 +174,7 @@ "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), }, - 2.661359, + 1.607137, ] ) TEST_CASES.append( @@ -184,16 +184,21 @@ "input": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device), "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device), }, - 2.661359, + 1.607137, ] ) TEST_CASES_LOG = [[*inputs, np.log(np.array(output) + 1)] for *inputs, output in TEST_CASES] +def _describe_test_case(test_func, test_number, params): + input_param, input_data, _ = params.args + return f"params:{input_param}, shape:{input_data['input'].shape}, device:{input_data['input'].device}" + + @skipUnless(has_scipy, "Scipy required") class TestHausdorffDTLoss(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TEST_CASES, doc_func=_describe_test_case) def test_shape(self, input_param, input_data, expected_val): result = HausdorffDTLoss(**input_param).forward(**input_data) np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) @@ -229,7 +234,7 @@ def test_input_warnings(self): @skipUnless(has_scipy, "Scipy required") class TesLogtHausdorffDTLoss(unittest.TestCase): - @parameterized.expand(TEST_CASES_LOG) + @parameterized.expand(TEST_CASES_LOG, doc_func=_describe_test_case) def test_shape(self, input_param, input_data, expected_val): result = LogHausdorffDTLoss(**input_param).forward(**input_data) np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)