From 07efc776b93bcc21263e212ee972763e46514f3f Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Thu, 6 Feb 2020 16:56:53 +0800
Subject: [PATCH 01/11] [DLMED] add support to SlidingWindow inference and image padding

---
 monai/data/transforms/image_end_padder.py |  44 ++++++
 monai/utils/sliding_window_inference.py   | 165 ++++++++++++++++++++++
 tests/test_image_end_padder.py            |  36 +++++
 tests/test_sliding_window_inference.py    |  37 +++++
 4 files changed, 282 insertions(+)
 create mode 100644 monai/data/transforms/image_end_padder.py
 create mode 100644 monai/utils/sliding_window_inference.py
 create mode 100644 tests/test_image_end_padder.py
 create mode 100644 tests/test_sliding_window_inference.py

diff --git a/monai/data/transforms/image_end_padder.py b/monai/data/transforms/image_end_padder.py
new file mode 100644
index 0000000000..86d435c820
--- /dev/null
+++ b/monai/data/transforms/image_end_padder.py
@@ -0,0 +1,44 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import monai
+
+export = monai.utils.export("monai.data.transforms")
+
+
+@export
+class ImageEndPadder:
+    """Pads the data up to `out_size` by appending to the end (one side only) of each spatial dimension.
+    Uses np.pad, so a padding mode must be provided. See numpy.lib.arraypad.pad
+    for additional details.
+
+    Args:
+        out_size (list): the spatial size of the data after padding.
+        mode (string): padding mode, passed through to np.pad (e.g. 'constant', 'edge').
+        dtype: output data type (default np.float32).
+    """
+
+    def __init__(self, out_size, mode, dtype=np.float32):
+        assert out_size is not None and isinstance(out_size, (list, tuple)), 'out_size must be list or tuple'
+        self.out_size = out_size
+        assert isinstance(mode, str), 'mode must be str'
+        self.mode = mode
+        self.dtype = dtype
+
+    def _determine_data_pad_width(self, data_shape):
+        return [(0, max(self.out_size[i] - data_shape[i], 0)) for i in range(len(self.out_size))]
+
+    def __call__(self, img):
+        data_pad_width = self._determine_data_pad_width(img.shape[2:])
+        all_pad_width = [(0, 0), (0, 0)] + data_pad_width  # no padding of the batch/channel dims
+        img = np.pad(img, all_pad_width, self.mode).astype(self.dtype)
+        return img
diff --git a/monai/utils/sliding_window_inference.py b/monai/utils/sliding_window_inference.py
new file mode 100644
index 0000000000..746e9932ba
--- /dev/null
+++ b/monai/utils/sliding_window_inference.py
@@ -0,0 +1,165 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+import numpy as np
+import torch
+from monai.data.transforms import ImageEndPadder
+
+
+def sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device):
+    """Use SlidingWindow method to execute inference.
+
+    Args:
+        roi_size (list, tuple): the window size to execute SlidingWindow inference.
+        sw_batch_size (int): the batch size to run window slices.
+        predictor: a portion from numpy.lib.arraypad.pad is copied below.
+        device: on which device to execute model inference, cpu or gpu.
+
+    Note:
+        the input must be channel-first; both 2D and 3D images are supported.
+        input data must have a batch dimension.
+        inference runs on one image at a time: each forward pass evaluates a batch of window slices of that image.
+    """
+    num_spatial_dims = len(inputs.shape) - 2
+    assert len(roi_size) == num_spatial_dims, 'roi_size {} does not match input dims.'.format(roi_size)
+
+    # determine image spatial size and batch size
+    # Note: all input images must have the same image size and batch size
+    image_size = list(inputs.shape[2:])
+    batch_size = inputs.shape[0]
+
+    # TODO: Enable batch sizes > 1 in future.
+    assert batch_size == 1, "input batch size must be 1"
+
+    # in case that image size is smaller than roi size
+    is_oversized = False
+    original_image_size = [image_size[i] for i in range(num_spatial_dims)]
+
+    for i in range(num_spatial_dims):
+        if int(roi_size[i]) > image_size[i]:
+            is_oversized = True
+            break
+
+    if is_oversized:
+        for i in range(num_spatial_dims):
+            image_size[i] = max(image_size[i], roi_size[i])
+
+        padder = ImageEndPadder(roi_size, 'constant')
+        inputs = padder(inputs)
+
+    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims)
+    scan_num = [int(math.ceil(float(image_size[i]) / scan_interval[i])) for i in range(num_spatial_dims)]
+
+    # Store all slices in list.
+ slices = [] + if num_spatial_dims == 3: + for i in range(scan_num[0]): + start_i = i * scan_interval[0] + start_i -= max(start_i + roi_size[0] - image_size[0], 0) + slice_i = slice(start_i, start_i + roi_size[0]) + + for j in range(scan_num[1]): + start_j = j * scan_interval[1] + start_j -= max(start_j + roi_size[1] - image_size[1], 0) + slice_j = slice(start_j, start_j + roi_size[1]) + + for k in range(0, scan_num[2]): + start_k = k * scan_interval[2] + start_k -= max(start_k + roi_size[2] - image_size[2], 0) + slice_k = slice(start_k, start_k + roi_size[2]) + slices.append((slice_i, slice_j, slice_k)) + else: + for i in range(scan_num[0]): + start_i = i * scan_interval[0] + start_i -= max(start_i + roi_size[0] - image_size[0], 0) + slice_i = slice(start_i, start_i + roi_size[0]) + + for j in range(scan_num[1]): + start_j = j * scan_interval[1] + start_j -= max(start_j + roi_size[1] - image_size[1], 0) + slice_j = slice(start_j, start_j + roi_size[1]) + slices.append((slice_i, slice_j)) + + buffered_requests = [] + for slice_index in range(0, len(slices), sw_batch_size): + slice_index_range = range(slice_index, min(slice_index + sw_batch_size, len(slices))) + input_slices = [] + for curr_index in slice_index_range: + if num_spatial_dims == 3: + slice_i, slice_j, slice_k = slices[curr_index] + input_slices.append(inputs[0, :, slice_i, slice_j, slice_k]) + else: + slice_i, slice_j = slices[curr_index] + input_slices.append(inputs[0, :, slice_i, slice_j]) + buffered_requests.append(np.stack(input_slices)) + + # Perform predictions + output_rois = list() + for data in buffered_requests: + output_rois.append(predictor(data)) + + output_classes = output_rois[0].shape[1] + output_shape = [batch_size, output_classes] + list(image_size) + + # allocate memory to store the full output and the count for overlapping parts + output_dict = torch.zeros(output_shape, dtype=torch.float32, device=device) + count_dict = torch.zeros(output_shape, dtype=torch.float32, device=device) + + window_index = 0 + for slice_index in range(0, len(slices), sw_batch_size): + slice_index_range = range(slice_index, min(slice_index + sw_batch_size, len(slices))) + output_roi = output_rois[window_index] + window_index += 1 + + # store the result in the proper location of the full output + for curr_index in slice_index_range: + if num_spatial_dims == 3: + slice_i, slice_j, slice_k = slices[curr_index] + output_dict[0, :, slice_i, slice_j, slice_k] += \ + output_roi[curr_index - slice_index, :] + count_dict[0, :, slice_i, slice_j, slice_k] += 1 + else: + slice_i, slice_j = slices[curr_index] + output_dict[0, :, slice_i, slice_j] += \ + output_roi[curr_index - slice_index, :] + count_dict[0, :, slice_i, slice_j] += 1 + + # account for any overlapping sections + output_dict /= count_dict + + # in case that image size is smaller than roi size + if is_oversized: + new_output_dict = list() + if num_spatial_dims == 3: + new_output_dict = output_dict[:, :, :original_image_size[0], + :original_image_size[1], :original_image_size[2]] + else: + new_output_dict = output_dict[:, :, :original_image_size[0], :original_image_size[1]] + + output_dict = new_output_dict + + return output_dict + + +def _get_scan_interval(image_size, roi_size, num_spatial_dims): + assert (len(image_size) == num_spatial_dims), 'image coord different from spatial dims.' + assert (len(roi_size) == num_spatial_dims), 'roi coord different from spatial dims.' 
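+    # scan_interval is the stride between neighbouring windows: an interval equal
+    # to roi_size tiles the image with no overlap; a smaller interval makes windows overlap.
+    # e.g. roi 96 -> interval 80 (96 - 16); roi 32 -> interval 24 (32 * 0.75).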
+
+    scan_interval = [1 for _ in range(num_spatial_dims)]
+    for i in range(num_spatial_dims):
+        if roi_size[i] == image_size[i]:
+            scan_interval[i] = int(roi_size[i])
+        else:
+            # this means that it's r-16 (if r>=64) and r*0.75 (if r<=64)
+            scan_interval[i] = int(max(roi_size[i] - 16, roi_size[i] * 0.75))
+    return scan_interval
diff --git a/tests/test_image_end_padder.py b/tests/test_image_end_padder.py
new file mode 100644
index 0000000000..ee186f8585
--- /dev/null
+++ b/tests/test_image_end_padder.py
@@ -0,0 +1,36 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from parameterized import parameterized
+from monai.data.transforms import ImageEndPadder
+
+TEST_CASE_1 = [
+    {
+        'out_size': [16, 16, 8],
+        'mode': 'constant'
+    },
+    np.zeros((1, 3, 8, 8, 4)),
+    np.zeros((1, 3, 16, 16, 8)),
+]
+
+class TestImageEndPadder(unittest.TestCase):
+
+    @parameterized.expand([TEST_CASE_1])
+    def test_image_end_pad_shape(self, input_param, input_data, expected_val):
+        padder = ImageEndPadder(**input_param)
+        result = padder(input_data)
+        self.assertEqual(result.shape, expected_val.shape)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py
new file mode 100644
index 0000000000..2ce9e9da07
--- /dev/null
+++ b/tests/test_sliding_window_inference.py
@@ -0,0 +1,37 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import unittest +import torch +import numpy as np + +from monai.utils.sliding_window_inference import sliding_window_inference + + +class TestSlidingWindowInference(unittest.TestCase): + + def test_sliding_window_default(self): + inputs = np.ones((1, 3, 16, 16, 8)) + roi_size = [4, 4, 4] + sw_batch_size = 4 + device = torch.device("cuda:0") + + def compute(data): + data = torch.from_numpy(data) + return data.to(device) + 1 + + result = sliding_window_inference(inputs, roi_size, sw_batch_size, compute, device) + expected_val = torch.ones((1, 3, 16, 16, 8), dtype=torch.float32, device=device) + 1 + self.assertAlmostEqual(result.shape, expected_val.shape) + + +if __name__ == '__main__': + unittest.main() From 786810feefffb40bb95d90963fc8fd2d95178f80 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 6 Feb 2020 18:53:54 +0800 Subject: [PATCH 02/11] [DLMED] add Nifti writer for model output data --- monai/data/writers/niftiwriter.py | 38 +++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 monai/data/writers/niftiwriter.py diff --git a/monai/data/writers/niftiwriter.py b/monai/data/writers/niftiwriter.py new file mode 100644 index 0000000000..9ba9e52c01 --- /dev/null +++ b/monai/data/writers/niftiwriter.py @@ -0,0 +1,38 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import nibabel as nib + + +def write_nifti(data, affine, file_name, revert_canonical, dtype="float32"): + """Write numpy data into nifti files to disk. + + Args: + data (numpy.ndarray): input data to write to file. + affine (numpy.ndarray): affine information for the data. + file_name (string): expected file name that saved on disk. + revert_canonical (bool): whether to revert canonical. + dtype (np.dtype, optional): convert the loaded image to this data type. + + """ + assert isinstance(data, np.ndarray), 'input data must be numpy array.' 
+ if affine is None: + affine = np.eye(4) + + if revert_canonical: + codes = nib.orientations.axcodes2ornt(nib.orientations.aff2axcodes(np.linalg.inv(affine))) + reverted_results = nib.orientations.apply_orientation(np.squeeze(data), codes) + results_img = nib.Nifti1Image(reverted_results.astype(dtype), affine) + else: + results_img = nib.Nifti1Image(np.squeeze(data).astype(dtype), np.squeeze(affine)) + + nib.save(results_img, file_name) From a0894adb0fc3dc94fb71d6eb1a73d26e297bb949 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 6 Feb 2020 23:30:28 +0000 Subject: [PATCH 03/11] initial sliding window inference workflow --- examples/unet_inference_3d.py | 82 +++++++++++++++++ examples/unet_segmentation_3d.py | 40 ++------- .../handlers/segmentation_saver.py | 88 +++++++++++++++++++ monai/data/readers/niftireader.py | 24 ++++- monai/utils/generateddata.py | 32 ++++++- monai/utils/sliding_window_inference.py | 73 ++++++--------- tests/test_sliding_window_inference.py | 31 ++++--- 7 files changed, 273 insertions(+), 97 deletions(-) create mode 100644 examples/unet_inference_3d.py create mode 100644 monai/application/handlers/segmentation_saver.py diff --git a/examples/unet_inference_3d.py b/examples/unet_inference_3d.py new file mode 100644 index 0000000000..6c28d7506e --- /dev/null +++ b/examples/unet_inference_3d.py @@ -0,0 +1,82 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
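+
+# Example workflow: generate synthetic 3D volumes, build a UNet, run whole-volume
+# inference with sliding_window_inference through an ignite Engine, and save the
+# predicted segmentations with the SegmentationSaver handler.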
+ +import os +import sys +import tempfile +from glob import glob + +import nibabel as nib +import numpy as np +import torch +import torchvision.transforms as transforms +from ignite.engine import Engine +from torch.utils.data import DataLoader + +from monai import application +from monai.application.handlers.segmentation_saver import SegmentationSaver +from monai.data.readers import NiftiDataset +from monai.data.transforms import AddChannel, Rescale, ToTensor +from monai.networks.nets.unet import UNet +from monai.networks.utils import predict_segmentation +from monai.utils.generateddata import create_test_image_3d +from monai.utils.sliding_window_inference import sliding_window_inference + +sys.path.append("..") # assumes the framework is found here, change as necessary +application.config.print_config() +tempdir = tempfile.mkdtemp() +# tempdir = './' +for i in range(50): + im, seg = create_test_image_3d(256, 256, 256) + + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i)) + + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i)) + +images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz'))) +segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz'))) +imtrans = transforms.Compose([Rescale(), AddChannel(), ToTensor()]) +segtrans = transforms.Compose([AddChannel(), ToTensor()]) +ds = NiftiDataset(images, segs, imtrans, segtrans, image_only=False) + +device = torch.device("cpu:0") +roi_size = (64, 64, 64) +sw_batch_size = 4 +net = UNet( + dimensions=3, + in_channels=1, + num_classes=1, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + num_res_units=2, +) +net.to(device) + + +def _sliding_window_processor(_engine, batch): + net.eval() + img, seg, meta_data = batch + with torch.no_grad(): + seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, net, device) + return predict_segmentation(seg_probs) + + +infer_engine = Engine(_sliding_window_processor) + +# checkpoint_handler = ModelCheckpoint('./', 'net', n_saved=10, save_interval=3, require_empty=False) +# infer_engine.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net}) + +SegmentationSaver().attach(infer_engine) + +loader = DataLoader(ds, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available()) +state = infer_engine.run(loader) diff --git a/examples/unet_segmentation_3d.py b/examples/unet_segmentation_3d.py index 4a508dbd08..fc2fe0d5ae 100644 --- a/examples/unet_segmentation_3d.py +++ b/examples/unet_segmentation_3d.py @@ -22,8 +22,9 @@ from ignite.handlers import ModelCheckpoint from torch.utils.data import DataLoader -from monai import application, networks, utils +from monai import application, networks from monai.data.readers import NiftiDataset +from monai.utils.generateddata import create_test_image_3d from monai.data.transforms import (AddChannel, Rescale, ToTensor, UniformRandomPatch) # assumes the framework is found here, change as necessary @@ -31,32 +32,6 @@ application.config.print_config() - -def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_max=0.0, num_seg_classes=5): - '''Return a noisy 3D image and segmentation.''' - image = np.zeros((width, height, depth)) - - for i in range(num_objs): - x = np.random.randint(rad_max, width - rad_max) - y = np.random.randint(rad_max, height - rad_max) - z = np.random.randint(rad_max, depth - rad_max) - rad = np.random.randint(5, rad_max) - spy, spx, spz = np.ogrid[-x:width - x, -y:height - y, -z:depth - z] 
- circle = (spx * spx + spy * spy + spz * spz) <= rad * rad - - if num_seg_classes > 1: - image[circle] = np.ceil(np.random.random() * num_seg_classes) - else: - image[circle] = np.random.random() * 0.5 + 0.5 - - labels = np.ceil(image).astype(np.int32) - - norm = np.random.uniform(0, num_seg_classes * noise_max, size=image.shape) - noisyimage = utils.arrayutils.rescale_array(np.maximum(image, norm)) - - return noisyimage, labels - - tempdir = tempfile.mkdtemp() for i in range(50): @@ -77,12 +52,12 @@ def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_ma ds = NiftiDataset(images, segs, imtrans, segtrans) -loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) -im, seg = utils.mathutils.first(loader) -print(im.shape, seg.shape) +# loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) +# im, seg = utils.mathutils.first(loader) +# print(im.shape, seg.shape) +train_epochs = 30 lr = 1e-3 - net = networks.nets.UNet( dimensions=3, in_channels=1, @@ -91,12 +66,9 @@ def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_ma strides=(2, 2, 2, 2), num_res_units=2, ) - loss = networks.losses.DiceLoss() opt = torch.optim.Adam(net.parameters(), lr) -train_epochs = 30 - def _loss_fn(i, j): return loss(i[0], j) diff --git a/monai/application/handlers/segmentation_saver.py b/monai/application/handlers/segmentation_saver.py new file mode 100644 index 0000000000..91905103bc --- /dev/null +++ b/monai/application/handlers/segmentation_saver.py @@ -0,0 +1,88 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import nibabel as nib +import torch +from ignite.engine import Events + +from monai.data.writers.niftiwriter import write_nifti + + +class SegmentationSaver: + """ + Event handler triggered on completing every iteration to save the segmentation predictions. + """ + + def __init__(self, output_path='./', dtype='float32', output_postfix='seg', output_ext='.nii.gz'): + self.output_path = output_path + self.dtype = dtype + self.output_postfix = output_postfix + self.output_ext = output_ext + + def attach(self, engine): + return engine.add_event_handler(Events.ITERATION_COMPLETED, self) + + @staticmethod + def _create_file_basename(postfix, input_file_name, folder_path, data_root_dir=""): + """ + Utility function to create the path to the output file based on the input + filename (extension is added by lib level writer before writing the file) + + Args: + postfix (str): output name's postfix + input_file_name (str): path to the input image file + folder_path (str): path for the output file + data_root_dir (str): if not empty, it specifies the beginning parts of the input file's + absolute path. This is used to compute `input_file_rel_path`, the relative path to the file from + `data_root_dir` to preserve folder structure when saving in case there are files in different + folders with the same file names. 
+ """ + + # get the filename and directory + filedir, filename = os.path.split(input_file_name) + + # jettison the extension to have just filename + filename, ext = os.path.splitext(filename) + while ext != "": + filename, ext = os.path.splitext(filename) + + # use data_root_dir to find relative path to file + filedir_rel_path = "" + if data_root_dir: + filedir_rel_path = os.path.relpath(filedir, data_root_dir) + + # sub-folder path will be original name without the extension + subfolder_path = os.path.join(folder_path, filedir_rel_path, filename) + if not os.path.exists(subfolder_path): + os.makedirs(subfolder_path) + + # add the sub-folder plus the postfix name to become the file basename in the output path + return os.path.join(subfolder_path, filename + "_" + postfix) + + def __call__(self, engine): + """ + This method assumes: + - 3rd output of engine.state.batch is a meta data dict + - engine.state.output is already post-processed and ready for saving as a Nifti1Image. + """ + meta_data = engine.state.batch[2] # assuming 3rd output of input dataset is a meta data dict + filenames = meta_data['filename_or_obj'] + for batch_id, filename in enumerate(filenames): # save a batch of files + seg_output = engine.state.output[batch_id] + if isinstance(seg_output, torch.Tensor): + seg_output = seg_output.detach().cpu().numpy() + original_affine = nib.load(filename).affine + output_filename = self._create_file_basename(self.output_postfix, filename, self.output_path) + output_filename = '{}{}'.format(output_filename, self.output_ext) + write_nifti(seg_output, original_affine, output_filename, revert_canonical=True, dtype=seg_output.dtype) + print('saved: {}'.format(output_filename)) diff --git a/monai/data/readers/niftireader.py b/monai/data/readers/niftireader.py index 34622819ca..6dafd9e52f 100644 --- a/monai/data/readers/niftireader.py +++ b/monai/data/readers/niftireader.py @@ -14,6 +14,7 @@ import random from torch.utils.data import Dataset +from torch.utils.data._utils.collate import np_str_obj_array_pattern from monai.utils.moduleutils import export @@ -59,7 +60,7 @@ class NiftiDataset(Dataset): for the image and segmentation arrays separately. """ - def __init__(self, image_files, seg_files, transform=None, seg_transform=None): + def __init__(self, image_files, seg_files, transform=None, seg_transform=None, image_only=True): """ Initializes the dataset with the image and segmentation filename lists. The transform `transform` is applied to the images and `seg_transform` to the segmentations. 
@@ -77,13 +78,18 @@ def __init__(self, image_files, seg_files, transform=None, seg_transform=None): self.image_files = image_files self.seg_files = seg_files self.transform = transform - self.seg_transform = seg_transform + self.seg_transform = seg_transform + self.image_only = image_only def __len__(self): return len(self.image_files) def __getitem__(self, index): - img = load_nifti(self.image_files[index]) + meta_data = None + if self.image_only: + img = load_nifti(self.image_files[index], image_only=self.image_only) + else: + img, meta_data = load_nifti(self.image_files[index], image_only=self.image_only) seg = load_nifti(self.seg_files[index]) # https://github.com/pytorch/vision/issues/9#issuecomment-304224800 @@ -97,4 +103,14 @@ def __getitem__(self, index): random.seed(seed) # ensure randomized transforms roll the same values for segmentations as images seg = self.seg_transform(seg) - return img, seg + if self.image_only or meta_data is None: + return img, seg + + compatible_meta = {} + for meta_key in meta_data: + meta_datum = meta_data[meta_key] + if type(meta_datum).__name__ == 'ndarray' \ + and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None: + continue + compatible_meta[meta_key] = meta_datum + return img, seg, compatible_meta diff --git a/monai/utils/generateddata.py b/monai/utils/generateddata.py index 26904943e1..04edb5f6c4 100644 --- a/monai/utils/generateddata.py +++ b/monai/utils/generateddata.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import numpy as np from monai.utils.arrayutils import rescale_array @@ -28,7 +27,7 @@ def create_test_image_2d(width, height, num_objs=12, rad_max=30, noise_max=0.0, x = np.random.randint(rad_max, width - rad_max) y = np.random.randint(rad_max, height - rad_max) rad = np.random.randint(5, rad_max) - spy, spx = np.ogrid[-x : width - x, -y : height - y] + spy, spx = np.ogrid[-x:width - x, -y:height - y] circle = (spx * spx + spy * spy) <= rad * rad if num_seg_classes > 1: @@ -42,3 +41,32 @@ def create_test_image_2d(width, height, num_objs=12, rad_max=30, noise_max=0.0, noisyimage = rescale_array(np.maximum(image, norm)) return noisyimage, labels + + +def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_max=0.0, num_seg_classes=5): + """ + Return a noisy 3D image and segmentation. + + See also: create_test_image_2d + """ + image = np.zeros((width, height, depth)) + + for i in range(num_objs): + x = np.random.randint(rad_max, width - rad_max) + y = np.random.randint(rad_max, height - rad_max) + z = np.random.randint(rad_max, depth - rad_max) + rad = np.random.randint(5, rad_max) + spy, spx, spz = np.ogrid[-x:width - x, -y:height - y, -z:depth - z] + circle = (spx * spx + spy * spy + spz * spz) <= rad * rad + + if num_seg_classes > 1: + image[circle] = np.ceil(np.random.random() * num_seg_classes) + else: + image[circle] = np.random.random() * 0.5 + 0.5 + + labels = np.ceil(image).astype(np.int32) + + norm = np.random.uniform(0, num_seg_classes * noise_max, size=image.shape) + noisyimage = rescale_array(np.maximum(image, norm)) + + return noisyimage, labels diff --git a/monai/utils/sliding_window_inference.py b/monai/utils/sliding_window_inference.py index 746e9932ba..8041635999 100644 --- a/monai/utils/sliding_window_inference.py +++ b/monai/utils/sliding_window_inference.py @@ -9,20 +9,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- import math -import numpy as np + import torch + from monai.data.transforms import ImageEndPadder +from monai.data.transforms.dataset_transforms import ToTensor def sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device): """Use SlidingWindow method to execute inference. Args: + inputs (numpy array): input image to be processed (assuming NCHW[D]) roi_size (list, tuple): the window size to execute SlidingWindow inference. sw_batch_size (int): the batch size to run window slices. - predictor: a portion from numpy.lib.arraypad.pad is copied below. + predictor: a moani.networks.nets module device: on which device to execute model inference, cpu or gpu. Note: @@ -39,24 +41,16 @@ def sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device) batch_size = inputs.shape[0] # TODO: Enable batch sizes > 1 in future. - assert batch_size == 1, "input batch size must be 1" + if batch_size > 1: + raise NotImplementedError - # in case that image size is smaller than roi size - is_oversized = False original_image_size = [image_size[i] for i in range(num_spatial_dims)] + # in case that image size is smaller than roi size + image_size = tuple(max(image_size[i], roi_size[i]) for i in range(num_spatial_dims)) + inputs = ImageEndPadder(roi_size, 'constant')(inputs) # in np array + inputs = ToTensor()(inputs) - for i in range(num_spatial_dims): - if int(roi_size[i]) > image_size[i]: - is_oversized = True - break - - if is_oversized: - for i in range(num_spatial_dims): - image_size[i] = max(image_size[i], roi_size[i]) - - padder = ImageEndPadder(roi_size, 'constant') - inputs = padder(inputs) - + # TODO: interval from user's specification scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims) scan_num = [int(math.ceil(float(image_size[i]) / scan_interval[i])) for i in range(num_spatial_dims)] @@ -101,54 +95,41 @@ def sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device) else: slice_i, slice_j = slices[curr_index] input_slices.append(inputs[0, :, slice_i, slice_j]) - buffered_requests.append(np.stack(input_slices)) + buffered_requests.append(torch.stack(input_slices)) # Perform predictions output_rois = list() for data in buffered_requests: - output_rois.append(predictor(data)) + seg_prob, _ = predictor(data) # segmentation probabilities + output_rois.append(seg_prob) + # stitching output image output_classes = output_rois[0].shape[1] output_shape = [batch_size, output_classes] + list(image_size) # allocate memory to store the full output and the count for overlapping parts - output_dict = torch.zeros(output_shape, dtype=torch.float32, device=device) - count_dict = torch.zeros(output_shape, dtype=torch.float32, device=device) + output_image = torch.zeros(output_shape, dtype=torch.float32, device=device) + count_map = torch.zeros(output_shape, dtype=torch.float32, device=device) - window_index = 0 - for slice_index in range(0, len(slices), sw_batch_size): + for window_id, slice_index in enumerate(range(0, len(slices), sw_batch_size)): slice_index_range = range(slice_index, min(slice_index + sw_batch_size, len(slices))) - output_roi = output_rois[window_index] - window_index += 1 - # store the result in the proper location of the full output for curr_index in slice_index_range: if num_spatial_dims == 3: slice_i, slice_j, slice_k = slices[curr_index] - output_dict[0, :, slice_i, slice_j, slice_k] += \ - output_roi[curr_index - slice_index, :] - count_dict[0, :, slice_i, slice_j, slice_k] += 1 + output_image[0, :, slice_i, slice_j, 
slice_k] += output_rois[window_id][curr_index - slice_index, :] + count_map[0, :, slice_i, slice_j, slice_k] += 1. else: slice_i, slice_j = slices[curr_index] - output_dict[0, :, slice_i, slice_j] += \ - output_roi[curr_index - slice_index, :] - count_dict[0, :, slice_i, slice_j] += 1 + output_image[0, :, slice_i, slice_j] += output_rois[window_id][curr_index - slice_index, :] + count_map[0, :, slice_i, slice_j] += 1. # account for any overlapping sections - output_dict /= count_dict - - # in case that image size is smaller than roi size - if is_oversized: - new_output_dict = list() - if num_spatial_dims == 3: - new_output_dict = output_dict[:, :, :original_image_size[0], - :original_image_size[1], :original_image_size[2]] - else: - new_output_dict = output_dict[:, :, :original_image_size[0], :original_image_size[1]] + output_image /= count_map - output_dict = new_output_dict - - return output_dict + if num_spatial_dims == 3: + return output_image[..., :original_image_size[0], :original_image_size[1], :original_image_size[2]] + return output_image[..., :original_image_size[0], :original_image_size[1]] # 2D def _get_scan_interval(image_size, roi_size, num_spatial_dims): diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py index 2ce9e9da07..a3f659aeaf 100644 --- a/tests/test_sliding_window_inference.py +++ b/tests/test_sliding_window_inference.py @@ -10,27 +10,36 @@ # limitations under the License. import unittest -import torch + import numpy as np +import torch +from parameterized import parameterized from monai.utils.sliding_window_inference import sliding_window_inference +TEST_CASE_1 = [(1, 3, 16, 15, 7), (4, 10, 7), 3] # 3D small roi + +TEST_CASE_2 = [(1, 3, 16, 15, 7), (20, 22, 23), 10] # 3D large roi + +TEST_CASE_3 = [(1, 3, 15, 7), (2, 6), 1000] # 2D small roi, large batch + +TEST_CASE_4 = [(1, 3, 16, 7), (80, 50), 7] # 2D large roi + class TestSlidingWindowInference(unittest.TestCase): - def test_sliding_window_default(self): - inputs = np.ones((1, 3, 16, 16, 8)) - roi_size = [4, 4, 4] - sw_batch_size = 4 - device = torch.device("cuda:0") + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + def test_sliding_window_default(self, image_shape, roi_shape, sw_batch_size): + inputs = np.ones(image_shape) + device = torch.device("cpu:0") def compute(data): - data = torch.from_numpy(data) - return data.to(device) + 1 + # data = torch.from_numpy(data) + return data.to(device) + 1, None # to be consistent with monai.networks.nets.unet.UNet - result = sliding_window_inference(inputs, roi_size, sw_batch_size, compute, device) - expected_val = torch.ones((1, 3, 16, 16, 8), dtype=torch.float32, device=device) + 1 - self.assertAlmostEqual(result.shape, expected_val.shape) + result = sliding_window_inference(inputs, roi_shape, sw_batch_size, compute, device) + expected_val = np.ones(image_shape, dtype=np.float32) + 1 + self.assertTrue(np.allclose(result.numpy(), expected_val)) if __name__ == '__main__': From 775715d1c74816a7af7f54d574c0cdf132410874 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 7 Feb 2020 10:56:08 +0800 Subject: [PATCH 04/11] [DLMED] add checkpoint loader as ignite event-handler --- .../application/handlers/checkpoint_loader.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 monai/application/handlers/checkpoint_loader.py diff --git a/monai/application/handlers/checkpoint_loader.py b/monai/application/handlers/checkpoint_loader.py new file mode 100644 index 0000000000..1304044b1c --- 
/dev/null
+++ b/monai/application/handlers/checkpoint_loader.py
@@ -0,0 +1,44 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from ignite.handlers import Checkpoint
+from ignite.engine import Events
+import monai
+
+
+@monai.utils.export("monai.application.handlers")
+@monai.utils.alias("CheckpointLoader")
+class CheckpointLoader:
+    """CheckpointLoader acts as an ignite handler to load checkpoint data from file.
+    It can load variables for network, optimizer, lr_scheduler.
+    It can also restore training if the state_dict of the ignite engine is loaded.
+
+    Args:
+        load_path (string): the file path of checkpint, it shoule be a PyTorch pth file.
+        load_dict (dict): target objects that load checkpoint to. examples:
+            {'network': net, 'optimizer': optimizer, 'engine': engine}
+
+    """
+    def __init__(self, load_path, load_dict):
+        assert load_path is not None, 'must provide clear path to load checkpoint.'
+        self.load_path = load_path
+        assert load_dict is not None and len(load_dict) > 0, 'must provide target objects to load.'
+        self.load_dict = load_dict
+
+    def attach(self, engine):
+        return engine.add_event_handler(Events.STARTED, self)
+
+    def __call__(self, engine):
+        checkpoint = torch.load(self.load_path)
+        Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint)
+        print('Restored all variables from {}'.format(self.load_path))
From 17f574faf5c45c344ccd8cb6fa0d9ed3a1195155 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Fri, 7 Feb 2020 10:03:53 +0000
Subject: [PATCH 05/11] added iter_dense_patch_slices, rename buffered_slices
 to slice_batches, fixes typos

---
 monai/utils/arrayutils.py               | 61 +++++++++++++++++++++++--
 monai/utils/sliding_window_inference.py | 45 ++++--------------
 2 files changed, 64 insertions(+), 42 deletions(-)

diff --git a/monai/utils/arrayutils.py b/monai/utils/arrayutils.py
index 79cf96ffb5..01af572bba 100644
--- a/monai/utils/arrayutils.py
+++ b/monai/utils/arrayutils.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import math
 import random
 from itertools import product, starmap
 
@@ -193,8 +193,8 @@ def get_random_patch(dims, patch_size):
 
 def iter_patch_slices(dims, patch_size, start_pos=()):
     """
-    Yield successive tuples of slices defining patches of size `patch_size` from an array of dimensions `dims`. The 
-    iteration starts from position `start_pos` in the array, or starting at the origin if this isn't provided. Each 
+    Yield successive tuples of slices defining patches of size `patch_size` from an array of dimensions `dims`. The
+    iteration starts from position `start_pos` in the array, or starting at the origin if this isn't provided. Each
     patch is chosen in a contiguous grid using a first dimension as least significant ordering.
 
     Args:
@@ -219,10 +219,61 @@ def iter_patch_slices(dims, patch_size, start_pos=()):
     yield tuple(slice(s, s + p) for s, p in zip(position[::-1], patch_size))
 
 
+def iter_dense_patch_slices(image_size, patch_size, scan_interval):
+    """
+    Enumerate all slices defining 2D/3D patches of size `patch_size` from an `image_size` input image.
+
+    Args:
+        image_size (tuple of int): dimensions of image to iterate over
+        patch_size (tuple of int): size of patches to generate slices
+        scan_interval (tuple of int): dense patch sampling interval
+
+    Returns:
+        a list of slice objects defining each patch
+    """
+    num_spatial_dims = len(image_size)
+    if num_spatial_dims not in (2, 3):
+        raise ValueError('image_size should have 2 or 3 elements')
+    patch_size = get_valid_patch_size(image_size, patch_size)
+    scan_interval = ensure_tuple_size(scan_interval, num_spatial_dims)
+
+    scan_num = [int(math.ceil(float(image_size[i]) / scan_interval[i])) if scan_interval[i] != 0 else 1
+                for i in range(num_spatial_dims)]
+    slices = []
+    if num_spatial_dims == 3:
+        for i in range(scan_num[0]):
+            start_i = i * scan_interval[0]
+            start_i -= max(start_i + patch_size[0] - image_size[0], 0)  # clamp so the last patch stays inside the image
+            slice_i = slice(start_i, start_i + patch_size[0])
+
+            for j in range(scan_num[1]):
+                start_j = j * scan_interval[1]
+                start_j -= max(start_j + patch_size[1] - image_size[1], 0)
+                slice_j = slice(start_j, start_j + patch_size[1])
+
+                for k in range(0, scan_num[2]):
+                    start_k = k * scan_interval[2]
+                    start_k -= max(start_k + patch_size[2] - image_size[2], 0)
+                    slice_k = slice(start_k, start_k + patch_size[2])
+                    slices.append((slice_i, slice_j, slice_k))
+    else:
+        for i in range(scan_num[0]):
+            start_i = i * scan_interval[0]
+            start_i -= max(start_i + patch_size[0] - image_size[0], 0)  # same edge clamping as the 3D branch
+            slice_i = slice(start_i, start_i + patch_size[0])
+
+            for j in range(scan_num[1]):
+                start_j = j * scan_interval[1]
+                start_j -= max(start_j + patch_size[1] - image_size[1], 0)
+                slice_j = slice(start_j, start_j + patch_size[1])
+                slices.append((slice_i, slice_j))
+    return slices
+
+
 def iter_patch(arr, patch_size, start_pos=(), copy_back=True, pad_mode="wrap", **pad_opts):
     """
-    Yield successive patches from `arr' of size `patchSize'. The iteration can start from position `startPos' in `arr' 
-    but drawing from a padded array extended by the `patchSize' in each dimension (so these coordinates can be negative 
+    Yield successive patches from `arr' of size `patchSize'. The iteration can start from position `startPos' in `arr'
+    but drawing from a padded array extended by the `patchSize' in each dimension (so these coordinates can be negative
     to start in the padded region). If `copyBack' is True the values from each patch are written back to `arr'.
 
     Args:
diff --git a/monai/utils/sliding_window_inference.py b/monai/utils/sliding_window_inference.py
index 8041635999..243067acf4 100644
--- a/monai/utils/sliding_window_inference.py
+++ b/monai/utils/sliding_window_inference.py
@@ -9,12 +9,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import math - import torch from monai.data.transforms import ImageEndPadder from monai.data.transforms.dataset_transforms import ToTensor +from monai.utils.arrayutils import iter_dense_patch_slices def sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device): @@ -24,7 +23,7 @@ def sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device) inputs (numpy array): input image to be processed (assuming NCHW[D]) roi_size (list, tuple): the window size to execute SlidingWindow inference. sw_batch_size (int): the batch size to run window slices. - predictor: a moani.networks.nets module + predictor: a monai.networks.nets module device: on which device to execute model inference, cpu or gpu. Note: @@ -52,39 +51,11 @@ def sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device) # TODO: interval from user's specification scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims) - scan_num = [int(math.ceil(float(image_size[i]) / scan_interval[i])) for i in range(num_spatial_dims)] # Store all slices in list. - slices = [] - if num_spatial_dims == 3: - for i in range(scan_num[0]): - start_i = i * scan_interval[0] - start_i -= max(start_i + roi_size[0] - image_size[0], 0) - slice_i = slice(start_i, start_i + roi_size[0]) - - for j in range(scan_num[1]): - start_j = j * scan_interval[1] - start_j -= max(start_j + roi_size[1] - image_size[1], 0) - slice_j = slice(start_j, start_j + roi_size[1]) - - for k in range(0, scan_num[2]): - start_k = k * scan_interval[2] - start_k -= max(start_k + roi_size[2] - image_size[2], 0) - slice_k = slice(start_k, start_k + roi_size[2]) - slices.append((slice_i, slice_j, slice_k)) - else: - for i in range(scan_num[0]): - start_i = i * scan_interval[0] - start_i -= max(start_i + roi_size[0] - image_size[0], 0) - slice_i = slice(start_i, start_i + roi_size[0]) - - for j in range(scan_num[1]): - start_j = j * scan_interval[1] - start_j -= max(start_j + roi_size[1] - image_size[1], 0) - slice_j = slice(start_j, start_j + roi_size[1]) - slices.append((slice_i, slice_j)) - - buffered_requests = [] + slices = iter_dense_patch_slices(image_size, roi_size, scan_interval) + + slice_batches = [] for slice_index in range(0, len(slices), sw_batch_size): slice_index_range = range(slice_index, min(slice_index + sw_batch_size, len(slices))) input_slices = [] @@ -95,11 +66,11 @@ def sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device) else: slice_i, slice_j = slices[curr_index] input_slices.append(inputs[0, :, slice_i, slice_j]) - buffered_requests.append(torch.stack(input_slices)) + slice_batches.append(torch.stack(input_slices)) # Perform predictions output_rois = list() - for data in buffered_requests: + for data in slice_batches: seg_prob, _ = predictor(data) # segmentation probabilities output_rois.append(seg_prob) @@ -143,4 +114,4 @@ def _get_scan_interval(image_size, roi_size, num_spatial_dims): else: # this means that it's r-16 (if r>=64) and r*0.75 (if r<=64) scan_interval[i] = int(max(roi_size[i] - 16, roi_size[i] * 0.75)) - return scan_interval + return tuple(scan_interval) From 78e17cafe93f15d7f3bbe2533fe6d94af567062d Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 7 Feb 2020 11:25:32 +0000 Subject: [PATCH 06/11] include checkpoint loader handler --- .gitignore | 1 + examples/unet_inference_3d.py | 8 +++++--- examples/unet_segmentation_3d.py | 13 +++++++------ monai/application/handlers/checkpoint_loader.py | 10 ++++++---- 4 files changed, 19 insertions(+), 13 deletions(-) 
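For context, the wiring this patch completes in the example reduces to the sketch below. It is a minimal illustration using the handler APIs introduced in this series; `step_fn`, `net`, `output_dir` and `loader` are placeholders for the caller's own objects, and the dataset is assumed to yield `(img, seg, meta_data)` triples so the saver can read filenames from `batch[2]`:

    from ignite.engine import Engine
    from monai.application.handlers.checkpoint_loader import CheckpointLoader
    from monai.application.handlers.segmentation_saver import SegmentationSaver

    infer_engine = Engine(step_fn)  # step_fn(engine, batch) -> post-processed predictions
    # restore {'net': net} from the pth file once, on Events.STARTED
    CheckpointLoader(load_path='./net_checkpoint_9.pth', load_dict={'net': net}).attach(infer_engine)
    # write one '<input name>_seg.nii.gz' per batch item, on Events.ITERATION_COMPLETED
    SegmentationSaver(output_path=output_dir).attach(infer_engine)
    infer_engine.run(loader)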
diff --git a/.gitignore b/.gitignore
index c30f242fd2..ca4a13ee25 100644
--- a/.gitignore
+++ b/.gitignore
@@ -103,4 +103,5 @@ venv.bak/
 # mypy
 .mypy_cache/
 examples/scd_lvsegs.npz
+.temp/
 .idea/
diff --git a/examples/unet_inference_3d.py b/examples/unet_inference_3d.py
index 6c28d7506e..98dce86d5b 100644
--- a/examples/unet_inference_3d.py
+++ b/examples/unet_inference_3d.py
@@ -22,6 +22,7 @@
 from torch.utils.data import DataLoader
 
 from monai import application
+from monai.application.handlers.checkpoint_loader import CheckpointLoader
 from monai.application.handlers.segmentation_saver import SegmentationSaver
 from monai.data.readers import NiftiDataset
 from monai.data.transforms import AddChannel, Rescale, ToTensor
@@ -33,7 +34,7 @@
 sys.path.append("..")  # assumes the framework is found here, change as necessary
 application.config.print_config()
 tempdir = tempfile.mkdtemp()
-# tempdir = './'
+tempdir = './temp'
 for i in range(50):
     im, seg = create_test_image_3d(256, 256, 256)
 
@@ -76,7 +77,8 @@ def _sliding_window_processor(_engine, batch):
 # checkpoint_handler = ModelCheckpoint('./', 'net', n_saved=10, save_interval=3, require_empty=False)
 # infer_engine.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net})
 
-SegmentationSaver().attach(infer_engine)
+SegmentationSaver(output_path=tempdir, output_ext='.nii.gz', output_postfix='seg').attach(infer_engine)
+CheckpointLoader(load_path='./net_checkpoint_9.pth', load_dict={'net': net}).attach(infer_engine)
 
-loader = DataLoader(ds, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
+loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
 state = infer_engine.run(loader)
diff --git a/examples/unet_segmentation_3d.py b/examples/unet_segmentation_3d.py
index fc2fe0d5ae..ff9dde8851 100644
--- a/examples/unet_segmentation_3d.py
+++ b/examples/unet_segmentation_3d.py
@@ -24,8 +24,8 @@
 
 from monai import application, networks
 from monai.data.readers import NiftiDataset
-from monai.utils.generateddata import create_test_image_3d
 from monai.data.transforms import (AddChannel, Rescale, ToTensor, UniformRandomPatch)
+from monai.utils.generateddata import create_test_image_3d
 
 # assumes the framework is found here, change as necessary
 sys.path.append("..")
@@ -33,7 +33,7 @@
 application.config.print_config()
 
 tempdir = tempfile.mkdtemp()
-
+tempdir = './temp'
 for i in range(50):
 
     im, seg = create_test_image_3d(256, 256, 256)
@@ -73,12 +73,14 @@ def _loss_fn(i, j):
     return loss(i[0], j)
 
 
-device = torch.device("cuda:0")
+device = torch.device("cpu:0")
 
 trainer = create_supervised_trainer(net, opt, _loss_fn, device, False)
 
-checkpoint_handler = ModelCheckpoint('./', 'net', n_saved=10, save_interval=3, require_empty=False)
-trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net})
+checkpoint_handler = ModelCheckpoint('./', 'net', n_saved=10, require_empty=False)
+trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED(every=3),
+                          handler=checkpoint_handler,
+                          to_save={'net': net, 'opt': opt})
 
 
 @trainer.on(Events.EPOCH_COMPLETED)
diff --git a/monai/application/handlers/checkpoint_loader.py b/monai/application/handlers/checkpoint_loader.py
index 1304044b1c..dc2d2df5b2 100644
--- a/monai/application/handlers/checkpoint_loader.py
+++ b/monai/application/handlers/checkpoint_loader.py
@@ -9,26 +9,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import torch
-from ignite.handlers import Checkpoint
 from ignite.engine import Events
+from ignite.handlers import Checkpoint
+
 import monai
 
 
 @monai.utils.export("monai.application.handlers")
 @monai.utils.alias("CheckpointLoader")
 class CheckpointLoader:
-    """CheckpointLoader acts as an ignite handler to load checkpoint data from file.
+    """
+    CheckpointLoader acts as an ignite handler to load checkpoint data from file.
     It can load variables for network, optimizer, lr_scheduler.
     It can also restore training if the state_dict of the ignite engine is loaded.
 
     Args:
-        load_path (string): the file path of checkpint, it shoule be a PyTorch pth file.
+        load_path (string): the file path of checkpoint, it should be a PyTorch pth file.
         load_dict (dict): target objects that load checkpoint to. examples:
            {'network': net, 'optimizer': optimizer, 'engine': engine}
 
     """
+
     def __init__(self, load_path, load_dict):
From 8427da6d0c7094e5eed9b689c2f3f57f408ebecc Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Fri, 7 Feb 2020 12:17:41 +0000
Subject: [PATCH 07/11] renaming iter_dense_patch_slices to dense_patch_slices

---
 monai/utils/arrayutils.py               | 2 +-
 monai/utils/sliding_window_inference.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/monai/utils/arrayutils.py b/monai/utils/arrayutils.py
index 01af572bba..8d9c19bb0b 100644
--- a/monai/utils/arrayutils.py
+++ b/monai/utils/arrayutils.py
@@ -219,7 +219,7 @@ def iter_patch_slices(dims, patch_size, start_pos=()):
     yield tuple(slice(s, s + p) for s, p in zip(position[::-1], patch_size))
 
 
-def iter_dense_patch_slices(image_size, patch_size, scan_interval):
+def dense_patch_slices(image_size, patch_size, scan_interval):
     """
     Enumerate all slices defining 2D/3D patches of size `patch_size` from an `image_size` input image.
 
diff --git a/monai/utils/sliding_window_inference.py b/monai/utils/sliding_window_inference.py
index 243067acf4..b9698b9b4e 100644
--- a/monai/utils/sliding_window_inference.py
+++ b/monai/utils/sliding_window_inference.py
@@ -13,7 +13,7 @@
 
 from monai.data.transforms import ImageEndPadder
 from monai.data.transforms.dataset_transforms import ToTensor
-from monai.utils.arrayutils import iter_dense_patch_slices
+from monai.utils.arrayutils import dense_patch_slices
 
 
 def sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device):
@@ -53,7 +53,7 @@ def sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device)
     scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims)
 
     # Store all slices in list.
-    slices = iter_dense_patch_slices(image_size, roi_size, scan_interval)
+    slices = dense_patch_slices(image_size, roi_size, scan_interval)
From 2401e2195ba8859abb8b11386910d4af0a6813dd Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Wed, 12 Feb 2020 18:24:15 +0000
Subject: [PATCH 08/11] updates interfaces of sliding window predictor and seg
 saver

---
 examples/unet_inference_3d.py                    |  2 +-
 monai/application/handlers/segmentation_saver.py | 18 ++++++++++++++++--
 monai/data/writers/niftiwriter.py                |  2 +-
 monai/utils/sliding_window_inference.py          |  6 ++++--
 4 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/examples/unet_inference_3d.py b/examples/unet_inference_3d.py
index 98dce86d5b..e28b0d1427 100644
--- a/examples/unet_inference_3d.py
+++ b/examples/unet_inference_3d.py
@@ -68,7 +68,7 @@ def _sliding_window_processor(_engine, batch):
     net.eval()
     img, seg, meta_data = batch
     with torch.no_grad():
-        seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, net, device)
+        seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, lambda x: net(x)[0], device)
         return predict_segmentation(seg_probs)
 
 
diff --git a/monai/application/handlers/segmentation_saver.py b/monai/application/handlers/segmentation_saver.py
index 91905103bc..366d92786b 100644
--- a/monai/application/handlers/segmentation_saver.py
+++ b/monai/application/handlers/segmentation_saver.py
@@ -23,11 +23,24 @@ class SegmentationSaver:
     Event handler triggered on completing every iteration to save the segmentation predictions.
     """
 
-    def __init__(self, output_path='./', dtype='float32', output_postfix='seg', output_ext='.nii.gz'):
+    def __init__(self, output_path='./', dtype='float32', output_postfix='seg', output_ext='.nii.gz',
+                 output_transform=lambda x: x):
+        """
+        Args:
+            output_path (str): output image directory.
+            dtype (str): convert the saved image to this data type.
+            output_postfix (str): a string appended to all output file names.
+            output_ext (str): output file extension name.
+            output_transform (Callable): a callable that is used to transform the
+                ignite.engine.output into the expected nifti image data.
+                The first dimension of this transform's output will be treated as the
+                batch dimension. Each item in the batch will be saved individually.
+        """
         self.output_path = output_path
         self.dtype = dtype
         self.output_postfix = output_postfix
         self.output_ext = output_ext
+        self.output_transform = output_transform
 
     def attach(self, engine):
         return engine.add_event_handler(Events.ITERATION_COMPLETED, self)
@@ -77,8 +90,9 @@ def __call__(self, engine):
         """
         meta_data = engine.state.batch[2]  # assuming 3rd output of input dataset is a meta data dict
         filenames = meta_data['filename_or_obj']
+        engine_output = self.output_transform(engine.state.output)
         for batch_id, filename in enumerate(filenames):  # save a batch of files
-            seg_output = engine.state.output[batch_id]
+            seg_output = engine_output[batch_id]
             if isinstance(seg_output, torch.Tensor):
                 seg_output = seg_output.detach().cpu().numpy()
             original_affine = nib.load(filename).affine
diff --git a/monai/data/writers/niftiwriter.py b/monai/data/writers/niftiwriter.py
index 9ba9e52c01..6a97b42c46 100644
--- a/monai/data/writers/niftiwriter.py
+++ b/monai/data/writers/niftiwriter.py
@@ -21,7 +21,7 @@ def write_nifti(data, affine, file_name, revert_canonical, dtype="float32"):
         affine (numpy.ndarray): affine information for the data.
         file_name (string): expected file name that saved on disk.
revert_canonical (bool): whether to revert canonical. - dtype (np.dtype, optional): convert the loaded image to this data type. + dtype (np.dtype, optional): convert the image to save to this data type. """ assert isinstance(data, np.ndarray), 'input data must be numpy array.' diff --git a/monai/utils/sliding_window_inference.py b/monai/utils/sliding_window_inference.py index b9698b9b4e..c81303a8cb 100644 --- a/monai/utils/sliding_window_inference.py +++ b/monai/utils/sliding_window_inference.py @@ -23,7 +23,9 @@ def sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device) inputs (numpy array): input image to be processed (assuming NCHW[D]) roi_size (list, tuple): the window size to execute SlidingWindow inference. sw_batch_size (int): the batch size to run window slices. - predictor: a monai.networks.nets module + predictor (Callable): given input tensor `patch_data` in shape NCHW[D], `predictor(patch_data)` + should return a prediction with the same spatial shape and batch_size, i.e. NMHW[D]; + where HW[D] represents the patch spatial size, M is the number of output channels, N is `sw_batch_size`. device: on which device to execute model inference, cpu or gpu. Note: @@ -71,7 +73,7 @@ def sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device) # Perform predictions output_rois = list() for data in slice_batches: - seg_prob, _ = predictor(data) # segmentation probabilities + seg_prob = predictor(data) # batched patch segmentation output_rois.append(seg_prob) # stitching output image From 3925c2d791bc3daf79221c78ca967ca2d1fddcff Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 12 Feb 2020 18:30:37 +0000 Subject: [PATCH 09/11] fixes sliding window unit test --- tests/test_sliding_window_inference.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py index a3f659aeaf..e0d727c407 100644 --- a/tests/test_sliding_window_inference.py +++ b/tests/test_sliding_window_inference.py @@ -34,8 +34,7 @@ def test_sliding_window_default(self, image_shape, roi_shape, sw_batch_size): device = torch.device("cpu:0") def compute(data): - # data = torch.from_numpy(data) - return data.to(device) + 1, None # to be consistent with monai.networks.nets.unet.UNet + return data.to(device) + 1 result = sliding_window_inference(inputs, roi_shape, sw_batch_size, compute, device) expected_val = np.ones(image_shape, dtype=np.float32) + 1 From 064e6b32e26c8ea34cd02ee35a3f0af971769296 Mon Sep 17 00:00:00 2001 From: Mohammad Adil Date: Thu, 13 Feb 2020 02:58:24 -0800 Subject: [PATCH 10/11] Fix canonical transformation bug in nifti writer. (#74) * Fix canonical transformation bug in nifti writer. 
* fixes nifti reader/writer's as_closest_canonical option, adds unit tests

Co-authored-by: Wenqi Li
---
 monai/data/readers/niftireader.py | 19 ++++++---
 monai/data/writers/niftiwriter.py | 19 +++++----
 tests/test_nifti_rw.py            | 65 +++++++++++++++++++++++++++++++
 tests/utils.py                    | 14 +++++++
 4 files changed, 103 insertions(+), 14 deletions(-)
 create mode 100644 tests/test_nifti_rw.py

diff --git a/monai/data/readers/niftireader.py b/monai/data/readers/niftireader.py
index 6dafd9e52f..1898411694 100644
--- a/monai/data/readers/niftireader.py
+++ b/monai/data/readers/niftireader.py
@@ -32,25 +32,32 @@ def load_nifti(filename_or_obj, as_closest_canonical=False, image_only=True, dty
    Returns:
        The loaded image volume if `image_only` is True, or a tuple containing the volume
        and the Nifti header in dict format otherwise
+
+    Note:
+        header['original_affine'] stores the original affine loaded from `filename_or_obj`.
+        header['affine'] stores the affine after the optional `as_closest_canonical` transform.
    """

    img = nib.load(filename_or_obj)

+    header = dict(img.header)
+    header['filename_or_obj'] = filename_or_obj
+    header['original_affine'] = img.affine
+    header['affine'] = img.affine
+    header['as_closest_canonical'] = as_closest_canonical
+
    if as_closest_canonical:
        img = nib.as_closest_canonical(img)
+        header['affine'] = img.affine

    if dtype is not None:
        dat = img.get_fdata(dtype=dtype)
    else:
        dat = np.asanyarray(img.dataobj)

-    header = dict(img.header)
-    header['filename_or_obj'] = filename_or_obj
-
    if image_only:
        return dat
-    else:
-        return dat, header
+    return dat, header

@@ -77,7 +84,7 @@ def __init__(self, image_files, seg_files, transform=None, seg_transform=None, i
        self.image_files = image_files
        self.seg_files = seg_files
-        self.transform = transform
+        self.transform = transform
        self.seg_transform = seg_transform
        self.image_only = image_only

diff --git a/monai/data/writers/niftiwriter.py b/monai/data/writers/niftiwriter.py
index 6a97b42c46..44e454d91a 100644
--- a/monai/data/writers/niftiwriter.py
+++ b/monai/data/writers/niftiwriter.py
@@ -13,26 +13,29 @@
 import nibabel as nib


-def write_nifti(data, affine, file_name, revert_canonical, dtype="float32"):
+def write_nifti(data, affine, file_name, target_affine=None, dtype="float32"):
    """Write numpy data into nifti files to disk.

    Args:
        data (numpy.ndarray): input data to write to file.
        affine (numpy.ndarray): affine information for the data.
        file_name (string): expected file name to be saved on disk.
-        revert_canonical (bool): whether to revert canonical.
+        target_affine (numpy.ndarray, optional): before saving, the (data, affine) is
+            transformed into the orientation defined by `target_affine`.
        dtype (np.dtype, optional): convert the image to this data type before saving.
    """
    assert isinstance(data, np.ndarray), 'input data must be numpy array.'
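    # A concrete illustration of the reorientation performed below (values follow
    # nibabel's conventions): io_orientation(np.eye(4)) yields [[0, 1], [1, 1], [2, 1]],
    # while a target affine of np.diag([-1, 1, 1, 1]) yields [[0, -1], [1, 1], [2, 1]];
    # ornt_transform between the two returns [[0, -1], [1, 1], [2, 1]], i.e. 'flip the
    # first axis', which apply_orientation then carries out on the data array.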
    if affine is None:
        affine = np.eye(4)

-    if revert_canonical:
-        codes = nib.orientations.axcodes2ornt(nib.orientations.aff2axcodes(np.linalg.inv(affine)))
-        reverted_results = nib.orientations.apply_orientation(np.squeeze(data), codes)
-        results_img = nib.Nifti1Image(reverted_results.astype(dtype), affine)
+    if target_affine is None:
+        results_img = nib.Nifti1Image(data.astype(dtype), affine)
    else:
-        results_img = nib.Nifti1Image(np.squeeze(data).astype(dtype), np.squeeze(affine))
+        start_ornt = nib.orientations.io_orientation(affine)
+        target_ornt = nib.orientations.io_orientation(target_affine)
+        ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt)
+
+        reverted_results = nib.orientations.apply_orientation(data, ornt_transform)
+        results_img = nib.Nifti1Image(reverted_results.astype(dtype), target_affine)

    nib.save(results_img, file_name)

diff --git a/tests/test_nifti_rw.py b/tests/test_nifti_rw.py
new file mode 100644
index 0000000000..40ecb1fba7
--- /dev/null
+++ b/tests/test_nifti_rw.py
@@ -0,0 +1,65 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+import nibabel as nib
+import numpy as np
+from parameterized import parameterized
+
+from monai.data.readers.niftireader import load_nifti
+from monai.data.writers.niftiwriter import write_nifti
+
+from .utils import make_nifti_image
+
+TEST_IMAGE = np.zeros((1, 2, 3))
+TEST_AFFINE = np.array([[-5.3, 0., 0., 102.01], [0., 0.52, 2.17, -7.50], [-0., 1.98, -0.26, -23.12], [0., 0., 0., 1.]])
+
+TEST_CASE_1 = [TEST_IMAGE, TEST_AFFINE, (1, 2, 3), dict(as_closest_canonical=True, image_only=False)]
+TEST_CASE_2 = [TEST_IMAGE, TEST_AFFINE, (1, 3, 2), dict(as_closest_canonical=True, image_only=True)]
+TEST_CASE_3 = [TEST_IMAGE, TEST_AFFINE, (1, 2, 3), dict(as_closest_canonical=False, image_only=True)]
+TEST_CASE_4 = [TEST_IMAGE, TEST_AFFINE, (1, 2, 3), dict(as_closest_canonical=False, image_only=False)]
+
+
+class TestNiftiLoadRead(unittest.TestCase):
+
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
+    def test_orientation(self, array, affine, expected_shape, reader_param):
+        test_image = make_nifti_image(array, affine)
+
+        # read test cases
+        load_result = load_nifti(test_image, **reader_param)
+        if isinstance(load_result, tuple):
+            data_array, header = load_result
+        else:
+            data_array = load_result
+            header = None
+        if os.path.exists(test_image):
+            os.remove(test_image)
+
+        # write test cases
+        if header is not None:
+            write_nifti(data_array, header['affine'], test_image, header['original_affine'])
+        else:
+            write_nifti(data_array, affine, test_image)
+        saved = nib.load(test_image)
+        saved_affine = saved.affine
+        saved_shape = saved.get_fdata().shape
+        if os.path.exists(test_image):
+            os.remove(test_image)
+
+        self.assertTrue(np.allclose(saved_affine, affine))
+        self.assertTrue(np.allclose(saved_shape, expected_shape))
+
+
+if __name__ == '__main__':
+    unittest.main()

diff --git a/tests/utils.py b/tests/utils.py
index c2b6de1580..e356b8b7b3 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -11,6 +11,8 @@
 import os
 import unittest
+import tempfile
+import nibabel as nib
 import numpy as np
 import torch
@@ -26,6 +28,18 @@ def skip_if_quick(obj):
    return unittest.skipIf(is_quick, "Skipping slow tests")(obj)


+def make_nifti_image(array, affine):
+    """
+    Create a temporary nifti image on the disk and return the image name.
+    User is responsible for deleting the temporary file when done with it.
+    """
+    test_image = nib.Nifti1Image(array, affine)
+
+    _, image_name = tempfile.mkstemp(suffix='.nii.gz')
+    nib.save(test_image, image_name)
+    return image_name
+
+
 class NumpyImageTestCase2D(unittest.TestCase):
    im_shape = (128, 128)
    input_channels = 1

From 3f4968b723fa82cebb2618e73f183520d4967b4a Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Thu, 13 Feb 2020 11:13:38 +0000
Subject: [PATCH 11/11] resolves image writer issues

---
 examples/unet_inference_3d.py                    |  2 +-
 monai/application/handlers/segmentation_saver.py | 14 ++++++++------
 monai/data/readers/niftireader.py                | 14 +++++++++++---
 3 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/examples/unet_inference_3d.py b/examples/unet_inference_3d.py
index e28b0d1427..871d5bde94 100644
--- a/examples/unet_inference_3d.py
+++ b/examples/unet_inference_3d.py
@@ -34,7 +34,7 @@
 sys.path.append("..")  # assumes the framework is found here, change as necessary
 application.config.print_config()

 tempdir = tempfile.mkdtemp()
-tempdir = './temp'
+# tempdir = './temp'

 for i in range(50):
    im, seg = create_test_image_3d(256, 256, 256)

diff --git a/monai/application/handlers/segmentation_saver.py b/monai/application/handlers/segmentation_saver.py
index 366d92786b..fb3b2a56e7 100644
--- a/monai/application/handlers/segmentation_saver.py
+++ b/monai/application/handlers/segmentation_saver.py
@@ -11,7 +11,6 @@

 import os

-import nibabel as nib
 import torch
 from ignite.engine import Events

@@ -20,7 +19,7 @@

 class SegmentationSaver:
    """
-    Event handler triggered on completing every iteration to save the segmentation predictions.
+    Event handler triggered on completing every iteration to save the segmentation predictions as nifti files.
    """

    def __init__(self, output_path='./', dtype='float32', output_postfix='seg', output_ext='.nii.gz',
                 output_transform=lambda x: x):
@@ -85,18 +84,21 @@ def _create_file_basename(postfix, input_file_name, folder_path, data_root_dir="
    def __call__(self, engine):
        """
        This method assumes:
-        - 3rd output of engine.state.batch is a meta data dict
-        - engine.state.output is already post-processed and ready for saving as a Nifti1Image.
+        - 3rd output of engine.state.batch is a meta data dict with the keys:
+          'filename_or_obj' -- for output file name creation,
+          and optionally 'original_affine', 'affine' for data orientation handling.
+        - the output file datatype is taken from `engine.state.output.dtype`.
        """
        meta_data = engine.state.batch[2]  # assuming 3rd output of input dataset is a meta data dict
        filenames = meta_data['filename_or_obj']
+        original_affine = meta_data.get('original_affine', None)
+        affine = meta_data.get('affine', None)
        engine_output = self.output_transform(engine.state.output)
        for batch_id, filename in enumerate(filenames):  # save a batch of files
            seg_output = engine_output[batch_id]
            if isinstance(seg_output, torch.Tensor):
                seg_output = seg_output.detach().cpu().numpy()
-            original_affine = nib.load(filename).affine
            output_filename = self._create_file_basename(self.output_postfix, filename, self.output_path)
            output_filename = '{}{}'.format(output_filename, self.output_ext)
-            write_nifti(seg_output, original_affine, output_filename, revert_canonical=True, dtype=seg_output.dtype)
+            write_nifti(seg_output, affine, output_filename, original_affine, dtype=seg_output.dtype)
            print('saved: {}'.format(output_filename))

diff --git a/monai/data/readers/niftireader.py b/monai/data/readers/niftireader.py
index 1898411694..a700737bbb 100644
--- a/monai/data/readers/niftireader.py
+++ b/monai/data/readers/niftireader.py
@@ -67,7 +67,8 @@ class NiftiDataset(Dataset):
    for the image and segmentation arrays separately.
    """

-    def __init__(self, image_files, seg_files, transform=None, seg_transform=None, image_only=True):
+    def __init__(self, image_files, seg_files, as_closest_canonical=False,
+                 transform=None, seg_transform=None, image_only=True, dtype=None):
        """
        Initializes the dataset with the image and segmentation filename lists. The transform
        `transform` is applied to the images and `seg_transform` to the segmentations.
@@ -75,8 +76,11 @@ def __init__(self, image_files, seg_files, transform=None, seg_transform=None, i
        Args:
            image_files (list of str): list of image filenames
            seg_files (list of str): list of segmentation filenames
+            as_closest_canonical (bool): if True, load the image as closest to canonical orientation
            transform (Callable, optional): transform to apply to image arrays
            seg_transform (Callable, optional): transform to apply to segmentation arrays
+            image_only (bool): if True return only the image volume, otherwise return image volume and header dict
+            dtype (np.dtype, optional): if not None, convert the loaded image to this data type
        """
        if len(image_files) != len(seg_files):
@@ -84,9 +88,11 @@ def __init__(self, image_files, seg_files, transform=None, seg_transform=None, i
        self.image_files = image_files
        self.seg_files = seg_files
+        self.as_closest_canonical = as_closest_canonical
        self.transform = transform
        self.seg_transform = seg_transform
        self.image_only = image_only
+        self.dtype = dtype

    def __len__(self):
        return len(self.image_files)
@@ -94,9 +100,11 @@ def __len__(self):
    def __getitem__(self, index):
        meta_data = None
        if self.image_only:
-            img = load_nifti(self.image_files[index], image_only=self.image_only)
+            img = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical,
+                             image_only=self.image_only, dtype=self.dtype)
        else:
-            img, meta_data = load_nifti(self.image_files[index], image_only=self.image_only)
+            img, meta_data = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical,
+                                        image_only=self.image_only, dtype=self.dtype)
        seg = load_nifti(self.seg_files[index])
        # https://github.com/pytorch/vision/issues/9#issuecomment-304224800