Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
5fc2478
first commit
KumoLiu Nov 3, 2022
5346b6e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2022
dc54220
updated based on comments
KumoLiu Nov 3, 2022
8c1820f
Merge branch 'postprocess-hovernet' of https://github.com/KumoLiu/MON…
KumoLiu Nov 3, 2022
c553867
add docstring
KumoLiu Nov 3, 2022
4ce5a53
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2022
e91635b
minor fix
KumoLiu Nov 3, 2022
7c26e39
Merge branch 'dev' into postprocess-hovernet
KumoLiu Nov 7, 2022
43ce043
Merge remote-tracking branch 'origin/dev' into postprocess-hovernet
KumoLiu Nov 7, 2022
b4d0053
Merge branch 'postprocess-hovernet' of https://github.com/KumoLiu/MON…
KumoLiu Nov 7, 2022
2e05042
fix flake8
KumoLiu Nov 7, 2022
d3d9941
fix flake8
KumoLiu Nov 8, 2022
cfafa4f
Merge remote-tracking branch 'origin/dev' into postprocess-hovernet
KumoLiu Nov 8, 2022
accd1ce
Merge branch 'dev' into postprocess-hovernet
bhashemian Nov 8, 2022
348f830
update basd on comments
KumoLiu Nov 9, 2022
ddeea74
Merge branch 'postprocess-hovernet' of https://github.com/KumoLiu/MON…
KumoLiu Nov 9, 2022
4260e85
Merge remote-tracking branch 'origin/dev' into postprocess-hovernet
KumoLiu Nov 9, 2022
a321474
add unit tests
KumoLiu Nov 9, 2022
399a05c
fix CI
KumoLiu Nov 9, 2022
767f88a
minor fix
KumoLiu Nov 9, 2022
a9cf3ec
Merge branch 'dev' into postprocess-hovernet
bhashemian Nov 9, 2022
2c64220
Merge branch 'dev' into postprocess-hovernet
bhashemian Nov 10, 2022
fbb0546
update based on comments
KumoLiu Nov 11, 2022
653a0e0
Merge remote-tracking branch 'origin/dev' into postprocess-hovernet
KumoLiu Nov 11, 2022
95c2dfc
fix docstring
KumoLiu Nov 11, 2022
cde52c4
Merge branch 'dev' into postprocess-hovernet
KumoLiu Nov 12, 2022
85aba18
Merge branch 'dev' into postprocess-hovernet
bhashemian Nov 14, 2022
e4e07ac
Merge remote-tracking branch 'origin/dev' into postprocess-hovernet
KumoLiu Nov 15, 2022
87a5ea6
update based on comments
KumoLiu Nov 15, 2022
9a58465
Merge branch 'dev' into postprocess-hovernet
bhashemian Nov 15, 2022
f675f9e
Merge branch 'dev' into postprocess-hovernet
Nic-Ma Nov 16, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/apps.rst
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ Applications
:members:
.. autoclass:: GenerateWatershedMarkers
:members:
.. autoclass:: HoVerNetNuclearTypePostProcessing
:members:

.. automodule:: monai.apps.pathology.transforms.post.dictionary
.. autoclass:: GenerateSuccinctContourd
Expand All @@ -169,6 +171,8 @@ Applications
:members:
.. autoclass:: GenerateWatershedMarkersd
:members:
.. autoclass:: HoVerNetNuclearTypePostProcessingd
:members:

