From b15768fd2a54ebcca9d0def2df911d9bfb846ffc Mon Sep 17 00:00:00 2001
From: Tomasz Bartczak
Date: Thu, 18 May 2023 09:36:18 +0200
Subject: [PATCH 1/2] FROC metric in ND

Signed-off-by: Tomasz Bartczak
---
 monai/metrics/__init__.py  |  2 +-
 monai/metrics/froc.py      | 71 +++++++++++++++++++++++++++++---------
 tests/test_compute_froc.py | 38 +++++++++++++++++++-
 3 files changed, 92 insertions(+), 19 deletions(-)

diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py
index 1af1d757ee..4af1b5760d 100644
--- a/monai/metrics/__init__.py
+++ b/monai/metrics/__init__.py
@@ -15,7 +15,7 @@
 from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
 from .cumulative_average import CumulativeAverage
 from .f_beta_score import FBetaScore
-from .froc import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score
+from .froc import compute_fp_tp_probs, compute_fp_tp_probs_nd, compute_froc_curve_data, compute_froc_score
 from .generalized_dice import GeneralizedDiceScore, compute_generalized_dice
 from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance, compute_percent_hausdorff_distance
 from .loss_metric import LossMetric
diff --git a/monai/metrics/froc.py b/monai/metrics/froc.py
index 6fd367d1e4..b914740a7b 100644
--- a/monai/metrics/froc.py
+++ b/monai/metrics/froc.py
@@ -19,13 +19,11 @@
 from monai.config import NdarrayOrTensor
 
 
-def compute_fp_tp_probs(
+def compute_fp_tp_probs_nd(
     probs: NdarrayOrTensor,
-    y_coord: NdarrayOrTensor,
-    x_coord: NdarrayOrTensor,
+    coords: NdarrayOrTensor,
     evaluation_mask: NdarrayOrTensor,
     labels_to_exclude: list | None = None,
-    resolution_level: int = 0,
 ) -> tuple[NdarrayOrTensor, NdarrayOrTensor, int]:
     """
     This function is modified from the official evaluation code of
@@ -36,11 +34,9 @@
     Args:
         probs: an array with shape (n,) that represents the probabilities of the detections.
             Where, n is the number of predicted detections.
-        y_coord: an array with shape (n,) that represents the Y-coordinates of the detections.
-        x_coord: an array with shape (n,) that represents the X-coordinates of the detections.
+        coords: an array with shape (n, n_dim) that represents the coordinates of the detections in the same order as in `evaluation_mask`.
         evaluation_mask: the ground truth mask for evaluation.
         labels_to_exclude: labels in this list will not be counted for metric calculation.
-        resolution_level: the level at which the evaluation mask is made.
 
     Returns:
         fp_probs: an array that contains the probabilities of the false positive detections.
@@ -48,17 +44,17 @@
         num_targets: the total number of targets (excluding `labels_to_exclude`) for all images under evaluation.
 
     """
-    if not (probs.shape == y_coord.shape == x_coord.shape):
+    if len(probs) != len(coords):
+        raise ValueError(f"the length of probs {probs.shape} should match the length of coords {coords.shape}.")
+    if not (len(coords.shape) > 1 and coords.shape[1] == len(evaluation_mask.shape)):
         raise ValueError(
-            f"the shapes between probs {probs.shape}, y_coord {y_coord.shape} and x_coord {x_coord.shape} should be the same."
+            f"coords {coords.shape} must represent the same number of dimensions as the mask {evaluation_mask.shape}."
         )
 
     if isinstance(probs, torch.Tensor):
         probs = probs.detach().cpu().numpy()
-    if isinstance(y_coord, torch.Tensor):
-        y_coord = y_coord.detach().cpu().numpy()
-    if isinstance(x_coord, torch.Tensor):
-        x_coord = x_coord.detach().cpu().numpy()
+    if isinstance(coords, torch.Tensor):
+        coords = coords.detach().cpu().numpy()
     if isinstance(evaluation_mask, torch.Tensor):
         evaluation_mask = evaluation_mask.detach().cpu().numpy()
 
@@ -68,10 +64,7 @@
     max_label = np.max(evaluation_mask)
     tp_probs = np.zeros((max_label,), dtype=np.float32)
 
-    y_coord = (y_coord / pow(2, resolution_level)).astype(int)
-    x_coord = (x_coord / pow(2, resolution_level)).astype(int)
-
-    hittedlabel = evaluation_mask[y_coord, x_coord]
+    hittedlabel = evaluation_mask[tuple(coords.T)]
     fp_probs = probs[np.where(hittedlabel == 0)]
     for i in range(1, max_label + 1):
         if i not in labels_to_exclude and i in hittedlabel:
@@ -81,6 +74,50 @@
     return fp_probs, tp_probs, cast(int, num_targets)
 
 
+def compute_fp_tp_probs(
+    probs: NdarrayOrTensor,
+    y_coord: NdarrayOrTensor,
+    x_coord: NdarrayOrTensor,
+    evaluation_mask: NdarrayOrTensor,
+    labels_to_exclude: list | None = None,
+    resolution_level: int = 0,
+) -> tuple[NdarrayOrTensor, NdarrayOrTensor, int]:
+    """
+    This function is modified from the official evaluation code of
+    `CAMELYON 16 Challenge <https://camelyon16.grand-challenge.org/>`_, and used to distinguish
+    true positive and false positive predictions. A true positive prediction is defined when
+    the detection point is within the annotated ground truth region.
+
+    Args:
+        probs: an array with shape (n,) that represents the probabilities of the detections,
+            where n is the number of predicted detections.
+        y_coord: an array with shape (n,) that represents the Y-coordinates of the detections.
+        x_coord: an array with shape (n,) that represents the X-coordinates of the detections.
+        evaluation_mask: the ground truth mask for evaluation.
+        labels_to_exclude: labels in this list will not be counted for metric calculation.
+        resolution_level: the level at which the evaluation mask is made.
+
+    Returns:
+        fp_probs: an array that contains the probabilities of the false positive detections.
+        tp_probs: an array that contains the probabilities of the true positive detections.
+        num_targets: the total number of targets (excluding `labels_to_exclude`) for all images under evaluation.
+
+    """
+    if isinstance(y_coord, torch.Tensor):
+        y_coord = y_coord.detach().cpu().numpy()
+    if isinstance(x_coord, torch.Tensor):
+        x_coord = x_coord.detach().cpu().numpy()
+
+    y_coord = (y_coord / pow(2, resolution_level)).astype(int)
+    x_coord = (x_coord / pow(2, resolution_level)).astype(int)
+
+    stacked = np.stack([y_coord, x_coord], axis=1)
+
+    return compute_fp_tp_probs_nd(
+        probs=probs, coords=stacked, evaluation_mask=evaluation_mask, labels_to_exclude=labels_to_exclude
+    )
+
+
 def compute_froc_curve_data(
     fp_probs: np.ndarray | torch.Tensor, tp_probs: np.ndarray | torch.Tensor, num_targets: int, num_images: int
 ) -> tuple[np.ndarray, np.ndarray]:
diff --git a/tests/test_compute_froc.py b/tests/test_compute_froc.py
index 1724c469d5..0a48dc099a 100644
--- a/tests/test_compute_froc.py
+++ b/tests/test_compute_froc.py
@@ -17,7 +17,7 @@
 import torch
 from parameterized import parameterized
 
-from monai.metrics import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score
+from monai.metrics import compute_fp_tp_probs, compute_fp_tp_probs_nd, compute_froc_curve_data, compute_froc_score
 
 _device = "cuda:0" if torch.cuda.is_available() else "cpu"
 TEST_CASE_1 = [
@@ -82,6 +82,33 @@
     0.75,
 ]
 
+TEST_CASE_ND_1 = [
+    {
+        "probs": torch.tensor([1, 0.6, 0.8]),
+        "coords": torch.tensor([[0, 3], [2, 0], [3, 1]]),
+        "evaluation_mask": np.array([[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 3, 3, 3]]),
+    },
+    np.array([0.6]),
+    np.array([1, 0, 0.8]),
+    3,
+]
+
+TEST_CASE_ND_2 = [
+    {
+        "probs": torch.tensor([1, 0.6, 0.8]),
+        "coords": torch.tensor([[0, 0, 3], [1, 2, 0], [0, 3, 1]]),
+        "evaluation_mask": np.array(
+            [
+                [[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 3, 3, 3]],
+                [[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 3, 3, 3]],
+            ]
+        ),
+    },
+    np.array([0.6]),
+    np.array([1, 0, 0.8]),
+    3,
+]
+
 
 class TestComputeFpTp(unittest.TestCase):
     @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
     def test_value(self, input_data, expected_fp, expected_tp, expected_num):
@@ -92,6 +119,15 @@
         np.testing.assert_equal(num_tumors, expected_num)
 
 
+class TestComputeFpTpNd(unittest.TestCase):
+    @parameterized.expand([TEST_CASE_ND_1, TEST_CASE_ND_2])
+    def test_value(self, input_data, expected_fp, expected_tp, expected_num):
+        fp_probs, tp_probs, num_tumors = compute_fp_tp_probs_nd(**input_data)
+        np.testing.assert_allclose(fp_probs, expected_fp, rtol=1e-5)
+        np.testing.assert_allclose(tp_probs, expected_tp, rtol=1e-5)
+        np.testing.assert_equal(num_tumors, expected_num)
+
+
 class TestComputeFrocScore(unittest.TestCase):
     @parameterized.expand([TEST_CASE_4, TEST_CASE_5])
     def test_value(self, input_data, thresholds, expected_score):

From ff675755e9e56d312cb6e0c219b3db53028cd247 Mon Sep 17 00:00:00 2001
From: Tomasz Bartczak
Date: Thu, 18 May 2023 15:46:51 +0200
Subject: [PATCH 2/2] code style fix in FROC

Signed-off-by: Tomasz Bartczak
---
 monai/metrics/froc.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/monai/metrics/froc.py b/monai/metrics/froc.py
index b914740a7b..81a890aa68 100644
--- a/monai/metrics/froc.py
+++ b/monai/metrics/froc.py
@@ -34,7 +34,8 @@ def compute_fp_tp_probs_nd(
     Args:
         probs: an array with shape (n,) that represents the probabilities of the detections.
             Where, n is the number of predicted detections.
-        coords: an array with shape (n, n_dim) that represents the coordinates of the detections in the same order as in `evaluation_mask`.
+        coords: an array with shape (n, n_dim) that represents the coordinates of the detections.
+            The dimensions must be in the same order as in `evaluation_mask`.
         evaluation_mask: the ground truth mask for evaluation.
         labels_to_exclude: labels in this list will not be counted for metric calculation.
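
A minimal end-to-end sketch of how the new ND entry point composes with the
existing FROC helpers; the 3D mask, detections, and expected values below are
illustrative and are not taken from the patch's test cases:

    import numpy as np
    import torch

    from monai.metrics import compute_fp_tp_probs_nd, compute_froc_curve_data, compute_froc_score

    # illustrative 3D ground truth: a 4x4x4 mask containing two labelled targets (1 and 2)
    evaluation_mask = np.zeros((4, 4, 4), dtype=np.int32)
    evaluation_mask[0, 0:2, 0:2] = 1
    evaluation_mask[2, 2:4, 2:4] = 2

    # three detections as (n, n_dim) coordinates, axes ordered as in the mask;
    # the first two land inside targets 1 and 2, the third hits background
    probs = torch.tensor([0.9, 0.7, 0.4])
    coords = torch.tensor([[0, 0, 1], [2, 3, 3], [1, 0, 0]])

    fp_probs, tp_probs, num_targets = compute_fp_tp_probs_nd(probs, coords, evaluation_mask)
    # fp_probs -> [0.4], tp_probs -> [0.9, 0.7], num_targets -> 2

    # the downstream FROC helpers are unchanged and accept the ND outputs directly
    fps_per_image, total_sensitivity = compute_froc_curve_data(fp_probs, tp_probs, num_targets, num_images=1)
    froc_score = compute_froc_score(fps_per_image, total_sensitivity, eval_thresholds=(0.25, 0.5, 1, 2, 4, 8))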