diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index 26487083b1..ed109393f9 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -161,6 +161,7 @@ def plot_engine_status( window_fraction: int = 20, image_fn: Optional[Callable] = tensor_to_images, fig=None, + selected_inst: int = 0, ) -> Tuple: """ Plot the status of the given Engine with its logger. The plot will consist of a graph of loss values and metrics @@ -192,19 +193,33 @@ def plot_engine_status( imagemap = {} if image_fn is not None and engine.state is not None and engine.state.batch is not None: for src in (engine.state.batch, engine.state.output): + label = "Batch" if src is engine.state.batch else "Output" + batch_selected_inst = selected_inst # selected batch index, set to 0 when src is decollated + + # if the src object is a list of elements, ie. a decollated batch, select an element and keep it as + # a dictionary of tensors with a batch dimension added if isinstance(src, list): - for i, s in enumerate(src): - if isinstance(s, dict): - for k, v in s.items(): - if isinstance(v, torch.Tensor): - image = image_fn(k, v) - if image is not None: - imagemap[f"{k}_{i}"] = image - elif isinstance(s, torch.Tensor): - label = "Batch" if src is engine.state.batch else "Output" - image = image_fn(label, s) + selected_dict = src[selected_inst] # select this element + batch_selected_inst = 0 # set the selection to be the single index in the batch dimension + # store each tensor that is interpretable as an image with an added batch dimension + src = {k: v[None] for k, v in selected_dict.items() if isinstance(v, torch.Tensor) and v.ndim >= 3} + + # images will be generated from the batch item selected above only, or from the single item given as `src` + + if isinstance(src, dict): + for k, v in src.items(): + if isinstance(v, torch.Tensor) and v.ndim >= 4: + image = image_fn(k, v[batch_selected_inst]) + + # if we have images add each one separately to the map if image is not None: - imagemap[f"{label}_{i}"] = image + for i, im in enumerate(image): + imagemap[f"{k}_{i}"] = im + + elif isinstance(src, torch.Tensor): + image = image_fn(label, src) + if image is not None: + imagemap[f"{label}_{i}"] = image axes = plot_metric_images(fig, title, graphmap, imagemap, yscale, avg_keys, window_fraction) diff --git a/tests/test_threadcontainer.py b/tests/test_threadcontainer.py index 543dab4d0c..5613b1babd 100644 --- a/tests/test_threadcontainer.py +++ b/tests/test_threadcontainer.py @@ -79,7 +79,7 @@ def test_plot(self): # a third non-image key is added to test that this is correctly ignored when plotting data = {CommonKeys.IMAGE: img, CommonKeys.LABEL: img, "Not Image Data": ["This isn't an image"]} - loader = DataLoader([data] * 10) + loader = DataLoader([data] * 20, batch_size=2) trainer = SupervisedTrainer( device=torch.device("cpu"),