Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions monai/utils/jupyter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def plot_engine_status(
window_fraction: int = 20,
image_fn: Optional[Callable] = tensor_to_images,
fig=None,
selected_inst: int = 0,
) -> Tuple:
"""
Plot the status of the given Engine with its logger. The plot will consist of a graph of loss values and metrics
Expand Down Expand Up @@ -192,19 +193,33 @@ def plot_engine_status(
imagemap = {}
if image_fn is not None and engine.state is not None and engine.state.batch is not None:
for src in (engine.state.batch, engine.state.output):
label = "Batch" if src is engine.state.batch else "Output"
batch_selected_inst = selected_inst # selected batch index, set to 0 when src is decollated

# if the src object is a list of elements, ie. a decollated batch, select an element and keep it as
# a dictionary of tensors with a batch dimension added
if isinstance(src, list):
for i, s in enumerate(src):
if isinstance(s, dict):
for k, v in s.items():
if isinstance(v, torch.Tensor):
image = image_fn(k, v)
if image is not None:
imagemap[f"{k}_{i}"] = image
elif isinstance(s, torch.Tensor):
label = "Batch" if src is engine.state.batch else "Output"
image = image_fn(label, s)
selected_dict = src[selected_inst] # select this element
batch_selected_inst = 0 # set the selection to be the single index in the batch dimension
# store each tensor that is interpretable as an image with an added batch dimension
src = {k: v[None] for k, v in selected_dict.items() if isinstance(v, torch.Tensor) and v.ndim >= 3}

# images will be generated from the batch item selected above only, or from the single item given as `src`

if isinstance(src, dict):
for k, v in src.items():
if isinstance(v, torch.Tensor) and v.ndim >= 4:
image = image_fn(k, v[batch_selected_inst])

# if we have images add each one separately to the map
if image is not None:
imagemap[f"{label}_{i}"] = image
for i, im in enumerate(image):
imagemap[f"{k}_{i}"] = im

elif isinstance(src, torch.Tensor):
image = image_fn(label, src)
if image is not None:
imagemap[f"{label}_{i}"] = image

axes = plot_metric_images(fig, title, graphmap, imagemap, yscale, avg_keys, window_fraction)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_threadcontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_plot(self):
# a third non-image key is added to test that this is correctly ignored when plotting
data = {CommonKeys.IMAGE: img, CommonKeys.LABEL: img, "Not Image Data": ["This isn't an image"]}

loader = DataLoader([data] * 10)
loader = DataLoader([data] * 20, batch_size=2)

trainer = SupervisedTrainer(
device=torch.device("cpu"),
Expand Down