diff --git a/.gitignore b/.gitignore
index b15c95db2f..2c3face445 100644
--- a/.gitignore
+++ b/.gitignore
@@ -103,6 +103,7 @@ venv.bak/
 # mypy
 .mypy_cache/
 examples/scd_lvsegs.npz
+.temp/
 .idea/
 *~
diff --git a/examples/unet_inference_3d.py b/examples/unet_inference_3d.py
new file mode 100644
index 0000000000..aa84a6560d
--- /dev/null
+++ b/examples/unet_inference_3d.py
@@ -0,0 +1,85 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import tempfile
+from glob import glob
+
+import nibabel as nib
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from ignite.engine import Engine
+from torch.utils.data import DataLoader
+
+from monai import config
+from monai.handlers.checkpoint_loader import CheckpointLoader
+from monai.handlers.segmentation_saver import SegmentationSaver
+from monai.data.nifti_reader import NiftiDataset
+from monai.transforms import AddChannel, Rescale, ToTensor
+from monai.networks.nets.unet import UNet
+from monai.networks.utils import predict_segmentation
+from monai.data.synthetic import create_test_image_3d
+from monai.utils.sliding_window_inference import sliding_window_inference
+
+sys.path.append("..")  # assumes the framework is found here, change as necessary
+config.print_config()
+
+tempdir = tempfile.mkdtemp()
+# tempdir = './temp'
+for i in range(50):
+    im, seg = create_test_image_3d(256, 256, 256)
+
+    n = nib.Nifti1Image(im, np.eye(4))
+    nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))
+
+    n = nib.Nifti1Image(seg, np.eye(4))
+    nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))
+
+images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
+segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
+imtrans = transforms.Compose([Rescale(), AddChannel(), ToTensor()])
+segtrans = transforms.Compose([AddChannel(), ToTensor()])
+ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
+
+device = torch.device("cpu:0")
+roi_size = (64, 64, 64)
+sw_batch_size = 4
+net = UNet(
+    dimensions=3,
+    in_channels=1,
+    num_classes=1,
+    channels=(16, 32, 64, 128, 256),
+    strides=(2, 2, 2, 2),
+    num_res_units=2,
+)
+net.to(device)
+
+
+def _sliding_window_processor(_engine, batch):
+    net.eval()
+    img, seg, meta_data = batch
+    with torch.no_grad():
+        seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, lambda x: net(x)[0], device)
+        return predict_segmentation(seg_probs)
+
+
+infer_engine = Engine(_sliding_window_processor)
+
+# checkpoint_handler = ModelCheckpoint('./', 'net', n_saved=10, save_interval=3, require_empty=False)
+# infer_engine.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net})
+
+SegmentationSaver(output_path='tempdir', output_ext='.nii.gz', output_postfix='seg').attach(infer_engine)
+CheckpointLoader(load_path='./net_checkpoint_9.pth', load_dict={'net': net}).attach(infer_engine)
+
+loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
+state = infer_engine.run(loader)
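Usage note: the inference example above loads './net_checkpoint_9.pth' through CheckpointLoader, which delegates to ignite's Checkpoint.load_objects, so the file must be a dict keyed like load_dict. A minimal sketch of producing a compatible file by hand (normally the training example's ModelCheckpoint handler is the source of this checkpoint):

    import torch
    # 'net' is the UNet instance from the example; the file name is the one the example assumes
    torch.save({'net': net.state_dict()}, './net_checkpoint_9.pth')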
diff --git a/examples/unet_segmentation_3d.py b/examples/unet_segmentation_3d.py
index e4b3d4e5f2..86fcaef5e7 100644
--- a/examples/unet_segmentation_3d.py
+++ b/examples/unet_segmentation_3d.py
@@ -30,40 +30,14 @@ from monai.transforms import (AddChannel, Rescale, ToTensor, UniformRandomPatch)
 from monai.handlers.stats_handler import StatsHandler
 from monai.handlers.mean_dice import MeanDice
-from monai.transforms.utils import rescale_array
 from monai.visualize import img2tensorboard
+from monai.data.synthetic import create_test_image_3d
 
 # assumes the framework is found here, change as necessary
 sys.path.append("..")
 
 config.print_config()
 
-
-def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_max=0.0, num_seg_classes=5):
-    '''Return a noisy 3D image and segmentation.'''
-    image = np.zeros((width, height, depth))
-
-    for i in range(num_objs):
-        x = np.random.randint(rad_max, width - rad_max)
-        y = np.random.randint(rad_max, height - rad_max)
-        z = np.random.randint(rad_max, depth - rad_max)
-        rad = np.random.randint(5, rad_max)
-        spy, spx, spz = np.ogrid[-x:width - x, -y:height - y, -z:depth - z]
-        circle = (spx * spx + spy * spy + spz * spz) <= rad * rad
-
-        if num_seg_classes > 1:
-            image[circle] = np.ceil(np.random.random() * num_seg_classes)
-        else:
-            image[circle] = np.random.random() * 0.5 + 0.5
-
-    labels = np.ceil(image).astype(np.int32)
-
-    norm = np.random.uniform(0, num_seg_classes * noise_max, size=image.shape)
-    noisyimage = rescale_array(np.maximum(image, norm))
-
-    return noisyimage, labels
-
-
 tempdir = tempfile.mkdtemp()
 
 for i in range(50):
@@ -82,7 +56,7 @@ def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_ma
 segtrans = transforms.Compose([AddChannel(), UniformRandomPatch((64, 64, 64)), ToTensor()])
 
-ds = NiftiDataset(images, segs, imtrans, segtrans)
+ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans)
 loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
 im, seg = monai.utils.misc.first(loader)
@@ -115,7 +89,9 @@ def _loss_fn(i, j):
                  output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y])
 
 checkpoint_handler = ModelCheckpoint('./', 'net', n_saved=10, require_empty=False)
-trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net})
+trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
+                          handler=checkpoint_handler,
+                          to_save={'net': net, 'opt': opt})
 
 dice_metric = MeanDice(add_sigmoid=True, output_transform=lambda output: (output[0][0], output[2]))
 dice_metric.attach(trainer, "Training Dice")
diff --git a/monai/__init__.py b/monai/__init__.py
index 4662715a93..d101b7d6dc 100644
--- a/monai/__init__.py
+++ b/monai/__init__.py
@@ -16,7 +16,7 @@
 __copyright__ = "(c) 2020 MONAI Consortium"
 
 __version__tuple__ = (0, 0, 1)
-__version__ = "%i.%i.%i" % (__version__tuple__)
+__version__ = "%i.%i.%i" % __version__tuple__
 
 __basedir__ = os.path.dirname(__file__)
diff --git a/monai/data/nifti_reader.py b/monai/data/nifti_reader.py
index 2cf78971bc..b66a549cea 100644
--- a/monai/data/nifti_reader.py
+++ b/monai/data/nifti_reader.py
@@ -14,6 +14,7 @@
 import random
 
 from torch.utils.data import Dataset
+from torch.utils.data._utils.collate import np_str_obj_array_pattern
 
 from monai.utils.module import export
 
@@ -31,25 +32,32 @@ def load_nifti(filename_or_obj, as_closest_canonical=False, image_only=True, dty
     Returns:
         The loaded image volume if `image_only` is True, or a tuple
        containing the volume and the Nifti header in dict format otherwise
+
+    Note:
+        header['original_affine'] stores the original affine loaded from `filename_or_obj`.
+        header['affine'] stores the affine after the optional `as_closest_canonical` transform.
     """
     img = nib.load(filename_or_obj)
 
+    header = dict(img.header)
+    header['filename_or_obj'] = filename_or_obj
+    header['original_affine'] = img.affine
+    header['affine'] = img.affine
+    header['as_closest_canonical'] = as_closest_canonical
+
     if as_closest_canonical:
         img = nib.as_closest_canonical(img)
+        header['affine'] = img.affine
 
     if dtype is not None:
         dat = img.get_fdata(dtype=dtype)
     else:
         dat = np.asanyarray(img.dataobj)
 
-    header = dict(img.header)
-    header['filename_or_obj'] = filename_or_obj
-
     if image_only:
         return dat
-    else:
-        return dat, header
+    return dat, header
 
 
 @export("monai.data")
@@ -59,7 +67,8 @@ class NiftiDataset(Dataset):
     for the image and segmentation arrays separately.
     """
 
-    def __init__(self, image_files, seg_files, transform=None, seg_transform=None):
+    def __init__(self, image_files, seg_files, as_closest_canonical=False,
+                 transform=None, seg_transform=None, image_only=True, dtype=None):
         """
         Initializes the dataset with the image and segmentation filename lists. The transform `transform` is applied
         to the images and `seg_transform` to the segmentations.
@@ -67,8 +76,11 @@ def __init__(self, image_files, seg_files, transform=None, seg_transform=None):
         Args:
             image_files (list of str): list of image filenames
             seg_files (list of str): list of segmentation filenames
+            as_closest_canonical (bool): if True, load the image as closest to canonical orientation
             transform (Callable, optional): transform to apply to image arrays
             seg_transform (Callable, optional): transform to apply to segmentation arrays
+            image_only (bool): if True, return only the image volume, otherwise return the image volume and header dict
+            dtype (np.dtype, optional): if not None, convert the loaded image to this data type
         """
 
         if len(image_files) != len(seg_files):
@@ -76,14 +88,23 @@ def __init__(self, image_files, seg_files, transform=None, seg_transform=None):
 
         self.image_files = image_files
         self.seg_files = seg_files
-        self.transform = transform
-        self.seg_transform = seg_transform
+        self.as_closest_canonical = as_closest_canonical
+        self.transform = transform
+        self.seg_transform = seg_transform
+        self.image_only = image_only
+        self.dtype = dtype
 
     def __len__(self):
         return len(self.image_files)
 
     def __getitem__(self, index):
-        img = load_nifti(self.image_files[index])
+        meta_data = None
+        if self.image_only:
+            img = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical,
+                             image_only=self.image_only, dtype=self.dtype)
+        else:
+            img, meta_data = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical,
+                                        image_only=self.image_only, dtype=self.dtype)
         seg = load_nifti(self.seg_files[index])
 
         # https://github.com/pytorch/vision/issues/9#issuecomment-304224800
@@ -97,4 +118,14 @@ def __getitem__(self, index):
         random.seed(seed)  # ensure randomized transforms roll the same values for segmentations as images
         seg = self.seg_transform(seg)
 
-        return img, seg
+        if self.image_only or meta_data is None:
+            return img, seg
+
+        compatible_meta = {}
+        for meta_key in meta_data:
+            meta_datum = meta_data[meta_key]
+            if type(meta_datum).__name__ == 'ndarray' \
+                    and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None:
+                continue
+            compatible_meta[meta_key] = meta_datum
+        return img, seg, compatible_meta
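Usage sketch for the reader changes above (file names hypothetical): with image_only=False each dataset item carries a collate-friendly meta dict; string/object ndarray entries are filtered out, while numeric arrays such as the affines survive and can later drive write_nifti.

    from monai.data.nifti_reader import NiftiDataset

    ds = NiftiDataset(['im0.nii.gz'], ['seg0.nii.gz'],
                      as_closest_canonical=True, image_only=False)
    img, seg, meta = ds[0]
    print(meta['filename_or_obj'])                  # 'im0.nii.gz'
    print(meta['original_affine'], meta['affine'])  # differ when a reorientation happened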
diff --git a/monai/data/nifti_writer.py b/monai/data/nifti_writer.py
new file mode 100644
index 0000000000..44e454d91a
--- /dev/null
+++ b/monai/data/nifti_writer.py
@@ -0,0 +1,41 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import nibabel as nib
+
+
+def write_nifti(data, affine, file_name, target_affine=None, dtype="float32"):
+    """Write numpy data to a nifti file on disk.
+
+    Args:
+        data (numpy.ndarray): input data to write to file.
+        affine (numpy.ndarray): affine information for the data.
+        file_name (string): the file name to save on disk.
+        target_affine (numpy.ndarray, optional):
+            before saving the (data, affine), transform the data into the orientation defined by `target_affine`.
+        dtype (np.dtype, optional): convert the data to this data type before saving.
+    """
+    assert isinstance(data, np.ndarray), 'input data must be numpy array.'
+    if affine is None:
+        affine = np.eye(4)
+
+    if target_affine is None:
+        results_img = nib.Nifti1Image(data.astype(dtype), affine)
+    else:
+        start_ornt = nib.orientations.io_orientation(affine)
+        target_ornt = nib.orientations.io_orientation(target_affine)
+        ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt)
+
+        reverted_results = nib.orientations.apply_orientation(data, ornt_transform)
+        results_img = nib.Nifti1Image(reverted_results.astype(dtype), target_affine)
+
+    nib.save(results_img, file_name)
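A short sketch of the intended round trip (values illustrative): data loaded with as_closest_canonical=True can be written back in its on-disk orientation by passing the canonical affine as `affine` and the stored original affine as `target_affine`.

    import numpy as np
    from monai.data.nifti_writer import write_nifti

    data = np.zeros((8, 8, 8), dtype=np.float32)
    canonical_affine = np.eye(4)                  # affine of `data` as currently oriented
    original_affine = np.diag([-1., 1., 1., 1.])  # orientation to restore on disk
    write_nifti(data, canonical_affine, 'out.nii.gz', target_affine=original_affine)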
diff --git a/monai/data/synthetic.py b/monai/data/synthetic.py
index c7120d16b5..a51d730357 100644
--- a/monai/data/synthetic.py
+++ b/monai/data/synthetic.py
@@ -9,7 +9,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import numpy as np
 
 from monai.transforms.utils import rescale_array
@@ -28,7 +27,7 @@ def create_test_image_2d(width, height, num_objs=12, rad_max=30, noise_max=0.0,
         x = np.random.randint(rad_max, width - rad_max)
         y = np.random.randint(rad_max, height - rad_max)
         rad = np.random.randint(5, rad_max)
-        spy, spx = np.ogrid[-x : width - x, -y : height - y]
+        spy, spx = np.ogrid[-x:width - x, -y:height - y]
         circle = (spx * spx + spy * spy) <= rad * rad
 
         if num_seg_classes > 1:
@@ -42,3 +41,32 @@ def create_test_image_2d(width, height, num_objs=12, rad_max=30, noise_max=0.0,
     noisyimage = rescale_array(np.maximum(image, norm))
 
     return noisyimage, labels
+
+
+def create_test_image_3d(height, width, depth, num_objs=12, rad_max=30, noise_max=0.0, num_seg_classes=5):
+    """
+    Return a noisy 3D image and segmentation.
+
+    See also: create_test_image_2d
+    """
+    image = np.zeros((width, height, depth))
+
+    for i in range(num_objs):
+        x = np.random.randint(rad_max, width - rad_max)
+        y = np.random.randint(rad_max, height - rad_max)
+        z = np.random.randint(rad_max, depth - rad_max)
+        rad = np.random.randint(5, rad_max)
+        spy, spx, spz = np.ogrid[-x:width - x, -y:height - y, -z:depth - z]
+        circle = (spx * spx + spy * spy + spz * spz) <= rad * rad
+
+        if num_seg_classes > 1:
+            image[circle] = np.ceil(np.random.random() * num_seg_classes)
+        else:
+            image[circle] = np.random.random() * 0.5 + 0.5
+
+    labels = np.ceil(image).astype(np.int32)
+
+    norm = np.random.uniform(0, num_seg_classes * noise_max, size=image.shape)
+    noisyimage = rescale_array(np.maximum(image, norm))
+
+    return noisyimage, labels
diff --git a/monai/data/utils.py b/monai/data/utils.py
index b1731a0945..f8c1f7722a 100644
--- a/monai/data/utils.py
+++ b/monai/data/utils.py
@@ -9,6 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 from itertools import starmap, product
 
 import numpy as np
@@ -64,6 +65,57 @@ def iter_patch_slices(dims, patch_size, start_pos=()):
         yield tuple(slice(s, s + p) for s, p in zip(position[::-1], patch_size))
 
 
+def dense_patch_slices(image_size, patch_size, scan_interval):
+    """
+    Enumerate all slices defining 2D/3D patches of size `patch_size` from an `image_size` input image.
+
+    Args:
+        image_size (tuple of int): dimensions of image to iterate over
+        patch_size (tuple of int): size of patches to generate slices
+        scan_interval (tuple of int): dense patch sampling interval
+
+    Returns:
+        a list of slice objects defining each patch
+    """
+    num_spatial_dims = len(image_size)
+    if num_spatial_dims not in (2, 3):
+        raise ValueError('image_size should have 2 or 3 elements')
+    patch_size = get_valid_patch_size(image_size, patch_size)
+    scan_interval = ensure_tuple_size(scan_interval, num_spatial_dims)
+
+    scan_num = [int(math.ceil(float(image_size[i]) / scan_interval[i])) if scan_interval[i] != 0 else 1
+                for i in range(num_spatial_dims)]
+    slices = []
+    if num_spatial_dims == 3:
+        for i in range(scan_num[0]):
+            start_i = i * scan_interval[0]
+            start_i -= max(start_i + patch_size[0] - image_size[0], 0)
+            slice_i = slice(start_i, start_i + patch_size[0])
+
+            for j in range(scan_num[1]):
+                start_j = j * scan_interval[1]
+                start_j -= max(start_j + patch_size[1] - image_size[1], 0)
+                slice_j = slice(start_j, start_j + patch_size[1])
+
+                for k in range(0, scan_num[2]):
+                    start_k = k * scan_interval[2]
+                    start_k -= max(start_k + patch_size[2] - image_size[2], 0)
+                    slice_k = slice(start_k, start_k + patch_size[2])
+                    slices.append((slice_i, slice_j, slice_k))
+    else:
+        for i in range(scan_num[0]):
+            start_i = i * scan_interval[0]
+            start_i -= max(start_i + patch_size[0] - image_size[0], 0)
+            slice_i = slice(start_i, start_i + patch_size[0])
+
+            for j in range(scan_num[1]):
+                start_j = j * scan_interval[1]
+                start_j -= max(start_j + patch_size[1] - image_size[1], 0)
+                slice_j = slice(start_j, start_j + patch_size[1])
+                slices.append((slice_i, slice_j))
+    return slices
+
+
 def iter_patch(arr, patch_size, start_pos=(), copy_back=True, pad_mode="wrap", **pad_opts):
     """
     Yield successive patches from `arr' of size `patchSize'.
     The iteration can start from position `startPos' in `arr'
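A worked example of dense_patch_slices (illustrative values): windows are placed every `scan_interval` voxels and shifted back at the border so each patch stays inside the image, which can duplicate border windows.

    from monai.data.utils import dense_patch_slices

    slices = dense_patch_slices(image_size=(5, 5), patch_size=(3, 3), scan_interval=(2, 2))
    print(len(slices))   # 9
    print(slices[0])     # (slice(0, 3), slice(0, 3))
    print(slices[-1])    # (slice(2, 5), slice(2, 5)) -- start 4 shifted back to 2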
diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py
new file mode 100644
index 0000000000..bbf1323a17
--- /dev/null
+++ b/monai/handlers/checkpoint_loader.py
@@ -0,0 +1,46 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from ignite.engine import Events
+from ignite.handlers import Checkpoint
+
+import monai
+
+
+@monai.utils.export("monai.handlers")
+@monai.utils.alias("CheckpointLoader")
+class CheckpointLoader:
+    """
+    CheckpointLoader acts as an ignite handler to load checkpoint data from file.
+    It can load variables for the network, optimizer and lr_scheduler, and can also
+    restore training by loading the state_dict of the ignite engine.
+
+    Args:
+        load_path (string): the file path of the checkpoint; it should be a PyTorch pth file.
+        load_dict (dict): target objects that load checkpoint to. examples:
+            {'network': net, 'optimizer': optimizer, 'engine': engine}
+
+    """
+
+    def __init__(self, load_path, load_dict):
+        assert load_path is not None, 'must provide a valid path to load checkpoint.'
+        self.load_path = load_path
+        assert load_dict is not None and len(load_dict) > 0, 'must provide target objects to load.'
+        self.load_dict = load_dict
+
+    def attach(self, engine):
+        return engine.add_event_handler(Events.STARTED, self)
+
+    def __call__(self, engine):
+        checkpoint = torch.load(self.load_path)
+        Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint)
+        print('Restored all variables from {}'.format(self.load_path))
diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py
index cb46aac351..7a0ff3efe0 100644
--- a/monai/handlers/metric_logger.py
+++ b/monai/handlers/metric_logger.py
@@ -12,6 +12,7 @@
 from collections import defaultdict
 
 import monai
+from ignite.engine import Events
 
 
 @monai.utils.export("monai.handlers")
@@ -25,7 +26,7 @@ def __init__(self, loss_transform=lambda x: x, metric_transform=lambda x: x):
         self.metrics = defaultdict(list)
 
     def attach(self, engine):
-        return engine.add_event_handler(monai.application.engine.Events.ITERATION_COMPLETED, self)
+        return engine.add_event_handler(Events.ITERATION_COMPLETED, self)
 
     def __call__(self, engine):
         self.loss.append(self.loss_transform(engine.state.output))
diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py
new file mode 100644
index 0000000000..31e6881678
--- /dev/null
+++ b/monai/handlers/segmentation_saver.py
@@ -0,0 +1,106 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import torch
+from ignite.engine import Events
+
+from monai.data.nifti_writer import write_nifti
+
+
+class SegmentationSaver:
+    """
+    Event handler triggered on completing every iteration to save the segmentation predictions as nifti files.
+    """
+
+    def __init__(self, output_path='./', dtype='float32', output_postfix='seg', output_ext='.nii.gz',
+                 output_transform=lambda x: x):
+        """
+        Args:
+            output_path (str): output image directory.
+            dtype (str): convert the image to this data type before saving.
+            output_postfix (str): a string appended to all output file names.
+            output_ext (str): output file extension name.
+            output_transform (Callable): a callable that is used to transform the
+                ignite.engine.output into the form expected for nifti image data.
+                The first dimension of this transform's output will be treated as the
+                batch dimension. Each item in the batch will be saved individually.
+        """
+        self.output_path = output_path
+        self.dtype = dtype
+        self.output_postfix = output_postfix
+        self.output_ext = output_ext
+        self.output_transform = output_transform
+
+    def attach(self, engine):
+        return engine.add_event_handler(Events.ITERATION_COMPLETED, self)
+
+    @staticmethod
+    def _create_file_basename(postfix, input_file_name, folder_path, data_root_dir=""):
+        """
+        Utility function to create the path to the output file based on the input
+        filename (the extension is added by the lib level writer before writing the file).
+
+        Args:
+            postfix (str): output name's postfix
+            input_file_name (str): path to the input image file
+            folder_path (str): path for the output file
+            data_root_dir (str): if not empty, it specifies the beginning parts of the input file's
+                absolute path. This is used to compute `input_file_rel_path`, the relative path to the file from
+                `data_root_dir` to preserve folder structure when saving in case there are files in different
+                folders with the same file names.
+        """
+
+        # get the filename and directory
+        filedir, filename = os.path.split(input_file_name)
+
+        # jettison the extension to have just filename
+        filename, ext = os.path.splitext(filename)
+        while ext != "":
+            filename, ext = os.path.splitext(filename)
+
+        # use data_root_dir to find relative path to file
+        filedir_rel_path = ""
+        if data_root_dir:
+            filedir_rel_path = os.path.relpath(filedir, data_root_dir)
+
+        # sub-folder path will be original name without the extension
+        subfolder_path = os.path.join(folder_path, filedir_rel_path, filename)
+        if not os.path.exists(subfolder_path):
+            os.makedirs(subfolder_path)
+
+        # add the sub-folder plus the postfix name to become the file basename in the output path
+        return os.path.join(subfolder_path, filename + "_" + postfix)
+
+    def __call__(self, engine):
+        """
+        This method assumes:
+        - the 3rd output of engine.state.batch is a meta data dict with the keys:
+          'filename_or_obj' -- for output file name creation,
+          and optionally 'original_affine', 'affine' for data orientation handling.
+        - the output file data type comes from `engine.state.output.dtype`.
+        """
+        meta_data = engine.state.batch[2]  # assuming 3rd output of input dataset is a meta data dict
+        filenames = meta_data['filename_or_obj']
+        original_affine = meta_data.get('original_affine', None)
+        affine = meta_data.get('affine', None)
+        engine_output = self.output_transform(engine.state.output)
+        for batch_id, filename in enumerate(filenames):  # save a batch of files
+            seg_output = engine_output[batch_id]
+            _affine = affine[batch_id]
+            _original_affine = original_affine[batch_id]
+            if isinstance(seg_output, torch.Tensor):
+                seg_output = seg_output.detach().cpu().numpy()
+            output_filename = self._create_file_basename(self.output_postfix, filename, self.output_path)
+            output_filename = '{}{}'.format(output_filename, self.output_ext)
+            write_nifti(seg_output, _affine, output_filename, _original_affine, dtype=seg_output.dtype)
+            print('saved: {}'.format(output_filename))
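A concrete illustration of the naming scheme (paths hypothetical): both extensions are stripped, a per-image sub-folder is created, and the postfix plus output_ext are appended.

    # creates './out/im0' as a side effect and returns the basename
    basename = SegmentationSaver._create_file_basename('seg', '/data/case1/im0.nii.gz', './out')
    print(basename)  # './out/im0/im0_seg'; the handler then appends output_ext -> './out/im0/im0_seg.nii.gz'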
+ """ + meta_data = engine.state.batch[2] # assuming 3rd output of input dataset is a meta data dict + filenames = meta_data['filename_or_obj'] + original_affine = meta_data.get('original_affine', None) + affine = meta_data.get('affine', None) + engine_output = self.output_transform(engine.state.output) + for batch_id, filename in enumerate(filenames): # save a batch of files + seg_output = engine_output[batch_id] + _affine = affine[batch_id] + _original_affine = original_affine[batch_id] + if isinstance(seg_output, torch.Tensor): + seg_output = seg_output.detach().cpu().numpy() + output_filename = self._create_file_basename(self.output_postfix, filename, self.output_path) + output_filename = '{}{}'.format(output_filename, self.output_ext) + write_nifti(seg_output, _affine, output_filename, _original_affine, dtype=seg_output.dtype) + print('saved: {}'.format(output_filename)) diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py index 7439ac6230..4fc52fad0f 100644 --- a/monai/transforms/transforms.py +++ b/monai/transforms/transforms.py @@ -116,3 +116,32 @@ def __call__(self, img): if self.dtype != img.dtype: img = img.astype(self.dtype) return img + + +@export +class ImageEndPadder: + """Performs padding by appending to the end of the data all on one side for each dimension. + Uses np.pad so in practice, a mode needs to be provided. See numpy.lib.arraypad.pad + for additional details. + + Args: + out_size (list): the size of region of interest at the end of the operation. + mode (string): a portion from numpy.lib.arraypad.pad is copied below. + dtype: output data format. + """ + + def __init__(self, out_size, mode, dtype=np.float32): + assert out_size is not None and isinstance(out_size, (list, tuple)), 'out_size must be list or tuple' + self.out_size = out_size + assert isinstance(mode, str), 'mode must be str' + self.mode = mode + self.dtype = dtype + + def _determine_data_pad_width(self, data_shape): + return [(0, max(self.out_size[i] - data_shape[i], 0)) for i in range(len(self.out_size))] + + def __call__(self, img): + data_pad_width = self._determine_data_pad_width(img.shape[2:]) + all_pad_width = [(0, 0), (0, 0)] + data_pad_width + img = np.pad(img, all_pad_width, self.mode) + return img diff --git a/monai/utils/sliding_window_inference.py b/monai/utils/sliding_window_inference.py new file mode 100644 index 0000000000..2efc7a7481 --- /dev/null +++ b/monai/utils/sliding_window_inference.py @@ -0,0 +1,119 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from monai.transforms.transforms import ImageEndPadder +from monai.transforms.transforms import ToTensor +from monai.data.utils import dense_patch_slices + + +def sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device): + """Use SlidingWindow method to execute inference. + + Args: + inputs (numpy array): input image to be processed (assuming NCHW[D]) + roi_size (list, tuple): the window size to execute SlidingWindow inference. 
diff --git a/monai/utils/sliding_window_inference.py b/monai/utils/sliding_window_inference.py
new file mode 100644
index 0000000000..2efc7a7481
--- /dev/null
+++ b/monai/utils/sliding_window_inference.py
@@ -0,0 +1,119 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from monai.transforms.transforms import ImageEndPadder
+from monai.transforms.transforms import ToTensor
+from monai.data.utils import dense_patch_slices
+
+
+def sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device):
+    """Use the sliding window method to execute inference.
+
+    Args:
+        inputs (numpy array): input image to be processed (assuming NCHW[D])
+        roi_size (list, tuple): the window size to execute sliding window inference.
+        sw_batch_size (int): the batch size to run window slices.
+        predictor (Callable): given input tensor `patch_data` in shape NCHW[D], `predictor(patch_data)`
+            should return a prediction with the same spatial shape and batch_size, i.e. NMHW[D];
+            where HW[D] represents the patch spatial size, M is the number of output channels, N is `sw_batch_size`.
+        device: on which device to execute model inference, cpu or gpu.
+
+    Note:
+        The input must be channel-first and have a batch dimension; both 2D and 3D inputs are supported.
+        Inference executes on one image per call, running batches of window slices from that single image.
+    """
+    num_spatial_dims = len(inputs.shape) - 2
+    assert len(roi_size) == num_spatial_dims, 'roi_size {} does not match input dims.'.format(roi_size)
+
+    # determine image spatial size and batch size
+    # Note: all input images must have the same image size and batch size
+    image_size = list(inputs.shape[2:])
+    batch_size = inputs.shape[0]
+
+    # TODO: Enable batch sizes > 1 in future.
+    if batch_size > 1:
+        raise NotImplementedError
+
+    original_image_size = [image_size[i] for i in range(num_spatial_dims)]
+    # in case that image size is smaller than roi size
+    image_size = tuple(max(image_size[i], roi_size[i]) for i in range(num_spatial_dims))
+    inputs = ImageEndPadder(roi_size, 'constant')(inputs)  # in np array
+    inputs = ToTensor()(inputs)
+
+    # TODO: interval from user's specification
+    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims)
+
+    # Store all slices in list.
+    slices = dense_patch_slices(image_size, roi_size, scan_interval)
+
+    slice_batches = []
+    for slice_index in range(0, len(slices), sw_batch_size):
+        slice_index_range = range(slice_index, min(slice_index + sw_batch_size, len(slices)))
+        input_slices = []
+        for curr_index in slice_index_range:
+            if num_spatial_dims == 3:
+                slice_i, slice_j, slice_k = slices[curr_index]
+                input_slices.append(inputs[0, :, slice_i, slice_j, slice_k])
+            else:
+                slice_i, slice_j = slices[curr_index]
+                input_slices.append(inputs[0, :, slice_i, slice_j])
+        slice_batches.append(torch.stack(input_slices))
+
+    # Perform predictions
+    output_rois = list()
+    for data in slice_batches:
+        seg_prob = predictor(data)  # batched patch segmentation
+        output_rois.append(seg_prob)
+
+    # stitching output image
+    output_classes = output_rois[0].shape[1]
+    output_shape = [batch_size, output_classes] + list(image_size)
+
+    # allocate memory to store the full output and the count for overlapping parts
+    output_image = torch.zeros(output_shape, dtype=torch.float32, device=device)
+    count_map = torch.zeros(output_shape, dtype=torch.float32, device=device)
+
+    for window_id, slice_index in enumerate(range(0, len(slices), sw_batch_size)):
+        slice_index_range = range(slice_index, min(slice_index + sw_batch_size, len(slices)))
+        # store the result in the proper location of the full output
+        for curr_index in slice_index_range:
+            if num_spatial_dims == 3:
+                slice_i, slice_j, slice_k = slices[curr_index]
+                output_image[0, :, slice_i, slice_j, slice_k] += output_rois[window_id][curr_index - slice_index, :]
+                count_map[0, :, slice_i, slice_j, slice_k] += 1.
+            else:
+                slice_i, slice_j = slices[curr_index]
+                output_image[0, :, slice_i, slice_j] += output_rois[window_id][curr_index - slice_index, :]
+                count_map[0, :, slice_i, slice_j] += 1.
+
+    # account for any overlapping sections
+    output_image /= count_map
+
+    if num_spatial_dims == 3:
+        return output_image[..., :original_image_size[0], :original_image_size[1], :original_image_size[2]]
+    return output_image[..., :original_image_size[0], :original_image_size[1]]  # 2D
+
+
+def _get_scan_interval(image_size, roi_size, num_spatial_dims):
+    assert (len(image_size) == num_spatial_dims), 'image coord different from spatial dims.'
+    assert (len(roi_size) == num_spatial_dims), 'roi coord different from spatial dims.'
+
+    scan_interval = [1 for _ in range(num_spatial_dims)]
+    for i in range(num_spatial_dims):
+        if roi_size[i] == image_size[i]:
+            scan_interval[i] = int(roi_size[i])
+        else:
+            # this means that it's r-16 (if r >= 64) and r*0.75 (if r <= 64)
+            scan_interval[i] = int(max(roi_size[i] - 16, roi_size[i] * 0.75))
+    return tuple(scan_interval)
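A worked example of the interval rule above (values illustrative): for a spatial extent of 100 and roi 64, the interval is max(64 - 16, 64 * 0.75) = 48, so windows start at 0, 48 and 96, with the last start shifted back to 36 by dense_patch_slices so the window fits.

    from monai.utils.sliding_window_inference import _get_scan_interval
    from monai.data.utils import dense_patch_slices

    interval = _get_scan_interval((100, 100), (64, 64), 2)
    slices = dense_patch_slices((100, 100), (64, 64), interval)
    print(interval, len(slices))  # (48, 48) 9
    print(slices[-1])             # (slice(36, 100), slice(36, 100))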
diff --git a/runtests.sh b/runtests.sh
index a805455bc1..54de9103c1 100755
--- a/runtests.sh
+++ b/runtests.sh
@@ -4,10 +4,10 @@ set -e
 
 homedir="$( cd -P "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
-cd $homedir
+cd "$homedir"
 
 export PYTHONPATH="$homedir:$PYTHONPATH"
-echo $PYTHONPATH
+echo "$PYTHONPATH"
 
 # configuration values
 doCoverage=false
diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py
index a322fb8285..5c1d7e8eb1 100644
--- a/tests/test_dice_loss.py
+++ b/tests/test_dice_loss.py
@@ -14,7 +14,7 @@
 import torch
 from parameterized import parameterized
 
-from monai.losses import DiceLoss
+from monai.losses.dice import DiceLoss
 
 TEST_CASE_1 = [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)
     {
diff --git a/tests/test_image_end_padder.py b/tests/test_image_end_padder.py
new file mode 100644
index 0000000000..1d705a0ce1
--- /dev/null
+++ b/tests/test_image_end_padder.py
@@ -0,0 +1,36 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.transforms.transforms import ImageEndPadder
+
+TEST_CASE_1 = [
+    {
+        'out_size': [16, 16, 8],
+        'mode': 'constant'
+    },
+    np.zeros((1, 3, 8, 8, 4)),
+    np.zeros((1, 3, 16, 16, 8)),
+]
+
+
+class TestImageEndPadder(unittest.TestCase):
+
+    @parameterized.expand([TEST_CASE_1])
+    def test_image_end_pad_shape(self, input_param, input_data, expected_val):
+        padder = ImageEndPadder(**input_param)
+        result = padder(input_data)
+        self.assertEqual(result.shape, expected_val.shape)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_nifti_rw.py b/tests/test_nifti_rw.py
new file mode 100644
index 0000000000..d758380c70
--- /dev/null
+++ b/tests/test_nifti_rw.py
@@ -0,0 +1,65 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+import nibabel as nib
+import numpy as np
+from parameterized import parameterized
+
+from monai.data.nifti_reader import load_nifti
+from monai.data.nifti_writer import write_nifti
+
+from .utils import make_nifti_image
+
+TEST_IMAGE = np.zeros((1, 2, 3))
+TEST_AFFINE = np.array([[-5.3, 0., 0., 102.01], [0., 0.52, 2.17, -7.50], [-0., 1.98, -0.26, -23.12], [0., 0., 0., 1.]])
+
+TEST_CASE_1 = [TEST_IMAGE, TEST_AFFINE, (1, 2, 3), dict(as_closest_canonical=True, image_only=False)]
+TEST_CASE_2 = [TEST_IMAGE, TEST_AFFINE, (1, 3, 2), dict(as_closest_canonical=True, image_only=True)]
+TEST_CASE_3 = [TEST_IMAGE, TEST_AFFINE, (1, 2, 3), dict(as_closest_canonical=False, image_only=True)]
+TEST_CASE_4 = [TEST_IMAGE, TEST_AFFINE, (1, 2, 3), dict(as_closest_canonical=False, image_only=False)]
+
+
+class TestNiftiLoadRead(unittest.TestCase):
+
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
+    def test_orientation(self, array, affine, expected_shape, reader_param):
+        test_image = make_nifti_image(array, affine)
+
+        # read test cases
+        load_result = load_nifti(test_image, **reader_param)
+        if isinstance(load_result, tuple):
+            data_array, header = load_result
+        else:
+            data_array = load_result
+            header = None
+        if os.path.exists(test_image):
+            os.remove(test_image)
+
+        # write test cases
+        if header is not None:
+            write_nifti(data_array, header['affine'], test_image, header['original_affine'])
+        else:
+            write_nifti(data_array, affine, test_image)
+        saved = nib.load(test_image)
+        saved_affine = saved.affine
+        saved_shape = saved.get_fdata().shape
+        if os.path.exists(test_image):
+            os.remove(test_image)
+
+        self.assertTrue(np.allclose(saved_affine, affine))
+        self.assertTrue(np.allclose(saved_shape, expected_shape))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py
new file mode 100644
index 0000000000..e0d727c407
--- /dev/null
+++ b/tests/test_sliding_window_inference.py
@@ -0,0 +1,45 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.utils.sliding_window_inference import sliding_window_inference
+
+TEST_CASE_1 = [(1, 3, 16, 15, 7), (4, 10, 7), 3]  # 3D small roi
+
+TEST_CASE_2 = [(1, 3, 16, 15, 7), (20, 22, 23), 10]  # 3D large roi
+
+TEST_CASE_3 = [(1, 3, 15, 7), (2, 6), 1000]  # 2D small roi, large batch
+
+TEST_CASE_4 = [(1, 3, 16, 7), (80, 50), 7]  # 2D large roi
+
+
+class TestSlidingWindowInference(unittest.TestCase):
+
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
+    def test_sliding_window_default(self, image_shape, roi_shape, sw_batch_size):
+        inputs = np.ones(image_shape)
+        device = torch.device("cpu:0")
+
+        def compute(data):
+            return data.to(device) + 1
+
+        result = sliding_window_inference(inputs, roi_shape, sw_batch_size, compute, device)
+        expected_val = np.ones(image_shape, dtype=np.float32) + 1
+        self.assertTrue(np.allclose(result.numpy(), expected_val))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/utils.py b/tests/utils.py
index 3a37003c50..27f75fe7c3 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -11,6 +11,8 @@
 
 import os
 import unittest
+import tempfile
+import nibabel as nib
 
 import numpy as np
 import torch
@@ -26,6 +28,18 @@ def skip_if_quick(obj):
     return unittest.skipIf(is_quick, "Skipping slow tests")(obj)
 
 
+def make_nifti_image(array, affine):
+    """
+    Create a temporary nifti image on disk and return the image name.
+    The caller is responsible for deleting the temporary file when done with it.
+    """
+    test_image = nib.Nifti1Image(array, affine)
+
+    _, image_name = tempfile.mkstemp(suffix='.nii.gz')
+    nib.save(test_image, image_name)
+    return image_name
+
+
 class NumpyImageTestCase2D(unittest.TestCase):
     im_shape = (128, 128)
     input_channels = 1
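For reference, a minimal sketch of the new test helper in use, mirroring test_nifti_rw.py above (run from the repository root so `tests` is importable):

    import os
    import numpy as np
    from tests.utils import make_nifti_image

    path = make_nifti_image(np.zeros((1, 2, 3)), np.eye(4))
    # ... read/write assertions against `path` ...
    os.remove(path)  # the helper leaves cleanup to the caller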