Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions monai/networks/layers/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,17 @@ class BilateralFilter(torch.autograd.Function):

@staticmethod
def forward(ctx, input, spatial_sigma=5, color_sigma=0.5, fast_approx=True):
ctx.save_for_backward(spatial_sigma, color_sigma, fast_approx)
ctx.ss = spatial_sigma
ctx.cs = color_sigma
ctx.fa = fast_approx
output_data = _C.bilateral_filter(input, spatial_sigma, color_sigma, fast_approx)
return output_data

@staticmethod
def backward(ctx, grad_output):
spatial_sigma, color_sigma, fast_approx = ctx.saved_variables
spatial_sigma, color_sigma, fast_approx = ctx.ss, ctx.cs, ctx.fa
grad_input = _C.bilateral_filter(grad_output, spatial_sigma, color_sigma, fast_approx)
return grad_input
return grad_input, None, None, None


class PHLFilter(torch.autograd.Function):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_bilateral_approx_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import torch
from parameterized import parameterized
from torch.autograd import gradcheck

from monai.networks.layers.filtering import BilateralFilter
from tests.utils import skip_if_no_cpp_extension
Expand Down Expand Up @@ -376,6 +377,23 @@ def test_cpu_approx(self, test_case_description, sigmas, input, expected):
# Ensure result are as expected
np.testing.assert_allclose(output, expected, atol=1e-5)

@parameterized.expand(TEST_CASES)
def test_cpu_approx_backwards(self, test_case_description, sigmas, input, expected):

# Params to determine the implementation to test
device = torch.device("cpu")
fast_approx = True

# Prepare input tensor
input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device)
input_tensor.requires_grad = True

# Prepare args
args = (input_tensor, *sigmas, fast_approx)

# Run grad check
gradcheck(BilateralFilter.apply, args, raise_exception=False)


if __name__ == "__main__":
unittest.main()
18 changes: 18 additions & 0 deletions tests/test_bilateral_approx_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import torch
from parameterized import parameterized
from torch.autograd import gradcheck

from monai.networks.layers.filtering import BilateralFilter
from tests.utils import skip_if_no_cpp_extension, skip_if_no_cuda
Expand Down Expand Up @@ -381,6 +382,23 @@ def test_cuda_approx(self, test_case_description, sigmas, input, expected):
# Ensure result are as expected
np.testing.assert_allclose(output, expected, atol=1e-2)

@parameterized.expand(TEST_CASES)
def test_cpu_approx_backwards(self, test_case_description, sigmas, input, expected):

# Params to determine the implementation to test
device = torch.device("cuda")
fast_approx = True

# Prepare input tensor
input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device)
input_tensor.requires_grad = True

# Prepare args
args = (input_tensor, *sigmas, fast_approx)

# Run grad check
gradcheck(BilateralFilter.apply, args, raise_exception=False)


if __name__ == "__main__":
unittest.main()
43 changes: 39 additions & 4 deletions tests/test_bilateral_precise.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import torch
from parameterized import parameterized
from torch.autograd import gradcheck

from monai.networks.layers.filtering import BilateralFilter
from tests.utils import skip_if_no_cpp_extension, skip_if_no_cuda
Expand Down Expand Up @@ -361,9 +362,9 @@


@skip_if_no_cpp_extension
class BilateralFilterTestCaseCpuPrecised(unittest.TestCase):
class BilateralFilterTestCaseCpuPrecise(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_cpu_precised(self, test_case_description, sigmas, input, expected):
def test_cpu_precise(self, test_case_description, sigmas, input, expected):

# Params to determine the implementation to test
device = torch.device("cpu")
Expand All @@ -376,12 +377,29 @@ def test_cpu_precised(self, test_case_description, sigmas, input, expected):
# Ensure result are as expected
np.testing.assert_allclose(output, expected, atol=1e-5)

@parameterized.expand(TEST_CASES)
def test_cpu_precise_backwards(self, test_case_description, sigmas, input, expected):

# Params to determine the implementation to test
device = torch.device("cpu")
fast_approx = False

# Prepare input tensor
input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device)
input_tensor.requires_grad = True

# Prepare args
args = (input_tensor, *sigmas, fast_approx)

# Run grad check
gradcheck(BilateralFilter.apply, args, raise_exception=False)


@skip_if_no_cuda
@skip_if_no_cpp_extension
class BilateralFilterTestCaseCudaPrecised(unittest.TestCase):
class BilateralFilterTestCaseCudaPrecise(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_cuda_precised(self, test_case_description, sigmas, input, expected):
def test_cuda_precise(self, test_case_description, sigmas, input, expected):

# Skip this test
if not torch.cuda.is_available():
Expand All @@ -398,6 +416,23 @@ def test_cuda_precised(self, test_case_description, sigmas, input, expected):
# Ensure result are as expected
np.testing.assert_allclose(output, expected, atol=1e-5)

@parameterized.expand(TEST_CASES)
def test_cuda_precise_backwards(self, test_case_description, sigmas, input, expected):

# Params to determine the implementation to test
device = torch.device("cuda")
fast_approx = False

# Prepare input tensor
input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device)
input_tensor.requires_grad = True

# Prepare args
args = (input_tensor, *sigmas, fast_approx)

# Run grad check
gradcheck(BilateralFilter.apply, args, raise_exception=False)


if __name__ == "__main__":
unittest.main()