
sliding_window_inference cannot handle output in formats other than tensor #6320

@KumoLiu

Description


Describe the bug
sliding_window_inference cannot handle predictor outputs in formats other than a plain tensor (e.g. the dict returned by HoVerNet); this is a regression introduced by #6254. The failure happens at the final metadata re-attach step:

final_output = MetaTensor(final_output).copy_meta_from(temp_meta)



RuntimeError                              Traceback (most recent call last)
Cell In [1], line 15
      6 model = HoVerNet(
      7     in_channels=3,
      8     out_classes=5,
      9     mode = "fast",
     10     adapt_standard_resnet=True
     11 )
     13 data = MetaTensor(torch.rand([2, 3, 512, 512]))
---> 15 sliding_window_inference(inputs=data, roi_size=(256, 256), sw_batch_size=2, predictor=model)

File /workspace/Code/MONAI/monai/inferers/utils.py:311, in sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, overlap, mode, sigma_scale, padding_mode, cval, sw_device, device, progress, roi_weight_map, process_fn, buffer_steps, buffer_dim, *args, **kwargs)
    309 final_output = convert_to_dst_type(final_output, inputs, device=device)[0]  # type: ignore
    310 if temp_meta is not None:
--> 311     final_output = MetaTensor(final_output).copy_meta_from(temp_meta)
    312 return final_output

File /workspace/Code/MONAI/monai/data/meta_tensor.py:116, in MetaTensor.__new__(cls, x, affine, meta, applied_operations, *args, **kwargs)
    105 @staticmethod
    106 def __new__(
    107     cls,
   (...)
    113     **kwargs,
    114 ) -> MetaTensor:
    115     _kwargs = {"device": kwargs.pop("device", None), "dtype": kwargs.pop("dtype", None)} if kwargs else {}
--> 116     return torch.as_tensor(x, *args, **_kwargs).as_subclass(cls)

RuntimeError: Could not infer dtype of dict
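
The failure comes from the final metadata re-attach: HoVerNet returns a dict of tensors, and MetaTensor.__new__ passes that dict straight to torch.as_tensor, which cannot infer a dtype for it. A minimal sketch of the underlying failure (the key name and shape below are only illustrative):

import torch
from monai.data import MetaTensor

# A HoVerNet-style output is a dict of tensors; torch.as_tensor(dict) cannot
# infer a dtype, and that is exactly what MetaTensor.__new__ ends up calling.
dict_output = {"nucleus_prediction": torch.rand(2, 2, 164, 164)}
MetaTensor(dict_output)  # RuntimeError: Could not infer dtype of dict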

To Reproduce

import torch
from monai.data import MetaTensor
from monai.networks.nets import HoVerNet
from monai.inferers.utils import sliding_window_inference

model = HoVerNet(
    in_channels=3,
    out_classes=5,
    mode="fast",
    adapt_standard_resnet=True
)

data = MetaTensor(torch.rand([2, 3, 512, 512]))

sliding_window_inference(inputs=data, roi_size=(256, 256), sw_batch_size=2, predictor=model)
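
As a temporary workaround (an untested sketch, not a fix), passing a plain torch.Tensor instead of a MetaTensor avoids the failing branch, since the metadata copy at monai/inferers/utils.py:311 only runs when temp_meta is not None:

# Untested workaround sketch: with a plain tensor input, temp_meta stays None,
# so the MetaTensor(final_output) wrap at utils.py:311 is skipped entirely.
plain_data = torch.rand([2, 3, 512, 512])
outputs = sliding_window_inference(
    inputs=plain_data, roi_size=(256, 256), sw_batch_size=2, predictor=model
)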

Expected behavior
sliding_window_inference should not assume the predictor output is a single tensor; it should also handle dict (and tuple/list) outputs such as the one returned by HoVerNet. One possible direction is sketched below.
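
A possible direction for a fix (just a sketch under my own assumptions, not the actual patch; the helper name is hypothetical): only re-attach metadata to tensor outputs, and recurse into dict/list/tuple outputs so they never reach torch.as_tensor via MetaTensor.__new__.

import torch
from monai.data import MetaTensor


def _copy_meta_to_outputs(output, meta_source):
    """Hypothetical helper: attach metadata to tensor leaves only, recursing into
    dict/list/tuple outputs (e.g. HoVerNet's dict) instead of wrapping them whole."""
    if isinstance(output, torch.Tensor):
        return MetaTensor(output).copy_meta_from(meta_source)
    if isinstance(output, dict):
        return {k: _copy_meta_to_outputs(v, meta_source) for k, v in output.items()}
    if isinstance(output, (list, tuple)):
        return type(output)(_copy_meta_to_outputs(v, meta_source) for v in output)
    return output

The final lines of sliding_window_inference could then call this helper in place of the unconditional MetaTensor(final_output).copy_meta_from(temp_meta).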

Environment

Ensuring you use the relevant python executable, please paste the output of:

python -c "import monai; monai.config.print_debug_info()"

================================
Printing MONAI config...
================================
MONAI version: 1.2.0rc2+34.gf98f0fda
Numpy version: 1.22.2
Pytorch version: 1.13.0a0+d0d6b1f
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: f98f0fda3fc57293917ba9a5cd7838bf05477191
MONAI __file__: /workspace/Code/MONAI/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.10
ITK version: 5.3.0
Nibabel version: 4.0.2
scikit-image version: 0.19.3
Pillow version: 9.0.1
Tensorboard version: 2.10.1
gdown version: 4.6.0
TorchVision version: 0.14.0a0
tqdm version: 4.64.1
lmdb version: 1.3.0
psutil version: 5.9.2
pandas version: 1.4.4
einops version: 0.6.0
transformers version: 4.21.3
mlflow version: 2.0.1
pynrrd version: 1.0.0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies


================================
Printing system config...
================================
System: Linux
Linux version: Ubuntu 20.04.5 LTS
Platform: Linux-5.15.0-67-generic-x86_64-with-glibc2.10
Processor: x86_64
Machine: x86_64
Python version: 3.8.13
Process name: python
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: []
Num physical CPUs: 12
Num logical CPUs: 20
Num usable CPUs: 20
CPU usage (%): [5.2, 4.6, 5.7, 4.6, 5.2, 4.0, 5.2, 4.6, 6.9, 4.6, 98.8, 4.6, 6.4, 4.6, 6.8, 5.1, 4.6, 4.6, 4.0, 4.6]
CPU freq. (MHz): 3621
Load avg. in last 1, 5, 15 mins (%): [1.4, 0.9, 0.9]
Disk usage (%): 81.9
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 62.6
Available memory (GB): 13.7
Used memory (GB): 48.2

================================
Printing GPU config...
================================
Num GPUs: 2
Has CUDA: True
CUDA version: 11.8
cuDNN enabled: True
cuDNN version: 8600
Current device: 0
Library compiled for CUDA architectures: ['sm_52', 'sm_60', 'sm_61', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90', 'compute_90']
GPU 0 Name: NVIDIA GeForce RTX 3090 Ti
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 84
GPU 0 Total memory (GB): 23.7
GPU 0 CUDA capability (maj.min): 8.6
GPU 1 Name: NVIDIA GeForce RTX 3090 Ti
GPU 1 Is integrated: False
GPU 1 Is multi GPU board: False
GPU 1 Multi processor count: 84
GPU 1 Total memory (GB): 23.7
GPU 1 CUDA capability (maj.min): 8.6

Additional context
@wyli may need to confirm. This may be a blocking issue for the pathology_nuclei_segmentation_classification bundle.
