Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions monai/apps/deepedit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
167 changes: 167 additions & 0 deletions monai/apps/deepedit/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import json
import logging
from typing import Dict, Hashable, Mapping, Tuple

import numpy as np

from monai.config import KeysCollection
from monai.transforms.transform import MapTransform, Randomizable, Transform

logger = logging.getLogger(__name__)

from monai.utils import optional_import

distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")


class DiscardAddGuidanced(MapTransform):
def __init__(
self,
keys: KeysCollection,
probability: float = 1.0,
allow_missing_keys: bool = False,
):
"""
Discard positive and negative points randomly or Add the two channels for inference time

:param probability: Discard probability; For inference it will be always 1.0
"""
super().__init__(keys, allow_missing_keys)
self.probability = probability

def _apply(self, image):
if self.probability >= 1.0 or np.random.choice([True, False], p=[self.probability, 1 - self.probability]):
signal = np.zeros((1, image.shape[-3], image.shape[-2], image.shape[-1]), dtype=np.float32)
if image.shape[0] == 3:
image[1] = signal
image[2] = signal
else:
image = np.concatenate((image, signal, signal), axis=0)
return image

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d: Dict = dict(data)
for key in self.key_iterator(d):
if key == "image":
d[key] = self._apply(d[key])
else:
print("This transform only applies to the image")
return d


class ResizeGuidanceCustomd(Transform):
"""
Resize the guidance based on cropped vs resized image.
"""

def __init__(
self,
guidance: str,
ref_image: str,
) -> None:
self.guidance = guidance
self.ref_image = ref_image

def __call__(self, data):
d = dict(data)
current_shape = d[self.ref_image].shape[1:]

factor = np.divide(current_shape, d["image_meta_dict"]["dim"][1:4])
pos_clicks, neg_clicks = d["foreground"], d["background"]

pos = np.multiply(pos_clicks, factor).astype(int).tolist() if len(pos_clicks) else []
neg = np.multiply(neg_clicks, factor).astype(int).tolist() if len(neg_clicks) else []

d[self.guidance] = [pos, neg]
return d


class ClickRatioAddRandomGuidanced(Randomizable, Transform):
"""
Add random guidance based on discrepancies that were found between label and prediction.
Args:
guidance: key to guidance source, shape (2, N, # of dim)
discrepancy: key that represents discrepancies found between label and prediction, shape (2, C, D, H, W) or (2, C, H, W)
probability: key that represents click/interaction probability, shape (1)
fn_fp_click_ratio: ratio of clicks between FN and FP
"""

def __init__(
self,
guidance: str = "guidance",
discrepancy: str = "discrepancy",
probability: str = "probability",
fn_fp_click_ratio: Tuple[float, float] = (1.0, 1.0),
):
self.guidance = guidance
self.discrepancy = discrepancy
self.probability = probability
self.fn_fp_click_ratio = fn_fp_click_ratio
self._will_interact = None

def randomize(self, data=None):
probability = data[self.probability]
self._will_interact = self.R.choice([True, False], p=[probability, 1.0 - probability])

def find_guidance(self, discrepancy):
distance = distance_transform_cdt(discrepancy).flatten()
probability = np.exp(distance) - 1.0
idx = np.where(discrepancy.flatten() > 0)[0]

if np.sum(discrepancy > 0) > 0:
seed = self.R.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx]))
dst = distance[seed]

g = np.asarray(np.unravel_index(seed, discrepancy.shape)).transpose().tolist()[0]
g[0] = dst[0]
return g
return None

def add_guidance(self, discrepancy, will_interact):
if not will_interact:
return None, None

pos_discr = discrepancy[0]
neg_discr = discrepancy[1]

can_be_positive = np.sum(pos_discr) > 0
can_be_negative = np.sum(neg_discr) > 0

pos_prob = self.fn_fp_click_ratio[0] / (self.fn_fp_click_ratio[0] + self.fn_fp_click_ratio[1])
neg_prob = self.fn_fp_click_ratio[1] / (self.fn_fp_click_ratio[0] + self.fn_fp_click_ratio[1])

correct_pos = self.R.choice([True, False], p=[pos_prob, neg_prob])

if can_be_positive and not can_be_negative:
return self.find_guidance(pos_discr), None

if not can_be_positive and can_be_negative:
return None, self.find_guidance(neg_discr)

