diff --git a/.gitignore b/.gitignore index 4889d2d917..7444d7f2f9 100644 --- a/.gitignore +++ b/.gitignore @@ -48,6 +48,9 @@ coverage.xml .hypothesis/ .pytest_cache/ +# temporary unittest artifacts +tests/testing_data/temp_* + # Translations *.mo *.pot diff --git a/docs/source/apps.rst b/docs/source/apps.rst index fa92a2bc2d..29d835514f 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -74,10 +74,16 @@ Applications .. autoclass:: MaskedInferenceWSIDataset :members: +.. automodule:: monai.apps.pathology.handlers +.. autoclass:: ProbMapProducer + :members: + +.. automodule:: monai.apps.pathology.metrics +.. autoclass:: LesionFROC + :members: + .. automodule:: monai.apps.pathology.utils +.. autofunction:: compute_multi_instance_mask +.. autofunction:: compute_isolated_tumor_cells .. autoclass:: PathologyProbNMS :members: - -.. automodule:: monai.apps.pathology.handlers -.. autoclass:: ProbMapProducer - :members: \ No newline at end of file diff --git a/monai/apps/pathology/__init__.py b/monai/apps/pathology/__init__.py index 3474a7c10a..203e1a80d7 100644 --- a/monai/apps/pathology/__init__.py +++ b/monai/apps/pathology/__init__.py @@ -11,4 +11,5 @@ from .datasets import MaskedInferenceWSIDataset, PatchWSIDataset, SmartCacheDataset from .handlers import ProbMapProducer -from .utils import ProbNMS +from .metrics import LesionFROC +from .utils import PathologyProbNMS, compute_isolated_tumor_cells, compute_multi_instance_mask diff --git a/monai/apps/pathology/metrics.py b/monai/apps/pathology/metrics.py new file mode 100644 index 0000000000..63b9d073a7 --- /dev/null +++ b/monai/apps/pathology/metrics.py @@ -0,0 +1,180 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Dict, List, Tuple, Union + +import numpy as np + +from monai.apps.pathology.utils import PathologyProbNMS, compute_isolated_tumor_cells, compute_multi_instance_mask +from monai.data.image_reader import WSIReader +from monai.metrics import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score + + +class LesionFROC: + """ + Evaluate with Free Response Operating Characteristic (FROC) score. + + Args: + data: either the list of dictionaries containing probability maps (inference result) and + tumor mask (ground truth), as below, or the path to a json file containing such list. + `{ + "prob_map": "path/to/prob_map_1.npy", + "tumor_mask": "path/to/ground_truth_1.tiff", + "level": 6, + "pixel_spacing": 0.243 + }` + grow_distance: Euclidean distance (in micrometer) by which to grow the label the ground truth's tumors. + Defaults to 75, which is the equivalent size of 5 tumor cells. + itc_diameter: the maximum diameter of a region (in micrometer) to be considered as an isolated tumor cell. + Defaults to 200. + eval_thresholds: the false positive rates for calculating the average sensitivity. + Defaults to (0.25, 0.5, 1, 2, 4, 8) which is the same as the CAMELYON 16 Challenge. + nms_sigma: the standard deviation for gaussian filter of non-maximal suppression. Defaults to 0.0. + nms_prob_threshold: the probability threshold of non-maximal suppression. Defaults to 0.5. + nms_box_size: the box size (in pixel) to be removed around the the pixel for non-maximal suppression. + image_reader_name: the name of library to be used for loading whole slide imaging, either CuCIM or OpenSlide. + Defaults to CuCIM. + + Note: + For more info on `nms_*` parameters look at monai.utils.prob_nms.ProbNMS`. + + """ + + def __init__( + self, + data: Union[List[Dict], str], + grow_distance: int = 75, + itc_diameter: int = 200, + eval_thresholds: Tuple = (0.25, 0.5, 1, 2, 4, 8), + nms_sigma: float = 0.0, + nms_prob_threshold: float = 0.5, + nms_box_size: int = 48, + image_reader_name: str = "cuCIM", + ) -> None: + + if isinstance(data, str): + self.data = self._load_data(data) + else: + self.data = data + self.grow_distance = grow_distance + self.itc_diameter = itc_diameter + self.eval_thresholds = eval_thresholds + self.image_reader = WSIReader(image_reader_name) + self.nms = PathologyProbNMS( + sigma=nms_sigma, + prob_threshold=nms_prob_threshold, + box_size=nms_box_size, + ) + + def _load_data(self, file_path: str) -> List[Dict]: + with open(file_path, "r") as f: + data: List[Dict] = json.load(f) + return data + + def prepare_inference_result(self, sample: Dict): + """ + Prepare the probability map for detection evaluation. + + """ + # load the probability map (the result of model inference) + prob_map = np.load(sample["prob_map"]) + + # apply non-maximal suppression + nms_outputs = self.nms(probs_map=prob_map, resolution_level=sample["level"]) + + # separate nms outputs + if nms_outputs: + probs, x_coord, y_coord = zip(*nms_outputs) + else: + probs, x_coord, y_coord = [], [], [] + + return np.array(probs), np.array(x_coord), np.array(y_coord) + + def prepare_ground_truth(self, sample): + """ + Prepare the ground truth for evaluation based on the binary tumor mask + + """ + # load binary tumor masks + img_obj = self.image_reader.read(sample["tumor_mask"]) + tumor_mask = self.image_reader.get_data(img_obj, level=sample["level"])[0][0] + + # calculate pixel spacing at the mask level + mask_pixel_spacing = sample["pixel_spacing"] * pow(2, sample["level"]) + + # compute multi-instance mask from a binary mask + grow_pixel_threshold = self.grow_distance / (mask_pixel_spacing * 2) + tumor_mask = compute_multi_instance_mask(mask=tumor_mask, threshold=grow_pixel_threshold) + + # identify isolated tumor cells + itc_threshold = (self.itc_diameter + self.grow_distance) / mask_pixel_spacing + itc_labels = compute_isolated_tumor_cells(tumor_mask=tumor_mask, threshold=itc_threshold) + + return tumor_mask, itc_labels + + def compute_fp_tp(self): + """ + Compute false positive and true positive probabilities for tumor detection, + by comparing the model outputs with the prepared ground truths for all samples + + """ + total_fp_probs, total_tp_probs = [], [] + total_num_targets = 0 + num_images = len(self.data) + + for sample in self.data: + probs, y_coord, x_coord = self.prepare_inference_result(sample) + ground_truth, itc_labels = self.prepare_ground_truth(sample) + # compute FP and TP probabilities for a pair of an image and an ground truth mask + fp_probs, tp_probs, num_targets = compute_fp_tp_probs( + probs=probs, + y_coord=y_coord, + x_coord=x_coord, + evaluation_mask=ground_truth, + labels_to_exclude=itc_labels, + resolution_level=sample["level"], + ) + total_fp_probs.extend(fp_probs) + total_tp_probs.extend(tp_probs) + total_num_targets += num_targets + + return ( + np.array(total_fp_probs), + np.array(total_tp_probs), + total_num_targets, + num_images, + ) + + def evaluate(self): + """ + Evaluate the detection performance of a model based on the model probability map output, + the ground truth tumor mask, and their associated metadata (e.g., pixel_spacing, level) + """ + # compute false positive (FP) and true positive (TP) probabilities for all images + fp_probs, tp_probs, num_targets, num_images = self.compute_fp_tp() + + # compute FROC curve given the evaluation of all images + fps_per_image, total_sensitivity = compute_froc_curve_data( + fp_probs=fp_probs, + tp_probs=tp_probs, + num_targets=num_targets, + num_images=num_images, + ) + + # compute FROC score give specific evaluation threshold + froc_score = compute_froc_score( + fps_per_image=fps_per_image, + total_sensitivity=total_sensitivity, + eval_thresholds=self.eval_thresholds, + ) + + return froc_score diff --git a/monai/apps/pathology/utils.py b/monai/apps/pathology/utils.py index b0803526fd..ae77bfafd1 100644 --- a/monai/apps/pathology/utils.py +++ b/monai/apps/pathology/utils.py @@ -9,12 +9,53 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import List, Union import numpy as np import torch -from monai.utils import ProbNMS +from monai.utils import ProbNMS, optional_import + +measure, _ = optional_import("skimage.measure") +ndimage, _ = optional_import("scipy.ndimage") + + +def compute_multi_instance_mask(mask: np.ndarray, threshold: float): + """ + This method computes the segmentation mask according to the binary tumor mask. + + Args: + mask: the binary mask array + threshold: the threshold to fill holes + """ + + neg = 255 - mask * 255 + distance = ndimage.morphology.distance_transform_edt(neg) + binary = distance < threshold + + filled_image = ndimage.morphology.binary_fill_holes(binary) + multi_instance_mask = measure.label(filled_image, connectivity=2) + + return multi_instance_mask + + +def compute_isolated_tumor_cells(tumor_mask: np.ndarray, threshold: float) -> List[int]: + """ + This method computes identifies Isolated Tumor Cells (ITC) and return their labels. + + Args: + tumor_mask: the tumor mask. + threshold: the threshold (at the mask level) to define an isolated tumor cell (ITC). + A region with the longest diameter less than this threshold is considered as an ITC. + """ + max_label = np.amax(tumor_mask) + properties = measure.regionprops(tumor_mask, coordinates="rc") + itc_list = [] + for i in range(max_label): # type: ignore + if properties[i].major_axis_length < threshold: + itc_list.append(i + 1) + + return itc_list class PathologyProbNMS(ProbNMS): diff --git a/monai/utils/prob_nms.py b/monai/utils/prob_nms.py index 29ba93d287..c789dab0bb 100644 --- a/monai/utils/prob_nms.py +++ b/monai/utils/prob_nms.py @@ -22,8 +22,9 @@ class ProbNMS: prob_threshold: the probability threshold, the function will stop searching if the highest probability is no larger than the threshold. The value should be no less than 0.0. Defaults to 0.5. - box_size: determines the sizes of the removing area of the selected coordinates for - each dimensions. Defaults to 48. + box_size: the box size (in pixel) to be removed around the the pixel with the maximum probability. + It can be an integer that defines the size of a square or cube, + or a list containing different values for each dimensions. Defaults to 48. Return: a list of selected lists, where inner lists contain probability and coordinates. diff --git a/tests/test_lesion_froc.py b/tests/test_lesion_froc.py new file mode 100644 index 0000000000..6702997c64 --- /dev/null +++ b/tests/test_lesion_froc.py @@ -0,0 +1,320 @@ +import os +import unittest +from unittest import skipUnless + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.metrics import LesionFROC +from monai.utils import optional_import + +_, has_cucim = optional_import("cucim") +_, has_skimage = optional_import("skimage.measure") +_, has_sp = optional_import("scipy.ndimage") +PILImage, has_pil = optional_import("PIL.Image") + + +def save_as_tif(filename, array): + array = array[::-1, ...] # Upside-down + img = PILImage.fromarray(array) + if not filename.endswith(".tif"): + filename += ".tif" + img.save(os.path.join("tests", "testing_data", filename)) + + +def around(val, interval=3): + return slice(val - interval, val + interval) + + +# mask and prediction image size +HEIGHT = 101 +WIDTH = 800 + + +def prepare_test_data(): + # ------------------------------------- + # Ground Truth - Binary Masks + # ------------------------------------- + # ground truth with no tumor + ground_truth = np.zeros((HEIGHT, WIDTH), dtype=np.uint8) + save_as_tif("temp_ground_truth_0", ground_truth) + + # ground truth with one tumor + ground_truth[around(HEIGHT // 2), around(1 * WIDTH // 7)] = 1 + save_as_tif("temp_ground_truth_1", ground_truth) + + # ground truth with two tumors + ground_truth[around(HEIGHT // 2), around(2 * WIDTH // 7)] = 1 + save_as_tif("temp_ground_truth_2", ground_truth) + + # ground truth with three tumors + ground_truth[around(HEIGHT // 2), around(3 * WIDTH // 7)] = 1 + save_as_tif("temp_ground_truth_3", ground_truth) + + # ground truth with four tumors + ground_truth[around(HEIGHT // 2), around(4 * WIDTH // 7)] = 1 + save_as_tif("temp_ground_truth_4", ground_truth) + + # ------------------------------------- + # predictions - Probability Maps + # ------------------------------------- + + # prediction with no tumor + prob_map = np.zeros((HEIGHT, WIDTH)) + np.save("./tests/testing_data/temp_prob_map_0_0.npy", prob_map) + + # prediction with one incorrect tumor + prob_map[HEIGHT // 2, 5 * WIDTH // 7] = 0.6 + np.save("./tests/testing_data/temp_prob_map_0_1.npy", prob_map) + + # prediction with correct first tumors and an incorrect tumor + prob_map[HEIGHT // 2, 1 * WIDTH // 7] = 0.8 + np.save("./tests/testing_data/temp_prob_map_1_1.npy", prob_map) + + # prediction with correct firt two tumors and an incorrect tumor + prob_map[HEIGHT // 2, 2 * WIDTH // 7] = 0.8 + np.save("./tests/testing_data/temp_prob_map_2_1.npy", prob_map) + + # prediction with two incorrect tumors + prob_map = np.zeros((HEIGHT, WIDTH)) + prob_map[HEIGHT // 2, 5 * WIDTH // 7] = 0.6 + prob_map[HEIGHT // 2, 6 * WIDTH // 7] = 0.4 + np.save("./tests/testing_data/temp_prob_map_0_2.npy", prob_map) + + # prediction with correct first tumors and two incorrect tumors + prob_map[HEIGHT // 2, 1 * WIDTH // 7] = 0.8 + np.save("./tests/testing_data/temp_prob_map_1_2.npy", prob_map) + + +TEST_CASE_0 = [ + { + "data": [ + { + "prob_map": "./tests/testing_data/temp_prob_map_0_0.npy", + "tumor_mask": "./tests/testing_data/temp_ground_truth_0.tif", + "level": 0, + "pixel_spacing": 1, + } + ], + "grow_distance": 2, + "itc_diameter": 0, + }, + np.nan, +] + + +TEST_CASE_1 = [ + { + "data": [ + { + "prob_map": "./tests/testing_data/temp_prob_map_0_0.npy", + "tumor_mask": "./tests/testing_data/temp_ground_truth_1.tif", + "level": 0, + "pixel_spacing": 1, + } + ], + "grow_distance": 2, + "itc_diameter": 0, + }, + 0.0, +] + +TEST_CASE_2 = [ + { + "data": [ + { + "prob_map": "./tests/testing_data/temp_prob_map_1_1.npy", + "tumor_mask": "./tests/testing_data/temp_ground_truth_1.tif", + "level": 0, + "pixel_spacing": 1, + } + ], + "grow_distance": 2, + "itc_diameter": 0, + }, + 1.0, +] + +TEST_CASE_3 = [ + { + "data": [ + { + "prob_map": "./tests/testing_data/temp_prob_map_2_1.npy", + "tumor_mask": "./tests/testing_data/temp_ground_truth_1.tif", + "level": 0, + "pixel_spacing": 1, + } + ], + "grow_distance": 2, + "itc_diameter": 0, + }, + 1.0, +] + + +TEST_CASE_4 = [ + { + "data": [ + { + "prob_map": "./tests/testing_data/temp_prob_map_2_1.npy", + "tumor_mask": "./tests/testing_data/temp_ground_truth_2.tif", + "level": 0, + "pixel_spacing": 1, + } + ], + "grow_distance": 2, + "itc_diameter": 0, + }, + 1.0, +] + +TEST_CASE_5 = [ + { + "data": [ + { + "prob_map": "./tests/testing_data/temp_prob_map_1_2.npy", + "tumor_mask": "./tests/testing_data/temp_ground_truth_2.tif", + "level": 0, + "pixel_spacing": 1, + } + ], + "grow_distance": 2, + "itc_diameter": 0, + }, + 0.5, +] + + +TEST_CASE_5 = [ + { + "data": [ + { + "prob_map": "./tests/testing_data/temp_prob_map_1_1.npy", + "tumor_mask": "./tests/testing_data/temp_ground_truth_1.tif", + "level": 0, + "pixel_spacing": 1, + }, + { + "prob_map": "./tests/testing_data/temp_prob_map_1_2.npy", + "tumor_mask": "./tests/testing_data/temp_ground_truth_2.tif", + "level": 0, + "pixel_spacing": 1, + }, + ], + "grow_distance": 2, + "itc_diameter": 0, + }, + 2.0 / 3.0, +] + +TEST_CASE_6 = [ + { + "data": [ + { + "prob_map": "./tests/testing_data/temp_prob_map_1_1.npy", + "tumor_mask": "./tests/testing_data/temp_ground_truth_3.tif", + "level": 0, + "pixel_spacing": 1, + }, + { + "prob_map": "./tests/testing_data/temp_prob_map_1_2.npy", + "tumor_mask": "./tests/testing_data/temp_ground_truth_2.tif", + "level": 0, + "pixel_spacing": 1, + }, + ], + "grow_distance": 2, + "itc_diameter": 0, + }, + 0.4, +] + +TEST_CASE_7 = [ + { + "data": [ + { + "prob_map": "./tests/testing_data/temp_prob_map_0_1.npy", + "tumor_mask": "./tests/testing_data/temp_ground_truth_1.tif", + "level": 0, + "pixel_spacing": 1, + }, + { + "prob_map": "./tests/testing_data/temp_prob_map_1_1.npy", + "tumor_mask": "./tests/testing_data/temp_ground_truth_3.tif", + "level": 0, + "pixel_spacing": 1, + }, + { + "prob_map": "./tests/testing_data/temp_prob_map_1_2.npy", + "tumor_mask": "./tests/testing_data/temp_ground_truth_2.tif", + "level": 0, + "pixel_spacing": 1, + }, + ], + "grow_distance": 2, + "itc_diameter": 0, + }, + 1.0 / 3.0, +] + +TEST_CASE_8 = [ + { + "data": [ + { + "prob_map": "./tests/testing_data/temp_prob_map_0_2.npy", + "tumor_mask": "./tests/testing_data/temp_ground_truth_4.tif", + "level": 0, + "pixel_spacing": 1, + }, + { + "prob_map": "./tests/testing_data/temp_prob_map_1_1.npy", + "tumor_mask": "./tests/testing_data/temp_ground_truth_3.tif", + "level": 0, + "pixel_spacing": 1, + }, + { + "prob_map": "./tests/testing_data/temp_prob_map_1_2.npy", + "tumor_mask": "./tests/testing_data/temp_ground_truth_2.tif", + "level": 0, + "pixel_spacing": 1, + }, + ], + "grow_distance": 2, + "itc_diameter": 0, + }, + 2.0 / 9.0, +] + + +class TestEvaluateTumorFROC(unittest.TestCase): + @skipUnless(has_cucim, "Requires cucim") + @skipUnless(has_skimage, "Requires skimage") + @skipUnless(has_sp, "Requires scipy") + @skipUnless(has_pil, "Requires PIL") + def setUp(self): + prepare_test_data() + + @parameterized.expand( + [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + ] + ) + def test_read_patches_cucim(self, input_parameters, expected): + froc = LesionFROC(**input_parameters) + froc_score = froc.evaluate() + if np.isnan(expected): + self.assertTrue(np.isnan(froc_score)) + else: + self.assertAlmostEqual(froc_score, expected) + + +if __name__ == "__main__": + unittest.main()