From 8c097d165eee703d6d590f8e394d7ca1cebb9062 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 10 Mar 2020 00:04:47 +0800 Subject: [PATCH 1/7] [DLMED] implement TensorBoardHandler --- examples/unet_segmentation_3d_array.py | 60 ++++------- monai/handlers/tensorboard_handler.py | 141 +++++++++++++++++++++++++ monai/networks/utils.py | 2 +- 3 files changed, 164 insertions(+), 39 deletions(-) create mode 100644 monai/handlers/tensorboard_handler.py diff --git a/examples/unet_segmentation_3d_array.py b/examples/unet_segmentation_3d_array.py index 0e68eb67c8..77ee7f688d 100644 --- a/examples/unet_segmentation_3d_array.py +++ b/examples/unet_segmentation_3d_array.py @@ -18,7 +18,6 @@ import nibabel as nib import numpy as np import torch -from torch.utils.tensorboard import SummaryWriter from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator from ignite.handlers import ModelCheckpoint, EarlyStopping from torch.utils.data import DataLoader @@ -32,8 +31,8 @@ from monai.data.nifti_reader import NiftiDataset from monai.transforms import (AddChannel, Rescale, ToTensor, UniformRandomPatch) from monai.handlers.stats_handler import StatsHandler +from monai.handlers.tensorboard_handler import TensorBoardHandler from monai.handlers.mean_dice import MeanDice -from monai.visualize import img2tensorboard from monai.data.synthetic import create_test_image_3d from monai.handlers.utils import stopping_fn_from_metric @@ -88,69 +87,58 @@ loss = monai.losses.DiceLoss(do_sigmoid=True) opt = torch.optim.Adam(net.parameters(), lr) + # Since network outputs logits and segmentation, we need a custom function. 
def _loss_fn(i, j): return loss(i[0], j) + # Create trainer device = torch.device("cuda:0") trainer = create_supervised_trainer(net, opt, _loss_fn, device, False, - output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y]) + output_transform=lambda x, y, y_pred, loss: [y_pred[1], loss.item(), y]) # adding checkpoint handler to save models (network params and optimizer stats) during training checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) +logging.basicConfig(stream=sys.stdout, level=logging.INFO) train_stats_handler = StatsHandler() train_stats_handler.attach(trainer) + @trainer.on(Events.EPOCH_COMPLETED) def log_training_loss(engine): - # log loss to tensorboard with second item of engine.state.output, loss.item() from output_transform - writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch) - - # tensor of ones to use where for converting labels to zero and ones - ones = torch.ones(engine.state.batch[1][0].shape, dtype=torch.int32) - first_output_tensor = engine.state.output[0][1][0].detach().cpu() - # log model output to tensorboard, as three dimensional tensor with no channels dimension - img2tensorboard.add_animated_gif_no_channels(writer, "first_output_final_batch", first_output_tensor, 64, - 255, engine.state.epoch) - # get label tensor and convert to single class - first_label_tensor = torch.where(engine.state.batch[1][0] > 0, ones, engine.state.batch[1][0]) - # log label tensor to tensorboard, there is a channel dimension when getting label from batch - img2tensorboard.add_animated_gif(writer, "first_label_final_batch", first_label_tensor, 64, - 255, engine.state.epoch) - second_output_tensor = engine.state.output[0][1][1].detach().cpu() - img2tensorboard.add_animated_gif_no_channels(writer, "second_output_final_batch", second_output_tensor, 64, - 255, 
engine.state.epoch) - second_label_tensor = torch.where(engine.state.batch[1][1] > 0, ones, engine.state.batch[1][1]) - img2tensorboard.add_animated_gif(writer, "second_label_final_batch", second_label_tensor, 64, - 255, engine.state.epoch) - third_output_tensor = engine.state.output[0][1][2].detach().cpu() - img2tensorboard.add_animated_gif_no_channels(writer, "third_output_final_batch", third_output_tensor, 64, - 255, engine.state.epoch) - third_label_tensor = torch.where(engine.state.batch[1][2] > 0, ones, engine.state.batch[1][2]) - img2tensorboard.add_animated_gif(writer, "third_label_final_batch", third_label_tensor, 64, - 255, engine.state.epoch) engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1]) -writer = SummaryWriter() # Set parameters for validation validation_every_n_epochs = 1 metric_name = 'Mean_Dice' # add evaluation metric to the evaluator engine -val_metrics = {metric_name: MeanDice(add_sigmoid=True)} -evaluator = create_supervised_evaluator(net, val_metrics, device, True, - output_transform=lambda x, y, y_pred: (y_pred[0], y)) +val_metrics = {metric_name: MeanDice( + add_sigmoid=True, output_transform=lambda output: (output[0][0], output[1])) +} +evaluator = create_supervised_evaluator(net, val_metrics, device, True) # Add stats event handler to print validation stats via evaluator -logging.basicConfig(stream=sys.stdout, level=logging.INFO) val_stats_handler = StatsHandler() val_stats_handler.attach(evaluator) + +def _global_epoch_transform(): + return trainer.state.epoch + + +val_tensorboard_handler = TensorBoardHandler( + batch_transform=lambda batch: (batch[0][0:3, 0:1, ...], batch[1][0:3, 0:1, ...]), + output_transform=lambda output: (output[0][1][0:3, 0:1, ...], None), + global_epoch_transform=_global_epoch_transform +) +val_tensorboard_handler.attach(evaluator) + # Add early stopping handler to evaluator. 
early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), @@ -166,10 +154,6 @@ def log_training_loss(engine): def run_validation(engine): evaluator.run(val_loader) -@evaluator.on(Events.EPOCH_COMPLETED) -def log_metrics_to_tensorboard(engine): - for name, value in engine.state.metrics.items(): - writer.add_scalar('Metrics/{name}', value, trainer.state.epoch) # create a training data loader logging.basicConfig(stream=sys.stdout, level=logging.INFO) diff --git a/monai/handlers/tensorboard_handler.py b/monai/handlers/tensorboard_handler.py new file mode 100644 index 0000000000..6cfe1e9857 --- /dev/null +++ b/monai/handlers/tensorboard_handler.py @@ -0,0 +1,141 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch.utils.tensorboard import SummaryWriter +from ignite.engine import Engine, Events +from monai.visualize import img2tensorboard + + +class TensorBoardHandler(object): + """TensorBoardHandler defines a set of Ignite Event-handlers for all the TensorBoard logics. + It's can be used for any Ignite Engine(trainer, validator and evaluator). + And it can support both epoch level and iteration level with pre-defined TensorBoard event writer. + The expected data source is ignite `engine.state.batch`, `engine.state.output` and `engine.state.metrics`. + Default behaviors: + (1) Write metrics to TensorBoard when EPOCH_COMPLETED. 
+ (2) Write loss to TensorBoard when ITERATION_COMPLETED, expected output format is (y_pred, loss). + (3) Show y_pred as images(GIF for 3D) on TensorBoard when EPOCH_COMPLETED, + need to use `transform` to specify how many images to show and show which channel. + Expected batch format is (image[N, channel, ...], label[N, channel, ...]), + Expected output format is (y_pred[N, channel, ...], loss). + """ + + def __init__(self, + epoch_event_writer=None, + iteration_event_writer=None, + batch_transform=lambda x: x, + output_transform=lambda x: x, + global_epoch_transform=None): + """ + Args: + epoch_event_writer (Callable): customized callable TensorBoard writer for epoch level. + must accept parameter "engine" and "summary_writer", use default event writer if None. + iteration_event_writer (Callable): custimized callable TensorBoard writer for iteration level. + must accept parameter "engine" and "summary_writer", use default event writer if None. + batch_transform (Callable): a callable that is used to transform the + ignite.engine.batch into expected format to extract several label data. + output_transform (Callable): a callable that is used to transform the + ignite.engine.output into expected format to extract several output data. + global_epoch_transform (Callable): a callable that is used to customize global epoch number. + For example, in evaluation, the evaluator engine needs to know current epoch from trainer. + """ + + self.epoch_event_writer = epoch_event_writer + self.iteration_event_writer = iteration_event_writer + self.batch_transform = batch_transform + self.output_transform = output_transform + self.global_epoch_transform = global_epoch_transform + self._writer = SummaryWriter() + + + def attach(self, engine: Engine): + """Register a set of Ignite Event-Handlers to a specified Ignite engine. + + Args: + engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. 
+ + """ + if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): + engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) + if not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED): + engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed) + + def epoch_completed(self, engine: Engine): + """handler for train or validation/evaluation epoch completed Event. + Write epoch level events, default values are from ignite state.metrics dict. + + Args: + engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. + + """ + if self.epoch_event_writer is not None: + self.epoch_event_writer(engine, self._writer) + else: + self._default_epoch_writer(engine, self._writer) + + def iteration_completed(self, engine: Engine): + """handler for train or validation/evaluation iteration completed Event. + Write iteration level events, default values are from ignite state.logs dict. + + Args: + engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. + + """ + if self.iteration_event_writer is not None: + self.iteration_event_writer(engine, self._writer) + else: + self._default_iteration_writer(engine, self._writer) + + def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter): + """Execute epoch level event write operation based on Ignite engine.state data. + Default is to write the values from ignite state.metrics dict. + + Args: + engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. + writer (SummaryWriter): TensorBoard writer, created in TensorBoardHandler. 
+ + """ + current_epoch = self.global_epoch_transform() if self.global_epoch_transform is not None \ + else engine.state.epoch + summary_dict = engine.state.metrics + for name, value in summary_dict.items(): + writer.add_scalar(name, value, current_epoch) + + show_labels = self.batch_transform(engine.state.batch)[1] + show_outputs = self.output_transform(engine.state.output)[0].detach().cpu() + if show_labels is not None and show_outputs is not None: + ones = torch.ones(show_labels[0].shape, dtype=torch.int32) + if len(show_labels.shape) == 5: + assert show_labels.shape[1] == show_outputs.shape[1] == 1, \ + '3D images must select only 1 channel to show.' + for i in range(len(show_labels)): + img2tensorboard.add_animated_gif(writer, 'output' + str(i), show_outputs[i], 64, 255, current_epoch) + label_tensor = torch.where(show_labels[i] > 0, ones, show_labels[i]) + img2tensorboard.add_animated_gif(writer, 'label' + str(i), label_tensor, 64, 255, current_epoch) + elif len(show_labels.shape) == 4: + for i in range(len(show_labels)): + writer.add_image('output' + str(i), show_outputs[i], current_epoch) + label_tensor = torch.where(show_labels[i] > 0, ones, show_labels[i]) + writer.add_image('label' + str(i), label_tensor, current_epoch) + + def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter): + """Execute iteration level event write operation based on Ignite engine.state data. + Default is to write the loss value of current iteration. + + Args: + engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. + writer (SummaryWriter): TensorBoard writer, created in TensorBoardHandler. 
+ + """ + loss = self.output_transform(engine.state.output)[1] + if loss is not None: + writer.add_scalar('Loss', loss.item(), engine.state.iteration) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 5a22884846..c8516120fa 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -52,6 +52,6 @@ def predict_segmentation(logits): """ # generate prediction outputs, logits has shape BCHW[D] if logits.shape[1] == 1: - return (logits[:, 0] >= 0).int() # for binary segmentation threshold on channel 0 + return (logits >= 0).int() # for binary segmentation threshold on channel 0 else: return logits.max(1)[1] # take the index of the max value along dimension 1 From a4f8ffd7794fdf698456b19add21dfae9ff5627a Mon Sep 17 00:00:00 2001 From: Kevin Lu Date: Mon, 9 Mar 2020 22:38:40 -0400 Subject: [PATCH 2/7] [DLMED] fix dimension for tensorboard logging --- monai/visualize/img2tensorboard.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py index 82211a8ecd..ab421e5c7b 100644 --- a/monai/visualize/img2tensorboard.py +++ b/monai/visualize/img2tensorboard.py @@ -49,8 +49,8 @@ def _image3_animated_gif(imp, scale_factor=1): def make_animated_gif_summary(tag, tensor, max_out=3, - animation_axes=(1,), - image_axes=(2, 3), + animation_axes=(3,), + image_axes=(1, 2), other_indices=None, scale_factor=1): """ @@ -58,7 +58,7 @@ def make_animated_gif_summary(tag, Args: tag: Data identifier - tensor: tensor for the image, expected to be in CDHW format + tensor: tensor for the image, expected to be in CHWD format max_out: maximum number of slices to animate through animation_axes: axis to animate on (not currently used) image_axes: axes of image (not currently used) From 3e7c03ff74d348c1de04525b60d9cea8f4232af2 Mon Sep 17 00:00:00 2001 From: Kevin Lu Date: Mon, 9 Mar 2020 22:51:10 -0400 Subject: [PATCH 3/7] [DLMED] should actually fix dimension to be CHWD --- 
monai/visualize/img2tensorboard.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py index ab421e5c7b..8fce996685 100644 --- a/monai/visualize/img2tensorboard.py +++ b/monai/visualize/img2tensorboard.py @@ -27,8 +27,8 @@ def _image3_animated_gif(imp, scale_factor=1): # x=numpy.random.randint(0,256,[10,10,10],numpy.uint8) (tag, ims) = imp ims = [ - (np.asarray((ims[i, :, :])) * scale_factor).astype(np.uint8) - for i in range(ims.shape[0]) + (np.asarray((ims[:, :, i])) * scale_factor).astype(np.uint8) + for i in range(ims.shape[2]) ] ims = [GifImage.fromarray(im) for im in ims] img_str = b'' From 4d49cba3aeb615d2f48bc8878ca2ff4a69751d3f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 10 Mar 2020 13:35:40 +0800 Subject: [PATCH 4/7] [DLMED] update according to comments --- monai/handlers/tensorboard_handler.py | 4 +++- monai/networks/utils.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/handlers/tensorboard_handler.py b/monai/handlers/tensorboard_handler.py index 6cfe1e9857..dcd24eb870 100644 --- a/monai/handlers/tensorboard_handler.py +++ b/monai/handlers/tensorboard_handler.py @@ -126,6 +126,8 @@ def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter): writer.add_image('output' + str(i), show_outputs[i], current_epoch) label_tensor = torch.where(show_labels[i] > 0, ones, show_labels[i]) writer.add_image('label' + str(i), label_tensor, current_epoch) + else: + raise ValueError('unsupported input data shape.') def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter): """Execute iteration level event write operation based on Ignite engine.state data. 
@@ -137,5 +139,5 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter): """ loss = self.output_transform(engine.state.output)[1] - if loss is not None: + if loss is None or (isinstance(loss, torch.Tensor) is True and len(loss.shape) > 0): writer.add_scalar('Loss', loss.item(), engine.state.iteration) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index c8516120fa..bca9922374 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -54,4 +54,4 @@ def predict_segmentation(logits): if logits.shape[1] == 1: return (logits >= 0).int() # for binary segmentation threshold on channel 0 else: - return logits.max(1)[1] # take the index of the max value along dimension 1 + return logits.argmax(1).unsqueeze(1) # take the index of the max value along dimension 1 From c533063bdc00090de376494c14d92093bf116220 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 10 Mar 2020 20:38:07 +0800 Subject: [PATCH 5/7] [DLMED] change to channel_last when saving nifti --- monai/handlers/segmentation_saver.py | 5 +++-- tests/integration_sliding_window.py | 10 ++++++---- tests/integration_unet2d.py | 9 +++++---- tests/test_unet.py | 6 +++--- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index a87e517f81..17573defb4 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -10,7 +10,7 @@ # limitations under the License. 
import os - +import numpy as np import torch from ignite.engine import Events @@ -107,5 +107,6 @@ def __call__(self, engine): seg_output = seg_output.detach().cpu().numpy() output_filename = self._create_file_basename(self.output_postfix, filename, self.output_path) output_filename = '{}{}'.format(output_filename, self.output_ext) - write_nifti(seg_output, _affine, output_filename, _original_affine, dtype=seg_output.dtype) + # change output to "channel last" format and write to nifti format file + write_nifti(np.moveaxis(seg_output, 0, -1), _affine, output_filename, _original_affine, dtype=seg_output.dtype) self.logger.info('saved: {}'.format(output_filename)) diff --git a/tests/integration_sliding_window.py b/tests/integration_sliding_window.py index 32abbfff5c..706ee6233f 100644 --- a/tests/integration_sliding_window.py +++ b/tests/integration_sliding_window.py @@ -64,17 +64,19 @@ def _sliding_window_processor(_engine, batch): basename = os.path.basename(img_name)[:-len('.nii.gz')] saved_name = os.path.join(temp_dir, basename, '{}_seg.nii.gz'.format(basename)) - testing_shape = nib.load(saved_name).get_fdata().shape + # get spatial dimensions shape, the saved nifti image format: HWDC + testing_shape = nib.load(saved_name).get_fdata().shape[:-1] if os.path.exists(img_name): os.remove(img_name) if os.path.exists(seg_name): os.remove(seg_name) - - return testing_shape == input_shape + if testing_shape != input_shape: + print('testing shape: {} does not match input shape: {}.'.format(testing_shape, input_shape)) + return False + return True if __name__ == "__main__": result = run_test() - sys.exit(0 if result else 1) diff --git a/tests/integration_unet2d.py b/tests/integration_unet2d.py index 1fd9074c66..7b0f116b77 100644 --- a/tests/integration_unet2d.py +++ b/tests/integration_unet2d.py @@ -51,12 +51,13 @@ def loss_fn(pred, grnd): trainer = create_supervised_trainer(net, opt, loss_fn, device, False) trainer.run(src, 1) - - return trainer.state.output + loss = 
trainer.state.output + print('Loss:', loss) + if loss >= 1: + print('Loss value is wrong, expect to be < 1.') + return loss if __name__ == "__main__": result = run_test() - print(result) - sys.exit(0 if result < 1 else 1) diff --git a/tests/test_unet.py b/tests/test_unet.py index 98102375a6..c1e838c284 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -26,7 +26,7 @@ 'num_res_units': 1, }, torch.randn(16, 1, 32, 32), - (16, 32, 32), + (16, 1, 32, 32), ] TEST_CASE_2 = [ # single channel 3D, batch 16 @@ -39,7 +39,7 @@ 'num_res_units': 1, }, torch.randn(16, 1, 32, 24, 48), - (16, 32, 24, 48), + (16, 1, 32, 24, 48), ] TEST_CASE_3 = [ # 4-channel 3D, batch 16 @@ -52,7 +52,7 @@ 'num_res_units': 1, }, torch.randn(16, 4, 32, 64, 48), - (16, 32, 64, 48), + (16, 1, 32, 64, 48), ] From 0752d91ea32932a238dd7b27717839de0a235fde Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 10 Mar 2020 23:46:28 +0800 Subject: [PATCH 6/7] [DLMED] split TensorBoard logic into Stats and Image --- examples/unet_segmentation_3d_array.py | 23 ++-- monai/handlers/stats_handler.py | 9 +- ...ard_handler.py => tensorboard_handlers.py} | 102 +++++++++++++----- 3 files changed, 94 insertions(+), 40 deletions(-) rename monai/handlers/{tensorboard_handler.py => tensorboard_handlers.py} (61%) diff --git a/examples/unet_segmentation_3d_array.py b/examples/unet_segmentation_3d_array.py index 008ca0e6e4..d1bfa7e06f 100644 --- a/examples/unet_segmentation_3d_array.py +++ b/examples/unet_segmentation_3d_array.py @@ -31,7 +31,7 @@ from monai.data.nifti_reader import NiftiDataset from monai.transforms import AddChannel, Rescale, ToTensor, UniformRandomPatch from monai.handlers.stats_handler import StatsHandler -from monai.handlers.tensorboard_handler import TensorBoardHandler +from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler, TensorBoardImageHandler from monai.handlers.mean_dice import MeanDice from monai.data.synthetic import create_test_image_3d from monai.handlers.utils import 
stopping_fn_from_metric @@ -123,23 +123,28 @@ def log_training_loss(engine): } evaluator = create_supervised_evaluator(net, val_metrics, device, True) -# Add stats event handler to print validation stats via evaluator -val_stats_handler = StatsHandler() -val_stats_handler.attach(evaluator) - def _global_epoch_transform(): return trainer.state.epoch -val_tensorboard_handler = TensorBoardHandler( +# Add stats event handler to print validation stats via evaluator +val_stats_handler = StatsHandler(global_epoch_transform=_global_epoch_transform) +val_stats_handler.attach(evaluator) + +# add handler to record metrics to TensorBoard +val_tensorboard_stats_handler = TensorBoardStatsHandler(global_epoch_transform=_global_epoch_transform) +val_tensorboard_stats_handler.attach(evaluator) +# add handler to draw several images and the corresponding labels and model outputs +# here we draw the first 3 images(draw the first channel) as GIF format along Depth axis +val_tensorboard_image_handler = TensorBoardImageHandler( batch_transform=lambda batch: (batch[0][0:3, 0:1, ...], batch[1][0:3, 0:1, ...]), output_transform=lambda output: (output[0][1][0:3, 0:1, ...], None), - global_epoch_transform=_global_epoch_transform + global_step_transform=_global_epoch_transform ) -val_tensorboard_handler.attach(evaluator) +evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler) -# Add early stopping handler to evaluator. 
+# Add early stopping handler to evaluator early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer) diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index 47709543fe..500835df2a 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -30,7 +30,8 @@ def __init__(self, iteration_print_logger=None, batch_transform=lambda x: x, output_transform=lambda x: x, - name=None): + name=None, + global_epoch_transform=None): """ Args: epoch_print_logger (Callable): customized callable printer for epoch level logging. @@ -42,12 +43,15 @@ def __init__(self, output_transform (Callable): a callable that is used to transform the ignite.engine.output into expected format to extract several output data. name (str): identifier of logging.logger to use, defaulting to `engine.logger`. + global_epoch_transform (Callable): a callable that is used to customize global epoch number. + For example, in evaluation, the evaluator engine needs to know current epoch from trainer. 
""" self.epoch_print_logger = epoch_print_logger self.iteration_print_logger = iteration_print_logger self.batch_transform = batch_transform self.output_transform = output_transform + self.global_epoch_transform = global_epoch_transform self.logger = None if name is None else logging.getLogger(name) def attach(self, engine: Engine): @@ -116,7 +120,8 @@ def _default_epoch_print(self, engine: Engine): prints_dict = engine.state.metrics if not prints_dict: return - current_epoch = engine.state.epoch + current_epoch = self.global_epoch_transform() if self.global_epoch_transform is not None \ + else engine.state.epoch out_str = "Epoch[{}] Metrics -- ".format(current_epoch) for name in sorted(prints_dict): diff --git a/monai/handlers/tensorboard_handler.py b/monai/handlers/tensorboard_handlers.py similarity index 61% rename from monai/handlers/tensorboard_handler.py rename to monai/handlers/tensorboard_handlers.py index dcd24eb870..08414762cc 100644 --- a/monai/handlers/tensorboard_handler.py +++ b/monai/handlers/tensorboard_handlers.py @@ -15,21 +15,18 @@ from monai.visualize import img2tensorboard -class TensorBoardHandler(object): - """TensorBoardHandler defines a set of Ignite Event-handlers for all the TensorBoard logics. +class TensorBoardStatsHandler(object): + """TensorBoardStatsHandler defines a set of Ignite Event-handlers for all the TensorBoard logics. It's can be used for any Ignite Engine(trainer, validator and evaluator). And it can support both epoch level and iteration level with pre-defined TensorBoard event writer. - The expected data source is ignite `engine.state.batch`, `engine.state.output` and `engine.state.metrics`. + The expected data source is ignite `engine.state.output` and `engine.state.metrics`. Default behaviors: (1) Write metrics to TensorBoard when EPOCH_COMPLETED. (2) Write loss to TensorBoard when ITERATION_COMPLETED, expected output format is (y_pred, loss). 
- (3) Show y_pred as images(GIF for 3D) on TensorBoard when EPOCH_COMPLETED, - need to use `transform` to specify how many images to show and show which channel. - Expected batch format is (image[N, channel, ...], label[N, channel, ...]), - Expected output format is (y_pred[N, channel, ...], loss). - """ + """ def __init__(self, + summary_writer=None, epoch_event_writer=None, iteration_event_writer=None, batch_transform=lambda x: x, @@ -37,6 +34,8 @@ def __init__(self, global_epoch_transform=None): """ Args: + summary_writer (SummaryWriter): user can specify TensorBoard SummaryWriter, + default to create a new writer. epoch_event_writer (Callable): customized callable TensorBoard writer for epoch level. must accept parameter "engine" and "summary_writer", use default event writer if None. iteration_event_writer (Callable): custimized callable TensorBoard writer for iteration level. @@ -47,14 +46,14 @@ def __init__(self, ignite.engine.output into expected format to extract several output data. global_epoch_transform (Callable): a callable that is used to customize global epoch number. For example, in evaluation, the evaluator engine needs to know current epoch from trainer. - """ + """ + self._writer = SummaryWriter() if summary_writer is None else summary_writer self.epoch_event_writer = epoch_event_writer self.iteration_event_writer = iteration_event_writer self.batch_transform = batch_transform self.output_transform = output_transform self.global_epoch_transform = global_epoch_transform - self._writer = SummaryWriter() def attach(self, engine: Engine): @@ -109,35 +108,80 @@ def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter): summary_dict = engine.state.metrics for name, value in summary_dict.items(): writer.add_scalar(name, value, current_epoch) + writer.flush() + + def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter): + """Execute iteration level event write operation based on Ignite engine.state data. 
+ Default is to write the loss value of current iteration. + + Args: + engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. + writer (SummaryWriter): TensorBoard writer, created in TensorBoardHandler. + """ + loss = self.output_transform(engine.state.output)[1] + if loss is None or (torch.is_tensor(loss) and len(loss.shape) > 0): + return # not record multi dimensional output + writer.add_scalar('Loss', loss.item() if torch.is_tensor(loss) else loss, engine.state.iteration) + writer.flush() + + +class TensorBoardImageHandler(object): + """TensorBoardImageHandler is an ignite Event handler that can draw images, labels and outputs as 2D/3D images. + 2D output will be shown as simple image, 3D output will be shown as GIF image along the last axis(typically Depth) + It's can be used for any Ignite Engine(trainer, validator and evaluator). + User can easily added it to engine for any expected Event, for example: EPOCH_COMPLETED, ITERATION_COMPLETED. + The expected data source is ignite `engine.state.batch` and `engine.state.output`. + Default behavior: + Show y_pred as images(GIF for 3D) on TensorBoard when Event triggered, + need to use `batch_transform` and `output_transform` to specify how many images to show and show which channel. + Expected batch data format is (image[N, channel, ...], label[N, channel, ...]), + Expected output data format is (y_pred[N, channel, ...], loss). + """ + + def __init__(self, + summary_writer=None, + batch_transform=lambda x: x, + output_transform=lambda x: x, + global_step_transform=None): + """ + Args: + summary_writer (SummaryWriter): user can specify TensorBoard SummaryWriter, + default to create a new writer. + batch_transform (Callable): a callable that is used to transform the + ignite.engine.batch into expected format to extract several label data. + output_transform (Callable): a callable that is used to transform the + ignite.engine.output into expected format to extract several output data. 
+ global_step_transform (Callable): a callable that is used to customize global step number for TensorBoard. + For example, in evaluation, the evaluator engine needs to know current epoch from trainer. + + """ + self._writer = SummaryWriter() if summary_writer is None else summary_writer + self.batch_transform = batch_transform + self.output_transform = output_transform + self.global_step_transform = global_step_transform + + def __call__(self, engine): + step = self.global_step_transform() if self.global_step_transform is not None else None + show_images = self.batch_transform(engine.state.batch)[0] show_labels = self.batch_transform(engine.state.batch)[1] show_outputs = self.output_transform(engine.state.output)[0].detach().cpu() - if show_labels is not None and show_outputs is not None: + if show_images is not None and show_labels is not None and show_outputs is not None: ones = torch.ones(show_labels[0].shape, dtype=torch.int32) if len(show_labels.shape) == 5: - assert show_labels.shape[1] == show_outputs.shape[1] == 1, \ + assert show_images.shape[1] == show_labels.shape[1] == show_outputs.shape[1] == 1, \ '3D images must select only 1 channel to show.' 
for i in range(len(show_labels)): - img2tensorboard.add_animated_gif(writer, 'output' + str(i), show_outputs[i], 64, 255, current_epoch) + img2tensorboard.add_animated_gif(self._writer, 'image' + str(i), show_images[i], 64, 255, step) label_tensor = torch.where(show_labels[i] > 0, ones, show_labels[i]) - img2tensorboard.add_animated_gif(writer, 'label' + str(i), label_tensor, 64, 255, current_epoch) + img2tensorboard.add_animated_gif(self._writer, 'label' + str(i), label_tensor, 64, 255, step) + img2tensorboard.add_animated_gif(self._writer, 'output' + str(i), show_outputs[i], 64, 255, step) elif len(show_labels.shape) == 4: for i in range(len(show_labels)): - writer.add_image('output' + str(i), show_outputs[i], current_epoch) + self._writer.add_image('image' + str(i), show_images[i], step) label_tensor = torch.where(show_labels[i] > 0, ones, show_labels[i]) - writer.add_image('label' + str(i), label_tensor, current_epoch) + self._writer.add_image('label' + str(i), label_tensor, step) + self._writer.add_image('output' + str(i), show_outputs[i], step) else: raise ValueError('unsupported input data shape.') - - def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter): - """Execute iteration level event write operation based on Ignite engine.state data. - Default is to write the loss value of current iteration. - - Args: - engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. - writer (SummaryWriter): TensorBoard writer, created in TensorBoardHandler. 
- - """ - loss = self.output_transform(engine.state.output)[1] - if loss is None or (isinstance(loss, torch.Tensor) is True and len(loss.shape) > 0): - writer.add_scalar('Loss', loss.item(), engine.state.iteration) + self._writer.flush() From 9a58eed2f3d40c69fdcab86a6d98e61f67682f9d Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 11 Mar 2020 17:57:33 +0000 Subject: [PATCH 7/7] revise stats handlers, tensorboard handlers --- examples/densenet_classification_3d.py | 4 +- examples/unet_segmentation_3d.ipynb | 4 +- examples/unet_segmentation_3d_array.py | 35 +++-- examples/unet_segmentation_3d_dict.py | 4 +- monai/handlers/stats_handler.py | 76 ++++++++--- monai/handlers/tensorboard_handlers.py | 176 +++++++++++++++++-------- monai/handlers/utils.py | 4 + monai/utils/misc.py | 15 +++ tests/test_handler_stats.py | 38 +++++- tests/test_handler_tb_image.py | 60 +++++++++ tests/test_handler_tb_stats.py | 81 ++++++++++++ 11 files changed, 393 insertions(+), 104 deletions(-) create mode 100644 tests/test_handler_tb_image.py create mode 100644 tests/test_handler_tb_stats.py diff --git a/examples/densenet_classification_3d.py b/examples/densenet_classification_3d.py index 07ac3ffe04..753097a2a9 100644 --- a/examples/densenet_classification_3d.py +++ b/examples/densenet_classification_3d.py @@ -91,7 +91,7 @@ trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) -train_stats_handler = StatsHandler() +train_stats_handler = StatsHandler(output_transform=lambda x: x[3]) train_stats_handler.attach(trainer) @trainer.on(Events.EPOCH_COMPLETED) @@ -108,7 +108,7 @@ def log_training_loss(engine): # Add stats event handler to print validation stats via evaluator logging.basicConfig(stream=sys.stdout, level=logging.INFO) -val_stats_handler = StatsHandler() +val_stats_handler = StatsHandler(output_transform=lambda x: None) val_stats_handler.attach(evaluator) # Add early stopping handler to evaluator. 
diff --git a/examples/unet_segmentation_3d.ipynb b/examples/unet_segmentation_3d.ipynb index bfba0a70bf..0d49742f10 100644 --- a/examples/unet_segmentation_3d.ipynb +++ b/examples/unet_segmentation_3d.ipynb @@ -197,7 +197,7 @@ "trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,\n", " handler=checkpoint_handler,\n", " to_save={'net': net, 'opt': opt})\n", - "train_stats_handler = StatsHandler()\n", + "train_stats_handler = StatsHandler(output_transform=lambda x: x[1])\n", "train_stats_handler.attach(trainer)\n", "\n", "writer = SummaryWriter()\n", @@ -260,7 +260,7 @@ "\n", "# Add stats event handler to print validation stats via evaluator\n", "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n", - "val_stats_handler = StatsHandler()\n", + "val_stats_handler = StatsHandler(lambda x: None)\n", "val_stats_handler.attach(evaluator)\n", "\n", "# Add early stopping handler to evaluator.\n", diff --git a/examples/unet_segmentation_3d_array.py b/examples/unet_segmentation_3d_array.py index d1bfa7e06f..3b0d880f10 100644 --- a/examples/unet_segmentation_3d_array.py +++ b/examples/unet_segmentation_3d_array.py @@ -94,7 +94,7 @@ def _loss_fn(i, j): # Create trainer -device = torch.device("cuda:0") +device = torch.device("cpu:0") trainer = create_supervised_trainer(net, opt, _loss_fn, device, False, output_transform=lambda x, y, y_pred, loss: [y_pred[1], loss.item(), y]) @@ -104,9 +104,17 @@ def _loss_fn(i, j): handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) logging.basicConfig(stream=sys.stdout, level=logging.INFO) -train_stats_handler = StatsHandler() + +# print training loss to commandline +train_stats_handler = StatsHandler(output_transform=lambda x: x[1]) train_stats_handler.attach(trainer) +# record training loss to TensorBoard at every iteration +train_tensorboard_stats_handler = TensorBoardStatsHandler( + output_transform=lambda x: {'training_dice_loss': x[1]}, # plot under tag name taining_dice_loss + global_epoch_transform=lambda x: 
trainer.state.epoch) +train_tensorboard_stats_handler.attach(trainer) + @trainer.on(Events.EPOCH_COMPLETED) def log_training_loss(engine): @@ -119,28 +127,27 @@ def log_training_loss(engine): # add evaluation metric to the evaluator engine val_metrics = {metric_name: MeanDice( - add_sigmoid=True, output_transform=lambda output: (output[0][0], output[1])) + add_sigmoid=True, to_onehot_y=False, output_transform=lambda output: (output[0][0], output[1])) } evaluator = create_supervised_evaluator(net, val_metrics, device, True) - -def _global_epoch_transform(): - return trainer.state.epoch - - # Add stats event handler to print validation stats via evaluator -val_stats_handler = StatsHandler(global_epoch_transform=_global_epoch_transform) +val_stats_handler = StatsHandler( + output_transform=lambda x: None, # disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch) val_stats_handler.attach(evaluator) -# add handler to record metrics to TensorBoard -val_tensorboard_stats_handler = TensorBoardStatsHandler(global_epoch_transform=_global_epoch_transform) +# add handler to record metrics to TensorBoard at every epoch +val_tensorboard_stats_handler = TensorBoardStatsHandler( + output_transform=lambda x: None, # no iteration plot + global_epoch_transform=lambda x: trainer.state.epoch) # use epoch number from trainer val_tensorboard_stats_handler.attach(evaluator) # add handler to draw several images and the corresponding labels and model outputs # here we draw the first 3 images(draw the first channel) as GIF format along Depth axis val_tensorboard_image_handler = TensorBoardImageHandler( - batch_transform=lambda batch: (batch[0][0:3, 0:1, ...], batch[1][0:3, 0:1, ...]), - output_transform=lambda output: (output[0][1][0:3, 0:1, ...], None), - global_step_transform=_global_epoch_transform + batch_transform=lambda batch: (batch[0], batch[1]), + output_transform=lambda output: output[0][1], + global_iter_transform=lambda x: trainer.state.epoch ) 
evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler) diff --git a/examples/unet_segmentation_3d_dict.py b/examples/unet_segmentation_3d_dict.py index 640bed21c0..0e1e78811b 100644 --- a/examples/unet_segmentation_3d_dict.py +++ b/examples/unet_segmentation_3d_dict.py @@ -111,7 +111,7 @@ def prepare_batch(batch, device=None, non_blocking=False): trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net, 'opt': opt}) -train_stats_handler = StatsHandler() +train_stats_handler = StatsHandler(output_transform=lambda x: x[1]) train_stats_handler.attach(trainer) @@ -160,7 +160,7 @@ def log_training_loss(engine): # Add stats event handler to print validation stats via evaluator logging.basicConfig(stream=sys.stdout, level=logging.INFO) -val_stats_handler = StatsHandler() +val_stats_handler = StatsHandler(output_transform=lambda x: None) val_stats_handler.attach(evaluator) # Add early stopping handler to evaluator. diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index 500835df2a..9d2e518919 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -9,51 +9,63 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings import logging import torch from ignite.engine import Engine, Events +from monai.utils.misc import is_scalar -KEY_VAL_FORMAT = '{}: {:.4f} ' +DEFAULT_KEY_VAL_FORMAT = '{}: {:.4f} ' +DEFAULT_TAG = 'Loss' class StatsHandler(object): """StatsHandler defines a set of Ignite Event-handlers for all the log printing logics. It's can be used for any Ignite Engine(trainer, validator and evaluator). And it can support logging for epoch level and iteration level with pre-defined loggers. - By default: - (1) epoch_print_logger logs `engine.state.metrics`. - (2) iteration_print_logger logs loss value, expected output format is (y_pred, loss). 
+ + Default behaviors: + - When EPOCH_COMPLETED, logs ``engine.state.metrics`` using ``self.logger``. + - When ITERATION_COMPLETED, logs + ``self.output_transform(engine.state.output)`` using ``self.logger``. """ def __init__(self, epoch_print_logger=None, iteration_print_logger=None, - batch_transform=lambda x: x, output_transform=lambda x: x, + global_epoch_transform=lambda x: x, name=None, - global_epoch_transform=None): + tag_name=DEFAULT_TAG, + key_var_format=DEFAULT_KEY_VAL_FORMAT): """ Args: epoch_print_logger (Callable): customized callable printer for epoch level logging. must accept parameter "engine", use default printer if None. iteration_print_logger (Callable): custimized callable printer for iteration level logging. must accept parameter "engine", use default printer if None. - batch_transform (Callable): a callable that is used to transform the - ignite.engine.batch into expected format to extract input data. output_transform (Callable): a callable that is used to transform the - ignite.engine.output into expected format to extract several output data. - name (str): identifier of logging.logger to use, defaulting to `engine.logger`. + ``ignite.engine.output`` into a scalar to print, or a dictionary of {key: scalar}. + in the latter case, the output string will be formatted as key: value. + by default this value logging happens when every iteration completed. global_epoch_transform (Callable): a callable that is used to customize global epoch number. - For example, in evaluation, the evaluator engine needs to know current epoch from trainer. + For example, in evaluation, the evaluator engine might want to print synced epoch number + with the trainer engine. + name (str): identifier of logging.logger to use, defaulting to ``engine.logger``. + tag_name (string): when iteration output is a scalar, tag_name is used to print + tag_name: scalar_value to logger. Defaults to ``'Loss'``.
+ key_var_format (string): a formatting string to control the output string format of key: value. """ self.epoch_print_logger = epoch_print_logger self.iteration_print_logger = iteration_print_logger - self.batch_transform = batch_transform self.output_transform = output_transform self.global_epoch_transform = global_epoch_transform self.logger = None if name is None else logging.getLogger(name) + self.tag_name = tag_name + self.key_var_format = key_var_format + def attach(self, engine: Engine): """Register a set of Ignite Event-Handlers to a specified Ignite engine. @@ -120,13 +132,12 @@ def _default_epoch_print(self, engine: Engine): prints_dict = engine.state.metrics if not prints_dict: return - current_epoch = self.global_epoch_transform() if self.global_epoch_transform is not None \ - else engine.state.epoch + current_epoch = self.global_epoch_transform(engine.state.epoch) out_str = "Epoch[{}] Metrics -- ".format(current_epoch) for name in sorted(prints_dict): value = prints_dict[name] - out_str += KEY_VAL_FORMAT.format(name, value) + out_str += self.key_var_format.format(name, value) self.logger.info(out_str) @@ -139,19 +150,42 @@ def _default_iteration_print(self, engine: Engine): engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator. """ - loss = self.output_transform(engine.state.output)[1] - if loss is None or (torch.is_tensor(loss) and len(loss.shape) > 0): - return # not printing multi dimensional output + loss = self.output_transform(engine.state.output) + if loss is None: + return # no printing if the output is empty + + out_str = '' + if isinstance(loss, dict): # print dictionary items + for name in sorted(loss): + value = loss[name] + if not is_scalar(value): + warnings.warn('ignoring non-scalar output in StatsHandler,' + ' make sure `output_transform(engine.state.output)` returns' + ' a scalar or dictionary of key and scalar pairs to avoid this warning.' 
+ ' {}:{}'.format(name, type(value))) + continue # not printing multi dimensional output + out_str += self.key_var_format.format(name, value.item() if torch.is_tensor(value) else value) + else: + if is_scalar(loss): # not printing multi dimensional output + out_str += self.key_var_format.format(self.tag_name, loss.item() if torch.is_tensor(loss) else loss) + else: + warnings.warn('ignoring non-scalar output in StatsHandler,' + ' make sure `output_transform(engine.state.output)` returns' + ' a scalar or a dictionary of key and scalar pairs to avoid this warning.' + ' {}'.format(type(loss))) + + if not out_str: + return # no value to print + num_iterations = engine.state.epoch_length current_iteration = (engine.state.iteration - 1) % num_iterations + 1 current_epoch = engine.state.epoch num_epochs = engine.state.max_epochs - out_str = "Epoch: {}/{}, Iter: {}/{} -- ".format( + base_str = "Epoch: {}/{}, Iter: {}/{} --".format( current_epoch, num_epochs, current_iteration, num_iterations) - out_str += KEY_VAL_FORMAT.format('Loss', loss.item() if torch.is_tensor(loss) else loss) - self.logger.info(out_str) + self.logger.info(' '.join([base_str, out_str])) diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py index 08414762cc..2d5116bd59 100644 --- a/monai/handlers/tensorboard_handlers.py +++ b/monai/handlers/tensorboard_handlers.py @@ -9,29 +9,35 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np +import warnings import torch from torch.utils.tensorboard import SummaryWriter from ignite.engine import Engine, Events from monai.visualize import img2tensorboard +from monai.utils.misc import is_scalar +from monai.transforms.utils import rescale_array class TensorBoardStatsHandler(object): """TensorBoardStatsHandler defines a set of Ignite Event-handlers for all the TensorBoard logics. It's can be used for any Ignite Engine(trainer, validator and evaluator). 
And it can support both epoch level and iteration level with pre-defined TensorBoard event writer. - The expected data source is ignite `engine.state.output` and `engine.state.metrics`. - Default behaviors: - (1) Write metrics to TensorBoard when EPOCH_COMPLETED. - (2) Write loss to TensorBoard when ITERATION_COMPLETED, expected output format is (y_pred, loss). + The expected data source is ignite ``engine.state.output`` and ``engine.state.metrics``. + Default behaviors: + - When EPOCH_COMPLETED, write each dictionary item in + ``engine.state.metrics`` to TensorBoard. + - When ITERATION_COMPLETED, write each dictionary item in + ``self.output_transform(engine.state.output)`` to TensorBoard. """ + def __init__(self, summary_writer=None, epoch_event_writer=None, iteration_event_writer=None, - batch_transform=lambda x: x, - output_transform=lambda x: x, - global_epoch_transform=None): + output_transform=lambda x: {'Loss': x}, + global_epoch_transform=lambda x: x): """ Args: summary_writer (SummaryWriter): user can specify TensorBoard SummaryWriter, @@ -40,22 +46,19 @@ def __init__(self, must accept parameter "engine" and "summary_writer", use default event writer if None. iteration_event_writer (Callable): custimized callable TensorBoard writer for iteration level. must accept parameter "engine" and "summary_writer", use default event writer if None. - batch_transform (Callable): a callable that is used to transform the - ignite.engine.batch into expected format to extract several label data. output_transform (Callable): a callable that is used to transform the - ignite.engine.output into expected format to extract several output data. + ``ignite.engine.output`` into a dictionary of (tag_name: scalar) pairs to be plotted onto tensorboard. + by default this scalar plotting happens when every iteration completed. global_epoch_transform (Callable): a callable that is used to customize global epoch number.
- For example, in evaluation, the evaluator engine needs to know current epoch from trainer. - + For example, in evaluation, the evaluator engine might want to use trainer engines epoch number + when plotting epoch vs metric curves. """ self._writer = SummaryWriter() if summary_writer is None else summary_writer self.epoch_event_writer = epoch_event_writer self.iteration_event_writer = iteration_event_writer - self.batch_transform = batch_transform self.output_transform = output_transform self.global_epoch_transform = global_epoch_transform - def attach(self, engine: Engine): """Register a set of Ignite Event-Handlers to a specified Ignite engine. @@ -103,8 +106,7 @@ def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter): writer (SummaryWriter): TensorBoard writer, created in TensorBoardHandler. """ - current_epoch = self.global_epoch_transform() if self.global_epoch_transform is not None \ - else engine.state.epoch + current_epoch = self.global_epoch_transform(engine.state.epoch) summary_dict = engine.state.metrics for name, value in summary_dict.items(): writer.add_scalar(name, value, current_epoch) @@ -119,69 +121,129 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter): writer (SummaryWriter): TensorBoard writer, created in TensorBoardHandler. 
""" - loss = self.output_transform(engine.state.output)[1] - if loss is None or (torch.is_tensor(loss) and len(loss.shape) > 0): - return # not record multi dimensional output - writer.add_scalar('Loss', loss.item() if torch.is_tensor(loss) else loss, engine.state.iteration) - writer.flush() + loss_dict = self.output_transform(engine.state.output) + if loss_dict is None: + return # do nothing if output is empty + if not isinstance(loss_dict, dict): + raise ValueError('TensorBoardStatsHandler requires' + ' output_transform(engine.state.output) returning a dictionary' + ' of key and scalar pairs to plot' + ' got {}.'.format(type(loss_dict))) + for name, value in loss_dict.items(): + if not is_scalar(value): + warnings.warn('ignoring non-scalar output in tensorboard curve plotting,' + ' make sure `output_transform(engine.state.output)` returns' + ' a dictionary of key and scalar pairs to avoid this warning.' + ' Got {}:{}'.format(name, type(value))) + continue + plot_value = value.item() if torch.is_tensor(value) else value + writer.add_scalar(name, plot_value, engine.state.iteration) + writer.flush() class TensorBoardImageHandler(object): - """TensorBoardImageHandler is an ignite Event handler that can draw images, labels and outputs as 2D/3D images. - 2D output will be shown as simple image, 3D output will be shown as GIF image along the last axis(typically Depth) - It's can be used for any Ignite Engine(trainer, validator and evaluator). - User can easily added it to engine for any expected Event, for example: EPOCH_COMPLETED, ITERATION_COMPLETED. - The expected data source is ignite `engine.state.batch` and `engine.state.output`. + """TensorBoardImageHandler is an ignite Event handler that can visualise images, labels and outputs as 2D/3D images. 
+ 2D output (shape in Batch, channel, H, W) will be shown as simple image using the first element in the batch, + for 3D to ND output (shape in Batch, channel, H, W, D) input, + the last three dimensions will be shown as GIF image along the last axis (typically Depth). + + It can be used for any Ignite Engine (trainer, validator and evaluator). + User can easily add it to engine for any expected Event, for example: ``EPOCH_COMPLETED``, + ``ITERATION_COMPLETED``. The expected data source is ignite's ``engine.state.batch`` and ``engine.state.output``. + Default behavior: - Show y_pred as images(GIF for 3D) on TensorBoard when Event triggered, - need to use `batch_transform` and `output_transform` to specify how many images to show and show which channel. - Expected batch data format is (image[N, channel, ...], label[N, channel, ...]), - Expected output data format is (y_pred[N, channel, ...], loss). - """ + - Show y_pred as images (GIF for 3D) on TensorBoard when Event triggered, + - need to use ``batch_transform`` and ``output_transform`` to specify + how many images to show and show which channel. + - Expects ``batch_transform(engine.state.batch)`` to return data + format: (image[N, channel, ...], label[N, channel, ...]). + - Expects ``output_transform(engine.state.output)`` to return a torch + tensor in format (y_pred[N, channel, ...], loss). + """ def __init__(self, summary_writer=None, batch_transform=lambda x: x, output_transform=lambda x: x, - global_step_transform=None): + global_iter_transform=lambda x: x, + max_channels=1, + max_frames=64): """ Args: summary_writer (SummaryWriter): user can specify TensorBoard SummaryWriter, default to create a new writer. batch_transform (Callable): a callable that is used to transform the - ignite.engine.batch into expected format to extract several label data. + ``ignite.engine.batch`` into expected format to extract several label data.
output_transform (Callable): a callable that is used to transform the - ignite.engine.output into expected format to extract several output data. - global_step_transform (Callable): a callable that is used to customize global step number for TensorBoard. + ``ignite.engine.output`` into expected format to extract several output data. + global_iter_transform (Callable): a callable that is used to customize global step number for TensorBoard. For example, in evaluation, the evaluator engine needs to know current epoch from trainer. + max_channels (int): number of channels to plot. + max_frames (int): number of frames for 2D-t plot. """ self._writer = SummaryWriter() if summary_writer is None else summary_writer self.batch_transform = batch_transform self.output_transform = output_transform - self.global_step_transform = global_step_transform + self.global_iter_transform = global_iter_transform + + self.max_frames = max_frames + self.max_channels = max_channels def __call__(self, engine): - step = self.global_step_transform() if self.global_step_transform is not None else None + step = self.global_iter_transform(engine.state.iteration) + show_images = self.batch_transform(engine.state.batch)[0] + if torch.is_tensor(show_images): + show_images = show_images.detach().cpu().numpy() + if show_images is not None: + if not isinstance(show_images, np.ndarray): + raise ValueError('output_transform(engine.state.output)[0] must be an ndarray or tensor.') + self._add_2_or_3_d(show_images, step, 'input_0') + show_labels = self.batch_transform(engine.state.batch)[1] - show_outputs = self.output_transform(engine.state.output)[0].detach().cpu() - if show_images is not None and show_labels is not None and show_outputs is not None: - ones = torch.ones(show_labels[0].shape, dtype=torch.int32) - if len(show_labels.shape) == 5: - assert show_images.shape[1] == show_labels.shape[1] == show_outputs.shape[1] == 1, \ - '3D images must select only 1 channel to show.' 
- for i in range(len(show_labels)): - img2tensorboard.add_animated_gif(self._writer, 'image' + str(i), show_images[i], 64, 255, step) - label_tensor = torch.where(show_labels[i] > 0, ones, show_labels[i]) - img2tensorboard.add_animated_gif(self._writer, 'label' + str(i), label_tensor, 64, 255, step) - img2tensorboard.add_animated_gif(self._writer, 'output' + str(i), show_outputs[i], 64, 255, step) - elif len(show_labels.shape) == 4: - for i in range(len(show_labels)): - self._writer.add_image('image' + str(i), show_images[i], step) - label_tensor = torch.where(show_labels[i] > 0, ones, show_labels[i]) - self._writer.add_image('label' + str(i), label_tensor, step) - self._writer.add_image('output' + str(i), show_outputs[i], step) - else: - raise ValueError('unsupported input data shape.') - self._writer.flush() + if torch.is_tensor(show_labels): + show_labels = show_labels.detach().cpu().numpy() + if show_labels is not None: + if not isinstance(show_labels, np.ndarray): + raise ValueError('batch_transform(engine.state.batch)[1] must be an ndarray or tensor.') + self._add_2_or_3_d(show_labels, step, 'input_1') + + show_outputs = self.output_transform(engine.state.output) + if torch.is_tensor(show_outputs): + show_outputs = show_outputs.detach().cpu().numpy() + if show_outputs is not None: + if not isinstance(show_outputs, np.ndarray): + raise ValueError('output_transform(engine.state.output) must be an ndarray or tensor.') + self._add_2_or_3_d(show_outputs, step, 'output') + + self._writer.flush() + + def _add_2_or_3_d(self, data, step, tag='output'): + # for i, d in enumerate(data): # go through a batch of images + d = data[0] # show the first element in a batch + + if d.ndim == 2: + d = rescale_array(d, 0, 1) + dataformats = 'HW' + self._writer.add_image('{}_{}'.format(tag, dataformats), d, step, dataformats=dataformats) + return + + if d.ndim == 3: + if d.shape[0] == 3 and self.max_channels == 3: # rgb? 
+ dataformats = 'CHW' + self._writer.add_image('{}_{}'.format(tag, dataformats), d, step, dataformats='CHW') + return + for j, d2 in enumerate(d[:self.max_channels]): + d2 = rescale_array(d2, 0, 1) + dataformats = 'HW' + self._writer.add_image('{}_{}_{}'.format(tag, dataformats, j), d2, step, dataformats=dataformats) + return + + if d.ndim >= 4: + spatial = d.shape[-3:] + for j, d3 in enumerate(d.reshape([-1] + list(spatial))[:self.max_channels]): + d3 = rescale_array(d3, 0, 255) + img2tensorboard.add_animated_gif( + self._writer, '{}_HWD_{}'.format(tag, j), d3[None], self.max_frames, 1.0, step) + return diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 377d4d0073..1cd849d18e 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -12,13 +12,17 @@ def stopping_fn_from_metric(metric_name): """Returns a stopping function for ignite.handlers.EarlyStopping using the given metric name.""" + def stopping_fn(engine): return engine.state.metrics[metric_name] + return stopping_fn def stopping_fn_from_loss(): """Returns a stopping function for ignite.handlers.EarlyStopping using the loss value.""" + def stopping_fn(engine): return -engine.state.output + return stopping_fn diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 775b8f2ebd..261e521adb 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -11,6 +11,9 @@ import itertools +import numpy as np +import torch + def zip_with(op, *vals, mapfunc=map): """ @@ -40,3 +43,15 @@ def ensure_tuple(vals): vals = (vals,) return tuple(vals) + + +def is_scalar_tensor(val): + if torch.is_tensor(val) and val.ndim == 0: + return True + return False + + +def is_scalar(val): + if torch.is_tensor(val) and val.ndim == 0: + return True + return np.isscalar(val) diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index 58a62133d2..fdb0600e04 100644 --- a/tests/test_handler_stats.py +++ b/tests/test_handler_stats.py @@ -30,7 +30,7 @@ def test_metrics_print(self): # set up 
engine def _train_func(engine, batch): - return None, torch.tensor(0.0) + return torch.tensor(0.0) engine = Engine(_train_func) @@ -50,7 +50,6 @@ def _update_metric(engine): output_str = log_stream.getvalue() grep = re.compile('.*{}.*'.format(key_to_handler)) has_key_word = re.compile('.*{}.*'.format(key_to_print)) - matched = [] for idx, line in enumerate(output_str.split('\n')): if grep.match(line): if idx in [5, 10]: @@ -60,16 +59,44 @@ def test_loss_print(self): log_stream = StringIO() logging.basicConfig(stream=log_stream, level=logging.INFO) key_to_handler = 'test_logging' - key_to_print = 'Loss' + key_to_print = 'myLoss' # set up engine def _train_func(engine, batch): - return None, torch.tensor(0.0) + return torch.tensor(0.0) engine = Engine(_train_func) # set up testing handler - stats_handler = StatsHandler(name=key_to_handler) + stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print) + stats_handler.attach(engine) + + engine.run(range(3), max_epochs=2) + + # check logging output + output_str = log_stream.getvalue() + grep = re.compile('.*{}.*'.format(key_to_handler)) + has_key_word = re.compile('.*{}.*'.format(key_to_print)) + for idx, line in enumerate(output_str.split('\n')): + if grep.match(line): + if idx in [1, 2, 3, 6, 7, 8]: + self.assertTrue(has_key_word.match(line)) + + def test_loss_dict(self): + log_stream = StringIO() + logging.basicConfig(stream=log_stream, level=logging.INFO) + key_to_handler = 'test_logging' + key_to_print = 'myLoss1' + + # set up engine + def _train_func(engine, batch): + return torch.tensor(0.0) + + engine = Engine(_train_func) + + # set up testing handler + stats_handler = StatsHandler(name=key_to_handler, + output_transform=lambda x: {key_to_print: x}) stats_handler.attach(engine) engine.run(range(3), max_epochs=2) @@ -78,7 +105,6 @@ def _train_func(engine, batch): output_str = log_stream.getvalue() grep = re.compile('.*{}.*'.format(key_to_handler)) has_key_word = 
re.compile('.*{}.*'.format(key_to_print)) - matched = [] for idx, line in enumerate(output_str.split('\n')): if grep.match(line): if idx in [1, 2, 3, 6, 7, 8]: diff --git a/tests/test_handler_tb_image.py b/tests/test_handler_tb_image.py new file mode 100644 index 0000000000..9bf55e162b --- /dev/null +++ b/tests/test_handler_tb_image.py @@ -0,0 +1,60 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import os +import shutil +import unittest + +import numpy as np +import torch +from ignite.engine import Engine, Events +from parameterized import parameterized + +from monai.handlers.tensorboard_handlers import TensorBoardImageHandler + +TEST_CASES = [ + [[20, 20]], + [[2, 20, 20]], + [[3, 20, 20]], + [[20, 20, 20]], + [[2, 20, 20, 20]], + [[2, 2, 20, 20, 20]], +] + + +class TestHandlerTBImage(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_tb_image_shape(self, shape): + default_dir = os.path.join('.', 'runs') + shutil.rmtree(default_dir, ignore_errors=True) + + # set up engine + def _train_func(engine, batch): + return torch.zeros((1, 1, 10, 10)) + + engine = Engine(_train_func) + + # set up testing handler + stats_handler = TensorBoardImageHandler() + engine.add_event_handler(Events.ITERATION_COMPLETED, stats_handler) + + data = zip(np.random.normal(size=(10, 4, *shape)), np.random.normal(size=(10, 4, *shape))) + engine.run(data, epoch_length=10, max_epochs=1) + + self.assertTrue(os.path.exists(default_dir)) 
+ self.assertTrue(len(glob.glob(default_dir)) > 0) + shutil.rmtree(default_dir) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_handler_tb_stats.py b/tests/test_handler_tb_stats.py new file mode 100644 index 0000000000..53a691701f --- /dev/null +++ b/tests/test_handler_tb_stats.py @@ -0,0 +1,81 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import tempfile +import unittest +import glob + +from ignite.engine import Engine, Events +from torch.utils.tensorboard import SummaryWriter + +from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler + + +class TestHandlerTBStats(unittest.TestCase): + + def test_metrics_print(self): + default_dir = os.path.join('.', 'runs') + shutil.rmtree(default_dir, ignore_errors=True) + + # set up engine + def _train_func(engine, batch): + return batch + 1.0 + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get('acc', 0.1) + engine.state.metrics['acc'] = current_metric + 0.1 + + # set up testing handler + stats_handler = TensorBoardStatsHandler() + stats_handler.attach(engine) + engine.run(range(3), max_epochs=2) + # check logging output + + self.assertTrue(os.path.exists(default_dir)) + shutil.rmtree(default_dir) + + def test_metrics_writer(self): + default_dir = os.path.join('.', 'runs') + shutil.rmtree(default_dir, 
ignore_errors=True) + with tempfile.TemporaryDirectory() as temp_dir: + + # set up engine + def _train_func(engine, batch): + return batch + 1.0 + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get('acc', 0.1) + engine.state.metrics['acc'] = current_metric + 0.1 + + # set up testing handler + writer = SummaryWriter(log_dir=temp_dir) + stats_handler = TensorBoardStatsHandler( + writer, output_transform=lambda x: {'loss': x * 2.0}, + global_epoch_transform=lambda x: x * 3.0) + stats_handler.attach(engine) + engine.run(range(3), max_epochs=2) + # check logging output + self.assertTrue(len(glob.glob(temp_dir)) > 0) + self.assertTrue(not os.path.exists(default_dir)) + + +if __name__ == '__main__': + unittest.main()