From 3a8f99c03598307ecc7b34b2080e38add6227ba8 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 12 Mar 2020 09:26:24 +0800 Subject: [PATCH 1/3] 160 develop TensorBoard event handler (#161) Co-authored-by: Kevin Lu Co-authored-by: Wenqi Li --- examples/densenet_classification_3d.py | 4 +- examples/unet_segmentation_3d.ipynb | 4 +- examples/unet_segmentation_3d_array.py | 80 ++++---- examples/unet_segmentation_3d_dict.py | 4 +- monai/handlers/segmentation_saver.py | 5 +- monai/handlers/stats_handler.py | 77 ++++++-- monai/handlers/tensorboard_handlers.py | 249 +++++++++++++++++++++++++ monai/handlers/utils.py | 4 + monai/networks/utils.py | 4 +- monai/utils/misc.py | 15 ++ monai/visualize/img2tensorboard.py | 10 +- tests/integration_sliding_window.py | 10 +- tests/integration_unet2d.py | 9 +- tests/test_handler_stats.py | 38 +++- tests/test_handler_tb_image.py | 60 ++++++ tests/test_handler_tb_stats.py | 81 ++++++++ tests/test_unet.py | 6 +- 17 files changed, 567 insertions(+), 93 deletions(-) create mode 100644 monai/handlers/tensorboard_handlers.py create mode 100644 tests/test_handler_tb_image.py create mode 100644 tests/test_handler_tb_stats.py diff --git a/examples/densenet_classification_3d.py b/examples/densenet_classification_3d.py index 07ac3ffe04..753097a2a9 100644 --- a/examples/densenet_classification_3d.py +++ b/examples/densenet_classification_3d.py @@ -91,7 +91,7 @@ trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) -train_stats_handler = StatsHandler() +train_stats_handler = StatsHandler(output_transform=lambda x: x[3]) train_stats_handler.attach(trainer) @trainer.on(Events.EPOCH_COMPLETED) @@ -108,7 +108,7 @@ def log_training_loss(engine): # Add stats event handler to print validation stats via evaluator logging.basicConfig(stream=sys.stdout, level=logging.INFO) -val_stats_handler = StatsHandler() +val_stats_handler = StatsHandler(output_transform=lambda x: None) val_stats_handler.attach(evaluator) # Add early stopping handler to evaluator. 
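Note on the handler change above: `StatsHandler` now takes an explicit `output_transform` to pick what to print from `engine.state.output`, instead of assuming a `(y_pred, loss)` tuple. A minimal sketch of the reworked usage, mirroring the tests added later in this patch (the dummy engine and its scalar output are illustrative assumptions, not part of the diff):

```python
import logging
import sys

import torch
from ignite.engine import Engine

from monai.handlers.stats_handler import StatsHandler

logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# a dummy training step; a real engine would run the network and return its output
def _train_step(engine, batch):
    return torch.tensor(0.0)  # scalar loss

trainer = Engine(_train_step)

# output_transform selects what to print from engine.state.output:
# a scalar prints under tag_name, a {key: scalar} dict prints each pair,
# and `lambda x: None` disables per-iteration printing entirely
StatsHandler(tag_name='Loss', output_transform=lambda x: x).attach(trainer)

trainer.run(range(3), max_epochs=2)
```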
diff --git a/examples/unet_segmentation_3d.ipynb b/examples/unet_segmentation_3d.ipynb
index bfba0a70bf..0d49742f10 100644
--- a/examples/unet_segmentation_3d.ipynb
+++ b/examples/unet_segmentation_3d.ipynb
@@ -197,7 +197,7 @@
     "trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,\n",
     "                          handler=checkpoint_handler,\n",
     "                          to_save={'net': net, 'opt': opt})\n",
-    "train_stats_handler = StatsHandler()\n",
+    "train_stats_handler = StatsHandler(output_transform=lambda x: x[1])\n",
     "train_stats_handler.attach(trainer)\n",
     "\n",
     "writer = SummaryWriter()\n",
@@ -260,7 +260,7 @@
     "\n",
     "# Add stats event handler to print validation stats via evaluator\n",
     "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n",
-    "val_stats_handler = StatsHandler()\n",
+    "val_stats_handler = StatsHandler(output_transform=lambda x: None)\n",
     "val_stats_handler.attach(evaluator)\n",
     "\n",
     "# Add early stopping handler to evaluator.\n",
diff --git a/examples/unet_segmentation_3d_array.py b/examples/unet_segmentation_3d_array.py
index de5996470d..3b0d880f10 100644
--- a/examples/unet_segmentation_3d_array.py
+++ b/examples/unet_segmentation_3d_array.py
@@ -18,7 +18,6 @@
 import nibabel as nib
 import numpy as np
 import torch
-from torch.utils.tensorboard import SummaryWriter
 from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
 from ignite.handlers import ModelCheckpoint, EarlyStopping
 from torch.utils.data import DataLoader
@@ -32,8 +31,8 @@
 from monai.data.nifti_reader import NiftiDataset
 from monai.transforms import AddChannel, Rescale, ToTensor, UniformRandomPatch
 from monai.handlers.stats_handler import StatsHandler
+from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler, TensorBoardImageHandler
 from monai.handlers.mean_dice import MeanDice
-from monai.visualize import img2tensorboard
 from monai.data.synthetic import create_test_image_3d
 from monai.handlers.utils import stopping_fn_from_metric
@@ -88,70 +87,71 @@
 loss = monai.losses.DiceLoss(do_sigmoid=True)
 opt = torch.optim.Adam(net.parameters(), lr)
 
+# Since the network outputs logits and segmentation, we need a custom loss function.
 def _loss_fn(i, j):
     return loss(i[0], j)
+
 # Create trainer
-device = torch.device("cuda:0")
+device = torch.device("cpu:0")
 trainer = create_supervised_trainer(net, opt, _loss_fn, device, False,
-                                    output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y])
+                                    output_transform=lambda x, y, y_pred, loss: [y_pred[1], loss.item(), y])
 
 # adding checkpoint handler to save models (network params and optimizer stats) during training
 checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
 trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                           handler=checkpoint_handler,
                           to_save={'net': net, 'opt': opt})
-train_stats_handler = StatsHandler()
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+
+# print training loss to the command line
+train_stats_handler = StatsHandler(output_transform=lambda x: x[1])
 train_stats_handler.attach(trainer)
 
+# record training loss to TensorBoard at every iteration
+train_tensorboard_stats_handler = TensorBoardStatsHandler(
+    output_transform=lambda x: {'training_dice_loss': x[1]},  # plot under tag name training_dice_loss
+    global_epoch_transform=lambda x: trainer.state.epoch)
+train_tensorboard_stats_handler.attach(trainer)
+
+
 @trainer.on(Events.EPOCH_COMPLETED)
 def log_training_loss(engine):
-    # log loss to tensorboard with second item of engine.state.output, loss.item() from output_transform
-    writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch)
-
-    # tensor of ones to use where for converting labels to zero and ones
-    ones = torch.ones(engine.state.batch[1][0].shape, dtype=torch.int32)
-    first_output_tensor = engine.state.output[0][1][0].detach().cpu()
-    # log model output to tensorboard, as three dimensional tensor with no channels dimension
-    img2tensorboard.add_animated_gif_no_channels(writer, "first_output_final_batch", first_output_tensor, 64,
-                                                 255, engine.state.epoch)
-    # get label tensor and convert to single class
-    first_label_tensor = torch.where(engine.state.batch[1][0] > 0, ones, engine.state.batch[1][0])
-    # log label tensor to tensorboard, there is a channel dimension when getting label from batch
-    img2tensorboard.add_animated_gif(writer, "first_label_final_batch", first_label_tensor, 64,
-                                     255, engine.state.epoch)
-    second_output_tensor = engine.state.output[0][1][1].detach().cpu()
-    img2tensorboard.add_animated_gif_no_channels(writer, "second_output_final_batch", second_output_tensor, 64,
-                                                 255, engine.state.epoch)
-    second_label_tensor = torch.where(engine.state.batch[1][1] > 0, ones, engine.state.batch[1][1])
-    img2tensorboard.add_animated_gif(writer, "second_label_final_batch", second_label_tensor, 64,
-                                     255, engine.state.epoch)
-    third_output_tensor = engine.state.output[0][1][2].detach().cpu()
-    img2tensorboard.add_animated_gif_no_channels(writer, "third_output_final_batch", third_output_tensor, 64,
-                                                 255, engine.state.epoch)
-    third_label_tensor = torch.where(engine.state.batch[1][2] > 0, ones, engine.state.batch[1][2])
-    img2tensorboard.add_animated_gif(writer, "third_label_final_batch", third_label_tensor, 64,
-                                     255, engine.state.epoch)
     engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1])
 
-writer = SummaryWriter()
 # Set parameters for validation
 validation_every_n_epochs = 1
 metric_name = 'Mean_Dice'
 # add evaluation metric to the evaluator engine
-val_metrics = {metric_name: MeanDice(add_sigmoid=True)}
-evaluator = create_supervised_evaluator(net, val_metrics, device, True,
-                                        output_transform=lambda x, y, y_pred: (y_pred[0], y))
+val_metrics = {metric_name: MeanDice(
+    add_sigmoid=True, to_onehot_y=False, output_transform=lambda output: (output[0][0], output[1]))
+}
+evaluator = create_supervised_evaluator(net, val_metrics, device, True)
 
 # Add stats event handler to print validation stats via evaluator
-logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-val_stats_handler = StatsHandler()
+val_stats_handler = StatsHandler(
+    output_transform=lambda x: None,  # disable per iteration output
+    global_epoch_transform=lambda x: trainer.state.epoch)
 val_stats_handler.attach(evaluator)
 
-# Add early stopping handler to evaluator.
+# add handler to record metrics to TensorBoard at every epoch
+val_tensorboard_stats_handler = TensorBoardStatsHandler(
+    output_transform=lambda x: None,  # no iteration plot
+    global_epoch_transform=lambda x: trainer.state.epoch)  # use epoch number from trainer
+val_tensorboard_stats_handler.attach(evaluator)
+
+# add handler to draw several images and the corresponding labels and model outputs
+# here we draw the first 3 images (first channel only) in GIF format along the depth axis
+val_tensorboard_image_handler = TensorBoardImageHandler(
+    batch_transform=lambda batch: (batch[0], batch[1]),
+    output_transform=lambda output: output[0][1],
+    global_iter_transform=lambda x: trainer.state.epoch
+)
+evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler)
+
+# Add early stopping handler to evaluator
 early_stopper = EarlyStopping(patience=4,
                               score_function=stopping_fn_from_metric(metric_name),
                               trainer=trainer)
@@ -166,10 +166,6 @@ def log_training_loss(engine):
 def run_validation(engine):
     evaluator.run(val_loader)
 
-@evaluator.on(Events.EPOCH_COMPLETED)
-def log_metrics_to_tensorboard(engine):
-    for name, value in engine.state.metrics.items():
-        writer.add_scalar('Metrics/{name}', value, trainer.state.epoch)
 
 # create a training data loader
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
diff --git a/examples/unet_segmentation_3d_dict.py b/examples/unet_segmentation_3d_dict.py
index 640bed21c0..0e1e78811b 100644
--- a/examples/unet_segmentation_3d_dict.py
+++ b/examples/unet_segmentation_3d_dict.py
@@ -111,7 +111,7 @@ def prepare_batch(batch, device=None, non_blocking=False):
 trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                           handler=checkpoint_handler,
                           to_save={'net': net, 'opt': opt})
-train_stats_handler = StatsHandler()
+train_stats_handler = StatsHandler(output_transform=lambda x: x[1])
 train_stats_handler.attach(trainer)
 
 
@@ -160,7 +160,7 @@ def log_training_loss(engine):
 
 # Add stats event handler to print validation stats via evaluator
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-val_stats_handler = StatsHandler()
+val_stats_handler = StatsHandler(output_transform=lambda x: None)
 val_stats_handler.attach(evaluator)
 
 # Add early stopping handler to evaluator.
diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py
index e0fa50310a..1f3fe2615d 100644
--- a/monai/handlers/segmentation_saver.py
+++ b/monai/handlers/segmentation_saver.py
@@ -10,7 +10,7 @@
 # limitations under the License.
 import os
-
+import numpy as np
 import torch
 
 from ignite.engine import Events
@@ -114,5 +114,6 @@ def __call__(self, engine):
             seg_output = seg_output.detach().cpu().numpy()
             output_filename = self._create_file_basename(self.output_postfix, filename, self.output_path)
             output_filename = '{}{}'.format(output_filename, self.output_ext)
-            write_nifti(seg_output, affine_, output_filename, original_affine_, dtype=seg_output.dtype)
+            # change output to "channel last" format and write to a nifti format file
+            write_nifti(np.moveaxis(seg_output, 0, -1), affine_, output_filename, original_affine_, dtype=seg_output.dtype)
             self.logger.info('saved: {}'.format(output_filename))
diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py
index 47709543fe..9d2e518919 100644
--- a/monai/handlers/stats_handler.py
+++ b/monai/handlers/stats_handler.py
@@ -9,47 +9,63 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 import logging
 
 import torch
 from ignite.engine import Engine, Events
+from monai.utils.misc import is_scalar
 
-KEY_VAL_FORMAT = '{}: {:.4f} '
+DEFAULT_KEY_VAL_FORMAT = '{}: {:.4f} '
+DEFAULT_TAG = 'Loss'
 
 
 class StatsHandler(object):
     """StatsHandler defines a set of Ignite Event-handlers for all the log printing logics.
     It can be used with any Ignite Engine (trainer, validator or evaluator),
     and supports logging at both epoch level and iteration level with pre-defined loggers.
-    By default:
-    (1) epoch_print_logger logs `engine.state.metrics`.
-    (2) iteration_print_logger logs loss value, expected output format is (y_pred, loss).
+
+    Default behaviors:
+        - When EPOCH_COMPLETED, logs ``engine.state.metrics`` using ``self.logger``.
+        - When ITERATION_COMPLETED, logs
+          ``self.output_transform(engine.state.output)`` using ``self.logger``.
     """
 
     def __init__(self,
                  epoch_print_logger=None,
                  iteration_print_logger=None,
-                 batch_transform=lambda x: x,
                  output_transform=lambda x: x,
-                 name=None):
+                 global_epoch_transform=lambda x: x,
+                 name=None,
+                 tag_name=DEFAULT_TAG,
+                 key_var_format=DEFAULT_KEY_VAL_FORMAT):
        """
        Args:
            epoch_print_logger (Callable): customized callable printer for epoch level logging.
                must accept parameter "engine", use default printer if None.
            iteration_print_logger (Callable): customized callable printer for iteration level logging.
                must accept parameter "engine", use default printer if None.
-            batch_transform (Callable): a callable that is used to transform the
-                ignite.engine.batch into expected format to extract input data.
            output_transform (Callable): a callable that is used to transform the
-                ignite.engine.output into expected format to extract several output data.
-            name (str): identifier of logging.logger to use, defaulting to `engine.logger`.
+                ``ignite.engine.output`` into a scalar to print, or a dictionary of {key: scalar}.
+                In the latter case, the output string will be formatted as key: value pairs.
+                By default this value logging happens when every iteration completes.
+            global_epoch_transform (Callable): a callable that is used to customize the global epoch number.
+                For example, in evaluation, the evaluator engine might want to print an epoch number
+                synced with the trainer engine.
+            name (str): identifier of logging.logger to use, defaulting to ``engine.logger``.
+            tag_name (string): when the iteration output is a scalar, tag_name is used to print
+                tag_name: scalar_value to the logger. Defaults to ``'Loss'``.
+            key_var_format (string): a formatting string to control the output string format of key: value.
""" self.epoch_print_logger = epoch_print_logger self.iteration_print_logger = iteration_print_logger - self.batch_transform = batch_transform self.output_transform = output_transform + self.global_epoch_transform = global_epoch_transform self.logger = None if name is None else logging.getLogger(name) + self.tag_name = tag_name + self.key_var_format = key_var_format + def attach(self, engine: Engine): """Register a set of Ignite Event-Handlers to a specified Ignite engine. @@ -116,12 +132,12 @@ def _default_epoch_print(self, engine: Engine): prints_dict = engine.state.metrics if not prints_dict: return - current_epoch = engine.state.epoch + current_epoch = self.global_epoch_transform(engine.state.epoch) out_str = "Epoch[{}] Metrics -- ".format(current_epoch) for name in sorted(prints_dict): value = prints_dict[name] - out_str += KEY_VAL_FORMAT.format(name, value) + out_str += self.key_var_format.format(name, value) self.logger.info(out_str) @@ -134,19 +150,42 @@ def _default_iteration_print(self, engine: Engine): engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. """ - loss = self.output_transform(engine.state.output)[1] - if loss is None or (torch.is_tensor(loss) and len(loss.shape) > 0): - return # not printing multi dimensional output + loss = self.output_transform(engine.state.output) + if loss is None: + return # no printing if the output is empty + + out_str = '' + if isinstance(loss, dict): # print dictionary items + for name in sorted(loss): + value = loss[name] + if not is_scalar(value): + warnings.warn('ignoring non-scalar output in StatsHandler,' + ' make sure `output_transform(engine.state.output)` returns' + ' a scalar or dictionary of key and scalar pairs to avoid this warning.' + ' {}:{}'.format(name, type(value))) + continue # not printing multi dimensional output + out_str += self.key_var_format.format(name, value.item() if torch.is_tensor(value) else value) + else: + if is_scalar(loss): # not printing multi dimensional output + out_str += self.key_var_format.format(self.tag_name, loss.item() if torch.is_tensor(loss) else loss) + else: + warnings.warn('ignoring non-scalar output in StatsHandler,' + ' make sure `output_transform(engine.state.output)` returns' + ' a scalar or a dictionary of key and scalar pairs to avoid this warning.' + ' {}'.format(type(loss))) + + if not out_str: + return # no value to print + num_iterations = engine.state.epoch_length current_iteration = (engine.state.iteration - 1) % num_iterations + 1 current_epoch = engine.state.epoch num_epochs = engine.state.max_epochs - out_str = "Epoch: {}/{}, Iter: {}/{} -- ".format( + base_str = "Epoch: {}/{}, Iter: {}/{} --".format( current_epoch, num_epochs, current_iteration, num_iterations) - out_str += KEY_VAL_FORMAT.format('Loss', loss.item() if torch.is_tensor(loss) else loss) - self.logger.info(out_str) + self.logger.info(' '.join([base_str, out_str])) diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py new file mode 100644 index 0000000000..2d5116bd59 --- /dev/null +++ b/monai/handlers/tensorboard_handlers.py @@ -0,0 +1,249 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import warnings
+import torch
+from torch.utils.tensorboard import SummaryWriter
+from ignite.engine import Engine, Events
+from monai.visualize import img2tensorboard
+from monai.utils.misc import is_scalar
+from monai.transforms.utils import rescale_array
+
+
+class TensorBoardStatsHandler(object):
+    """TensorBoardStatsHandler defines a set of Ignite Event-handlers for all the TensorBoard logics.
+    It can be used with any Ignite Engine (trainer, validator or evaluator),
+    and supports both epoch level and iteration level writing with pre-defined TensorBoard event writers.
+    The expected data sources are ignite ``engine.state.output`` and ``engine.state.metrics``.
+
+    Default behaviors:
+        - When EPOCH_COMPLETED, write each dictionary item in
+          ``engine.state.metrics`` to TensorBoard.
+        - When ITERATION_COMPLETED, write each dictionary item in
+          ``self.output_transform(engine.state.output)`` to TensorBoard.
+    """
+
+    def __init__(self,
+                 summary_writer=None,
+                 epoch_event_writer=None,
+                 iteration_event_writer=None,
+                 output_transform=lambda x: {'Loss': x},
+                 global_epoch_transform=lambda x: x):
+        """
+        Args:
+            summary_writer (SummaryWriter): user can specify a TensorBoard SummaryWriter,
+                default is to create a new writer.
+            epoch_event_writer (Callable): customized callable TensorBoard writer for epoch level.
+                must accept parameters "engine" and "summary_writer", use default event writer if None.
+            iteration_event_writer (Callable): customized callable TensorBoard writer for iteration level.
+                must accept parameters "engine" and "summary_writer", use default event writer if None.
+            output_transform (Callable): a callable that is used to transform the
+                ``ignite.engine.output`` into a dictionary of (tag_name: scalar) pairs to be plotted onto TensorBoard.
+                By default this scalar plotting happens when every iteration completes.
+            global_epoch_transform (Callable): a callable that is used to customize the global epoch number.
+                For example, in evaluation, the evaluator engine might want to use the trainer engine's
+                epoch number when plotting epoch vs metric curves.
+        """
+        self._writer = SummaryWriter() if summary_writer is None else summary_writer
+        self.epoch_event_writer = epoch_event_writer
+        self.iteration_event_writer = iteration_event_writer
+        self.output_transform = output_transform
+        self.global_epoch_transform = global_epoch_transform
+
+    def attach(self, engine: Engine):
+        """Register a set of Ignite Event-Handlers to a specified Ignite engine.
+
+        Args:
+            engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator.
+
+        """
+        if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
+            engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
+        if not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
+            engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed)
+
+    def epoch_completed(self, engine: Engine):
+        """Handler for train or validation/evaluation epoch completed Event.
+        Write epoch level events, default values are from the ignite ``state.metrics`` dict.
+
+        Args:
+            engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator.
+
+        """
+        if self.epoch_event_writer is not None:
+            self.epoch_event_writer(engine, self._writer)
+        else:
+            self._default_epoch_writer(engine, self._writer)
+
+    def iteration_completed(self, engine: Engine):
+        """Handler for train or validation/evaluation iteration completed Event.
+        Write iteration level events, default values are from ``self.output_transform(engine.state.output)``.
+
+        Args:
+            engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator.
+
+        """
+        if self.iteration_event_writer is not None:
+            self.iteration_event_writer(engine, self._writer)
+        else:
+            self._default_iteration_writer(engine, self._writer)
+
+    def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter):
+        """Execute epoch level event write operation based on Ignite engine.state data.
+        Default is to write the values from the ignite ``state.metrics`` dict.
+
+        Args:
+            engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator.
+            writer (SummaryWriter): TensorBoard writer, created in TensorBoardHandler.
+
+        """
+        current_epoch = self.global_epoch_transform(engine.state.epoch)
+        summary_dict = engine.state.metrics
+        for name, value in summary_dict.items():
+            writer.add_scalar(name, value, current_epoch)
+        writer.flush()
+
+    def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter):
+        """Execute iteration level event write operation based on Ignite engine.state data.
+        Default is to write the loss value of the current iteration.
+
+        Args:
+            engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator.
+            writer (SummaryWriter): TensorBoard writer, created in TensorBoardHandler.
+
+        """
+        loss_dict = self.output_transform(engine.state.output)
+        if loss_dict is None:
+            return  # do nothing if output is empty
+        if not isinstance(loss_dict, dict):
+            raise ValueError('TensorBoardStatsHandler requires'
+                             ' output_transform(engine.state.output) to return a dictionary'
+                             ' of key and scalar pairs to plot,'
+                             ' got {}.'.format(type(loss_dict)))
+        for name, value in loss_dict.items():
+            if not is_scalar(value):
+                warnings.warn('ignoring non-scalar output in tensorboard curve plotting,'
+                              ' make sure `output_transform(engine.state.output)` returns'
+                              ' a dictionary of key and scalar pairs to avoid this warning.'
+                              ' Got {}:{}'.format(name, type(value)))
+                continue
+            plot_value = value.item() if torch.is_tensor(value) else value
+            writer.add_scalar(name, plot_value, engine.state.iteration)
+        writer.flush()
+
+
+class TensorBoardImageHandler(object):
+    """TensorBoardImageHandler is an ignite Event handler that can visualise images, labels and outputs as 2D/3D images.
+    A 2D output (shape [Batch, channel, H, W]) will be shown as a simple image using the first element in the batch.
+    For a 3D to ND output (shape [Batch, channel, H, W, D]),
+    the last three dimensions will be shown as a GIF along the last axis (typically Depth).
+
+    It can be used with any Ignite Engine (trainer, validator or evaluator).
+    Users can easily add it to an engine for any expected Event, for example ``EPOCH_COMPLETED`` or
+    ``ITERATION_COMPLETED``. The expected data sources are ignite's ``engine.state.batch`` and ``engine.state.output``.
+
+    Default behavior:
+        - Show y_pred as images (GIF for 3D) on TensorBoard when the Event is triggered;
+          use ``batch_transform`` and ``output_transform`` to specify
+          how many images to show and which channel to show.
+        - Expects ``batch_transform(engine.state.batch)`` to return data in
+          format: (image[N, channel, ...], label[N, channel, ...]).
+        - Expects ``output_transform(engine.state.output)`` to return a torch
+          tensor of predictions in format y_pred[N, channel, ...].
+    """
+
+    def __init__(self,
+                 summary_writer=None,
+                 batch_transform=lambda x: x,
+                 output_transform=lambda x: x,
+                 global_iter_transform=lambda x: x,
+                 max_channels=1,
+                 max_frames=64):
+        """
+        Args:
+            summary_writer (SummaryWriter): user can specify a TensorBoard SummaryWriter,
+                default is to create a new writer.
+            batch_transform (Callable): a callable that is used to transform the
+                ``ignite.engine.batch`` into the expected format to extract the image and label data.
+            output_transform (Callable): a callable that is used to transform the
+                ``ignite.engine.output`` into the expected format to extract the output data.
+            global_iter_transform (Callable): a callable that is used to customize the global step number for TensorBoard.
+                For example, in evaluation, the evaluator engine needs to know the current epoch from the trainer.
+            max_channels (int): number of channels to plot.
+            max_frames (int): number of frames for a 2D-t plot.
+
+        """
+        self._writer = SummaryWriter() if summary_writer is None else summary_writer
+        self.batch_transform = batch_transform
+        self.output_transform = output_transform
+        self.global_iter_transform = global_iter_transform
+
+        self.max_frames = max_frames
+        self.max_channels = max_channels
+
+    def __call__(self, engine):
+        step = self.global_iter_transform(engine.state.iteration)
+
+        show_images = self.batch_transform(engine.state.batch)[0]
+        if torch.is_tensor(show_images):
+            show_images = show_images.detach().cpu().numpy()
+        if show_images is not None:
+            if not isinstance(show_images, np.ndarray):
+                raise ValueError('batch_transform(engine.state.batch)[0] must be an ndarray or tensor.')
+            self._add_2_or_3_d(show_images, step, 'input_0')
+
+        show_labels = self.batch_transform(engine.state.batch)[1]
+        if torch.is_tensor(show_labels):
+            show_labels = show_labels.detach().cpu().numpy()
+        if show_labels is not None:
+            if not isinstance(show_labels, np.ndarray):
+                raise ValueError('batch_transform(engine.state.batch)[1] must be an ndarray or tensor.')
+            self._add_2_or_3_d(show_labels, step, 'input_1')
+
+        show_outputs = self.output_transform(engine.state.output)
+        if torch.is_tensor(show_outputs):
+            show_outputs = show_outputs.detach().cpu().numpy()
+        if show_outputs is not None:
+            if not isinstance(show_outputs, np.ndarray):
+                raise ValueError('output_transform(engine.state.output) must be an ndarray or tensor.')
+            self._add_2_or_3_d(show_outputs, step, 'output')
+
+        self._writer.flush()
+
+    def _add_2_or_3_d(self, data, step, tag='output'):
+        d = data[0]  # show only the first element in the batch
+
+        if d.ndim == 2:
+            d = rescale_array(d, 0, 1)
+            dataformats = 'HW'
+            self._writer.add_image('{}_{}'.format(tag, dataformats), d, step, dataformats=dataformats)
+            return
+
+        if d.ndim == 3:
+            if d.shape[0] == 3 and self.max_channels == 3:  # RGB image
+ dataformats = 'CHW' + self._writer.add_image('{}_{}'.format(tag, dataformats), d, step, dataformats='CHW') + return + for j, d2 in enumerate(d[:self.max_channels]): + d2 = rescale_array(d2, 0, 1) + dataformats = 'HW' + self._writer.add_image('{}_{}_{}'.format(tag, dataformats, j), d2, step, dataformats=dataformats) + return + + if d.ndim >= 4: + spatial = d.shape[-3:] + for j, d3 in enumerate(d.reshape([-1] + list(spatial))[:self.max_channels]): + d3 = rescale_array(d3, 0, 255) + img2tensorboard.add_animated_gif( + self._writer, '{}_HWD_{}'.format(tag, j), d3[None], self.max_frames, 1.0, step) + return diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 377d4d0073..1cd849d18e 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -12,13 +12,17 @@ def stopping_fn_from_metric(metric_name): """Returns a stopping function for ignite.handlers.EarlyStopping using the given metric name.""" + def stopping_fn(engine): return engine.state.metrics[metric_name] + return stopping_fn def stopping_fn_from_loss(): """Returns a stopping function for ignite.handlers.EarlyStopping using the loss value.""" + def stopping_fn(engine): return -engine.state.output + return stopping_fn diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 5a22884846..bca9922374 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -52,6 +52,6 @@ def predict_segmentation(logits): """ # generate prediction outputs, logits has shape BCHW[D] if logits.shape[1] == 1: - return (logits[:, 0] >= 0).int() # for binary segmentation threshold on channel 0 + return (logits >= 0).int() # for binary segmentation threshold on channel 0 else: - return logits.max(1)[1] # take the index of the max value along dimension 1 + return logits.argmax(1).unsqueeze(1) # take the index of the max value along dimension 1 diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 775b8f2ebd..261e521adb 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -11,6 +11,9 @@ import itertools +import numpy as np +import torch + def zip_with(op, *vals, mapfunc=map): """ @@ -40,3 +43,15 @@ def ensure_tuple(vals): vals = (vals,) return tuple(vals) + + +def is_scalar_tensor(val): + if torch.is_tensor(val) and val.ndim == 0: + return True + return False + + +def is_scalar(val): + if torch.is_tensor(val) and val.ndim == 0: + return True + return np.isscalar(val) diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py index 82211a8ecd..8fce996685 100644 --- a/monai/visualize/img2tensorboard.py +++ b/monai/visualize/img2tensorboard.py @@ -27,8 +27,8 @@ def _image3_animated_gif(imp, scale_factor=1): # x=numpy.random.randint(0,256,[10,10,10],numpy.uint8) (tag, ims) = imp ims = [ - (np.asarray((ims[i, :, :])) * scale_factor).astype(np.uint8) - for i in range(ims.shape[0]) + (np.asarray((ims[:, :, i])) * scale_factor).astype(np.uint8) + for i in range(ims.shape[2]) ] ims = [GifImage.fromarray(im) for im in ims] img_str = b'' @@ -49,8 +49,8 @@ def _image3_animated_gif(imp, scale_factor=1): def make_animated_gif_summary(tag, tensor, max_out=3, - animation_axes=(1,), - image_axes=(2, 3), + animation_axes=(3,), + image_axes=(1, 2), other_indices=None, scale_factor=1): """ @@ -58,7 +58,7 @@ def make_animated_gif_summary(tag, Args: tag: Data identifier - tensor: tensor for the image, expected to be in CDHW format + tensor: tensor for the image, expected to be in CHWD format max_out: maximum number of slices to animate through animation_axes: axis to animate on (not currently used) 
image_axes: axes of image (not currently used)
diff --git a/tests/integration_sliding_window.py b/tests/integration_sliding_window.py
index db10d7cc49..31c7a1248a 100644
--- a/tests/integration_sliding_window.py
+++ b/tests/integration_sliding_window.py
@@ -65,17 +65,19 @@ def _sliding_window_processor(_engine, batch):
     basename = os.path.basename(img_name)[:-len('.nii.gz')]
     saved_name = os.path.join(temp_dir, basename, '{}_seg.nii.gz'.format(basename))
-    testing_shape = nib.load(saved_name).get_fdata().shape
+    # get spatial dimensions shape, the saved nifti image format: HWDC
+    testing_shape = nib.load(saved_name).get_fdata().shape[:-1]
     if os.path.exists(img_name):
         os.remove(img_name)
     if os.path.exists(seg_name):
         os.remove(seg_name)
-
-    return testing_shape == input_shape
+    if testing_shape != input_shape:
+        print('testing shape: {} does not match input shape: {}.'.format(testing_shape, input_shape))
+        return False
+    return True
 
 
 if __name__ == "__main__":
     result = run_test()
-
     sys.exit(0 if result else 1)
diff --git a/tests/integration_unet2d.py b/tests/integration_unet2d.py
index 1fd9074c66..7b0f116b77 100644
--- a/tests/integration_unet2d.py
+++ b/tests/integration_unet2d.py
@@ -51,12 +51,13 @@ def loss_fn(pred, grnd):
     trainer = create_supervised_trainer(net, opt, loss_fn, device, False)
 
     trainer.run(src, 1)
-
-    return trainer.state.output
+    loss = trainer.state.output
+    print('Loss:', loss)
+    if loss >= 1:
+        print('Loss value is wrong, expected to be < 1.')
+    return loss
 
 
 if __name__ == "__main__":
     result = run_test()
-    print(result)
-    sys.exit(0 if result < 1 else 1)
diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py
index 58a62133d2..fdb0600e04 100644
--- a/tests/test_handler_stats.py
+++ b/tests/test_handler_stats.py
@@ -30,7 +30,7 @@ def test_metrics_print(self):
 
         # set up engine
         def _train_func(engine, batch):
-            return None, torch.tensor(0.0)
+            return torch.tensor(0.0)
 
         engine = Engine(_train_func)
 
@@ -50,7 +50,6 @@ def _update_metric(engine):
         output_str = log_stream.getvalue()
         grep = re.compile('.*{}.*'.format(key_to_handler))
         has_key_word = re.compile('.*{}.*'.format(key_to_print))
-        matched = []
         for idx, line in enumerate(output_str.split('\n')):
             if grep.match(line):
                 if idx in [5, 10]:
@@ -60,16 +59,44 @@ def test_loss_print(self):
         log_stream = StringIO()
         logging.basicConfig(stream=log_stream, level=logging.INFO)
         key_to_handler = 'test_logging'
-        key_to_print = 'Loss'
+        key_to_print = 'myLoss'
 
         # set up engine
         def _train_func(engine, batch):
-            return None, torch.tensor(0.0)
+            return torch.tensor(0.0)
 
         engine = Engine(_train_func)
 
         # set up testing handler
-        stats_handler = StatsHandler(name=key_to_handler)
+        stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print)
+        stats_handler.attach(engine)
+
+        engine.run(range(3), max_epochs=2)
+
+        # check logging output
+        output_str = log_stream.getvalue()
+        grep = re.compile('.*{}.*'.format(key_to_handler))
+        has_key_word = re.compile('.*{}.*'.format(key_to_print))
+        for idx, line in enumerate(output_str.split('\n')):
+            if grep.match(line):
+                if idx in [1, 2, 3, 6, 7, 8]:
+                    self.assertTrue(has_key_word.match(line))
+
+    def test_loss_dict(self):
+        log_stream = StringIO()
+        logging.basicConfig(stream=log_stream, level=logging.INFO)
+        key_to_handler = 'test_logging'
+        key_to_print = 'myLoss1'
+
+        # set up engine
+        def _train_func(engine, batch):
+            return torch.tensor(0.0)
+
+        engine = Engine(_train_func)
+
+        # set up testing handler
+        stats_handler = StatsHandler(name=key_to_handler,
+                                     output_transform=lambda x:
{key_to_print: x}) stats_handler.attach(engine) engine.run(range(3), max_epochs=2) @@ -78,7 +105,6 @@ def _train_func(engine, batch): output_str = log_stream.getvalue() grep = re.compile('.*{}.*'.format(key_to_handler)) has_key_word = re.compile('.*{}.*'.format(key_to_print)) - matched = [] for idx, line in enumerate(output_str.split('\n')): if grep.match(line): if idx in [1, 2, 3, 6, 7, 8]: diff --git a/tests/test_handler_tb_image.py b/tests/test_handler_tb_image.py new file mode 100644 index 0000000000..9bf55e162b --- /dev/null +++ b/tests/test_handler_tb_image.py @@ -0,0 +1,60 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import os +import shutil +import unittest + +import numpy as np +import torch +from ignite.engine import Engine, Events +from parameterized import parameterized + +from monai.handlers.tensorboard_handlers import TensorBoardImageHandler + +TEST_CASES = [ + [[20, 20]], + [[2, 20, 20]], + [[3, 20, 20]], + [[20, 20, 20]], + [[2, 20, 20, 20]], + [[2, 2, 20, 20, 20]], +] + + +class TestHandlerTBImage(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_tb_image_shape(self, shape): + default_dir = os.path.join('.', 'runs') + shutil.rmtree(default_dir, ignore_errors=True) + + # set up engine + def _train_func(engine, batch): + return torch.zeros((1, 1, 10, 10)) + + engine = Engine(_train_func) + + # set up testing handler + stats_handler = TensorBoardImageHandler() + engine.add_event_handler(Events.ITERATION_COMPLETED, stats_handler) + + data = zip(np.random.normal(size=(10, 4, *shape)), np.random.normal(size=(10, 4, *shape))) + engine.run(data, epoch_length=10, max_epochs=1) + + self.assertTrue(os.path.exists(default_dir)) + self.assertTrue(len(glob.glob(default_dir)) > 0) + shutil.rmtree(default_dir) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_handler_tb_stats.py b/tests/test_handler_tb_stats.py new file mode 100644 index 0000000000..53a691701f --- /dev/null +++ b/tests/test_handler_tb_stats.py @@ -0,0 +1,81 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import shutil +import tempfile +import unittest +import glob + +from ignite.engine import Engine, Events +from torch.utils.tensorboard import SummaryWriter + +from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler + + +class TestHandlerTBStats(unittest.TestCase): + + def test_metrics_print(self): + default_dir = os.path.join('.', 'runs') + shutil.rmtree(default_dir, ignore_errors=True) + + # set up engine + def _train_func(engine, batch): + return batch + 1.0 + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get('acc', 0.1) + engine.state.metrics['acc'] = current_metric + 0.1 + + # set up testing handler + stats_handler = TensorBoardStatsHandler() + stats_handler.attach(engine) + engine.run(range(3), max_epochs=2) + # check logging output + + self.assertTrue(os.path.exists(default_dir)) + shutil.rmtree(default_dir) + + def test_metrics_writer(self): + default_dir = os.path.join('.', 'runs') + shutil.rmtree(default_dir, ignore_errors=True) + with tempfile.TemporaryDirectory() as temp_dir: + + # set up engine + def _train_func(engine, batch): + return batch + 1.0 + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get('acc', 0.1) + engine.state.metrics['acc'] = current_metric + 0.1 + + # set up testing handler + writer = SummaryWriter(log_dir=temp_dir) + stats_handler = TensorBoardStatsHandler( + writer, output_transform=lambda x: {'loss': x * 2.0}, + global_epoch_transform=lambda x: x * 3.0) + stats_handler.attach(engine) + engine.run(range(3), max_epochs=2) + # check logging output + self.assertTrue(len(glob.glob(temp_dir)) > 0) + self.assertTrue(not os.path.exists(default_dir)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_unet.py b/tests/test_unet.py index 98102375a6..c1e838c284 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -26,7 +26,7 @@ 'num_res_units': 1, }, torch.randn(16, 1, 32, 32), - (16, 32, 32), + (16, 1, 32, 32), ] TEST_CASE_2 = [ # single channel 3D, batch 16 @@ -39,7 +39,7 @@ 'num_res_units': 1, }, torch.randn(16, 1, 32, 24, 48), - (16, 32, 24, 48), + (16, 1, 32, 24, 48), ] TEST_CASE_3 = [ # 4-channel 3D, batch 16 @@ -52,7 +52,7 @@ 'num_res_units': 1, }, torch.randn(16, 4, 32, 64, 48), - (16, 32, 64, 48), + (16, 1, 32, 64, 48), ] From e40917fe79992fb21e2c36640b684642bd506391 Mon Sep 17 00:00:00 2001 From: Mohammad Adil Date: Wed, 11 Mar 2020 18:49:14 -0700 Subject: [PATCH 2/3] Adding dict-based and random spatial transforms. 
(#163) Co-authored-by: Nic Ma
---
 monai/transforms/composables.py               | 230 +++++++++++++++++-
 monai/transforms/transforms.py                |   3 +-
 tests/test_flip.py                            |  33 ++-
 ...{test_random_flip.py => test_rand_flip.py} |  29 ++-
 ...t_random_rotate.py => test_rand_rotate.py} |   0
 ...{test_random_zoom.py => test_rand_zoom.py} |  22 +-
 tests/test_rotate.py                          |  23 +-
 tests/test_zoom.py                            |  38 ++-
 8 files changed, 330 insertions(+), 48 deletions(-)
 rename tests/{test_random_flip.py => test_rand_flip.py} (62%)
 rename tests/{test_random_rotate.py => test_rand_rotate.py} (100%)
 rename tests/{test_random_zoom.py => test_rand_zoom.py} (76%)

diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py
index 8ccd5a747a..c19c3f7df9 100644
--- a/monai/transforms/composables.py
+++ b/monai/transforms/composables.py
@@ -22,7 +22,8 @@
 from monai.transforms.compose import Randomizable, Transform
 from monai.transforms.transforms import (LoadNifti, AsChannelFirst, Orientation,
                                          AddChannel, Spacing, Rotate90, SpatialCrop,
-                                         RandAffine, Rand2DElastic, Rand3DElastic)
+                                         RandAffine, Rand2DElastic, Rand3DElastic,
+                                         Flip, Rotate, Zoom)
 from monai.utils.misc import ensure_tuple
 from monai.transforms.utils import generate_pos_neg_label_crop_centers, create_grid
 from monai.utils.aliases import alias
@@ -476,7 +477,6 @@ def __init__(self, keys,
             as_tensor_output (bool): the computation is implemented using pytorch tensors, this option specifies
                 whether to convert it back to numpy arrays.
             device (torch.device): device on which the tensor will be allocated.
-
         See also:
             - ``RandAffineGrid`` for the random affine parameters configurations.
             - ``Affine`` for the affine transformation parameters configurations.
@@ -551,7 +551,6 @@ def __init__(self, keys,
             as_tensor_output (bool): the computation is implemented using pytorch tensors, this option specifies
                 whether to convert it back to numpy arrays.
             device (torch.device): device on which the tensor will be allocated.
-
        See also:
            - ``RandAffineGrid`` for the random affine parameters configurations.
            - ``Affine`` for the affine transformation parameters configurations.
@@ -594,3 +593,228 @@ def __call__(self, data):
         for key in self.keys:  # same interpolation mode
             d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=self.rand_3d_elastic.mode)
         return d
+
+
+@export
+@alias('FlipD', 'FlipDict')
+class Flipd(MapTransform):
+    """Dictionary-based wrapper of Flip.
+
+    Args:
+        keys (hashable items): Keys of the corresponding items to be transformed.
+        axis (None, int or tuple of ints): Axes along which to flip over. Default is None.
+    """
+
+    def __init__(self, keys, axis=None):
+        MapTransform.__init__(self, keys)
+        self.flipper = Flip(axis=axis)
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.keys:
+            d[key] = self.flipper(d[key])
+        return d
+
+
+@export
+@alias('RandFlipD', 'RandFlipDict')
+class RandFlipd(Randomizable, MapTransform):
+    """Dict-based wrapper of RandFlip.
+
+    Args:
+        keys (hashable items): Keys of the corresponding items to be transformed.
+        prob (float): Probability of flipping.
+        axis (None, int or tuple of ints): Axes along which to flip over. Default is None.
+    """
+
+    def __init__(self, keys, prob=0.1, axis=None):
+        MapTransform.__init__(self, keys)
+        self.axis = axis
+        self.prob = prob
+
+        self._do_transform = False
+        self.flipper = Flip(axis=axis)
+
+    def randomize(self):
+        self._do_transform = self.R.random_sample() < self.prob
+
+    def __call__(self, data):
+        self.randomize()
+        d = dict(data)
+        if not self._do_transform:
+            return d
+        for key in self.keys:
+            d[key] = self.flipper(d[key])
+        return d
+
+
+@export
+@alias('RotateD', 'RotateDict')
+class Rotated(MapTransform):
+    """Dictionary-based wrapper of Rotate.
+
+    Args:
+        keys (hashable items): Keys of the corresponding items to be transformed.
+        angle (float): Rotation angle in degrees.
+        axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). These are the first two
+            axes in the spatial dimensions according to the MONAI channel-first shape assumption.
+        reshape (bool): If true, the output shape is adapted so that the input array is
+            contained completely in the output. Default: True.
+        order (int): Order of spline interpolation. Range 0-5. Default: 1. This is
+            different from scipy where the default interpolation is 3.
+        mode (str): Points outside the boundary are filled according to this mode. Options are
+            'constant', 'nearest', 'reflect', 'wrap'. Default: 'constant'.
+        cval (scalar): Value to fill outside the boundary. Default: 0.
+        prefilter (bool): Apply spline_filter before interpolation. Default: True.
+    """
+
+    def __init__(self, keys, angle, axes=(1, 2), reshape=True, order=1,
+                 mode='constant', cval=0, prefilter=True):
+        MapTransform.__init__(self, keys)
+        self.rotator = Rotate(angle=angle, axes=axes, reshape=reshape,
+                              order=order, mode=mode, cval=cval, prefilter=prefilter)
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.keys:
+            d[key] = self.rotator(d[key])
+        return d
+
+
+@export
+@alias('RandRotateD', 'RandRotateDict')
+class RandRotated(Randomizable, MapTransform):
+    """Randomly rotates the input arrays.
+
+    Args:
+        keys (hashable items): Keys of the corresponding items to be transformed.
+        prob (float): Probability of rotation.
+        degrees (tuple of float or float): Range of rotation in degrees. If a single number,
+            the angle is picked from (-degrees, degrees).
+        axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). These are the first two
+            axes in the spatial dimensions according to the MONAI channel-first shape assumption.
+        reshape (bool): If true, the output shape is adapted so that the input array is
+            contained completely in the output. Default: True.
+        order (int): Order of spline interpolation. Range 0-5. Default: 1. This is
+            different from scipy where the default interpolation is 3.
+        mode (str): Points outside the boundary are filled according to this mode. Options are
+            'constant', 'nearest', 'reflect', 'wrap'. Default: 'constant'.
+        cval (scalar): Value to fill outside the boundary. Default: 0.
+        prefilter (bool): Apply spline_filter before interpolation. Default: True.
+    """
+    def __init__(self, keys, degrees, prob=0.1, axes=(1, 2), reshape=True, order=1,
+                 mode='constant', cval=0, prefilter=True):
+        MapTransform.__init__(self, keys)
+        self.prob = prob
+        self.degrees = degrees
+        self.reshape = reshape
+        self.order = order
+        self.mode = mode
+        self.cval = cval
+        self.prefilter = prefilter
+        self.axes = axes
+
+        if not hasattr(self.degrees, '__iter__'):
+            self.degrees = (-self.degrees, self.degrees)
+        assert len(self.degrees) == 2, "degrees should be a number or pair of numbers."
+
+        self._do_transform = False
+        self.angle = None
+
+    def randomize(self):
+        self._do_transform = self.R.random_sample() < self.prob
+        self.angle = self.R.uniform(low=self.degrees[0], high=self.degrees[1])
+
+    def __call__(self, data):
+        self.randomize()
+        d = dict(data)
+        if not self._do_transform:
+            return d
+        rotator = Rotate(self.angle, self.axes, self.reshape, self.order,
+                         self.mode, self.cval, self.prefilter)
+        for key in self.keys:
+            d[key] = rotator(d[key])
+        return d
+
+
+@export
+@alias('ZoomD', 'ZoomDict')
+class Zoomd(MapTransform):
+    """Dictionary-based wrapper of Zoom transform.
+
+    Args:
+        keys (hashable items): Keys of the corresponding items to be transformed.
+        zoom (float or sequence): The zoom factor along the spatial axes.
+            If a float, zoom is the same for each spatial axis.
+            If a sequence, zoom should contain one value for each spatial axis.
+        order (int): order of interpolation. Default=3.
+        mode (str): Determines how the input is extended beyond boundaries. Default is 'constant'.
+        cval (scalar, optional): Value to fill past edges. Default is 0.
+        prefilter (bool): Apply spline_filter before interpolation. Default: True.
+        use_gpu (bool): Whether to use cpu or gpu. Uses cupyx, which doesn't support order > 1 and modes
+            'wrap' and 'reflect'. Defaults to cpu for these cases, or if cupyx is not found.
+        keep_size (bool): Whether to keep the original size (pad if needed).
+    """
+
+    def __init__(self, keys, zoom, order=3, mode='constant', cval=0,
+                 prefilter=True, use_gpu=False, keep_size=False):
+        MapTransform.__init__(self, keys)
+        self.zoomer = Zoom(zoom=zoom, order=order, mode=mode, cval=cval,
+                           prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size)
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.keys:
+            d[key] = self.zoomer(d[key])
+        return d
+
+
+@export
+@alias('RandZoomD', 'RandZoomDict')
+class RandZoomd(Randomizable, MapTransform):
+    """Dict-based wrapper of RandZoom.
+
+    Args:
+        keys (hashable items): Keys of the corresponding items to be transformed.
+        prob (float): Probability of zooming.
+        min_zoom (float or sequence): Min zoom factor. Can be a float or a sequence the same size as the image.
+        max_zoom (float or sequence): Max zoom factor. Can be a float or a sequence the same size as the image.
+        order (int): order of interpolation. Default=3.
+        mode ('reflect', 'constant', 'nearest', 'mirror', 'wrap'): Determines how the input is
+            extended beyond boundaries. Default: 'constant'.
+        cval (scalar, optional): Value to fill past edges. Default is 0.
+        prefilter (bool): Apply spline_filter before interpolation. Default: True.
+        use_gpu (bool): Whether to use cpu or gpu. Uses cupyx, which doesn't support order > 1 and modes
+            'wrap' and 'reflect'. Defaults to cpu for these cases, or if cupyx is not found.
+        keep_size (bool): Whether to keep the original size (pad if needed).
+    """
+
+    def __init__(self, keys, prob=0.1, min_zoom=0.9,
+                 max_zoom=1.1, order=3, mode='constant',
+                 cval=0, prefilter=True, use_gpu=False, keep_size=False):
+        MapTransform.__init__(self, keys)
+        if hasattr(min_zoom, '__iter__') and \
+                hasattr(max_zoom, '__iter__'):
+            assert len(min_zoom) == len(max_zoom), "min_zoom and max_zoom must have same length."
+        self.min_zoom = min_zoom
+        self.max_zoom = max_zoom
+        self.prob = prob
+        self.order = order
+        self.mode = mode
+        self.cval = cval
+        self.prefilter = prefilter
+        self.use_gpu = use_gpu
+        self.keep_size = keep_size
+
+        self._do_transform = False
+        self._zoom = None
+
+    def randomize(self):
+        self._do_transform = self.R.random_sample() < self.prob
+        if hasattr(self.min_zoom, '__iter__'):
+            # use a list, not a generator, so the sampled factors can be reused
+            self._zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)]
+        else:
+            self._zoom = self.R.uniform(self.min_zoom, self.max_zoom)
+
+    def __call__(self, data):
+        self.randomize()
+        d = dict(data)
+        if not self._do_transform:
+            return d
+        zoomer = Zoom(self._zoom, self.order, self.mode, self.cval, self.prefilter, self.use_gpu, self.keep_size)
+        for key in self.keys:
+            d[key] = zoomer(d[key])
+        return d
diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py
index 370a7fb305..8f140972f6 100644
--- a/monai/transforms/transforms.py
+++ b/monai/transforms/transforms.py
@@ -434,7 +434,7 @@ def __call__(self, img):
                 pad_vec[idx] = [half, diff - half]
             elif diff < 0:  # need slicing
                 slice_vec[idx] = slice(half, half + od)
-        zoomed = np.pad(zoomed, pad_vec)
+        zoomed = np.pad(zoomed, pad_vec, mode='constant')
         return zoomed[tuple(slice_vec)]
@@ -696,6 +696,7 @@ def __init__(self, prob=0.1, axis=None):
         self.flipper = Flip(axis=axis)
 
         self._do_transform = False
+        self.flipper = Flip(axis=axis)
 
     def randomize(self):
         self._do_transform = self.R.random_sample() < self.prob
diff --git a/tests/test_flip.py b/tests/test_flip.py
index 3b027ec2c8..a261c315e2 100644
--- a/tests/test_flip.py
+++ b/tests/test_flip.py
@@ -14,31 +14,44 @@
 import numpy as np
 from parameterized import parameterized
 
-from monai.transforms import Flip
+from monai.transforms import Flip, Flipd
 from tests.utils import NumpyImageTestCase2D
 
+INVALID_CASES = [("wrong_axis", ['s', 1], TypeError),
+                 ("not_numbers", 's', TypeError)]
+
+VALID_CASES = [("no_axis", None),
+               ("one_axis", 1),
+               ("many_axis", [0, 1, 2])]
+
 
 class FlipTest(NumpyImageTestCase2D):
 
-    @parameterized.expand([
-        ("wrong_axis", ['s', 1], TypeError),
-        ("not_numbers", 's', TypeError)
-    ])
+    @parameterized.expand(INVALID_CASES)
     def test_invalid_inputs(self, _, axis, raises):
         with self.assertRaises(raises):
             flip = Flip(axis)
             flip(self.imt)
 
-    @parameterized.expand([
-        ("no_axis", None),
-        ("one_axis", 1),
-        ("many_axis", [0, 1, 2])
-    ])
+    @parameterized.expand(INVALID_CASES)
+    def test_invalid_cases_dict(self, _, axis, raises):
+        with self.assertRaises(raises):
+            flip = Flipd(keys='img', axis=axis)
+            flip({'img': self.imt})
+
+    @parameterized.expand(VALID_CASES)
     def test_correct_results(self, _, axis):
         flip = Flip(axis=axis)
         expected = np.flip(self.imt, axis)
         self.assertTrue(np.allclose(expected, flip(self.imt)))
 
+    @parameterized.expand(VALID_CASES)
+    def test_correct_results_dict(self, _, axis):
+        flip = Flipd(keys='img', axis=axis)
+        expected = np.flip(self.imt, axis)
+        res = flip({'img': self.imt})
+        assert np.allclose(expected, res['img'])
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/test_random_flip.py b/tests/test_rand_flip.py
similarity index 62%
rename from tests/test_random_flip.py
rename to tests/test_rand_flip.py
index ee89a133d9..be03ff5a28 100644
--- a/tests/test_random_flip.py
+++ b/tests/test_rand_flip.py
@@ -14,31 +14,38 @@
 import numpy as np
 from parameterized import parameterized
 
-from monai.transforms import RandFlip
+from monai.transforms import RandFlip, RandFlipd
 from tests.utils import NumpyImageTestCase2D
+INVALID_CASES = [("wrong_axis", ['s', 1], TypeError), + ("not_numbers", 's', TypeError)] -class RandomFlipTest(NumpyImageTestCase2D): +VALID_CASES = [("no_axis", None), + ("one_axis", 1), + ("many_axis", [0, 1, 2])] - @parameterized.expand([ - ("wrong_axis", ['s', 1], TypeError), - ("not_numbers", 's', TypeError) - ]) +class RandFlipTest(NumpyImageTestCase2D): + + @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, _, axis, raises): with self.assertRaises(raises): flip = RandFlip(prob=1.0, axis=axis) flip(self.imt) - @parameterized.expand([ - ("no_axis", None), - ("one_axis", 1), - ("many_axis", [0, 1, 2]) - ]) + @parameterized.expand(VALID_CASES) def test_correct_results(self, _, axis): flip = RandFlip(prob=1.0, axis=axis) expected = np.flip(self.imt, axis) self.assertTrue(np.allclose(expected, flip(self.imt))) + @parameterized.expand(VALID_CASES) + def test_correct_results_dict(self, _, axis): + flip = RandFlipd(keys='img', prob=1.0, axis=axis) + res = flip({'img': self.imt}) + + expected = np.flip(self.imt, axis) + self.assertTrue(np.allclose(expected, res['img'])) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_random_rotate.py b/tests/test_rand_rotate.py similarity index 100% rename from tests/test_random_rotate.py rename to tests/test_rand_rotate.py diff --git a/tests/test_random_zoom.py b/tests/test_rand_zoom.py similarity index 76% rename from tests/test_random_zoom.py rename to tests/test_rand_zoom.py index d193a16dd2..530504b887 100644 --- a/tests/test_random_zoom.py +++ b/tests/test_rand_zoom.py @@ -17,15 +17,14 @@ from scipy.ndimage import zoom as zoom_scipy from parameterized import parameterized -from monai.transforms import RandZoom +from monai.transforms import RandZoom, RandZoomd from tests.utils import NumpyImageTestCase2D +VALID_CASES = [(0.9, 1.1, 3, 'constant', 0, True, False, False)] class ZoomTest(NumpyImageTestCase2D): - @parameterized.expand([ - (0.9, 1.1, 3, 'constant', 0, True, False, False), - ]) + @parameterized.expand(VALID_CASES) def test_correct_results(self, min_zoom, max_zoom, order, mode, cval, prefilter, use_gpu, keep_size): random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order, @@ -39,6 +38,21 @@ def test_correct_results(self, min_zoom, max_zoom, order, mode, self.assertTrue(np.allclose(expected, zoomed)) + @parameterized.expand(VALID_CASES) + def test_correct_results_dict(self, min_zoom, max_zoom, order, mode, + cval, prefilter, use_gpu, keep_size): + keys = 'img' + random_zoom = RandZoomd(keys, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order, + mode=mode, cval=cval, prefilter=prefilter, use_gpu=use_gpu, + keep_size=keep_size) + random_zoom.set_random_state(234) + + zoomed = random_zoom({keys: self.imt}) + expected = zoom_scipy(self.imt, zoom=random_zoom._zoom, mode=mode, + order=order, cval=cval, prefilter=prefilter) + + self.assertTrue(np.allclose(expected, zoomed[keys])) + @parameterized.expand([ (0.8, 1.2, 1, 'constant', 0, True) ]) diff --git a/tests/test_rotate.py b/tests/test_rotate.py index 98e25f587f..0c34f5809e 100644 --- a/tests/test_rotate.py +++ b/tests/test_rotate.py @@ -15,17 +15,16 @@ import scipy.ndimage from parameterized import parameterized -from monai.transforms import Rotate +from monai.transforms import Rotate, Rotated from tests.utils import NumpyImageTestCase2D +TEST_CASES = [(90, (1, 2), True, 1, 'reflect', 0, True), + (-90, (2, 1), True, 3, 'constant', 0, True), + (180, (2, 3), False, 2, 'constant', 4, False)] class RotateTest(NumpyImageTestCase2D): 
- @parameterized.expand([ - (90, (1, 2), True, 1, 'reflect', 0, True), - (-90, (2, 1), True, 3, 'constant', 0, True), - (180, (2, 3), False, 2, 'constant', 4, False), - ]) + @parameterized.expand(TEST_CASES) def test_correct_results(self, angle, axes, reshape, order, mode, cval, prefilter): rotate_fn = Rotate(angle, axes, reshape, @@ -36,6 +35,18 @@ def test_correct_results(self, angle, axes, reshape, mode=mode, cval=cval, prefilter=prefilter) self.assertTrue(np.allclose(expected, rotated)) + @parameterized.expand(TEST_CASES) + def test_correct_results_dict(self, angle, axes, reshape, + order, mode, cval, prefilter): + key = 'img' + rotate_fn = Rotated(key, angle, axes, reshape, order, + mode, cval, prefilter) + rotated = rotate_fn({key: self.imt}) + + expected = scipy.ndimage.rotate(self.imt, angle, axes, reshape, order=order, + mode=mode, cval=cval, prefilter=prefilter) + self.assertTrue(np.allclose(expected, rotated[key])) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_zoom.py b/tests/test_zoom.py index 874e587a98..83795542bc 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -17,17 +17,22 @@ from scipy.ndimage import zoom as zoom_scipy from parameterized import parameterized -from monai.transforms import Zoom +from monai.transforms import Zoom, Zoomd from tests.utils import NumpyImageTestCase2D +VALID_CASES = [(1.1, 3, 'constant', 0, True, False, False), + (0.9, 3, 'constant', 0, True, False, False), + (0.8, 1, 'reflect', 0, False, False, False)] + +GPU_CASES = [("gpu_zoom", 0.6, 1, 'constant', 0, True)] + +INVALID_CASES = [("no_zoom", None, 1, TypeError), + ("invalid_order", 0.9, 's', AssertionError)] + class ZoomTest(NumpyImageTestCase2D): - @parameterized.expand([ - (1.1, 3, 'constant', 0, True, False, False), - (0.9, 3, 'constant', 0, True, False, False), - (0.8, 1, 'reflect', 0, False, False, False) - ]) + @parameterized.expand(VALID_CASES) def test_correct_results(self, zoom, order, mode, cval, prefilter, use_gpu, keep_size): zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval, prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size) @@ -36,9 +41,19 @@ def test_correct_results(self, zoom, order, mode, cval, prefilter, use_gpu, keep cval=cval, prefilter=prefilter) self.assertTrue(np.allclose(expected, zoomed)) - @parameterized.expand([ - ("gpu_zoom", 0.6, 1, 'constant', 0, True) - ]) + @parameterized.expand(VALID_CASES) + def test_correct_results_dict(self, zoom, order, mode, cval, prefilter, use_gpu, keep_size): + key = 'img' + zoom_fn = Zoomd(key, zoom=zoom, order=order, mode=mode, cval=cval, + prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size) + zoomed = zoom_fn({key: self.imt[0]}) + + expected = zoom_scipy(self.imt, zoom=zoom, mode=mode, order=order, + cval=cval, prefilter=prefilter) + self.assertTrue(np.allclose(expected, zoomed[key])) + + + @parameterized.expand(GPU_CASES) def test_gpu_zoom(self, _, zoom, order, mode, cval, prefilter): if importlib.util.find_spec('cupy'): zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval, @@ -57,10 +72,7 @@ def test_keep_size(self): zoomed = zoom_fn(self.imt[0]) self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) - @parameterized.expand([ - ("no_zoom", None, 1, TypeError), - ("invalid_order", 0.9, 's', AssertionError) - ]) + @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, _, zoom, order, raises): with self.assertRaises(raises): zoom_fn = Zoom(zoom=zoom, order=order) From f0d4ebfe3fd0e7b412e9fb7317f8bd1354bfff9a Mon Sep 17 00:00:00 2001 From: root 
Date: Thu, 12 Mar 2020 02:39:45 +0000
Subject: [PATCH 3/3] [DLMED] add 3D classification inference examples

---
 ...py => densenet_classification_3d_array.py} | 36 +++--
 examples/densenet_classification_3d_dict.py | 147 ++++++++++++++++++
 examples/densenet_inference_3d_array.py | 93 +++++++++++
 examples/densenet_inference_3d_dict.py | 95 +++++++++++
 examples/unet_inference_3d_array.py | 6 +-
 examples/unet_inference_3d_dict.py | 2 +-
 examples/unet_segmentation_3d_array.py | 16 +-
 examples/unet_segmentation_3d_dict.py | 12 +-
 monai/handlers/classification_saver.py | 91 +++++++++++
 monai/transforms/composables.py | 56 ++++++-
 monai/transforms/transforms.py | 34 ++--
 tests/test_dice_loss.py | 2 +-
 tests/test_generalized_dice_loss.py | 2 +-
 tests/test_resize.py | 21 +--
 14 files changed, 548 insertions(+), 65 deletions(-)
 rename examples/{densenet_classification_3d.py => densenet_classification_3d_array.py} (83%)
 create mode 100644 examples/densenet_classification_3d_dict.py
 create mode 100644 examples/densenet_inference_3d_array.py
 create mode 100644 examples/densenet_inference_3d_dict.py
 create mode 100644 monai/handlers/classification_saver.py

diff --git a/examples/densenet_classification_3d.py b/examples/densenet_classification_3d_array.py
similarity index 83%
rename from examples/densenet_classification_3d.py
rename to examples/densenet_classification_3d_array.py
index 753097a2a9..001eb079e2 100644
--- a/examples/densenet_classification_3d.py
+++ b/examples/densenet_classification_3d_array.py
@@ -23,14 +23,14 @@
 import monai.transforms.compose as transforms
 
 from monai.data.nifti_reader import NiftiDataset
-from monai.transforms import (AddChannel, Rescale, ToTensor, UniformRandomPatch)
+from monai.transforms import (AddChannel, Rescale, Resize, RandRotate90)
 from monai.handlers.stats_handler import StatsHandler
 from ignite.metrics import Accuracy
 from monai.handlers.utils import stopping_fn_from_metric
 
 monai.config.print_config()
 
-# FIXME: temp test dataset, Wenqi will replace later
+# demo dataset; users can easily substitute their own dataset
 images = [
     "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz",
     "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz",
@@ -57,18 +57,23 @@
     0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0
 ])
 
-# Define transforms for image and segmentation
-imtrans = transforms.Compose([
+# Define transforms
+train_transforms = transforms.Compose([
     Rescale(),
     AddChannel(),
-    UniformRandomPatch((96, 96, 96)),
-    ToTensor()
+    Resize((96, 96, 96)),
+    RandRotate90()
+])
+val_transforms = transforms.Compose([
+    Rescale(),
+    AddChannel(),
+    Resize((96, 96, 96))
 ])
 
 # Define nifti dataset, dataloader.
-ds = NiftiDataset(image_files=images, labels=labels, transform=imtrans)
-loader = DataLoader(ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())
-im, label = monai.utils.misc.first(loader)
+check_ds = NiftiDataset(image_files=images, labels=labels, transform=train_transforms)
+check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())
+im, label = monai.utils.misc.first(check_loader)
 print(type(im), im.shape, label)
 
 lr = 1e-5
@@ -84,7 +89,8 @@
 
 # Create trainer
 device = torch.device("cuda:0")
-trainer = create_supervised_trainer(net, opt, loss, device, False)
+trainer = create_supervised_trainer(net, opt, loss, device, False,
+                                    output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y])
 
 # adding checkpoint handler to save models (network params and optimizer stats) during training
 checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
@@ -94,10 +100,6 @@
-train_stats_handler = StatsHandler(output_transform=lambda x: x[3])
+train_stats_handler = StatsHandler(output_transform=lambda x: x[1])
 train_stats_handler.attach(trainer)
 
-@trainer.on(Events.EPOCH_COMPLETED)
-def log_training_loss(engine):
-    engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output)
-
 # Set parameters for validation
 validation_every_n_epochs = 1
 metric_name = 'Accuracy'
@@ -118,8 +120,8 @@ def log_training_loss(engine):
 evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
 
 # create a validation data loader
-val_ds = NiftiDataset(image_files=images[-5:], labels=labels[-5:], transform=imtrans)
-val_loader = DataLoader(ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())
+val_ds = NiftiDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms)
+val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())
 
 
 @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
@@ -129,7 +131,7 @@ def run_validation(engine):
 
 # create a training data loader
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-train_ds = NiftiDataset(image_files=images[:15], labels=labels[:15], transform=imtrans)
+train_ds = NiftiDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms)
 train_loader = DataLoader(train_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())
 
 train_epochs = 30
diff --git a/examples/densenet_classification_3d_dict.py b/examples/densenet_classification_3d_dict.py
new file mode 100644
index 0000000000..1be49de05d
--- /dev/null
+++ b/examples/densenet_classification_3d_dict.py
@@ -0,0 +1,147 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import logging
+import numpy as np
+import torch
+from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, _prepare_batch
+from ignite.handlers import ModelCheckpoint, EarlyStopping
+from torch.utils.data import DataLoader
+
+# assumes the framework is found here, change as necessary
+sys.path.append("..")
+import monai
+import monai.transforms.compose as transforms
+from monai.transforms.composables import \
+    LoadNiftid, AddChanneld, Rescaled, Resized, RandRotate90d
+from monai.handlers.stats_handler import StatsHandler
+from ignite.metrics import Accuracy
+from monai.handlers.utils import stopping_fn_from_metric
+
+monai.config.print_config()
+
+# demo dataset; users can easily substitute their own dataset
+images = [
+    "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz"
+]
+labels = np.array([
+    0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0
+])
+train_files = [{'img': img, 'label': label} for img, label in zip(images[:10], labels[:10])]
+val_files = [{'img': img, 'label': label} for img, label in zip(images[-10:], labels[-10:])]
+
+# Define transforms for image
+train_transforms = transforms.Compose([
+    LoadNiftid(keys=['img']),
+    AddChanneld(keys=['img']),
+    Rescaled(keys=['img']),
+    Resized(keys=['img'], output_shape=(96, 96, 96)),
+    RandRotate90d(keys=['img'], prob=0.8, axes=[1, 3])
+])
+val_transforms = transforms.Compose([
+    LoadNiftid(keys=['img']),
+    AddChanneld(keys=['img']),
+    Rescaled(keys=['img']),
+    Resized(keys=['img'], output_shape=(96, 96, 96))
+])
+
+# Define dataset, dataloader.
+check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
+check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
+check_data = monai.utils.misc.first(check_loader)
+print(check_data['img'].shape, check_data['label'])
+
+lr = 1e-5
+
+# Create DenseNet121, CrossEntropyLoss and Adam optimizer.
+net = monai.networks.nets.densenet3d.densenet121(
+    in_channels=1,
+    out_channels=2,
+)
+
+loss = torch.nn.CrossEntropyLoss()
+opt = torch.optim.Adam(net.parameters(), lr)
+
+
+# define a prepare_batch function to extract (image, label) pairs from the dict-format batch
+def prepare_batch(batch, device=None, non_blocking=False):
+    return _prepare_batch((batch['img'], batch['label']), device, non_blocking)
+
+
+# Create trainer
+device = torch.device("cuda:0")
+trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch,
+                                    output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y])
+
+# adding checkpoint handler to save models (network params and optimizer stats) during training
+checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
+trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
+                          handler=checkpoint_handler,
+                          to_save={'net': net, 'opt': opt})
+train_stats_handler = StatsHandler()
+train_stats_handler.attach(trainer)
+
+# Set parameters for validation
+validation_every_n_epochs = 1
+metric_name = 'Accuracy'
+
+# add evaluation metric to the evaluator engine
+val_metrics = {metric_name: Accuracy()}
+evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)
+
+# Add stats event handler to print validation stats via evaluator
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+val_stats_handler = StatsHandler()
+val_stats_handler.attach(evaluator)
+
+# Add early stopping handler to evaluator.
+early_stopper = EarlyStopping(patience=4,
+                              score_function=stopping_fn_from_metric(metric_name),
+                              trainer=trainer)
+evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
+
+# create a validation data loader
+val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
+val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
+
+
+@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
+def run_validation(engine):
+    evaluator.run(val_loader)
+
+# create a training data loader
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+
+train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
+train_loader = DataLoader(train_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
+
+train_epochs = 30
+state = trainer.run(train_loader, train_epochs)
diff --git a/examples/densenet_inference_3d_array.py b/examples/densenet_inference_3d_array.py
new file mode 100644
index 0000000000..25908bb821
--- /dev/null
+++ b/examples/densenet_inference_3d_array.py
@@ -0,0 +1,93 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import logging
+import numpy as np
+import torch
+from ignite.engine import create_supervised_evaluator, _prepare_batch
+from torch.utils.data import DataLoader
+
+# assumes the framework is found here, change as necessary
+sys.path.append("..")
+import monai
+import monai.transforms.compose as transforms
+from monai.data.nifti_reader import NiftiDataset
+from monai.transforms import (AddChannel, Rescale, Resize)
+from monai.handlers.stats_handler import StatsHandler
+from monai.handlers.classification_saver import ClassificationSaver
+from monai.handlers.checkpoint_loader import CheckpointLoader
+from ignite.metrics import Accuracy
+
+monai.config.print_config()
+
+# demo dataset; users can easily substitute their own dataset
+images = [
+    "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz"
+]
+labels = np.array([
+    0, 0, 1, 0, 1, 0, 1, 0, 1, 0
+])
+
+# Define transforms for image
+val_transforms = transforms.Compose([
+    Rescale(),
+    AddChannel(),
+    Resize((96, 96, 96))
+])
+# Define nifti dataset
+val_ds = NiftiDataset(image_files=images, labels=labels, transform=val_transforms, image_only=False)
+# Create DenseNet121
+net = monai.networks.nets.densenet3d.densenet121(
+    in_channels=1,
+    out_channels=2,
+)
+
+device = torch.device("cuda:0")
+metric_name = 'Accuracy'
+
+# add evaluation metric to the evaluator engine
+val_metrics = {metric_name: Accuracy()}
+
+
+def prepare_batch(batch, device=None, non_blocking=False):
+    return _prepare_batch((batch[0], batch[1]), device, non_blocking)
+
+
+evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)
+
+# Add stats event handler to print validation stats via evaluator
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+val_stats_handler = StatsHandler()
+val_stats_handler.attach(evaluator)
+
+# for the array data format, assume the 3rd item of the batch data is the meta data
+prediction_saver = ClassificationSaver(output_dir='tempdir', batch_transform=lambda batch: batch[2],
+                                       output_transform=lambda output: output[0].argmax(1))
+prediction_saver.attach(evaluator)
+
+# the model was trained by the "densenet_classification_3d_array" example
+CheckpointLoader(load_path='./runs/net_checkpoint_40.pth', load_dict={'net': net}).attach(evaluator)
+
+# create a validation data loader
+val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
+
+state = evaluator.run(val_loader)
+prediction_saver.finalize()
diff --git a/examples/densenet_inference_3d_dict.py b/examples/densenet_inference_3d_dict.py
new file mode 100644
index 0000000000..8f7dc75c72
--- /dev/null
+++ b/examples/densenet_inference_3d_dict.py
@@ -0,0 +1,95 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ignite.metrics import Accuracy
+import sys
+import logging
+import numpy as np
+import torch
+from ignite.engine import create_supervised_evaluator, _prepare_batch
+from torch.utils.data import DataLoader
+
+# assumes the framework is found here, change as necessary
+sys.path.append("..")
+from monai.handlers.classification_saver import ClassificationSaver
+from monai.handlers.checkpoint_loader import CheckpointLoader
+from monai.handlers.stats_handler import StatsHandler
+from monai.transforms.composables import LoadNiftid, AddChanneld, Rescaled, Resized
+import monai.transforms.compose as transforms
+import monai
+
+monai.config.print_config()
+
+# demo dataset; users can easily substitute their own dataset
+images = [
+    "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
+    "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz"
+]
+labels = np.array([
+    0, 0, 1, 0, 1, 0, 1, 0, 1, 0
+])
+val_files = [{'img': img, 'label': label} for img, label in zip(images, labels)]
+
+# Define transforms for image
+val_transforms = transforms.Compose([
+    LoadNiftid(keys=['img']),
+    AddChanneld(keys=['img']),
+    Rescaled(keys=['img']),
+    Resized(keys=['img'], output_shape=(96, 96, 96))
+])
+
+# Create DenseNet121
+net = monai.networks.nets.densenet3d.densenet121(
+    in_channels=1,
+    out_channels=2,
+)
+
+
+def prepare_batch(batch, device=None, non_blocking=False):
+    return _prepare_batch((batch['img'], batch['label']), device, non_blocking)
+
+
+# Create evaluator
+device = torch.device("cuda:0")
+metric_name = 'Accuracy'
+
+# add evaluation metric to the evaluator engine
+val_metrics = {metric_name: Accuracy()}
+evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch)
+
+# Add stats event handler to print validation stats via evaluator
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+val_stats_handler = StatsHandler()
+val_stats_handler.attach(evaluator)
+
+# for the dict data format, the file name meta data is stored in the 'img.filename_or_obj' key
+prediction_saver = ClassificationSaver(output_dir='tempdir', batch_transform=lambda batch: {
+    'filename_or_obj': batch['img.filename_or_obj']},
+    output_transform=lambda output: output[0].argmax(1))
+prediction_saver.attach(evaluator)
+
+# the model was trained by the "densenet_classification_3d_dict" example
+CheckpointLoader(load_path='./runs/net_checkpoint_40.pth', load_dict={'net': net}).attach(evaluator)
+
+# create a validation data loader
+val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
+val_loader = DataLoader(val_ds, batch_size=2, num_workers=4,
+                        pin_memory=torch.cuda.is_available())
+
+state = evaluator.run(val_loader)
+prediction_saver.finalize()
diff --git a/examples/unet_inference_3d_array.py b/examples/unet_inference_3d_array.py
index 8fe417c7dd..ea87590285 100644
--- a/examples/unet_inference_3d_array.py
+++ b/examples/unet_inference_3d_array.py
@@ -28,7 +28,7 @@
 from monai.handlers.segmentation_saver import SegmentationSaver
 import monai.transforms.compose as transforms
 from monai.data.nifti_reader import NiftiDataset
-from monai.transforms import AddChannel, Rescale, ToTensor
+from monai.transforms import AddChannel, Rescale
 from monai.networks.nets.unet import UNet
 from monai.networks.utils import predict_segmentation
 from monai.data.synthetic import create_test_image_3d
@@ -50,8 +50,8 @@
 images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
 segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
 
-imtrans = transforms.Compose([Rescale(), AddChannel(), ToTensor()])
-segtrans = transforms.Compose([AddChannel(), ToTensor()])
+imtrans = transforms.Compose([Rescale(), AddChannel()])
+segtrans = transforms.Compose([AddChannel()])
 ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
 
 device = torch.device("cuda:0")
diff --git a/examples/unet_inference_3d_dict.py b/examples/unet_inference_3d_dict.py
index 405b49aa8d..7274453b22 100644
--- a/examples/unet_inference_3d_dict.py
+++ b/examples/unet_inference_3d_dict.py
@@ -87,7 +87,7 @@ def _sliding_window_processor(engine, batch):
                        'original_affine': batch['img.original_affine'],
                        'affine': batch['img.affine'],
                    }).attach(infer_engine)
-# the model was trained by "unet_segmentation_3d_array" exmple
+# the model was trained by the "unet_segmentation_3d_dict" example
 CheckpointLoader(load_path='./runs/net_checkpoint_120.pth', load_dict={'net': net}).attach(infer_engine)
 
 val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate,
diff --git a/examples/unet_segmentation_3d_array.py b/examples/unet_segmentation_3d_array.py
index 3b0d880f10..d43ead1d2b 100644
--- a/examples/unet_segmentation_3d_array.py
+++ b/examples/unet_segmentation_3d_array.py
@@ -29,7 +29,7 @@
 import monai.transforms.compose as transforms
 
 from monai.data.nifti_reader import NiftiDataset
-from monai.transforms import AddChannel, Rescale, ToTensor, UniformRandomPatch
+from monai.transforms import AddChannel, Rescale, UniformRandomPatch
 from monai.handlers.stats_handler import StatsHandler
 from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler, TensorBoardImageHandler
 from monai.handlers.mean_dice import MeanDice
@@ -57,19 +57,17 @@
 imtrans = transforms.Compose([
     Rescale(),
     AddChannel(),
-    UniformRandomPatch((96, 96, 96)),
-    ToTensor()
+    UniformRandomPatch((96, 96, 96))
 ])
 segtrans = transforms.Compose([
     AddChannel(),
-    UniformRandomPatch((96, 96, 96)),
-    ToTensor()
+    UniformRandomPatch((96, 96, 96))
 ])
 
 # Define nifti dataset, dataloader.
-ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans)
-loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
-im, seg = monai.utils.misc.first(loader)
+check_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans)
+check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
+im, seg = monai.utils.misc.first(check_loader)
 print(im.shape, seg.shape)
 
 lr = 1e-5
@@ -159,7 +157,7 @@ def log_training_loss(engine):
 
 # create a validation data loader
 val_ds = NiftiDataset(images[-20:], segs[-20:], transform=imtrans, seg_transform=segtrans)
-val_loader = DataLoader(ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())
+val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())
 
 
 @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
diff --git a/examples/unet_segmentation_3d_dict.py b/examples/unet_segmentation_3d_dict.py
index 0e1e78811b..b8622bb7eb 100644
--- a/examples/unet_segmentation_3d_dict.py
+++ b/examples/unet_segmentation_3d_dict.py
@@ -68,11 +68,11 @@
     AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1)
 ])
 
-# Define nifti dataset, dataloader.
-ds = monai.data.Dataset(data=train_files, transform=train_transforms)
-loader = DataLoader(ds, batch_size=2, num_workers=4, collate_fn=list_data_collate,
-                    pin_memory=torch.cuda.is_available())
-check_data = monai.utils.misc.first(loader)
+# Define dataset, dataloader.
+check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
+check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate,
+                          pin_memory=torch.cuda.is_available())
+check_data = monai.utils.misc.first(check_loader)
 print(check_data['img'].shape, check_data['seg'].shape)
 
 lr = 1e-5
@@ -171,7 +171,7 @@ def log_training_loss(engine):
 
 # create a validation data loader
 val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
-val_loader = DataLoader(ds, batch_size=5, num_workers=8, collate_fn=list_data_collate,
+val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, collate_fn=list_data_collate,
                         pin_memory=torch.cuda.is_available())
 
 
diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py
new file mode 100644
index 0000000000..48f0ac59bb
--- /dev/null
+++ b/monai/handlers/classification_saver.py
@@ -0,0 +1,91 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import csv
+import logging
+import numpy as np
+import torch
+from ignite.engine import Events
+
+
+class ClassificationSaver:
+    """
+    Event handler triggered on completing every iteration to save the classification predictions as a CSV file.
+    """
+
+    def __init__(self, output_dir='./', overwrite=True,
+                 batch_transform=lambda x: x, output_transform=lambda x: x, name=None):
+        """
+        Args:
+            output_dir (str): output CSV file directory.
+            overwrite (bool): whether to overwrite the existing CSV file content.
+                If we are not overwriting, then we check whether the results have been
+                previously saved, and load them into the prediction_dict.
+            batch_transform (Callable): a callable that is used to transform the
+                ignite.engine.batch into the expected format to extract the meta_data dictionary.
+            output_transform (Callable): a callable that is used to transform the
+                ignite.engine.output into the expected model prediction data.
+                The first dimension of this transform's output will be treated as the
+                batch dimension. Each item in the batch will be saved individually.
+            name (str): identifier of logging.logger to use, defaulting to `engine.logger`.
+
+        """
+        self.output_dir = output_dir
+        self._prediction_dict = {}
+        self._preds_filepath = os.path.join(output_dir, 'predictions.csv')
+        if not overwrite:
+            if os.path.exists(self._preds_filepath):
+                with open(self._preds_filepath, 'r') as f:
+                    reader = csv.reader(f)
+                    for row in reader:
+                        self._prediction_dict[row[0]] = np.array(row[1:]).astype(np.float32)
+
+        self.batch_transform = batch_transform
+        self.output_transform = output_transform
+
+        self.logger = None if name is None else logging.getLogger(name)
+
+    def attach(self, engine):
+        if self.logger is None:
+            self.logger = engine.logger
+        return engine.add_event_handler(Events.ITERATION_COMPLETED, self)
+
+    def finalize(self):
+        """
+        Writes the prediction dict to a CSV file.
+
+        """
+        if not os.path.exists(self.output_dir):
+            os.makedirs(self.output_dir)
+        with open(self._preds_filepath, 'w') as f:
+            for k, v in sorted(self._prediction_dict.items()):
+                f.write(k)
+                for result in v.flatten():
+                    f.write("," + str(result))
+                f.write("\n")
+        self.logger.info('saved classification predictions into: {}'.format(self._preds_filepath))
+
+    def __call__(self, engine):
+        """
+        This method assumes self.batch_transform will extract metadata from the input batch.
+        The metadata should have the following keys:
+
+        - ``'filename_or_obj'`` -- save the prediction corresponding to this file name.
+
+        """
+        meta_data = self.batch_transform(engine.state.batch)
+        filenames = meta_data['filename_or_obj']
+
+        engine_output = self.output_transform(engine.state.output)
+        for batch_id, filename in enumerate(filenames):  # save a batch of files
+            output = engine_output[batch_id]
+            if isinstance(output, torch.Tensor):
+                output = output.detach().cpu().numpy()
+            self._prediction_dict[filename] = output
diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py
index c19c3f7df9..b27b64780e 100644
--- a/monai/transforms/composables.py
+++ b/monai/transforms/composables.py
@@ -14,6 +14,7 @@
 """
 
 import torch
+import numpy as np
 from collections.abc import Hashable
 
 import monai
@@ -23,7 +24,7 @@
 from monai.transforms.transforms import (LoadNifti, AsChannelFirst, Orientation,
                                          AddChannel, Spacing, Rotate90, SpatialCrop,
                                          RandAffine, Rand2DElastic, Rand3DElastic,
-                                         Flip, Rotate, Zoom)
+                                         Rescale, Resize, Flip, Rotate, Zoom)
 from monai.utils.misc import ensure_tuple
 from monai.transforms.utils import generate_pos_neg_label_crop_centers, create_grid
 from monai.utils.aliases import alias
@@ -249,6 +250,59 @@
     return d
 
 
+@export
+@alias('RescaleD', 'RescaleDict')
+class Rescaled(MapTransform):
+    """
+    dictionary-based wrapper of Rescale.
+    """
+
+    def __init__(self, keys, minv=0.0, maxv=1.0, dtype=np.float32):
+        MapTransform.__init__(self, keys)
+        self.rescaler = Rescale(minv, maxv, dtype)
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.keys:
+            d[key] = self.rescaler(d[key])
+        return d
+
+
+@export
+@alias('ResizeD', 'ResizeDict')
+class Resized(MapTransform):
+    """
+    dictionary-based wrapper of Resize.
+    Args:
+        keys (hashable items): keys of the corresponding items to be transformed.
+            See also: monai.transforms.composables.MapTransform
+        output_shape (tuple or list): expected shape after resize operation.
+        order (int): Order of spline interpolation. Default=1.
+        mode (str): Points outside boundaries are filled according to given mode.
+            Options are 'constant', 'edge', 'symmetric', 'reflect', 'wrap'.
+        cval (float): Used with mode 'constant', the value outside image boundaries.
+        clip (bool): Whether to clip the range of output values after interpolation. Default: True.
+        preserve_range (bool): Whether to keep the original range of values. Default is True.
+            If False, the input is converted according to the conventions of img_as_float. See
+            https://scikit-image.org/docs/dev/user_guide/data_types.html.
+        anti_aliasing (bool): Whether to apply a Gaussian filter to the image before down-scaling.
+            Default is True.
+        anti_aliasing_sigma (float, tuple of floats): Standard deviation for Gaussian filtering.
+    """
+
+    def __init__(self, keys, output_shape, order=1, mode='reflect', cval=0,
+                 clip=True, preserve_range=True, anti_aliasing=True, anti_aliasing_sigma=None):
+        MapTransform.__init__(self, keys)
+        self.resizer = Resize(output_shape, order, mode, cval, clip, preserve_range,
+                              anti_aliasing, anti_aliasing_sigma)
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.keys:
+            d[key] = self.resizer(d[key])
+        return d
+
+
 @export
 @alias('UniformRandomPatchD', 'UniformRandomPatchDict')
 class UniformRandomPatchd(Randomizable, MapTransform):
diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py
index 8f140972f6..be6d38a4fa 100644
--- a/monai/transforms/transforms.py
+++ b/monai/transforms/transforms.py
@@ -289,6 +289,7 @@ class Resize:
     For additional details, see https://scikit-image.org/docs/dev/api/skimage.transform.html#skimage.transform.resize.
 
     Args:
+        output_shape (tuple or list): expected shape after resize operation.
     order (int): Order of spline interpolation. Default=1.
     mode (str): Points outside boundaries are filled according to given mode.
         Options are 'constant', 'edge', 'symmetric', 'reflect', 'wrap'.
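For context between the hunks: the new dictionary-based wrappers above pair with the channel-wise Resize behavior introduced in the next hunk. A minimal usage sketch, assuming a channel-first numpy volume; the 'img' key and the synthetic data are illustrative only:

import numpy as np

import monai.transforms.compose as transforms
from monai.transforms.composables import Rescaled, Resized

# a synthetic channel-first volume: 1 channel, 32x32x32 voxels (illustrative only)
data = {'img': np.random.rand(1, 32, 32, 32).astype(np.float32)}

# rescale intensities to [0, 1], then resize each channel to (16, 16, 16)
pipeline = transforms.Compose([
    Rescaled(keys=['img'], minv=0.0, maxv=1.0),
    Resized(keys=['img'], output_shape=(16, 16, 16))
])

result = pipeline(data)
print(result['img'].shape)  # (1, 16, 16, 16): the channel axis is preserved
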
@@ -315,11 +316,16 @@ def __init__(self, output_shape, order=1, mode='reflect', cval=0, self.anti_aliasing_sigma = anti_aliasing_sigma def __call__(self, img): - return resize(img, self.output_shape, order=self.order, - mode=self.mode, cval=self.cval, - clip=self.clip, preserve_range=self.preserve_range, - anti_aliasing=self.anti_aliasing, - anti_aliasing_sigma=self.anti_aliasing_sigma) + resized = list() + for channel in img: + resized.append( + resize(channel, self.output_shape, order=self.order, + mode=self.mode, cval=self.cval, + clip=self.clip, preserve_range=self.preserve_range, + anti_aliasing=self.anti_aliasing, + anti_aliasing_sigma=self.anti_aliasing_sigma) + ) + return np.stack(resized).astype(np.float32) @export @@ -353,7 +359,7 @@ def __init__(self, angle, axes=(1, 2), reshape=True, order=1, mode='constant', c def __call__(self, img): return scipy.ndimage.rotate(img, self.angle, self.axes, reshape=self.reshape, order=self.order, mode=self.mode, cval=self.cval, - prefilter=self.prefilter) + prefilter=self.prefilter).astype(np.float32) @export @@ -420,7 +426,7 @@ def __call__(self, img): mode=self.mode, cval=self.cval, prefilter=self.prefilter)) - zoomed = np.stack(zoomed) + zoomed = np.stack(zoomed).astype(np.float32) if not self.keep_size or np.allclose(img.shape, zoomed.shape): return zoomed @@ -478,16 +484,14 @@ class IntensityNormalizer: Args: subtrahend (ndarray): the amount to subtract by (usually the mean) divisor (ndarray): the amount to divide by (usually the standard deviation) - dtype: output data format """ - def __init__(self, subtrahend=None, divisor=None, dtype=np.float32): + def __init__(self, subtrahend=None, divisor=None): if subtrahend is not None or divisor is not None: assert isinstance(subtrahend, np.ndarray) and isinstance(divisor, np.ndarray), \ 'subtrahend and divisor must be set in pair and in numpy array.' self.subtrahend = subtrahend self.divisor = divisor - self.dtype = dtype def __call__(self, img): if self.subtrahend is not None and self.divisor is not None: @@ -497,8 +501,6 @@ def __call__(self, img): img -= np.mean(img) img /= np.std(img) - if self.dtype != img.dtype: - img = img.astype(self.dtype) return img @@ -511,15 +513,13 @@ class ImageEndPadder: Args: out_size (list): the size of region of interest at the end of the operation. mode (string): a portion from numpy.lib.arraypad.pad is copied below. - dtype: output data format. """ - def __init__(self, out_size, mode, dtype=np.float32): - assert out_size is not None and isinstance(out_size, (list, tuple)), 'out_size must be list or tuple' + def __init__(self, out_size, mode): + assert out_size is not None and isinstance(out_size, (list, tuple)), 'out_size must be list or tuple.' self.out_size = out_size - assert isinstance(mode, str), 'mode must be str' + assert isinstance(mode, str), 'mode must be str.' 
self.mode = mode - self.dtype = dtype def _determine_data_pad_width(self, data_shape): return [(0, max(self.out_size[i] - data_shape[i], 0)) for i in range(len(self.out_size))] diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index c5640a5660..e937185f91 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -82,7 +82,7 @@ TEST_CASE_6 = [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) { - 'include_background': False, + 'include_background': True, 'do_sigmoid': True, }, { diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index e08ff1d296..b2ce96169e 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -94,7 +94,7 @@ TEST_CASE_6 = [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) { - 'include_background': False, + 'include_background': True, 'do_sigmoid': True, }, { diff --git a/tests/test_resize.py b/tests/test_resize.py index 7feaf9f634..7a3d327b28 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -30,9 +30,9 @@ def test_invalid_inputs(self, _, order, raises): resize(self.imt) @parameterized.expand([ - ((1, 1, 64, 64), 1, 'reflect', 0, True, True, True, None), - ((1, 1, 32, 32), 2, 'constant', 3, False, False, False, None), - ((1, 1, 256, 256), 3, 'constant', 3, False, False, False, None), + ((64, 64), 1, 'reflect', 0, True, True, True, None), + ((32, 32), 2, 'constant', 3, False, False, False, None), + ((256, 256), 3, 'constant', 3, False, False, False, None), ]) def test_correct_results(self, output_shape, order, mode, cval, clip, preserve_range, @@ -40,12 +40,15 @@ def test_correct_results(self, output_shape, order, mode, resize = Resize(output_shape, order, mode, cval, clip, preserve_range, anti_aliasing, anti_aliasing_sigma) - expected = skimage.transform.resize(self.imt, output_shape, - order=order, mode=mode, - cval=cval, clip=clip, - preserve_range=preserve_range, - anti_aliasing=anti_aliasing, - anti_aliasing_sigma=anti_aliasing_sigma) + expected = list() + for channel in self.imt: + expected.append(skimage.transform.resize(channel, output_shape, + order=order, mode=mode, + cval=cval, clip=clip, + preserve_range=preserve_range, + anti_aliasing=anti_aliasing, + anti_aliasing_sigma=anti_aliasing_sigma)) + expected = np.stack(expected).astype(np.float32) self.assertTrue(np.allclose(resize(self.imt), expected))
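
To make the CSV layout produced by the new ClassificationSaver concrete, a minimal sketch that bypasses the ignite engine and fills the private _prediction_dict directly, purely for illustration; the directory name and the dummy predictions are not from the patch:

import numpy as np
from monai.handlers.classification_saver import ClassificationSaver

# passing `name` gives the handler a logger without attaching it to an engine
saver = ClassificationSaver(output_dir='tempdir', name='saver')
# simulate what __call__ accumulates over iterations: one prediction per file name
saver._prediction_dict['IXI175-HH-1570-T1.nii.gz'] = np.array([0])
saver._prediction_dict['IXI607-Guys-1097-T1.nii.gz'] = np.array([1])
saver.finalize()
# tempdir/predictions.csv now holds one "filename,prediction" row per image,
# sorted by file name:
#   IXI175-HH-1570-T1.nii.gz,0
#   IXI607-Guys-1097-T1.nii.gz,1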