diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 5aabdcdc33..8990e7991d 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -602,6 +602,12 @@ Post-processing :members: :special-members: __call__ +`DistanceTransformEDT` +""""""""""""""""""""""""""""""" +.. autoclass:: DistanceTransformEDT + :members: + :special-members: __call__ + `RemoveSmallObjects` """""""""""""""""""" .. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RemoveSmallObjects.png @@ -1622,6 +1628,12 @@ Post-processing (Dict) :members: :special-members: __call__ +`DistanceTransformEDTd` +"""""""""""""""""""""""""""""""" +.. autoclass:: DistanceTransformEDTd + :members: + :special-members: __call__ + `RemoveSmallObjectsd` """"""""""""""""""""" .. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RemoveSmallObjectsd.png diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 1d308b0ed4..97644843ff 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -277,6 +277,7 @@ from .post.array import ( Activations, AsDiscrete, + DistanceTransformEDT, FillHoles, Invert, KeepLargestConnectedComponent, @@ -295,6 +296,9 @@ AsDiscreteD, AsDiscreted, AsDiscreteDict, + DistanceTransformEDTd, + DistanceTransformEDTD, + DistanceTransformEDTDict, Ensembled, EnsembleD, EnsembleDict, diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index df2e807a4b..f10dd21642 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -31,6 +31,7 @@ from monai.transforms.utility.array import ToTensor from monai.transforms.utils import ( convert_applied_interp_mode, + distance_transform_edt, fill_holes, get_largest_connected_component_mask, get_unique_labels, @@ -53,6 +54,7 @@ "SobelGradients", "VoteEnsemble", "Invert", + "DistanceTransformEDT", ] @@ -936,3 +938,39 @@ def __call__(self, image: NdarrayOrTensor) -> torch.Tensor: grads = convert_to_dst_type(grads.squeeze(0), image_tensor)[0] return grads + + +class DistanceTransformEDT(Transform): + """ + Applies the Euclidean distance transform on the input. + Either GPU based with CuPy / cuCIM or CPU based with scipy. + To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device. + + Note that the results of the libraries can differ, so stick to one if possible. + For details, check out the `SciPy`_ and `cuCIM`_ documentation and / or :func:`monai.transforms.utils.distance_transform_edt`. + + .. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html + .. _cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt + """ + + backend = [TransformBackends.NUMPY, TransformBackends.CUPY] + + def __init__(self, sampling: None | float | list[float] = None) -> None: + super().__init__() + self.sampling = sampling + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + """ + Args: + img: Input image on which the distance transform shall be run. + Has to be a channel first array, must have shape: (num_channels, H, W [,D]). + Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere. + Input gets passed channel-wise to the distance-transform, thus results from this function will differ + from directly calling ``distance_transform_edt()`` in CuPy or SciPy. + sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1; + if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied. + + Returns: + An array with the same shape and data type as img + """ + return distance_transform_edt(img=img, sampling=self.sampling) # type: ignore diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 3fbfe46118..393f161917 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -33,6 +33,7 @@ from monai.transforms.post.array import ( Activations, AsDiscrete, + DistanceTransformEDT, FillHoles, KeepLargestConnectedComponent, LabelFilter, @@ -91,6 +92,9 @@ "VoteEnsembleD", "VoteEnsembleDict", "VoteEnsembled", + "DistanceTransformEDTd", + "DistanceTransformEDTD", + "DistanceTransformEDTDict", ] DEFAULT_POST_FIX = PostFix.meta() @@ -855,6 +859,51 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d +class DistanceTransformEDTd(MapTransform): + """ + Applies the Euclidean distance transform on the input. + Either GPU based with CuPy / cuCIM or CPU based with scipy. + To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device. + + Note that the results of the libraries can differ, so stick to one if possible. + For details, check out the `SciPy`_ and `cuCIM`_ documentation and / or :func:`monai.transforms.utils.distance_transform_edt`. + + + Note on the input shape: + Has to be a channel first array, must have shape: (num_channels, H, W [,D]). + Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere. + Input gets passed channel-wise to the distance-transform, thus results from this function will differ + from directly calling ``distance_transform_edt()`` in CuPy or SciPy. + + Args: + keys: keys of the corresponding items to be transformed. + allow_missing_keys: don't raise exception if key is missing. + sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1; + if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied. + + .. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html + .. _cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt + + + """ + + backend = DistanceTransformEDT.backend + + def __init__( + self, keys: KeysCollection, allow_missing_keys: bool = False, sampling: None | float | list[float] = None + ) -> None: + super().__init__(keys, allow_missing_keys) + self.sampling = sampling + self.distance_transform = DistanceTransformEDT(sampling=self.sampling) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.distance_transform(img=d[key]) + + return d + + ActivationsD = ActivationsDict = Activationsd AsDiscreteD = AsDiscreteDict = AsDiscreted FillHolesD = FillHolesDict = FillHolesd @@ -869,3 +918,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N VoteEnsembleD = VoteEnsembleDict = VoteEnsembled EnsembleD = EnsembleDict = Ensembled SobelGradientsD = SobelGradientsDict = SobelGradientsd +DistanceTransformEDTD = DistanceTransformEDTDict = DistanceTransformEDTd diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 6d7fa8ada8..d2c06dfd93 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -66,11 +66,17 @@ pytorch_after, ) from monai.utils.enums import TransformBackends -from monai.utils.type_conversion import convert_data_type, convert_to_cupy, convert_to_dst_type, convert_to_tensor +from monai.utils.type_conversion import ( + convert_data_type, + convert_to_cupy, + convert_to_dst_type, + convert_to_numpy, + convert_to_tensor, +) measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version) morphology, has_morphology = optional_import("skimage.morphology") -ndimage, _ = optional_import("scipy.ndimage") +ndimage, has_ndimage = optional_import("scipy.ndimage") cp, has_cp = optional_import("cupy") cp_ndarray, _ = optional_import("cupy", name="ndarray") exposure, has_skimage = optional_import("skimage.exposure") @@ -124,6 +130,7 @@ "reset_ops_id", "resolves_modes", "has_status_keys", + "distance_transform_edt", ] @@ -2051,5 +2058,142 @@ def has_status_keys(data: torch.Tensor, status_key: Any, default_message: str = return True, None +def distance_transform_edt( + img: NdarrayOrTensor, + sampling: None | float | list[float] = None, + return_distances: bool = True, + return_indices: bool = False, + distances: NdarrayOrTensor | None = None, + indices: NdarrayOrTensor | None = None, + *, + block_params: tuple[int, int, int] | None = None, + float64_distances: bool = False, +) -> None | NdarrayOrTensor | tuple[NdarrayOrTensor, NdarrayOrTensor]: + """ + Euclidean distance transform, either GPU based with CuPy / cuCIM or CPU based with scipy. + To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device. + + Note that the results of the libraries can differ, so stick to one if possible. + For details, check out the `SciPy`_ and `cuCIM`_ documentation. + + .. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html + .. _cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt + + Args: + img: Input image on which the distance transform shall be run. + Has to be a channel first array, must have shape: (num_channels, H, W [,D]). + Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere. + Input gets passed channel-wise to the distance-transform, thus results from this function will differ + from directly calling ``distance_transform_edt()`` in CuPy or SciPy. + sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1; + if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied. + return_distances: Whether to calculate the distance transform. + return_indices: Whether to calculate the feature transform. + distances: An output array to store the calculated distance transform, instead of returning it. + `return_distances` must be True. + indices: An output array to store the calculated feature transform, instead of returning it. `return_indicies` must be True. + block_params: This parameter is specific to cuCIM and does not exist in SciPy. For details, look into `cuCIM`_. + float64_distances: This parameter is specific to cuCIM and does not exist in SciPy. + If True, use double precision in the distance computation (to match SciPy behavior). + Otherwise, single precision will be used for efficiency. + + Returns: + distances: The calculated distance transform. Returned only when `return_distances` is True and `distances` is not supplied. + It will have the same shape as image. For cuCIM: Will have dtype torch.float64 if float64_distances is True, + otherwise it will have dtype torch.float32. For SciPy: Will have dtype np.float64. + indices: The calculated feature transform. It has an image-shaped array for each dimension of the image. + Returned only when `return_indices` is True and `indices` is not supplied. dtype np.float64. + + """ + distance_transform_edt, has_cucim = optional_import( + "cucim.core.operations.morphology", name="distance_transform_edt" + ) + use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device.type == "cuda" + + if not return_distances and not return_indices: + raise RuntimeError("Neither return_distances nor return_indices True") + + if not (img.ndim >= 3 and img.ndim <= 4): + raise RuntimeError("Wrong input dimensionality. Use (num_channels, H, W [,D])") + + distances_original, indices_original = distances, indices + distances, indices = None, None + if use_cp: + distances_, indices_ = None, None + if return_distances: + dtype = torch.float64 if float64_distances else torch.float32 + if distances is None: + distances = torch.zeros_like(img, dtype=dtype) # type: ignore + else: + if not isinstance(distances, torch.Tensor) and distances.device != img.device: + raise TypeError("distances must be a torch.Tensor on the same device as img") + if not distances.dtype == dtype: + raise TypeError("distances must be a torch.Tensor of dtype float32 or float64") + distances_ = convert_to_cupy(distances) + if return_indices: + dtype = torch.int32 + if indices is None: + indices = torch.zeros((img.dim(),) + img.shape, dtype=dtype) # type: ignore + else: + if not isinstance(indices, torch.Tensor) and indices.device != img.device: + raise TypeError("indices must be a torch.Tensor on the same device as img") + if not indices.dtype == dtype: + raise TypeError("indices must be a torch.Tensor of dtype int32") + indices_ = convert_to_cupy(indices) + img_ = convert_to_cupy(img) + for channel_idx in range(img_.shape[0]): + distance_transform_edt( + img_[channel_idx], + sampling=sampling, + return_distances=return_distances, + return_indices=return_indices, + distances=distances_[channel_idx] if distances_ is not None else None, + indices=indices_[channel_idx] if indices_ is not None else None, + block_params=block_params, + float64_distances=float64_distances, + ) + else: + if not has_ndimage: + raise RuntimeError("scipy.ndimage required if cupy is not available") + img_ = convert_to_numpy(img) + if return_distances: + if distances is None: + distances = np.zeros_like(img_, dtype=np.float64) + else: + if not isinstance(distances, np.ndarray): + raise TypeError("distances must be a numpy.ndarray") + if not distances.dtype == np.float64: + raise TypeError("distances must be a numpy.ndarray of dtype float64") + if return_indices: + if indices is None: + indices = np.zeros((img_.ndim,) + img_.shape, dtype=np.int32) + else: + if not isinstance(indices, np.ndarray): + raise TypeError("indices must be a numpy.ndarray") + if not indices.dtype == np.int32: + raise TypeError("indices must be a numpy.ndarray of dtype int32") + + for channel_idx in range(img_.shape[0]): + ndimage.distance_transform_edt( + img_[channel_idx], + sampling=sampling, + return_distances=return_distances, + return_indices=return_indices, + distances=distances[channel_idx] if distances is not None else None, + indices=indices[channel_idx] if indices is not None else None, + ) + + r_vals = [] + if return_distances and distances_original is None: + r_vals.append(distances) + if return_indices and indices_original is None: + r_vals.append(indices) + if not r_vals: + return None + if len(r_vals) == 1: + return r_vals[0] + return tuple(r_vals) # type: ignore + + if __name__ == "__main__": print_transform_backends() diff --git a/tests/min_tests.py b/tests/min_tests.py index 879cdc61b4..0f51df4652 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -61,6 +61,7 @@ def run_testsuit(): "test_deepgrow_transforms", "test_detect_envelope", "test_dints_network", + "test_distance_transform_edt", "test_efficientnet", "test_ensemble_evaluator", "test_ensure_channel_first", diff --git a/tests/test_distance_transform_edt.py b/tests/test_distance_transform_edt.py new file mode 100644 index 0000000000..83b9348348 --- /dev/null +++ b/tests/test_distance_transform_edt.py @@ -0,0 +1,202 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import DistanceTransformEDT, DistanceTransformEDTd +from tests.utils import HAS_CUPY, assert_allclose, optional_import, skip_if_no_cuda + +momorphology, has_cucim = optional_import("cucim.core.operations.morphology") +ndimage, has_ndimage = optional_import("scipy.ndimage") +cp, _ = optional_import("cupy") + +TEST_CASES = [ + [ + np.array( + ([[0, 1, 1, 1, 1], [0, 0, 1, 1, 1], [0, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1, 1, 0, 0]],), dtype=np.float32 + ), + np.array( + [ + [ + [0.0, 1.0, 1.4142, 2.2361, 3.0], + [0.0, 0.0, 1.0, 2.0, 2.0], + [0.0, 1.0, 1.4142, 1.4142, 1.0], + [0.0, 1.0, 1.4142, 1.0, 0.0], + [0.0, 1.0, 1.0, 0.0, 0.0], + ] + ] + ), + ], + [ # Example 4D input to test channel-wise CuPy + np.array( + [[[[0, 1, 1, 1, 1], [0, 0, 1, 1, 1], [0, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1, 1, 0, 0]]]], dtype=np.float32 + ), + np.array( + [ + [ + [ + [0.0, 1.0, 1.4142, 2.2361, 3.0], + [0.0, 0.0, 1.0, 2.0, 2.0], + [0.0, 1.0, 1.4142, 1.4142, 1.0], + [0.0, 1.0, 1.4142, 1.0, 0.0], + [0.0, 1.0, 1.0, 0.0, 0.0], + ] + ] + ] + ), + ], + [ + np.array( + [ + [ + [0.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 0.0, 0.0], + ], + [ + [0.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 0.0, 0.0], + ], + [ + [0.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 0.0, 0.0], + ], + ], + dtype=np.float32, + ), + np.array( + [ + [ + [0.0, 1.0, 1.4142135, 2.236068, 3.0], + [0.0, 0.0, 1.0, 2.0, 2.0], + [0.0, 1.0, 1.4142135, 1.4142135, 1.0], + [0.0, 1.0, 1.4142135, 1.0, 0.0], + [0.0, 1.0, 1.0, 0.0, 0.0], + ], + [ + [0.0, 1.0, 1.4142135, 2.236068, 3.0], + [0.0, 0.0, 1.0, 2.0, 2.0], + [0.0, 1.0, 1.4142135, 1.4142135, 1.0], + [0.0, 1.0, 1.4142135, 1.0, 0.0], + [0.0, 1.0, 1.0, 0.0, 0.0], + ], + [ + [0.0, 1.0, 1.4142135, 2.236068, 3.0], + [0.0, 0.0, 1.0, 2.0, 2.0], + [0.0, 1.0, 1.4142135, 1.4142135, 1.0], + [0.0, 1.0, 1.4142135, 1.0, 0.0], + [0.0, 1.0, 1.0, 0.0, 0.0], + ], + ], + dtype=np.float32, + ), + ], +] + +SAMPLING_TEST_CASES = [ + [ + 2, + np.array( + ([[0, 1, 1, 1, 1], [0, 0, 1, 1, 1], [0, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1, 1, 0, 0]],), dtype=np.float32 + ), + np.array( + [ + [ + [0.0, 2.0, 2.828427, 4.472136, 6.0], + [0.0, 0.0, 2.0, 4.0, 4.0], + [0.0, 2.0, 2.828427, 2.828427, 2.0], + [0.0, 2.0, 2.828427, 2.0, 0.0], + [0.0, 2.0, 2.0, 0.0, 0.0], + ] + ] + ), + ] +] + +RAISES_TEST_CASES = ( + [ # Example 4D input. Should raise under CuPy + np.array( + [[[[[0, 1, 1, 1, 1], [0, 0, 1, 1, 1], [0, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1, 1, 0, 0]]]]], + dtype=np.float32, + ) + ], +) + + +class TestDistanceTransformEDT(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_scipy_transform(self, input, expected_output): + transform = DistanceTransformEDT() + output = transform(input) + assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False) + + @parameterized.expand(TEST_CASES) + def test_scipy_transformd(self, input, expected_output): + transform = DistanceTransformEDTd(keys=("to_transform",)) + data = {"to_transform": input} + data_ = transform(data) + output = data_["to_transform"] + assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False) + + @parameterized.expand(SAMPLING_TEST_CASES) + def test_scipy_sampling(self, sampling, input, expected_output): + transform = DistanceTransformEDT(sampling=sampling) + output = transform(input) + assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False) + + @parameterized.expand(TEST_CASES) + @skip_if_no_cuda + @unittest.skipUnless(HAS_CUPY, "CuPy is required.") + @unittest.skipUnless(momorphology, "cuCIM transforms are required.") + def test_cucim_transform(self, input, expected_output): + input_ = torch.tensor(input, device="cuda") + transform = DistanceTransformEDT() + output = transform(input_) + assert_allclose(cp.asnumpy(output), cp.asnumpy(expected_output), atol=1e-4, rtol=1e-4, type_test=False) + + @parameterized.expand(SAMPLING_TEST_CASES) + @skip_if_no_cuda + @unittest.skipUnless(HAS_CUPY, "CuPy is required.") + @unittest.skipUnless(momorphology, "cuCIM transforms are required.") + def test_cucim_sampling(self, sampling, input, expected_output): + input_ = torch.tensor(input, device="cuda") + transform = DistanceTransformEDT(sampling=sampling) + output = transform(input_) + assert_allclose(cp.asnumpy(output), cp.asnumpy(expected_output), atol=1e-4, rtol=1e-4, type_test=False) + + @parameterized.expand(RAISES_TEST_CASES) + @skip_if_no_cuda + @unittest.skipUnless(HAS_CUPY, "CuPy is required.") + @unittest.skipUnless(momorphology, "cuCIM transforms are required.") + def test_cucim_raises(self, raises): + """Currently images of shape a certain shape are supported. This test checks for the according error message""" + input_ = torch.tensor(raises, device="cuda") + transform = DistanceTransformEDT() + with self.assertRaises(RuntimeError): + transform(input_) + + +if __name__ == "__main__": + unittest.main()