diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py
index 45742eff79..b171d20ebb 100644
--- a/monai/inferers/utils.py
+++ b/monai/inferers/utils.py
@@ -41,7 +41,7 @@ def sliding_window_inference(
-    inputs: torch.Tensor,
+    inputs: torch.Tensor | MetaTensor,
     roi_size: Sequence[int] | int,
     sw_batch_size: int,
     predictor: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],
@@ -307,9 +307,11 @@ def sliding_window_inference(
             output_image_list[ss] = output_i[(slice(None), slice(None), *final_slicing)]
 
     final_output = _pack_struct(output_image_list, dict_keys)
-    final_output = convert_to_dst_type(final_output, inputs, device=device)[0]
     if temp_meta is not None:
-        final_output = MetaTensor(final_output).copy_meta_from(temp_meta)
+        final_output = convert_to_dst_type(final_output, temp_meta, device=device)[0]  # type: ignore
+    else:
+        final_output = convert_to_dst_type(final_output, inputs, device=device)[0]
+
     return final_output  # type: ignore
 
 
diff --git a/tests/test_sliding_window_hovernet_inference.py b/tests/test_sliding_window_hovernet_inference.py
index 8f7f8346cc..b17e8525ec 100644
--- a/tests/test_sliding_window_hovernet_inference.py
+++ b/tests/test_sliding_window_hovernet_inference.py
@@ -18,6 +18,7 @@
 from parameterized import parameterized
 
 from monai.apps.pathology.inferers import SlidingWindowHoVerNetInferer
+from monai.data import MetaTensor
 from monai.inferers import sliding_window_inference
 from monai.utils import optional_import
 from tests.test_sliding_window_inference import TEST_CASES
@@ -31,6 +32,8 @@
     ["hover", (1, 3, 16, 8), (4, 4), 7, 0.5, "constant", torch.device("cpu:0"), (1,) * 4],
 ]
 
+TEST_CASES_MULTIOUTPUT = [[torch.ones((1, 6, 20, 20))], [MetaTensor(torch.ones((1, 6, 20, 20)))]]
+
 
 class TestSlidingWindowHoVerNetInference(unittest.TestCase):
     @parameterized.expand(TEST_CASES_PADDING)
@@ -245,9 +248,10 @@ def compute(data, test1, test2):
         )(inputs, compute, t1, test2=t2)
         np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)
 
-    def test_multioutput(self):
+    @parameterized.expand(TEST_CASES_MULTIOUTPUT)
+    def test_multioutput(self, inputs):
         device = "cuda" if torch.cuda.is_available() else "cpu:0"
-        inputs = torch.ones((1, 6, 20, 20)).to(device=device)
+        inputs = inputs.to(device=device)
         roi_shape = (8, 8)
         sw_batch_size = 10
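
For context, a minimal usage sketch (not part of the patch) of the behavior the utils.py change enables: when sliding_window_inference receives a MetaTensor, the result is now converted against the input's copied metadata (temp_meta) instead of being collapsed through the plain-tensor path first, so the output comes back as a MetaTensor carrying the input's metadata. The identity predictor and the patient_id key below are illustrative assumptions, not from this diff.

    import torch

    from monai.data import MetaTensor
    from monai.inferers import sliding_window_inference


    def predictor(patch: torch.Tensor) -> torch.Tensor:
        # Stand-in "network": returns each sliding window unchanged.
        return patch


    image = MetaTensor(torch.ones((1, 1, 20, 20)), meta={"patient_id": "demo"})
    output = sliding_window_inference(image, roi_size=(8, 8), sw_batch_size=4, predictor=predictor)

    assert isinstance(output, MetaTensor)  # output type matches the MetaTensor input
    assert output.meta.get("patient_id") == "demo"  # input metadata is carried through

A plain torch.Tensor input still takes the convert_to_dst_type(final_output, inputs, ...) branch, which is why TEST_CASES_MULTIOUTPUT parameterizes test_multioutput over both input types.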