diff --git a/examples/unet_segmentation_3d.py b/examples/unet_segmentation_3d_array.py similarity index 100% rename from examples/unet_segmentation_3d.py rename to examples/unet_segmentation_3d_array.py diff --git a/examples/unet_segmentation_3d_dict.py b/examples/unet_segmentation_3d_dict.py new file mode 100644 index 0000000000..8c6955d87b --- /dev/null +++ b/examples/unet_segmentation_3d_dict.py @@ -0,0 +1,185 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import tempfile +from glob import glob +import logging + +import nibabel as nib +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter +from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, _prepare_batch +from ignite.handlers import ModelCheckpoint, EarlyStopping +from torch.utils.data import DataLoader + +# assumes the framework is found here, change as necessary +sys.path.append("..") + +import monai +import monai.transforms.compose as transforms +from monai.data.nifti_reader import NiftiDatasetd +from monai.transforms.composables import AddChanneld, RandRotate90d +from monai.handlers.stats_handler import StatsHandler +from monai.handlers.mean_dice import MeanDice +from monai.visualize import img2tensorboard +from monai.data.synthetic import create_test_image_3d +from monai.handlers.utils import stopping_fn_from_metric + +monai.config.print_config() + +# Create a temporary directory and 50 random image, mask paris +tempdir = tempfile.mkdtemp() + +for i in range(50): + im, seg = create_test_image_3d(128, 128, 128) + + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i)) + + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i)) + +images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz'))) +segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) + +# Define transforms for image and segmentation +transforms = transforms.Compose([ + AddChanneld(keys=['image', 'seg']), + RandRotate90d(keys=['image', 'seg'], prob=0.8, axes=[1, 3]) +]) + +# Define nifti dataset, dataloader. +ds = NiftiDatasetd(images, segs, transform=transforms) +loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) +check_data = monai.utils.misc.first(loader) +print(check_data['image'].shape, check_data['seg'].shape) + +lr = 1e-5 + +# Create UNet, DiceLoss and Adam optimizer. +net = monai.networks.nets.UNet( + dimensions=3, + in_channels=1, + num_classes=1, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + num_res_units=2, +) + +loss = monai.losses.DiceLoss(do_sigmoid=True) +opt = torch.optim.Adam(net.parameters(), lr) + + +# Since network outputs logits and segmentation, we need a custom function. +def _loss_fn(i, j): + return loss(i[0], j) + + +# Create trainer +def prepare_batch(batch, device=None, non_blocking=False): + return _prepare_batch((batch['image'], batch['seg']), device, non_blocking) + + +device = torch.device("cuda:0") +trainer = create_supervised_trainer(net, opt, _loss_fn, device, False, + prepare_batch=prepare_batch, + output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y]) + +# adding checkpoint handler to save models (network params and optimizer stats) during training +checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) +trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, + handler=checkpoint_handler, + to_save={'net': net, 'opt': opt}) +train_stats_handler = StatsHandler() +train_stats_handler.attach(trainer) + + +@trainer.on(Events.EPOCH_COMPLETED) +def log_training_loss(engine): + # log loss to tensorboard with second item of engine.state.output, loss.item() from output_transform + writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch) + + # tensor of ones to use where for converting labels to zero and ones + ones = torch.ones(engine.state.batch['seg'][0].shape, dtype=torch.int32) + first_output_tensor = engine.state.output[0][1][0].detach().cpu() + # log model output to tensorboard, as three dimensional tensor with no channels dimension + img2tensorboard.add_animated_gif_no_channels(writer, "first_output_final_batch", first_output_tensor, 64, + 255, engine.state.epoch) + # get label tensor and convert to single class + first_label_tensor = torch.where(engine.state.batch['seg'][0] > 0, ones, engine.state.batch['seg'][0]) + # log label tensor to tensorboard, there is a channel dimension when getting label from batch + img2tensorboard.add_animated_gif(writer, "first_label_final_batch", first_label_tensor, 64, + 255, engine.state.epoch) + second_output_tensor = engine.state.output[0][1][1].detach().cpu() + img2tensorboard.add_animated_gif_no_channels(writer, "second_output_final_batch", second_output_tensor, 64, + 255, engine.state.epoch) + second_label_tensor = torch.where(engine.state.batch['seg'][1] > 0, ones, engine.state.batch['seg'][1]) + img2tensorboard.add_animated_gif(writer, "second_label_final_batch", second_label_tensor, 64, + 255, engine.state.epoch) + third_output_tensor = engine.state.output[0][1][2].detach().cpu() + img2tensorboard.add_animated_gif_no_channels(writer, "third_output_final_batch", third_output_tensor, 64, + 255, engine.state.epoch) + third_label_tensor = torch.where(engine.state.batch['seg'][2] > 0, ones, engine.state.batch['seg'][2]) + img2tensorboard.add_animated_gif(writer, "third_label_final_batch", third_label_tensor, 64, + 255, engine.state.epoch) + engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1]) + + +writer = SummaryWriter() + +# Set parameters for validation +validation_every_n_epochs = 1 +metric_name = 'Mean_Dice' + +# add evaluation metric to the evaluator engine +val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)} +evaluator = create_supervised_evaluator(net, val_metrics, device, True, + prepare_batch=prepare_batch, + output_transform=lambda x, y, y_pred: (y_pred[0], y)) + +# Add stats event handler to print validation stats via evaluator +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +val_stats_handler = StatsHandler() +val_stats_handler.attach(evaluator) + +# Add early stopping handler to evaluator. +early_stopper = EarlyStopping(patience=4, + score_function=stopping_fn_from_metric(metric_name), + trainer=trainer) +evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) + +# create a validation data loader +val_ds = NiftiDatasetd(images[-20:], segs[-20:], transform=transforms) +val_loader = DataLoader(ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available()) + + +@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs)) +def run_validation(engine): + evaluator.run(val_loader) + + +@evaluator.on(Events.EPOCH_COMPLETED) +def log_metrics_to_tensorboard(engine): + for _, value in engine.state.metrics.items(): + writer.add_scalar('Metrics/' + metric_name, value, trainer.state.epoch) + + +# create a training data loader +logging.basicConfig(stream=sys.stdout, level=logging.INFO) + +train_ds = NiftiDatasetd(images[:20], segs[:20], transform=transforms) +train_loader = DataLoader(train_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available()) + +train_epochs = 30 +state = trainer.run(train_loader, train_epochs) diff --git a/monai/data/nifti_reader.py b/monai/data/nifti_reader.py index 287c97fdde..753691de20 100644 --- a/monai/data/nifti_reader.py +++ b/monai/data/nifti_reader.py @@ -14,7 +14,6 @@ from torch.utils.data import Dataset from torch.utils.data._utils.collate import np_str_obj_array_pattern - from monai.utils.module import export from monai.transforms.compose import Randomizable @@ -107,6 +106,7 @@ def __getitem__(self, index): else: img, meta_data = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical, image_only=self.image_only, dtype=self.dtype) + target = None if self.seg_files is not None: target = load_nifti(self.seg_files[index]) elif self.labels is not None: @@ -135,3 +135,76 @@ def __getitem__(self, index): continue compatible_meta[meta_key] = meta_datum return img, target, compatible_meta + + +@export("monai.data") +class NiftiDatasetd(Dataset): + """ + Loads image/segmentation pairs of Nifti files from the given filename lists. Dict level transformations can be + specified for the dictionary data which is constructed by image, label and other metadata. + """ + + def __init__(self, image_files, seg_files=None, labels=None, as_closest_canonical=False, transform=None, dtype=None): + """ + Initializes the dataset with the image and segmentation filename lists. The transform `transform` is applied + to the images and `seg_transform` to the segmentations. + + Args: + image_files (list of str): list of image filenames. + seg_files (list of str): if in segmentation task, list of segmentation filenames. + labels (list or array): if in classification task, list of classification labels. + as_closest_canonical (bool): if True, load the image as closest to canonical orientation. + transform (Callable, optional): dict transforms to excute operations on dictionary data. + dtype (np.dtype, optional): if not None convert the loaded image to this data type. + """ + + if len(image_files) != len(seg_files): + raise ValueError('Must have same number of image and segmentation files') + + self.image_files = image_files + self.seg_files = seg_files + self.labels = labels + self.as_closest_canonical = as_closest_canonical + self.transform = transform + self.dtype = dtype + + def __len__(self): + return len(self.image_files) + + def __getitem__(self, index): + meta_data = None + img, meta_data = load_nifti( + filename_or_obj=self.image_files[index], + as_closest_canonical=self.as_closest_canonical, + image_only=False, + dtype=self.dtype + ) + + seg = None + if self.seg_files is not None: + seg = load_nifti(self.seg_files[index]) + label = None + if self.labels is not None: + label = self.labels[index] + + compatible_meta = {} + assert isinstance(meta_data, dict), 'meta_data must be in dictionary format.' + for meta_key in meta_data: + meta_datum = meta_data[meta_key] + if type(meta_datum).__name__ == 'ndarray' \ + and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None: + continue + compatible_meta[meta_key] = meta_datum + + data = {'image': img} + if seg is not None: + data['seg'] = seg + if label is not None: + data['label'] = label + if len(compatible_meta) > 0: + data.update(compatible_meta) + + if self.transform is not None: + data = self.transform(data) + + return data diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py index 4c13e105ca..d177bc98e1 100644 --- a/monai/transforms/composables.py +++ b/monai/transforms/composables.py @@ -18,7 +18,7 @@ import monai from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.compose import Randomizable, Transform -from monai.transforms.transforms import Rotate90, SpatialCrop +from monai.transforms.transforms import Rotate90, SpatialCrop, AddChannel from monai.utils.misc import ensure_tuple from monai.transforms.utils import generate_pos_neg_label_crop_centers @@ -149,6 +149,28 @@ def __call__(self, data): return d +@export +class AddChanneld(MapTransform): + """ + dictionary-based wrapper of AddChannel. + """ + + def __init__(self, keys): + """ + Args: + keys (hashable items): keys of the corresponding items to be transformed. + See also: monai.transform.composables.MapTransform + """ + MapTransform.__init__(self, keys) + self.adder = AddChannel() + + def __call__(self, data): + d = dict(data) + for key in self.keys: + d[key] = self.adder(d[key]) + return d + + @export class RandCropByPosNegLabeld(Randomizable, MapTransform): """ diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index 7f8961b4e6..3cca4a01c3 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -286,7 +286,7 @@ def __init__(self, k=1, axes=(1, 2)): self.plane_axes = axes def __call__(self, img): - return np.rot90(img, self.k, self.plane_axes) + return np.ascontiguousarray(np.rot90(img, self.k, self.plane_axes)) @export diff --git a/tests/test_add_channeld.py b/tests/test_add_channeld.py new file mode 100644 index 0000000000..a2940ffffb --- /dev/null +++ b/tests/test_add_channeld.py @@ -0,0 +1,37 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from parameterized import parameterized +from monai.transforms.composables import AddChanneld + +TEST_CASE_1 = [ + {'keys': ['img', 'seg']}, + { + 'img': np.array([[0, 1], [1, 2]]), + 'seg': np.array([[0, 1], [1, 2]]) + }, + (1, 2, 2), +] + + +class TestAddChanneld(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1]) + def test_shape(self, input_param, input_data, expected_shape): + result = AddChanneld(**input_param)(input_data) + self.assertEqual(result['img'].shape, expected_shape) + self.assertEqual(result['seg'].shape, expected_shape) + + +if __name__ == '__main__': + unittest.main()