From a958190276498836ed504a7fd5e3528e8dc8f0a2 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 10 Apr 2023 16:14:00 +0800 Subject: [PATCH 1/7] fix #6320 Signed-off-by: KumoLiu --- monai/inferers/utils.py | 8 +++++++- tests/test_sliding_window_hovernet_inference.py | 3 ++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 59fb479904..bfd48b5eec 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -308,7 +308,13 @@ def sliding_window_inference( final_output = _pack_struct(output_image_list, dict_keys) final_output = convert_to_dst_type(final_output, inputs, device=device)[0] # type: ignore if temp_meta is not None: - final_output = MetaTensor(final_output).copy_meta_from(temp_meta) + if dict_keys is not None: + for _dict_key in dict_keys: + final_output[_dict_key] = MetaTensor(final_output[_dict_key]).copy_meta_from(temp_meta) + elif isinstance(final_output, tuple): + final_output = ensure_tuple(MetaTensor(i).copy_meta_from(temp_meta) for i in final_output) + else: + final_output = MetaTensor(final_output).copy_meta_from(temp_meta) return final_output # type: ignore diff --git a/tests/test_sliding_window_hovernet_inference.py b/tests/test_sliding_window_hovernet_inference.py index 8f7f8346cc..37db3a8e8d 100644 --- a/tests/test_sliding_window_hovernet_inference.py +++ b/tests/test_sliding_window_hovernet_inference.py @@ -17,6 +17,7 @@ import torch from parameterized import parameterized +from monai.data import MetaTensor from monai.apps.pathology.inferers import SlidingWindowHoVerNetInferer from monai.inferers import sliding_window_inference from monai.utils import optional_import @@ -247,7 +248,7 @@ def compute(data, test1, test2): def test_multioutput(self): device = "cuda" if torch.cuda.is_available() else "cpu:0" - inputs = torch.ones((1, 6, 20, 20)).to(device=device) + inputs = MetaTensor(torch.ones((1, 6, 20, 20)).to(device=device)) roi_shape = (8, 8) sw_batch_size = 10 From c8735079a8ab6f64881a9b0205d4ac7367662864 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 10 Apr 2023 17:05:32 +0800 Subject: [PATCH 2/7] fix flake8 Signed-off-by: KumoLiu --- tests/test_sliding_window_hovernet_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_sliding_window_hovernet_inference.py b/tests/test_sliding_window_hovernet_inference.py index 37db3a8e8d..8503c68517 100644 --- a/tests/test_sliding_window_hovernet_inference.py +++ b/tests/test_sliding_window_hovernet_inference.py @@ -17,8 +17,8 @@ import torch from parameterized import parameterized -from monai.data import MetaTensor from monai.apps.pathology.inferers import SlidingWindowHoVerNetInferer +from monai.data import MetaTensor from monai.inferers import sliding_window_inference from monai.utils import optional_import from tests.test_sliding_window_inference import TEST_CASES From d24652e3b1077a431f4d64b510dae70f8e2f2345 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 11 Apr 2023 14:40:25 +0800 Subject: [PATCH 3/7] address comments Signed-off-by: KumoLiu --- tests/test_sliding_window_hovernet_inference.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_sliding_window_hovernet_inference.py b/tests/test_sliding_window_hovernet_inference.py index 8503c68517..83b0000fe1 100644 --- a/tests/test_sliding_window_hovernet_inference.py +++ b/tests/test_sliding_window_hovernet_inference.py @@ -32,6 +32,11 @@ ["hover", (1, 3, 16, 8), (4, 4), 7, 0.5, "constant", torch.device("cpu:0"), (1,) * 4], ] +TEST_CASES_MULTIOUTPUT = [ + [torch.ones((1, 6, 20, 20))], + [MetaTensor(torch.ones((1, 6, 20, 20)))] +] + class TestSlidingWindowHoVerNetInference(unittest.TestCase): @parameterized.expand(TEST_CASES_PADDING) @@ -246,9 +251,10 @@ def compute(data, test1, test2): )(inputs, compute, t1, test2=t2) np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4) - def test_multioutput(self): + @parameterized.expand(TEST_CASES_MULTIOUTPUT) + def test_multioutput(self, inputs): device = "cuda" if torch.cuda.is_available() else "cpu:0" - inputs = MetaTensor(torch.ones((1, 6, 20, 20)).to(device=device)) + inputs = inputs.to(device=device) roi_shape = (8, 8) sw_batch_size = 10 From 41b3a496d8444d06c9c9864e973dfe667041059a Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 11 Apr 2023 14:45:55 +0800 Subject: [PATCH 4/7] update type hints Signed-off-by: KumoLiu --- monai/inferers/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 9284046c49..e1a0981539 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -41,7 +41,7 @@ def sliding_window_inference( - inputs: torch.Tensor, + inputs: torch.Tensor | MetaTensor, roi_size: Sequence[int] | int, sw_batch_size: int, predictor: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]], From bb4d3b2348f481580a325f095153218dbdebe6a7 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 11 Apr 2023 14:47:14 +0800 Subject: [PATCH 5/7] fix flake8 Signed-off-by: KumoLiu --- tests/test_sliding_window_hovernet_inference.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/test_sliding_window_hovernet_inference.py b/tests/test_sliding_window_hovernet_inference.py index 83b0000fe1..b17e8525ec 100644 --- a/tests/test_sliding_window_hovernet_inference.py +++ b/tests/test_sliding_window_hovernet_inference.py @@ -32,10 +32,7 @@ ["hover", (1, 3, 16, 8), (4, 4), 7, 0.5, "constant", torch.device("cpu:0"), (1,) * 4], ] -TEST_CASES_MULTIOUTPUT = [ - [torch.ones((1, 6, 20, 20))], - [MetaTensor(torch.ones((1, 6, 20, 20)))] -] +TEST_CASES_MULTIOUTPUT = [[torch.ones((1, 6, 20, 20))], [MetaTensor(torch.ones((1, 6, 20, 20)))]] class TestSlidingWindowHoVerNetInference(unittest.TestCase): From 16fcbbce352394921f8680cd61e6b67320eaadf8 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 11 Apr 2023 15:20:13 +0800 Subject: [PATCH 6/7] use `convert_to_dst_type` instead Signed-off-by: KumoLiu --- monai/inferers/utils.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index e1a0981539..ac352b7f35 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -307,15 +307,9 @@ def sliding_window_inference( output_image_list[ss] = output_i[(slice(None), slice(None), *final_slicing)] final_output = _pack_struct(output_image_list, dict_keys) - final_output = convert_to_dst_type(final_output, inputs, device=device)[0] # type: ignore if temp_meta is not None: - if dict_keys is not None: - for _dict_key in dict_keys: - final_output[_dict_key] = MetaTensor(final_output[_dict_key]).copy_meta_from(temp_meta) - elif isinstance(final_output, tuple): - final_output = ensure_tuple(MetaTensor(i).copy_meta_from(temp_meta) for i in final_output) - else: - final_output = MetaTensor(final_output).copy_meta_from(temp_meta) + final_output = convert_to_dst_type(final_output, temp_meta, device=device)[0] # type: ignore + return final_output # type: ignore From 04da4391e661db83f31e956623beaf6ae383e122 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 11 Apr 2023 15:26:52 +0800 Subject: [PATCH 7/7] minor fix Signed-off-by: KumoLiu --- monai/inferers/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index ac352b7f35..31bf2a9db4 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -309,6 +309,8 @@ def sliding_window_inference( final_output = _pack_struct(output_image_list, dict_keys) if temp_meta is not None: final_output = convert_to_dst_type(final_output, temp_meta, device=device)[0] # type: ignore + else: + final_output = convert_to_dst_type(final_output, inputs, device=device)[0] return final_output # type: ignore