Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ venv.bak/
# mypy
.mypy_cache/
examples/scd_lvsegs.npz
.temp/
.idea/

*~
85 changes: 85 additions & 0 deletions examples/unet_inference_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import tempfile
from glob import glob

import nibabel as nib
import numpy as np
import torch
import torchvision.transforms as transforms
from ignite.engine import Engine
from torch.utils.data import DataLoader

from monai import config
from monai.handlers.checkpoint_loader import CheckpointLoader
from monai.handlers.segmentation_saver import SegmentationSaver
from monai.data.nifti_reader import NiftiDataset
from monai.transforms import AddChannel, Rescale, ToTensor
from monai.networks.nets.unet import UNet
from monai.networks.utils import predict_segmentation
from monai.data.synthetic import create_test_image_3d
from monai.utils.sliding_window_inference import sliding_window_inference

sys.path.append("..") # assumes the framework is found here, change as necessary
config.print_config()

tempdir = tempfile.mkdtemp()
# tempdir = './temp'
for i in range(50):
im, seg = create_test_image_3d(256, 256, 256)

n = nib.Nifti1Image(im, np.eye(4))
nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

n = nib.Nifti1Image(seg, np.eye(4))
nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
imtrans = transforms.Compose([Rescale(), AddChannel(), ToTensor()])
segtrans = transforms.Compose([AddChannel(), ToTensor()])
ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)

device = torch.device("cpu:0")
roi_size = (64, 64, 64)
sw_batch_size = 4
net = UNet(
dimensions=3,
in_channels=1,
num_classes=1,
channels=(16, 32, 64, 128, 256),
strides=(2, 2, 2, 2),
num_res_units=2,
)
net.to(device)


def _sliding_window_processor(_engine, batch):
net.eval()
img, seg, meta_data = batch
with torch.no_grad():
seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, lambda x: net(x)[0], device)
return predict_segmentation(seg_probs)


infer_engine = Engine(_sliding_window_processor)

# checkpoint_handler = ModelCheckpoint('./', 'net', n_saved=10, save_interval=3, require_empty=False)
# infer_engine.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net})

SegmentationSaver(output_path='tempdir', output_ext='.nii.gz', output_postfix='seg').attach(infer_engine)
CheckpointLoader(load_path='./net_checkpoint_9.pth', load_dict={'net': net}).attach(infer_engine)

loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
state = infer_engine.run(loader)
34 changes: 5 additions & 29 deletions examples/unet_segmentation_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,40 +30,14 @@
from monai.transforms import (AddChannel, Rescale, ToTensor, UniformRandomPatch)
from monai.handlers.stats_handler import StatsHandler
from monai.handlers.mean_dice import MeanDice
from monai.transforms.utils import rescale_array
from monai.visualize import img2tensorboard
from monai.data.synthetic import create_test_image_3d

# assumes the framework is found here, change as necessary
sys.path.append("..")

config.print_config()


def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_max=0.0, num_seg_classes=5):
'''Return a noisy 3D image and segmentation.'''
image = np.zeros((width, height, depth))

for i in range(num_objs):
x = np.random.randint(rad_max, width - rad_max)
y = np.random.randint(rad_max, height - rad_max)
z = np.random.randint(rad_max, depth - rad_max)
rad = np.random.randint(5, rad_max)
spy, spx, spz = np.ogrid[-x:width - x, -y:height - y, -z:depth - z]
circle = (spx * spx + spy * spy + spz * spz) <= rad * rad

if num_seg_classes > 1:
image[circle] = np.ceil(np.random.random() * num_seg_classes)
else:
image[circle] = np.random.random() * 0.5 + 0.5

labels = np.ceil(image).astype(np.int32)

norm = np.random.uniform(0, num_seg_classes * noise_max, size=image.shape)
noisyimage = rescale_array(np.maximum(image, norm))

return noisyimage, labels


tempdir = tempfile.mkdtemp()

for i in range(50):
Expand All @@ -82,7 +56,7 @@ def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_ma

segtrans = transforms.Compose([AddChannel(), UniformRandomPatch((64, 64, 64)), ToTensor()])

ds = NiftiDataset(images, segs, imtrans, segtrans)
ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans)

loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
im, seg = monai.utils.misc.first(loader)
Expand Down Expand Up @@ -115,7 +89,9 @@ def _loss_fn(i, j):
output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y])

checkpoint_handler = ModelCheckpoint('./', 'net', n_saved=10, require_empty=False)
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net})
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
handler=checkpoint_handler,
to_save={'net': net, 'opt': opt})

dice_metric = MeanDice(add_sigmoid=True, output_transform=lambda output: (output[0][0], output[2]))
dice_metric.attach(trainer, "Training Dice")
Expand Down
2 changes: 1 addition & 1 deletion monai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

__copyright__ = "(c) 2020 MONAI Consortium"
__version__tuple__ = (0, 0, 1)
__version__ = "%i.%i.%i" % (__version__tuple__)
__version__ = "%i.%i.%i" % __version__tuple__

__basedir__ = os.path.dirname(__file__)

Expand Down
51 changes: 41 additions & 10 deletions monai/data/nifti_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import random

from torch.utils.data import Dataset
from torch.utils.data._utils.collate import np_str_obj_array_pattern

from monai.utils.module import export

