diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 420da311d2..6a05d72b66 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -203,6 +203,11 @@ Layers .. autoclass:: BilateralFilter :members: +`SavitzkyGolayFilter` +~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: SavitzkyGolayFilter + :members: + `HilbertTransform` ~~~~~~~~~~~~~~~~~~ .. autoclass:: HilbertTransform diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index c769771f4a..90d960a6b9 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -186,6 +186,12 @@ Intensity :members: :special-members: __call__ +`SavitzkyGolaySmooth` +""""""""""""""""""""" +.. autoclass:: SavitzkyGolaySmooth + :members: + :special-members: __call__ + `GaussianSmooth` """""""""""""""" .. autoclass:: GaussianSmooth diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index dabec727ac..49c18eb5bf 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -19,6 +19,7 @@ GaussianFilter, HilbertTransform, Reshape, + SavitzkyGolayFilter, SkipConnection, separable_filtering, ) diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index a6524669b1..285b0d629f 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -39,6 +39,7 @@ "LLTM", "Reshape", "separable_filtering", + "SavitzkyGolayFilter", "HilbertTransform", "ChannelPad", ] @@ -163,7 +164,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x.reshape(shape) -def separable_filtering(x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Tensor: +def separable_filtering( + x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], torch.Tensor], mode: str = "zeros" +) -> torch.Tensor: """ Apply 1-D convolutions along each spatial dimension of `x`. @@ -171,10 +174,14 @@ def separable_filtering(x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], x: the input image. must have shape (batch, channels, H[, W, ...]). kernels: kernel along each spatial dimension. could be a single kernel (duplicated for all dimension), or `spatial_dims` number of kernels. + mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` + or ``'circular'``. Default: ``'zeros'``. Modes other than ``'zeros'`` require PyTorch version >= 1.5.1. See + torch.nn.Conv1d() for more information. Raises: TypeError: When ``x`` is not a ``torch.Tensor``. """ + if not torch.is_tensor(x): raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.") @@ -184,7 +191,7 @@ def separable_filtering(x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], for s in ensure_tuple_rep(kernels, spatial_dims) ] _paddings = [cast(int, (same_padding(k.shape[0]))) for k in _kernels] - n_chns = x.shape[1] + n_chs = x.shape[1] def _conv(input_: torch.Tensor, d: int) -> torch.Tensor: if d < 0: @@ -192,15 +199,95 @@ def _conv(input_: torch.Tensor, d: int) -> torch.Tensor: s = [1] * len(input_.shape) s[d + 2] = -1 _kernel = kernels[d].reshape(s) - _kernel = _kernel.repeat([n_chns, 1] + [1] * spatial_dims) + # if filter kernel is unity, don't convolve + if _kernel.numel() == 1 and _kernel[0] == 1: + return _conv(input_, d - 1) + _kernel = _kernel.repeat([n_chs, 1] + [1] * spatial_dims) _padding = [0] * spatial_dims _padding[d] = _paddings[d] conv_type = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1] - return conv_type(input=_conv(input_, d - 1), weight=_kernel, padding=_padding, groups=n_chns) + # translate padding for input to torch.nn.functional.pad + _reversed_padding_repeated_twice = [p for p in reversed(_padding) for _ in range(2)] + pad_mode = "constant" if mode == "zeros" else mode + return conv_type( + input=_conv(F.pad(input_, _reversed_padding_repeated_twice, mode=pad_mode), d - 1), + weight=_kernel, + groups=n_chs, + ) return _conv(x, spatial_dims - 1) +class SavitzkyGolayFilter(nn.Module): + """ + Convolve a Tensor along a particular axis with a Savitzky-Golay kernel. + + Args: + window_length: Length of the filter window, must be a positive odd integer. + order: Order of the polynomial to fit to each window, must be less than ``window_length``. + axis (optional): Axis along which to apply the filter kernel. Default 2 (first spatial dimension). + mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or + ``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information. + """ + + def __init__(self, window_length: int, order: int, axis: int = 2, mode: str = "zeros"): + + super().__init__() + if order >= window_length: + raise ValueError("order must be less than window_length.") + + self.axis = axis + self.mode = mode + self.coeffs = self._make_coeffs(window_length, order) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: Tensor or array-like to filter. Must be real, in shape ``[Batch, chns, spatial1, spatial2, ...]`` and + have a device type of ``'cpu'``. + Returns: + torch.Tensor: ``x`` filtered by Savitzky-Golay kernel with window length ``self.window_length`` using + polynomials of order ``self.order``, along axis specified in ``self.axis``. + """ + + # Make input a real tensor on the CPU + x = torch.as_tensor(x, device=x.device if torch.is_tensor(x) else None) + if torch.is_complex(x): + raise ValueError("x must be real.") + else: + x = x.to(dtype=torch.float) + + if (self.axis < 0) or (self.axis > len(x.shape) - 1): + raise ValueError("Invalid axis for shape of x.") + + # Create list of filter kernels (1 per spatial dimension). The kernel for self.axis will be the savgol coeffs, + # while the other kernels will be set to [1]. + n_spatial_dims = len(x.shape) - 2 + spatial_processing_axis = self.axis - 2 + new_dims_before = spatial_processing_axis + new_dims_after = n_spatial_dims - spatial_processing_axis - 1 + kernel_list = [self.coeffs.to(device=x.device, dtype=x.dtype)] + for _ in range(new_dims_before): + kernel_list.insert(0, torch.ones(1, device=x.device, dtype=x.dtype)) + for _ in range(new_dims_after): + kernel_list.append(torch.ones(1, device=x.device, dtype=x.dtype)) + + return separable_filtering(x, kernel_list, mode=self.mode) + + @staticmethod + def _make_coeffs(window_length, order): + + half_length, rem = divmod(window_length, 2) + if rem == 0: + raise ValueError("window_length must be odd.") + + idx = torch.arange(window_length - half_length - 1, -half_length - 1, -1, dtype=torch.float, device="cpu") + a = idx ** torch.arange(order + 1, dtype=torch.float, device="cpu").reshape(-1, 1) + y = torch.zeros(order + 1, dtype=torch.float, device="cpu") + y[0] = 1.0 + return torch.lstsq(y, a).solution.squeeze() + + class HilbertTransform(nn.Module): """ Determine the analytical signal of a Tensor along a particular axis. diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 8f46abf522..9eaedd6b15 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -78,6 +78,7 @@ RandHistogramShift, RandScaleIntensity, RandShiftIntensity, + SavitzkyGolaySmooth, ScaleIntensity, ScaleIntensityRange, ScaleIntensityRangePercentiles, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 2d3cca64e6..205b719246 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -20,7 +20,7 @@ import numpy as np import torch -from monai.networks.layers import GaussianFilter, HilbertTransform +from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter from monai.transforms.compose import Randomizable, Transform from monai.transforms.utils import rescale_array from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, dtype_torch_to_numpy, ensure_tuple_size @@ -39,6 +39,7 @@ "ScaleIntensityRangePercentiles", "MaskIntensity", "DetectEnvelope", + "SavitzkyGolaySmooth", "GaussianSmooth", "RandGaussianSmooth", "GaussianSharpen", @@ -544,6 +545,44 @@ def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> n return img * mask_data_ +class SavitzkyGolaySmooth(Transform): + """ + Smooth the input data along the given axis using a Savitzky-Golay filter. + + Args: + window_length: Length of the filter window, must be a positive odd integer. + order: Order of the polynomial to fit to each window, must be less than ``window_length``. + axis: Optional axis along which to apply the filter kernel. Default 1 (first spatial dimension). + mode: Optional padding mode, passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` + or ``'circular'``. Default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information. + """ + + def __init__(self, window_length: int, order: int, axis: int = 1, mode: str = "zeros"): + + if axis < 0: + raise ValueError("axis must be zero or positive.") + + self.window_length = window_length + self.order = order + self.axis = axis + self.mode = mode + + def __call__(self, img: np.ndarray) -> np.ndarray: + """ + Args: + img: numpy.ndarray containing input data. Must be real and in shape [channels, spatial1, spatial2, ...]. + + Returns: + np.ndarray containing smoothed result. + + """ + # add one to transform axis because a batch axis will be added at dimension 0 + savgol_filter = SavitzkyGolayFilter(self.window_length, self.order, self.axis + 1, self.mode) + # convert to Tensor and add Batch axis expected by HilbertTransform + input_data = torch.as_tensor(np.ascontiguousarray(img)).unsqueeze(0) + return savgol_filter(input_data).squeeze(0).numpy() + + class DetectEnvelope(Transform): """ Find the envelope of the input data along the requested axis using a Hilbert transform. diff --git a/tests/test_savitzky_golay_filter.py b/tests/test_savitzky_golay_filter.py new file mode 100644 index 0000000000..d76c42c15f --- /dev/null +++ b/tests/test_savitzky_golay_filter.py @@ -0,0 +1,152 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers import SavitzkyGolayFilter +from tests.utils import skip_if_no_cuda + +# Zero-padding trivial tests + +TEST_CASE_SINGLE_VALUE = [ + {"window_length": 3, "order": 1}, + torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value + torch.Tensor([1 / 3]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 + # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +TEST_CASE_1D = [ + {"window_length": 3, "order": 1}, + torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data + torch.Tensor([2 / 3, 1.0, 2 / 3]) + .unsqueeze(0) + .unsqueeze(0), # Expected output: zero padded, so linear interpolation + # over length-3 windows will result in output of [2/3, 1, 2/3]. + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2 = [ + {"window_length": 3, "order": 1}, # along default axis (2, first spatial dim) + torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[2 / 3, 2 / 3], [1.0, 1.0], [2 / 3, 2 / 3]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_3 = [ + {"window_length": 3, "order": 1, "axis": 3}, # along axis 3 (second spatial dim) + torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +# Replicated-padding trivial tests + +TEST_CASE_SINGLE_VALUE_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, + torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value + torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 + # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +TEST_CASE_1D_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, + torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data + torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Expected output: zero padded, so linear interpolation + # over length-3 windows will result in output of [2/3, 1, 2/3]. + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, # along default axis (2, first spatial dim) + torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_3_REP = [ + {"window_length": 3, "order": 1, "axis": 3, "mode": "replicate"}, # along axis 3 (second spatial dim) + torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +# Sine smoothing + +TEST_CASE_SINE_SMOOTH = [ + {"window_length": 3, "order": 1}, + # Sine wave with period equal to savgol window length (windowed to reduce edge effects). + torch.as_tensor(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100)).unsqueeze(0).unsqueeze(0), + # Should be smoothed out to zeros + torch.zeros(100).unsqueeze(0).unsqueeze(0), + # tolerance chosen by examining output of SciPy.signal.savgol_filter when provided the above input + 2e-2, # absolute tolerance +] + + +class TestSavitzkyGolayCPU(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_SINGLE_VALUE, + TEST_CASE_1D, + TEST_CASE_2D_AXIS_2, + TEST_CASE_2D_AXIS_3, + TEST_CASE_SINE_SMOOTH, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +class TestSavitzkyGolayCPUREP(unittest.TestCase): + @parameterized.expand( + [TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +@skip_if_no_cuda +class TestSavitzkyGolayGPU(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_SINGLE_VALUE, + TEST_CASE_1D, + TEST_CASE_2D_AXIS_2, + TEST_CASE_2D_AXIS_3, + TEST_CASE_SINE_SMOOTH, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) + np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) + + +@skip_if_no_cuda +class TestSavitzkyGolayGPUREP(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_SINGLE_VALUE_REP, + TEST_CASE_1D_REP, + TEST_CASE_2D_AXIS_2_REP, + TEST_CASE_2D_AXIS_3_REP, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) + np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) diff --git a/tests/test_savitzky_golay_smooth.py b/tests/test_savitzky_golay_smooth.py new file mode 100644 index 0000000000..2be0da1360 --- /dev/null +++ b/tests/test_savitzky_golay_smooth.py @@ -0,0 +1,70 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import SavitzkyGolaySmooth + +# Zero-padding trivial tests + +TEST_CASE_SINGLE_VALUE = [ + {"window_length": 3, "order": 1}, + np.expand_dims(np.array([1.0]), 0), # Input data: Single value + np.expand_dims(np.array([1 / 3]), 0), # Expected output: With a window length of 3 and polyorder 1 + # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2 = [ + {"window_length": 3, "order": 1, "axis": 2}, # along axis 2 (second spatial dim) + np.expand_dims(np.ones((2, 3)), 0), + np.expand_dims(np.array([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]), 0), + 1e-15, # absolute tolerance +] + +# Replicated-padding trivial tests + +TEST_CASE_SINGLE_VALUE_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, + np.expand_dims(np.array([1.0]), 0), # Input data: Single value + np.expand_dims(np.array([1.0]), 0), # Expected output: With a window length of 3 and polyorder 1 + # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +# Sine smoothing + +TEST_CASE_SINE_SMOOTH = [ + {"window_length": 3, "order": 1}, + # Sine wave with period equal to savgol window length (windowed to reduce edge effects). + np.expand_dims(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100), 0), + # Should be smoothed out to zeros + np.expand_dims(np.zeros(100), 0), + # tolerance chosen by examining output of SciPy.signal.savgol_filter() when provided the above input + 2e-2, # absolute tolerance +] + + +class TestSavitzkyGolaySmooth(unittest.TestCase): + @parameterized.expand([TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH]) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolaySmooth(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +class TestSavitzkyGolaySmoothREP(unittest.TestCase): + @parameterized.expand([TEST_CASE_SINGLE_VALUE_REP]) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolaySmooth(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol)