Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,11 @@ Post-processing
:members:
:special-members: __call__

`Prob NMS`
""""""""""
.. autoclass:: ProbNMS
:members:

`VoteEnsemble`
""""""""""""""
.. autoclass:: VoteEnsemble
Expand Down
5 changes: 0 additions & 5 deletions docs/source/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,6 @@ Misc
.. automodule:: monai.utils.misc
:members:

Prob NMS
--------
.. automodule:: monai.utils.prob_nms
.. autoclass:: ProbNMS
:members:

Profiling
---------
Expand Down
3 changes: 2 additions & 1 deletion monai/apps/pathology/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
import numpy as np
import torch

from monai.utils import ProbNMS, optional_import
from monai.transforms.post.array import ProbNMS
from monai.utils import optional_import

measure, _ = optional_import("skimage.measure")
ndimage, _ = optional_import("scipy.ndimage")
Expand Down
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@
KeepLargestConnectedComponent,
LabelToContour,
MeanEnsemble,
ProbNMS,
VoteEnsemble,
)
from .post.dictionary import (
Expand All @@ -182,6 +183,9 @@
MeanEnsembled,
MeanEnsembleD,
MeanEnsembleDict,
ProbNMSd,
ProbNMSD,
ProbNMSDict,
VoteEnsembled,
VoteEnsembleD,
VoteEnsembleDict,
Expand Down
95 changes: 95 additions & 0 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch.nn.functional as F

from monai.networks import one_hot
from monai.networks.layers import GaussianFilter
from monai.transforms.transform import Transform
from monai.transforms.utils import get_largest_connected_component_mask
from monai.utils import ensure_tuple
Expand Down Expand Up @@ -422,3 +423,97 @@ def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Te
return torch.argmax(img_, dim=1, keepdim=has_ch_dim)
# for One-Hot data, round the float number to 0 or 1
return torch.round(img_)


class ProbNMS(Transform):
"""
Performs probability based non-maximum suppression (NMS) on the probabilities map via
iteratively selecting the coordinate with highest probability and then move it as well
as its surrounding values. The remove range is determined by the parameter `box_size`.
If multiple coordinates have the same highest probability, only one of them will be
selected.

Args:
spatial_dims: number of spatial dimensions of the input probabilities map.
Defaults to 2.
sigma: the standard deviation for gaussian filter.
It could be a single value, or `spatial_dims` number of values. Defaults to 0.0.
prob_threshold: the probability threshold, the function will stop searching if
the highest probability is no larger than the threshold. The value should be
no less than 0.0. Defaults to 0.5.
box_size: the box size (in pixel) to be removed around the the pixel with the maximum probability.
It can be an integer that defines the size of a square or cube,
or a list containing different values for each dimensions. Defaults to 48.

Return:
a list of selected lists, where inner lists contain probability and coordinates.
For example, for 3D input, the inner lists are in the form of [probability, x, y, z].

Raises:
ValueError: When ``prob_threshold`` is less than 0.0.
ValueError: When ``box_size`` is a list or tuple, and its length is not equal to `spatial_dims`.
ValueError: When ``box_size`` has a less than 1 value.