`Detection`
-----------
Expand Down
4 changes: 4 additions & 0 deletions monai/apps/pathology/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
GenerateSuccinctContour,
GenerateWatershedMarkers,
GenerateWatershedMask,
HoVerNetNuclearTypePostProcessing,
Watershed,
)
from .post.dictionary import (
Expand Down Expand Up @@ -45,6 +46,9 @@
GenerateWatershedMaskD,
GenerateWatershedMaskd,
GenerateWatershedMaskDict,
HoVerNetNuclearTypePostProcessingD,
HoVerNetNuclearTypePostProcessingd,
HoVerNetNuclearTypePostProcessingDict,
WatershedD,
Watershedd,
WatershedDict,
Expand Down
4 changes: 4 additions & 0 deletions monai/apps/pathology/transforms/post/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
GenerateSuccinctContour,
GenerateWatershedMarkers,
GenerateWatershedMask,
HoVerNetNuclearTypePostProcessing,
Watershed,
)
from .dictionary import (
Expand Down Expand Up @@ -45,6 +46,9 @@
GenerateWatershedMaskD,
GenerateWatershedMaskd,
GenerateWatershedMaskDict,
HoVerNetNuclearTypePostProcessingD,
HoVerNetNuclearTypePostProcessingd,
HoVerNetNuclearTypePostProcessingDict,
WatershedD,
Watershedd,
WatershedDict,
Expand Down
84 changes: 76 additions & 8 deletions monai/apps/pathology/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, List, Optional, Sequence, Tuple, Union
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np

from monai.config.type_definitions import DtypeLike, NdarrayOrTensor
from monai.transforms.post.array import Activations, AsDiscrete, RemoveSmallObjects, SobelGradients
from monai.transforms import Activations, AsDiscrete, BoundingRect, RemoveSmallObjects, SobelGradients
from monai.transforms.transform import Transform
from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique
from monai.utils import TransformBackends, convert_to_numpy, optional_import
Expand All @@ -38,6 +38,7 @@
"GenerateInstanceContour",
"GenerateInstanceCentroid",
"GenerateInstanceType",
"HoVerNetNuclearTypePostProcessing",
]


Expand Down Expand Up @@ -512,7 +513,7 @@ class GenerateInstanceContour(Transform):
the pixels to which lines need to be drawn

Args:
points_num: assumed that the created contour does not form a contour if it does not contain more points
min_num_points: assumed that the created contour does not form a contour if it does not contain more points
than the specified value. Defaults to 3.
level: optional. Value along which to find contours in the array. By default, the level is set
to (max(image) + min(image)) / 2.
Expand All @@ -521,9 +522,9 @@ class GenerateInstanceContour(Transform):

backend = [TransformBackends.NUMPY]

def __init__(self, points_num: int = 3, level: Optional[float] = None) -> None:
def __init__(self, min_num_points: int = 3, level: Optional[float] = None) -> None:
self.level = level
self.points_num = points_num
self.min_num_points = min_num_points

def __call__(self, image: NdarrayOrTensor, offset: Optional[Sequence[int]] = (0, 0)) -> np.ndarray:
"""
Expand All @@ -537,10 +538,10 @@ def __call__(self, image: NdarrayOrTensor, offset: Optional[Sequence[int]] = (0,
generate_contour = GenerateSuccinctContour(image.shape[0], image.shape[1])
inst_contour = generate_contour(inst_contour_cv)

# < `self.points_num` points don't make a contour, so skip, likely artifact too
# < `self.min_num_points` points don't make a contour, so skip, likely artifact too
# as the contours obtained via approximation => too small or sthg
if inst_contour.shape[0] < self.points_num:
print(f"< {self.points_num} points don't make a contour, so skip")
if inst_contour.shape[0] < self.min_num_points:
print(f"< {self.min_num_points} points don't make a contour, so skip")
return None # type: ignore
# check for tricky shape
elif len(inst_contour.shape) != 2:
Expand Down Expand Up @@ -624,3 +625,70 @@ def __call__( # type: ignore
type_prob = type_dict[inst_type] / (sum(seg_map_crop) + 1.0e-6)

return (int(inst_type), float(type_prob))


class HoVerNetNuclearTypePostProcessing(Transform):
"""
The whole post-procesing transform for nuclear type classification branch. It return a dict which contains
centroid, bounding box, type prediciton for each instance.

