From f5cca2694b68c6fa1a8a431dce044d18c8d59bb4 Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Tue, 3 Mar 2020 22:07:48 +0800
Subject: [PATCH 1/3] [DLMED] add UNet example with dict-based transforms

---
 ...on_3d.py => unet_segmentation_3d_array.py} |   0
 examples/unet_segmentation_3d_dict.py         | 179 ++++++++++++++++++
 monai/data/nifti_reader.py                    |  81 +++++++-
 monai/transforms/composables.py               |  24 ++-
 monai/transforms/transforms.py                |   2 +-
 monai/utils/constants.py                      |  40 ++++
 tests/test_add_channeld.py                    |  37 ++++
 7 files changed, 355 insertions(+), 8 deletions(-)
 rename examples/{unet_segmentation_3d.py => unet_segmentation_3d_array.py} (100%)
 create mode 100644 examples/unet_segmentation_3d_dict.py
 create mode 100644 monai/utils/constants.py
 create mode 100644 tests/test_add_channeld.py

diff --git a/examples/unet_segmentation_3d.py b/examples/unet_segmentation_3d_array.py
similarity index 100%
rename from examples/unet_segmentation_3d.py
rename to examples/unet_segmentation_3d_array.py
diff --git a/examples/unet_segmentation_3d_dict.py b/examples/unet_segmentation_3d_dict.py
new file mode 100644
index 0000000000..53f6cff1aa
--- /dev/null
+++ b/examples/unet_segmentation_3d_dict.py
@@ -0,0 +1,179 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import tempfile
+from glob import glob
+import logging
+
+import nibabel as nib
+import numpy as np
+import torch
+from torch.utils.tensorboard import SummaryWriter
+from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, _prepare_batch
+from ignite.handlers import ModelCheckpoint, EarlyStopping
+from torch.utils.data import DataLoader
+
+# assumes the framework is found here, change as necessary
+sys.path.append("..")
+
+import monai
+import monai.transforms.compose as transforms
+from monai.utils.constants import DataElementKey as Dek
+from monai.data.nifti_reader import NiftiDatasetd
+from monai.transforms.composables import AddChanneld, RandRotate90d
+from monai.handlers.stats_handler import StatsHandler
+from monai.handlers.mean_dice import MeanDice
+from monai.visualize import img2tensorboard
+from monai.data.synthetic import create_test_image_3d
+from monai.handlers.utils import stopping_fn_from_metric
+
+monai.config.print_config()
+
+# Create a temporary directory and 50 random image, mask pairs
+tempdir = tempfile.mkdtemp()
+
+for i in range(50):
+    im, seg = create_test_image_3d(128, 128, 128)
+
+    n = nib.Nifti1Image(im, np.eye(4))
+    nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))
+
+    n = nib.Nifti1Image(seg, np.eye(4))
+    nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))
+
+images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
+segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
+
+# Define transforms for image and segmentation
+transforms = transforms.Compose([
+    AddChanneld(keys=[Dek.IMAGE, Dek.LABEL]),
+    RandRotate90d(keys=[Dek.IMAGE, Dek.LABEL], prob=0.8, axes=[1, 3])
+])
+
+# Define nifti dataset, dataloader.
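+# NiftiDatasetd yields one dictionary per sample (roughly {Dek.IMAGE: ..., Dek.LABEL: ...},
+# plus any collate-compatible Nifti metadata), so the DataLoader batches each key
+# separately and the dict-based transforms above only touch the keys they are given.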
+ds = NiftiDatasetd(images, segs, transform=transforms) +loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) +check_data = monai.utils.misc.first(loader) +print(check_data[Dek.IMAGE].shape, check_data[Dek.LABEL].shape) + +lr = 1e-5 + +# Create UNet, DiceLoss and Adam optimizer. +net = monai.networks.nets.UNet( + dimensions=3, + in_channels=1, + num_classes=1, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + num_res_units=2, +) + +loss = monai.losses.DiceLoss(do_sigmoid=True) +opt = torch.optim.Adam(net.parameters(), lr) + +# Since network outputs logits and segmentation, we need a custom function. +def _loss_fn(i, j): + return loss(i[0], j) + +# Create trainer +def prepare_batch(batch, device=None, non_blocking=False): + return _prepare_batch((batch[Dek.IMAGE], batch[Dek.LABEL]), device, non_blocking) + +device = torch.device("cuda:0") +trainer = create_supervised_trainer(net, opt, _loss_fn, device, False, + prepare_batch=prepare_batch, + output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y]) + +# adding checkpoint handler to save models (network params and optimizer stats) during training +checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) +trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, + handler=checkpoint_handler, + to_save={'net': net, 'opt': opt}) +train_stats_handler = StatsHandler() +train_stats_handler.attach(trainer) + +@trainer.on(Events.EPOCH_COMPLETED) +def log_training_loss(engine): + # log loss to tensorboard with second item of engine.state.output, loss.item() from output_transform + writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch) + + # tensor of ones to use where for converting labels to zero and ones + ones = torch.ones(engine.state.batch[Dek.LABEL][0].shape, dtype=torch.int32) + first_output_tensor = engine.state.output[0][1][0].detach().cpu() + # log model output to tensorboard, as three dimensional tensor with no channels dimension + img2tensorboard.add_animated_gif_no_channels(writer, "first_output_final_batch", first_output_tensor, 64, + 255, engine.state.epoch) + # get label tensor and convert to single class + first_label_tensor = torch.where(engine.state.batch[Dek.LABEL][0] > 0, ones, engine.state.batch[Dek.LABEL][0]) + # log label tensor to tensorboard, there is a channel dimension when getting label from batch + img2tensorboard.add_animated_gif(writer, "first_label_final_batch", first_label_tensor, 64, + 255, engine.state.epoch) + second_output_tensor = engine.state.output[0][1][1].detach().cpu() + img2tensorboard.add_animated_gif_no_channels(writer, "second_output_final_batch", second_output_tensor, 64, + 255, engine.state.epoch) + second_label_tensor = torch.where(engine.state.batch[Dek.LABEL][1] > 0, ones, engine.state.batch[Dek.LABEL][1]) + img2tensorboard.add_animated_gif(writer, "second_label_final_batch", second_label_tensor, 64, + 255, engine.state.epoch) + third_output_tensor = engine.state.output[0][1][2].detach().cpu() + img2tensorboard.add_animated_gif_no_channels(writer, "third_output_final_batch", third_output_tensor, 64, + 255, engine.state.epoch) + third_label_tensor = torch.where(engine.state.batch[Dek.LABEL][2] > 0, ones, engine.state.batch[Dek.LABEL][2]) + img2tensorboard.add_animated_gif(writer, "third_label_final_batch", third_label_tensor, 64, + 255, engine.state.epoch) + engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1]) + +writer = SummaryWriter() + +# Set 
parameters for validation
+validation_every_n_epochs = 1
+metric_name = 'Mean_Dice'
+
+# add evaluation metric to the evaluator engine
+val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}
+evaluator = create_supervised_evaluator(net, val_metrics, device, True,
+                                        prepare_batch=prepare_batch,
+                                        output_transform=lambda x, y, y_pred: (y_pred[0], y))
+
+# Add stats event handler to print validation stats via evaluator
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+val_stats_handler = StatsHandler()
+val_stats_handler.attach(evaluator)
+
+# Add early stopping handler to evaluator.
+early_stopper = EarlyStopping(patience=4,
+                              score_function=stopping_fn_from_metric(metric_name),
+                              trainer=trainer)
+evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
+
+# create a validation data loader
+val_ds = NiftiDatasetd(images[-20:], segs[-20:], transform=transforms)
+val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())
+
+
+@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
+def run_validation(engine):
+    evaluator.run(val_loader)
+
+@evaluator.on(Events.EPOCH_COMPLETED)
+def log_metrics_to_tensorboard(engine):
+    for _, value in engine.state.metrics.items():
+        writer.add_scalar('Metrics/' + metric_name, value, trainer.state.epoch)
+
+# create a training data loader
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+
+train_ds = NiftiDatasetd(images[:20], segs[:20], transform=transforms)
+train_loader = DataLoader(train_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())
+
+train_epochs = 30
+state = trainer.run(train_loader, train_epochs)
diff --git a/monai/data/nifti_reader.py b/monai/data/nifti_reader.py
index c803411b3b..6aa849e04a 100644
--- a/monai/data/nifti_reader.py
+++ b/monai/data/nifti_reader.py
@@ -14,7 +14,8 @@
 
 from torch.utils.data import Dataset
 from torch.utils.data._utils.collate import np_str_obj_array_pattern
-
+from monai.utils.constants import DataElementKey as Dek
+from monai.utils.constants import ImageProperty as Prop
 from monai.utils.module import export
 
 
@@ -40,14 +41,14 @@ def load_nifti(filename_or_obj, as_closest_canonical=False, image_only=True, dty
 
     img = nib.load(filename_or_obj)
 
     header = dict(img.header)
-    header['filename_or_obj'] = filename_or_obj
-    header['original_affine'] = img.affine
-    header['affine'] = img.affine
-    header['as_closest_canonical'] = as_closest_canonical
+    header[Prop.FILENAME_OR_OBJ] = filename_or_obj
+    header[Prop.ORIGINAL_AFFINE] = img.affine
+    header[Prop.AFFINE] = img.affine
+    header[Prop.AS_CLOSEST_CANONICAL] = as_closest_canonical
 
     if as_closest_canonical:
         img = nib.as_closest_canonical(img)
-        header['affine'] = img.affine
+        header[Prop.AFFINE] = img.affine
 
     if dtype is not None:
         dat = img.get_fdata(dtype=dtype)
@@ -131,3 +132,71 @@ def __getitem__(self, index):
                 continue
             compatible_meta[meta_key] = meta_datum
         return img, seg, compatible_meta
+
+
+@export("monai.data")
+class NiftiDatasetd(Dataset):
+    """
+    Loads image/segmentation pairs of Nifti files from the given filename lists. Dict-level transforms can be
+    specified for the dictionary data constructed from the image, label and other metadata.
+    """
+
+    def __init__(self, image_files, seg_files, as_closest_canonical=False, transform=None,
+                 image_only=True, dtype=None):
+        """
+        Initializes the dataset with the image and segmentation filename lists. The dict-based `transform` is applied
+        to the whole data dictionary of image, segmentation and compatible metadata.
+
+        Args:
+            image_files (list of str): list of image filenames.
+            seg_files (list of str): list of segmentation filenames.
+            as_closest_canonical (bool): if True, load the image as closest to canonical orientation.
+            transform (Callable, optional): dict transforms to execute operations on dictionary data.
+            image_only (bool): if True, return only the image volume; otherwise, return the image volume and header dict.
+            dtype (np.dtype, optional): if not None convert the loaded image to this data type.
+        """
+
+        if len(image_files) != len(seg_files):
+            raise ValueError('Must have same number of image and segmentation files')
+
+        self.image_files = image_files
+        self.seg_files = seg_files
+        self.as_closest_canonical = as_closest_canonical
+        self.transform = transform
+        self.image_only = image_only
+        self.dtype = dtype
+
+    def __len__(self):
+        return len(self.image_files)
+
+    def __getitem__(self, index):
+        meta_data = None
+        if self.image_only:
+            img = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical,
+                             image_only=self.image_only, dtype=self.dtype)
+        else:
+            img, meta_data = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical,
+                                        image_only=self.image_only, dtype=self.dtype)
+        seg = load_nifti(self.seg_files[index])
+
+        compatible_meta = {}
+        if meta_data is not None:
+            assert isinstance(meta_data, dict), 'meta_data must be in dictionary format.'
+            for meta_key in meta_data:
+                meta_datum = meta_data[meta_key]
+                if type(meta_datum).__name__ == 'ndarray' \
+                        and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None:
+                    continue
+                compatible_meta[meta_key] = meta_datum
+
+        data = {
+            Dek.IMAGE: img,
+            Dek.LABEL: seg
+        }
+        if len(compatible_meta) > 0:
+            data.update(compatible_meta)
+
+        if self.transform is not None:
+            data = self.transform(data)
+
+        return data
diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py
index 48c11b7487..e70fd9ed85 100644
--- a/monai/transforms/composables.py
+++ b/monai/transforms/composables.py
@@ -17,7 +17,7 @@
 
 import monai
 from monai.transforms.compose import Randomizable, Transform
-from monai.transforms.transforms import Rotate90
+from monai.transforms.transforms import Rotate90, AddChannel
 from monai.utils.misc import ensure_tuple
 
 export = monai.utils.export("monai.transforms")
@@ -120,6 +120,28 @@ def __call__(self, data):
         return d
 
 
+@export
+class AddChanneld(MapTransform):
+    """
+    Dictionary-based wrapper of AddChannel.
+    """
+
+    def __init__(self, keys):
+        """
+        Args:
+            keys (hashable items): keys of the corresponding items to be transformed.
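+                e.g. ``AddChanneld(keys=['image', 'label'])`` prepends a channel
+                dimension to both arrays in the data dictionary.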
+                See also: monai.transforms.composables.MapTransform
+        """
+        MapTransform.__init__(self, keys)
+        self.adder = AddChannel()
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.keys:
+            d[key] = self.adder(d[key])
+        return d
+
+
 # if __name__ == "__main__":
 #     import numpy as np
 #     data = {
diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py
index 602baf48d1..acc1c441d8 100644
--- a/monai/transforms/transforms.py
+++ b/monai/transforms/transforms.py
@@ -167,7 +167,7 @@ def __init__(self, k=1, axes=(1, 2)):
         self.plane_axes = axes
 
     def __call__(self, img):
-        return np.rot90(img, self.k, self.plane_axes)
+        return np.ascontiguousarray(np.rot90(img, self.k, self.plane_axes))
 
 
 @export
diff --git a/monai/utils/constants.py b/monai/utils/constants.py
new file mode 100644
index 0000000000..915c3fcc9a
--- /dev/null
+++ b/monai/utils/constants.py
@@ -0,0 +1,40 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+class ActivationFunc:
+    """Commonly used activation function names.
+    """
+
+    SOFTMAX = "softmax"
+    LOG_SOFTMAX = "log_softmax"
+    SIGMOID = "sigmoid"
+    LINEAR = "linear"
+    TANH = "tanh"
+
+
+class DataElementKey:
+    """Data Element keys
+    """
+
+    IMAGE = "image"
+    LABEL = "label"
+
+
+class ImageProperty:
+    """Key names for image properties.
+    """
+
+    FILENAME_OR_OBJ = 'filename_or_obj'
+    AFFINE = 'affine'  # image affine matrix
+    ORIGINAL_AFFINE = 'original_affine'  # original affine matrix before transformation
+    SPACING = 'spacing'  # itk naming convention for pixel/voxel size
+    AS_CLOSEST_CANONICAL = 'as_closest_canonical'  # load the image as closest to canonical axis format
+    BACKGROUND_INDEX = 'background_index'  # which index is background
diff --git a/tests/test_add_channeld.py b/tests/test_add_channeld.py
new file mode 100644
index 0000000000..a2940ffffb
--- /dev/null
+++ b/tests/test_add_channeld.py
@@ -0,0 +1,37 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
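+# AddChanneld should give every listed key a leading channel dimension:
+# each (2, 2) array below is expected to come back as (1, 2, 2).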
+ +import unittest +import numpy as np +from parameterized import parameterized +from monai.transforms.composables import AddChanneld + +TEST_CASE_1 = [ + {'keys': ['img', 'seg']}, + { + 'img': np.array([[0, 1], [1, 2]]), + 'seg': np.array([[0, 1], [1, 2]]) + }, + (1, 2, 2), +] + + +class TestAddChanneld(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1]) + def test_shape(self, input_param, input_data, expected_shape): + result = AddChanneld(**input_param)(input_data) + self.assertEqual(result['img'].shape, expected_shape) + self.assertEqual(result['seg'].shape, expected_shape) + + +if __name__ == '__main__': + unittest.main() From de4e390fc7e04070b780c966a2f5479d413043c4 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 5 Mar 2020 11:24:37 +0800 Subject: [PATCH 2/3] [DLMED] temporarily remove constants, will discuss after GTC --- examples/unet_segmentation_3d_dict.py | 24 ++++++++++------ monai/data/nifti_reader.py | 31 ++++++++++++--------- monai/utils/constants.py | 40 --------------------------- 3 files changed, 33 insertions(+), 62 deletions(-) delete mode 100644 monai/utils/constants.py diff --git a/examples/unet_segmentation_3d_dict.py b/examples/unet_segmentation_3d_dict.py index 53f6cff1aa..037bfb9fbb 100644 --- a/examples/unet_segmentation_3d_dict.py +++ b/examples/unet_segmentation_3d_dict.py @@ -28,7 +28,6 @@ import monai import monai.transforms.compose as transforms -from monai.utils.constants import DataElementKey as Dek from monai.data.nifti_reader import NiftiDatasetd from monai.transforms.composables import AddChanneld, RandRotate90d from monai.handlers.stats_handler import StatsHandler @@ -56,15 +55,15 @@ # Define transforms for image and segmentation transforms = transforms.Compose([ - AddChanneld(keys=[Dek.IMAGE, Dek.LABEL]), - RandRotate90d(keys=[Dek.IMAGE, Dek.LABEL], prob=0.8, axes=[1, 3]) + AddChanneld(keys=['image', 'label']), + RandRotate90d(keys=['image', 'label'], prob=0.8, axes=[1, 3]) ]) # Define nifti dataset, dataloader. ds = NiftiDatasetd(images, segs, transform=transforms) loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) check_data = monai.utils.misc.first(loader) -print(check_data[Dek.IMAGE].shape, check_data[Dek.LABEL].shape) +print(check_data['image'].shape, check_data['label'].shape) lr = 1e-5 @@ -81,13 +80,16 @@ loss = monai.losses.DiceLoss(do_sigmoid=True) opt = torch.optim.Adam(net.parameters(), lr) + # Since network outputs logits and segmentation, we need a custom function. 
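+# (the network output is a tuple, so `i[0]` picks out the logits for DiceLoss;
+#  per the comment above, the second element is the segmentation output)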
def _loss_fn(i, j): return loss(i[0], j) + # Create trainer def prepare_batch(batch, device=None, non_blocking=False): - return _prepare_batch((batch[Dek.IMAGE], batch[Dek.LABEL]), device, non_blocking) + return _prepare_batch((batch['image'], batch['label']), device, non_blocking) + device = torch.device("cuda:0") trainer = create_supervised_trainer(net, opt, _loss_fn, device, False, @@ -102,36 +104,38 @@ def prepare_batch(batch, device=None, non_blocking=False): train_stats_handler = StatsHandler() train_stats_handler.attach(trainer) + @trainer.on(Events.EPOCH_COMPLETED) def log_training_loss(engine): # log loss to tensorboard with second item of engine.state.output, loss.item() from output_transform writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch) # tensor of ones to use where for converting labels to zero and ones - ones = torch.ones(engine.state.batch[Dek.LABEL][0].shape, dtype=torch.int32) + ones = torch.ones(engine.state.batch['label'][0].shape, dtype=torch.int32) first_output_tensor = engine.state.output[0][1][0].detach().cpu() # log model output to tensorboard, as three dimensional tensor with no channels dimension img2tensorboard.add_animated_gif_no_channels(writer, "first_output_final_batch", first_output_tensor, 64, 255, engine.state.epoch) # get label tensor and convert to single class - first_label_tensor = torch.where(engine.state.batch[Dek.LABEL][0] > 0, ones, engine.state.batch[Dek.LABEL][0]) + first_label_tensor = torch.where(engine.state.batch['label'][0] > 0, ones, engine.state.batch['label'][0]) # log label tensor to tensorboard, there is a channel dimension when getting label from batch img2tensorboard.add_animated_gif(writer, "first_label_final_batch", first_label_tensor, 64, 255, engine.state.epoch) second_output_tensor = engine.state.output[0][1][1].detach().cpu() img2tensorboard.add_animated_gif_no_channels(writer, "second_output_final_batch", second_output_tensor, 64, 255, engine.state.epoch) - second_label_tensor = torch.where(engine.state.batch[Dek.LABEL][1] > 0, ones, engine.state.batch[Dek.LABEL][1]) + second_label_tensor = torch.where(engine.state.batch['label'][1] > 0, ones, engine.state.batch['label'][1]) img2tensorboard.add_animated_gif(writer, "second_label_final_batch", second_label_tensor, 64, 255, engine.state.epoch) third_output_tensor = engine.state.output[0][1][2].detach().cpu() img2tensorboard.add_animated_gif_no_channels(writer, "third_output_final_batch", third_output_tensor, 64, 255, engine.state.epoch) - third_label_tensor = torch.where(engine.state.batch[Dek.LABEL][2] > 0, ones, engine.state.batch[Dek.LABEL][2]) + third_label_tensor = torch.where(engine.state.batch['label'][2] > 0, ones, engine.state.batch['label'][2]) img2tensorboard.add_animated_gif(writer, "third_label_final_batch", third_label_tensor, 64, 255, engine.state.epoch) engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1]) + writer = SummaryWriter() # Set parameters for validation @@ -164,11 +168,13 @@ def log_training_loss(engine): def run_validation(engine): evaluator.run(val_loader) + @evaluator.on(Events.EPOCH_COMPLETED) def log_metrics_to_tensorboard(engine): for _, value in engine.state.metrics.items(): writer.add_scalar('Metrics/' + metric_name, value, trainer.state.epoch) + # create a training data loader logging.basicConfig(stream=sys.stdout, level=logging.INFO) diff --git a/monai/data/nifti_reader.py b/monai/data/nifti_reader.py index 2aeeafdc87..82bc04340b 100644 --- a/monai/data/nifti_reader.py +++ 
b/monai/data/nifti_reader.py
@@ -14,8 +14,6 @@
 
 from torch.utils.data import Dataset
 from torch.utils.data._utils.collate import np_str_obj_array_pattern
-from monai.utils.constants import DataElementKey as Dek
-from monai.utils.constants import ImageProperty as Prop
 from monai.utils.module import export
 from monai.transforms.compose import Randomizable
 
@@ -42,14 +40,14 @@ def load_nifti(filename_or_obj, as_closest_canonical=False, image_only=True, dty
 
     img = nib.load(filename_or_obj)
 
     header = dict(img.header)
-    header[Prop.FILENAME_OR_OBJ] = filename_or_obj
-    header[Prop.ORIGINAL_AFFINE] = img.affine
-    header[Prop.AFFINE] = img.affine
-    header[Prop.AS_CLOSEST_CANONICAL] = as_closest_canonical
+    header['filename_or_obj'] = filename_or_obj
+    header['original_affine'] = img.affine
+    header['affine'] = img.affine
+    header['as_closest_canonical'] = as_closest_canonical
 
     if as_closest_canonical:
         img = nib.as_closest_canonical(img)
-        header[Prop.AFFINE] = img.affine
+        header['affine'] = img.affine
 
     if dtype is not None:
         dat = img.get_fdata(dtype=dtype)
@@ -108,6 +106,7 @@ def __getitem__(self, index):
         else:
             img, meta_data = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical,
                                         image_only=self.image_only, dtype=self.dtype)
+        target = None
         if self.seg_files is not None:
             target = load_nifti(self.seg_files[index])
         elif self.labels is not None:
@@ -145,7 +144,7 @@ class NiftiDatasetd(Dataset):
     specified for the dictionary data constructed from the image, label and other metadata.
     """
 
-    def __init__(self, image_files, seg_files, as_closest_canonical=False, transform=None,
+    def __init__(self, image_files, seg_files=None, labels=None, as_closest_canonical=False, transform=None,
                  image_only=True, dtype=None):
         """
         Initializes the dataset with the image and segmentation filename lists. The dict-based `transform` is applied
@@ -153,7 +152,8 @@ def __init__(self, image_files, seg_files, as_closest_canonical=False, transform
 
         Args:
             image_files (list of str): list of image filenames.
-            seg_files (list of str): if in segmentation task, list of segmentation filenames.
+            seg_files (list of str): if in segmentation task, list of segmentation filenames.
+            labels (list or array): if in classification task, list of classification labels.
             as_closest_canonical (bool): if True, load the image as closest to canonical orientation.
             transform (Callable, optional): dict transforms to execute operations on dictionary data.
             image_only (bool): if True, return only the image volume; otherwise, return the image volume and header dict.
@@ -165,6 +165,7 @@ def __init__(self, image_files, seg_files, as_closest_canonical=False, transform self.image_files = image_files self.seg_files = seg_files + self.labels = labels self.as_closest_canonical = as_closest_canonical self.transform = transform self.image_only = image_only @@ -181,7 +182,11 @@ def __getitem__(self, index): else: img, meta_data = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical, image_only=self.image_only, dtype=self.dtype) - seg = load_nifti(self.seg_files[index]) + target = None + if self.seg_files is not None: + target = load_nifti(self.seg_files[index]) + elif self.labels is not None: + target = self.labels[index] compatible_meta = {} if meta_data is not None: @@ -194,8 +199,8 @@ def __getitem__(self, index): compatible_meta[meta_key] = meta_datum data = { - Dek.IMAGE: img, - Dek.LABEL: seg + 'image': img, + 'label': target } if len(compatible_meta) > 0: data.update(compatible_meta) @@ -203,4 +208,4 @@ def __getitem__(self, index): if self.transform is not None: data = self.transform(data) - return data \ No newline at end of file + return data diff --git a/monai/utils/constants.py b/monai/utils/constants.py deleted file mode 100644 index 915c3fcc9a..0000000000 --- a/monai/utils/constants.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -class ActivationFunc: - """Commonly used activation function names. - """ - - SOFTMAX = "softmax" - LOG_SOFTMAX = "log_softmax" - SIGMOID = "sigmoid" - LINEAR = "linear" - TANH = "tanh" - - -class DataElementKey: - """Data Element keys - """ - - IMAGE = "image" - LABEL = "label" - - -class ImageProperty: - """Key names for image properties. 
- """ - - FILENAME_OR_OBJ = 'filename_or_obj' - AFFINE = 'affine' # image affine matrix - ORIGINAL_AFFINE = 'original_affine' # original affine matrix before transformation - SPACING = 'spacing' # itk naming convention for pixel/voxel size - AS_CLOSEST_CANONICAL = 'as_closest_canonical' # load the image as closest to canonical axis format - BACKGROUND_INDEX = 'background_index' # which index is background From 66e80e55f0fa5ed688258a5dd9d105604d152d14 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 5 Mar 2020 20:40:25 +0800 Subject: [PATCH 3/3] [DLMED] adjust default behavior according to comments --- examples/unet_segmentation_3d_dict.py | 16 ++++---- monai/data/nifti_reader.py | 53 +++++++++++++-------------- 2 files changed, 34 insertions(+), 35 deletions(-) diff --git a/examples/unet_segmentation_3d_dict.py b/examples/unet_segmentation_3d_dict.py index 037bfb9fbb..8c6955d87b 100644 --- a/examples/unet_segmentation_3d_dict.py +++ b/examples/unet_segmentation_3d_dict.py @@ -55,15 +55,15 @@ # Define transforms for image and segmentation transforms = transforms.Compose([ - AddChanneld(keys=['image', 'label']), - RandRotate90d(keys=['image', 'label'], prob=0.8, axes=[1, 3]) + AddChanneld(keys=['image', 'seg']), + RandRotate90d(keys=['image', 'seg'], prob=0.8, axes=[1, 3]) ]) # Define nifti dataset, dataloader. ds = NiftiDatasetd(images, segs, transform=transforms) loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) check_data = monai.utils.misc.first(loader) -print(check_data['image'].shape, check_data['label'].shape) +print(check_data['image'].shape, check_data['seg'].shape) lr = 1e-5 @@ -88,7 +88,7 @@ def _loss_fn(i, j): # Create trainer def prepare_batch(batch, device=None, non_blocking=False): - return _prepare_batch((batch['image'], batch['label']), device, non_blocking) + return _prepare_batch((batch['image'], batch['seg']), device, non_blocking) device = torch.device("cuda:0") @@ -111,26 +111,26 @@ def log_training_loss(engine): writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch) # tensor of ones to use where for converting labels to zero and ones - ones = torch.ones(engine.state.batch['label'][0].shape, dtype=torch.int32) + ones = torch.ones(engine.state.batch['seg'][0].shape, dtype=torch.int32) first_output_tensor = engine.state.output[0][1][0].detach().cpu() # log model output to tensorboard, as three dimensional tensor with no channels dimension img2tensorboard.add_animated_gif_no_channels(writer, "first_output_final_batch", first_output_tensor, 64, 255, engine.state.epoch) # get label tensor and convert to single class - first_label_tensor = torch.where(engine.state.batch['label'][0] > 0, ones, engine.state.batch['label'][0]) + first_label_tensor = torch.where(engine.state.batch['seg'][0] > 0, ones, engine.state.batch['seg'][0]) # log label tensor to tensorboard, there is a channel dimension when getting label from batch img2tensorboard.add_animated_gif(writer, "first_label_final_batch", first_label_tensor, 64, 255, engine.state.epoch) second_output_tensor = engine.state.output[0][1][1].detach().cpu() img2tensorboard.add_animated_gif_no_channels(writer, "second_output_final_batch", second_output_tensor, 64, 255, engine.state.epoch) - second_label_tensor = torch.where(engine.state.batch['label'][1] > 0, ones, engine.state.batch['label'][1]) + second_label_tensor = torch.where(engine.state.batch['seg'][1] > 0, ones, engine.state.batch['seg'][1]) img2tensorboard.add_animated_gif(writer, "second_label_final_batch", 
second_label_tensor, 64,
                                      255, engine.state.epoch)
     third_output_tensor = engine.state.output[0][1][2].detach().cpu()
     img2tensorboard.add_animated_gif_no_channels(writer, "third_output_final_batch", third_output_tensor, 64,
                                                  255, engine.state.epoch)
-    third_label_tensor = torch.where(engine.state.batch['label'][2] > 0, ones, engine.state.batch['label'][2])
+    third_label_tensor = torch.where(engine.state.batch['seg'][2] > 0, ones, engine.state.batch['seg'][2])
     img2tensorboard.add_animated_gif(writer, "third_label_final_batch", third_label_tensor, 64,
                                      255, engine.state.epoch)
     engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1])
diff --git a/monai/data/nifti_reader.py b/monai/data/nifti_reader.py
index 82bc04340b..753691de20 100644
--- a/monai/data/nifti_reader.py
+++ b/monai/data/nifti_reader.py
@@ -144,8 +144,7 @@ class NiftiDatasetd(Dataset):
     specified for the dictionary data constructed from the image, label and other metadata.
     """
 
-    def __init__(self, image_files, seg_files=None, labels=None, as_closest_canonical=False, transform=None,
-                 image_only=True, dtype=None):
+    def __init__(self, image_files, seg_files=None, labels=None, as_closest_canonical=False, transform=None, dtype=None):
         """
         Initializes the dataset with the image and segmentation filename lists. The dict-based `transform` is applied
         to the whole data dictionary of image, segmentation and compatible metadata.
@@ -156,7 +155,6 @@ def __init__(self, image_files, seg_files=None, labels=None, as_closest_canonica
             labels (list or array): if in classification task, list of classification labels.
             as_closest_canonical (bool): if True, load the image as closest to canonical orientation.
             transform (Callable, optional): dict transforms to execute operations on dictionary data.
-            image_only (bool): if True, return only the image volume; otherwise, return the image volume and header dict.
             dtype (np.dtype, optional): if not None convert the loaded image to this data type.
         """
 
@@ -168,7 +166,6 @@ def __init__(self, image_files, seg_files=None, labels=None, as_closest_canonica
         self.labels = labels
         self.as_closest_canonical = as_closest_canonical
         self.transform = transform
-        self.image_only = image_only
         self.dtype = dtype
 
     def __len__(self):
@@ -176,32 +173,34 @@ def __len__(self):
 
     def __getitem__(self, index):
         meta_data = None
-        if self.image_only:
-            img = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical,
-                             image_only=self.image_only, dtype=self.dtype)
-        else:
-            img, meta_data = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical,
-                                        image_only=self.image_only, dtype=self.dtype)
-        target = None
+        img, meta_data = load_nifti(
+            filename_or_obj=self.image_files[index],
+            as_closest_canonical=self.as_closest_canonical,
+            image_only=False,
+            dtype=self.dtype
+        )
+
+        seg = None
         if self.seg_files is not None:
-            target = load_nifti(self.seg_files[index])
-        elif self.labels is not None:
-            target = self.labels[index]
+            seg = load_nifti(self.seg_files[index])
+        label = None
+        if self.labels is not None:
+            label = self.labels[index]
 
         compatible_meta = {}
-        if meta_data is not None:
-            assert isinstance(meta_data, dict), 'meta_data must be in dictionary format.'
- for meta_key in meta_data: - meta_datum = meta_data[meta_key] - if type(meta_datum).__name__ == 'ndarray' \ - and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None: - continue - compatible_meta[meta_key] = meta_datum - - data = { - 'image': img, - 'label': target - } + assert isinstance(meta_data, dict), 'meta_data must be in dictionary format.' + for meta_key in meta_data: + meta_datum = meta_data[meta_key] + if type(meta_datum).__name__ == 'ndarray' \ + and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None: + continue + compatible_meta[meta_key] = meta_datum + + data = {'image': img} + if seg is not None: + data['seg'] = seg + if label is not None: + data['label'] = label if len(compatible_meta) > 0: data.update(compatible_meta)