diff --git a/monai/apps/deepedit/__init__.py b/monai/apps/deepedit/__init__.py
new file mode 100644
index 0000000000..14ae193634
--- /dev/null
+++ b/monai/apps/deepedit/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/monai/apps/deepedit/transforms.py b/monai/apps/deepedit/transforms.py
new file mode 100644
index 0000000000..2ab2722076
--- /dev/null
+++ b/monai/apps/deepedit/transforms.py
@@ -0,0 +1,203 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import json
+import logging
+from typing import Dict, Hashable, Mapping, Tuple
+
+import numpy as np
+
+from monai.config import KeysCollection
+from monai.transforms.transform import MapTransform, Randomizable, Transform
+from monai.utils import optional_import
+
+logger = logging.getLogger(__name__)
+
+distance_transform_cdt, _ = optional_import("scipy.ndimage", name="distance_transform_cdt")
+
+
+class DiscardAddGuidanced(MapTransform):
+    """
+    Discard positive and negative points randomly or add the two channels for inference time.
+
+    Only the ``"image"`` key is transformed; a warning is logged for every other configured key.
+
+    Args:
+        keys: keys of the corresponding items to be transformed.
+        probability: discard probability; for inference it will be always 1.0.
+        allow_missing_keys: don't raise exception if key is missing.
+    """
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        probability: float = 1.0,
+        allow_missing_keys: bool = False,
+    ):
+        super().__init__(keys, allow_missing_keys)
+        self.probability = probability
+
+    def _apply(self, image):
+        """Zero the two guidance channels of ``image``, appending them first when there are not three channels."""
+        # NOTE(review): this draw uses the global NumPy random state, not a per-transform one.
+        if self.probability >= 1.0 or np.random.choice([True, False], p=[self.probability, 1 - self.probability]):
+            # one empty channel shaped like every non-channel axis, so 2D and 3D images both work
+            signal = np.zeros((1,) + tuple(image.shape[1:]), dtype=np.float32)
+            if image.shape[0] == 3:
+                image[1] = signal
+                image[2] = signal
+            else:
+                image = np.concatenate((image, signal, signal), axis=0)
+        return image
+
+    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
+        d: Dict = dict(data)
+        for key in self.key_iterator(d):
+            if key == "image":
+                d[key] = self._apply(d[key])
+            else:
+                logger.warning("This transform only applies to the image")
+        return d
+
+
+class ResizeGuidanceCustomd(Transform):
+    """
+    Resize the guidance based on cropped vs resized image.
+
+    Clicks are read from the ``"foreground"`` and ``"background"`` entries and scaled by the ratio between the
+    shape of the reference image (without its first axis) and ``data["image_meta_dict"]["dim"][1:4]``.
+
+    Args:
+        guidance: key under which the resized ``[positive, negative]`` clicks are stored.
+        ref_image: key of the reference image that provides the current shape.
+    """
+
+    def __init__(
+        self,
+        guidance: str,
+        ref_image: str,
+    ) -> None:
+        self.guidance = guidance
+        self.ref_image = ref_image
+
+    def __call__(self, data):
+        d = dict(data)
+        current_shape = d[self.ref_image].shape[1:]
+
+        factor = np.divide(current_shape, d["image_meta_dict"]["dim"][1:4])
+        pos_clicks, neg_clicks = d["foreground"], d["background"]
+
+        pos = np.multiply(pos_clicks, factor).astype(int).tolist() if len(pos_clicks) else []
+        neg = np.multiply(neg_clicks, factor).astype(int).tolist() if len(neg_clicks) else []
+
+        d[self.guidance] = [pos, neg]
+        return d
+
+
+class ClickRatioAddRandomGuidanced(Randomizable, Transform):
+    """
+    Add random guidance based on discrepancies that were found between label and prediction.
+
+    Args:
+        guidance: key to guidance source, shape (2, N, # of dim)
+        discrepancy: key that represents discrepancies found between label and prediction,
+            shape (2, C, D, H, W) or (2, C, H, W)
+        probability: key that represents click/interaction probability, shape (1)
+        fn_fp_click_ratio: ratio of clicks between FN and FP
+
+    Raises:
+        ValueError: When ``fn_fp_click_ratio`` has a negative entry or does not sum to a positive value.
+    """
+
+    def __init__(
+        self,
+        guidance: str = "guidance",
+        discrepancy: str = "discrepancy",
+        probability: str = "probability",
+        fn_fp_click_ratio: Tuple[float, float] = (1.0, 1.0),
+    ):
+        if min(fn_fp_click_ratio) < 0 or sum(fn_fp_click_ratio) <= 0:
+            raise ValueError(f"fn_fp_click_ratio must be non-negative with a positive sum, got {fn_fp_click_ratio}.")
+        self.guidance = guidance
+        self.discrepancy = discrepancy
+        self.probability = probability
+        self.fn_fp_click_ratio = fn_fp_click_ratio
+        self._will_interact = None
+
+    def randomize(self, data=None):
+        probability = data[self.probability]
+        self._will_interact = self.R.choice([True, False], p=[probability, 1.0 - probability])
+
+    def find_guidance(self, discrepancy):
+        """Sample one click inside ``discrepancy`` (weighted by ``exp(distance) - 1``); None if it is empty."""
+        idx = np.where(discrepancy.flatten() > 0)[0]
+        if idx.size == 0:
+            return None
+
+        distance = distance_transform_cdt(discrepancy).flatten()
+        probability = np.exp(distance) - 1.0
+        seed = self.R.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx]))
+        dst = distance[seed]
+
+        g = np.asarray(np.unravel_index(seed, discrepancy.shape)).transpose().tolist()[0]
+        g[0] = dst[0]  # the first entry carries the distance value instead of the first-axis index
+        return g
+
+    def add_guidance(self, discrepancy, will_interact):
+        """Return a ``(positive, negative)`` pair with at most one new click, or ``(None, None)``."""
+        if not will_interact:
+            return None, None
+
+        pos_discr = discrepancy[0]
+        neg_discr = discrepancy[1]
+
+        can_be_positive = np.sum(pos_discr) > 0
+        can_be_negative = np.sum(neg_discr) > 0
+
+        ratio_sum = self.fn_fp_click_ratio[0] + self.fn_fp_click_ratio[1]
+        pos_prob = self.fn_fp_click_ratio[0] / ratio_sum
+        neg_prob = self.fn_fp_click_ratio[1] / ratio_sum
+
+        # drawn unconditionally so the random stream does not depend on which discrepancies are present
+        correct_pos = self.R.choice([True, False], p=[pos_prob, neg_prob])
+
+        if can_be_positive and (correct_pos or not can_be_negative):
+            return self.find_guidance(pos_discr), None
+        if can_be_negative:
+            return None, self.find_guidance(neg_discr)
+        return None, None
+
+    def _apply(self, guidance, discrepancy):
+        if isinstance(guidance, np.ndarray):
+            guidance = guidance.tolist()
+        elif not isinstance(guidance, str):
+            # copy so the appends below never modify the nested lists owned by the caller
+            guidance = copy.deepcopy(guidance)
+        if isinstance(guidance, str):
+            guidance = json.loads(guidance)
+        pos, neg = self.add_guidance(discrepancy, self._will_interact)
+        if pos:
+            guidance[0].append(pos)
+            guidance[1].append([-1] * len(pos))
+        if neg:
+            guidance[0].append([-1] * len(neg))
+            guidance[1].append(neg)
+
+        return json.dumps(np.asarray(guidance).astype(int).tolist())
+
+    def __call__(self, data):
+        d = dict(data)
+        guidance = d[self.guidance]
+        discrepancy = d[self.discrepancy]
+        self.randomize(data)
+        d[self.guidance] = self._apply(guidance, discrepancy)
+        return d
diff --git a/tests/min_tests.py b/tests/min_tests.py
index d2f8f0aff6..5b376d7b57 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -37,6 +37,7 @@ def run_testsuit():
         "test_csv_iterable_dataset",
         "test_dataset",
         "test_dataset_summary",
