diff --git a/docs/source/networks.rst b/docs/source/networks.rst
index fc16e8c86e..ed17d815b4 100644
--- a/docs/source/networks.rst
+++ b/docs/source/networks.rst
@@ -183,6 +183,11 @@ Layers
 ~~~~~~~~~~~~~~~~
 .. autoclass:: GaussianFilter
     :members:
+
+`BilateralFilter`
+~~~~~~~~~~~~~~~~~
+.. autoclass:: BilateralFilter
+    :members:
 
 `HilbertTransform`
 ~~~~~~~~~~~~~~~~~~
diff --git a/monai/csrc/ext.cpp b/monai/csrc/ext.cpp
index 5aaa2e70c9..6740d1b5b4 100644
--- a/monai/csrc/ext.cpp
+++ b/monai/csrc/ext.cpp
@@ -12,11 +12,16 @@ limitations under the License.
 */
 
 #include <torch/extension.h>
+
+#include "filtering/filtering.h"
 #include "lltm/lltm.h"
 #include "resample/pushpull.h"
 #include "utils/resample_utils.h"
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  // filtering
+  m.def("bilateral_filter", &BilateralFilter, "Bilateral Filter");
+
   // lltm
   m.def("lltm_forward", &lltm_forward, "LLTM forward");
   m.def("lltm_backward", &lltm_backward, "LLTM backward");
diff --git a/monai/csrc/filtering/bilateral/bilateral.h b/monai/csrc/filtering/bilateral/bilateral.h
new file mode 100644
index 0000000000..68f8a3093c
--- /dev/null
+++ b/monai/csrc/filtering/bilateral/bilateral.h
@@ -0,0 +1,42 @@
+/*
+Copyright 2020 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#pragma once
+
+#include <torch/extension.h>
+#include "utils/common_utils.h"
+
+torch::Tensor BilateralFilterCpu(torch::Tensor input, float spatial_sigma, float color_sigma);
+torch::Tensor BilateralFilterPHLCpu(torch::Tensor input, float spatial_sigma, float color_sigma);
+
+#ifdef WITH_CUDA
+torch::Tensor BilateralFilterCuda(torch::Tensor input, float spatial_sigma, float color_sigma);
+torch::Tensor BilateralFilterPHLCuda(torch::Tensor input, float spatial_sigma, float color_sigma);
+#endif
+
+torch::Tensor BilateralFilter(torch::Tensor input, float spatial_sigma, float color_sigma, bool usePHL) {
+  torch::Tensor (*filterFunction)(torch::Tensor, float, float);
+
+#ifdef WITH_CUDA
+  if (torch::cuda::is_available() && input.is_cuda()) {
+    CHECK_CONTIGUOUS_CUDA(input);
+    filterFunction = usePHL ? &BilateralFilterPHLCuda : &BilateralFilterCuda;
+  } else {
+    filterFunction = usePHL ? &BilateralFilterPHLCpu : &BilateralFilterCpu;
+  }
+#else
+  filterFunction = usePHL ? &BilateralFilterPHLCpu : &BilateralFilterCpu;
+#endif
+
+  return filterFunction(input, spatial_sigma, color_sigma);
+}
diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp b/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp
new file mode 100644
index 0000000000..cdce729f17
--- /dev/null
+++ b/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp
@@ -0,0 +1,167 @@
+/*
+Copyright 2020 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include <math.h>
+#include <torch/extension.h>
+
+#include "utils/tensor_description.h"
+
+struct Indexer {
+ public:
+  Indexer(int dimensions, int* sizes) {
+    m_dimensions = dimensions;
+    m_sizes = sizes;
+    m_index = new int[dimensions]{0};
+  }
+
+  bool operator++(int) {
+    for (int i = 0; i < m_dimensions; i++) {
+      m_index[i] += 1;
+
+      if (m_index[i] < m_sizes[i]) {
+        return true;
+      } else {
+        m_index[i] = 0;
+      }
+    }
+
+    return false;
+  }
+
+  int& operator[](int dimensionIndex) {
+    return m_index[dimensionIndex];
+  }
+
+ private:
+  int m_dimensions;
+  int* m_sizes;
+  int* m_index;
+};
+
+template <typename scalar_t>
+void BilateralFilterCpu(torch::Tensor inputTensor, torch::Tensor outputTensor, float spatialSigma, float colorSigma) {
+  // Getting tensor description.
+  TensorDescription desc = TensorDescription(inputTensor);
+
+  // Raw tensor data pointers.
+  scalar_t* inputTensorData = inputTensor.data_ptr<scalar_t>();
+  scalar_t* outputTensorData = outputTensor.data_ptr<scalar_t>();
+
+  // Pre-calculate common values
+  int windowSize = (int)ceil(5.0f * spatialSigma) | 1; // ORing last bit to ensure odd window size
+  int halfWindowSize = floor(0.5f * windowSize);
+  scalar_t spatialExpConstant = -1.0f / (2 * spatialSigma * spatialSigma);
+  scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma);
+
+  // Kernel sizes.
+  int* kernelSizes = new int[desc.dimensions];
+
+  for (int i = 0; i < desc.dimensions; i++) {
+    kernelSizes[i] = windowSize;
+  }
+
+  // Pre-calculate gaussian kernel in 1D.
+  scalar_t* gaussianKernel = new scalar_t[windowSize];
+
+  for (int i = 0; i < windowSize; i++) {
+    int distance = i - halfWindowSize;
+    gaussianKernel[i] = exp(distance * distance * spatialExpConstant);
+  }
+
+  // Kernel aggregates used to calculate
+  // the output value.
+  scalar_t* valueSum = new scalar_t[desc.channelCount];
+  scalar_t weightSum = 0;
+
+  // Looping over the batches
+  for (int b = 0; b < desc.batchCount; b++) {
+    int batchOffset = b * desc.batchStride;
+
+    // Looping over all dimensions for the home element
+    Indexer homeIndex = Indexer(desc.dimensions, desc.sizes);
+    do // while(homeIndex++)
+    {
+      // Calculating indexing offset for the home element
+      int homeOffset = batchOffset;
+
+      for (int i = 0; i < desc.dimensions; i++) {
+        homeOffset += homeIndex[i] * desc.strides[i];
+      }
+
+      // Zero kernel aggregates.
+      for (int i = 0; i < desc.channelCount; i++) {
+        valueSum[i] = 0;
+      }
+
+      weightSum = 0.0f;
+
+      // Looping over all dimensions for the neighbour element
+      Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes);
+      do // while(kernelIndex++)
+      {
+        // Calculating buffer offset for the neighbour element
+        // Index is clamped to the border in each dimension.
+        int neighbourOffset = batchOffset;
+
+        for (int i = 0; i < desc.dimensions; i++) {
+          int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize;
+          int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex));
+          neighbourOffset += neighbourIndexClamped * desc.strides[i];
+        }
+
+        // Euclidean color distance.
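+        // (squared L2 distance between the home and neighbour intensities,
+        // summed over channels; it drives the range weight
+        // exp(colorDistanceSquared * colorExpConstant) computed below)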
+ scalar_t colorDistanceSquared = 0; + + for (int i = 0; i < desc.channelCount; i++) { + scalar_t diff = inputTensorData[homeOffset + i * desc.channelStride] - + inputTensorData[neighbourOffset + i * desc.channelStride]; + colorDistanceSquared += diff * diff; + } + + // Calculating and combining the spatial + // and color weights. + scalar_t spatialWeight = 1; + + for (int i = 0; i < desc.dimensions; i++) { + spatialWeight *= gaussianKernel[kernelIndex[i]]; + } + + scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant); + scalar_t totalWeight = spatialWeight * colorWeight; + + // Aggregating values. + for (int i = 0; i < desc.channelCount; i++) { + valueSum[i] += inputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight; + } + + weightSum += totalWeight; + } while (kernelIndex++); + + for (int i = 0; i < desc.channelCount; i++) { + outputTensorData[homeOffset + i * desc.channelStride] = valueSum[i] / weightSum; + } + } while (homeIndex++); + } +} + +torch::Tensor BilateralFilterCpu(torch::Tensor inputTensor, float spatialSigma, float colorSigma) { + // Preparing output tensor. + torch::Tensor outputTensor = torch::zeros_like(inputTensor); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputTensor.type(), "BilateralFilterCpu", ([&] { + BilateralFilterCpu( + inputTensor, outputTensor, spatialSigma, colorSigma); + })); + + return outputTensor; +} \ No newline at end of file diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp new file mode 100644 index 0000000000..eb94749ea5 --- /dev/null +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp @@ -0,0 +1,89 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include + +#include "filtering/permutohedral/permutohedral.h" +#include "utils/tensor_description.h" + +template +void BilateralFilterPHLCpu( + torch::Tensor inputTensor, + torch::Tensor outputTensor, + float spatialSigma, + float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + int featureChannels = desc.channelCount + desc.dimensions; + + // Preparing memory + scalar_t* inputTensorData = inputTensor.data_ptr(); + scalar_t* outputTensorData = outputTensor.data_ptr(); + scalar_t* data = new scalar_t[desc.channelStride * desc.channelCount]; + scalar_t* features = new scalar_t[desc.channelStride * featureChannels]; + + // Precalculating inverse sigmas + float invSpatialSigma = 1.0f / spatialSigma; + float invColorSigma = 1.0f / colorSigma; + + // Looping over batches + for (int b = 0; b < desc.batchCount; b++) { + int batchOffset = b * desc.batchStride; + + // Creating features (also permuting input data to be channel last. 
+    // Permutohedral implementation should be changed to channel first to avoid this)
+    for (int i = 0; i < desc.channelStride; i++) {
+      // Color features (and permutation)
+      for (int c = 0; c < desc.channelCount; c++) {
+        features[i * featureChannels + c] = invColorSigma * inputTensorData[batchOffset + i + c * desc.channelStride];
+        data[i * desc.channelCount + c] = inputTensorData[batchOffset + i + c * desc.channelStride];
+      }
+
+      // Spatial features
+      int offsetRemainder = i;
+
+      for (int d = 0; d < desc.dimensions; d++) {
+        int coord = offsetRemainder / desc.strides[d];
+        offsetRemainder -= coord * desc.strides[d];
+
+        features[i * featureChannels + desc.channelCount + d] = invSpatialSigma * coord;
+      }
+    }
+
+    // Filtering data with respect to the features.
+    scalar_t* output =
+        PermutohedralCPU<scalar_t>(data, features, desc.channelCount, featureChannels, desc.channelStride);
+
+    // Writing output tensor.
+    for (int i = 0; i < desc.channelStride; i++) {
+      for (int c = 0; c < desc.channelCount; c++) {
+        outputTensorData[batchOffset + i + c * desc.channelStride] = output[i * desc.channelCount + c];
+      }
+    }
+
+    // PermutohedralCPU allocates its result with new[]; free it each batch iteration.
+    delete[] output;
+  }
+
+  delete[] data;
+  delete[] features;
+}
+
+// Function to choose template implementation based on the dynamic datatype, channels and dimensions
+torch::Tensor BilateralFilterPHLCpu(torch::Tensor inputTensor, float spatialSigma, float colorSigma) {
+  torch::Tensor outputTensor = torch::zeros_like(inputTensor);
+
+  AT_DISPATCH_FLOATING_TYPES(inputTensor.type(), "BilateralFilterPhlCpu", ([&] {
+                               BilateralFilterPHLCpu<scalar_t>(inputTensor, outputTensor, spatialSigma, colorSigma);
+                             }));
+
+  return outputTensor;
+}
diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu
new file mode 100644
index 0000000000..872ff652cb
--- /dev/null
+++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu
@@ -0,0 +1,245 @@
+/*
+Copyright 2020 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +#include +#include +#include + +#include "utils/meta_macros.h" +#include "utils/tensor_description.h" + +__constant__ int cBatchStride; +__constant__ int cColorStride; + +__constant__ int cSizes[3]; +__constant__ int cStrides[3]; + +__constant__ int cKernelSize; +__constant__ float cKernel[256]; + +__constant__ float cColorExponentFactor; + +template +__global__ void BilateralFilterCudaKernel1D(scalar_t* input, scalar_t* output) { + int kernelHalfSize = cKernelSize / 2; + + int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; + int batchOffset = blockIdx.y * cBatchStride; + + scalar_t weightSum = 0; + + for (int kernelOffset = 0; kernelOffset < cKernelSize; kernelOffset++) { + int neighbourOffset = max(0, min(homeOffset + (kernelOffset - kernelHalfSize), cSizes[0] - 1)); + scalar_t gaussian = cKernel[kernelOffset]; + + scalar_t distanceSquared = 0; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + homeOffset + c * cColorStride]; + scalar_t b = input[batchOffset + neighbourOffset + c * cColorStride]; + scalar_t diff = a - b; + distanceSquared += diff * diff; + } + + scalar_t spatialWeight = gaussian; + scalar_t colorWeight = exp(cColorExponentFactor * distanceSquared); + scalar_t totalWeight = spatialWeight * colorWeight; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + neighbourOffset + c * cColorStride]; + + output[batchOffset + homeOffset + c * cColorStride] += a * totalWeight; + } + + weightSum += totalWeight; + } + +#pragma unroll + for (int c = 0; c < C; c++) { + output[batchOffset + homeOffset + c * cColorStride] /= weightSum; + } +} + +template +__global__ void BilateralFilterCudaKernel2D(scalar_t* input, scalar_t* output) { + int kernelHalfSize = cKernelSize / 2; + + int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; + int batchOffset = blockIdx.y * cBatchStride; + + int homeX = homeOffset / cStrides[0]; + int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1]; + + scalar_t weightSum = 0; + + for (int kernelX = 0; kernelX < cKernelSize; kernelX++) { + int neighbourX = max(0, min(homeX + (kernelX - kernelHalfSize), cSizes[0] - 1)); + scalar_t gaussianX = cKernel[kernelX]; + + for (int kernelY = 0; kernelY < cKernelSize; kernelY++) { + int neighbourY = max(0, min(homeY + (kernelY - kernelHalfSize), cSizes[1] - 1)); + scalar_t gaussianY = cKernel[kernelY]; + + int neighbourOffset = neighbourX * cStrides[0] + neighbourY; + + scalar_t distanceSquared = 0; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + homeOffset + c * cColorStride]; + scalar_t b = input[batchOffset + neighbourOffset + c * cColorStride]; + scalar_t diff = a - b; + distanceSquared += diff * diff; + } + + scalar_t spatialWeight = gaussianX * gaussianY; + scalar_t colorWeight = exp(cColorExponentFactor * distanceSquared); + scalar_t totalWeight = spatialWeight * colorWeight; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + neighbourOffset + c * cColorStride]; + + output[batchOffset + homeOffset + c * cColorStride] += a * totalWeight; + } + + weightSum += totalWeight; + } + } + +#pragma unroll + for (int c = 0; c < C; c++) { + output[batchOffset + homeOffset + c * cColorStride] /= weightSum; + } +} + +template +__global__ void BilateralFilterCudaKernel3D(scalar_t* input, scalar_t* output) { + int kernelHalfSize = cKernelSize / 2; + + int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; + int batchOffset = blockIdx.y * cBatchStride; + + int homeX = homeOffset 
/ cStrides[0]; + int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1]; + int homeZ = (homeOffset - homeX * cStrides[0] - homeY * cStrides[1]) / cStrides[2]; + + scalar_t weightSum = 0; + + for (int kernelX = 0; kernelX < cKernelSize; kernelX++) { + int neighbourX = max(0, min(homeX + (kernelX - kernelHalfSize), cSizes[0] - 1)); + scalar_t gaussianX = cKernel[kernelX]; + + for (int kernelY = 0; kernelY < cKernelSize; kernelY++) { + int neighbourY = max(0, min(homeY + (kernelY - kernelHalfSize), cSizes[1] - 1)); + scalar_t gaussianY = cKernel[kernelY]; + + for (int kernelZ = 0; kernelZ < cKernelSize; kernelZ++) { + int neighbourZ = max(0, min(homeZ + (kernelZ - kernelHalfSize), cSizes[2] - 1)); + scalar_t gaussianZ = cKernel[kernelZ]; + + int neighbourOffset = neighbourX * cStrides[0] + neighbourY * cStrides[1] + neighbourZ; + + scalar_t distanceSquared = 0; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + homeOffset + c * cColorStride]; + scalar_t b = input[batchOffset + neighbourOffset + c * cColorStride]; + scalar_t diff = a - b; + distanceSquared += diff * diff; + } + + scalar_t spatialWeight = gaussianX * gaussianY * gaussianZ; + scalar_t colorWeight = exp(cColorExponentFactor * distanceSquared); + scalar_t totalWeight = spatialWeight * colorWeight; + +#pragma unroll + for (int c = 0; c < C; c++) { + scalar_t a = input[batchOffset + neighbourOffset + c * cColorStride]; + output[batchOffset + homeOffset + c * cColorStride] += a * totalWeight; + } + + weightSum += totalWeight; + } + } + } + +#pragma unroll + for (int c = 0; c < C; c++) { + output[batchOffset + homeOffset + c * cColorStride] /= weightSum; + } +} + +template +void BilateralFilterCuda(torch::Tensor inputTensor, torch::Tensor outputTensor, float spatialSigma, float colorSigma) { + // Getting tensor description. + TensorDescription desc = TensorDescription(inputTensor); + + // Pre-calculating exponent factors. + float spatialExponentFactor = -1.0f / (2 * spatialSigma * spatialSigma); + float colorExponentFactor = -1.0f / (2 * colorSigma * colorSigma); + + // Pre-calculating gaussian kernel. + int kernelSize = (int)ceil(5.0f * spatialSigma) | 1; // ORing last bit to ensure odd window size + int kernelHalfSize = floor(0.5f * kernelSize); + float* kernel = new float[kernelSize]; + + for (int i = 0; i < kernelSize; i++) { + int distance = i - kernelHalfSize; + kernel[i] = exp(distance * distance * spatialExponentFactor); + } + + // Writing constant memory. + cudaMemcpyToSymbol(cBatchStride, &desc.batchStride, sizeof(int)); + cudaMemcpyToSymbol(cColorStride, &desc.channelStride, sizeof(int)); + cudaMemcpyToSymbol(cSizes, desc.sizes, sizeof(int) * D); + cudaMemcpyToSymbol(cStrides, desc.strides, sizeof(int) * D); + cudaMemcpyToSymbol(cKernelSize, &kernelSize, sizeof(int)); + cudaMemcpyToSymbol(cKernel, kernel, sizeof(float) * kernelSize); + cudaMemcpyToSymbol(cColorExponentFactor, &colorExponentFactor, sizeof(float)); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + inputTensor.type(), "BilateralFilterCudaKernel", ([&] { + // Dispatch kernel. 
+        // (Partial template function specialisation not supported at present so using this switch
+        // instead)
+        switch (D) {
+          case (1):
+            BilateralFilterCudaKernel1D<scalar_t, C><<>>(
+                inputTensor.data_ptr<scalar_t>(), outputTensor.data_ptr<scalar_t>());
+            break;
+          case (2):
+            BilateralFilterCudaKernel2D<scalar_t, C><<>>(
+                inputTensor.data_ptr<scalar_t>(), outputTensor.data_ptr<scalar_t>());
+            break;
+          case (3):
+            BilateralFilterCudaKernel3D<scalar_t, C><<>>(
+                inputTensor.data_ptr<scalar_t>(), outputTensor.data_ptr<scalar_t>());
+            break;
+        }
+      }));
+
+  delete[] kernel;
+}
+
+// Function to choose template implementation based on the dynamic datatype, channels and dimensions
+torch::Tensor BilateralFilterCuda(torch::Tensor inputTensor, float spatialSigma, float colorSigma) {
+  torch::Tensor outputTensor = torch::zeros_like(inputTensor);
+
+#define CASE(c, d) BilateralFilterCuda<c, d>(inputTensor, outputTensor, spatialSigma, colorSigma);
+  SWITCH_AB(CASE, 16, 3, inputTensor.size(1), inputTensor.dim() - 2);
+
+  return outputTensor;
+}
diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu
new file mode 100644
index 0000000000..df4ed8771b
--- /dev/null
+++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu
@@ -0,0 +1,130 @@
+/*
+Copyright 2020 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <torch/extension.h>
+
+#include "filtering/permutohedral/permutohedral.h"
+#include "utils/meta_macros.h"
+#include "utils/tensor_description.h"
+
+__constant__ int cBatchStride;
+__constant__ int cChannelStride;
+__constant__ int cSpatialStrides[3];
+__constant__ float cInvSpatialSigma;
+__constant__ float cInvColorSigma;
+
+template <typename scalar_t, int C, int D>
+__global__ void FeatureCreation(const scalar_t* inputTensor, scalar_t* outputData, scalar_t* outputFeatures) {
+  int elementIndex = blockIdx.x * blockDim.x + threadIdx.x;
+  int batchIndex = blockIdx.y;
+
+  int dataBatchOffset = batchIndex * cBatchStride;
+  int featureBatchOffset = batchIndex * (D + C) * cChannelStride;
+
+#pragma unroll
+  for (int i = 0; i < C; i++) {
+    outputData[dataBatchOffset + elementIndex * C + i] =
+        inputTensor[dataBatchOffset + elementIndex + i * cChannelStride];
+    outputFeatures[featureBatchOffset + elementIndex * (C + D) + i] =
+        inputTensor[dataBatchOffset + elementIndex + i * cChannelStride] * cInvColorSigma;
+  }
+
+  int remainder = elementIndex;
+
+#pragma unroll
+  for (int i = 0; i < D; i++) {
+    int coord = remainder / cSpatialStrides[i];
+    remainder -= coord * cSpatialStrides[i];
+
+    outputFeatures[featureBatchOffset + elementIndex * (C + D) + C + i] = coord * cInvSpatialSigma;
+  }
+}
+
+template <typename scalar_t, int C>
+__global__ void WriteOutput(const scalar_t* data, scalar_t* outputTensor) {
+  int elementIndex = blockIdx.x * blockDim.x + threadIdx.x;
+  int batchIndex = blockIdx.y;
+  int batchOffset = batchIndex * cBatchStride;
+
+#pragma unroll
+  for (int i = 0; i < C; i++) {
+    outputTensor[batchOffset + elementIndex + i * cChannelStride] = data[batchOffset + elementIndex * C + i];
+  }
+}
+
+template <typename scalar_t, int C, int D>
+void BilateralFilterPHLCuda(
+    torch::Tensor inputTensor,
+    torch::Tensor outputTensor,
+    float spatialSigma,
+    float colorSigma) {
+  // Getting tensor description.
+  TensorDescription desc = TensorDescription(inputTensor);
+
+  int featureChannelCount = desc.channelCount + desc.dimensions;
+
+  // Pre-calculating inverse sigmas.
+  float invSpatialSigma = 1.0f / spatialSigma;
+  float invColorSigma = 1.0f / colorSigma;
+
+  // Preparing global memory
+  scalar_t* inputTensorData = inputTensor.data_ptr<scalar_t>();
+  scalar_t* outputTensorData = outputTensor.data_ptr<scalar_t>();
+
+  scalar_t* data;
+  scalar_t* features;
+  cudaMalloc(&data, desc.batchCount * desc.channelStride * desc.channelCount * sizeof(scalar_t));
+  cudaMalloc(&features, desc.batchCount * desc.channelStride * featureChannelCount * sizeof(scalar_t));
+
+  // Preparing constant memory
+  cudaMemcpyToSymbol(cBatchStride, &desc.batchStride, sizeof(int));
+  cudaMemcpyToSymbol(cChannelStride, &desc.channelStride, sizeof(int));
+  cudaMemcpyToSymbol(cSpatialStrides, desc.strides, sizeof(int) * desc.dimensions);
+  cudaMemcpyToSymbol(cInvSpatialSigma, &invSpatialSigma, sizeof(float));
+  cudaMemcpyToSymbol(cInvColorSigma, &invColorSigma, sizeof(float));
+
+  // Creating features
+  FeatureCreation<scalar_t, C, D>
+      <<>>(inputTensorData, data, features);
+
+  // Filtering data with respect to the features for each sample in batch
+  for (int batchIndex = 0; batchIndex < desc.batchCount; batchIndex++) {
+    scalar_t* offsetData = data + batchIndex * desc.batchStride;
+    scalar_t* offsetFeatures = features + batchIndex * featureChannelCount * desc.channelStride;
+
+    PermutohedralCuda<scalar_t, C + D, C>(offsetData, offsetFeatures, desc.channelStride, true);
+  }
+
+  // Writing output
+  WriteOutput<scalar_t, C><<>>(data, outputTensorData);
+
+  cudaFree(data);
+  cudaFree(features);
+}
+
+// Function to choose template implementation based on the dynamic datatype, channels and dimensions
+torch::Tensor BilateralFilterPHLCuda(torch::Tensor inputTensor, float spatialSigma, float colorSigma) {
+  torch::Tensor outputTensor = torch::zeros_like(inputTensor);
+
+#define CASE(c, d)                                                                       \
+  AT_DISPATCH_FLOATING_TYPES(inputTensor.type(), "BilateralFilterCudaPHL", ([&] {        \
+                               BilateralFilterPHLCuda<scalar_t, c, d>(                   \
+                                   inputTensor, outputTensor, spatialSigma, colorSigma); \
+                             }));
+
+  SWITCH_AB(CASE, 16, 3, inputTensor.size(1), inputTensor.dim() - 2);
+
+  return outputTensor;
+}
diff --git a/monai/csrc/filtering/filtering.h b/monai/csrc/filtering/filtering.h
new file mode 100644
index 0000000000..18cf2ae6f4
--- /dev/null
+++ b/monai/csrc/filtering/filtering.h
@@ -0,0 +1,16 @@
+/*
+Copyright 2020 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#pragma once
+
+#include "bilateral/bilateral.h"
\ No newline at end of file
diff --git a/monai/csrc/filtering/permutohedral/hash_table.cu b/monai/csrc/filtering/permutohedral/hash_table.cu
new file mode 100644
index 0000000000..cdda0b4fed
--- /dev/null
+++ b/monai/csrc/filtering/permutohedral/hash_table.cu
@@ -0,0 +1,255 @@
+/*
+Copyright 2020 MONAI Consortium
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include + +//#define USE_ADDITIVE_HASH + +// turn this on if you want to get slighly less memory consumption and slightly longer run times. +//#define LINEAR_D_MEMORY + +#define USE_CUSTOM_MODULO + +__device__ __constant__ signed short* table_keys; +__device__ __constant__ int* table_entries; +__device__ __constant__ unsigned int table_capacity; +__device__ __constant__ signed short* table_zeros; +__device__ __constant__ char* table_rank; + +/*************************************************************/ +/* Fast computation of modulo operator with constant divisor */ +/*************************************************************/ +__device__ __constant__ unsigned int __div_m; +__device__ __constant__ unsigned int __div_l; +__device__ __constant__ unsigned int __div_c; + +#ifdef USE_CUSTOM_MODULO +__device__ inline unsigned int modHash(unsigned int n) { + unsigned int t1 = __umulhi(__div_m, n); + return n - ((t1 + ((n - t1) >> 1)) >> (__div_l - 1)) * __div_c; +} + +#else +#define modHash(n) ((n) % (2 * table_capacity)); +#endif + +/*************************************************************/ +/* End modulo */ +/*************************************************************/ + +__device__ __constant__ static unsigned int hOffset[64]; + +template +static scalar_t* createHashTable(int capacity) { + scalar_t* values; + cudaMalloc(&values, capacity * vd * sizeof(scalar_t)); + cudaMemset(values, 0, capacity * vd * sizeof(scalar_t)); + + int* entries; + cudaMalloc(&entries, capacity * 2 * sizeof(int)); + cudaMemset(entries, -1, capacity * 2 * sizeof(int)); + + cudaMemcpyToSymbol(table_capacity, &capacity, sizeof(int)); + + cudaMemcpyToSymbol(table_entries, &entries, sizeof(int*)); + +#ifdef LINEAR_D_MEMORY + + char* ranks; + cudaMalloc(&ranks, capacity * sizeof(char)); + + signed short* zeros; + cudaMalloc(&zeros, capacity * sizeof(signed short)); + + cudaMemcpyToSymbol(table_rank, &ranks, sizeof(char*)); + cudaMemcpyToSymbol(table_zeros, &zeros, sizeof(char*)); + +#else + + signed short* keys; + cudaMalloc(&keys, capacity * kd * sizeof(signed short)); + cudaMemset(keys, 0, capacity * kd * sizeof(signed short)); + + cudaMemcpyToSymbol(table_keys, &keys, sizeof(unsigned int*)); + +#endif + + return values; +} + +template +static void destroyHashTable() { +#ifndef LINEAR_D_MEMORY + cudaFree(table_keys); +#endif + cudaFree(table_entries); +} + +template +__device__ __host__ static unsigned int hash(signed short* key) { + unsigned int k = 0; + for (int i = 0; i < kd; i++) { + k += key[i]; + k = k * 2531011; + } + return k; +} + +template +__device__ __host__ static unsigned int hash(int* key) { + unsigned int k = 0; + for (int i = 0; i < kd; i++) { + k += key[i]; + k = k * 2531011; + } + return k; +} + +template +__device__ static bool matchKey(int idx, signed short* key) { + bool match = true; + int slot = idx / (d + 1), color = idx - slot * (d + 1); + char* rank = table_rank + slot * (d + 1); + signed short* zero = table_zeros + slot * (d + 1); + + for (int i = 0; i < d && match; i++) { + match = (key[i] == zero[i] + color - (rank[i] > d - color ? 
(d + 1) : 0)); + } + + return match; +} + +template +__device__ static void generateKey(int idx, signed short* key) { + int slot = idx / (d + 1), color = idx - slot * (d + 1); + char* rank = table_rank + slot * (d + 1); + signed short* zero = table_zeros + slot * (d + 1); + + for (int i = 0; i < d; i++) { + key[i] = zero[i] + color - (rank[i] > d - color ? (d + 1) : 0); + } +} + +template +__device__ static int hashTableInsert(unsigned int fh, signed short* key, unsigned int slot) { + int h = modHash(fh); + while (1) { + int* e = &table_entries[h]; + + // If the cell is empty (-1), lock it (-2) + int contents = atomicCAS(e, -1, -2); + + if (contents == -2) { + // If it was locked already, move on to the next cell + } else if (contents == -1) { + // If it was empty, we successfully locked it. Write our key. + +#ifndef LINEAR_D_MEMORY + for (int i = 0; i < kd; i++) { + table_keys[slot * kd + i] = key[i]; + } +#endif + + // Unlock + atomicExch(e, slot); + + return h; + } else { +// The cell is unlocked and has a key in it, check if it matches +#ifdef LINEAR_D_MEMORY + if (matchKey(contents, key)) + return h; +#else + bool match = true; + + for (int i = 0; i < kd && match; i++) { + match = (table_keys[contents * kd + i] == key[i]); + } + + if (match) + return h; +#endif + } + // increment the bucket with wraparound + h++; + + if (h == table_capacity * 2) + h = 0; + } +} + +template +__device__ static int hashTableInsert(signed short* key, unsigned int slot) { + unsigned int myHash = hash(key); + return hashTableInsert(myHash, key, slot); +} + +template +__device__ static int hashTableRetrieveWithHash(unsigned int fh, signed short* key) { + int h = modHash(fh); + while (1) { + int* e = table_entries + h; + + if (*e == -1) + return -1; + +#ifdef LINEAR_D_MEMORY + if (matchKey((*e), key)) + return *e; +#else + bool match = true; + + for (int i = 0; i < kd && match; i++) { + match = (table_keys[(*e) * kd + i] == key[i]); + } + + if (match) + return *e; +#endif + + h++; + + if (h == table_capacity * 2) + h = 0; + } +} + +template +__device__ static int hashTableRetrieve(signed short* key) { + int h = modHash(hash(key)); + while (1) { + int* e = table_entries + h; + + if (*e == -1) + return -1; + +#ifdef LINEAR_D_MEMORY + if (matchKey((*e), key)) + return *e; +#else + bool match = true; + + for (int i = 0; i < kd && match; i++) { + match = (table_keys[(*e) * kd + i] == key[i]); + } + + if (match) + return *e; +#endif + + h++; + + if (h == table_capacity * 2) + h = 0; + } +} \ No newline at end of file diff --git a/monai/csrc/filtering/permutohedral/permutohedral.h b/monai/csrc/filtering/permutohedral/permutohedral.h new file mode 100644 index 0000000000..7f57c91a78 --- /dev/null +++ b/monai/csrc/filtering/permutohedral/permutohedral.h @@ -0,0 +1,20 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#pragma once +template +scalar_t* PermutohedralCPU(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount); +#ifdef WITH_CUDA +template +void PermutohedralCuda(scalar_t* data, scalar_t* features, int elementCount, bool accurate); +#endif diff --git a/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp b/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp new file mode 100644 index 0000000000..597bf263c1 --- /dev/null +++ b/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp @@ -0,0 +1,516 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +/* +Adapted from https://github.com/abadams/permutohedral +which has the following license... + +MIT License + +Copyright (c) 2020 Andrew Adams + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +#include +#include + +#include + +using namespace std; + +/***************************************************************/ +/* Hash table implementation for permutohedral lattice + * + * The lattice points are stored sparsely using a hash table. + * The key for each point is its spatial location in the (d+1)- + * dimensional space. + */ +/***************************************************************/ +template +class HashTablePermutohedral { + public: + /* Constructor + * kd_: the dimensionality of the position vectors on the hyperplane. + * vd_: the dimensionality of the value vectors + */ + HashTablePermutohedral(int kd_, int vd_) : kd(kd_), vd(vd_) { + capacity = 1 << 15; + filled = 0; + entries = new Entry[capacity]; + keys = new short[kd * capacity / 2]; + values = new scalar_t[vd * capacity / 2]; + memset(values, 0, sizeof(scalar_t) * vd * capacity / 2); + } + + // Returns the number of vectors stored. + int size() { + return filled; + } + + // Returns a pointer to the keys array. + short* getKeys() { + return keys; + } + + // Returns a pointer to the values array. + scalar_t* getValues() { + return values; + } + + /* Returns the index into the hash table for a given key. 
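+   * Collisions are resolved by open addressing with linear probing: on a
+   * mismatch the index advances, with wraparound, until a matching or empty
+   * slot is found.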
+ * key: a pointer to the position vector. + * h: hash of the position vector. + * create: a flag specifying whether an entry should be created, + * should an entry with the given key not found. + */ + int lookupOffset(short* key, size_t h, bool create = true) { + // Double hash table size if necessary + if (filled >= (capacity / 2) - 1) { + grow(); + } + + // Find the entry with the given key + while (1) { + Entry e = entries[h]; + // check if the cell is empty + if (e.keyIdx == -1) { + if (!create) + return -1; // Return not found. + // need to create an entry. Store the given key. + for (int i = 0; i < kd; i++) + keys[filled * kd + i] = key[i]; + e.keyIdx = filled * kd; + e.valueIdx = filled * vd; + entries[h] = e; + filled++; + return e.valueIdx; + } + + // check if the cell has a matching key + bool match = true; + for (int i = 0; i < kd && match; i++) + match = keys[e.keyIdx + i] == key[i]; + if (match) + return e.valueIdx; + + // increment the bucket with wraparound + h++; + if (h == capacity) + h = 0; + } + } + + /* Looks up the value vector associated with a given key vector. + * k : pointer to the key vector to be looked up. + * create : true if a non-existing key should be created. + */ + scalar_t* lookup(short* k, bool create = true) { + size_t h = hash(k) % capacity; + int offset = lookupOffset(k, h, create); + if (offset < 0) + return NULL; + else + return values + offset; + }; + + /* Hash function used in this implementation. A simple base conversion. */ + size_t hash(const short* key) { + size_t k = 0; + for (int i = 0; i < kd; i++) { + k += key[i]; + k *= 2531011; + } + return k; + } + + private: + /* Grows the size of the hash table */ + void grow() { + size_t oldCapacity = capacity; + capacity *= 2; + + // Migrate the value vectors. + scalar_t* newValues = new scalar_t[vd * capacity / 2]; + memset(newValues, 0, sizeof(scalar_t) * vd * capacity / 2); + memcpy(newValues, values, sizeof(scalar_t) * vd * filled); + delete[] values; + values = newValues; + + // Migrate the key vectors. + short* newKeys = new short[kd * capacity / 2]; + memcpy(newKeys, keys, sizeof(short) * kd * filled); + delete[] keys; + keys = newKeys; + + Entry* newEntries = new Entry[capacity]; + + // Migrate the table of indices. + for (size_t i = 0; i < oldCapacity; i++) { + if (entries[i].keyIdx == -1) + continue; + size_t h = hash(keys + entries[i].keyIdx) % capacity; + while (newEntries[h].keyIdx != -1) { + h++; + if (h == capacity) + h = 0; + } + newEntries[h] = entries[i]; + } + delete[] entries; + entries = newEntries; + } + + // Private struct for the hash table entries. + struct Entry { + Entry() : keyIdx(-1), valueIdx(-1) {} + int keyIdx; + int valueIdx; + }; + + short* keys; + scalar_t* values; + Entry* entries; + size_t capacity, filled; + int kd, vd; +}; + +/***************************************************************/ +/* The algorithm class that performs the filter + * + * PermutohedralLattice::filter(...) does all the work. + * + */ +/***************************************************************/ +template +class PermutohedralLattice { + public: + /* Filters given image against a reference image. + * im : image to be bilateral-filtered. + * ref : reference image whose edges are to be respected. 
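+   * (In this adaptation the image/reference pair is generalised: `data` holds
+   * the flattened values to be filtered, `features` holds the guidance
+   * vectors, and the channel and element counts are passed explicitly.)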
+ */ + static scalar_t* filter(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount) { + // Create lattice + PermutohedralLattice lattice(featureChannels, dataChannels + 1, elementCount); + + // Splat into the lattice + scalar_t* col = new scalar_t[dataChannels + 1]; + col[dataChannels] = 1; // homogeneous coordinate + + for (int i = 0, e = 0; e < elementCount; e++) { + for (int c = 0; c < dataChannels; c++, i++) { + col[c] = data[i]; + } + + scalar_t* featureVec = features + e * featureChannels; + lattice.splat(featureVec, col); + } + + // Blur the lattice + lattice.blur(); + + // Slice from the lattice + scalar_t* outputData = new scalar_t[elementCount * dataChannels]; + + lattice.beginSlice(); + + for (int i = 0, e = 0; e < elementCount; e++) { + lattice.slice(col); + + scalar_t scale = 1.0f / col[dataChannels]; + for (int c = 0; c < dataChannels; c++, i++) { + outputData[i] = col[c] * scale; + } + } + + return outputData; + } + + /* Constructor + * d_ : dimensionality of key vectors + * vd_ : dimensionality of value vectors + * nData_ : number of points in the input + */ + PermutohedralLattice(int d_, int vd_, int nData_) : d(d_), vd(vd_), nData(nData_), hashTable(d_, vd_) { + // Allocate storage for various arrays + elevated = new scalar_t[d + 1]; + scaleFactor = new scalar_t[d]; + + greedy = new short[d + 1]; + rank = new char[d + 1]; + barycentric = new scalar_t[d + 2]; + replay = new ReplayEntry[nData * (d + 1)]; + nReplay = 0; + canonical = new short[(d + 1) * (d + 1)]; + key = new short[d + 1]; + + // compute the coordinates of the canonical simplex, in which + // the difference between a contained point and the zero + // remainder vertex is always in ascending order. (See pg.4 of paper.) + for (int i = 0; i <= d; i++) { + for (int j = 0; j <= d - i; j++) + canonical[i * (d + 1) + j] = i; + for (int j = d - i + 1; j <= d; j++) + canonical[i * (d + 1) + j] = i - (d + 1); + } + + // Compute parts of the rotation matrix E. (See pg.4-5 of paper.) + for (int i = 0; i < d; i++) { + // the diagonal entries for normalization + scaleFactor[i] = 1.0f / (sqrtf((scalar_t)(i + 1) * (i + 2))); + + /* We presume that the user would like to do a Gaussian blur of standard deviation + * 1 in each dimension (or a total variance of d, summed over dimensions.) + * Because the total variance of the blur performed by this algorithm is not d, + * we must scale the space to offset this. + * + * The total variance of the algorithm is (See pg.6 and 10 of paper): + * [variance of splatting] + [variance of blurring] + [variance of splatting] + * = d(d+1)(d+1)/12 + d(d+1)(d+1)/2 + d(d+1)(d+1)/12 + * = 2d(d+1)(d+1)/3. + * + * So we need to scale the space by (d+1)sqrt(2/3). 
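+     * (Scaling distances by a factor s divides a variance by s^2, so with
+     * s = (d+1)sqrt(2/3) the total variance 2d(d+1)(d+1)/3 above is reduced
+     * to d, i.e. unit variance per dimension, as intended.)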
+ */ + scaleFactor[i] *= (d + 1) * sqrtf(2.0 / 3); + } + } + + /* Performs splatting with given position and value vectors */ + void splat(scalar_t* position, scalar_t* value) { + // first rotate position into the (d+1)-dimensional hyperplane + elevated[d] = -d * position[d - 1] * scaleFactor[d - 1]; + for (int i = d - 1; i > 0; i--) + elevated[i] = + (elevated[i + 1] - i * position[i - 1] * scaleFactor[i - 1] + (i + 2) * position[i] * scaleFactor[i]); + elevated[0] = elevated[1] + 2 * position[0] * scaleFactor[0]; + + // prepare to find the closest lattice points + scalar_t scale = 1.0f / (d + 1); + char* myrank = rank; + short* mygreedy = greedy; + + // greedily search for the closest zero-colored lattice point + int sum = 0; + for (int i = 0; i <= d; i++) { + scalar_t v = elevated[i] * scale; + scalar_t up = ceilf(v) * (d + 1); + scalar_t down = floorf(v) * (d + 1); + + if (up - elevated[i] < elevated[i] - down) + mygreedy[i] = (short)up; + else + mygreedy[i] = (short)down; + + sum += mygreedy[i]; + } + sum /= d + 1; + + // rank differential to find the permutation between this simplex and the canonical one. + // (See pg. 3-4 in paper.) + memset(myrank, 0, sizeof(char) * (d + 1)); + for (int i = 0; i < d; i++) + for (int j = i + 1; j <= d; j++) + if (elevated[i] - mygreedy[i] < elevated[j] - mygreedy[j]) + myrank[i]++; + else + myrank[j]++; + + if (sum > 0) { + // sum too large - the point is off the hyperplane. + // need to bring down the ones with the smallest differential + for (int i = 0; i <= d; i++) { + if (myrank[i] >= d + 1 - sum) { + mygreedy[i] -= d + 1; + myrank[i] += sum - (d + 1); + } else + myrank[i] += sum; + } + } else if (sum < 0) { + // sum too small - the point is off the hyperplane + // need to bring up the ones with largest differential + for (int i = 0; i <= d; i++) { + if (myrank[i] < -sum) { + mygreedy[i] += d + 1; + myrank[i] += (d + 1) + sum; + } else + myrank[i] += sum; + } + } + + // Compute barycentric coordinates (See pg.10 of paper.) + memset(barycentric, 0, sizeof(scalar_t) * (d + 2)); + for (int i = 0; i <= d; i++) { + barycentric[d - myrank[i]] += (elevated[i] - mygreedy[i]) * scale; + barycentric[d + 1 - myrank[i]] -= (elevated[i] - mygreedy[i]) * scale; + } + barycentric[0] += 1.0f + barycentric[d + 1]; + + // Splat the value into each vertex of the simplex, with barycentric weights. + for (int remainder = 0; remainder <= d; remainder++) { + // Compute the location of the lattice point explicitly (all but the last coordinate - it's redundant because they + // sum to zero) + for (int i = 0; i < d; i++) + key[i] = mygreedy[i] + canonical[remainder * (d + 1) + myrank[i]]; + + // Retrieve pointer to the value at this vertex. + scalar_t* val = hashTable.lookup(key, true); + + // Accumulate values with barycentric weight. + for (int i = 0; i < vd; i++) + val[i] += barycentric[remainder] * value[i]; + + // Record this interaction to use later when slicing + replay[nReplay].offset = val - hashTable.getValues(); + replay[nReplay].weight = barycentric[remainder]; + nReplay++; + } + } + + // Prepare for slicing + void beginSlice() { + nReplay = 0; + } + + /* Performs slicing out of position vectors. Note that the barycentric weights and the simplex + * containing each position vector were calculated and stored in the splatting step. + * We may reuse this to accelerate the algorithm. (See pg. 6 in paper.) 
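+   * Note that the replay array is consumed sequentially (beginSlice resets
+   * its cursor), so slice must be called for positions in the same order in
+   * which they were splatted.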
+ */ + void slice(scalar_t* col) { + scalar_t* base = hashTable.getValues(); + for (int j = 0; j < vd; j++) + col[j] = 0; + for (int i = 0; i <= d; i++) { + ReplayEntry r = replay[nReplay++]; + for (int j = 0; j < vd; j++) { + col[j] += r.weight * base[r.offset + j]; + } + } + } + + /* Performs a Gaussian blur along each projected axis in the hyperplane. */ + void blur() { + // Prepare arrays + short* neighbor1 = new short[d + 1]; + short* neighbor2 = new short[d + 1]; + scalar_t* newValue = new scalar_t[vd * hashTable.size()]; + scalar_t* oldValue = hashTable.getValues(); + scalar_t* hashTableBase = oldValue; + + scalar_t* zero = new scalar_t[vd]; + for (int k = 0; k < vd; k++) + zero[k] = 0; + + // For each of d+1 axes, + for (int j = 0; j <= d; j++) { + // For each vertex in the lattice, + for (int i = 0; i < hashTable.size(); i++) { // blur point i in dimension j + short* key = hashTable.getKeys() + i * (d); // keys to current vertex + for (int k = 0; k < d; k++) { + neighbor1[k] = key[k] + 1; + neighbor2[k] = key[k] - 1; + } + neighbor1[j] = key[j] - d; + neighbor2[j] = key[j] + d; // keys to the neighbors along the given axis. + + scalar_t* oldVal = oldValue + i * vd; + scalar_t* newVal = newValue + i * vd; + + scalar_t *vm1, *vp1; + + vm1 = hashTable.lookup(neighbor1, false); // look up first neighbor + if (vm1) + vm1 = vm1 - hashTableBase + oldValue; + else + vm1 = zero; + + vp1 = hashTable.lookup(neighbor2, false); // look up second neighbor + if (vp1) + vp1 = vp1 - hashTableBase + oldValue; + else + vp1 = zero; + + // Mix values of the three vertices + for (int k = 0; k < vd; k++) + newVal[k] = (0.25f * vm1[k] + 0.5f * oldVal[k] + 0.25f * vp1[k]); + } + scalar_t* tmp = newValue; + newValue = oldValue; + oldValue = tmp; + // the freshest data is now in oldValue, and newValue is ready to be written over + } + + // depending where we ended up, we may have to copy data + if (oldValue != hashTableBase) { + memcpy(hashTableBase, oldValue, hashTable.size() * vd * sizeof(scalar_t)); + delete oldValue; + } else { + delete newValue; + } + + delete zero; + delete neighbor1; + delete neighbor2; + } + + private: + int d, vd, nData; + scalar_t *elevated, *scaleFactor, *barycentric; + short* canonical; + short* key; + + // slicing is done by replaying splatting (ie storing the sparse matrix) + struct ReplayEntry { + int offset; + scalar_t weight; + } * replay; + int nReplay, nReplaySub; + + public: + char* rank; + short* greedy; + HashTablePermutohedral hashTable; +}; + +template +scalar_t* PermutohedralCPU( + scalar_t* data, + scalar_t* features, + int dataChannels, + int featureChannels, + int elementCount) { + return PermutohedralLattice::filter(data, features, dataChannels, featureChannels, elementCount); +} + +template float* PermutohedralCPU(float* data, float* features, int dataChannels, int featureChannels, int elementCount); +template double* PermutohedralCPU( + double* data, + double* features, + int dataChannels, + int featureChannels, + int elementCount); \ No newline at end of file diff --git a/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu b/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu new file mode 100644 index 0000000000..c60d0d8c31 --- /dev/null +++ b/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu @@ -0,0 +1,537 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +/* +Adapted from https://github.com/abadams/permutohedral +which has the following license... + +MIT License + +Copyright (c) 2020 Andrew Adams + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +#define BLOCK_SIZE 64 + +#include +#include +#include +#include +#include + +#include "hash_table.cu" +#include "utils/meta_macros.h" + +template +struct MatrixEntry { + int index; + scalar_t weight; +}; + +template +__global__ static void createMatrix( + const int elementCount, + const scalar_t* positions, + const scalar_t* values, + const scalar_t* scaleFactor, + MatrixEntry* matrix) { + const int threadId = threadIdx.x; + const int idx = threadIdx.x + blockIdx.x * BLOCK_SIZE; + const bool outOfBounds = idx >= elementCount; + + scalar_t myElevated[pd + 1]; + const scalar_t* myPosition = positions + idx * pd; + + int myGreedy[pd + 1]; + int myRank[pd + 1]; + + scalar_t myBarycentric[pd + 2]; + __shared__ short keys[pd * BLOCK_SIZE]; + short* myKey = keys + threadId * pd; + + if (!outOfBounds) { + myElevated[pd] = -pd * myPosition[pd - 1] * scaleFactor[pd - 1]; + + for (int i = pd - 1; i > 0; i--) { + myElevated[i] = + myElevated[i + 1] - i * (myPosition[i - 1]) * scaleFactor[i - 1] + (i + 2) * myPosition[i] * scaleFactor[i]; + } + + myElevated[0] = myElevated[1] + 2 * myPosition[0] * scaleFactor[0]; + + // find the closest zero-colored lattice point + + // greedily search for the closest zero-colored lattice point + signed short sum = 0; + + for (int i = 0; i <= pd; i++) { + scalar_t v = myElevated[i] * (1.0f / (pd + 1)); + scalar_t up = ceilf(v) * (pd + 1); + scalar_t down = floorf(v) * (pd + 1); + + myGreedy[i] = (signed short)(up - myElevated[i] < myElevated[i] - down ? 
up : down); + sum += myGreedy[i]; + } + + sum /= pd + 1; + + // sort differential to find the permutation between this simplex and the canonical one + for (int i = 0; i <= pd; i++) { + myRank[i] = 0; + + for (int j = 0; j <= pd; j++) { + scalar_t iDiff = myElevated[i] - myGreedy[i]; + scalar_t jDiff = myElevated[j] - myGreedy[j]; + + if (iDiff < jDiff || (iDiff == jDiff && i > j)) { + myRank[i]++; + } + } + } + + if (sum > 0) // sum too large, need to bring down the ones with the smallest differential + { + for (int i = 0; i <= pd; i++) { + if (myRank[i] >= pd + 1 - sum) { + myGreedy[i] -= (pd + 1); + myRank[i] += sum - (pd + 1); + } else { + myRank[i] += sum; + } + } + } else if (sum < 0) // sum too small, need to bring up the ones with largest differential + { + for (int i = 0; i <= pd; i++) { + if (myRank[i] < -sum) { + myGreedy[i] += (pd + 1); + myRank[i] += sum + (pd + 1); + } else { + myRank[i] += sum; + } + } + } + +#ifdef LINEAR_D_MEMORY + for (int i = 0; i <= pd; i++) { + table_zeros[idx * (pd + 1) + i] = myGreedy[i]; + table_rank[idx * (pd + 1) + i] = myRank[i]; + } +#endif + + // turn delta into barycentric coords + for (int i = 0; i <= pd + 1; i++) { + myBarycentric[i] = 0; + } + + for (int i = 0; i <= pd; i++) { + scalar_t delta = (myElevated[i] - myGreedy[i]) * (1.0f / (pd + 1)); + myBarycentric[pd - myRank[i]] += delta; + myBarycentric[pd + 1 - myRank[i]] -= delta; + } + + myBarycentric[0] += 1.0f + myBarycentric[pd + 1]; + } + +#ifdef USE_ADDITIVE_HASH + unsigned int cumulative_hash = hash(myGreedy); +#endif + + for (int color = 0; color <= pd; color++) { + // Compute the location of the lattice point explicitly (all but + // the last coordinate - it's redundant because they sum to zero) + if (!outOfBounds) { + for (int i = 0; i < pd; i++) { + myKey[i] = myGreedy[i] + color; + + if (myRank[i] > pd - color) { + myKey[i] -= (pd + 1); + } + } + } + +#ifdef USE_ADDITIVE_HASH + for (int i = 0; i < pd; i++) { + if (myRank[i] == pd - color) { + cumulative_hash += hOffset[i]; + } + } +#endif + + if (!outOfBounds) { + MatrixEntry r; + +#ifdef USE_ADDITIVE_HASH + r.index = hashTableInsert(cumulative_hash, myKey, idx * (pd + 1) + color); +#else + r.index = hashTableInsert(myKey, idx * (pd + 1) + color); +#endif + + r.weight = myBarycentric[color]; + matrix[idx * (pd + 1) + color] = r; + } + } +} + +template +__global__ static void cleanHashTable(const int elementCount, MatrixEntry* matrix) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + + if (idx >= elementCount) + return; + + // find my hash table entry + int* e = table_entries + idx; + + // Check if I created my own key in the previous phase + if (*e >= 0) { + // Rehash my key and reset the pointer in order to merge with + // any other pixel that created a different entry under the + // same key. If the computation was serial this would never + // happen, but sometimes race conditions can make the same key + // be inserted twice. hashTableRetrieve always returns the + // earlier, so it's no problem as long as we rehash now. 
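+    // After this pass every entry points at a single canonical slot per key,
+    // which splat/splatCache rely on when canonicalising matrix indices.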
+ +#ifdef LINEAR_D_MEMORY + // Get my key + short myKey[kd]; + generateKey(*e, myKey); + *e = hashTableRetrieve(myKey); +#else + *e = hashTableRetrieve(table_keys + *e * kd); +#endif + } +} + +template +__global__ static void splat( + const int elementCount, + scalar_t* values, + MatrixEntry* matrix, + scalar_t* table_values) { + const int color = threadIdx.y; + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + + const bool outOfBounds = idx >= elementCount; + + if (outOfBounds) { + return; + } + + scalar_t* myValue = values + idx * vd; + + MatrixEntry r = matrix[idx * (pd + 1) + color]; + + matrix[idx * (pd + 1) + color].index = r.index = table_entries[r.index]; + scalar_t* val = table_values + r.index * (vd + 1); + + for (int j = 0; j < vd; j++) { + gpuAtomicAdd(val + j, myValue[j] * r.weight); + } + + gpuAtomicAdd(val + vd, r.weight); +} + +// splat splits by color, so extend the y coordinate to our blocks to represent that +// dim3 oldblocks((w-1)/8+1, (h-1)/8+1, 1); +// dim3 oldblockSize(8, 8, 1); +// oldblocks.y *= pd+1; +// splatCache<<>>(w, h, values, matrix); + +// int blockCount = (elementCount + 1) / BLOCK_SIZE + 1; +// int blockSize = BLOCK_SIZE; + +// splatCache<<>>(elementCount, values, matrix); + +template +__global__ static void splatCache( + const int elementCount, + scalar_t* values, + MatrixEntry* matrix, + scalar_t* table_values) { + // const int x = threadIdx.x + blockIdx.x * blockDim.x; + // const int y = threadIdx.y + (blockIdx.y/(pd+1)) * blockDim.y; + + // const int threadId = threadIdx.y*blockDim.x + threadIdx.x; + // const int color = blockIdx.y % (pd+1); + // const int idx = y*w + x; + + const int threadId = threadIdx.x; + const int color = threadIdx.y; + const int idx = threadIdx.x + blockIdx.x * BLOCK_SIZE; + + const bool outOfBounds = idx >= elementCount; + + __shared__ int sharedOffsets[BLOCK_SIZE]; + __shared__ scalar_t sharedValues[BLOCK_SIZE * (vd + 1)]; + + int myOffset = -1; + scalar_t* myValue = sharedValues + threadId * (vd + 1); + + if (!outOfBounds) { + scalar_t* value = values + idx * vd; + + MatrixEntry r = matrix[idx * (pd + 1) + color]; + + // convert the matrix entry from a pointer into the entries array to a pointer into the keys/values array + matrix[idx * (pd + 1) + color].index = r.index = table_entries[r.index]; + // record the offset into the keys/values array in shared space + myOffset = sharedOffsets[threadId] = r.index * (vd + 1); + + for (int j = 0; j < vd; j++) { + myValue[j] = value[j] * r.weight; + } + myValue[vd] = r.weight; + + } else { + sharedOffsets[threadId] = -1; + } + + __syncthreads(); + + // am I the first thread in this block to care about this key? 
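+  // (The lowest thread id in the block touching a given lattice offset wins:
+  // it accumulates the values of later threads, which return early, so each
+  // shared offset is flushed to global memory by a single writer.)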
+ + if (outOfBounds) + return; + + for (int i = 0; i < BLOCK_SIZE; i++) { + if (i < threadId) { + if (myOffset == sharedOffsets[i]) { + // somebody else with higher priority cares about this key + return; + } + } else if (i > threadId) { + if (myOffset == sharedOffsets[i]) { + // someone else with lower priority cares about this key, accumulate it into mine + for (int j = 0; j <= vd; j++) { + sharedValues[threadId * (vd + 1) + j] += sharedValues[i * (vd + 1) + j]; + } + } + } + } + + // only the threads with something to write to main memory are still going + scalar_t* val = table_values + myOffset; + for (int j = 0; j <= vd; j++) { + gpuAtomicAdd(val + j, myValue[j]); + } +} + +template +__global__ static void blur( + int n, + scalar_t* newValues, + MatrixEntry* matrix, + int color, + scalar_t* table_values) { + const int idx = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x * blockDim.y + threadIdx.x; + + if (idx >= n) + return; + + // Check if I'm valid + if (matrix[idx].index != idx) + return; + + // find my key and the keys of my neighbours + short myKey[pd + 1]; + short np[pd + 1]; + short nm[pd + 1]; + +#ifdef LINEAR_D_MEMORY + generateKey(idx, myKey); + for (int i = 0; i < pd; i++) { + np[i] = myKey[i] + 1; + nm[i] = myKey[i] - 1; + } +#else + for (int i = 0; i < pd; i++) { + myKey[i] = table_keys[idx * pd + i]; + np[i] = myKey[i] + 1; + nm[i] = myKey[i] - 1; + } +#endif + + np[color] -= pd + 1; + nm[color] += pd + 1; + +#ifdef USE_ADDITIVE_HASH + unsigned int hCurrent = hash(myKey); + int offNp = hashTableRetrieveWithHash(hCurrent + hOffset[color], np); + int offNm = hashTableRetrieveWithHash(hCurrent - hOffset[color], nm); +#else + int offNp = hashTableRetrieve(np); + int offNm = hashTableRetrieve(nm); +#endif + + scalar_t* valMe = table_values + (vd + 1) * idx; + scalar_t* valNp = table_values + (vd + 1) * offNp; + scalar_t* valNm = table_values + (vd + 1) * offNm; + scalar_t* valOut = newValues + (vd + 1) * idx; + + if (offNp >= 0 && offNm >= 0) { + for (int i = 0; i <= vd; i++) { + valOut[i] = (valNp[i] + (valMe[i] * 2) + valNm[i]) / 4; + } + } else if (offNp >= 0) { + for (int i = 0; i <= vd; i++) { + valOut[i] = (valNp[i] + (valMe[i] * 2)) / 4; + } + } else if (offNm >= 0) { + for (int i = 0; i <= vd; i++) { + valOut[i] = (valNm[i] + (valMe[i] * 2)) / 4; + } + } else { + for (int i = 0; i <= vd; i++) { + valOut[i] = valMe[i] * 2; + } + } +} + +template +__global__ static void slice( + const int elementCount, + scalar_t* values, + MatrixEntry* matrix, + scalar_t* table_values) { + const int threadId = threadIdx.x; + const int idx = threadIdx.x + blockIdx.x * BLOCK_SIZE; + const bool outOfBounds = idx >= elementCount; + + if (outOfBounds) + return; + + __shared__ scalar_t localValue[BLOCK_SIZE * vd]; + + scalar_t* myValue = localValue + threadId * vd; + scalar_t myWeight = 0; + + for (int i = 0; i < vd; i++) { + myValue[i] = 0; + } + + for (int i = 0; i <= pd; i++) { + MatrixEntry r = matrix[idx * (pd + 1) + i]; + scalar_t* val = table_values + r.index * (vd + 1); + + for (int j = 0; j < vd; j++) { + myValue[j] += r.weight * val[j]; + } + + myWeight += r.weight * val[vd]; + } + + myWeight = 1.0f / myWeight; + + for (int j = 0; j < vd; j++) { + values[idx * vd + j] = myValue[j] * myWeight; + } +} + +template +void PermutohedralCuda(scalar_t* values, scalar_t* positions, int elementCount, bool accurate) { + scalar_t blurVariance = accurate ? 
0.5 : 0; + + scalar_t* scaleFactor; + cudaMalloc(&scaleFactor, pd * sizeof(scalar_t)); + + scalar_t scaleFactorHost[pd]; + for (int i = 0; i < pd; i++) { + scaleFactorHost[i] = (pd + 1) * sqrtf((1.0 / 6 + blurVariance) / ((i + 1) * (i + 2))); + } + + cudaMemcpy(scaleFactor, scaleFactorHost, pd * sizeof(scalar_t), cudaMemcpyHostToDevice); + + MatrixEntry* matrix; + cudaMalloc(&matrix, elementCount * (pd + 1) * sizeof(MatrixEntry)); + + scalar_t* table_values = createHashTable(elementCount * (pd + 1)); + + // Populate constant memory for hash helpers + unsigned long long int __host_two32 = ((unsigned long long int)1) << 32; + unsigned int __host_div_c = 2 * (elementCount * (pd + 1)); + unsigned int __host_div_l = ceilf(logf((float)__host_div_c) / logf(2.0f)); + unsigned int __host_div_m = (__host_two32 << __host_div_l) / __host_div_c - __host_two32 + 1; + cudaMemcpyToSymbol(__div_c, &__host_div_c, sizeof(unsigned int)); + cudaMemcpyToSymbol(__div_l, &__host_div_l, sizeof(unsigned int)); + cudaMemcpyToSymbol(__div_m, &__host_div_m, sizeof(unsigned int)); + + // Populate constant memory with hash of offset vectors + unsigned int hOffset_host[pd + 1]; + signed short offset[pd + 1]; + for (int i = 0; i < pd; offset[i] = 1, i++) + ; + for (int i = 0; i <= pd; i++) { + offset[i] -= pd + 1; + hOffset_host[i] = hash(offset); + offset[i] += pd + 1; + } + cudaMemcpyToSymbol(hOffset, &hOffset_host, sizeof(unsigned int) * (pd + 1)); + + int blockCount = (elementCount + 1) / BLOCK_SIZE + 1; + int blockSize = BLOCK_SIZE; + + createMatrix<<>>(elementCount, positions, values, scaleFactor, matrix); + + // fix duplicate hash table entries + int tableSize = elementCount * 2 * (pd + 1); + int cleanBlockSize = 32; + int cleanBlocks = (tableSize - 1) / cleanBlockSize + 1; + + cleanHashTable<<>>(tableSize, matrix); + + splat<<>>(elementCount, values, matrix, table_values); + + if (accurate) { + scalar_t* newValues; + cudaMalloc(&newValues, elementCount * (pd + 1) * (vd + 1) * sizeof(scalar_t)); + cudaMemset(newValues, 0, elementCount * (pd + 1) * (vd + 1) * sizeof(scalar_t)); + + for (int color = 0; color <= pd; color++) { + blur + <<>>(elementCount * (pd + 1), newValues, matrix, color, table_values); + + scalar_t* swap = newValues; + newValues = table_values; + table_values = swap; + } + + cudaFree(newValues); + } + + slice<<>>(elementCount, values, matrix, table_values); + + destroyHashTable(); + cudaFree(table_values); +} + +#define DECLARATION(dc, fc) \ + template void PermutohedralCuda(float* values, float* positions, int elementCount, bool accurate); \ + template void PermutohedralCuda(double* values, double* positions, int elementCount, bool accurate); +DO_FOR_AB(DECLARATION, 16, 19) diff --git a/monai/csrc/utils/meta_macros.h b/monai/csrc/utils/meta_macros.h new file mode 100644 index 0000000000..73d1851198 --- /dev/null +++ b/monai/csrc/utils/meta_macros.h @@ -0,0 +1,131 @@ +/* +Copyright 2020 MONAI Consortium +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#pragma once + +// Helper Macros: for internal use (see below) +#define _DO_1(TARGET) TARGET(1) +#define _DO_2(TARGET) TARGET(2) _DO_1(TARGET) +#define _DO_3(TARGET) TARGET(3) _DO_2(TARGET) +#define _DO_4(TARGET) TARGET(4) _DO_3(TARGET) +#define _DO_5(TARGET) TARGET(5) _DO_4(TARGET) +#define _DO_6(TARGET) TARGET(6) _DO_5(TARGET) +#define _DO_7(TARGET) TARGET(7) _DO_6(TARGET) +#define _DO_8(TARGET) TARGET(8) _DO_7(TARGET) +#define _DO_9(TARGET) TARGET(9) _DO_8(TARGET) +#define _DO_10(TARGET) TARGET(10) _DO_9(TARGET) +#define _DO_11(TARGET) TARGET(11) _DO_10(TARGET) +#define _DO_12(TARGET) TARGET(12) _DO_11(TARGET) +#define _DO_13(TARGET) TARGET(13) _DO_12(TARGET) +#define _DO_14(TARGET) TARGET(14) _DO_13(TARGET) +#define _DO_15(TARGET) TARGET(15) _DO_14(TARGET) +#define _DO_16(TARGET) TARGET(16) _DO_15(TARGET) +#define _DO_17(TARGET) TARGET(17) _DO_16(TARGET) +#define _DO_18(TARGET) TARGET(18) _DO_17(TARGET) +#define _DO_19(TARGET) TARGET(19) _DO_18(TARGET) +#define _DO_20(TARGET) TARGET(20) _DO_19(TARGET) +#define _DO_21(TARGET) TARGET(21) _DO_20(TARGET) +#define _DO_22(TARGET) TARGET(22) _DO_21(TARGET) +#define _DO_23(TARGET) TARGET(23) _DO_22(TARGET) +#define _DO_24(TARGET) TARGET(24) _DO_23(TARGET) +#define _DO_25(TARGET) TARGET(25) _DO_24(TARGET) +#define _DO_26(TARGET) TARGET(26) _DO_25(TARGET) +#define _DO_27(TARGET) TARGET(27) _DO_26(TARGET) +#define _DO_28(TARGET) TARGET(28) _DO_27(TARGET) +#define _DO_29(TARGET) TARGET(29) _DO_28(TARGET) +#define _DO_30(TARGET) TARGET(30) _DO_29(TARGET) +#define _DO_31(TARGET) TARGET(31) _DO_30(TARGET) +#define _DO_32(TARGET) TARGET(32) _DO_31(TARGET) + +#define _DO_A_1(TARGET, A) TARGET(A, 1) +#define _DO_A_2(TARGET, A) TARGET(A, 2) _DO_A_1(TARGET, A) +#define _DO_A_3(TARGET, A) TARGET(A, 3) _DO_A_2(TARGET, A) +#define _DO_A_4(TARGET, A) TARGET(A, 4) _DO_A_3(TARGET, A) +#define _DO_A_5(TARGET, A) TARGET(A, 5) _DO_A_4(TARGET, A) +#define _DO_A_6(TARGET, A) TARGET(A, 6) _DO_A_5(TARGET, A) +#define _DO_A_7(TARGET, A) TARGET(A, 7) _DO_A_6(TARGET, A) +#define _DO_A_8(TARGET, A) TARGET(A, 8) _DO_A_7(TARGET, A) +#define _DO_A_9(TARGET, A) TARGET(A, 9) _DO_A_8(TARGET, A) +#define _DO_A_10(TARGET, A) TARGET(A, 10) _DO_A_9(TARGET, A) +#define _DO_A_11(TARGET, A) TARGET(A, 11) _DO_A_10(TARGET, A) +#define _DO_A_12(TARGET, A) TARGET(A, 12) _DO_A_11(TARGET, A) +#define _DO_A_13(TARGET, A) TARGET(A, 13) _DO_A_12(TARGET, A) +#define _DO_A_14(TARGET, A) TARGET(A, 14) _DO_A_13(TARGET, A) +#define _DO_A_15(TARGET, A) TARGET(A, 15) _DO_A_14(TARGET, A) +#define _DO_A_16(TARGET, A) TARGET(A, 16) _DO_A_15(TARGET, A) +#define _DO_A_17(TARGET, A) TARGET(A, 17) _DO_A_16(TARGET, A) +#define _DO_A_18(TARGET, A) TARGET(A, 18) _DO_A_17(TARGET, A) +#define _DO_A_19(TARGET, A) TARGET(A, 19) _DO_A_18(TARGET, A) +#define _DO_A_20(TARGET, A) TARGET(A, 20) _DO_A_19(TARGET, A) +#define _DO_A_21(TARGET, A) TARGET(A, 21) _DO_A_20(TARGET, A) +#define _DO_A_22(TARGET, A) TARGET(A, 22) _DO_A_21(TARGET, A) +#define _DO_A_23(TARGET, A) TARGET(A, 23) _DO_A_22(TARGET, A) +#define _DO_A_24(TARGET, A) TARGET(A, 24) _DO_A_23(TARGET, A) +#define _DO_A_25(TARGET, A) TARGET(A, 25) _DO_A_24(TARGET, A) +#define _DO_A_26(TARGET, A) TARGET(A, 26) _DO_A_25(TARGET, A) +#define _DO_A_27(TARGET, A) TARGET(A, 27) _DO_A_26(TARGET, A) +#define _DO_A_28(TARGET, A) TARGET(A, 28) _DO_A_27(TARGET, A) +#define _DO_A_29(TARGET, A) TARGET(A, 29) _DO_A_28(TARGET, A) +#define _DO_A_30(TARGET, A) TARGET(A, 30) _DO_A_29(TARGET, A) +#define _DO_A_31(TARGET, A) TARGET(A, 31) _DO_A_30(TARGET, A) +#define 
_DO_A_32(TARGET, A) TARGET(A, 32) _DO_A_31(TARGET, A)
+
+#define _DO_1_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 1)
+#define _DO_2_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 2) _DO_1_B(TARGET, B_RANGE)
+#define _DO_3_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 3) _DO_2_B(TARGET, B_RANGE)
+#define _DO_4_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 4) _DO_3_B(TARGET, B_RANGE)
+#define _DO_5_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 5) _DO_4_B(TARGET, B_RANGE)
+#define _DO_6_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 6) _DO_5_B(TARGET, B_RANGE)
+#define _DO_7_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 7) _DO_6_B(TARGET, B_RANGE)
+#define _DO_8_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 8) _DO_7_B(TARGET, B_RANGE)
+#define _DO_9_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 9) _DO_8_B(TARGET, B_RANGE)
+#define _DO_10_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 10) _DO_9_B(TARGET, B_RANGE)
+#define _DO_11_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 11) _DO_10_B(TARGET, B_RANGE)
+#define _DO_12_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 12) _DO_11_B(TARGET, B_RANGE)
+#define _DO_13_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 13) _DO_12_B(TARGET, B_RANGE)
+#define _DO_14_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 14) _DO_13_B(TARGET, B_RANGE)
+#define _DO_15_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 15) _DO_14_B(TARGET, B_RANGE)
+#define _DO_16_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 16) _DO_15_B(TARGET, B_RANGE)
+#define _DO_17_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 17) _DO_16_B(TARGET, B_RANGE)
+#define _DO_18_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 18) _DO_17_B(TARGET, B_RANGE)
+#define _DO_19_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 19) _DO_18_B(TARGET, B_RANGE)
+#define _DO_20_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 20) _DO_19_B(TARGET, B_RANGE)
+#define _DO_21_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 21) _DO_20_B(TARGET, B_RANGE)
+#define _DO_22_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 22) _DO_21_B(TARGET, B_RANGE)
+#define _DO_23_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 23) _DO_22_B(TARGET, B_RANGE)
+#define _DO_24_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 24) _DO_23_B(TARGET, B_RANGE)
+#define _DO_25_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 25) _DO_24_B(TARGET, B_RANGE)
+#define _DO_26_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 26) _DO_25_B(TARGET, B_RANGE)
+#define _DO_27_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 27) _DO_26_B(TARGET, B_RANGE)
+#define _DO_28_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 28) _DO_27_B(TARGET, B_RANGE)
+#define _DO_29_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 29) _DO_28_B(TARGET, B_RANGE)
+#define _DO_30_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 30) _DO_29_B(TARGET, B_RANGE)
+#define _DO_31_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 31) _DO_30_B(TARGET, B_RANGE)
+#define _DO_32_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 32) _DO_31_B(TARGET, B_RANGE)
+
+#define _CASE_A(A) \
+  case (A): \
+    CASE(A) break;
+#define _CASE_AB(A, B) \
+  case (A * 100 + B): \
+    CASE(A, B) break;
+
+// Preprocessor For Loops
+#define DO_FOR_A(TARGET, A_RANGE) _DO_##A_RANGE(TARGET)
+#define DO_FOR_AB(TARGET, A_RANGE, B_RANGE) _DO_##A_RANGE##_B(TARGET, B_RANGE)
+
+// Preprocessor Switch Statement Generators
+#define SWITCH_A(CASE, A_RANGE, A) \
+  switch (A) { DO_FOR_A(_CASE_A, A_RANGE) }
+#define SWITCH_AB(CALL, A_RANGE, B_RANGE, A, B) \
+  switch (A * 100 + B) { DO_FOR_AB(_CASE_AB, A_RANGE, B_RANGE) }
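To make the recursion above concrete, here is a rough Python rendering of what the preprocessor "for loop" does (the helper name do_for_ab is ours, not part of the source). The DO_FOR_AB(DECLARATION, 16, 19) line at the end of permutohedral_cuda.cu relies on exactly this pattern to emit one pair of explicit template instantiations per (dc, fc) combination:

# Rough Python model of DO_FOR_AB(TARGET, 16, 19): it recursively expands to
# TARGET(a, b) for every a in 1..16 and b in 1..19.
def do_for_ab(target, a_range, b_range):
    for a in range(a_range, 0, -1):      # _DO_16_B -> _DO_15_B -> ... -> _DO_1_B
        for b in range(b_range, 0, -1):  # _DO_A_19 -> _DO_A_18 -> ... -> _DO_A_1
            target(a, b)

decls = []
do_for_ab(
    lambda dc, fc: decls.append(
        f"template void PermutohedralCuda<float, {dc}, {fc}>(float*, float*, int, bool);"
    ),
    16,
    19,
)
print(len(decls))  # 304 float instantiations (the real macro also emits double versions)

SWITCH_AB then generates the matching switch over A * 100 + B, so a runtime (pd, vd) pair can dispatch to the corresponding compile-time instantiation.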
diff --git a/monai/csrc/utils/tensor_description.h
new file mode 100644
index 0000000000..6072037f72
--- /dev/null
+++ b/monai/csrc/utils/tensor_description.h
@@ -0,0 +1,40 @@
+
+#include <torch/extension.h>
+
+// Struct to easily cache descriptive information about a tensor.
+// This is helpful as regular calls to the size and stride member
+// functions of tensors appear to cause memory issues.
+struct TensorDescription {
+ public:
+  TensorDescription(torch::Tensor tensor) {
+    batchCount = tensor.size(0);
+    batchStride = tensor.stride(0);
+
+    channelCount = tensor.size(1);
+    channelStride = tensor.stride(1);
+
+    dimensions = tensor.dim() - 2;
+    sizes = new int[dimensions];
+    strides = new int[dimensions];
+
+    for (int i = 0; i < dimensions; i++) {
+      sizes[i] = tensor.size(i + 2);
+      strides[i] = tensor.stride(i + 2);
+    }
+  }
+
+  ~TensorDescription() {
+    delete[] sizes;
+    delete[] strides;
+  }
+
+  int batchCount;
+  int batchStride;
+
+  int channelCount;
+  int channelStride;
+
+  int dimensions;
+  int* sizes;
+  int* strides;
+};
diff --git a/monai/networks/layers/__init__.py
index 9125dc38cf..f400eaf3a3 100644
--- a/monai/networks/layers/__init__.py
+++ b/monai/networks/layers/__init__.py
@@ -11,5 +11,6 @@
 from .convutils import *
 from .factories import *
+from .filtering import *
 from .simplelayers import *
 from .spatial_transforms import *
diff --git a/monai/networks/layers/filtering.py
new file mode 100644
index 0000000000..dcb172d892
--- /dev/null
+++ b/monai/networks/layers/filtering.py
@@ -0,0 +1,58 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from monai.utils.module import optional_import
+
+_C, _ = optional_import("monai._C")
+
+__all__ = ["BilateralFilter"]
+
+
+class BilateralFilter(torch.autograd.Function):
+    """
+    Blurs the input tensor spatially whilst preserving edges. Can run on 1D, 2D, or 3D
+    tensors (on top of batch and channel dimensions). Two implementations are provided:
+    an exact solution and a much faster approximation which uses a permutohedral lattice.
+
+    See:
+        https://en.wikipedia.org/wiki/Bilateral_filter
+        https://graphics.stanford.edu/papers/permutohedral/
+
+    Args:
+        input: input tensor.
+
+        spatial_sigma: the standard deviation of the spatial blur. Higher values can
+            hurt performance when not using the approximate method (see fast_approx).
+
+        color_sigma: the standard deviation of the color blur. Lower values preserve
+            edges better whilst higher values tend towards a simple Gaussian spatial blur.
+
+        fast_approx: this flag chooses between the two implementations. The approximate
+            method may produce artifacts in some scenarios, whereas the exact solution
+            may be intolerably slow for high spatial standard deviations.
+
+    Returns:
+        output (torch.Tensor): output tensor.
+    """
+
+    @staticmethod
+    def forward(ctx, input, spatial_sigma=5, color_sigma=0.5, fast_approx=True):
+        # The sigmas and the implementation flag are plain Python values, not
+        # tensors, so they are stashed on the context directly rather than passed
+        # to ctx.save_for_backward (which accepts tensors only).
+        ctx.ss, ctx.cs, ctx.fa = spatial_sigma, color_sigma, fast_approx
+        output_data = _C.bilateral_filter(input, spatial_sigma, color_sigma, fast_approx)
+        return output_data
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # The gradient is approximated by running the incoming gradient through
+        # the same bilateral filter; the three non-tensor inputs receive None.
+        spatial_sigma, color_sigma, fast_approx = ctx.ss, ctx.cs, ctx.fa
+        grad_input = _C.bilateral_filter(grad_output, spatial_sigma, color_sigma, fast_approx)
+        return grad_input, None, None, None
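For reference, a minimal usage sketch, assuming MONAI was built with the C++/CUDA extension enabled; the tensor shape and sigma values here are illustrative only. As in the tests below, the autograd.Function is invoked through its .apply method:

import torch

from monai.networks.layers.filtering import BilateralFilter

# Batch x Channel x Spatial: a single-channel 2D image.
img = torch.rand(1, 1, 64, 64)

# Argument order: (input, spatial_sigma, color_sigma, fast_approx).
smoothed = BilateralFilter.apply(img, 3.0, 0.3, True)
print(smoothed.shape)  # torch.Size([1, 1, 64, 64])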
+ """ + + @staticmethod + def forward(ctx, input, spatial_sigma=5, color_sigma=0.5, fast_approx=True): + ctx.save_for_backward(spatial_sigma, color_sigma, fast_approx) + output_data = _C.bilateral_filter(input, spatial_sigma, color_sigma, fast_approx) + return output_data + + @staticmethod + def backward(ctx, grad_output): + spatial_sigma, color_sigma, fast_approx = ctx.saved_variables + grad_input = _C.bilateral_filter(grad_output, spatial_sigma, color_sigma, fast_approx) + return grad_input diff --git a/tests/test_bilateral_approx_cpu.py b/tests/test_bilateral_approx_cpu.py new file mode 100644 index 0000000000..13aaaeb34e --- /dev/null +++ b/tests/test_bilateral_approx_cpu.py @@ -0,0 +1,381 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers.filtering import BilateralFilter +from tests.utils import skip_if_no_cpp_extention + +TEST_CASES = [ + [ + # Case Descirption + "1 dimension, 1 channel, low spatial sigma, low color sigma", + # Spatial and Color Sigmas + (1, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [1.000000, 0.000000, 0.000000, 0.000000, 1.000000] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000000, 1.000000, 0.000000, 0.000000] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, low spatial sigma, high color sigma", + # Spatial and Color Sigmas + (1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.631360, 0.099349, 0.070177, 0.164534, 0.649869] + ], + # Batch 1 + [ + # Channel 0 + [0.052271, 0.173599, 0.481337, 0.183721, 0.045619] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, high spatial sigma, low color sigma", + # Spatial and Color Sigmas + (4, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [1.000000, 0.000000, 0.000000, 0.000000, 1.000000] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000000, 1.000000, 0.000000, 0.000000] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.497667, 0.268683, 0.265026, 0.261467, 0.495981] + ], + # Batch 1 + [ + # Channel 0 + [0.145959, 0.142282, 0.315710, 0.135609, 0.132572] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 4 channel, low spatial sigma, high color sigma", + # Spatial and Color Sigmas + (1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 0], + # Channel 1 + [1, 0, 1, 0, 
0], + # Channel 2 + [0, 0, 1, 0, 1], + # Channel 3 + [0, 0, 0, 0, 1], + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.960843, 0.073540, 0.027689, 0.002676, 0.000000], + # Channel 1 + [0.960843, 0.073540, 0.951248, 0.003033, 0.000750], + # Channel 2 + [0.000000, 0.000000, 0.923559, 0.000357, 0.981324], + # Channel 3 + [0.000000, 0.000000, 0.000000, 0.000000, 0.980574], + ] + ], + ], + [ + # Case Descirption + "2 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]] + ], + # Batch 1 + [ + # Channel 0 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.213684, 0.094356, 0.092973, 0.091650, 0.216281], + [0.094085, 0.092654, 0.091395, 0.090186, 0.089302], + [0.092436, 0.091150, 0.090008, 0.088896, 0.088897], + [0.090849, 0.089717, 0.088759, 0.087751, 0.088501], + [0.211458, 0.088334, 0.087495, 0.087049, 0.212173], + ] + ], + # Batch 1 + [ + # Channel 0 + [ + [0.033341, 0.031314, 0.029367, 0.027494, 0.025692], + [0.031869, 0.030632, 0.028820, 0.027074, 0.025454], + [0.030455, 0.029628, 0.084257, 0.026704, 0.025372], + [0.029095, 0.028391, 0.027790, 0.026375, 0.025292], + [0.027786, 0.027197, 0.026692, 0.026181, 0.025213], + ] + ], + ], + ], + [ + # Case Descirption + "2 dimension, 4 channel, high spatial sigma, high color sigma", + # Spatial and Color Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]], + # Channel 1 + [[1, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 1]], + # Channel 2 + [[0, 0, 1, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 0]], + # Channel 3 + [[0, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0]], + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.244373, 0.014488, 0.036589, 0.014226, 0.024329], + [0.014108, 0.014228, 0.014096, 0.013961, 0.013823], + [0.013574, 0.013757, 0.013836, 0.013699, 0.013558], + [0.013008, 0.013211, 0.013404, 0.013438, 0.013295], + [0.025179, 0.012634, 0.034555, 0.013050, 0.237582], + ], + # Channel 1 + [ + [0.271496, 0.015547, 0.439432, 0.015700, 0.089579], + [0.015252, 0.015702, 0.015779, 0.015859, 0.015940], + [0.015020, 0.015556, 0.015935, 0.016015, 0.016098], + [0.014774, 0.015331, 0.015860, 0.016171, 0.016255], + [0.107384, 0.015094, 0.462471, 0.016166, 0.263480], + ], + # Channel 2 + [ + [0.027123, 0.003527, 0.467273, 0.004912, 0.645776], + [0.003810, 0.004908, 0.005605, 0.006319, 0.007050], + [0.004816, 0.005991, 0.006989, 0.007716, 0.008459], + [0.005880, 0.007060, 0.008179, 0.009101, 0.009858], + [0.633398, 0.008191, 0.496893, 0.010376, 0.025898], + ], + # Channel 3 + [ + [0.000000, 0.002468, 0.064430, 0.003437, 0.580526], + [0.002666, 0.003434, 0.003922, 0.004422, 0.004933], + [0.003370, 0.004192, 0.004890, 0.005399, 0.005919], + [0.004115, 0.004940, 0.005723, 0.006368, 0.006898], + [0.551194, 0.005731, 0.068977, 0.007260, 0.000000], + ], + ] + ], + ], + [ + # Case Descirption + "3 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + # Frame 1 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 
0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 2 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 3 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 4 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + ] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [ + [0.086801, 0.036670, 0.035971, 0.035304, 0.088456], + [0.036639, 0.035652, 0.035009, 0.034394, 0.033803], + [0.035899, 0.034897, 0.034136, 0.033566, 0.033129], + [0.035180, 0.034238, 0.033413, 0.032811, 0.032577], + [0.088290, 0.033597, 0.032821, 0.032134, 0.088786], + ], + # Frame 1 + [ + [0.036286, 0.035269, 0.034632, 0.034021, 0.033435], + [0.035398, 0.034485, 0.033922, 0.033381, 0.033177], + [0.034688, 0.033822, 0.033169, 0.032664, 0.032780], + [0.034024, 0.033234, 0.032533, 0.032005, 0.032388], + [0.033564, 0.032797, 0.032118, 0.031525, 0.032105], + ], + # Frame 2 + [ + [0.035225, 0.034169, 0.033404, 0.032843, 0.032766], + [0.034383, 0.033487, 0.032908, 0.032415, 0.032650], + [0.033691, 0.032921, 0.032353, 0.031900, 0.032384], + [0.033080, 0.032390, 0.031786, 0.031432, 0.032008], + [0.033099, 0.032373, 0.031737, 0.031479, 0.032054], + ], + # Frame 3 + [ + [0.034216, 0.033231, 0.032337, 0.031758, 0.032101], + [0.033456, 0.032669, 0.031913, 0.031455, 0.032034], + [0.032788, 0.032140, 0.031618, 0.031413, 0.031977], + [0.032221, 0.031650, 0.031145, 0.031130, 0.031652], + [0.032642, 0.031968, 0.031378, 0.031433, 0.032003], + ], + # Frame 4 + [ + [0.086207, 0.032335, 0.031499, 0.030832, 0.087498], + [0.032570, 0.031884, 0.031155, 0.030858, 0.031401], + [0.031967, 0.031417, 0.030876, 0.030881, 0.031388], + [0.031602, 0.031103, 0.030696, 0.030960, 0.031455], + [0.090599, 0.031546, 0.031127, 0.031386, 0.083483], + ], + ] + ] + ], + ], +] + + +@skip_if_no_cpp_extention +class BilateralFilterTestCaseCpuApprox(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_cpu_approx(self, test_case_description, sigmas, input, expected): + + # Params to determine the implementation to test + device = torch.device("cpu") + fast_approx = True + + # Create input tensor and apply filter + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device) + output = BilateralFilter.apply(input_tensor, *sigmas, fast_approx).cpu().numpy() + + # Ensure result are as expected + np.testing.assert_allclose(output, expected, atol=1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_bilateral_approx_cuda.py b/tests/test_bilateral_approx_cuda.py new file mode 100644 index 0000000000..5ea0d997d1 --- /dev/null +++ b/tests/test_bilateral_approx_cuda.py @@ -0,0 +1,386 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
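+#
+# Note: this file exercises the approximate (permutohedral) code path on the
+# GPU. Its expected values differ slightly from the CPU approximate results in
+# test_bilateral_approx_cpu.py, since the two kernels are separate
+# implementations, and the assertions at the bottom of this file accordingly
+# use a looser tolerance (atol=1e-2).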
+ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers.filtering import BilateralFilter +from tests.utils import skip_if_no_cpp_extention, skip_if_no_cuda + +TEST_CASES = [ + [ + # Case Descirption + "1 dimension, 1 channel, low spatial sigma, low color sigma", + # Spatial and Color Sigmas + (1, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [1.000000, 0.000000, 0.000000, 0.000000, 1.000000] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000000, 1.000000, 0.000000, 0.000000] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, low spatial sigma, high color sigma", + # Spatial and Color Sigmas + (1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.880626, 0.306148, 0.158734, 0.164534, 0.754386] + ], + # Batch 1 + [ + # Channel 0 + [0.019010, 0.104507, 0.605634, 0.183721, 0.045619] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, high spatial sigma, low color sigma", + # Spatial and Color Sigmas + (4, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [1.000000, 0.000000, 0.000000, 0.000000, 1.000000] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000000, 1.000000, 0.000000, 0.000000] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.497667, 0.268683, 0.265026, 0.261467, 0.495981] + ], + # Batch 1 + [ + # Channel 0 + [0.149889, 0.148226, 0.367978, 0.144023, 0.141317] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 4 channel, low spatial sigma, high color sigma", + # Spatial and Color Sigmas + (1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 0], + # Channel 1 + [1, 0, 1, 0, 0], + # Channel 2 + [0, 0, 1, 0, 1], + # Channel 3 + [0, 0, 0, 0, 1], + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.988107, 0.061340, 0.001565, 0.000011, 0.000000], + # Channel 1 + [0.988107, 0.061340, 0.998000, 0.000016, 0.000123], + # Channel 2 + [0.000000, 0.000000, 0.996435, 0.000006, 0.999236], + # Channel 3 + [0.000000, 0.000000, 0.000000, 0.000000, 0.999113], + ] + ], + ], + [ + # Case Descirption + "2 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]] + ], + # Batch 1 + [ + # Channel 0 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.211469, 0.094356, 0.092973, 0.091650, 0.211894], + [0.093755, 0.091753, 0.090524, 0.089343, 0.088384], + [0.091803, 0.089783, 0.088409, 0.087346, 0.086927], + [0.089938, 0.088126, 0.086613, 0.085601, 0.085535], + [0.208359, 0.086535, 0.085179, 0.084210, 0.205858], + ] + ], + # Batch 1 + [ + # Channel 0 + [ + [0.032760, 0.030146, 0.027442, 0.024643, 0.021744], + [0.030955, 0.029416, 0.026574, 0.023629, 0.020841], + [0.028915, 
0.027834, 0.115442, 0.022515, 0.020442], + [0.026589, 0.025447, 0.024319, 0.021286, 0.019964], + [0.023913, 0.022704, 0.021510, 0.020388, 0.019379], + ] + ], + ], + ], + [ + # Case Descirption + "2 dimension, 4 channel, high spatial sigma, high color sigma", + # Spatial and Color Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]], + # Channel 1 + [[1, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 1]], + # Channel 2 + [[0, 0, 1, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 0]], + # Channel 3 + [[0, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0]], + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.557349, 0.011031, 0.001800, 0.011265, 0.000631], + [0.009824, 0.010361, 0.010429, 0.010506, 0.010595], + [0.008709, 0.009252, 0.009688, 0.009714, 0.009744], + [0.007589, 0.008042, 0.008576, 0.008887, 0.008852], + [0.000420, 0.006827, 0.001048, 0.007763, 0.190722], + ], + # Channel 1 + [ + [0.614072, 0.011045, 0.925766, 0.011287, 0.007548], + [0.009838, 0.010382, 0.010454, 0.010536, 0.010630], + [0.008727, 0.009277, 0.009720, 0.009751, 0.009787], + [0.007611, 0.008071, 0.008613, 0.008932, 0.008904], + [0.027088, 0.006859, 0.950749, 0.007815, 0.230270], + ], + # Channel 2 + [ + [0.056723, 0.000150, 0.973790, 0.000233, 0.990814], + [0.000151, 0.000214, 0.000257, 0.000307, 0.000364], + [0.000186, 0.000257, 0.000328, 0.000384, 0.000449], + [0.000221, 0.000295, 0.000382, 0.000465, 0.000538], + [0.993884, 0.000333, 0.984743, 0.000532, 0.039548], + ], + # Channel 3 + [ + [0.000000, 0.000136, 0.049824, 0.000210, 0.983897], + [0.000136, 0.000193, 0.000232, 0.000277, 0.000329], + [0.000168, 0.000232, 0.000297, 0.000347, 0.000405], + [0.000200, 0.000266, 0.000345, 0.000420, 0.000485], + [0.967217, 0.000301, 0.035041, 0.000481, 0.000000], + ], + ] + ], + ], + [ + # Case Descirption + "3 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + # Frame 1 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 2 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 3 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 4 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + ] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [ + [0.085451, 0.037820, 0.036880, 0.035978, 0.084296], + [0.037939, 0.036953, 0.036155, 0.035385, 0.034640], + [0.037167, 0.036302, 0.035603, 0.034931, 0.034465], + [0.036469, 0.035724, 0.035137, 0.034572, 0.034480], + [0.088942, 0.035193, 0.034682, 0.034266, 0.090568], + ], + # Frame 1 + [ + [0.037125, 0.035944, 0.035103, 0.033429, 0.033498], + [0.033380, 0.032653, 0.033748, 0.033073, 0.032549], + [0.034834, 0.034001, 0.033500, 0.032902, 0.032560], + [0.033972, 0.033554, 0.033220, 0.032765, 0.032570], + [0.033590, 0.033222, 0.032927, 0.032689, 0.032629], + ], + # Frame 2 + [ + [0.035635, 0.034468, 0.033551, 0.032818, 0.032302], + [0.034523, 0.032830, 0.032146, 0.031536, 0.031149], + [0.033612, 0.032011, 0.031664, 0.031128, 0.030839], + [0.032801, 0.031668, 0.031529, 0.031198, 0.030978], + [0.032337, 0.031550, 0.031419, 
0.031383, 0.031211], + ], + # Frame 3 + [ + [0.034300, 0.033236, 0.032239, 0.031517, 0.031133], + [0.033357, 0.031842, 0.031035, 0.030471, 0.030126], + [0.032563, 0.031094, 0.030156, 0.029703, 0.029324], + [0.031850, 0.030505, 0.030027, 0.029802, 0.029461], + [0.031555, 0.030121, 0.029943, 0.030000, 0.029700], + ], + # Frame 4 + [ + [0.083156, 0.032122, 0.031204, 0.030380, 0.080582], + [0.032296, 0.030936, 0.030170, 0.029557, 0.029124], + [0.031617, 0.030293, 0.029377, 0.028886, 0.028431], + [0.031084, 0.029859, 0.028839, 0.028439, 0.027973], + [0.164616, 0.029457, 0.028484, 0.028532, 0.211082], + ], + ] + ] + ], + ], +] + + +@skip_if_no_cuda +@skip_if_no_cpp_extention +class BilateralFilterTestCaseCudaApprox(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_cuda_approx(self, test_case_description, sigmas, input, expected): + + # Skip this test + if not torch.cuda.is_available(): + return + + # Params to determine the implementation to test + device = torch.device("cuda") + fast_approx = True + + # Create input tensor and apply filter + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device) + output = BilateralFilter.apply(input_tensor, *sigmas, fast_approx).cpu().numpy() + + # Ensure result are as expected + np.testing.assert_allclose(output, expected, atol=1e-2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_bilateral_precise.py b/tests/test_bilateral_precise.py new file mode 100644 index 0000000000..f2a265b106 --- /dev/null +++ b/tests/test_bilateral_precise.py @@ -0,0 +1,403 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
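+#
+# Note: this file exercises the exact (fast_approx=False) implementation; the
+# same TEST_CASES are asserted on both CPU and CUDA with a tight tolerance
+# (atol=1e-5).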
+ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers.filtering import BilateralFilter +from tests.utils import skip_if_no_cpp_extention, skip_if_no_cuda + +TEST_CASES = [ + [ + # Case Descirption + "1 dimension, 1 channel, low spatial sigma, low color sigma", + # Spatial and Color Sigmas + (1, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.999998, 0.000002, 0.000000, 0.000002, 0.999998] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000001, 0.999995, 0.000001, 0.000000] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, low spatial sigma, high color sigma", + # Spatial and Color Sigmas + (1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.813183, 0.186817, 0.061890, 0.186817, 0.813183] + ], + # Batch 1 + [ + # Channel 0 + [0.030148, 0.148418, 0.555452, 0.148418, 0.030148] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, high spatial sigma, low color sigma", + # Spatial and Color Sigmas + (4, 0.2), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.999999, 0.000009, 0.000009, 0.000009, 0.999999] + ], + # Batch 1 + [ + # Channel 0 + [0.000000, 0.000000, 0.999967, 0.000000, 0.000000] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.839145, 0.572834, 0.562460, 0.572834, 0.839145] + ], + # Batch 1 + [ + # Channel 0 + [0.049925, 0.055062, 0.171732, 0.055062, 0.049925] + ], + ], + ], + [ + # Case Descirption + "1 dimension, 4 channel, low spatial sigma, high color sigma", + # Spatial and Color Sigmas + (1, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 0], + # Channel 1 + [1, 0, 1, 0, 0], + # Channel 2 + [0, 0, 1, 0, 1], + # Channel 3 + [0, 0, 0, 0, 1], + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.889742, 0.141296, 0.027504, 0.000000, 0.000000], + # Channel 1 + [0.909856, 0.256817, 0.725970, 0.115520, 0.020114], + # Channel 2 + [0.020114, 0.115520, 0.725970, 0.256817, 0.909856], + # Channel 3 + [0.000000, 0.000000, 0.027504, 0.141296, 0.889742], + ] + ], + ], + [ + # Case Descirption + "2 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]] + ], + # Batch 1 + [ + # Channel 0 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.688943, 0.374599, 0.368574, 0.374599, 0.688943], + [0.374599, 0.358248, 0.352546, 0.358248, 0.374599], + [0.368574, 0.352546, 0.346955, 0.352546, 0.368574], + [0.374599, 0.358248, 0.352546, 0.358248, 0.374599], + [0.688943, 0.374599, 0.368574, 0.374599, 0.688943], + ] + ], + # Batch 1 + [ + # Channel 0 + [ + [0.004266, 0.004687, 0.004836, 0.004687, 0.004266], + [0.004687, 0.005150, 0.005314, 0.005150, 0.004687], + [0.004836, 
0.005314, 0.018598, 0.005314, 0.004836], + [0.004687, 0.005150, 0.005314, 0.005150, 0.004687], + [0.004266, 0.004687, 0.004836, 0.004687, 0.004266], + ] + ], + ], + ], + [ + # Case Descirption + "2 dimension, 4 channel, high spatial sigma, high color sigma", + # Spatial and Color Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]], + # Channel 1 + [[1, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 1]], + # Channel 2 + [[0, 0, 1, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 0]], + # Channel 3 + [[0, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0]], + ] + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [0.692549, 0.149979, 0.220063, 0.115840, 0.035799], + [0.148403, 0.133935, 0.123253, 0.116828, 0.114623], + [0.128773, 0.122804, 0.120731, 0.122804, 0.128773], + [0.114623, 0.116828, 0.123253, 0.133935, 0.148403], + [0.035799, 0.115840, 0.220063, 0.149979, 0.692549], + ], + # Channel 1 + [ + [0.731597, 0.186319, 0.436069, 0.152181, 0.074847], + [0.180049, 0.168217, 0.158453, 0.151110, 0.146269], + [0.159760, 0.156381, 0.155211, 0.156381, 0.159760], + [0.146269, 0.151110, 0.158453, 0.168217, 0.180049], + [0.074847, 0.152181, 0.436068, 0.186319, 0.731597], + ], + # Channel 2 + [ + [0.074847, 0.152181, 0.436068, 0.186319, 0.731597], + [0.146269, 0.151110, 0.158453, 0.168217, 0.180049], + [0.159760, 0.156381, 0.155211, 0.156381, 0.159760], + [0.180049, 0.168217, 0.158453, 0.151110, 0.146269], + [0.731597, 0.186319, 0.436069, 0.152181, 0.074847], + ], + # Channel 3 + [ + [0.035799, 0.115840, 0.220063, 0.149979, 0.692549], + [0.114623, 0.116828, 0.123253, 0.133935, 0.148403], + [0.128773, 0.122804, 0.120731, 0.122804, 0.128773], + [0.148403, 0.133935, 0.123253, 0.116828, 0.114623], + [0.692549, 0.149979, 0.220063, 0.115840, 0.035799], + ], + ] + ], + ], + [ + # Case Descirption + "3 dimension, 1 channel, high spatial sigma, high color sigma", + # Sigmas + (4, 0.9), + # Input + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + # Frame 1 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 2 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 3 + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + # Frame 4 + [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]], + ] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [ + [0.554430, 0.254995, 0.251207, 0.254996, 0.554430], + [0.254996, 0.244691, 0.241082, 0.244692, 0.254996], + [0.251207, 0.241082, 0.237534, 0.241082, 0.251207], + [0.254996, 0.244691, 0.241082, 0.244692, 0.254996], + [0.554430, 0.254995, 0.251207, 0.254996, 0.554430], + ], + # Frame 1 + [ + [0.254996, 0.244691, 0.241082, 0.244692, 0.254996], + [0.244692, 0.234873, 0.231432, 0.234873, 0.244692], + [0.241082, 0.231431, 0.228049, 0.231432, 0.241082], + [0.244692, 0.234873, 0.231432, 0.234873, 0.244692], + [0.254996, 0.244691, 0.241082, 0.244692, 0.254996], + ], + # Frame 2 + [ + [0.251207, 0.241081, 0.237534, 0.241082, 0.251207], + [0.241082, 0.231431, 0.228049, 0.231432, 0.241082], + [0.237534, 0.228048, 0.224724, 0.228049, 0.237534], + [0.241082, 0.231431, 0.228049, 0.231432, 0.241082], + [0.251207, 0.241081, 0.237534, 
0.241082, 0.251207],
+                    ],
+                    # Frame 3
+                    [
+                        [0.254996, 0.244691, 0.241082, 0.244692, 0.254996],
+                        [0.244692, 0.234873, 0.231432, 0.234873, 0.244692],
+                        [0.241082, 0.231431, 0.228049, 0.231432, 0.241082],
+                        [0.244692, 0.234873, 0.231432, 0.234873, 0.244692],
+                        [0.254996, 0.244691, 0.241082, 0.244692, 0.254996],
+                    ],
+                    # Frame 4
+                    [
+                        [0.554430, 0.254995, 0.251207, 0.254996, 0.554430],
+                        [0.254996, 0.244691, 0.241082, 0.244692, 0.254996],
+                        [0.251207, 0.241082, 0.237534, 0.241082, 0.251207],
+                        [0.254996, 0.244691, 0.241082, 0.244692, 0.254996],
+                        [0.554430, 0.254995, 0.251207, 0.254996, 0.554430],
+                    ],
+                ]
+            ]
+        ],
+    ],
+]
+
+
+@skip_if_no_cpp_extention
+class BilateralFilterTestCaseCpuPrecised(unittest.TestCase):
+    @parameterized.expand(TEST_CASES)
+    def test_cpu_precised(self, test_case_description, sigmas, input, expected):
+
+        # Params to determine the implementation to test
+        device = torch.device("cpu")
+        fast_approx = False
+
+        # Create input tensor and apply filter
+        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device)
+        output = BilateralFilter.apply(input_tensor, *sigmas, fast_approx).cpu().numpy()
+
+        # Ensure results are as expected
+        np.testing.assert_allclose(output, expected, atol=1e-5)
+
+
+@skip_if_no_cuda
+@skip_if_no_cpp_extention
+class BilateralFilterTestCaseCudaPrecised(unittest.TestCase):
+    @parameterized.expand(TEST_CASES)
+    def test_cuda_precised(self, test_case_description, sigmas, input, expected):
+
+        # Skip this test when CUDA is unavailable
+        if not torch.cuda.is_available():
+            return
+
+        # Params to determine the implementation to test
+        device = torch.device("cuda")
+        fast_approx = False
+
+        # Create input tensor and apply filter
+        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device)
+        output = BilateralFilter.apply(input_tensor, *sigmas, fast_approx).cpu().numpy()
+
+        # Ensure results are as expected
+        np.testing.assert_allclose(output, expected, atol=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/utils.py
index 50c159053e..0b6c4e7318 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -28,6 +28,7 @@
 import torch
 import torch.distributed as dist
+from monai.config.deviceconfig import USE_COMPILED
 from monai.data import create_test_image_2d, create_test_image_3d
 from monai.utils import ensure_tuple, optional_import, set_determinism
 from monai.utils.module import get_torch_version_tuple
@@ -80,6 +81,13 @@
     def __call__(self, obj):
         return unittest.skipIf(self.module_avail, f"Skipping because optional module present: {self.module_name}")(obj)
+
+
+def skip_if_no_cpp_extention(obj):
+    """
+    Skip the unit tests if the C++ extension isn't available
+    """
+    return unittest.skipIf(not USE_COMPILED, "Skipping cpp extension tests")(obj)
+
 
 def skip_if_no_cuda(obj):
     """
     Skip the unit tests if torch.cuda.is_available is False
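As a closing illustration (not part of the change set): the two code paths can be cross-checked directly against each other. The shape, sigmas, and expectation below are illustrative only, chosen to mirror the tolerances used in the tests above:

import torch

from monai.networks.layers.filtering import BilateralFilter

x = torch.rand(1, 1, 32, 32)

fast = BilateralFilter.apply(x, 4.0, 0.9, True)      # permutohedral approximation
precise = BilateralFilter.apply(x, 4.0, 0.9, False)  # exact brute-force kernel

# The two implementations agree only loosely, which is why the approximate
# CUDA tests assert with atol=1e-2 while the precise tests use atol=1e-5.
print(float((fast - precise).abs().max()))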