From 46959b900e4a26f7dd164bacf0d42bfabbb9014d Mon Sep 17 00:00:00 2001
From: charliebudd
Date: Wed, 13 Jan 2021 11:28:48 +0000
Subject: [PATCH 01/12] exposing permutohedral lattice filter in python api

Signed-off-by: charliebudd
---
 monai/csrc/ext.cpp                            |  1 +
 monai/csrc/filtering/filtering.h              |  3 +-
 .../filtering/permutohedral/permutohedral.cpp | 41 ++++++++++++++++++
 .../filtering/permutohedral/permutohedral.h   |  6 ++-
 .../permutohedral/permutohedral_cpu.cpp       | 16 +++----
 monai/networks/layers/filtering.py            | 42 ++++++++++++++++++-
 6 files changed, 96 insertions(+), 13 deletions(-)
 create mode 100644 monai/csrc/filtering/permutohedral/permutohedral.cpp

diff --git a/monai/csrc/ext.cpp b/monai/csrc/ext.cpp
index 6740d1b5b4..8de1c5e7a1 100644
--- a/monai/csrc/ext.cpp
+++ b/monai/csrc/ext.cpp
@@ -21,6 +21,7 @@ limitations under the License.
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // filtering
   m.def("bilateral_filter", &BilateralFilter, "Bilateral Filter");
+  m.def("phl_filter", &PermutohedralFilter, "Permutohedral Filter");
 
   // lltm
   m.def("lltm_forward", &lltm_forward, "LLTM forward");
diff --git a/monai/csrc/filtering/filtering.h b/monai/csrc/filtering/filtering.h
index 18cf2ae6f4..9340d60805 100644
--- a/monai/csrc/filtering/filtering.h
+++ b/monai/csrc/filtering/filtering.h
@@ -13,4 +13,5 @@ limitations under the License.
 
 #pragma once
 
-#include "bilateral/bilateral.h"
\ No newline at end of file
+#include "bilateral/bilateral.h"
+#include "permutohedral/permutohedral.h"
\ No newline at end of file
diff --git a/monai/csrc/filtering/permutohedral/permutohedral.cpp b/monai/csrc/filtering/permutohedral/permutohedral.cpp
new file mode 100644
index 0000000000..2776142b56
--- /dev/null
+++ b/monai/csrc/filtering/permutohedral/permutohedral.cpp
@@ -0,0 +1,41 @@
+#include "utils/common_utils.h"
+#include "utils/meta_macros.h"
+
+#include "permutohedral.h"
+#include <torch/extension.h>
+
+torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features) {
+
+  input = input.contiguous();
+
+  int elementCount = input.stride(1);
+  int channelCount = input.size(1);
+  int featureCount = features.size(1);
+
+  torch::Tensor data = input.clone().movedim(1, -1).contiguous();
+  features = features.movedim(1, -1).contiguous();
+
+  #ifdef WITH_CUDA
+  if (torch::cuda::is_available() && data.is_cuda()) {
+    CHECK_CONTIGUOUS_CUDA(data);
+
+    #define CASE(dc, fc) AT_DISPATCH_FLOATING_TYPES(data.scalar_type(), "PermutohedralCuda", ([&] { \
+      PermutohedralCuda<scalar_t, dc, fc>(data.data_ptr<scalar_t>(), features.data_ptr<scalar_t>(), elementCount, true); \
+    }));
+    SWITCH_AB(CASE, 16, 19, channelCount, featureCount);
+
+  } else {
+    AT_DISPATCH_FLOATING_TYPES(data.scalar_type(), "PermutohedralCPU", ([&] { \
+      PermutohedralCPU<scalar_t>(data.data_ptr<scalar_t>(), features.data_ptr<scalar_t>(), channelCount, featureCount, elementCount); \
+    }));
+  }
+#else
+  AT_DISPATCH_FLOATING_TYPES(data.scalar_type(), "PermutohedralCPU", ([&] { \
+    PermutohedralCPU<scalar_t>(data.data_ptr<scalar_t>(), features.data_ptr<scalar_t>(), channelCount, featureCount, elementCount); \
+  }));
+#endif
+
+  data = data.movedim(-1, 1);
+
+  return data;
+}
diff --git a/monai/csrc/filtering/permutohedral/permutohedral.h b/monai/csrc/filtering/permutohedral/permutohedral.h
index 7f57c91a78..c97f867282 100644
--- a/monai/csrc/filtering/permutohedral/permutohedral.h
+++ b/monai/csrc/filtering/permutohedral/permutohedral.h
@@ -11,10 +11,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
+#include <torch/extension.h>
+
 #pragma once
 
 template <typename scalar_t>
-scalar_t* PermutohedralCPU(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount);
+void PermutohedralCPU(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount);
 #ifdef WITH_CUDA
 template <typename scalar_t>
 void PermutohedralCuda(scalar_t* data, scalar_t* features, int elementCount, bool accurate);
 #endif
+
+torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features);
\ No newline at end of file
diff --git a/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp b/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp
index 597bf263c1..d248ad47f6 100644
--- a/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp
+++ b/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp
@@ -215,7 +215,7 @@ class PermutohedralLattice {
    * im : image to be bilateral-filtered.
    * ref : reference image whose edges are to be respected.
    */
-  static scalar_t* filter(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount) {
+  static void filter(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount) {
     // Create lattice
     PermutohedralLattice lattice(featureChannels, dataChannels + 1, elementCount);
 
@@ -236,8 +236,6 @@ class PermutohedralLattice {
     lattice.blur();
 
     // Slice from the lattice
-    scalar_t* outputData = new scalar_t[elementCount * dataChannels];
-
     lattice.beginSlice();
 
     for (int i = 0, e = 0; e < elementCount; e++) {
@@ -245,11 +243,9 @@ class PermutohedralLattice {
       scalar_t scale = 1.0f / col[dataChannels];
 
       for (int c = 0; c < dataChannels; c++, i++) {
-        outputData[i] = col[c] * scale;
+        data[i] = col[c] * scale;
       }
     }
-
-    return outputData;
   }
 
   /* Constructor
@@ -498,17 +494,17 @@ class PermutohedralLattice {
 };
 
 template <typename scalar_t>
-scalar_t* PermutohedralCPU(
+void PermutohedralCPU(
     scalar_t* data,
     scalar_t* features,
     int dataChannels,
     int featureChannels,
     int elementCount) {
-  return PermutohedralLattice<scalar_t>::filter(data, features, dataChannels, featureChannels, elementCount);
+  PermutohedralLattice<scalar_t>::filter(data, features, dataChannels, featureChannels, elementCount);
 }
 
-template float* PermutohedralCPU(float* data, float* features, int dataChannels, int featureChannels, int elementCount);
-template double* PermutohedralCPU(
+template void PermutohedralCPU(float* data, float* features, int dataChannels, int featureChannels, int elementCount);
+template void PermutohedralCPU(
     double* data,
     double* features,
     int dataChannels,
diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py
index dcb172d892..5b5cc1da6a 100644
--- a/monai/networks/layers/filtering.py
+++ b/monai/networks/layers/filtering.py
@@ -15,7 +15,7 @@
 
 _C, _ = optional_import("monai._C")
 
-__all__ = ["BilateralFilter"]
+__all__ = ["BilateralFilter", "PHLFilter"]
 
 
 class BilateralFilter(torch.autograd.Function):
@@ -56,3 +56,43 @@ def backward(ctx, grad_output):
         spatial_sigma, color_sigma, fast_approx = ctx.saved_variables
         grad_input = _C.bilateral_filter(grad_output, spatial_sigma, color_sigma, fast_approx)
         return grad_input
+
+
+class PHLFilter(torch.autograd.Function):
+    """
+    Filters input based on arbitrary feature vectors. Uses a permutohedral
+    lattice data structure to efficiently approximate n-dimensional Gaussian
+    filtering. Complexity is broadly independent of kernel size. Most applicable
+    to higher filter dimensions and larger kernel sizes.
+ + See: + https://graphics.stanford.edu/papers/permutohedral/ + + Args: + input: input tensor to be filtered. + + features: feature tensor used to filter the input. + + sigmas: the standard deviations of each feature in the filter. + + Returns: + output (torch.Tensor): output tensor. + """ + + @staticmethod + def forward(ctx, input, features, sigmas=None): + + scaled_features = features + if sigmas is not None: + for i in range(features.size(1)): + scaled_features[:, i, ...] /= sigmas[i] + + ctx.save_for_backward(scaled_features) + output_data = _C.phl_filter(input, scaled_features) + return output_data + + @staticmethod + def backward(ctx, grad_output): + scaled_features = ctx.saved_variables + grad_input = PHLFilter.scale(grad_output, scaled_features) + return grad_input From a4cc4aa09b4e64be8f71b4d89076313f8901f546 Mon Sep 17 00:00:00 2001 From: charliebudd Date: Wed, 13 Jan 2021 16:50:55 +0000 Subject: [PATCH 02/12] unit tests for phl filter Signed-off-by: charliebudd --- .../filtering/permutohedral/permutohedral.cpp | 33 +- tests/test_phl_cpu.py | 341 ++++++++++++++++++ tests/test_phl_cuda.py | 190 ++++++++++ 3 files changed, 551 insertions(+), 13 deletions(-) create mode 100644 tests/test_phl_cpu.py create mode 100644 tests/test_phl_cuda.py diff --git a/monai/csrc/filtering/permutohedral/permutohedral.cpp b/monai/csrc/filtering/permutohedral/permutohedral.cpp index 2776142b56..605ae01526 100644 --- a/monai/csrc/filtering/permutohedral/permutohedral.cpp +++ b/monai/csrc/filtering/permutohedral/permutohedral.cpp @@ -8,6 +8,8 @@ torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features) { input = input.contiguous(); + int batchCount = input.size(0); + int batchStride = input.stride(0); int elementCount = input.stride(1); int channelCount = input.size(1); int featureCount = features.size(1); @@ -19,21 +21,26 @@ torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features) { if (torch::cuda::is_available() && data.is_cuda()) { CHECK_CONTIGUOUS_CUDA(data); - #define CASE(dc, fc) AT_DISPATCH_FLOATING_TYPES(data.scalar_type(), "PermutohedralCuda", ([&] { \ - PermutohedralCuda(data.data_ptr(), features.data_ptr(), elementCount, true); \ - })); + #define CASE(dc, fc) AT_DISPATCH_FLOATING_TYPES(data.scalar_type(), "PermutohedralCuda", ([&] { \ + for (int batchIndex = 0; batchIndex < batchCount; batchIndex++) { \ + scalar_t* offsetData = data.data_ptr() + batchIndex * batchStride; \ + scalar_t* offsetFeatures = features.data_ptr() + batchIndex * fc * elementCount; \ + PermutohedralCuda(offsetData, offsetFeatures, elementCount, true); \ + }})); SWITCH_AB(CASE, 16, 19, channelCount, featureCount); - } else { - AT_DISPATCH_FLOATING_TYPES(data.scalar_type(), "PermutohedralCPU", ([&] { \ - PermutohedralCPU(data.data_ptr(), features.data_ptr(), channelCount, featureCount, elementCount); \ - })); - } -#else - AT_DISPATCH_FLOATING_TYPES(data.scalar_type(), "PermutohedralCPU", ([&] { \ - PermutohedralCPU(data.data_ptr(), features.data_ptr(), channelCount, featureCount, elementCount); \ - })); -#endif + } + else { + #endif + AT_DISPATCH_FLOATING_TYPES(data.scalar_type(), "PermutohedralCPU", ([&] { \ + for (int batchIndex = 0; batchIndex < batchCount; batchIndex++) { \ + scalar_t* offsetData = data.data_ptr() + batchIndex * batchStride; \ + scalar_t* offsetFeatures = features.data_ptr() + batchIndex * featureCount * elementCount;\ + PermutohedralCPU(offsetData, offsetFeatures, channelCount, featureCount, elementCount); \ + }})); + #ifdef WITH_CUDA + } + #endif data = 
data.movedim(-1, 1); diff --git a/tests/test_phl_cpu.py b/tests/test_phl_cpu.py new file mode 100644 index 0000000000..88859cb02a --- /dev/null +++ b/tests/test_phl_cpu.py @@ -0,0 +1,341 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers.filtering import PHLFilter +from tests.utils import skip_if_no_cpp_extention + +TEST_CASES = [ + [ + # Case Descirption + "2 batches, 1 dimensions, 1 channels, 1 features", + # Sigmas + [1, 0.2], + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [1, 0.2, 0.5, 0, 1], + ], + # Batch 1 + [ + # Channel 0 + [0.5, 0, 1, 1, 1] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.468968, 0.364596, 0.4082 , 0.332579, 0.468968] + ], + # Batch 1 + [ + # Channel 0 + [0.202473, 0.176527, 0.220995, 0.220995, 0.220995] + ], + ], + ], + [ + # Case Descirption + "1 batches, 1 dimensions, 3 channels, 1 features", + # Sigmas + [1], + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 0], + # Channel 1 + [0, 0, 0, 0, 1], + # Channel 2 + [0, 0, 1, 0, 0] + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [1, 0.2, 0.5, 0.2, 1], + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.229572, 0.182884, 0.202637, 0.182884, 0.229572], + # Channel 1 + [0.229572, 0.182884, 0.202637, 0.182884, 0.229572], + # Channel 2 + [0.201235, 0.208194, 0.205409, 0.208194, 0.201235] + ], + ], + ], + [ + # Case Descirption + "1 batches, 2 dimensions, 1 channels, 3 features", + # Sigmas + [5, 3, 3], + # Input + [ + # Batch 0 + [ + # Channel 0 + [ + [9, 9, 0, 0, 0], + [9, 9, 0, 0, 0], + [9, 9, 0, 0, 0], + [9, 9, 6, 6, 6], + [9, 9, 6, 6, 6] + ] + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [ + [9, 9, 0, 0, 0], + [9, 9, 0, 0, 0], + [9, 9, 0, 0, 0], + [9, 9, 6, 6, 6], + [9, 9, 6, 6, 6] + ], + # Channel 1 + [ + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4] + ], + # Channel 2 + [ + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [2, 2, 2, 2, 2], + [3, 3, 3, 3, 3], + [4, 4, 4, 4, 4] + ] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [7.696051, 7.427121, 1.191990, 1.156004, 1.157489], + [7.670297, 7.371155, 1.340232, 1.287871, 1.304018], + [7.639579, 7.365163, 1.473319, 1.397826, 1.416861], + [7.613517, 7.359183, 5.846500, 5.638952, 5.350098], + [7.598255, 7.458446, 5.912375, 5.583625, 5.233126] + ] + ], + ], + ], + [ + # Case Descirption + "1 batches, 3 dimensions, 1 channels, 1 features", + # Sigmas + [5, 3, 3], + # Input + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [9, 9, 9, 0, 0], + [9, 9, 9, 0, 0], + [9, 9, 9, 0, 0] + ], + # Frame 1 + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [9, 9, 9, 0, 0], + [9, 9, 9, 0, 0], + [9, 9, 9, 0, 0] + ], + # Frame 2 + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 
0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0] + ], + # Frame 3 + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0] + ], + # Frame 4 + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0] + ] + ] + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [9, 9, 9, 0, 0], + [9, 9, 9, 0, 0], + [9, 9, 9, 0, 0] + ], + # Frame 1 + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [9, 9, 9, 0, 0], + [9, 9, 9, 0, 0], + [9, 9, 9, 0, 0] + ], + # Frame 2 + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0] + ], + # Frame 3 + [ + [0, 0, 5, 5, 5], + [0, 0, 5, 5, 5], + [0, 0, 5, 5, 5], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0] + ], + # Frame 4 + [ + [0, 0, 5, 5, 5], + [0, 0, 5, 5, 5], + [0, 0, 5, 5, 5], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0] + ] + ] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + # Frame 0 + [ + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [3.578490, 3.578490, 3.578490, 0.284234, 0.284234], + [3.578490, 3.578490, 3.578490, 0.284234, 0.284234], + [3.578490, 3.578490, 3.578490, 0.284234, 0.284234] + ], + # Frame 1 + [ + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [3.578490, 3.578490, 3.578490, 0.284234, 0.284234], + [3.578490, 3.578490, 3.578490, 0.284234, 0.284234], + [3.578490, 3.578490, 3.578490, 0.284234, 0.284234] + ], + # Frame 2 + [ + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234] + ], + # Frame 3 + [ + [0.284234, 0.284234, 1.359728, 1.359728, 1.359728], + [0.284234, 0.284234, 1.359728, 1.359728, 1.359728], + [0.284234, 0.284234, 1.359728, 1.359728, 1.359728], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234] + ], + # Frame 4 + [ + [0.284234, 0.284234, 1.359728, 1.359728, 1.359728], + [0.284234, 0.284234, 1.359728, 1.359728, 1.359728], + [0.284234, 0.284234, 1.359728, 1.359728, 1.359728], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234] + ] + ] + ], + ], + ], +] + +#@skip_if_no_cpp_extention +class PHLFilterTestCaseCpu(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_cpu(self, test_case_description, sigmas, input, features, expected): + + # Create input tensors + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=torch.device("cpu")) + feature_tensor = torch.from_numpy(np.array(features)).to(dtype=torch.float, device=torch.device("cpu")) + + # apply filter + output = PHLFilter.apply(input_tensor, feature_tensor, sigmas).cpu().numpy() + + # Ensure result are as expected + np.testing.assert_allclose(output, expected, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_phl_cuda.py b/tests/test_phl_cuda.py new file mode 100644 index 0000000000..2fa6380a29 --- /dev/null +++ b/tests/test_phl_cuda.py @@ -0,0 +1,190 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers.filtering import PHLFilter +from tests.utils import skip_if_no_cpp_extention, skip_if_no_cuda + +TEST_CASES = [ + [ + # Case Descirption + "2 batches, 1 dimensions, 1 channels, 1 features", + # Sigmas + [1, 0.2], + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 1] + ], + # Batch 1 + [ + # Channel 0 + [0, 0, 1, 0, 0] + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [1, 0.2, 0.5, 0, 1], + ], + # Batch 1 + [ + # Channel 0 + [0.5, 0, 1, 1, 1] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.468968, 0.364596, 0.408200, 0.332579, 0.468968] + ], + # Batch 1 + [ + # Channel 0 + [0.202473, 0.176527, 0.220995, 0.220995, 0.220995] + ], + ], + ], + [ + # Case Descirption + "1 batches, 1 dimensions, 3 channels, 1 features", + # Sigmas + [1], + # Input + [ + # Batch 0 + [ + # Channel 0 + [1, 0, 0, 0, 0], + # Channel 1 + [0, 0, 0, 0, 1], + # Channel 2 + [0, 0, 1, 0, 0] + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [1, 0.2, 0.5, 0.2, 1], + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [0.229572, 0.182884, 0.202637, 0.182884, 0.229572], + # Channel 1 + [0.229572, 0.182884, 0.202637, 0.182884, 0.229572], + # Channel 2 + [0.201235, 0.208194, 0.205409, 0.208194, 0.201235] + ], + ], + ], + [ + # Case Descirption + "1 batches, 2 dimensions, 1 channels, 3 features", + # Sigmas + [5, 3, 3], + # Input + [ + # Batch 0 + [ + # Channel 0 + [ + [9, 9, 0, 0, 0], + [9, 9, 0, 0, 0], + [9, 9, 0, 0, 0], + [9, 9, 6, 6, 6], + [9, 9, 6, 6, 6] + ] + ], + ], + # Features + [ + # Batch 0 + [ + # Channel 0 + [ + [9, 9, 0, 0, 0], + [9, 9, 0, 0, 0], + [9, 9, 0, 0, 0], + [9, 9, 6, 6, 6], + [9, 9, 6, 6, 6] + ], + # Channel 1 + [ + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4] + ], + # Channel 2 + [ + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [2, 2, 2, 2, 2], + [3, 3, 3, 3, 3], + [4, 4, 4, 4, 4] + ] + ], + ], + # Expected + [ + # Batch 0 + [ + # Channel 0 + [ + [7.792655, 7.511395, 0.953769, 0.860538, 0.912978], + [7.758870, 7.426762, 1.164386, 1.050956, 1.121830], + [7.733974, 7.429964, 1.405752, 1.244949, 1.320862], + [7.712976, 7.429060, 5.789552, 5.594258, 5.371737], + [7.701185, 7.492719, 5.860026, 5.538241, 5.281656] + ] + ], + ], + ], +] + + +#@skip_if_no_cuda +#@skip_if_no_cpp_extention +class PHLFilterTestCaseCuda(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_cuda(self, test_case_description, sigmas, input, features, expected): + + # Create input tensors + input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=torch.device("cuda")) + feature_tensor = torch.from_numpy(np.array(features)).to(dtype=torch.float, device=torch.device("cuda")) + + # apply filter + output = PHLFilter.apply(input_tensor, feature_tensor, sigmas).cpu().numpy() + + # Ensure result are as expected + np.testing.assert_allclose(output, expected, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() From 29a74adb95200e61633c7d4e959aa5d02821986c Mon Sep 17 00:00:00 2001 From: 
charliebudd Date: Wed, 13 Jan 2021 17:32:47 +0000 Subject: [PATCH 03/12] auto formatting Signed-off-by: charliebudd --- .../bilateral/bilateralfilter_cpu_phl.cpp | 5 +- .../bilateral/bilateralfilter_cuda_phl.cu | 10 +- monai/networks/layers/filtering.py | 6 +- tests/test_phl_cpu.py | 138 ++++-------------- tests/test_phl_cuda.py | 44 ++---- 5 files changed, 50 insertions(+), 153 deletions(-) diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp index eb94749ea5..75ad368687 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp @@ -62,13 +62,12 @@ void BilateralFilterPHLCpu( } // Filtering data with respect to the features. - scalar_t* output = - PermutohedralCPU(data, features, desc.channelCount, featureChannels, desc.channelStride); + PermutohedralCPU(data, features, desc.channelCount, featureChannels, desc.channelStride); // Writing output tensor. for (int i = 0; i < desc.channelStride; i++) { for (int c = 0; c < desc.channelCount; c++) { - outputTensorData[batchOffset + i + c * desc.channelStride] = output[i * desc.channelCount + c]; + outputTensorData[batchOffset + i + c * desc.channelStride] = data[i * desc.channelCount + c]; } } } diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu index df4ed8771b..f1d5682475 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu @@ -30,6 +30,8 @@ __global__ void FeatureCreation(const scalar_t* inputTensor, scalar_t* outputDat int elementIndex = blockIdx.x * blockDim.x + threadIdx.x; int batchIndex = blockIdx.y; + if (elementIndex >= cChannelStride) return; + int dataBatchOffset = batchIndex * cBatchStride; int featureBatchOffset = batchIndex * (D + C) * cChannelStride; @@ -56,6 +58,10 @@ template __global__ void WriteOutput(const scalar_t* data, scalar_t* outputTensor) { int elementIndex = blockIdx.x * blockDim.x + threadIdx.x; int batchIndex = blockIdx.y; + + + if (elementIndex >= cChannelStride) return; + int batchOffset = batchIndex * cBatchStride; #pragma unroll @@ -97,7 +103,7 @@ void BilateralFilterPHLCuda( // Creating features FeatureCreation - <<>>(inputTensorData, data, features); + <<>>(inputTensorData, data, features); // Filtering data with respect to the features for each sample in batch for (int batchIndex = 0; batchIndex < desc.batchCount; batchIndex++) { @@ -108,7 +114,7 @@ void BilateralFilterPHLCuda( } // Writing output - WriteOutput<<>>(data, outputTensorData); + WriteOutput<<>>(data, outputTensorData); cudaFree(data); cudaFree(features); diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 5b5cc1da6a..034161ab9c 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -60,9 +60,9 @@ def backward(ctx, grad_output): class PHLFilter(torch.autograd.Function): """ - Filters input based on arbitrary feature vectors. Uses a permutohedral + Filters input based on arbitrary feature vectors. Uses a permutohedral lattice data structure to efficiently approximate n-dimensional gaussian - filtering. Complexity is broadly independant of kernel size. Most applicable + filtering. Complexity is broadly independant of kernel size. Most applicable to higher filter dimensions and larger kernel sizes. 
See: @@ -81,7 +81,7 @@ class PHLFilter(torch.autograd.Function): @staticmethod def forward(ctx, input, features, sigmas=None): - + scaled_features = features if sigmas is not None: for i in range(features.size(1)): diff --git a/tests/test_phl_cpu.py b/tests/test_phl_cpu.py index 88859cb02a..a3cb7f12cf 100644 --- a/tests/test_phl_cpu.py +++ b/tests/test_phl_cpu.py @@ -55,7 +55,7 @@ # Batch 0 [ # Channel 0 - [0.468968, 0.364596, 0.4082 , 0.332579, 0.468968] + [0.468968, 0.364596, 0.4082, 0.332579, 0.468968] ], # Batch 1 [ @@ -78,7 +78,7 @@ # Channel 1 [0, 0, 0, 0, 1], # Channel 2 - [0, 0, 1, 0, 0] + [0, 0, 1, 0, 0], ], ], # Features @@ -98,7 +98,7 @@ # Channel 1 [0.229572, 0.182884, 0.202637, 0.182884, 0.229572], # Channel 2 - [0.201235, 0.208194, 0.205409, 0.208194, 0.201235] + [0.201235, 0.208194, 0.205409, 0.208194, 0.201235], ], ], ], @@ -112,13 +112,7 @@ # Batch 0 [ # Channel 0 - [ - [9, 9, 0, 0, 0], - [9, 9, 0, 0, 0], - [9, 9, 0, 0, 0], - [9, 9, 6, 6, 6], - [9, 9, 6, 6, 6] - ] + [[9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 6, 6, 6], [9, 9, 6, 6, 6]] ], ], # Features @@ -126,29 +120,11 @@ # Batch 0 [ # Channel 0 - [ - [9, 9, 0, 0, 0], - [9, 9, 0, 0, 0], - [9, 9, 0, 0, 0], - [9, 9, 6, 6, 6], - [9, 9, 6, 6, 6] - ], + [[9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 6, 6, 6], [9, 9, 6, 6, 6]], # Channel 1 - [ - [0, 1, 2, 3, 4], - [0, 1, 2, 3, 4], - [0, 1, 2, 3, 4], - [0, 1, 2, 3, 4], - [0, 1, 2, 3, 4] - ], + [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]], # Channel 2 - [ - [0, 0, 0, 0, 0], - [1, 1, 1, 1, 1], - [2, 2, 2, 2, 2], - [3, 3, 3, 3, 3], - [4, 4, 4, 4, 4] - ] + [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3], [4, 4, 4, 4, 4]], ], ], # Expected @@ -161,7 +137,7 @@ [7.670297, 7.371155, 1.340232, 1.287871, 1.304018], [7.639579, 7.365163, 1.473319, 1.397826, 1.416861], [7.613517, 7.359183, 5.846500, 5.638952, 5.350098], - [7.598255, 7.458446, 5.912375, 5.583625, 5.233126] + [7.598255, 7.458446, 5.912375, 5.583625, 5.233126], ] ], ], @@ -178,45 +154,15 @@ # Channel 0 [ # Frame 0 - [ - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [9, 9, 9, 0, 0], - [9, 9, 9, 0, 0], - [9, 9, 9, 0, 0] - ], + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0]], # Frame 1 - [ - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [9, 9, 9, 0, 0], - [9, 9, 9, 0, 0], - [9, 9, 9, 0, 0] - ], + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0]], # Frame 2 - [ - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0] - ], + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], # Frame 3 - [ - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0] - ], + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], # Frame 4 - [ - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0] - ] + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], ] ], ], @@ -227,45 +173,15 @@ # Channel 0 [ # Frame 0 - [ - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [9, 9, 9, 0, 0], - [9, 9, 9, 0, 0], - [9, 9, 9, 0, 0] - ], + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0]], # Frame 1 - [ - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [9, 9, 9, 0, 0], - [9, 9, 9, 0, 0], - [9, 9, 9, 0, 0] - ], + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0]], # Frame 
2 - [ - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0] - ], + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], # Frame 3 - [ - [0, 0, 5, 5, 5], - [0, 0, 5, 5, 5], - [0, 0, 5, 5, 5], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0] - ], + [[0, 0, 5, 5, 5], [0, 0, 5, 5, 5], [0, 0, 5, 5, 5], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], # Frame 4 - [ - [0, 0, 5, 5, 5], - [0, 0, 5, 5, 5], - [0, 0, 5, 5, 5], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0] - ] + [[0, 0, 5, 5, 5], [0, 0, 5, 5, 5], [0, 0, 5, 5, 5], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], ] ], ], @@ -281,7 +197,7 @@ [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], [3.578490, 3.578490, 3.578490, 0.284234, 0.284234], [3.578490, 3.578490, 3.578490, 0.284234, 0.284234], - [3.578490, 3.578490, 3.578490, 0.284234, 0.284234] + [3.578490, 3.578490, 3.578490, 0.284234, 0.284234], ], # Frame 1 [ @@ -289,7 +205,7 @@ [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], [3.578490, 3.578490, 3.578490, 0.284234, 0.284234], [3.578490, 3.578490, 3.578490, 0.284234, 0.284234], - [3.578490, 3.578490, 3.578490, 0.284234, 0.284234] + [3.578490, 3.578490, 3.578490, 0.284234, 0.284234], ], # Frame 2 [ @@ -297,7 +213,7 @@ [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], - [0.284234, 0.284234, 0.284234, 0.284234, 0.284234] + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], ], # Frame 3 [ @@ -305,7 +221,7 @@ [0.284234, 0.284234, 1.359728, 1.359728, 1.359728], [0.284234, 0.284234, 1.359728, 1.359728, 1.359728], [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], - [0.284234, 0.284234, 0.284234, 0.284234, 0.284234] + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], ], # Frame 4 [ @@ -313,19 +229,19 @@ [0.284234, 0.284234, 1.359728, 1.359728, 1.359728], [0.284234, 0.284234, 1.359728, 1.359728, 1.359728], [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], - [0.284234, 0.284234, 0.284234, 0.284234, 0.284234] - ] + [0.284234, 0.284234, 0.284234, 0.284234, 0.284234], + ], ] ], ], ], ] - -#@skip_if_no_cpp_extention + +# @skip_if_no_cpp_extention class PHLFilterTestCaseCpu(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_cpu(self, test_case_description, sigmas, input, features, expected): - + # Create input tensors input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=torch.device("cpu")) feature_tensor = torch.from_numpy(np.array(features)).to(dtype=torch.float, device=torch.device("cpu")) diff --git a/tests/test_phl_cuda.py b/tests/test_phl_cuda.py index 2fa6380a29..7fa3c2e909 100644 --- a/tests/test_phl_cuda.py +++ b/tests/test_phl_cuda.py @@ -78,7 +78,7 @@ # Channel 1 [0, 0, 0, 0, 1], # Channel 2 - [0, 0, 1, 0, 0] + [0, 0, 1, 0, 0], ], ], # Features @@ -98,7 +98,7 @@ # Channel 1 [0.229572, 0.182884, 0.202637, 0.182884, 0.229572], # Channel 2 - [0.201235, 0.208194, 0.205409, 0.208194, 0.201235] + [0.201235, 0.208194, 0.205409, 0.208194, 0.201235], ], ], ], @@ -112,13 +112,7 @@ # Batch 0 [ # Channel 0 - [ - [9, 9, 0, 0, 0], - [9, 9, 0, 0, 0], - [9, 9, 0, 0, 0], - [9, 9, 6, 6, 6], - [9, 9, 6, 6, 6] - ] + [[9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 6, 6, 6], [9, 9, 6, 6, 6]] ], ], # Features @@ -126,29 +120,11 @@ # Batch 0 [ # Channel 0 - [ - [9, 9, 0, 0, 0], - [9, 9, 0, 0, 0], - [9, 9, 0, 0, 0], - [9, 9, 6, 6, 6], - [9, 9, 6, 6, 6] - ], + [[9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 6, 6, 6], [9, 9, 6, 6, 6]], # Channel 1 - [ - [0, 1, 2, 3, 
4], - [0, 1, 2, 3, 4], - [0, 1, 2, 3, 4], - [0, 1, 2, 3, 4], - [0, 1, 2, 3, 4] - ], + [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]], # Channel 2 - [ - [0, 0, 0, 0, 0], - [1, 1, 1, 1, 1], - [2, 2, 2, 2, 2], - [3, 3, 3, 3, 3], - [4, 4, 4, 4, 4] - ] + [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3], [4, 4, 4, 4, 4]], ], ], # Expected @@ -161,7 +137,7 @@ [7.758870, 7.426762, 1.164386, 1.050956, 1.121830], [7.733974, 7.429964, 1.405752, 1.244949, 1.320862], [7.712976, 7.429060, 5.789552, 5.594258, 5.371737], - [7.701185, 7.492719, 5.860026, 5.538241, 5.281656] + [7.701185, 7.492719, 5.860026, 5.538241, 5.281656], ] ], ], @@ -169,12 +145,12 @@ ] -#@skip_if_no_cuda -#@skip_if_no_cpp_extention +# @skip_if_no_cuda +# @skip_if_no_cpp_extention class PHLFilterTestCaseCuda(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_cuda(self, test_case_description, sigmas, input, features, expected): - + # Create input tensors input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=torch.device("cuda")) feature_tensor = torch.from_numpy(np.array(features)).to(dtype=torch.float, device=torch.device("cuda")) From 5b93a18e91e86d2720ac70079cccd3244d7153fa Mon Sep 17 00:00:00 2001 From: charliebudd Date: Fri, 15 Jan 2021 10:56:05 +0000 Subject: [PATCH 04/12] expanding block sizes for better gpu utilisation Signed-off-by: charliebudd --- .../filtering/bilateral/bilateralfilter_cuda.cu | 14 +++++++++++--- .../bilateral/bilateralfilter_cuda_phl.cu | 6 ++++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu index 872ff652cb..9fc25dc84f 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu @@ -36,6 +36,8 @@ __global__ void BilateralFilterCudaKernel1D(scalar_t* input, scalar_t* output) { int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; int batchOffset = blockIdx.y * cBatchStride; + if (homeOffset >= cColorStride) return; + scalar_t weightSum = 0; for (int kernelOffset = 0; kernelOffset < cKernelSize; kernelOffset++) { @@ -78,6 +80,8 @@ __global__ void BilateralFilterCudaKernel2D(scalar_t* input, scalar_t* output) { int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; int batchOffset = blockIdx.y * cBatchStride; + + if (homeOffset >= cColorStride) return; int homeX = homeOffset / cStrides[0]; int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1]; @@ -132,6 +136,8 @@ __global__ void BilateralFilterCudaKernel3D(scalar_t* input, scalar_t* output) { int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; int batchOffset = blockIdx.y * cBatchStride; + if (homeOffset >= cColorStride) return; + int homeX = homeOffset / cStrides[0]; int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1]; int homeZ = (homeOffset - homeX * cStrides[0] - homeY * cStrides[1]) / cStrides[2]; @@ -211,21 +217,23 @@ void BilateralFilterCuda(torch::Tensor inputTensor, torch::Tensor outputTensor, cudaMemcpyToSymbol(cKernel, kernel, sizeof(float) * kernelSize); cudaMemcpyToSymbol(cColorExponentFactor, &colorExponentFactor, sizeof(float)); + #define BLOCK_SIZE 32 + AT_DISPATCH_FLOATING_TYPES_AND_HALF( inputTensor.type(), "BilateralFilterCudaKernel", ([&] { // Dispatch kernel. 
(Partial template function specialisation not supported at present so using this switch // instead) switch (D) { case (1): - BilateralFilterCudaKernel1D<<>>( + BilateralFilterCudaKernel1D<<>>( inputTensor.data_ptr(), outputTensor.data_ptr()); break; case (2): - BilateralFilterCudaKernel2D<<>>( + BilateralFilterCudaKernel2D<<>>( inputTensor.data_ptr(), outputTensor.data_ptr()); break; case (3): - BilateralFilterCudaKernel3D<<>>( + BilateralFilterCudaKernel3D<<>>( inputTensor.data_ptr(), outputTensor.data_ptr()); break; } diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu index f1d5682475..3acced8e99 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu @@ -101,9 +101,11 @@ void BilateralFilterPHLCuda( cudaMemcpyToSymbol(cInvSpatialSigma, &invSpatialSigma, sizeof(float)); cudaMemcpyToSymbol(cInvColorSigma, &invColorSigma, sizeof(float)); + #define BLOCK_SIZE 32 + // Creating features FeatureCreation - <<>>(inputTensorData, data, features); + <<>>(inputTensorData, data, features); // Filtering data with respect to the features for each sample in batch for (int batchIndex = 0; batchIndex < desc.batchCount; batchIndex++) { @@ -114,7 +116,7 @@ void BilateralFilterPHLCuda( } // Writing output - WriteOutput<<>>(data, outputTensorData); + WriteOutput<<>>(data, outputTensorData); cudaFree(data); cudaFree(features); From 98c7ec5a7a85a466e98916f4dc0d9f56445132c2 Mon Sep 17 00:00:00 2001 From: charliebudd Date: Fri, 15 Jan 2021 11:41:34 +0000 Subject: [PATCH 05/12] un-commenting skip test decorators Signed-off-by: charliebudd --- tests/test_phl_cpu.py | 3 ++- tests/test_phl_cuda.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_phl_cpu.py b/tests/test_phl_cpu.py index a3cb7f12cf..f0e62cbddb 100644 --- a/tests/test_phl_cpu.py +++ b/tests/test_phl_cpu.py @@ -237,7 +237,8 @@ ], ] -# @skip_if_no_cpp_extention + +@skip_if_no_cpp_extention class PHLFilterTestCaseCpu(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_cpu(self, test_case_description, sigmas, input, features, expected): diff --git a/tests/test_phl_cuda.py b/tests/test_phl_cuda.py index 7fa3c2e909..8b89efce1a 100644 --- a/tests/test_phl_cuda.py +++ b/tests/test_phl_cuda.py @@ -145,8 +145,8 @@ ] -# @skip_if_no_cuda -# @skip_if_no_cpp_extention +@skip_if_no_cuda +@skip_if_no_cpp_extention class PHLFilterTestCaseCuda(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_cuda(self, test_case_description, sigmas, input, features, expected): From 7fe830711e7553b880412491b99cb0d14cc3a802 Mon Sep 17 00:00:00 2001 From: charliebudd Date: Fri, 15 Jan 2021 11:41:59 +0000 Subject: [PATCH 06/12] adding docs entry for PHLFilter Signed-off-by: charliebudd --- docs/source/networks.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index ed17d815b4..ae096b5f89 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -189,6 +189,11 @@ Layers .. autoclass:: BilateralFilter :members: +`PHLFilter` +~~~~~~~~~~~~~~~~~ +.. autoclass:: PHLFilter + :members: + `HilbertTransform` ~~~~~~~~~~~~~~~~~~ .. 
autoclass:: HilbertTransform From b153e9f27f7906aa9b71b2674c18de81139bb974 Mon Sep 17 00:00:00 2001 From: charliebudd Date: Fri, 15 Jan 2021 14:26:28 +0000 Subject: [PATCH 07/12] adding PHLFilter to init file Signed-off-by: charliebudd --- monai/networks/layers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py index dabec727ac..4007d88b51 100644 --- a/monai/networks/layers/__init__.py +++ b/monai/networks/layers/__init__.py @@ -11,7 +11,7 @@ from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args -from .filtering import BilateralFilter +from .filtering import BilateralFilter, PHLFilter from .simplelayers import ( LLTM, ChannelPad, From 452e4231f71e157aa14e9d4c4f83803a4b374358 Mon Sep 17 00:00:00 2001 From: charliebudd Date: Fri, 15 Jan 2021 17:00:03 +0000 Subject: [PATCH 08/12] fixing some compilation warnings Signed-off-by: charliebudd --- monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp | 2 +- .../csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp | 2 +- monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu | 2 +- .../csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu | 2 +- .../permutohedral/{hash_table.cu => hash_table.cuh} | 9 +++++++-- monai/csrc/filtering/permutohedral/permutohedral_cuda.cu | 2 +- 6 files changed, 12 insertions(+), 7 deletions(-) rename monai/csrc/filtering/permutohedral/{hash_table.cu => hash_table.cuh} (96%) diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp b/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp index ea56ff7526..474d24b4fa 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp @@ -158,7 +158,7 @@ torch::Tensor BilateralFilterCpu(torch::Tensor inputTensor, float spatialSigma, // Preparing output tensor. 
torch::Tensor outputTensor = torch::zeros_like(inputTensor); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputTensor.type(), "BilateralFilterCpu", ([&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputTensor.scalar_type(), "BilateralFilterCpu", ([&] { BilateralFilterCpu( inputTensor, outputTensor, spatialSigma, colorSigma); })); diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp index 8a7671c527..1fb48cb6c9 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp @@ -80,7 +80,7 @@ void BilateralFilterPHLCpu( torch::Tensor BilateralFilterPHLCpu(torch::Tensor inputTensor, float spatialSigma, float colorSigma) { torch::Tensor outputTensor = torch::zeros_like(inputTensor); - AT_DISPATCH_FLOATING_TYPES(inputTensor.type(), "BilateralFilterPhlCpu", ([&] { + AT_DISPATCH_FLOATING_TYPES(inputTensor.scalar_type(), "BilateralFilterPhlCpu", ([&] { BilateralFilterPHLCpu(inputTensor, outputTensor, spatialSigma, colorSigma); })); diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu index 0215c2e124..2ea618576f 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu @@ -220,7 +220,7 @@ void BilateralFilterCuda(torch::Tensor inputTensor, torch::Tensor outputTensor, #define BLOCK_SIZE 32 AT_DISPATCH_FLOATING_TYPES_AND_HALF( - inputTensor.type(), "BilateralFilterCudaKernel", ([&] { + inputTensor.scalar_type(), "BilateralFilterCudaKernel", ([&] { // Dispatch kernel. (Partial template function specialisation not supported at present so using this switch // instead) switch (D) { diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu index 2d48837b91..35e30a2d33 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu @@ -127,7 +127,7 @@ torch::Tensor BilateralFilterPHLCuda(torch::Tensor inputTensor, float spatialSig torch::Tensor outputTensor = torch::zeros_like(inputTensor); #define CASE(c, d) \ - AT_DISPATCH_FLOATING_TYPES(inputTensor.type(), "BilateralFilterCudaPHL", ([&] { \ + AT_DISPATCH_FLOATING_TYPES(inputTensor.scalar_type(), "BilateralFilterCudaPHL", ([&] { \ BilateralFilterPHLCuda( \ inputTensor, outputTensor, spatialSigma, colorSigma); \ })); diff --git a/monai/csrc/filtering/permutohedral/hash_table.cu b/monai/csrc/filtering/permutohedral/hash_table.cuh similarity index 96% rename from monai/csrc/filtering/permutohedral/hash_table.cu rename to monai/csrc/filtering/permutohedral/hash_table.cuh index bedf0c1efc..7d9d7eb163 100644 --- a/monai/csrc/filtering/permutohedral/hash_table.cu +++ b/monai/csrc/filtering/permutohedral/hash_table.cuh @@ -90,9 +90,14 @@ static scalar_t* createHashTable(int capacity) { template static void destroyHashTable() { #ifndef LINEAR_D_MEMORY - cudaFree(table_keys); + signed short* keys; + cudaMemcpyFromSymbol(&keys, table_keys, sizeof(unsigned int*)); + cudaFree(keys); #endif - cudaFree(table_entries); + + int* entries; + cudaMemcpyFromSymbol(&entries, table_entries, sizeof(int*)); + cudaFree(entries); } template diff --git a/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu b/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu index 94c5b90659..b87a88a84f 100644 --- 
a/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu +++ b/monai/csrc/filtering/permutohedral/permutohedral_cuda.cu @@ -46,7 +46,7 @@ SOFTWARE. #include #include -#include "hash_table.cu" +#include "hash_table.cuh" #include "utils/meta_macros.h" template From eda1e4e7d2bb08acb570decae1cb36545c0380de Mon Sep 17 00:00:00 2001 From: charliebudd Date: Mon, 18 Jan 2021 12:48:47 +0000 Subject: [PATCH 09/12] implementing manual movedim for torch < 1.7 Signed-off-by: charliebudd --- .../filtering/permutohedral/permutohedral.cpp | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/monai/csrc/filtering/permutohedral/permutohedral.cpp b/monai/csrc/filtering/permutohedral/permutohedral.cpp index 605ae01526..bfad0d7faf 100644 --- a/monai/csrc/filtering/permutohedral/permutohedral.cpp +++ b/monai/csrc/filtering/permutohedral/permutohedral.cpp @@ -14,8 +14,22 @@ torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features) { int channelCount = input.size(1); int featureCount = features.size(1); + // movedim not support in torch < 1.7 + #if MONAI_TORCH_VERSION >= 10700 torch::Tensor data = input.clone().movedim(1, -1).contiguous(); features = features.movedim(1, -1).contiguous(); + #else + torch::Tensor data = input.clone(); + features = features; + + for (int i=1; i < input.dim()-1; i++){ + data = data.transpose(i, i+1); + features = features.transpose(i, i+1); + } + + data = data.contiguous(); + features = features.contiguous(); + #endif #ifdef WITH_CUDA if (torch::cuda::is_available() && data.is_cuda()) { @@ -42,7 +56,14 @@ torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features) { } #endif + // movedim not support in torch < 1.7 + #if MONAI_TORCH_VERSION >= 10700 data = data.movedim(-1, 1); + #else + for (int i=input.dim()-1; i > 1; i--){ + data = data.transpose(i-1, i); + } + #endif return data; } From 80bf2bc17019a3c5938d6062a161e2e6539d7bcb Mon Sep 17 00:00:00 2001 From: charliebudd Date: Mon, 18 Jan 2021 14:47:14 +0000 Subject: [PATCH 10/12] fixing dispatch width Signed-off-by: charliebudd --- monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu | 6 +++--- monai/csrc/filtering/permutohedral/permutohedral.cpp | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu index 2ea618576f..c392ca9a9a 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu @@ -225,15 +225,15 @@ void BilateralFilterCuda(torch::Tensor inputTensor, torch::Tensor outputTensor, // instead) switch (D) { case (1): - BilateralFilterCudaKernel1D<<>>( + BilateralFilterCudaKernel1D<<>>( inputTensor.data_ptr(), outputTensor.data_ptr()); break; case (2): - BilateralFilterCudaKernel2D<<>>( + BilateralFilterCudaKernel2D<<>>( inputTensor.data_ptr(), outputTensor.data_ptr()); break; case (3): - BilateralFilterCudaKernel3D<<>>( + BilateralFilterCudaKernel3D<<>>( inputTensor.data_ptr(), outputTensor.data_ptr()); break; } diff --git a/monai/csrc/filtering/permutohedral/permutohedral.cpp b/monai/csrc/filtering/permutohedral/permutohedral.cpp index bfad0d7faf..89d6d10c99 100644 --- a/monai/csrc/filtering/permutohedral/permutohedral.cpp +++ b/monai/csrc/filtering/permutohedral/permutohedral.cpp @@ -2,7 +2,6 @@ #include "utils/meta_macros.h" #include "permutohedral.h" -#include torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features) { From 
24d34354922340a6bb84d065509586853332066e Mon Sep 17 00:00:00 2001 From: monai-bot Date: Mon, 18 Jan 2021 16:13:07 +0000 Subject: [PATCH 11/12] [MONAI] python code formatting Signed-off-by: monai-bot --- .../bilateral/bilateralfilter_cuda.cu | 28 +- .../bilateral/bilateralfilter_cuda_phl.cu | 17 +- .../filtering/permutohedral/permutohedral.cpp | 107 +++--- .../permutohedral/permutohedral_cpu.cpp | 14 +- tests/test_savitzky_golay_filter.py | 304 +++++++++--------- tests/test_savitzky_golay_smooth.py | 140 ++++---- 6 files changed, 306 insertions(+), 304 deletions(-) diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu index c392ca9a9a..4477ce5845 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu @@ -36,7 +36,8 @@ __global__ void BilateralFilterCudaKernel1D(scalar_t* input, scalar_t* output) { int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; int batchOffset = blockIdx.y * cBatchStride; - if (homeOffset >= cColorStride) return; + if (homeOffset >= cColorStride) + return; scalar_t weightSum = 0; @@ -80,8 +81,9 @@ __global__ void BilateralFilterCudaKernel2D(scalar_t* input, scalar_t* output) { int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; int batchOffset = blockIdx.y * cBatchStride; - - if (homeOffset >= cColorStride) return; + + if (homeOffset >= cColorStride) + return; int homeX = homeOffset / cStrides[0]; int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1]; @@ -136,7 +138,8 @@ __global__ void BilateralFilterCudaKernel3D(scalar_t* input, scalar_t* output) { int homeOffset = blockIdx.x * blockDim.x + threadIdx.x; int batchOffset = blockIdx.y * cBatchStride; - if (homeOffset >= cColorStride) return; + if (homeOffset >= cColorStride) + return; int homeX = homeOffset / cStrides[0]; int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1]; @@ -217,7 +220,7 @@ void BilateralFilterCuda(torch::Tensor inputTensor, torch::Tensor outputTensor, cudaMemcpyToSymbol(cKernel, kernel, sizeof(float) * kernelSize); cudaMemcpyToSymbol(cColorExponentFactor, &colorExponentFactor, sizeof(float)); - #define BLOCK_SIZE 32 +#define BLOCK_SIZE 32 AT_DISPATCH_FLOATING_TYPES_AND_HALF( inputTensor.scalar_type(), "BilateralFilterCudaKernel", ([&] { @@ -225,16 +228,19 @@ void BilateralFilterCuda(torch::Tensor inputTensor, torch::Tensor outputTensor, // instead) switch (D) { case (1): - BilateralFilterCudaKernel1D<<>>( - inputTensor.data_ptr(), outputTensor.data_ptr()); + BilateralFilterCudaKernel1D + <<>>( + inputTensor.data_ptr(), outputTensor.data_ptr()); break; case (2): - BilateralFilterCudaKernel2D<<>>( - inputTensor.data_ptr(), outputTensor.data_ptr()); + BilateralFilterCudaKernel2D + <<>>( + inputTensor.data_ptr(), outputTensor.data_ptr()); break; case (3): - BilateralFilterCudaKernel3D<<>>( - inputTensor.data_ptr(), outputTensor.data_ptr()); + BilateralFilterCudaKernel3D + <<>>( + inputTensor.data_ptr(), outputTensor.data_ptr()); break; } })); diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu index 35e30a2d33..603ab689cf 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu @@ -30,7 +30,8 @@ __global__ void FeatureCreation(const scalar_t* inputTensor, scalar_t* outputDat int elementIndex = blockIdx.x * blockDim.x + threadIdx.x; int batchIndex = blockIdx.y; - if (elementIndex 
>= cChannelStride) return; + if (elementIndex >= cChannelStride) + return; int dataBatchOffset = batchIndex * cBatchStride; int featureBatchOffset = batchIndex * (D + C) * cChannelStride; @@ -59,8 +60,8 @@ __global__ void WriteOutput(const scalar_t* data, scalar_t* outputTensor) { int elementIndex = blockIdx.x * blockDim.x + threadIdx.x; int batchIndex = blockIdx.y; - - if (elementIndex >= cChannelStride) return; + if (elementIndex >= cChannelStride) + return; int batchOffset = batchIndex * cBatchStride; @@ -101,11 +102,12 @@ void BilateralFilterPHLCuda( cudaMemcpyToSymbol(cInvSpatialSigma, &invSpatialSigma, sizeof(float)); cudaMemcpyToSymbol(cInvColorSigma, &invColorSigma, sizeof(float)); - #define BLOCK_SIZE 32 +#define BLOCK_SIZE 32 // Creating features FeatureCreation - <<>>(inputTensorData, data, features); + <<>>( + inputTensorData, data, features); // Filtering data with respect to the features for each sample in batch for (int batchIndex = 0; batchIndex < desc.batchCount; batchIndex++) { @@ -116,7 +118,8 @@ void BilateralFilterPHLCuda( } // Writing output - WriteOutput<<>>(data, outputTensorData); + WriteOutput<<>>( + data, outputTensorData); cudaFree(data); cudaFree(features); @@ -127,7 +130,7 @@ torch::Tensor BilateralFilterPHLCuda(torch::Tensor inputTensor, float spatialSig torch::Tensor outputTensor = torch::zeros_like(inputTensor); #define CASE(c, d) \ - AT_DISPATCH_FLOATING_TYPES(inputTensor.scalar_type(), "BilateralFilterCudaPHL", ([&] { \ + AT_DISPATCH_FLOATING_TYPES(inputTensor.scalar_type(), "BilateralFilterCudaPHL", ([&] { \ BilateralFilterPHLCuda( \ inputTensor, outputTensor, spatialSigma, colorSigma); \ })); diff --git a/monai/csrc/filtering/permutohedral/permutohedral.cpp b/monai/csrc/filtering/permutohedral/permutohedral.cpp index 89d6d10c99..c8a7bd881a 100644 --- a/monai/csrc/filtering/permutohedral/permutohedral.cpp +++ b/monai/csrc/filtering/permutohedral/permutohedral.cpp @@ -4,65 +4,68 @@ #include "permutohedral.h" torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features) { + input = input.contiguous(); - input = input.contiguous(); + int batchCount = input.size(0); + int batchStride = input.stride(0); + int elementCount = input.stride(1); + int channelCount = input.size(1); + int featureCount = features.size(1); - int batchCount = input.size(0); - int batchStride = input.stride(0); - int elementCount = input.stride(1); - int channelCount = input.size(1); - int featureCount = features.size(1); +// movedim not support in torch < 1.7 +#if MONAI_TORCH_VERSION >= 10700 + torch::Tensor data = input.clone().movedim(1, -1).contiguous(); + features = features.movedim(1, -1).contiguous(); +#else + torch::Tensor data = input.clone(); + features = features; - // movedim not support in torch < 1.7 - #if MONAI_TORCH_VERSION >= 10700 - torch::Tensor data = input.clone().movedim(1, -1).contiguous(); - features = features.movedim(1, -1).contiguous(); - #else - torch::Tensor data = input.clone(); - features = features; + for (int i = 1; i < input.dim() - 1; i++) { + data = data.transpose(i, i + 1); + features = features.transpose(i, i + 1); + } - for (int i=1; i < input.dim()-1; i++){ - data = data.transpose(i, i+1); - features = features.transpose(i, i+1); - } + data = data.contiguous(); + features = features.contiguous(); +#endif - data = data.contiguous(); - features = features.contiguous(); - #endif +#ifdef WITH_CUDA + if (torch::cuda::is_available() && data.is_cuda()) { + CHECK_CONTIGUOUS_CUDA(data); - #ifdef WITH_CUDA - if (torch::cuda::is_available() 
&& data.is_cuda()) { - CHECK_CONTIGUOUS_CUDA(data); - - #define CASE(dc, fc) AT_DISPATCH_FLOATING_TYPES(data.scalar_type(), "PermutohedralCuda", ([&] { \ - for (int batchIndex = 0; batchIndex < batchCount; batchIndex++) { \ - scalar_t* offsetData = data.data_ptr() + batchIndex * batchStride; \ - scalar_t* offsetFeatures = features.data_ptr() + batchIndex * fc * elementCount; \ - PermutohedralCuda(offsetData, offsetFeatures, elementCount, true); \ - }})); - SWITCH_AB(CASE, 16, 19, channelCount, featureCount); +#define CASE(dc, fc) \ + AT_DISPATCH_FLOATING_TYPES(data.scalar_type(), "PermutohedralCuda", ([&] { \ + for (int batchIndex = 0; batchIndex < batchCount; batchIndex++) { \ + scalar_t* offsetData = data.data_ptr() + batchIndex * batchStride; \ + scalar_t* offsetFeatures = \ + features.data_ptr() + batchIndex * fc * elementCount; \ + PermutohedralCuda(offsetData, offsetFeatures, elementCount, true); \ + } \ + })); + SWITCH_AB(CASE, 16, 19, channelCount, featureCount); - } - else { - #endif - AT_DISPATCH_FLOATING_TYPES(data.scalar_type(), "PermutohedralCPU", ([&] { \ - for (int batchIndex = 0; batchIndex < batchCount; batchIndex++) { \ - scalar_t* offsetData = data.data_ptr() + batchIndex * batchStride; \ - scalar_t* offsetFeatures = features.data_ptr() + batchIndex * featureCount * elementCount;\ - PermutohedralCPU(offsetData, offsetFeatures, channelCount, featureCount, elementCount); \ - }})); - #ifdef WITH_CUDA - } - #endif + } else { +#endif + AT_DISPATCH_FLOATING_TYPES( + data.scalar_type(), "PermutohedralCPU", ([&] { + for (int batchIndex = 0; batchIndex < batchCount; batchIndex++) { + scalar_t* offsetData = data.data_ptr() + batchIndex * batchStride; + scalar_t* offsetFeatures = features.data_ptr() + batchIndex * featureCount * elementCount; + PermutohedralCPU(offsetData, offsetFeatures, channelCount, featureCount, elementCount); + } + })); +#ifdef WITH_CUDA + } +#endif - // movedim not support in torch < 1.7 - #if MONAI_TORCH_VERSION >= 10700 - data = data.movedim(-1, 1); - #else - for (int i=input.dim()-1; i > 1; i--){ - data = data.transpose(i-1, i); - } - #endif +// movedim not support in torch < 1.7 +#if MONAI_TORCH_VERSION >= 10700 + data = data.movedim(-1, 1); +#else + for (int i = input.dim() - 1; i > 1; i--) { + data = data.transpose(i - 1, i); + } +#endif - return data; + return data; } diff --git a/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp b/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp index bc9ab46973..0876997448 100644 --- a/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp +++ b/monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp @@ -494,19 +494,9 @@ class PermutohedralLattice { }; template -void PermutohedralCPU( - scalar_t* data, - scalar_t* features, - int dataChannels, - int featureChannels, - int elementCount) { +void PermutohedralCPU(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount) { PermutohedralLattice::filter(data, features, dataChannels, featureChannels, elementCount); } template void PermutohedralCPU(float* data, float* features, int dataChannels, int featureChannels, int elementCount); -template void PermutohedralCPU( - double* data, - double* features, - int dataChannels, - int featureChannels, - int elementCount); \ No newline at end of file +template void PermutohedralCPU(double* data, double* features, int dataChannels, int featureChannels, int elementCount); \ No newline at end of file diff --git a/tests/test_savitzky_golay_filter.py 
b/tests/test_savitzky_golay_filter.py index d76c42c15f..9163204810 100644 --- a/tests/test_savitzky_golay_filter.py +++ b/tests/test_savitzky_golay_filter.py @@ -1,152 +1,152 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -import torch -from parameterized import parameterized - -from monai.networks.layers import SavitzkyGolayFilter -from tests.utils import skip_if_no_cuda - -# Zero-padding trivial tests - -TEST_CASE_SINGLE_VALUE = [ - {"window_length": 3, "order": 1}, - torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value - torch.Tensor([1 / 3]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 - # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) - 1e-15, # absolute tolerance -] - -TEST_CASE_1D = [ - {"window_length": 3, "order": 1}, - torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data - torch.Tensor([2 / 3, 1.0, 2 / 3]) - .unsqueeze(0) - .unsqueeze(0), # Expected output: zero padded, so linear interpolation - # over length-3 windows will result in output of [2/3, 1, 2/3]. - 1e-15, # absolute tolerance -] - -TEST_CASE_2D_AXIS_2 = [ - {"window_length": 3, "order": 1}, # along default axis (2, first spatial dim) - torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), - torch.Tensor([[2 / 3, 2 / 3], [1.0, 1.0], [2 / 3, 2 / 3]]).unsqueeze(0).unsqueeze(0), - 1e-15, # absolute tolerance -] - -TEST_CASE_2D_AXIS_3 = [ - {"window_length": 3, "order": 1, "axis": 3}, # along axis 3 (second spatial dim) - torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), - torch.Tensor([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]).unsqueeze(0).unsqueeze(0), - 1e-15, # absolute tolerance -] - -# Replicated-padding trivial tests - -TEST_CASE_SINGLE_VALUE_REP = [ - {"window_length": 3, "order": 1, "mode": "replicate"}, - torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value - torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 - # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) - 1e-15, # absolute tolerance -] - -TEST_CASE_1D_REP = [ - {"window_length": 3, "order": 1, "mode": "replicate"}, - torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data - torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Expected output: zero padded, so linear interpolation - # over length-3 windows will result in output of [2/3, 1, 2/3]. 
- 1e-15, # absolute tolerance -] - -TEST_CASE_2D_AXIS_2_REP = [ - {"window_length": 3, "order": 1, "mode": "replicate"}, # along default axis (2, first spatial dim) - torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), - torch.Tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]).unsqueeze(0).unsqueeze(0), - 1e-15, # absolute tolerance -] - -TEST_CASE_2D_AXIS_3_REP = [ - {"window_length": 3, "order": 1, "axis": 3, "mode": "replicate"}, # along axis 3 (second spatial dim) - torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), - torch.Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]).unsqueeze(0).unsqueeze(0), - 1e-15, # absolute tolerance -] - -# Sine smoothing - -TEST_CASE_SINE_SMOOTH = [ - {"window_length": 3, "order": 1}, - # Sine wave with period equal to savgol window length (windowed to reduce edge effects). - torch.as_tensor(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100)).unsqueeze(0).unsqueeze(0), - # Should be smoothed out to zeros - torch.zeros(100).unsqueeze(0).unsqueeze(0), - # tolerance chosen by examining output of SciPy.signal.savgol_filter when provided the above input - 2e-2, # absolute tolerance -] - - -class TestSavitzkyGolayCPU(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_SINGLE_VALUE, - TEST_CASE_1D, - TEST_CASE_2D_AXIS_2, - TEST_CASE_2D_AXIS_3, - TEST_CASE_SINE_SMOOTH, - ] - ) - def test_value(self, arguments, image, expected_data, atol): - result = SavitzkyGolayFilter(**arguments)(image) - np.testing.assert_allclose(result, expected_data, atol=atol) - - -class TestSavitzkyGolayCPUREP(unittest.TestCase): - @parameterized.expand( - [TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP] - ) - def test_value(self, arguments, image, expected_data, atol): - result = SavitzkyGolayFilter(**arguments)(image) - np.testing.assert_allclose(result, expected_data, atol=atol) - - -@skip_if_no_cuda -class TestSavitzkyGolayGPU(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_SINGLE_VALUE, - TEST_CASE_1D, - TEST_CASE_2D_AXIS_2, - TEST_CASE_2D_AXIS_3, - TEST_CASE_SINE_SMOOTH, - ] - ) - def test_value(self, arguments, image, expected_data, atol): - result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) - np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) - - -@skip_if_no_cuda -class TestSavitzkyGolayGPUREP(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_SINGLE_VALUE_REP, - TEST_CASE_1D_REP, - TEST_CASE_2D_AXIS_2_REP, - TEST_CASE_2D_AXIS_3_REP, - ] - ) - def test_value(self, arguments, image, expected_data, atol): - result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) - np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
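The padding expectations encoded in the test cases below all reduce to window means: with window_length=3 and order=1, a Savitzky-Golay filter fits a line to each three-sample window and evaluates it at the centre, which is just a 3-point moving average. The [2/3, 1.0, 2/3] and all-ones values can be cross-checked against SciPy; a minimal sketch, assuming scipy is installed (it is not a dependency of this test module):

import numpy as np
from scipy.signal import savgol_filter

ones = np.ones(3)

# Zero padding: each edge window sees a padded 0, so its mean drops to 2/3.
print(savgol_filter(ones, window_length=3, polyorder=1, mode="constant", cval=0.0))  # [2/3, 1.0, 2/3]

# Replicate padding: each edge window sees a copied 1, so the output stays all ones.
print(savgol_filter(ones, window_length=3, polyorder=1, mode="nearest"))  # [1.0, 1.0, 1.0]

The single-value cases follow the same arithmetic: a zero-padded window [0, 1, 0] has mean 1/3, while a replicate-padded window [1, 1, 1] has mean 1.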
+ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks.layers import SavitzkyGolayFilter +from tests.utils import skip_if_no_cuda + +# Zero-padding trivial tests + +TEST_CASE_SINGLE_VALUE = [ + {"window_length": 3, "order": 1}, + torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value + torch.Tensor([1 / 3]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 + # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +TEST_CASE_1D = [ + {"window_length": 3, "order": 1}, + torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data + torch.Tensor([2 / 3, 1.0, 2 / 3]) + .unsqueeze(0) + .unsqueeze(0), # Expected output: zero padded, so linear interpolation + # over length-3 windows will result in output of [2/3, 1, 2/3]. + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2 = [ + {"window_length": 3, "order": 1}, # along default axis (2, first spatial dim) + torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[2 / 3, 2 / 3], [1.0, 1.0], [2 / 3, 2 / 3]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_3 = [ + {"window_length": 3, "order": 1, "axis": 3}, # along axis 3 (second spatial dim) + torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +# Replicated-padding trivial tests + +TEST_CASE_SINGLE_VALUE_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, + torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Input data: Single value + torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0), # Expected output: With a window length of 3 and polyorder 1 + # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +TEST_CASE_1D_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, + torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Input data + torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0), # Expected output: replicate padded, so linear interpolation + # over length-3 windows will result in output of [1, 1, 1]. + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, # along default axis (2, first spatial dim) + torch.ones((3, 2)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_3_REP = [ + {"window_length": 3, "order": 1, "axis": 3, "mode": "replicate"}, # along axis 3 (second spatial dim) + torch.ones((2, 3)).unsqueeze(0).unsqueeze(0), + torch.Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]).unsqueeze(0).unsqueeze(0), + 1e-15, # absolute tolerance +] + +# Sine smoothing + +TEST_CASE_SINE_SMOOTH = [ + {"window_length": 3, "order": 1}, + # Sine wave with period equal to savgol window length (windowed to reduce edge effects).
+ torch.as_tensor(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100)).unsqueeze(0).unsqueeze(0), + # Should be smoothed out to zeros + torch.zeros(100).unsqueeze(0).unsqueeze(0), + # tolerance chosen by examining output of SciPy.signal.savgol_filter when provided the above input + 2e-2, # absolute tolerance +] + + +class TestSavitzkyGolayCPU(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_SINGLE_VALUE, + TEST_CASE_1D, + TEST_CASE_2D_AXIS_2, + TEST_CASE_2D_AXIS_3, + TEST_CASE_SINE_SMOOTH, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +class TestSavitzkyGolayCPUREP(unittest.TestCase): + @parameterized.expand( + [TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +@skip_if_no_cuda +class TestSavitzkyGolayGPU(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_SINGLE_VALUE, + TEST_CASE_1D, + TEST_CASE_2D_AXIS_2, + TEST_CASE_2D_AXIS_3, + TEST_CASE_SINE_SMOOTH, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) + np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) + + +@skip_if_no_cuda +class TestSavitzkyGolayGPUREP(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_SINGLE_VALUE_REP, + TEST_CASE_1D_REP, + TEST_CASE_2D_AXIS_2_REP, + TEST_CASE_2D_AXIS_3_REP, + ] + ) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolayFilter(**arguments)(image.to(device="cuda")) + np.testing.assert_allclose(result.cpu(), expected_data, atol=atol) diff --git a/tests/test_savitzky_golay_smooth.py b/tests/test_savitzky_golay_smooth.py index 2be0da1360..63dcce1b05 100644 --- a/tests/test_savitzky_golay_smooth.py +++ b/tests/test_savitzky_golay_smooth.py @@ -1,70 +1,70 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest - -import numpy as np -from parameterized import parameterized - -from monai.transforms import SavitzkyGolaySmooth - -# Zero-padding trivial tests - -TEST_CASE_SINGLE_VALUE = [ - {"window_length": 3, "order": 1}, - np.expand_dims(np.array([1.0]), 0), # Input data: Single value - np.expand_dims(np.array([1 / 3]), 0), # Expected output: With a window length of 3 and polyorder 1 - # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) - 1e-15, # absolute tolerance -] - -TEST_CASE_2D_AXIS_2 = [ - {"window_length": 3, "order": 1, "axis": 2}, # along axis 2 (second spatial dim) - np.expand_dims(np.ones((2, 3)), 0), - np.expand_dims(np.array([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]), 0), - 1e-15, # absolute tolerance -] - -# Replicated-padding trivial tests - -TEST_CASE_SINGLE_VALUE_REP = [ - {"window_length": 3, "order": 1, "mode": "replicate"}, - np.expand_dims(np.array([1.0]), 0), # Input data: Single value - np.expand_dims(np.array([1.0]), 0), # Expected output: With a window length of 3 and polyorder 1 - # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) - 1e-15, # absolute tolerance -] - -# Sine smoothing - -TEST_CASE_SINE_SMOOTH = [ - {"window_length": 3, "order": 1}, - # Sine wave with period equal to savgol window length (windowed to reduce edge effects). - np.expand_dims(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100), 0), - # Should be smoothed out to zeros - np.expand_dims(np.zeros(100), 0), - # tolerance chosen by examining output of SciPy.signal.savgol_filter() when provided the above input - 2e-2, # absolute tolerance -] - - -class TestSavitzkyGolaySmooth(unittest.TestCase): - @parameterized.expand([TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH]) - def test_value(self, arguments, image, expected_data, atol): - result = SavitzkyGolaySmooth(**arguments)(image) - np.testing.assert_allclose(result, expected_data, atol=atol) - - -class TestSavitzkyGolaySmoothREP(unittest.TestCase): - @parameterized.expand([TEST_CASE_SINGLE_VALUE_REP]) - def test_value(self, arguments, image, expected_data, atol): - result = SavitzkyGolaySmooth(**arguments)(image) - np.testing.assert_allclose(result, expected_data, atol=atol) +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
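The sine-smoothing expectation below rests on a short derivation: with window_length=3 and order=1 the filter acts as a 3-point moving average, and a period-3 sine sampled at integers cycles through 0, sqrt(3)/2, -sqrt(3)/2, so every three consecutive samples sum to zero. The Hann window only modulates that cycle slowly, which is why the residual stays near the 2e-2 tolerance. A numpy sketch of the arithmetic, illustrative only and not part of the test:

import numpy as np

# Windowed period-3 sine, exactly as constructed in TEST_CASE_SINE_SMOOTH below.
x = np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100)

# Zero-padded 3-point moving average, matching the filter's default padding.
smoothed = np.convolve(x, np.ones(3) / 3, mode="same")

# The residual comes only from the slowly varying Hann envelope, so it is small.
print(np.abs(smoothed).max())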
+ +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import SavitzkyGolaySmooth + +# Zero-padding trivial tests + +TEST_CASE_SINGLE_VALUE = [ + {"window_length": 3, "order": 1}, + np.expand_dims(np.array([1.0]), 0), # Input data: Single value + np.expand_dims(np.array([1 / 3]), 0), # Expected output: With a window length of 3 and polyorder 1 + # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +TEST_CASE_2D_AXIS_2 = [ + {"window_length": 3, "order": 1, "axis": 2}, # along axis 2 (second spatial dim) + np.expand_dims(np.ones((2, 3)), 0), + np.expand_dims(np.array([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]), 0), + 1e-15, # absolute tolerance +] + +# Replicated-padding trivial tests + +TEST_CASE_SINGLE_VALUE_REP = [ + {"window_length": 3, "order": 1, "mode": "replicate"}, + np.expand_dims(np.array([1.0]), 0), # Input data: Single value + np.expand_dims(np.array([1.0]), 0), # Expected output: With a window length of 3 and polyorder 1 + # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed) + 1e-15, # absolute tolerance +] + +# Sine smoothing + +TEST_CASE_SINE_SMOOTH = [ + {"window_length": 3, "order": 1}, + # Sine wave with period equal to savgol window length (windowed to reduce edge effects). + np.expand_dims(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100), 0), + # Should be smoothed out to zeros + np.expand_dims(np.zeros(100), 0), + # tolerance chosen by examining output of SciPy.signal.savgol_filter() when provided the above input + 2e-2, # absolute tolerance +] + + +class TestSavitzkyGolaySmooth(unittest.TestCase): + @parameterized.expand([TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH]) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolaySmooth(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) + + +class TestSavitzkyGolaySmoothREP(unittest.TestCase): + @parameterized.expand([TEST_CASE_SINGLE_VALUE_REP]) + def test_value(self, arguments, image, expected_data, atol): + result = SavitzkyGolaySmooth(**arguments)(image) + np.testing.assert_allclose(result, expected_data, atol=atol) From 2e9b0a7457455adb8c18e27c64bd5b2dc1ca61e6 Mon Sep 17 00:00:00 2001 From: charliebudd Date: Mon, 18 Jan 2021 16:21:31 +0000 Subject: [PATCH 12/12] changing torch version from which to use movedim to >= 1.7.1 Signed-off-by: charliebudd --- monai/csrc/filtering/permutohedral/permutohedral.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/csrc/filtering/permutohedral/permutohedral.cpp b/monai/csrc/filtering/permutohedral/permutohedral.cpp index c8a7bd881a..5d6916b8f4 100644 --- a/monai/csrc/filtering/permutohedral/permutohedral.cpp +++ b/monai/csrc/filtering/permutohedral/permutohedral.cpp @@ -12,8 +12,8 @@ torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features) { int channelCount = input.size(1); int featureCount = features.size(1); -// movedim not supported in torch < 1.7 -#if MONAI_TORCH_VERSION >= 10700 +// movedim not supported in torch < 1.7.1 +#if MONAI_TORCH_VERSION >= 10701 torch::Tensor data = input.clone().movedim(1, -1).contiguous(); features = features.movedim(1, -1).contiguous(); #else @@ -58,8 +58,8 @@ torch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features) { } #endif -// movedim not supported in torch < 1.7 -#if 
MONAI_TORCH_VERSION >= 10700 +// movedim not supported in torch < 1.7.1 +#if MONAI_TORCH_VERSION >= 10701 data = data.movedim(-1, 1); #else for (int i = input.dim() - 1; i > 1; i--) {
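With this final change the series settles on the shape conventions for the exposed filter: PermutohedralFilter takes a batched input of shape (B, C, ...) and a feature tensor of shape (B, F, ...), moves channels to the last dimension (via movedim on torch >= 1.7.1, a transpose loop otherwise), and hands each batch element to the CUDA or CPU lattice, with SWITCH_AB(CASE, 16, 19, channelCount, featureCount) specialising the CUDA kernels up to 16 data channels and 19 feature channels. A minimal usage sketch of the Python-side PHLFilter follows; the bilateral-style feature construction and the sigma values are illustrative assumptions of this example, not part of the patch, and MONAI must be built with its C++ extension for the call to resolve:

import torch
from monai.networks.layers.filtering import PHLFilter

image = torch.rand(1, 1, 10, 10)  # (B, C, H, W) tensor to be filtered

# Guidance features stacked on the feature dim (B, F, H, W): two spatial
# coordinates plus intensity, each divided by an assumed sigma so that
# distance in feature space plays the role bilateral_filter gives its
# explicit spatial_sigma / color_sigma arguments.
spatial_sigma, color_sigma = 5.0, 0.5
ys, xs = torch.meshgrid(torch.arange(10.0), torch.arange(10.0))
features = torch.cat(
    [
        xs[None, None] / spatial_sigma,
        ys[None, None] / spatial_sigma,
        image / color_sigma,
    ],
    dim=1,
)

# PHLFilter is a torch.autograd.Function, so it is invoked through apply().
output = PHLFilter.apply(image, features)  # same shape as image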