Skip to content
41 changes: 32 additions & 9 deletions monai/handlers/tensorboard_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def __init__(
global_iter_transform: Callable = lambda x: x,
index: int = 0,
max_channels: int = 1,
frame_dim: int = -3,
max_frames: int = 64,
) -> None:
"""
Expand All @@ -301,6 +302,8 @@ def __init__(
For example, in evaluation, the evaluator engine needs to know current epoch from trainer.
index: plot which element in a data batch, default is the first element.
max_channels: number of channels to plot.
frame_dim: if plotting 3D image as GIF, specify the dimension used as frames,
expect input data shape as `NCHWD`, default to `-3` (the first spatial dim)
max_frames: if plotting a 3D RGB image as a video in TensorBoardX, set the FPS to `max_frames`.
"""
super().__init__(summary_writer=summary_writer, log_dir=log_dir)
Expand All @@ -310,6 +313,7 @@ def __init__(
self.output_transform = output_transform
self.global_iter_transform = global_iter_transform
self.index = index
self.frame_dim = frame_dim
self.max_frames = max_frames
self.max_channels = max_channels

Expand Down Expand Up @@ -349,13 +353,14 @@ def __call__(self, engine: Engine) -> None:
)
plot_2d_or_3d_image(
# add batch dim and plot the first item
show_images[None],
step,
self._writer,
0,
self.max_channels,
self.max_frames,
"input_0",
data=show_images[None],
step=step,
writer=self._writer,
index=0,
max_channels=self.max_channels,
frame_dim=self.frame_dim,
max_frames=self.max_frames,
tag="input_0",
)

show_labels = self.batch_transform(engine.state.batch)[1][self.index]
Expand All @@ -367,7 +372,16 @@ def __call__(self, engine: Engine) -> None:
"batch_transform(engine.state.batch)[1] must be None or one of "
f"(numpy.ndarray, torch.Tensor) but is {type(show_labels).__name__}."
)
plot_2d_or_3d_image(show_labels[None], step, self._writer, 0, self.max_channels, self.max_frames, "input_1")
plot_2d_or_3d_image(
data=show_labels[None],
step=step,
writer=self._writer,
index=0,
max_channels=self.max_channels,
frame_dim=self.frame_dim,
max_frames=self.max_frames,
tag="input_1",
)

show_outputs = self.output_transform(engine.state.output)[self.index]
if isinstance(show_outputs, torch.Tensor):
Expand All @@ -378,6 +392,15 @@ def __call__(self, engine: Engine) -> None:
"output_transform(engine.state.output) must be None or one of "
f"(numpy.ndarray, torch.Tensor) but is {type(show_outputs).__name__}."
)
plot_2d_or_3d_image(show_outputs[None], step, self._writer, 0, self.max_channels, self.max_frames, "output")
plot_2d_or_3d_image(
data=show_outputs[None],
step=step,
writer=self._writer,
index=0,
max_channels=self.max_channels,
frame_dim=self.frame_dim,
max_frames=self.max_frames,
tag="output",
)

self._writer.flush()
7 changes: 1 addition & 6 deletions monai/visualize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,7 @@
# limitations under the License.

from .class_activation_maps import CAM, GradCAM, GradCAMpp, ModelWithHooks, default_normalizer
from .img2tensorboard import (
add_animated_gif,
add_animated_gif_no_channels,
make_animated_gif_summary,
plot_2d_or_3d_image,
)
from .img2tensorboard import add_animated_gif, make_animated_gif_summary, plot_2d_or_3d_image
from .occlusion_sensitivity import OcclusionSensitivity
from .utils import blend_images, matshow3d
from .visualizer import default_upsampler
103 changes: 34 additions & 69 deletions monai/visualize/img2tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
import torch

from monai.config import NdarrayTensor
from monai.transforms import rescale_array
from monai.utils import optional_import
from monai.utils import convert_data_type, optional_import

PIL, _ = optional_import("PIL")
GifImage, _ = optional_import("PIL.GifImagePlugin", name="Image")
Expand All @@ -30,22 +30,28 @@
Summary, _ = optional_import("tensorboard.compat.proto.summary_pb2", name="Summary")
SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter")

__all__ = ["make_animated_gif_summary", "add_animated_gif", "add_animated_gif_no_channels", "plot_2d_or_3d_image"]
__all__ = ["make_animated_gif_summary", "add_animated_gif", "plot_2d_or_3d_image"]


def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], writer, scale_factor: float = 1.0):
def _image3_animated_gif(
tag: str, image: Union[np.ndarray, torch.Tensor], writer, frame_dim: int = 0, scale_factor: float = 1.0
):
"""Function to actually create the animated gif.

Args:
tag: Data identifier
image: 3D image tensors expected to be in `HWD` format
writer: the tensorboard writer to plot image
frame_dim: the dimension used as frames for GIF image, expect data shape as `HWD`, default to `0`.
scale_factor: amount to multiply values by. if the image data is between 0 and 1, using 255 for this value will
scale it to displayable range
"""
if len(image.shape) != 3:
raise AssertionError("3D image tensors expected to be in `HWD` format, len(image.shape) != 3")

