44 changes: 28 additions & 16 deletions examples/unet_segmentation_3d_dict.py
@@ -28,42 +28,52 @@

import monai
import monai.transforms.compose as transforms
from monai.data.nifti_reader import NiftiDatasetd
from monai.transforms.composables import AddChanneld, RandRotate90d
from monai.transforms.composables import \
LoadNiftid, AsChannelFirstd, RandCropByPosNegLabeld, RandRotate90d
from monai.handlers.stats_handler import StatsHandler
from monai.handlers.mean_dice import MeanDice
from monai.visualize import img2tensorboard
from monai.data.synthetic import create_test_image_3d
from monai.handlers.utils import stopping_fn_from_metric
from monai.data.utils import list_data_collate

monai.config.print_config()

# Create a temporary directory and 50 random image, mask pairs
tempdir = tempfile.mkdtemp()

for i in range(50):
im, seg = create_test_image_3d(128, 128, 128)
im, seg = create_test_image_3d(128, 128, 128, channel_dim=-1)

n = nib.Nifti1Image(im, np.eye(4))
nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))
nib.save(n, os.path.join(tempdir, 'img%i.nii.gz' % i))

n = nib.Nifti1Image(seg, np.eye(4))
nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
images = sorted(glob(os.path.join(tempdir, 'img*.nii.gz')))
segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
train_files = [{'img': img, 'seg': seg} for img, seg in zip(images[:40], segs[:40])]
val_files = [{'img': img, 'seg': seg} for img, seg in zip(images[-10:], segs[-10:])]

# Define transforms for image and segmentation
transforms = transforms.Compose([
AddChanneld(keys=['image', 'seg']),
RandRotate90d(keys=['image', 'seg'], prob=0.8, axes=[1, 3])
train_transforms = transforms.Compose([
LoadNiftid(keys=['img', 'seg']),
AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
RandCropByPosNegLabeld(keys=['img', 'seg'], label_key='seg', size=[96, 96, 96], pos=1, neg=1, num_samples=4),
RandRotate90d(keys=['img', 'seg'], prob=0.8, axes=[1, 3])
])
val_transforms = transforms.Compose([
LoadNiftid(keys=['img', 'seg']),
AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1)
])

# Define nifti dataset, dataloader.
ds = NiftiDatasetd(images, segs, transform=transforms)
loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
ds = monai.data.Dataset(data=train_files, transform=train_transforms)
loader = DataLoader(ds, batch_size=2, num_workers=2, collate_fn=list_data_collate,
pin_memory=torch.cuda.is_available())
check_data = monai.utils.misc.first(loader)
print(check_data['image'].shape, check_data['seg'].shape)
print(check_data['img'].shape, check_data['seg'].shape)

lr = 1e-5

@@ -88,7 +98,7 @@ def _loss_fn(i, j):

# Create trainer
def prepare_batch(batch, device=None, non_blocking=False):
return _prepare_batch((batch['image'], batch['seg']), device, non_blocking)
return _prepare_batch((batch['img'], batch['seg']), device, non_blocking)


device = torch.device("cuda:0")
@@ -160,8 +170,9 @@ def log_training_loss(engine):
evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

# create a validation data loader
val_ds = NiftiDatasetd(images[-20:], segs[-20:], transform=transforms)
val_loader = DataLoader(ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, collate_fn=list_data_collate,
pin_memory=torch.cuda.is_available())


@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
@@ -178,8 +189,9 @@ def log_metrics_to_tensorboard(engine):
# create a training data loader
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

train_ds = NiftiDatasetd(images[:20], segs[:20], transform=transforms)
train_loader = DataLoader(train_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())
train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=2, num_workers=8, collate_fn=list_data_collate,
pin_memory=torch.cuda.is_available())

train_epochs = 30
state = trainer.run(train_loader, train_epochs)
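
A minimal sketch (editor's note, not part of the diff) of why the new loaders pass `collate_fn=list_data_collate`: `RandCropByPosNegLabeld` with `num_samples=4` turns each dataset item into a list of four cropped dictionaries, which the default PyTorch collate cannot stack, and `list_data_collate` flattens those lists into a single batch. The file paths below are hypothetical placeholders, so the snippet only runs with real NIfTI files in place.

```python
import torch
from torch.utils.data import DataLoader

import monai
import monai.transforms.compose as transforms
from monai.transforms.composables import LoadNiftid, AsChannelFirstd, RandCropByPosNegLabeld
from monai.data.utils import list_data_collate

# hypothetical image/label pairs on disk (channel-last NIfTI volumes)
files = [{'img': 'img%i.nii.gz' % i, 'seg': 'seg%i.nii.gz' % i} for i in range(4)]

xform = transforms.Compose([
    LoadNiftid(keys=['img', 'seg']),
    AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
    RandCropByPosNegLabeld(keys=['img', 'seg'], label_key='seg',
                           size=[96, 96, 96], pos=1, neg=1, num_samples=4),
])
ds = monai.data.Dataset(data=files, transform=xform)
# one dataset item is now a list of 4 crop dicts, not a single dict
print(type(ds[0]), len(ds[0]))

loader = DataLoader(ds, batch_size=2, collate_fn=list_data_collate,
                    pin_memory=torch.cuda.is_available())
batch = monai.utils.misc.first(loader)
# the 2 items x 4 crops are flattened into one batch dimension
print(batch['img'].shape)  # e.g. (8, 1, 96, 96, 96)
```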
45 changes: 45 additions & 0 deletions monai/data/dataset.py
@@ -0,0 +1,45 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from monai.utils.module import export


@export("monai.data")
class Dataset(torch.utils.data.Dataset):
"""
General Dataset to handle dictionary format data; it can apply transforms to specific fields.
For example, typical input data can be a list of dictionaries:
[{ { {
'img': 'image1.nii.gz', 'img': 'image2.nii.gz', 'img': 'image3.nii.gz',
'seg': 'label1.nii.gz', 'seg': 'label2.nii.gz', 'seg': 'label3.nii.gz',
'extra': 123 'extra': 456 'extra': 789
}, }, }]
"""

def __init__(self, data, transform=None):
"""
Args:
data (Iterable): input data to load and transform to generate dataset for model.
transform (Callable, optional): transforms to execute operations on input data.
"""
self.data = data
self.transform = transform

def __len__(self):
return len(self.data)

def __getitem__(self, index):
data = self.data[index]
if self.transform is not None:
data = self.transform(data)

return data
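
A short, self-contained sketch (editor's illustration) of what the new `Dataset` does: it simply indexes the input sequence and applies the optional callable transform to each item, so it works for arbitrary dictionary data, not only images. It assumes `monai.data.Dataset` is reachable through the package namespace, as the updated example script already uses it.

```python
import monai

data = [{'img': 'image1.nii.gz', 'seg': 'label1.nii.gz', 'extra': 123},
        {'img': 'image2.nii.gz', 'seg': 'label2.nii.gz', 'extra': 456}]

# any callable can serve as the transform; this trivial one just tags each item
ds = monai.data.Dataset(data=data, transform=lambda d: {**d, 'seen': True})

print(len(ds))                         # 2
print(ds[0]['extra'], ds[0]['seen'])   # 123 True
```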
73 changes: 0 additions & 73 deletions monai/data/nifti_reader.py
@@ -135,76 +135,3 @@ def __getitem__(self, index):
continue
compatible_meta[meta_key] = meta_datum
return img, target, compatible_meta


@export("monai.data")
class NiftiDatasetd(Dataset):
"""
Loads image/segmentation pairs of Nifti files from the given filename lists. Dict level transformations can be
specified for the dictionary data which is constructed by image, label and other metadata.
"""

def __init__(self, image_files, seg_files=None, labels=None, as_closest_canonical=False, transform=None, dtype=None):
"""
Initializes the dataset with the image and segmentation filename lists. The transform `transform` is applied
to the images and `seg_transform` to the segmentations.

Args:
image_files (list of str): list of image filenames.
seg_files (list of str): if in segmentation task, list of segmentation filenames.
labels (list or array): if in classification task, list of classification labels.
as_closest_canonical (bool): if True, load the image as closest to canonical orientation.
transform (Callable, optional): dict transforms to excute operations on dictionary data.
dtype (np.dtype, optional): if not None convert the loaded image to this data type.
"""

if len(image_files) != len(seg_files):
raise ValueError('Must have same number of image and segmentation files')

self.image_files = image_files
self.seg_files = seg_files
self.labels = labels
self.as_closest_canonical = as_closest_canonical
self.transform = transform
self.dtype = dtype

def __len__(self):
return len(self.image_files)

def __getitem__(self, index):
meta_data = None
img, meta_data = load_nifti(
filename_or_obj=self.image_files[index],
as_closest_canonical=self.as_closest_canonical,
image_only=False,
dtype=self.dtype
)

seg = None
if self.seg_files is not None:
seg = load_nifti(self.seg_files[index])
label = None
if self.labels is not None:
label = self.labels[index]

compatible_meta = {}
assert isinstance(meta_data, dict), 'meta_data must be in dictionary format.'
for meta_key in meta_data:
meta_datum = meta_data[meta_key]
if type(meta_datum).__name__ == 'ndarray' \
and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None:
continue
compatible_meta[meta_key] = meta_datum

data = {'image': img}
if seg is not None:
data['seg'] = seg
if label is not None:
data['label'] = label
if len(compatible_meta) > 0:
data.update(compatible_meta)

if self.transform is not None:
data = self.transform(data)

return data
18 changes: 15 additions & 3 deletions monai/data/synthetic.py
@@ -14,12 +14,13 @@
from monai.transforms.utils import rescale_array


def create_test_image_2d(width, height, num_objs=12, rad_max=30, noise_max=0.0, num_seg_classes=5):
def create_test_image_2d(width, height, num_objs=12, rad_max=30, noise_max=0.0, num_seg_classes=5, channel_dim=None):
"""
Return a noisy 2D image with `numObj' circles and a 2D mask image. The maximum radius of the circles is given as
`radMax'. The mask will have `numSegClasses' number of classes for segmentations labeled sequentially from 1, plus a
background class represented as 0. If `noiseMax' is greater than 0 then noise will be added to the image taken from
the uniform distribution on range [0,noiseMax).
the uniform distribution on range [0,noiseMax). If `channel_dim' is None, the image is created without a channel
dimension; otherwise a singleton channel dimension is added as either the first or the last dim.
"""
image = np.zeros((width, height))

@@ -40,10 +41,16 @@ def create_test_image_2d(width, height, num_objs=12, rad_max=30, noise_max=0.0,
norm = np.random.uniform(0, num_seg_classes * noise_max, size=image.shape)
noisyimage = rescale_array(np.maximum(image, norm))

if channel_dim is not None:
assert isinstance(channel_dim, int) and channel_dim in (-1, 0, 2), 'invalid channel dim.'
noisyimage, labels = (noisyimage[None], labels[None]) \
if channel_dim == 0 else (noisyimage[..., None], labels[..., None])

return noisyimage, labels


def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_max=0.0, num_seg_classes=5):
def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30,
noise_max=0.0, num_seg_classes=5, channel_dim=None):
"""
Return a noisy 3D image and segmentation.

@@ -69,4 +76,9 @@ def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_ma
norm = np.random.uniform(0, num_seg_classes * noise_max, size=image.shape)
noisyimage = rescale_array(np.maximum(image, norm))

if channel_dim is not None:
assert isinstance(channel_dim, int) and channel_dim in (-1, 0, 3), 'invalid channel dim.'
noisyimage, labels = (noisyimage[None], labels[None]) \
if channel_dim == 0 else (noisyimage[..., None], labels[..., None])

return noisyimage, labels
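
An editor's sketch of the new `channel_dim` argument: with the default `None` the synthetic volumes keep their plain spatial shape, while `0` or `-1` adds a singleton channel at the front or the back; the dictionary example above relies on `channel_dim=-1` so that `AsChannelFirstd` can move the channel to the front.

```python
from monai.data.synthetic import create_test_image_3d

im, seg = create_test_image_3d(128, 128, 128)                  # (128, 128, 128)
im0, _ = create_test_image_3d(128, 128, 128, channel_dim=0)    # (1, 128, 128, 128)
im1, _ = create_test_image_3d(128, 128, 128, channel_dim=-1)   # (128, 128, 128, 1)
print(im.shape, im0.shape, im1.shape)
```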
4 changes: 2 additions & 2 deletions monai/losses/dice.py
@@ -78,7 +78,7 @@ def forward(self, pred, ground, smooth=1e-5):
intersection = psum * tsum
sums = psum + tsum

score = 2.0 * (intersection.sum(2) + smooth) / (sums.sum(2) + smooth)
score = (2.0 * intersection.sum(2) + smooth) / (sums.sum(2) + smooth)
return 1 - score.mean()


@@ -159,5 +159,5 @@ def forward(self, pred, ground, smooth=1e-5):
b[infs] = 0.0
b[infs] = torch.max(b)

score = 2.0 * (intersection.sum(2) * w) / (sums.sum(2) * w + smooth)
score = (2.0 * intersection.sum(2) * w + smooth) / (sums.sum(2) * w + smooth)
return 1 - score.mean()
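
An editor's illustration of why the `smooth` term moves inside the factor of 2: when there is no foreground in either the prediction or the ground truth (intersection and sums both 0), the old formula evaluates to 2·smooth/smooth = 2, giving a Dice loss of -1, while the corrected formula evaluates to 1 and a loss of 0, keeping the score properly bounded.

```python
smooth = 1e-5
intersection, sums = 0.0, 0.0   # empty prediction and empty ground truth

old_score = 2.0 * (intersection + smooth) / (sums + smooth)   # 2.0
new_score = (2.0 * intersection + smooth) / (sums + smooth)   # 1.0

print(1 - old_score, 1 - new_score)   # -1.0 vs 0.0
```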