diff --git a/docs/source/apps.rst b/docs/source/apps.rst index f2afd93836..5301516b0b 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -28,3 +28,13 @@ Applications .. autofunction:: extractall .. autofunction:: download_and_extract + +`Deepgrow` +---------- + +.. automodule:: monai.apps.deepgrow.dataset +.. autofunction:: create_dataset + +.. automodule:: monai.apps.deepgrow.interaction +.. autoclass:: Interaction + :members: diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py new file mode 100644 index 0000000000..77e271a9eb --- /dev/null +++ b/monai/apps/deepgrow/interaction.py @@ -0,0 +1,79 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Dict, Sequence, Union + +import torch + +from monai.engines import SupervisedEvaluator, SupervisedTrainer +from monai.engines.utils import CommonKeys +from monai.engines.workflow import Events +from monai.transforms import Compose + + +class Interaction: + """ + Ignite handler used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation. + This implementation is based on: + + Sakinis et al., Interactive segmentation of medical images through + fully convolutional neural networks. (2019) https://arxiv.org/abs/1903.08205 + + Args: + transforms: execute additional transformation during every iteration (before train). + Typically, several Tensor based transforms composed by `Compose`. + max_interactions: maximum number of interactions per iteration + train: training or evaluation + key_probability: field name to fill probability for every interaction + """ + + def __init__( + self, + transforms: Union[Sequence[Callable], Callable], + max_interactions: int, + train: bool, + key_probability: str = "probability", + ) -> None: + + if not isinstance(transforms, Compose): + transforms = Compose(transforms) + + self.transforms = transforms + self.max_interactions = max_interactions + self.train = train + self.key_probability = key_probability + + def attach(self, engine: Union[SupervisedTrainer, SupervisedEvaluator]) -> None: + if not engine.has_event_handler(self, Events.ITERATION_STARTED): + engine.add_event_handler(Events.ITERATION_STARTED, self) + + def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchdata: Dict[str, torch.Tensor]): + if batchdata is None: + raise ValueError("Must provide batch data for current iteration.") + + for j in range(self.max_interactions): + inputs, _ = engine.prepare_batch(batchdata) + inputs = inputs.to(engine.state.device) + + engine.network.eval() + with torch.no_grad(): + if engine.amp: + with torch.cuda.amp.autocast(): + predictions = engine.inferer(inputs, engine.network) + else: + predictions = engine.inferer(inputs, engine.network) + + batchdata.update({CommonKeys.PRED: predictions}) + batchdata[self.key_probability] = torch.as_tensor( + ([1.0 - ((1.0 / self.max_interactions) * j)] if self.train else [1.0]) * len(inputs) + ) + batchdata = self.transforms(batchdata) + + return engine._iteration(engine, batchdata) diff --git a/tests/min_tests.py b/tests/min_tests.py index e2c7bc529a..24a2667bfa 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -104,6 +104,7 @@ def run_testsuit(): "test_handler_metrics_saver_dist", "test_evenly_divisible_all_gather_dist", "test_handler_classification_saver_dist", + "test_deepgrow_interaction", "test_deepgrow_dataset", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_deepgrow_interaction.py b/tests/test_deepgrow_interaction.py new file mode 100644 index 0000000000..272ae82a5b --- /dev/null +++ b/tests/test_deepgrow_interaction.py @@ -0,0 +1,62 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from monai.apps.deepgrow.interaction import Interaction +from monai.data import Dataset +from monai.engines import SupervisedTrainer +from monai.transforms import Activationsd, Compose, ToNumpyd + + +class TestInteractions(unittest.TestCase): + def run_interaction(self, train, compose): + data = [] + for i in range(5): + data.append({"image": torch.tensor([float(i)]), "label": torch.tensor([float(i)])}) + network = torch.nn.Linear(1, 1) + lr = 1e-3 + opt = torch.optim.SGD(network.parameters(), lr) + loss = torch.nn.L1Loss() + dataset = Dataset(data, transform=None) + data_loader = torch.utils.data.DataLoader(dataset, batch_size=5) + + iteration_transforms = [Activationsd(keys="pred", sigmoid=True), ToNumpyd(keys="pred")] + iteration_transforms = Compose(iteration_transforms) if compose else iteration_transforms + + i = Interaction(transforms=iteration_transforms, train=train, max_interactions=5) + self.assertEqual(len(i.transforms.transforms), 2, "Mismatch in expected transforms") + + # set up engine + engine = SupervisedTrainer( + device=torch.device("cpu"), + max_epochs=1, + train_data_loader=data_loader, + network=network, + optimizer=opt, + loss_function=loss, + iteration_update=i, + ) + + engine.run() + self.assertIsNotNone(engine.state.batch.get("probability"), "Probability is missing") + + def test_train_interaction(self): + self.run_interaction(train=True, compose=True) + + def test_val_interaction(self): + self.run_interaction(train=False, compose=False) + + +if __name__ == "__main__": + unittest.main()