diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 81d28fb4ac..a629b28b27 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -85,6 +85,9 @@ Training stats handler Tensorboard handlers -------------------- +.. autoclass:: TensorBoardHandler + :members: + .. autoclass:: TensorBoardStatsHandler :members: diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 81c65ed580..8f73f7f2fd 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -24,7 +24,7 @@ from .smartcache_handler import SmartCacheHandler from .stats_handler import StatsHandler from .surface_distance import SurfaceDistance -from .tensorboard_handlers import TensorBoardImageHandler, TensorBoardStatsHandler +from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler from .utils import ( evenly_divisible_all_gather, stopping_fn_from_loss, diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py index b9ea296821..87d7223c96 100644 --- a/monai/handlers/metrics_saver.py +++ b/monai/handlers/metrics_saver.py @@ -34,20 +34,22 @@ class MetricsSaver: "*" - save all the existing metrics in `engine.state.metrics` dict into separate files. list of strings - specify the expected metrics to save. default to "*" to save all the metrics into `metrics.csv`. - metric_details: expected metric details to save into files, for example: mean dice - of every channel of every image in the validation dataset. - the data in `engine.state.metric_details` must contain at least 2 dims: (batch, classes, ...), + metric_details: expected metric details to save into files, the data comes from + `engine.state.metric_details`, which should be provided by different `Metrics`, + typically, it's some intermediate values in metric computation. + for example: mean dice of every channel of every image in the validation dataset. + it must contain at least 2 dims: (batch, classes, ...), if not, will unsequeeze to 2 dims. this arg can be: None, "*" or list of strings. - None - don't save any metrics into files. - "*" - save all the existing metrics in `engine.state.metric_details` dict into separate files. - list of strings - specify the expected metrics to save. - if not None, every metric will save a separate `{metric name}_raw.csv` file. + None - don't save any metric_details into files. + "*" - save all the existing metric_details in `engine.state.metric_details` dict into separate files. + list of strings - specify the metric_details of expected metrics to save. + if not None, every metric_details array will save a separate `{metric name}_raw.csv` file. batch_transform: callable function to extract the meta_dict from input batch data if saving metric details. used to extract filenames from input dict data. - summary_ops: expected computation operations to generate the summary report. + summary_ops: expected computation operations to generate the summary report based on specified metric_details. it can be: None, "*" or list of strings. - None - don't generate summary report for every expected metric_details + None - don't generate summary report for every specified metric_details "*" - generate summary report for every metric_details with all the supported operations. list of strings - generate summary report for every metric_details with specified operations, they should be within this list: [`mean`, `median`, `max`, `min`, `90percent`, `std`]. diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py index acdfb84c8c..4ee88bcfc9 100644 --- a/monai/handlers/tensorboard_handlers.py +++ b/monai/handlers/tensorboard_handlers.py @@ -29,7 +29,38 @@ DEFAULT_TAG = "Loss" -class TensorBoardStatsHandler: +class TensorBoardHandler: + """ + Base class for the handlers to write data into TensorBoard. + + Args: + summary_writer: user can specify TensorBoard SummaryWriter, + default to create a new writer. + log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`. + + """ + + def __init__(self, summary_writer: Optional[SummaryWriter] = None, log_dir: str = "./runs"): + if summary_writer is None: + self._writer = SummaryWriter(log_dir=log_dir) + self.internal_writer = True + else: + self._writer = summary_writer + self.internal_writer = False + + def attach(self, engine: Engine) -> None: + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + def close(self): + """ + Close the summary writer if created in this TensorBoard handler. + + """ + if self.internal_writer: + self._writer.close() + + +class TensorBoardStatsHandler(TensorBoardHandler): """ TensorBoardStatsHandler defines a set of Ignite Event-handlers for all the TensorBoard logics. It's can be used for any Ignite Engine(trainer, validator and evaluator). @@ -71,7 +102,7 @@ def __init__( when plotting epoch vs metric curves. tag_name: when iteration output is a scalar, tag_name is used to plot, defaults to ``'Loss'``. """ - self._writer = SummaryWriter(log_dir=log_dir) if summary_writer is None else summary_writer + super().__init__(summary_writer=summary_writer, log_dir=log_dir) self.epoch_event_writer = epoch_event_writer self.iteration_event_writer = iteration_event_writer self.output_transform = output_transform @@ -176,7 +207,7 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> No writer.flush() -class TensorBoardImageHandler: +class TensorBoardImageHandler(TensorBoardHandler): """ TensorBoardImageHandler is an Ignite Event handler that can visualize images, labels and outputs as 2D/3D images. 2D output (shape in Batch, channel, H, W) will be shown as simple image using the first element in the batch, @@ -229,7 +260,7 @@ def __init__( max_channels: number of channels to plot. max_frames: number of frames for 2D-t plot. """ - self._writer = SummaryWriter(log_dir=log_dir) if summary_writer is None else summary_writer + super().__init__(summary_writer=summary_writer, log_dir=log_dir) self.interval = interval self.epoch_level = epoch_level self.batch_transform = batch_transform diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index f4f0250b49..3e36af0652 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -128,7 +128,7 @@ def write_metrics_reports( images: name or path of every input image corresponding to the metric_details data. if None, will use index number as the filename of every input image. metrics: a dictionary of (metric name, metric value) pairs. - metric_details: a dictionary of (metric name, metric raw values) pairs, + metric_details: a dictionary of (metric name, metric raw values) pairs, usually, it comes from metrics computation, for example, the raw value can be the mean_dice of every channel of every input image. summary_ops: expected computation operations to generate the summary report. it can be: None, "*" or list of strings. diff --git a/tests/test_handler_tb_image.py b/tests/test_handler_tb_image.py index ed3ba8a32d..f946fb6060 100644 --- a/tests/test_handler_tb_image.py +++ b/tests/test_handler_tb_image.py @@ -40,6 +40,7 @@ def _train_func(engine, batch): data = zip(np.random.normal(size=(10, 4, *shape)), np.random.normal(size=(10, 4, *shape))) engine.run(data, epoch_length=10, max_epochs=1) + stats_handler.close() self.assertTrue(len(glob.glob(tempdir)) > 0) diff --git a/tests/test_handler_tb_stats.py b/tests/test_handler_tb_stats.py index 2d7d18d1f6..0d8654cb09 100644 --- a/tests/test_handler_tb_stats.py +++ b/tests/test_handler_tb_stats.py @@ -39,6 +39,7 @@ def _update_metric(engine): stats_handler = TensorBoardStatsHandler(log_dir=tempdir) stats_handler.attach(engine) engine.run(range(3), max_epochs=2) + stats_handler.close() # check logging output self.assertTrue(len(glob.glob(tempdir)) > 0) @@ -64,6 +65,7 @@ def _update_metric(engine): ) stats_handler.attach(engine) engine.run(range(3), max_epochs=2) + writer.close() # check logging output self.assertTrue(len(glob.glob(tempdir)) > 0) diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index aa4ccbb76d..db7580bf86 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -22,6 +22,7 @@ import numpy as np import torch from ignite.metrics import Accuracy +from torch.utils.tensorboard import SummaryWriter import monai from monai.data import create_test_image_3d @@ -105,6 +106,7 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): loss = monai.losses.DiceLoss(sigmoid=True) opt = torch.optim.Adam(net.parameters(), 1e-3) lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1) + summary_writer = SummaryWriter(log_dir=root_dir) val_post_transforms = Compose( [ @@ -123,7 +125,7 @@ def _forward_completed(self, engine): val_handlers = [ StatsHandler(output_transform=lambda x: None), - TensorBoardStatsHandler(log_dir=root_dir, output_transform=lambda x: None), + TensorBoardStatsHandler(summary_writer=summary_writer, output_transform=lambda x: None), TensorBoardImageHandler( log_dir=root_dir, batch_transform=lambda x: (x["image"], x["label"]), output_transform=lambda x: x["pred"] ), @@ -176,7 +178,9 @@ def _optimizer_completed(self, engine): LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True), ValidationHandler(validator=evaluator, interval=2, epoch_level=True), StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]), - TensorBoardStatsHandler(log_dir=root_dir, tag_name="train_loss", output_transform=lambda x: x["loss"]), + TensorBoardStatsHandler( + summary_writer=summary_writer, tag_name="train_loss", output_transform=lambda x: x["loss"] + ), CheckpointSaver(save_dir=root_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True), _TestTrainIterEvents(), ]