5 changes: 5 additions & 0 deletions docs/source/networks.rst
@@ -203,6 +203,11 @@ Layers
.. autoclass:: BilateralFilter
:members:

`SavitzkyGolayFilter`
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: SavitzkyGolayFilter
:members:

`HilbertTransform`
~~~~~~~~~~~~~~~~~~
.. autoclass:: HilbertTransform
6 changes: 6 additions & 0 deletions docs/source/transforms.rst
@@ -186,6 +186,12 @@ Intensity
:members:
:special-members: __call__

`SavitzkyGolaySmooth`
"""""""""""""""""""""
.. autoclass:: SavitzkyGolaySmooth
:members:
:special-members: __call__

`GaussianSmooth`
""""""""""""""""
.. autoclass:: GaussianSmooth
1 change: 1 addition & 0 deletions monai/networks/layers/__init__.py
@@ -19,6 +19,7 @@
GaussianFilter,
HilbertTransform,
Reshape,
SavitzkyGolayFilter,
SkipConnection,
separable_filtering,
)
95 changes: 91 additions & 4 deletions monai/networks/layers/simplelayers.py
@@ -39,6 +39,7 @@
"LLTM",
"Reshape",
"separable_filtering",
"SavitzkyGolayFilter",
"HilbertTransform",
"ChannelPad",
]
@@ -163,18 +164,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.reshape(shape)


def separable_filtering(x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Tensor:
def separable_filtering(
x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], torch.Tensor], mode: str = "zeros"
) -> torch.Tensor:
"""
Apply 1-D convolutions along each spatial dimension of `x`.

Args:
x: the input image. must have shape (batch, channels, H[, W, ...]).
kernels: kernel along each spatial dimension.
could be a single kernel (duplicated for all dimension), or `spatial_dims` number of kernels.
mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'``
or ``'circular'``. Default: ``'zeros'``. Modes other than ``'zeros'`` require PyTorch version >= 1.5.1. See
torch.nn.Conv1d() for more information.

Raises:
TypeError: When ``x`` is not a ``torch.Tensor``.
"""

if not torch.is_tensor(x):
raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.")

@@ -184,23 +191,103 @@ def separable_filtering(x: torch.Tensor, kernels: Union[Sequence[torch.Tensor],
for s in ensure_tuple_rep(kernels, spatial_dims)
]
_paddings = [cast(int, (same_padding(k.shape[0]))) for k in _kernels]
n_chns = x.shape[1]
n_chs = x.shape[1]

def _conv(input_: torch.Tensor, d: int) -> torch.Tensor:
if d < 0:
return input_
s = [1] * len(input_.shape)
s[d + 2] = -1
_kernel = kernels[d].reshape(s)
_kernel = _kernel.repeat([n_chns, 1] + [1] * spatial_dims)
# if filter kernel is unity, don't convolve
if _kernel.numel() == 1 and _kernel[0] == 1:
return _conv(input_, d - 1)
_kernel = _kernel.repeat([n_chs, 1] + [1] * spatial_dims)
_padding = [0] * spatial_dims
_padding[d] = _paddings[d]
conv_type = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1]
return conv_type(input=_conv(input_, d - 1), weight=_kernel, padding=_padding, groups=n_chns)
# translate padding for input to torch.nn.functional.pad
_reversed_padding_repeated_twice = [p for p in reversed(_padding) for _ in range(2)]
pad_mode = "constant" if mode == "zeros" else mode
return conv_type(
input=_conv(F.pad(input_, _reversed_padding_repeated_twice, mode=pad_mode), d - 1),
weight=_kernel,
groups=n_chs,
)

return _conv(x, spatial_dims - 1)
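
A minimal usage sketch of the extended `separable_filtering` signature; the box kernel and image shape below are illustrative, not taken from this PR:

```python
import torch

from monai.networks.layers import separable_filtering

# batch=1, channel=1, 5x5 image of ones
img = torch.ones(1, 1, 5, 5)
# a simple 3-tap box kernel, duplicated along both spatial dimensions
kernel = torch.tensor([1.0, 1.0, 1.0]) / 3.0

zero_padded = separable_filtering(img, kernel, mode="zeros")      # corners drop to 4/9
replicated = separable_filtering(img, kernel, mode="replicate")   # corners stay at 1.0
print(zero_padded[0, 0, 0, 0].item(), replicated[0, 0, 0, 0].item())
```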


class SavitzkyGolayFilter(nn.Module):
"""
Convolve a Tensor along a particular axis with a Savitzky-Golay kernel.

Args:
window_length: Length of the filter window, must be a positive odd integer.
order: Order of the polynomial to fit to each window, must be less than ``window_length``.
axis (optional): Axis along which to apply the filter kernel. Default 2 (first spatial dimension).
mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or
``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information.
"""

def __init__(self, window_length: int, order: int, axis: int = 2, mode: str = "zeros"):

super().__init__()
if order >= window_length:
raise ValueError("order must be less than window_length.")

self.axis = axis
self.mode = mode
self.coeffs = self._make_coeffs(window_length, order)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Tensor or array-like to filter. Must be real and in shape ``[Batch, chns, spatial1, spatial2, ...]``.
CPU and CUDA inputs are both supported (the kernel is moved to the device of ``x``).
Returns:
torch.Tensor: ``x`` filtered by Savitzky-Golay kernel with window length ``self.window_length`` using
polynomials of order ``self.order``, along axis specified in ``self.axis``.
"""

# Convert array-like input to a real float tensor, keeping the device of tensor inputs
x = torch.as_tensor(x, device=x.device if torch.is_tensor(x) else None)
if torch.is_complex(x):
raise ValueError("x must be real.")
else:
x = x.to(dtype=torch.float)

if (self.axis < 0) or (self.axis > len(x.shape) - 1):
raise ValueError("Invalid axis for shape of x.")

# Create list of filter kernels (1 per spatial dimension). The kernel for self.axis will be the savgol coeffs,
# while the other kernels will be set to [1].
n_spatial_dims = len(x.shape) - 2
spatial_processing_axis = self.axis - 2
new_dims_before = spatial_processing_axis
new_dims_after = n_spatial_dims - spatial_processing_axis - 1
kernel_list = [self.coeffs.to(device=x.device, dtype=x.dtype)]
for _ in range(new_dims_before):
kernel_list.insert(0, torch.ones(1, device=x.device, dtype=x.dtype))
for _ in range(new_dims_after):
kernel_list.append(torch.ones(1, device=x.device, dtype=x.dtype))

return separable_filtering(x, kernel_list, mode=self.mode)

@staticmethod
def _make_coeffs(window_length, order):

half_length, rem = divmod(window_length, 2)
if rem == 0:
raise ValueError("window_length must be odd.")

idx = torch.arange(window_length - half_length - 1, -half_length - 1, -1, dtype=torch.float, device="cpu")
a = idx ** torch.arange(order + 1, dtype=torch.float, device="cpu").reshape(-1, 1)
y = torch.zeros(order + 1, dtype=torch.float, device="cpu")
y[0] = 1.0
return torch.lstsq(y, a).solution.squeeze()
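
A sketch of how the new layer behaves. The noisy-ramp input is invented for illustration, and `_make_coeffs` is a private helper called here only to sanity-check it against the textbook window-5, order-2 smoothing coefficients (-3, 12, 17, 12, -3)/35:

```python
import torch

from monai.networks.layers import SavitzkyGolayFilter

# smooth a noisy ramp along the first spatial dimension (axis=2 is the default)
x = torch.linspace(0, 1, 50).reshape(1, 1, 50) + 0.05 * torch.randn(1, 1, 50)
smoothed = SavitzkyGolayFilter(window_length=7, order=2)(x)
print(smoothed.shape)  # torch.Size([1, 1, 50]) -- spatial shape is preserved

# the least-norm solve reproduces the classic Savitzky-Golay smoothing coefficients
coeffs = SavitzkyGolayFilter._make_coeffs(5, 2)
expected = torch.tensor([-3.0, 12.0, 17.0, 12.0, -3.0]) / 35.0
print(torch.allclose(coeffs, expected, atol=1e-5))  # True
```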


class HilbertTransform(nn.Module):
"""
Determine the analytical signal of a Tensor along a particular axis.
1 change: 1 addition & 0 deletions monai/transforms/__init__.py
@@ -78,6 +78,7 @@
RandHistogramShift,
RandScaleIntensity,
RandShiftIntensity,
SavitzkyGolaySmooth,
ScaleIntensity,
ScaleIntensityRange,
ScaleIntensityRangePercentiles,
41 changes: 40 additions & 1 deletion monai/transforms/intensity/array.py
@@ -20,7 +20,7 @@
import numpy as np
import torch

from monai.networks.layers import GaussianFilter, HilbertTransform
from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter
from monai.transforms.compose import Randomizable, Transform
from monai.transforms.utils import rescale_array
from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, dtype_torch_to_numpy, ensure_tuple_size
@@ -39,6 +39,7 @@
"ScaleIntensityRangePercentiles",
"MaskIntensity",
"DetectEnvelope",
"SavitzkyGolaySmooth",
"GaussianSmooth",
"RandGaussianSmooth",
"GaussianSharpen",
@@ -544,6 +545,44 @@ def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> n
return img * mask_data_


class SavitzkyGolaySmooth(Transform):
"""
Smooth the input data along the given axis using a Savitzky-Golay filter.

Args:
window_length: Length of the filter window, must be a positive odd integer.
order: Order of the polynomial to fit to each window, must be less than ``window_length``.
axis: Optional axis along which to apply the filter kernel. Default 1 (first spatial dimension).
mode: Optional padding mode, passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'``
or ``'circular'``. Default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information.
"""

def __init__(self, window_length: int, order: int, axis: int = 1, mode: str = "zeros"):

if axis < 0:
raise ValueError("axis must be zero or positive.")

self.window_length = window_length
self.order = order
self.axis = axis
self.mode = mode

def __call__(self, img: np.ndarray) -> np.ndarray:
"""
Args:
img: numpy.ndarray containing input data. Must be real and in shape [channels, spatial1, spatial2, ...].

Returns:
np.ndarray containing smoothed result.

"""
# add one to transform axis because a batch axis will be added at dimension 0
savgol_filter = SavitzkyGolayFilter(self.window_length, self.order, self.axis + 1, self.mode)
# convert to Tensor and add the Batch axis expected by SavitzkyGolayFilter
input_data = torch.as_tensor(np.ascontiguousarray(img)).unsqueeze(0)
return savgol_filter(input_data).squeeze(0).numpy()
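
A minimal sketch of the array transform on channel-first data; the noisy sine and parameter values are illustrative:

```python
import numpy as np

from monai.transforms import SavitzkyGolaySmooth

# one channel, 1-D signal: noisy sine in shape [channels, spatial1]
rng = np.random.default_rng(0)
signal = np.sin(np.linspace(0, 4 * np.pi, 200)) + 0.1 * rng.standard_normal(200)
img = signal[np.newaxis, :].astype(np.float32)

smoother = SavitzkyGolaySmooth(window_length=9, order=3, axis=1, mode="replicate")
smoothed = smoother(img)
print(img.shape, smoothed.shape)  # (1, 200) (1, 200)
```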


class DetectEnvelope(Transform):
"""
Find the envelope of the input data along the requested axis using a Hilbert transform.
152 changes: 152 additions & 0 deletions tests/test_savitzky_golay_filter.py
@@ -0,0 +1,152 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.networks.layers import SavitzkyGolayFilter
from tests.utils import skip_if_no_cuda

# Zero-padding trivial tests

TEST_CASE_SINGLE_VALUE = [
{"window_length": 3, "order": 1},
torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value
torch.Tensor([1 / 3]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1
# output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed)
1e-15, # absolute tolerance
]

TEST_CASE_1D = [
{"window_length": 3, "order": 1},
torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data
torch.Tensor([2 / 3, 1.0, 2 / 3])
.unsqueeze(0)
.unsqueeze(0), # Expected output: zero padded, so linear interpolation
# over length-3 windows will result in output of [2/3, 1, 2/3].
1e-15, # absolute tolerance
]

TEST_CASE_2D_AXIS_2 = [
{"window_length": 3, "order": 1}, # along default axis (2, first spatial dim)
torch.ones((3, 2)).unsqueeze(0).unsqueeze(0),
torch.Tensor([[2 / 3, 2 / 3], [1.0, 1.0], [2 / 3, 2 / 3]]).unsqueeze(0).unsqueeze(0),
1e-15, # absolute tolerance
]

TEST_CASE_2D_AXIS_3 = [
{"window_length": 3, "order": 1, "axis": 3}, # along axis 3 (second spatial dim)
torch.ones((2, 3)).unsqueeze(0).unsqueeze(0),
torch.Tensor([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]).unsqueeze(0).unsqueeze(0),
1e-15, # absolute tolerance
]

# Replicated-padding trivial tests

TEST_CASE_SINGLE_VALUE_REP = [
{"window_length": 3, "order": 1, "mode": "replicate"},
torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value
torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1
# output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed)
1e-15, # absolute tolerance
]

TEST_CASE_1D_REP = [
{"window_length": 3, "order": 1, "mode": "replicate"},
torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data
torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Expected output: replicate padded, so a linear fit
# over each length-3 window of ones gives [1, 1, 1].
1e-15, # absolute tolerance
]

TEST_CASE_2D_AXIS_2_REP = [
{"window_length": 3, "order": 1, "mode": "replicate"}, # along default axis (2, first spatial dim)
torch.ones((3, 2)).unsqueeze(0).unsqueeze(0),
torch.Tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]).unsqueeze(0).unsqueeze(0),
1e-15, # absolute tolerance
]

TEST_CASE_2D_AXIS_3_REP = [
{"window_length": 3, "order": 1, "axis": 3, "mode": "replicate"}, # along axis 3 (second spatial dim)
torch.ones((2, 3)).unsqueeze(0).unsqueeze(0),
torch.Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]).unsqueeze(0).unsqueeze(0),
1e-15, # absolute tolerance
]

# Sine smoothing

TEST_CASE_SINE_SMOOTH = [
{"window_length": 3, "order": 1},
# Sine wave with period equal to savgol window length (windowed to reduce edge effects).
torch.as_tensor(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100)).unsqueeze(0).unsqueeze(0),
# Should be smoothed out to zeros
torch.zeros(100).unsqueeze(0).unsqueeze(0),
# tolerance chosen by examining output of scipy.signal.savgol_filter when provided the above input
2e-2, # absolute tolerance
]
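
If scipy is available, these expected values can be cross-checked against `scipy.signal.savgol_filter`. The mode mapping used below (``'zeros'`` corresponds to scipy ``mode='constant'`` with ``cval=0.0``; ``'replicate'`` to ``mode='nearest'``) is my reading of the two APIs rather than something asserted by this PR:

```python
import numpy as np
import torch
from scipy.signal import savgol_filter

from monai.networks.layers import SavitzkyGolayFilter

x = np.sin(2 * np.pi / 3 * np.arange(100)) * np.hanning(100)

monai_out = (
    SavitzkyGolayFilter(window_length=3, order=1)(torch.as_tensor(x).unsqueeze(0).unsqueeze(0)).squeeze().numpy()
)
scipy_out = savgol_filter(x, window_length=3, polyorder=1, mode="constant", cval=0.0)

# looser tolerance: the layer computes in float32, scipy in float64
np.testing.assert_allclose(monai_out, scipy_out, atol=1e-5)
```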


class TestSavitzkyGolayCPU(unittest.TestCase):
@parameterized.expand(
[
TEST_CASE_SINGLE_VALUE,
TEST_CASE_1D,
TEST_CASE_2D_AXIS_2,
TEST_CASE_2D_AXIS_3,
TEST_CASE_SINE_SMOOTH,
]
)
def test_value(self, arguments, image, expected_data, atol):
result = SavitzkyGolayFilter(**arguments)(image)
np.testing.assert_allclose(result, expected_data, atol=atol)


class TestSavitzkyGolayCPUREP(unittest.TestCase):
@parameterized.expand(
[TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP]
)
def test_value(self, arguments, image, expected_data, atol):
result = SavitzkyGolayFilter(**arguments)(image)
np.testing.assert_allclose(result, expected_data, atol=atol)


@skip_if_no_cuda
class TestSavitzkyGolayGPU(unittest.TestCase):
@parameterized.expand(
[
TEST_CASE_SINGLE_VALUE,
TEST_CASE_1D,
TEST_CASE_2D_AXIS_2,
TEST_CASE_2D_AXIS_3,
TEST_CASE_SINE_SMOOTH,
]
)
def test_value(self, arguments, image, expected_data, atol):
result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda"))
np.testing.assert_allclose(result.cpu(), expected_data, atol=atol)


@skip_if_no_cuda
class TestSavitzkyGolayGPUREP(unittest.TestCase):
@parameterized.expand(
[
TEST_CASE_SINGLE_VALUE_REP,
TEST_CASE_1D_REP,
TEST_CASE_2D_AXIS_2_REP,
TEST_CASE_2D_AXIS_3_REP,
]
)
def test_value(self, arguments, image, expected_data, atol):
result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda"))
np.testing.assert_allclose(result.cpu(), expected_data, atol=atol)