diff --git a/docs/source/apps.rst b/docs/source/apps.rst index e4bdf301a4..b4cc200f08 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -130,6 +130,30 @@ Applications .. autoclass:: TileOnGridd :members: +.. automodule:: monai.apps.pathology.transforms.post.array +.. autoclass:: Watershed + :members: +.. autoclass:: GenerateWatershedMask + :members: +.. autoclass:: GenerateInstanceBorder + :members: +.. autoclass:: GenerateDistanceMap + :members: +.. autoclass:: GenerateWatershedMarkers + :members: + +.. automodule:: monai.apps.pathology.transforms.post.dictionary +.. autoclass:: Watershedd + :members: +.. autoclass:: GenerateWatershedMaskd + :members: +.. autoclass:: GenerateInstanceBorderd + :members: +.. autoclass:: GenerateDistanceMapd + :members: +.. autoclass:: GenerateWatershedMarkersd + :members: + `Detection` ----------- diff --git a/monai/apps/pathology/transforms/__init__.py b/monai/apps/pathology/transforms/__init__.py index 290c0ba6a8..616cf3220a 100644 --- a/monai/apps/pathology/transforms/__init__.py +++ b/monai/apps/pathology/transforms/__init__.py @@ -9,6 +9,30 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .post.array import ( + GenerateDistanceMap, + GenerateInstanceBorder, + GenerateWatershedMarkers, + GenerateWatershedMask, + Watershed, +) +from .post.dictionary import ( + GenerateDistanceMapD, + GenerateDistanceMapd, + GenerateDistanceMapDict, + GenerateInstanceBorderD, + GenerateInstanceBorderd, + GenerateInstanceBorderDict, + GenerateWatershedMarkersD, + GenerateWatershedMarkersd, + GenerateWatershedMarkersDict, + GenerateWatershedMaskD, + GenerateWatershedMaskd, + GenerateWatershedMaskDict, + WatershedD, + Watershedd, + WatershedDict, +) from .spatial.array import SplitOnGrid, TileOnGrid from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict from .stain.array import ExtractHEStains, NormalizeHEStains diff --git a/monai/apps/pathology/transforms/post/__init__.py b/monai/apps/pathology/transforms/post/__init__.py new file mode 100644 index 0000000000..46e1968367 --- /dev/null +++ b/monai/apps/pathology/transforms/post/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .array import ( + GenerateDistanceMap, + GenerateInstanceBorder, + GenerateWatershedMarkers, + GenerateWatershedMask, + Watershed, +) +from .dictionary import ( + GenerateDistanceMapD, + GenerateDistanceMapd, + GenerateDistanceMapDict, + GenerateInstanceBorderD, + GenerateInstanceBorderd, + GenerateInstanceBorderDict, + GenerateWatershedMarkersD, + GenerateWatershedMarkersd, + GenerateWatershedMarkersDict, + GenerateWatershedMaskD, + GenerateWatershedMaskd, + GenerateWatershedMaskDict, + WatershedD, + Watershedd, + WatershedDict, +) diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py new file mode 100644 index 0000000000..8cef96a2c7 --- /dev/null +++ b/monai/apps/pathology/transforms/post/array.py @@ -0,0 +1,322 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional + +import numpy as np + +from monai.config.type_definitions import DtypeLike, NdarrayOrTensor +from monai.transforms.post.array import Activations, AsDiscrete, RemoveSmallObjects, SobelGradients +from monai.transforms.transform import Transform +from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min +from monai.utils import TransformBackends, convert_to_numpy, optional_import +from monai.utils.type_conversion import convert_to_dst_type + +label, _ = optional_import("scipy.ndimage.measurements", name="label") +disk, _ = optional_import("skimage.morphology", name="disk") +opening, _ = optional_import("skimage.morphology", name="opening") +watershed, _ = optional_import("skimage.segmentation", name="watershed") + +__all__ = [ + "Watershed", + "GenerateWatershedMask", + "GenerateInstanceBorder", + "GenerateDistanceMap", + "GenerateWatershedMarkers", +] + + +class Watershed(Transform): + """ + Use `skimage.segmentation.watershed` to get instance segmentation results from images. + See: https://scikit-image.org/docs/stable/api/skimage.segmentation.html#skimage.segmentation.watershed. + + Args: + connectivity: An array with the same number of dimensions as image whose non-zero elements indicate + neighbors for connection. Following the scipy convention, default is a one-connected array of + the dimension of the image. + dtype: target data content type to convert, default is np.uint8. + + """ + + backend = [TransformBackends.NUMPY] + + def __init__(self, connectivity: Optional[int] = 1, dtype: DtypeLike = np.uint8) -> None: + self.connectivity = connectivity + self.dtype = dtype + + def __call__( # type: ignore + self, image: NdarrayOrTensor, mask: Optional[NdarrayOrTensor] = None, markers: Optional[NdarrayOrTensor] = None + ) -> NdarrayOrTensor: + """ + Args: + image: image where the lowest value points are labeled first. Shape must be [1, H, W, [D]]. + mask: optional, the same shape as image. Only points at which mask == True will be labeled. + If None (no mask given), it is a volume of all 1s. + markers: optional, the same shape as image. The desired number of markers, or an array marking + the basins with the values to be assigned in the label matrix. Zero means not a marker. + If None (no markers given), the local minima of the image are used as markers. + """ + + image = convert_to_numpy(image) + markers = convert_to_numpy(markers) + mask = convert_to_numpy(mask) + + instance_seg = watershed(image, markers=markers, mask=mask, connectivity=self.connectivity) + + return convert_to_dst_type(instance_seg, image, dtype=self.dtype)[0] + + +class GenerateWatershedMask(Transform): + """ + generate mask used in `watershed`. Only points at which mask == True will be labeled. + + Args: + softmax: if True, apply a softmax function to the prediction. + sigmoid: if True, apply a sigmoid function to the prediction. + threshold: if not None, threshold the float values to int number 0 or 1 with specified theashold. + remove_small_objects: whether need to remove some objects in the marker. Defaults to True. + min_size: objects smaller than this size are removed if `remove_small_objects` is True. Defaults to 10. + dtype: target data content type to convert, default is np.uint8. + + """ + + backend = [TransformBackends.NUMPY] + + def __init__( + self, + softmax: bool = True, + sigmoid: bool = False, + threshold: Optional[float] = None, + remove_small_objects: bool = True, + min_size: int = 10, + dtype: DtypeLike = np.uint8, + ) -> None: + if sigmoid and threshold is None: + raise ValueError("Threshold is needed when using sigmoid activation.") + + self.dtype = dtype + self.activations = Activations(sigmoid=sigmoid, softmax=softmax) + self.asdiscrete = AsDiscrete(threshold=threshold, argmax=softmax) + if remove_small_objects: + self.remove_small_objects = RemoveSmallObjects(min_size=min_size) + else: + self.remove_small_objects = None # type: ignore + + def __call__(self, prob_map: NdarrayOrTensor) -> NdarrayOrTensor: + """ + Args: + prob_map: probability map of segmentation, shape must be [C, H, W, [D]] + """ + + pred = self.activations(prob_map) + pred = self.asdiscrete(pred) + + pred = convert_to_numpy(pred) + + pred = label(pred)[0] + if self.remove_small_objects: + pred = self.remove_small_objects(pred) + pred[pred > 0] = 1 # type: ignore + + return convert_to_dst_type(pred, prob_map, dtype=self.dtype)[0] + + +class GenerateInstanceBorder(Transform): + """ + Generate instance border by hover map. The more parts of the image that cannot be identified as foreground areas, + the larger the grey scale value. The grey value of the instance's border will be larger. + + Args: + kernel_size: the size of the Sobel kernel. Defaults to 21. + min_size: objects smaller than this size are removed if `remove_small_objects` is True. Defaults to 10. + remove_small_objects: whether need to remove some objects in segmentation results. Defaults to True. + dtype: target data content type to convert, default is np.float32. + + + Raises: + ValueError: when the `mask` shape is not [1, H, W]. + ValueError: when the `hover_map` shape is not [2, H, W]. + + """ + + backend = [TransformBackends.NUMPY] + + def __init__( + self, + kernel_size: int = 21, + min_size: int = 10, + remove_small_objects: bool = True, + dtype: DtypeLike = np.float32, + ) -> None: + + self.dtype = dtype + + self.sobel_gradient = SobelGradients(kernel_size=kernel_size) + if remove_small_objects: + self.remove_small_objects = RemoveSmallObjects(min_size=min_size) + else: + self.remove_small_objects = None # type: ignore + + def __call__(self, mask: NdarrayOrTensor, hover_map: NdarrayOrTensor) -> NdarrayOrTensor: # type: ignore + """ + Args: + mask: binarized segmentation result. Shape must be [1, H, W]. + hover_map: horizontal and vertical distances of nuclear pixels to their centres of mass. Shape must be [2, H, W]. + The first and second channel represent the horizontal and vertical maps respectively. For more details refer + to papers: https://arxiv.org/abs/1812.06499. + + Return: + Instance border map. + + Raises: + ValueError: when the `hover_map` has only one value. + ValueError: when the `sobel gradient map` has only one value. + + """ + if len(mask.shape) != 3 or len(hover_map.shape) != 3: + raise ValueError( + f"Suppose the mask and hover map should be with shape of [C, H, W], but got {mask.shape}, {hover_map.shape}" + ) + if mask.shape[0] != 1: + raise ValueError(f"Suppose the mask only has one channel, but got {mask.shape[0]}") + if hover_map.shape[0] != 2: + raise ValueError(f"Suppose the hover map only has two channels, but got {hover_map.shape[0]}") + + hover_h = hover_map[0:1, ...] + hover_v = hover_map[1:2, ...] + + hover_h_min, hover_h_max = min(hover_h), max(hover_h) + hover_v_min, hover_v_max = min(hover_v), max(hover_v) + if (hover_h_max - hover_h_min) == 0 or (hover_v_max - hover_v_min) == 0: + raise ValueError("Not a valid hover map, please check your input") + hover_h = (hover_h - hover_h_min) / (hover_h_max - hover_h_min) + hover_v = (hover_v - hover_v_min) / (hover_v_max - hover_v_min) + sobelh = self.sobel_gradient(hover_h)[0, ...] + sobelv = self.sobel_gradient(hover_v)[1, ...] + sobelh_min, sobelh_max = min(sobelh), max(sobelh) + sobelv_min, sobelv_max = min(sobelv), max(sobelv) + if (sobelh_max - sobelh_min) == 0 or (sobelv_max - sobelv_min) == 0: + raise ValueError("Not a valid sobel gradient map") + sobelh = 1 - (sobelh - sobelh_min) / (sobelh_max - sobelh_min) + sobelv = 1 - (sobelv - sobelv_min) / (sobelv_max - sobelv_min) + + # combine the h & v values using max + overall = maximum(sobelh, sobelv) + overall = overall - (1 - mask) + overall[overall < 0] = 0 + + return convert_to_dst_type(overall, mask, dtype=self.dtype)[0] + + +class GenerateDistanceMap(Transform): + """ + Generate distance map. + In general, the instance map is calculated from the distance to the background. + Here, we use 1 - "instance border map" to generate the distance map. + Nuclei values form mountains so inverse to get basins. + + Args: + smooth_fn: execute smooth function on distance map. Defaults to None. You can specify + callable functions for smoothing. + For example, if you want apply gaussian smooth, you can specify `smooth_fn = GaussianSmooth()` + dtype: target data content type to convert, default is np.float32. + """ + + backend = [TransformBackends.NUMPY] + + def __init__(self, smooth_fn: Optional[Callable] = None, dtype: DtypeLike = np.float32) -> None: + self.smooth_fn = smooth_fn + self.dtype = dtype + + def __call__(self, mask: NdarrayOrTensor, instance_border: NdarrayOrTensor) -> NdarrayOrTensor: # type: ignore + """ + Args: + mask: binarized segmentation result. Shape must be [1, H, W]. + instance_border: foreground probability map. Shape must be [1, H, W]. + """ + if mask.shape[0] != 1 or mask.ndim != 3: + raise ValueError(f"Input mask should be with size of [1, H, W], but got {mask.shape}") + if instance_border.shape[0] != 1 or instance_border.ndim != 3: + raise ValueError(f"Input instance_border should be with size of [1, H, W], but got {instance_border.shape}") + + distance_map = (1.0 - instance_border) * mask + + if callable(self.smooth_fn): + distance_map = self.smooth_fn(distance_map) + + return convert_to_dst_type(-distance_map, mask, dtype=self.dtype)[0] + + +class GenerateWatershedMarkers(Transform): + """ + Generate markers to be used in `watershed`. The watershed algorithm treats pixels values as a local topography + (elevation). The algorithm floods basins from the markers until basins attributed to different markers meet on + watershed lines. Generally, markers are chosen as local minima of the image, from which basins are flooded. + Here is the implementation from HoVerNet papar. + For more details refer to papers: https://arxiv.org/abs/1812.06499. + + Args: + threshold: threshold the float values of foreground probability map to int 0 or 1 with specified theashold. + It turns uncertain area to 1 and other area to 0. Defaults to 0.4. + radius: the radius of the disk-shaped footprint used in `opening`. Defaults to 2. + min_size: objects smaller than this size are removed if `remove_small_objects` is True. Defaults to 10. + remove_small_objects: whether need to remove some objects in the marker. Defaults to True. + postprocess_fn: execute additional post transformation on marker. Defaults to None. + dtype: target data content type to convert, default is np.uint8. + + """ + + backend = [TransformBackends.NUMPY] + + def __init__( + self, + threshold: float = 0.4, + radius: int = 2, + min_size: int = 10, + remove_small_objects: bool = True, + postprocess_fn: Optional[Callable] = None, + dtype: DtypeLike = np.uint8, + ) -> None: + self.threshold = threshold + self.radius = radius + self.postprocess_fn = postprocess_fn + self.dtype = dtype + + if remove_small_objects: + self.remove_small_objects = RemoveSmallObjects(min_size=min_size) + + def __call__(self, mask: NdarrayOrTensor, instance_border: NdarrayOrTensor) -> NdarrayOrTensor: # type: ignore + """ + Args: + mask: binarized segmentation result. Shape must be [1, H, W]. + instance_border: instance border map. Shape must be [1, H, W]. + """ + if mask.shape[0] != 1 or mask.ndim != 3: + raise ValueError(f"Input mask should be with size of [1, H, W], but got {mask.shape}") + if instance_border.shape[0] != 1 or instance_border.ndim != 3: + raise ValueError(f"Input instance_border should be with size of [1, H, W], but got {instance_border.shape}") + + instance_border = instance_border >= self.threshold # uncertain area + + marker = mask - convert_to_dst_type(instance_border, mask, np.uint8)[0] # certain foreground + marker[marker < 0] = 0 # type: ignore + if self.postprocess_fn: + marker = self.postprocess_fn(marker) + + marker = convert_to_numpy(marker) + + marker = opening(marker.squeeze(), disk(self.radius)) + marker = label(marker)[0] + if self.remove_small_objects: + marker = self.remove_small_objects(marker[None]) + + return convert_to_dst_type(marker, mask, dtype=self.dtype)[0] diff --git a/monai/apps/pathology/transforms/post/dictionary.py b/monai/apps/pathology/transforms/post/dictionary.py new file mode 100644 index 0000000000..3eab526ee7 --- /dev/null +++ b/monai/apps/pathology/transforms/post/dictionary.py @@ -0,0 +1,302 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, Hashable, Mapping, Optional + +import numpy as np + +from monai.apps.pathology.transforms.post.array import ( + GenerateDistanceMap, + GenerateInstanceBorder, + GenerateWatershedMarkers, + GenerateWatershedMask, + Watershed, +) +from monai.config.type_definitions import DtypeLike, KeysCollection, NdarrayOrTensor +from monai.transforms.transform import MapTransform + +__all__ = [ + "WatershedD", + "WatershedDict", + "Watershedd", + "GenerateWatershedMaskD", + "GenerateWatershedMaskDict", + "GenerateWatershedMaskd", + "GenerateInstanceBorderD", + "GenerateInstanceBorderDict", + "GenerateInstanceBorderd", + "GenerateDistanceMapD", + "GenerateDistanceMapDict", + "GenerateDistanceMapd", + "GenerateWatershedMarkersD", + "GenerateWatershedMarkersDict", + "GenerateWatershedMarkersd", +] + + +class Watershedd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.array.Watershed`. + Use `skimage.segmentation.watershed` to get instance segmentation results from images. + See: https://scikit-image.org/docs/stable/api/skimage.segmentation.html#skimage.segmentation.watershed. + + Args: + keys: keys of the corresponding items to be transformed. + See also: monai.transforms.MapTransform + mask_key: keys of mask used in watershed. Only points at which mask == True will be labeled. + markers_key: keys of markers used in watershed. If None (no markers given), the local minima of the image are + used as markers. + connectivity: An array with the same number of dimensions as image whose non-zero elements indicate neighbors + for connection. Following the scipy convention, default is a one-connected array of the dimension of the + image. + dtype: target data content type to convert. Defaults to np.uint8. + allow_missing_keys: don't raise exception if key is missing. + + Raises: + ValueError: when the `image` shape is not [1, H, W]. + ValueError: when the `mask` shape is not [1, H, W]. + + """ + + backend = Watershed.backend + + def __init__( + self, + keys: KeysCollection, + mask_key: Optional[str] = "mask", + markers_key: Optional[str] = None, + connectivity: Optional[int] = 1, + dtype: DtypeLike = np.uint8, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.mask_key = mask_key + self.markers_key = markers_key + self.transform = Watershed(connectivity=connectivity, dtype=dtype) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + markers = d[self.markers_key] if self.markers_key else None + mask = d[self.mask_key] if self.mask_key else None + + for key in self.key_iterator(d): + d[key] = self.transform(d[key], mask, markers) + + return d + + +class GenerateWatershedMaskd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.array.GenerateWatershedMask`. + + Args: + keys: keys of the corresponding items to be transformed. + mask_key: the mask will be written to the value of `{mask_key}`. + softmax: if True, apply a softmax function to the prediction. + sigmoid: if True, apply a sigmoid function to the prediction. + threshold: if not None, threshold the float values to int number 0 or 1 with specified theashold. + remove_small_objects: whether need to remove some objects in the marker. Defaults to True. + min_size: objects smaller than this size are removed if `remove_small_objects` is True. Defaults to 10. + dtype: target data content type to convert. Defaults to np.uint8. + allow_missing_keys: don't raise exception if key is missing. + + """ + + backend = GenerateWatershedMask.backend + + def __init__( + self, + keys: KeysCollection, + mask_key: str = "mask", + softmax: bool = True, + sigmoid: bool = False, + threshold: Optional[float] = None, + remove_small_objects: bool = True, + min_size: int = 10, + dtype: DtypeLike = np.uint8, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.mask_key = mask_key + self.transform = GenerateWatershedMask( + softmax=softmax, + sigmoid=sigmoid, + threshold=threshold, + remove_small_objects=remove_small_objects, + min_size=min_size, + dtype=dtype, + ) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + mask = self.transform(d[key]) + key_to_add = f"{self.mask_key}" + if key_to_add in d: + raise KeyError(f"Mask with key {key_to_add} already exists.") + d[key_to_add] = mask + return d + + +class GenerateInstanceBorderd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.array.GenerateInstanceBorder`. + + Args: + keys: keys of the corresponding items to be transformed. + hover_map_key: keys of hover map used to generate probability map. + border_key: the instance border map will be written to the value of `{border_key}`. + kernel_size: the size of the Sobel kernel. Defaults to 21. + min_size: objects smaller than this size are removed if `remove_small_objects` is True. Defaults to 10. + remove_small_objects: whether need to remove some objects in segmentation results. Defaults to True. + dtype: target data content type to convert, default is np.float32. + allow_missing_keys: don't raise exception if key is missing. + + Raises: + ValueError: when the `hover_map` has only one value. + ValueError: when the `sobel gradient map` has only one value. + + """ + + backend = GenerateInstanceBorder.backend + + def __init__( + self, + keys: KeysCollection, + hover_map_key: str = "hover_map", + border_key: str = "border", + kernel_size: int = 21, + min_size: int = 10, + remove_small_objects: bool = True, + dtype: DtypeLike = np.float32, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.hover_map_key = hover_map_key + self.border_key = border_key + self.transform = GenerateInstanceBorder( + kernel_size=kernel_size, remove_small_objects=remove_small_objects, min_size=min_size, dtype=dtype + ) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + instance_border = self.transform(d[key], d[self.hover_map_key]) + key_to_add = f"{self.border_key}" + if key_to_add in d: + raise KeyError(f"Instance border map with key {key_to_add} already exists.") + d[key_to_add] = instance_border + return d + + +class GenerateDistanceMapd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.array.GenerateDistanceMap`. + + Args: + keys: keys of the corresponding items to be transformed. + border_key: keys of the instance border map used to generate distance map. + dist_key: the distance map will be written to the value of `{dist_key}`. + smooth_fn: execute smooth function on distance map. Defaults to None. You can specify + callable functions for smoothing. + For example, if you want apply gaussian smooth, you can specify `smooth_fn = GaussianSmooth()` + dtype: target data content type to convert, default is np.float32. + allow_missing_keys: don't raise exception if key is missing. + """ + + backend = GenerateDistanceMap.backend + + def __init__( + self, + keys: KeysCollection, + border_key: str = "border", + dist_key: str = "dist", + smooth_fn: Optional[Callable] = None, + dtype: DtypeLike = np.float32, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.border_key = border_key + self.dist_key = dist_key + self.transform = GenerateDistanceMap(smooth_fn=smooth_fn, dtype=dtype) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + distance_map = self.transform(d[key], d[self.border_key]) + key_to_add = f"{self.dist_key}" + if key_to_add in d: + raise KeyError(f"Distance map with key {key_to_add} already exists.") + d[key_to_add] = distance_map + return d + + +class GenerateWatershedMarkersd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.array.GenerateWatershedMarkers`. + + Args: + keys: keys of the corresponding items to be transformed. + border_key: keys of the instance border map used to generate markers. + markers_key: the markers will be written to the value of `{markers_key}`. + threshold: threshold the float values of instance border map to int 0 or 1 with specified theashold. + It turns uncertain area to 1 and other area to 0. Defaults to 0.4. + radius: the radius of the disk-shaped footprint used in `opening`. Defaults to 2. + min_size: objects smaller than this size are removed if `remove_small_objects` is True. Defaults to 10. + remove_small_objects: whether need to remove some objects in the marker. Defaults to True. + postprocess_fn: execute additional post transformation on marker. Defaults to None. + dtype: target data content type to convert, default is np.uint8. + allow_missing_keys: don't raise exception if key is missing. + """ + + backend = GenerateWatershedMarkers.backend + + def __init__( + self, + keys: KeysCollection, + border_key: str = "border", + markers_key: str = "markers", + threshold: float = 0.4, + radius: int = 2, + min_size: int = 10, + remove_small_objects: bool = True, + postprocess_fn: Optional[Callable] = None, + dtype: DtypeLike = np.uint8, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.border_key = border_key + self.markers_key = markers_key + self.transform = GenerateWatershedMarkers( + threshold=threshold, + radius=radius, + min_size=min_size, + remove_small_objects=remove_small_objects, + postprocess_fn=postprocess_fn, + dtype=dtype, + ) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + markers = self.transform(d[key], d[self.border_key]) + key_to_add = f"{self.markers_key}" + if key_to_add in d: + raise KeyError(f"Markers with key {key_to_add} already exists.") + d[key_to_add] = markers + return d + + +WatershedD = WatershedDict = Watershedd +GenerateWatershedMaskD = GenerateWatershedMaskDict = GenerateWatershedMaskd +GenerateInstanceBorderD = GenerateInstanceBorderDict = GenerateInstanceBorderd +GenerateDistanceMapD = GenerateDistanceMapDict = GenerateDistanceMapd +GenerateWatershedMarkersD = GenerateWatershedMarkersDict = GenerateWatershedMarkersd diff --git a/tests/test_generate_distance_map.py b/tests/test_generate_distance_map.py new file mode 100644 index 0000000000..0be252dbf8 --- /dev/null +++ b/tests/test_generate_distance_map.py @@ -0,0 +1,51 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.array import GenerateDistanceMap +from monai.transforms.intensity.array import GaussianSmooth +from tests.utils import TEST_NDARRAYS + +EXCEPTION_TESTS = [] +TESTS = [] + +np.random.RandomState(123) + +for p in TEST_NDARRAYS: + EXCEPTION_TESTS.append([{}, p(np.random.rand(2, 5, 5)), p(np.random.rand(1, 5, 5)), ValueError]) + + EXCEPTION_TESTS.append([{}, p(np.random.rand(1, 5, 5)), p(np.random.rand(2, 5, 5)), ValueError]) + +for p in TEST_NDARRAYS: + TESTS.append([{}, p(np.random.rand(1, 5, 5)), p(np.random.rand(1, 5, 5)), (1, 5, 5)]) + TESTS.append( + [{"smooth_fn": GaussianSmooth(sigma=0.4)}, p(np.random.rand(1, 5, 5)), p(np.random.rand(1, 5, 5)), (1, 5, 5)] + ) + + +class TestGenerateDistanceMap(unittest.TestCase): + @parameterized.expand(EXCEPTION_TESTS) + def test_value(self, argments, mask, probmap, exception_type): + with self.assertRaises(exception_type): + GenerateDistanceMap(**argments)(mask, probmap) + + @parameterized.expand(TESTS) + def test_value2(self, argments, mask, probmap, expected_shape): + result = GenerateDistanceMap(**argments)(mask, probmap) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generate_distance_mapd.py b/tests/test_generate_distance_mapd.py new file mode 100644 index 0000000000..fb6e59f36b --- /dev/null +++ b/tests/test_generate_distance_mapd.py @@ -0,0 +1,62 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.dictionary import GenerateDistanceMapd +from monai.transforms.intensity.array import GaussianSmooth +from tests.utils import TEST_NDARRAYS + +EXCEPTION_TESTS = [] +TESTS = [] + +np.random.RandomState(123) + +for p in TEST_NDARRAYS: + EXCEPTION_TESTS.append( + [{"keys": "mask", "border_key": "border"}, p(np.random.rand(2, 5, 5)), p(np.random.rand(1, 5, 5)), ValueError] + ) + + EXCEPTION_TESTS.append( + [{"keys": "mask", "border_key": "border"}, p(np.random.rand(1, 5, 5)), p(np.random.rand(2, 5, 5)), ValueError] + ) + +for p in TEST_NDARRAYS: + TESTS.append( + [{"keys": "mask", "border_key": "border"}, p(np.random.rand(1, 5, 5)), p(np.random.rand(1, 5, 5)), (1, 5, 5)] + ) + TESTS.append( + [ + {"keys": "mask", "border_key": "border", "smooth_fn": GaussianSmooth(sigma=0.4)}, + p(np.random.rand(1, 5, 5)), + p(np.random.rand(1, 5, 5)), + (1, 5, 5), + ] + ) + + +class TestGenerateDistanceMapd(unittest.TestCase): + @parameterized.expand(EXCEPTION_TESTS) + def test_value(self, argments, mask, border_map, exception_type): + with self.assertRaises(exception_type): + GenerateDistanceMapd(**argments)({"mask": mask, "border": border_map}) + + @parameterized.expand(TESTS) + def test_value2(self, argments, mask, border_map, expected_shape): + result = GenerateDistanceMapd(**argments)({"mask": mask, "border": border_map}) + self.assertEqual(result["dist"].shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generate_instance_border.py b/tests/test_generate_instance_border.py new file mode 100644 index 0000000000..ae265b5774 --- /dev/null +++ b/tests/test_generate_instance_border.py @@ -0,0 +1,86 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.array import GenerateInstanceBorder +from tests.utils import TEST_NDARRAYS + +EXCEPTION_TESTS = [] +TESTS = [] + +np.random.RandomState(123) + + +for p in TEST_NDARRAYS: + EXCEPTION_TESTS.append( + [ + {"kernel_size": 3, "remove_small_objects": False}, + p(np.random.rand(1, 5, 5, 5)), + p(np.random.rand(2, 5, 5)), + ValueError, + ] + ) + + EXCEPTION_TESTS.append( + [ + {"kernel_size": 3, "remove_small_objects": False}, + p(np.random.rand(1, 5, 5)), + p(np.random.rand(1, 5, 5)), + ValueError, + ] + ) + + EXCEPTION_TESTS.append( + [ + {"kernel_size": 3, "remove_small_objects": False}, + p(np.random.rand(2, 5, 5)), + p(np.random.rand(2, 5, 5)), + ValueError, + ] + ) + +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"kernel_size": 3, "remove_small_objects": False}, + p(np.random.rand(1, 5, 5)), + p(np.random.rand(2, 5, 5)), + (1, 5, 5), + ] + ) + TESTS.append( + [ + {"kernel_size": 3, "remove_small_objects": False}, + p(np.random.rand(1, 5, 5)), + p(np.random.rand(2, 5, 5)), + (1, 5, 5), + ] + ) + + +class TestGenerateInstanceBorder(unittest.TestCase): + @parameterized.expand(EXCEPTION_TESTS) + def test_value(self, argments, mask, hover_map, exception_type): + with self.assertRaises(exception_type): + GenerateInstanceBorder(**argments)(mask, hover_map) + + @parameterized.expand(TESTS) + def test_value2(self, argments, mask, hover_map, expected_shape): + result = GenerateInstanceBorder(**argments)(mask, hover_map) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generate_instance_borderd.py b/tests/test_generate_instance_borderd.py new file mode 100644 index 0000000000..f1139676c0 --- /dev/null +++ b/tests/test_generate_instance_borderd.py @@ -0,0 +1,86 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.dictionary import GenerateInstanceBorderd +from tests.utils import TEST_NDARRAYS + +EXCEPTION_TESTS = [] +TESTS = [] + +np.random.RandomState(123) + + +for p in TEST_NDARRAYS: + EXCEPTION_TESTS.append( + [ + {"keys": "mask", "kernel_size": 3, "remove_small_objects": True, "min_size": 10}, + p(np.random.rand(1, 5, 5, 5)), + p(np.random.rand(2, 5, 5)), + ValueError, + ] + ) + + EXCEPTION_TESTS.append( + [ + {"keys": "mask", "kernel_size": 3, "remove_small_objects": True, "min_size": 10}, + p(np.random.rand(1, 5, 5)), + p(np.random.rand(1, 5, 5)), + ValueError, + ] + ) + + EXCEPTION_TESTS.append( + [ + {"keys": "mask", "kernel_size": 3, "remove_small_objects": True, "min_size": 10}, + p(np.random.rand(2, 5, 5)), + p(np.random.rand(2, 5, 5)), + ValueError, + ] + ) + +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": "mask", "kernel_size": 3, "remove_small_objects": False, "min_size": 10}, + p(np.random.rand(1, 5, 5)), + p(np.random.rand(2, 5, 5)), + (1, 5, 5), + ] + ) + TESTS.append( + [ + {"keys": "mask", "kernel_size": 3, "remove_small_objects": True, "min_size": 10}, + p(np.random.rand(1, 5, 5)), + p(np.random.rand(2, 5, 5)), + (1, 5, 5), + ] + ) + + +class TestGenerateInstanceBorderd(unittest.TestCase): + @parameterized.expand(EXCEPTION_TESTS) + def test_value(self, argments, mask, hover_map, exception_type): + with self.assertRaises(exception_type): + GenerateInstanceBorderd(**argments)({"mask": mask, "hover_map": hover_map}) + + @parameterized.expand(TESTS) + def test_value2(self, argments, mask, hover_map, expected_shape): + result = GenerateInstanceBorderd(**argments)({"mask": mask, "hover_map": hover_map}) + self.assertEqual(result["border"].shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generate_watershed_markers.py b/tests/test_generate_watershed_markers.py new file mode 100644 index 0000000000..7b046686e9 --- /dev/null +++ b/tests/test_generate_watershed_markers.py @@ -0,0 +1,53 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.array import GenerateWatershedMarkers +from monai.utils import min_version, optional_import +from tests.utils import TEST_NDARRAYS + +_, has_skimage = optional_import("skimage", "0.19.3", min_version) +_, has_scipy = optional_import("scipy", "1.8.1", min_version) + +EXCEPTION_TESTS = [] +TESTS = [] + +np.random.RandomState(123) + +for p in TEST_NDARRAYS: + EXCEPTION_TESTS.append([{}, p(np.random.rand(2, 5, 5)), p(np.random.rand(1, 5, 5)), ValueError]) + + EXCEPTION_TESTS.append([{}, p(np.random.rand(1, 5, 5)), p(np.random.rand(2, 5, 5)), ValueError]) + +for p in TEST_NDARRAYS: + TESTS.append([{}, p(np.random.rand(1, 5, 5)), p(np.random.rand(1, 5, 5)), (1, 5, 5)]) + + +@unittest.skipUnless(has_skimage, "Requires scikit-image library.") +@unittest.skipUnless(has_scipy, "Requires scipy library.") +class TestGenerateWatershedMarkers(unittest.TestCase): + @parameterized.expand(EXCEPTION_TESTS) + def test_value(self, argments, mask, probmap, exception_type): + with self.assertRaises(exception_type): + GenerateWatershedMarkers(**argments)(mask, probmap) + + @parameterized.expand(TESTS) + def test_value2(self, argments, mask, probmap, expected_shape): + result = GenerateWatershedMarkers(**argments)(mask, probmap) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generate_watershed_markersd.py b/tests/test_generate_watershed_markersd.py new file mode 100644 index 0000000000..cccb20c985 --- /dev/null +++ b/tests/test_generate_watershed_markersd.py @@ -0,0 +1,68 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.dictionary import GenerateWatershedMarkersd +from monai.utils import min_version, optional_import +from tests.utils import TEST_NDARRAYS + +_, has_skimage = optional_import("skimage", "0.19.3", min_version) +_, has_scipy = optional_import("scipy", "1.8.1", min_version) + +EXCEPTION_TESTS = [] +TESTS = [] + +np.random.RandomState(123) + +for p in TEST_NDARRAYS: + EXCEPTION_TESTS.append( + [{"keys": "mask", "border_key": "border"}, p(np.random.rand(2, 5, 5)), p(np.random.rand(1, 5, 5)), ValueError] + ) + + EXCEPTION_TESTS.append( + [{"keys": "mask", "border_key": "border"}, p(np.random.rand(1, 5, 5)), p(np.random.rand(2, 5, 5)), ValueError] + ) + + EXCEPTION_TESTS.append( + [ + {"keys": "mask", "border_key": "border", "markers_key": "old_markers"}, + p(np.random.rand(1, 5, 5)), + p(np.random.rand(1, 5, 5)), + KeyError, + ] + ) + +for p in TEST_NDARRAYS: + TESTS.append( + [{"keys": "mask", "border_key": "border"}, p(np.random.rand(1, 5, 5)), p(np.random.rand(1, 5, 5)), (1, 5, 5)] + ) + + +@unittest.skipUnless(has_skimage, "Requires scikit-image library.") +@unittest.skipUnless(has_scipy, "Requires scipy library.") +class TestGenerateWatershedMarkersd(unittest.TestCase): + @parameterized.expand(EXCEPTION_TESTS) + def test_value(self, argments, mask, border_map, exception_type): + with self.assertRaises(exception_type): + GenerateWatershedMarkersd(**argments)({"mask": mask, "border": border_map, "old_markers": 1}) + + @parameterized.expand(TESTS) + def test_value2(self, argments, mask, border_map, expected_shape): + result = GenerateWatershedMarkersd(**argments)({"mask": mask, "border": border_map}) + self.assertEqual(result["markers"].shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generate_watershed_mask.py b/tests/test_generate_watershed_mask.py new file mode 100644 index 0000000000..d6ad491dc6 --- /dev/null +++ b/tests/test_generate_watershed_mask.py @@ -0,0 +1,82 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.array import GenerateWatershedMask +from monai.utils import min_version, optional_import +from tests.utils import TEST_NDARRAYS + +_, has_scipy = optional_import("scipy", "1.8.1", min_version) + +EXCEPTION_TESTS = [] +TESTS = [] + +np.random.RandomState(123) + + +for p in TEST_NDARRAYS: + EXCEPTION_TESTS.append( + [ + {"softmax": False, "sigmoid": True, "remove_small_objects": True, "min_size": 10}, + p(np.random.rand(1, 5, 5)), + ValueError, + ] + ) + +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"softmax": True, "sigmoid": False, "threshold": None, "remove_small_objects": False, "min_size": 10}, + p( + [ + [[0.5022, 0.3403, 0.9997], [0.8793, 0.5514, 0.2697], [0.6134, 0.6389, 0.0680]], + [[0.5000, 0.3400, 0.9900], [0.8900, 0.5600, 0.2700], [0.6100, 0.6300, 0.0600]], + ] + ), + (1, 3, 3), + [0, 1], + ] + ) + + TESTS.append( + [ + {"softmax": False, "sigmoid": True, "threshold": 0.5, "remove_small_objects": False, "min_size": 10}, + p([[[0.5022, 0.3403, 0.9997], [0.8793, 0.5514, 0.2697], [-0.1134, -0.0389, -0.0680]]]), + (1, 3, 3), + [0, 1], + ] + ) + + +@unittest.skipUnless(has_scipy, "Requires scipy library.") +class TestGenerateWatershedMask(unittest.TestCase): + @parameterized.expand(EXCEPTION_TESTS) + def test_value(self, argments, image, exception_type): + with self.assertRaises(exception_type): + GenerateWatershedMask(**argments)(image) + + @parameterized.expand(TESTS) + def test_value2(self, argments, image, expected_shape, expected_value): + result = GenerateWatershedMask(**argments)(image) + self.assertEqual(result.shape, expected_shape) + + if isinstance(result, torch.Tensor): + result = result.cpu().numpy() + self.assertEqual(np.unique(result).tolist(), expected_value) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generate_watershed_maskd.py b/tests/test_generate_watershed_maskd.py new file mode 100644 index 0000000000..5c9699d8fb --- /dev/null +++ b/tests/test_generate_watershed_maskd.py @@ -0,0 +1,98 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.dictionary import GenerateWatershedMaskd +from monai.utils import min_version, optional_import +from tests.utils import TEST_NDARRAYS + +_, has_scipy = optional_import("scipy", "1.8.1", min_version) + +EXCEPTION_TESTS = [] +TESTS = [] + +np.random.RandomState(123) + + +for p in TEST_NDARRAYS: + EXCEPTION_TESTS.append( + [ + {"keys": "img", "softmax": False, "sigmoid": True, "remove_small_objects": True, "min_size": 10}, + p(np.random.rand(1, 5, 5)), + ValueError, + ] + ) + +for p in TEST_NDARRAYS: + TESTS.append( + [ + { + "keys": "img", + "mask_key": "mask", + "softmax": True, + "sigmoid": False, + "threshold": None, + "remove_small_objects": False, + "min_size": 10, + }, + p( + [ + [[0.5022, 0.3403, 0.9997], [0.8793, 0.5514, 0.2697], [0.6134, 0.6389, 0.0680]], + [[0.5000, 0.3400, 0.9900], [0.8900, 0.5600, 0.2700], [0.6100, 0.6300, 0.0600]], + ] + ), + (1, 3, 3), + [0, 1], + ] + ) + + TESTS.append( + [ + { + "keys": "img", + "mask_key": "mask", + "softmax": False, + "sigmoid": True, + "threshold": 0.5, + "remove_small_objects": False, + "min_size": 10, + }, + p([[[0.5022, 0.3403, 0.9997], [0.8793, 0.5514, 0.2697], [-0.1134, -0.0389, -0.0680]]]), + (1, 3, 3), + [0, 1], + ] + ) + + +@unittest.skipUnless(has_scipy, "Requires scipy library.") +class TestGenerateWatershedMaskd(unittest.TestCase): + @parameterized.expand(EXCEPTION_TESTS) + def test_value(self, argments, image, exception_type): + with self.assertRaises(exception_type): + GenerateWatershedMaskd(**argments)({"img": image}) + + @parameterized.expand(TESTS) + def test_value2(self, argments, image, expected_shape, expected_value): + result = GenerateWatershedMaskd(**argments)({"img": image}) + self.assertEqual(result["mask"].shape, expected_shape) + + if isinstance(result["mask"], torch.Tensor): + result["mask"] = result["mask"].cpu().numpy() + self.assertEqual(np.unique(result["mask"]).tolist(), expected_value) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_watershed.py b/tests/test_watershed.py new file mode 100644 index 0000000000..705ddce817 --- /dev/null +++ b/tests/test_watershed.py @@ -0,0 +1,58 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.array import ( + GenerateDistanceMap, + GenerateInstanceBorder, + GenerateWatershedMarkers, + GenerateWatershedMask, + Watershed, +) +from monai.utils import min_version, optional_import +from tests.utils import TEST_NDARRAYS + +_, has_skimage = optional_import("skimage", "0.19.3", min_version) +_, has_scipy = optional_import("scipy", "1.8.1", min_version) + +np.random.RandomState(123) + +TESTS = [] +params = {"connectivity": 1} +for p in TEST_NDARRAYS: + image = p(np.random.rand(1, 10, 10)) + hover_map = p(np.random.rand(2, 10, 10)) + + TESTS.append([params, image, hover_map, (1, 10, 10)]) + + +@unittest.skipUnless(has_skimage, "Requires scikit-image library.") +@unittest.skipUnless(has_scipy, "Requires scipy library.") +class TestWatershed(unittest.TestCase): + @parameterized.expand(TESTS) + def test_output(self, args, image, hover_map, expected_shape): + mask = GenerateWatershedMask()(image) + border_map = GenerateInstanceBorder(kernel_size=3)(mask, hover_map) + distance_map = GenerateDistanceMap()(mask, border_map) + markers = GenerateWatershedMarkers()(mask, border_map) + + calculate_instance_seg = Watershed(**args) + output = calculate_instance_seg(distance_map, mask, markers) + + self.assertTupleEqual(output.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_watershedd.py b/tests/test_watershedd.py new file mode 100644 index 0000000000..6474759de4 --- /dev/null +++ b/tests/test_watershedd.py @@ -0,0 +1,68 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.dictionary import ( + GenerateDistanceMapd, + GenerateInstanceBorderd, + GenerateWatershedMarkersd, + GenerateWatershedMaskd, + Watershedd, +) +from monai.transforms import Compose +from monai.utils import min_version, optional_import +from tests.utils import TEST_NDARRAYS + +_, has_skimage = optional_import("skimage", "0.19.3", min_version) +_, has_scipy = optional_import("scipy", "1.8.1", min_version) + +TESTS = [] +params = {"keys": "dist", "mask_key": "mask", "markers_key": "markers", "connectivity": 1} +for p in TEST_NDARRAYS: + image = p(np.random.rand(1, 10, 10)) + hover_map = p(np.random.rand(2, 10, 10)) + + TESTS.append([params, image, hover_map, (1, 10, 10)]) + + params.update({"markers_key": None}) + TESTS.append([params, image, hover_map, (1, 10, 10)]) + + params.update({"mask_key": None, "markers_key": None}) + TESTS.append([params, image, hover_map, (1, 10, 10)]) + + +@unittest.skipUnless(has_skimage, "Requires scikit-image library.") +@unittest.skipUnless(has_scipy, "Requires scipy library.") +class TestWatershedd(unittest.TestCase): + @parameterized.expand(TESTS) + def test_output(self, args, image, hover_map, expected_shape): + data = {"output": image, "hover_map": hover_map} + + trans = Compose( + [ + GenerateWatershedMaskd(keys="output"), + GenerateInstanceBorderd(keys="mask", hover_map_key="hover_map", kernel_size=3), + GenerateDistanceMapd(keys="mask", border_key="border"), + GenerateWatershedMarkersd(keys="mask", border_key="border"), + Watershedd(**args), + ] + ) + + output = trans(data) + self.assertTupleEqual(output["dist"].shape, expected_shape) + + +if __name__ == "__main__": + unittest.main()