diff --git a/monai/visualize/utils.py b/monai/visualize/utils.py index a4b85d6e58..8779ec1ee6 100644 --- a/monai/visualize/utils.py +++ b/monai/visualize/utils.py @@ -32,6 +32,7 @@ def matshow3d( title: Optional[str] = None, figsize=(10, 10), frames_per_row: Optional[int] = None, + frame_dim: int = -3, vmin=None, vmax=None, every_n: int = 1, @@ -53,6 +54,8 @@ def matshow3d( title: title of the figure. figsize: size of the figure. frames_per_row: number of frames to display in each row. If None, sqrt(firstdim) will be used. + frame_dim: for higher dimensional arrays, which dimension (`-1`, `-2`, `-3`) is moved to the `-3` + dim and reshape to (-1, H, W) shape to construct frames, default to `-3`. vmin: `vmin` for the matplotlib `imshow`. vmax: `vmax` for the matplotlib `imshow`. every_n: factor to subsample the frames so that only every n-th frame is displayed. @@ -95,6 +98,8 @@ def matshow3d( else: # ndarray while len(vol.shape) < 3: vol = np.expand_dims(vol, 0) # so that we display 1d and 2d as well + + vol = np.moveaxis(vol, frame_dim, -3) # move the expected dim to construct frames with `B` or `C` dims if len(vol.shape) > 3: vol = vol.reshape((-1, vol.shape[-2], vol.shape[-1])) vmin = np.nanmin(vol) if vmin is None else vmin diff --git a/tests/test_matshow3d.py b/tests/test_matshow3d.py index 6be0938e8b..fa834ca431 100644 --- a/tests/test_matshow3d.py +++ b/tests/test_matshow3d.py @@ -34,7 +34,7 @@ def test_3d(self): ims = xforms({keys: image_path}) fig = pyplot.figure() # external figure - fig, _ = matshow3d(ims[keys], fig=fig, figsize=(2, 2), frames_per_row=5, every_n=2, show=False) + fig, _ = matshow3d(ims[keys], fig=fig, figsize=(2, 2), frames_per_row=5, every_n=2, frame_dim=-1, show=False) with tempfile.TemporaryDirectory() as tempdir: tempimg = f"{tempdir}/matshow3d_test.png" diff --git a/tests/testing_data/matshow3d_test.png b/tests/testing_data/matshow3d_test.png index 789c1a5ac8..d720a0c407 100644 Binary files a/tests/testing_data/matshow3d_test.png and b/tests/testing_data/matshow3d_test.png differ