Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
a3e419c
Support to train/run Deepgrow 2D/3D models
SachidanandAlle Dec 23, 2020
0f626f6
Fix import dependencies
SachidanandAlle Dec 23, 2020
d15f04e
Fix import dependencies
SachidanandAlle Dec 23, 2020
95f11db
Fix import dependencies
SachidanandAlle Dec 23, 2020
9fd9456
Fix bbox for inference and stat handler for additional info
SachidanandAlle Dec 26, 2020
06879fe
Fix handler and transform init in iteration
SachidanandAlle Dec 27, 2020
ec8d0d7
fix transforms for training/inference
SachidanandAlle Jan 4, 2021
bf4d031
fix 2D transform
SachidanandAlle Jan 9, 2021
91aa795
Merge branch 'master' into master
SachidanandAlle Jan 15, 2021
a2c0c89
Fix ci build
SachidanandAlle Jan 15, 2021
41ad0af
Merge branch 'master' into master
SachidanandAlle Feb 3, 2021
3f70ccf
Merge branch 'master' into master
SachidanandAlle Feb 4, 2021
672ae85
Merge branch 'master' into master
SachidanandAlle Feb 8, 2021
cad793a
fix comments + add unit tests
SachidanandAlle Feb 8, 2021
e45e4e0
add deegpr tests to exlucde list
SachidanandAlle Feb 8, 2021
7c641b8
fix docs
SachidanandAlle Feb 9, 2021
8787d3e
Add more unittests and docstrings
YuanTingHsieh Feb 9, 2021
28191dd
fix build
SachidanandAlle Feb 9, 2021
f92c319
fix test
SachidanandAlle Feb 9, 2021
284fa13
fix test
SachidanandAlle Feb 9, 2021
cfdc769
fix build + docs
SachidanandAlle Feb 9, 2021
34a4422
Merge branch 'master' into master
SachidanandAlle Feb 9, 2021
cc81759
Fix connected_regions / refactor docstring / add unit tests
YuanTingHsieh Feb 11, 2021
d5a6d6a
remove tests in min_tests.py
YuanTingHsieh Feb 11, 2021
3893a62
fix review comments
SachidanandAlle Feb 11, 2021
19bd7b7
Merge branch 'master' into master
SachidanandAlle Feb 11, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions docs/source/apps.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,35 @@ Applications
.. autofunction:: extractall

.. autofunction:: download_and_extract

`Deepgrow`
----------

.. automodule:: monai.apps.deepgrow.dataset
.. autofunction:: create_dataset

.. automodule:: monai.apps.deepgrow.interaction
.. autoclass:: Interaction
:members:

.. automodule:: monai.apps.deepgrow.transforms
.. autoclass:: AddInitialSeedPointd
:members:
.. autoclass:: AddGuidanceSignald
:members:
.. autoclass:: AddRandomGuidanced
:members:
.. autoclass:: AddGuidanceFromPointsd
:members:
.. autoclass:: SpatialCropForegroundd
:members:
.. autoclass:: SpatialCropGuidanced
:members:
.. autoclass:: RestoreCroppedLabeld
:members:
.. autoclass:: FindDiscrepancyRegionsd
:members:
.. autoclass:: FindAllValidSlicesd
:members:
.. autoclass:: Fetch2DSliced
:members:
10 changes: 10 additions & 0 deletions monai/apps/deepgrow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
268 changes: 268 additions & 0 deletions monai/apps/deepgrow/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import Dict, List

import numpy as np

from monai.transforms import AsChannelFirstd, Compose, LoadImaged, Orientationd, Spacingd
from monai.utils import GridSampleMode


def create_dataset(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is create_datalist instead of create_dataset, right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be create_dataset.
As it reads in the raw images and it will transform them into new volumes (a lot of .npy files) and store them back on disk.
So that means it creates a new dataset.
Just that this function also creates a datalist that has the entries of the new dataset and returns it.

datalist,
output_dir,
dimension,
pixdim,
keys=("image", "label"),
base_dir=None,
limit=0,
relative_path=False,
transforms=None,
) -> List[Dict]:
"""
Utility to pre-process and create dataset list for Deepgrow training over on existing one.
The input data list is normally a list of images and labels (3D volume) that needs pre-processing
for Deepgrow training pipeline.

Args:
datalist: A generic dataset with a length property which normally contains a list of data dictionary.
For example, typical input data can be a list of dictionaries::

[{'image': 'img1.nii', 'label': 'label1.nii'}]

output_dir: target directory to store the training data for Deepgrow Training
pixdim: output voxel spacing.
dimension: dimension for Deepgrow training. It can be 2 or 3.
keys: Image and Label keys in input datalist. Defaults to 'image' and 'label'
base_dir: base directory in case related path is used for the keys in datalist. Defaults to None.
limit: limit number of inputs for pre-processing. Defaults to 0 (no limit).
relative_path: output keys values should be based on relative path. Defaults to False.
transforms: explicit transforms to execute operations on input data.

Raises:
ValueError: When ``dimension`` is not one of [2, 3]
ValueError: When ``datalist`` is Empty

Example::

datalist = create_dataset(
datalist=[{'image': 'img1.nii', 'label': 'label1.nii'}],
base_dir=None,
output_dir=output_2d,
dimension=2,
keys=('image', 'label')
pixdim=(1.0, 1.0),
limit=0,
relative_path=True
)

print(datalist[0]["image"], datalist[0]["label"])
"""