+        "test_deepedit_transforms",
         "test_deepgrow_dataset",
         "test_deepgrow_interaction",
         "test_deepgrow_transforms",
diff --git a/tests/test_deepedit_transforms.py b/tests/test_deepedit_transforms.py
new file mode 100644
index 0000000000..c2b11e8ee7
--- /dev/null
+++ b/tests/test_deepedit_transforms.py
@@ -0,0 +1,114 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.apps.deepedit.transforms import ClickRatioAddRandomGuidanced, DiscardAddGuidanced, ResizeGuidanceCustomd
+
+IMAGE = np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]])
+LABEL = np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]])
+
+DATA_1 = {
+    "image": IMAGE,
+    "label": LABEL,
+    "image_meta_dict": {"dim": IMAGE.shape},
+    "label_meta_dict": {},
+    "foreground": [0, 0, 0],
+    "background": [0, 0, 0],
+}
+
+DISCARD_ADD_GUIDANCE_TEST_CASE = [
+    {"keys": ["image", "label"]},
+    DATA_1,
+    (3, 1, 5, 5),
+]
+
+DISCARD_ADD_GUIDANCE_TEST_CASE_2D = [
+    {"keys": "image"},
+    {"image": IMAGE[0]},
+    (3, 5, 5),
+]
+
+DATA_2 = {
+    "image": IMAGE,
+    "label": LABEL,
+    "guidance": np.array([[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]),
+    "discrepancy": np.array(
+        [
+            [[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
+            [[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
+        ]
+    ),
+    "probability": 1.0,
+}
+
+CLICK_RATIO_ADD_RANDOM_GUIDANCE_TEST_CASE_1 = [
+    {"guidance": "guidance", "discrepancy": "discrepancy", "probability": "probability"},
+    DATA_2,
+    "[[[1, 0, 2, 2], [-1, -1, -1, -1]], [[-1, -1, -1, -1], [1, 0, 2, 1]]]",
+]
+
+DATA_3 = {
+    "image": np.arange(1000).reshape((1, 5, 10, 20)),
+    "image_meta_dict": {"foreground_cropped_shape": (1, 10, 20, 40), "dim": [3, 512, 512, 128]},
+    "guidance": [[[6, 10, 14], [8, 10, 14]], [[8, 10, 16]]],
+    "foreground": [[10, 14, 6], [10, 14, 8]],
+    "background": [[10, 16, 8]],
+}
+
+RESIZE_GUIDANCE_TEST_CASE_1 = [
+    {"ref_image": "image", "guidance": "guidance"},
+    DATA_3,
+    [[[0, 0, 0], [0, 0, 1]], [[0, 0, 1]]],
+]
+
+
+class TestDiscardAddGuidanced(unittest.TestCase):
+    @parameterized.expand([DISCARD_ADD_GUIDANCE_TEST_CASE, DISCARD_ADD_GUIDANCE_TEST_CASE_2D])
+    def test_correct_results(self, arguments, input_data, expected_result):
+        add_fn = DiscardAddGuidanced(**arguments)
+        result = add_fn(input_data)
+        self.assertEqual(result["image"].shape, expected_result)
+
+
+class TestClickRatioAddRandomGuidanced(unittest.TestCase):
+    @parameterized.expand([CLICK_RATIO_ADD_RANDOM_GUIDANCE_TEST_CASE_1])
+    def test_correct_results(self, arguments, input_data, expected_result):
+        seed = 0
+        add_fn = ClickRatioAddRandomGuidanced(**arguments)
+        add_fn.set_random_state(seed)
+        result = add_fn(input_data)
+        self.assertEqual(result[arguments["guidance"]], expected_result)
+
+    def test_input_guidance_not_modified(self):
+        guidance = [[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]
+        add_fn = ClickRatioAddRandomGuidanced()
+        add_fn.set_random_state(0)
+        add_fn(dict(DATA_2, guidance=guidance))
+        self.assertEqual(guidance, [[[1, 0, 2, 2]], [[-1, -1, -1, -1]]])
+
+    def test_invalid_click_ratio(self):
+        with self.assertRaises(ValueError):
+            ClickRatioAddRandomGuidanced(fn_fp_click_ratio=(0.0, 0.0))
+
+
+class TestResizeGuidanced(unittest.TestCase):
+    @parameterized.expand([RESIZE_GUIDANCE_TEST_CASE_1])
+    def test_correct_results(self, arguments, input_data, expected_result):
+        result = ResizeGuidanceCustomd(**arguments)(input_data)
+        self.assertEqual(result[arguments["guidance"]], expected_result)
+
+
+if __name__ == "__main__":
+    unittest.main()