diff --git a/monai/utils/prob_nms.py b/monai/utils/prob_nms.py index c789dab0bb..c25223d524 100644 --- a/monai/utils/prob_nms.py +++ b/monai/utils/prob_nms.py @@ -65,36 +65,36 @@ def __init__( def __call__( self, - probs_map: Union[np.ndarray, torch.Tensor], + prob_map: Union[np.ndarray, torch.Tensor], ): """ - probs_map: the input probabilities map, it must have shape (H[, W, ...]). + prob_map: the input probabilities map, it must have shape (H[, W, ...]). """ if self.sigma != 0: - if not isinstance(probs_map, torch.Tensor): - probs_map = torch.as_tensor(probs_map, dtype=torch.float) - self.filter.to(probs_map) - probs_map = self.filter(probs_map) + if not isinstance(prob_map, torch.Tensor): + prob_map = torch.as_tensor(prob_map, dtype=torch.float) + self.filter.to(prob_map) + prob_map = self.filter(prob_map) else: - if not isinstance(probs_map, torch.Tensor): - probs_map = probs_map.copy() + if not isinstance(prob_map, torch.Tensor): + prob_map = prob_map.copy() - if isinstance(probs_map, torch.Tensor): - probs_map = probs_map.detach().cpu().numpy() + if isinstance(prob_map, torch.Tensor): + prob_map = prob_map.detach().cpu().numpy() - probs_map_shape = probs_map.shape + prob_map_shape = prob_map.shape outputs = [] - while np.max(probs_map) > self.prob_threshold: - max_idx = np.unravel_index(probs_map.argmax(), probs_map_shape) - prob_max = probs_map[max_idx] + while np.max(prob_map) > self.prob_threshold: + max_idx = np.unravel_index(prob_map.argmax(), prob_map_shape) + prob_max = prob_map[max_idx] max_idx_arr = np.asarray(max_idx) outputs.append([prob_max] + list(max_idx_arr)) idx_min_range = (max_idx_arr - self.box_lower_bd).clip(0, None) - idx_max_range = (max_idx_arr + self.box_upper_bd).clip(None, probs_map_shape) + idx_max_range = (max_idx_arr + self.box_upper_bd).clip(None, prob_map_shape) # for each dimension, set values during index ranges to 0 slices = tuple(slice(idx_min_range[i], idx_max_range[i]) for i in range(self.spatial_dims)) - probs_map[slices] = 0 + prob_map[slices] = 0 return outputs diff --git a/tests/min_tests.py b/tests/min_tests.py index e896e81c70..06231af0a1 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -43,7 +43,7 @@ def run_testsuit(): "test_handler_confusion_matrix_dist", "test_handler_hausdorff_distance", "test_handler_mean_dice", - "test_handler_prob_map_generator", + "test_handler_prob_map_producer", "test_handler_rocauc", "test_handler_rocauc_dist", "test_handler_segmentation_saver", diff --git a/tests/test_handler_prob_map_generator.py b/tests/test_handler_prob_map_producer.py similarity index 94% rename from tests/test_handler_prob_map_generator.py rename to tests/test_handler_prob_map_producer.py index 4882060be9..8bf42131b4 100644 --- a/tests/test_handler_prob_map_generator.py +++ b/tests/test_handler_prob_map_producer.py @@ -23,9 +23,9 @@ from monai.engines import Evaluator from monai.handlers import ValidationHandler -TEST_CASE_0 = ["image_inference_output_1", 2] -TEST_CASE_1 = ["image_inference_output_2", 9] -TEST_CASE_2 = ["image_inference_output_3", 1000] +TEST_CASE_0 = ["temp_image_inference_output_1", 2] +TEST_CASE_1 = ["temp_image_inference_output_2", 9] +TEST_CASE_2 = ["temp_image_inference_output_3", 1000] class TestDataset(Dataset):