4 changes: 2 additions & 2 deletions examples/densenet_classification_3d.py
@@ -91,7 +91,7 @@
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
handler=checkpoint_handler,
to_save={'net': net, 'opt': opt})
train_stats_handler = StatsHandler()
train_stats_handler = StatsHandler(output_transform=lambda x: x[3])
train_stats_handler.attach(trainer)

@trainer.on(Events.EPOCH_COMPLETED)
@@ -108,7 +108,7 @@ def log_training_loss(engine):

# Add stats event handler to print validation stats via evaluator
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
val_stats_handler = StatsHandler()
val_stats_handler = StatsHandler(output_transform=lambda x: None)
val_stats_handler.attach(evaluator)

# Add early stopping handler to evaluator.
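Note on the change above: `StatsHandler` prints whatever `output_transform` extracts from `engine.state.output`, so when a trainer packs several items into its output the transform must select the scalar loss (index 3 in this example's output layout), and returning `None` silences the per-iteration printer. A minimal sketch of the selection pattern, using a toy engine and made-up values rather than this example's trainer:

```python
# Sketch only (toy engine, hypothetical values): StatsHandler prints the
# scalar that output_transform picks out of engine.state.output each iteration.
import logging
import sys

from ignite.engine import Engine
from monai.handlers.stats_handler import StatsHandler

logging.basicConfig(stream=sys.stdout, level=logging.INFO)

def _step(engine, batch):
    # stand-in for a real training step; here the loss lives at index 1
    return ['fake_predictions', 0.5]

toy_trainer = Engine(_step)
StatsHandler(output_transform=lambda x: x[1]).attach(toy_trainer)
toy_trainer.run(data=[0, 1, 2], max_epochs=1)  # prints "... Loss: 0.5000" per iteration
```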
4 changes: 2 additions & 2 deletions examples/unet_segmentation_3d.ipynb
@@ -197,7 +197,7 @@
"trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,\n",
" handler=checkpoint_handler,\n",
" to_save={'net': net, 'opt': opt})\n",
"train_stats_handler = StatsHandler()\n",
"train_stats_handler = StatsHandler(output_transform=lambda x: x[1])\n",
"train_stats_handler.attach(trainer)\n",
"\n",
"writer = SummaryWriter()\n",
@@ -260,7 +260,7 @@
"\n",
"# Add stats event handler to print validation stats via evaluator\n",
"logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n",
"val_stats_handler = StatsHandler()\n",
"val_stats_handler = StatsHandler(lambda x: None)\n",
"val_stats_handler.attach(evaluator)\n",
"\n",
"# Add early stopping handler to evaluator.\n",
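The evaluator-side counterpart: with `output_transform=lambda x: None` the iteration-level printer stays quiet, and only `engine.state.metrics` is logged when the epoch completes. A hedged sketch of that division of labor, with a hand-written handler standing in for a real Ignite metric:

```python
# Sketch only: per-iteration printing disabled; epoch-level metric printing kept.
import logging
import sys

from ignite.engine import Engine, Events
from monai.handlers.stats_handler import StatsHandler

logging.basicConfig(stream=sys.stdout, level=logging.INFO)

evaluator = Engine(lambda engine, batch: None)  # toy eval step, nothing to print

@evaluator.on(Events.EPOCH_COMPLETED)  # registered before attach() so it runs first
def _fake_metric(engine):  # hypothetical stand-in for MeanDice and friends
    engine.state.metrics['Mean_Dice'] = 0.9

StatsHandler(output_transform=lambda x: None).attach(evaluator)
evaluator.run(data=[0, 1], max_epochs=1)  # logs "Epoch[1] Metrics -- Mean_Dice: 0.9000"
```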
80 changes: 38 additions & 42 deletions examples/unet_segmentation_3d_array.py
@@ -18,7 +18,6 @@
import nibabel as nib
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.handlers import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader
@@ -32,8 +31,8 @@
from monai.data.nifti_reader import NiftiDataset
from monai.transforms import AddChannel, Rescale, ToTensor, UniformRandomPatch
from monai.handlers.stats_handler import StatsHandler
from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler, TensorBoardImageHandler
from monai.handlers.mean_dice import MeanDice
from monai.visualize import img2tensorboard
from monai.data.synthetic import create_test_image_3d
from monai.handlers.utils import stopping_fn_from_metric

@@ -88,70 +87,71 @@
loss = monai.losses.DiceLoss(do_sigmoid=True)
opt = torch.optim.Adam(net.parameters(), lr)


# Since the network outputs both logits and a segmentation, we need a custom loss function.
def _loss_fn(i, j):
return loss(i[0], j)


# Create trainer
device = torch.device("cuda:0")
device = torch.device("cpu:0")
trainer = create_supervised_trainer(net, opt, _loss_fn, device, False,
output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y])
output_transform=lambda x, y, y_pred, loss: [y_pred[1], loss.item(), y])

# adding checkpoint handler to save models (network params and optimizer stats) during training
checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
handler=checkpoint_handler,
to_save={'net': net, 'opt': opt})
train_stats_handler = StatsHandler()
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# print training loss to the command line
train_stats_handler = StatsHandler(output_transform=lambda x: x[1])
train_stats_handler.attach(trainer)

# record training loss to TensorBoard at every iteration
train_tensorboard_stats_handler = TensorBoardStatsHandler(
    output_transform=lambda x: {'training_dice_loss': x[1]},  # plot under tag name training_dice_loss
global_epoch_transform=lambda x: trainer.state.epoch)
train_tensorboard_stats_handler.attach(trainer)


@trainer.on(Events.EPOCH_COMPLETED)
def log_training_loss(engine):
# log loss to tensorboard with second item of engine.state.output, loss.item() from output_transform
writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch)

# tensor of ones to use where for converting labels to zero and ones
ones = torch.ones(engine.state.batch[1][0].shape, dtype=torch.int32)
first_output_tensor = engine.state.output[0][1][0].detach().cpu()
# log model output to tensorboard, as three dimensional tensor with no channels dimension
img2tensorboard.add_animated_gif_no_channels(writer, "first_output_final_batch", first_output_tensor, 64,
255, engine.state.epoch)
# get label tensor and convert to single class
first_label_tensor = torch.where(engine.state.batch[1][0] > 0, ones, engine.state.batch[1][0])
# log label tensor to tensorboard, there is a channel dimension when getting label from batch
img2tensorboard.add_animated_gif(writer, "first_label_final_batch", first_label_tensor, 64,
255, engine.state.epoch)
second_output_tensor = engine.state.output[0][1][1].detach().cpu()
img2tensorboard.add_animated_gif_no_channels(writer, "second_output_final_batch", second_output_tensor, 64,
255, engine.state.epoch)
second_label_tensor = torch.where(engine.state.batch[1][1] > 0, ones, engine.state.batch[1][1])
img2tensorboard.add_animated_gif(writer, "second_label_final_batch", second_label_tensor, 64,
255, engine.state.epoch)
third_output_tensor = engine.state.output[0][1][2].detach().cpu()
img2tensorboard.add_animated_gif_no_channels(writer, "third_output_final_batch", third_output_tensor, 64,
255, engine.state.epoch)
third_label_tensor = torch.where(engine.state.batch[1][2] > 0, ones, engine.state.batch[1][2])
img2tensorboard.add_animated_gif(writer, "third_label_final_batch", third_label_tensor, 64,
255, engine.state.epoch)
engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1])

writer = SummaryWriter()

# Set parameters for validation
validation_every_n_epochs = 1
metric_name = 'Mean_Dice'

# add evaluation metric to the evaluator engine
val_metrics = {metric_name: MeanDice(add_sigmoid=True)}
evaluator = create_supervised_evaluator(net, val_metrics, device, True,
output_transform=lambda x, y, y_pred: (y_pred[0], y))
val_metrics = {metric_name: MeanDice(
add_sigmoid=True, to_onehot_y=False, output_transform=lambda output: (output[0][0], output[1]))
}
evaluator = create_supervised_evaluator(net, val_metrics, device, True)

# Add stats event handler to print validation stats via evaluator
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
val_stats_handler = StatsHandler()
val_stats_handler = StatsHandler(
output_transform=lambda x: None, # disable per iteration output
global_epoch_transform=lambda x: trainer.state.epoch)
val_stats_handler.attach(evaluator)

# Add early stopping handler to evaluator.
# add handler to record metrics to TensorBoard at every epoch
val_tensorboard_stats_handler = TensorBoardStatsHandler(
output_transform=lambda x: None, # no iteration plot
global_epoch_transform=lambda x: trainer.state.epoch) # use epoch number from trainer
val_tensorboard_stats_handler.attach(evaluator)
# add handler to draw several input images with the corresponding labels and model outputs
# here we draw the first 3 images (first channel only) in GIF format along the depth axis
val_tensorboard_image_handler = TensorBoardImageHandler(
batch_transform=lambda batch: (batch[0], batch[1]),
output_transform=lambda output: output[0][1],
global_iter_transform=lambda x: trainer.state.epoch
)
evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler)

# Add early stopping handler to evaluator
early_stopper = EarlyStopping(patience=4,
score_function=stopping_fn_from_metric(metric_name),
trainer=trainer)
@@ -166,10 +166,6 @@ def log_training_loss(engine):
def run_validation(engine):
evaluator.run(val_loader)

@evaluator.on(Events.EPOCH_COMPLETED)
def log_metrics_to_tensorboard(engine):
for name, value in engine.state.metrics.items():
writer.add_scalar('Metrics/{name}', value, trainer.state.epoch)

# create a training data loader
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
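The net effect of this file's rewrite: the hand-rolled `SummaryWriter` calls are replaced by `TensorBoardStatsHandler` (scalar curves) and `TensorBoardImageHandler` (image/label/output snapshots). A condensed restatement of the wiring above — not runnable standalone, since `trainer` and `evaluator` are the engines defined earlier in this script:

```python
# Condensed from the example above (assumes the trainer/evaluator engines
# defined earlier in this script).
from monai.handlers.tensorboard_handlers import TensorBoardStatsHandler, TensorBoardImageHandler

# training loss curve, one point per iteration, under the tag 'training_dice_loss'
TensorBoardStatsHandler(
    output_transform=lambda x: {'training_dice_loss': x[1]},
    global_epoch_transform=lambda x: trainer.state.epoch).attach(trainer)

# validation: no per-iteration scalars; epoch metrics plotted against the trainer's epoch
TensorBoardStatsHandler(
    output_transform=lambda x: None,
    global_epoch_transform=lambda x: trainer.state.epoch).attach(evaluator)
```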
4 changes: 2 additions & 2 deletions examples/unet_segmentation_3d_dict.py
@@ -111,7 +111,7 @@ def prepare_batch(batch, device=None, non_blocking=False):
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
handler=checkpoint_handler,
to_save={'net': net, 'opt': opt})
train_stats_handler = StatsHandler()
train_stats_handler = StatsHandler(output_transform=lambda x: x[1])
train_stats_handler.attach(trainer)


@@ -160,7 +160,7 @@ def log_training_loss(engine):

# Add stats event handler to print validation stats via evaluator
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
val_stats_handler = StatsHandler()
val_stats_handler = StatsHandler(output_transform=lambda x: None)
val_stats_handler.attach(evaluator)

# Add early stopping handler to evaluator.
5 changes: 3 additions & 2 deletions monai/handlers/segmentation_saver.py
@@ -10,7 +10,7 @@
# limitations under the License.

import os

import numpy as np
import torch
from ignite.engine import Events

@@ -114,5 +114,6 @@ def __call__(self, engine):
seg_output = seg_output.detach().cpu().numpy()
output_filename = self._create_file_basename(self.output_postfix, filename, self.output_path)
output_filename = '{}{}'.format(output_filename, self.output_ext)
write_nifti(seg_output, affine_, output_filename, original_affine_, dtype=seg_output.dtype)
        # change the output to "channel-last" format and write it to a NIfTI file
write_nifti(np.moveaxis(seg_output, 0, -1), affine_, output_filename, original_affine_, dtype=seg_output.dtype)
self.logger.info('saved: {}'.format(output_filename))
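The `np.moveaxis` call is the substance of this fix: the segmentation arrives channel-first, while the NIfTI writer expects the channel dimension last. A self-contained check of what the reordering does:

```python
# Self-contained check: np.moveaxis relocates the channel axis (returns a view, no copy).
import numpy as np

seg_output = np.zeros((2, 64, 64, 48), dtype=np.float32)  # (channel, H, W, D)
channel_last = np.moveaxis(seg_output, 0, -1)
print(channel_last.shape)  # (64, 64, 48, 2) -- what write_nifti now receives
```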
77 changes: 58 additions & 19 deletions monai/handlers/stats_handler.py
@@ -9,47 +9,63 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
import logging
import torch
from ignite.engine import Engine, Events
from monai.utils.misc import is_scalar

KEY_VAL_FORMAT = '{}: {:.4f} '
DEFAULT_KEY_VAL_FORMAT = '{}: {:.4f} '
DEFAULT_TAG = 'Loss'


class StatsHandler(object):
"""StatsHandler defines a set of Ignite Event-handlers for all the log printing logics.
It's can be used for any Ignite Engine(trainer, validator and evaluator).
And it can support logging for epoch level and iteration level with pre-defined loggers.
By default:
(1) epoch_print_logger logs `engine.state.metrics`.
(2) iteration_print_logger logs loss value, expected output format is (y_pred, loss).

Default behaviors:
- When EPOCH_COMPLETED, logs ``engine.state.metrics`` using ``self.logger``.
        - When ITERATION_COMPLETED, logs
``self.output_transform(engine.state.output)`` using ``self.logger``.
"""

def __init__(self,
epoch_print_logger=None,
iteration_print_logger=None,
batch_transform=lambda x: x,
output_transform=lambda x: x,
name=None):
global_epoch_transform=lambda x: x,
name=None,
tag_name=DEFAULT_TAG,
key_var_format=DEFAULT_KEY_VAL_FORMAT):
"""
Args:
            epoch_print_logger (Callable): customized callable printer for epoch-level logging.
                Must accept the parameter "engine"; uses the default printer if None.
            iteration_print_logger (Callable): customized callable printer for iteration-level logging.
                Must accept the parameter "engine"; uses the default printer if None.
batch_transform (Callable): a callable that is used to transform the
ignite.engine.batch into expected format to extract input data.
output_transform (Callable): a callable that is used to transform the
ignite.engine.output into expected format to extract several output data.
name (str): identifier of logging.logger to use, defaulting to `engine.logger`.
                ``ignite.engine.output`` into a scalar to print, or a dictionary of {key: scalar}.
                In the latter case, the output string will be formatted as ``key: value``.
                By default this value is logged every time an iteration completes.
            global_epoch_transform (Callable): a callable that is used to customize the global epoch number.
                For example, in evaluation, the evaluator engine might want to print an epoch number
                synced with the trainer engine.
name (str): identifier of logging.logger to use, defaulting to ``engine.logger``.
            tag_name (string): when the iteration output is a scalar, it is printed to the logger
                as ``tag_name: scalar_value``. Defaults to ``'Loss'``.
            key_var_format (string): a formatting string that controls how each ``key: value`` pair is printed.
"""

self.epoch_print_logger = epoch_print_logger
self.iteration_print_logger = iteration_print_logger
self.batch_transform = batch_transform
self.output_transform = output_transform
self.global_epoch_transform = global_epoch_transform
self.logger = None if name is None else logging.getLogger(name)

self.tag_name = tag_name
self.key_var_format = key_var_format

def attach(self, engine: Engine):
"""Register a set of Ignite Event-Handlers to a specified Ignite engine.

@@ -116,12 +132,12 @@ def _default_epoch_print(self, engine: Engine):
prints_dict = engine.state.metrics
if not prints_dict:
return
current_epoch = engine.state.epoch
current_epoch = self.global_epoch_transform(engine.state.epoch)

out_str = "Epoch[{}] Metrics -- ".format(current_epoch)
for name in sorted(prints_dict):
value = prints_dict[name]
out_str += KEY_VAL_FORMAT.format(name, value)
out_str += self.key_var_format.format(name, value)

self.logger.info(out_str)

@@ -134,19 +150,42 @@ def _default_iteration_print(self, engine: Engine):
engine (ignite.engine): Ignite Engine, it can be a trainer, validator or evaluator.

"""
loss = self.output_transform(engine.state.output)[1]
if loss is None or (torch.is_tensor(loss) and len(loss.shape) > 0):
return # not printing multi dimensional output
loss = self.output_transform(engine.state.output)
if loss is None:
return # no printing if the output is empty

out_str = ''
if isinstance(loss, dict): # print dictionary items
for name in sorted(loss):
value = loss[name]
if not is_scalar(value):
warnings.warn('ignoring non-scalar output in StatsHandler,'
' make sure `output_transform(engine.state.output)` returns'
' a scalar or dictionary of key and scalar pairs to avoid this warning.'
' {}:{}'.format(name, type(value)))
continue # not printing multi dimensional output
out_str += self.key_var_format.format(name, value.item() if torch.is_tensor(value) else value)
else:
if is_scalar(loss): # not printing multi dimensional output
out_str += self.key_var_format.format(self.tag_name, loss.item() if torch.is_tensor(loss) else loss)
else:
warnings.warn('ignoring non-scalar output in StatsHandler,'
' make sure `output_transform(engine.state.output)` returns'
' a scalar or a dictionary of key and scalar pairs to avoid this warning.'
' {}'.format(type(loss)))

if not out_str:
return # no value to print

num_iterations = engine.state.epoch_length
current_iteration = (engine.state.iteration - 1) % num_iterations + 1
current_epoch = engine.state.epoch
num_epochs = engine.state.max_epochs

out_str = "Epoch: {}/{}, Iter: {}/{} -- ".format(
base_str = "Epoch: {}/{}, Iter: {}/{} --".format(
current_epoch,
num_epochs,
current_iteration,
num_iterations)
out_str += KEY_VAL_FORMAT.format('Loss', loss.item() if torch.is_tensor(loss) else loss)

self.logger.info(out_str)
self.logger.info(' '.join([base_str, out_str]))
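The rewritten `_default_iteration_print` accepts either a scalar or a `{key: scalar}` dictionary and warns on anything else instead of printing garbage. A sketch of the dictionary path, using a toy engine and made-up values:

```python
# Sketch only: a dict output prints one "key: value" pair per scalar entry,
# formatted with key_var_format and prefixed with epoch/iteration counters.
import logging
import sys

from ignite.engine import Engine
from monai.handlers.stats_handler import StatsHandler

logging.basicConfig(stream=sys.stdout, level=logging.INFO)

def _step(engine, batch):
    return {'dice_loss': 0.31, 'lr': 1e-3}  # made-up scalars

engine = Engine(_step)
StatsHandler(output_transform=lambda x: x).attach(engine)
engine.run(data=[0], max_epochs=1)  # e.g. "Epoch: 1/1, Iter: 1/1 -- dice_loss: 0.3100 lr: 0.0010"
```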