@@ -17,25 +17,28 @@
import nibabel as nib
import numpy as np
import torch
import torchvision.transforms as transforms
from ignite.engine import Engine
from torch.utils.data import DataLoader

# assumes the framework is found here, change as necessary
sys.path.append("..")

from monai import config
from monai.handlers.checkpoint_loader import CheckpointLoader
from monai.handlers.segmentation_saver import SegmentationSaver
import monai.transforms.compose as transforms
from monai.data.nifti_reader import NiftiDataset
from monai.transforms import AddChannel, Rescale, ToTensor
from monai.networks.nets.unet import UNet
from monai.networks.utils import predict_segmentation
from monai.data.synthetic import create_test_image_3d
from monai.utils.sliding_window_inference import sliding_window_inference

sys.path.append("..") # assumes the framework is found here, change as necessary
config.print_config()

tempdir = tempfile.mkdtemp()
# tempdir = './temp'
print('generating synthetic data to {} (this may take a while)'.format(tempdir))
for i in range(50):
im, seg = create_test_image_3d(256, 256, 256)

@@ -51,7 +54,7 @@
segtrans = transforms.Compose([AddChannel(), ToTensor()])
ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)

device = torch.device("cpu:0")
device = torch.device("cuda:0")
roi_size = (64, 64, 64)
sw_batch_size = 4
net = UNet(
@@ -65,7 +68,7 @@
net.to(device)


def _sliding_window_processor(_engine, batch):
def _sliding_window_processor(engine, batch):
net.eval()
img, seg, meta_data = batch
with torch.no_grad():
@@ -75,11 +78,11 @@ def _sliding_window_processor(_engine, batch):

infer_engine = Engine(_sliding_window_processor)

# checkpoint_handler = ModelCheckpoint('./', 'net', n_saved=10, save_interval=3, require_empty=False)
# infer_engine.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net})

SegmentationSaver(output_path='tempdir', output_ext='.nii.gz', output_postfix='seg').attach(infer_engine)
CheckpointLoader(load_path='./net_checkpoint_9.pth', load_dict={'net': net}).attach(infer_engine)
# for the array data format, assume the 3rd item of batch data is the meta_data
SegmentationSaver(output_path='tempdir', output_ext='.nii.gz', output_postfix='seg',
batch_transform=lambda x: x[2]).attach(infer_engine)
# the model was trained by the "unet_segmentation_3d_array" example
CheckpointLoader(load_path='./runs/net_checkpoint_120.pth', load_dict={'net': net}).attach(infer_engine)

loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
state = infer_engine.run(loader)
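
A minimal sketch (not part of this diff) of why `batch_transform=lambda x: x[2]` fits the array data format: NiftiDataset with image_only=False yields (image, seg, meta_data) triples, so the meta data dict is the 3rd item of each batch. The toy values below are illustrative only.

import torch

# toy batch in the array data format: (image, seg, meta_data)
batch = (
    torch.zeros(1, 1, 64, 64, 64),        # image tensor
    torch.zeros(1, 1, 64, 64, 64),        # segmentation tensor
    {'filename_or_obj': ['im0.nii.gz']},  # meta data dict (toy value)
)
batch_transform = lambda x: x[2]          # same hook as passed to SegmentationSaver
meta_data = batch_transform(batch)
print(meta_data['filename_or_obj'])       # ['im0.nii.gz']
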
95 changes: 95 additions & 0 deletions examples/unet_inference_3d_dict.py
@@ -0,0 +1,95 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import tempfile
from glob import glob

import nibabel as nib
import numpy as np
import torch
from ignite.engine import Engine
from torch.utils.data import DataLoader

# assumes the framework is found here, change as necessary
sys.path.append("..")

import monai
from monai.data.utils import list_data_collate
from monai.utils.sliding_window_inference import sliding_window_inference
from monai.data.synthetic import create_test_image_3d
from monai.networks.utils import predict_segmentation
from monai.networks.nets.unet import UNet
from monai.transforms.composables import LoadNiftid, AsChannelFirstd
import monai.transforms.compose as transforms
from monai.handlers.segmentation_saver import SegmentationSaver
from monai.handlers.checkpoint_loader import CheckpointLoader
from monai import config

config.print_config()

tempdir = tempfile.mkdtemp()
# tempdir = './temp'
print('generating synthetic data to {} (this may take a while)'.format(tempdir))
for i in range(50):
im, seg = create_test_image_3d(256, 256, 256, channel_dim=-1)

n = nib.Nifti1Image(im, np.eye(4))
nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

n = nib.Nifti1Image(seg, np.eye(4))
nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
val_files = [{'img': img, 'seg': seg} for img, seg in zip(images, segs)]
val_transforms = transforms.Compose([
LoadNiftid(keys=['img', 'seg']),
AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1)
])
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

device = torch.device("cuda:0")
roi_size = (64, 64, 64)
sw_batch_size = 4
net = UNet(
dimensions=3,
in_channels=1,
num_classes=1,
channels=(16, 32, 64, 128, 256),
strides=(2, 2, 2, 2),
num_res_units=2,
)
net.to(device)


def _sliding_window_processor(engine, batch):
net.eval()
with torch.no_grad():
seg_probs = sliding_window_inference(batch['img'], roi_size, sw_batch_size, lambda x: net(x)[0], device)
return predict_segmentation(seg_probs)


infer_engine = Engine(_sliding_window_processor)

# for the dict data format, extract the meta data from the flattened keys of the batch dict
SegmentationSaver(output_path='tempdir', output_ext='.nii.gz', output_postfix='seg',
batch_transform=lambda batch: {'filename_or_obj': batch['img.filename_or_obj'],
'original_affine': batch['img.original_affine'],
'affine': batch['img.affine'],
}).attach(infer_engine)
# the model was trained by the "unet_segmentation_3d_array" example
CheckpointLoader(load_path='./runs/net_checkpoint_120.pth', load_dict={'net': net}).attach(infer_engine)

val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate,
pin_memory=torch.cuda.is_available())
state = infer_engine.run(val_loader)
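
A hedged sketch of the dict-format meta data handling above: after LoadNiftid and list_data_collate, the per-image meta data is assumed to sit under flattened keys such as 'img.filename_or_obj' (as the batch_transform above suggests), and the transform repacks them into the dict SegmentationSaver expects. The toy values are illustrative only.

import numpy as np

# toy batch mimicking the assumed flattened key layout
batch = {
    'img': np.zeros((1, 1, 64, 64, 64), dtype=np.float32),
    'img.filename_or_obj': ['im0.nii.gz'],
    'img.original_affine': [np.eye(4)],
    'img.affine': [np.eye(4)],
}
batch_transform = lambda b: {'filename_or_obj': b['img.filename_or_obj'],
                             'original_affine': b['img.original_affine'],
                             'affine': b['img.affine']}
meta_data = batch_transform(batch)   # the dict SegmentationSaver expects
print(meta_data['filename_or_obj'])  # ['im0.nii.gz']
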
4 changes: 2 additions & 2 deletions examples/unet_segmentation_3d_array.py
@@ -30,7 +30,7 @@
import monai.transforms.compose as transforms

from monai.data.nifti_reader import NiftiDataset
from monai.transforms import (AddChannel, Rescale, ToTensor, UniformRandomPatch)
from monai.transforms import AddChannel, Rescale, ToTensor, UniformRandomPatch
from monai.handlers.stats_handler import StatsHandler
from monai.handlers.mean_dice import MeanDice
from monai.visualize import img2tensorboard
@@ -41,7 +41,7 @@

# Create a temporary directory and 50 random image, mask pairs
tempdir = tempfile.mkdtemp()

print('generating synthetic data to {} (this may take a while)'.format(tempdir))
for i in range(50):
im, seg = create_test_image_3d(128, 128, 128)

6 changes: 3 additions & 3 deletions examples/unet_segmentation_3d_dict.py
@@ -41,7 +41,7 @@

# Create a temporary directory and 50 random image, mask pairs
tempdir = tempfile.mkdtemp()

print('generating synthetic data to {} (this may take a while)'.format(tempdir))
for i in range(50):
im, seg = create_test_image_3d(128, 128, 128, channel_dim=-1)

@@ -70,7 +70,7 @@

# Define nifti dataset, dataloader.
ds = monai.data.Dataset(data=train_files, transform=train_transforms)
loader = DataLoader(ds, batch_size=2, num_workers=2, collate_fn=list_data_collate,
loader = DataLoader(ds, batch_size=2, num_workers=4, collate_fn=list_data_collate,
pin_memory=torch.cuda.is_available())
check_data = monai.utils.misc.first(loader)
print(check_data['img'].shape, check_data['seg'].shape)
@@ -190,7 +190,7 @@ def log_metrics_to_tensorboard(engine):
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=2, num_workers=8, collate_fn=list_data_collate,
train_loader = DataLoader(train_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate,
pin_memory=torch.cuda.is_available())

train_epochs = 30
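
A hedged toy illustration of why these dict examples pass collate_fn=list_data_collate: assuming a transform can expand one dataset item into a list of dicts (e.g. several random patches), list_data_collate flattens those lists into the batch before the default collation, which the stock collate would not do.

import numpy as np
from monai.data.utils import list_data_collate

# two dataset items, each already expanded into two patch dicts
item_a = [{'img': np.zeros((1, 8, 8, 8))}, {'img': np.ones((1, 8, 8, 8))}]
item_b = [{'img': np.zeros((1, 8, 8, 8))}, {'img': np.ones((1, 8, 8, 8))}]
batch = list_data_collate([item_a, item_b])
print(batch['img'].shape)  # expected: (4, 1, 8, 8, 8), patches merged into one batch
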
27 changes: 17 additions & 10 deletions monai/handlers/segmentation_saver.py
@@ -23,13 +23,15 @@ class SegmentationSaver:
"""

def __init__(self, output_path='./', dtype='float32', output_postfix='seg', output_ext='.nii.gz',
output_transform=lambda x: x, name=None):
batch_transform=lambda x: x, output_transform=lambda x: x, name=None):
"""
Args:
output_path (str): output image directory.
dtype (str): to convert the image to save to this datatype.
output_postfix (str): a string appended to all output file names.
output_ext (str): output file extension name.
batch_transform (Callable): a callable that is used to transform the
ignite.engine.batch into the expected format to extract the meta_data dictionary.
output_transform (Callable): a callable that is used to transform the
ignite.engine.output into the expected nifti image data.
The first dimension of this transform's output will be treated as the
@@ -40,6 +42,7 @@ def __init__(self, output_path='./', dtype='float32', output_postfix='seg', outp
self.dtype = dtype
self.output_postfix = output_postfix
self.output_ext = output_ext
self.batch_transform = batch_transform
self.output_transform = output_transform

self.logger = None if name is None else logging.getLogger(name)
@@ -88,24 +91,28 @@ def _create_file_basename(postfix, input_file_name, folder_path, data_root_dir="

def __call__(self, engine):
"""
This method assumes:
- 3rd output of engine.state.batch is a meta data dict, and have the keys:
'filename_or_obj' -- for output file name creation
and optionally 'original_affine', 'affine' for data orientation handling.
- output file datatype from `engine.state.output.dtype`.
This method assumes self.batch_transform will extract metadata from the input batch.
The metadata should have the following keys:

- ``'filename_or_obj'`` -- for output file name creation
- ``'original_affine'`` (optional) for data orientation handling
- ``'affine'`` (optional) for data output affine.

The output file datatype is determined by ``engine.state.output.dtype``.
"""
meta_data = engine.state.batch[2] # assuming 3rd output of input dataset is a meta data dict
meta_data = self.batch_transform(engine.state.batch)
filenames = meta_data['filename_or_obj']
original_affine = meta_data.get('original_affine', None)
affine = meta_data.get('affine', None)

engine_output = self.output_transform(engine.state.output)
for batch_id, filename in enumerate(filenames): # save a batch of files
seg_output = engine_output[batch_id]
_affine = affine[batch_id]
_original_affine = original_affine[batch_id]
affine_ = affine[batch_id]
original_affine_ = original_affine[batch_id]
if isinstance(seg_output, torch.Tensor):
seg_output = seg_output.detach().cpu().numpy()
output_filename = self._create_file_basename(self.output_postfix, filename, self.output_path)
output_filename = '{}{}'.format(output_filename, self.output_ext)
write_nifti(seg_output, _affine, output_filename, _original_affine, dtype=seg_output.dtype)
write_nifti(seg_output, affine_, output_filename, original_affine_, dtype=seg_output.dtype)
self.logger.info('saved: {}'.format(output_filename))
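
Taken together, the two hooks let one handler serve both data formats. A hedged usage sketch (infer_engine is the engine built in the examples above; the (seg, loss) output layout is an assumption, not this PR's contract):

from monai.handlers.segmentation_saver import SegmentationSaver

saver = SegmentationSaver(
    output_path='./out',
    output_ext='.nii.gz',
    output_postfix='seg',
    batch_transform=lambda batch: batch[2],     # meta data is the 3rd batch item
    output_transform=lambda output: output[0],  # keep only the segmentation part
)
saver.attach(infer_engine)  # infer_engine as defined in the examples above
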
36 changes: 21 additions & 15 deletions monai/handlers/stats_handler.py
@@ -10,7 +10,7 @@
# limitations under the License.

import logging

import torch
from ignite.engine import Engine, Events

KEY_VAL_FORMAT = '{}: {:.4f} '
@@ -19,29 +19,37 @@
class StatsHandler(object):
"""StatsHandler defines a set of Ignite Event-handlers for all the log printing logics.
It's can be used for any Ignite Engine(trainer, validator and evaluator).
And it can support logging for epoch level and iteration level with pre-defined StatsLoggers.
By default, this class logs the dictionary of `engine.state.metrics`.
It supports epoch-level and iteration-level logging with pre-defined loggers.
By default:
(1) epoch_print_logger logs `engine.state.metrics`.
(2) iteration_print_logger logs the loss value; the expected output format is (y_pred, loss).
"""

def __init__(self,
epoch_print_logger=None,
iteration_print_logger=None,
batch_transform=lambda x: x,
output_transform=lambda x: x,
name=None):
"""
Args:
epoch_print_logger (Callable): customized callable printer for epoch level logging.
must accept parameter "engine", use default printer if None.
must accept parameter "engine", use default printer if None.
iteration_print_logger (Callable): customized callable printer for iteration level logging.
must accept parameter "engine", use default printer if None.
must accept parameter "engine", use default printer if None.
batch_transform (Callable): a callable that is used to transform the
ignite.engine.batch into the expected format to extract the input data.
output_transform (Callable): a callable that is used to transform the
ignite.engine.output into the expected format to extract the output data.
name (str): identifier of logging.logger to use, defaulting to `engine.logger`.
"""

self.epoch_print_logger = epoch_print_logger
self.iteration_print_logger = iteration_print_logger

self.batch_transform = batch_transform
self.output_transform = output_transform
self.logger = None if name is None else logging.getLogger(name)


def attach(self, engine: Engine):
"""Register a set of Ignite Event-Handlers to a specified Ignite engine.

@@ -119,15 +127,16 @@ def _default_epoch_print(self, engine: Engine):

def _default_iteration_print(self, engine: Engine):
"""Execute iteration log operation based on Ignite engine.state data.
print the values from ignite state.logs dict.
Print the values from the ignite state.logs dict.
The default behavior is to print the loss from output[1], skipping if output[1] is not a loss value.

Args:
engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator.

"""
prints_dict = engine.state.metrics
if not prints_dict:
return
loss = self.output_transform(engine.state.output)[1]
if loss is None or (torch.is_tensor(loss) and len(loss.shape) > 0):
return # not printing multi dimensional output
num_iterations = engine.state.epoch_length
current_iteration = (engine.state.iteration - 1) % num_iterations + 1
current_epoch = engine.state.epoch
Expand All @@ -138,9 +147,6 @@ def _default_iteration_print(self, engine: Engine):
num_epochs,
current_iteration,
num_iterations)

for name in sorted(prints_dict):
value = prints_dict[name]
out_str += KEY_VAL_FORMAT.format(name, value)
out_str += KEY_VAL_FORMAT.format('Loss', loss.item() if torch.is_tensor(loss) else loss)

self.logger.info(out_str)
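
A minimal sketch, not from this diff, of a process function whose output satisfies the new default iteration printer, which reads the loss from output[1]; net, optimizer, loss_fn and device stand in for the objects defined in the training examples.

from ignite.engine import Engine
from monai.handlers.stats_handler import StatsHandler

def _train_step(engine, batch):
    img, seg = batch[0].to(device), batch[1].to(device)
    optimizer.zero_grad()
    y_pred = net(img)
    loss = loss_fn(y_pred, seg)
    loss.backward()
    optimizer.step()
    return y_pred, loss  # (y_pred, loss): the default printer logs loss via output[1]

trainer = Engine(_train_step)
StatsHandler().attach(trainer)  # prints 'Loss: ...' every iteration by default
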
5 changes: 2 additions & 3 deletions monai/utils/sliding_window_inference.py
@@ -10,9 +10,8 @@
# limitations under the License.

import torch

from ignite.utils import convert_tensor
from monai.transforms.transforms import ImageEndPadder
from monai.transforms.transforms import ToTensor
from monai.data.utils import dense_patch_slices


@@ -49,7 +48,7 @@ def sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device)
# in case that image size is smaller than roi size
image_size = tuple(max(image_size[i], roi_size[i]) for i in range(num_spatial_dims))
inputs = ImageEndPadder(roi_size, 'constant')(inputs) # in np array
inputs = ToTensor()(inputs)
inputs = convert_tensor(torch.from_numpy(inputs), device, False)

# TODO: interval from user's specification
scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims)
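
For orientation, a hedged call sketch mirroring how the examples above invoke this function; the keyword names come from the signature in the hunk header, while net, img and device are assumed from those examples.

import torch
from monai.utils.sliding_window_inference import sliding_window_inference
from monai.networks.utils import predict_segmentation

with torch.no_grad():
    seg_probs = sliding_window_inference(
        img,                            # one batch of images from the DataLoader
        roi_size=(64, 64, 64),          # window size seen by the network
        sw_batch_size=4,                # windows evaluated per forward pass
        predictor=lambda x: net(x)[0],  # the example UNet forward returns a tuple
        device=device,
    )
seg = predict_segmentation(seg_probs)   # threshold probabilities to a label map
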
3 changes: 2 additions & 1 deletion tests/integration_sliding_window.py
@@ -58,7 +58,8 @@ def _sliding_window_processor(_engine, batch):
infer_engine = Engine(_sliding_window_processor)

with tempfile.TemporaryDirectory() as temp_dir:
SegmentationSaver(output_path=temp_dir, output_ext='.nii.gz', output_postfix='seg').attach(infer_engine)
SegmentationSaver(output_path=temp_dir, output_ext='.nii.gz', output_postfix='seg',
batch_transform=lambda x: x[2]).attach(infer_engine)

infer_engine.run(loader)