if correct_pos and can_be_positive:
return self.find_guidance(pos_discr), None

if not correct_pos and can_be_negative:
return None, self.find_guidance(neg_discr)
return None, None

def _apply(self, guidance, discrepancy):
guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance
guidance = json.loads(guidance) if isinstance(guidance, str) else guidance
pos, neg = self.add_guidance(discrepancy, self._will_interact)
if pos:
guidance[0].append(pos)
guidance[1].append([-1] * len(pos))
if neg:
guidance[0].append([-1] * len(neg))
guidance[1].append(neg)

return json.dumps(np.asarray(guidance).astype(int).tolist())

def __call__(self, data):
d = dict(data)
guidance = d[self.guidance]
discrepancy = d[self.discrepancy]
self.randomize(data)
d[self.guidance] = self._apply(guidance, discrepancy)
return d
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def run_testsuit():
"test_csv_iterable_dataset",
"test_dataset",
"test_dataset_summary",
"test_deepedit_transforms",
"test_deepgrow_dataset",
"test_deepgrow_interaction",
"test_deepgrow_transforms",
Expand Down
97 changes: 97 additions & 0 deletions tests/test_deepedit_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from parameterized import parameterized

from monai.apps.deepedit.transforms import ClickRatioAddRandomGuidanced, DiscardAddGuidanced, ResizeGuidanceCustomd

IMAGE = np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]])
LABEL = np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]])

DATA_1 = {
"image": IMAGE,
"label": LABEL,
"image_meta_dict": {"dim": IMAGE.shape},
"label_meta_dict": {},
"foreground": [0, 0, 0],
"background": [0, 0, 0],
}

DISCARD_ADD_GUIDANCE_TEST_CASE = [
{"image": IMAGE, "label": LABEL},
DATA_1,
(3, 1, 5, 5),
]

DATA_2 = {
"image": IMAGE,
"label": LABEL,
"guidance": np.array([[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]),
"discrepancy": np.array(
[
[[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
[[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
]
),
"probability": 1.0,
}

CLICK_RATIO_ADD_RANDOM_GUIDANCE_TEST_CASE_1 = [
{"guidance": "guidance", "discrepancy": "discrepancy", "probability": "probability"},
DATA_2,
"[[[1, 0, 2, 2], [-1, -1, -1, -1]], [[-1, -1, -1, -1], [1, 0, 2, 1]]]",
]

DATA_3 = {
"image": np.arange(1000).reshape((1, 5, 10, 20)),
"image_meta_dict": {"foreground_cropped_shape": (1, 10, 20, 40), "dim": [3, 512, 512, 128]},
"guidance": [[[6, 10, 14], [8, 10, 14]], [[8, 10, 16]]],
"foreground": [[10, 14, 6], [10, 14, 8]],
"background": [[10, 16, 8]],
}

RESIZE_GUIDANCE_TEST_CASE_1 = [
{"ref_image": "image", "guidance": "guidance"},
DATA_3,
[[[0, 0, 0], [0, 0, 1]], [[0, 0, 1]]],
]


class TestDiscardAddGuidanced(unittest.TestCase):
@parameterized.expand([DISCARD_ADD_GUIDANCE_TEST_CASE])
def test_correct_results(self, arguments, input_data, expected_result):
add_fn = DiscardAddGuidanced(arguments)
result = add_fn(input_data)
self.assertEqual(result["image"].shape, expected_result)


class TestClickRatioAddRandomGuidanced(unittest.TestCase):
@parameterized.expand([CLICK_RATIO_ADD_RANDOM_GUIDANCE_TEST_CASE_1])
def test_correct_results(self, arguments, input_data, expected_result):
seed = 0
add_fn = ClickRatioAddRandomGuidanced(**arguments)
add_fn.set_random_state(seed)
result = add_fn(input_data)
self.assertEqual(result[arguments["guidance"]], expected_result)


class TestResizeGuidanced(unittest.TestCase):
@parameterized.expand([RESIZE_GUIDANCE_TEST_CASE_1])
def test_correct_results(self, arguments, input_data, expected_result):
result = ResizeGuidanceCustomd(**arguments)(input_data)
self.assertEqual(result[arguments["guidance"]], expected_result)


if __name__ == "__main__":
unittest.main()