From e6c6d025d9a00f81c0255e5c8ccd8538972ace02 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 11 Apr 2023 22:58:19 +0800 Subject: [PATCH 1/2] add `device` in hovernet postpocessing Signed-off-by: KumoLiu --- monai/apps/pathology/transforms/post/array.py | 17 +++++++++++++---- .../pathology/transforms/post/dictionary.py | 9 +++++++-- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 5289dc101c..88da2bad6f 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -31,7 +31,7 @@ from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique from monai.utils import TransformBackends, convert_to_numpy, optional_import from monai.utils.misc import ensure_tuple_rep -from monai.utils.type_conversion import convert_to_dst_type +from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor label, _ = optional_import("scipy.ndimage.measurements", name="label") disk, _ = optional_import("skimage.morphology", name="disk") @@ -671,6 +671,7 @@ class HoVerNetInstanceMapPostProcessing(Transform): min_num_points: minimum number of points to be considered as a contour. Defaults to 3. contour_level: an optional value for `skimage.measure.find_contours` to find contours in the array. If not provided, the level is set to `(max(image) + min(image)) / 2`. + device: target device to put the output Tensor data. """ def __init__( @@ -686,9 +687,10 @@ def __init__( watershed_connectivity: int | None = 1, min_num_points: int = 3, contour_level: float | None = None, + device: str | torch.device | None = None ) -> None: super().__init__() - + self.device = device self.generate_watershed_mask = GenerateWatershedMask( activation=activation, threshold=mask_threshold, min_object_size=min_object_size ) @@ -742,7 +744,7 @@ def __call__( # type: ignore "centroid": instance_centroid, "contour": instance_contour, } - + instance_map = convert_to_tensor(instance_map, device=self.device) return instance_info, instance_map @@ -758,13 +760,19 @@ class HoVerNetNuclearTypePostProcessing(Transform): threshold: an optional float value to threshold to binarize probability map. If not provided, defaults to 0.5 when activation is not "softmax", otherwise None. return_type_map: whether to calculate and return pixel-level type map. + device: target device to put the output Tensor data. """ def __init__( - self, activation: str | Callable = "softmax", threshold: float | None = None, return_type_map: bool = True + self, + activation: str | Callable = "softmax", + threshold: float | None = None, + return_type_map: bool = True, + device: str | torch.device | None = None ) -> None: super().__init__() + self.device = device self.return_type_map = return_type_map self.generate_instance_type = GenerateInstanceType() @@ -824,5 +832,6 @@ def __call__( # type: ignore # update instance type map if type_map is not None: type_map[instance_map == inst_id] = instance_type + type_map = convert_to_tensor(type_map, device=self.device) return instance_info, type_map diff --git a/monai/apps/pathology/transforms/post/dictionary.py b/monai/apps/pathology/transforms/post/dictionary.py index ef6de1b596..6dd9d52fd4 100644 --- a/monai/apps/pathology/transforms/post/dictionary.py +++ b/monai/apps/pathology/transforms/post/dictionary.py @@ -13,6 +13,7 @@ from collections.abc import Callable, Hashable, Mapping +import torch import numpy as np from monai.apps.pathology.transforms.post.array import ( @@ -488,6 +489,7 @@ class HoVerNetInstanceMapPostProcessingd(Transform): min_num_points: minimum number of points to be considered as a contour. Defaults to 3. contour_level: an optional value for `skimage.measure.find_contours` to find contours in the array. If not provided, the level is set to `(max(image) + min(image)) / 2`. + device: target device to put the output Tensor data. """ def __init__( @@ -507,6 +509,7 @@ def __init__( watershed_connectivity: int | None = 1, min_num_points: int = 3, contour_level: float | None = None, + device: str | torch.device | None = None, ) -> None: super().__init__() self.instance_map_post_process = HoVerNetInstanceMapPostProcessing( @@ -521,6 +524,7 @@ def __init__( watershed_connectivity=watershed_connectivity, min_num_points=min_num_points, contour_level=contour_level, + device=device, ) self.nuclear_prediction_key = nuclear_prediction_key self.hover_map_key = hover_map_key @@ -553,7 +557,7 @@ class HoVerNetNuclearTypePostProcessingd(Transform): Defaults to `"instance_info"`. instance_map_key: the key where instance map is stored. Defaults to `"instance_map"`. type_map_key: the output key where type map is written. Defaults to `"type_map"`. - + device: target device to put the output Tensor data. """ @@ -566,10 +570,11 @@ def __init__( activation: str | Callable = "softmax", threshold: float | None = None, return_type_map: bool = True, + device: str | torch.device | None = None, ) -> None: super().__init__() self.type_post_process = HoVerNetNuclearTypePostProcessing( - activation=activation, threshold=threshold, return_type_map=return_type_map + activation=activation, threshold=threshold, return_type_map=return_type_map, device=device ) self.type_prediction_key = type_prediction_key self.instance_info_key = instance_info_key From e0236780abe23990b4ac040d0cf8b1a41091ad9c Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 11 Apr 2023 16:40:59 +0100 Subject: [PATCH 2/2] style fixes Signed-off-by: Wenqi Li --- monai/apps/pathology/transforms/post/array.py | 4 ++-- monai/apps/pathology/transforms/post/dictionary.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 88da2bad6f..248ff24bec 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -687,7 +687,7 @@ def __init__( watershed_connectivity: int | None = 1, min_num_points: int = 3, contour_level: float | None = None, - device: str | torch.device | None = None + device: str | torch.device | None = None, ) -> None: super().__init__() self.device = device @@ -769,7 +769,7 @@ def __init__( activation: str | Callable = "softmax", threshold: float | None = None, return_type_map: bool = True, - device: str | torch.device | None = None + device: str | torch.device | None = None, ) -> None: super().__init__() self.device = device diff --git a/monai/apps/pathology/transforms/post/dictionary.py b/monai/apps/pathology/transforms/post/dictionary.py index 6dd9d52fd4..a95bdfd48f 100644 --- a/monai/apps/pathology/transforms/post/dictionary.py +++ b/monai/apps/pathology/transforms/post/dictionary.py @@ -13,8 +13,8 @@ from collections.abc import Callable, Hashable, Mapping -import torch import numpy as np +import torch from monai.apps.pathology.transforms.post.array import ( GenerateDistanceMap,