From 5fc247831dfa26563acf327fac9c9923c20c7df0 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 3 Nov 2022 17:08:29 +0800 Subject: [PATCH 01/15] first commit Signed-off-by: KumoLiu --- monai/apps/pathology/transforms/post/array.py | 82 ++++++++++++++++++- 1 file changed, 80 insertions(+), 2 deletions(-) diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 2f84e96257..56eb11b1f9 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -9,16 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional +from typing import Callable, Dict, Hashable, Mapping, Optional import numpy as np from monai.config.type_definitions import DtypeLike, NdarrayOrTensor from monai.transforms.post.array import Activations, AsDiscrete, RemoveSmallObjects, SobelGradients from monai.transforms.transform import Transform +from monai.transforms import Activations, AsDiscrete, BoundingRect from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min from monai.utils import TransformBackends, convert_to_numpy, optional_import -from monai.utils.type_conversion import convert_to_dst_type +from monai.utils.enums import HoVerNetBranch +from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor label, _ = optional_import("scipy.ndimage.measurements", name="label") disk, _ = optional_import("skimage.morphology", name="disk") @@ -320,3 +322,79 @@ def __call__(self, mask: NdarrayOrTensor, instance_border: NdarrayOrTensor) -> N marker = self.remove_small_objects(marker[None]) return convert_to_dst_type(marker, mask, dtype=self.dtype)[0] + + +class PostProcessHoVerNet(Transform): + def __init__( + self, + post_process_segmentation: Transform, + distance_map_key: str = "dist", + points_num: int = 3, + level: Optional[float] = None, + dtype: Optional[DtypeLike] = int, + return_binary: Optional[bool] = True, + pred_binary_key: Optional[str] = 'pred_binary', + return_centroids: Optional[bool] = None, + output_classes: Optional[int] = None, + inst_info_dict_key: Optional[str] = "inst_info_dict", + ) -> None: + super().__init__() + self.distance_map_key = distance_map_key + self.return_binary = return_binary + self.pred_binary_key = pred_binary_key + self.return_centroids = return_centroids + self.output_classes = output_classes + self.inst_info_dict_key = inst_info_dict_key + + self.post_process_segmentation = post_process_segmentation + self.generate_instance_contour = GenerateInstanceContour(points_num=points_num, level=level) + self.generate_instance_centroid = GenerateInstanceCentroid(dtype=dtype) + self.generate_instance_type = GenerateInstanceType() + + def __call__(self, pred: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + device = pred[HoVerNetBranch.NP.value].device + if HoVerNetBranch.NC.value in pred.keys(): + type_pred = Activations(softmax=True)(pred[HoVerNetBranch.NC.value]) + type_pred = AsDiscrete(argmax=True)(type_pred) + + pred_inst_dict = self.post_process_segmentation(pred) + pred_inst = pred_inst_dict[self.distance_map_key] + + inst_id_list = np.unique(pred_inst)[1:] # exclude background + inst_info_dict = None + if self.return_centroids: + inst_info_dict = {} + for inst_id in inst_id_list: + inst_map = pred_inst == inst_id + inst_bbox = BoundingRect()(inst_map) + inst_map = inst_map[:, inst_bbox[0][0]: inst_bbox[0][1], inst_bbox[0][2]: inst_bbox[0][3]] + offset = [inst_bbox[0][2], inst_bbox[0][0]] + inst_contour = self.generate_instance_contour(inst_map, offset) + inst_centroid = self.generate_instance_centroid(inst_map, offset) + if inst_contour is not None: + inst_info_dict[inst_id] = { # inst_id should start at 1 + "bounding_box": inst_bbox, + "centroid": inst_centroid, + "contour": inst_contour, + "type_probability": None, + "type": None, + } + + if self.output_classes is not None: + for inst_id in list(inst_info_dict.keys()): + inst_type, type_prob = self.generate_instance_type( + bbox=inst_info_dict[inst_id]["bounding_box"], + type_pred=type_pred, + seg_pred=pred_inst, + instance_id=inst_id, + ) + inst_info_dict[inst_id]["type"] = inst_type + inst_info_dict[inst_id]["type_probability"] = type_prob + + pred_inst = convert_to_tensor(pred_inst, device=device) + pred[HoVerNetBranch.NP.value] = pred_inst + if self.return_binary: + pred_inst[pred_inst > 0] = 1 + pred[self.pred_binary_key] = pred_inst + pred[self.inst_info_dict_key] = inst_info_dict + return pred From 5346b6e263d3801ad5c428b7b3d61145ed2590f0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Nov 2022 09:19:15 +0000 Subject: [PATCH 02/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/pathology/transforms/post/array.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 56eb11b1f9..a68da2677d 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -350,13 +350,13 @@ def __init__( self.generate_instance_contour = GenerateInstanceContour(points_num=points_num, level=level) self.generate_instance_centroid = GenerateInstanceCentroid(dtype=dtype) self.generate_instance_type = GenerateInstanceType() - + def __call__(self, pred: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: device = pred[HoVerNetBranch.NP.value].device if HoVerNetBranch.NC.value in pred.keys(): type_pred = Activations(softmax=True)(pred[HoVerNetBranch.NC.value]) type_pred = AsDiscrete(argmax=True)(type_pred) - + pred_inst_dict = self.post_process_segmentation(pred) pred_inst = pred_inst_dict[self.distance_map_key] @@ -383,9 +383,9 @@ def __call__(self, pred: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N if self.output_classes is not None: for inst_id in list(inst_info_dict.keys()): inst_type, type_prob = self.generate_instance_type( - bbox=inst_info_dict[inst_id]["bounding_box"], - type_pred=type_pred, - seg_pred=pred_inst, + bbox=inst_info_dict[inst_id]["bounding_box"], + type_pred=type_pred, + seg_pred=pred_inst, instance_id=inst_id, ) inst_info_dict[inst_id]["type"] = inst_type From dc54220456a8947337ada49bcd99fc68c6843344 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 3 Nov 2022 18:20:45 +0800 Subject: [PATCH 03/15] updated based on comments Signed-off-by: KumoLiu --- monai/apps/pathology/transforms/post/array.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 56eb11b1f9..cc6d6e3a8d 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -14,9 +14,9 @@ import numpy as np from monai.config.type_definitions import DtypeLike, NdarrayOrTensor +from monai.transforms import Activations, AsDiscrete, BoundingRect from monai.transforms.post.array import Activations, AsDiscrete, RemoveSmallObjects, SobelGradients from monai.transforms.transform import Transform -from monai.transforms import Activations, AsDiscrete, BoundingRect from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min from monai.utils import TransformBackends, convert_to_numpy, optional_import from monai.utils.enums import HoVerNetBranch @@ -327,13 +327,13 @@ def __call__(self, mask: NdarrayOrTensor, instance_border: NdarrayOrTensor) -> N class PostProcessHoVerNet(Transform): def __init__( self, - post_process_segmentation: Transform, + seg_postpprocessing: Callable, distance_map_key: str = "dist", points_num: int = 3, level: Optional[float] = None, dtype: Optional[DtypeLike] = int, return_binary: Optional[bool] = True, - pred_binary_key: Optional[str] = 'pred_binary', + pred_binary_key: Optional[str] = "pred_binary", return_centroids: Optional[bool] = None, output_classes: Optional[int] = None, inst_info_dict_key: Optional[str] = "inst_info_dict", @@ -346,18 +346,18 @@ def __init__( self.output_classes = output_classes self.inst_info_dict_key = inst_info_dict_key - self.post_process_segmentation = post_process_segmentation + self.seg_postpprocessing = seg_postpprocessing self.generate_instance_contour = GenerateInstanceContour(points_num=points_num, level=level) self.generate_instance_centroid = GenerateInstanceCentroid(dtype=dtype) self.generate_instance_type = GenerateInstanceType() - + def __call__(self, pred: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: device = pred[HoVerNetBranch.NP.value].device if HoVerNetBranch.NC.value in pred.keys(): type_pred = Activations(softmax=True)(pred[HoVerNetBranch.NC.value]) type_pred = AsDiscrete(argmax=True)(type_pred) - - pred_inst_dict = self.post_process_segmentation(pred) + + pred_inst_dict = self.seg_postpprocessing(pred) pred_inst = pred_inst_dict[self.distance_map_key] inst_id_list = np.unique(pred_inst)[1:] # exclude background @@ -367,7 +367,7 @@ def __call__(self, pred: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N for inst_id in inst_id_list: inst_map = pred_inst == inst_id inst_bbox = BoundingRect()(inst_map) - inst_map = inst_map[:, inst_bbox[0][0]: inst_bbox[0][1], inst_bbox[0][2]: inst_bbox[0][3]] + inst_map = inst_map[:, inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]] offset = [inst_bbox[0][2], inst_bbox[0][0]] inst_contour = self.generate_instance_contour(inst_map, offset) inst_centroid = self.generate_instance_centroid(inst_map, offset) @@ -383,9 +383,9 @@ def __call__(self, pred: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N if self.output_classes is not None: for inst_id in list(inst_info_dict.keys()): inst_type, type_prob = self.generate_instance_type( - bbox=inst_info_dict[inst_id]["bounding_box"], - type_pred=type_pred, - seg_pred=pred_inst, + bbox=inst_info_dict[inst_id]["bounding_box"], + type_pred=type_pred, + seg_pred=pred_inst, instance_id=inst_id, ) inst_info_dict[inst_id]["type"] = inst_type From c55386785ff28bda5b1666e380c32deaaf1d4605 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 3 Nov 2022 20:45:04 +0800 Subject: [PATCH 04/15] add docstring Signed-off-by: KumoLiu --- monai/apps/pathology/transforms/post/array.py | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index cc6d6e3a8d..5204ceaa18 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -325,16 +325,34 @@ def __call__(self, mask: NdarrayOrTensor, instance_border: NdarrayOrTensor) -> N class PostProcessHoVerNet(Transform): + """ + Since HoVerNet do segmentation and classification meanwhile, this transform is used to combine postprocessing + for segmentation and classification. It assumes input as a dictionary. + + Args: + seg_postpprocessing: execute post-processing transformation for the segmentation output from model. + Typically, several Tensor based transforms composed by `Compose`. + distance_map_key: the key pointing to the distance map generated by `seg_postprocessing`. + points_num: assumed that the created contour does not form a contour if it does not contain more points + than the specified value. Defaults to 3. + level: optional. Used in `skimage.measure.find_contours`. Value along which to find contours in the array. + By default, the level is set to (max(image) + min(image)) / 2. + return_binary: whether to return the binary segmentation prediction after `seg_postprocessing`. + pred_binary_key: if `return_binary` is True, this `pred_binary_key` is used to get the binary prediciton + from the output. + return_centroids: whether to return centroids for each instance. + output_classes: number of the nuclear type classes. + inst_info_dict_keys: key use to record information for each instance. + """ def __init__( self, seg_postpprocessing: Callable, distance_map_key: str = "dist", points_num: int = 3, level: Optional[float] = None, - dtype: Optional[DtypeLike] = int, return_binary: Optional[bool] = True, pred_binary_key: Optional[str] = "pred_binary", - return_centroids: Optional[bool] = None, + return_centroids: Optional[bool] = False, output_classes: Optional[int] = None, inst_info_dict_key: Optional[str] = "inst_info_dict", ) -> None: @@ -348,7 +366,7 @@ def __init__( self.seg_postpprocessing = seg_postpprocessing self.generate_instance_contour = GenerateInstanceContour(points_num=points_num, level=level) - self.generate_instance_centroid = GenerateInstanceCentroid(dtype=dtype) + self.generate_instance_centroid = GenerateInstanceCentroid() self.generate_instance_type = GenerateInstanceType() def __call__(self, pred: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: From 4ce5a53af12cd957d5ca014a480c9f86bcc66517 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Nov 2022 12:46:21 +0000 Subject: [PATCH 05/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/pathology/transforms/post/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 5204ceaa18..67dea9d650 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -335,7 +335,7 @@ class PostProcessHoVerNet(Transform): distance_map_key: the key pointing to the distance map generated by `seg_postprocessing`. points_num: assumed that the created contour does not form a contour if it does not contain more points than the specified value. Defaults to 3. - level: optional. Used in `skimage.measure.find_contours`. Value along which to find contours in the array. + level: optional. Used in `skimage.measure.find_contours`. Value along which to find contours in the array. By default, the level is set to (max(image) + min(image)) / 2. return_binary: whether to return the binary segmentation prediction after `seg_postprocessing`. pred_binary_key: if `return_binary` is True, this `pred_binary_key` is used to get the binary prediciton From e91635bd1f4f83e1f12107f4ce78ac5c3df614d3 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 3 Nov 2022 22:28:57 +0800 Subject: [PATCH 06/15] minor fix Signed-off-by: KumoLiu --- docs/source/apps.rst | 2 ++ monai/apps/pathology/transforms/__init__.py | 1 + monai/apps/pathology/transforms/post/__init__.py | 1 + monai/apps/pathology/transforms/post/array.py | 3 ++- 4 files changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index b4cc200f08..0d9ffccee3 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -141,6 +141,8 @@ Applications :members: .. autoclass:: GenerateWatershedMarkers :members: +.. autoclass:: HoVerNetPostProcessing + :members: .. automodule:: monai.apps.pathology.transforms.post.dictionary .. autoclass:: Watershedd diff --git a/monai/apps/pathology/transforms/__init__.py b/monai/apps/pathology/transforms/__init__.py index 616cf3220a..06854483ff 100644 --- a/monai/apps/pathology/transforms/__init__.py +++ b/monai/apps/pathology/transforms/__init__.py @@ -15,6 +15,7 @@ GenerateWatershedMarkers, GenerateWatershedMask, Watershed, + HoVerNetPostProcessing, ) from .post.dictionary import ( GenerateDistanceMapD, diff --git a/monai/apps/pathology/transforms/post/__init__.py b/monai/apps/pathology/transforms/post/__init__.py index 46e1968367..ef2818d011 100644 --- a/monai/apps/pathology/transforms/post/__init__.py +++ b/monai/apps/pathology/transforms/post/__init__.py @@ -15,6 +15,7 @@ GenerateWatershedMarkers, GenerateWatershedMask, Watershed, + HoVerNetPostProcessing, ) from .dictionary import ( GenerateDistanceMapD, diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 67dea9d650..b78b7d655b 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -33,6 +33,7 @@ "GenerateInstanceBorder", "GenerateDistanceMap", "GenerateWatershedMarkers", + "HoVerNetPostProcessing", ] @@ -324,7 +325,7 @@ def __call__(self, mask: NdarrayOrTensor, instance_border: NdarrayOrTensor) -> N return convert_to_dst_type(marker, mask, dtype=self.dtype)[0] -class PostProcessHoVerNet(Transform): +class HoVerNetPostProcessing(Transform): """ Since HoVerNet do segmentation and classification meanwhile, this transform is used to combine postprocessing for segmentation and classification. It assumes input as a dictionary. From 2e05042b10864e55c5ea27d1173503e1730be7d1 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 7 Nov 2022 23:45:14 +0800 Subject: [PATCH 07/15] fix flake8 Signed-off-by: KumoLiu --- monai/apps/pathology/transforms/__init__.py | 2 +- monai/apps/pathology/transforms/post/__init__.py | 2 +- monai/apps/pathology/transforms/post/array.py | 9 ++++----- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/monai/apps/pathology/transforms/__init__.py b/monai/apps/pathology/transforms/__init__.py index b19544a7ef..44ff50e271 100644 --- a/monai/apps/pathology/transforms/__init__.py +++ b/monai/apps/pathology/transforms/__init__.py @@ -18,8 +18,8 @@ GenerateSuccinctContour, GenerateWatershedMarkers, GenerateWatershedMask, - Watershed, HoVerNetPostProcessing, + Watershed, ) from .post.dictionary import ( GenerateDistanceMapD, diff --git a/monai/apps/pathology/transforms/post/__init__.py b/monai/apps/pathology/transforms/post/__init__.py index befb66db94..bb50682c1a 100644 --- a/monai/apps/pathology/transforms/post/__init__.py +++ b/monai/apps/pathology/transforms/post/__init__.py @@ -18,8 +18,8 @@ GenerateSuccinctContour, GenerateWatershedMarkers, GenerateWatershedMask, - Watershed, HoVerNetPostProcessing, + Watershed, ) from .dictionary import ( GenerateDistanceMapD, diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 501fc91a02..54c41bffa8 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -9,20 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, List, Hashable, Optional, Mapping, Sequence, Tuple, Union +from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np from monai.config.type_definitions import DtypeLike, NdarrayOrTensor -from monai.transforms import Activations, AsDiscrete, BoundingRect -from monai.transforms.post.array import Activations, AsDiscrete, RemoveSmallObjects, SobelGradients +from monai.transforms import Activations, AsDiscrete, BoundingRect, RemoveSmallObjects, SobelGradients from monai.transforms.transform import Transform from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique from monai.utils import TransformBackends, convert_to_numpy, optional_import from monai.utils.enums import HoVerNetBranch -from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor from monai.utils.misc import ensure_tuple_rep -from monai.utils.type_conversion import convert_to_dst_type +from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor label, _ = optional_import("scipy.ndimage.measurements", name="label") disk, _ = optional_import("skimage.morphology", name="disk") @@ -650,6 +648,7 @@ class HoVerNetPostProcessing(Transform): output_classes: number of the nuclear type classes. inst_info_dict_keys: key use to record information for each instance. """ + def __init__( self, seg_postpprocessing: Callable, From d3d9941ef3c5e359a7ffb2016f5dfc871b0ca5e6 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 8 Nov 2022 10:31:19 +0800 Subject: [PATCH 08/15] fix flake8 Signed-off-by: KumoLiu --- monai/apps/pathology/transforms/post/array.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 54c41bffa8..6933bfd890 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -20,7 +20,7 @@ from monai.utils import TransformBackends, convert_to_numpy, optional_import from monai.utils.enums import HoVerNetBranch from monai.utils.misc import ensure_tuple_rep -from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor +from monai.utils.type_conversion import convert_to_dst_type label, _ = optional_import("scipy.ndimage.measurements", name="label") disk, _ = optional_import("skimage.morphology", name="disk") @@ -675,7 +675,7 @@ def __init__( self.generate_instance_type = GenerateInstanceType() def __call__(self, pred: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - device = pred[HoVerNetBranch.NP.value].device + pred = dict(pred) if HoVerNetBranch.NC.value in pred.keys(): type_pred = Activations(softmax=True)(pred[HoVerNetBranch.NC.value]) type_pred = AsDiscrete(argmax=True)(type_pred) @@ -706,18 +706,18 @@ def __call__(self, pred: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N if self.output_classes is not None: for inst_id in list(inst_info_dict.keys()): inst_type, type_prob = self.generate_instance_type( - bbox=inst_info_dict[inst_id]["bounding_box"], + bbox=inst_info_dict[inst_id]["bounding_box"], # type: ignore type_pred=type_pred, seg_pred=pred_inst, instance_id=inst_id, ) - inst_info_dict[inst_id]["type"] = inst_type - inst_info_dict[inst_id]["type_probability"] = type_prob + inst_info_dict[inst_id]["type"] = inst_type # type: ignore + inst_info_dict[inst_id]["type_probability"] = type_prob # type: ignore - pred_inst = convert_to_tensor(pred_inst, device=device) + pred_inst = convert_to_dst_type(pred_inst, pred[HoVerNetBranch.NP.value]) pred[HoVerNetBranch.NP.value] = pred_inst if self.return_binary: pred_inst[pred_inst > 0] = 1 pred[self.pred_binary_key] = pred_inst - pred[self.inst_info_dict_key] = inst_info_dict + pred[self.inst_info_dict_key] = inst_info_dict # type: ignore return pred From 348f8307f1a83275ff829b311ad1b79d4fc13161 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 9 Nov 2022 10:37:03 +0800 Subject: [PATCH 09/15] update basd on comments Signed-off-by: KumoLiu --- monai/apps/pathology/transforms/post/array.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 6933bfd890..333ae6f8d9 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -514,7 +514,7 @@ class GenerateInstanceContour(Transform): the pixels to which lines need to be drawn Args: - points_num: assumed that the created contour does not form a contour if it does not contain more points + min_num_points: assumed that the created contour does not form a contour if it does not contain more points than the specified value. Defaults to 3. level: optional. Value along which to find contours in the array. By default, the level is set to (max(image) + min(image)) / 2. @@ -523,9 +523,9 @@ class GenerateInstanceContour(Transform): backend = [TransformBackends.NUMPY] - def __init__(self, points_num: int = 3, level: Optional[float] = None) -> None: + def __init__(self, min_num_points: int = 3, level: Optional[float] = None) -> None: self.level = level - self.points_num = points_num + self.min_num_points = min_num_points def __call__(self, image: NdarrayOrTensor, offset: Optional[Sequence[int]] = (0, 0)) -> np.ndarray: """ @@ -539,10 +539,10 @@ def __call__(self, image: NdarrayOrTensor, offset: Optional[Sequence[int]] = (0, generate_contour = GenerateSuccinctContour(image.shape[0], image.shape[1]) inst_contour = generate_contour(inst_contour_cv) - # < `self.points_num` points don't make a contour, so skip, likely artifact too + # < `self.min_num_points` points don't make a contour, so skip, likely artifact too # as the contours obtained via approximation => too small or sthg - if inst_contour.shape[0] < self.points_num: - print(f"< {self.points_num} points don't make a contour, so skip") + if inst_contour.shape[0] < self.min_num_points: + print(f"< {self.min_num_points} points don't make a contour, so skip") return None # type: ignore # check for tricky shape elif len(inst_contour.shape) != 2: @@ -637,7 +637,7 @@ class HoVerNetPostProcessing(Transform): seg_postpprocessing: execute post-processing transformation for the segmentation output from model. Typically, several Tensor based transforms composed by `Compose`. distance_map_key: the key pointing to the distance map generated by `seg_postprocessing`. - points_num: assumed that the created contour does not form a contour if it does not contain more points + min_num_points: assumed that the created contour does not form a contour if it does not contain more points than the specified value. Defaults to 3. level: optional. Used in `skimage.measure.find_contours`. Value along which to find contours in the array. By default, the level is set to (max(image) + min(image)) / 2. @@ -653,7 +653,7 @@ def __init__( self, seg_postpprocessing: Callable, distance_map_key: str = "dist", - points_num: int = 3, + min_num_points: int = 3, level: Optional[float] = None, return_binary: Optional[bool] = True, pred_binary_key: Optional[str] = "pred_binary", @@ -670,7 +670,7 @@ def __init__( self.inst_info_dict_key = inst_info_dict_key self.seg_postpprocessing = seg_postpprocessing - self.generate_instance_contour = GenerateInstanceContour(points_num=points_num, level=level) + self.generate_instance_contour = GenerateInstanceContour(min_num_points=min_num_points, level=level) self.generate_instance_centroid = GenerateInstanceCentroid() self.generate_instance_type = GenerateInstanceType() From a32147481b623dd31fe620f118bc080757d77cbf Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 9 Nov 2022 15:06:48 +0800 Subject: [PATCH 10/15] add unit tests Signed-off-by: KumoLiu --- monai/apps/pathology/transforms/post/array.py | 4 +- tests/test_hovernet_post_processing.py | 87 +++++++++++++++++++ 2 files changed, 89 insertions(+), 2 deletions(-) create mode 100644 tests/test_hovernet_post_processing.py diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 333ae6f8d9..0e2ce8bcb5 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -646,7 +646,7 @@ class HoVerNetPostProcessing(Transform): from the output. return_centroids: whether to return centroids for each instance. output_classes: number of the nuclear type classes. - inst_info_dict_keys: key use to record information for each instance. + inst_info_dict_key: key use to record information for each instance. """ def __init__( @@ -714,7 +714,7 @@ def __call__(self, pred: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N inst_info_dict[inst_id]["type"] = inst_type # type: ignore inst_info_dict[inst_id]["type_probability"] = type_prob # type: ignore - pred_inst = convert_to_dst_type(pred_inst, pred[HoVerNetBranch.NP.value]) + pred_inst = convert_to_dst_type(pred_inst, pred[HoVerNetBranch.NP.value])[0] pred[HoVerNetBranch.NP.value] = pred_inst if self.return_binary: pred_inst[pred_inst > 0] = 1 diff --git a/tests/test_hovernet_post_processing.py b/tests/test_hovernet_post_processing.py new file mode 100644 index 0000000000..79137c78d0 --- /dev/null +++ b/tests/test_hovernet_post_processing.py @@ -0,0 +1,87 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.array import HoVerNetPostProcessing +from monai.apps.pathology.transforms.post.dictionary import ( + GenerateDistanceMapd, + GenerateInstanceBorderd, + GenerateWatershedMarkersd, + GenerateWatershedMaskd, + Watershedd, +) +from monai.transforms import Compose, ComputeHoVerMaps, FillHoles, GaussianSmooth +from monai.utils import min_version, optional_import +from monai.utils.enums import HoVerNetBranch +from tests.utils import TEST_NDARRAYS, assert_allclose + +_, has_scipy = optional_import("scipy", "1.8.1", min_version) +_, has_skimage = optional_import("skimage", "0.19.3", min_version) + +y, x = np.ogrid[0:30, 0:30] +image = (x - 10) ** 2 + (y - 10) ** 2 <= 5**2 +hovermap = ComputeHoVerMaps()(image[None].astype(int)) + +seg_postpprocessing = Compose( + [ + GenerateWatershedMaskd( + keys=HoVerNetBranch.NP.value, sigmoid=True, softmax=False, threshold=0.7, remove_small_objects=False + ), + GenerateInstanceBorderd( + keys="mask", hover_map_key=HoVerNetBranch.HV.value, kernel_size=3, remove_small_objects=False + ), + GenerateDistanceMapd(keys="mask", border_key="border", smooth_fn=GaussianSmooth()), + GenerateWatershedMarkersd( + keys="mask", border_key="border", threshold=0.9, radius=2, postprocess_fn=FillHoles() + ), + Watershedd(keys="dist", mask_key="mask", markers_key="markers"), + ] +) +TEST_CASE_1 = [ + seg_postpprocessing, + {"return_centroids": True, "return_centroids": True, "output_classes": 1}, + [image, [10, 10]], +] +TEST_CASE_2 = [ + seg_postpprocessing, + {"return_centroids": False, "return_centroids": False, "output_classes": None}, + [image], +] + + +TEST_CASE = [] +for p in TEST_NDARRAYS: + pred = { + HoVerNetBranch.NP.value: p(image[None].astype(float)), + HoVerNetBranch.HV.value: p(hovermap), + HoVerNetBranch.NC.value: p(image[None]), + } + TEST_CASE.append([pred] + TEST_CASE_1) + TEST_CASE.append([pred] + TEST_CASE_2) + + +@unittest.skipUnless(has_skimage, "Requires scikit-image library.") +class TestHoVerNetPostProcessing(unittest.TestCase): + @parameterized.expand(TEST_CASE) + def test_value(self, test_data, seg_postpprocessing, kwargs, expected): + post_transforms = HoVerNetPostProcessing(seg_postpprocessing=seg_postpprocessing, **kwargs) + out = post_transforms(test_data) + assert_allclose(out[HoVerNetBranch.NP.value].squeeze(), expected[0], type_test=False) + if out["inst_info_dict"] is not None: + assert_allclose(out["inst_info_dict"][1]["centroid"], expected[1], type_test=False) + + +if __name__ == "__main__": + unittest.main() From 399a05ccfe9c0755e95c46ec1392fcb6cb722ab9 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 9 Nov 2022 16:03:44 +0800 Subject: [PATCH 11/15] fix CI Signed-off-by: KumoLiu --- tests/test_hovernet_post_processing.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/test_hovernet_post_processing.py b/tests/test_hovernet_post_processing.py index 79137c78d0..f24881baff 100644 --- a/tests/test_hovernet_post_processing.py +++ b/tests/test_hovernet_post_processing.py @@ -32,7 +32,6 @@ y, x = np.ogrid[0:30, 0:30] image = (x - 10) ** 2 + (y - 10) ** 2 <= 5**2 -hovermap = ComputeHoVerMaps()(image[None].astype(int)) seg_postpprocessing = Compose( [ @@ -63,21 +62,24 @@ TEST_CASE = [] for p in TEST_NDARRAYS: - pred = { - HoVerNetBranch.NP.value: p(image[None].astype(float)), - HoVerNetBranch.HV.value: p(hovermap), - HoVerNetBranch.NC.value: p(image[None]), - } - TEST_CASE.append([pred] + TEST_CASE_1) - TEST_CASE.append([pred] + TEST_CASE_2) + TEST_CASE.append([p, image] + TEST_CASE_1) + TEST_CASE.append([p, image] + TEST_CASE_2) +@unittest.skipUnless(has_scipy, "Requires scipy library.") @unittest.skipUnless(has_skimage, "Requires scikit-image library.") class TestHoVerNetPostProcessing(unittest.TestCase): @parameterized.expand(TEST_CASE) - def test_value(self, test_data, seg_postpprocessing, kwargs, expected): + def test_value(self, in_type, test_data, seg_postpprocessing, kwargs, expected): + hovermap = ComputeHoVerMaps()(test_data[None].astype(int)) + input = { + HoVerNetBranch.NP.value: in_type(test_data[None].astype(float)), + HoVerNetBranch.HV.value: in_type(hovermap), + HoVerNetBranch.NC.value: in_type(test_data[None]), + } + post_transforms = HoVerNetPostProcessing(seg_postpprocessing=seg_postpprocessing, **kwargs) - out = post_transforms(test_data) + out = post_transforms(input) assert_allclose(out[HoVerNetBranch.NP.value].squeeze(), expected[0], type_test=False) if out["inst_info_dict"] is not None: assert_allclose(out["inst_info_dict"][1]["centroid"], expected[1], type_test=False) From 767f88a5e041b8f0048807fe6ce0b517a88c5ea1 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Wed, 9 Nov 2022 17:26:21 +0800 Subject: [PATCH 12/15] minor fix Signed-off-by: KumoLiu --- monai/apps/pathology/transforms/post/dictionary.py | 6 +++--- tests/test_generate_instance_contour.py | 4 ++-- tests/test_generate_instance_contourd.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/monai/apps/pathology/transforms/post/dictionary.py b/monai/apps/pathology/transforms/post/dictionary.py index c358eebf39..1103990f94 100644 --- a/monai/apps/pathology/transforms/post/dictionary.py +++ b/monai/apps/pathology/transforms/post/dictionary.py @@ -354,7 +354,7 @@ class GenerateInstanceContourd(MapTransform): contour_key_postfix: the output contour coordinates will be written to the value of `{key}_{contour_key_postfix}`. offset_key: keys of offset used in `GenerateInstanceContour`. - points_num: assumed that the created contour does not form a contour if it does not contain more points + min_num_points: assumed that the created contour does not form a contour if it does not contain more points than the specified value. Defaults to 3. level: optional. Value along which to find contours in the array. By default, the level is set to (max(image) + min(image)) / 2. @@ -369,12 +369,12 @@ def __init__( keys: KeysCollection, contour_key_postfix: str = "contour", offset_key: Optional[str] = None, - points_num: int = 3, + min_num_points: int = 3, level: Optional[float] = None, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.converter = GenerateInstanceContour(points_num=points_num, level=level) + self.converter = GenerateInstanceContour(min_num_points=min_num_points, level=level) self.contour_key_postfix = contour_key_postfix self.offset_key = offset_key diff --git a/tests/test_generate_instance_contour.py b/tests/test_generate_instance_contour.py index 22b778c06c..8c43bf5bc5 100644 --- a/tests/test_generate_instance_contour.py +++ b/tests/test_generate_instance_contour.py @@ -45,11 +45,11 @@ @unittest.skipUnless(has_skimage, "Requires scikit-image library.") class TestGenerateInstanceContour(unittest.TestCase): @parameterized.expand(TEST_CASE) - def test_shape(self, in_type, test_data, points_num, offset, expected): + def test_shape(self, in_type, test_data, min_num_points, offset, expected): inst_bbox = get_bbox(test_data[None]) inst_map = test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]] - result = GenerateInstanceContour(points_num=points_num)(in_type(inst_map[None]), offset=offset) + result = GenerateInstanceContour(min_num_points=min_num_points)(in_type(inst_map[None]), offset=offset) assert_allclose(result, expected, type_test=False) diff --git a/tests/test_generate_instance_contourd.py b/tests/test_generate_instance_contourd.py index 9c9c1efbe6..b02bfd461b 100644 --- a/tests/test_generate_instance_contourd.py +++ b/tests/test_generate_instance_contourd.py @@ -46,12 +46,12 @@ @unittest.skipUnless(has_skimage, "Requires scikit-image library.") class TestGenerateInstanceContourd(unittest.TestCase): @parameterized.expand(TEST_CASE) - def test_shape(self, in_type, test_data, points_num, offset, expected): + def test_shape(self, in_type, test_data, min_num_points, offset, expected): inst_bbox = get_bbox(test_data[None]) inst_map = test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]] test_data = {"image": in_type(inst_map[None]), "offset": offset} result = GenerateInstanceContourd( - keys="image", contour_key_postfix="contour", offset_key="offset", points_num=points_num + keys="image", contour_key_postfix="contour", offset_key="offset", min_num_points=min_num_points )(test_data) assert_allclose(result["image_contour"], expected, type_test=False) From fbb05460a001655d683f4a513770490fc873b5db Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 11 Nov 2022 16:53:16 +0800 Subject: [PATCH 13/15] update based on comments Signed-off-by: KumoLiu --- monai/apps/pathology/transforms/post/array.py | 45 ++--------- .../pathology/transforms/post/dictionary.py | 68 +++++++++++++++- tests/test_hovernet_post_processing.py | 23 ++---- tests/test_hovernet_post_processingd.py | 80 +++++++++++++++++++ 4 files changed, 162 insertions(+), 54 deletions(-) create mode 100644 tests/test_hovernet_post_processingd.py diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 0e2ce8bcb5..4b057a7a0f 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np @@ -18,7 +18,6 @@ from monai.transforms.transform import Transform from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique from monai.utils import TransformBackends, convert_to_numpy, optional_import -from monai.utils.enums import HoVerNetBranch from monai.utils.misc import ensure_tuple_rep from monai.utils.type_conversion import convert_to_dst_type @@ -634,61 +633,39 @@ class HoVerNetPostProcessing(Transform): for segmentation and classification. It assumes input as a dictionary. Args: - seg_postpprocessing: execute post-processing transformation for the segmentation output from model. - Typically, several Tensor based transforms composed by `Compose`. - distance_map_key: the key pointing to the distance map generated by `seg_postprocessing`. min_num_points: assumed that the created contour does not form a contour if it does not contain more points than the specified value. Defaults to 3. level: optional. Used in `skimage.measure.find_contours`. Value along which to find contours in the array. By default, the level is set to (max(image) + min(image)) / 2. - return_binary: whether to return the binary segmentation prediction after `seg_postprocessing`. - pred_binary_key: if `return_binary` is True, this `pred_binary_key` is used to get the binary prediciton - from the output. return_centroids: whether to return centroids for each instance. output_classes: number of the nuclear type classes. - inst_info_dict_key: key use to record information for each instance. """ def __init__( self, - seg_postpprocessing: Callable, - distance_map_key: str = "dist", min_num_points: int = 3, level: Optional[float] = None, - return_binary: Optional[bool] = True, - pred_binary_key: Optional[str] = "pred_binary", return_centroids: Optional[bool] = False, output_classes: Optional[int] = None, - inst_info_dict_key: Optional[str] = "inst_info_dict", ) -> None: super().__init__() - self.distance_map_key = distance_map_key - self.return_binary = return_binary - self.pred_binary_key = pred_binary_key self.return_centroids = return_centroids self.output_classes = output_classes - self.inst_info_dict_key = inst_info_dict_key - self.seg_postpprocessing = seg_postpprocessing self.generate_instance_contour = GenerateInstanceContour(min_num_points=min_num_points, level=level) self.generate_instance_centroid = GenerateInstanceCentroid() self.generate_instance_type = GenerateInstanceType() - def __call__(self, pred: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - pred = dict(pred) - if HoVerNetBranch.NC.value in pred.keys(): - type_pred = Activations(softmax=True)(pred[HoVerNetBranch.NC.value]) - type_pred = AsDiscrete(argmax=True)(type_pred) + def __call__(self, type_pred: NdarrayOrTensor, inst_pred: NdarrayOrTensor) -> Dict: # type: ignore + type_pred = Activations(softmax=True)(type_pred) + type_pred = AsDiscrete(argmax=True)(type_pred) - pred_inst_dict = self.seg_postpprocessing(pred) - pred_inst = pred_inst_dict[self.distance_map_key] - - inst_id_list = np.unique(pred_inst)[1:] # exclude background + inst_id_list = np.unique(inst_pred)[1:] # exclude background inst_info_dict = None if self.return_centroids: inst_info_dict = {} for inst_id in inst_id_list: - inst_map = pred_inst == inst_id + inst_map = inst_pred == inst_id inst_bbox = BoundingRect()(inst_map) inst_map = inst_map[:, inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]] offset = [inst_bbox[0][2], inst_bbox[0][0]] @@ -708,16 +685,10 @@ def __call__(self, pred: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N inst_type, type_prob = self.generate_instance_type( bbox=inst_info_dict[inst_id]["bounding_box"], # type: ignore type_pred=type_pred, - seg_pred=pred_inst, + seg_pred=inst_pred, instance_id=inst_id, ) inst_info_dict[inst_id]["type"] = inst_type # type: ignore inst_info_dict[inst_id]["type_probability"] = type_prob # type: ignore - pred_inst = convert_to_dst_type(pred_inst, pred[HoVerNetBranch.NP.value])[0] - pred[HoVerNetBranch.NP.value] = pred_inst - if self.return_binary: - pred_inst[pred_inst > 0] = 1 - pred[self.pred_binary_key] = pred_inst - pred[self.inst_info_dict_key] = inst_info_dict # type: ignore - return pred + return inst_info_dict diff --git a/monai/apps/pathology/transforms/post/dictionary.py b/monai/apps/pathology/transforms/post/dictionary.py index 1103990f94..6cced88ca9 100644 --- a/monai/apps/pathology/transforms/post/dictionary.py +++ b/monai/apps/pathology/transforms/post/dictionary.py @@ -22,11 +22,13 @@ GenerateSuccinctContour, GenerateWatershedMarkers, GenerateWatershedMask, + HoVerNetPostProcessing, Watershed, ) from monai.config.type_definitions import DtypeLike, KeysCollection, NdarrayOrTensor -from monai.transforms.transform import MapTransform -from monai.utils import optional_import +from monai.transforms.transform import MapTransform, Transform +from monai.utils import convert_to_dst_type, optional_import +from monai.utils.enums import HoVerNetBranch find_contours, _ = optional_import("skimage.measure", name="find_contours") moments, _ = optional_import("skimage.measure", name="moments") @@ -480,6 +482,68 @@ def __call__(self, data): return d +class HoVerNetPostProcessingd(Transform): + """ + Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.HoVerNetPostProcessing`. + Generate instance type and probability for each instance. + + Args: + type_pred_key: the key pointing to the pred type map to be transformed. + inst_pred_key: the key pointing to the pred distance map. + min_num_points: assumed that the created contour does not form a contour if it does not contain more points + than the specified value. Defaults to 3. + level: optional. Used in `skimage.measure.find_contours`. Value along which to find contours in the array. + By default, the level is set to (max(image) + min(image)) / 2. + return_binary: whether to return the binary segmentation prediction after `seg_postprocessing`. + pred_binary_key: if `return_binary` is True, this `pred_binary_key` is used to get the binary prediciton + from the output. + return_centroids: whether to return centroids for each instance. + output_classes: number of the nuclear type classes. + inst_info_dict_key: key use to record information for each instance. + allow_missing_keys: don't raise exception if key is missing. + + """ + + def __init__( + self, + type_pred_key: str = HoVerNetBranch.NC.value, + inst_pred_key: str = "dist", + min_num_points: int = 3, + level: Optional[float] = None, + return_binary: Optional[bool] = True, + pred_binary_key: Optional[str] = "pred_binary", + return_centroids: Optional[bool] = False, + output_classes: Optional[int] = None, + inst_info_dict_key: Optional[str] = "inst_info_dict", + ) -> None: + super().__init__() + self.converter = HoVerNetPostProcessing( + min_num_points=min_num_points, level=level, return_centroids=return_centroids, output_classes=output_classes + ) + self.type_pred_key = type_pred_key + self.inst_pred_key = inst_pred_key + self.pred_binary_key = pred_binary_key + self.inst_info_dict_key = inst_info_dict_key + self.return_binary = return_binary + + def __call__(self, data): + d = dict(data) + inst_pred = d[self.inst_pred_key] + type_pred = d[self.type_pred_key] + inst_info_dict = self.converter(type_pred, inst_pred) + key_to_add = f"{self.inst_info_dict_key}" + if key_to_add in d: + raise KeyError(f"Type information with key {key_to_add} already exists.") + d[key_to_add] = inst_info_dict # type: ignore + + inst_pred = convert_to_dst_type(inst_pred, type_pred)[0] + if self.return_binary: + inst_pred[inst_pred > 0] = 1 + d[self.pred_binary_key] = inst_pred + + return d + + WatershedD = WatershedDict = Watershedd GenerateWatershedMaskD = GenerateWatershedMaskDict = GenerateWatershedMaskd GenerateInstanceBorderD = GenerateInstanceBorderDict = GenerateInstanceBorderd diff --git a/tests/test_hovernet_post_processing.py b/tests/test_hovernet_post_processing.py index f24881baff..aa60929e53 100644 --- a/tests/test_hovernet_post_processing.py +++ b/tests/test_hovernet_post_processing.py @@ -48,16 +48,8 @@ Watershedd(keys="dist", mask_key="mask", markers_key="markers"), ] ) -TEST_CASE_1 = [ - seg_postpprocessing, - {"return_centroids": True, "return_centroids": True, "output_classes": 1}, - [image, [10, 10]], -] -TEST_CASE_2 = [ - seg_postpprocessing, - {"return_centroids": False, "return_centroids": False, "output_classes": None}, - [image], -] +TEST_CASE_1 = [seg_postpprocessing, {"return_centroids": True, "output_classes": 1}, [image, [10, 10]]] +TEST_CASE_2 = [seg_postpprocessing, {"return_centroids": False, "output_classes": None}, [image]] TEST_CASE = [] @@ -78,11 +70,12 @@ def test_value(self, in_type, test_data, seg_postpprocessing, kwargs, expected): HoVerNetBranch.NC.value: in_type(test_data[None]), } - post_transforms = HoVerNetPostProcessing(seg_postpprocessing=seg_postpprocessing, **kwargs) - out = post_transforms(input) - assert_allclose(out[HoVerNetBranch.NP.value].squeeze(), expected[0], type_test=False) - if out["inst_info_dict"] is not None: - assert_allclose(out["inst_info_dict"][1]["centroid"], expected[1], type_test=False) + pred = seg_postpprocessing(input) + + post_transforms = HoVerNetPostProcessing(**kwargs) + out = post_transforms(type_pred=in_type(test_data[None]), inst_pred=pred["dist"]) + if out is not None: + assert_allclose(out[1]["centroid"], expected[1], type_test=False) if __name__ == "__main__": diff --git a/tests/test_hovernet_post_processingd.py b/tests/test_hovernet_post_processingd.py new file mode 100644 index 0000000000..e534e5aaf8 --- /dev/null +++ b/tests/test_hovernet_post_processingd.py @@ -0,0 +1,80 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms.post.dictionary import ( + GenerateDistanceMapd, + GenerateInstanceBorderd, + GenerateWatershedMarkersd, + GenerateWatershedMaskd, + HoVerNetPostProcessingd, + Watershedd, +) +from monai.transforms import Compose, ComputeHoVerMaps, FillHoles, GaussianSmooth +from monai.utils import min_version, optional_import +from monai.utils.enums import HoVerNetBranch +from tests.utils import TEST_NDARRAYS, assert_allclose + +_, has_scipy = optional_import("scipy", "1.8.1", min_version) +_, has_skimage = optional_import("skimage", "0.19.3", min_version) + +y, x = np.ogrid[0:30, 0:30] +image = (x - 10) ** 2 + (y - 10) ** 2 <= 5**2 + +seg_postpprocessing = [ + GenerateWatershedMaskd( + keys=HoVerNetBranch.NP.value, sigmoid=True, softmax=False, threshold=0.7, remove_small_objects=False + ), + GenerateInstanceBorderd( + keys="mask", hover_map_key=HoVerNetBranch.HV.value, kernel_size=3, remove_small_objects=False + ), + GenerateDistanceMapd(keys="mask", border_key="border", smooth_fn=GaussianSmooth()), + GenerateWatershedMarkersd(keys="mask", border_key="border", threshold=0.9, radius=2, postprocess_fn=FillHoles()), + Watershedd(keys="dist", mask_key="mask", markers_key="markers"), +] + +TEST_CASE_1 = [seg_postpprocessing, {"return_centroids": True, "output_classes": 1}, [image, [10, 10]]] +TEST_CASE_2 = [seg_postpprocessing, {"return_centroids": False, "output_classes": None}, [image]] + + +TEST_CASE = [] +for p in TEST_NDARRAYS: + TEST_CASE.append([p, image] + TEST_CASE_1) + TEST_CASE.append([p, image] + TEST_CASE_2) + + +@unittest.skipUnless(has_scipy, "Requires scipy library.") +@unittest.skipUnless(has_skimage, "Requires scikit-image library.") +class TestHoVerNetPostProcessing(unittest.TestCase): + @parameterized.expand(TEST_CASE) + def test_value(self, in_type, test_data, seg_postpprocessing, kwargs, expected): + hovermap = ComputeHoVerMaps()(test_data[None].astype(int)) + input = { + HoVerNetBranch.NP.value: in_type(test_data[None].astype(float)), + HoVerNetBranch.HV.value: in_type(hovermap), + HoVerNetBranch.NC.value: in_type(test_data[None]), + } + + class_trans = [HoVerNetPostProcessingd(**kwargs)] + post_transforms = Compose(seg_postpprocessing + class_trans) + out = post_transforms(input) + + assert_allclose(out["dist"].squeeze(), expected[0], type_test=False) + if out["inst_info_dict"] is not None: + assert_allclose(out["inst_info_dict"][1]["centroid"], expected[1], type_test=False) + + +if __name__ == "__main__": + unittest.main() From 95c2dfc4ee934c4e6673d81cd6cfc3f184bb8c7b Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 11 Nov 2022 22:38:13 +0800 Subject: [PATCH 14/15] fix docstring Signed-off-by: KumoLiu --- docs/source/apps.rst | 4 +++- monai/apps/pathology/transforms/__init__.py | 5 ++++- monai/apps/pathology/transforms/post/__init__.py | 5 ++++- monai/apps/pathology/transforms/post/array.py | 8 ++++---- monai/apps/pathology/transforms/post/dictionary.py | 12 ++++++++---- ...py => test_hovernet_nc_branch_post_processing.py} | 6 +++--- ...y => test_hovernet_nc_branch_post_processingd.py} | 6 +++--- 7 files changed, 29 insertions(+), 17 deletions(-) rename tests/{test_hovernet_post_processing.py => test_hovernet_nc_branch_post_processing.py} (93%) rename tests/{test_hovernet_post_processingd.py => test_hovernet_nc_branch_post_processingd.py} (94%) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index dcfbd8e528..295a64e479 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -149,7 +149,7 @@ Applications :members: .. autoclass:: GenerateWatershedMarkers :members: -.. autoclass:: HoVerNetPostProcessing +.. autoclass:: HoVerNetNCBranchPostProcessing :members: .. automodule:: monai.apps.pathology.transforms.post.dictionary @@ -171,6 +171,8 @@ Applications :members: .. autoclass:: GenerateWatershedMarkersd :members: +.. autoclass:: HoVerNetNCBranchPostProcessingd + :members: `Detection` ----------- diff --git a/monai/apps/pathology/transforms/__init__.py b/monai/apps/pathology/transforms/__init__.py index 44ff50e271..182948a83c 100644 --- a/monai/apps/pathology/transforms/__init__.py +++ b/monai/apps/pathology/transforms/__init__.py @@ -18,7 +18,7 @@ GenerateSuccinctContour, GenerateWatershedMarkers, GenerateWatershedMask, - HoVerNetPostProcessing, + HoVerNetNCBranchPostProcessing, Watershed, ) from .post.dictionary import ( @@ -46,6 +46,9 @@ GenerateWatershedMaskD, GenerateWatershedMaskd, GenerateWatershedMaskDict, + HoVerNetNCBranchPostProcessingD, + HoVerNetNCBranchPostProcessingd, + HoVerNetNCBranchPostProcessingDict, WatershedD, Watershedd, WatershedDict, diff --git a/monai/apps/pathology/transforms/post/__init__.py b/monai/apps/pathology/transforms/post/__init__.py index bb50682c1a..f281cc49f4 100644 --- a/monai/apps/pathology/transforms/post/__init__.py +++ b/monai/apps/pathology/transforms/post/__init__.py @@ -18,7 +18,7 @@ GenerateSuccinctContour, GenerateWatershedMarkers, GenerateWatershedMask, - HoVerNetPostProcessing, + HoVerNetNCBranchPostProcessing, Watershed, ) from .dictionary import ( @@ -46,6 +46,9 @@ GenerateWatershedMaskD, GenerateWatershedMaskd, GenerateWatershedMaskDict, + HoVerNetNCBranchPostProcessingD, + HoVerNetNCBranchPostProcessingd, + HoVerNetNCBranchPostProcessingDict, WatershedD, Watershedd, WatershedDict, diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 4b057a7a0f..15412fdf49 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -38,7 +38,7 @@ "GenerateInstanceContour", "GenerateInstanceCentroid", "GenerateInstanceType", - "HoVerNetPostProcessing", + "HoVerNetNCBranchPostProcessing", ] @@ -627,10 +627,10 @@ def __call__( # type: ignore return (int(inst_type), float(type_prob)) -class HoVerNetPostProcessing(Transform): +class HoVerNetNCBranchPostProcessing(Transform): """ - Since HoVerNet do segmentation and classification meanwhile, this transform is used to combine postprocessing - for segmentation and classification. It assumes input as a dictionary. + The whole post-procesing transform for nuclear type classification branch. It return a dict which contains + centroid, bounding box, type prediciton for each instance. Args: min_num_points: assumed that the created contour does not form a contour if it does not contain more points diff --git a/monai/apps/pathology/transforms/post/dictionary.py b/monai/apps/pathology/transforms/post/dictionary.py index 6cced88ca9..a695194147 100644 --- a/monai/apps/pathology/transforms/post/dictionary.py +++ b/monai/apps/pathology/transforms/post/dictionary.py @@ -22,7 +22,7 @@ GenerateSuccinctContour, GenerateWatershedMarkers, GenerateWatershedMask, - HoVerNetPostProcessing, + HoVerNetNCBranchPostProcessing, Watershed, ) from monai.config.type_definitions import DtypeLike, KeysCollection, NdarrayOrTensor @@ -61,6 +61,9 @@ "GenerateInstanceTypeDict", "GenerateInstanceTypeD", "GenerateInstanceTyped", + "HoVerNetNCBranchPostProcessingDict", + "HoVerNetNCBranchPostProcessingD", + "HoVerNetNCBranchPostProcessingd", ] @@ -482,9 +485,9 @@ def __call__(self, data): return d -class HoVerNetPostProcessingd(Transform): +class HoVerNetNCBranchPostProcessingd(Transform): """ - Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.HoVerNetPostProcessing`. + Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.HoVerNetNCBranchPostProcessing`. Generate instance type and probability for each instance. Args: @@ -517,7 +520,7 @@ def __init__( inst_info_dict_key: Optional[str] = "inst_info_dict", ) -> None: super().__init__() - self.converter = HoVerNetPostProcessing( + self.converter = HoVerNetNCBranchPostProcessing( min_num_points=min_num_points, level=level, return_centroids=return_centroids, output_classes=output_classes ) self.type_pred_key = type_pred_key @@ -553,3 +556,4 @@ def __call__(self, data): GenerateInstanceContourDict = GenerateInstanceContourD = GenerateInstanceContourd GenerateInstanceCentroidDict = GenerateInstanceCentroidD = GenerateInstanceCentroidd GenerateInstanceTypeDict = GenerateInstanceTypeD = GenerateInstanceTyped +HoVerNetNCBranchPostProcessingDict = HoVerNetNCBranchPostProcessingD = HoVerNetNCBranchPostProcessingd diff --git a/tests/test_hovernet_post_processing.py b/tests/test_hovernet_nc_branch_post_processing.py similarity index 93% rename from tests/test_hovernet_post_processing.py rename to tests/test_hovernet_nc_branch_post_processing.py index aa60929e53..1b06656e89 100644 --- a/tests/test_hovernet_post_processing.py +++ b/tests/test_hovernet_nc_branch_post_processing.py @@ -14,7 +14,7 @@ import numpy as np from parameterized import parameterized -from monai.apps.pathology.transforms.post.array import HoVerNetPostProcessing +from monai.apps.pathology.transforms.post.array import HoVerNetNCBranchPostProcessing from monai.apps.pathology.transforms.post.dictionary import ( GenerateDistanceMapd, GenerateInstanceBorderd, @@ -60,7 +60,7 @@ @unittest.skipUnless(has_scipy, "Requires scipy library.") @unittest.skipUnless(has_skimage, "Requires scikit-image library.") -class TestHoVerNetPostProcessing(unittest.TestCase): +class TestHoVerNetNCBranchPostProcessing(unittest.TestCase): @parameterized.expand(TEST_CASE) def test_value(self, in_type, test_data, seg_postpprocessing, kwargs, expected): hovermap = ComputeHoVerMaps()(test_data[None].astype(int)) @@ -72,7 +72,7 @@ def test_value(self, in_type, test_data, seg_postpprocessing, kwargs, expected): pred = seg_postpprocessing(input) - post_transforms = HoVerNetPostProcessing(**kwargs) + post_transforms = HoVerNetNCBranchPostProcessing(**kwargs) out = post_transforms(type_pred=in_type(test_data[None]), inst_pred=pred["dist"]) if out is not None: assert_allclose(out[1]["centroid"], expected[1], type_test=False) diff --git a/tests/test_hovernet_post_processingd.py b/tests/test_hovernet_nc_branch_post_processingd.py similarity index 94% rename from tests/test_hovernet_post_processingd.py rename to tests/test_hovernet_nc_branch_post_processingd.py index e534e5aaf8..9f2b684886 100644 --- a/tests/test_hovernet_post_processingd.py +++ b/tests/test_hovernet_nc_branch_post_processingd.py @@ -19,7 +19,7 @@ GenerateInstanceBorderd, GenerateWatershedMarkersd, GenerateWatershedMaskd, - HoVerNetPostProcessingd, + HoVerNetNCBranchPostProcessingd, Watershedd, ) from monai.transforms import Compose, ComputeHoVerMaps, FillHoles, GaussianSmooth @@ -57,7 +57,7 @@ @unittest.skipUnless(has_scipy, "Requires scipy library.") @unittest.skipUnless(has_skimage, "Requires scikit-image library.") -class TestHoVerNetPostProcessing(unittest.TestCase): +class TestHoVerNetNCBranchPostProcessingd(unittest.TestCase): @parameterized.expand(TEST_CASE) def test_value(self, in_type, test_data, seg_postpprocessing, kwargs, expected): hovermap = ComputeHoVerMaps()(test_data[None].astype(int)) @@ -67,7 +67,7 @@ def test_value(self, in_type, test_data, seg_postpprocessing, kwargs, expected): HoVerNetBranch.NC.value: in_type(test_data[None]), } - class_trans = [HoVerNetPostProcessingd(**kwargs)] + class_trans = [HoVerNetNCBranchPostProcessingd(**kwargs)] post_transforms = Compose(seg_postpprocessing + class_trans) out = post_transforms(input) From 87a5ea6827e05b89204d466d7b88f3e94392f856 Mon Sep 17 00:00:00 2001 From: yunliu Date: Tue, 15 Nov 2022 06:06:07 +0000 Subject: [PATCH 15/15] update based on comments Signed-off-by: yunliu --- docs/source/apps.rst | 4 +-- monai/apps/pathology/transforms/__init__.py | 8 ++--- .../pathology/transforms/post/__init__.py | 8 ++--- monai/apps/pathology/transforms/post/array.py | 12 +++---- .../pathology/transforms/post/dictionary.py | 32 +++++++++---------- ..._hovernet_nuclear_type_post_processing.py} | 8 ++--- ...hovernet_nuclear_type_post_processingd.py} | 10 +++--- 7 files changed, 41 insertions(+), 41 deletions(-) rename tests/{test_hovernet_nc_branch_post_processing.py => test_hovernet_nuclear_type_post_processing.py} (94%) rename tests/{test_hovernet_nc_branch_post_processingd.py => test_hovernet_nuclear_type_post_processingd.py} (90%) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 295a64e479..cf44547aa7 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -149,7 +149,7 @@ Applications :members: .. autoclass:: GenerateWatershedMarkers :members: -.. autoclass:: HoVerNetNCBranchPostProcessing +.. autoclass:: HoVerNetNuclearTypePostProcessing :members: .. automodule:: monai.apps.pathology.transforms.post.dictionary @@ -171,7 +171,7 @@ Applications :members: .. autoclass:: GenerateWatershedMarkersd :members: -.. autoclass:: HoVerNetNCBranchPostProcessingd +.. autoclass:: HoVerNetNuclearTypePostProcessingd :members: `Detection` diff --git a/monai/apps/pathology/transforms/__init__.py b/monai/apps/pathology/transforms/__init__.py index 182948a83c..18b53d7f2a 100644 --- a/monai/apps/pathology/transforms/__init__.py +++ b/monai/apps/pathology/transforms/__init__.py @@ -18,7 +18,7 @@ GenerateSuccinctContour, GenerateWatershedMarkers, GenerateWatershedMask, - HoVerNetNCBranchPostProcessing, + HoVerNetNuclearTypePostProcessing, Watershed, ) from .post.dictionary import ( @@ -46,9 +46,9 @@ GenerateWatershedMaskD, GenerateWatershedMaskd, GenerateWatershedMaskDict, - HoVerNetNCBranchPostProcessingD, - HoVerNetNCBranchPostProcessingd, - HoVerNetNCBranchPostProcessingDict, + HoVerNetNuclearTypePostProcessingD, + HoVerNetNuclearTypePostProcessingd, + HoVerNetNuclearTypePostProcessingDict, WatershedD, Watershedd, WatershedDict, diff --git a/monai/apps/pathology/transforms/post/__init__.py b/monai/apps/pathology/transforms/post/__init__.py index f281cc49f4..836582b8a3 100644 --- a/monai/apps/pathology/transforms/post/__init__.py +++ b/monai/apps/pathology/transforms/post/__init__.py @@ -18,7 +18,7 @@ GenerateSuccinctContour, GenerateWatershedMarkers, GenerateWatershedMask, - HoVerNetNCBranchPostProcessing, + HoVerNetNuclearTypePostProcessing, Watershed, ) from .dictionary import ( @@ -46,9 +46,9 @@ GenerateWatershedMaskD, GenerateWatershedMaskd, GenerateWatershedMaskDict, - HoVerNetNCBranchPostProcessingD, - HoVerNetNCBranchPostProcessingd, - HoVerNetNCBranchPostProcessingDict, + HoVerNetNuclearTypePostProcessingD, + HoVerNetNuclearTypePostProcessingd, + HoVerNetNuclearTypePostProcessingDict, WatershedD, Watershedd, WatershedDict, diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 15412fdf49..619dd0028a 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -38,7 +38,7 @@ "GenerateInstanceContour", "GenerateInstanceCentroid", "GenerateInstanceType", - "HoVerNetNCBranchPostProcessing", + "HoVerNetNuclearTypePostProcessing", ] @@ -627,7 +627,7 @@ def __call__( # type: ignore return (int(inst_type), float(type_prob)) -class HoVerNetNCBranchPostProcessing(Transform): +class HoVerNetNuclearTypePostProcessing(Transform): """ The whole post-procesing transform for nuclear type classification branch. It return a dict which contains centroid, bounding box, type prediciton for each instance. @@ -656,16 +656,16 @@ def __init__( self.generate_instance_centroid = GenerateInstanceCentroid() self.generate_instance_type = GenerateInstanceType() - def __call__(self, type_pred: NdarrayOrTensor, inst_pred: NdarrayOrTensor) -> Dict: # type: ignore + def __call__(self, type_pred: NdarrayOrTensor, instance_pred: NdarrayOrTensor) -> Dict: # type: ignore type_pred = Activations(softmax=True)(type_pred) type_pred = AsDiscrete(argmax=True)(type_pred) - inst_id_list = np.unique(inst_pred)[1:] # exclude background + inst_id_list = np.unique(instance_pred)[1:] # exclude background inst_info_dict = None if self.return_centroids: inst_info_dict = {} for inst_id in inst_id_list: - inst_map = inst_pred == inst_id + inst_map = instance_pred == inst_id inst_bbox = BoundingRect()(inst_map) inst_map = inst_map[:, inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]] offset = [inst_bbox[0][2], inst_bbox[0][0]] @@ -685,7 +685,7 @@ def __call__(self, type_pred: NdarrayOrTensor, inst_pred: NdarrayOrTensor) -> Di inst_type, type_prob = self.generate_instance_type( bbox=inst_info_dict[inst_id]["bounding_box"], # type: ignore type_pred=type_pred, - seg_pred=inst_pred, + seg_pred=instance_pred, instance_id=inst_id, ) inst_info_dict[inst_id]["type"] = inst_type # type: ignore diff --git a/monai/apps/pathology/transforms/post/dictionary.py b/monai/apps/pathology/transforms/post/dictionary.py index a695194147..6c42c8ad41 100644 --- a/monai/apps/pathology/transforms/post/dictionary.py +++ b/monai/apps/pathology/transforms/post/dictionary.py @@ -22,7 +22,7 @@ GenerateSuccinctContour, GenerateWatershedMarkers, GenerateWatershedMask, - HoVerNetNCBranchPostProcessing, + HoVerNetNuclearTypePostProcessing, Watershed, ) from monai.config.type_definitions import DtypeLike, KeysCollection, NdarrayOrTensor @@ -61,9 +61,9 @@ "GenerateInstanceTypeDict", "GenerateInstanceTypeD", "GenerateInstanceTyped", - "HoVerNetNCBranchPostProcessingDict", - "HoVerNetNCBranchPostProcessingD", - "HoVerNetNCBranchPostProcessingd", + "HoVerNetNuclearTypePostProcessingDict", + "HoVerNetNuclearTypePostProcessingD", + "HoVerNetNuclearTypePostProcessingd", ] @@ -485,14 +485,14 @@ def __call__(self, data): return d -class HoVerNetNCBranchPostProcessingd(Transform): +class HoVerNetNuclearTypePostProcessingd(Transform): """ - Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.HoVerNetNCBranchPostProcessing`. + Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.HoVerNetNuclearTypePostProcessing`. Generate instance type and probability for each instance. Args: type_pred_key: the key pointing to the pred type map to be transformed. - inst_pred_key: the key pointing to the pred distance map. + instance_pred_key: the key pointing to the pred distance map. min_num_points: assumed that the created contour does not form a contour if it does not contain more points than the specified value. Defaults to 3. level: optional. Used in `skimage.measure.find_contours`. Value along which to find contours in the array. @@ -502,7 +502,7 @@ class HoVerNetNCBranchPostProcessingd(Transform): from the output. return_centroids: whether to return centroids for each instance. output_classes: number of the nuclear type classes. - inst_info_dict_key: key use to record information for each instance. + instance_info_dict_key: key use to record information for each instance. allow_missing_keys: don't raise exception if key is missing. """ @@ -510,31 +510,31 @@ class HoVerNetNCBranchPostProcessingd(Transform): def __init__( self, type_pred_key: str = HoVerNetBranch.NC.value, - inst_pred_key: str = "dist", + instance_pred_key: str = "dist", min_num_points: int = 3, level: Optional[float] = None, return_binary: Optional[bool] = True, pred_binary_key: Optional[str] = "pred_binary", return_centroids: Optional[bool] = False, output_classes: Optional[int] = None, - inst_info_dict_key: Optional[str] = "inst_info_dict", + instance_info_dict_key: Optional[str] = "instance_info_dict", ) -> None: super().__init__() - self.converter = HoVerNetNCBranchPostProcessing( + self.converter = HoVerNetNuclearTypePostProcessing( min_num_points=min_num_points, level=level, return_centroids=return_centroids, output_classes=output_classes ) self.type_pred_key = type_pred_key - self.inst_pred_key = inst_pred_key + self.instance_pred_key = instance_pred_key self.pred_binary_key = pred_binary_key - self.inst_info_dict_key = inst_info_dict_key + self.instance_info_dict_key = instance_info_dict_key self.return_binary = return_binary def __call__(self, data): d = dict(data) - inst_pred = d[self.inst_pred_key] + inst_pred = d[self.instance_pred_key] type_pred = d[self.type_pred_key] inst_info_dict = self.converter(type_pred, inst_pred) - key_to_add = f"{self.inst_info_dict_key}" + key_to_add = f"{self.instance_info_dict_key}" if key_to_add in d: raise KeyError(f"Type information with key {key_to_add} already exists.") d[key_to_add] = inst_info_dict # type: ignore @@ -556,4 +556,4 @@ def __call__(self, data): GenerateInstanceContourDict = GenerateInstanceContourD = GenerateInstanceContourd GenerateInstanceCentroidDict = GenerateInstanceCentroidD = GenerateInstanceCentroidd GenerateInstanceTypeDict = GenerateInstanceTypeD = GenerateInstanceTyped -HoVerNetNCBranchPostProcessingDict = HoVerNetNCBranchPostProcessingD = HoVerNetNCBranchPostProcessingd +HoVerNetNuclearTypePostProcessingDict = HoVerNetNuclearTypePostProcessingD = HoVerNetNuclearTypePostProcessingd diff --git a/tests/test_hovernet_nc_branch_post_processing.py b/tests/test_hovernet_nuclear_type_post_processing.py similarity index 94% rename from tests/test_hovernet_nc_branch_post_processing.py rename to tests/test_hovernet_nuclear_type_post_processing.py index 1b06656e89..954c2a34fb 100644 --- a/tests/test_hovernet_nc_branch_post_processing.py +++ b/tests/test_hovernet_nuclear_type_post_processing.py @@ -14,7 +14,7 @@ import numpy as np from parameterized import parameterized -from monai.apps.pathology.transforms.post.array import HoVerNetNCBranchPostProcessing +from monai.apps.pathology.transforms.post.array import HoVerNetNuclearTypePostProcessing from monai.apps.pathology.transforms.post.dictionary import ( GenerateDistanceMapd, GenerateInstanceBorderd, @@ -60,7 +60,7 @@ @unittest.skipUnless(has_scipy, "Requires scipy library.") @unittest.skipUnless(has_skimage, "Requires scikit-image library.") -class TestHoVerNetNCBranchPostProcessing(unittest.TestCase): +class TestHoVerNetNuclearTypePostProcessing(unittest.TestCase): @parameterized.expand(TEST_CASE) def test_value(self, in_type, test_data, seg_postpprocessing, kwargs, expected): hovermap = ComputeHoVerMaps()(test_data[None].astype(int)) @@ -72,8 +72,8 @@ def test_value(self, in_type, test_data, seg_postpprocessing, kwargs, expected): pred = seg_postpprocessing(input) - post_transforms = HoVerNetNCBranchPostProcessing(**kwargs) - out = post_transforms(type_pred=in_type(test_data[None]), inst_pred=pred["dist"]) + post_transforms = HoVerNetNuclearTypePostProcessing(**kwargs) + out = post_transforms(type_pred=in_type(test_data[None]), instance_pred=pred["dist"]) if out is not None: assert_allclose(out[1]["centroid"], expected[1], type_test=False) diff --git a/tests/test_hovernet_nc_branch_post_processingd.py b/tests/test_hovernet_nuclear_type_post_processingd.py similarity index 90% rename from tests/test_hovernet_nc_branch_post_processingd.py rename to tests/test_hovernet_nuclear_type_post_processingd.py index 9f2b684886..116a6bb9af 100644 --- a/tests/test_hovernet_nc_branch_post_processingd.py +++ b/tests/test_hovernet_nuclear_type_post_processingd.py @@ -19,7 +19,7 @@ GenerateInstanceBorderd, GenerateWatershedMarkersd, GenerateWatershedMaskd, - HoVerNetNCBranchPostProcessingd, + HoVerNetNuclearTypePostProcessingd, Watershedd, ) from monai.transforms import Compose, ComputeHoVerMaps, FillHoles, GaussianSmooth @@ -57,7 +57,7 @@ @unittest.skipUnless(has_scipy, "Requires scipy library.") @unittest.skipUnless(has_skimage, "Requires scikit-image library.") -class TestHoVerNetNCBranchPostProcessingd(unittest.TestCase): +class TestHoVerNetNuclearTypePostProcessingd(unittest.TestCase): @parameterized.expand(TEST_CASE) def test_value(self, in_type, test_data, seg_postpprocessing, kwargs, expected): hovermap = ComputeHoVerMaps()(test_data[None].astype(int)) @@ -67,13 +67,13 @@ def test_value(self, in_type, test_data, seg_postpprocessing, kwargs, expected): HoVerNetBranch.NC.value: in_type(test_data[None]), } - class_trans = [HoVerNetNCBranchPostProcessingd(**kwargs)] + class_trans = [HoVerNetNuclearTypePostProcessingd(**kwargs)] post_transforms = Compose(seg_postpprocessing + class_trans) out = post_transforms(input) assert_allclose(out["dist"].squeeze(), expected[0], type_test=False) - if out["inst_info_dict"] is not None: - assert_allclose(out["inst_info_dict"][1]["centroid"], expected[1], type_test=False) + if out["instance_info_dict"] is not None: + assert_allclose(out["instance_info_dict"][1]["centroid"], expected[1], type_test=False) if __name__ == "__main__":