From 142ee73bde675b82f3980ad52797422c7ccad795 Mon Sep 17 00:00:00 2001 From: Andres Date: Thu, 19 Aug 2021 17:00:19 +0100 Subject: [PATCH 1/5] Add deepedit transforms Signed-off-by: Andres --- monai/apps/deepedit/__init__.py | 10 ++ monai/apps/deepedit/transforms.py | 169 ++++++++++++++++++++++++++++++ 2 files changed, 179 insertions(+) create mode 100644 monai/apps/deepedit/__init__.py create mode 100644 monai/apps/deepedit/transforms.py diff --git a/monai/apps/deepedit/__init__.py b/monai/apps/deepedit/__init__.py new file mode 100644 index 0000000000..14ae193634 --- /dev/null +++ b/monai/apps/deepedit/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/apps/deepedit/transforms.py b/monai/apps/deepedit/transforms.py new file mode 100644 index 0000000000..e5c5db64c2 --- /dev/null +++ b/monai/apps/deepedit/transforms.py @@ -0,0 +1,169 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +from typing import Dict, Tuple + +import numpy as np +from monai.transforms.transform import Randomizable, Transform + +logger = logging.getLogger(__name__) + +from monai.utils import optional_import + +distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") + + +class DiscardAddGuidanced(Transform): + def __init__(self, image: str = "image", probability: float = 1.0): + """ + Discard positive and negative points randomly or Add the two channels for inference time + + :param image: image key + :param batched: Is it batched (if used during training and data is batched as interaction transform) + :param probability: Discard probability; For inference it will be always 1.0 + """ + self.image = image + self.probability = probability + + def _apply(self, image): + if self.probability >= 1.0 or np.random.choice([True, False], p=[self.probability, 1 - self.probability]): + signal = np.zeros((1, image.shape[-3], image.shape[-2], image.shape[-1]), dtype=np.float32) + if image.shape[0] == 3: + image[1] = signal + image[2] = signal + else: + image = np.concatenate((image, signal, signal), axis=0) + return image + + def __call__(self, data): + d: Dict = dict(data) + d[self.image] = self._apply(d[self.image]) + return d + + +class ResizeGuidanceCustomd(Transform): + """ + Resize the guidance based on cropped vs resized image. + """ + + def __init__( + self, + guidance: str, + ref_image: str, + ) -> None: + self.guidance = guidance + self.ref_image = ref_image + + def __call__(self, data): + d = dict(data) + current_shape = d[self.ref_image].shape[1:] + + factor = np.divide(current_shape, d["image_meta_dict"]["dim"][1:4]) + pos_clicks, neg_clicks = d["foreground"], d["background"] + + pos = np.multiply(pos_clicks, factor).astype(int).tolist() if len(pos_clicks) else [] + neg = np.multiply(neg_clicks, factor).astype(int).tolist() if len(neg_clicks) else [] + + d[self.guidance] = [pos, neg] + return d + + +class ClickRatioAddRandomGuidanced(Randomizable, Transform): + """ + Add random guidance based on discrepancies that were found between label and prediction. + Args: + guidance: key to guidance source, shape (2, N, # of dim) + discrepancy: key that represents discrepancies found between label and prediction, shape (2, C, D, H, W) or (2, C, H, W) + probability: key that represents click/interaction probability, shape (1) + fn_fp_click_ratio: ratio of clicks between FN and FP + """ + + def __init__( + self, + guidance: str = "guidance", + discrepancy: str = "discrepancy", + probability: str = "probability", + fn_fp_click_ratio: Tuple[float, float] = (1.0, 1.0), + ): + self.guidance = guidance + self.discrepancy = discrepancy + self.probability = probability + self.fn_fp_click_ratio = fn_fp_click_ratio + self._will_interact = None + + def randomize(self, data=None): + probability = data[self.probability] + self._will_interact = self.R.choice([True, False], p=[probability, 1.0 - probability]) + + def find_guidance(self, discrepancy): + distance = distance_transform_cdt(discrepancy).flatten() + probability = np.exp(distance) - 1.0 + idx = np.where(discrepancy.flatten() > 0)[0] + + if np.sum(discrepancy > 0) > 0: + seed = self.R.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx])) + dst = distance[seed] + + g = np.asarray(np.unravel_index(seed, discrepancy.shape)).transpose().tolist()[0] + g[0] = dst[0] + return g + return None + + def add_guidance(self, discrepancy, will_interact): + if not will_interact: + return None, None + + pos_discr = discrepancy[0] + neg_discr = discrepancy[1] + + can_be_positive = np.sum(pos_discr) > 0 + can_be_negative = np.sum(neg_discr) > 0 + + pos_prob = self.fn_fp_click_ratio[0] / (self.fn_fp_click_ratio[0] + self.fn_fp_click_ratio[1]) + neg_prob = self.fn_fp_click_ratio[1] / (self.fn_fp_click_ratio[0] + self.fn_fp_click_ratio[1]) + + correct_pos = self.R.choice([True, False], p=[pos_prob, neg_prob]) + + if can_be_positive and not can_be_negative: + return self.find_guidance(pos_discr), None + + if not can_be_positive and can_be_negative: + return None, self.find_guidance(neg_discr) + + if correct_pos and can_be_positive: + return self.find_guidance(pos_discr), None + + if not correct_pos and can_be_negative: + return None, self.find_guidance(neg_discr) + return None, None + + def _apply(self, guidance, discrepancy): + guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance + guidance = json.loads(guidance) if isinstance(guidance, str) else guidance + pos, neg = self.add_guidance(discrepancy, self._will_interact) + if pos: + guidance[0].append(pos) + guidance[1].append([-1] * len(pos)) + if neg: + guidance[0].append([-1] * len(neg)) + guidance[1].append(neg) + + return json.dumps(np.asarray(guidance).astype(int).tolist()) + + def __call__(self, data): + d = dict(data) + guidance = d[self.guidance] + discrepancy = d[self.discrepancy] + self.randomize(data) + d[self.guidance] = self._apply(guidance, discrepancy) + return d From c6358c3be7d4775499af4e43670f737a3caea4bf Mon Sep 17 00:00:00 2001 From: Andres Date: Thu, 19 Aug 2021 19:09:08 +0100 Subject: [PATCH 2/5] Run unittests - autofix Signed-off-by: Andres --- monai/apps/deepedit/transforms.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/apps/deepedit/transforms.py b/monai/apps/deepedit/transforms.py index e5c5db64c2..0467dec3c2 100644 --- a/monai/apps/deepedit/transforms.py +++ b/monai/apps/deepedit/transforms.py @@ -14,6 +14,7 @@ from typing import Dict, Tuple import numpy as np + from monai.transforms.transform import Randomizable, Transform logger = logging.getLogger(__name__) From 1975d47cf45591cfff4c00d663cf8551128648bd Mon Sep 17 00:00:00 2001 From: Andres Date: Sun, 22 Aug 2021 15:12:01 +0100 Subject: [PATCH 3/5] Update transform Signed-off-by: Andres --- monai/apps/deepedit/transforms.py | 37 ++++++++++++++----------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/monai/apps/deepedit/transforms.py b/monai/apps/deepedit/transforms.py index 0467dec3c2..2ab2722076 100644 --- a/monai/apps/deepedit/transforms.py +++ b/monai/apps/deepedit/transforms.py @@ -1,21 +1,11 @@ -# Copyright 2020 - 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import json import logging -from typing import Dict, Tuple +from typing import Dict, Hashable, Mapping, Tuple import numpy as np -from monai.transforms.transform import Randomizable, Transform +from monai.config import KeysCollection +from monai.transforms.transform import MapTransform, Randomizable, Transform logger = logging.getLogger(__name__) @@ -24,16 +14,19 @@ distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") -class DiscardAddGuidanced(Transform): - def __init__(self, image: str = "image", probability: float = 1.0): +class DiscardAddGuidanced(MapTransform): + def __init__( + self, + keys: KeysCollection, + probability: float = 1.0, + allow_missing_keys: bool = False, + ): """ Discard positive and negative points randomly or Add the two channels for inference time - :param image: image key - :param batched: Is it batched (if used during training and data is batched as interaction transform) :param probability: Discard probability; For inference it will be always 1.0 """ - self.image = image + super().__init__(keys, allow_missing_keys) self.probability = probability def _apply(self, image): @@ -46,9 +39,13 @@ def _apply(self, image): image = np.concatenate((image, signal, signal), axis=0) return image - def __call__(self, data): + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d: Dict = dict(data) - d[self.image] = self._apply(d[self.image]) + for key in self.key_iterator(d): + if key == "image": + d[key] = self._apply(d[key]) + else: + print("This transform only applies to the image") return d From 260917abfa2328b04a0bca6c59af393aef122b8a Mon Sep 17 00:00:00 2001 From: Andres Date: Fri, 27 Aug 2021 02:57:38 +0100 Subject: [PATCH 4/5] Add unit tests for DeepEdit transforms Signed-off-by: Andres --- tests/test_deepedit_transforms.py | 97 +++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 tests/test_deepedit_transforms.py diff --git a/tests/test_deepedit_transforms.py b/tests/test_deepedit_transforms.py new file mode 100644 index 0000000000..c2b11e8ee7 --- /dev/null +++ b/tests/test_deepedit_transforms.py @@ -0,0 +1,97 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.apps.deepedit.transforms import ClickRatioAddRandomGuidanced, DiscardAddGuidanced, ResizeGuidanceCustomd + +IMAGE = np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]) +LABEL = np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]) + +DATA_1 = { + "image": IMAGE, + "label": LABEL, + "image_meta_dict": {"dim": IMAGE.shape}, + "label_meta_dict": {}, + "foreground": [0, 0, 0], + "background": [0, 0, 0], +} + +DISCARD_ADD_GUIDANCE_TEST_CASE = [ + {"image": IMAGE, "label": LABEL}, + DATA_1, + (3, 1, 5, 5), +] + +DATA_2 = { + "image": IMAGE, + "label": LABEL, + "guidance": np.array([[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]), + "discrepancy": np.array( + [ + [[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]], + [[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]], + ] + ), + "probability": 1.0, +} + +CLICK_RATIO_ADD_RANDOM_GUIDANCE_TEST_CASE_1 = [ + {"guidance": "guidance", "discrepancy": "discrepancy", "probability": "probability"}, + DATA_2, + "[[[1, 0, 2, 2], [-1, -1, -1, -1]], [[-1, -1, -1, -1], [1, 0, 2, 1]]]", +] + +DATA_3 = { + "image": np.arange(1000).reshape((1, 5, 10, 20)), + "image_meta_dict": {"foreground_cropped_shape": (1, 10, 20, 40), "dim": [3, 512, 512, 128]}, + "guidance": [[[6, 10, 14], [8, 10, 14]], [[8, 10, 16]]], + "foreground": [[10, 14, 6], [10, 14, 8]], + "background": [[10, 16, 8]], +} + +RESIZE_GUIDANCE_TEST_CASE_1 = [ + {"ref_image": "image", "guidance": "guidance"}, + DATA_3, + [[[0, 0, 0], [0, 0, 1]], [[0, 0, 1]]], +] + + +class TestDiscardAddGuidanced(unittest.TestCase): + @parameterized.expand([DISCARD_ADD_GUIDANCE_TEST_CASE]) + def test_correct_results(self, arguments, input_data, expected_result): + add_fn = DiscardAddGuidanced(arguments) + result = add_fn(input_data) + self.assertEqual(result["image"].shape, expected_result) + + +class TestClickRatioAddRandomGuidanced(unittest.TestCase): + @parameterized.expand([CLICK_RATIO_ADD_RANDOM_GUIDANCE_TEST_CASE_1]) + def test_correct_results(self, arguments, input_data, expected_result): + seed = 0 + add_fn = ClickRatioAddRandomGuidanced(**arguments) + add_fn.set_random_state(seed) + result = add_fn(input_data) + self.assertEqual(result[arguments["guidance"]], expected_result) + + +class TestResizeGuidanced(unittest.TestCase): + @parameterized.expand([RESIZE_GUIDANCE_TEST_CASE_1]) + def test_correct_results(self, arguments, input_data, expected_result): + result = ResizeGuidanceCustomd(**arguments)(input_data) + self.assertEqual(result[arguments["guidance"]], expected_result) + + +if __name__ == "__main__": + unittest.main() From 7a77173d4e84399d4eabf0f9a04af3615c2ae47a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 27 Aug 2021 09:01:10 +0100 Subject: [PATCH 5/5] exclude in min tests. Signed-off-by: Wenqi Li --- tests/min_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/min_tests.py b/tests/min_tests.py index d2f8f0aff6..5b376d7b57 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -37,6 +37,7 @@ def run_testsuit(): "test_csv_iterable_dataset", "test_dataset", "test_dataset_summary", + "test_deepedit_transforms", "test_deepgrow_dataset", "test_deepgrow_interaction", "test_deepgrow_transforms",