Expand All @@ -31,25 +32,32 @@ def load_nifti(filename_or_obj, as_closest_canonical=False, image_only=True, dty
Returns:
The loaded image volume if `image_only` is True, or a tuple containing the volume and the Nifti
header in dict format otherwise

Note:
header['original_affine'] stores the original affine loaded from `filename_or_obj`.
header['affine'] stores the affine after the optional `as_closest_canonical` transform.
"""

img = nib.load(filename_or_obj)

header = dict(img.header)
header['filename_or_obj'] = filename_or_obj
header['original_affine'] = img.affine
header['affine'] = img.affine
header['as_closest_canonical'] = as_closest_canonical

if as_closest_canonical:
img = nib.as_closest_canonical(img)
header['affine'] = img.affine

if dtype is not None:
dat = img.get_fdata(dtype=dtype)
else:
dat = np.asanyarray(img.dataobj)

header = dict(img.header)
header['filename_or_obj'] = filename_or_obj

if image_only:
return dat
else:
return dat, header
return dat, header


@export("monai.data")
Expand All @@ -59,31 +67,44 @@ class NiftiDataset(Dataset):
for the image and segmentation arrays separately.
"""

def __init__(self, image_files, seg_files, transform=None, seg_transform=None):
def __init__(self, image_files, seg_files, as_closest_canonical=False,
transform=None, seg_transform=None, image_only=True, dtype=None):
"""
Initializes the dataset with the image and segmentation filename lists. The transform `transform` is applied
to the images and `seg_transform` to the segmentations.

Args:
image_files (list of str): list of image filenames
seg_files (list of str): list of segmentation filenames
as_closest_canonical (bool): if True, load the image as closest to canonical orientation
transform (Callable, optional): transform to apply to image arrays
seg_transform (Callable, optional): transform to apply to segmentation arrays
image_only (bool): if True return only the image volume, other return image volume and header dict
dtype (np.dtype, optional): if not None convert the loaded image to this data type
"""

if len(image_files) != len(seg_files):
raise ValueError('Must have same number of image and segmentation files')

self.image_files = image_files
self.seg_files = seg_files
self.transform = transform
self.seg_transform = seg_transform
self.as_closest_canonical = as_closest_canonical
self.transform = transform
self.seg_transform = seg_transform
self.image_only = image_only
self.dtype = dtype

def __len__(self):
return len(self.image_files)

def __getitem__(self, index):
img = load_nifti(self.image_files[index])
meta_data = None
if self.image_only:
img = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical,
image_only=self.image_only, dtype=self.dtype)
else:
img, meta_data = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical,
image_only=self.image_only, dtype=self.dtype)
seg = load_nifti(self.seg_files[index])

# https://github.com/pytorch/vision/issues/9#issuecomment-304224800
Expand All @@ -97,4 +118,14 @@ def __getitem__(self, index):
random.seed(seed) # ensure randomized transforms roll the same values for segmentations as images
seg = self.seg_transform(seg)

return img, seg
if self.image_only or meta_data is None:
return img, seg

compatible_meta = {}
for meta_key in meta_data:
meta_datum = meta_data[meta_key]
if type(meta_datum).__name__ == 'ndarray' \
and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None:
continue
compatible_meta[meta_key] = meta_datum
return img, seg, compatible_meta
41 changes: 41 additions & 0 deletions monai/data/nifti_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import nibabel as nib


def write_nifti(data, affine, file_name, target_affine=None, dtype="float32"):
"""Write numpy data into nifti files to disk.

Args:
data (numpy.ndarray): input data to write to file.
affine (numpy.ndarray): affine information for the data.
file_name (string): expected file name that saved on disk.
target_affine (numpy.ndarray, optional):
before saving the (data, affine), transform the data into the orientation defined by `target_affine`.
dtype (np.dtype, optional): convert the image to save to this data type.
"""
assert isinstance(data, np.ndarray), 'input data must be numpy array.'
if affine is None:
affine = np.eye(4)

if target_affine is None:
results_img = nib.Nifti1Image(data.astype(dtype), affine)
else:
start_ornt = nib.orientations.io_orientation(affine)
target_ornt = nib.orientations.io_orientation(target_affine)
ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt)

reverted_results = nib.orientations.apply_orientation(data, ornt_transform)
results_img = nib.Nifti1Image(reverted_results.astype(dtype), target_affine)

nib.save(results_img, file_name)
32 changes: 30 additions & 2 deletions monai/data/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np

from monai.transforms.utils import rescale_array
Expand All @@ -28,7 +27,7 @@ def create_test_image_2d(width, height, num_objs=12, rad_max=30, noise_max=0.0,
x = np.random.randint(rad_max, width - rad_max)
y = np.random.randint(rad_max, height - rad_max)
rad = np.random.randint(5, rad_max)
spy, spx = np.ogrid[-x : width - x, -y : height - y]
spy, spx = np.ogrid[-x:width - x, -y:height - y]
circle = (spx * spx + spy * spy) <= rad * rad

if num_seg_classes > 1:
Expand All @@ -42,3 +41,32 @@ def create_test_image_2d(width, height, num_objs=12, rad_max=30, noise_max=0.0,
noisyimage = rescale_array(np.maximum(image, norm))

return noisyimage, labels


def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_max=0.0, num_seg_classes=5):
"""
Return a noisy 3D image and segmentation.

See also: create_test_image_2d
"""
image = np.zeros((width, height, depth))

for i in range(num_objs):
x = np.random.randint(rad_max, width - rad_max)
y = np.random.randint(rad_max, height - rad_max)
z = np.random.randint(rad_max, depth - rad_max)
rad = np.random.randint(5, rad_max)
spy, spx, spz = np.ogrid[-x:width - x, -y:height - y, -z:depth - z]
circle = (spx * spx + spy * spy + spz * spz) <= rad * rad

if num_seg_classes > 1:
image[circle] = np.ceil(np.random.random() * num_seg_classes)
else:
image[circle] = np.random.random() * 0.5 + 0.5

labels = np.ceil(image).astype(np.int32)

norm = np.random.uniform(0, num_seg_classes * noise_max, size=image.shape)
noisyimage = rescale_array(np.maximum(image, norm))

return noisyimage, labels
Loading