
SegResNetVAE fails with sliding window inference #2927

@lyndonboone

Description


Describe the bug
The forward method of SegResNetVAE returns a tuple, which causes an error when the model is used in eval mode with sliding_window_inference (which expects the predictor to return a tensor, not a tuple).

I was just wondering: is there a reason why, when not in training mode, the forward method of SegResNetVAE can't return just the output tensor (i.e., return x) instead of a tuple (i.e., return x, None)? I understand that in training mode the second element of the tuple holds the VAE loss term (vae_loss).
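One way to get the behaviour asked about here, without editing MONAI itself, is a small mixin that unwraps the tuple only in eval mode. This is a sketch under stated assumptions: EvalTensorOutputMixin is a hypothetical name, not a MONAI API, and the code assumes the wrapped class exposes a training attribute (as torch.nn.Module does) and returns a tuple such as (x, vae_loss) or (x, None) from forward().

```python
class EvalTensorOutputMixin:
    """Hypothetical mixin (not part of MONAI).

    List it before a model class whose forward() returns a tuple such as
    (output, vae_loss). In eval mode (self.training is False) only the first
    element is returned; in training mode the tuple is returned unchanged,
    so the extra term (e.g. the VAE loss) is still available for the loss.
    """

    def forward(self, *args, **kwargs):
        # Delegate to the next class in the MRO (e.g. SegResNetVAE.forward).
        out = super().forward(*args, **kwargs)
        if not self.training and isinstance(out, tuple):
            # Eval mode: drop the second element (None for SegResNetVAE).
            return out[0]
        return out
```

Hypothetical usage: define class SegResNetVAEEvalTensor(EvalTensorOutputMixin, SegResNetVAE): pass and construct it with the same arguments as SegResNetVAE. Because the mixin comes first in the base-class list, its forward() runs first and reaches SegResNetVAE.forward() through super(), so training code still receives (x, vae_loss).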

To Reproduce
The code below is a minimal example that reproduces the error:

```python
import torch
from monai.inferers import sliding_window_inference
from monai.networks.nets import SegResNetVAE

device = torch.device("cpu")
model = SegResNetVAE(input_image_size=(32,) * 3).to(device)
model.eval()  # in eval mode, forward() returns (output, None) instead of a tensor

inputs = torch.rand((1, 1, 32, 32, 32), device=device)
# roi_size=(32, 32, 32), sw_batch_size=4; this call raises the error
outputs = sliding_window_inference(inputs, (32,) * 3, 4, model)
```

Expected behavior
Before checking the source code, I expected the forward method of SegResNetVAE to return a tensor in eval mode, but it returns a tuple instead. This appears to be what causes sliding_window_inference to raise the error.
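Until the return type changes, a possible workaround is to leave the model as it is and wrap the predictor passed to sliding_window_inference so that it returns only the first element of the tuple. A minimal sketch, assuming the first tuple element is the segmentation output (as with (x, None) in eval mode); first_output is a hypothetical helper, not a MONAI function.

```python
from typing import Any, Callable


def first_output(predictor: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical helper: wrap `predictor` so a tuple result is reduced to
    its first element. Non-tuple results are passed through unchanged."""

    def wrapped(*args: Any, **kwargs: Any) -> Any:
        out = predictor(*args, **kwargs)
        if isinstance(out, tuple):
            # e.g. SegResNetVAE in eval mode returns (x, None); keep x only
            return out[0]
        return out

    return wrapped
```

With the reproduction above, the call would become sliding_window_inference(inputs, (32,) * 3, 4, first_output(model)). An inline lambda x: model(x)[0] does the same job for this one model; the helper additionally passes tensor outputs from other predictors through unchanged.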

Screenshots
Screenshot of the error thrown when executing the code block above: (image not available)

Environment

Ensuring you use the relevant python executable, please paste the output of:

python -c 'import monai; monai.config.print_debug_info()'

Output:

================================
Printing MONAI config...
================================
MONAI version: 0.7.dev2133
Numpy version: 1.19.2
Pytorch version: 1.8.1+cu102
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: 20ffa3f987fad60a8428ec635fb0b4f6ccca9747

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 3.2.1
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: 7.2.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.9.1+cu102
tqdm version: 4.60.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: NOT INSTALLED or UNKNOWN VERSION.
pandas version: 1.2.4
einops version: 0.3.2

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies


================================
Printing system config...
================================
`psutil` required for `print_system_info`

================================
Printing GPU config...
================================
Num GPUs: 0
Has CUDA: False
cuDNN enabled: True
cuDNN version: 7605

Additional context
N/A

Metadata

Assignees: No one assigned
Labels: question (Further information is requested)