if dimension not in [2, 3]:
raise ValueError("Dimension can be only 2 or 3 as Deepgrow supports only 2D/3D Training")

if not len(datalist):
raise ValueError("Input Datalist is empty")

if not isinstance(keys, list) and not isinstance(keys, tuple):
keys = [keys]

transforms = _default_transforms(keys, pixdim) if transforms is None else transforms
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @wyli @SachidanandAlle , I feel this create_datalist() function actually can be a dict transform which works after these transforms, what do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

based on Daguang/research ask, we are providing only utility to flatten/pre-process the data prior to training.
otherwise, there can be very good chained dataset logic possible where some transforms contribute to create new set of dataset. keeping it simple based on the ask.

new_datalist = []
for idx in range(len(datalist)):
if limit and idx >= limit:
break

image = datalist[idx][keys[0]]
label = datalist[idx].get(keys[1]) if len(keys) > 1 else None
if base_dir:
image = os.path.join(base_dir, image)
label = os.path.join(base_dir, label) if label else None

image = os.path.abspath(image)
label = os.path.abspath(label) if label else None

logging.info("Image: {}; Label: {}".format(image, label if label else None))
if dimension == 2:
data = _save_data_2d(
vol_idx=idx,
data=transforms({"image": image, "label": label}),
keys=("image", "label"),
dataset_dir=output_dir,
relative_path=relative_path,
)
else:
data = _save_data_3d(
vol_idx=idx,
data=transforms({"image": image, "label": label}),
keys=("image", "label"),
dataset_dir=output_dir,
relative_path=relative_path,
)
new_datalist.extend(data)
return new_datalist


def _default_transforms(keys, pixdim):
mode = [GridSampleMode.BILINEAR, GridSampleMode.NEAREST] if len(keys) == 2 else [GridSampleMode.BILINEAR]
return Compose([
LoadImaged(keys=keys),
AsChannelFirstd(keys=keys),
Spacingd(keys=keys, pixdim=pixdim, mode=mode),
Orientationd(keys=keys, axcodes="RAS"),
])


def _save_data_2d(vol_idx, data, keys, dataset_dir, relative_path):
vol_image = data[keys[0]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it's better to input image, label in args directly,

vol_label = data.get(keys[1])
data_list = []

if len(vol_image.shape) == 4:
logging.info(
"4D-Image, pick only first series; Image: {}; Label: {}".format(
vol_image.shape, vol_label.shape if vol_label else None
)
)
vol_image = vol_image[0]
vol_image = np.moveaxis(vol_image, -1, 0)

image_count = 0
label_count = 0
unique_labels_count = 0
for sid in range(vol_image.shape[0]):
image = vol_image[sid, ...]
label = vol_label[sid, ...] if vol_label is not None else None

if vol_label is not None and np.sum(label) == 0:
continue

image_file_prefix = "vol_idx_{:0>4d}_slice_{:0>3d}".format(vol_idx, sid)
image_file = os.path.join(dataset_dir, "images", image_file_prefix)
image_file += ".npy"

os.makedirs(os.path.join(dataset_dir, "images"), exist_ok=True)
np.save(image_file, image)
image_count += 1

# Test Data
if vol_label is None:
data_list.append(
{
"image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file,
}
)
continue

# For all Labels
unique_labels = np.unique(label.flatten())
unique_labels = unique_labels[unique_labels != 0]
unique_labels_count = max(unique_labels_count, len(unique_labels))

for idx in unique_labels:
label_file_prefix = "{}_region_{:0>2d}".format(image_file_prefix, int(idx))
label_file = os.path.join(dataset_dir, "labels", label_file_prefix)
label_file += ".npy"

os.makedirs(os.path.join(dataset_dir, "labels"), exist_ok=True)
curr_label = (label == idx).astype(np.float32)
np.save(label_file, curr_label)

label_count += 1
data_list.append(
{
"image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file,
"label": label_file.replace(dataset_dir + "/", "") if relative_path else label_file,
"region": int(idx),
}
)

logging.info(
"{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format(
vol_idx,
vol_image.shape,
image_count,
vol_label.shape if vol_label is not None else None,
label_count,
unique_labels_count,
)
)
return data_list


