diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index a726b25435..4f039b9c35 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -305,6 +305,11 @@ Post-processing :members: :special-members: __call__ +`Prob NMS` +"""""""""" +.. autoclass:: ProbNMS + :members: + `VoteEnsemble` """""""""""""" .. autoclass:: VoteEnsemble diff --git a/docs/source/utils.rst b/docs/source/utils.rst index 071d9ecefd..855954fd29 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -27,11 +27,6 @@ Misc .. automodule:: monai.utils.misc :members: -Prob NMS --------- -.. automodule:: monai.utils.prob_nms -.. autoclass:: ProbNMS - :members: Profiling --------- diff --git a/monai/apps/pathology/utils.py b/monai/apps/pathology/utils.py index ae77bfafd1..0d1f530bff 100644 --- a/monai/apps/pathology/utils.py +++ b/monai/apps/pathology/utils.py @@ -14,7 +14,8 @@ import numpy as np import torch -from monai.utils import ProbNMS, optional_import +from monai.transforms.post.array import ProbNMS +from monai.utils import optional_import measure, _ = optional_import("skimage.measure") ndimage, _ = optional_import("scipy.ndimage") diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index b8cc832db1..b66567e71a 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -160,6 +160,7 @@ KeepLargestConnectedComponent, LabelToContour, MeanEnsemble, + ProbNMS, VoteEnsemble, ) from .post.dictionary import ( @@ -182,6 +183,9 @@ MeanEnsembled, MeanEnsembleD, MeanEnsembleDict, + ProbNMSd, + ProbNMSD, + ProbNMSDict, VoteEnsembled, VoteEnsembleD, VoteEnsembleDict, diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 6462753cf9..7ac0e6799c 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -21,6 +21,7 @@ import torch.nn.functional as F from monai.networks import one_hot +from monai.networks.layers import GaussianFilter from monai.transforms.transform import Transform from monai.transforms.utils import get_largest_connected_component_mask from monai.utils import ensure_tuple @@ -422,3 +423,97 @@ def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Te return torch.argmax(img_, dim=1, keepdim=has_ch_dim) # for One-Hot data, round the float number to 0 or 1 return torch.round(img_) + + +class ProbNMS(Transform): + """ + Performs probability based non-maximum suppression (NMS) on the probabilities map via + iteratively selecting the coordinate with highest probability and then move it as well + as its surrounding values. The remove range is determined by the parameter `box_size`. + If multiple coordinates have the same highest probability, only one of them will be + selected. + + Args: + spatial_dims: number of spatial dimensions of the input probabilities map. + Defaults to 2. + sigma: the standard deviation for gaussian filter. + It could be a single value, or `spatial_dims` number of values. Defaults to 0.0. + prob_threshold: the probability threshold, the function will stop searching if + the highest probability is no larger than the threshold. The value should be + no less than 0.0. Defaults to 0.5. + box_size: the box size (in pixel) to be removed around the the pixel with the maximum probability. + It can be an integer that defines the size of a square or cube, + or a list containing different values for each dimensions. Defaults to 48. + + Return: + a list of selected lists, where inner lists contain probability and coordinates. + For example, for 3D input, the inner lists are in the form of [probability, x, y, z]. + + Raises: + ValueError: When ``prob_threshold`` is less than 0.0. + ValueError: When ``box_size`` is a list or tuple, and its length is not equal to `spatial_dims`. + ValueError: When ``box_size`` has a less than 1 value. + + """ + + def __init__( + self, + spatial_dims: int = 2, + sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 0.0, + prob_threshold: float = 0.5, + box_size: Union[int, Sequence[int]] = 48, + ) -> None: + self.sigma = sigma + self.spatial_dims = spatial_dims + if self.sigma != 0: + self.filter = GaussianFilter(spatial_dims=spatial_dims, sigma=sigma) + if prob_threshold < 0: + raise ValueError("prob_threshold should be no less than 0.0.") + self.prob_threshold = prob_threshold + if isinstance(box_size, int): + self.box_size = np.asarray([box_size] * spatial_dims) + else: + if len(box_size) != spatial_dims: + raise ValueError("the sequence length of box_size should be the same as spatial_dims.") + self.box_size = np.asarray(box_size) + if self.box_size.min() <= 0: + raise ValueError("box_size should be larger than 0.") + + self.box_lower_bd = self.box_size // 2 + self.box_upper_bd = self.box_size - self.box_lower_bd + + def __call__( + self, + prob_map: Union[np.ndarray, torch.Tensor], + ): + """ + prob_map: the input probabilities map, it must have shape (H[, W, ...]). + """ + if self.sigma != 0: + if not isinstance(prob_map, torch.Tensor): + prob_map = torch.as_tensor(prob_map, dtype=torch.float) + self.filter.to(prob_map) + prob_map = self.filter(prob_map) + else: + if not isinstance(prob_map, torch.Tensor): + prob_map = prob_map.copy() + + if isinstance(prob_map, torch.Tensor): + prob_map = prob_map.detach().cpu().numpy() + + prob_map_shape = prob_map.shape + + outputs = [] + while np.max(prob_map) > self.prob_threshold: + max_idx = np.unravel_index(prob_map.argmax(), prob_map_shape) + prob_max = prob_map[max_idx] + max_idx_arr = np.asarray(max_idx) + outputs.append([prob_max] + list(max_idx_arr)) + + idx_min_range = (max_idx_arr - self.box_lower_bd).clip(0, None) + idx_max_range = (max_idx_arr + self.box_upper_bd).clip(None, prob_map_shape) + # for each dimension, set values during index ranges to 0 + slices = tuple(slice(idx_min_range[i], idx_max_range[i]) for i in range(self.spatial_dims)) + prob_map[slices] = 0 + + return outputs diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 6d28f780d4..52bde4ab79 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -28,6 +28,7 @@ KeepLargestConnectedComponent, LabelToContour, MeanEnsemble, + ProbNMS, VoteEnsemble, ) from monai.transforms.transform import MapTransform @@ -340,10 +341,66 @@ def __call__(self, data: dict) -> List[dict]: return monai.data.decollate_batch(data, self.batch_size) +class ProbNMSd(MapTransform): + """ + Performs probability based non-maximum suppression (NMS) on the probabilities map via + iteratively selecting the coordinate with highest probability and then move it as well + as its surrounding values. The remove range is determined by the parameter `box_size`. + If multiple coordinates have the same highest probability, only one of them will be + selected. + + Args: + spatial_dims: number of spatial dimensions of the input probabilities map. + Defaults to 2. + sigma: the standard deviation for gaussian filter. + It could be a single value, or `spatial_dims` number of values. Defaults to 0.0. + prob_threshold: the probability threshold, the function will stop searching if + the highest probability is no larger than the threshold. The value should be + no less than 0.0. Defaults to 0.5. + box_size: the box size (in pixel) to be removed around the the pixel with the maximum probability. + It can be an integer that defines the size of a square or cube, + or a list containing different values for each dimensions. Defaults to 48. + + Return: + a list of selected lists, where inner lists contain probability and coordinates. + For example, for 3D input, the inner lists are in the form of [probability, x, y, z]. + + Raises: + ValueError: When ``prob_threshold`` is less than 0.0. + ValueError: When ``box_size`` is a list or tuple, and its length is not equal to `spatial_dims`. + ValueError: When ``box_size`` has a less than 1 value. + + """ + + def __init__( + self, + keys: KeysCollection, + spatial_dims: int = 2, + sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 0.0, + prob_threshold: float = 0.5, + box_size: Union[int, Sequence[int]] = 48, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.prob_nms = ProbNMS( + spatial_dims=spatial_dims, + sigma=sigma, + prob_threshold=prob_threshold, + box_size=box_size, + ) + + def __call__(self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]): + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.prob_nms(d[key]) + return d + + ActivationsD = ActivationsDict = Activationsd AsDiscreteD = AsDiscreteDict = AsDiscreted KeepLargestConnectedComponentD = KeepLargestConnectedComponentDict = KeepLargestConnectedComponentd LabelToContourD = LabelToContourDict = LabelToContourd MeanEnsembleD = MeanEnsembleDict = MeanEnsembled +ProbNMSD = ProbNMSDict = ProbNMSd VoteEnsembleD = VoteEnsembleDict = VoteEnsembled DecollateD = DecollateDict = Decollated diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index f6a137f47d..d622ce96ae 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -69,6 +69,5 @@ min_version, optional_import, ) -from .prob_nms import ProbNMS from .profiling import PerfContext, torch_profiler_full, torch_profiler_time_cpu_gpu, torch_profiler_time_end_to_end from .state_cacher import StateCacher diff --git a/monai/utils/prob_nms.py b/monai/utils/prob_nms.py deleted file mode 100644 index c25223d524..0000000000 --- a/monai/utils/prob_nms.py +++ /dev/null @@ -1,100 +0,0 @@ -from typing import List, Sequence, Tuple, Union - -import numpy as np -import torch - -from monai.networks.layers import GaussianFilter - - -class ProbNMS: - """ - Performs probability based non-maximum suppression (NMS) on the probabilities map via - iteratively selecting the coordinate with highest probability and then move it as well - as its surrounding values. The remove range is determined by the parameter `box_size`. - If multiple coordinates have the same highest probability, only one of them will be - selected. - - Args: - spatial_dims: number of spatial dimensions of the input probabilities map. - Defaults to 2. - sigma: the standard deviation for gaussian filter. - It could be a single value, or `spatial_dims` number of values. Defaults to 0.0. - prob_threshold: the probability threshold, the function will stop searching if - the highest probability is no larger than the threshold. The value should be - no less than 0.0. Defaults to 0.5. - box_size: the box size (in pixel) to be removed around the the pixel with the maximum probability. - It can be an integer that defines the size of a square or cube, - or a list containing different values for each dimensions. Defaults to 48. - - Return: - a list of selected lists, where inner lists contain probability and coordinates. - For example, for 3D input, the inner lists are in the form of [probability, x, y, z]. - - Raises: - ValueError: When ``prob_threshold`` is less than 0.0. - ValueError: When ``box_size`` is a list or tuple, and its length is not equal to `spatial_dims`. - ValueError: When ``box_size`` has a less than 1 value. - - """ - - def __init__( - self, - spatial_dims: int = 2, - sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 0.0, - prob_threshold: float = 0.5, - box_size: Union[int, List[int], Tuple[int]] = 48, - ) -> None: - self.sigma = sigma - self.spatial_dims = spatial_dims - if self.sigma != 0: - self.filter = GaussianFilter(spatial_dims=spatial_dims, sigma=sigma) - if prob_threshold < 0: - raise ValueError("prob_threshold should be no less than 0.0.") - self.prob_threshold = prob_threshold - if isinstance(box_size, int): - self.box_size = np.asarray([box_size] * spatial_dims) - else: - if len(box_size) != spatial_dims: - raise ValueError("the sequence length of box_size should be the same as spatial_dims.") - self.box_size = np.asarray(box_size) - if self.box_size.min() <= 0: - raise ValueError("box_size should be larger than 0.") - - self.box_lower_bd = self.box_size // 2 - self.box_upper_bd = self.box_size - self.box_lower_bd - - def __call__( - self, - prob_map: Union[np.ndarray, torch.Tensor], - ): - """ - prob_map: the input probabilities map, it must have shape (H[, W, ...]). - """ - if self.sigma != 0: - if not isinstance(prob_map, torch.Tensor): - prob_map = torch.as_tensor(prob_map, dtype=torch.float) - self.filter.to(prob_map) - prob_map = self.filter(prob_map) - else: - if not isinstance(prob_map, torch.Tensor): - prob_map = prob_map.copy() - - if isinstance(prob_map, torch.Tensor): - prob_map = prob_map.detach().cpu().numpy() - - prob_map_shape = prob_map.shape - - outputs = [] - while np.max(prob_map) > self.prob_threshold: - max_idx = np.unravel_index(prob_map.argmax(), prob_map_shape) - prob_max = prob_map[max_idx] - max_idx_arr = np.asarray(max_idx) - outputs.append([prob_max] + list(max_idx_arr)) - - idx_min_range = (max_idx_arr - self.box_lower_bd).clip(0, None) - idx_max_range = (max_idx_arr + self.box_upper_bd).clip(None, prob_map_shape) - # for each dimension, set values during index ranges to 0 - slices = tuple(slice(idx_min_range[i], idx_max_range[i]) for i in range(self.spatial_dims)) - prob_map[slices] = 0 - - return outputs diff --git a/tests/test_prob_nms.py b/tests/test_probnms.py similarity index 98% rename from tests/test_prob_nms.py rename to tests/test_probnms.py index fb88d9cfb4..e51d1017d8 100644 --- a/tests/test_prob_nms.py +++ b/tests/test_probnms.py @@ -15,7 +15,7 @@ import torch from parameterized import parameterized -from monai.utils import ProbNMS +from monai.transforms.post.array import ProbNMS probs_map_1 = np.random.rand(100, 100).clip(0, 0.5) TEST_CASES_2D_1 = [{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, probs_map_1, []] diff --git a/tests/test_probnmsd.py b/tests/test_probnmsd.py new file mode 100644 index 0000000000..5b75d4310f --- /dev/null +++ b/tests/test_probnmsd.py @@ -0,0 +1,103 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms.post.dictionary import ProbNMSD + +probs_map_1 = np.random.rand(100, 100).clip(0, 0.5) +TEST_CASES_2D_1 = [{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, {"prob_map": probs_map_1}, []] + +probs_map_2 = np.random.rand(100, 100).clip(0, 0.5) +probs_map_2[33, 33] = 0.7 +probs_map_2[66, 66] = 0.9 +expected_2 = [[0.9, 66, 66], [0.7, 33, 33]] +TEST_CASES_2D_2 = [ + {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": [10, 10]}, + {"prob_map": probs_map_2}, + expected_2, +] + +probs_map_3 = np.random.rand(100, 100).clip(0, 0.5) +probs_map_3[56, 58] = 0.7 +probs_map_3[60, 66] = 0.8 +probs_map_3[66, 66] = 0.9 +expected_3 = [[0.9, 66, 66], [0.8, 60, 66]] +TEST_CASES_2D_3 = [ + {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": (10, 20)}, + {"prob_map": probs_map_3}, + expected_3, +] + +probs_map_4 = np.random.rand(100, 100).clip(0, 0.5) +probs_map_4[33, 33] = 0.7 +probs_map_4[66, 66] = 0.9 +expected_4 = [[0.9, 66, 66]] +TEST_CASES_2D_4 = [ + {"spatial_dims": 2, "prob_threshold": 0.8, "box_size": 10}, + {"prob_map": probs_map_4}, + expected_4, +] + +probs_map_5 = np.random.rand(100, 100).clip(0, 0.5) +TEST_CASES_2D_5 = [{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, {"prob_map": probs_map_5}, []] + +probs_map_6 = torch.as_tensor(np.random.rand(100, 100).clip(0, 0.5)) +TEST_CASES_2D_6 = [{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, {"prob_map": probs_map_6}, []] + +probs_map_7 = torch.as_tensor(np.random.rand(100, 100).clip(0, 0.5)) +probs_map_7[33, 33] = 0.7 +probs_map_7[66, 66] = 0.9 +if torch.cuda.is_available(): + probs_map_7 = probs_map_7.cuda() +expected_7 = [[0.9, 66, 66], [0.7, 33, 33]] +TEST_CASES_2D_7 = [ + {"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, + {"prob_map": probs_map_7}, + expected_7, +] + +probs_map_3d = torch.rand([50, 50, 50]).uniform_(0, 0.5) +probs_map_3d[25, 25, 25] = 0.7 +probs_map_3d[45, 45, 45] = 0.9 +expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]] +TEST_CASES_3D = [ + {"spatial_dims": 3, "prob_threshold": 0.5, "box_size": (10, 10, 10)}, + {"prob_map": probs_map_3d}, + expected_3d, +] + + +class TestProbNMS(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASES_2D_1, + TEST_CASES_2D_2, + TEST_CASES_2D_3, + TEST_CASES_2D_4, + TEST_CASES_2D_5, + TEST_CASES_2D_6, + TEST_CASES_2D_7, + TEST_CASES_3D, + ] + ) + def test_output(self, class_args, probs_map, expected): + nms = ProbNMSD(keys="prob_map", **class_args) + output = nms(probs_map) + np.testing.assert_allclose(output["prob_map"], expected) + + +if __name__ == "__main__": + unittest.main()