From c35ad20480bfd0370cf549d30a19700d7ca1477f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 21 Feb 2023 12:43:38 +0000 Subject: [PATCH 1/3] differentiable sliding window utility Signed-off-by: Wenqi Li --- monai/inferers/utils.py | 2 +- tests/test_sliding_window_inference.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 13c6d9d9ca..882f0f9101 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -262,7 +262,7 @@ def sliding_window_inference( # account for any overlapping sections for ss in range(len(output_image_list)): - output_image_list[ss] = output_image_list[ss].detach() + output_image_list[ss] = output_image_list[ss] _map = count_map_list.pop(0) for _i in range(output_image_list[ss].shape[1]): output_image_list[ss][:, _i : _i + 1, ...] /= _map diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py index 2696d4456c..5f07084927 100644 --- a/tests/test_sliding_window_inference.py +++ b/tests/test_sliding_window_inference.py @@ -83,10 +83,12 @@ def test_default_device(self, data_type): def compute(data): return data + 1 + inputs.requires_grad = True result = sliding_window_inference(inputs, roi_shape, sw_batch_size, compute) + self.assertTrue(result.requires_grad) np.testing.assert_string_equal(inputs.device.type, result.device.type) expected_val = np.ones((1, 3, 16, 15, 7), dtype=np.float32) + 1 - np.testing.assert_allclose(result.cpu().numpy(), expected_val) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val) @parameterized.expand(list(itertools.product(TEST_TORCH_AND_META_TENSORS, ("cpu", "cuda"), ("cpu", "cuda", None)))) @skip_if_no_cuda From 59aef40764fe05a5a47ac98542dc2de4255b6a18 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 21 Feb 2023 16:28:13 +0000 Subject: [PATCH 2/3] skip quick tests Signed-off-by: Wenqi Li --- tests/test_vit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_vit.py b/tests/test_vit.py index 62f8ad4f23..504c1ccebd 100644 --- a/tests/test_vit.py +++ b/tests/test_vit.py @@ -18,7 +18,7 @@ from monai.networks import eval_mode from monai.networks.nets.vit import ViT -from tests.utils import SkipIfBeforePyTorchVersion, test_script_save +from tests.utils import SkipIfBeforePyTorchVersion, skip_if_quick, test_script_save TEST_CASE_Vit = [] for dropout_rate in [0.6]: @@ -59,6 +59,7 @@ TEST_CASE_Vit.append(test_case) +@skip_if_quick class TestViT(unittest.TestCase): @parameterized.expand(TEST_CASE_Vit) def test_shape(self, input_param, input_shape, expected_shape): From 2bbf3a21a5664bfaf3129ff58d44cf0709eed28b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 21 Feb 2023 17:32:40 +0000 Subject: [PATCH 3/3] skip warp if quick tests Signed-off-by: Wenqi Li --- tests/test_warp.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/test_warp.py b/tests/test_warp.py index 0f1ef3101f..e614973f90 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -22,7 +22,13 @@ from monai.networks.blocks.warp import Warp from monai.transforms import LoadImaged from monai.utils import GridSampleMode, GridSamplePadMode -from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule, download_url_or_skip_test, testing_data_config +from tests.utils import ( + SkipIfBeforePyTorchVersion, + SkipIfNoModule, + download_url_or_skip_test, + skip_if_quick, + testing_data_config, +) LOW_POWER_TEST_CASES = [ # run with BUILD_MONAI=1 to test csrc/resample, BUILD_MONAI=0 to test native grid_sample [ @@ -98,6 +104,7 @@ TEST_CASES += CPP_TEST_CASES +@skip_if_quick class TestWarp(unittest.TestCase): def setUp(self): config = testing_data_config("images", "Prostate_T2W_AX_1")