def _save_data_3d(vol_idx, data, keys, dataset_dir, relative_path):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think 2D and 3D have very similar logic, could you please merge them into 1 function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

acutally its different.. 2D flattens slices over multiple labels and 3D only does for labels.
common things are already taken out of these 2 private functions.

vol_image = data[keys[0]]
vol_label = data.get(keys[1])
data_list = []

if len(vol_image.shape) == 4:
logging.info("4D-Image, pick only first series; Image: {}; Label: {}".format(vol_image.shape, vol_label.shape))
vol_image = vol_image[0]
vol_image = np.moveaxis(vol_image, -1, 0)

image_count = 0
label_count = 0
unique_labels_count = 0

image_file_prefix = "vol_idx_{:0>4d}".format(vol_idx)
image_file = os.path.join(dataset_dir, "images", image_file_prefix)
image_file += ".npy"

os.makedirs(os.path.join(dataset_dir, "images"), exist_ok=True)
np.save(image_file, vol_image)
image_count += 1

# Test Data
if vol_label is None:
data_list.append(
{
"image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file,
}
)
else:
# For all Labels
unique_labels = np.unique(vol_label.flatten())
unique_labels = unique_labels[unique_labels != 0]
unique_labels_count = max(unique_labels_count, len(unique_labels))

for idx in unique_labels:
label_file_prefix = "{}_region_{:0>2d}".format(image_file_prefix, int(idx))
label_file = os.path.join(dataset_dir, "labels", label_file_prefix)
label_file += ".npy"

curr_label = (vol_label == idx).astype(np.float32)
os.makedirs(os.path.join(dataset_dir, "labels"), exist_ok=True)
np.save(label_file, curr_label)

label_count += 1
data_list.append(
{
"image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file,
"label": label_file.replace(dataset_dir + "/", "") if relative_path else label_file,
"region": int(idx),
}
)

logging.info(
"{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format(
vol_idx,
vol_image.shape,
image_count,
vol_label.shape if vol_label is not None else None,
label_count,
unique_labels_count,
)
)
return data_list
81 changes: 81 additions & 0 deletions monai/apps/deepgrow/interaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import Dict

import torch

from monai.engines.utils import CommonKeys
from monai.engines.workflow import Engine, Events
from monai.transforms import Compose


class Interaction:
"""
Ignite handler used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation.

Args:
transforms: execute additional transformation during every iteration (before train).
Typically, several Tensor based transforms composed by `Compose`.
max_interactions: maximum number of interactions per iteration
train: training or evaluation
key_probability: field name to fill probability for every interaction
"""

def __init__(self, transforms, max_interactions: int, train: bool, key_probability: str = "probability") -> None:
self.transforms = transforms
self.max_interactions = max_interactions
self.train = train
self.key_probability = key_probability

if not isinstance(self.transforms, Compose):
transforms = []
for t in self.transforms:
transforms.append(self.init_external_class(t))
self.transforms = Compose(transforms)

@staticmethod
def init_external_class(config_dict):
class_args = None if config_dict.get("args") is None else dict(config_dict.get("args"))
class_path = config_dict.get("path", config_dict["name"])

module_name, class_name = class_path.rsplit(".", 1)
m = importlib.import_module(module_name)
c = getattr(m, class_name)
return c(**class_args) if class_args else c()

def attach(self, engine: Engine) -> None:
if not engine.has_event_handler(self, Events.ITERATION_STARTED):
engine.add_event_handler(Events.ITERATION_STARTED, self)

