3 changes: 3 additions & 0 deletions docs/source/handlers.rst
@@ -85,6 +85,9 @@ Training stats handler

Tensorboard handlers
--------------------
.. autoclass:: TensorBoardHandler
:members:

.. autoclass:: TensorBoardStatsHandler
:members:

2 changes: 1 addition & 1 deletion monai/handlers/__init__.py
@@ -24,7 +24,7 @@
from .smartcache_handler import SmartCacheHandler
from .stats_handler import StatsHandler
from .surface_distance import SurfaceDistance
from .tensorboard_handlers import TensorBoardImageHandler, TensorBoardStatsHandler
from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler
from .utils import (
evenly_divisible_all_gather,
stopping_fn_from_loss,
20 changes: 11 additions & 9 deletions monai/handlers/metrics_saver.py
@@ -34,20 +34,22 @@ class MetricsSaver:
"*" - save all the existing metrics in `engine.state.metrics` dict into separate files.
list of strings - specify the expected metrics to save.
default to "*" to save all the metrics into `metrics.csv`.
metric_details: expected metric details to save into files, for example: mean dice
of every channel of every image in the validation dataset.
the data in `engine.state.metric_details` must contain at least 2 dims: (batch, classes, ...),
metric_details: expected metric details to save into files, the data comes from
`engine.state.metric_details`, which should be provided by different `Metrics`,
typically, these are intermediate values of the metric computation.
for example: mean dice of every channel of every image in the validation dataset.
it must contain at least 2 dims: (batch, classes, ...),
if not, will unsqueeze to 2 dims.
this arg can be: None, "*" or list of strings.
None - don't save any metrics into files.
"*" - save all the existing metrics in `engine.state.metric_details` dict into separate files.
list of strings - specify the expected metrics to save.
if not None, every metric will save a separate `{metric name}_raw.csv` file.
None - don't save any metric_details into files.
"*" - save all the existing metric_details in `engine.state.metric_details` dict into separate files.
list of strings - specify the metric_details of expected metrics to save.
if not None, every metric_details array will be saved into a separate `{metric name}_raw.csv` file.
batch_transform: callable function to extract the meta_dict from input batch data if saving metric details.
used to extract filenames from input dict data.
summary_ops: expected computation operations to generate the summary report.
summary_ops: expected computation operations to generate the summary report based on specified metric_details.
it can be: None, "*" or list of strings.
None - don't generate summary report for every expected metric_details
None - don't generate summary report for every specified metric_details
"*" - generate summary report for every metric_details with all the supported operations.
list of strings - generate summary report for every metric_details with specified operations, they
should be within this list: [`mean`, `median`, `max`, `min`, `90percent`, `std`].
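For orientation, here is a minimal sketch of how the arguments documented above might be wired together. The `save_dir` value, the `image_meta_dict` key, and the `evaluator` engine are illustrative assumptions, not taken from this diff:

```python
# Hedged sketch: attach a MetricsSaver to an evaluator so that per-image/per-class
# values from engine.state.metric_details are written to CSV files.
from monai.handlers import MetricsSaver

metrics_saver = MetricsSaver(
    save_dir="./eval_metrics",             # assumed output directory
    metrics="*",                           # save every entry of engine.state.metrics
    metric_details=["mean_dice"],          # save raw per-image/per-class values
    batch_transform=lambda batch: batch["image_meta_dict"],  # assumed meta_dict key
    summary_ops=["mean", "median", "std"], # operations from the supported list above
)
metrics_saver.attach(evaluator)            # `evaluator` is an existing Ignite engine
```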
39 changes: 35 additions & 4 deletions monai/handlers/tensorboard_handlers.py
@@ -29,7 +29,38 @@
DEFAULT_TAG = "Loss"


class TensorBoardStatsHandler:
class TensorBoardHandler:
"""
Base class for the handlers to write data into TensorBoard.

Args:
summary_writer: user can specify a TensorBoard SummaryWriter,
default is to create a new writer.
log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`.

"""

def __init__(self, summary_writer: Optional[SummaryWriter] = None, log_dir: str = "./runs"):
if summary_writer is None:
self._writer = SummaryWriter(log_dir=log_dir)
self.internal_writer = True
else:
self._writer = summary_writer
self.internal_writer = False

def attach(self, engine: Engine) -> None:
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

def close(self):
"""
Close the summary writer if created in this TensorBoard handler.

"""
if self.internal_writer:
self._writer.close()


class TensorBoardStatsHandler(TensorBoardHandler):
"""
TensorBoardStatsHandler defines a set of Ignite Event-handlers for all the TensorBoard logic.
It can be used for any Ignite Engine (trainer, validator and evaluator).
@@ -71,7 +102,7 @@ def __init__(
when plotting epoch vs metric curves.
tag_name: when iteration output is a scalar, tag_name is used to plot, defaults to ``'Loss'``.
"""
self._writer = SummaryWriter(log_dir=log_dir) if summary_writer is None else summary_writer
super().__init__(summary_writer=summary_writer, log_dir=log_dir)
self.epoch_event_writer = epoch_event_writer
self.iteration_event_writer = iteration_event_writer
self.output_transform = output_transform
@@ -176,7 +207,7 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> No
writer.flush()


class TensorBoardImageHandler:
class TensorBoardImageHandler(TensorBoardHandler):
"""
TensorBoardImageHandler is an Ignite Event handler that can visualize images, labels and outputs as 2D/3D images.
2D output (shape in Batch, channel, H, W) will be shown as a simple image using the first element in the batch,
@@ -229,7 +260,7 @@ def __init__(
max_channels: number of channels to plot.
max_frames: number of frames for 2D-t plot.
"""
self._writer = SummaryWriter(log_dir=log_dir) if summary_writer is None else summary_writer
super().__init__(summary_writer=summary_writer, log_dir=log_dir)
self.interval = interval
self.epoch_level = epoch_level
self.batch_transform = batch_transform
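As a quick illustration of the new base-class contract introduced above: a subclass inherits the writer setup and `close()` logic and only needs to implement `attach()`. The sketch below is hypothetical (the `EpochCounterTBHandler` class and the `trainer` engine are not part of this change):

```python
# Hedged sketch: a custom handler built on the new TensorBoardHandler base class.
from ignite.engine import Engine, Events
from monai.handlers import TensorBoardHandler


class EpochCounterTBHandler(TensorBoardHandler):
    """Writes the current epoch number to TensorBoard after every epoch."""

    def attach(self, engine: Engine) -> None:
        engine.add_event_handler(Events.EPOCH_COMPLETED, self._write_epoch)

    def _write_epoch(self, engine: Engine) -> None:
        self._writer.add_scalar("epoch", engine.state.epoch, engine.state.epoch)
        self._writer.flush()


# Usage: no summary_writer is passed, so the handler owns an internal writer
# (internal_writer=True) and close() will release it.
handler = EpochCounterTBHandler(log_dir="./runs_example")
handler.attach(trainer)  # `trainer` is an existing Ignite engine
# ... trainer.run(...)
handler.close()
```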
2 changes: 1 addition & 1 deletion monai/handlers/utils.py
@@ -128,7 +128,7 @@ def write_metrics_reports(
images: name or path of every input image corresponding to the metric_details data.
if None, will use index number as the filename of every input image.
metrics: a dictionary of (metric name, metric value) pairs.
metric_details: a dictionary of (metric name, metric raw values) pairs,
metric_details: a dictionary of (metric name, metric raw values) pairs, usually produced during the metrics computation,
for example, the raw value can be the mean_dice of every channel of every input image.
summary_ops: expected computation operations to generate the summary report.
it can be: None, "*" or list of strings.
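A minimal sketch of calling this utility directly, assuming it also takes a `save_dir` argument for the output location (the filenames, values, and array shapes below are illustrative):

```python
# Hedged sketch: write a summary report plus a raw-values CSV for "mean_dice".
import numpy as np
from monai.handlers.utils import write_metrics_reports

write_metrics_reports(
    save_dir="./reports",                                # assumed output directory
    images=["img0.nii.gz", "img1.nii.gz"],               # one filename per data row
    metrics={"mean_dice": 0.82},                         # aggregated metric values
    metric_details={"mean_dice": np.random.rand(2, 3)},  # (batch, classes) raw values
    summary_ops=["mean", "median", "90percent"],
)
```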
1 change: 1 addition & 0 deletions tests/test_handler_tb_image.py
@@ -40,6 +40,7 @@ def _train_func(engine, batch):

data = zip(np.random.normal(size=(10, 4, *shape)), np.random.normal(size=(10, 4, *shape)))
engine.run(data, epoch_length=10, max_epochs=1)
stats_handler.close()

self.assertTrue(len(glob.glob(tempdir)) > 0)

2 changes: 2 additions & 0 deletions tests/test_handler_tb_stats.py
@@ -39,6 +39,7 @@ def _update_metric(engine):
stats_handler = TensorBoardStatsHandler(log_dir=tempdir)
stats_handler.attach(engine)
engine.run(range(3), max_epochs=2)
stats_handler.close()
# check logging output
self.assertTrue(len(glob.glob(tempdir)) > 0)

@@ -64,6 +65,7 @@ def _update_metric(engine):
)
stats_handler.attach(engine)
engine.run(range(3), max_epochs=2)
writer.close()
# check logging output
self.assertTrue(len(glob.glob(tempdir)) > 0)

8 changes: 6 additions & 2 deletions tests/test_integration_workflows.py
@@ -22,6 +22,7 @@
import numpy as np
import torch
from ignite.metrics import Accuracy
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import create_test_image_3d
@@ -105,6 +106,7 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4):
loss = monai.losses.DiceLoss(sigmoid=True)
opt = torch.optim.Adam(net.parameters(), 1e-3)
lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)
summary_writer = SummaryWriter(log_dir=root_dir)

val_post_transforms = Compose(
[
@@ -123,7 +125,7 @@ def _forward_completed(self, engine):

val_handlers = [
StatsHandler(output_transform=lambda x: None),
TensorBoardStatsHandler(log_dir=root_dir, output_transform=lambda x: None),
TensorBoardStatsHandler(summary_writer=summary_writer, output_transform=lambda x: None),
TensorBoardImageHandler(
log_dir=root_dir, batch_transform=lambda x: (x["image"], x["label"]), output_transform=lambda x: x["pred"]
),
@@ -176,7 +178,9 @@ def _optimizer_completed(self, engine):
LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
TensorBoardStatsHandler(log_dir=root_dir, tag_name="train_loss", output_transform=lambda x: x["loss"]),
TensorBoardStatsHandler(
summary_writer=summary_writer, tag_name="train_loss", output_transform=lambda x: x["loss"]
),
CheckpointSaver(save_dir=root_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
_TestTrainIterEvents(),
]
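The integration-test change above shares one SummaryWriter between the training and validation stats handlers, so all curves land in a single event file and the caller controls the writer's lifetime. Distilled to its essence (the `trainer`/`evaluator` engines and the log directory are assumed to exist), the pattern looks like this:

```python
# Hedged sketch: one external SummaryWriter shared by several TensorBoard handlers.
from torch.utils.tensorboard import SummaryWriter
from monai.handlers import TensorBoardStatsHandler

writer = SummaryWriter(log_dir="./runs_shared")  # assumed log directory
train_stats = TensorBoardStatsHandler(
    summary_writer=writer, tag_name="train_loss", output_transform=lambda x: x["loss"]
)
val_stats = TensorBoardStatsHandler(summary_writer=writer, output_transform=lambda x: None)
train_stats.attach(trainer)
val_stats.attach(evaluator)
# ... run training/validation ...
writer.close()  # external writer: the handlers' close() will not close it
```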