diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py
index 490ba5d2d1..11b487ec99 100644
--- a/monai/handlers/tensorboard_handlers.py
+++ b/monai/handlers/tensorboard_handlers.py
@@ -275,6 +275,7 @@ def __init__(
         global_iter_transform: Callable = lambda x: x,
         index: int = 0,
         max_channels: int = 1,
+        frame_dim: int = -3,
         max_frames: int = 64,
     ) -> None:
         """
@@ -301,6 +302,8 @@ def __init__(
                 For example, in evaluation, the evaluator engine needs to know current epoch from trainer.
             index: plot which element in a data batch, default is the first element.
             max_channels: number of channels to plot.
+            frame_dim: if plotting a 3D image as a GIF, specify the dimension used as frames,
+                expected input data shape is `NCHWD`, defaults to `-3` (the first spatial dim).
             max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`.
         """
         super().__init__(summary_writer=summary_writer, log_dir=log_dir)
@@ -310,6 +313,7 @@ def __init__(
         self.output_transform = output_transform
         self.global_iter_transform = global_iter_transform
         self.index = index
+        self.frame_dim = frame_dim
         self.max_frames = max_frames
         self.max_channels = max_channels
 
@@ -349,13 +353,14 @@ def __call__(self, engine: Engine) -> None:
                 )
             plot_2d_or_3d_image(
                 # add batch dim and plot the first item
-                show_images[None],
-                step,
-                self._writer,
-                0,
-                self.max_channels,
-                self.max_frames,
-                "input_0",
+                data=show_images[None],
+                step=step,
+                writer=self._writer,
+                index=0,
+                max_channels=self.max_channels,
+                frame_dim=self.frame_dim,
+                max_frames=self.max_frames,
+                tag="input_0",
             )
 
         show_labels = self.batch_transform(engine.state.batch)[1][self.index]
@@ -367,7 +372,16 @@ def __call__(self, engine: Engine) -> None:
                     "batch_transform(engine.state.batch)[1] must be None or one of "
                     f"(numpy.ndarray, torch.Tensor) but is {type(show_labels).__name__}."
                 )
-            plot_2d_or_3d_image(show_labels[None], step, self._writer, 0, self.max_channels, self.max_frames, "input_1")
+            plot_2d_or_3d_image(
+                data=show_labels[None],
+                step=step,
+                writer=self._writer,
+                index=0,
+                max_channels=self.max_channels,
+                frame_dim=self.frame_dim,
+                max_frames=self.max_frames,
+                tag="input_1",
+            )
 
         show_outputs = self.output_transform(engine.state.output)[self.index]
         if isinstance(show_outputs, torch.Tensor):
@@ -378,6 +392,15 @@ def __call__(self, engine: Engine) -> None:
                     "output_transform(engine.state.output) must be None or one of "
                     f"(numpy.ndarray, torch.Tensor) but is {type(show_outputs).__name__}."
                 )
-            plot_2d_or_3d_image(show_outputs[None], step, self._writer, 0, self.max_channels, self.max_frames, "output")
+            plot_2d_or_3d_image(
+                data=show_outputs[None],
+                step=step,
+                writer=self._writer,
+                index=0,
+                max_channels=self.max_channels,
+                frame_dim=self.frame_dim,
+                max_frames=self.max_frames,
+                tag="output",
+            )
 
         self._writer.flush()
diff --git a/monai/visualize/__init__.py b/monai/visualize/__init__.py
index 4356b5f8d9..1e3140d3a4 100644
--- a/monai/visualize/__init__.py
+++ b/monai/visualize/__init__.py
@@ -10,12 +10,7 @@
 # limitations under the License.
 
 from .class_activation_maps import CAM, GradCAM, GradCAMpp, ModelWithHooks, default_normalizer
-from .img2tensorboard import (
-    add_animated_gif,
-    add_animated_gif_no_channels,
-    make_animated_gif_summary,
-    plot_2d_or_3d_image,
-)
+from .img2tensorboard import add_animated_gif, make_animated_gif_summary, plot_2d_or_3d_image
 from .occlusion_sensitivity import OcclusionSensitivity
 from .utils import blend_images, matshow3d
 from .visualizer import default_upsampler
diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py
index 309e7f4b18..619c65f79d 100644
--- a/monai/visualize/img2tensorboard.py
+++ b/monai/visualize/img2tensorboard.py
@@ -9,14 +9,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import numpy as np
 import torch
 
 from monai.config import NdarrayTensor
 from monai.transforms import rescale_array
-from monai.utils import optional_import
+from monai.utils import convert_data_type, optional_import
 
 PIL, _ = optional_import("PIL")
 GifImage, _ = optional_import("PIL.GifImagePlugin", name="Image")
@@ -30,22 +30,28 @@
 Summary, _ = optional_import("tensorboard.compat.proto.summary_pb2", name="Summary")
 SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter")
 
-__all__ = ["make_animated_gif_summary", "add_animated_gif", "add_animated_gif_no_channels", "plot_2d_or_3d_image"]
+__all__ = ["make_animated_gif_summary", "add_animated_gif", "plot_2d_or_3d_image"]
 
 
-def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], writer, scale_factor: float = 1.0):
+def _image3_animated_gif(
+    tag: str, image: Union[np.ndarray, torch.Tensor], writer, frame_dim: int = 0, scale_factor: float = 1.0
+):
     """Function to actually create the animated gif.
 
     Args:
         tag: Data identifier
         image: 3D image tensors expected to be in `HWD` format
+        writer: the tensorboard writer to plot image
+        frame_dim: the dimension used as frames for the GIF image, expected data shape is `HWD`, defaults to `0`.
         scale_factor: amount to multiply values by.
             if the image data is between 0 and 1, using 255 for this value will scale it to displayable range
     """
     if len(image.shape) != 3:
         raise AssertionError("3D image tensors expected to be in `HWD` format, len(image.shape) != 3")
 
-    ims = [(np.asarray(image[:, :, i]) * scale_factor).astype(np.uint8, copy=False) for i in range(image.shape[2])]
+    image_np: np.ndarray
+    image_np, *_ = convert_data_type(image, output_type=np.ndarray)  # type: ignore
+    ims = [(i * scale_factor).astype(np.uint8, copy=False) for i in np.moveaxis(image_np, frame_dim, 0)]
     ims = [GifImage.fromarray(im) for im in ims]
     img_str = b""
     for b_data in PIL.GifImagePlugin.getheader(ims[0])[0]:
@@ -67,45 +73,34 @@ def make_animated_gif_summary(
     image: Union[np.ndarray, torch.Tensor],
     writer=None,
     max_out: int = 3,
-    animation_axes: Sequence[int] = (3,),
-    image_axes: Sequence[int] = (1, 2),
-    other_indices: Optional[Dict] = None,
+    frame_dim: int = -3,
     scale_factor: float = 1.0,
 ) -> Summary:
     """Creates an animated gif out of an image tensor in 'CHWD' format and returns Summary.
 
     Args:
         tag: Data identifier
-        image: The image, expected to be in CHWD format
+        image: The image, expected to be in `CHWD` format
         writer: the tensorboard writer to plot image
         max_out: maximum number of image channels to animate through
-        animation_axes: axis to animate on (not currently used)
-        image_axes: axes of image (not currently used)
-        other_indices: (not currently used)
+        frame_dim: the dimension used as frames for the GIF image, expected input data shape is `CHWD`,
+            defaults to `-3` (the first spatial dim)
        scale_factor: amount to multiply values by.
            if the image data is between 0 and 1, using 255 for this value will scale it to displayable range
     """
     suffix = "/image" if max_out == 1 else "/image/{}"
 
-    if other_indices is None:
-        other_indices = {}
-    axis_order = [0] + list(animation_axes) + list(image_axes)
-
-    slicing = []
-    for i in range(len(image.shape)):
-        if i in axis_order:
-            slicing.append(slice(None))
-        else:
-            other_ind = other_indices.get(i, 0)
-            slicing.append(slice(other_ind, other_ind + 1))
-    image = image[tuple(slicing)]
+    # GIF image has no channel dim, reduce the spatial dim index if positive
+    frame_dim = frame_dim - 1 if frame_dim > 0 else frame_dim
 
     summary_op = []
     for it_i in range(min(max_out, list(image.shape)[0])):
         one_channel_img: Union[torch.Tensor, np.ndarray] = (
             image[it_i, :, :, :].squeeze(dim=0) if isinstance(image, torch.Tensor) else image[it_i, :, :, :]
         )
-        summary_op.append(_image3_animated_gif(tag + suffix.format(it_i), one_channel_img, writer, scale_factor))
+        summary_op.append(
+            _image3_animated_gif(tag + suffix.format(it_i), one_channel_img, writer, frame_dim, scale_factor)
+        )
     return summary_op
 
 
@@ -114,6 +109,7 @@ def add_animated_gif(
     tag: str,
     image_tensor: Union[np.ndarray, torch.Tensor],
     max_out: int = 3,
+    frame_dim: int = -3,
     scale_factor: float = 1.0,
     global_step: Optional[int] = None,
 ) -> None:
@@ -122,67 +118,29 @@ def add_animated_gif(
     Args:
         writer: Tensorboard SummaryWriter to write to
         tag: Data identifier
-        image_tensor: tensor for the image to add, expected to be in CHWD format
+        image_tensor: tensor for the image to add, expected to be in `CHWD` format
         max_out: maximum number of image channels to animate through
+        frame_dim: the dimension used as frames for the GIF image, expected input data shape is `CHWD`,
+            defaults to `-3` (the first spatial dim)
         scale_factor: amount to multiply values by. If the image data is between 0 and 1,
             using 255 for this value will scale it to displayable range
         global_step: Global step value to record
     """
     summary = make_animated_gif_summary(
-        tag=tag,
-        image=image_tensor,
-        writer=writer,
-        max_out=max_out,
-        animation_axes=[1],
-        image_axes=[2, 3],
-        scale_factor=scale_factor,
+        tag=tag, image=image_tensor, writer=writer, max_out=max_out, frame_dim=frame_dim, scale_factor=scale_factor
     )
     for s in summary:
         # add GIF for every channel separately
         writer._get_file_writer().add_summary(s, global_step)
 
 
-def add_animated_gif_no_channels(
-    writer: SummaryWriter,
-    tag: str,
-    image_tensor: Union[np.ndarray, torch.Tensor],
-    max_out: int = 3,
-    scale_factor: float = 1.0,
-    global_step: Optional[int] = None,
-) -> None:
-    """Creates an animated gif out of an image tensor in 'HWD' format that does not have
-    a channel dimension and writes it with SummaryWriter. This is similar to the "add_animated_gif"
-    after inserting a channel dimension of 1.
-
-    Args:
-        writer: Tensorboard SummaryWriter to write to
-        tag: Data identifier
-        image_tensor: tensor for the image to add, expected to be in HWD format
-        max_out: maximum number of image channels to animate through
-        scale_factor: amount to multiply values by. If the image data is between 0 and 1,
-            using 255 for this value will scale it to displayable range
-        global_step: Global step value to record
-    """
-    writer._get_file_writer().add_summary(
-        make_animated_gif_summary(
-            tag=tag,
-            image=image_tensor,
-            writer=writer,
-            max_out=max_out,
-            animation_axes=[1],
-            image_axes=[1, 2],
-            scale_factor=scale_factor,
-        )[0],
-        global_step,
-    )
-
-
 def plot_2d_or_3d_image(
     data: Union[NdarrayTensor, List[NdarrayTensor]],
     step: int,
     writer: SummaryWriter,
     index: int = 0,
     max_channels: int = 1,
+    frame_dim: int = -3,
     max_frames: int = 24,
     tag: str = "output",
 ) -> None:
@@ -200,10 +158,15 @@ def plot_2d_or_3d_image(
         writer: specify TensorBoard or TensorBoardX SummaryWriter to plot the image.
         index: plot which element in the input data batch, default is the first element.
         max_channels: number of channels to plot.
+        frame_dim: if plotting a 3D image as a GIF, specify the dimension used as frames,
+            expected input data shape is `NCHWD`, defaults to `-3` (the first spatial dim).
         max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`.
         tag: tag of the plotted image on TensorBoard.
     """
     data_index = data[index]
+    # as the `d` data has no batch dim, reduce the spatial dim index if positive
+    frame_dim = frame_dim - 1 if frame_dim > 0 else frame_dim
+
     d: np.ndarray = data_index.detach().cpu().numpy() if isinstance(data_index, torch.Tensor) else data_index
 
     if d.ndim == 2:
@@ -227,11 +190,13 @@ def plot_2d_or_3d_image(
         spatial = d.shape[-3:]
         d = d.reshape([-1] + list(spatial))
         if d.shape[0] == 3 and max_channels == 3 and has_tensorboardx and isinstance(writer, SummaryWriterX):  # RGB
+            # move the expected frame dim to the end as `T` dim for video
+            d = np.moveaxis(d, frame_dim, -1)
             writer.add_video(tag, d[None], step, fps=max_frames, dataformats="NCHWT")
             return
         # scale data to 0 - 255 for visualization
         max_channels = min(max_channels, d.shape[0])
         d = np.stack([rescale_array(i, 0, 255) for i in d[:max_channels]], axis=0)
         # will plot every channel as a separate GIF image
-        add_animated_gif(writer, f"{tag}_HWD", d, max_out=max_channels, global_step=step)
+        add_animated_gif(writer, f"{tag}_HWD", d, max_out=max_channels, frame_dim=frame_dim, global_step=step)
         return
diff --git a/tests/test_img2tensorboard.py b/tests/test_img2tensorboard.py
index bf6890bcad..5d76231356 100644
--- a/tests/test_img2tensorboard.py
+++ b/tests/test_img2tensorboard.py
@@ -22,12 +22,7 @@ class TestImg2Tensorboard(unittest.TestCase):
     def test_write_gray(self):
         nparr = np.ones(shape=(1, 32, 32, 32), dtype=np.float32)
         summary_object_np = make_animated_gif_summary(
-            tag="test_summary_nparr.png",
-            image=nparr,
-            max_out=1,
-            animation_axes=(3,),
-            image_axes=(1, 2),
-            scale_factor=253.0,
+            tag="test_summary_nparr.png", image=nparr, max_out=1, scale_factor=253.0
         )
         for s in summary_object_np:
             assert isinstance(
@@ -36,12 +31,7 @@ def test_write_gray(self):
 
         tensorarr = torch.tensor(nparr)
         summary_object_tensor = make_animated_gif_summary(
-            tag="test_summary_tensorarr.png",
-            image=tensorarr,
-            max_out=1,
-            animation_axes=(3,),
-            image_axes=(1, 2),
-            scale_factor=253.0,
+            tag="test_summary_tensorarr.png", image=tensorarr, max_out=1, frame_dim=-1, scale_factor=253.0
         )
         for s in summary_object_tensor:
             assert isinstance(
diff --git a/tests/test_plot_2d_or_3d_image.py b/tests/test_plot_2d_or_3d_image.py
index c645c8ff86..cfcb145503 100644
--- a/tests/test_plot_2d_or_3d_image.py
+++ b/tests/test_plot_2d_or_3d_image.py
@@ -39,7 +39,7 @@ class TestPlot2dOr3dImage(unittest.TestCase):
     def test_tb_image(self, shape):
         with tempfile.TemporaryDirectory() as tempdir:
            writer = SummaryWriter(log_dir=tempdir)
-            plot_2d_or_3d_image(torch.zeros(shape), 0, writer, max_channels=20)
+            plot_2d_or_3d_image(torch.zeros(shape), 0, writer, max_channels=3, frame_dim=-1)
            writer.flush()
            writer.close()
            self.assertTrue(len(glob.glob(tempdir)) > 0)
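
Note (not part of the patch): a minimal usage sketch of the new `frame_dim` argument introduced above, assuming `torch` and `tensorboard` are installed; the tensor shape, log directory, and tag are illustrative only. Positive `frame_dim` values are shifted down internally (see the comments added in `make_animated_gif_summary` and `plot_2d_or_3d_image`) because the batch dim, and for GIFs the channel dim, are stripped before frames are extracted, so callers always index into the `NCHWD` layout of the input batch.

    import torch
    from torch.utils.tensorboard import SummaryWriter

    from monai.visualize import plot_2d_or_3d_image

    writer = SummaryWriter(log_dir="./runs")  # hypothetical log directory
    volume = torch.rand(1, 1, 64, 64, 48)     # one single-channel 3D volume in `NCHWD` layout
    # animate along the last spatial dimension (depth) instead of the default `-3`
    plot_2d_or_3d_image(volume, step=0, writer=writer, frame_dim=-1, tag="volume")
    writer.flush()
    writer.close()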