def __call__(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
if batchdata is None:
raise ValueError("Must provide batch data for current iteration.")

for j in range(self.max_interactions):
inputs, _ = engine.prepare_batch(batchdata)
inputs = inputs.to(engine.state.device)

engine.network.eval()
with torch.no_grad():
if engine.amp:
with torch.cuda.amp.autocast():
predictions = engine.inferer(inputs, engine.network)
else:
predictions = engine.inferer(inputs, engine.network)

batchdata.update({CommonKeys.PRED: predictions})
batchdata[self.key_probability] = torch.as_tensor(
([1.0 - ((1.0 / self.max_interactions) * j)] if self.train else [1.0]) * len(inputs)
)
batchdata = self.transforms(batchdata)

return engine._iteration(engine, batchdata)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this handler will attach to ITERATION_STARTED and the engine will run _iteration() later, why you call engine._iteration here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the batchdata is updated... and the _iteration() must happen on new modified batchdata. will returning batchdata is enough? and _iteration() in later call uses the updated batchdata? in that case, we need one round of training to make sure it works as expected.

this code was based on your initial feedback on how to run interactions while training. may be something got changed recently. nevertheless, you can confirm what to make _iteration run on new/updated batchdata in above case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest to use:

engine.state.batch = self.transforms(batchdata)

Then return nothing in this function.

Thanks.

Copy link
Contributor Author

@SachidanandAlle SachidanandAlle Feb 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you mean engine.state.batch=batchdata
but it failed on post transform ( i guess engine.state.output is empty )

Traceback (most recent call last):
File "/opt/conda/lib/python3.6/site-packages/ignite/engine/engine.py", line 730, in _internal_run
time_taken = self._run_once_on_dataset()
File "/opt/conda/lib/python3.6/site-packages/ignite/engine/engine.py", line 828, in _run_once_on_dataset
self._handle_exception(e)
File "/opt/conda/lib/python3.6/site-packages/ignite/engine/engine.py", line 465, in _handle_exception
self._fire_event(Events.EXCEPTION_RAISED, e)
File "/opt/conda/lib/python3.6/site-packages/ignite/engine/engine.py", line 423, in _fire_event
func(*first, *(event_args + others), **kwargs)
File "/opt/monai/monai/handlers/stats_handler.py", line 145, in exception_raised
raise e
File "/opt/conda/lib/python3.6/site-packages/ignite/engine/engine.py", line 812, in _run_once_on_dataset
self._fire_event(Events.ITERATION_COMPLETED)
File "/opt/conda/lib/python3.6/site-packages/ignite/engine/engine.py", line 423, in _fire_event
func(*first, *(event_args + others), **kwargs)
File "/opt/monai/monai/engines/workflow.py", line 150, in run_post_transform
engine.state.output = apply_transform(posttrans, engine.state.output)
File "/opt/monai/monai/transforms/utils.py", line 387, in apply_transform
raise RuntimeError(f"applying transform {transform}") from e
RuntimeError: applying transform <monai.transforms.compose.Compose object at 0x7ff083933e80>
Traceback (most recent call last):
File "/opt/monai/monai/transforms/utils.py", line 385, in apply_transform
return transform(data)
File "/opt/monai/monai/transforms/post/dictionary.py", line 91, in call
d = dict(data)
TypeError: 'NoneType' object is not iterable

Traceback (most recent call last):
File "/opt/conda/lib/python3.6/runpy.py", line 193, in _run_module_as_main
"main", mod_spec)
File "/opt/conda/lib/python3.6/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "apps/train.py", line 35, in
File "apps/train.py", line 27, in main
File "apps/mmar_conf.py", line 32, in train_mmar
File "/opt/monai/monai/engines/trainer.py", line 48, in run
super().run()
File "/opt/monai/monai/engines/workflow.py", line 191, in run
super().run(data=self.data_loader, max_epochs=self.state.max_epochs)
File "/opt/conda/lib/python3.6/site-packages/ignite/engine/engine.py", line 691, in run
return self._internal_run()
File "/opt/conda/lib/python3.6/site-packages/ignite/engine/engine.py", line 762, in _internal_run
self._handle_exception(e)
File "/opt/conda/lib/python3.6/site-packages/ignite/engine/engine.py", line 465, in _handle_exception
self._fire_event(Events.EXCEPTION_RAISED, e)
File "/opt/conda/lib/python3.6/site-packages/ignite/engine/engine.py", line 423, in _fire_event
func(*first, *(event_args + others), **kwargs)
File "/opt/monai/monai/handlers/stats_handler.py", line 145, in exception_raised
raise e
File "/opt/conda/lib/python3.6/site-packages/ignite/engine/engine.py", line 730, in _internal_run
time_taken = self._run_once_on_dataset()
File "/opt/conda/lib/python3.6/site-packages/ignite/engine/engine.py", line 828, in _run_once_on_dataset
self._handle_exception(e)
File "/opt/conda/lib/python3.6/site-packages/ignite/engine/engine.py", line 465, in _handle_exception
self._fire_event(Events.EXCEPTION_RAISED, e)
File "/opt/conda/lib/python3.6/site-packages/ignite/engine/engine.py", line 423, in _fire_event
func(*first, *(event_args + others), **kwargs)
File "/opt/monai/monai/handlers/stats_handler.py", line 145, in exception_raised
raise e
File "/opt/conda/lib/python3.6/site-packages/ignite/engine/engine.py", line 812, in _run_once_on_dataset
self._fire_event(Events.ITERATION_COMPLETED)
File "/opt/conda/lib/python3.6/site-packages/ignite/engine/engine.py", line 423, in _fire_event
func(*first, *(event_args + others), **kwargs)
File "/opt/monai/monai/engines/workflow.py", line 150, in run_post_transform
engine.state.output = apply_transform(posttrans, engine.state.output)
File "/opt/monai/monai/transforms/utils.py", line 387, in apply_transform
raise RuntimeError(f"applying transform {transform}") from e
RuntimeError: applying transform <monai.transforms.compose.Compose object at 0x7ff083933e80>

Loading