185 changes: 185 additions & 0 deletions examples/unet_segmentation_3d_dict.py
@@ -0,0 +1,185 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import tempfile
from glob import glob
import logging

import nibabel as nib
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, _prepare_batch
from ignite.handlers import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader

# assumes the framework is found here, change as necessary
sys.path.append("..")

import monai
import monai.transforms.compose as transforms
from monai.data.nifti_reader import NiftiDatasetd
from monai.transforms.composables import AddChanneld, RandRotate90d
from monai.handlers.stats_handler import StatsHandler
from monai.handlers.mean_dice import MeanDice
from monai.visualize import img2tensorboard
from monai.data.synthetic import create_test_image_3d
from monai.handlers.utils import stopping_fn_from_metric

monai.config.print_config()

# Create a temporary directory and 50 random image/mask pairs
tempdir = tempfile.mkdtemp()

for i in range(50):
    im, seg = create_test_image_3d(128, 128, 128)

    n = nib.Nifti1Image(im, np.eye(4))
    nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

    n = nib.Nifti1Image(seg, np.eye(4))
    nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

# Define transforms for image and segmentation
transforms = transforms.Compose([
    AddChanneld(keys=['image', 'seg']),
    RandRotate90d(keys=['image', 'seg'], prob=0.8, axes=[1, 3])
])

# Define nifti dataset, dataloader.
ds = NiftiDatasetd(images, segs, transform=transforms)
loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
check_data = monai.utils.misc.first(loader)
print(check_data['image'].shape, check_data['seg'].shape)

lr = 1e-5

# Create UNet, DiceLoss and Adam optimizer.
net = monai.networks.nets.UNet(
    dimensions=3,
    in_channels=1,
    num_classes=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)

loss = monai.losses.DiceLoss(do_sigmoid=True)
opt = torch.optim.Adam(net.parameters(), lr)


# The network outputs a (logits, segmentation) pair, so the loss function needs to select the logits.
def _loss_fn(i, j):
    return loss(i[0], j)


# Create trainer; prepare_batch converts the dict-based batch into the (image, seg) pair Ignite expects.
def prepare_batch(batch, device=None, non_blocking=False):
    return _prepare_batch((batch['image'], batch['seg']), device, non_blocking)


device = torch.device("cuda:0")
trainer = create_supervised_trainer(net, opt, _loss_fn, device, False,
                                    prepare_batch=prepare_batch,
                                    output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y])

# adding checkpoint handler to save models (network params and optimizer stats) during training
checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                          handler=checkpoint_handler,
                          to_save={'net': net, 'opt': opt})
train_stats_handler = StatsHandler()
train_stats_handler.attach(trainer)


@trainer.on(Events.EPOCH_COMPLETED)
def log_training_loss(engine):
    # log loss to tensorboard; engine.state.output[1] is loss.item() from the trainer's output_transform
    writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch)

    # tensor of ones used with torch.where to convert labels to zeros and ones
    ones = torch.ones(engine.state.batch['seg'][0].shape, dtype=torch.int32)
    first_output_tensor = engine.state.output[0][1][0].detach().cpu()
    # log the model output to tensorboard as a three-dimensional tensor with no channel dimension
    img2tensorboard.add_animated_gif_no_channels(writer, "first_output_final_batch", first_output_tensor, 64,
                                                 255, engine.state.epoch)
    # get the label tensor and convert it to a single class
    first_label_tensor = torch.where(engine.state.batch['seg'][0] > 0, ones, engine.state.batch['seg'][0])
    # log the label tensor to tensorboard; there is a channel dimension when taking the label from the batch
    img2tensorboard.add_animated_gif(writer, "first_label_final_batch", first_label_tensor, 64,
                                     255, engine.state.epoch)
    second_output_tensor = engine.state.output[0][1][1].detach().cpu()
    img2tensorboard.add_animated_gif_no_channels(writer, "second_output_final_batch", second_output_tensor, 64,
                                                 255, engine.state.epoch)
    second_label_tensor = torch.where(engine.state.batch['seg'][1] > 0, ones, engine.state.batch['seg'][1])
    img2tensorboard.add_animated_gif(writer, "second_label_final_batch", second_label_tensor, 64,
                                     255, engine.state.epoch)
    third_output_tensor = engine.state.output[0][1][2].detach().cpu()
    img2tensorboard.add_animated_gif_no_channels(writer, "third_output_final_batch", third_output_tensor, 64,
                                                 255, engine.state.epoch)
    third_label_tensor = torch.where(engine.state.batch['seg'][2] > 0, ones, engine.state.batch['seg'][2])
    img2tensorboard.add_animated_gif(writer, "third_label_final_batch", third_label_tensor, 64,
                                     255, engine.state.epoch)
    engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1])


writer = SummaryWriter()

# Set parameters for validation
validation_every_n_epochs = 1
metric_name = 'Mean_Dice'

# add evaluation metric to the evaluator engine
val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}
evaluator = create_supervised_evaluator(net, val_metrics, device, True,
                                        prepare_batch=prepare_batch,
                                        output_transform=lambda x, y, y_pred: (y_pred[0], y))

# Add stats event handler to print validation stats via evaluator
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
val_stats_handler = StatsHandler()
val_stats_handler.attach(evaluator)

# Add early stopping handler to evaluator.
early_stopper = EarlyStopping(patience=4,
                              score_function=stopping_fn_from_metric(metric_name),
                              trainer=trainer)
evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

# create a validation data loader
val_ds = NiftiDatasetd(images[-20:], segs[-20:], transform=transforms)
val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())


@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
def run_validation(engine):
    evaluator.run(val_loader)


@evaluator.on(Events.EPOCH_COMPLETED)
def log_metrics_to_tensorboard(engine):
    for _, value in engine.state.metrics.items():
        writer.add_scalar('Metrics/' + metric_name, value, trainer.state.epoch)


# create a training data loader
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

train_ds = NiftiDatasetd(images[:20], segs[:20], transform=transforms)
train_loader = DataLoader(train_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())

train_epochs = 30
state = trainer.run(train_loader, train_epochs)
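For readers following the TensorBoard handler above, it helps to spell out the structure of engine.state.output: the trainer's output_transform packs [y_pred, loss.item(), y], and y_pred is the (logits, segmentation) pair the example's comments describe the network as returning, which is why the handler indexes output[0][1][b] for the b-th predicted volume. A minimal, self-contained sketch with dummy tensors (illustrative only, not part of the example script):

import torch

# Mimic what the trainer's output_transform above produces: [y_pred, loss.item(), y],
# where y_pred is the network's (logits, segmentation) pair.
logits = torch.randn(3, 1, 64, 64, 64)
segmentation = (torch.sigmoid(logits) > 0.5).float()
y = torch.randint(0, 2, (3, 1, 64, 64, 64)).float()
output = [(logits, segmentation), 0.42, y]              # 0.42 stands in for loss.item()

first_output_tensor = output[0][1][0].detach().cpu()    # same indexing as the handler above
print(first_output_tensor.shape)                        # torch.Size([1, 64, 64, 64])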
75 changes: 74 additions & 1 deletion monai/data/nifti_reader.py
@@ -14,7 +14,6 @@

from torch.utils.data import Dataset
from torch.utils.data._utils.collate import np_str_obj_array_pattern

from monai.utils.module import export
from monai.transforms.compose import Randomizable

@@ -107,6 +106,7 @@ def __getitem__(self, index):
        else:
            img, meta_data = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical,
                                        image_only=self.image_only, dtype=self.dtype)
        target = None
        if self.seg_files is not None:
            target = load_nifti(self.seg_files[index])
        elif self.labels is not None:
@@ -135,3 +135,76 @@ def __getitem__(self, index):
                continue
            compatible_meta[meta_key] = meta_datum
        return img, target, compatible_meta


@export("monai.data")
class NiftiDatasetd(Dataset):
    """
    Loads image/segmentation pairs of Nifti files from the given filename lists. Dictionary-level transforms can be
    specified for the dictionary data constructed from the image, label and other metadata.
    """

    def __init__(self, image_files, seg_files=None, labels=None, as_closest_canonical=False, transform=None, dtype=None):
        """
        Initializes the dataset with the image and segmentation filename lists. The dict transform `transform` is
        applied to the data dictionary built from the image, segmentation, label and metadata.

        Args:
            image_files (list of str): list of image filenames.
            seg_files (list of str): if in a segmentation task, list of segmentation filenames.
            labels (list or array): if in a classification task, list of classification labels.
            as_closest_canonical (bool): if True, load the image as closest to canonical orientation.
            transform (Callable, optional): dict transforms to execute operations on dictionary data.
            dtype (np.dtype, optional): if not None, convert the loaded image to this data type.
        """

        if seg_files is not None and len(image_files) != len(seg_files):
            raise ValueError('Must have same number of image and segmentation files')

        self.image_files = image_files
        self.seg_files = seg_files
        self.labels = labels
        self.as_closest_canonical = as_closest_canonical
        self.transform = transform
        self.dtype = dtype

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        meta_data = None
        img, meta_data = load_nifti(
            filename_or_obj=self.image_files[index],
            as_closest_canonical=self.as_closest_canonical,
            image_only=False,
            dtype=self.dtype
        )

        seg = None
        if self.seg_files is not None:
            seg = load_nifti(self.seg_files[index])
        label = None
        if self.labels is not None:
            label = self.labels[index]

        compatible_meta = {}
        assert isinstance(meta_data, dict), 'meta_data must be in dictionary format.'
        for meta_key in meta_data:
            meta_datum = meta_data[meta_key]
            if type(meta_datum).__name__ == 'ndarray' \
                    and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None:
                continue
            compatible_meta[meta_key] = meta_datum

        data = {'image': img}
        if seg is not None:
            data['seg'] = seg
        if label is not None:
            data['label'] = label
        if len(compatible_meta) > 0:
            data.update(compatible_meta)

        if self.transform is not None:
            data = self.transform(data)

        return data
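A brief, hypothetical usage sketch for the new NiftiDatasetd; the images and segs lists are assumed to be matching lists of Nifti file paths, such as the synthetic ones created in the example script above:

ds = NiftiDatasetd(images, segs)
sample = ds[0]                   # a dict rather than a tuple
print(sample['image'].shape)     # the loaded image array
print(sample['seg'].shape)       # the matching segmentation array
# the remaining keys are whatever metadata load_nifti reports for this file
print(sorted(sample.keys()))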
24 changes: 23 additions & 1 deletion monai/transforms/composables.py
@@ -18,7 +18,7 @@
import monai
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.transforms.compose import Randomizable, Transform
from monai.transforms.transforms import Rotate90, SpatialCrop
from monai.transforms.transforms import Rotate90, SpatialCrop, AddChannel
from monai.utils.misc import ensure_tuple
from monai.transforms.utils import generate_pos_neg_label_crop_centers

@@ -149,6 +149,28 @@ def __call__(self, data):
        return d


@export
class AddChanneld(MapTransform):
    """
    Dictionary-based wrapper of AddChannel.
    """

    def __init__(self, keys):
        """
        Args:
            keys (hashable items): keys of the corresponding items to be transformed.
                See also: monai.transforms.composables.MapTransform
        """
        MapTransform.__init__(self, keys)
        self.adder = AddChannel()

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            d[key] = self.adder(d[key])
        return d


@export
class RandCropByPosNegLabeld(Randomizable, MapTransform):
"""
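A quick sketch of how the new AddChanneld behaves on a data dictionary, mirroring the unit test added at the end of this PR:

import numpy as np
from monai.transforms.composables import AddChanneld

data = {'img': np.zeros((2, 2)), 'seg': np.zeros((2, 2))}
result = AddChanneld(keys=['img', 'seg'])(data)
print(result['img'].shape)   # (1, 2, 2) -- a leading channel dimension is prepended to each keyed array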
2 changes: 1 addition & 1 deletion monai/transforms/transforms.py
@@ -286,7 +286,7 @@ def __init__(self, k=1, axes=(1, 2)):
        self.plane_axes = axes

    def __call__(self, img):
        return np.rot90(img, self.k, self.plane_axes)
        return np.ascontiguousarray(np.rot90(img, self.k, self.plane_axes))


@export
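The switch to np.ascontiguousarray matters because np.rot90 returns a strided view, and torch.from_numpy (used when rotated arrays are later collated into tensors) generally refuses arrays with negative strides. A minimal illustration, not taken from the PR:

import numpy as np
import torch

img = np.arange(8, dtype=np.float32).reshape(2, 2, 2)
rotated_view = np.rot90(img, 1, (0, 1))          # a view with negative strides, no copy made
# torch.from_numpy(rotated_view)                 # typically raises ValueError about negative strides
rotated = np.ascontiguousarray(rotated_view)     # copy into contiguous memory, as Rotate90 now does
print(torch.from_numpy(rotated).shape)           # torch.Size([2, 2, 2])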
37 changes: 37 additions & 0 deletions tests/test_add_channeld.py
@@ -0,0 +1,37 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from parameterized import parameterized
from monai.transforms.composables import AddChanneld

TEST_CASE_1 = [
    {'keys': ['img', 'seg']},
    {
        'img': np.array([[0, 1], [1, 2]]),
        'seg': np.array([[0, 1], [1, 2]])
    },
    (1, 2, 2),
]


class TestAddChanneld(unittest.TestCase):

    @parameterized.expand([TEST_CASE_1])
    def test_shape(self, input_param, input_data, expected_shape):
        result = AddChanneld(**input_param)(input_data)
        self.assertEqual(result['img'].shape, expected_shape)
        self.assertEqual(result['seg'].shape, expected_shape)


if __name__ == '__main__':
    unittest.main()
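The new test runs standalone from the repository root, e.g. with python tests/test_add_channeld.py, assuming the parameterized test dependency is installed.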