From 2f4db726f8644fb9a0a6e484ecd6a29d2904bfd5 Mon Sep 17 00:00:00 2001 From: YuanTingHsieh Date: Fri, 12 Feb 2021 17:03:48 -0800 Subject: [PATCH 1/4] Add deepgrow interaction Signed-off-by: YuanTingHsieh --- monai/apps/deepgrow/interaction.py | 82 ++++++++++++++++++++++++++++++ tests/min_tests.py | 1 + tests/test_deepgrow_interaction.py | 54 ++++++++++++++++++++ 3 files changed, 137 insertions(+) create mode 100644 monai/apps/deepgrow/interaction.py create mode 100644 tests/test_deepgrow_interaction.py diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py new file mode 100644 index 0000000000..9be5e72ae5 --- /dev/null +++ b/monai/apps/deepgrow/interaction.py @@ -0,0 +1,82 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +from typing import Dict, Union + +import torch + +from monai.engines import SupervisedEvaluator, SupervisedTrainer +from monai.engines.utils import CommonKeys +from monai.engines.workflow import Events +from monai.transforms import Compose + + +class Interaction: + """ + Ignite handler used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation. + + Args: + transforms: execute additional transformation during every iteration (before train). + Typically, several Tensor based transforms composed by `Compose`. + max_interactions: maximum number of interactions per iteration + train: training or evaluation + key_probability: field name to fill probability for every interaction + """ + + def __init__(self, transforms, max_interactions: int, train: bool, key_probability: str = "probability") -> None: + self.transforms = transforms + self.max_interactions = max_interactions + self.train = train + self.key_probability = key_probability + + if not isinstance(self.transforms, Compose): + transforms = [] + for t in self.transforms: + transforms.append(self.init_external_class(t)) + self.transforms = Compose(transforms) + + @staticmethod + def init_external_class(config_dict): + class_args = None if config_dict.get("args") is None else dict(config_dict.get("args")) + class_path = config_dict.get("path", config_dict["name"]) + + module_name, class_name = class_path.rsplit(".", 1) + m = importlib.import_module(module_name) + c = getattr(m, class_name) + return c(**class_args) if class_args else c() + + def attach(self, engine: Union[SupervisedTrainer, SupervisedEvaluator]) -> None: + if not engine.has_event_handler(self, Events.ITERATION_STARTED): + engine.add_event_handler(Events.ITERATION_STARTED, self) + + def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchdata: Dict[str, torch.Tensor]): + if batchdata is None: + raise ValueError("Must provide batch data for current iteration.") + + for j in range(self.max_interactions): + inputs, _ = engine.prepare_batch(batchdata) + inputs = inputs.to(engine.state.device) + + engine.network.eval() + with torch.no_grad(): + if engine.amp: + with torch.cuda.amp.autocast(): + predictions = engine.inferer(inputs, engine.network) + else: + predictions = engine.inferer(inputs, engine.network) + + batchdata.update({CommonKeys.PRED: predictions}) + batchdata[self.key_probability] = torch.as_tensor( + ([1.0 - ((1.0 / self.max_interactions) * j)] if self.train else [1.0]) * len(inputs) + ) + batchdata = self.transforms(batchdata) + + return engine._iteration(engine, batchdata) diff --git a/tests/min_tests.py b/tests/min_tests.py index 0fd6985067..c284db886e 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -104,6 +104,7 @@ def run_testsuit(): "test_handler_metrics_saver_dist", "test_evenly_divisible_all_gather_dist", "test_handler_classification_saver_dist", + "test_deepgrow_interaction", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_deepgrow_interaction.py b/tests/test_deepgrow_interaction.py new file mode 100644 index 0000000000..9c22db0c71 --- /dev/null +++ b/tests/test_deepgrow_interaction.py @@ -0,0 +1,54 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from monai.apps.deepgrow.interaction import Interaction +from monai.data import Dataset +from monai.engines import SupervisedTrainer +from monai.transforms import Activationsd, Compose, ToNumpyd + + +class TestInteractions(unittest.TestCase): + def test_interaction(self): + data = [] + for i in range(5): + data.append({"image": torch.tensor([float(i)]), "label": torch.tensor([float(i)])}) + network = torch.nn.Linear(1, 1) + lr = 1e-3 + opt = torch.optim.SGD(network.parameters(), lr) + loss = torch.nn.L1Loss() + dataset = Dataset(data, transform=None) + data_loader = torch.utils.data.DataLoader(dataset, batch_size=5) + + iteration_transforms = Compose([Activationsd(keys="pred", sigmoid=True), ToNumpyd(keys="pred")]) + + i = Interaction(transforms=iteration_transforms, train=True, max_interactions=5) + assert len(i.transforms.transforms) == 2 + + # set up engine + engine = SupervisedTrainer( + device=torch.device("cpu"), + max_epochs=1, + train_data_loader=data_loader, + network=network, + optimizer=opt, + loss_function=loss, + iteration_update=i, + ) + + engine.run() + + +if __name__ == "__main__": + unittest.main() From af7794a45788aaf5bcc8d487b61112259c1ff8ce Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Sat, 20 Feb 2021 03:35:32 -0800 Subject: [PATCH 2/4] Fix review comments Signed-off-by: Sachidanand Alle --- monai/apps/deepgrow/interaction.py | 31 ++++++++++++------------------ tests/test_deepgrow_interaction.py | 16 +++++++++++---- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index 9be5e72ae5..004cce6d36 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -8,8 +8,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import importlib -from typing import Dict, Union +from typing import Callable, Dict, Sequence, Union import torch @@ -31,28 +30,22 @@ class Interaction: key_probability: field name to fill probability for every interaction """ - def __init__(self, transforms, max_interactions: int, train: bool, key_probability: str = "probability") -> None: + def __init__( + self, + transforms: Union[Sequence[Callable], Callable], + max_interactions: int, + train: bool, + key_probability: str = "probability", + ) -> None: + + if not isinstance(transforms, Compose): + transforms = Compose(transforms) + self.transforms = transforms self.max_interactions = max_interactions self.train = train self.key_probability = key_probability - if not isinstance(self.transforms, Compose): - transforms = [] - for t in self.transforms: - transforms.append(self.init_external_class(t)) - self.transforms = Compose(transforms) - - @staticmethod - def init_external_class(config_dict): - class_args = None if config_dict.get("args") is None else dict(config_dict.get("args")) - class_path = config_dict.get("path", config_dict["name"]) - - module_name, class_name = class_path.rsplit(".", 1) - m = importlib.import_module(module_name) - c = getattr(m, class_name) - return c(**class_args) if class_args else c() - def attach(self, engine: Union[SupervisedTrainer, SupervisedEvaluator]) -> None: if not engine.has_event_handler(self, Events.ITERATION_STARTED): engine.add_event_handler(Events.ITERATION_STARTED, self) diff --git a/tests/test_deepgrow_interaction.py b/tests/test_deepgrow_interaction.py index 9c22db0c71..272ae82a5b 100644 --- a/tests/test_deepgrow_interaction.py +++ b/tests/test_deepgrow_interaction.py @@ -20,7 +20,7 @@ class TestInteractions(unittest.TestCase): - def test_interaction(self): + def run_interaction(self, train, compose): data = [] for i in range(5): data.append({"image": torch.tensor([float(i)]), "label": torch.tensor([float(i)])}) @@ -31,10 +31,11 @@ def test_interaction(self): dataset = Dataset(data, transform=None) data_loader = torch.utils.data.DataLoader(dataset, batch_size=5) - iteration_transforms = Compose([Activationsd(keys="pred", sigmoid=True), ToNumpyd(keys="pred")]) + iteration_transforms = [Activationsd(keys="pred", sigmoid=True), ToNumpyd(keys="pred")] + iteration_transforms = Compose(iteration_transforms) if compose else iteration_transforms - i = Interaction(transforms=iteration_transforms, train=True, max_interactions=5) - assert len(i.transforms.transforms) == 2 + i = Interaction(transforms=iteration_transforms, train=train, max_interactions=5) + self.assertEqual(len(i.transforms.transforms), 2, "Mismatch in expected transforms") # set up engine engine = SupervisedTrainer( @@ -48,6 +49,13 @@ def test_interaction(self): ) engine.run() + self.assertIsNotNone(engine.state.batch.get("probability"), "Probability is missing") + + def test_train_interaction(self): + self.run_interaction(train=True, compose=True) + + def test_val_interaction(self): + self.run_interaction(train=False, compose=False) if __name__ == "__main__": From 261185c16e5674b8c1f6c5713eb39844a252dfa0 Mon Sep 17 00:00:00 2001 From: Sachidanand Alle Date: Mon, 22 Feb 2021 03:33:46 -0800 Subject: [PATCH 3/4] Fix doc Signed-off-by: Sachidanand Alle --- monai/apps/deepgrow/interaction.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index 004cce6d36..e6c24a8483 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -21,6 +21,7 @@ class Interaction: """ Ignite handler used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation. + To learn about Interactive segmentations please refer https://arxiv.org/abs/1903.08205 Args: transforms: execute additional transformation during every iteration (before train). From 7fcfea90fe1b112f65aa74511b44d026dcaff529 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 22 Feb 2021 14:13:12 +0000 Subject: [PATCH 4/4] fixes docs Signed-off-by: Wenqi Li --- docs/source/apps.rst | 10 ++++++++++ monai/apps/deepgrow/interaction.py | 5 ++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index f2afd93836..5301516b0b 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -28,3 +28,13 @@ Applications .. autofunction:: extractall .. autofunction:: download_and_extract + +`Deepgrow` +---------- + +.. automodule:: monai.apps.deepgrow.dataset +.. autofunction:: create_dataset + +.. automodule:: monai.apps.deepgrow.interaction +.. autoclass:: Interaction + :members: diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index e6c24a8483..77e271a9eb 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -21,7 +21,10 @@ class Interaction: """ Ignite handler used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation. - To learn about Interactive segmentations please refer https://arxiv.org/abs/1903.08205 + This implementation is based on: + + Sakinis et al., Interactive segmentation of medical images through + fully convolutional neural networks. (2019) https://arxiv.org/abs/1903.08205 Args: transforms: execute additional transformation during every iteration (before train).