diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 248813d679..cf44547aa7 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -149,6 +149,8 @@ Applications :members: .. autoclass:: GenerateWatershedMarkers :members: +.. autoclass:: HoVerNetNuclearTypePostProcessing + :members: .. automodule:: monai.apps.pathology.transforms.post.dictionary .. autoclass:: GenerateSuccinctContourd @@ -169,6 +171,8 @@ Applications :members: .. autoclass:: GenerateWatershedMarkersd :members: +.. autoclass:: HoVerNetNuclearTypePostProcessingd + :members: `Detection` ----------- diff --git a/monai/apps/pathology/transforms/__init__.py b/monai/apps/pathology/transforms/__init__.py index 3e784b8ebf..18b53d7f2a 100644 --- a/monai/apps/pathology/transforms/__init__.py +++ b/monai/apps/pathology/transforms/__init__.py @@ -18,6 +18,7 @@ GenerateSuccinctContour, GenerateWatershedMarkers, GenerateWatershedMask, + HoVerNetNuclearTypePostProcessing, Watershed, ) from .post.dictionary import ( @@ -45,6 +46,9 @@ GenerateWatershedMaskD, GenerateWatershedMaskd, GenerateWatershedMaskDict, + HoVerNetNuclearTypePostProcessingD, + HoVerNetNuclearTypePostProcessingd, + HoVerNetNuclearTypePostProcessingDict, WatershedD, Watershedd, WatershedDict, diff --git a/monai/apps/pathology/transforms/post/__init__.py b/monai/apps/pathology/transforms/post/__init__.py index 3e6af77ce6..836582b8a3 100644 --- a/monai/apps/pathology/transforms/post/__init__.py +++ b/monai/apps/pathology/transforms/post/__init__.py @@ -18,6 +18,7 @@ GenerateSuccinctContour, GenerateWatershedMarkers, GenerateWatershedMask, + HoVerNetNuclearTypePostProcessing, Watershed, ) from .dictionary import ( @@ -45,6 +46,9 @@ GenerateWatershedMaskD, GenerateWatershedMaskd, GenerateWatershedMaskDict, + HoVerNetNuclearTypePostProcessingD, + HoVerNetNuclearTypePostProcessingd, + HoVerNetNuclearTypePostProcessingDict, WatershedD, Watershedd, WatershedDict, diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 55ff531172..619dd0028a 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -9,12 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, List, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np from monai.config.type_definitions import DtypeLike, NdarrayOrTensor -from monai.transforms.post.array import Activations, AsDiscrete, RemoveSmallObjects, SobelGradients +from monai.transforms import Activations, AsDiscrete, BoundingRect, RemoveSmallObjects, SobelGradients from monai.transforms.transform import Transform from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique from monai.utils import TransformBackends, convert_to_numpy, optional_import @@ -38,6 +38,7 @@ "GenerateInstanceContour", "GenerateInstanceCentroid", "GenerateInstanceType", + "HoVerNetNuclearTypePostProcessing", ] @@ -512,7 +513,7 @@ class GenerateInstanceContour(Transform): the pixels to which lines need to be drawn Args: - points_num: assumed that the created contour does not form a contour if it does not contain more points + min_num_points: assumed that the created contour does not form a contour if it does not contain more points than the specified value. Defaults to 3. level: optional. Value along which to find contours in the array. By default, the level is set to (max(image) + min(image)) / 2. @@ -521,9 +522,9 @@ class GenerateInstanceContour(Transform): backend = [TransformBackends.NUMPY] - def __init__(self, points_num: int = 3, level: Optional[float] = None) -> None: + def __init__(self, min_num_points: int = 3, level: Optional[float] = None) -> None: self.level = level - self.points_num = points_num + self.min_num_points = min_num_points def __call__(self, image: NdarrayOrTensor, offset: Optional[Sequence[int]] = (0, 0)) -> np.ndarray: """ @@ -537,10 +538,10 @@ def __call__(self, image: NdarrayOrTensor, offset: Optional[Sequence[int]] = (0, generate_contour = GenerateSuccinctContour(image.shape[0], image.shape[1]) inst_contour = generate_contour(inst_contour_cv) - # < `self.points_num` points don't make a contour, so skip, likely artifact too + # < `self.min_num_points` points don't make a contour, so skip, likely artifact too # as the contours obtained via approximation => too small or sthg - if inst_contour.shape[0] < self.points_num: - print(f"< {self.points_num} points don't make a contour, so skip") + if inst_contour.shape[0] < self.min_num_points: + print(f"< {self.min_num_points} points don't make a contour, so skip") return None # type: ignore # check for tricky shape elif len(inst_contour.shape) != 2: @@ -624,3 +625,70 @@ def __call__( # type: ignore type_prob = type_dict[inst_type] / (sum(seg_map_crop) + 1.0e-6) return (int(inst_type), float(type_prob)) + + +class HoVerNetNuclearTypePostProcessing(Transform): + """ + The whole post-procesing transform for nuclear type classification branch. It return a dict which contains + centroid, bounding box, type prediciton for each instance. + + Args: + min_num_points: assumed that the created contour does not form a contour if it does not contain more points + than the specified value. Defaults to 3. + level: optional. Used in `skimage.measure.find_contours`. Value along which to find contours in the array. + By default, the level is set to (max(image) + min(image)) / 2. + return_centroids: whether to return centroids for each instance. + output_classes: number of the nuclear type classes. + """ + + def __init__( + self, + min_num_points: int = 3, + level: Optional[float] = None, + return_centroids: Optional[bool] = False, + output_classes: Optional[int] = None, + ) -> None: + super().__init__() + self.return_centroids = return_centroids + self.output_classes = output_classes + + self.generate_instance_contour = GenerateInstanceContour(min_num_points=min_num_points, level=level) + self.generate_instance_centroid = GenerateInstanceCentroid() + self.generate_instance_type = GenerateInstanceType() + + def __call__(self, type_pred: NdarrayOrTensor, instance_pred: NdarrayOrTensor) -> Dict: # type: ignore + type_pred = Activations(softmax=True)(type_pred) + type_pred = AsDiscrete(argmax=True)(type_pred) + + inst_id_list = np.unique(instance_pred)[1:] # exclude background + inst_info_dict = None + if self.return_centroids: + inst_info_dict = {} + for inst_id in inst_id_list: + inst_map = instance_pred == inst_id + inst_bbox = BoundingRect()(inst_map) + inst_map = inst_map[:, inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]] + offset = [inst_bbox[0][2], inst_bbox[0][0]] + inst_contour = self.generate_instance_contour(inst_map, offset) + inst_centroid = self.generate_instance_centroid(inst_map, offset) + if inst_contour is not None: + inst_info_dict[inst_id] = { # inst_id should start at 1 + "bounding_box": inst_bbox, + "centroid": inst_centroid, + "contour": inst_contour, + "type_probability": None, + "type": None, + } + + if self.output_classes is not None: + for inst_id in list(inst_info_dict.keys()): + inst_type, type_prob = self.generate_instance_type( + bbox=inst_info_dict[inst_id]["bounding_box"], # type: ignore + type_pred=type_pred, + seg_pred=instance_pred, + instance_id=inst_id, + ) + inst_info_dict[inst_id]["type"] = inst_type # type: ignore + inst_info_dict[inst_id]["type_probability"] = type_prob # type: ignore + + return inst_info_dict diff --git a/monai/apps/pathology/transforms/post/dictionary.py b/monai/apps/pathology/transforms/post/dictionary.py index c358eebf39..6c42c8ad41 100644 --- a/monai/apps/pathology/transforms/post/dictionary.py +++ b/monai/apps/pathology/transforms/post/dictionary.py @@ -22,11 +22,13 @@ GenerateSuccinctContour, GenerateWatershedMarkers, GenerateWatershedMask, + HoVerNetNuclearTypePostProcessing, Watershed, ) from monai.config.type_definitions import DtypeLike, KeysCollection, NdarrayOrTensor -from monai.transforms.transform import MapTransform -from monai.utils import optional_import +from monai.transforms.transform import MapTransform, Transform +from monai.utils import convert_to_dst_type, optional_import +from monai.utils.enums import HoVerNetBranch find_contours, _ = optional_import("skimage.measure", name="find_contours") moments, _ = optional_import("skimage.measure", name="moments") @@ -59,6 +61,9 @@ "GenerateInstanceTypeDict", "GenerateInstanceTypeD", "GenerateInstanceTyped", + "HoVerNetNuclearTypePostProcessingDict", + "HoVerNetNuclearTypePostProcessingD", + "HoVerNetNuclearTypePostProcessingd", ] @@ -354,7 +359,7 @@ class GenerateInstanceContourd(MapTransform): contour_key_postfix: the output contour coordinates will be written to the value of `{key}_{contour_key_postfix}`. offset_key: keys of offset used in `GenerateInstanceContour`. - points_num: assumed that the created contour does not form a contour if it does not contain more points + min_num_points: assumed that the created contour does not form a contour if it does not contain more points than the specified value. Defaults to 3. level: optional. Value along which to find contours in the array. By default, the level is set to (max(image) + min(image)) / 2. @@ -369,12 +374,12 @@ def __init__( keys: KeysCollection, contour_key_postfix: str = "contour", offset_key: Optional[str] = None, - points_num: int = 3, + min_num_points: int = 3, level: Optional[float] = None, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.converter = GenerateInstanceContour(points_num=points_num, level=level) + self.converter = GenerateInstanceContour(min_num_points=min_num_points, level=level) self.contour_key_postfix = contour_key_postfix self.offset_key = offset_key @@ -480,6 +485,68 @@ def __call__(self, data): return d +class HoVerNetNuclearTypePostProcessingd(Transform): + """ + Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.HoVerNetNuclearTypePostProcessing`. + Generate instance type and probability for each instance. + + Args: + type_pred_key: the key pointing to the pred type map to be transformed. + instance_pred_key: the key pointing to the pred distance map. + min_num_points: assumed that the created contour does not form a contour if it does not contain more points + than the specified value. Defaults to 3. + level: optional. Used in `skimage.measure.find_contours`. Value along which to find contours in the array. + By default, the level is set to (max(image) + min(image)) / 2. + return_binary: whether to return the binary segmentation prediction after `seg_postprocessing`. + pred_binary_key: if `return_binary` is True, this `pred_binary_key` is used to get the binary prediciton + from the output. + return_centroids: whether to return centroids for each instance. + output_classes: number of the nuclear type classes. + instance_info_dict_key: key use to record information for each instance. + allow_missing_keys: don't raise exception if key is missing. + + """ + + def __init__( + self, + type_pred_key: str = HoVerNetBranch.NC.value, + instance_pred_key: str = "dist", + min_num_points: int = 3, + level: Optional[float] = None, + return_binary: Optional[bool] = True, + pred_binary_key: Optional[str] = "pred_binary", + return_centroids: Optional[bool] = False, + output_classes: Optional[int] = None, + instance_info_dict_key: Optional[str] = "instance_info_dict", + ) -> None: + super().__init__() + self.converter = HoVerNetNuclearTypePostProcessing( + min_num_points=min_num_points, level=level, return_centroids=return_centroids, output_classes=output_classes + ) + self.type_pred_key = type_pred_key + self.instance_pred_key = instance_pred_key + self.pred_binary_key = pred_binary_key + self.instance_info_dict_key = instance_info_dict_key + self.return_binary = return_binary + + def __call__(self, data): + d = dict(data) + inst_pred = d[self.instance_pred_key] + type_pred = d[self.type_pred_key] + inst_info_dict = self.converter(type_pred, inst_pred) + key_to_add = f"{self.instance_info_dict_key}" + if key_to_add in d: + raise KeyError(f"Type information with key {key_to_add} already exists.") + d[key_to_add] = inst_info_dict # type: ignore + + inst_pred = convert_to_dst_type(inst_pred, type_pred)[0] + if self.return_binary: + inst_pred[inst_pred > 0] = 1 + d[self.pred_binary_key] = inst_pred + + return d + + WatershedD = WatershedDict = Watershedd GenerateWatershedMaskD = GenerateWatershedMaskDict = GenerateWatershedMaskd GenerateInstanceBorderD = GenerateInstanceBorderDict = GenerateInstanceBorderd @@ -489,3 +556,4 @@ def __call__(self, data): GenerateInstanceContourDict = GenerateInstanceContourD = GenerateInstanceContourd GenerateInstanceCentroidDict = GenerateInstanceCentroidD = GenerateInstanceCentroidd GenerateInstanceTypeDict = GenerateInstanceTypeD = GenerateInstanceTyped +HoVerNetNuclearTypePostProcessingDict = HoVerNetNuclearTypePostProcessingD = HoVerNetNuclearTypePostProcessingd diff --git a/tests/test_generate_instance_contour.py b/tests/test_generate_instance_contour.py index 22b778c06c..8c43bf5bc5 100644 --- a/tests/test_generate_instance_contour.py +++ b/tests/test_generate_instance_contour.py @@ -45,11 +45,11 @@ @unittest.skipUnless(has_skimage, "Requires scikit-image library.") class TestGenerateInstanceContour(unittest.TestCase): @parameterized.expand(TEST_CASE) - def test_shape(self, in_type, test_data, points_num, offset, expected): + def test_shape(self, in_type, test_data, min_num_points, offset, expected): inst_bbox = get_bbox(test_data[None]) inst_map = test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]] - result = GenerateInstanceContour(points_num=points_num)(in_type(inst_map[None]), offset=offset) + result = GenerateInstanceContour(min_num_points=min_num_points)(in_type(inst_map[None]), offset=offset) assert_allclose(result, expected, type_test=False) diff --git a/tests/test_generate_instance_contourd.py b/tests/test_generate_instance_contourd.py index be9a8d321d..e92020c6bc 100644 --- a/tests/test_generate_instance_contourd.py +++ b/tests/test_generate_instance_contourd.py @@ -45,12 +45,12 @@ @unittest.skipUnless(has_skimage, "Requires scikit-image library.") class TestGenerateInstanceContourd(unittest.TestCase): @parameterized.expand(TEST_CASE) - def test_shape(self, in_type, test_data, points_num, offset, expected): + def test_shape(self, in_type, test_data, min_num_points, offset, expected): inst_bbox = get_bbox(test_data[None]) inst_map = test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]] test_data = {"image": in_type(inst_map[None]), "offset": offset} result = GenerateInstanceContourd( - keys="image", contour_key_postfix="contour", offset_key="offset", points_num=points_num + keys="image", contour_key_postfix="contour", offset_key="offset", min_num_points=min_num_points )(test_data) assert_allclose(result["image_contour"], expected, type_test=False) diff --git a/tests/test_hovernet_nuclear_type_post_processing.py b/tests/test_hovernet_nuclear_type_post_processing.py new file mode 100644 index 0000000000..954c2a34fb --- /dev/null +++ b/tests/test_hovernet_nuclear_type_post_processing.py @@ -0,0 +1,82 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.array import HoVerNetNuclearTypePostProcessing +from monai.apps.pathology.transforms.post.dictionary import ( + GenerateDistanceMapd, + GenerateInstanceBorderd, + GenerateWatershedMarkersd, + GenerateWatershedMaskd, + Watershedd, +) +from monai.transforms import Compose, ComputeHoVerMaps, FillHoles, GaussianSmooth +from monai.utils import min_version, optional_import +from monai.utils.enums import HoVerNetBranch +from tests.utils import TEST_NDARRAYS, assert_allclose + +_, has_scipy = optional_import("scipy", "1.8.1", min_version) +_, has_skimage = optional_import("skimage", "0.19.3", min_version) + +y, x = np.ogrid[0:30, 0:30] +image = (x - 10) ** 2 + (y - 10) ** 2 <= 5**2 + +seg_postpprocessing = Compose( + [ + GenerateWatershedMaskd( + keys=HoVerNetBranch.NP.value, sigmoid=True, softmax=False, threshold=0.7, remove_small_objects=False + ), + GenerateInstanceBorderd( + keys="mask", hover_map_key=HoVerNetBranch.HV.value, kernel_size=3, remove_small_objects=False + ), + GenerateDistanceMapd(keys="mask", border_key="border", smooth_fn=GaussianSmooth()), + GenerateWatershedMarkersd( + keys="mask", border_key="border", threshold=0.9, radius=2, postprocess_fn=FillHoles() + ), + Watershedd(keys="dist", mask_key="mask", markers_key="markers"), + ] +) +TEST_CASE_1 = [seg_postpprocessing, {"return_centroids": True, "output_classes": 1}, [image, [10, 10]]] +TEST_CASE_2 = [seg_postpprocessing, {"return_centroids": False, "output_classes": None}, [image]] + + +TEST_CASE = [] +for p in TEST_NDARRAYS: + TEST_CASE.append([p, image] + TEST_CASE_1) + TEST_CASE.append([p, image] + TEST_CASE_2) + + +@unittest.skipUnless(has_scipy, "Requires scipy library.") +@unittest.skipUnless(has_skimage, "Requires scikit-image library.") +class TestHoVerNetNuclearTypePostProcessing(unittest.TestCase): + @parameterized.expand(TEST_CASE) + def test_value(self, in_type, test_data, seg_postpprocessing, kwargs, expected): + hovermap = ComputeHoVerMaps()(test_data[None].astype(int)) + input = { + HoVerNetBranch.NP.value: in_type(test_data[None].astype(float)), + HoVerNetBranch.HV.value: in_type(hovermap), + HoVerNetBranch.NC.value: in_type(test_data[None]), + } + + pred = seg_postpprocessing(input) + + post_transforms = HoVerNetNuclearTypePostProcessing(**kwargs) + out = post_transforms(type_pred=in_type(test_data[None]), instance_pred=pred["dist"]) + if out is not None: + assert_allclose(out[1]["centroid"], expected[1], type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_hovernet_nuclear_type_post_processingd.py b/tests/test_hovernet_nuclear_type_post_processingd.py new file mode 100644 index 0000000000..116a6bb9af --- /dev/null +++ b/tests/test_hovernet_nuclear_type_post_processingd.py @@ -0,0 +1,80 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.dictionary import ( + GenerateDistanceMapd, + GenerateInstanceBorderd, + GenerateWatershedMarkersd, + GenerateWatershedMaskd, + HoVerNetNuclearTypePostProcessingd, + Watershedd, +) +from monai.transforms import Compose, ComputeHoVerMaps, FillHoles, GaussianSmooth +from monai.utils import min_version, optional_import +from monai.utils.enums import HoVerNetBranch +from tests.utils import TEST_NDARRAYS, assert_allclose + +_, has_scipy = optional_import("scipy", "1.8.1", min_version) +_, has_skimage = optional_import("skimage", "0.19.3", min_version) + +y, x = np.ogrid[0:30, 0:30] +image = (x - 10) ** 2 + (y - 10) ** 2 <= 5**2 + +seg_postpprocessing = [ + GenerateWatershedMaskd( + keys=HoVerNetBranch.NP.value, sigmoid=True, softmax=False, threshold=0.7, remove_small_objects=False + ), + GenerateInstanceBorderd( + keys="mask", hover_map_key=HoVerNetBranch.HV.value, kernel_size=3, remove_small_objects=False + ), + GenerateDistanceMapd(keys="mask", border_key="border", smooth_fn=GaussianSmooth()), + GenerateWatershedMarkersd(keys="mask", border_key="border", threshold=0.9, radius=2, postprocess_fn=FillHoles()), + Watershedd(keys="dist", mask_key="mask", markers_key="markers"), +] + +TEST_CASE_1 = [seg_postpprocessing, {"return_centroids": True, "output_classes": 1}, [image, [10, 10]]] +TEST_CASE_2 = [seg_postpprocessing, {"return_centroids": False, "output_classes": None}, [image]] + + +TEST_CASE = [] +for p in TEST_NDARRAYS: + TEST_CASE.append([p, image] + TEST_CASE_1) + TEST_CASE.append([p, image] + TEST_CASE_2) + + +@unittest.skipUnless(has_scipy, "Requires scipy library.") +@unittest.skipUnless(has_skimage, "Requires scikit-image library.") +class TestHoVerNetNuclearTypePostProcessingd(unittest.TestCase): + @parameterized.expand(TEST_CASE) + def test_value(self, in_type, test_data, seg_postpprocessing, kwargs, expected): + hovermap = ComputeHoVerMaps()(test_data[None].astype(int)) + input = { + HoVerNetBranch.NP.value: in_type(test_data[None].astype(float)), + HoVerNetBranch.HV.value: in_type(hovermap), + HoVerNetBranch.NC.value: in_type(test_data[None]), + } + + class_trans = [HoVerNetNuclearTypePostProcessingd(**kwargs)] + post_transforms = Compose(seg_postpprocessing + class_trans) + out = post_transforms(input) + + assert_allclose(out["dist"].squeeze(), expected[0], type_test=False) + if out["instance_info_dict"] is not None: + assert_allclose(out["instance_info_dict"][1]["centroid"], expected[1], type_test=False) + + +if __name__ == "__main__": + unittest.main()