Description
Describe the bug
The forward method of SegResNetVAE returns a tuple, even in eval mode. This causes an error when the model is passed as the predictor to sliding_window_inference, which expects the predictor to output a tensor, not a tuple.
Is there a reason why, when not in train mode, the forward method of SegResNetVAE can't simply return the output tensor (i.e., return x) instead of a tuple (i.e., return x, None)? I understand that in train mode the second element of the tuple holds the vae_loss term.
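The proposed eval-mode behaviour can be sketched with a toy class. This is a hypothetical stand-in, not MONAI's SegResNetVAE: ToySegVAE, its placeholder "segmentation" and placeholder loss are illustrative assumptions, and it only mimics torch.nn.Module's train()/eval() convention through a training flag so that it runs without PyTorch.

```python
from typing import List, Tuple, Union


class ToySegVAE:
    """Hypothetical stand-in for a segmentation network with a VAE branch.

    Mimics torch.nn.Module's train()/eval() convention via a `training` flag.
    This is NOT MONAI's SegResNetVAE implementation.
    """

    def __init__(self) -> None:
        self.training = True

    def train(self, mode: bool = True) -> "ToySegVAE":
        self.training = mode
        return self

    def eval(self) -> "ToySegVAE":
        return self.train(False)

    def forward(self, x: List[float]) -> Union[Tuple[List[float], float], List[float]]:
        seg = [2.0 * v for v in x]  # placeholder "segmentation" output
        if self.training:
            # placeholder regulariser standing in for the VAE loss term
            vae_loss = sum(v * v for v in x) / len(x)
            return seg, vae_loss
        # Proposed eval-mode behaviour: return the bare output, not (seg, None).
        return seg

    # make instances callable like nn.Module: model(x) -> model.forward(x)
    __call__ = forward
```

One trade-off of this design is that the return type depends on the mode, which callers and type checkers must account for; the alternative is to keep the tuple and unwrap it at the call site.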
To Reproduce
The following minimal example reproduces the error:
```python
import torch
from monai.inferers import sliding_window_inference
from monai.networks.nets import SegResNetVAE

device = torch.device("cpu")
model = SegResNetVAE(input_image_size=(32,) * 3).to(device)
model.eval()  # even in eval mode, forward() returns a tuple: (output, None)

inputs = torch.rand((1, 1, 32, 32, 32), device=device)
# roi_size=(32, 32, 32), sw_batch_size=4; fails because the predictor returns a tuple
outputs = sliding_window_inference(inputs, (32,) * 3, 4, model)
```
Expected behavior
Before checking the source code, I expected the forward method of SegResNetVAE to return a tensor in eval mode, but it returns a tuple instead. This appears to be what causes sliding_window_inference to throw the error.
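Until the return type changes, one possible workaround is to wrap the model so that only the first element of a tuple output reaches sliding_window_inference. This is a minimal sketch: first_output is a hypothetical helper, not part of MONAI, and it assumes sliding_window_inference accepts any callable as its predictor argument.

```python
from typing import Any, Callable


def first_output(predictor: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap `predictor` so that a tuple/list output is reduced to its first element.

    Hypothetical helper (not part of MONAI). Outputs that are not a tuple or
    list (e.g. a plain tensor) are passed through unchanged.
    """

    def wrapped(*args: Any, **kwargs: Any) -> Any:
        out = predictor(*args, **kwargs)
        if isinstance(out, (tuple, list)):
            return out[0]  # keep the segmentation output, drop e.g. the VAE loss / None
        return out

    return wrapped
```

With the reproduction above, the call would become sliding_window_inference(inputs, (32,) * 3, 4, first_output(model)). An inline lambda x: model(x)[0] does the same for this model; the wrapper additionally tolerates networks that already return a bare tensor. Note that it would silently discard extra outputs for networks that return several meaningful tensors (e.g. deep supervision).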
Screenshots
Screenshot of the error thrown when executing the code block above:
Environment
Ensuring you use the relevant python executable, please paste the output of:
python -c 'import monai; monai.config.print_debug_info()'
Output:
================================
Printing MONAI config...
================================
MONAI version: 0.7.dev2133
Numpy version: 1.19.2
Pytorch version: 1.8.1+cu102
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: 20ffa3f987fad60a8428ec635fb0b4f6ccca9747
Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 3.2.1
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: 7.2.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.9.1+cu102
tqdm version: 4.60.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: NOT INSTALLED or UNKNOWN VERSION.
pandas version: 1.2.4
einops version: 0.3.2
For details about installing the optional dependencies, please visit:
https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
================================
Printing system config...
================================
`psutil` required for `print_system_info`
================================
Printing GPU config...
================================
Num GPUs: 0
Has CUDA: False
cuDNN enabled: True
cuDNN version: 7605
Additional context
N/A
