Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/source/apps.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,13 @@ Applications
.. autofunction:: extractall

.. autofunction:: download_and_extract

`Deepgrow`
----------

.. automodule:: monai.apps.deepgrow.dataset
.. autofunction:: create_dataset

.. automodule:: monai.apps.deepgrow.interaction
.. autoclass:: Interaction
:members:
79 changes: 79 additions & 0 deletions monai/apps/deepgrow/interaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, Sequence, Union

import torch

from monai.engines import SupervisedEvaluator, SupervisedTrainer
from monai.engines.utils import CommonKeys
from monai.engines.workflow import Events
from monai.transforms import Compose


class Interaction:
"""
Ignite handler used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation.
This implementation is based on:

Sakinis et al., Interactive segmentation of medical images through
fully convolutional neural networks. (2019) https://arxiv.org/abs/1903.08205

Args:
transforms: execute additional transformation during every iteration (before train).
Typically, several Tensor based transforms composed by `Compose`.
max_interactions: maximum number of interactions per iteration
train: training or evaluation
key_probability: field name to fill probability for every interaction
"""

def __init__(
self,
transforms: Union[Sequence[Callable], Callable],
max_interactions: int,
train: bool,
key_probability: str = "probability",
) -> None:

if not isinstance(transforms, Compose):
transforms = Compose(transforms)

self.transforms = transforms
self.max_interactions = max_interactions
self.train = train
self.key_probability = key_probability

def attach(self, engine: Union[SupervisedTrainer, SupervisedEvaluator]) -> None:
if not engine.has_event_handler(self, Events.ITERATION_STARTED):
engine.add_event_handler(Events.ITERATION_STARTED, self)

def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchdata: Dict[str, torch.Tensor]):
if batchdata is None:
raise ValueError("Must provide batch data for current iteration.")

for j in range(self.max_interactions):
inputs, _ = engine.prepare_batch(batchdata)
inputs = inputs.to(engine.state.device)

engine.network.eval()
with torch.no_grad():
if engine.amp:
with torch.cuda.amp.autocast():
predictions = engine.inferer(inputs, engine.network)
else:
predictions = engine.inferer(inputs, engine.network)

batchdata.update({CommonKeys.PRED: predictions})
batchdata[self.key_probability] = torch.as_tensor(
([1.0 - ((1.0 / self.max_interactions) * j)] if self.train else [1.0]) * len(inputs)
)
batchdata = self.transforms(batchdata)

return engine._iteration(engine, batchdata)
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def run_testsuit():
"test_handler_metrics_saver_dist",
"test_evenly_divisible_all_gather_dist",
"test_handler_classification_saver_dist",
"test_deepgrow_interaction",
"test_deepgrow_dataset",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"
Expand Down
62 changes: 62 additions & 0 deletions tests/test_deepgrow_interaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from monai.apps.deepgrow.interaction import Interaction
from monai.data import Dataset
from monai.engines import SupervisedTrainer
from monai.transforms import Activationsd, Compose, ToNumpyd


class TestInteractions(unittest.TestCase):
def run_interaction(self, train, compose):
data = []
for i in range(5):
data.append({"image": torch.tensor([float(i)]), "label": torch.tensor([float(i)])})
network = torch.nn.Linear(1, 1)
lr = 1e-3
opt = torch.optim.SGD(network.parameters(), lr)
loss = torch.nn.L1Loss()
dataset = Dataset(data, transform=None)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=5)

iteration_transforms = [Activationsd(keys="pred", sigmoid=True), ToNumpyd(keys="pred")]
iteration_transforms = Compose(iteration_transforms) if compose else iteration_transforms

i = Interaction(transforms=iteration_transforms, train=train, max_interactions=5)
self.assertEqual(len(i.transforms.transforms), 2, "Mismatch in expected transforms")

# set up engine
engine = SupervisedTrainer(
device=torch.device("cpu"),
max_epochs=1,
train_data_loader=data_loader,
network=network,
optimizer=opt,
loss_function=loss,
iteration_update=i,
)

engine.run()
self.assertIsNotNone(engine.state.batch.get("probability"), "Probability is missing")

def test_train_interaction(self):
self.run_interaction(train=True, compose=True)

def test_val_interaction(self):
self.run_interaction(train=False, compose=False)


if __name__ == "__main__":
unittest.main()