diff --git a/docs/source/apps.rst b/docs/source/apps.rst index b4cc200f08..248813d679 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -131,6 +131,14 @@ Applications :members: .. automodule:: monai.apps.pathology.transforms.post.array +.. autoclass:: GenerateSuccinctContour + :members: +.. autoclass:: GenerateInstanceContour + :members: +.. autoclass:: GenerateInstanceCentroid + :members: +.. autoclass:: GenerateInstanceType + :members: .. autoclass:: Watershed :members: .. autoclass:: GenerateWatershedMask @@ -143,6 +151,14 @@ Applications :members: .. automodule:: monai.apps.pathology.transforms.post.dictionary +.. autoclass:: GenerateSuccinctContourd + :members: +.. autoclass:: GenerateInstanceContourd + :members: +.. autoclass:: GenerateInstanceCentroidd + :members: +.. autoclass:: GenerateInstanceTyped + :members: .. autoclass:: Watershedd :members: .. autoclass:: GenerateWatershedMaskd diff --git a/monai/apps/pathology/transforms/__init__.py b/monai/apps/pathology/transforms/__init__.py index 616cf3220a..3e784b8ebf 100644 --- a/monai/apps/pathology/transforms/__init__.py +++ b/monai/apps/pathology/transforms/__init__.py @@ -12,6 +12,10 @@ from .post.array import ( GenerateDistanceMap, GenerateInstanceBorder, + GenerateInstanceCentroid, + GenerateInstanceContour, + GenerateInstanceType, + GenerateSuccinctContour, GenerateWatershedMarkers, GenerateWatershedMask, Watershed, @@ -23,6 +27,18 @@ GenerateInstanceBorderD, GenerateInstanceBorderd, GenerateInstanceBorderDict, + GenerateInstanceCentroidD, + GenerateInstanceCentroidd, + GenerateInstanceCentroidDict, + GenerateInstanceContourD, + GenerateInstanceContourd, + GenerateInstanceContourDict, + GenerateInstanceTypeD, + GenerateInstanceTyped, + GenerateInstanceTypeDict, + GenerateSuccinctContourD, + GenerateSuccinctContourd, + GenerateSuccinctContourDict, GenerateWatershedMarkersD, GenerateWatershedMarkersd, GenerateWatershedMarkersDict, diff --git a/monai/apps/pathology/transforms/post/__init__.py b/monai/apps/pathology/transforms/post/__init__.py index 46e1968367..3e6af77ce6 100644 --- a/monai/apps/pathology/transforms/post/__init__.py +++ b/monai/apps/pathology/transforms/post/__init__.py @@ -12,6 +12,10 @@ from .array import ( GenerateDistanceMap, GenerateInstanceBorder, + GenerateInstanceCentroid, + GenerateInstanceContour, + GenerateInstanceType, + GenerateSuccinctContour, GenerateWatershedMarkers, GenerateWatershedMask, Watershed, @@ -23,6 +27,18 @@ GenerateInstanceBorderD, GenerateInstanceBorderd, GenerateInstanceBorderDict, + GenerateInstanceCentroidD, + GenerateInstanceCentroidd, + GenerateInstanceCentroidDict, + GenerateInstanceContourD, + GenerateInstanceContourd, + GenerateInstanceContourDict, + GenerateInstanceTypeD, + GenerateInstanceTyped, + GenerateInstanceTypeDict, + GenerateSuccinctContourD, + GenerateSuccinctContourd, + GenerateSuccinctContourDict, GenerateWatershedMarkersD, GenerateWatershedMarkersd, GenerateWatershedMarkersDict, diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 2f84e96257..55ff531172 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -9,21 +9,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional +from typing import Callable, List, Optional, Sequence, Tuple, Union import numpy as np from monai.config.type_definitions import DtypeLike, NdarrayOrTensor from monai.transforms.post.array import Activations, AsDiscrete, RemoveSmallObjects, SobelGradients from monai.transforms.transform import Transform -from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min +from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique from monai.utils import TransformBackends, convert_to_numpy, optional_import +from monai.utils.misc import ensure_tuple_rep from monai.utils.type_conversion import convert_to_dst_type label, _ = optional_import("scipy.ndimage.measurements", name="label") disk, _ = optional_import("skimage.morphology", name="disk") opening, _ = optional_import("skimage.morphology", name="opening") watershed, _ = optional_import("skimage.segmentation", name="watershed") +find_contours, _ = optional_import("skimage.measure", name="find_contours") +centroid, _ = optional_import("skimage.measure", name="centroid") __all__ = [ "Watershed", @@ -31,6 +34,10 @@ "GenerateInstanceBorder", "GenerateDistanceMap", "GenerateWatershedMarkers", + "GenerateSuccinctContour", + "GenerateInstanceContour", + "GenerateInstanceCentroid", + "GenerateInstanceType", ] @@ -320,3 +327,300 @@ def __call__(self, mask: NdarrayOrTensor, instance_border: NdarrayOrTensor) -> N marker = self.remove_small_objects(marker[None]) return convert_to_dst_type(marker, mask, dtype=self.dtype)[0] + + +class GenerateSuccinctContour(Transform): + """ + Converts Scipy-style contours(generated by skimage.measure.find_contours) to a more succinct version which only includes + the pixels to which lines need to be drawn (i.e. not the intervening pixels along each line). + + Args: + height: height of bounding box, used to detect direction of line segment. + width: width of bounding box, used to detect direction of line segment. + + Returns: + the pixels that need to be joined by straight lines to describe the outmost pixels of the foreground similar to + OpenCV's cv.CHAIN_APPROX_SIMPLE (counterclockwise) + """ + + def __init__(self, height: int, width: int) -> None: + self.height = height + self.width = width + + def _generate_contour_coord(self, current: np.ndarray, previous: np.ndarray) -> Tuple[int, int]: + """ + Generate contour coordinates. Given the previous and current coordinates of border positions, + returns the int pixel that marks the extremity of the segmented pixels. + + Args: + current: coordinates of the current border position. + previous: coordinates of the previous border position. + """ + + p_delta = (current[0] - previous[0], current[1] - previous[1]) + + if p_delta == (0.0, 1.0) or p_delta == (0.5, 0.5) or p_delta == (1.0, 0.0): + row = int(current[0] + 0.5) + col = int(current[1]) + elif p_delta == (0.0, -1.0) or p_delta == (0.5, -0.5): + row = int(current[0]) + col = int(current[1]) + elif p_delta == (-1, 0.0) or p_delta == (-0.5, -0.5): + row = int(current[0]) + col = int(current[1] + 0.5) + elif p_delta == (-0.5, 0.5): + row = int(current[0] + 0.5) + col = int(current[1] + 0.5) + + return row, col + + def _calculate_distance_from_topleft(self, sequence: Sequence[Tuple[int, int]]) -> int: + """ + Each sequence of coordinates describes a boundary between foreground and background starting and ending at two sides + of the bounding box. To order the sequences correctly, we compute the distance from the topleft of the bounding box + around the perimeter in a clockwise direction. + + Args: + sequence: list of border points coordinates. + + Returns: + the distance round the perimeter of the bounding box from the top-left origin + """ + distance: int + first_coord = sequence[0] + if first_coord[0] == 0: + distance = first_coord[1] + elif first_coord[1] == self.width - 1: + distance = self.width + first_coord[0] + elif first_coord[0] == self.height - 1: + distance = 2 * self.width + self.height - first_coord[1] + else: + distance = 2 * (self.width + self.height) - first_coord[0] + + return distance + + def __call__(self, contours: List[np.ndarray]) -> np.ndarray: + """ + Args: + contours: list of (n, 2)-ndarrays, scipy-style clockwise line segments, with lines separating foreground/background. + Each contour is an ndarray of shape (n, 2), consisting of n (row, column) coordinates along the contour. + """ + pixels: List[Tuple[int, int]] = [] + sequences = [] + corners = [False, False, False, False] + + for group in contours: + sequence: List[Tuple[int, int]] = [] + last_added = None + prev = None + corner = -1 + + for i, coord in enumerate(group): + if i == 0: + # originating from the top, so must be heading south east + if coord[0] == 0.0: + corner = 1 + pixel = (0, int(coord[1] - 0.5)) + if pixel[1] == self.width - 1: + corners[1] = True + elif pixel[1] == 0.0: + corners[0] = True + # originating from the left, so must be heading north east + elif coord[1] == 0.0: + corner = 0 + pixel = (int(coord[0] + 0.5), 0) + # originating from the bottom, so must be heading north west + elif coord[0] == self.height - 1: + corner = 3 + pixel = (int(coord[0]), int(coord[1] + 0.5)) + if pixel[1] == self.width - 1: + corners[2] = True + # originating from the right, so must be heading south west + elif coord[1] == self.width - 1: + corner = 2 + pixel = (int(coord[0] - 0.5), int(coord[1])) + sequence.append(pixel) + last_added = pixel + elif i == len(group) - 1: + # add this point + pixel = self._generate_contour_coord(coord, prev) # type: ignore + if pixel != last_added: + sequence.append(pixel) + last_added = pixel + elif np.any(coord - prev != group[i + 1] - coord): + pixel = self._generate_contour_coord(coord, prev) # type: ignore + if pixel != last_added: + sequence.append(pixel) + last_added = pixel + + # flag whether each corner has been crossed + if i == len(group) - 1: + if corner == 0: + if coord[0] == 0: + corners[corner] = True + elif corner == 1: + if coord[1] == self.width - 1: + corners[corner] = True + elif corner == 2: + if coord[0] == self.height - 1: + corners[corner] = True + elif corner == 3: + if coord[1] == 0.0: + corners[corner] = True + + prev = coord + dist = self._calculate_distance_from_topleft(sequence) + + sequences.append({"distance": dist, "sequence": sequence}) + + # check whether we need to insert any missing corners + if corners[0] is False: + sequences.append({"distance": 0, "sequence": [(0, 0)]}) + if corners[1] is False: + sequences.append({"distance": self.width, "sequence": [(0, self.width - 1)]}) + if corners[2] is False: + sequences.append({"distance": self.width + self.height, "sequence": [(self.height - 1, self.width - 1)]}) + if corners[3] is False: + sequences.append({"distance": 2 * self.width + self.height, "sequence": [(self.height - 1, 0)]}) + + # join the sequences into a single contour + # starting at top left and rotating clockwise + sequences.sort(key=lambda x: x.get("distance")) # type: ignore + + last = (-1, -1) + for _sequence in sequences: + if _sequence["sequence"][0] == last: # type: ignore + pixels.pop() + if pixels: + pixels = [*pixels, *_sequence["sequence"]] # type: ignore + else: + pixels = _sequence["sequence"] # type: ignore + last = pixels[-1] + + if pixels[0] == last: + pixels.pop(0) + + if pixels[0] == (0, 0): + pixels.append(pixels.pop(0)) + + return np.flip(convert_to_numpy(pixels, dtype=np.int32)) # type: ignore + + +class GenerateInstanceContour(Transform): + """ + Generate contour for each instance in a 2D array. Use `GenerateSuccinctContour` to only include + the pixels to which lines need to be drawn + + Args: + points_num: assumed that the created contour does not form a contour if it does not contain more points + than the specified value. Defaults to 3. + level: optional. Value along which to find contours in the array. By default, the level is set + to (max(image) + min(image)) / 2. + + """ + + backend = [TransformBackends.NUMPY] + + def __init__(self, points_num: int = 3, level: Optional[float] = None) -> None: + self.level = level + self.points_num = points_num + + def __call__(self, image: NdarrayOrTensor, offset: Optional[Sequence[int]] = (0, 0)) -> np.ndarray: + """ + Args: + image: instance-level segmentation result. Shape should be [C, H, W] + offset: optional, offset of starting position of the instance in the array, default is (0, 0). + """ + image = image.squeeze() # squeeze channel dim + image = convert_to_numpy(image) + inst_contour_cv = find_contours(image, level=self.level) + generate_contour = GenerateSuccinctContour(image.shape[0], image.shape[1]) + inst_contour = generate_contour(inst_contour_cv) + + # < `self.points_num` points don't make a contour, so skip, likely artifact too + # as the contours obtained via approximation => too small or sthg + if inst_contour.shape[0] < self.points_num: + print(f"< {self.points_num} points don't make a contour, so skip") + return None # type: ignore + # check for tricky shape + elif len(inst_contour.shape) != 2: + print(f"{len(inst_contour.shape)} != 2, check for tricky shape") + return None # type: ignore + else: + inst_contour[:, 0] += offset[0] # type: ignore + inst_contour[:, 1] += offset[1] # type: ignore + return inst_contour + + +class GenerateInstanceCentroid(Transform): + """ + Generate instance centroid using `skimage.measure.centroid`. + + Args: + dtype: the data type of output centroid. + + """ + + backend = [TransformBackends.NUMPY] + + def __init__(self, dtype: Optional[DtypeLike] = int) -> None: + self.dtype = dtype + + def __call__(self, image: NdarrayOrTensor, offset: Union[Sequence[int], int] = 0) -> np.ndarray: + """ + Args: + image: instance-level segmentation result. Shape should be [1, H, W, [D]] + offset: optional, offset of starting position of the instance in the array, default is 0 for each dim. + + """ + image = convert_to_numpy(image) + image = image.squeeze(0) # squeeze channel dim + ndim = len(image.shape) + offset = ensure_tuple_rep(offset, ndim) + + inst_centroid = centroid(image) + for i in range(ndim): + inst_centroid[i] += offset[i] + + return convert_to_dst_type(inst_centroid, image, dtype=self.dtype)[0] # type: ignore + + +class GenerateInstanceType(Transform): + """ + Generate instance type and probability for each instance. + """ + + backend = [TransformBackends.NUMPY] + + def __init__(self) -> None: + super().__init__() + + def __call__( # type: ignore + self, type_pred: NdarrayOrTensor, seg_pred: NdarrayOrTensor, bbox: np.ndarray, instance_id: int + ) -> Tuple[int, float]: + """ + Args: + type_pred: pixel-level type prediction map after activation function. + seg_pred: pixel-level segmentation prediction map after activation function. + bbox: bounding box coordinates of the instance, shape is [channel, 2 * spatial dims]. + instance_id: get instance type from specified instance id. + """ + + rmin, rmax, cmin, cmax = bbox.flatten() + seg_map_crop = seg_pred[0, rmin:rmax, cmin:cmax] + type_map_crop = type_pred[0, rmin:rmax, cmin:cmax] + + seg_map_crop = convert_to_dst_type(seg_map_crop == instance_id, type_map_crop, dtype=bool)[0] + + inst_type = type_map_crop[seg_map_crop] # type: ignore + type_list, type_pixels = unique(inst_type, return_counts=True) + type_list = list(zip(type_list, type_pixels)) + type_list = sorted(type_list, key=lambda x: x[1], reverse=True) # type: ignore + inst_type = type_list[0][0] + if inst_type == 0: # ! pick the 2nd most dominant if exist + if len(type_list) > 1: + inst_type = type_list[1][0] + type_dict = {v[0]: v[1] for v in type_list} + type_prob = type_dict[inst_type] / (sum(seg_map_crop) + 1.0e-6) + + return (int(inst_type), float(type_prob)) diff --git a/monai/apps/pathology/transforms/post/dictionary.py b/monai/apps/pathology/transforms/post/dictionary.py index 3eab526ee7..c358eebf39 100644 --- a/monai/apps/pathology/transforms/post/dictionary.py +++ b/monai/apps/pathology/transforms/post/dictionary.py @@ -16,12 +16,20 @@ from monai.apps.pathology.transforms.post.array import ( GenerateDistanceMap, GenerateInstanceBorder, + GenerateInstanceCentroid, + GenerateInstanceContour, + GenerateInstanceType, + GenerateSuccinctContour, GenerateWatershedMarkers, GenerateWatershedMask, Watershed, ) from monai.config.type_definitions import DtypeLike, KeysCollection, NdarrayOrTensor from monai.transforms.transform import MapTransform +from monai.utils import optional_import + +find_contours, _ = optional_import("skimage.measure", name="find_contours") +moments, _ = optional_import("skimage.measure", name="moments") __all__ = [ "WatershedD", @@ -39,6 +47,18 @@ "GenerateWatershedMarkersD", "GenerateWatershedMarkersDict", "GenerateWatershedMarkersd", + "GenerateSuccinctContourDict", + "GenerateSuccinctContourD", + "GenerateSuccinctContourd", + "GenerateInstanceContourDict", + "GenerateInstanceContourD", + "GenerateInstanceContourd", + "GenerateInstanceCentroidDict", + "GenerateInstanceCentroidD", + "GenerateInstanceCentroidd", + "GenerateInstanceTypeDict", + "GenerateInstanceTypeD", + "GenerateInstanceTyped", ] @@ -295,8 +315,177 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d +class GenerateSuccinctContourd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.GenerateSuccinctContour`. + Converts Scipy-style contours(generated by skimage.measure.find_contours) to a more succinct version which + only includes the pixels to which lines need to be drawn (i.e. not the intervening pixels along each line). + + Args: + keys: keys of the corresponding items to be transformed. + height: height of bounding box, used to detect direction of line segment. + width: width of bounding box, used to detect direction of line segment. + allow_missing_keys: don't raise exception if key is missing. + + """ + + backend = GenerateSuccinctContour.backend + + def __init__(self, keys: KeysCollection, height: int, width: int, allow_missing_keys: bool = False) -> None: + super().__init__(keys, allow_missing_keys) + self.converter = GenerateSuccinctContour(height=height, width=width) + + def __call__(self, data): + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.converter(d[key]) + + return d + + +class GenerateInstanceContourd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.GenerateInstanceContour`. + Generate contour for each instance in a 2D array. Use `GenerateSuccinctContour` to only include the pixels + to which lines need to be drawn + + Args: + keys: keys of the corresponding items to be transformed. + contour_key_postfix: the output contour coordinates will be written to the value of + `{key}_{contour_key_postfix}`. + offset_key: keys of offset used in `GenerateInstanceContour`. + points_num: assumed that the created contour does not form a contour if it does not contain more points + than the specified value. Defaults to 3. + level: optional. Value along which to find contours in the array. By default, the level is set + to (max(image) + min(image)) / 2. + allow_missing_keys: don't raise exception if key is missing. + + """ + + backend = GenerateInstanceContour.backend + + def __init__( + self, + keys: KeysCollection, + contour_key_postfix: str = "contour", + offset_key: Optional[str] = None, + points_num: int = 3, + level: Optional[float] = None, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.converter = GenerateInstanceContour(points_num=points_num, level=level) + self.contour_key_postfix = contour_key_postfix + self.offset_key = offset_key + + def __call__(self, data): + d = dict(data) + for key in self.key_iterator(d): + offset = d[self.offset_key] if self.offset_key else None + contour = self.converter(d[key], offset) + key_to_add = f"{key}_{self.contour_key_postfix}" + if key_to_add in d: + raise KeyError(f"Contour with key {key_to_add} already exists.") + d[key_to_add] = contour + return d + + +class GenerateInstanceCentroidd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.GenerateInstanceCentroid`. + Generate instance centroid using `skimage.measure.centroid`. + + Args: + keys: keys of the corresponding items to be transformed. + centroid_key_postfix: the output centroid coordinates will be written to the value of + `{key}_{centroid_key_postfix}`. + offset_key: keys of offset used in `GenerateInstanceCentroid`. + dtype: the data type of output centroid. + allow_missing_keys: don't raise exception if key is missing. + + """ + + backend = GenerateInstanceCentroid.backend + + def __init__( + self, + keys: KeysCollection, + centroid_key_postfix: str = "centroid", + offset_key: Optional[str] = None, + dtype: Optional[DtypeLike] = int, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.converter = GenerateInstanceCentroid(dtype=dtype) + self.centroid_key_postfix = centroid_key_postfix + self.offset_key = offset_key + + def __call__(self, data): + d = dict(data) + for key in self.key_iterator(d): + offset = d[self.offset_key] if self.offset_key else None + centroid = self.converter(d[key], offset) + key_to_add = f"{key}_{self.centroid_key_postfix}" + if key_to_add in d: + raise KeyError(f"Centroid with key {key_to_add} already exists.") + d[key_to_add] = centroid + return d + + +class GenerateInstanceTyped(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.GenerateInstanceType`. + Generate instance type and probability for each instance. + + Args: + keys: keys of the corresponding items to be transformed. + type_info_key: the output instance type and probability will be written to the value of + `{type_info_key}`. + bbox_key: keys of bounding box. + seg_pred_key: keys of segmentation prediction map. + instance_id_key: keys of instance id. + allow_missing_keys: don't raise exception if key is missing. + + """ + + backend = GenerateInstanceType.backend + + def __init__( + self, + keys: KeysCollection, + type_info_key: str = "type_info", + bbox_key: str = "bbox", + seg_pred_key: str = "seg", + instance_id_key: str = "id", + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.converter = GenerateInstanceType() + self.type_info_key = type_info_key + self.bbox_key = bbox_key + self.seg_pred_key = seg_pred_key + self.instance_id_key = instance_id_key + + def __call__(self, data): + d = dict(data) + for key in self.key_iterator(d): + seg = d[self.seg_pred_key] + bbox = d[self.bbox_key] + id = d[self.instance_id_key] + instance_type, type_prob = self.converter(d[key], seg, bbox, id) + key_to_add = f"{self.type_info_key}" + if key_to_add in d: + raise KeyError(f"Type information with key {key_to_add} already exists.") + d[key_to_add] = {"inst_type": instance_type, "type_prob": type_prob} + return d + + WatershedD = WatershedDict = Watershedd GenerateWatershedMaskD = GenerateWatershedMaskDict = GenerateWatershedMaskd GenerateInstanceBorderD = GenerateInstanceBorderDict = GenerateInstanceBorderd GenerateDistanceMapD = GenerateDistanceMapDict = GenerateDistanceMapd GenerateWatershedMarkersD = GenerateWatershedMarkersDict = GenerateWatershedMarkersd +GenerateSuccinctContourDict = GenerateSuccinctContourD = GenerateSuccinctContourd +GenerateInstanceContourDict = GenerateInstanceContourD = GenerateInstanceContourd +GenerateInstanceCentroidDict = GenerateInstanceCentroidD = GenerateInstanceCentroidd +GenerateInstanceTypeDict = GenerateInstanceTypeD = GenerateInstanceTyped diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index e97e61c73c..aef4a32fe3 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -387,13 +387,13 @@ def mode(x: NdarrayTensor, dim: int = -1, to_long: bool = True) -> NdarrayTensor return o -def unique(x: NdarrayTensor) -> NdarrayTensor: +def unique(x: NdarrayTensor, **kwargs) -> NdarrayTensor: """`torch.unique` with equivalent implementation for numpy. Args: x: array/tensor. """ - return np.unique(x) if isinstance(x, (np.ndarray, list)) else torch.unique(x) # type: ignore + return np.unique(x, **kwargs) if isinstance(x, (np.ndarray, list)) else torch.unique(x, **kwargs) # type: ignore def linalg_inv(x: NdarrayTensor) -> NdarrayTensor: diff --git a/tests/test_generate_instance_centroid.py b/tests/test_generate_instance_centroid.py new file mode 100644 index 0000000000..46f94be637 --- /dev/null +++ b/tests/test_generate_instance_centroid.py @@ -0,0 +1,52 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.array import GenerateInstanceCentroid +from monai.transforms import BoundingRect +from monai.utils import min_version, optional_import +from tests.utils import TEST_NDARRAYS, assert_allclose + +_, has_skimage = optional_import("skimage", "0.19.3", min_version) + +y, x = np.ogrid[0:30, 0:30] +get_bbox = BoundingRect() + +TEST_CASE_1 = [(x - 2) ** 2 + (y - 2) ** 2 <= 2**2, [0, 0], [2, 2]] + +TEST_CASE_2 = [(x - 8) ** 2 + (y - 8) ** 2 <= 2**2, [6, 6], [8, 8]] + +TEST_CASE_3 = [(x - 5) ** 2 / 3**2 + (y - 5) ** 2 / 2**2 <= 1, [2, 3], [4, 6]] + + +TEST_CASE = [] +for p in TEST_NDARRAYS: + TEST_CASE.append([p, *TEST_CASE_1]) + TEST_CASE.append([p, *TEST_CASE_2]) + TEST_CASE.append([p, *TEST_CASE_3]) + + +@unittest.skipUnless(has_skimage, "Requires scikit-image library.") +class TestGenerateInstanceCentroid(unittest.TestCase): + @parameterized.expand(TEST_CASE) + def test_shape(self, in_type, test_data, offset, expected): + inst_bbox = get_bbox(test_data[None]) + inst_map = test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]] + result = GenerateInstanceCentroid()(in_type(inst_map[None]), offset=offset) + assert_allclose(result, expected, type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generate_instance_centroidd.py b/tests/test_generate_instance_centroidd.py new file mode 100644 index 0000000000..f989de5ff2 --- /dev/null +++ b/tests/test_generate_instance_centroidd.py @@ -0,0 +1,54 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.dictionary import GenerateInstanceCentroidd +from monai.transforms import BoundingRect +from monai.utils import min_version, optional_import +from tests.utils import TEST_NDARRAYS, assert_allclose + +_, has_skimage = optional_import("skimage", "0.19.3", min_version) + +y, x = np.ogrid[0:30, 0:30] +get_bbox = BoundingRect() + +TEST_CASE_1 = [(x - 2) ** 2 + (y - 2) ** 2 <= 2**2, [0, 0], [2, 2]] + +TEST_CASE_2 = [(x - 8) ** 2 + (y - 8) ** 2 <= 2**2, [6, 6], [8, 8]] + +TEST_CASE_3 = [(x - 5) ** 2 / 3**2 + (y - 5) ** 2 / 2**2 <= 1, [2, 3], [4, 6]] + +TEST_CASE = [] +for p in TEST_NDARRAYS: + TEST_CASE.append([p, *TEST_CASE_1]) + TEST_CASE.append([p, *TEST_CASE_2]) + TEST_CASE.append([p, *TEST_CASE_3]) + + +@unittest.skipUnless(has_skimage, "Requires scikit-image library.") +class TestGenerateInstanceCentroidd(unittest.TestCase): + @parameterized.expand(TEST_CASE) + def test_shape(self, in_type, test_data, offset, expected): + inst_bbox = get_bbox(test_data[None]) + inst_map = test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]] + test_case = {"image": in_type(inst_map[None]), "offset": offset} + result = GenerateInstanceCentroidd(keys="image", centroid_key_postfix="centroid", offset_key="offset")( + test_case + ) + assert_allclose(result["image_centroid"], expected, type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generate_instance_contour.py b/tests/test_generate_instance_contour.py new file mode 100644 index 0000000000..22b778c06c --- /dev/null +++ b/tests/test_generate_instance_contour.py @@ -0,0 +1,57 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.array import GenerateInstanceContour +from monai.transforms import BoundingRect +from monai.utils import min_version, optional_import +from tests.utils import TEST_NDARRAYS, assert_allclose + +_, has_skimage = optional_import("skimage", "0.19.3", min_version) + +y, x = np.ogrid[0:30, 0:30] +get_bbox = BoundingRect() + +TEST_CASE_1 = [(x - 2) ** 2 + (y - 2) ** 2 <= 2**2, 3, [0, 0], [[2, 0], [0, 2], [2, 4], [4, 2]]] + +TEST_CASE_2 = [(x - 8) ** 2 + (y - 8) ** 2 <= 2**2, 3, [8, 8], [[10, 8], [8, 10], [10, 12], [12, 10]]] + +TEST_CASE_3 = [ + (x - 5) ** 2 / 3**2 + (y - 5) ** 2 / 2**2 <= 1, + 3, + [2, 3], + [[5, 3], [4, 4], [3, 4], [2, 5], [3, 6], [4, 6], [5, 7], [6, 6], [7, 6], [8, 5], [7, 4], [6, 4]], +] + +TEST_CASE = [] +for p in TEST_NDARRAYS: + TEST_CASE.append([p, *TEST_CASE_1]) + TEST_CASE.append([p, *TEST_CASE_2]) + TEST_CASE.append([p, *TEST_CASE_3]) + + +@unittest.skipUnless(has_skimage, "Requires scikit-image library.") +class TestGenerateInstanceContour(unittest.TestCase): + @parameterized.expand(TEST_CASE) + def test_shape(self, in_type, test_data, points_num, offset, expected): + + inst_bbox = get_bbox(test_data[None]) + inst_map = test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]] + result = GenerateInstanceContour(points_num=points_num)(in_type(inst_map[None]), offset=offset) + assert_allclose(result, expected, type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generate_instance_contourd.py b/tests/test_generate_instance_contourd.py new file mode 100644 index 0000000000..9c9c1efbe6 --- /dev/null +++ b/tests/test_generate_instance_contourd.py @@ -0,0 +1,60 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.dictionary import GenerateInstanceContourd +from monai.transforms import BoundingRect +from monai.utils import min_version, optional_import +from tests.utils import TEST_NDARRAYS, assert_allclose + +_, has_skimage = optional_import("skimage", "0.19.3", min_version) + +y, x = np.ogrid[0:30, 0:30] +get_bbox = BoundingRect() + +TEST_CASE_1 = [(x - 2) ** 2 + (y - 2) ** 2 <= 2**2, 3, [0, 0], [[2, 0], [0, 2], [2, 4], [4, 2]]] + +TEST_CASE_2 = [(x - 10) ** 2 + (y - 10) ** 2 <= 2**2, 3, [8, 8], [[10, 8], [8, 10], [10, 12], [12, 10]]] + + +TEST_CASE_3 = [ + (x - 5) ** 2 / 3**2 + (y - 5) ** 2 / 2**2 <= 1, + 3, + [2, 3], + [[5, 3], [4, 4], [3, 4], [2, 5], [3, 6], [4, 6], [5, 7], [6, 6], [7, 6], [8, 5], [7, 4], [6, 4]], +] + +TEST_CASE = [] +for p in TEST_NDARRAYS: + TEST_CASE.append([p, *TEST_CASE_1]) + TEST_CASE.append([p, *TEST_CASE_2]) + TEST_CASE.append([p, *TEST_CASE_3]) + + +@unittest.skipUnless(has_skimage, "Requires scikit-image library.") +class TestGenerateInstanceContourd(unittest.TestCase): + @parameterized.expand(TEST_CASE) + def test_shape(self, in_type, test_data, points_num, offset, expected): + inst_bbox = get_bbox(test_data[None]) + inst_map = test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]] + test_data = {"image": in_type(inst_map[None]), "offset": offset} + result = GenerateInstanceContourd( + keys="image", contour_key_postfix="contour", offset_key="offset", points_num=points_num + )(test_data) + assert_allclose(result["image_contour"], expected, type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generate_instance_type.py b/tests/test_generate_instance_type.py new file mode 100644 index 0000000000..8a083d19b7 --- /dev/null +++ b/tests/test_generate_instance_type.py @@ -0,0 +1,49 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.array import GenerateInstanceType +from tests.utils import TEST_NDARRAYS, assert_allclose + +y, x = np.ogrid[0:30, 0:30] + +TEST_CASE_1 = [ + (x - 2) ** 2 + (y - 2) ** 2 <= 2**2, + (x - 2) ** 2 + (y - 3) ** 2 <= 2**2, + np.array([[0, 5, 0, 5]]), + [1, 0.6666666111111158], +] + +TEST_CASE_2 = [ + (x - 8) ** 2 / 3**2 + (y - 8) ** 2 / 2**2 <= 1, + (x - 7) ** 2 / 3**2 + (y - 7) ** 2 / 2**2 <= 1, + np.array([[6, 11, 5, 12]]), + [1, 0.7058823114186875], +] +TEST_CASE = [] +for p in TEST_NDARRAYS: + TEST_CASE.append([p, *TEST_CASE_1]) + TEST_CASE.append([p, *TEST_CASE_2]) + + +class TestGenerateInstanceType(unittest.TestCase): + @parameterized.expand(TEST_CASE) + def test_shape(self, in_type, type_pred, seg_pred, bbox, expected): + result = GenerateInstanceType()(in_type(type_pred[None]), in_type(seg_pred[None]), bbox, 1) + assert_allclose(result, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generate_instance_typed.py b/tests/test_generate_instance_typed.py new file mode 100644 index 0000000000..08d9f550a9 --- /dev/null +++ b/tests/test_generate_instance_typed.py @@ -0,0 +1,51 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.dictionary import GenerateInstanceTyped +from tests.utils import TEST_NDARRAYS, assert_allclose + +y, x = np.ogrid[0:30, 0:30] + +TEST_CASE_1 = [ + (x - 2) ** 2 + (y - 2) ** 2 <= 2**2, + (x - 2) ** 2 + (y - 3) ** 2 <= 2**2, + np.array([[0, 5, 0, 5]]), + [1, 0.6666666111111158], +] + +TEST_CASE_2 = [ + (x - 8) ** 2 / 3**2 + (y - 8) ** 2 / 2**2 <= 1, + (x - 7) ** 2 / 3**2 + (y - 7) ** 2 / 2**2 <= 1, + np.array([[6, 11, 5, 12]]), + [1, 0.7058823114186875], +] +TEST_CASE = [] +for p in TEST_NDARRAYS: + TEST_CASE.append([p, *TEST_CASE_1]) + TEST_CASE.append([p, *TEST_CASE_2]) + + +class TestGenerateInstanceTyped(unittest.TestCase): + @parameterized.expand(TEST_CASE) + def test_shape(self, in_type, type_pred, seg_pred, bbox, expected): + test_data = {"type_pred": in_type(type_pred[None]), "seg": in_type(seg_pred[None]), "bbox": bbox, "id": 1} + result = GenerateInstanceTyped(keys="type_pred")(test_data) + assert_allclose(result["type_info"]["inst_type"], expected[0]) + assert_allclose(result["type_info"]["type_prob"], expected[1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generate_succinct_contour.py b/tests/test_generate_succinct_contour.py new file mode 100644 index 0000000000..478c23b522 --- /dev/null +++ b/tests/test_generate_succinct_contour.py @@ -0,0 +1,52 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.array import GenerateSuccinctContour + +TEST_CASE_1 = [ + [ + np.array([[1.5, 0.0], [1.0, 0.5], [0.5, 1.0], [0.0, 1.5]]), + np.array([[0.0, 2.5], [0.5, 3.0], [1.0, 3.5], [1.5, 4.0]]), + np.array([[4.0, 1.5], [3.5, 1.0], [3.0, 0.5], [2.5, 0.0]]), + np.array([[2.5, 4.0], [3.0, 3.5], [3.5, 3.0], [4.0, 2.5]]), + ], + 5, + 5, + [[2, 0], [0, 2], [2, 4], [4, 2]], +] + +TEST_CASE_2 = [ + [ + np.array([[1.5, 0.0], [1.0, 0.5], [0.5, 1.0], [0.5, 2.0], [0.0, 2.5]]), + np.array([[0.0, 3.5], [0.5, 4.0], [0.5, 5.0], [1.0, 5.5], [1.5, 6.0]]), + np.array([[4.0, 2.5], [3.5, 2.0], [3.5, 1.0], [3.0, 0.5], [2.5, 0.0]]), + np.array([[2.5, 6.0], [3.0, 5.5], [3.5, 5.0], [3.5, 4.0], [4.0, 3.5]]), + ], + 5, + 7, + [[3, 0], [2, 1], [1, 1], [0, 2], [1, 3], [2, 3], [3, 4], [4, 3], [5, 3], [6, 2], [5, 1], [4, 1]], +] + + +class TestGenerateSuccinctContour(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_shape(self, test_data, height, width, expected): + result = GenerateSuccinctContour(height=height, width=width)(test_data) + np.testing.assert_allclose(result, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generate_succinct_contourd.py b/tests/test_generate_succinct_contourd.py new file mode 100644 index 0000000000..b34142ec0d --- /dev/null +++ b/tests/test_generate_succinct_contourd.py @@ -0,0 +1,54 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.dictionary import GenerateSuccinctContourd + +y, x = np.ogrid[0:5, 0:5] +TEST_CASE_1 = [ + [ + np.array([[1.5, 0.0], [1.0, 0.5], [0.5, 1.0], [0.0, 1.5]]), + np.array([[0.0, 2.5], [0.5, 3.0], [1.0, 3.5], [1.5, 4.0]]), + np.array([[4.0, 1.5], [3.5, 1.0], [3.0, 0.5], [2.5, 0.0]]), + np.array([[2.5, 4.0], [3.0, 3.5], [3.5, 3.0], [4.0, 2.5]]), + ], + 5, + 5, + [[2, 0], [0, 2], [2, 4], [4, 2]], +] + +TEST_CASE_2 = [ + [ + np.array([[1.5, 0.0], [1.0, 0.5], [0.5, 1.0], [0.5, 2.0], [0.0, 2.5]]), + np.array([[0.0, 3.5], [0.5, 4.0], [0.5, 5.0], [1.0, 5.5], [1.5, 6.0]]), + np.array([[4.0, 2.5], [3.5, 2.0], [3.5, 1.0], [3.0, 0.5], [2.5, 0.0]]), + np.array([[2.5, 6.0], [3.0, 5.5], [3.5, 5.0], [3.5, 4.0], [4.0, 3.5]]), + ], + 5, + 7, + [[3, 0], [2, 1], [1, 1], [0, 2], [1, 3], [2, 3], [3, 4], [4, 3], [5, 3], [6, 2], [5, 1], [4, 1]], +] + + +class TestGenerateSuccinctContour(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_shape(self, data, height, width, expected): + test_data = {"contour": data} + result = GenerateSuccinctContourd(keys="contour", height=height, width=width)(test_data) + np.testing.assert_allclose(result["contour"], expected) + + +if __name__ == "__main__": + unittest.main()