From d6ad24a876cae9d64517b59a6ed91706719a8de1 Mon Sep 17 00:00:00 2001 From: Christian Baker Date: Fri, 8 Jan 2021 10:14:00 +0000 Subject: [PATCH 1/5] Added new simplelayer SavitskyGolayFilter() Signed-off-by: Christian Baker --- monai/networks/layers/__init__.py | 3 +- monai/networks/layers/simplelayers.py | 102 ++++++++++++++++++++++++-- 2 files changed, 98 insertions(+), 7 deletions(-) diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index dabec727ac..1121652170 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -17,9 +17,10 @@ ChannelPad, Flatten, GaussianFilter, + SavitskyGolayFilter, HilbertTransform, Reshape, SkipConnection, separable_filtering, ) -from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push +from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push \ No newline at end of file diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index ba60f4eca4..d263cfa5a7 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -39,8 +39,9 @@ "LLTM", "Reshape", "separable_filtering", + "SavitskyGolayFilter", "HilbertTransform", - "ChannelPad", + "ChannelPad" ] @@ -163,7 +164,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x.reshape(shape) -def separable_filtering(x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Tensor: +def separable_filtering( + x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], torch.Tensor], mode: str = "zeros" +) -> torch.Tensor: """ Apply 1-D convolutions along each spatial dimension of `x`. @@ -171,6 +174,8 @@ def separable_filtering(x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], x: the input image. must have shape (batch, channels, H[, W, ...]). kernels: kernel along each spatial dimension. could be a single kernel (duplicated for all dimension), or `spatial_dims` number of kernels. + mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or + ``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information. Raises: TypeError: When ``x`` is not a ``torch.Tensor``. @@ -184,7 +189,7 @@ def separable_filtering(x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], for s in ensure_tuple_rep(kernels, spatial_dims) ] _paddings = [cast(int, (same_padding(k.shape[0]))) for k in _kernels] - n_chns = x.shape[1] + n_chs = x.shape[1] def _conv(input_: torch.Tensor, d: int) -> torch.Tensor: if d < 0: @@ -192,15 +197,100 @@ def _conv(input_: torch.Tensor, d: int) -> torch.Tensor: s = [1] * len(input_.shape) s[d + 2] = -1 _kernel = kernels[d].reshape(s) - _kernel = _kernel.repeat([n_chns, 1] + [1] * spatial_dims) + # if filter kernel is unity, don't convolve + if torch.equal(_kernel.squeeze(), torch.ones(1, device=_kernel.device)): + return _conv(input_, d - 1) + _kernel = _kernel.repeat([n_chs, 1] + [1] * spatial_dims) _padding = [0] * spatial_dims _padding[d] = _paddings[d] - conv_type = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1] - return conv_type(input=_conv(input_, d - 1), weight=_kernel, padding=_padding, groups=n_chns) + + if mode == "zeros": # if zero padding (default), can use functional convolution + conv_type = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1] + return conv_type(input=_conv(input_, d - 1), weight=_kernel, padding=_padding, groups=n_chs) + else: + conv_type = [ + nn.Conv1d(n_chs, n_chs, _kernel.shape, padding=_padding, groups=n_chs, bias=False, padding_mode=mode), + nn.Conv2d(n_chs, n_chs, _kernel.shape, padding=_padding, groups=n_chs, bias=False, padding_mode=mode), + nn.Conv3d(n_chs, n_chs, _kernel.shape, padding=_padding, groups=n_chs, bias=False, padding_mode=mode), + ][spatial_dims - 1] + conv_type.weight = torch.nn.Parameter(_kernel, requires_grad=_kernel.requires_grad) + return conv_type(input=_conv(input_, d - 1)) return _conv(x, spatial_dims - 1) +class SavitskyGolayFilter(nn.Module): + """ + Convolve a Tensor along a particular axis with a Savitsky-Golay kernel. + + Args: + window_length: Length of the filter window, must be a positive odd integer. + order: Order of the polynomial to fit to each window, must be less than ``window_length``. + axis (optional): Axis along which to apply the filter kernel. Default 2 (first spatial dimension). + mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or + ``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information. + """ + + def __init__(self, window_length: int, order: int, axis: int = 2, mode: str = "zeros"): + + super().__init__() + if order >= window_length: + raise ValueError("order must be less than window_length.") + + self.axis = axis + self.mode = mode + self.coeffs = self._make_coeffs(window_length, order) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: Tensor or array-like to filter. Must be real, in shape ``[Batch, chns, spatial1, spatial2, ...]`` and + have a device type of ``'cpu'``. + Returns: + torch.Tensor: ``x`` filtered by Savitsky-Golay kernel with window length ``self.window_length`` using + polynomials of order ``self.order``, along axis specified in ``self.axis``. + """ + + # Make input a real tensor on the CPU + x = torch.as_tensor(x, device=x.device if torch.is_tensor(x) else None) + if torch.is_complex(x): + raise ValueError("x must be real.") + else: + x = x.to(dtype=torch.float) + + if (self.axis < 0) or (self.axis > len(x.shape) - 1): + raise ValueError("Invalid axis for shape of x.") + + # Create list of filter kernels (1 per spatial dimension). The kernel for self.axis will be the savgol coeffs, + # while the other kernels will be set to [1]. + n_spatial_dims = len(x.shape) - 2 + spatial_processing_axis = self.axis - 2 + new_dims_before = spatial_processing_axis + new_dims_after = n_spatial_dims - spatial_processing_axis - 1 + kernel_list = [self.coeffs.to(device=x.device, dtype=x.dtype)] + for _ in range(new_dims_before): + kernel_list.insert(0, torch.ones(1, device=x.device, dtype=x.dtype)) + for _ in range(new_dims_after): + kernel_list.append(torch.ones(1, device=x.device, dtype=x.dtype)) + + return separable_filtering(x, kernel_list, mode=self.mode) + + @staticmethod + def _make_coeffs(window_length, order): + + half_length, rem = divmod(window_length, 2) + if rem == 0: + raise ValueError("window_length must be odd.") + + idx = torch.arange( + window_length - half_length - 1, -half_length - 1, -1, dtype=torch.float, device="cpu" + ) + a = idx ** torch.arange(order + 1, dtype=torch.float, device="cpu").reshape(-1, 1) + y = torch.zeros(order + 1, dtype=torch.float, device="cpu") + y[0] = 1.0 + return torch.lstsq(y, a).solution.squeeze() + + class HilbertTransform(nn.Module): """ Determine the analytical signal of a Tensor along a particular axis. From 4e4416a2cddd86d77779138c23a0ffd9f2e46f45 Mon Sep 17 00:00:00 2001 From: Christian Baker Date: Fri, 8 Jan 2021 10:23:14 +0000 Subject: [PATCH 2/5] Unit tests written for SavitskyGolayFilter() Signed-off-by: Christian Baker --- tests/test_savitsky_golay_filter.py | 136 ++++++++++++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 tests/test_savitsky_golay_filter.py diff --git a/tests/test_savitsky_golay_filter.py b/tests/test_savitsky_golay_filter.py new file mode 100644 index 0000000000..f13ead2bfe --- /dev/null +++ b/tests/test_savitsky_golay_filter.py @@ -0,0 +1,136 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers import SavitskyGolayFilter +from tests.utils import skip_if_no_cuda + +# Zero-padding trivial tests + +TEST_CASE_SINGLE_VALUE = [ + {"window_length": 3, "order": 1}, + torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value + torch.Tensor([1 / 3]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 + # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +TEST_CASE_1D = [ + {"window_length": 3, "order": 1}, + torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data + torch.Tensor([2 / 3, 1.0, 2 / 3]) + .unsqueeze(0) + .unsqueeze(0), # Expected output: zero padded, so linear interpolation + # over length-3 windows will result in output of [2/3, 1, 2/3]. + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2 = [ + {"window_length": 3, "order": 1}, # along default axis (2, first spatial dim) + torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[2 / 3, 2 / 3], [1.0, 1.0], [2 / 3, 2 / 3]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_3 = [ + {"window_length": 3, "order": 1, "axis": 3}, # along axis 3 (second spatial dim) + torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +# Replicated-padding trivial tests + +TEST_CASE_SINGLE_VALUE_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, + torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value + torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 + # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +TEST_CASE_1D_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, + torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data + torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Expected output: zero padded, so linear interpolation + # over length-3 windows will result in output of [2/3, 1, 2/3]. + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, # along default axis (2, first spatial dim) + torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_3_REP = [ + {"window_length": 3, "order": 1, "axis": 3, "mode": "replicate"}, # along axis 3 (second spatial dim) + torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +# Sine smoothing + +TEST_CASE_SINE_SMOOTH = [ + {"window_length": 3, "order": 1}, + # Sine wave with period equal to savgol window length (windowed to reduce edge effects). + torch.as_tensor(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100)).unsqueeze(0).unsqueeze(0), + # Should be smoothed out to zeros + torch.zeros(100).unsqueeze(0).unsqueeze(0), + # tolerance chosen by examining output of SciPy.signal.savgol_filter when provided the above input + 2e-2, # absolute tolerance +] + + +class TestSavitskyGolayCPU(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_SINGLE_VALUE, + TEST_CASE_1D, + TEST_CASE_2D_AXIS_2, + TEST_CASE_2D_AXIS_3, + TEST_CASE_SINGLE_VALUE_REP, + TEST_CASE_1D_REP, + TEST_CASE_2D_AXIS_2_REP, + TEST_CASE_2D_AXIS_3_REP, + TEST_CASE_SINE_SMOOTH, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitskyGolayFilter(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +@skip_if_no_cuda +class TestSavitskyGolayGPU(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_SINGLE_VALUE, + TEST_CASE_1D, + TEST_CASE_2D_AXIS_2, + TEST_CASE_2D_AXIS_3, + TEST_CASE_SINGLE_VALUE_REP, + TEST_CASE_1D_REP, + TEST_CASE_2D_AXIS_2_REP, + TEST_CASE_2D_AXIS_3_REP, + TEST_CASE_SINE_SMOOTH, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitskyGolayFilter(**arguments)(image.to(device="cuda")) + np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) From 0bee6f466d9a06a6cd396739bcead159e604e840 Mon Sep 17 00:00:00 2001 From: Christian Baker Date: Fri, 8 Jan 2021 10:28:37 +0000 Subject: [PATCH 3/5] New array transform SavitskyGolaySmooth() written that wraps SavitskyGolayFilter() simplelayer Signed-off-by: Christian Baker --- monai/networks/layers/__init__.py | 2 +- monai/transforms/__init__.py | 1 + monai/transforms/intensity/array.py | 41 ++++++++++++++++++++++++++++- 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index 1121652170..8a603d1123 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -17,9 +17,9 @@ ChannelPad, Flatten, GaussianFilter, - SavitskyGolayFilter, HilbertTransform, Reshape, + SavitskyGolayFilter, SkipConnection, separable_filtering, ) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 305c27607e..db6900a345 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -104,6 +104,7 @@ RandHistogramShift, RandScaleIntensity, RandShiftIntensity, + SavitskyGolaySmooth, ScaleIntensity, ScaleIntensityRange, ScaleIntensityRangePercentiles, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 84d25c663f..307537590f 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -20,7 +20,7 @@ import numpy as np import torch -from monai.networks.layers import GaussianFilter, HilbertTransform +from monai.networks.layers import GaussianFilter, HilbertTransform, SavitskyGolayFilter from monai.transforms.compose import Randomizable, Transform from monai.transforms.utils import rescale_array from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, dtype_torch_to_numpy, ensure_tuple_size @@ -39,6 +39,7 @@ "ScaleIntensityRangePercentiles", "MaskIntensity", "DetectEnvelope", + "SavitskyGolaySmooth", "GaussianSmooth", "RandGaussianSmooth", "GaussianSharpen", @@ -533,6 +534,44 @@ def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> n return img * mask_data_ +class SavitskyGolaySmooth(Transform): + """ + Smooth the input data along the given axis using a Savitsky-Golay filter. + + Args: + window_length: Length of the filter window, must be a positive odd integer. + order: Order of the polynomial to fit to each window, must be less than ``window_length``. + axis: Optional axis along which to apply the filter kernel. Default 1 (first spatial dimension). + mode: Optional padding mode, passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` + or ``'circular'``. Default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information. + """ + + def __init__(self, window_length: int, order: int, axis: int = 1, mode: str = "zeros"): + + if axis < 0: + raise ValueError("axis must be zero or positive.") + + self.window_length = window_length + self.order = order + self.axis = axis + self.mode = mode + + def __call__(self, img: np.ndarray) -> np.ndarray: + """ + Args: + img: numpy.ndarray containing input data. Must be real and in shape [channels, spatial1, spatial2, ...]. + + Returns: + np.ndarray containing smoothed result. + + """ + # add one to transform axis because a batch axis will be added at dimension 0 + savgol_filter = SavitskyGolayFilter(self.window_length, self.order, self.axis + 1, self.mode) + # convert to Tensor and add Batch axis expected by HilbertTransform + input_data = torch.as_tensor(np.ascontiguousarray(img)).unsqueeze(0) + return savgol_filter(input_data).squeeze(0).numpy() + + class DetectEnvelope(Transform): """ Find the envelope of the input data along the requested axis using a Hilbert transform. From a41fa18328d5207d72d6bb31b0275ef984ce1119 Mon Sep 17 00:00:00 2001 From: Christian Baker Date: Fri, 8 Jan 2021 10:44:43 +0000 Subject: [PATCH 4/5] Tests added for SavitskyGolaySmooth() Signed-off-by: Christian Baker --- tests/test_savitsky_golay_smooth.py | 65 +++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 tests/test_savitsky_golay_smooth.py diff --git a/tests/test_savitsky_golay_smooth.py b/tests/test_savitsky_golay_smooth.py new file mode 100644 index 0000000000..dcd19d7f28 --- /dev/null +++ b/tests/test_savitsky_golay_smooth.py @@ -0,0 +1,65 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import SavitskyGolaySmooth + +# Zero-padding trivial tests + +TEST_CASE_SINGLE_VALUE = [ + {"window_length": 3, "order": 1}, + np.expand_dims(np.array([1.0]), 0), # Input data: Single value + np.expand_dims(np.array([1 / 3]), 0), # Expected output: With a window length of 3 and polyorder 1 + # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2 = [ + {"window_length": 3, "order": 1, "axis": 2}, # along axis 2 (second spatial dim) + np.expand_dims(np.ones((2, 3)), 0), + np.expand_dims(np.array([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]), 0), + 1e-15, # absolute tolerance +] + +# Replicated-padding trivial tests + +TEST_CASE_SINGLE_VALUE_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, + np.expand_dims(np.array([1.0]), 0), # Input data: Single value + np.expand_dims(np.array([1.0]), 0), # Expected output: With a window length of 3 and polyorder 1 + # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +# Sine smoothing + +TEST_CASE_SINE_SMOOTH = [ + {"window_length": 3, "order": 1}, + # Sine wave with period equal to savgol window length (windowed to reduce edge effects). + np.expand_dims(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100), 0), + # Should be smoothed out to zeros + np.expand_dims(np.zeros(100), 0), + # tolerance chosen by examining output of SciPy.signal.savgol_filter() when provided the above input + 2e-2, # absolute tolerance +] + + +class TestSavitskyGolaySmooth(unittest.TestCase): + @parameterized.expand( + [TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_SINE_SMOOTH] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitskyGolaySmooth(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) \ No newline at end of file From 4b3345b2a2e389ec756b6c6d33079015c6e72b69 Mon Sep 17 00:00:00 2001 From: Christian Baker Date: Fri, 8 Jan 2021 10:51:34 +0000 Subject: [PATCH 5/5] Added to Sphinx .rst files and rebuilt docs Signed-off-by: Christian Baker --- docs/source/networks.rst | 5 +++ docs/source/transforms.rst | 6 +++ monai/networks/layers/__init__.py | 4 +- monai/networks/layers/simplelayers.py | 43 +++++++++---------- monai/transforms/__init__.py | 2 +- monai/transforms/intensity/array.py | 10 ++--- ...ilter.py => test_savitzky_golay_filter.py} | 36 +++++++++++----- ...mooth.py => test_savitzky_golay_smooth.py} | 19 +++++--- 8 files changed, 77 insertions(+), 48 deletions(-) rename tests/{test_savitsky_golay_filter.py => test_savitzky_golay_filter.py} (81%) rename tests/{test_savitsky_golay_smooth.py => test_savitzky_golay_smooth.py} (79%) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index ed17d815b4..b1d029c730 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -189,6 +189,11 @@ Layers .. autoclass:: BilateralFilter :members: +`SavitzkyGolayFilter` +~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: SavitzkyGolayFilter + :members: + `HilbertTransform` ~~~~~~~~~~~~~~~~~~ .. autoclass:: HilbertTransform diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index f7e075f376..3c9b3d62fc 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -186,6 +186,12 @@ Intensity :members: :special-members: __call__ +`SavitzkyGolaySmooth` +""""""""""""""""""""" +.. autoclass:: SavitzkyGolaySmooth + :members: + :special-members: __call__ + `GaussianSmooth` """""""""""""""" .. autoclass:: GaussianSmooth diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index 8a603d1123..49c18eb5bf 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -19,8 +19,8 @@ GaussianFilter, HilbertTransform, Reshape, - SavitskyGolayFilter, + SavitzkyGolayFilter, SkipConnection, separable_filtering, ) -from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push \ No newline at end of file +from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index d263cfa5a7..17a1c69f60 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -39,9 +39,9 @@ "LLTM", "Reshape", "separable_filtering", - "SavitskyGolayFilter", + "SavitzkyGolayFilter", "HilbertTransform", - "ChannelPad" + "ChannelPad", ] @@ -174,12 +174,14 @@ def separable_filtering( x: the input image. must have shape (batch, channels, H[, W, ...]). kernels: kernel along each spatial dimension. could be a single kernel (duplicated for all dimension), or `spatial_dims` number of kernels. - mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or - ``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information. + mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` + or ``'circular'``. Default: ``'zeros'``. Modes other than ``'zeros'`` require PyTorch version >= 1.5.1. See + torch.nn.Conv1d() for more information. Raises: TypeError: When ``x`` is not a ``torch.Tensor``. """ + if not torch.is_tensor(x): raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.") @@ -198,30 +200,27 @@ def _conv(input_: torch.Tensor, d: int) -> torch.Tensor: s[d + 2] = -1 _kernel = kernels[d].reshape(s) # if filter kernel is unity, don't convolve - if torch.equal(_kernel.squeeze(), torch.ones(1, device=_kernel.device)): + if _kernel.numel() == 1 and _kernel[0] == 1: return _conv(input_, d - 1) _kernel = _kernel.repeat([n_chs, 1] + [1] * spatial_dims) _padding = [0] * spatial_dims _padding[d] = _paddings[d] - - if mode == "zeros": # if zero padding (default), can use functional convolution - conv_type = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1] - return conv_type(input=_conv(input_, d - 1), weight=_kernel, padding=_padding, groups=n_chs) - else: - conv_type = [ - nn.Conv1d(n_chs, n_chs, _kernel.shape, padding=_padding, groups=n_chs, bias=False, padding_mode=mode), - nn.Conv2d(n_chs, n_chs, _kernel.shape, padding=_padding, groups=n_chs, bias=False, padding_mode=mode), - nn.Conv3d(n_chs, n_chs, _kernel.shape, padding=_padding, groups=n_chs, bias=False, padding_mode=mode), - ][spatial_dims - 1] - conv_type.weight = torch.nn.Parameter(_kernel, requires_grad=_kernel.requires_grad) - return conv_type(input=_conv(input_, d - 1)) + conv_type = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1] + # translate padding for input to torch.nn.functional.pad + _reversed_padding_repeated_twice = [p for p in reversed(_padding) for _ in range(2)] + pad_mode = "constant" if mode == "zeros" else mode + return conv_type( + input=_conv(F.pad(input_, _reversed_padding_repeated_twice, mode=pad_mode), d - 1), + weight=_kernel, + groups=n_chs, + ) return _conv(x, spatial_dims - 1) -class SavitskyGolayFilter(nn.Module): +class SavitzkyGolayFilter(nn.Module): """ - Convolve a Tensor along a particular axis with a Savitsky-Golay kernel. + Convolve a Tensor along a particular axis with a Savitzky-Golay kernel. Args: window_length: Length of the filter window, must be a positive odd integer. @@ -247,7 +246,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x: Tensor or array-like to filter. Must be real, in shape ``[Batch, chns, spatial1, spatial2, ...]`` and have a device type of ``'cpu'``. Returns: - torch.Tensor: ``x`` filtered by Savitsky-Golay kernel with window length ``self.window_length`` using + torch.Tensor: ``x`` filtered by Savitzky-Golay kernel with window length ``self.window_length`` using polynomials of order ``self.order``, along axis specified in ``self.axis``. """ @@ -282,9 +281,7 @@ def _make_coeffs(window_length, order): if rem == 0: raise ValueError("window_length must be odd.") - idx = torch.arange( - window_length - half_length - 1, -half_length - 1, -1, dtype=torch.float, device="cpu" - ) + idx = torch.arange(window_length - half_length - 1, -half_length - 1, -1, dtype=torch.float, device="cpu") a = idx ** torch.arange(order + 1, dtype=torch.float, device="cpu").reshape(-1, 1) y = torch.zeros(order + 1, dtype=torch.float, device="cpu") y[0] = 1.0 diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index db6900a345..a1d9f12670 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -104,7 +104,7 @@ RandHistogramShift, RandScaleIntensity, RandShiftIntensity, - SavitskyGolaySmooth, + SavitzkyGolaySmooth, ScaleIntensity, ScaleIntensityRange, ScaleIntensityRangePercentiles, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 307537590f..bd52a50ba7 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -20,7 +20,7 @@ import numpy as np import torch -from monai.networks.layers import GaussianFilter, HilbertTransform, SavitskyGolayFilter +from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter from monai.transforms.compose import Randomizable, Transform from monai.transforms.utils import rescale_array from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, dtype_torch_to_numpy, ensure_tuple_size @@ -39,7 +39,7 @@ "ScaleIntensityRangePercentiles", "MaskIntensity", "DetectEnvelope", - "SavitskyGolaySmooth", + "SavitzkyGolaySmooth", "GaussianSmooth", "RandGaussianSmooth", "GaussianSharpen", @@ -534,9 +534,9 @@ def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> n return img * mask_data_ -class SavitskyGolaySmooth(Transform): +class SavitzkyGolaySmooth(Transform): """ - Smooth the input data along the given axis using a Savitsky-Golay filter. + Smooth the input data along the given axis using a Savitzky-Golay filter. Args: window_length: Length of the filter window, must be a positive odd integer. @@ -566,7 +566,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: """ # add one to transform axis because a batch axis will be added at dimension 0 - savgol_filter = SavitskyGolayFilter(self.window_length, self.order, self.axis + 1, self.mode) + savgol_filter = SavitzkyGolayFilter(self.window_length, self.order, self.axis + 1, self.mode) # convert to Tensor and add Batch axis expected by HilbertTransform input_data = torch.as_tensor(np.ascontiguousarray(img)).unsqueeze(0) return savgol_filter(input_data).squeeze(0).numpy() diff --git a/tests/test_savitsky_golay_filter.py b/tests/test_savitzky_golay_filter.py similarity index 81% rename from tests/test_savitsky_golay_filter.py rename to tests/test_savitzky_golay_filter.py index f13ead2bfe..d76c42c15f 100644 --- a/tests/test_savitsky_golay_filter.py +++ b/tests/test_savitzky_golay_filter.py @@ -15,7 +15,7 @@ import torch from parameterized import parameterized -from monai.networks.layers import SavitskyGolayFilter +from monai.networks.layers import SavitzkyGolayFilter from tests.utils import skip_if_no_cuda # Zero-padding trivial tests @@ -97,40 +97,56 @@ ] -class TestSavitskyGolayCPU(unittest.TestCase): +class TestSavitzkyGolayCPU(unittest.TestCase): @parameterized.expand( [ TEST_CASE_SINGLE_VALUE, TEST_CASE_1D, TEST_CASE_2D_AXIS_2, TEST_CASE_2D_AXIS_3, - TEST_CASE_SINGLE_VALUE_REP, - TEST_CASE_1D_REP, - TEST_CASE_2D_AXIS_2_REP, - TEST_CASE_2D_AXIS_3_REP, TEST_CASE_SINE_SMOOTH, ] ) def test_value(self, arguments, image, expected_data, atol): - result = SavitskyGolayFilter(**arguments)(image) + result = SavitzkyGolayFilter(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +class TestSavitzkyGolayCPUREP(unittest.TestCase): + @parameterized.expand( + [TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image) np.testing.assert_allclose(result, expected_data, atol=atol) @skip_if_no_cuda -class TestSavitskyGolayGPU(unittest.TestCase): +class TestSavitzkyGolayGPU(unittest.TestCase): @parameterized.expand( [ TEST_CASE_SINGLE_VALUE, TEST_CASE_1D, TEST_CASE_2D_AXIS_2, TEST_CASE_2D_AXIS_3, + TEST_CASE_SINE_SMOOTH, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) + np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) + + +@skip_if_no_cuda +class TestSavitzkyGolayGPUREP(unittest.TestCase): + @parameterized.expand( + [ TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP, - TEST_CASE_SINE_SMOOTH, ] ) def test_value(self, arguments, image, expected_data, atol): - result = SavitskyGolayFilter(**arguments)(image.to(device="cuda")) + result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) diff --git a/tests/test_savitsky_golay_smooth.py b/tests/test_savitzky_golay_smooth.py similarity index 79% rename from tests/test_savitsky_golay_smooth.py rename to tests/test_savitzky_golay_smooth.py index dcd19d7f28..2be0da1360 100644 --- a/tests/test_savitsky_golay_smooth.py +++ b/tests/test_savitzky_golay_smooth.py @@ -14,7 +14,7 @@ import numpy as np from parameterized import parameterized -from monai.transforms import SavitskyGolaySmooth +from monai.transforms import SavitzkyGolaySmooth # Zero-padding trivial tests @@ -56,10 +56,15 @@ ] -class TestSavitskyGolaySmooth(unittest.TestCase): - @parameterized.expand( - [TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_SINE_SMOOTH] - ) +class TestSavitzkyGolaySmooth(unittest.TestCase): + @parameterized.expand([TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH]) def test_value(self, arguments, image, expected_data, atol): - result = SavitskyGolaySmooth(**arguments)(image) - np.testing.assert_allclose(result, expected_data, atol=atol) \ No newline at end of file + result = SavitzkyGolaySmooth(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +class TestSavitzkyGolaySmoothREP(unittest.TestCase): + @parameterized.expand([TEST_CASE_SINGLE_VALUE_REP]) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolaySmooth(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol)