diff --git a/examples/densenet_classification_3d.py b/examples/densenet_classification_3d.py index 07ac3ffe04..753097a2a9 100644 --- a/examples/densenet_classification_3d.py +++ b/examples/densenet_classification_3d.py @@ -91,7 +91,7 @@ trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) -train_stats_handler = StatsHandler() +train_stats_handler = StatsHandler(output_transform=lambda x: x[3]) train_stats_handler.attach(trainer) @trainer.on(Events.EPOCH_COMPLETED) @@ -108,7 +108,7 @@ def log_training_loss(engine): # Add stats event handler to print validation stats via evaluator logging.basicConfig(stream=sys.stdout, level=logging.INFO) -val_stats_handler = StatsHandler() +val_stats_handler = StatsHandler(output_transform=lambda x: None) val_stats_handler.attach(evaluator) # Add early stopping handler to evaluator. diff --git a/examples/unet_segmentation_3d.ipynb b/examples/unet_segmentation_3d.ipynb index bfba0a70bf..0d49742f10 100644 --- a/examples/unet_segmentation_3d.ipynb +++ b/examples/unet_segmentation_3d.ipynb @@ -197,7 +197,7 @@ "trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,\n", " handler=checkpoint_handler,\n", " to_save={'net': net, 'opt': opt})\n", - "train_stats_handler = StatsHandler()\n", + "train_stats_handler = StatsHandler(output_transform=lambda x: x[1])\n", "train_stats_handler.attach(trainer)\n", "\n", "writer = SummaryWriter()\n", @@ -260,7 +260,7 @@ "\n", "# Add stats event handler to print validation stats via evaluator\n", "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n", - "val_stats_handler = StatsHandler()\n", + "val_stats_handler = StatsHandler(lambda x: None)\n", "val_stats_handler.attach(evaluator)\n", "\n", "# Add early stopping handler to evaluator.\n", diff --git a/examples/unet_segmentation_3d_array.py b/examples/unet_segmentation_3d_array.py index de5996470d..3b0d880f10 100644 --- a/examples/unet_segmentation_3d_array.py +++ b/examples/unet_segmentation_3d_array.py @@ -18,7 +18,6 @@ import nibabel as nib import numpy as np import torch -from torch.utils.tensorboard import SummaryWriter from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator from ignite.handlers import ModelCheckpoint, EarlyStopping from torch.utils.data import DataLoader @@ -32,8 +31,8 @@ from monai.data.nifti_reader import NiftiDataset from monai.transforms import AddChannel, Rescale, ToTensor, UniformRandomPatch from monai.handlers.stats_handler import StatsHandler +from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler, TensorBoardImageHandler from monai.handlers.mean_dice import MeanDice -from monai.visualize import img2tensorboard from monai.data.synthetic import create_test_image_3d from monai.handlers.utils import stopping_fn_from_metric @@ -88,70 +87,71 @@ loss = monai.losses.DiceLoss(do_sigmoid=True) opt = torch.optim.Adam(net.parameters(), lr) + # Since network outputs logits and segmentation, we need a custom function. def _loss_fn(i, j): return loss(i[0], j) + # Create trainer -device = torch.device("cuda:0") +device = torch.device("cpu:0") trainer = create_supervised_trainer(net, opt, _loss_fn, device, False, - output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y]) + output_transform=lambda x, y, y_pred, loss: [y_pred[1], loss.item(), y]) # adding checkpoint handler to save models (network params and optimizer stats) during training checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) -train_stats_handler = StatsHandler() +logging.basicConfig(stream=sys.stdout, level=logging.INFO) + +# print training loss to commandline +train_stats_handler = StatsHandler(output_transform=lambda x: x[1]) train_stats_handler.attach(trainer) +# record training loss to TensorBoard at every iteration +train_tensorboard_stats_handler = TensorBoardStatsHandler( + output_transform=lambda x: {'training_dice_loss': x[1]}, # plot under tag name taining_dice_loss + global_epoch_transform=lambda x: trainer.state.epoch) +train_tensorboard_stats_handler.attach(trainer) + + @trainer.on(Events.EPOCH_COMPLETED) def log_training_loss(engine): - # log loss to tensorboard with second item of engine.state.output, loss.item() from output_transform - writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch) - - # tensor of ones to use where for converting labels to zero and ones - ones = torch.ones(engine.state.batch[1][0].shape, dtype=torch.int32) - first_output_tensor = engine.state.output[0][1][0].detach().cpu() - # log model output to tensorboard, as three dimensional tensor with no channels dimension - img2tensorboard.add_animated_gif_no_channels(writer, "first_output_final_batch", first_output_tensor, 64, - 255, engine.state.epoch) - # get label tensor and convert to single class - first_label_tensor = torch.where(engine.state.batch[1][0] > 0, ones, engine.state.batch[1][0]) - # log label tensor to tensorboard, there is a channel dimension when getting label from batch - img2tensorboard.add_animated_gif(writer, "first_label_final_batch", first_label_tensor, 64, - 255, engine.state.epoch) - second_output_tensor = engine.state.output[0][1][1].detach().cpu() - img2tensorboard.add_animated_gif_no_channels(writer, "second_output_final_batch", second_output_tensor, 64, - 255, engine.state.epoch) - second_label_tensor = torch.where(engine.state.batch[1][1] > 0, ones, engine.state.batch[1][1]) - img2tensorboard.add_animated_gif(writer, "second_label_final_batch", second_label_tensor, 64, - 255, engine.state.epoch) - third_output_tensor = engine.state.output[0][1][2].detach().cpu() - img2tensorboard.add_animated_gif_no_channels(writer, "third_output_final_batch", third_output_tensor, 64, - 255, engine.state.epoch) - third_label_tensor = torch.where(engine.state.batch[1][2] > 0, ones, engine.state.batch[1][2]) - img2tensorboard.add_animated_gif(writer, "third_label_final_batch", third_label_tensor, 64, - 255, engine.state.epoch) engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1]) -writer = SummaryWriter() # Set parameters for validation validation_every_n_epochs = 1 metric_name = 'Mean_Dice' # add evaluation metric to the evaluator engine -val_metrics = {metric_name: MeanDice(add_sigmoid=True)} -evaluator = create_supervised_evaluator(net, val_metrics, device, True, - output_transform=lambda x, y, y_pred: (y_pred[0], y)) +val_metrics = {metric_name: MeanDice( + add_sigmoid=True, to_onehot_y=False, output_transform=lambda output: (output[0][0], output[1])) +} +evaluator = create_supervised_evaluator(net, val_metrics, device, True) # Add stats event handler to print validation stats via evaluator -logging.basicConfig(stream=sys.stdout, level=logging.INFO) -val_stats_handler = StatsHandler() +val_stats_handler = StatsHandler( + output_transform=lambda x: None, # disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) val_stats_handler.attach(evaluator) -# Add early stopping handler to evaluator. +# add handler to record metrics to TensorBoard at every epoch +val_tensorboard_stats_handler = TensorBoardStatsHandler( + output_transform=lambda x: None, # no iteration plot + global_epoch_transform=lambda x: trainer.state.epoch) # use epoch number from trainer +val_tensorboard_stats_handler.attach(evaluator) +# add handler to draw several images and the corresponding labels and model outputs +# here we draw the first 3 images(draw the first channel) as GIF format along Depth axis +val_tensorboard_image_handler = TensorBoardImageHandler( + batch_transform=lambda batch: (batch[0], batch[1]), + output_transform=lambda output: output[0][1], + global_iter_transform=lambda x: trainer.state.epoch +) +evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler) + +# Add early stopping handler to evaluator early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer) @@ -166,10 +166,6 @@ def log_training_loss(engine): def run_validation(engine): evaluator.run(val_loader) -@evaluator.on(Events.EPOCH_COMPLETED) -def log_metrics_to_tensorboard(engine): - for name, value in engine.state.metrics.items(): - writer.add_scalar('Metrics/{name}', value, trainer.state.epoch) # create a training data loader logging.basicConfig(stream=sys.stdout, level=logging.INFO) diff --git a/examples/unet_segmentation_3d_dict.py b/examples/unet_segmentation_3d_dict.py index 640bed21c0..0e1e78811b 100644 --- a/examples/unet_segmentation_3d_dict.py +++ b/examples/unet_segmentation_3d_dict.py @@ -111,7 +111,7 @@ def prepare_batch(batch, device=None, non_blocking=False): trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) -train_stats_handler = StatsHandler() +train_stats_handler = StatsHandler(output_transform=lambda x: x[1]) train_stats_handler.attach(trainer) @@ -160,7 +160,7 @@ def log_training_loss(engine): # Add stats event handler to print validation stats via evaluator logging.basicConfig(stream=sys.stdout, level=logging.INFO) -val_stats_handler = StatsHandler() +val_stats_handler = StatsHandler(output_transform=lambda x: None) val_stats_handler.attach(evaluator) # Add early stopping handler to evaluator. diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index e0fa50310a..1f3fe2615d 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -10,7 +10,7 @@ # limitations under the License. import os - +import numpy as np import torch from ignite.engine import Events @@ -114,5 +114,6 @@ def __call__(self, engine): seg_output = seg_output.detach().cpu().numpy() output_filename = self._create_file_basename(self.output_postfix, filename, self.output_path) output_filename = '{}{}'.format(output_filename, self.output_ext) - write_nifti(seg_output, affine_, output_filename, original_affine_, dtype=seg_output.dtype) + # change output to "channel last" format and write to nifti format file + write_nifti(np.moveaxis(seg_output, 0, -1), affine_, output_filename, original_affine_, dtype=seg_output.dtype) self.logger.info('saved: {}'.format(output_filename)) diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index 47709543fe..9d2e518919 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -9,47 +9,63 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings import logging import torch from ignite.engine import Engine, Events +from monai.utils.misc import is_scalar -KEY_VAL_FORMAT = '{}: {:.4f} ' +DEFAULT_KEY_VAL_FORMAT = '{}: {:.4f} ' +DEFAULT_TAG = 'Loss' class StatsHandler(object): """StatsHandler defines a set of Ignite Event-handlers for all the log printing logics. It's can be used for any Ignite Engine(trainer, validator and evaluator). And it can support logging for epoch level and iteration level with pre-defined loggers. - By default: - (1) epoch_print_logger logs `engine.state.metrics`. - (2) iteration_print_logger logs loss value, expected output format is (y_pred, loss). + + Default behaviors: + - When EPOCH_COMPLETED, logs ``engine.state.metrics`` using ``self.logger``. + - When ITERATION_COMPELTED, logs + ``self.output_transform(engine.state.output)`` using ``self.logger``. """ def __init__(self, epoch_print_logger=None, iteration_print_logger=None, - batch_transform=lambda x: x, output_transform=lambda x: x, - name=None): + global_epoch_transform=lambda x: x, + name=None, + tag_name=DEFAULT_TAG, + key_var_format=DEFAULT_KEY_VAL_FORMAT): """ Args: epoch_print_logger (Callable): customized callable printer for epoch level logging. must accept parameter "engine", use default printer if None. iteration_print_logger (Callable): custimized callable printer for iteration level logging. must accept parameter "engine", use default printer if None. - batch_transform (Callable): a callable that is used to transform the - ignite.engine.batch into expected format to extract input data. output_transform (Callable): a callable that is used to transform the - ignite.engine.output into expected format to extract several output data. - name (str): identifier of logging.logger to use, defaulting to `engine.logger`. + ``ignite.engine.output`` into a scalar to print, or a dictionary of {key: scalar}. + in the latter case, the output string will be formated as key: value. + by default this value logging happens when every iteration completed. + global_epoch_transform (Callable): a callable that is used to customize global epoch number. + For example, in evaluation, the evaluator engine might want to print synced epoch number + with the trainer engine. + name (str): identifier of logging.logger to use, defaulting to ``engine.logger``. + tag_name (string): when iteration output is a scalar, tag_name is used to print + tag_name: scalar_value to logger. Defaults to ``'Loss'``. + key_var_format (string): a formatting string to control the output string format of key: value. """ self.epoch_print_logger = epoch_print_logger self.iteration_print_logger = iteration_print_logger - self.batch_transform = batch_transform self.output_transform = output_transform + self.global_epoch_transform = global_epoch_transform self.logger = None if name is None else logging.getLogger(name) + self.tag_name = tag_name + self.key_var_format = key_var_format + def attach(self, engine: Engine): """Register a set of Ignite Event-Handlers to a specified Ignite engine. @@ -116,12 +132,12 @@ def _default_epoch_print(self, engine: Engine): prints_dict = engine.state.metrics if not prints_dict: return - current_epoch = engine.state.epoch + current_epoch = self.global_epoch_transform(engine.state.epoch) out_str = "Epoch[{}] Metrics -- ".format(current_epoch) for name in sorted(prints_dict): value = prints_dict[name] - out_str += KEY_VAL_FORMAT.format(name, value) + out_str += self.key_var_format.format(name, value) self.logger.info(out_str) @@ -134,19 +150,42 @@ def _default_iteration_print(self, engine: Engine): engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. """ - loss = self.output_transform(engine.state.output)[1] - if loss is None or (torch.is_tensor(loss) and len(loss.shape) > 0): - return # not printing multi dimensional output + loss = self.output_transform(engine.state.output) + if loss is None: + return # no printing if the output is empty + + out_str = '' + if isinstance(loss, dict): # print dictionary items + for name in sorted(loss): + value = loss[name] + if not is_scalar(value): + warnings.warn('ignoring non-scalar output in StatsHandler,' + ' make sure `output_transform(engine.state.output)` returns' + ' a scalar or dictionary of key and scalar pairs to avoid this warning.' + ' {}:{}'.format(name, type(value))) + continue # not printing multi dimensional output + out_str += self.key_var_format.format(name, value.item() if torch.is_tensor(value) else value) + else: + if is_scalar(loss): # not printing multi dimensional output + out_str += self.key_var_format.format(self.tag_name, loss.item() if torch.is_tensor(loss) else loss) + else: + warnings.warn('ignoring non-scalar output in StatsHandler,' + ' make sure `output_transform(engine.state.output)` returns' + ' a scalar or a dictionary of key and scalar pairs to avoid this warning.' + ' {}'.format(type(loss))) + + if not out_str: + return # no value to print + num_iterations = engine.state.epoch_length current_iteration = (engine.state.iteration - 1) % num_iterations + 1 current_epoch = engine.state.epoch num_epochs = engine.state.max_epochs - out_str = "Epoch: {}/{}, Iter: {}/{} -- ".format( + base_str = "Epoch: {}/{}, Iter: {}/{} --".format( current_epoch, num_epochs, current_iteration, num_iterations) - out_str += KEY_VAL_FORMAT.format('Loss', loss.item() if torch.is_tensor(loss) else loss) - self.logger.info(out_str) + self.logger.info(' '.join([base_str, out_str])) diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py new file mode 100644 index 0000000000..2d5116bd59 --- /dev/null +++ b/monai/handlers/tensorboard_handlers.py @@ -0,0 +1,249 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import warnings +import torch +from torch.utils.tensorboard import SummaryWriter +from ignite.engine import Engine, Events +from monai.visualize import img2tensorboard +from monai.utils.misc import is_scalar +from monai.transforms.utils import rescale_array + + +class TensorBoardStatsHandler(object): + """TensorBoardStatsHandler defines a set of Ignite Event-handlers for all the TensorBoard logics. + It's can be used for any Ignite Engine(trainer, validator and evaluator). + And it can support both epoch level and iteration level with pre-defined TensorBoard event writer. + The expected data source is ignite ``engine.state.output`` and ``engine.state.metrics``. + + Default behaviors: + - When EPOCH_COMPLETED, write each dictionary item in + ``engine.state.metrics`` to TensorBoard. + - When ITERATION_COMPELTED, write each dictionary item in + ``self.output_transform(engine.state.output)`` to TensorBoard. + """ + + def __init__(self, + summary_writer=None, + epoch_event_writer=None, + iteration_event_writer=None, + output_transform=lambda x: {'Loss': x}, + global_epoch_transform=lambda x: x): + """ + Args: + summary_writer (SummaryWriter): user can specify TensorBoard SummaryWriter, + default to create a new writer. + epoch_event_writer (Callable): customized callable TensorBoard writer for epoch level. + must accept parameter "engine" and "summary_writer", use default event writer if None. + iteration_event_writer (Callable): custimized callable TensorBoard writer for iteration level. + must accept parameter "engine" and "summary_writer", use default event writer if None. + output_transform (Callable): a callable that is used to transform the + ``ignite.engine.output`` into a dictionary of (tag_name: scalar) pairs to be plotted onto tensorboard. + by default this scalar plotting happens when every iteration completed. + global_epoch_transform (Callable): a callable that is used to customize global epoch number. + For example, in evaluation, the evaluator engine might want to use trainer engines epoch number + when plotting epoch vs metric curves. + """ + self._writer = SummaryWriter() if summary_writer is None else summary_writer + self.epoch_event_writer = epoch_event_writer + self.iteration_event_writer = iteration_event_writer + self.output_transform = output_transform + self.global_epoch_transform = global_epoch_transform + + def attach(self, engine: Engine): + """Register a set of Ignite Event-Handlers to a specified Ignite engine. + + Args: + engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. + + """ + if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): + engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) + if not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): + engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed) + + def epoch_completed(self, engine: Engine): + """handler for train or validation/evaluation epoch completed Event. + Write epoch level events, default values are from ignite state.metrics dict. + + Args: + engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. + + """ + if self.epoch_event_writer is not None: + self.epoch_event_writer(engine, self._writer) + else: + self._default_epoch_writer(engine, self._writer) + + def iteration_completed(self, engine: Engine): + """handler for train or validation/evaluation iteration completed Event. + Write iteration level events, default values are from ignite state.logs dict. + + Args: + engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. + + """ + if self.iteration_event_writer is not None: + self.iteration_event_writer(engine, self._writer) + else: + self._default_iteration_writer(engine, self._writer) + + def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter): + """Execute epoch level event write operation based on Ignite engine.state data. + Default is to write the values from ignite state.metrics dict. + + Args: + engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. + writer (SummaryWriter): TensorBoard writer, created in TensorBoardHandler. + + """ + current_epoch = self.global_epoch_transform(engine.state.epoch) + summary_dict = engine.state.metrics + for name, value in summary_dict.items(): + writer.add_scalar(name, value, current_epoch) + writer.flush() + + def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter): + """Execute iteration level event write operation based on Ignite engine.state data. + Default is to write the loss value of current iteration. + + Args: + engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. + writer (SummaryWriter): TensorBoard writer, created in TensorBoardHandler. + + """ + loss_dict = self.output_transform(engine.state.output) + if loss_dict is None: + return # do nothing if output is empty + if not isinstance(loss_dict, dict): + raise ValueError('TensorBoardStatsHandler requires' + ' output_transform(engine.state.output) returning a dictionary' + ' of key and scalar pairs to plot' + ' got {}.'.format(type(loss_dict))) + for name, value in loss_dict.items(): + if not is_scalar(value): + warnings.warn('ignoring non-scalar output in tensorboard curve plotting,' + ' make sure `output_transform(engine.state.output)` returns' + ' a dictionary of key and scalar pairs to avoid this warning.' + ' Got {}:{}'.format(name, type(value))) + continue + plot_value = value.item() if torch.is_tensor(value) else value + writer.add_scalar(name, plot_value, engine.state.iteration) + writer.flush() + + +class TensorBoardImageHandler(object): + """TensorBoardImageHandler is an ignite Event handler that can visualise images, labels and outputs as 2D/3D images. + 2D output (shape in Batch, channel, H, W) will be shown as simple image using the first element in the batch, + for 3D to ND output (shape in Batch, channel, H, W, D) input, + the last three dimensions will be shown as GIF image along the last axis (typically Depth). + + It's can be used for any Ignite Engine (trainer, validator and evaluator). + User can easily added it to engine for any expected Event, for example: ``EPOCH_COMPLETED``, + ``ITERATION_COMPLETED``. The expected data source is ignite's ``engine.state.batch`` and ``engine.state.output``. + + Default behavior: + - Show y_pred as images (GIF for 3D) on TensorBoard when Event triggered, + - need to use ``batch_transform`` and ``output_transform`` to specify + how many images to show and show which channel. + - Expects ``batch_transform(engine.state.batch)`` to return data + format: (image[N, channel, ...], label[N, channel, ...]). + - Expects ``output_transform(engine.state.output)`` to return a torch + tensor in format (y_pred[N, channel, ...], loss). + """ + + def __init__(self, + summary_writer=None, + batch_transform=lambda x: x, + output_transform=lambda x: x, + global_iter_transform=lambda x: x, + max_channels=1, + max_frames=64): + """ + Args: + summary_writer (SummaryWriter): user can specify TensorBoard SummaryWriter, + default to create a new writer. + batch_transform (Callable): a callable that is used to transform the + ``ignite.engine.batch`` into expected format to extract several label data. + output_transform (Callable): a callable that is used to transform the + ``ignite.engine.output`` into expected format to extract several output data. + global_iter_transform (Callable): a callable that is used to customize global step number for TensorBoard. + For example, in evaluation, the evaluator engine needs to know current epoch from trainer. + max_channels (int): number of channels to plot. + max_frames (int): number of frames for 2D-t plot. + + """ + self._writer = SummaryWriter() if summary_writer is None else summary_writer + self.batch_transform = batch_transform + self.output_transform = output_transform + self.global_iter_transform = global_iter_transform + + self.max_frames = max_frames + self.max_channels = max_channels + + def __call__(self, engine): + step = self.global_iter_transform(engine.state.iteration) + + show_images = self.batch_transform(engine.state.batch)[0] + if torch.is_tensor(show_images): + show_images = show_images.detach().cpu().numpy() + if show_images is not None: + if not isinstance(show_images, np.ndarray): + raise ValueError('output_transform(engine.state.output)[0] must be an ndarray or tensor.') + self._add_2_or_3_d(show_images, step, 'input_0') + + show_labels = self.batch_transform(engine.state.batch)[1] + if torch.is_tensor(show_labels): + show_labels = show_labels.detach().cpu().numpy() + if show_labels is not None: + if not isinstance(show_labels, np.ndarray): + raise ValueError('batch_transform(engine.state.batch)[1] must be an ndarray or tensor.') + self._add_2_or_3_d(show_labels, step, 'input_1') + + show_outputs = self.output_transform(engine.state.output) + if torch.is_tensor(show_outputs): + show_outputs = show_outputs.detach().cpu().numpy() + if show_outputs is not None: + if not isinstance(show_outputs, np.ndarray): + raise ValueError('output_transform(engine.state.output) must be an ndarray or tensor.') + self._add_2_or_3_d(show_outputs, step, 'output') + + self._writer.flush() + + def _add_2_or_3_d(self, data, step, tag='output'): + # for i, d in enumerate(data): # go through a batch of images + d = data[0] # show the first element in a batch + + if d.ndim == 2: + d = rescale_array(d, 0, 1) + dataformats = 'HW' + self._writer.add_image('{}_{}'.format(tag, dataformats), d, step, dataformats=dataformats) + return + + if d.ndim == 3: + if d.shape[0] == 3 and self.max_channels == 3: # rgb? + dataformats = 'CHW' + self._writer.add_image('{}_{}'.format(tag, dataformats), d, step, dataformats='CHW') + return + for j, d2 in enumerate(d[:self.max_channels]): + d2 = rescale_array(d2, 0, 1) + dataformats = 'HW' + self._writer.add_image('{}_{}_{}'.format(tag, dataformats, j), d2, step, dataformats=dataformats) + return + + if d.ndim >= 4: + spatial = d.shape[-3:] + for j, d3 in enumerate(d.reshape([-1] + list(spatial))[:self.max_channels]): + d3 = rescale_array(d3, 0, 255) + img2tensorboard.add_animated_gif( + self._writer, '{}_HWD_{}'.format(tag, j), d3[None], self.max_frames, 1.0, step) + return diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 377d4d0073..1cd849d18e 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -12,13 +12,17 @@ def stopping_fn_from_metric(metric_name): """Returns a stopping function for ignite.handlers.EarlyStopping using the given metric name.""" + def stopping_fn(engine): return engine.state.metrics[metric_name] + return stopping_fn def stopping_fn_from_loss(): """Returns a stopping function for ignite.handlers.EarlyStopping using the loss value.""" + def stopping_fn(engine): return -engine.state.output + return stopping_fn diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 5a22884846..bca9922374 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -52,6 +52,6 @@ def predict_segmentation(logits): """ # generate prediction outputs, logits has shape BCHW[D] if logits.shape[1] == 1: - return (logits[:, 0] >= 0).int() # for binary segmentation threshold on channel 0 + return (logits >= 0).int() # for binary segmentation threshold on channel 0 else: - return logits.max(1)[1] # take the index of the max value along dimension 1 + return logits.argmax(1).unsqueeze(1) # take the index of the max value along dimension 1 diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 775b8f2ebd..261e521adb 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -11,6 +11,9 @@ import itertools +import numpy as np +import torch + def zip_with(op, *vals, mapfunc=map): """ @@ -40,3 +43,15 @@ def ensure_tuple(vals): vals = (vals,) return tuple(vals) + + +def is_scalar_tensor(val): + if torch.is_tensor(val) and val.ndim == 0: + return True + return False + + +def is_scalar(val): + if torch.is_tensor(val) and val.ndim == 0: + return True + return np.isscalar(val) diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py index 82211a8ecd..8fce996685 100644 --- a/monai/visualize/img2tensorboard.py +++ b/monai/visualize/img2tensorboard.py @@ -27,8 +27,8 @@ def _image3_animated_gif(imp, scale_factor=1): # x=numpy.random.randint(0,256,[10,10,10],numpy.uint8) (tag, ims) = imp ims = [ - (np.asarray((ims[i, :, :])) * scale_factor).astype(np.uint8) - for i in range(ims.shape[0]) + (np.asarray((ims[:, :, i])) * scale_factor).astype(np.uint8) + for i in range(ims.shape[2]) ] ims = [GifImage.fromarray(im) for im in ims] img_str = b'' @@ -49,8 +49,8 @@ def _image3_animated_gif(imp, scale_factor=1): def make_animated_gif_summary(tag, tensor, max_out=3, - animation_axes=(1,), - image_axes=(2, 3), + animation_axes=(3,), + image_axes=(1, 2), other_indices=None, scale_factor=1): """ @@ -58,7 +58,7 @@ def make_animated_gif_summary(tag, Args: tag: Data identifier - tensor: tensor for the image, expected to be in CDHW format + tensor: tensor for the image, expected to be in CHWD format max_out: maximum number of slices to animate through animation_axes: axis to animate on (not currently used) image_axes: axes of image (not currently used) diff --git a/tests/integration_sliding_window.py b/tests/integration_sliding_window.py index db10d7cc49..31c7a1248a 100644 --- a/tests/integration_sliding_window.py +++ b/tests/integration_sliding_window.py @@ -65,17 +65,19 @@ def _sliding_window_processor(_engine, batch): basename = os.path.basename(img_name)[:-len('.nii.gz')] saved_name = os.path.join(temp_dir, basename, '{}_seg.nii.gz'.format(basename)) - testing_shape = nib.load(saved_name).get_fdata().shape + # get spatial dimensions shape, the saved nifti image format: HWDC + testing_shape = nib.load(saved_name).get_fdata().shape[:-1] if os.path.exists(img_name): os.remove(img_name) if os.path.exists(seg_name): os.remove(seg_name) - - return testing_shape == input_shape + if testing_shape != input_shape: + print('testing shape: {} does not match input shape: {}.'.format(testing_shape, input_shape)) + return False + return True if __name__ == "__main__": result = run_test() - sys.exit(0 if result else 1) diff --git a/tests/integration_unet2d.py b/tests/integration_unet2d.py index 1fd9074c66..7b0f116b77 100644 --- a/tests/integration_unet2d.py +++ b/tests/integration_unet2d.py @@ -51,12 +51,13 @@ def loss_fn(pred, grnd): trainer = create_supervised_trainer(net, opt, loss_fn, device, False) trainer.run(src, 1) - - return trainer.state.output + loss = trainer.state.output + print('Loss:', loss) + if loss >= 1: + print('Loss value is wrong, expect to be < 1.') + return loss if __name__ == "__main__": result = run_test() - print(result) - sys.exit(0 if result < 1 else 1) diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index 58a62133d2..fdb0600e04 100644 --- a/tests/test_handler_stats.py +++ b/tests/test_handler_stats.py @@ -30,7 +30,7 @@ def test_metrics_print(self): # set up engine def _train_func(engine, batch): - return None, torch.tensor(0.0) + return torch.tensor(0.0) engine = Engine(_train_func) @@ -50,7 +50,6 @@ def _update_metric(engine): output_str = log_stream.getvalue() grep = re.compile('.*{}.*'.format(key_to_handler)) has_key_word = re.compile('.*{}.*'.format(key_to_print)) - matched = [] for idx, line in enumerate(output_str.split('\n')): if grep.match(line): if idx in [5, 10]: @@ -60,16 +59,44 @@ def test_loss_print(self): log_stream = StringIO() logging.basicConfig(stream=log_stream, level=logging.INFO) key_to_handler = 'test_logging' - key_to_print = 'Loss' + key_to_print = 'myLoss' # set up engine def _train_func(engine, batch): - return None, torch.tensor(0.0) + return torch.tensor(0.0) engine = Engine(_train_func) # set up testing handler - stats_handler = StatsHandler(name=key_to_handler) + stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print) + stats_handler.attach(engine) + + engine.run(range(3), max_epochs=2) + + # check logging output + output_str = log_stream.getvalue() + grep = re.compile('.*{}.*'.format(key_to_handler)) + has_key_word = re.compile('.*{}.*'.format(key_to_print)) + for idx, line in enumerate(output_str.split('\n')): + if grep.match(line): + if idx in [1, 2, 3, 6, 7, 8]: + self.assertTrue(has_key_word.match(line)) + + def test_loss_dict(self): + log_stream = StringIO() + logging.basicConfig(stream=log_stream, level=logging.INFO) + key_to_handler = 'test_logging' + key_to_print = 'myLoss1' + + # set up engine + def _train_func(engine, batch): + return torch.tensor(0.0) + + engine = Engine(_train_func) + + # set up testing handler + stats_handler = StatsHandler(name=key_to_handler, + output_transform=lambda x: {key_to_print: x}) stats_handler.attach(engine) engine.run(range(3), max_epochs=2) @@ -78,7 +105,6 @@ def _train_func(engine, batch): output_str = log_stream.getvalue() grep = re.compile('.*{}.*'.format(key_to_handler)) has_key_word = re.compile('.*{}.*'.format(key_to_print)) - matched = [] for idx, line in enumerate(output_str.split('\n')): if grep.match(line): if idx in [1, 2, 3, 6, 7, 8]: diff --git a/tests/test_handler_tb_image.py b/tests/test_handler_tb_image.py new file mode 100644 index 0000000000..9bf55e162b --- /dev/null +++ b/tests/test_handler_tb_image.py @@ -0,0 +1,60 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import os +import shutil +import unittest + +import numpy as np +import torch +from ignite.engine import Engine, Events +from parameterized import parameterized + +from monai.handlers.tensorboard_handlers import TensorBoardImageHandler + +TEST_CASES = [ + [[20, 20]], + [[2, 20, 20]], + [[3, 20, 20]], + [[20, 20, 20]], + [[2, 20, 20, 20]], + [[2, 2, 20, 20, 20]], +] + + +class TestHandlerTBImage(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_tb_image_shape(self, shape): + default_dir = os.path.join('.', 'runs') + shutil.rmtree(default_dir, ignore_errors=True) + + # set up engine + def _train_func(engine, batch): + return torch.zeros((1, 1, 10, 10)) + + engine = Engine(_train_func) + + # set up testing handler + stats_handler = TensorBoardImageHandler() + engine.add_event_handler(Events.ITERATION_COMPLETED, stats_handler) + + data = zip(np.random.normal(size=(10, 4, *shape)), np.random.normal(size=(10, 4, *shape))) + engine.run(data, epoch_length=10, max_epochs=1) + + self.assertTrue(os.path.exists(default_dir)) + self.assertTrue(len(glob.glob(default_dir)) > 0) + shutil.rmtree(default_dir) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_handler_tb_stats.py b/tests/test_handler_tb_stats.py new file mode 100644 index 0000000000..53a691701f --- /dev/null +++ b/tests/test_handler_tb_stats.py @@ -0,0 +1,81 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import tempfile +import unittest +import glob + +from ignite.engine import Engine, Events +from torch.utils.tensorboard import SummaryWriter + +from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler + + +class TestHandlerTBStats(unittest.TestCase): + + def test_metrics_print(self): + default_dir = os.path.join('.', 'runs') + shutil.rmtree(default_dir, ignore_errors=True) + + # set up engine + def _train_func(engine, batch): + return batch + 1.0 + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get('acc', 0.1) + engine.state.metrics['acc'] = current_metric + 0.1 + + # set up testing handler + stats_handler = TensorBoardStatsHandler() + stats_handler.attach(engine) + engine.run(range(3), max_epochs=2) + # check logging output + + self.assertTrue(os.path.exists(default_dir)) + shutil.rmtree(default_dir) + + def test_metrics_writer(self): + default_dir = os.path.join('.', 'runs') + shutil.rmtree(default_dir, ignore_errors=True) + with tempfile.TemporaryDirectory() as temp_dir: + + # set up engine + def _train_func(engine, batch): + return batch + 1.0 + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get('acc', 0.1) + engine.state.metrics['acc'] = current_metric + 0.1 + + # set up testing handler + writer = SummaryWriter(log_dir=temp_dir) + stats_handler = TensorBoardStatsHandler( + writer, output_transform=lambda x: {'loss': x * 2.0}, + global_epoch_transform=lambda x: x * 3.0) + stats_handler.attach(engine) + engine.run(range(3), max_epochs=2) + # check logging output + self.assertTrue(len(glob.glob(temp_dir)) > 0) + self.assertTrue(not os.path.exists(default_dir)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_unet.py b/tests/test_unet.py index 98102375a6..c1e838c284 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -26,7 +26,7 @@ 'num_res_units': 1, }, torch.randn(16, 1, 32, 32), - (16, 32, 32), + (16, 1, 32, 32), ] TEST_CASE_2 = [ # single channel 3D, batch 16 @@ -39,7 +39,7 @@ 'num_res_units': 1, }, torch.randn(16, 1, 32, 24, 48), - (16, 32, 24, 48), + (16, 1, 32, 24, 48), ] TEST_CASE_3 = [ # 4-channel 3D, batch 16 @@ -52,7 +52,7 @@ 'num_res_units': 1, }, torch.randn(16, 4, 32, 64, 48), - (16, 32, 64, 48), + (16, 1, 32, 64, 48), ]