"""

def __init__(
self,
spatial_dims: int = 2,
sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 0.0,
prob_threshold: float = 0.5,
box_size: Union[int, Sequence[int]] = 48,
) -> None:
self.sigma = sigma
self.spatial_dims = spatial_dims
if self.sigma != 0:
self.filter = GaussianFilter(spatial_dims=spatial_dims, sigma=sigma)
if prob_threshold < 0:
raise ValueError("prob_threshold should be no less than 0.0.")
self.prob_threshold = prob_threshold
if isinstance(box_size, int):
self.box_size = np.asarray([box_size] * spatial_dims)
else:
if len(box_size) != spatial_dims:
raise ValueError("the sequence length of box_size should be the same as spatial_dims.")
self.box_size = np.asarray(box_size)
if self.box_size.min() <= 0:
raise ValueError("box_size should be larger than 0.")

self.box_lower_bd = self.box_size // 2
self.box_upper_bd = self.box_size - self.box_lower_bd

def __call__(
self,
prob_map: Union[np.ndarray, torch.Tensor],
):
"""
prob_map: the input probabilities map, it must have shape (H[, W, ...]).
"""
if self.sigma != 0:
if not isinstance(prob_map, torch.Tensor):
prob_map = torch.as_tensor(prob_map, dtype=torch.float)
self.filter.to(prob_map)
prob_map = self.filter(prob_map)
else:
if not isinstance(prob_map, torch.Tensor):
prob_map = prob_map.copy()

if isinstance(prob_map, torch.Tensor):
prob_map = prob_map.detach().cpu().numpy()

prob_map_shape = prob_map.shape

outputs = []
while np.max(prob_map) > self.prob_threshold:
max_idx = np.unravel_index(prob_map.argmax(), prob_map_shape)
prob_max = prob_map[max_idx]
max_idx_arr = np.asarray(max_idx)
outputs.append([prob_max] + list(max_idx_arr))

idx_min_range = (max_idx_arr - self.box_lower_bd).clip(0, None)
idx_max_range = (max_idx_arr + self.box_upper_bd).clip(None, prob_map_shape)
# for each dimension, set values during index ranges to 0
slices = tuple(slice(idx_min_range[i], idx_max_range[i]) for i in range(self.spatial_dims))
prob_map[slices] = 0

return outputs
57 changes: 57 additions & 0 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
KeepLargestConnectedComponent,
LabelToContour,
MeanEnsemble,
ProbNMS,
VoteEnsemble,
)
from monai.transforms.transform import MapTransform
Expand Down Expand Up @@ -340,10 +341,66 @@ def __call__(self, data: dict) -> List[dict]:
return monai.data.decollate_batch(data, self.batch_size)


class ProbNMSd(MapTransform):
"""
Performs probability based non-maximum suppression (NMS) on the probabilities map via
iteratively selecting the coordinate with highest probability and then move it as well
as its surrounding values. The remove range is determined by the parameter `box_size`.
If multiple coordinates have the same highest probability, only one of them will be
selected.

Args:
spatial_dims: number of spatial dimensions of the input probabilities map.
Defaults to 2.
sigma: the standard deviation for gaussian filter.
It could be a single value, or `spatial_dims` number of values. Defaults to 0.0.
prob_threshold: the probability threshold, the function will stop searching if
the highest probability is no larger than the threshold. The value should be
no less than 0.0. Defaults to 0.5.
box_size: the box size (in pixel) to be removed around the the pixel with the maximum probability.
It can be an integer that defines the size of a square or cube,
or a list containing different values for each dimensions. Defaults to 48.

Return:
a list of selected lists, where inner lists contain probability and coordinates.
For example, for 3D input, the inner lists are in the form of [probability, x, y, z].

Raises:
ValueError: When ``prob_threshold`` is less than 0.0.
ValueError: When ``box_size`` is a list or tuple, and its length is not equal to `spatial_dims`.
ValueError: When ``box_size`` has a less than 1 value.

"""

def __init__(
self,
keys: KeysCollection,
spatial_dims: int = 2,
sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 0.0,
prob_threshold: float = 0.5,
box_size: Union[int, Sequence[int]] = 48,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.prob_nms = ProbNMS(
spatial_dims=spatial_dims,
sigma=sigma,
prob_threshold=prob_threshold,
box_size=box_size,
)

def __call__(self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]):
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.prob_nms(d[key])
return d


ActivationsD = ActivationsDict = Activationsd
AsDiscreteD = AsDiscreteDict = AsDiscreted
KeepLargestConnectedComponentD = KeepLargestConnectedComponentDict = KeepLargestConnectedComponentd
LabelToContourD = LabelToContourDict = LabelToContourd
MeanEnsembleD = MeanEnsembleDict = MeanEnsembled
ProbNMSD = ProbNMSDict = ProbNMSd
VoteEnsembleD = VoteEnsembleDict = VoteEnsembled
DecollateD = DecollateDict = Decollated
1 change: 0 additions & 1 deletion monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,5 @@
min_version,
optional_import,
)
from .prob_nms import ProbNMS
from .profiling import PerfContext, torch_profiler_full, torch_profiler_time_cpu_gpu, torch_profiler_time_end_to_end
from .state_cacher import StateCacher
100 changes: 0 additions & 100 deletions monai/utils/prob_nms.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/test_prob_nms.py → tests/test_probnms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch
from parameterized import parameterized

from monai.utils import ProbNMS
from monai.transforms.post.array import ProbNMS

probs_map_1 = np.random.rand(100, 100).clip(0, 0.5)
TEST_CASES_2D_1 = [{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, probs_map_1, []]
Expand Down
Loading