diff --git a/examples/densenet_classification_3d_array.py b/examples/densenet_classification_3d_array.py index 10cc0b2689..8fa61ca284 100644 --- a/examples/densenet_classification_3d_array.py +++ b/examples/densenet_classification_3d_array.py @@ -56,7 +56,7 @@ 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0 ]) -# Define transforms for image +# Define transforms train_transforms = transforms.Compose([ Rescale(), AddChannel(), @@ -96,7 +96,8 @@ trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) -train_stats_handler = StatsHandler() + +train_stats_handler = StatsHandler(output_transform=lambda x: x[3]) train_stats_handler.attach(trainer) # Set parameters for validation @@ -109,7 +110,8 @@ # Add stats event handler to print validation stats via evaluator logging.basicConfig(stream=sys.stdout, level=logging.INFO) -val_stats_handler = StatsHandler() + +val_stats_handler = StatsHandler(output_transform=lambda x: None) val_stats_handler.attach(evaluator) # Add early stopping handler to evaluator. 
diff --git a/examples/unet_segmentation_3d.ipynb b/examples/unet_segmentation_3d.ipynb index bfba0a70bf..0d49742f10 100644 --- a/examples/unet_segmentation_3d.ipynb +++ b/examples/unet_segmentation_3d.ipynb @@ -197,7 +197,7 @@ "trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,\n", " handler=checkpoint_handler,\n", " to_save={'net': net, 'opt': opt})\n", - "train_stats_handler = StatsHandler()\n", + "train_stats_handler = StatsHandler(output_transform=lambda x: x[1])\n", "train_stats_handler.attach(trainer)\n", "\n", "writer = SummaryWriter()\n", @@ -260,7 +260,7 @@ "\n", "# Add stats event handler to print validation stats via evaluator\n", "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n", - "val_stats_handler = StatsHandler()\n", + "val_stats_handler = StatsHandler(lambda x: None)\n", "val_stats_handler.attach(evaluator)\n", "\n", "# Add early stopping handler to evaluator.\n", diff --git a/examples/unet_segmentation_3d_array.py b/examples/unet_segmentation_3d_array.py index 3e1a819e79..d43ead1d2b 100644 --- a/examples/unet_segmentation_3d_array.py +++ b/examples/unet_segmentation_3d_array.py @@ -18,7 +18,6 @@ import nibabel as nib import numpy as np import torch -from torch.utils.tensorboard import SummaryWriter from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator from ignite.handlers import ModelCheckpoint, EarlyStopping from torch.utils.data import DataLoader @@ -32,8 +31,8 @@ from monai.data.nifti_reader import NiftiDataset from monai.transforms import AddChannel, Rescale, UniformRandomPatch from monai.handlers.stats_handler import StatsHandler +from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler, TensorBoardImageHandler from monai.handlers.mean_dice import MeanDice -from monai.visualize import img2tensorboard from monai.data.synthetic import create_test_image_3d from monai.handlers.utils import stopping_fn_from_metric @@ -86,70 +85,71 @@ loss = 
monai.losses.DiceLoss(do_sigmoid=True) opt = torch.optim.Adam(net.parameters(), lr) + # Since network outputs logits and segmentation, we need a custom function. def _loss_fn(i, j): return loss(i[0], j) + # Create trainer -device = torch.device("cuda:0") +device = torch.device("cpu:0") trainer = create_supervised_trainer(net, opt, _loss_fn, device, False, - output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y]) + output_transform=lambda x, y, y_pred, loss: [y_pred[1], loss.item(), y]) # adding checkpoint handler to save models (network params and optimizer stats) during training checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) -train_stats_handler = StatsHandler() +logging.basicConfig(stream=sys.stdout, level=logging.INFO) + +# print training loss to commandline +train_stats_handler = StatsHandler(output_transform=lambda x: x[1]) train_stats_handler.attach(trainer) +# record training loss to TensorBoard at every iteration +train_tensorboard_stats_handler = TensorBoardStatsHandler( + output_transform=lambda x: {'training_dice_loss': x[1]}, # plot under tag name taining_dice_loss + global_epoch_transform=lambda x: trainer.state.epoch) +train_tensorboard_stats_handler.attach(trainer) + + @trainer.on(Events.EPOCH_COMPLETED) def log_training_loss(engine): - # log loss to tensorboard with second item of engine.state.output, loss.item() from output_transform - writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch) - - # tensor of ones to use where for converting labels to zero and ones - ones = torch.ones(engine.state.batch[1][0].shape, dtype=torch.int32) - first_output_tensor = engine.state.output[0][1][0].detach().cpu() - # log model output to tensorboard, as three dimensional tensor with no channels dimension - img2tensorboard.add_animated_gif_no_channels(writer, 
"first_output_final_batch", first_output_tensor, 64, - 255, engine.state.epoch) - # get label tensor and convert to single class - first_label_tensor = torch.where(engine.state.batch[1][0] > 0, ones, engine.state.batch[1][0]) - # log label tensor to tensorboard, there is a channel dimension when getting label from batch - img2tensorboard.add_animated_gif(writer, "first_label_final_batch", first_label_tensor, 64, - 255, engine.state.epoch) - second_output_tensor = engine.state.output[0][1][1].detach().cpu() - img2tensorboard.add_animated_gif_no_channels(writer, "second_output_final_batch", second_output_tensor, 64, - 255, engine.state.epoch) - second_label_tensor = torch.where(engine.state.batch[1][1] > 0, ones, engine.state.batch[1][1]) - img2tensorboard.add_animated_gif(writer, "second_label_final_batch", second_label_tensor, 64, - 255, engine.state.epoch) - third_output_tensor = engine.state.output[0][1][2].detach().cpu() - img2tensorboard.add_animated_gif_no_channels(writer, "third_output_final_batch", third_output_tensor, 64, - 255, engine.state.epoch) - third_label_tensor = torch.where(engine.state.batch[1][2] > 0, ones, engine.state.batch[1][2]) - img2tensorboard.add_animated_gif(writer, "third_label_final_batch", third_label_tensor, 64, - 255, engine.state.epoch) engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1]) -writer = SummaryWriter() # Set parameters for validation validation_every_n_epochs = 1 metric_name = 'Mean_Dice' # add evaluation metric to the evaluator engine -val_metrics = {metric_name: MeanDice(add_sigmoid=True)} -evaluator = create_supervised_evaluator(net, val_metrics, device, True, - output_transform=lambda x, y, y_pred: (y_pred[0], y)) +val_metrics = {metric_name: MeanDice( + add_sigmoid=True, to_onehot_y=False, output_transform=lambda output: (output[0][0], output[1])) +} +evaluator = create_supervised_evaluator(net, val_metrics, device, True) # Add stats event handler to print validation stats via 
evaluator -logging.basicConfig(stream=sys.stdout, level=logging.INFO) -val_stats_handler = StatsHandler() +val_stats_handler = StatsHandler( + output_transform=lambda x: None, # disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) val_stats_handler.attach(evaluator) -# Add early stopping handler to evaluator. +# add handler to record metrics to TensorBoard at every epoch +val_tensorboard_stats_handler = TensorBoardStatsHandler( + output_transform=lambda x: None, # no iteration plot + global_epoch_transform=lambda x: trainer.state.epoch) # use epoch number from trainer +val_tensorboard_stats_handler.attach(evaluator) +# add handler to draw several images and the corresponding labels and model outputs +# here we draw the first 3 images(draw the first channel) as GIF format along Depth axis +val_tensorboard_image_handler = TensorBoardImageHandler( + batch_transform=lambda batch: (batch[0], batch[1]), + output_transform=lambda output: output[0][1], + global_iter_transform=lambda x: trainer.state.epoch +) +evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler) + +# Add early stopping handler to evaluator early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer) @@ -164,10 +164,6 @@ def log_training_loss(engine): def run_validation(engine): evaluator.run(val_loader) -@evaluator.on(Events.EPOCH_COMPLETED) -def log_metrics_to_tensorboard(engine): - for name, value in engine.state.metrics.items(): - writer.add_scalar('Metrics/{name}', value, trainer.state.epoch) # create a training data loader logging.basicConfig(stream=sys.stdout, level=logging.INFO) diff --git a/examples/unet_segmentation_3d_dict.py b/examples/unet_segmentation_3d_dict.py index 8eb2a76c73..b8622bb7eb 100644 --- a/examples/unet_segmentation_3d_dict.py +++ b/examples/unet_segmentation_3d_dict.py @@ -111,7 +111,7 @@ def prepare_batch(batch, device=None, non_blocking=False): 
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) -train_stats_handler = StatsHandler() +train_stats_handler = StatsHandler(output_transform=lambda x: x[1]) train_stats_handler.attach(trainer) @@ -160,7 +160,7 @@ def log_training_loss(engine): # Add stats event handler to print validation stats via evaluator logging.basicConfig(stream=sys.stdout, level=logging.INFO) -val_stats_handler = StatsHandler() +val_stats_handler = StatsHandler(output_transform=lambda x: None) val_stats_handler.attach(evaluator) # Add early stopping handler to evaluator. diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index e0fa50310a..1f3fe2615d 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -10,7 +10,7 @@ # limitations under the License. import os - +import numpy as np import torch from ignite.engine import Events @@ -114,5 +114,6 @@ def __call__(self, engine): seg_output = seg_output.detach().cpu().numpy() output_filename = self._create_file_basename(self.output_postfix, filename, self.output_path) output_filename = '{}{}'.format(output_filename, self.output_ext) - write_nifti(seg_output, affine_, output_filename, original_affine_, dtype=seg_output.dtype) + # change output to "channel last" format and write to nifti format file + write_nifti(np.moveaxis(seg_output, 0, -1), affine_, output_filename, original_affine_, dtype=seg_output.dtype) self.logger.info('saved: {}'.format(output_filename)) diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index 47709543fe..9d2e518919 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -9,47 +9,63 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import warnings import logging import torch from ignite.engine import Engine, Events +from monai.utils.misc import is_scalar -KEY_VAL_FORMAT = '{}: {:.4f} ' +DEFAULT_KEY_VAL_FORMAT = '{}: {:.4f} ' +DEFAULT_TAG = 'Loss' class StatsHandler(object): """StatsHandler defines a set of Ignite Event-handlers for all the log printing logics. It's can be used for any Ignite Engine(trainer, validator and evaluator). And it can support logging for epoch level and iteration level with pre-defined loggers. - By default: - (1) epoch_print_logger logs `engine.state.metrics`. - (2) iteration_print_logger logs loss value, expected output format is (y_pred, loss). + + Default behaviors: + - When EPOCH_COMPLETED, logs ``engine.state.metrics`` using ``self.logger``. + - When ITERATION_COMPELTED, logs + ``self.output_transform(engine.state.output)`` using ``self.logger``. """ def __init__(self, epoch_print_logger=None, iteration_print_logger=None, - batch_transform=lambda x: x, output_transform=lambda x: x, - name=None): + global_epoch_transform=lambda x: x, + name=None, + tag_name=DEFAULT_TAG, + key_var_format=DEFAULT_KEY_VAL_FORMAT): """ Args: epoch_print_logger (Callable): customized callable printer for epoch level logging. must accept parameter "engine", use default printer if None. iteration_print_logger (Callable): custimized callable printer for iteration level logging. must accept parameter "engine", use default printer if None. - batch_transform (Callable): a callable that is used to transform the - ignite.engine.batch into expected format to extract input data. output_transform (Callable): a callable that is used to transform the - ignite.engine.output into expected format to extract several output data. - name (str): identifier of logging.logger to use, defaulting to `engine.logger`. + ``ignite.engine.output`` into a scalar to print, or a dictionary of {key: scalar}. + in the latter case, the output string will be formated as key: value. 
+ by default this value logging happens when every iteration completed. + global_epoch_transform (Callable): a callable that is used to customize global epoch number. + For example, in evaluation, the evaluator engine might want to print synced epoch number + with the trainer engine. + name (str): identifier of logging.logger to use, defaulting to ``engine.logger``. + tag_name (string): when iteration output is a scalar, tag_name is used to print + tag_name: scalar_value to logger. Defaults to ``'Loss'``. + key_var_format (string): a formatting string to control the output string format of key: value. """ self.epoch_print_logger = epoch_print_logger self.iteration_print_logger = iteration_print_logger - self.batch_transform = batch_transform self.output_transform = output_transform + self.global_epoch_transform = global_epoch_transform self.logger = None if name is None else logging.getLogger(name) + self.tag_name = tag_name + self.key_var_format = key_var_format + def attach(self, engine: Engine): """Register a set of Ignite Event-Handlers to a specified Ignite engine. @@ -116,12 +132,12 @@ def _default_epoch_print(self, engine: Engine): prints_dict = engine.state.metrics if not prints_dict: return - current_epoch = engine.state.epoch + current_epoch = self.global_epoch_transform(engine.state.epoch) out_str = "Epoch[{}] Metrics -- ".format(current_epoch) for name in sorted(prints_dict): value = prints_dict[name] - out_str += KEY_VAL_FORMAT.format(name, value) + out_str += self.key_var_format.format(name, value) self.logger.info(out_str) @@ -134,19 +150,42 @@ def _default_iteration_print(self, engine: Engine): engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. 
""" - loss = self.output_transform(engine.state.output)[1] - if loss is None or (torch.is_tensor(loss) and len(loss.shape) > 0): - return # not printing multi dimensional output + loss = self.output_transform(engine.state.output) + if loss is None: + return # no printing if the output is empty + + out_str = '' + if isinstance(loss, dict): # print dictionary items + for name in sorted(loss): + value = loss[name] + if not is_scalar(value): + warnings.warn('ignoring non-scalar output in StatsHandler,' + ' make sure `output_transform(engine.state.output)` returns' + ' a scalar or dictionary of key and scalar pairs to avoid this warning.' + ' {}:{}'.format(name, type(value))) + continue # not printing multi dimensional output + out_str += self.key_var_format.format(name, value.item() if torch.is_tensor(value) else value) + else: + if is_scalar(loss): # not printing multi dimensional output + out_str += self.key_var_format.format(self.tag_name, loss.item() if torch.is_tensor(loss) else loss) + else: + warnings.warn('ignoring non-scalar output in StatsHandler,' + ' make sure `output_transform(engine.state.output)` returns' + ' a scalar or a dictionary of key and scalar pairs to avoid this warning.' 
+ ' {}'.format(type(loss))) + + if not out_str: + return # no value to print + num_iterations = engine.state.epoch_length current_iteration = (engine.state.iteration - 1) % num_iterations + 1 current_epoch = engine.state.epoch num_epochs = engine.state.max_epochs - out_str = "Epoch: {}/{}, Iter: {}/{} -- ".format( + base_str = "Epoch: {}/{}, Iter: {}/{} --".format( current_epoch, num_epochs, current_iteration, num_iterations) - out_str += KEY_VAL_FORMAT.format('Loss', loss.item() if torch.is_tensor(loss) else loss) - self.logger.info(out_str) + self.logger.info(' '.join([base_str, out_str])) diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py new file mode 100644 index 0000000000..2d5116bd59 --- /dev/null +++ b/monai/handlers/tensorboard_handlers.py @@ -0,0 +1,249 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import warnings +import torch +from torch.utils.tensorboard import SummaryWriter +from ignite.engine import Engine, Events +from monai.visualize import img2tensorboard +from monai.utils.misc import is_scalar +from monai.transforms.utils import rescale_array + + +class TensorBoardStatsHandler(object): + """TensorBoardStatsHandler defines a set of Ignite Event-handlers for all the TensorBoard logics. + It's can be used for any Ignite Engine(trainer, validator and evaluator). 
+ And it can support both epoch level and iteration level with pre-defined TensorBoard event writer. + The expected data source is ignite ``engine.state.output`` and ``engine.state.metrics``. + + Default behaviors: + - When EPOCH_COMPLETED, write each dictionary item in + ``engine.state.metrics`` to TensorBoard. + - When ITERATION_COMPELTED, write each dictionary item in + ``self.output_transform(engine.state.output)`` to TensorBoard. + """ + + def __init__(self, + summary_writer=None, + epoch_event_writer=None, + iteration_event_writer=None, + output_transform=lambda x: {'Loss': x}, + global_epoch_transform=lambda x: x): + """ + Args: + summary_writer (SummaryWriter): user can specify TensorBoard SummaryWriter, + default to create a new writer. + epoch_event_writer (Callable): customized callable TensorBoard writer for epoch level. + must accept parameter "engine" and "summary_writer", use default event writer if None. + iteration_event_writer (Callable): custimized callable TensorBoard writer for iteration level. + must accept parameter "engine" and "summary_writer", use default event writer if None. + output_transform (Callable): a callable that is used to transform the + ``ignite.engine.output`` into a dictionary of (tag_name: scalar) pairs to be plotted onto tensorboard. + by default this scalar plotting happens when every iteration completed. + global_epoch_transform (Callable): a callable that is used to customize global epoch number. + For example, in evaluation, the evaluator engine might want to use trainer engines epoch number + when plotting epoch vs metric curves. 
+ """ + self._writer = SummaryWriter() if summary_writer is None else summary_writer + self.epoch_event_writer = epoch_event_writer + self.iteration_event_writer = iteration_event_writer + self.output_transform = output_transform + self.global_epoch_transform = global_epoch_transform + + def attach(self, engine: Engine): + """Register a set of Ignite Event-Handlers to a specified Ignite engine. + + Args: + engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. + + """ + if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): + engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) + if not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): + engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed) + + def epoch_completed(self, engine: Engine): + """handler for train or validation/evaluation epoch completed Event. + Write epoch level events, default values are from ignite state.metrics dict. + + Args: + engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. + + """ + if self.epoch_event_writer is not None: + self.epoch_event_writer(engine, self._writer) + else: + self._default_epoch_writer(engine, self._writer) + + def iteration_completed(self, engine: Engine): + """handler for train or validation/evaluation iteration completed Event. + Write iteration level events, default values are from ignite state.logs dict. + + Args: + engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. + + """ + if self.iteration_event_writer is not None: + self.iteration_event_writer(engine, self._writer) + else: + self._default_iteration_writer(engine, self._writer) + + def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter): + """Execute epoch level event write operation based on Ignite engine.state data. + Default is to write the values from ignite state.metrics dict. 
+ + Args: + engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. + writer (SummaryWriter): TensorBoard writer, created in TensorBoardHandler. + + """ + current_epoch = self.global_epoch_transform(engine.state.epoch) + summary_dict = engine.state.metrics + for name, value in summary_dict.items(): + writer.add_scalar(name, value, current_epoch) + writer.flush() + + def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter): + """Execute iteration level event write operation based on Ignite engine.state data. + Default is to write the loss value of current iteration. + + Args: + engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. + writer (SummaryWriter): TensorBoard writer, created in TensorBoardHandler. + + """ + loss_dict = self.output_transform(engine.state.output) + if loss_dict is None: + return # do nothing if output is empty + if not isinstance(loss_dict, dict): + raise ValueError('TensorBoardStatsHandler requires' + ' output_transform(engine.state.output) returning a dictionary' + ' of key and scalar pairs to plot' + ' got {}.'.format(type(loss_dict))) + for name, value in loss_dict.items(): + if not is_scalar(value): + warnings.warn('ignoring non-scalar output in tensorboard curve plotting,' + ' make sure `output_transform(engine.state.output)` returns' + ' a dictionary of key and scalar pairs to avoid this warning.' + ' Got {}:{}'.format(name, type(value))) + continue + plot_value = value.item() if torch.is_tensor(value) else value + writer.add_scalar(name, plot_value, engine.state.iteration) + writer.flush() + + +class TensorBoardImageHandler(object): + """TensorBoardImageHandler is an ignite Event handler that can visualise images, labels and outputs as 2D/3D images. 
+ 2D output (shape in Batch, channel, H, W) will be shown as simple image using the first element in the batch, + for 3D to ND output (shape in Batch, channel, H, W, D) input, + the last three dimensions will be shown as GIF image along the last axis (typically Depth). + + It's can be used for any Ignite Engine (trainer, validator and evaluator). + User can easily added it to engine for any expected Event, for example: ``EPOCH_COMPLETED``, + ``ITERATION_COMPLETED``. The expected data source is ignite's ``engine.state.batch`` and ``engine.state.output``. + + Default behavior: + - Show y_pred as images (GIF for 3D) on TensorBoard when Event triggered, + - need to use ``batch_transform`` and ``output_transform`` to specify + how many images to show and show which channel. + - Expects ``batch_transform(engine.state.batch)`` to return data + format: (image[N, channel, ...], label[N, channel, ...]). + - Expects ``output_transform(engine.state.output)`` to return a torch + tensor in format (y_pred[N, channel, ...], loss). + """ + + def __init__(self, + summary_writer=None, + batch_transform=lambda x: x, + output_transform=lambda x: x, + global_iter_transform=lambda x: x, + max_channels=1, + max_frames=64): + """ + Args: + summary_writer (SummaryWriter): user can specify TensorBoard SummaryWriter, + default to create a new writer. + batch_transform (Callable): a callable that is used to transform the + ``ignite.engine.batch`` into expected format to extract several label data. + output_transform (Callable): a callable that is used to transform the + ``ignite.engine.output`` into expected format to extract several output data. + global_iter_transform (Callable): a callable that is used to customize global step number for TensorBoard. + For example, in evaluation, the evaluator engine needs to know current epoch from trainer. + max_channels (int): number of channels to plot. + max_frames (int): number of frames for 2D-t plot. 
+ + """ + self._writer = SummaryWriter() if summary_writer is None else summary_writer + self.batch_transform = batch_transform + self.output_transform = output_transform + self.global_iter_transform = global_iter_transform + + self.max_frames = max_frames + self.max_channels = max_channels + + def __call__(self, engine): + step = self.global_iter_transform(engine.state.iteration) + + show_images = self.batch_transform(engine.state.batch)[0] + if torch.is_tensor(show_images): + show_images = show_images.detach().cpu().numpy() + if show_images is not None: + if not isinstance(show_images, np.ndarray): + raise ValueError('output_transform(engine.state.output)[0] must be an ndarray or tensor.') + self._add_2_or_3_d(show_images, step, 'input_0') + + show_labels = self.batch_transform(engine.state.batch)[1] + if torch.is_tensor(show_labels): + show_labels = show_labels.detach().cpu().numpy() + if show_labels is not None: + if not isinstance(show_labels, np.ndarray): + raise ValueError('batch_transform(engine.state.batch)[1] must be an ndarray or tensor.') + self._add_2_or_3_d(show_labels, step, 'input_1') + + show_outputs = self.output_transform(engine.state.output) + if torch.is_tensor(show_outputs): + show_outputs = show_outputs.detach().cpu().numpy() + if show_outputs is not None: + if not isinstance(show_outputs, np.ndarray): + raise ValueError('output_transform(engine.state.output) must be an ndarray or tensor.') + self._add_2_or_3_d(show_outputs, step, 'output') + + self._writer.flush() + + def _add_2_or_3_d(self, data, step, tag='output'): + # for i, d in enumerate(data): # go through a batch of images + d = data[0] # show the first element in a batch + + if d.ndim == 2: + d = rescale_array(d, 0, 1) + dataformats = 'HW' + self._writer.add_image('{}_{}'.format(tag, dataformats), d, step, dataformats=dataformats) + return + + if d.ndim == 3: + if d.shape[0] == 3 and self.max_channels == 3: # rgb? 
+ dataformats = 'CHW' + self._writer.add_image('{}_{}'.format(tag, dataformats), d, step, dataformats='CHW') + return + for j, d2 in enumerate(d[:self.max_channels]): + d2 = rescale_array(d2, 0, 1) + dataformats = 'HW' + self._writer.add_image('{}_{}_{}'.format(tag, dataformats, j), d2, step, dataformats=dataformats) + return + + if d.ndim >= 4: + spatial = d.shape[-3:] + for j, d3 in enumerate(d.reshape([-1] + list(spatial))[:self.max_channels]): + d3 = rescale_array(d3, 0, 255) + img2tensorboard.add_animated_gif( + self._writer, '{}_HWD_{}'.format(tag, j), d3[None], self.max_frames, 1.0, step) + return diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 377d4d0073..1cd849d18e 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -12,13 +12,17 @@ def stopping_fn_from_metric(metric_name): """Returns a stopping function for ignite.handlers.EarlyStopping using the given metric name.""" + def stopping_fn(engine): return engine.state.metrics[metric_name] + return stopping_fn def stopping_fn_from_loss(): """Returns a stopping function for ignite.handlers.EarlyStopping using the loss value.""" + def stopping_fn(engine): return -engine.state.output + return stopping_fn diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 5a22884846..bca9922374 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -52,6 +52,6 @@ def predict_segmentation(logits): """ # generate prediction outputs, logits has shape BCHW[D] if logits.shape[1] == 1: - return (logits[:, 0] >= 0).int() # for binary segmentation threshold on channel 0 + return (logits >= 0).int() # for binary segmentation threshold on channel 0 else: - return logits.max(1)[1] # take the index of the max value along dimension 1 + return logits.argmax(1).unsqueeze(1) # take the index of the max value along dimension 1 diff --git a/monai/transforms/composables.py b/monai/transforms/composables.py index 21d1471f1a..4a666dd417 100644 --- a/monai/transforms/composables.py 
+++ b/monai/transforms/composables.py @@ -14,6 +14,7 @@ """ import torch +import numpy as np from collections.abc import Hashable import numpy as np import monai @@ -21,8 +22,9 @@ from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.compose import Randomizable, Transform from monai.transforms.transforms import (LoadNifti, AsChannelFirst, Orientation, - AddChannel, Spacing, Rotate90, Rescale, Resize, - SpatialCrop, RandAffine, Rand2DElastic, Rand3DElastic) + AddChannel, Spacing, Rotate90, SpatialCrop, + RandAffine, Rand2DElastic, Rand3DElastic, + Rescale, Resize, Flip, Rotate, Zoom) from monai.utils.misc import ensure_tuple from monai.transforms.utils import generate_pos_neg_label_crop_centers, create_grid from monai.utils.aliases import alias @@ -529,7 +531,6 @@ def __init__(self, keys, as_tensor_output (bool): the computation is implemented using pytorch tensors, this option specifies whether to convert it back to numpy arrays. device (torch.device): device on which the tensor will be allocated. - See also: - ``RandAffineGrid`` for the random affine paramters configurations. - ``Affine`` for the affine transformation parameters configurations. @@ -604,7 +605,6 @@ def __init__(self, keys, as_tensor_output (bool): the computation is implemented using pytorch tensors, this option specifies whether to convert it back to numpy arrays. device (torch.device): device on which the tensor will be allocated. - See also: - ``RandAffineGrid`` for the random affine paramters configurations. - ``Affine`` for the affine transformation parameters configurations. @@ -647,3 +647,228 @@ def __call__(self, data): for key in self.keys: # same interpolation mode d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=self.rand_3d_elastic.mode) return d + + +@export +@alias('FlipD', 'FlipDict') +class Flipd(MapTransform): + """Dictionary-based wrapper of Flip. + + Args: + keys (dict): Keys to pick data for transformation. 
+ axis (None, int or tuple of ints): Axes along which to flip over. Default is None. + """ + + def __init__(self, keys, axis=None): + MapTransform.__init__(self, keys) + self.flipper = Flip(axis=axis) + + def __call__(self, data): + d = dict(data) + for key in self.keys: + d[key] = self.flipper(d[key]) + return d + + +@export +@alias('RandFlipD', 'RandFlipDict') +class RandFlipd(Randomizable, MapTransform): + """Dict-based wrapper of RandFlip. + + Args: + prob (float): Probability of flipping. + axis (None, int or tuple of ints): Axes along which to flip over. Default is None. + """ + + def __init__(self, keys, prob=0.1, axis=None): + MapTransform.__init__(self, keys) + self.axis = axis + self.prob = prob + + self._do_transform = False + self.flipper = Flip(axis=axis) + + def randomize(self): + self._do_transform = self.R.random_sample() < self.prob + + def __call__(self, data): + self.randomize() + d = dict(data) + if not self._do_transform: + return d + for key in self.keys: + d[key] = self.flipper(d[key]) + return d + + +@export +@alias('RotateD', 'RotateDict') +class Rotated(MapTransform): + """Dictionary-based wrapper of Rotate. + + Args: + keys (dict): Keys to pick data for transformation. + angle (float): Rotation angle in degrees. + axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). This is the first two + axis in spatial dimensions according to MONAI channel first shape assumption. + reshape (bool): If true, output shape is made same as input. Default: True. + order (int): Order of spline interpolation. Range 0-5. Default: 1. This is + different from scipy where default interpolation is 3. + mode (str): Points outside boundary filled according to this mode. Options are + 'constant', 'nearest', 'reflect', 'wrap'. Default: 'constant'. + cval (scalar): Values to fill outside boundary. Default: 0. + prefiter (bool): Apply spline_filter before interpolation. Default: True. 
+ """ + + def __init__(self, keys, angle, axes=(1, 2), reshape=True, order=1, + mode='constant', cval=0, prefilter=True): + MapTransform.__init__(self, keys) + self.rotator = Rotate(angle=angle, axes=axes, reshape=reshape, + order=order, mode=mode, cval=cval, prefilter=prefilter) + + def __call__(self, data): + d = dict(data) + for key in self.keys: + d[key] = self.rotator(d[key]) + return d + + +@export +@alias('RandRotateD', 'RandRotateDict') +class RandRotated(Randomizable, MapTransform): + """Randomly rotates the input arrays. + + Args: + prob (float): Probability of rotation. + degrees (tuple of float or float): Range of rotation in degrees. If single number, + angle is picked from (-degrees, degrees). + axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). This is the first two + axis in spatial dimensions according to MONAI channel first shape assumption. + reshape (bool): If true, output shape is made same as input. Default: True. + order (int): Order of spline interpolation. Range 0-5. Default: 1. This is + different from scipy where default interpolation is 3. + mode (str): Points outside boundary filled according to this mode. Options are + 'constant', 'nearest', 'reflect', 'wrap'. Default: 'constant'. + cval (scalar): Value to fill outside boundary. Default: 0. + prefiter (bool): Apply spline_filter before interpolation. Default: True. + """ + def __init__(self, keys, degrees, prob=0.1, axes=(1, 2), reshape=True, order=1, + mode='constant', cval=0, prefilter=True): + MapTransform.__init__(self, keys) + self.prob = prob + self.degrees = degrees + self.reshape = reshape + self.order = order + self.mode = mode + self.cval = cval + self.prefilter = prefilter + self.axes = axes + + if not hasattr(self.degrees, '__iter__'): + self.degrees = (-self.degrees, self.degrees) + assert len(self.degrees) == 2, "degrees should be a number or pair of numbers." 
+
+        self._do_transform = False
+        self.angle = None
+
+    def randomize(self):
+        self._do_transform = self.R.random_sample() < self.prob
+        self.angle = self.R.uniform(low=self.degrees[0], high=self.degrees[1])
+
+    def __call__(self, data):
+        self.randomize()
+        d = dict(data)
+        if not self._do_transform:
+            return d
+        rotator = Rotate(self.angle, self.axes, self.reshape, self.order,
+                         self.mode, self.cval, self.prefilter)
+        for key in self.keys:
+            # apply the freshly-parameterized rotator (was: self.flipper, which
+            # does not exist on RandRotated and raised AttributeError)
+            d[key] = rotator(d[key])
+        return d
+
+
+@export
+@alias('ZoomD', 'ZoomDict')
+class Zoomd(MapTransform):
+    """Dictionary-based wrapper of Zoom transform.
+
+    Args:
+        zoom (float or sequence): The zoom factor along the spatial axes.
+            If a float, zoom is the same for each spatial axis.
+            If a sequence, zoom should contain one value for each spatial axis.
+        order (int): order of interpolation. Default=3.
+        mode (str): Determines how input is extended beyond boundaries. Default is 'constant'.
+        cval (scalar, optional): Value to fill past edges. Default is 0.
+        use_gpu (bool): Should use cpu or gpu. Uses cupyx which doesn't support order > 1 and modes
+            'wrap' and 'reflect'. Defaults to cpu for these cases or if cupyx not found.
+        keep_size (bool): Should keep original size (pad if needed).
+    """
+
+    def __init__(self, keys, zoom, order=3, mode='constant', cval=0,
+                 prefilter=True, use_gpu=False, keep_size=False):
+        MapTransform.__init__(self, keys)
+        self.zoomer = Zoom(zoom=zoom, order=order, mode=mode, cval=cval,
+                           prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size)
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.keys:
+            d[key] = self.zoomer(d[key])
+        return d
+
+
+@export
+@alias('RandZoomD', 'RandZoomDict')
+class RandZoomd(Randomizable, MapTransform):
+    """Dict-based wrapper of RandZoom.
+
+    Args:
+        keys (dict): Keys to pick data for transformation.
+        prob (float): Probability of zooming.
+        min_zoom (float or sequence): Min zoom factor. Can be float or sequence same size as image.
+        max_zoom (float or sequence): Max zoom factor. Can be float or sequence same size as image.
+        order (int): order of interpolation. Default=3.
+        mode ('reflect', 'constant', 'nearest', 'mirror', 'wrap'): Determines how input is
+            extended beyond boundaries. Default: 'constant'.
+        cval (scalar, optional): Value to fill past edges. Default is 0.
+        use_gpu (bool): Should use cpu or gpu. Uses cupyx which doesn't support order > 1 and modes
+            'wrap' and 'reflect'. Defaults to cpu for these cases or if cupyx not found.
+        keep_size (bool): Should keep original size (pad if needed).
+    """
+
+    def __init__(self, keys, prob=0.1, min_zoom=0.9,
+                 max_zoom=1.1, order=3, mode='constant',
+                 cval=0, prefilter=True, use_gpu=False, keep_size=False):
+        MapTransform.__init__(self, keys)
+        if hasattr(min_zoom, '__iter__') and \
+                hasattr(max_zoom, '__iter__'):
+            assert len(min_zoom) == len(max_zoom), "min_zoom and max_zoom must have same length."
+        self.min_zoom = min_zoom
+        self.max_zoom = max_zoom
+        self.prob = prob
+        self.order = order
+        self.mode = mode
+        self.cval = cval
+        self.prefilter = prefilter
+        self.use_gpu = use_gpu
+        self.keep_size = keep_size
+
+        self._do_transform = False
+        self._zoom = None
+
+    def randomize(self):
+        self._do_transform = self.R.random_sample() < self.prob
+        if hasattr(self.min_zoom, '__iter__'):
+            # materialize as a list (was a single-use generator expression, which
+            # scipy's zoom cannot consume as a sequence and which is exhausted
+            # after one read — tests also inspect self._zoom after __call__)
+            self._zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)]
+        else:
+            self._zoom = self.R.uniform(self.min_zoom, self.max_zoom)
+
+    def __call__(self, data):
+        self.randomize()
+        d = dict(data)
+        if not self._do_transform:
+            return d
+        zoomer = Zoom(self._zoom, self.order, self.mode, self.cval, self.prefilter, self.use_gpu, self.keep_size)
+        for key in self.keys:
+            d[key] = zoomer(d[key])
+        return d
diff --git a/monai/transforms/transforms.py b/monai/transforms/transforms.py
index c34d1c4c2e..be6d38a4fa 100644
--- a/monai/transforms/transforms.py
+++ b/monai/transforms/transforms.py
@@ -440,7 +440,7 @@ def __call__(self, img):
pad_vec[idx] = [half, diff - half] elif diff < 0: # need slicing slice_vec[idx] = slice(half, half + od) - zoomed = np.pad(zoomed, pad_vec) + zoomed = np.pad(zoomed, pad_vec, mode='constant') return zoomed[tuple(slice_vec)] @@ -484,7 +484,6 @@ class IntensityNormalizer: Args: subtrahend (ndarray): the amount to subtract by (usually the mean) divisor (ndarray): the amount to divide by (usually the standard deviation) - dtype: output data format """ def __init__(self, subtrahend=None, divisor=None): @@ -514,13 +513,12 @@ class ImageEndPadder: Args: out_size (list): the size of region of interest at the end of the operation. mode (string): a portion from numpy.lib.arraypad.pad is copied below. - dtype: output data format. """ def __init__(self, out_size, mode): - assert out_size is not None and isinstance(out_size, (list, tuple)), 'out_size must be list or tuple' + assert out_size is not None and isinstance(out_size, (list, tuple)), 'out_size must be list or tuple.' self.out_size = out_size - assert isinstance(mode, str), 'mode must be str' + assert isinstance(mode, str), 'mode must be str.' 
self.mode = mode def _determine_data_pad_width(self, data_shape): @@ -698,6 +696,7 @@ def __init__(self, prob=0.1, axis=None): self.flipper = Flip(axis=axis) self._do_transform = False + self.flipper = Flip(axis=axis) def randomize(self): self._do_transform = self.R.random_sample() < self.prob diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 775b8f2ebd..261e521adb 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -11,6 +11,9 @@ import itertools +import numpy as np +import torch + def zip_with(op, *vals, mapfunc=map): """ @@ -40,3 +43,15 @@ def ensure_tuple(vals): vals = (vals,) return tuple(vals) + + +def is_scalar_tensor(val): + if torch.is_tensor(val) and val.ndim == 0: + return True + return False + + +def is_scalar(val): + if torch.is_tensor(val) and val.ndim == 0: + return True + return np.isscalar(val) diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py index 82211a8ecd..8fce996685 100644 --- a/monai/visualize/img2tensorboard.py +++ b/monai/visualize/img2tensorboard.py @@ -27,8 +27,8 @@ def _image3_animated_gif(imp, scale_factor=1): # x=numpy.random.randint(0,256,[10,10,10],numpy.uint8) (tag, ims) = imp ims = [ - (np.asarray((ims[i, :, :])) * scale_factor).astype(np.uint8) - for i in range(ims.shape[0]) + (np.asarray((ims[:, :, i])) * scale_factor).astype(np.uint8) + for i in range(ims.shape[2]) ] ims = [GifImage.fromarray(im) for im in ims] img_str = b'' @@ -49,8 +49,8 @@ def _image3_animated_gif(imp, scale_factor=1): def make_animated_gif_summary(tag, tensor, max_out=3, - animation_axes=(1,), - image_axes=(2, 3), + animation_axes=(3,), + image_axes=(1, 2), other_indices=None, scale_factor=1): """ @@ -58,7 +58,7 @@ def make_animated_gif_summary(tag, Args: tag: Data identifier - tensor: tensor for the image, expected to be in CDHW format + tensor: tensor for the image, expected to be in CHWD format max_out: maximum number of slices to animate through animation_axes: axis to animate on (not currently 
used) image_axes: axes of image (not currently used) diff --git a/tests/integration_sliding_window.py b/tests/integration_sliding_window.py index 8025a4c821..4677b33b3a 100644 --- a/tests/integration_sliding_window.py +++ b/tests/integration_sliding_window.py @@ -65,17 +65,19 @@ def _sliding_window_processor(_engine, batch): basename = os.path.basename(img_name)[:-len('.nii.gz')] saved_name = os.path.join(temp_dir, basename, '{}_seg.nii.gz'.format(basename)) - testing_shape = nib.load(saved_name).get_fdata().shape + # get spatial dimensions shape, the saved nifti image format: HWDC + testing_shape = nib.load(saved_name).get_fdata().shape[:-1] if os.path.exists(img_name): os.remove(img_name) if os.path.exists(seg_name): os.remove(seg_name) - - return testing_shape == input_shape + if testing_shape != input_shape: + print('testing shape: {} does not match input shape: {}.'.format(testing_shape, input_shape)) + return False + return True if __name__ == "__main__": result = run_test() - sys.exit(0 if result else 1) diff --git a/tests/integration_unet2d.py b/tests/integration_unet2d.py index ed1afa9872..3dac697397 100644 --- a/tests/integration_unet2d.py +++ b/tests/integration_unet2d.py @@ -51,12 +51,13 @@ def loss_fn(pred, grnd): trainer = create_supervised_trainer(net, opt, loss_fn, device, False) trainer.run(src, 1) - - return trainer.state.output + loss = trainer.state.output + print('Loss:', loss) + if loss >= 1: + print('Loss value is wrong, expect to be < 1.') + return loss if __name__ == "__main__": result = run_test() - print(result) - sys.exit(0 if result < 1 else 1) diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index c5640a5660..e937185f91 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -82,7 +82,7 @@ TEST_CASE_6 = [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) { - 'include_background': False, + 'include_background': True, 'do_sigmoid': True, }, { diff --git a/tests/test_flip.py b/tests/test_flip.py index 3b027ec2c8..a261c315e2 
100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -14,31 +14,44 @@ import numpy as np from parameterized import parameterized -from monai.transforms import Flip +from monai.transforms import Flip, Flipd from tests.utils import NumpyImageTestCase2D +INVALID_CASES = [("wrong_axis", ['s', 1], TypeError), + ("not_numbers", 's', TypeError)] + +VALID_CASES = [("no_axis", None), + ("one_axis", 1), + ("many_axis", [0, 1, 2])] + class FlipTest(NumpyImageTestCase2D): - @parameterized.expand([ - ("wrong_axis", ['s', 1], TypeError), - ("not_numbers", 's', TypeError) - ]) + @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, _, axis, raises): with self.assertRaises(raises): flip = Flip(axis) flip(self.imt) - @parameterized.expand([ - ("no_axis", None), - ("one_axis", 1), - ("many_axis", [0, 1, 2]) - ]) + @parameterized.expand(INVALID_CASES) + def test_invalid_cases_dict(self, _, axis, raises): + with self.assertRaises(raises): + flip = Flipd(keys='img', axis=axis) + flip({'img': self.imt}) + + @parameterized.expand(VALID_CASES) def test_correct_results(self, _, axis): flip = Flip(axis=axis) expected = np.flip(self.imt, axis) self.assertTrue(np.allclose(expected, flip(self.imt))) + @parameterized.expand(VALID_CASES) + def test_correct_results_dict(self, _, axis): + flip = Flipd(keys='img', axis=axis) + expected = np.flip(self.imt, axis) + res = flip({'img': self.imt}) + assert np.allclose(expected, res['img']) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index e08ff1d296..b2ce96169e 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -94,7 +94,7 @@ TEST_CASE_6 = [ # shape: (1, 1, 2, 2), (1, 1, 2, 2) { - 'include_background': False, + 'include_background': True, 'do_sigmoid': True, }, { diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index 58a62133d2..fdb0600e04 100644 --- a/tests/test_handler_stats.py +++ 
b/tests/test_handler_stats.py @@ -30,7 +30,7 @@ def test_metrics_print(self): # set up engine def _train_func(engine, batch): - return None, torch.tensor(0.0) + return torch.tensor(0.0) engine = Engine(_train_func) @@ -50,7 +50,6 @@ def _update_metric(engine): output_str = log_stream.getvalue() grep = re.compile('.*{}.*'.format(key_to_handler)) has_key_word = re.compile('.*{}.*'.format(key_to_print)) - matched = [] for idx, line in enumerate(output_str.split('\n')): if grep.match(line): if idx in [5, 10]: @@ -60,16 +59,44 @@ def test_loss_print(self): log_stream = StringIO() logging.basicConfig(stream=log_stream, level=logging.INFO) key_to_handler = 'test_logging' - key_to_print = 'Loss' + key_to_print = 'myLoss' # set up engine def _train_func(engine, batch): - return None, torch.tensor(0.0) + return torch.tensor(0.0) engine = Engine(_train_func) # set up testing handler - stats_handler = StatsHandler(name=key_to_handler) + stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print) + stats_handler.attach(engine) + + engine.run(range(3), max_epochs=2) + + # check logging output + output_str = log_stream.getvalue() + grep = re.compile('.*{}.*'.format(key_to_handler)) + has_key_word = re.compile('.*{}.*'.format(key_to_print)) + for idx, line in enumerate(output_str.split('\n')): + if grep.match(line): + if idx in [1, 2, 3, 6, 7, 8]: + self.assertTrue(has_key_word.match(line)) + + def test_loss_dict(self): + log_stream = StringIO() + logging.basicConfig(stream=log_stream, level=logging.INFO) + key_to_handler = 'test_logging' + key_to_print = 'myLoss1' + + # set up engine + def _train_func(engine, batch): + return torch.tensor(0.0) + + engine = Engine(_train_func) + + # set up testing handler + stats_handler = StatsHandler(name=key_to_handler, + output_transform=lambda x: {key_to_print: x}) stats_handler.attach(engine) engine.run(range(3), max_epochs=2) @@ -78,7 +105,6 @@ def _train_func(engine, batch): output_str = log_stream.getvalue() grep = 
re.compile('.*{}.*'.format(key_to_handler)) has_key_word = re.compile('.*{}.*'.format(key_to_print)) - matched = [] for idx, line in enumerate(output_str.split('\n')): if grep.match(line): if idx in [1, 2, 3, 6, 7, 8]: diff --git a/tests/test_handler_tb_image.py b/tests/test_handler_tb_image.py new file mode 100644 index 0000000000..9bf55e162b --- /dev/null +++ b/tests/test_handler_tb_image.py @@ -0,0 +1,60 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import os +import shutil +import unittest + +import numpy as np +import torch +from ignite.engine import Engine, Events +from parameterized import parameterized + +from monai.handlers.tensorboard_handlers import TensorBoardImageHandler + +TEST_CASES = [ + [[20, 20]], + [[2, 20, 20]], + [[3, 20, 20]], + [[20, 20, 20]], + [[2, 20, 20, 20]], + [[2, 2, 20, 20, 20]], +] + + +class TestHandlerTBImage(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_tb_image_shape(self, shape): + default_dir = os.path.join('.', 'runs') + shutil.rmtree(default_dir, ignore_errors=True) + + # set up engine + def _train_func(engine, batch): + return torch.zeros((1, 1, 10, 10)) + + engine = Engine(_train_func) + + # set up testing handler + stats_handler = TensorBoardImageHandler() + engine.add_event_handler(Events.ITERATION_COMPLETED, stats_handler) + + data = zip(np.random.normal(size=(10, 4, *shape)), np.random.normal(size=(10, 4, *shape))) + engine.run(data, epoch_length=10, 
max_epochs=1) + + self.assertTrue(os.path.exists(default_dir)) + self.assertTrue(len(glob.glob(default_dir)) > 0) + shutil.rmtree(default_dir) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_handler_tb_stats.py b/tests/test_handler_tb_stats.py new file mode 100644 index 0000000000..53a691701f --- /dev/null +++ b/tests/test_handler_tb_stats.py @@ -0,0 +1,81 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import tempfile +import unittest +import glob + +from ignite.engine import Engine, Events +from torch.utils.tensorboard import SummaryWriter + +from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler + + +class TestHandlerTBStats(unittest.TestCase): + + def test_metrics_print(self): + default_dir = os.path.join('.', 'runs') + shutil.rmtree(default_dir, ignore_errors=True) + + # set up engine + def _train_func(engine, batch): + return batch + 1.0 + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get('acc', 0.1) + engine.state.metrics['acc'] = current_metric + 0.1 + + # set up testing handler + stats_handler = TensorBoardStatsHandler() + stats_handler.attach(engine) + engine.run(range(3), max_epochs=2) + # check logging output + + self.assertTrue(os.path.exists(default_dir)) + shutil.rmtree(default_dir) + + def test_metrics_writer(self): + default_dir = 
os.path.join('.', 'runs') + shutil.rmtree(default_dir, ignore_errors=True) + with tempfile.TemporaryDirectory() as temp_dir: + + # set up engine + def _train_func(engine, batch): + return batch + 1.0 + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get('acc', 0.1) + engine.state.metrics['acc'] = current_metric + 0.1 + + # set up testing handler + writer = SummaryWriter(log_dir=temp_dir) + stats_handler = TensorBoardStatsHandler( + writer, output_transform=lambda x: {'loss': x * 2.0}, + global_epoch_transform=lambda x: x * 3.0) + stats_handler.attach(engine) + engine.run(range(3), max_epochs=2) + # check logging output + self.assertTrue(len(glob.glob(temp_dir)) > 0) + self.assertTrue(not os.path.exists(default_dir)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_random_flip.py b/tests/test_rand_flip.py similarity index 62% rename from tests/test_random_flip.py rename to tests/test_rand_flip.py index ee89a133d9..be03ff5a28 100644 --- a/tests/test_random_flip.py +++ b/tests/test_rand_flip.py @@ -14,31 +14,38 @@ import numpy as np from parameterized import parameterized -from monai.transforms import RandFlip +from monai.transforms import RandFlip, RandFlipd from tests.utils import NumpyImageTestCase2D +INVALID_CASES = [("wrong_axis", ['s', 1], TypeError), + ("not_numbers", 's', TypeError)] -class RandomFlipTest(NumpyImageTestCase2D): +VALID_CASES = [("no_axis", None), + ("one_axis", 1), + ("many_axis", [0, 1, 2])] - @parameterized.expand([ - ("wrong_axis", ['s', 1], TypeError), - ("not_numbers", 's', TypeError) - ]) +class RandFlipTest(NumpyImageTestCase2D): + + @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, _, axis, raises): with self.assertRaises(raises): flip = RandFlip(prob=1.0, axis=axis) flip(self.imt) - @parameterized.expand([ - ("no_axis", None), - ("one_axis", 1), - ("many_axis", [0, 1, 2]) - ]) + 
@parameterized.expand(VALID_CASES) def test_correct_results(self, _, axis): flip = RandFlip(prob=1.0, axis=axis) expected = np.flip(self.imt, axis) self.assertTrue(np.allclose(expected, flip(self.imt))) + @parameterized.expand(VALID_CASES) + def test_correct_results_dict(self, _, axis): + flip = RandFlipd(keys='img', prob=1.0, axis=axis) + res = flip({'img': self.imt}) + + expected = np.flip(self.imt, axis) + self.assertTrue(np.allclose(expected, res['img'])) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_random_rotate.py b/tests/test_rand_rotate.py similarity index 100% rename from tests/test_random_rotate.py rename to tests/test_rand_rotate.py diff --git a/tests/test_random_zoom.py b/tests/test_rand_zoom.py similarity index 76% rename from tests/test_random_zoom.py rename to tests/test_rand_zoom.py index d193a16dd2..530504b887 100644 --- a/tests/test_random_zoom.py +++ b/tests/test_rand_zoom.py @@ -17,15 +17,14 @@ from scipy.ndimage import zoom as zoom_scipy from parameterized import parameterized -from monai.transforms import RandZoom +from monai.transforms import RandZoom, RandZoomd from tests.utils import NumpyImageTestCase2D +VALID_CASES = [(0.9, 1.1, 3, 'constant', 0, True, False, False)] class ZoomTest(NumpyImageTestCase2D): - @parameterized.expand([ - (0.9, 1.1, 3, 'constant', 0, True, False, False), - ]) + @parameterized.expand(VALID_CASES) def test_correct_results(self, min_zoom, max_zoom, order, mode, cval, prefilter, use_gpu, keep_size): random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order, @@ -39,6 +38,21 @@ def test_correct_results(self, min_zoom, max_zoom, order, mode, self.assertTrue(np.allclose(expected, zoomed)) + @parameterized.expand(VALID_CASES) + def test_correct_results_dict(self, min_zoom, max_zoom, order, mode, + cval, prefilter, use_gpu, keep_size): + keys = 'img' + random_zoom = RandZoomd(keys, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, order=order, + mode=mode, cval=cval, 
prefilter=prefilter, use_gpu=use_gpu, + keep_size=keep_size) + random_zoom.set_random_state(234) + + zoomed = random_zoom({keys: self.imt}) + expected = zoom_scipy(self.imt, zoom=random_zoom._zoom, mode=mode, + order=order, cval=cval, prefilter=prefilter) + + self.assertTrue(np.allclose(expected, zoomed[keys])) + @parameterized.expand([ (0.8, 1.2, 1, 'constant', 0, True) ]) diff --git a/tests/test_rotate.py b/tests/test_rotate.py index 98e25f587f..0c34f5809e 100644 --- a/tests/test_rotate.py +++ b/tests/test_rotate.py @@ -15,17 +15,16 @@ import scipy.ndimage from parameterized import parameterized -from monai.transforms import Rotate +from monai.transforms import Rotate, Rotated from tests.utils import NumpyImageTestCase2D +TEST_CASES = [(90, (1, 2), True, 1, 'reflect', 0, True), + (-90, (2, 1), True, 3, 'constant', 0, True), + (180, (2, 3), False, 2, 'constant', 4, False)] class RotateTest(NumpyImageTestCase2D): - @parameterized.expand([ - (90, (1, 2), True, 1, 'reflect', 0, True), - (-90, (2, 1), True, 3, 'constant', 0, True), - (180, (2, 3), False, 2, 'constant', 4, False), - ]) + @parameterized.expand(TEST_CASES) def test_correct_results(self, angle, axes, reshape, order, mode, cval, prefilter): rotate_fn = Rotate(angle, axes, reshape, @@ -36,6 +35,18 @@ def test_correct_results(self, angle, axes, reshape, mode=mode, cval=cval, prefilter=prefilter) self.assertTrue(np.allclose(expected, rotated)) + @parameterized.expand(TEST_CASES) + def test_correct_results_dict(self, angle, axes, reshape, + order, mode, cval, prefilter): + key = 'img' + rotate_fn = Rotated(key, angle, axes, reshape, order, + mode, cval, prefilter) + rotated = rotate_fn({key: self.imt}) + + expected = scipy.ndimage.rotate(self.imt, angle, axes, reshape, order=order, + mode=mode, cval=cval, prefilter=prefilter) + self.assertTrue(np.allclose(expected, rotated[key])) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_zoom.py b/tests/test_zoom.py index 874e587a98..83795542bc 
100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -17,17 +17,22 @@ from scipy.ndimage import zoom as zoom_scipy from parameterized import parameterized -from monai.transforms import Zoom +from monai.transforms import Zoom, Zoomd from tests.utils import NumpyImageTestCase2D +VALID_CASES = [(1.1, 3, 'constant', 0, True, False, False), + (0.9, 3, 'constant', 0, True, False, False), + (0.8, 1, 'reflect', 0, False, False, False)] + +GPU_CASES = [("gpu_zoom", 0.6, 1, 'constant', 0, True)] + +INVALID_CASES = [("no_zoom", None, 1, TypeError), + ("invalid_order", 0.9, 's', AssertionError)] + class ZoomTest(NumpyImageTestCase2D): - @parameterized.expand([ - (1.1, 3, 'constant', 0, True, False, False), - (0.9, 3, 'constant', 0, True, False, False), - (0.8, 1, 'reflect', 0, False, False, False) - ]) + @parameterized.expand(VALID_CASES) def test_correct_results(self, zoom, order, mode, cval, prefilter, use_gpu, keep_size): zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval, prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size) @@ -36,9 +41,19 @@ def test_correct_results(self, zoom, order, mode, cval, prefilter, use_gpu, keep cval=cval, prefilter=prefilter) self.assertTrue(np.allclose(expected, zoomed)) - @parameterized.expand([ - ("gpu_zoom", 0.6, 1, 'constant', 0, True) - ]) + @parameterized.expand(VALID_CASES) + def test_correct_results_dict(self, zoom, order, mode, cval, prefilter, use_gpu, keep_size): + key = 'img' + zoom_fn = Zoomd(key, zoom=zoom, order=order, mode=mode, cval=cval, + prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size) + zoomed = zoom_fn({key: self.imt[0]}) + + expected = zoom_scipy(self.imt, zoom=zoom, mode=mode, order=order, + cval=cval, prefilter=prefilter) + self.assertTrue(np.allclose(expected, zoomed[key])) + + + @parameterized.expand(GPU_CASES) def test_gpu_zoom(self, _, zoom, order, mode, cval, prefilter): if importlib.util.find_spec('cupy'): zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval, @@ -57,10 +72,7 
@@ def test_keep_size(self): zoomed = zoom_fn(self.imt[0]) self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) - @parameterized.expand([ - ("no_zoom", None, 1, TypeError), - ("invalid_order", 0.9, 's', AssertionError) - ]) + @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, _, zoom, order, raises): with self.assertRaises(raises): zoom_fn = Zoom(zoom=zoom, order=order)