ims = [(np.asarray(image[:, :, i]) * scale_factor).astype(np.uint8, copy=False) for i in range(image.shape[2])]
image_np: np.ndarray
image_np, *_ = convert_data_type(image, output_type=np.ndarray) # type: ignore
ims = [(i * scale_factor).astype(np.uint8, copy=False) for i in np.moveaxis(image_np, frame_dim, 0)]
ims = [GifImage.fromarray(im) for im in ims]
img_str = b""
for b_data in PIL.GifImagePlugin.getheader(ims[0])[0]:
Expand All @@ -67,45 +73,34 @@ def make_animated_gif_summary(
image: Union[np.ndarray, torch.Tensor],
writer=None,
max_out: int = 3,
animation_axes: Sequence[int] = (3,),
image_axes: Sequence[int] = (1, 2),
other_indices: Optional[Dict] = None,
frame_dim: int = -3,
scale_factor: float = 1.0,
) -> Summary:
"""Creates an animated gif out of an image tensor in 'CHWD' format and returns Summary.

Args:
tag: Data identifier
image: The image, expected to be in CHWD format
image: The image, expected to be in `CHWD` format
writer: the tensorboard writer to plot image
max_out: maximum number of image channels to animate through
animation_axes: axis to animate on (not currently used)
image_axes: axes of image (not currently used)
other_indices: (not currently used)
frame_dim: the dimension used as frames for GIF image, expect input data shape as `CHWD`,
default to `-3` (the first spatial dim)
scale_factor: amount to multiply values by.
if the image data is between 0 and 1, using 255 for this value will scale it to displayable range
"""

suffix = "/image" if max_out == 1 else "/image/{}"
if other_indices is None:
other_indices = {}
axis_order = [0] + list(animation_axes) + list(image_axes)

slicing = []
for i in range(len(image.shape)):
if i in axis_order:
slicing.append(slice(None))
else:
other_ind = other_indices.get(i, 0)
slicing.append(slice(other_ind, other_ind + 1))
image = image[tuple(slicing)]
# GIF image has no channel dim, reduce the spatial dim index if positive
frame_dim = frame_dim - 1 if frame_dim > 0 else frame_dim

summary_op = []
for it_i in range(min(max_out, list(image.shape)[0])):
one_channel_img: Union[torch.Tensor, np.ndarray] = (
image[it_i, :, :, :].squeeze(dim=0) if isinstance(image, torch.Tensor) else image[it_i, :, :, :]
)
summary_op.append(_image3_animated_gif(tag + suffix.format(it_i), one_channel_img, writer, scale_factor))
summary_op.append(
_image3_animated_gif(tag + suffix.format(it_i), one_channel_img, writer, frame_dim, scale_factor)
)
return summary_op