Args:
min_num_points: assumed that the created contour does not form a contour if it does not contain more points
than the specified value. Defaults to 3.
level: optional. Used in `skimage.measure.find_contours`. Value along which to find contours in the array.
By default, the level is set to (max(image) + min(image)) / 2.
return_centroids: whether to return centroids for each instance.
output_classes: number of the nuclear type classes.
"""

def __init__(
self,
min_num_points: int = 3,
level: Optional[float] = None,
return_centroids: Optional[bool] = False,
output_classes: Optional[int] = None,
) -> None:
super().__init__()
self.return_centroids = return_centroids
self.output_classes = output_classes

self.generate_instance_contour = GenerateInstanceContour(min_num_points=min_num_points, level=level)
self.generate_instance_centroid = GenerateInstanceCentroid()
self.generate_instance_type = GenerateInstanceType()

def __call__(self, type_pred: NdarrayOrTensor, instance_pred: NdarrayOrTensor) -> Dict: # type: ignore
type_pred = Activations(softmax=True)(type_pred)
type_pred = AsDiscrete(argmax=True)(type_pred)

inst_id_list = np.unique(instance_pred)[1:] # exclude background
inst_info_dict = None
if self.return_centroids:
inst_info_dict = {}
for inst_id in inst_id_list:
inst_map = instance_pred == inst_id
inst_bbox = BoundingRect()(inst_map)
inst_map = inst_map[:, inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]]
offset = [inst_bbox[0][2], inst_bbox[0][0]]
inst_contour = self.generate_instance_contour(inst_map, offset)
inst_centroid = self.generate_instance_centroid(inst_map, offset)
if inst_contour is not None:
inst_info_dict[inst_id] = { # inst_id should start at 1
"bounding_box": inst_bbox,
"centroid": inst_centroid,
"contour": inst_contour,
"type_probability": None,
"type": None,
}

if self.output_classes is not None:
for inst_id in list(inst_info_dict.keys()):
inst_type, type_prob = self.generate_instance_type(
bbox=inst_info_dict[inst_id]["bounding_box"], # type: ignore
type_pred=type_pred,
seg_pred=instance_pred,
instance_id=inst_id,
)
inst_info_dict[inst_id]["type"] = inst_type # type: ignore
inst_info_dict[inst_id]["type_probability"] = type_prob # type: ignore

return inst_info_dict
78 changes: 73 additions & 5 deletions monai/apps/pathology/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
GenerateSuccinctContour,
GenerateWatershedMarkers,
GenerateWatershedMask,
HoVerNetNuclearTypePostProcessing,
Watershed,
)
from monai.config.type_definitions import DtypeLike, KeysCollection, NdarrayOrTensor
from monai.transforms.transform import MapTransform
from monai.utils import optional_import
from monai.transforms.transform import MapTransform, Transform
from monai.utils import convert_to_dst_type, optional_import
from monai.utils.enums import HoVerNetBranch

find_contours, _ = optional_import("skimage.measure", name="find_contours")
moments, _ = optional_import("skimage.measure", name="moments")
Expand Down Expand Up @@ -59,6 +61,9 @@
"GenerateInstanceTypeDict",
"GenerateInstanceTypeD",
"GenerateInstanceTyped",
"HoVerNetNuclearTypePostProcessingDict",
"HoVerNetNuclearTypePostProcessingD",
"HoVerNetNuclearTypePostProcessingd",
]


Expand Down Expand Up @@ -354,7 +359,7 @@ class GenerateInstanceContourd(MapTransform):
contour_key_postfix: the output contour coordinates will be written to the value of
`{key}_{contour_key_postfix}`.
offset_key: keys of offset used in `GenerateInstanceContour`.
points_num: assumed that the created contour does not form a contour if it does not contain more points
min_num_points: assumed that the created contour does not form a contour if it does not contain more points
than the specified value. Defaults to 3.
level: optional. Value along which to find contours in the array. By default, the level is set
to (max(image) + min(image)) / 2.
Expand All @@ -369,12 +374,12 @@ def __init__(
keys: KeysCollection,
contour_key_postfix: str = "contour",
offset_key: Optional[str] = None,
points_num: int = 3,
min_num_points: int = 3,
level: Optional[float] = None,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.converter = GenerateInstanceContour(points_num=points_num, level=level)
self.converter = GenerateInstanceContour(min_num_points=min_num_points, level=level)
self.contour_key_postfix = contour_key_postfix
self.offset_key = offset_key

Expand Down Expand Up @@ -480,6 +485,68 @@ def __call__(self, data):
return d


class HoVerNetNuclearTypePostProcessingd(Transform):
"""
Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.HoVerNetNuclearTypePostProcessing`.
Generate instance type and probability for each instance.

Args:
type_pred_key: the key pointing to the pred type map to be transformed.
instance_pred_key: the key pointing to the pred distance map.
min_num_points: assumed that the created contour does not form a contour if it does not contain more points
than the specified value. Defaults to 3.
level: optional. Used in `skimage.measure.find_contours`. Value along which to find contours in the array.
By default, the level is set to (max(image) + min(image)) / 2.
return_binary: whether to return the binary segmentation prediction after `seg_postprocessing`.
pred_binary_key: if `return_binary` is True, this `pred_binary_key` is used to get the binary prediciton
from the output.
return_centroids: whether to return centroids for each instance.
output_classes: number of the nuclear type classes.
instance_info_dict_key: key use to record information for each instance.
allow_missing_keys: don't raise exception if key is missing.

"""

def __init__(
self,
type_pred_key: str = HoVerNetBranch.NC.value,
instance_pred_key: str = "dist",
min_num_points: int = 3,
level: Optional[float] = None,
return_binary: Optional[bool] = True,
pred_binary_key: Optional[str] = "pred_binary",
return_centroids: Optional[bool] = False,
output_classes: Optional[int] = None,
instance_info_dict_key: Optional[str] = "instance_info_dict",
) -> None:
super().__init__()
self.converter = HoVerNetNuclearTypePostProcessing(
min_num_points=min_num_points, level=level, return_centroids=return_centroids, output_classes=output_classes
)
self.type_pred_key = type_pred_key
self.instance_pred_key = instance_pred_key
self.pred_binary_key = pred_binary_key
self.instance_info_dict_key = instance_info_dict_key
self.return_binary = return_binary

def __call__(self, data):
d = dict(data)
inst_pred = d[self.instance_pred_key]
type_pred = d[self.type_pred_key]
inst_info_dict = self.converter(type_pred, inst_pred)
key_to_add = f"{self.instance_info_dict_key}"
if key_to_add in d:
raise KeyError(f"Type information with key {key_to_add} already exists.")
d[key_to_add] = inst_info_dict # type: ignore

inst_pred = convert_to_dst_type(inst_pred, type_pred)[0]
if self.return_binary:
inst_pred[inst_pred > 0] = 1
d[self.pred_binary_key] = inst_pred

return d


WatershedD = WatershedDict = Watershedd
GenerateWatershedMaskD = GenerateWatershedMaskDict = GenerateWatershedMaskd
GenerateInstanceBorderD = GenerateInstanceBorderDict = GenerateInstanceBorderd
Expand All @@ -489,3 +556,4 @@ def __call__(self, data):
GenerateInstanceContourDict = GenerateInstanceContourD = GenerateInstanceContourd
GenerateInstanceCentroidDict = GenerateInstanceCentroidD = GenerateInstanceCentroidd
GenerateInstanceTypeDict = GenerateInstanceTypeD = GenerateInstanceTyped
HoVerNetNuclearTypePostProcessingDict = HoVerNetNuclearTypePostProcessingD = HoVerNetNuclearTypePostProcessingd
4 changes: 2 additions & 2 deletions tests/test_generate_instance_contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@
@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
class TestGenerateInstanceContour(unittest.TestCase):
@parameterized.expand(TEST_CASE)
def test_shape(self, in_type, test_data, points_num, offset, expected):
def test_shape(self, in_type, test_data, min_num_points, offset, expected):

inst_bbox = get_bbox(test_data[None])
inst_map = test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]]
result = GenerateInstanceContour(points_num=points_num)(in_type(inst_map[None]), offset=offset)
result = GenerateInstanceContour(min_num_points=min_num_points)(in_type(inst_map[None]), offset=offset)
assert_allclose(result, expected, type_test=False)


Expand Down
4 changes: 2 additions & 2 deletions tests/test_generate_instance_contourd.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@
@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
class TestGenerateInstanceContourd(unittest.TestCase):
@parameterized.expand(TEST_CASE)
def test_shape(self, in_type, test_data, points_num, offset, expected):
def test_shape(self, in_type, test_data, min_num_points, offset, expected):
inst_bbox = get_bbox(test_data[None])
inst_map = test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]]
test_data = {"image": in_type(inst_map[None]), "offset": offset}
result = GenerateInstanceContourd(
keys="image", contour_key_postfix="contour", offset_key="offset", points_num=points_num
keys="image", contour_key_postfix="contour", offset_key="offset", min_num_points=min_num_points
)(test_data)
assert_allclose(result["image_contour"], expected, type_test=False)

Expand Down
Loading