From ede726e9de18b1197ce13938d39d5bba593ee927 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 5 Mar 2020 20:22:05 +0800 Subject: [PATCH 1/9] [DLMED] add LoadNifti transform and common Dataset --- monai/data/nifti_reader.py | 35 +++++++++++++++++++ monai/transforms/composables.py | 31 ++++++++++++++++- monai/transforms/transforms.py | 62 ++++++++++++++++++++++++++++++++- tests/test_load_nifti.py | 53 ++++++++++++++++++++++++++++ tests/test_load_niftid.py | 54 ++++++++++++++++++++++++++++ 5 files changed, 233 insertions(+), 2 deletions(-) create mode 100644 tests/test_load_nifti.py create mode 100644 tests/test_load_niftid.py diff --git a/monai/data/nifti_reader.py b/monai/data/nifti_reader.py index 287c97fdde..2c5ed339dc 100644 --- a/monai/data/nifti_reader.py +++ b/monai/data/nifti_reader.py @@ -135,3 +135,38 @@ def __getitem__(self, index): continue compatible_meta[meta_key] = meta_datum return img, target, compatible_meta + + +@export("monai.data") +class Datasetd(Dataset): + """ + General Dataset to handle dictionary format data, it can operate transforms for specific fields. + Input data should be a list of dictionaries, for example: + [{ { { + 'img': 'image1.nii.gz', 'img': 'image2.nii.gz', 'img': 'image3.nii.gz', + 'seg': 'label1.nii.gz', 'seg': 'label2.nii.gz', 'seg': 'label3.nii.gz', + 'extra': 123, 'extra': 456, 'extra': 789, + 'shape': 'CHWD' 'shape': 'CHWD' 'shape': 'CHWD' + }, }, }] + """ + + def __init__(self, data, transform=None): + """ + Args: + data (dict): input data to load and transform to generate dataset for model. + transform (Callable, optional): dict transforms to excute operations on dictionary data. + """ + assert isinstance(data, list) and all(isinstance(item, dict) for item in data), \ + 'input data must be a list of dictionaries.' + self.data = data + self.transform = transform + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + data = self.data[index] + if self.transform is not None: + data = self.transform(data) + + return data diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py index 4c13e105ca..ef3e6755d4 100644 --- a/monai/transforms/composables.py +++ b/monai/transforms/composables.py @@ -18,7 +18,7 @@ import monai from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.compose import Randomizable, Transform -from monai.transforms.transforms import Rotate90, SpatialCrop +from monai.transforms.transforms import LoadNifti, Rotate90, SpatialCrop from monai.utils.misc import ensure_tuple from monai.transforms.utils import generate_pos_neg_label_crop_centers @@ -53,6 +53,35 @@ def __init__(self, keys): raise ValueError('keys should be a hashable or a sequence of hashables, got {}'.format(type(key))) +@export +class LoadNiftid(MapTransform): + """ + dictionary-based wrapper of LoadNifti, must load image and metadata together. + """ + + def __init__(self, keys, as_closest_canonical=False, dtype=None): + """ + Args: + keys (hashable items): keys of the corresponding items to be transformed. + See also: monai.transform.composables.MapTransform + as_closest_canonical (bool): if True, load the image as closest to canonical axis format. + dtype (np.dtype, optional): if not None convert the loaded image to this data type. 
+ """ + MapTransform.__init__(self, keys) + self.loader = LoadNifti(as_closest_canonical, False, dtype) + + def __call__(self, data): + d = dict(data) + for key in self.keys: + data = self.loader(d[key]) + assert isinstance(data, (tuple, list)), 'if data contains metadata, must be tuple or list.' + d[key] = data[0] + assert isinstance(data[1], dict), 'metadata must be in dict format.' + for k, v in data[1].items(): + d[key + '.' + k] = v + return d + + @export class Rotate90d(MapTransform): """ diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index dc6f571106..c47dd598c0 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -14,8 +14,9 @@ """ import numpy as np +import nibabel as nib import torch - +from torch.utils.data._utils.collate import np_str_obj_array_pattern import monai from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.compose import Randomizable @@ -24,6 +25,65 @@ export = monai.utils.export("monai.transforms") +@export +class LoadNifti: + """ + Load Nifti format file from provided path. + """ + + def __init__(self, as_closest_canonical=False, image_only=False, dtype=None): + """ + Args: + as_closest_canonical (bool): if True, load the image as closest to canonical axis format. + image_only (bool): if True return only the image volume, other return image volume and header dict. + dtype (np.dtype, optional): if not None convert the loaded image to this data type. + """ + self.as_closest_canonical = as_closest_canonical + self.image_only = image_only + self.dtype = dtype + + def __call__(self, img): + """ + Args: + img (str or file): path to file or file-like object. + + Returns: + The loaded image volume if `image_only` is True, or a tuple containing the volume and the Nifti + header in dict format otherwise. + + Note: + header['original_affine'] stores the original affine loaded from `filename_or_obj`. + header['affine'] stores the affine after the optional `as_closest_canonical` transform. + """ + data = nib.load(img) + + header = dict(data.header) + header['filename_or_obj'] = img + header['original_affine'] = data.affine + header['affine'] = data.affine + header['as_closest_canonical'] = self.as_closest_canonical + + if self.as_closest_canonical: + data = nib.as_closest_canonical(data) + header['affine'] = data.affine + + if self.dtype is not None: + data = data.get_fdata(dtype=self.dtype) + else: + data = np.asanyarray(data.dataobj) + + if self.image_only: + return data + compatible_meta = dict() + for meta_key in header: + meta_datum = header[meta_key] + if type(meta_datum).__name__ == 'ndarray' \ + and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None: + continue + compatible_meta[meta_key] = meta_datum + return data, compatible_meta + + @export class AddChannel: """ diff --git a/tests/test_load_nifti.py b/tests/test_load_nifti.py new file mode 100644 index 0000000000..ccf79cfaa3 --- /dev/null +++ b/tests/test_load_nifti.py @@ -0,0 +1,53 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import os +import numpy as np +import nibabel as nib +from parameterized import parameterized +from monai.transforms.transforms import LoadNifti + +TEST_CASE_1 = [ + { + 'as_closest_canonical': False, + 'image_only': True + }, + 'test_image.nii.gz', + (128, 128, 128) +] + +TEST_CASE_2 = [ + { + 'as_closest_canonical': False, + 'image_only': False + }, + 'test_image.nii.gz', + (128, 128, 128) +] + + +class TestLoadNifti(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_shape(self, input_param, input_data, expected_shape): + test_image = np.random.randint(0, 2, size=[128, 128, 128]) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), 'test_image.nii.gz') + result = LoadNifti(**input_param)(input_data) + if os.path.exists('test_image.nii.gz'): + os.remove('test_image.nii.gz') + if isinstance(result, tuple): + result = result[0] + self.assertTupleEqual(result.shape, expected_shape) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_load_niftid.py b/tests/test_load_niftid.py new file mode 100644 index 0000000000..644f8e9d20 --- /dev/null +++ b/tests/test_load_niftid.py @@ -0,0 +1,54 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import os +import numpy as np +import nibabel as nib +from parameterized import parameterized +from monai.transforms.composables import LoadNiftid + +TEST_CASE_1 = [ + { + 'keys': ['image', 'label', 'extra'], + 'as_closest_canonical': False + }, + { + 'image': 'test_image.nii.gz', + 'label': 'test_label.nii.gz', + 'extra': 'test_extra.nii.gz' + }, + (128, 128, 128) +] + + +class TestLoadNiftid(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1]) + def test_shape(self, input_param, input_data, expected_shape): + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) + nib.save(test_image, 'test_image.nii.gz') + nib.save(test_image, 'test_label.nii.gz') + nib.save(test_image, 'test_extra.nii.gz') + result = LoadNiftid(**input_param)(input_data) + if os.path.exists('test_image.nii.gz'): + os.remove('test_image.nii.gz') + if os.path.exists('test_label.nii.gz'): + os.remove('test_label.nii.gz') + if os.path.exists('test_extra.nii.gz'): + os.remove('test_extra.nii.gz') + self.assertTupleEqual(result['image'].shape, expected_shape) + self.assertTupleEqual(result['label'].shape, expected_shape) + self.assertTupleEqual(result['extra'].shape, expected_shape) + + +if __name__ == '__main__': + unittest.main() From c11112b8b4f247014ac9b3d9893d16a2e7cf5cb2 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 5 Mar 2020 23:05:31 +0800 Subject: [PATCH 2/9] [DLMED] update according to comments --- monai/data/dataset.py | 45 ++++++++++++++++++++ monai/data/nifti_reader.py | 35 ---------------- monai/transforms/transforms.py | 26 ++++++------ tests/test_dataset.py | 75 ++++++++++++++++++++++++++++++++++ 4 files changed, 133 insertions(+), 48 deletions(-) create mode 100644 monai/data/dataset.py create mode 100644 tests/test_dataset.py diff --git a/monai/data/dataset.py b/monai/data/dataset.py new file mode 100644 index 0000000000..4b3221d19d --- /dev/null +++ b/monai/data/dataset.py @@ -0,0 +1,45 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from monai.utils.module import export + + +@export("monai.data") +class Dataset(torch.utils.data.Dataset): + """ + General Dataset to handle dictionary format data, it can operate transforms for specific fields. + For example, typical input data can be a list of dictionaries: + [{ { { + 'img': 'image1.nii.gz', 'img': 'image2.nii.gz', 'img': 'image3.nii.gz', + 'seg': 'label1.nii.gz', 'seg': 'label2.nii.gz', 'seg': 'label3.nii.gz', + 'extra': 123 'extra': 456 'extra': 789 + }, }, }] + """ + + def __init__(self, data, transform=None): + """ + Args: + data (Iterable): input data to load and transform to generate dataset for model. + transform (Callable, optional): transforms to excute operations on input data. 
+ """ + self.data = data + self.transform = transform + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + data = self.data[index] + if self.transform is not None: + data = self.transform(data) + + return data diff --git a/monai/data/nifti_reader.py b/monai/data/nifti_reader.py index 96077ad099..bdfb2de951 100644 --- a/monai/data/nifti_reader.py +++ b/monai/data/nifti_reader.py @@ -137,41 +137,6 @@ def __getitem__(self, index): return img, target, compatible_meta -@export("monai.data") -class Datasetd(Dataset): - """ - General Dataset to handle dictionary format data, it can operate transforms for specific fields. - Input data should be a list of dictionaries, for example: - [{ { { - 'img': 'image1.nii.gz', 'img': 'image2.nii.gz', 'img': 'image3.nii.gz', - 'seg': 'label1.nii.gz', 'seg': 'label2.nii.gz', 'seg': 'label3.nii.gz', - 'extra': 123, 'extra': 456, 'extra': 789, - 'shape': 'CHWD' 'shape': 'CHWD' 'shape': 'CHWD' - }, }, }] - """ - - def __init__(self, data, transform=None): - """ - Args: - data (dict): input data to load and transform to generate dataset for model. - transform (Callable, optional): dict transforms to excute operations on dictionary data. - """ - assert isinstance(data, list) and all(isinstance(item, dict) for item in data), \ - 'input data must be a list of dictionaries.' - self.data = data - self.transform = transform - - def __len__(self): - return len(self.data) - - def __getitem__(self, index): - data = self.data[index] - if self.transform is not None: - data = self.transform(data) - - return data - - @export("monai.data") class NiftiDatasetd(Dataset): """ diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index fcd5b745da..a8e533aa7c 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -45,10 +45,10 @@ def __init__(self, as_closest_canonical=False, image_only=False, dtype=None): self.image_only = image_only self.dtype = dtype - def __call__(self, img): + def __call__(self, filename): """ Args: - img (str or file): path to file or file-like object. + filename (str or file): path to file or file-like object. Returns: The loaded image volume if `image_only` is True, or a tuple containing the volume and the Nifti @@ -58,25 +58,25 @@ def __call__(self, img): header['original_affine'] stores the original affine loaded from `filename_or_obj`. header['affine'] stores the affine after the optional `as_closest_canonical` transform. 
""" - data = nib.load(img) + img = nib.load(filename) - header = dict(data.header) - header['filename_or_obj'] = img - header['original_affine'] = data.affine - header['affine'] = data.affine + header = dict(img.header) + header['filename_or_obj'] = filename + header['original_affine'] = img.affine + header['affine'] = img.affine header['as_closest_canonical'] = self.as_closest_canonical if self.as_closest_canonical: - data = nib.as_closest_canonical(data) - header['affine'] = data.affine + img = nib.as_closest_canonical(img) + header['affine'] = img.affine if self.dtype is not None: - data = data.get_fdata(dtype=self.dtype) + img = img.get_fdata(dtype=self.dtype) else: - data = np.asanyarray(data.dataobj) + img = np.asanyarray(img.dataobj) if self.image_only: - return data + return img compatible_meta = dict() for meta_key in header: meta_datum = header[meta_key] @@ -84,7 +84,7 @@ def __call__(self, img): and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None: continue compatible_meta[meta_key] = meta_datum - return data, compatible_meta + return img, compatible_meta @export diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 0000000000..d8f137cb0a --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,75 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import os +import numpy as np +import nibabel as nib +from parameterized import parameterized +from monai.data.dataset import Dataset +from monai.transforms.composables import LoadNiftid + +TEST_CASE_1 = [ + { + 'data': [ + { + 'image': 'test_image1.nii.gz', + 'label': 'test_label1.nii.gz', + 'extra': 'test_extra1.nii.gz' + }, + { + 'image': 'test_image2.nii.gz', + 'label': 'test_label2.nii.gz', + 'extra': 'test_extra2.nii.gz' + } + ], + 'transform': LoadNiftid(keys=['image', 'label', 'extra']) + }, + (128, 128, 128) +] + + +class TestDataset(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1]) + def test_shape(self, input_param, expected_shape): + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) + nib.save(test_image, 'test_image1.nii.gz') + nib.save(test_image, 'test_label1.nii.gz') + nib.save(test_image, 'test_extra1.nii.gz') + nib.save(test_image, 'test_image2.nii.gz') + nib.save(test_image, 'test_label2.nii.gz') + nib.save(test_image, 'test_extra2.nii.gz') + dataset = Dataset(**input_param) + data1 = dataset[0] + data2 = dataset[1] + if os.path.exists('test_image1.nii.gz'): + os.remove('test_image1.nii.gz') + if os.path.exists('test_label1.nii.gz'): + os.remove('test_label1.nii.gz') + if os.path.exists('test_extra1.nii.gz'): + os.remove('test_extra1.nii.gz') + if os.path.exists('test_image2.nii.gz'): + os.remove('test_image2.nii.gz') + if os.path.exists('test_label2.nii.gz'): + os.remove('test_label2.nii.gz') + if os.path.exists('test_extra2.nii.gz'): + os.remove('test_extra2.nii.gz') + self.assertTupleEqual(data1['image'].shape, expected_shape) + self.assertTupleEqual(data1['label'].shape, expected_shape) + self.assertTupleEqual(data1['extra'].shape, expected_shape) + self.assertTupleEqual(data2['image'].shape, expected_shape) + self.assertTupleEqual(data2['label'].shape, expected_shape) + self.assertTupleEqual(data2['extra'].shape, expected_shape) + + +if __name__ == '__main__': + unittest.main() From 852ef7e1a1719ed06bb899dc5c5459705ca16b54 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 5 Mar 2020 23:19:22 +0800 Subject: [PATCH 3/9] [DLMED] fix typo --- monai/data/nifti_reader.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/monai/data/nifti_reader.py b/monai/data/nifti_reader.py index bdfb2de951..753691de20 100644 --- a/monai/data/nifti_reader.py +++ b/monai/data/nifti_reader.py @@ -203,3 +203,8 @@ def __getitem__(self, index): data['label'] = label if len(compatible_meta) > 0: data.update(compatible_meta) + + if self.transform is not None: + data = self.transform(data) + + return data From 01ede7883dde0597d2f03f945cf88d3909dd5d66 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 5 Mar 2020 23:51:26 +0800 Subject: [PATCH 4/9] [DLMED] update example to use latest API --- examples/unet_segmentation_3d_dict.py | 17 +++++-- monai/data/nifti_reader.py | 73 --------------------------- 2 files changed, 12 insertions(+), 78 deletions(-) diff --git a/examples/unet_segmentation_3d_dict.py b/examples/unet_segmentation_3d_dict.py index 8c6955d87b..6e31ae68b8 100644 --- a/examples/unet_segmentation_3d_dict.py +++ b/examples/unet_segmentation_3d_dict.py @@ -28,8 +28,8 @@ import monai import monai.transforms.compose as transforms -from monai.data.nifti_reader import NiftiDatasetd -from monai.transforms.composables import AddChanneld, RandRotate90d +from monai.data.dataset import Dataset +from monai.transforms.composables import LoadNiftid, AddChanneld, RandRotate90d from monai.handlers.stats_handler import StatsHandler 
from monai.handlers.mean_dice import MeanDice from monai.visualize import img2tensorboard @@ -52,15 +52,22 @@ images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz'))) segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) +train_data = list() +for img, seg in zip(images[:30], segs[:30]): + train_data.append({'image': img, 'seg': seg}) +val_data = list() +for img, seg in zip(images[-20:], segs[-20:]): + val_data.append({'image': img, 'seg': seg}) # Define transforms for image and segmentation transforms = transforms.Compose([ + LoadNiftid(keys=['image', 'seg']), AddChanneld(keys=['image', 'seg']), RandRotate90d(keys=['image', 'seg'], prob=0.8, axes=[1, 3]) ]) # Define nifti dataset, dataloader. -ds = NiftiDatasetd(images, segs, transform=transforms) +ds = Dataset(data=train_data, transform=transforms) loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) check_data = monai.utils.misc.first(loader) print(check_data['image'].shape, check_data['seg'].shape) @@ -160,7 +167,7 @@ def log_training_loss(engine): evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) # create a validation data loader -val_ds = NiftiDatasetd(images[-20:], segs[-20:], transform=transforms) +val_ds = Dataset(data=val_data, transform=transforms) val_loader = DataLoader(ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available()) @@ -178,7 +185,7 @@ def log_metrics_to_tensorboard(engine): # create a training data loader logging.basicConfig(stream=sys.stdout, level=logging.INFO) -train_ds = NiftiDatasetd(images[:20], segs[:20], transform=transforms) +train_ds = Dataset(data=train_data, transform=transforms) train_loader = DataLoader(train_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available()) train_epochs = 30 diff --git a/monai/data/nifti_reader.py b/monai/data/nifti_reader.py index 753691de20..3ece2264d9 100644 --- a/monai/data/nifti_reader.py +++ b/monai/data/nifti_reader.py @@ -135,76 +135,3 @@ def __getitem__(self, index): continue compatible_meta[meta_key] = meta_datum return img, target, compatible_meta - - -@export("monai.data") -class NiftiDatasetd(Dataset): - """ - Loads image/segmentation pairs of Nifti files from the given filename lists. Dict level transformations can be - specified for the dictionary data which is constructed by image, label and other metadata. - """ - - def __init__(self, image_files, seg_files=None, labels=None, as_closest_canonical=False, transform=None, dtype=None): - """ - Initializes the dataset with the image and segmentation filename lists. The transform `transform` is applied - to the images and `seg_transform` to the segmentations. - - Args: - image_files (list of str): list of image filenames. - seg_files (list of str): if in segmentation task, list of segmentation filenames. - labels (list or array): if in classification task, list of classification labels. - as_closest_canonical (bool): if True, load the image as closest to canonical orientation. - transform (Callable, optional): dict transforms to excute operations on dictionary data. - dtype (np.dtype, optional): if not None convert the loaded image to this data type. 
- """ - - if len(image_files) != len(seg_files): - raise ValueError('Must have same number of image and segmentation files') - - self.image_files = image_files - self.seg_files = seg_files - self.labels = labels - self.as_closest_canonical = as_closest_canonical - self.transform = transform - self.dtype = dtype - - def __len__(self): - return len(self.image_files) - - def __getitem__(self, index): - meta_data = None - img, meta_data = load_nifti( - filename_or_obj=self.image_files[index], - as_closest_canonical=self.as_closest_canonical, - image_only=False, - dtype=self.dtype - ) - - seg = None - if self.seg_files is not None: - seg = load_nifti(self.seg_files[index]) - label = None - if self.labels is not None: - label = self.labels[index] - - compatible_meta = {} - assert isinstance(meta_data, dict), 'meta_data must be in dictionary format.' - for meta_key in meta_data: - meta_datum = meta_data[meta_key] - if type(meta_datum).__name__ == 'ndarray' \ - and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None: - continue - compatible_meta[meta_key] = meta_datum - - data = {'image': img} - if seg is not None: - data['seg'] = seg - if label is not None: - data['label'] = label - if len(compatible_meta) > 0: - data.update(compatible_meta) - - if self.transform is not None: - data = self.transform(data) - - return data From 4c6b65bdb8facf05a51dbcc9fd00b432e3edb4ac Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 6 Mar 2020 14:49:04 +0800 Subject: [PATCH 5/9] [DLMED] update code to add more features --- examples/unet_segmentation_3d_dict.py | 51 +++++++++------- monai/data/synthetic.py | 15 ++++- monai/losses/dice.py | 2 +- monai/transforms/composables.py | 84 ++++++++++++++++++-------- monai/transforms/transforms.py | 19 ++++++ tests/test_change_to_channel_first.py | 49 +++++++++++++++ tests/test_change_to_channel_firstd.py | 58 ++++++++++++++++++ tests/test_dataset.py | 59 ++++++++---------- tests/test_load_nifti.py | 21 ++++--- tests/test_load_niftid.py | 30 +++++---- tests/test_spatial_crop.py | 5 +- 11 files changed, 278 insertions(+), 115 deletions(-) create mode 100644 tests/test_change_to_channel_first.py create mode 100644 tests/test_change_to_channel_firstd.py diff --git a/examples/unet_segmentation_3d_dict.py b/examples/unet_segmentation_3d_dict.py index 6e31ae68b8..6bf576fb5e 100644 --- a/examples/unet_segmentation_3d_dict.py +++ b/examples/unet_segmentation_3d_dict.py @@ -28,13 +28,14 @@ import monai import monai.transforms.compose as transforms -from monai.data.dataset import Dataset -from monai.transforms.composables import LoadNiftid, AddChanneld, RandRotate90d +from monai.transforms.composables import \ + LoadNiftid, ChangeToChannelFirstd, RandCropByPosNegLabeld, RandRotate90d from monai.handlers.stats_handler import StatsHandler from monai.handlers.mean_dice import MeanDice from monai.visualize import img2tensorboard from monai.data.synthetic import create_test_image_3d from monai.handlers.utils import stopping_fn_from_metric +from monai.data.utils import list_data_collate monai.config.print_config() @@ -42,35 +43,37 @@ tempdir = tempfile.mkdtemp() for i in range(50): - im, seg = create_test_image_3d(128, 128, 128) + im, seg = create_test_image_3d(128, 128, 128, channel_dim=-1) n = nib.Nifti1Image(im, np.eye(4)) - nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i)) + nib.save(n, os.path.join(tempdir, 'img%i.nii.gz' % i)) n = nib.Nifti1Image(seg, np.eye(4)) nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i)) -images = sorted(glob(os.path.join(tempdir, 
'im*.nii.gz'))) +images = sorted(glob(os.path.join(tempdir, 'img*.nii.gz'))) segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) -train_data = list() -for img, seg in zip(images[:30], segs[:30]): - train_data.append({'image': img, 'seg': seg}) -val_data = list() -for img, seg in zip(images[-20:], segs[-20:]): - val_data.append({'image': img, 'seg': seg}) +train_files = [{'img': img, 'seg': seg} for img, seg in zip(images[:40], segs[:40])] +val_files = [{'img': img, 'seg': seg} for img, seg in zip(images[-10:], segs[-10:])] # Define transforms for image and segmentation -transforms = transforms.Compose([ - LoadNiftid(keys=['image', 'seg']), - AddChanneld(keys=['image', 'seg']), - RandRotate90d(keys=['image', 'seg'], prob=0.8, axes=[1, 3]) +train_transforms = transforms.Compose([ + LoadNiftid(keys=['img', 'seg']), + ChangeToChannelFirstd(keys=['img', 'seg'], channel_dim=-1), + RandCropByPosNegLabeld(keys=['img', 'seg'], label_key='seg', size=[96, 96, 96], pos=1, neg=1, num_samples=4), + RandRotate90d(keys=['img', 'seg'], prob=0.8, axes=[1, 3]) +]) +val_transforms = transforms.Compose([ + LoadNiftid(keys=['img', 'seg']), + ChangeToChannelFirstd(keys=['img', 'seg'], channel_dim=-1) ]) # Define nifti dataset, dataloader. -ds = Dataset(data=train_data, transform=transforms) -loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) +ds = monai.data.Dataset(data=train_files, transform=train_transforms) +loader = DataLoader(ds, batch_size=2, num_workers=2, collate_fn=list_data_collate, + pin_memory=torch.cuda.is_available()) check_data = monai.utils.misc.first(loader) -print(check_data['image'].shape, check_data['seg'].shape) +print(check_data['img'].shape, check_data['seg'].shape) lr = 1e-5 @@ -95,7 +98,7 @@ def _loss_fn(i, j): # Create trainer def prepare_batch(batch, device=None, non_blocking=False): - return _prepare_batch((batch['image'], batch['seg']), device, non_blocking) + return _prepare_batch((batch['img'], batch['seg']), device, non_blocking) device = torch.device("cuda:0") @@ -167,8 +170,9 @@ def log_training_loss(engine): evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) # create a validation data loader -val_ds = Dataset(data=val_data, transform=transforms) -val_loader = DataLoader(ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available()) +val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) +val_loader = DataLoader(ds, batch_size=5, num_workers=8, collate_fn=list_data_collate, + pin_memory=torch.cuda.is_available()) @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs)) @@ -185,8 +189,9 @@ def log_metrics_to_tensorboard(engine): # create a training data loader logging.basicConfig(stream=sys.stdout, level=logging.INFO) -train_ds = Dataset(data=train_data, transform=transforms) -train_loader = DataLoader(train_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available()) +train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) +train_loader = DataLoader(train_ds, batch_size=2, num_workers=8, collate_fn=list_data_collate, + pin_memory=torch.cuda.is_available()) train_epochs = 30 state = trainer.run(train_loader, train_epochs) diff --git a/monai/data/synthetic.py b/monai/data/synthetic.py index a51d730357..fb27544a07 100644 --- a/monai/data/synthetic.py +++ b/monai/data/synthetic.py @@ -14,7 +14,7 @@ from monai.transforms.utils import rescale_array -def create_test_image_2d(width, height, num_objs=12, rad_max=30, noise_max=0.0, 
num_seg_classes=5): +def create_test_image_2d(width, height, channel_dim=None, num_objs=12, rad_max=30, noise_max=0.0, num_seg_classes=5): """ Return a noisy 2D image with `numObj' circles and a 2D mask image. The maximum radius of the circles is given as `radMax'. The mask will have `numSegClasses' number of classes for segmentations labeled sequentially from 1, plus a @@ -40,10 +40,16 @@ def create_test_image_2d(width, height, num_objs=12, rad_max=30, noise_max=0.0, norm = np.random.uniform(0, num_seg_classes * noise_max, size=image.shape) noisyimage = rescale_array(np.maximum(image, norm)) + if channel_dim is not None: + assert isinstance(channel_dim, int) and channel_dim in (-1, 0, 2), 'invalid channel dim.' + noisyimage, labels = noisyimage[None], labels[None] \ + if channel_dim == 0 else noisyimage[..., None], labels[..., None] + return noisyimage, labels -def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_max=0.0, num_seg_classes=5): +def create_test_image_3d(height, width, depth, channel_dim=None, num_objs=12, + rad_max=30, noise_max=0.0, num_seg_classes=5): """ Return a noisy 3D image and segmentation. @@ -69,4 +75,9 @@ def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_ma norm = np.random.uniform(0, num_seg_classes * noise_max, size=image.shape) noisyimage = rescale_array(np.maximum(image, norm)) + if channel_dim is not None: + assert isinstance(channel_dim, int) and channel_dim in (-1, 0, 3), 'invalid channel dim.' + noisyimage, labels = (noisyimage[None], labels[None]) \ + if channel_dim == 0 else (noisyimage[..., None], labels[..., None]) + return noisyimage, labels diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 46792a4714..0b513f8b56 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -78,7 +78,7 @@ def forward(self, pred, ground, smooth=1e-5): intersection = psum * tsum sums = psum + tsum - score = 2.0 * (intersection.sum(2) + smooth) / (sums.sum(2) + smooth) + score = 2.0 * (intersection.sum(2)) / (sums.sum(2) + smooth) return 1 - score.mean() diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py index 3e56384d1a..ee66c39226 100644 --- a/monai/transforms/composables.py +++ b/monai/transforms/composables.py @@ -18,7 +18,7 @@ import monai from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.compose import Randomizable, Transform -from monai.transforms.transforms import LoadNifti, Rotate90, SpatialCrop, AddChannel +from monai.transforms.transforms import LoadNifti, ChangeToChannelFirst, AddChannel, Rotate90, SpatialCrop from monai.utils.misc import ensure_tuple from monai.transforms.utils import generate_pos_neg_label_crop_centers @@ -59,16 +59,22 @@ class LoadNiftid(MapTransform): dictionary-based wrapper of LoadNifti, must load image and metadata together. """ - def __init__(self, keys, as_closest_canonical=False, dtype=None): + def __init__(self, keys, as_closest_canonical=False, dtype=None, meta_key_format='{}.{}', overwriting_keys=False): """ Args: keys (hashable items): keys of the corresponding items to be transformed. See also: monai.transform.composables.MapTransform as_closest_canonical (bool): if True, load the image as closest to canonical axis format. dtype (np.dtype, optional): if not None convert the loaded image to this data type. + meta_key_format (str): key format to store meta data of the nifti image. + it must contain 2 fields for the key of this image and the key of every meta data item. 
+ overwriting_keys (bool): whether allow to overwrite existing keys of meta data. + default is False, which will raise exception if encountering existing key. """ MapTransform.__init__(self, keys) self.loader = LoadNifti(as_closest_canonical, False, dtype) + self.meta_key_format = meta_key_format + self.overwriting_keys = overwriting_keys def __call__(self, data): d = dict(data) @@ -77,8 +83,56 @@ def __call__(self, data): assert isinstance(data, (tuple, list)), 'if data contains metadata, must be tuple or list.' d[key] = data[0] assert isinstance(data[1], dict), 'metadata must be in dict format.' - for k, v in data[1].items(): - d[key + '.' + k] = v + for k in sorted(data[1].keys()): + key_to_add = self.meta_key_format.format(key, k) + if key_to_add in d and self.overwriting_keys is False: + raise KeyError('meta data key is alreay existing.') + d[key_to_add] = data[1][k] + return d + + +@export +class ChangeToChannelFirstd(MapTransform): + """ + dictionary-based wrapper of ChangeToChannelFirst. + """ + + def __init__(self, keys, channel_dim=-1): + """ + Args: + keys (hashable items): keys of the corresponding items to be transformed. + See also: monai.transform.composables.MapTransform + channel_dim (int): which dimension of input image is the channel, default is the last dimension. + """ + MapTransform.__init__(self, keys) + self.converter = ChangeToChannelFirst(channel_dim=channel_dim) + + def __call__(self, data): + d = dict(data) + for key in self.keys: + d[key] = self.converter(d[key]) + return d + + +@export +class AddChanneld(MapTransform): + """ + dictionary-based wrapper of AddChannel. + """ + + def __init__(self, keys): + """ + Args: + keys (hashable items): keys of the corresponding items to be transformed. + See also: monai.transform.composables.MapTransform + """ + MapTransform.__init__(self, keys) + self.adder = AddChannel() + + def __call__(self, data): + d = dict(data) + for key in self.keys: + d[key] = self.adder(d[key]) return d @@ -178,28 +232,6 @@ def __call__(self, data): return d -@export -class AddChanneld(MapTransform): - """ - dictionary-based wrapper of AddChannel. - """ - - def __init__(self, keys): - """ - Args: - keys (hashable items): keys of the corresponding items to be transformed. - See also: monai.transform.composables.MapTransform - """ - MapTransform.__init__(self, keys) - self.adder = AddChannel() - - def __call__(self, data): - d = dict(data) - for key in self.keys: - d[key] = self.adder(d[key]) - return d - - @export class RandCropByPosNegLabeld(Randomizable, MapTransform): """ diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index a8e533aa7c..253523aa37 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -87,6 +87,25 @@ def __call__(self, filename): return img, compatible_meta +@export +class ChangeToChannelFirst: + """ + Change the channel dimension of the image to the first dimension. + Args: + channel_dim (int): which dimension of input image is the channel, default is the last dimension. 
+ """ + + def __init__(self, channel_dim=-1): + self.channel_dim = channel_dim + + def __call__(self, img): + if self.channel_dim == -1: + self.channel_dim = img.ndim - 1 + axes = list(range(img.ndim)) + axes.remove(self.channel_dim) + return np.transpose(img, [self.channel_dim] + axes) + + @export class AddChannel: """ diff --git a/tests/test_change_to_channel_first.py b/tests/test_change_to_channel_first.py new file mode 100644 index 0000000000..9dc9ae5318 --- /dev/null +++ b/tests/test_change_to_channel_first.py @@ -0,0 +1,49 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from parameterized import parameterized +from monai.transforms.transforms import ChangeToChannelFirst + +TEST_CASE_1 = [ + { + 'channel_dim': -1 + }, + (4, 1, 2, 3) +] + +TEST_CASE_2 = [ + { + 'channel_dim': 3 + }, + (4, 1, 2, 3) +] + +TEST_CASE_3 = [ + { + 'channel_dim': 2 + }, + (3, 1, 2, 4) +] + + +class TestChangeToChannelFirst(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, input_param, expected_shape): + test_data = np.random.randint(0, 2, size=[1, 2, 3, 4]) + result = ChangeToChannelFirst(**input_param)(test_data) + self.assertTupleEqual(result.shape, expected_shape) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_change_to_channel_firstd.py b/tests/test_change_to_channel_firstd.py new file mode 100644 index 0000000000..4df97c06b1 --- /dev/null +++ b/tests/test_change_to_channel_firstd.py @@ -0,0 +1,58 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import numpy as np +from parameterized import parameterized +from monai.transforms.composables import ChangeToChannelFirstd + +TEST_CASE_1 = [ + { + 'keys': ['image', 'label', 'extra'], + 'channel_dim': -1 + }, + (4, 1, 2, 3) +] + +TEST_CASE_2 = [ + { + 'keys': ['image', 'label', 'extra'], + 'channel_dim': 3 + }, + (4, 1, 2, 3) +] + +TEST_CASE_3 = [ + { + 'keys': ['image', 'label', 'extra'], + 'channel_dim': 2 + }, + (3, 1, 2, 4) +] + + +class TestChangeToChannelFirstd(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, input_param, expected_shape): + test_data = { + 'image': np.random.randint(0, 2, size=[1, 2, 3, 4]), + 'label': np.random.randint(0, 2, size=[1, 2, 3, 4]), + 'extra': np.random.randint(0, 2, size=[1, 2, 3, 4]) + } + result = ChangeToChannelFirstd(**input_param)(test_data) + self.assertTupleEqual(result['image'].shape, expected_shape) + self.assertTupleEqual(result['label'].shape, expected_shape) + self.assertTupleEqual(result['extra'].shape, expected_shape) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_dataset.py b/tests/test_dataset.py index d8f137cb0a..6829812dbc 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -11,28 +11,15 @@ import unittest import os +import shutil import numpy as np +import tempfile import nibabel as nib from parameterized import parameterized from monai.data.dataset import Dataset from monai.transforms.composables import LoadNiftid TEST_CASE_1 = [ - { - 'data': [ - { - 'image': 'test_image1.nii.gz', - 'label': 'test_label1.nii.gz', - 'extra': 'test_extra1.nii.gz' - }, - { - 'image': 'test_image2.nii.gz', - 'label': 'test_label2.nii.gz', - 'extra': 'test_extra2.nii.gz' - } - ], - 'transform': LoadNiftid(keys=['image', 'label', 'extra']) - }, (128, 128, 128) ] @@ -40,29 +27,31 @@ class TestDataset(unittest.TestCase): @parameterized.expand([TEST_CASE_1]) - def test_shape(self, input_param, expected_shape): + def test_shape(self, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) - nib.save(test_image, 'test_image1.nii.gz') - nib.save(test_image, 'test_label1.nii.gz') - nib.save(test_image, 'test_extra1.nii.gz') - nib.save(test_image, 'test_image2.nii.gz') - nib.save(test_image, 'test_label2.nii.gz') - nib.save(test_image, 'test_extra2.nii.gz') - dataset = Dataset(**input_param) + tempdir = tempfile.mkdtemp() + nib.save(test_image, os.path.join(tempdir, 'test_image1.nii.gz')) + nib.save(test_image, os.path.join(tempdir, 'test_label1.nii.gz')) + nib.save(test_image, os.path.join(tempdir, 'test_extra1.nii.gz')) + nib.save(test_image, os.path.join(tempdir, 'test_image2.nii.gz')) + nib.save(test_image, os.path.join(tempdir, 'test_label2.nii.gz')) + nib.save(test_image, os.path.join(tempdir, 'test_extra2.nii.gz')) + test_data = [ + { + 'image': os.path.join(tempdir, 'test_image1.nii.gz'), + 'label': os.path.join(tempdir, 'test_label1.nii.gz'), + 'extra': os.path.join(tempdir, 'test_extra1.nii.gz') + }, + { + 'image': os.path.join(tempdir, 'test_image2.nii.gz'), + 'label': os.path.join(tempdir, 'test_label2.nii.gz'), + 'extra': os.path.join(tempdir, 'test_extra2.nii.gz') + } + ] + dataset = Dataset(data=test_data, transform=LoadNiftid(keys=['image', 'label', 'extra'])) data1 = dataset[0] data2 = dataset[1] - if os.path.exists('test_image1.nii.gz'): - os.remove('test_image1.nii.gz') - if os.path.exists('test_label1.nii.gz'): - os.remove('test_label1.nii.gz') - if 
os.path.exists('test_extra1.nii.gz'): - os.remove('test_extra1.nii.gz') - if os.path.exists('test_image2.nii.gz'): - os.remove('test_image2.nii.gz') - if os.path.exists('test_label2.nii.gz'): - os.remove('test_label2.nii.gz') - if os.path.exists('test_extra2.nii.gz'): - os.remove('test_extra2.nii.gz') + shutil.rmtree(tempdir) self.assertTupleEqual(data1['image'].shape, expected_shape) self.assertTupleEqual(data1['label'].shape, expected_shape) self.assertTupleEqual(data1['extra'].shape, expected_shape) diff --git a/tests/test_load_nifti.py b/tests/test_load_nifti.py index ccf79cfaa3..de0660ccb3 100644 --- a/tests/test_load_nifti.py +++ b/tests/test_load_nifti.py @@ -11,39 +11,40 @@ import unittest import os +import shutil import numpy as np +import tempfile import nibabel as nib from parameterized import parameterized from monai.transforms.transforms import LoadNifti -TEST_CASE_1 = [ +TEST_CASE_IMAGE_ONLY = [ { 'as_closest_canonical': False, 'image_only': True }, - 'test_image.nii.gz', (128, 128, 128) ] -TEST_CASE_2 = [ +TEST_CASE_IMAGE_METADATA = [ { 'as_closest_canonical': False, 'image_only': False }, - 'test_image.nii.gz', (128, 128, 128) ] class TestLoadNifti(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_shape(self, input_param, input_data, expected_shape): + @parameterized.expand([TEST_CASE_IMAGE_ONLY, TEST_CASE_IMAGE_METADATA]) + def test_shape(self, input_param, expected_shape): test_image = np.random.randint(0, 2, size=[128, 128, 128]) - nib.save(nib.Nifti1Image(test_image, np.eye(4)), 'test_image.nii.gz') - result = LoadNifti(**input_param)(input_data) - if os.path.exists('test_image.nii.gz'): - os.remove('test_image.nii.gz') + tempdir = tempfile.mkdtemp() + nib.save(nib.Nifti1Image(test_image, np.eye(4)), os.path.join(tempdir, 'test_image.nii.gz')) + test_data = os.path.join(tempdir, 'test_image.nii.gz') + result = LoadNifti(**input_param)(test_data) + shutil.rmtree(tempdir) if isinstance(result, tuple): result = result[0] self.assertTupleEqual(result.shape, expected_shape) diff --git a/tests/test_load_niftid.py b/tests/test_load_niftid.py index 644f8e9d20..76819cf4d6 100644 --- a/tests/test_load_niftid.py +++ b/tests/test_load_niftid.py @@ -11,7 +11,9 @@ import unittest import os +import shutil import numpy as np +import tempfile import nibabel as nib from parameterized import parameterized from monai.transforms.composables import LoadNiftid @@ -21,11 +23,6 @@ 'keys': ['image', 'label', 'extra'], 'as_closest_canonical': False }, - { - 'image': 'test_image.nii.gz', - 'label': 'test_label.nii.gz', - 'extra': 'test_extra.nii.gz' - }, (128, 128, 128) ] @@ -33,18 +30,19 @@ class TestLoadNiftid(unittest.TestCase): @parameterized.expand([TEST_CASE_1]) - def test_shape(self, input_param, input_data, expected_shape): + def test_shape(self, input_param, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) - nib.save(test_image, 'test_image.nii.gz') - nib.save(test_image, 'test_label.nii.gz') - nib.save(test_image, 'test_extra.nii.gz') - result = LoadNiftid(**input_param)(input_data) - if os.path.exists('test_image.nii.gz'): - os.remove('test_image.nii.gz') - if os.path.exists('test_label.nii.gz'): - os.remove('test_label.nii.gz') - if os.path.exists('test_extra.nii.gz'): - os.remove('test_extra.nii.gz') + tempdir = tempfile.mkdtemp() + nib.save(test_image, os.path.join(tempdir, 'test_image.nii.gz')) + nib.save(test_image, os.path.join(tempdir, 'test_label.nii.gz')) + nib.save(test_image, 
os.path.join(tempdir, 'test_extra.nii.gz')) + test_data = { + 'image': os.path.join(tempdir, 'test_image.nii.gz'), + 'label': os.path.join(tempdir, 'test_label.nii.gz'), + 'extra': os.path.join(tempdir, 'test_extra.nii.gz') + } + result = LoadNiftid(**input_param)(test_data) + shutil.rmtree(tempdir) self.assertTupleEqual(result['image'].shape, expected_shape) self.assertTupleEqual(result['label'].shape, expected_shape) self.assertTupleEqual(result['extra'].shape, expected_shape) diff --git a/tests/test_spatial_crop.py b/tests/test_spatial_crop.py index 2a3c2e7f9c..8a99d90f63 100644 --- a/tests/test_spatial_crop.py +++ b/tests/test_spatial_crop.py @@ -20,7 +20,7 @@ 'roi_size': [2, 2, 2] }, np.random.randint(0, 2, size=[3, 3, 3, 3]), - (3, 2, 2, 2), + (3, 2, 2, 2) ] TEST_CASE_2 = [ @@ -29,9 +29,10 @@ 'roi_end': [2, 2, 2] }, np.random.randint(0, 2, size=[3, 3, 3, 3]), - (3, 2, 2, 2), + (3, 2, 2, 2) ] + class TestSpatialCrop(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) From 96e60c15b9d40a074d16289f68ad09855e63de59 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 6 Mar 2020 17:14:18 +0800 Subject: [PATCH 6/9] [DLMED] update to AsChannelFirst --- examples/unet_segmentation_3d_dict.py | 6 +++--- monai/data/synthetic.py | 6 +++--- monai/losses/dice.py | 2 +- monai/transforms/composables.py | 8 ++++---- monai/transforms/transforms.py | 2 +- ...hange_to_channel_first.py => test_as_channel_first.py} | 6 +++--- ...nge_to_channel_firstd.py => test_as_channel_firstd.py} | 6 +++--- tests/test_dice_loss.py | 6 +++--- 8 files changed, 21 insertions(+), 21 deletions(-) rename tests/{test_change_to_channel_first.py => test_as_channel_first.py} (86%) rename tests/{test_change_to_channel_firstd.py => test_as_channel_firstd.py} (89%) diff --git a/examples/unet_segmentation_3d_dict.py b/examples/unet_segmentation_3d_dict.py index 6bf576fb5e..d7ea3795ea 100644 --- a/examples/unet_segmentation_3d_dict.py +++ b/examples/unet_segmentation_3d_dict.py @@ -29,7 +29,7 @@ import monai import monai.transforms.compose as transforms from monai.transforms.composables import \ - LoadNiftid, ChangeToChannelFirstd, RandCropByPosNegLabeld, RandRotate90d + LoadNiftid, AsChannelFirstd, RandCropByPosNegLabeld, RandRotate90d from monai.handlers.stats_handler import StatsHandler from monai.handlers.mean_dice import MeanDice from monai.visualize import img2tensorboard @@ -59,13 +59,13 @@ # Define transforms for image and segmentation train_transforms = transforms.Compose([ LoadNiftid(keys=['img', 'seg']), - ChangeToChannelFirstd(keys=['img', 'seg'], channel_dim=-1), + AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1), RandCropByPosNegLabeld(keys=['img', 'seg'], label_key='seg', size=[96, 96, 96], pos=1, neg=1, num_samples=4), RandRotate90d(keys=['img', 'seg'], prob=0.8, axes=[1, 3]) ]) val_transforms = transforms.Compose([ LoadNiftid(keys=['img', 'seg']), - ChangeToChannelFirstd(keys=['img', 'seg'], channel_dim=-1) + AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1) ]) # Define nifti dataset, dataloader. 
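For reference, the renamed `AsChannelFirstd` keeps the behaviour shown above: for every listed key it moves the channel axis of the array to position 0, so the channel-last volumes written by `create_test_image_3d(..., channel_dim=-1)` come out channel-first. A minimal shape sketch with plain numpy, assuming only the `channel_dim=-1` semantics used in this patch (the array sizes are illustrative):

import numpy as np

# mimic AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1) on channel-last arrays
data = {
    'img': np.zeros((128, 128, 128, 1), dtype=np.float32),
    'seg': np.zeros((128, 128, 128, 1), dtype=np.float32),
}
channel_first = {key: np.moveaxis(value, -1, 0) for key, value in data.items()}
print(channel_first['img'].shape, channel_first['seg'].shape)  # (1, 128, 128, 128) twice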
diff --git a/monai/data/synthetic.py b/monai/data/synthetic.py index fb27544a07..6ea1517dc9 100644 --- a/monai/data/synthetic.py +++ b/monai/data/synthetic.py @@ -14,7 +14,7 @@ from monai.transforms.utils import rescale_array -def create_test_image_2d(width, height, channel_dim=None, num_objs=12, rad_max=30, noise_max=0.0, num_seg_classes=5): +def create_test_image_2d(width, height, num_objs=12, rad_max=30, noise_max=0.0, num_seg_classes=5, channel_dim=None): """ Return a noisy 2D image with `numObj' circles and a 2D mask image. The maximum radius of the circles is given as `radMax'. The mask will have `numSegClasses' number of classes for segmentations labeled sequentially from 1, plus a @@ -48,8 +48,8 @@ def create_test_image_2d(width, height, channel_dim=None, num_objs=12, rad_max=3 return noisyimage, labels -def create_test_image_3d(height, width, depth, channel_dim=None, num_objs=12, - rad_max=30, noise_max=0.0, num_seg_classes=5): +def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, + noise_max=0.0, num_seg_classes=5, channel_dim=None): """ Return a noisy 3D image and segmentation. diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 0b513f8b56..f7ebacaa9e 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -78,7 +78,7 @@ def forward(self, pred, ground, smooth=1e-5): intersection = psum * tsum sums = psum + tsum - score = 2.0 * (intersection.sum(2)) / (sums.sum(2) + smooth) + score = (2.0 * intersection.sum(2) + smooth) / (sums.sum(2) + smooth) return 1 - score.mean() diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py index ee66c39226..af2a973a4a 100644 --- a/monai/transforms/composables.py +++ b/monai/transforms/composables.py @@ -18,7 +18,7 @@ import monai from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.compose import Randomizable, Transform -from monai.transforms.transforms import LoadNifti, ChangeToChannelFirst, AddChannel, Rotate90, SpatialCrop +from monai.transforms.transforms import LoadNifti, AsChannelFirst, AddChannel, Rotate90, SpatialCrop from monai.utils.misc import ensure_tuple from monai.transforms.utils import generate_pos_neg_label_crop_centers @@ -92,9 +92,9 @@ def __call__(self, data): @export -class ChangeToChannelFirstd(MapTransform): +class AsChannelFirstd(MapTransform): """ - dictionary-based wrapper of ChangeToChannelFirst. + dictionary-based wrapper of AsChannelFirst. """ def __init__(self, keys, channel_dim=-1): @@ -105,7 +105,7 @@ def __init__(self, keys, channel_dim=-1): channel_dim (int): which dimension of input image is the channel, default is the last dimension. """ MapTransform.__init__(self, keys) - self.converter = ChangeToChannelFirst(channel_dim=channel_dim) + self.converter = AsChannelFirst(channel_dim=channel_dim) def __call__(self, data): d = dict(data) diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index d1ac92a5a5..8076ce36e0 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -88,7 +88,7 @@ def __call__(self, filename): @export -class ChangeToChannelFirst: +class AsChannelFirst: """ Change the channel dimension of the image to the first dimension. 
Args: diff --git a/tests/test_change_to_channel_first.py b/tests/test_as_channel_first.py similarity index 86% rename from tests/test_change_to_channel_first.py rename to tests/test_as_channel_first.py index 9dc9ae5318..ccd0f3765a 100644 --- a/tests/test_change_to_channel_first.py +++ b/tests/test_as_channel_first.py @@ -12,7 +12,7 @@ import unittest import numpy as np from parameterized import parameterized -from monai.transforms.transforms import ChangeToChannelFirst +from monai.transforms.transforms import AsChannelFirst TEST_CASE_1 = [ { @@ -36,12 +36,12 @@ ] -class TestChangeToChannelFirst(unittest.TestCase): +class TestAsChannelFirst(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, input_param, expected_shape): test_data = np.random.randint(0, 2, size=[1, 2, 3, 4]) - result = ChangeToChannelFirst(**input_param)(test_data) + result = AsChannelFirst(**input_param)(test_data) self.assertTupleEqual(result.shape, expected_shape) diff --git a/tests/test_change_to_channel_firstd.py b/tests/test_as_channel_firstd.py similarity index 89% rename from tests/test_change_to_channel_firstd.py rename to tests/test_as_channel_firstd.py index 4df97c06b1..6f9b450c4f 100644 --- a/tests/test_change_to_channel_firstd.py +++ b/tests/test_as_channel_firstd.py @@ -12,7 +12,7 @@ import unittest import numpy as np from parameterized import parameterized -from monai.transforms.composables import ChangeToChannelFirstd +from monai.transforms.composables import AsChannelFirstd TEST_CASE_1 = [ { @@ -39,7 +39,7 @@ ] -class TestChangeToChannelFirstd(unittest.TestCase): +class TestAsChannelFirstd(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, input_param, expected_shape): @@ -48,7 +48,7 @@ def test_shape(self, input_param, expected_shape): 'label': np.random.randint(0, 2, size=[1, 2, 3, 4]), 'extra': np.random.randint(0, 2, size=[1, 2, 3, 4]) } - result = ChangeToChannelFirstd(**input_param)(test_data) + result = AsChannelFirstd(**input_param)(test_data) self.assertTupleEqual(result['image'].shape, expected_shape) self.assertTupleEqual(result['label'].shape, expected_shape) self.assertTupleEqual(result['extra'].shape, expected_shape) diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index a7ad9171b9..c5640a5660 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -39,7 +39,7 @@ 'ground': torch.tensor([[[[1., 1.], [1., 1.]]], [[[1., 0.], [1., 0.]]]]), 'smooth': 1e-4, }, - 0.416636, + 0.416657, ] TEST_CASE_3 = [ # shape: (2, 2, 3), (2, 1, 3) @@ -64,7 +64,7 @@ 'ground': torch.tensor([[[1., 0., 0.]], [[1., 1., 0.]]]), 'smooth': 1e-4, }, - 0.435015, + 0.435050, ] TEST_CASE_5 = [ # shape: (2, 2, 3), (2, 1, 3) @@ -77,7 +77,7 @@ 'ground': torch.tensor([[[1., 0., 0.]], [[1., 1., 0.]]]), 'smooth': 1e-4, }, - 0.383678, + 0.383713, ] TEST_CASE_6 = [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) From bfe75a35f7ab186f39544636a74e657a32b67a52 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 6 Mar 2020 17:42:35 +0800 Subject: [PATCH 7/9] [DLMED] update to use np.moveaxis API instead --- monai/transforms/transforms.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index 8076ce36e0..6afa77bfbb 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -96,14 +96,11 @@ class AsChannelFirst: """ def __init__(self, channel_dim=-1): + assert isinstance(channel_dim, int) and channel_dim >= -1, 
'invalid channel dimension.' self.channel_dim = channel_dim def __call__(self, img): - if self.channel_dim == -1: - self.channel_dim = img.ndim - 1 - axes = list(range(img.ndim)) - axes.remove(self.channel_dim) - return np.transpose(img, [self.channel_dim] + axes) + return np.moveaxis(img, self.channel_dim, 0) @export From b620fcb0751e59215d0d00bb3636710e9ee09e74 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 7 Mar 2020 09:21:25 +0800 Subject: [PATCH 8/9] [DLMED] update according to comments --- monai/data/synthetic.py | 3 ++- monai/transforms/transforms.py | 14 ++++++-------- tests/test_load_niftid.py | 21 +++++++++------------ 3 files changed, 17 insertions(+), 21 deletions(-) diff --git a/monai/data/synthetic.py b/monai/data/synthetic.py index 6ea1517dc9..4efd4fe393 100644 --- a/monai/data/synthetic.py +++ b/monai/data/synthetic.py @@ -19,7 +19,8 @@ def create_test_image_2d(width, height, num_objs=12, rad_max=30, noise_max=0.0, Return a noisy 2D image with `numObj' circles and a 2D mask image. The maximum radius of the circles is given as `radMax'. The mask will have `numSegClasses' number of classes for segmentations labeled sequentially from 1, plus a background class represented as 0. If `noiseMax' is greater than 0 then noise will be added to the image taken from - the uniform distribution on range [0,noiseMax). + the uniform distribution on range [0,noiseMax). If `channel_dim' is None, will create an image without channel + dimemsion, otherwise create an image with channel dimension as first dim or last dim. """ image = np.zeros((width, height)) diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index 6afa77bfbb..35d7c673b7 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -40,6 +40,12 @@ def __init__(self, as_closest_canonical=False, image_only=False, dtype=None): as_closest_canonical (bool): if True, load the image as closest to canonical axis format. image_only (bool): if True return only the image volume, other return image volume and header dict. dtype (np.dtype, optional): if not None convert the loaded image to this data type. + + Note: + The loaded image volume if `image_only` is True, or a tuple containing the volume and the Nifti + header in dict format otherwise. + header['original_affine'] stores the original affine loaded from `filename_or_obj`. + header['affine'] stores the affine after the optional `as_closest_canonical` transform. """ self.as_closest_canonical = as_closest_canonical self.image_only = image_only @@ -49,14 +55,6 @@ def __call__(self, filename): """ Args: filename (str or file): path to file or file-like object. - - Returns: - The loaded image volume if `image_only` is True, or a tuple containing the volume and the Nifti - header in dict format otherwise. - - Note: - header['original_affine'] stores the original affine loaded from `filename_or_obj`. - header['affine'] stores the affine after the optional `as_closest_canonical` transform. 
""" img = nib.load(filename) diff --git a/tests/test_load_niftid.py b/tests/test_load_niftid.py index 76819cf4d6..071972f03f 100644 --- a/tests/test_load_niftid.py +++ b/tests/test_load_niftid.py @@ -18,9 +18,11 @@ from parameterized import parameterized from monai.transforms.composables import LoadNiftid +KEYS = ['image', 'label', 'extra'] + TEST_CASE_1 = [ { - 'keys': ['image', 'label', 'extra'], + 'keys': KEYS, 'as_closest_canonical': False }, (128, 128, 128) @@ -33,19 +35,14 @@ class TestLoadNiftid(unittest.TestCase): def test_shape(self, input_param, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) tempdir = tempfile.mkdtemp() - nib.save(test_image, os.path.join(tempdir, 'test_image.nii.gz')) - nib.save(test_image, os.path.join(tempdir, 'test_label.nii.gz')) - nib.save(test_image, os.path.join(tempdir, 'test_extra.nii.gz')) - test_data = { - 'image': os.path.join(tempdir, 'test_image.nii.gz'), - 'label': os.path.join(tempdir, 'test_label.nii.gz'), - 'extra': os.path.join(tempdir, 'test_extra.nii.gz') - } + test_data = dict() + for key in KEYS: + nib.save(test_image, os.path.join(tempdir, key + '.nii.gz')) + test_data.update({key: os.path.join(tempdir, key + '.nii.gz')}) result = LoadNiftid(**input_param)(test_data) shutil.rmtree(tempdir) - self.assertTupleEqual(result['image'].shape, expected_shape) - self.assertTupleEqual(result['label'].shape, expected_shape) - self.assertTupleEqual(result['extra'].shape, expected_shape) + for key in KEYS: + self.assertTupleEqual(result[key].shape, expected_shape) if __name__ == '__main__': From 225f34ffeb708545f7e5beb0fdab5bba8ded2aa7 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 9 Mar 2020 21:02:12 +0000 Subject: [PATCH 9/9] update generalized dice to be consistent with the changes in dice loss --- monai/losses/dice.py | 2 +- tests/test_generalized_dice_loss.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index f7ebacaa9e..808c3c65d3 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -159,5 +159,5 @@ def forward(self, pred, ground, smooth=1e-5): b[infs] = 0.0 b[infs] = torch.max(b) - score = 2.0 * (intersection.sum(2) * w) / (sums.sum(2) * w + smooth) + score = (2.0 * intersection.sum(2) * w + smooth) / (sums.sum(2) * w + smooth) return 1 - score.mean() diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index fe29bc2d11..e08ff1d296 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -39,7 +39,7 @@ 'ground': torch.tensor([[[[1., 1.], [1., 1.]]], [[[1., 0.], [1., 0.]]]]), 'smooth': 1e-4, }, - 0.41678, + 0.416597, ] TEST_CASE_2 = [ # shape: (2, 2, 3), (2, 1, 3) @@ -64,7 +64,7 @@ 'ground': torch.tensor([[[1., 0., 0.]], [[1., 1., 0.]]]), 'smooth': 1e-4, }, - 0.435111, + 0.435034, ] TEST_CASE_4 = [ # shape: (2, 2, 3), (2, 1, 3) @@ -77,7 +77,7 @@ 'ground': torch.tensor([[[1., 0., 0.]], [[1., 1., 0.]]]), 'smooth': 1e-4, }, - 0.383776, + 0.383699, ] TEST_CASE_5 = [ # shape: (2, 2, 3), (2, 1, 3) @@ -89,7 +89,7 @@ 'ground': torch.tensor([[[0., 0., 0.]], [[0., 0., 0.]]]), 'smooth': 1e-8, }, - 1.0, + 0.0, ] TEST_CASE_6 = [ # shape: (1, 1, 2, 2), (1, 1, 2, 2)