Expand All @@ -114,6 +109,7 @@ def add_animated_gif(
tag: str,
image_tensor: Union[np.ndarray, torch.Tensor],
max_out: int = 3,
frame_dim: int = -3,
scale_factor: float = 1.0,
global_step: Optional[int] = None,
) -> None:
Expand All @@ -122,67 +118,29 @@ def add_animated_gif(
Args:
writer: Tensorboard SummaryWriter to write to
tag: Data identifier
image_tensor: tensor for the image to add, expected to be in CHWD format
image_tensor: tensor for the image to add, expected to be in `CHWD` format
max_out: maximum number of image channels to animate through
frame_dim: the dimension used as frames for GIF image, expect input data shape as `CHWD`,
default to `-3` (the first spatial dim)
scale_factor: amount to multiply values by. If the image data is between 0 and 1, using 255 for this value will
scale it to displayable range
global_step: Global step value to record
"""
summary = make_animated_gif_summary(
tag=tag,
image=image_tensor,
writer=writer,
max_out=max_out,
animation_axes=[1],
image_axes=[2, 3],
scale_factor=scale_factor,
tag=tag, image=image_tensor, writer=writer, max_out=max_out, frame_dim=frame_dim, scale_factor=scale_factor
)
for s in summary:
# add GIF for every channel separately
writer._get_file_writer().add_summary(s, global_step)


def add_animated_gif_no_channels(
    writer: SummaryWriter,
    tag: str,
    image_tensor: Union[np.ndarray, torch.Tensor],
    max_out: int = 3,
    scale_factor: float = 1.0,
    global_step: Optional[int] = None,
) -> None:
    """Creates an animated gif out of an image tensor in 'HWD' format that does not have
    a channel dimension and writes it with SummaryWriter. This is equivalent to calling
    "add_animated_gif" after inserting a channel dimension of 1; only the GIF of that
    single channel is written.

    Args:
        writer: Tensorboard SummaryWriter to write to
        tag: Data identifier
        image_tensor: tensor for the image to add, expected to be in `HWD` format
        max_out: maximum number of image channels to animate through
        scale_factor: amount to multiply values by. If the image data is between 0 and 1,
            using 255 for this value will scale it to displayable range
        global_step: Global step value to record
    """
    # `make_animated_gif_summary` walks the leading axis as channels and indexes each one
    # with `image[c, :, :, :]`, which fails on a bare 3D volume. Give `HWD` input the
    # leading channel axis of size 1 that the docstring promises; `[None]` works for both
    # numpy arrays and torch tensors. Input that already has 4 dims is passed through as is.
    if len(image_tensor.shape) == 3:
        image_tensor = image_tensor[None]

    # Forward only the arguments this wrapper owns and rely on the callee's defaults for
    # the rest, so that the whole volume is animated rather than a single slice of it.
    summaries = make_animated_gif_summary(
        tag=tag,
        image=image_tensor,
        writer=writer,
        max_out=max_out,
        scale_factor=scale_factor,
    )
    # there is exactly one channel here, hence only the first summary is written
    writer._get_file_writer().add_summary(summaries[0], global_step)


def plot_2d_or_3d_image(
data: Union[NdarrayTensor, List[NdarrayTensor]],
step: int,
writer: SummaryWriter,
index: int = 0,
max_channels: int = 1,
frame_dim: int = -3,
max_frames: int = 24,
tag: str = "output",
) -> None:
Expand All @@ -200,10 +158,15 @@ def plot_2d_or_3d_image(
writer: specify TensorBoard or TensorBoardX SummaryWriter to plot the image.
index: plot which element in the input data batch, default is the first element.
max_channels: number of channels to plot.
frame_dim: if plotting 3D image as GIF, specify the dimension used as frames,
expect input data shape as `NCHWD`, default to `-3` (the first spatial dim)
max_frames: if plotting a 3D RGB image as a video in TensorBoardX, set the FPS to `max_frames`.
tag: tag of the plotted image on TensorBoard.
"""
data_index = data[index]
# as the `d` data has no batch dim, reduce the spatial dim index if positive
frame_dim = frame_dim - 1 if frame_dim > 0 else frame_dim

d: np.ndarray = data_index.detach().cpu().numpy() if isinstance(data_index, torch.Tensor) else data_index

if d.ndim == 2:
Expand All @@ -227,11 +190,13 @@ def plot_2d_or_3d_image(
spatial = d.shape[-3:]
d = d.reshape([-1] + list(spatial))
if d.shape[0] == 3 and max_channels == 3 and has_tensorboardx and isinstance(writer, SummaryWriterX): # RGB
# move the expected frame dim to the end as `T` dim for video
d = np.moveaxis(d, frame_dim, -1)
writer.add_video(tag, d[None], step, fps=max_frames, dataformats="NCHWT")
return
# scale data to 0 - 255 for visualization
max_channels = min(max_channels, d.shape[0])
d = np.stack([rescale_array(i, 0, 255) for i in d[:max_channels]], axis=0)
# will plot every channel as a separate GIF image
add_animated_gif(writer, f"{tag}_HWD", d, max_out=max_channels, global_step=step)
add_animated_gif(writer, f"{tag}_HWD", d, max_out=max_channels, frame_dim=frame_dim, global_step=step)
return
14 changes: 2 additions & 12 deletions tests/test_img2tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,7 @@ class TestImg2Tensorboard(unittest.TestCase):
def test_write_gray(self):
nparr = np.ones(shape=(1, 32, 32, 32), dtype=np.float32)
summary_object_np = make_animated_gif_summary(
tag="test_summary_nparr.png",
image=nparr,
max_out=1,
animation_axes=(3,),
image_axes=(1, 2),
scale_factor=253.0,
tag="test_summary_nparr.png", image=nparr, max_out=1, scale_factor=253.0
)
for s in summary_object_np:
assert isinstance(
Expand All @@ -36,12 +31,7 @@ def test_write_gray(self):

tensorarr = torch.tensor(nparr)
summary_object_tensor = make_animated_gif_summary(
tag="test_summary_tensorarr.png",
image=tensorarr,
max_out=1,
animation_axes=(3,),
image_axes=(1, 2),
scale_factor=253.0,
tag="test_summary_tensorarr.png", image=tensorarr, max_out=1, frame_dim=-1, scale_factor=253.0
)
for s in summary_object_tensor:
assert isinstance(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_plot_2d_or_3d_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class TestPlot2dOr3dImage(unittest.TestCase):
def test_tb_image(self, shape):
with tempfile.TemporaryDirectory() as tempdir:
writer = SummaryWriter(log_dir=tempdir)
plot_2d_or_3d_image(torch.zeros(shape), 0, writer, max_channels=20)
plot_2d_or_3d_image(torch.zeros(shape), 0, writer, max_channels=3, frame_dim=-1)
writer.flush()
writer.close()
self.assertTrue(len(glob.glob(tempdir)) > 0)
Expand Down