diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 7eca03a280..fc6c0a38b5 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -47,15 +47,17 @@ class BilateralFilter(torch.autograd.Function): @staticmethod def forward(ctx, input, spatial_sigma=5, color_sigma=0.5, fast_approx=True): - ctx.save_for_backward(spatial_sigma, color_sigma, fast_approx) + ctx.ss = spatial_sigma + ctx.cs = color_sigma + ctx.fa = fast_approx output_data = _C.bilateral_filter(input, spatial_sigma, color_sigma, fast_approx) return output_data @staticmethod def backward(ctx, grad_output): - spatial_sigma, color_sigma, fast_approx = ctx.saved_variables + spatial_sigma, color_sigma, fast_approx = ctx.ss, ctx.cs, ctx.fa grad_input = _C.bilateral_filter(grad_output, spatial_sigma, color_sigma, fast_approx) - return grad_input + return grad_input, None, None, None class PHLFilter(torch.autograd.Function): diff --git a/tests/test_bilateral_approx_cpu.py b/tests/test_bilateral_approx_cpu.py index 2b6088a56f..96d60fb22c 100644 --- a/tests/test_bilateral_approx_cpu.py +++ b/tests/test_bilateral_approx_cpu.py @@ -14,6 +14,7 @@ import numpy as np import torch from parameterized import parameterized +from torch.autograd import gradcheck from monai.networks.layers.filtering import BilateralFilter from tests.utils import skip_if_no_cpp_extension @@ -376,6 +377,23 @@ def test_cpu_approx(self, test_case_description, sigmas, input, expected): # Ensure result are as expected np.testing.assert_allclose(output, expected, atol=1e-5) + @parameterized.expand(TEST_CASES) + def test_cpu_approx_backwards(self, test_case_description, sigmas, input, expected): + + # Params to determine the implementation to test + device = torch.device("cpu") + fast_approx = True + + # Prepare input tensor + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device) + input_tensor.requires_grad = True + + # Prepare args + args = (input_tensor, *sigmas, fast_approx) + + # Run grad check + gradcheck(BilateralFilter.apply, args, raise_exception=False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_bilateral_approx_cuda.py b/tests/test_bilateral_approx_cuda.py index fdaba26f72..c95fbed1f5 100644 --- a/tests/test_bilateral_approx_cuda.py +++ b/tests/test_bilateral_approx_cuda.py @@ -14,6 +14,7 @@ import numpy as np import torch from parameterized import parameterized +from torch.autograd import gradcheck from monai.networks.layers.filtering import BilateralFilter from tests.utils import skip_if_no_cpp_extension, skip_if_no_cuda @@ -381,6 +382,23 @@ def test_cuda_approx(self, test_case_description, sigmas, input, expected): # Ensure result are as expected np.testing.assert_allclose(output, expected, atol=1e-2) + @parameterized.expand(TEST_CASES) + def test_cpu_approx_backwards(self, test_case_description, sigmas, input, expected): + + # Params to determine the implementation to test + device = torch.device("cuda") + fast_approx = True + + # Prepare input tensor + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device) + input_tensor.requires_grad = True + + # Prepare args + args = (input_tensor, *sigmas, fast_approx) + + # Run grad check + gradcheck(BilateralFilter.apply, args, raise_exception=False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_bilateral_precise.py b/tests/test_bilateral_precise.py index db2ee88239..df4d0c2500 100644 --- a/tests/test_bilateral_precise.py +++ b/tests/test_bilateral_precise.py @@ -14,6 +14,7 @@ import numpy as np import torch from parameterized import parameterized +from torch.autograd import gradcheck from monai.networks.layers.filtering import BilateralFilter from tests.utils import skip_if_no_cpp_extension, skip_if_no_cuda @@ -361,9 +362,9 @@ @skip_if_no_cpp_extension -class BilateralFilterTestCaseCpuPrecised(unittest.TestCase): +class BilateralFilterTestCaseCpuPrecise(unittest.TestCase): @parameterized.expand(TEST_CASES) - def test_cpu_precised(self, test_case_description, sigmas, input, expected): + def test_cpu_precise(self, test_case_description, sigmas, input, expected): # Params to determine the implementation to test device = torch.device("cpu") @@ -376,12 +377,29 @@ def test_cpu_precised(self, test_case_description, sigmas, input, expected): # Ensure result are as expected np.testing.assert_allclose(output, expected, atol=1e-5) + @parameterized.expand(TEST_CASES) + def test_cpu_precise_backwards(self, test_case_description, sigmas, input, expected): + + # Params to determine the implementation to test + device = torch.device("cpu") + fast_approx = False + + # Prepare input tensor + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device) + input_tensor.requires_grad = True + + # Prepare args + args = (input_tensor, *sigmas, fast_approx) + + # Run grad check + gradcheck(BilateralFilter.apply, args, raise_exception=False) + @skip_if_no_cuda @skip_if_no_cpp_extension -class BilateralFilterTestCaseCudaPrecised(unittest.TestCase): +class BilateralFilterTestCaseCudaPrecise(unittest.TestCase): @parameterized.expand(TEST_CASES) - def test_cuda_precised(self, test_case_description, sigmas, input, expected): + def test_cuda_precise(self, test_case_description, sigmas, input, expected): # Skip this test if not torch.cuda.is_available(): @@ -398,6 +416,23 @@ def test_cuda_precised(self, test_case_description, sigmas, input, expected): # Ensure result are as expected np.testing.assert_allclose(output, expected, atol=1e-5) + @parameterized.expand(TEST_CASES) + def test_cuda_precise_backwards(self, test_case_description, sigmas, input, expected): + + # Params to determine the implementation to test + device = torch.device("cuda") + fast_approx = False + + # Prepare input tensor + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device) + input_tensor.requires_grad = True + + # Prepare args + args = (input_tensor, *sigmas, fast_approx) + + # Run grad check + gradcheck(BilateralFilter.apply, args, raise_exception=False) + if __name__ == "__main__": unittest.main()