From a75d8915e756e56347a5f751890c62a3bb8e2935 Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Mon, 9 Mar 2020 10:52:09 +0800
Subject: [PATCH 1/7] [DLMED] add arbitrary format support for all event-handlers

---
 ...ference_3d.py => unet_inference_3d_array.py} | 17 +++++++++--------
 monai/handlers/segmentation_saver.py            | 15 +++++++++------
 monai/utils/sliding_window_inference.py         |  4 ++--
 tests/test_sliding_window_inference.py          |  2 +-
 4 files changed, 21 insertions(+), 17 deletions(-)
 rename examples/{unet_inference_3d.py => unet_inference_3d_array.py} (83%)

diff --git a/examples/unet_inference_3d.py b/examples/unet_inference_3d_array.py
similarity index 83%
rename from examples/unet_inference_3d.py
rename to examples/unet_inference_3d_array.py
index aa84a6560d..a4d0b7ffec 100644
--- a/examples/unet_inference_3d.py
+++ b/examples/unet_inference_3d_array.py
@@ -21,6 +21,9 @@
 from ignite.engine import Engine
 from torch.utils.data import DataLoader
 
+# assumes the framework is found here, change as necessary
+sys.path.append("..")
+
 from monai import config
 from monai.handlers.checkpoint_loader import CheckpointLoader
 from monai.handlers.segmentation_saver import SegmentationSaver
@@ -31,7 +34,6 @@
 from monai.data.synthetic import create_test_image_3d
 from monai.utils.sliding_window_inference import sliding_window_inference
 
-sys.path.append("..")  # assumes the framework is found here, change as necessary
 config.print_config()
 
 tempdir = tempfile.mkdtemp()
@@ -51,7 +53,7 @@
 segtrans = transforms.Compose([AddChannel(), ToTensor()])
 ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
 
-device = torch.device("cpu:0")
+device = torch.device("cuda:0")
 roi_size = (64, 64, 64)
 sw_batch_size = 4
 net = UNet(
@@ -65,7 +67,7 @@
 net.to(device)
 
 
-def _sliding_window_processor(_engine, batch):
+def _sliding_window_processor(engine, batch):
     net.eval()
     img, seg, meta_data = batch
     with torch.no_grad():
@@ -75,11 +77,10 @@ def _sliding_window_processor(_engine, batch):
 
 infer_engine = Engine(_sliding_window_processor)
 
-# checkpoint_handler = ModelCheckpoint('./', 'net', n_saved=10, save_interval=3, require_empty=False)
-# infer_engine.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net})
-
-SegmentationSaver(output_path='tempdir', output_ext='.nii.gz', output_postfix='seg').attach(infer_engine)
-CheckpointLoader(load_path='./net_checkpoint_9.pth', load_dict={'net': net}).attach(infer_engine)
+# for the array data format, assume the 3rd item of batch data is the meta_data
+SegmentationSaver(output_path='tempdir', output_ext='.nii.gz', output_postfix='seg',
+                  batch_transform=lambda x: x[2]).attach(infer_engine)
+CheckpointLoader(load_path='./runs/net_checkpoint_120.pth', load_dict={'net': net}).attach(infer_engine)
 
 loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
 state = infer_engine.run(loader)
diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py
index a87e517f81..65b18158e1 100644
--- a/monai/handlers/segmentation_saver.py
+++ b/monai/handlers/segmentation_saver.py
@@ -23,13 +23,15 @@ class SegmentationSaver:
     """
 
    def __init__(self, output_path='./', dtype='float32', output_postfix='seg', output_ext='.nii.gz',
-                 output_transform=lambda x: x, name=None):
+                 batch_transform=lambda x: x, output_transform=lambda x: x, name=None):
         """
         Args:
             output_path (str): output image directory.
             dtype (str): to convert the image to save to this datatype.
             output_postfix (str): a string appended to all output file names.
             output_ext (str): output file extension name.
+            batch_transform (Callable): a callable that is used to transform the
+                ignite.engine.batch into the expected format to extract the meta_data dictionary.
             output_transform (Callable): a callable that is used to transform the
                 ignite.engine.output into the expected nifti image data.
                 The first dimension of this transform's output will be treated as the
@@ -40,6 +42,7 @@ def __init__(self, output_path='./', dtype='float32', output_postfix='seg', outp
         self.dtype = dtype
         self.output_postfix = output_postfix
         self.output_ext = output_ext
+        self.batch_transform = batch_transform
         self.output_transform = output_transform
         self.logger = None if name is None else logging.getLogger(name)
 
@@ -88,12 +91,12 @@ def _create_file_basename(postfix, input_file_name, folder_path, data_root_dir="
 
     def __call__(self, engine):
         """
-        This method assumes:
-        - 3rd output of engine.state.batch is a meta data dict, and have the keys:
+        This method assumes to have the keys:
            'filename_or_obj' -- for output file name creation
            and optionally 'original_affine', 'affine' for data orientation handling.
 
        output file datatype from `engine.state.output.dtype`.
        """
+        meta_data = self.batch_transform(engine.state.batch)
         meta_data = engine.state.batch[2]  # assuming 3rd output of input dataset is a meta data dict
         filenames = meta_data['filename_or_obj']
         original_affine = meta_data.get('original_affine', None)
         affine = meta_data.get('affine', None)
         engine_output = self.output_transform(engine.state.output)
         for batch_id, filename in enumerate(filenames):  # save a batch of files
             seg_output = engine_output[batch_id]
-            _affine = affine[batch_id]
-            _original_affine = original_affine[batch_id]
+            affine_ = affine[batch_id]
+            original_affine_ = original_affine[batch_id]
             if isinstance(seg_output, torch.Tensor):
                 seg_output = seg_output.detach().cpu().numpy()
             output_filename = self._create_file_basename(self.output_postfix, filename, self.output_path)
             output_filename = '{}{}'.format(output_filename, self.output_ext)
-            write_nifti(seg_output, _affine, output_filename, _original_affine, dtype=seg_output.dtype)
+            write_nifti(seg_output, affine_, output_filename, original_affine_, dtype=seg_output.dtype)
             self.logger.info('saved: {}'.format(output_filename))
diff --git a/monai/utils/sliding_window_inference.py b/monai/utils/sliding_window_inference.py
index 2efc7a7481..708a0f284e 100644
--- a/monai/utils/sliding_window_inference.py
+++ b/monai/utils/sliding_window_inference.py
@@ -10,7 +10,7 @@
 # limitations under the License.
 
 import torch
-
+from ignite.utils import convert_tensor
 from monai.transforms.transforms import ImageEndPadder
 from monai.transforms.transforms import ToTensor
 from monai.data.utils import dense_patch_slices
@@ -49,7 +49,7 @@ def sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device)
     # in case that image size is smaller than roi size
     image_size = tuple(max(image_size[i], roi_size[i]) for i in range(num_spatial_dims))
     inputs = ImageEndPadder(roi_size, 'constant')(inputs)  # in np array
-    inputs = ToTensor()(inputs)
+    inputs = convert_tensor(ToTensor()(inputs), device, False)
 
     # TODO: interval from user's specification
     scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims)
diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py
index e0d727c407..142d3e31e5 100644
--- a/tests/test_sliding_window_inference.py
+++ b/tests/test_sliding_window_inference.py
@@ -34,7 +34,7 @@ def test_sliding_window_default(self, image_shape, roi_shape, sw_batch_size):
         device = torch.device("cpu:0")
 
         def compute(data):
-            return data.to(device) + 1
+            return data + 1
 
         result = sliding_window_inference(inputs, roi_shape, sw_batch_size, compute, device)
         expected_val = np.ones(image_shape, dtype=np.float32) + 1

From c7b3168ee7db9e6b8f2a66a89f03c65cda05ec16 Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Tue, 10 Mar 2020 01:25:25 +0800
Subject: [PATCH 2/7] [DLMED] update StatsHandler and fix comments

---
 examples/unet_inference_3d_array.py     |  3 +-
 examples/unet_segmentation_3d_array.py  |  4 +--
 monai/handlers/stats_handler.py         | 28 +++++++++++-------
 monai/utils/sliding_window_inference.py |  3 +-
 tests/test_handler_stats.py             | 38 +++++++++++++++++++++----
 5 files changed, 55 insertions(+), 21 deletions(-)

diff --git a/examples/unet_inference_3d_array.py b/examples/unet_inference_3d_array.py
index a4d0b7ffec..07cf243b37 100644
--- a/examples/unet_inference_3d_array.py
+++ b/examples/unet_inference_3d_array.py
@@ -17,7 +17,6 @@
 import nibabel as nib
 import numpy as np
 import torch
-import torchvision.transforms as transforms
 from ignite.engine import Engine
 from torch.utils.data import DataLoader
 
@@ -27,6 +26,7 @@
 from monai import config
 from monai.handlers.checkpoint_loader import CheckpointLoader
 from monai.handlers.segmentation_saver import SegmentationSaver
+import monai.transforms.compose as transforms
 from monai.data.nifti_reader import NiftiDataset
 from monai.transforms import AddChannel, Rescale, ToTensor
 from monai.networks.nets.unet import UNet
@@ -80,6 +80,7 @@ def _sliding_window_processor(engine, batch):
 # for the array data format, assume the 3rd item of batch data is the meta_data
 SegmentationSaver(output_path='tempdir', output_ext='.nii.gz', output_postfix='seg',
                   batch_transform=lambda x: x[2]).attach(infer_engine)
+# the model was trained by the "unet_segmentation_3d_array" example
 CheckpointLoader(load_path='./runs/net_checkpoint_120.pth', load_dict={'net': net}).attach(infer_engine)
 
 loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
diff --git a/examples/unet_segmentation_3d_array.py b/examples/unet_segmentation_3d_array.py
index 0e68eb67c8..11a4565a3a 100644
--- a/examples/unet_segmentation_3d_array.py
+++ b/examples/unet_segmentation_3d_array.py
@@ -30,7 +30,7 @@
 
 import monai.transforms.compose as transforms
 from monai.data.nifti_reader import NiftiDataset
-from monai.transforms import (AddChannel, Rescale, ToTensor, UniformRandomPatch)
+from monai.transforms import AddChannel, Rescale, ToTensor, UniformRandomPatch
 from monai.handlers.stats_handler import StatsHandler
 from monai.handlers.mean_dice import MeanDice
 from monai.visualize import img2tensorboard
@@ -148,7 +148,7 @@ def log_training_loss(engine):
 
 # Add stats event handler to print validation stats via evaluator
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-val_stats_handler = StatsHandler()
+val_stats_handler = StatsHandler(output_transform=lambda output: (None, None))
 val_stats_handler.attach(evaluator)
 
 # Add early stopping handler to evaluator.
diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py
index b1ba3563d5..ba0f1633ec 100644
--- a/monai/handlers/stats_handler.py
+++ b/monai/handlers/stats_handler.py
@@ -19,26 +19,35 @@
 class StatsHandler(object):
     """StatsHandler defines a set of Ignite Event-handlers for all the log printing logics.
     It can be used for any Ignite Engine (trainer, validator and evaluator).
-    And it can support logging for epoch level and iteration level with pre-defined StatsLoggers.
-    By default, this class logs the dictionary of `engine.state.metrics`.
+    And it can support logging for epoch level and iteration level with pre-defined loggers.
+    By default:
+    (1) epoch_print_logger logs `engine.state.metrics`.
+    (2) iteration_print_logger logs the loss value; the expected output format is (y_pred, loss).
 
     """
 
     def __init__(self,
                  epoch_print_logger=None,
                  iteration_print_logger=None,
+                 batch_transform=lambda x: x,
+                 output_transform=lambda x: x,
                  name=None):
         """
         Args:
             epoch_print_logger (Callable): customized callable printer for epoch level logging.
-                must accept parameter "engine", use default printer if None.
+                must accept parameter "engine", use default printer if None.
             iteration_print_logger (Callable): customized callable printer for iteration level logging.
-                must accept parameter "engine", use default printer if None.
+                must accept parameter "engine", use default printer if None.
+            batch_transform (Callable): a callable that is used to transform the
+                ignite.engine.batch into the expected format to extract the input data.
+            output_transform (Callable): a callable that is used to transform the
+                ignite.engine.output into the expected format to extract the output data.
             name (str): identifier of logging.logger to use, defaulting to `engine.logger`.
 
         """
         self.epoch_print_logger = epoch_print_logger
         self.iteration_print_logger = iteration_print_logger
-
+        self.batch_transform = batch_transform
+        self.output_transform = output_transform
         self.logger = None if name is None else logging.getLogger(name)
 
@@ -125,8 +134,8 @@ def _default_iteration_print(self, engine: Engine):
             engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator.
 
""" - prints_dict = engine.state.metrics - if not prints_dict: + loss = self.output_transform(engine.state.output)[1] + if loss is None: return num_iterations = engine.state.epoch_length current_iteration = (engine.state.iteration - 1) % num_iterations + 1 @@ -138,9 +147,6 @@ def _default_iteration_print(self, engine: Engine): num_epochs, current_iteration, num_iterations) - - for name in sorted(prints_dict): - value = prints_dict[name] - out_str += KEY_VAL_FORMAT.format(name, value) + out_str += KEY_VAL_FORMAT.format('Loss', loss) self.logger.info(out_str) diff --git a/monai/utils/sliding_window_inference.py b/monai/utils/sliding_window_inference.py index 708a0f284e..0125a19d06 100644 --- a/monai/utils/sliding_window_inference.py +++ b/monai/utils/sliding_window_inference.py @@ -12,7 +12,6 @@ import torch from ignite.utils import convert_tensor from monai.transforms.transforms import ImageEndPadder -from monai.transforms.transforms import ToTensor from monai.data.utils import dense_patch_slices @@ -49,7 +48,7 @@ def sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, device) # in case that image size is smaller than roi size image_size = tuple(max(image_size[i], roi_size[i]) for i in range(num_spatial_dims)) inputs = ImageEndPadder(roi_size, 'constant')(inputs) # in np array - inputs = convert_tensor(ToTensor()(inputs), device, False) + inputs = convert_tensor(torch.from_numpy(inputs), device, False) # TODO: interval from user's specification scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims) diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index 5bbe17d1c2..58a62133d2 100644 --- a/tests/test_handler_stats.py +++ b/tests/test_handler_stats.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import torch
 import logging
 import re
 import unittest
@@ -29,12 +30,12 @@ def test_metrics_print(self):
 
         # set up engine
         def _train_func(engine, batch):
-            pass
+            return None, torch.tensor(0.0)
 
         engine = Engine(_train_func)
 
         # set up dummy metric
-        @engine.on(Events.ITERATION_COMPLETED)
+        @engine.on(Events.EPOCH_COMPLETED)
         def _update_metric(engine):
             current_metric = engine.state.metrics.get(key_to_print, 0.1)
             engine.state.metrics[key_to_print] = current_metric + 0.1
@@ -52,9 +53,36 @@ def _update_metric(engine):
         matched = []
         for idx, line in enumerate(output_str.split('\n')):
             if grep.match(line):
-                self.assertTrue(has_key_word.match(line))
-                matched.append(idx)
-        self.assertEqual(matched, [1, 2, 3, 5, 6, 7, 8, 10])
+                if idx in [5, 10]:
+                    self.assertTrue(has_key_word.match(line))
+
+    def test_loss_print(self):
+        log_stream = StringIO()
+        logging.basicConfig(stream=log_stream, level=logging.INFO)
+        key_to_handler = 'test_logging'
+        key_to_print = 'Loss'
+
+        # set up engine
+        def _train_func(engine, batch):
+            return None, torch.tensor(0.0)
+
+        engine = Engine(_train_func)
+
+        # set up testing handler
+        stats_handler = StatsHandler(name=key_to_handler)
+        stats_handler.attach(engine)
+
+        engine.run(range(3), max_epochs=2)
+
+        # check logging output
+        output_str = log_stream.getvalue()
+        grep = re.compile('.*{}.*'.format(key_to_handler))
+        has_key_word = re.compile('.*{}.*'.format(key_to_print))
+        matched = []
+        for idx, line in enumerate(output_str.split('\n')):
+            if grep.match(line):
+                if idx in [1, 2, 3, 6, 7, 8]:
+                    self.assertTrue(has_key_word.match(line))
 
 
 if __name__ == '__main__':

From df66cbaa58ae948c16a2eb9561af4636cfdee933 Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Tue, 10 Mar 2020 11:26:08 +0800
Subject: [PATCH 3/7] [DLMED] update according to comments

---
 examples/unet_segmentation_3d_array.py | 2 +-
 monai/handlers/segmentation_saver.py   | 2 +-
 monai/handlers/stats_handler.py        | 7 ++++---
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/examples/unet_segmentation_3d_array.py b/examples/unet_segmentation_3d_array.py
index 11a4565a3a..f431b47faf 100644
--- a/examples/unet_segmentation_3d_array.py
+++ b/examples/unet_segmentation_3d_array.py
@@ -148,7 +148,7 @@ def log_training_loss(engine):
 
 # Add stats event handler to print validation stats via evaluator
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-val_stats_handler = StatsHandler(output_transform=lambda output: (None, None))
+val_stats_handler = StatsHandler()
 val_stats_handler.attach(evaluator)
 
 # Add early stopping handler to evaluator.
diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py
index 65b18158e1..0983f4b135 100644
--- a/monai/handlers/segmentation_saver.py
+++ b/monai/handlers/segmentation_saver.py
@@ -95,9 +95,9 @@ def __call__(self, engine):
             'filename_or_obj' -- for output file name creation
             and optionally 'original_affine', 'affine' for data orientation handling.
 
-        output file datatype from `engine.state.output.dtype`.
+        output file datatype from `engine.state.output.dtype`.
+        And this method assumes self.batch_transform will extract meta data from the input batch.
""" meta_data = self.batch_transform(engine.state.batch) - meta_data = engine.state.batch[2] # assuming 3rd output of input dataset is a meta data dict filenames = meta_data['filename_or_obj'] original_affine = meta_data.get('original_affine', None) affine = meta_data.get('affine', None) diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index ba0f1633ec..3dba0fca22 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -10,7 +10,7 @@ # limitations under the License. import logging - +import torch from ignite.engine import Engine, Events KEY_VAL_FORMAT = '{}: {:.4f} ' @@ -128,14 +128,15 @@ def _default_epoch_print(self, engine: Engine): def _default_iteration_print(self, engine: Engine): """Execute iteration log operation based on Ignite engine.state data. - print the values from ignite state.logs dict. + Print the values from ignite state.logs dict. + Default behaivor is to print loss from output[1], skip if output[1] is not loss. Args: engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. """ loss = self.output_transform(engine.state.output)[1] - if loss is None: + if loss is None or (isinstance(loss, torch.Tensor) is True and len(loss.shape) > 0): return num_iterations = engine.state.epoch_length current_iteration = (engine.state.iteration - 1) % num_iterations + 1 From 2d13d1de3134211712437410e4e5352a936c1781 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 10 Mar 2020 13:17:54 +0800 Subject: [PATCH 4/7] [DLMED] fix integration test error --- tests/integration_sliding_window.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/integration_sliding_window.py b/tests/integration_sliding_window.py index 32abbfff5c..db10d7cc49 100644 --- a/tests/integration_sliding_window.py +++ b/tests/integration_sliding_window.py @@ -58,7 +58,8 @@ def _sliding_window_processor(_engine, batch): infer_engine = Engine(_sliding_window_processor) with tempfile.TemporaryDirectory() as temp_dir: - SegmentationSaver(output_path=temp_dir, output_ext='.nii.gz', output_postfix='seg').attach(infer_engine) + SegmentationSaver(output_path=temp_dir, output_ext='.nii.gz', output_postfix='seg', + batch_transform=lambda x: x[2]).attach(infer_engine) infer_engine.run(loader) From 476e5336120973f7720c10d9e533ad8008299ee3 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 10 Mar 2020 14:27:40 +0800 Subject: [PATCH 5/7] [DLMED] support both Tensor and Float loss value --- monai/handlers/stats_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index 3dba0fca22..69bdc9379d 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -148,6 +148,6 @@ def _default_iteration_print(self, engine: Engine): num_epochs, current_iteration, num_iterations) - out_str += KEY_VAL_FORMAT.format('Loss', loss) + out_str += KEY_VAL_FORMAT.format('Loss', loss.item() if isinstance(loss, torch.Tensor) else loss) self.logger.info(out_str) From 50dadecb097c21d453177fb991973b4c14e81614 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 10 Mar 2020 18:09:11 +0800 Subject: [PATCH 6/7] [DLMED] add dict based UNet inference example --- examples/unet_inference_3d_dict.py | 94 +++++++++++++++++++++++++++ examples/unet_segmentation_3d_dict.py | 4 +- 2 files changed, 96 insertions(+), 2 deletions(-) create mode 100644 examples/unet_inference_3d_dict.py diff --git a/examples/unet_inference_3d_dict.py b/examples/unet_inference_3d_dict.py 
new file mode 100644
index 0000000000..c0d27c3c20
--- /dev/null
+++ b/examples/unet_inference_3d_dict.py
@@ -0,0 +1,94 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import tempfile
+from glob import glob
+
+import nibabel as nib
+import numpy as np
+import torch
+from ignite.engine import Engine
+from torch.utils.data import DataLoader
+
+# assumes the framework is found here, change as necessary
+sys.path.append("..")
+
+import monai
+from monai.data.utils import list_data_collate
+from monai.utils.sliding_window_inference import sliding_window_inference
+from monai.data.synthetic import create_test_image_3d
+from monai.networks.utils import predict_segmentation
+from monai.networks.nets.unet import UNet
+from monai.transforms.composables import LoadNiftid, AsChannelFirstd
+import monai.transforms.compose as transforms
+from monai.handlers.segmentation_saver import SegmentationSaver
+from monai.handlers.checkpoint_loader import CheckpointLoader
+from monai import config
+
+config.print_config()
+
+tempdir = tempfile.mkdtemp()
+# tempdir = './temp'
+for i in range(50):
+    im, seg = create_test_image_3d(256, 256, 256, channel_dim=-1)
+
+    n = nib.Nifti1Image(im, np.eye(4))
+    nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))
+
+    n = nib.Nifti1Image(seg, np.eye(4))
+    nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))
+
+images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
+segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
+val_files = [{'img': img, 'seg': seg} for img, seg in zip(images, segs)]
+val_transforms = transforms.Compose([
+    LoadNiftid(keys=['img', 'seg']),
+    AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1)
+])
+val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
+
+device = torch.device("cuda:0")
+roi_size = (64, 64, 64)
+sw_batch_size = 4
+net = UNet(
+    dimensions=3,
+    in_channels=1,
+    num_classes=1,
+    channels=(16, 32, 64, 128, 256),
+    strides=(2, 2, 2, 2),
+    num_res_units=2,
+)
+net.to(device)
+
+
+def _sliding_window_processor(engine, batch):
+    net.eval()
+    with torch.no_grad():
+        seg_probs = sliding_window_inference(batch['img'], roi_size, sw_batch_size, lambda x: net(x)[0], device)
+        return predict_segmentation(seg_probs)
+
+
+infer_engine = Engine(_sliding_window_processor)
+
+# for the dict data format, the meta data is stored under flattened keys of the batch dict
+SegmentationSaver(output_path='tempdir', output_ext='.nii.gz', output_postfix='seg',
+                  batch_transform=lambda batch: {'filename_or_obj': batch['img.filename_or_obj'],
+                                                 'original_affine': batch['img.original_affine'],
+                                                 'affine': batch['img.affine'],
+                                                 }).attach(infer_engine)
+# the model was trained by the "unet_segmentation_3d_array" example
+CheckpointLoader(load_path='./runs/net_checkpoint_120.pth', load_dict={'net': net}).attach(infer_engine)
+
+val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate,
+                        pin_memory=torch.cuda.is_available())
+state = infer_engine.run(val_loader)
diff --git a/examples/unet_segmentation_3d_dict.py b/examples/unet_segmentation_3d_dict.py
index d7ea3795ea..5eb4dec27c 100644
--- a/examples/unet_segmentation_3d_dict.py
+++ b/examples/unet_segmentation_3d_dict.py
@@ -70,7 +70,7 @@
 
 # Define nifti dataset, dataloader.
 ds = monai.data.Dataset(data=train_files, transform=train_transforms)
-loader = DataLoader(ds, batch_size=2, num_workers=2, collate_fn=list_data_collate,
+loader = DataLoader(ds, batch_size=2, num_workers=4, collate_fn=list_data_collate,
                     pin_memory=torch.cuda.is_available())
 check_data = monai.utils.misc.first(loader)
 print(check_data['img'].shape, check_data['seg'].shape)
@@ -190,7 +190,7 @@
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
 train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
-train_loader = DataLoader(train_ds, batch_size=2, num_workers=8, collate_fn=list_data_collate,
+train_loader = DataLoader(train_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate,
                           pin_memory=torch.cuda.is_available())
 train_epochs = 30
 

From 2a85d9da1adaacdba6dc2529ee33ba512e949e4c Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Tue, 10 Mar 2020 11:46:03 +0000
Subject: [PATCH 7/7] update docstring

---
 examples/unet_inference_3d_array.py    |  1 +
 examples/unet_inference_3d_dict.py     |  1 +
 examples/unet_segmentation_3d_array.py |  2 +-
 examples/unet_segmentation_3d_dict.py  |  2 +-
 monai/handlers/segmentation_saver.py   | 14 +++++++++-----
 monai/handlers/stats_handler.py        |  7 +++----
 6 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/examples/unet_inference_3d_array.py b/examples/unet_inference_3d_array.py
index 07cf243b37..8fe417c7dd 100644
--- a/examples/unet_inference_3d_array.py
+++ b/examples/unet_inference_3d_array.py
@@ -38,6 +38,7 @@
 
 tempdir = tempfile.mkdtemp()
 # tempdir = './temp'
+print('generating synthetic data to {} (this may take a while)'.format(tempdir))
 for i in range(50):
     im, seg = create_test_image_3d(256, 256, 256)
 
diff --git a/examples/unet_inference_3d_dict.py b/examples/unet_inference_3d_dict.py
index c0d27c3c20..405b49aa8d 100644
--- a/examples/unet_inference_3d_dict.py
+++ b/examples/unet_inference_3d_dict.py
@@ -39,6 +39,7 @@
 
 tempdir = tempfile.mkdtemp()
 # tempdir = './temp'
+print('generating synthetic data to {} (this may take a while)'.format(tempdir))
 for i in range(50):
     im, seg = create_test_image_3d(256, 256, 256, channel_dim=-1)
 
diff --git a/examples/unet_segmentation_3d_array.py b/examples/unet_segmentation_3d_array.py
index f431b47faf..de5996470d 100644
--- a/examples/unet_segmentation_3d_array.py
+++ b/examples/unet_segmentation_3d_array.py
@@ -41,7 +41,7 @@
 
 # Create a temporary directory and 50 random image, mask pairs
 tempdir = tempfile.mkdtemp()
-
+print('generating synthetic data to {} (this may take a while)'.format(tempdir))
 for i in range(50):
     im, seg = create_test_image_3d(128, 128, 128)
 
diff --git a/examples/unet_segmentation_3d_dict.py b/examples/unet_segmentation_3d_dict.py
index 5eb4dec27c..640bed21c0 100644
--- a/examples/unet_segmentation_3d_dict.py
+++ b/examples/unet_segmentation_3d_dict.py
@@ -41,7 +41,7 @@
 
 # Create a temporary directory and 50 random image, mask pairs
 tempdir = tempfile.mkdtemp()
-
+print('generating synthetic data to {} (this may take a while)'.format(tempdir))
 for i in range(50):
     im, seg = create_test_image_3d(128, 128, 128, channel_dim=-1)
 
diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py
index 0983f4b135..e0fa50310a 100644
--- a/monai/handlers/segmentation_saver.py
+++ b/monai/handlers/segmentation_saver.py
@@ -91,16 +91,20 @@ def _create_file_basename(postfix, input_file_name, folder_path, data_root_dir="
 
     def __call__(self, engine):
         """
-        This method assumes to have the keys:
-            'filename_or_obj' -- for output file name creation
-            and optionally 'original_affine', 'affine' for data orientation handling.
-
-        output file datatype from `engine.state.output.dtype`.
-        And this method assumes self.batch_transform will extract meta data from the input batch.
+        This method assumes self.batch_transform will extract the meta data from the input batch.
+        The meta data should have the following keys:
+
+        - ``'filename_or_obj'`` -- for output file name creation
+        - ``'original_affine'`` (optional) for data orientation handling
+        - ``'affine'`` (optional) for the data output affine.
+
+        The output file datatype is determined from ``engine.state.output.dtype``.
         """
         meta_data = self.batch_transform(engine.state.batch)
         filenames = meta_data['filename_or_obj']
         original_affine = meta_data.get('original_affine', None)
         affine = meta_data.get('affine', None)
+
         engine_output = self.output_transform(engine.state.output)
         for batch_id, filename in enumerate(filenames):  # save a batch of files
             seg_output = engine_output[batch_id]
diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py
index 69bdc9379d..47709543fe 100644
--- a/monai/handlers/stats_handler.py
+++ b/monai/handlers/stats_handler.py
@@ -50,7 +50,6 @@
         self.output_transform = output_transform
         self.logger = None if name is None else logging.getLogger(name)
 
-
     def attach(self, engine: Engine):
         """Register a set of Ignite Event-Handlers to a specified Ignite engine.
 
@@ -136,8 +135,8 @@ def _default_iteration_print(self, engine: Engine):
 
         """
         loss = self.output_transform(engine.state.output)[1]
-        if loss is None or (isinstance(loss, torch.Tensor) is True and len(loss.shape) > 0):
-            return
+        if loss is None or (torch.is_tensor(loss) and len(loss.shape) > 0):
+            return  # not printing multi-dimensional output
         num_iterations = engine.state.epoch_length
         current_iteration = (engine.state.iteration - 1) % num_iterations + 1
         current_epoch = engine.state.epoch
@@ -148,6 +147,6 @@
             num_epochs,
             current_iteration,
             num_iterations)
-        out_str += KEY_VAL_FORMAT.format('Loss', loss.item() if isinstance(loss, torch.Tensor) else loss)
+        out_str += KEY_VAL_FORMAT.format('Loss', loss.item() if torch.is_tensor(loss) else loss)
 
         self.logger.info(out_str)
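
Note: taken together, the seven patches replace the handlers' hard-coded assumptions about the engine data with two callable hooks, batch_transform and output_transform. The sketch below is not part of the patches; the process function, output path and all values are illustrative placeholders. It only shows how the hooks introduced by this series are expected to be wired up for the array and the dict data formats:

    import torch
    from ignite.engine import Engine

    from monai.handlers.segmentation_saver import SegmentationSaver
    from monai.handlers.stats_handler import StatsHandler

    def _process(engine, batch):
        # placeholder process function returning (y_pred, loss),
        # the default output format expected by StatsHandler
        return torch.zeros(1), torch.tensor(0.5)

    engine = Engine(_process)

    # the default iteration logger prints the scalar loss from output[1];
    # pass output_transform to adapt any other engine output format
    StatsHandler(name='train_log').attach(engine)

    # array format: the meta data dict is the 3rd item of the batch
    SegmentationSaver(output_path='./output', output_ext='.nii.gz', output_postfix='seg',
                      batch_transform=lambda batch: batch[2]).attach(engine)

    # dict format: the meta data fields live under flattened keys of the batch dict
    SegmentationSaver(output_path='./output', output_ext='.nii.gz', output_postfix='seg',
                      batch_transform=lambda batch: {
                          'filename_or_obj': batch['img.filename_or_obj'],
                          'original_affine': batch['img.original_affine'],
                          'affine': batch['img.affine'],
                      }).attach(engine)

Both savers are attached here only to contrast the two batch_transform styles; a real script would attach whichever one matches its dataset format.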