diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 6a05d72b66..cf383d2908 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -203,6 +203,10 @@ Layers .. autoclass:: BilateralFilter :members: +`PHLFilter` +~~~~~~~~~~~~~~~~~ +.. autoclass:: PHLFilter + `SavitzkyGolayFilter` ~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: SavitzkyGolayFilter diff --git a/monai/csrc/ext.cpp b/monai/csrc/ext.cpp index c96e081a95..2e0644bc78 100644 --- a/monai/csrc/ext.cpp +++ b/monai/csrc/ext.cpp @@ -21,6 +21,7 @@ limitations under the License. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // filtering m.def("bilateral_filter", &BilateralFilter, "Bilateral Filter"); + m.def("phl_filter", &PermutohedralFilter, "Permutohedral Filter"); // lltm m.def("lltm_forward", &lltm_forward, "LLTM forward"); diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp b/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp index ea56ff7526..474d24b4fa 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp @@ -158,7 +158,7 @@ torch::Tensor BilateralFilterCpu(torch::Tensor inputTensor, float spatialSigma, // Preparing output tensor. torch::Tensor outputTensor = torch::zeros_like(inputTensor); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputTensor.type(), "BilateralFilterCpu", ([&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputTensor.scalar_type(), "BilateralFilterCpu", ([&] { BilateralFilterCpu( inputTensor, outputTensor, spatialSigma, colorSigma); })); diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp index 26f4f1b54f..1fb48cb6c9 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp @@ -62,13 +62,12 @@ void BilateralFilterPHLCpu( } // Filtering data with respect to the features. - scalar_t* output = - PermutohedralCPU(data, features, desc.channelCount, featureChannels, desc.channelStride); + PermutohedralCPU(data, features, desc.channelCount, featureChannels, desc.channelStride); // Writing output tensor. for (int i = 0; i < desc.channelStride; i++) { for (int c = 0; c < desc.channelCount; c++) { - outputTensorData[batchOffset + i + c * desc.channelStride] = output[i * desc.channelCount + c]; + outputTensorData[batchOffset + i + c * desc.channelStride] = data[i * desc.channelCount + c]; } } } @@ -81,7 +80,7 @@ void BilateralFilterPHLCpu( torch::Tensor BilateralFilterPHLCpu(torch::Tensor inputTensor, float spatialSigma, float colorSigma) { torch::Tensor outputTensor = torch::zeros_like(inputTensor); - AT_DISPATCH_FLOATING_TYPES(inputTensor.type(), "BilateralFilterPhlCpu", ([&] { + AT_DISPATCH_FLOATING_TYPES(inputTensor.scalar_type(), "BilateralFilterPhlCpu", ([&] { BilateralFilterPHLCpu(inputTensor, outputTensor, spatialSigma, colorSigma); })); diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu index 44c8172870..4477ce5845 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu @@ -36,6 +36,9 @@ __global__ void BilateralFilterCudaKernel1D(scalar_t* input, scalar_t* output) { int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; int batchOffset = blockIdx.y * cBatchStride; + if (homeOffset >= cColorStride) + return; + scalar_t weightSum = 0; for (int kernelOffset = 0; kernelOffset < cKernelSize; kernelOffset++) { @@ -79,6 +82,9 @@ __global__ void BilateralFilterCudaKernel2D(scalar_t* input, scalar_t* output) { int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; int batchOffset = blockIdx.y * cBatchStride; + if (homeOffset >= cColorStride) + return; + int homeX = homeOffset / cStrides[0]; int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1]; @@ -132,6 +138,9 @@ __global__ void BilateralFilterCudaKernel3D(scalar_t* input, scalar_t* output) { int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; int batchOffset = blockIdx.y * cBatchStride; + if (homeOffset >= cColorStride) + return; + int homeX = homeOffset / cStrides[0]; int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1]; int homeZ = (homeOffset - homeX * cStrides[0] - homeY * cStrides[1]) / cStrides[2]; @@ -211,22 +220,27 @@ void BilateralFilterCuda(torch::Tensor inputTensor, torch::Tensor outputTensor, cudaMemcpyToSymbol(cKernel, kernel, sizeof(float) * kernelSize); cudaMemcpyToSymbol(cColorExponentFactor, &colorExponentFactor, sizeof(float)); +#define BLOCK_SIZE 32 + AT_DISPATCH_FLOATING_TYPES_AND_HALF( - inputTensor.type(), "BilateralFilterCudaKernel", ([&] { + inputTensor.scalar_type(), "BilateralFilterCudaKernel", ([&] { // Dispatch kernel. (Partial template function specialisation not supported at present so using this switch // instead) switch (D) { case (1): - BilateralFilterCudaKernel1D<<>>( - inputTensor.data_ptr(), outputTensor.data_ptr()); + BilateralFilterCudaKernel1D + <<>>( + inputTensor.data_ptr(), outputTensor.data_ptr()); break; case (2): - BilateralFilterCudaKernel2D<<>>( - inputTensor.data_ptr(), outputTensor.data_ptr()); + BilateralFilterCudaKernel2D + <<>>( + inputTensor.data_ptr(), outputTensor.data_ptr()); break; case (3): - BilateralFilterCudaKernel3D<<>>( - inputTensor.data_ptr(), outputTensor.data_ptr()); + BilateralFilterCudaKernel3D + <<>>( + inputTensor.data_ptr(), outputTensor.data_ptr()); break; } })); diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu index 1c8a163ec5..603ab689cf 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu @@ -30,6 +30,9 @@ __global__ void FeatureCreation(const scalar_t* inputTensor, scalar_t* outputDat int elementIndex = blockIdx.x * blockDim.x + threadIdx.x; int batchIndex = blockIdx.y; + if (elementIndex >= cChannelStride) + return; + int dataBatchOffset = batchIndex * cBatchStride; int featureBatchOffset = batchIndex * (D + C) * cChannelStride; @@ -56,6 +59,10 @@ template __global__ void WriteOutput(const scalar_t* data, scalar_t* outputTensor) { int elementIndex = blockIdx.x * blockDim.x + threadIdx.x; int batchIndex = blockIdx.y; + + if (elementIndex >= cChannelStride) + return; + int batchOffset = batchIndex * cBatchStride; #pragma unroll @@ -95,9 +102,12 @@ void BilateralFilterPHLCuda( cudaMemcpyToSymbol(cInvSpatialSigma, &invSpatialSigma, sizeof(float)); cudaMemcpyToSymbol(cInvColorSigma, &invColorSigma, sizeof(float)); +#define BLOCK_SIZE 32 + // Creating features FeatureCreation - <<>>(inputTensorData, data, features); + <<>>( + inputTensorData, data, features); // Filtering data with respect to the features for each sample in batch for (int batchIndex = 0; batchIndex < desc.batchCount; batchIndex++) { @@ -108,7 +118,8 @@ void BilateralFilterPHLCuda( } // Writing output - WriteOutput<<>>(data, outputTensorData); + WriteOutput<<>>( + data, outputTensorData); cudaFree(data); cudaFree(features); @@ -119,7 +130,7 @@ torch::Tensor BilateralFilterPHLCuda(torch::Tensor inputTensor, float spatialSig torch::Tensor outputTensor = torch::zeros_like(inputTensor); #define CASE(c, d) \ - AT_DISPATCH_FLOATING_TYPES(inputTensor.type(), "BilateralFilterCudaPHL", ([&] { \ + AT_DISPATCH_FLOATING_TYPES(inputTensor.scalar_type(), "BilateralFilterCudaPHL", ([&] { \ BilateralFilterPHLCuda( \ inputTensor, outputTensor, spatialSigma, colorSigma); \ })); diff --git a/monai/csrc/filtering/filtering.h b/monai/csrc/filtering/filtering.h index be348e9183..25186b182a 100644 --- a/monai/csrc/filtering/filtering.h +++ b/monai/csrc/filtering/filtering.h @@ -13,4 +13,5 @@ limitations under the License. #pragma once -#include "bilateral/bilateral.h" \ No newline at end of file +#include "bilateral/bilateral.h" +#include "permutohedral/permutohedral.h" \ No newline at end of file diff --git a/monai/csrc/filtering/permutohedral/hash_table.cu b/monai/csrc/filtering/permutohedral/hash_table.cuh similarity index 96% rename from monai/csrc/filtering/permutohedral/hash_table.cu rename to monai/csrc/filtering/permutohedral/hash_table.cuh index bedf0c1efc..7d9d7eb163 100644 --- a/monai/csrc/filtering/permutohedral/hash_table.cu +++ b/monai/csrc/filtering/permutohedral/hash_table.cuh @@ -90,9 +90,14 @@ static scalar_t* createHashTable(int capacity) { template static void destroyHashTable() { #ifndef LINEAR_D_MEMORY - cudaFree(table_keys); + signed short* keys; + cudaMemcpyFromSymbol(&keys, table_keys, sizeof(unsigned int*)); + cudaFree(keys); #endif - cudaFree(table_entries); + + int* entries; + cudaMemcpyFromSymbol(&entries, table_entries, sizeof(int*)); + cudaFree(entries); } template diff --git a/monai/csrc/filtering/permutohedral/permutohedral.cpp b/monai/csrc/filtering/permutohedral/permutohedral.cpp new file mode 100644 index 0000000000..5d6916b8f4 --- /dev/null +++ b/monai/csrc/filtering/permutohedral/permutohedral.cpp @@ -0,0 +1,71 @@ +#include "utils/common_utils.h" +#include "utils/meta_macros.h" + +#include "permutohedral.h" + +torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features) { + input = input.contiguous(); + + int batchCount = input.size(0); + int batchStride = input.stride(0); + int elementCount = input.stride(1); + int channelCount = input.size(1); + int featureCount = features.size(1); + +// movedim not support in torch < 1.7.1 +#if MONAI_TORCH_VERSION >= 10701 + torch::Tensor data = input.clone().movedim(1, -1).contiguous(); + features = features.movedim(1, -1).contiguous(); +#else + torch::Tensor data = input.clone(); + features = features; + + for (int i = 1; i < input.dim() - 1; i++) { + data = data.transpose(i, i + 1); + features = features.transpose(i, i + 1); + } + + data = data.contiguous(); + features = features.contiguous(); +#endif + +#ifdef WITH_CUDA + if (torch::cuda::is_available() && data.is_cuda()) { + CHECK_CONTIGUOUS_CUDA(data); + +#define CASE(dc, fc) \ + AT_DISPATCH_FLOATING_TYPES(data.scalar_type(), "PermutohedralCuda", ([&] { \ + for (int batchIndex = 0; batchIndex < batchCount; batchIndex++) { \ + scalar_t* offsetData = data.data_ptr() + batchIndex * batchStride; \ + scalar_t* offsetFeatures = \ + features.data_ptr() + batchIndex * fc * elementCount; \ + PermutohedralCuda(offsetData, offsetFeatures, elementCount, true); \ + } \ + })); + SWITCH_AB(CASE, 16, 19, channelCount, featureCount); + + } else { +#endif + AT_DISPATCH_FLOATING_TYPES( + data.scalar_type(), "PermutohedralCPU", ([&] { + for (int batchIndex = 0; batchIndex < batchCount; batchIndex++) { + scalar_t* offsetData = data.data_ptr() + batchIndex * batchStride; + scalar_t* offsetFeatures = features.data_ptr() + batchIndex * featureCount * elementCount; + PermutohedralCPU(offsetData, offsetFeatures, channelCount, featureCount, elementCount); + } + })); +#ifdef WITH_CUDA + } +#endif + +// movedim not support in torch < 1.7.1 +#if MONAI_TORCH_VERSION >= 10701 + data = data.movedim(-1, 1); +#else + for (int i = input.dim() - 1; i > 1; i--) { + data = data.transpose(i - 1, i); + } +#endif + + return data; +} diff --git a/monai/csrc/filtering/permutohedral/permutohedral.h b/monai/csrc/filtering/permutohedral/permutohedral.h index 4b80c2bfc9..27b0ff4859 100644 --- a/monai/csrc/filtering/permutohedral/permutohedral.h +++ b/monai/csrc/filtering/permutohedral/permutohedral.h @@ -11,10 +11,14 @@ See the License for the specific language governing permissions and limitations under the License. */ +#include + #pragma once template -scalar_t* PermutohedralCPU(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount); +void PermutohedralCPU(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount); #ifdef WITH_CUDA template void PermutohedralCuda(scalar_t* data, scalar_t* features, int elementCount, bool accurate); #endif + +torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features); \ No newline at end of file diff --git a/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp b/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp index 19b195908c..0876997448 100644 --- a/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp +++ b/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp @@ -215,7 +215,7 @@ class PermutohedralLattice { * im : image to be bilateral-filtered. * ref : reference image whose edges are to be respected. */ - static scalar_t* filter(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount) { + static void filter(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount) { // Create lattice PermutohedralLattice lattice(featureChannels, dataChannels + 1, elementCount); @@ -236,8 +236,6 @@ class PermutohedralLattice { lattice.blur(); // Slice from the lattice - scalar_t* outputData = new scalar_t[elementCount * dataChannels]; - lattice.beginSlice(); for (int i = 0, e = 0; e < elementCount; e++) { @@ -245,11 +243,9 @@ class PermutohedralLattice { scalar_t scale = 1.0f / col[dataChannels]; for (int c = 0; c < dataChannels; c++, i++) { - outputData[i] = col[c] * scale; + data[i] = col[c] * scale; } } - - return outputData; } /* Constructor @@ -498,19 +494,9 @@ class PermutohedralLattice { }; template -scalar_t* PermutohedralCPU( - scalar_t* data, - scalar_t* features, - int dataChannels, - int featureChannels, - int elementCount) { - return PermutohedralLattice::filter(data, features, dataChannels, featureChannels, elementCount); +void PermutohedralCPU(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount) { + PermutohedralLattice::filter(data, features, dataChannels, featureChannels, elementCount); } -template float* PermutohedralCPU(float* data, float* features, int dataChannels, int featureChannels, int elementCount); -template double* PermutohedralCPU( - double* data, - double* features, - int dataChannels, - int featureChannels, - int elementCount); \ No newline at end of file +template void PermutohedralCPU(float* data, float* features, int dataChannels, int featureChannels, int elementCount); +template void PermutohedralCPU(double* data, double* features, int dataChannels, int featureChannels, int elementCount); \ No newline at end of file diff --git a/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu b/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu index 94c5b90659..b87a88a84f 100644 --- a/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu +++ b/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu @@ -46,7 +46,7 @@ SOFTWARE. #include #include -#include "hash_table.cu" +#include "hash_table.cuh" #include "utils/meta_macros.h" template diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index 49c18eb5bf..ba61774a96 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -11,7 +11,7 @@ from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args -from .filtering import BilateralFilter +from .filtering import BilateralFilter, PHLFilter from .simplelayers import ( LLTM, ChannelPad, diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 420851d755..83a33bc609 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -15,7 +15,7 @@ _C, _ = optional_import("monai._C") -__all__ = ["BilateralFilter"] +__all__ = ["BilateralFilter", "PHLFilter"] class BilateralFilter(torch.autograd.Function): @@ -56,3 +56,43 @@ def backward(ctx, grad_output): spatial_sigma, color_sigma, fast_approx = ctx.saved_variables grad_input = _C.bilateral_filter(grad_output, spatial_sigma, color_sigma, fast_approx) return grad_input + + +class PHLFilter(torch.autograd.Function): + """ + Filters input based on arbitrary feature vectors. Uses a permutohedral + lattice data structure to efficiently approximate n-dimensional gaussian + filtering. Complexity is broadly independant of kernel size. Most applicable + to higher filter dimensions and larger kernel sizes. + + See: + https://graphics.stanford.edu/papers/permutohedral/ + + Args: + input: input tensor to be filtered. + + features: feature tensor used to filter the input. + + sigmas: the standard deviations of each feature in the filter. + + Returns: + output (torch.Tensor): output tensor. + """ + + @staticmethod + def forward(ctx, input, features, sigmas=None): + + scaled_features = features + if sigmas is not None: + for i in range(features.size(1)): + scaled_features[:, i, ...] /= sigmas[i] + + ctx.save_for_backward(scaled_features) + output_data = _C.phl_filter(input, scaled_features) + return output_data + + @staticmethod + def backward(ctx, grad_output): + scaled_features = ctx.saved_variables + grad_input = PHLFilter.scale(grad_output, scaled_features) + return grad_input diff --git a/tests/test_phl_cpu.py b/tests/test_phl_cpu.py new file mode 100644 index 0000000000..f0e62cbddb --- /dev/null +++ b/tests/test_phl_cpu.py @@ -0,0 +1,258 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers.filtering import PHLFilter +from tests.utils import skip_if_no_cpp_extention + +TEST_CASES = [ + [ + # Case Descirption + "2 batches, 1 dimensions, 1 channels, 1 features", + # Sigmas + [1, 0.2], + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [1, 0.2, 0.5, 0, 1], + ], + # Batch 1 + [ + # Channel 0 + [0.5, 0, 1, 1, 1] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.468968, 0.364596, 0.4082, 0.332579, 0.468968] + ], + # Batch 1 + [ + # Channel 0 + [0.202473, 0.176527, 0.220995, 0.220995, 0.220995] + ], + ], + ], + [ + # Case Descirption + "1 batches, 1 dimensions, 3 channels, 1 features", + # Sigmas + [1], + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 0], + # Channel 1 + [0, 0, 0, 0, 1], + # Channel 2 + [0, 0, 1, 0, 0], + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [1, 0.2, 0.5, 0.2, 1], + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.229572, 0.182884, 0.202637, 0.182884, 0.229572], + # Channel 1 + [0.229572, 0.182884, 0.202637, 0.182884, 0.229572], + # Channel 2 + [0.201235, 0.208194, 0.205409, 0.208194, 0.201235], + ], + ], + ], + [ + # Case Descirption + "1 batches, 2 dimensions, 1 channels, 3 features", + # Sigmas + [5, 3, 3], + # Input + [ + # Batch 0 + [ + # Channel 0 + [[9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 6, 6, 6], [9, 9, 6, 6, 6]] + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [[9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 6, 6, 6], [9, 9, 6, 6, 6]], + # Channel 1 + [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]], + # Channel 2 + [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3], [4, 4, 4, 4, 4]], + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [7.696051, 7.427121, 1.191990, 1.156004, 1.157489], + [7.670297, 7.371155, 1.340232, 1.287871, 1.304018], + [7.639579, 7.365163, 1.473319, 1.397826, 1.416861], + [7.613517, 7.359183, 5.846500, 5.638952, 5.350098], + [7.598255, 7.458446, 5.912375, 5.583625, 5.233126], + ] + ], + ], + ], + [ + # Case Descirption + "1 batches, 3 dimensions, 1 channels, 1 features", + # Sigmas + [5, 3, 3], + # Input + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0]], + # Frame 1 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0]], + # Frame 2 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 3 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 4 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + ] + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0]], + # Frame 1 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0]], + # Frame 2 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 3 + [[0, 0, 5, 5, 5], [0, 0, 5, 5, 5], [0, 0, 5, 5, 5], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 4 + [[0, 0, 5, 5, 5], [0, 0, 5, 5, 5], [0, 0, 5, 5, 5], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + ] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [ + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [3.578490, 3.578490, 3.578490, 0.284234, 0.284234], + [3.578490, 3.578490, 3.578490, 0.284234, 0.284234], + [3.578490, 3.578490, 3.578490, 0.284234, 0.284234], + ], + # Frame 1 + [ + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [3.578490, 3.578490, 3.578490, 0.284234, 0.284234], + [3.578490, 3.578490, 3.578490, 0.284234, 0.284234], + [3.578490, 3.578490, 3.578490, 0.284234, 0.284234], + ], + # Frame 2 + [ + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + ], + # Frame 3 + [ + [0.284234, 0.284234, 1.359728, 1.359728, 1.359728], + [0.284234, 0.284234, 1.359728, 1.359728, 1.359728], + [0.284234, 0.284234, 1.359728, 1.359728, 1.359728], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + ], + # Frame 4 + [ + [0.284234, 0.284234, 1.359728, 1.359728, 1.359728], + [0.284234, 0.284234, 1.359728, 1.359728, 1.359728], + [0.284234, 0.284234, 1.359728, 1.359728, 1.359728], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + ], + ] + ], + ], + ], +] + + +@skip_if_no_cpp_extention +class PHLFilterTestCaseCpu(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_cpu(self, test_case_description, sigmas, input, features, expected): + + # Create input tensors + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=torch.device("cpu")) + feature_tensor = torch.from_numpy(np.array(features)).to(dtype=torch.float, device=torch.device("cpu")) + + # apply filter + output = PHLFilter.apply(input_tensor, feature_tensor, sigmas).cpu().numpy() + + # Ensure result are as expected + np.testing.assert_allclose(output, expected, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_phl_cuda.py b/tests/test_phl_cuda.py new file mode 100644 index 0000000000..8b89efce1a --- /dev/null +++ b/tests/test_phl_cuda.py @@ -0,0 +1,166 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers.filtering import PHLFilter +from tests.utils import skip_if_no_cpp_extention, skip_if_no_cuda + +TEST_CASES = [ + [ + # Case Descirption + "2 batches, 1 dimensions, 1 channels, 1 features", + # Sigmas + [1, 0.2], + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [1, 0.2, 0.5, 0, 1], + ], + # Batch 1 + [ + # Channel 0 + [0.5, 0, 1, 1, 1] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.468968, 0.364596, 0.408200, 0.332579, 0.468968] + ], + # Batch 1 + [ + # Channel 0 + [0.202473, 0.176527, 0.220995, 0.220995, 0.220995] + ], + ], + ], + [ + # Case Descirption + "1 batches, 1 dimensions, 3 channels, 1 features", + # Sigmas + [1], + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 0], + # Channel 1 + [0, 0, 0, 0, 1], + # Channel 2 + [0, 0, 1, 0, 0], + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [1, 0.2, 0.5, 0.2, 1], + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.229572, 0.182884, 0.202637, 0.182884, 0.229572], + # Channel 1 + [0.229572, 0.182884, 0.202637, 0.182884, 0.229572], + # Channel 2 + [0.201235, 0.208194, 0.205409, 0.208194, 0.201235], + ], + ], + ], + [ + # Case Descirption + "1 batches, 2 dimensions, 1 channels, 3 features", + # Sigmas + [5, 3, 3], + # Input + [ + # Batch 0 + [ + # Channel 0 + [[9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 6, 6, 6], [9, 9, 6, 6, 6]] + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [[9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 6, 6, 6], [9, 9, 6, 6, 6]], + # Channel 1 + [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]], + # Channel 2 + [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3], [4, 4, 4, 4, 4]], + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [7.792655, 7.511395, 0.953769, 0.860538, 0.912978], + [7.758870, 7.426762, 1.164386, 1.050956, 1.121830], + [7.733974, 7.429964, 1.405752, 1.244949, 1.320862], + [7.712976, 7.429060, 5.789552, 5.594258, 5.371737], + [7.701185, 7.492719, 5.860026, 5.538241, 5.281656], + ] + ], + ], + ], +] + + +@skip_if_no_cuda +@skip_if_no_cpp_extention +class PHLFilterTestCaseCuda(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_cuda(self, test_case_description, sigmas, input, features, expected): + + # Create input tensors + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=torch.device("cuda")) + feature_tensor = torch.from_numpy(np.array(features)).to(dtype=torch.float, device=torch.device("cuda")) + + # apply filter + output = PHLFilter.apply(input_tensor, feature_tensor, sigmas).cpu().numpy() + + # Ensure result are as expected + np.testing.assert_allclose(output, expected, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_savitzky_golay_filter.py b/tests/test_savitzky_golay_filter.py index d76c42c15f..9163204810 100644 --- a/tests/test_savitzky_golay_filter.py +++ b/tests/test_savitzky_golay_filter.py @@ -1,152 +1,152 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -import torch -from parameterized import parameterized - -from monai.networks.layers import SavitzkyGolayFilter -from tests.utils import skip_if_no_cuda - -# Zero-padding trivial tests - -TEST_CASE_SINGLE_VALUE = [ - {"window_length": 3, "order": 1}, - torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value - torch.Tensor([1 / 3]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 - # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) - 1e-15, # absolute tolerance -] - -TEST_CASE_1D = [ - {"window_length": 3, "order": 1}, - torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data - torch.Tensor([2 / 3, 1.0, 2 / 3]) - .unsqueeze(0) - .unsqueeze(0), # Expected output: zero padded, so linear interpolation - # over length-3 windows will result in output of [2/3, 1, 2/3]. - 1e-15, # absolute tolerance -] - -TEST_CASE_2D_AXIS_2 = [ - {"window_length": 3, "order": 1}, # along default axis (2, first spatial dim) - torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), - torch.Tensor([[2 / 3, 2 / 3], [1.0, 1.0], [2 / 3, 2 / 3]]).unsqueeze(0).unsqueeze(0), - 1e-15, # absolute tolerance -] - -TEST_CASE_2D_AXIS_3 = [ - {"window_length": 3, "order": 1, "axis": 3}, # along axis 3 (second spatial dim) - torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), - torch.Tensor([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]).unsqueeze(0).unsqueeze(0), - 1e-15, # absolute tolerance -] - -# Replicated-padding trivial tests - -TEST_CASE_SINGLE_VALUE_REP = [ - {"window_length": 3, "order": 1, "mode": "replicate"}, - torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value - torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 - # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) - 1e-15, # absolute tolerance -] - -TEST_CASE_1D_REP = [ - {"window_length": 3, "order": 1, "mode": "replicate"}, - torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data - torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Expected output: zero padded, so linear interpolation - # over length-3 windows will result in output of [2/3, 1, 2/3]. - 1e-15, # absolute tolerance -] - -TEST_CASE_2D_AXIS_2_REP = [ - {"window_length": 3, "order": 1, "mode": "replicate"}, # along default axis (2, first spatial dim) - torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), - torch.Tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]).unsqueeze(0).unsqueeze(0), - 1e-15, # absolute tolerance -] - -TEST_CASE_2D_AXIS_3_REP = [ - {"window_length": 3, "order": 1, "axis": 3, "mode": "replicate"}, # along axis 3 (second spatial dim) - torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), - torch.Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]).unsqueeze(0).unsqueeze(0), - 1e-15, # absolute tolerance -] - -# Sine smoothing - -TEST_CASE_SINE_SMOOTH = [ - {"window_length": 3, "order": 1}, - # Sine wave with period equal to savgol window length (windowed to reduce edge effects). - torch.as_tensor(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100)).unsqueeze(0).unsqueeze(0), - # Should be smoothed out to zeros - torch.zeros(100).unsqueeze(0).unsqueeze(0), - # tolerance chosen by examining output of SciPy.signal.savgol_filter when provided the above input - 2e-2, # absolute tolerance -] - - -class TestSavitzkyGolayCPU(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_SINGLE_VALUE, - TEST_CASE_1D, - TEST_CASE_2D_AXIS_2, - TEST_CASE_2D_AXIS_3, - TEST_CASE_SINE_SMOOTH, - ] - ) - def test_value(self, arguments, image, expected_data, atol): - result = SavitzkyGolayFilter(**arguments)(image) - np.testing.assert_allclose(result, expected_data, atol=atol) - - -class TestSavitzkyGolayCPUREP(unittest.TestCase): - @parameterized.expand( - [TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP] - ) - def test_value(self, arguments, image, expected_data, atol): - result = SavitzkyGolayFilter(**arguments)(image) - np.testing.assert_allclose(result, expected_data, atol=atol) - - -@skip_if_no_cuda -class TestSavitzkyGolayGPU(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_SINGLE_VALUE, - TEST_CASE_1D, - TEST_CASE_2D_AXIS_2, - TEST_CASE_2D_AXIS_3, - TEST_CASE_SINE_SMOOTH, - ] - ) - def test_value(self, arguments, image, expected_data, atol): - result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) - np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) - - -@skip_if_no_cuda -class TestSavitzkyGolayGPUREP(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_SINGLE_VALUE_REP, - TEST_CASE_1D_REP, - TEST_CASE_2D_AXIS_2_REP, - TEST_CASE_2D_AXIS_3_REP, - ] - ) - def test_value(self, arguments, image, expected_data, atol): - result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) - np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers import SavitzkyGolayFilter +from tests.utils import skip_if_no_cuda + +# Zero-padding trivial tests + +TEST_CASE_SINGLE_VALUE = [ + {"window_length": 3, "order": 1}, + torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value + torch.Tensor([1 / 3]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 + # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +TEST_CASE_1D = [ + {"window_length": 3, "order": 1}, + torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data + torch.Tensor([2 / 3, 1.0, 2 / 3]) + .unsqueeze(0) + .unsqueeze(0), # Expected output: zero padded, so linear interpolation + # over length-3 windows will result in output of [2/3, 1, 2/3]. + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2 = [ + {"window_length": 3, "order": 1}, # along default axis (2, first spatial dim) + torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[2 / 3, 2 / 3], [1.0, 1.0], [2 / 3, 2 / 3]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_3 = [ + {"window_length": 3, "order": 1, "axis": 3}, # along axis 3 (second spatial dim) + torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +# Replicated-padding trivial tests + +TEST_CASE_SINGLE_VALUE_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, + torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value + torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 + # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +TEST_CASE_1D_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, + torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data + torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Expected output: zero padded, so linear interpolation + # over length-3 windows will result in output of [2/3, 1, 2/3]. + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, # along default axis (2, first spatial dim) + torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_3_REP = [ + {"window_length": 3, "order": 1, "axis": 3, "mode": "replicate"}, # along axis 3 (second spatial dim) + torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +# Sine smoothing + +TEST_CASE_SINE_SMOOTH = [ + {"window_length": 3, "order": 1}, + # Sine wave with period equal to savgol window length (windowed to reduce edge effects). + torch.as_tensor(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100)).unsqueeze(0).unsqueeze(0), + # Should be smoothed out to zeros + torch.zeros(100).unsqueeze(0).unsqueeze(0), + # tolerance chosen by examining output of SciPy.signal.savgol_filter when provided the above input + 2e-2, # absolute tolerance +] + + +class TestSavitzkyGolayCPU(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_SINGLE_VALUE, + TEST_CASE_1D, + TEST_CASE_2D_AXIS_2, + TEST_CASE_2D_AXIS_3, + TEST_CASE_SINE_SMOOTH, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +class TestSavitzkyGolayCPUREP(unittest.TestCase): + @parameterized.expand( + [TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +@skip_if_no_cuda +class TestSavitzkyGolayGPU(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_SINGLE_VALUE, + TEST_CASE_1D, + TEST_CASE_2D_AXIS_2, + TEST_CASE_2D_AXIS_3, + TEST_CASE_SINE_SMOOTH, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) + np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) + + +@skip_if_no_cuda +class TestSavitzkyGolayGPUREP(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_SINGLE_VALUE_REP, + TEST_CASE_1D_REP, + TEST_CASE_2D_AXIS_2_REP, + TEST_CASE_2D_AXIS_3_REP, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) + np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) diff --git a/tests/test_savitzky_golay_smooth.py b/tests/test_savitzky_golay_smooth.py index 2be0da1360..63dcce1b05 100644 --- a/tests/test_savitzky_golay_smooth.py +++ b/tests/test_savitzky_golay_smooth.py @@ -1,70 +1,70 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -from parameterized import parameterized - -from monai.transforms import SavitzkyGolaySmooth - -# Zero-padding trivial tests - -TEST_CASE_SINGLE_VALUE = [ - {"window_length": 3, "order": 1}, - np.expand_dims(np.array([1.0]), 0), # Input data: Single value - np.expand_dims(np.array([1 / 3]), 0), # Expected output: With a window length of 3 and polyorder 1 - # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) - 1e-15, # absolute tolerance -] - -TEST_CASE_2D_AXIS_2 = [ - {"window_length": 3, "order": 1, "axis": 2}, # along axis 2 (second spatial dim) - np.expand_dims(np.ones((2, 3)), 0), - np.expand_dims(np.array([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]), 0), - 1e-15, # absolute tolerance -] - -# Replicated-padding trivial tests - -TEST_CASE_SINGLE_VALUE_REP = [ - {"window_length": 3, "order": 1, "mode": "replicate"}, - np.expand_dims(np.array([1.0]), 0), # Input data: Single value - np.expand_dims(np.array([1.0]), 0), # Expected output: With a window length of 3 and polyorder 1 - # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) - 1e-15, # absolute tolerance -] - -# Sine smoothing - -TEST_CASE_SINE_SMOOTH = [ - {"window_length": 3, "order": 1}, - # Sine wave with period equal to savgol window length (windowed to reduce edge effects). - np.expand_dims(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100), 0), - # Should be smoothed out to zeros - np.expand_dims(np.zeros(100), 0), - # tolerance chosen by examining output of SciPy.signal.savgol_filter() when provided the above input - 2e-2, # absolute tolerance -] - - -class TestSavitzkyGolaySmooth(unittest.TestCase): - @parameterized.expand([TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH]) - def test_value(self, arguments, image, expected_data, atol): - result = SavitzkyGolaySmooth(**arguments)(image) - np.testing.assert_allclose(result, expected_data, atol=atol) - - -class TestSavitzkyGolaySmoothREP(unittest.TestCase): - @parameterized.expand([TEST_CASE_SINGLE_VALUE_REP]) - def test_value(self, arguments, image, expected_data, atol): - result = SavitzkyGolaySmooth(**arguments)(image) - np.testing.assert_allclose(result, expected_data, atol=atol) +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import SavitzkyGolaySmooth + +# Zero-padding trivial tests + +TEST_CASE_SINGLE_VALUE = [ + {"window_length": 3, "order": 1}, + np.expand_dims(np.array([1.0]), 0), # Input data: Single value + np.expand_dims(np.array([1 / 3]), 0), # Expected output: With a window length of 3 and polyorder 1 + # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2 = [ + {"window_length": 3, "order": 1, "axis": 2}, # along axis 2 (second spatial dim) + np.expand_dims(np.ones((2, 3)), 0), + np.expand_dims(np.array([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]), 0), + 1e-15, # absolute tolerance +] + +# Replicated-padding trivial tests + +TEST_CASE_SINGLE_VALUE_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, + np.expand_dims(np.array([1.0]), 0), # Input data: Single value + np.expand_dims(np.array([1.0]), 0), # Expected output: With a window length of 3 and polyorder 1 + # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +# Sine smoothing + +TEST_CASE_SINE_SMOOTH = [ + {"window_length": 3, "order": 1}, + # Sine wave with period equal to savgol window length (windowed to reduce edge effects). + np.expand_dims(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100), 0), + # Should be smoothed out to zeros + np.expand_dims(np.zeros(100), 0), + # tolerance chosen by examining output of SciPy.signal.savgol_filter() when provided the above input + 2e-2, # absolute tolerance +] + + +class TestSavitzkyGolaySmooth(unittest.TestCase): + @parameterized.expand([TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH]) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolaySmooth(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +class TestSavitzkyGolaySmoothREP(unittest.TestCase): + @parameterized.expand([TEST_CASE_SINGLE_VALUE_REP]) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolaySmooth(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol)