From 79c1a4317f26612d4d26e319124186549260ab3c Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 19 Jan 2021 16:12:43 +0000 Subject: [PATCH 01/14] update openmp flag Signed-off-by: Wenqi Li --- monai/csrc/utils/resample_utils.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/csrc/utils/resample_utils.h b/monai/csrc/utils/resample_utils.h index 4735d13ca1..bbdf258b4c 100644 --- a/monai/csrc/utils/resample_utils.h +++ b/monai/csrc/utils/resample_utils.h @@ -62,7 +62,9 @@ namespace monai { template static inline void cpuAtomicAdd(scalar_t* ptr, offset_t offset, scalar_t value) { #if AT_PARALLEL_OPENMP +#if _OPENMP #pragma omp atomic +#endif #endif ptr[offset] += value; } From 947d448d5056801ef0ebd0deca206642073541cd Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 19 Jan 2021 16:15:43 +0000 Subject: [PATCH 02/14] improves boundtype docs Signed-off-by: Wenqi Li --- monai/csrc/ext.cpp | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/monai/csrc/ext.cpp b/monai/csrc/ext.cpp index 2e0644bc78..b4bb0f2c04 100644 --- a/monai/csrc/ext.cpp +++ b/monai/csrc/ext.cpp @@ -29,14 +29,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // resample bound mode py::enum_(m, "BoundType") - .value("replicate", monai::BoundType::Replicate) - .value("dct1", monai::BoundType::DCT1) - .value("dct2", monai::BoundType::DCT2) - .value("dst1", monai::BoundType::DST1) - .value("dst2", monai::BoundType::DST2) - .value("dft", monai::BoundType::DFT) - .value("sliding", monai::BoundType::Sliding) - .value("zero", monai::BoundType::Zero) + .value("replicate", monai::BoundType::Replicate, "a a a | a b c d | d d d") + .value("nearest", monai::BoundType::Replicate, "a a a | a b c d | d d d") + .value("dct1", monai::BoundType::DCT1, "d c b | a b c d | c b a") + .value("mirror", monai::BoundType::DCT1, "d c b | a b c d | c b a") + .value("dct2", monai::BoundType::DCT2, "c b a | a b c d | d c b") + .value("reflect", monai::BoundType::DCT2, "c b a | a b c d | d c b") + .value("dst1", monai::BoundType::DST1, "-b -a 0 | a b c d | 0 -d -c") + .value("antimirror", monai::BoundType::DST1, "-b -a 0 | a b c d | 0 -d -c") + .value("dst2", monai::BoundType::DST2, "-c -b -a | a b c d | -d -c -b") + .value("antireflect", monai::BoundType::DST2, "-c -b -a | a b c d | -d -c -b") + .value("dft", monai::BoundType::DFT, "b c d | a b c d | a b c") + .value("wrap", monai::BoundType::DFT, "b c d | a b c d | a b c") + // .value("sliding", monai::BoundType::Sliding) + .value("zero", monai::BoundType::Zero, "0 0 0 | a b c d | 0 0 0") .export_values(); // resample interpolation mode From ed01c6d19f5262c4b1bc47926ec7f752bd98d862 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 19 Jan 2021 16:22:41 +0000 Subject: [PATCH 03/14] update setup.py Signed-off-by: Wenqi Li --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 9b20df845a..426866428c 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,7 @@ FORCE_CUDA = os.getenv("FORCE_CUDA", "0") == "1" # flag ignored if BUILD_MONAI is False BUILD_CPP = BUILD_CUDA = False +TORCH_VERSION = 0 try: import torch @@ -35,14 +36,13 @@ BUILD_CPP = True from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension - BUILD_CUDA = (torch.cuda.is_available() and (CUDA_HOME is not None)) or FORCE_CUDA + BUILD_CUDA = (CUDA_HOME is not None) if torch.cuda.is_available() else FORCE_CUDA _pt_version = pkg_resources.parse_version(torch.__version__).release # type: ignore[attr-defined] if _pt_version is None or len(_pt_version) < 3: raise AssertionError("unknown torch version") TORCH_VERSION = int(_pt_version[0]) * 10000 + int(_pt_version[1]) * 100 + int(_pt_version[2]) except (ImportError, TypeError, AssertionError, AttributeError) as e: - TORCH_VERSION = 0 warnings.warn(f"extension build skipped: {e}") finally: if not RUN_BUILD: From 76284a369512d49328dd09ccbf6a138b06085ed8 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 19 Jan 2021 16:34:37 +0000 Subject: [PATCH 04/14] input validation 1d Signed-off-by: Wenqi Li --- monai/csrc/resample/pushpull.h | 22 +++++++++++++--------- monai/csrc/utils/common_utils.h | 25 +++++++++++++------------ 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/monai/csrc/resample/pushpull.h b/monai/csrc/resample/pushpull.h index 45fd5ce564..1c20cc0114 100644 --- a/monai/csrc/resample/pushpull.h +++ b/monai/csrc/resample/pushpull.h @@ -69,8 +69,8 @@ at::Tensor grid_pull( CHECK_STRIDED(grid_opt) CHECK_SAME_DEVICE(input_opt, grid_opt) CHECK_SAME_DTYPE(input_opt, grid_opt) - CHECK_SPATIAL_2D_OR_3D(input) - CHECK_SPATIAL_2D_OR_3D(grid) + CHECK_SPATIAL_1D_2D_OR_3D(input) + CHECK_SPATIAL_1D_2D_OR_3D(grid) CHECK_GRID_COMPONENT(grid, grid.dim()) CHECK_SPATIAL_NOT_EMPTY(input) CHECK_SPATIAL_NOT_EMPTY(grid) @@ -165,8 +165,8 @@ at::Tensor grid_push( CHECK_STRIDED(grid_opt) CHECK_SAME_DEVICE(input_opt, grid_opt) CHECK_SAME_DTYPE(input_opt, grid_opt) - CHECK_SPATIAL_2D_OR_3D(input) - CHECK_SPATIAL_2D_OR_3D(grid) + CHECK_SPATIAL_1D_2D_OR_3D(input) + CHECK_SPATIAL_1D_2D_OR_3D(grid) CHECK_GRID_COMPONENT(grid, grid.dim()) CHECK_SPATIAL_NOT_EMPTY(input) CHECK_SPATIAL_NOT_EMPTY(grid) @@ -175,7 +175,10 @@ at::Tensor grid_push( CHECK_VEC_NOT_EMPTY(interpolation_mode); if (source_size.empty()) { - auto size = c10::IntArrayRef({input.size(2), input.size(3), input.dim() == 5 ? input.size(4) : 1}); + auto size = c10::IntArrayRef( + {input.dim() >= 3 ? input.size(2) : 1, + input.dim() >= 4 ? input.size(3) : 1, + input.dim() >= 5 ? input.size(4) : 1}); if (input.is_cuda()) #ifdef WITH_CUDA return cuda::pushpull( @@ -295,14 +298,15 @@ at::Tensor grid_count( CHECK_DEFINED(grid) auto grid_opt = grid.options(); CHECK_STRIDED(grid_opt) - CHECK_SPATIAL_2D_OR_3D(grid) + CHECK_SPATIAL_1D_2D_OR_3D(grid) CHECK_GRID_COMPONENT(grid, grid.dim()) CHECK_SPATIAL_NOT_EMPTY(grid) CHECK_VEC_NOT_EMPTY(bound_mode); CHECK_VEC_NOT_EMPTY(interpolation_mode); if (source_size.empty()) { - auto size = c10::IntArrayRef({grid.size(1), grid.size(2), grid.dim() == 5 ? grid.size(3) : 1}); + auto size = c10::IntArrayRef( + {grid.dim() >= 3 ? grid.size(2) : 1, grid.dim() >= 4 ? grid.size(3) : 1, grid.dim() >= 5 ? grid.size(4) : 1}); if (grid.is_cuda()) #ifdef WITH_CUDA return cuda::pushpull( @@ -422,8 +426,8 @@ at::Tensor grid_grad( CHECK_STRIDED(grid_opt) CHECK_SAME_DEVICE(input_opt, grid_opt) CHECK_SAME_DTYPE(input_opt, grid_opt) - CHECK_SPATIAL_2D_OR_3D(input) - CHECK_SPATIAL_2D_OR_3D(grid) + CHECK_SPATIAL_1D_2D_OR_3D(input) + CHECK_SPATIAL_1D_2D_OR_3D(grid) CHECK_GRID_COMPONENT(grid, grid.dim()) CHECK_SPATIAL_NOT_EMPTY(input) CHECK_SPATIAL_NOT_EMPTY(grid) diff --git a/monai/csrc/utils/common_utils.h b/monai/csrc/utils/common_utils.h index 882312acb3..3b90221ac3 100644 --- a/monai/csrc/utils/common_utils.h +++ b/monai/csrc/utils/common_utils.h @@ -26,10 +26,10 @@ limitations under the License. value.layout() == at::kStrided, \ "(): expected " #value "to have torch.strided layout, but it has ", \ value.layout()); -#define CHECK_SPATIAL_2D_OR_3D(value) \ - TORCH_CHECK( \ - (value.dim() == 4 || value.dim() == 5), \ - "(): expected 4D or 5D " #value " but got input with sizes ", \ +#define CHECK_SPATIAL_1D_2D_OR_3D(value) \ + TORCH_CHECK( \ + (value.dim() == 3 || value.dim() == 4 || value.dim() == 5), \ + "(): expected 3D, 4D or 5D " #value " but got input with sizes ", \ value.sizes()); #define CHECK_GRID_COMPONENT(value, dim) \ TORCH_CHECK( \ @@ -67,14 +67,15 @@ limitations under the License. i, \ " being empty"); \ } -#define CHECK_GRID_TARGET_COMPAT(value1, value2) \ - TORCH_CHECK( \ - value2.size(0) == value1.size(0) && value2.size(2) == value1.size(1) && value2.size(3) == value1.size(2) && \ - (value2.dim() == 4 || value2.size(4) == value1.size(3)), \ - "(): expected " #value2 " and " #value1 \ - " to have same batch, width, height and (optionally) depth sizes, but got " #value2 " with sizes ", \ - value2.sizes(), \ - " and " #value1 " with sizes ", \ +#define CHECK_GRID_TARGET_COMPAT(value1, value2) \ + TORCH_CHECK( \ + value2.size(0) == value1.size(0) && (value2.dim() <= 2 || value2.size(2) == value1.size(1)) && \ + (value2.dim() <= 3 || value2.size(3) == value1.size(2)) && \ + (value2.dim() <= 4 || value2.size(4) == value1.size(3)), \ + "(): expected " #value2 " and " #value1 \ + " to have same batch, width, height and (optionally) depth sizes, but got " #value2 " with sizes ", \ + value2.sizes(), \ + " and " #value1 " with sizes ", \ value1.sizes()); #define CHECK_SPATIAL_LENGTH(value, dim) \ TORCH_CHECK(((int64_t)(value.size()) == dim - 2), "(): expected ", dim, #value " elements but got ", value.size()); From 61386fff5bee95bdac2f438e59c3801dd9d3b85f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 19 Jan 2021 16:53:23 +0000 Subject: [PATCH 05/14] fixes typos Signed-off-by: Wenqi Li --- monai/csrc/resample/pushpull_cpu.cpp | 3 +-- monai/csrc/resample/pushpull_cuda.cu | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/csrc/resample/pushpull_cpu.cpp b/monai/csrc/resample/pushpull_cpu.cpp index 40743a6cf1..a51b86fd19 100644 --- a/monai/csrc/resample/pushpull_cpu.cpp +++ b/monai/csrc/resample/pushpull_cpu.cpp @@ -18,7 +18,7 @@ limitations under the License. // It handles boundary conditions and interpolation orders defined in // `utils/resample_utils.h` and `utils/resample_utils.h`. // These parameters can be specified per dimension. -// Isotorpic 0-th and 1-st order interpolation have their own (faster) +// Isotropic 0-th and 1-st order interpolation have their own (faster) // implementations. Sliding boundary conditions are also implemented // separately. @@ -1461,7 +1461,6 @@ MONAI_NAMESPACE_DEVICE { // cpu scalar_t w10 = dx1 * dy0; scalar_t w01 = dx0 * dy1; scalar_t w11 = dx1 * dy1; - ; // Sign (/!\ compute sign before warping indices) int8_t sx1 = bound::sign(bound0, ix0 + 1, src_X); diff --git a/monai/csrc/resample/pushpull_cuda.cu b/monai/csrc/resample/pushpull_cuda.cu index ecfeb562ab..2a85b70299 100644 --- a/monai/csrc/resample/pushpull_cuda.cu +++ b/monai/csrc/resample/pushpull_cuda.cu @@ -1428,7 +1428,6 @@ MONAI_NAMESPACE_DEVICE { // cuda scalar_t w10 = dx1 * dy0; scalar_t w01 = dx0 * dy1; scalar_t w11 = dx1 * dy1; - ; // Sign (/!\ compute sign before warping indices) int8_t sx1 = bound::sign(bound0, ix0 + 1, src_X); From a2738c67ae92a78675bee14da7c8ae3e4ad24906 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 26 Jan 2021 20:11:32 +0000 Subject: [PATCH 06/14] fixes typos Signed-off-by: Wenqi Li --- monai/csrc/resample/pushpull_cpu.cpp | 3 ++- monai/csrc/resample/pushpull_cuda.cu | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/csrc/resample/pushpull_cpu.cpp b/monai/csrc/resample/pushpull_cpu.cpp index a51b86fd19..9d228f0cfb 100644 --- a/monai/csrc/resample/pushpull_cpu.cpp +++ b/monai/csrc/resample/pushpull_cpu.cpp @@ -37,6 +37,7 @@ limitations under the License. // . input bound/inter are always vectors -> clean unused constructors #include +#include #include #include "bounds_common.h" #include "interpolation_common.h" @@ -44,7 +45,7 @@ limitations under the License. //#include // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// CPU/GPU -specific parameters +// CPU-specific parameters #include namespace { // This parameter specifies the minimum number of voxels that should be diff --git a/monai/csrc/resample/pushpull_cuda.cu b/monai/csrc/resample/pushpull_cuda.cu index 2a85b70299..029fe9eee1 100644 --- a/monai/csrc/resample/pushpull_cuda.cu +++ b/monai/csrc/resample/pushpull_cuda.cu @@ -37,6 +37,7 @@ limitations under the License. // . input bound/inter are always vectors -> clean unused constructors #include +#include #include #include "bounds_common.h" #include "interpolation_common.h" From 19a3637191c1fd9900611ce60d605de3eae911a7 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 27 Jan 2021 18:49:06 +0000 Subject: [PATCH 07/14] merge upstream changes Signed-off-by: Wenqi Li --- monai/csrc/resample/pushpull_cpu.cpp | 1214 +++++++++++++++++-------- monai/csrc/resample/pushpull_cuda.cu | 1220 ++++++++++++++++++-------- runtests.sh | 1 + 3 files changed, 1718 insertions(+), 717 deletions(-) diff --git a/monai/csrc/resample/pushpull_cpu.cpp b/monai/csrc/resample/pushpull_cpu.cpp index 9d228f0cfb..dd10dd76ee 100644 --- a/monai/csrc/resample/pushpull_cpu.cpp +++ b/monai/csrc/resample/pushpull_cpu.cpp @@ -25,6 +25,7 @@ limitations under the License. // TODO: // . [DONE] generic 3d // . [DONE] generic 2d +// . [DONE] generic 1d // . sliding nearest 3d // . sliding nearest 2d // . sliding linear 3d @@ -75,18 +76,27 @@ MONAI_NAMESPACE_DEVICE { // cpu namespace { // anonymous namespace > everything inside has internal linkage // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // GENERIC PUSHPULL CLASS + // INDEXING UTILS // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // This class implements the bulk of the code. - // /!\ No type and shape checking is performed here. - template - class PushPullImpl { + // This class reads and sets all the parameters that will later be used + // by the algorithm in PushPullImpl. All of this is done outside of the + // implementation class so that we do not depend on generic types. The + // point is to pre-allocate all necessary tensors so that we can check + // if they're all compatible with 32 bit math. If it's the case, we can + // dispatch to a 32b cuda implementation, which might increase + // performance. Else, we use 64 bit math to compute offsets. + // (On CPU, we always use 64 bit offsets because it doesn't make a huge + // difference. It would be different if we had a vectorized + // implementation as in PyTorch). + class PushPullAllocator { public: + static constexpr int64_t max_int32 = std::numeric_limits::max(); + // ~~~ CONSTRUCTORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MONAI_HOST - PushPullImpl( + PushPullAllocator( int dim, BoundVectorRef bound, InterpolationVectorRef interpolation, @@ -126,101 +136,418 @@ MONAI_NAMESPACE_DEVICE { // cpu iso = interpolation0 == interpolation1 && interpolation0 == interpolation2; } - MONAI_HOST - PushPullImpl( - int dim, - BoundType bound, - InterpolationVectorRef interpolation, - bool extrapolate, - bool do_pull, - bool do_push, - bool do_count, - bool do_grad, - bool do_sgrad) - : dim(dim), - bound0(bound), - bound1(bound), - bound2(bound), - interpolation0(interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear), - interpolation1( - interpolation.size() > 1 ? interpolation[1] - : interpolation.size() > 0 ? interpolation[0] - : InterpolationType::Linear), - interpolation2( - interpolation.size() > 2 ? interpolation[2] - : interpolation.size() > 1 ? interpolation[1] - : interpolation.size() > 0 ? interpolation[0] - : InterpolationType::Linear), - extrapolate(extrapolate), - do_pull(do_pull), - do_push(do_push), - do_count(do_count), - do_grad(do_grad), - do_sgrad(do_sgrad) { - iso = interpolation0 == interpolation1 && interpolation0 == interpolation2; + // ~~~ FUNCTORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + // Usually used for pull: + // - do_pull -> return source[grid] + // - do_push -> fails + // - do_grad -> return J(source)[grid] + // - do_sgrad -> return H(source)[grid] + MONAI_HOST void ioset(const Tensor& source, const Tensor& grid) { + init_all(); + init_source(source); + init_grid(grid); + init_output(); } - MONAI_HOST - PushPullImpl( - int dim, - BoundVectorRef bound, - InterpolationType interpolation, - bool extrapolate, - bool do_pull, - bool do_push, - bool do_count, - bool do_grad, - bool do_sgrad) - : dim(dim), - bound0(bound.size() > 0 ? bound[0] : BoundType::Replicate), - bound1( - bound.size() > 1 ? bound[1] - : bound.size() > 0 ? bound[0] - : BoundType::Replicate), - bound2( - bound.size() > 2 ? bound[2] - : bound.size() > 1 ? bound[1] - : bound.size() > 0 ? bound[0] - : BoundType::Replicate), - interpolation0(interpolation), - interpolation1(interpolation), - interpolation2(interpolation), - extrapolate(extrapolate), - do_pull(do_pull), - do_push(do_push), - do_count(do_count), - do_grad(do_grad), - do_sgrad(do_sgrad) { - iso = interpolation0 == interpolation1 && interpolation0 == interpolation2; + // Usually used for pull_backward: + // - do_pull -> return source[grid] + // - do_push -> return push(target, grid, source.shape) + // - do_grad -> return J(source)[grid] + // - do_sgrad -> return H(source)[grid] + MONAI_HOST void ioset(const Tensor& source, const Tensor& grid, const Tensor& target) { + init_all(); + init_source(source); + init_grid(grid); + init_target(target); + init_output(); } - MONAI_HOST - PushPullImpl( - int dim, - BoundType bound, - InterpolationType interpolation, - bool extrapolate, - bool do_pull, - bool do_push, - bool do_count, - bool do_grad, - bool do_sgrad) - : dim(dim), - bound0(bound), - bound1(bound), - bound2(bound), - interpolation0(interpolation), - interpolation1(interpolation), - interpolation2(interpolation), - extrapolate(extrapolate), - do_pull(do_pull), - do_push(do_push), - do_count(do_count), - do_grad(do_grad), - do_sgrad(do_sgrad) { - iso = interpolation0 == interpolation1 && interpolation0 == interpolation2; + // Usually used for push: + // - do_pull -> fails + // - do_push -> return push(target, grid, source_size) + // - do_grad -> fails + // - do_sgrad -> fails + MONAI_HOST void ioset(IntArrayRef source_size, const Tensor& grid, const Tensor& target) { + init_all(); + init_source(source_size); + init_grid(grid); + init_target(target); + init_output(); } + // Usually used for count: + // - do_pull -> fails + // - do_push -> return push(ones, grid, source_size) + // - do_grad -> fails + // - do_sgrad -> fails + MONAI_HOST void ioset(IntArrayRef source_size, const Tensor& grid) { + init_all(); + init_source(source_size); + init_grid(grid); + init_output(); + } + + // We just check that all tensors that we own are compatible with 32b math + bool canUse32BitIndexMath(int64_t max_elem = max_int32) const { + return src_32b_ok && trgt_32b_ok && grid_32b_ok && grad_32b_ok && out_32b_ok; + } + + private: + // Copied from aten/src/ATen/native/IndexingUtils.cpp in PyTorch 1.6. + // It is used to decide to which pointer type we should dispatch to. + // Basically, we need to make sure that the "furthest" element we need + // to reach is less than max_elem away. + static bool tensorCanUse32BitIndexMath(const Tensor& t, int64_t max_elem = max_int32) { + int64_t elements = t.numel(); + if (elements >= max_elem) { + return false; + } + if (elements == 0) { + return max_elem > 0; + } + + int64_t offset = 0; + int64_t linearId = elements - 1; + + // NOTE: Assumes all strides are positive, which is true for now + for (int i = t.dim() - 1; i >= 0; --i) { + int64_t curDimIndex = linearId % t.size(i); + int64_t curDimOffset = curDimIndex * t.stride(i); + offset += curDimOffset; + linearId /= t.size(i); + } + + if (offset >= max_elem) { + return false; + } + + return true; + } + + // ~~~ COMPONENTS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + MONAI_HOST void init_all(); + MONAI_HOST void init_source(const Tensor& source); + MONAI_HOST void init_source(IntArrayRef source_size); + MONAI_HOST void init_grid(const Tensor& grid); + MONAI_HOST void init_target(const Tensor& target); + MONAI_HOST void init_output(); + + // ~~~ OPTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + int dim; // dimensionality (2 or 3) + BoundType bound0; // boundary condition // x|W + BoundType bound1; // boundary condition // y|H + BoundType bound2; // boundary condition // z|D + InterpolationType interpolation0; // interpolation order // x|W + InterpolationType interpolation1; // interpolation order // y|H + InterpolationType interpolation2; // interpolation order // z|D + bool iso; // isotropic interpolation? + bool extrapolate; // compute out-of-bound values + bool do_pull; // sample a volume + bool do_push; // splat a volume + bool do_count; // splatting weights (= jacobian determinant) + bool do_grad; // backprop: gradient of grid // pull + bool do_sgrad; // sample spatial gradients + + // ~~~ NAVIGATORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + std::deque output; + TensorOptions src_opt; + TensorOptions grid_opt; + TensorOptions trgt_opt; + int64_t N; + int64_t C; + int64_t src_X; + int64_t src_Y; + int64_t src_Z; + int64_t trgt_X; + int64_t trgt_Y; + int64_t trgt_Z; + int64_t trgt_K; + int64_t src_sN; + int64_t src_sC; + int64_t src_sX; + int64_t src_sY; + int64_t src_sZ; + bool src_32b_ok; + void* src_ptr; + int64_t trgt_sN; + int64_t trgt_sC; + int64_t trgt_sX; + int64_t trgt_sY; + int64_t trgt_sZ; + int64_t trgt_sK; + bool trgt_32b_ok; + void* trgt_ptr; + int64_t grid_sN; + int64_t grid_sC; + int64_t grid_sX; + int64_t grid_sY; + int64_t grid_sZ; + bool grid_32b_ok; + void* grid_ptr; + int64_t out_sN; + int64_t out_sC; + int64_t out_sX; + int64_t out_sY; + int64_t out_sZ; + int64_t out_sK; // gradient dimension + bool out_32b_ok; + void* out_ptr; + int64_t grad_sN; + int64_t grad_sC; + int64_t grad_sX; + int64_t grad_sY; + int64_t grad_sZ; + bool grad_32b_ok; + void* grad_ptr; + + // Allow PushPullImpl's constructor to access PushPullAllocator's + // private members. + template + friend class PushPullImpl; + }; + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // INITIALISATION + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + MONAI_HOST + void PushPullAllocator::init_all() { + src_opt = grid_opt = trgt_opt = TensorOptions(); + N = C = 1L; + src_X = src_Y = src_Z = 1L; + trgt_X = trgt_Y = trgt_Z = 1L; + trgt_K = 0L; + src_sN = src_sC = src_sX = src_sY = src_sZ = 0L; + grid_sN = grid_sC = grid_sX = grid_sY = grid_sZ = 0L; + grad_sN = grad_sC = grad_sX = grad_sY = grad_sZ = 0L; + trgt_sN = trgt_sC = trgt_sX = trgt_sY = trgt_sZ = trgt_sK = 0L; + out_sN = out_sC = out_sX = out_sY = out_sZ = out_sK = 0L; + src_ptr = trgt_ptr = grid_ptr = out_ptr = grad_ptr = static_cast(0); + src_32b_ok = trgt_32b_ok = grid_32b_ok = out_32b_ok = grad_32b_ok = true; + } + + MONAI_HOST + void PushPullAllocator::init_source(const Tensor& source) { + N = source.size(0); + C = source.size(1); + src_X = source.size(2); + src_Y = dim < 2 ? 1L : source.size(3); + src_Z = dim < 3 ? 1L : source.size(4); + src_sN = source.stride(0); + src_sC = source.stride(1); + src_sX = source.stride(2); + src_sY = dim < 2 ? 0L : source.stride(3); + src_sZ = dim < 3 ? 0L : source.stride(4); + src_ptr = source.data_ptr(); + src_opt = source.options(); + src_32b_ok = tensorCanUse32BitIndexMath(source); + } + + MONAI_HOST + void PushPullAllocator::init_source(IntArrayRef source_size) { + src_X = source_size[0]; + src_Y = dim < 2 ? 1L : source_size[1]; + src_Z = dim < 3 ? 1L : source_size[2]; + } + + MONAI_HOST + void PushPullAllocator::init_grid(const Tensor& grid) { + N = grid.size(0); + trgt_X = grid.size(1); + trgt_Y = dim < 2 ? 1L : grid.size(2); + trgt_Z = dim < 3 ? 1L : grid.size(3); + grid_sN = grid.stride(0); + grid_sX = grid.stride(1); + grid_sY = dim < 2 ? 0L : grid.stride(2); + grid_sZ = dim < 3 ? 0L : grid.stride(3); + grid_sC = grid.stride(dim == 1 ? 2 : dim == 2 ? 3 : 4); + grid_ptr = grid.data_ptr(); + grid_opt = grid.options(); + grid_32b_ok = tensorCanUse32BitIndexMath(grid); + } + + MONAI_HOST + void PushPullAllocator::init_target(const Tensor& target) { + N = target.size(0); + C = target.size(1); + trgt_X = target.size(2); + trgt_Y = dim < 2 ? 1L : target.size(3); + trgt_Z = dim < 3 ? 1L : target.size(4); + trgt_K = target.dim() == dim + 3 ? target.size(dim == 1 ? 3 : dim == 2 ? 4 : 5) : 0L; + trgt_sN = target.stride(0); + trgt_sC = target.stride(1); + trgt_sX = target.stride(2); + trgt_sY = dim < 2 ? 0L : target.stride(3); + trgt_sZ = dim < 3 ? 0L : target.stride(4); + trgt_sK = target.dim() == dim + 3 ? target.stride(dim == 1 ? 3 : dim == 2 ? 4 : 5) : 0L; + trgt_ptr = target.data_ptr(); + trgt_opt = target.options(); + trgt_32b_ok = tensorCanUse32BitIndexMath(target); + } + + MONAI_HOST + void PushPullAllocator::init_output() { + output.clear(); + if (do_pull) { + if (dim == 1) + output.push_back(at::empty({N, C, trgt_X}, src_opt)); + else if (dim == 2) + output.push_back(at::empty({N, C, trgt_X, trgt_Y}, src_opt)); + else + output.push_back(at::empty({N, C, trgt_X, trgt_Y, trgt_Z}, src_opt)); + auto pull = output.back(); + out_sN = pull.stride(0); + out_sC = pull.stride(1); + out_sX = pull.stride(2); + out_sY = dim < 2 ? 0L : pull.stride(3); + out_sZ = dim < 3 ? 0L : pull.stride(4); + out_sK = 0L; + out_ptr = pull.data_ptr(); + out_32b_ok = tensorCanUse32BitIndexMath(pull); + } else if (do_sgrad) { + if (dim == 1) + output.push_back(at::empty({N, C, trgt_X, 1}, src_opt)); + else if (dim == 2) + output.push_back(at::empty({N, C, trgt_X, trgt_Y, 2}, src_opt)); + else + output.push_back(at::empty({N, C, trgt_X, trgt_Y, trgt_Z, 3}, src_opt)); + auto sgrad = output.back(); + out_sN = sgrad.stride(0); + out_sC = sgrad.stride(1); + out_sX = sgrad.stride(2); + out_sY = dim < 2 ? 0L : sgrad.stride(3); + out_sZ = dim < 3 ? 0L : sgrad.stride(4); + out_sK = sgrad.stride(dim == 1 ? 3 : dim == 2 ? 4 : 5); + out_ptr = sgrad.data_ptr(); + out_32b_ok = tensorCanUse32BitIndexMath(sgrad); + + if (iso && interpolation0 == InterpolationType::Nearest) + sgrad.zero_(); + if (iso && interpolation0 == InterpolationType::Linear && dim == 1) + sgrad.zero_(); + } else if (do_push) { + if (dim == 1) + output.push_back(at::zeros({N, C, src_X}, trgt_opt)); + else if (dim == 2) + output.push_back(at::zeros({N, C, src_X, src_Y}, trgt_opt)); + else + output.push_back(at::zeros({N, C, src_X, src_Y, src_Z}, trgt_opt)); + auto push = output.back(); + out_sN = push.stride(0); + out_sC = push.stride(1); + out_sX = push.stride(2); + out_sY = dim < 2 ? 0L : push.stride(3); + out_sZ = dim < 3 ? 0L : push.stride(4); + out_sK = 0L; + out_ptr = push.data_ptr(); + out_32b_ok = tensorCanUse32BitIndexMath(push); + } else if (do_count) { + if (dim == 1) + output.push_back(at::zeros({N, 1, src_X}, grid_opt)); + else if (dim == 2) + output.push_back(at::zeros({N, 1, src_X, src_Y}, grid_opt)); + else + output.push_back(at::zeros({N, 1, src_X, src_Y, src_Z}, grid_opt)); + auto count = output.back(); + out_sN = count.stride(0); + out_sC = count.stride(1); + out_sX = count.stride(2); + out_sY = dim < 2 ? 0L : count.stride(3); + out_sZ = dim < 3 ? 0L : count.stride(4); + out_sK = 0L; + out_ptr = count.data_ptr(); + out_32b_ok = tensorCanUse32BitIndexMath(count); + } + if (do_grad) { + if (dim == 1) + output.push_back(at::zeros({N, trgt_X, 1}, grid_opt)); + else if (dim == 2) + output.push_back(at::zeros({N, trgt_X, trgt_Y, 2}, grid_opt)); + else + output.push_back(at::zeros({N, trgt_X, trgt_Y, trgt_Z, 3}, grid_opt)); + auto grad = output.back(); + grad_sN = grad.stride(0); + grad_sX = grad.stride(1); + grad_sY = dim < 2 ? 0L : grad.stride(2); + grad_sZ = dim < 3 ? 0L : grad.stride(3); + grad_sC = grad.stride(dim == 1 ? 2 : dim == 2 ? 3 : 4); + grad_ptr = grad.data_ptr(); + out_32b_ok = tensorCanUse32BitIndexMath(grad); + + if (iso && interpolation0 == InterpolationType::Nearest) + grad.zero_(); + } + } + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // GENERIC PUSHPULL CLASS + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // This class implements the bulk of the code. + // /!\ No type and shape checking is performed here. + + template + class PushPullImpl { + public: + // ~~~ CONSTRUCTOR ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + PushPullImpl(const PushPullAllocator& info) + : output(info.output), + dim(info.dim), + bound0(info.bound0), + bound1(info.bound1), + bound2(info.bound2), + interpolation0(info.interpolation0), + interpolation1(info.interpolation1), + interpolation2(info.interpolation1), + iso(info.iso), + extrapolate(info.extrapolate), + do_pull(info.do_pull), + do_push(info.do_push), + do_count(info.do_count), + do_grad(info.do_grad), + do_sgrad(info.do_sgrad), + N(static_cast(info.N)), + C(static_cast(info.C)), + src_X(static_cast(info.src_X)), + src_Y(static_cast(info.src_Y)), + src_Z(static_cast(info.src_Z)), + trgt_X(static_cast(info.trgt_X)), + trgt_Y(static_cast(info.trgt_Y)), + trgt_Z(static_cast(info.trgt_Z)), + trgt_K(static_cast(info.trgt_K)), + src_sN(static_cast(info.src_sN)), + src_sC(static_cast(info.src_sC)), + src_sX(static_cast(info.src_sX)), + src_sY(static_cast(info.src_sY)), + src_sZ(static_cast(info.src_sZ)), + src_ptr(static_cast(info.src_ptr)), + trgt_sN(static_cast(info.trgt_sN)), + trgt_sC(static_cast(info.trgt_sC)), + trgt_sX(static_cast(info.trgt_sX)), + trgt_sY(static_cast(info.trgt_sY)), + trgt_sZ(static_cast(info.trgt_sZ)), + trgt_sK(static_cast(info.trgt_sK)), + trgt_ptr(static_cast(info.trgt_ptr)), + grid_sN(static_cast(info.grid_sN)), + grid_sC(static_cast(info.grid_sC)), + grid_sX(static_cast(info.grid_sX)), + grid_sY(static_cast(info.grid_sY)), + grid_sZ(static_cast(info.grid_sZ)), + grid_ptr(static_cast(info.grid_ptr)), + out_sN(static_cast(info.out_sN)), + out_sC(static_cast(info.out_sC)), + out_sX(static_cast(info.out_sX)), + out_sY(static_cast(info.out_sY)), + out_sZ(static_cast(info.out_sZ)), + out_sK(static_cast(info.out_sK)), + out_ptr(static_cast(info.out_ptr)), + grad_sN(static_cast(info.grad_sN)), + grad_sC(static_cast(info.grad_sC)), + grad_sX(static_cast(info.grad_sX)), + grad_sY(static_cast(info.grad_sY)), + grad_sZ(static_cast(info.grad_sZ)), + grad_ptr(static_cast(info.grad_ptr)) {} + // ~~~ PUBLIC VALUE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ std::deque output; @@ -248,39 +575,8 @@ MONAI_NAMESPACE_DEVICE { // cpu // } // ~~~ FUNCTORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - MONAI_HOST void ioset // Pull - (const Tensor& source, const Tensor& grid) { - init_all(); - init_source(source); - init_grid(grid); - init_output(); - } - - MONAI_HOST void ioset(const Tensor& source, const Tensor& grid, const Tensor& target) { - init_all(); - init_source(source); - init_grid(grid); - init_target(target); - init_output(); - } - - MONAI_HOST void ioset // Push - (IntArrayRef source_size, const Tensor& grid, const Tensor& target) { - init_all(); - init_source(source_size); - init_grid(grid); - init_target(target); - init_output(); - } - - MONAI_HOST void ioset // Count - (IntArrayRef source_size, const Tensor& grid) { - init_all(); - init_source(source_size); - init_grid(grid); - init_output(); - } + // Loop over all voxels void loop() const; MONAI_HOST MONAI_DEVICE int64_t voxcount() const { @@ -289,14 +585,18 @@ MONAI_NAMESPACE_DEVICE { // cpu private: // ~~~ COMPONENTS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - MONAI_HOST void init_all(); - MONAI_HOST void init_source(const Tensor& source); - MONAI_HOST void init_source(IntArrayRef source_size); - MONAI_HOST void init_grid(const Tensor& grid); - MONAI_HOST void init_target(const Tensor& target); - MONAI_HOST void init_output(); + MONAI_DEVICE void check1d(offset_t w, offset_t n) const; MONAI_DEVICE void check2d(offset_t w, offset_t h, offset_t n) const; MONAI_DEVICE void check3d(offset_t w, offset_t h, offset_t d, offset_t n) const; + MONAI_DEVICE void interpolate1d(scalar_t x, offset_t w, offset_t n) const; + MONAI_DEVICE void interpolate1d_nearest(scalar_t x, offset_t w, offset_t n) const; + MONAI_DEVICE void interpolate1d_linear(scalar_t x, offset_t w, offset_t n) const; + MONAI_DEVICE void interpolate1d_sliding(scalar_t x, offset_t w, offset_t n) const { /*TODO*/ + } + MONAI_DEVICE void interpolate1d_sliding_nearest(scalar_t x, offset_t w, offset_t n) const { /*TODO*/ + } + MONAI_DEVICE void interpolate1d_sliding_linear(scalar_t x, offset_t w, offset_t n) const { /*TODO*/ + } MONAI_DEVICE void interpolate2d(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n) const; MONAI_DEVICE void interpolate2d_nearest(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n) const; MONAI_DEVICE void interpolate2d_bilinear(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n) const; @@ -371,9 +671,6 @@ MONAI_NAMESPACE_DEVICE { // cpu bool do_sgrad; // sample spatial gradients // ~~~ NAVIGATORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - TensorOptions src_opt; - TensorOptions grid_opt; - TensorOptions trgt_opt; offset_t N; offset_t C; offset_t src_X; @@ -403,174 +700,24 @@ MONAI_NAMESPACE_DEVICE { // cpu offset_t grid_sZ; scalar_t* grid_ptr; offset_t out_sN; - offset_t out_sC; - offset_t out_sX; - offset_t out_sY; - offset_t out_sZ; - offset_t out_sK; // gradient dimension - scalar_t* out_ptr; - offset_t grad_sN; - offset_t grad_sC; - offset_t grad_sX; - offset_t grad_sY; - offset_t grad_sZ; - scalar_t* grad_ptr; - }; - - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // INITIALISATION - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - template - void PushPullImpl::init_all() { - src_opt = grid_opt = trgt_opt = TensorOptions(); - N = C = static_cast(1); - src_X = src_Y = src_Z = static_cast(1); - trgt_X = trgt_Y = trgt_Z = trgt_K = static_cast(1); - src_sN = src_sC = src_sX = src_sY = src_sZ = static_cast(0); - grid_sN = grid_sC = grid_sX = grid_sY = grid_sZ = static_cast(0); - grad_sN = grad_sC = grad_sX = grad_sY = grad_sZ = static_cast(0); - trgt_sN = trgt_sC = trgt_sX = trgt_sY = trgt_sZ = trgt_sK = static_cast(0); - out_sN = out_sC = out_sX = out_sY = out_sZ = out_sK = static_cast(0); - src_ptr = trgt_ptr = grid_ptr = out_ptr = grad_ptr = static_cast(0); - } - - template - MONAI_HOST void PushPullImpl::init_source(const Tensor& source) { - N = source.size(0); - C = source.size(1); - src_X = source.size(2); - src_Y = source.size(3); - src_Z = dim == 2 ? static_cast(1) : source.size(4); - src_sN = source.stride(0); - src_sC = source.stride(1); - src_sX = source.stride(2); - src_sY = source.stride(3); - src_sZ = dim == 2 ? static_cast(0) : source.stride(4); - src_ptr = source.data_ptr(); - src_opt = source.options(); - } - - template - MONAI_HOST void PushPullImpl::init_source(IntArrayRef source_size) { - src_X = source_size[0]; - src_Y = source_size[1]; - src_Z = dim == 2 ? static_cast(1) : source_size[2]; - } - - template - MONAI_HOST void PushPullImpl::init_grid(const Tensor& grid) { - N = grid.size(0); - trgt_X = grid.size(1); - trgt_Y = grid.size(2); - trgt_Z = dim == 2 ? static_cast(1) : grid.size(3); - grid_sN = grid.stride(0); - grid_sX = grid.stride(1); - grid_sY = grid.stride(2); - grid_sZ = dim == 2 ? static_cast(0) : grid.stride(3); - grid_sC = grid.stride(dim == 2 ? 3 : 4); - grid_ptr = grid.data_ptr(); - grid_opt = grid.options(); - } - - template - MONAI_HOST void PushPullImpl::init_target(const Tensor& target) { - N = target.size(0); - C = target.size(1); - trgt_X = target.size(2); - trgt_Y = target.size(3); - trgt_Z = dim == 2 ? static_cast(1) : target.size(4); - trgt_K = target.dim() == dim + 3 ? target.size(dim == 2 ? 4 : 5) : static_cast(1); - trgt_sN = target.stride(0); - trgt_sC = target.stride(1); - trgt_sX = target.stride(2); - trgt_sY = target.stride(3); - trgt_sZ = dim == 2 ? static_cast(0) : target.stride(4); - trgt_sK = target.dim() == dim + 3 ? target.stride(dim == 2 ? 4 : 5) : static_cast(0); - trgt_ptr = target.data_ptr(); - trgt_opt = target.options(); - } - - template - MONAI_HOST void PushPullImpl::init_output() { - output.clear(); - if (do_pull) { - if (dim == 2) - output.push_back(at::empty({N, C, trgt_X, trgt_Y}, src_opt)); - else - output.push_back(at::empty({N, C, trgt_X, trgt_Y, trgt_Z}, src_opt)); - auto pull = output.back(); - out_sN = pull.stride(0); - out_sC = pull.stride(1); - out_sX = pull.stride(2); - out_sY = pull.stride(3); - out_sZ = dim == 2 ? static_cast(0) : pull.stride(4); - out_sK = static_cast(0); - out_ptr = pull.template data_ptr(); - } else if (do_sgrad) { - if (dim == 2) - output.push_back(at::empty({N, C, trgt_X, trgt_Y, 2}, src_opt)); - else - output.push_back(at::empty({N, C, trgt_X, trgt_Y, trgt_Z, 3}, src_opt)); - auto sgrad = output.back(); - out_sN = sgrad.stride(0); - out_sC = sgrad.stride(1); - out_sX = sgrad.stride(2); - out_sY = sgrad.stride(3); - out_sZ = dim == 2 ? static_cast(0) : sgrad.stride(4); - out_sK = sgrad.stride(dim == 2 ? 4 : 5); - out_ptr = sgrad.template data_ptr(); - - if (iso && interpolation0 == InterpolationType::Nearest) - sgrad.zero_(); - } else if (do_push) { - if (dim == 2) - output.push_back(at::zeros({N, C, src_X, src_Y}, trgt_opt)); - else - output.push_back(at::zeros({N, C, src_X, src_Y, src_Z}, trgt_opt)); - auto push = output.back(); - out_sN = push.stride(0); - out_sC = push.stride(1); - out_sX = push.stride(2); - out_sY = push.stride(3); - out_sZ = dim == 2 ? static_cast(0) : push.stride(4); - out_sK = static_cast(0); - out_ptr = push.template data_ptr(); - } else if (do_count) { - if (dim == 2) - output.push_back(at::zeros({N, 1, src_X, src_Y}, grid_opt)); - else - output.push_back(at::zeros({N, 1, src_X, src_Y, src_Z}, grid_opt)); - auto count = output.back(); - out_sN = count.stride(0); - out_sC = count.stride(1); - out_sX = count.stride(2); - out_sY = count.stride(3); - out_sZ = dim == 2 ? static_cast(0) : count.stride(4); - out_sK = static_cast(0); - out_ptr = count.template data_ptr(); - } - if (do_grad) { - if (dim == 2) - output.push_back(at::zeros({N, src_X, src_Y, 2}, grid_opt)); - else - output.push_back(at::zeros({N, src_X, src_Y, src_Z, 3}, grid_opt)); - auto grad = output.back(); - grad_sN = grad.stride(0); - grad_sX = grad.stride(1); - grad_sY = grad.stride(2); - grad_sZ = dim == 2 ? static_cast(0) : grad.stride(3); - grad_sC = grad.stride(dim == 2 ? 3 : 4); - grad_ptr = grad.template data_ptr(); - - if (iso && interpolation0 == InterpolationType::Nearest) - grad.zero_(); - } - } + offset_t out_sC; + offset_t out_sX; + offset_t out_sY; + offset_t out_sZ; + offset_t out_sK; // gradient dimension + scalar_t* out_ptr; + offset_t grad_sN; + offset_t grad_sC; + offset_t grad_sX; + offset_t grad_sY; + offset_t grad_sZ; + scalar_t* grad_ptr; + }; // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // LOOP // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // This bit loops over all target voxels. We therefore need to // convert linear indices to multivariate indices. The way I do it // might not be optimal. @@ -587,7 +734,10 @@ MONAI_NAMESPACE_DEVICE { // cpu // parallelize across voxels. at::parallel_for(0, N, 0, [&](offset_t start, offset_t end) { for (offset_t n = start; n < end; ++n) { - if (dim == 2) { + if (dim == 1) { + for (offset_t w = 0; w < trgt_X; ++w) + check1d(w, n); + } else if (dim == 2) { for (offset_t h = 0; h < trgt_Y; ++h) for (offset_t w = 0; w < trgt_X; ++w) check2d(w, h, n); @@ -601,8 +751,8 @@ MONAI_NAMESPACE_DEVICE { // cpu }); return; } -#endif +#endif // Parallelize across voxels offset_t trgt_NXYZ = trgt_Z * trgt_Y * trgt_X * N; offset_t trgt_XYZ = trgt_Z * trgt_Y * trgt_X; @@ -616,7 +766,9 @@ MONAI_NAMESPACE_DEVICE { // cpu h = (i / trgt_Z) % trgt_Y; d = i % trgt_Z; - if (dim == 2) + if (dim == 1) + check1d(w, n); + else if (dim == 2) check2d(w, h, n); else check3d(w, h, d, n); @@ -632,6 +784,59 @@ MONAI_NAMESPACE_DEVICE { // cpu // 1) read the [x,y,z] source coordinate for the current target voxel // 3) check if the source coordinate is in bounds + template + MONAI_DEVICE void PushPullImpl::check3d(offset_t w, offset_t h, offset_t d, offset_t n) const { + // get the corresponding input x, y, z co-ordinates from grid + scalar_t* grid_ptr_NXYZ = grid_ptr + n * grid_sN + w * grid_sX + h * grid_sY + d * grid_sZ; + scalar_t x = *grid_ptr_NXYZ; + scalar_t y = grid_ptr_NXYZ[grid_sC]; + scalar_t z = grid_ptr_NXYZ[grid_sC * 2]; + + // Check if out-of-bound + if (!(extrapolate || + (inbounds(x, src_X, static_cast(TINY)) && inbounds(y, src_Y, static_cast(TINY)) && + inbounds(z, src_Z, static_cast(TINY))))) { + if (do_pull || do_sgrad) { + scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ; + for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC) { + *out_ptr_NCXYZ = static_cast(0); + if (do_sgrad) { + out_ptr_NCXYZ[out_sK] = static_cast(0); + out_ptr_NCXYZ[out_sK * 2] = static_cast(0); + } + } + } + if (do_grad) { + scalar_t* grad_ptr_NXYZ = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY + d * grad_sZ; + (*grad_ptr_NXYZ) = static_cast(0); + grad_ptr_NXYZ[grad_sC] = static_cast(0); + grad_ptr_NXYZ[grad_sC * 2] = static_cast(0); + } + return; + } + + // Next step + if (bound0 == BoundType::Sliding) { + if (iso) + switch (static_cast(interpolation0)) { + case 0: + return interpolate3d_sliding_nearest(x, y, z, w, h, d, n); + case 1: + return interpolate3d_sliding_trilinear(x, y, z, w, h, d, n); + } + return interpolate3d_sliding(x, y, z, w, h, d, n); + } else { + if (iso) + switch (static_cast(interpolation0)) { + case 0: + return interpolate3d_nearest(x, y, z, w, h, d, n); + case 1: + return interpolate3d_trilinear(x, y, z, w, h, d, n); + } + return interpolate3d(x, y, z, w, h, d, n); + } + } + template MONAI_DEVICE void PushPullImpl::check2d(offset_t w, offset_t h, offset_t n) const { // get the corresponding input x, y, z co-ordinates from grid @@ -643,7 +848,7 @@ MONAI_NAMESPACE_DEVICE { // cpu if (!(extrapolate || (inbounds(x, src_X, static_cast(TINY)) && inbounds(y, src_Y, static_cast(TINY))))) { if (do_pull || do_sgrad) { - scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sZ + h * out_sY; + scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY; for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC) { *out_ptr_NCXY = static_cast(0); if (do_sgrad) @@ -681,32 +886,25 @@ MONAI_NAMESPACE_DEVICE { // cpu } template - MONAI_DEVICE void PushPullImpl::check3d(offset_t w, offset_t h, offset_t d, offset_t n) const { + MONAI_DEVICE void PushPullImpl::check1d(offset_t w, offset_t n) const { // get the corresponding input x, y, z co-ordinates from grid - scalar_t* grid_ptr_NXYZ = grid_ptr + n * grid_sN + w * grid_sX + h * grid_sY + d * grid_sZ; - scalar_t x = *grid_ptr_NXYZ; - scalar_t y = grid_ptr_NXYZ[grid_sC]; - scalar_t z = grid_ptr_NXYZ[grid_sC * 2]; + scalar_t* grid_ptr_NX = grid_ptr + n * grid_sN + w * grid_sX; + scalar_t x = *grid_ptr_NX; // Check if out-of-bound - if (!(extrapolate || - (inbounds(x, src_X, static_cast(TINY)) && inbounds(y, src_Y, static_cast(TINY)) && - inbounds(z, src_Z, static_cast(TINY))))) { + if (!(extrapolate || inbounds(x, src_X, static_cast(TINY)))) { if (do_pull || do_sgrad) { - scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ; - for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC) { - *out_ptr_NCXYZ = static_cast(0); - if (do_sgrad) { - out_ptr_NCXYZ[out_sK] = static_cast(0); - out_ptr_NCXYZ[out_sK * 2] = static_cast(0); - } + scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; + for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC) { + *out_ptr_NCX = static_cast(0); + if (do_sgrad) + out_ptr_NCX[out_sK] = static_cast(0); } } if (do_grad) { - scalar_t* grad_ptr_NXYZ = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY + d * grad_sZ; - (*grad_ptr_NXYZ) = static_cast(0); - grad_ptr_NXYZ[grad_sC] = static_cast(0); - grad_ptr_NXYZ[grad_sC * 2] = static_cast(0); + scalar_t* grad_ptr_NX = grad_ptr + n * grad_sN + w * grad_sX; + (*grad_ptr_NX) = static_cast(0); + grad_ptr_NX[grad_sC] = static_cast(0); } return; } @@ -716,20 +914,20 @@ MONAI_NAMESPACE_DEVICE { // cpu if (iso) switch (static_cast(interpolation0)) { case 0: - return interpolate3d_sliding_nearest(x, y, z, w, h, d, n); + return interpolate1d_sliding_nearest(x, w, n); case 1: - return interpolate3d_sliding_trilinear(x, y, z, w, h, d, n); + return interpolate1d_sliding_linear(x, w, n); } - return interpolate3d_sliding(x, y, z, w, h, d, n); + return interpolate1d_sliding(x, w, n); } else { if (iso) switch (static_cast(interpolation0)) { case 0: - return interpolate3d_nearest(x, y, z, w, h, d, n); + return interpolate1d_nearest(x, w, n); case 1: - return interpolate3d_trilinear(x, y, z, w, h, d, n); + return interpolate1d_linear(x, w, n); } - return interpolate3d(x, y, z, w, h, d, n); + return interpolate1d(x, w, n); } } @@ -764,7 +962,7 @@ MONAI_NAMESPACE_DEVICE { // cpu if (trgt_ptr && (do_push || do_grad)) for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC) { target[c] = *trgt_ptr_NCXYZ; - if (trgt_K > 1) { + if (trgt_K > 0) { target[c + C] = trgt_ptr_NCXYZ[trgt_sK]; target[c + C * 2] = trgt_ptr_NCXYZ[trgt_sK * 2]; } @@ -882,7 +1080,7 @@ MONAI_NAMESPACE_DEVICE { // cpu // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - if (trgt_K == 1) { + if (trgt_K == 0) { // Diff w.r.t. push/pull scalar_t* out_ptr_NC = out_ptr_NC0; for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC) @@ -905,7 +1103,7 @@ MONAI_NAMESPACE_DEVICE { // cpu // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { - if (trgt_K == 1) { + if (trgt_K == 0) { // Diff w.r.t. pull/push scalar_t* src_ptr_NC = src_ptr_NC0; scalar_t dot = static_cast(0); @@ -974,7 +1172,7 @@ MONAI_NAMESPACE_DEVICE { // cpu if (trgt_ptr && (do_push || do_grad)) for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC) { target[c] = *trgt_ptr_NCXY; - if (trgt_K > 1) { + if (trgt_K > 0) { target[c + C] = trgt_ptr_NCXY[trgt_sK]; } } @@ -1067,7 +1265,7 @@ MONAI_NAMESPACE_DEVICE { // cpu // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - if (trgt_K == 1) { + if (trgt_K == 0) { // Diff w.r.t. push/pull scalar_t* out_ptr_NC = out_ptr_NC0; for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC) @@ -1089,7 +1287,7 @@ MONAI_NAMESPACE_DEVICE { // cpu // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { - if (trgt_K == 1) { + if (trgt_K == 0) { // Diff w.r.t. pull/push scalar_t* src_ptr_NC = src_ptr_NC0; scalar_t dot = static_cast(0); @@ -1126,6 +1324,150 @@ MONAI_NAMESPACE_DEVICE { // cpu } } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // GENERIC INTERPOLATION 1D + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + template + MONAI_DEVICE void PushPullImpl::interpolate1d(scalar_t x, offset_t w, offset_t n) const { + // Get corner pixel values from (x, y) + offset_t bx0, bx1; + interpolation::bounds(interpolation0, x, bx0, bx1); + offset_t dbx = bx1 - bx0; + + // Pre-compute offsets and target value + scalar_t* src_ptr_NC0 = src_ptr + n * src_sN; + scalar_t* out_ptr_NC0 = out_ptr + n * out_sN; + scalar_t* out_ptr_NCX0 = out_ptr + n * out_sN + w * out_sX; + scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; + scalar_t target[2 * MONAI_MAX_NUM_CHANNELS]; + if (trgt_ptr && (do_push || do_grad)) + for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC) { + target[c] = *trgt_ptr_NCX; + if (trgt_K > 0) { + target[c + C] = trgt_ptr_NCX[trgt_sK]; + } + } + + // Initialize output + scalar_t* out_ptr_NCX = out_ptr_NCX0; + if (do_pull || do_sgrad) { + for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC) { + *out_ptr_NCX = static_cast(0); + if (do_sgrad) { + out_ptr_NCX[out_sK] = static_cast(0); + } + } + } + + // Pre-compute indices/weights/grad + scalar_t wx[8]; // B-spline weights + scalar_t gx[8]; // B-spline derivatives + scalar_t hx[8]; // B-spline 2nd derivatives + offset_t ix[8]; // Warped indices + uint8_t sx[8]; // Warped indices + + { + scalar_t *owx = static_cast(wx), *ogx = static_cast(gx), *ohx = static_cast(hx); + offset_t* oix = static_cast(ix); + uint8_t* osx = static_cast(sx); + for (offset_t bx = bx0; bx <= bx1; ++bx) { + scalar_t dx = x - bx; + *(owx++) = interpolation::fastweight(interpolation0, dx); + if (do_grad || do_sgrad) + *(ogx++) = interpolation::fastgrad(interpolation0, dx); + if (do_grad && trgt_sK > 1) + *(ohx++) = interpolation::fasthess(interpolation0, dx); + *(osx++) = bound::sign(bound0, bx, src_X); + *(oix++) = bound::index(bound0, bx, src_X); + } + } + + // Convolve coefficients with basis functions + scalar_t ogx; + ogx = static_cast(0); + for (offset_t i = 0; i <= dbx; ++i) { + offset_t oox = ix[i] * out_sX; + offset_t osx = ix[i] * src_sX; + uint8_t sxx = sx[i]; + scalar_t wxx = wx[i]; + scalar_t gxx = gx[i]; + scalar_t hxx = hx[i]; + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if (do_pull) { + scalar_t* src_ptr_NC = src_ptr_NC0; + scalar_t* out_ptr_NCX = out_ptr_NCX0; + for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) + *out_ptr_NCX += bound::get(src_ptr_NC, osx, sxx) * wxx; + } + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + else if (do_sgrad) { + scalar_t* src_ptr_NC = src_ptr_NC0; + scalar_t* out_ptr_NCX = out_ptr_NCX0; + for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) { + scalar_t src = bound::get(src_ptr_NC, osx, sxx); + *out_ptr_NCX += src * gxx; + } + } + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + else if (do_push) { + if (trgt_K == 0) { + // Diff w.r.t. push/pull + scalar_t* out_ptr_NC = out_ptr_NC0; + for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC) + bound::add(out_ptr_NC, oox, wxx * target[c], sxx); + } else { + // Diff w.r.t. sgrad + scalar_t* out_ptr_NC = out_ptr_NC0; + for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC) { + scalar_t val = gxx * target[c]; + bound::add(out_ptr_NC, oox, val, sxx); + } + } + } + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Count ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + else if (do_count) { + bound::add(out_ptr_NC0, oox, wxx, sxx); + } + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if (do_grad) { + if (trgt_K == 0) { + // Diff w.r.t. pull/push + scalar_t* src_ptr_NC = src_ptr_NC0; + scalar_t dot = static_cast(0); + for (offset_t c = 0; c < C; ++c, src_ptr_NC += src_sC) { + scalar_t src = bound::get(src_ptr_NC, osx, sxx); + dot += (trgt_ptr ? src * target[c] : src); + // trgt_ptr == 0 in the backward pass of 'count' + } + ogx += gxx * dot; + } else { + // Diff w.r.t. sgrad + scalar_t* src_ptr_NC = src_ptr_NC0; + scalar_t dot; + dot = static_cast(0); + for (offset_t c = 0; c < C; ++c, src_ptr_NC += src_sC) { + scalar_t src = bound::get(src_ptr_NC, osx, sxx); + dot += src * target[c]; + } + ogx += hxx * dot; + } + } + + } // x + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if (do_grad) { + scalar_t* grad_ptr_NX = grad_ptr + n * grad_sN + w * grad_sX; + (*grad_ptr_NX) = ogx; + } + } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // LINEAR INTERPOLATION 3D // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1215,7 +1557,7 @@ MONAI_NAMESPACE_DEVICE { // cpu scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ; scalar_t* src_ptr_NC = src_ptr + n * src_sN; - if (trgt_K == 1) { + if (trgt_K == 0) { // backward w.r.t. push/pull for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC, src_ptr_NC += src_sC) { scalar_t src; @@ -1377,7 +1719,7 @@ MONAI_NAMESPACE_DEVICE { // cpu o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ; scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ; scalar_t* out_ptr_NC = out_ptr + n * out_sN; - if (trgt_K == 1) { + if (trgt_K == 0) { // Diff w.r.t. push/pull for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC, out_ptr_NC += out_sC) { scalar_t trgt = *trgt_ptr_NCXYZ; @@ -1500,7 +1842,7 @@ MONAI_NAMESPACE_DEVICE { // cpu scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY; scalar_t* src_ptr_NC = src_ptr + n * src_sN; - if (trgt_K == 1) { + if (trgt_K == 0) { // backward w.r.t. push/pull for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC, src_ptr_NC += src_sC) { scalar_t src; @@ -1547,9 +1889,9 @@ MONAI_NAMESPACE_DEVICE { // cpu } } - scalar_t* grad_ptr_NXYZ = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY; - (*grad_ptr_NXYZ) = gx; - grad_ptr_NXYZ[grad_sC] = gy; + scalar_t* grad_ptr_NXY = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY; + (*grad_ptr_NXY) = gx; + grad_ptr_NXY[grad_sC] = gy; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { @@ -1591,7 +1933,7 @@ MONAI_NAMESPACE_DEVICE { // cpu o11 = ix1 * out_sX + iy1 * out_sY; scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY; scalar_t* out_ptr_NC = out_ptr + n * out_sN; - if (trgt_K == 1) { + if (trgt_K == 0) { // Diff w.r.t. push/pull for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC, out_ptr_NC += out_sC) { scalar_t trgt = *trgt_ptr_NCXY; @@ -1632,6 +1974,123 @@ MONAI_NAMESPACE_DEVICE { // cpu } } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // LINEAR INTERPOLATION 1D + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + template + MONAI_DEVICE void PushPullImpl::interpolate1d_linear(scalar_t x, offset_t w, offset_t n) const { + // Get corner pixel values from (x) + offset_t ix0 = static_cast(std::floor(x)); + + // Interpolation weights (inversely proportional to distance) + scalar_t w1 = x - ix0; + scalar_t w0 = 1. - w1; + + // Sign (/!\ compute sign before warping indices) + int8_t s1 = bound::sign(bound0, ix0 + 1, src_X); + int8_t s0 = bound::sign(bound0, ix0, src_X); + + // Warp indices + offset_t ix1; + ix1 = bound::index(bound0, ix0 + 1, src_X); + ix0 = bound::index(bound0, ix0, src_X); + + // Offsets into source volume + offset_t o0, o1; + if (do_pull || do_grad || do_sgrad) { + o0 = ix0 * src_sX; + o1 = ix1 * src_sX; + } + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ + if (do_grad) { + if (trgt_K == 0) { + // backward w.r.t. push/pull + + o0 = ix0 * src_sX; + o1 = ix1 * src_sX; + scalar_t gx = static_cast(0); + scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; + scalar_t* src_ptr_NC = src_ptr + n * src_sN; + + for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, src_ptr_NC += src_sC) { + scalar_t src; + scalar_t trgt = trgt_ptr ? *trgt_ptr_NCX : static_cast(1); + // ^ trgt_ptr == 0 during the backward pass of count + src = bound::get(src_ptr_NC, o0, s0); + if (trgt_ptr) + src *= trgt; + gx -= src; + src = bound::get(src_ptr_NC, o1, s1); + if (trgt_ptr) + src *= trgt; + gx += src; + } + + scalar_t* grad_ptr_NX = grad_ptr + n * grad_sN + w * grad_sX; + (*grad_ptr_NX) = gx; + } else { + // backward w.r.t. sgrad + // -> zero (make sure this is done at initialization) + } + } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if (do_pull) { + o0 = ix0 * src_sX; + o1 = ix1 * src_sX; + scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; + scalar_t* src_ptr_NC = src_ptr + n * src_sN; + for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) { + *out_ptr_NCX = bound::get(src_ptr_NC, o0, s0) * w0 + bound::get(src_ptr_NC, o1, s1) * w1; + } + } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + else if (do_sgrad) { + o0 = ix0 * src_sX; + o1 = ix1 * src_sX; + scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; + scalar_t* src_ptr_NC = src_ptr + n * src_sN; + + for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) { + *out_ptr_NCX = bound::get(src_ptr_NC, o1, s1) - bound::get(src_ptr_NC, o0, s0); + } + } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + else if (do_push) { + // Offsets into 'push' volume + o0 = ix0 * out_sX; + o1 = ix1 * out_sX; + scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; + scalar_t* out_ptr_NC = out_ptr + n * out_sN; + if (trgt_K == 0) { + // Diff w.r.t. push/pull + for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, out_ptr_NC += out_sC) { + scalar_t trgt = *trgt_ptr_NCX; + bound::add(out_ptr_NC, o0, w0 * trgt, s0); + bound::add(out_ptr_NC, o1, w1 * trgt, s1); + } + } else { + // Diff w.r.t. sgrad + for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, out_ptr_NC += out_sC) { + scalar_t trgt0 = *trgt_ptr_NCX; + bound::add(out_ptr_NC, o0, -trgt0, s0); + bound::add(out_ptr_NC, o1, trgt0, s1); + } + } + } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + else if (do_count) { + // Offsets into 'push' volume + o0 = ix0 * out_sX; + o1 = ix1 * out_sX; + + scalar_t* out_ptr_N = out_ptr + n * out_sN; + bound::add(out_ptr_N, o0, w0, s0); + bound::add(out_ptr_N, o1, w1, s1); + } + } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // NEAREST NEIGHBOR INTERPOLATION 3D // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1666,7 +2125,7 @@ MONAI_NAMESPACE_DEVICE { // cpu scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC) *out_ptr_NCXYZ = bound::get(src_ptr_NC, o, s); - } else if (do_push && trgt_K == 1) { + } else if (do_push && trgt_K == 0) { offset_t o = iz * out_sZ + iy * out_sY + ix * out_sX; scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ; scalar_t* out_ptr_NC = out_ptr + n * out_sN; @@ -1709,7 +2168,7 @@ MONAI_NAMESPACE_DEVICE { // cpu scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC) *out_ptr_NCXY = bound::get(src_ptr_NC, o, s); - } else if (do_push && trgt_K == 1) { + } else if (do_push && trgt_K == 0) { offset_t o = iy * out_sY + ix * out_sX; scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY; scalar_t* out_ptr_NC = out_ptr + n * out_sN; @@ -1722,10 +2181,48 @@ MONAI_NAMESPACE_DEVICE { // cpu bound::add(out_ptr_NC, o, static_cast(1), s); } } + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // NEAREST NEIGHBOR INTERPOLATION 1D + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + template + MONAI_DEVICE void PushPullImpl::interpolate1d_nearest(scalar_t x, offset_t w, offset_t n) const { + offset_t i = static_cast(std::round(x)); + + // Boundary condition (/!\ compute sign before warping indices) + int8_t s = bound::sign(bound0, i, src_X); + i = bound::index(bound0, i, src_X); + + if (do_pull) { + offset_t o = i * src_sX; + scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; + scalar_t* src_ptr_NC = src_ptr + n * src_sN; + for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) + *out_ptr_NCX = bound::get(src_ptr_NC, o, s); + } else if (do_push && trgt_K == 0) { + offset_t o = i * out_sX; + scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; + scalar_t* out_ptr_NC = out_ptr + n * out_sN; + for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, out_ptr_NC += out_sC) + bound::add(out_ptr_NC, o, *trgt_ptr_NCX, s); + } else if (do_count) { + offset_t o = i * out_sX; + scalar_t* out_ptr_NC = out_ptr + n * out_sN; + for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC) + bound::add(out_ptr_NC, o, static_cast(1), s); + } + } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // LINEAR INTERPOLATION 3D + SLIDING BOUNDARY // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // TODO + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // CUDA KERNEL (MUST BE OUT OF CLASS) + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + } // namespace // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1757,8 +2254,6 @@ MONAI_NAMESPACE_DEVICE { // cpu PUSHPULL_INSTANTIATE1(BoundType); \ PUSHPULL_INSTANTIATE1(BoundVectorRef) - // ~~~ CPU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // Two arguments (source, grid) // > `bound` and `interpolation` can be single arguments or vectors. template @@ -1773,12 +2268,14 @@ MONAI_NAMESPACE_DEVICE { // cpu bool do_count, bool do_grad, bool do_sgrad) { + PushPullAllocator info( + grid.dim() - 2, bound, interpolation, extrapolate, do_pull, do_push, do_count, do_grad, do_sgrad); + info.ioset(source, grid); + return AT_DISPATCH_FLOATING_TYPES(grid.scalar_type(), "pushpull", [&] { - PushPullImpl f( - grid.dim() - 2, bound, interpolation, extrapolate, do_pull, do_push, do_count, do_grad, do_sgrad); - f.ioset(source, grid); - f.loop(); - return f.output; + PushPullImpl algo(info); + algo.loop(); + return algo.output; }); } @@ -1798,17 +2295,18 @@ MONAI_NAMESPACE_DEVICE { // cpu bool do_count, bool do_grad, bool do_sgrad) { + PushPullAllocator info( + grid.dim() - 2, bound, interpolation, extrapolate, do_pull, do_push, do_count, do_grad, do_sgrad); + info.ioset(source, grid, target); + return AT_DISPATCH_FLOATING_TYPES(grid.scalar_type(), "pushpull", [&] { - PushPullImpl f( - grid.dim() - 2, bound, interpolation, extrapolate, do_pull, do_push, do_count, do_grad, do_sgrad); - f.ioset(source, grid, target); - f.loop(); - return f.output; + PushPullImpl algo(info); + algo.loop(); + return algo.output; }); } PUSHPULL_INSTANTIATE; -} // namespace - +} // namespace cpu } // namespace monai diff --git a/monai/csrc/resample/pushpull_cuda.cu b/monai/csrc/resample/pushpull_cuda.cu index 029fe9eee1..38d34ffe98 100644 --- a/monai/csrc/resample/pushpull_cuda.cu +++ b/monai/csrc/resample/pushpull_cuda.cu @@ -25,6 +25,7 @@ limitations under the License. // TODO: // . [DONE] generic 3d // . [DONE] generic 2d +// . [DONE] generic 1d // . sliding nearest 3d // . sliding nearest 2d // . sliding linear 3d @@ -72,18 +73,27 @@ MONAI_NAMESPACE_DEVICE { // cuda namespace { // anonymous namespace > everything inside has internal linkage // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // GENERIC PUSHPULL CLASS + // INDEXING UTILS // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // This class implements the bulk of the code. - // /!\ No type and shape checking is performed here. - template - class PushPullImpl { + // This class reads and sets all the parameters that will later be used + // by the algorithm in PushPullImpl. All of this is done outside of the + // implementation class so that we do not depend on generic types. The + // point is to pre-allocate all necessary tensors so that we can check + // if they're all compatible with 32 bit math. If it's the case, we can + // dispatch to a 32b cuda implementation, which might increase + // performance. Else, we use 64 bit math to compute offsets. + // (On CPU, we always use 64 bit offsets because it doesn't make a huge + // difference. It would be different if we had a vectorized + // implementation as in PyTorch). + class PushPullAllocator { public: + static constexpr int64_t max_int32 = std::numeric_limits::max(); + // ~~~ CONSTRUCTORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MONAI_HOST - PushPullImpl( + PushPullAllocator( int dim, BoundVectorRef bound, InterpolationVectorRef interpolation, @@ -123,100 +133,417 @@ MONAI_NAMESPACE_DEVICE { // cuda iso = interpolation0 == interpolation1 && interpolation0 == interpolation2; } - MONAI_HOST - PushPullImpl( - int dim, - BoundType bound, - InterpolationVectorRef interpolation, - bool extrapolate, - bool do_pull, - bool do_push, - bool do_count, - bool do_grad, - bool do_sgrad) - : dim(dim), - bound0(bound), - bound1(bound), - bound2(bound), - interpolation0(interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear), - interpolation1( - interpolation.size() > 1 ? interpolation[1] - : interpolation.size() > 0 ? interpolation[0] - : InterpolationType::Linear), - interpolation2( - interpolation.size() > 2 ? interpolation[2] - : interpolation.size() > 1 ? interpolation[1] - : interpolation.size() > 0 ? interpolation[0] - : InterpolationType::Linear), - extrapolate(extrapolate), - do_pull(do_pull), - do_push(do_push), - do_count(do_count), - do_grad(do_grad), - do_sgrad(do_sgrad) { - iso = interpolation0 == interpolation1 && interpolation0 == interpolation2; + // ~~~ FUNCTORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + // Usually used for pull: + // - do_pull -> return source[grid] + // - do_push -> fails + // - do_grad -> return J(source)[grid] + // - do_sgrad -> return H(source)[grid] + MONAI_HOST void ioset(const Tensor& source, const Tensor& grid) { + init_all(); + init_source(source); + init_grid(grid); + init_output(); } - MONAI_HOST - PushPullImpl( - int dim, - BoundVectorRef bound, - InterpolationType interpolation, - bool extrapolate, - bool do_pull, - bool do_push, - bool do_count, - bool do_grad, - bool do_sgrad) - : dim(dim), - bound0(bound.size() > 0 ? bound[0] : BoundType::Replicate), - bound1( - bound.size() > 1 ? bound[1] - : bound.size() > 0 ? bound[0] - : BoundType::Replicate), - bound2( - bound.size() > 2 ? bound[2] - : bound.size() > 1 ? bound[1] - : bound.size() > 0 ? bound[0] - : BoundType::Replicate), - interpolation0(interpolation), - interpolation1(interpolation), - interpolation2(interpolation), - extrapolate(extrapolate), - do_pull(do_pull), - do_push(do_push), - do_count(do_count), - do_grad(do_grad), - do_sgrad(do_sgrad) { - iso = interpolation0 == interpolation1 && interpolation0 == interpolation2; + // Usually used for pull_backward: + // - do_pull -> return source[grid] + // - do_push -> return push(target, grid, source.shape) + // - do_grad -> return J(source)[grid] + // - do_sgrad -> return H(source)[grid] + MONAI_HOST void ioset(const Tensor& source, const Tensor& grid, const Tensor& target) { + init_all(); + init_source(source); + init_grid(grid); + init_target(target); + init_output(); } - MONAI_HOST - PushPullImpl( - int dim, - BoundType bound, - InterpolationType interpolation, - bool extrapolate, - bool do_pull, - bool do_push, - bool do_count, - bool do_grad, - bool do_sgrad) - : dim(dim), - bound0(bound), - bound1(bound), - bound2(bound), - interpolation0(interpolation), - interpolation1(interpolation), - interpolation2(interpolation), - extrapolate(extrapolate), - do_pull(do_pull), - do_push(do_push), - do_count(do_count), - do_grad(do_grad), - do_sgrad(do_sgrad) { - iso = interpolation0 == interpolation1 && interpolation0 == interpolation2; + // Usually used for push: + // - do_pull -> fails + // - do_push -> return push(target, grid, source_size) + // - do_grad -> fails + // - do_sgrad -> fails + MONAI_HOST void ioset(IntArrayRef source_size, const Tensor& grid, const Tensor& target) { + init_all(); + init_source(source_size); + init_grid(grid); + init_target(target); + init_output(); + } + + // Usually used for count: + // - do_pull -> fails + // - do_push -> return push(ones, grid, source_size) + // - do_grad -> fails + // - do_sgrad -> fails + MONAI_HOST void ioset(IntArrayRef source_size, const Tensor& grid) { + init_all(); + init_source(source_size); + init_grid(grid); + init_output(); + } + + // We just check that all tensors that we own are compatible with 32b math + bool canUse32BitIndexMath(int64_t max_elem = max_int32) const { + return src_32b_ok && trgt_32b_ok && grid_32b_ok && grad_32b_ok && out_32b_ok; + } + + private: + // Copied from aten/src/ATen/native/IndexingUtils.cpp in PyTorch 1.6. + // It is used to decide to which pointer type we should dispatch to. + // Basically, we need to make sure that the "furthest" element we need + // to reach is less than max_elem away. + static bool tensorCanUse32BitIndexMath(const Tensor& t, int64_t max_elem = max_int32) { + int64_t elements = t.numel(); + if (elements >= max_elem) { + return false; + } + if (elements == 0) { + return max_elem > 0; + } + + int64_t offset = 0; + int64_t linearId = elements - 1; + + // NOTE: Assumes all strides are positive, which is true for now + for (int i = t.dim() - 1; i >= 0; --i) { + int64_t curDimIndex = linearId % t.size(i); + int64_t curDimOffset = curDimIndex * t.stride(i); + offset += curDimOffset; + linearId /= t.size(i); + } + + if (offset >= max_elem) { + return false; + } + + return true; + } + + // ~~~ COMPONENTS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + MONAI_HOST void init_all(); + MONAI_HOST void init_source(const Tensor& source); + MONAI_HOST void init_source(IntArrayRef source_size); + MONAI_HOST void init_grid(const Tensor& grid); + MONAI_HOST void init_target(const Tensor& target); + MONAI_HOST void init_output(); + + // ~~~ OPTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + int dim; // dimensionality (2 or 3) + BoundType bound0; // boundary condition // x|W + BoundType bound1; // boundary condition // y|H + BoundType bound2; // boundary condition // z|D + InterpolationType interpolation0; // interpolation order // x|W + InterpolationType interpolation1; // interpolation order // y|H + InterpolationType interpolation2; // interpolation order // z|D + bool iso; // isotropic interpolation? + bool extrapolate; // compute out-of-bound values + bool do_pull; // sample a volume + bool do_push; // splat a volume + bool do_count; // splatting weights (= jacobian determinant) + bool do_grad; // backprop: gradient of grid // pull + bool do_sgrad; // sample spatial gradients + + // ~~~ NAVIGATORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + std::deque output; + TensorOptions src_opt; + TensorOptions grid_opt; + TensorOptions trgt_opt; + int64_t N; + int64_t C; + int64_t src_X; + int64_t src_Y; + int64_t src_Z; + int64_t trgt_X; + int64_t trgt_Y; + int64_t trgt_Z; + int64_t trgt_K; + int64_t src_sN; + int64_t src_sC; + int64_t src_sX; + int64_t src_sY; + int64_t src_sZ; + bool src_32b_ok; + void* src_ptr; + int64_t trgt_sN; + int64_t trgt_sC; + int64_t trgt_sX; + int64_t trgt_sY; + int64_t trgt_sZ; + int64_t trgt_sK; + bool trgt_32b_ok; + void* trgt_ptr; + int64_t grid_sN; + int64_t grid_sC; + int64_t grid_sX; + int64_t grid_sY; + int64_t grid_sZ; + bool grid_32b_ok; + void* grid_ptr; + int64_t out_sN; + int64_t out_sC; + int64_t out_sX; + int64_t out_sY; + int64_t out_sZ; + int64_t out_sK; // gradient dimension + bool out_32b_ok; + void* out_ptr; + int64_t grad_sN; + int64_t grad_sC; + int64_t grad_sX; + int64_t grad_sY; + int64_t grad_sZ; + bool grad_32b_ok; + void* grad_ptr; + + // Allow PushPullImpl's constructor to access PushPullAllocator's + // private members. + template + friend class PushPullImpl; + }; + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // INITIALISATION + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + MONAI_HOST + void PushPullAllocator::init_all() { + src_opt = grid_opt = trgt_opt = TensorOptions(); + N = C = 1L; + src_X = src_Y = src_Z = 1L; + trgt_X = trgt_Y = trgt_Z = 1L; + trgt_K = 0L; + src_sN = src_sC = src_sX = src_sY = src_sZ = 0L; + grid_sN = grid_sC = grid_sX = grid_sY = grid_sZ = 0L; + grad_sN = grad_sC = grad_sX = grad_sY = grad_sZ = 0L; + trgt_sN = trgt_sC = trgt_sX = trgt_sY = trgt_sZ = trgt_sK = 0L; + out_sN = out_sC = out_sX = out_sY = out_sZ = out_sK = 0L; + src_ptr = trgt_ptr = grid_ptr = out_ptr = grad_ptr = static_cast(0); + src_32b_ok = trgt_32b_ok = grid_32b_ok = out_32b_ok = grad_32b_ok = true; + } + + MONAI_HOST + void PushPullAllocator::init_source(const Tensor& source) { + N = source.size(0); + C = source.size(1); + src_X = source.size(2); + src_Y = dim < 2 ? 1L : source.size(3); + src_Z = dim < 3 ? 1L : source.size(4); + src_sN = source.stride(0); + src_sC = source.stride(1); + src_sX = source.stride(2); + src_sY = dim < 2 ? 0L : source.stride(3); + src_sZ = dim < 3 ? 0L : source.stride(4); + src_ptr = source.data_ptr(); + src_opt = source.options(); + src_32b_ok = tensorCanUse32BitIndexMath(source); + } + + MONAI_HOST + void PushPullAllocator::init_source(IntArrayRef source_size) { + src_X = source_size[0]; + src_Y = dim < 2 ? 1L : source_size[1]; + src_Z = dim < 3 ? 1L : source_size[2]; + } + + MONAI_HOST + void PushPullAllocator::init_grid(const Tensor& grid) { + N = grid.size(0); + trgt_X = grid.size(1); + trgt_Y = dim < 2 ? 1L : grid.size(2); + trgt_Z = dim < 3 ? 1L : grid.size(3); + grid_sN = grid.stride(0); + grid_sX = grid.stride(1); + grid_sY = dim < 2 ? 0L : grid.stride(2); + grid_sZ = dim < 3 ? 0L : grid.stride(3); + grid_sC = grid.stride(dim == 1 ? 2 : dim == 2 ? 3 : 4); + grid_ptr = grid.data_ptr(); + grid_opt = grid.options(); + grid_32b_ok = tensorCanUse32BitIndexMath(grid); + } + + MONAI_HOST + void PushPullAllocator::init_target(const Tensor& target) { + N = target.size(0); + C = target.size(1); + trgt_X = target.size(2); + trgt_Y = dim < 2 ? 1L : target.size(3); + trgt_Z = dim < 3 ? 1L : target.size(4); + trgt_K = target.dim() == dim + 3 ? target.size(dim == 1 ? 3 : dim == 2 ? 4 : 5) : 0L; + trgt_sN = target.stride(0); + trgt_sC = target.stride(1); + trgt_sX = target.stride(2); + trgt_sY = dim < 2 ? 0L : target.stride(3); + trgt_sZ = dim < 3 ? 0L : target.stride(4); + trgt_sK = target.dim() == dim + 3 ? target.stride(dim == 1 ? 3 : dim == 2 ? 4 : 5) : 0L; + trgt_ptr = target.data_ptr(); + trgt_opt = target.options(); + trgt_32b_ok = tensorCanUse32BitIndexMath(target); + } + + MONAI_HOST + void PushPullAllocator::init_output() { + output.clear(); + if (do_pull) { + if (dim == 1) + output.push_back(at::empty({N, C, trgt_X}, src_opt)); + else if (dim == 2) + output.push_back(at::empty({N, C, trgt_X, trgt_Y}, src_opt)); + else + output.push_back(at::empty({N, C, trgt_X, trgt_Y, trgt_Z}, src_opt)); + auto pull = output.back(); + out_sN = pull.stride(0); + out_sC = pull.stride(1); + out_sX = pull.stride(2); + out_sY = dim < 2 ? 0L : pull.stride(3); + out_sZ = dim < 3 ? 0L : pull.stride(4); + out_sK = 0L; + out_ptr = pull.data_ptr(); + out_32b_ok = tensorCanUse32BitIndexMath(pull); + } else if (do_sgrad) { + if (dim == 1) + output.push_back(at::empty({N, C, trgt_X, 1}, src_opt)); + else if (dim == 2) + output.push_back(at::empty({N, C, trgt_X, trgt_Y, 2}, src_opt)); + else + output.push_back(at::empty({N, C, trgt_X, trgt_Y, trgt_Z, 3}, src_opt)); + auto sgrad = output.back(); + out_sN = sgrad.stride(0); + out_sC = sgrad.stride(1); + out_sX = sgrad.stride(2); + out_sY = dim < 2 ? 0L : sgrad.stride(3); + out_sZ = dim < 3 ? 0L : sgrad.stride(4); + out_sK = sgrad.stride(dim == 1 ? 3 : dim == 2 ? 4 : 5); + out_ptr = sgrad.data_ptr(); + out_32b_ok = tensorCanUse32BitIndexMath(sgrad); + + if (iso && interpolation0 == InterpolationType::Nearest) + sgrad.zero_(); + if (iso && interpolation0 == InterpolationType::Linear && dim == 1) + sgrad.zero_(); + } else if (do_push) { + if (dim == 1) + output.push_back(at::zeros({N, C, src_X}, trgt_opt)); + else if (dim == 2) + output.push_back(at::zeros({N, C, src_X, src_Y}, trgt_opt)); + else + output.push_back(at::zeros({N, C, src_X, src_Y, src_Z}, trgt_opt)); + auto push = output.back(); + out_sN = push.stride(0); + out_sC = push.stride(1); + out_sX = push.stride(2); + out_sY = dim < 2 ? 0L : push.stride(3); + out_sZ = dim < 3 ? 0L : push.stride(4); + out_sK = 0L; + out_ptr = push.data_ptr(); + out_32b_ok = tensorCanUse32BitIndexMath(push); + } else if (do_count) { + if (dim == 1) + output.push_back(at::zeros({N, 1, src_X}, grid_opt)); + else if (dim == 2) + output.push_back(at::zeros({N, 1, src_X, src_Y}, grid_opt)); + else + output.push_back(at::zeros({N, 1, src_X, src_Y, src_Z}, grid_opt)); + auto count = output.back(); + out_sN = count.stride(0); + out_sC = count.stride(1); + out_sX = count.stride(2); + out_sY = dim < 2 ? 0L : count.stride(3); + out_sZ = dim < 3 ? 0L : count.stride(4); + out_sK = 0L; + out_ptr = count.data_ptr(); + out_32b_ok = tensorCanUse32BitIndexMath(count); + } + if (do_grad) { + if (dim == 1) + output.push_back(at::zeros({N, trgt_X, 1}, grid_opt)); + else if (dim == 2) + output.push_back(at::zeros({N, trgt_X, trgt_Y, 2}, grid_opt)); + else + output.push_back(at::zeros({N, trgt_X, trgt_Y, trgt_Z, 3}, grid_opt)); + auto grad = output.back(); + grad_sN = grad.stride(0); + grad_sX = grad.stride(1); + grad_sY = dim < 2 ? 0L : grad.stride(2); + grad_sZ = dim < 3 ? 0L : grad.stride(3); + grad_sC = grad.stride(dim == 1 ? 2 : dim == 2 ? 3 : 4); + grad_ptr = grad.data_ptr(); + out_32b_ok = tensorCanUse32BitIndexMath(grad); + + if (iso && interpolation0 == InterpolationType::Nearest) + grad.zero_(); } + } + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // GENERIC PUSHPULL CLASS + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // This class implements the bulk of the code. + // /!\ No type and shape checking is performed here. + + template + class PushPullImpl { + public: + // ~~~ CONSTRUCTOR ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + PushPullImpl(const PushPullAllocator& info) + : output(info.output), + dim(info.dim), + bound0(info.bound0), + bound1(info.bound1), + bound2(info.bound2), + interpolation0(info.interpolation0), + interpolation1(info.interpolation1), + interpolation2(info.interpolation1), + iso(info.iso), + extrapolate(info.extrapolate), + do_pull(info.do_pull), + do_push(info.do_push), + do_count(info.do_count), + do_grad(info.do_grad), + do_sgrad(info.do_sgrad), + N(static_cast(info.N)), + C(static_cast(info.C)), + src_X(static_cast(info.src_X)), + src_Y(static_cast(info.src_Y)), + src_Z(static_cast(info.src_Z)), + trgt_X(static_cast(info.trgt_X)), + trgt_Y(static_cast(info.trgt_Y)), + trgt_Z(static_cast(info.trgt_Z)), + trgt_K(static_cast(info.trgt_K)), + src_sN(static_cast(info.src_sN)), + src_sC(static_cast(info.src_sC)), + src_sX(static_cast(info.src_sX)), + src_sY(static_cast(info.src_sY)), + src_sZ(static_cast(info.src_sZ)), + src_ptr(static_cast(info.src_ptr)), + trgt_sN(static_cast(info.trgt_sN)), + trgt_sC(static_cast(info.trgt_sC)), + trgt_sX(static_cast(info.trgt_sX)), + trgt_sY(static_cast(info.trgt_sY)), + trgt_sZ(static_cast(info.trgt_sZ)), + trgt_sK(static_cast(info.trgt_sK)), + trgt_ptr(static_cast(info.trgt_ptr)), + grid_sN(static_cast(info.grid_sN)), + grid_sC(static_cast(info.grid_sC)), + grid_sX(static_cast(info.grid_sX)), + grid_sY(static_cast(info.grid_sY)), + grid_sZ(static_cast(info.grid_sZ)), + grid_ptr(static_cast(info.grid_ptr)), + out_sN(static_cast(info.out_sN)), + out_sC(static_cast(info.out_sC)), + out_sX(static_cast(info.out_sX)), + out_sY(static_cast(info.out_sY)), + out_sZ(static_cast(info.out_sZ)), + out_sK(static_cast(info.out_sK)), + out_ptr(static_cast(info.out_ptr)), + grad_sN(static_cast(info.grad_sN)), + grad_sC(static_cast(info.grad_sC)), + grad_sX(static_cast(info.grad_sX)), + grad_sY(static_cast(info.grad_sY)), + grad_sZ(static_cast(info.grad_sZ)), + grad_ptr(static_cast(info.grad_ptr)) {} // ~~~ PUBLIC VALUE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -245,39 +572,9 @@ MONAI_NAMESPACE_DEVICE { // cuda // } // ~~~ FUNCTORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - MONAI_HOST void ioset // Pull - (const Tensor& source, const Tensor& grid) { - init_all(); - init_source(source); - init_grid(grid); - init_output(); - } - - MONAI_HOST void ioset(const Tensor& source, const Tensor& grid, const Tensor& target) { - init_all(); - init_source(source); - init_grid(grid); - init_target(target); - init_output(); - } - - MONAI_HOST void ioset // Push - (IntArrayRef source_size, const Tensor& grid, const Tensor& target) { - init_all(); - init_source(source_size); - init_grid(grid); - init_target(target); - init_output(); - } - - MONAI_HOST void ioset // Count - (IntArrayRef source_size, const Tensor& grid) { - init_all(); - init_source(source_size); - init_grid(grid); - init_output(); - } + // Loop over voxels that belong to one CUDA block + // This function is called by the CUDA kernel MONAI_DEVICE void loop(int threadIdx, int blockIdx, int blockDim, int gridDim) const; MONAI_HOST MONAI_DEVICE int64_t voxcount() const { @@ -286,14 +583,18 @@ MONAI_NAMESPACE_DEVICE { // cuda private: // ~~~ COMPONENTS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - MONAI_HOST void init_all(); - MONAI_HOST void init_source(const Tensor& source); - MONAI_HOST void init_source(IntArrayRef source_size); - MONAI_HOST void init_grid(const Tensor& grid); - MONAI_HOST void init_target(const Tensor& target); - MONAI_HOST void init_output(); + MONAI_DEVICE void check1d(offset_t w, offset_t n) const; MONAI_DEVICE void check2d(offset_t w, offset_t h, offset_t n) const; MONAI_DEVICE void check3d(offset_t w, offset_t h, offset_t d, offset_t n) const; + MONAI_DEVICE void interpolate1d(scalar_t x, offset_t w, offset_t n) const; + MONAI_DEVICE void interpolate1d_nearest(scalar_t x, offset_t w, offset_t n) const; + MONAI_DEVICE void interpolate1d_linear(scalar_t x, offset_t w, offset_t n) const; + MONAI_DEVICE void interpolate1d_sliding(scalar_t x, offset_t w, offset_t n) const { /*TODO*/ + } + MONAI_DEVICE void interpolate1d_sliding_nearest(scalar_t x, offset_t w, offset_t n) const { /*TODO*/ + } + MONAI_DEVICE void interpolate1d_sliding_linear(scalar_t x, offset_t w, offset_t n) const { /*TODO*/ + } MONAI_DEVICE void interpolate2d(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n) const; MONAI_DEVICE void interpolate2d_nearest(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n) const; MONAI_DEVICE void interpolate2d_bilinear(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n) const; @@ -368,9 +669,6 @@ MONAI_NAMESPACE_DEVICE { // cuda bool do_sgrad; // sample spatial gradients // ~~~ NAVIGATORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - TensorOptions src_opt; - TensorOptions grid_opt; - TensorOptions trgt_opt; offset_t N; offset_t C; offset_t src_X; @@ -397,173 +695,22 @@ MONAI_NAMESPACE_DEVICE { // cuda offset_t grid_sC; offset_t grid_sX; offset_t grid_sY; - offset_t grid_sZ; - scalar_t* grid_ptr; - offset_t out_sN; - offset_t out_sC; - offset_t out_sX; - offset_t out_sY; - offset_t out_sZ; - offset_t out_sK; // gradient dimension - scalar_t* out_ptr; - offset_t grad_sN; - offset_t grad_sC; - offset_t grad_sX; - offset_t grad_sY; - offset_t grad_sZ; - scalar_t* grad_ptr; - }; - - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // INITIALISATION - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - template - void PushPullImpl::init_all() { - src_opt = grid_opt = trgt_opt = TensorOptions(); - N = C = static_cast(1); - src_X = src_Y = src_Z = static_cast(1); - trgt_X = trgt_Y = trgt_Z = trgt_K = static_cast(1); - src_sN = src_sC = src_sX = src_sY = src_sZ = static_cast(0); - grid_sN = grid_sC = grid_sX = grid_sY = grid_sZ = static_cast(0); - grad_sN = grad_sC = grad_sX = grad_sY = grad_sZ = static_cast(0); - trgt_sN = trgt_sC = trgt_sX = trgt_sY = trgt_sZ = trgt_sK = static_cast(0); - out_sN = out_sC = out_sX = out_sY = out_sZ = out_sK = static_cast(0); - src_ptr = trgt_ptr = grid_ptr = out_ptr = grad_ptr = static_cast(0); - } - - template - MONAI_HOST void PushPullImpl::init_source(const Tensor& source) { - N = source.size(0); - C = source.size(1); - src_X = source.size(2); - src_Y = source.size(3); - src_Z = dim == 2 ? static_cast(1) : source.size(4); - src_sN = source.stride(0); - src_sC = source.stride(1); - src_sX = source.stride(2); - src_sY = source.stride(3); - src_sZ = dim == 2 ? static_cast(0) : source.stride(4); - src_ptr = source.data_ptr(); - src_opt = source.options(); - } - - template - MONAI_HOST void PushPullImpl::init_source(IntArrayRef source_size) { - src_X = source_size[0]; - src_Y = source_size[1]; - src_Z = dim == 2 ? static_cast(1) : source_size[2]; - } - - template - MONAI_HOST void PushPullImpl::init_grid(const Tensor& grid) { - N = grid.size(0); - trgt_X = grid.size(1); - trgt_Y = grid.size(2); - trgt_Z = dim == 2 ? static_cast(1) : grid.size(3); - grid_sN = grid.stride(0); - grid_sX = grid.stride(1); - grid_sY = grid.stride(2); - grid_sZ = dim == 2 ? static_cast(0) : grid.stride(3); - grid_sC = grid.stride(dim == 2 ? 3 : 4); - grid_ptr = grid.data_ptr(); - grid_opt = grid.options(); - } - - template - MONAI_HOST void PushPullImpl::init_target(const Tensor& target) { - N = target.size(0); - C = target.size(1); - trgt_X = target.size(2); - trgt_Y = target.size(3); - trgt_Z = dim == 2 ? static_cast(1) : target.size(4); - trgt_K = target.dim() == dim + 3 ? target.size(dim == 2 ? 4 : 5) : static_cast(1); - trgt_sN = target.stride(0); - trgt_sC = target.stride(1); - trgt_sX = target.stride(2); - trgt_sY = target.stride(3); - trgt_sZ = dim == 2 ? static_cast(0) : target.stride(4); - trgt_sK = target.dim() == dim + 3 ? target.stride(dim == 2 ? 4 : 5) : static_cast(0); - trgt_ptr = target.data_ptr(); - trgt_opt = target.options(); - } - - template - MONAI_HOST void PushPullImpl::init_output() { - output.clear(); - if (do_pull) { - if (dim == 2) - output.push_back(at::empty({N, C, trgt_X, trgt_Y}, src_opt)); - else - output.push_back(at::empty({N, C, trgt_X, trgt_Y, trgt_Z}, src_opt)); - auto pull = output.back(); - out_sN = pull.stride(0); - out_sC = pull.stride(1); - out_sX = pull.stride(2); - out_sY = pull.stride(3); - out_sZ = dim == 2 ? static_cast(0) : pull.stride(4); - out_sK = static_cast(0); - out_ptr = pull.template data_ptr(); - } else if (do_sgrad) { - if (dim == 2) - output.push_back(at::empty({N, C, trgt_X, trgt_Y, 2}, src_opt)); - else - output.push_back(at::empty({N, C, trgt_X, trgt_Y, trgt_Z, 3}, src_opt)); - auto sgrad = output.back(); - out_sN = sgrad.stride(0); - out_sC = sgrad.stride(1); - out_sX = sgrad.stride(2); - out_sY = sgrad.stride(3); - out_sZ = dim == 2 ? static_cast(0) : sgrad.stride(4); - out_sK = sgrad.stride(dim == 2 ? 4 : 5); - out_ptr = sgrad.template data_ptr(); - - if (iso && interpolation0 == InterpolationType::Nearest) - sgrad.zero_(); - } else if (do_push) { - if (dim == 2) - output.push_back(at::zeros({N, C, src_X, src_Y}, trgt_opt)); - else - output.push_back(at::zeros({N, C, src_X, src_Y, src_Z}, trgt_opt)); - auto push = output.back(); - out_sN = push.stride(0); - out_sC = push.stride(1); - out_sX = push.stride(2); - out_sY = push.stride(3); - out_sZ = dim == 2 ? static_cast(0) : push.stride(4); - out_sK = static_cast(0); - out_ptr = push.template data_ptr(); - } else if (do_count) { - if (dim == 2) - output.push_back(at::zeros({N, 1, src_X, src_Y}, grid_opt)); - else - output.push_back(at::zeros({N, 1, src_X, src_Y, src_Z}, grid_opt)); - auto count = output.back(); - out_sN = count.stride(0); - out_sC = count.stride(1); - out_sX = count.stride(2); - out_sY = count.stride(3); - out_sZ = dim == 2 ? static_cast(0) : count.stride(4); - out_sK = static_cast(0); - out_ptr = count.template data_ptr(); - } - if (do_grad) { - if (dim == 2) - output.push_back(at::zeros({N, src_X, src_Y, 2}, grid_opt)); - else - output.push_back(at::zeros({N, src_X, src_Y, src_Z, 3}, grid_opt)); - auto grad = output.back(); - grad_sN = grad.stride(0); - grad_sX = grad.stride(1); - grad_sY = grad.stride(2); - grad_sZ = dim == 2 ? static_cast(0) : grad.stride(3); - grad_sC = grad.stride(dim == 2 ? 3 : 4); - grad_ptr = grad.template data_ptr(); - - if (iso && interpolation0 == InterpolationType::Nearest) - grad.zero_(); - } - } + offset_t grid_sZ; + scalar_t* grid_ptr; + offset_t out_sN; + offset_t out_sC; + offset_t out_sX; + offset_t out_sY; + offset_t out_sZ; + offset_t out_sK; // gradient dimension + scalar_t* out_ptr; + offset_t grad_sN; + offset_t grad_sC; + offset_t grad_sX; + offset_t grad_sY; + offset_t grad_sZ; + scalar_t* grad_ptr; + }; // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // LOOP @@ -584,7 +731,9 @@ MONAI_NAMESPACE_DEVICE { // cuda h = (i / trgt_Z) % trgt_Y; d = i % trgt_Z; - if (dim == 2) + if (dim == 1) + check1d(w, n); + else if (dim == 2) check2d(w, h, n); else check3d(w, h, d, n); @@ -599,6 +748,59 @@ MONAI_NAMESPACE_DEVICE { // cuda // 1) read the [x,y,z] source coordinate for the current target voxel // 3) check if the source coordinate is in bounds + template + MONAI_DEVICE void PushPullImpl::check3d(offset_t w, offset_t h, offset_t d, offset_t n) const { + // get the corresponding input x, y, z co-ordinates from grid + scalar_t* grid_ptr_NXYZ = grid_ptr + n * grid_sN + w * grid_sX + h * grid_sY + d * grid_sZ; + scalar_t x = *grid_ptr_NXYZ; + scalar_t y = grid_ptr_NXYZ[grid_sC]; + scalar_t z = grid_ptr_NXYZ[grid_sC * 2]; + + // Check if out-of-bound + if (!(extrapolate || + (inbounds(x, src_X, static_cast(TINY)) && inbounds(y, src_Y, static_cast(TINY)) && + inbounds(z, src_Z, static_cast(TINY))))) { + if (do_pull || do_sgrad) { + scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ; + for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC) { + *out_ptr_NCXYZ = static_cast(0); + if (do_sgrad) { + out_ptr_NCXYZ[out_sK] = static_cast(0); + out_ptr_NCXYZ[out_sK * 2] = static_cast(0); + } + } + } + if (do_grad) { + scalar_t* grad_ptr_NXYZ = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY + d * grad_sZ; + (*grad_ptr_NXYZ) = static_cast(0); + grad_ptr_NXYZ[grad_sC] = static_cast(0); + grad_ptr_NXYZ[grad_sC * 2] = static_cast(0); + } + return; + } + + // Next step + if (bound0 == BoundType::Sliding) { + if (iso) + switch (static_cast(interpolation0)) { + case 0: + return interpolate3d_sliding_nearest(x, y, z, w, h, d, n); + case 1: + return interpolate3d_sliding_trilinear(x, y, z, w, h, d, n); + } + return interpolate3d_sliding(x, y, z, w, h, d, n); + } else { + if (iso) + switch (static_cast(interpolation0)) { + case 0: + return interpolate3d_nearest(x, y, z, w, h, d, n); + case 1: + return interpolate3d_trilinear(x, y, z, w, h, d, n); + } + return interpolate3d(x, y, z, w, h, d, n); + } + } + template MONAI_DEVICE void PushPullImpl::check2d(offset_t w, offset_t h, offset_t n) const { // get the corresponding input x, y, z co-ordinates from grid @@ -610,7 +812,7 @@ MONAI_NAMESPACE_DEVICE { // cuda if (!(extrapolate || (inbounds(x, src_X, static_cast(TINY)) && inbounds(y, src_Y, static_cast(TINY))))) { if (do_pull || do_sgrad) { - scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sZ + h * out_sY; + scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY; for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC) { *out_ptr_NCXY = static_cast(0); if (do_sgrad) @@ -648,32 +850,25 @@ MONAI_NAMESPACE_DEVICE { // cuda } template - MONAI_DEVICE void PushPullImpl::check3d(offset_t w, offset_t h, offset_t d, offset_t n) const { + MONAI_DEVICE void PushPullImpl::check1d(offset_t w, offset_t n) const { // get the corresponding input x, y, z co-ordinates from grid - scalar_t* grid_ptr_NXYZ = grid_ptr + n * grid_sN + w * grid_sX + h * grid_sY + d * grid_sZ; - scalar_t x = *grid_ptr_NXYZ; - scalar_t y = grid_ptr_NXYZ[grid_sC]; - scalar_t z = grid_ptr_NXYZ[grid_sC * 2]; + scalar_t* grid_ptr_NX = grid_ptr + n * grid_sN + w * grid_sX; + scalar_t x = *grid_ptr_NX; // Check if out-of-bound - if (!(extrapolate || - (inbounds(x, src_X, static_cast(TINY)) && inbounds(y, src_Y, static_cast(TINY)) && - inbounds(z, src_Z, static_cast(TINY))))) { + if (!(extrapolate || inbounds(x, src_X, static_cast(TINY)))) { if (do_pull || do_sgrad) { - scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ; - for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC) { - *out_ptr_NCXYZ = static_cast(0); - if (do_sgrad) { - out_ptr_NCXYZ[out_sK] = static_cast(0); - out_ptr_NCXYZ[out_sK * 2] = static_cast(0); - } + scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; + for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC) { + *out_ptr_NCX = static_cast(0); + if (do_sgrad) + out_ptr_NCX[out_sK] = static_cast(0); } } if (do_grad) { - scalar_t* grad_ptr_NXYZ = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY + d * grad_sZ; - (*grad_ptr_NXYZ) = static_cast(0); - grad_ptr_NXYZ[grad_sC] = static_cast(0); - grad_ptr_NXYZ[grad_sC * 2] = static_cast(0); + scalar_t* grad_ptr_NX = grad_ptr + n * grad_sN + w * grad_sX; + (*grad_ptr_NX) = static_cast(0); + grad_ptr_NX[grad_sC] = static_cast(0); } return; } @@ -683,20 +878,20 @@ MONAI_NAMESPACE_DEVICE { // cuda if (iso) switch (static_cast(interpolation0)) { case 0: - return interpolate3d_sliding_nearest(x, y, z, w, h, d, n); + return interpolate1d_sliding_nearest(x, w, n); case 1: - return interpolate3d_sliding_trilinear(x, y, z, w, h, d, n); + return interpolate1d_sliding_linear(x, w, n); } - return interpolate3d_sliding(x, y, z, w, h, d, n); + return interpolate1d_sliding(x, w, n); } else { if (iso) switch (static_cast(interpolation0)) { case 0: - return interpolate3d_nearest(x, y, z, w, h, d, n); + return interpolate1d_nearest(x, w, n); case 1: - return interpolate3d_trilinear(x, y, z, w, h, d, n); + return interpolate1d_linear(x, w, n); } - return interpolate3d(x, y, z, w, h, d, n); + return interpolate1d(x, w, n); } } @@ -731,7 +926,7 @@ MONAI_NAMESPACE_DEVICE { // cuda if (trgt_ptr && (do_push || do_grad)) for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC) { target[c] = *trgt_ptr_NCXYZ; - if (trgt_K > 1) { + if (trgt_K > 0) { target[c + C] = trgt_ptr_NCXYZ[trgt_sK]; target[c + C * 2] = trgt_ptr_NCXYZ[trgt_sK * 2]; } @@ -849,7 +1044,7 @@ MONAI_NAMESPACE_DEVICE { // cuda // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - if (trgt_K == 1) { + if (trgt_K == 0) { // Diff w.r.t. push/pull scalar_t* out_ptr_NC = out_ptr_NC0; for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC) @@ -872,7 +1067,7 @@ MONAI_NAMESPACE_DEVICE { // cuda // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { - if (trgt_K == 1) { + if (trgt_K == 0) { // Diff w.r.t. pull/push scalar_t* src_ptr_NC = src_ptr_NC0; scalar_t dot = static_cast(0); @@ -941,7 +1136,7 @@ MONAI_NAMESPACE_DEVICE { // cuda if (trgt_ptr && (do_push || do_grad)) for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC) { target[c] = *trgt_ptr_NCXY; - if (trgt_K > 1) { + if (trgt_K > 0) { target[c + C] = trgt_ptr_NCXY[trgt_sK]; } } @@ -1034,7 +1229,7 @@ MONAI_NAMESPACE_DEVICE { // cuda // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - if (trgt_K == 1) { + if (trgt_K == 0) { // Diff w.r.t. push/pull scalar_t* out_ptr_NC = out_ptr_NC0; for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC) @@ -1056,7 +1251,7 @@ MONAI_NAMESPACE_DEVICE { // cuda // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { - if (trgt_K == 1) { + if (trgt_K == 0) { // Diff w.r.t. pull/push scalar_t* src_ptr_NC = src_ptr_NC0; scalar_t dot = static_cast(0); @@ -1093,6 +1288,150 @@ MONAI_NAMESPACE_DEVICE { // cuda } } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // GENERIC INTERPOLATION 1D + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + template + MONAI_DEVICE void PushPullImpl::interpolate1d(scalar_t x, offset_t w, offset_t n) const { + // Get corner pixel values from (x, y) + offset_t bx0, bx1; + interpolation::bounds(interpolation0, x, bx0, bx1); + offset_t dbx = bx1 - bx0; + + // Pre-compute offsets and target value + scalar_t* src_ptr_NC0 = src_ptr + n * src_sN; + scalar_t* out_ptr_NC0 = out_ptr + n * out_sN; + scalar_t* out_ptr_NCX0 = out_ptr + n * out_sN + w * out_sX; + scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; + scalar_t target[2 * MONAI_MAX_NUM_CHANNELS]; + if (trgt_ptr && (do_push || do_grad)) + for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC) { + target[c] = *trgt_ptr_NCX; + if (trgt_K > 0) { + target[c + C] = trgt_ptr_NCX[trgt_sK]; + } + } + + // Initialize output + scalar_t* out_ptr_NCX = out_ptr_NCX0; + if (do_pull || do_sgrad) { + for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC) { + *out_ptr_NCX = static_cast(0); + if (do_sgrad) { + out_ptr_NCX[out_sK] = static_cast(0); + } + } + } + + // Pre-compute indices/weights/grad + scalar_t wx[8]; // B-spline weights + scalar_t gx[8]; // B-spline derivatives + scalar_t hx[8]; // B-spline 2nd derivatives + offset_t ix[8]; // Warped indices + uint8_t sx[8]; // Warped indices + + { + scalar_t *owx = static_cast(wx), *ogx = static_cast(gx), *ohx = static_cast(hx); + offset_t* oix = static_cast(ix); + uint8_t* osx = static_cast(sx); + for (offset_t bx = bx0; bx <= bx1; ++bx) { + scalar_t dx = x - bx; + *(owx++) = interpolation::fastweight(interpolation0, dx); + if (do_grad || do_sgrad) + *(ogx++) = interpolation::fastgrad(interpolation0, dx); + if (do_grad && trgt_sK > 1) + *(ohx++) = interpolation::fasthess(interpolation0, dx); + *(osx++) = bound::sign(bound0, bx, src_X); + *(oix++) = bound::index(bound0, bx, src_X); + } + } + + // Convolve coefficients with basis functions + scalar_t ogx; + ogx = static_cast(0); + for (offset_t i = 0; i <= dbx; ++i) { + offset_t oox = ix[i] * out_sX; + offset_t osx = ix[i] * src_sX; + uint8_t sxx = sx[i]; + scalar_t wxx = wx[i]; + scalar_t gxx = gx[i]; + scalar_t hxx = hx[i]; + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if (do_pull) { + scalar_t* src_ptr_NC = src_ptr_NC0; + scalar_t* out_ptr_NCX = out_ptr_NCX0; + for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) + *out_ptr_NCX += bound::get(src_ptr_NC, osx, sxx) * wxx; + } + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + else if (do_sgrad) { + scalar_t* src_ptr_NC = src_ptr_NC0; + scalar_t* out_ptr_NCX = out_ptr_NCX0; + for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) { + scalar_t src = bound::get(src_ptr_NC, osx, sxx); + *out_ptr_NCX += src * gxx; + } + } + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + else if (do_push) { + if (trgt_K == 0) { + // Diff w.r.t. push/pull + scalar_t* out_ptr_NC = out_ptr_NC0; + for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC) + bound::add(out_ptr_NC, oox, wxx * target[c], sxx); + } else { + // Diff w.r.t. sgrad + scalar_t* out_ptr_NC = out_ptr_NC0; + for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC) { + scalar_t val = gxx * target[c]; + bound::add(out_ptr_NC, oox, val, sxx); + } + } + } + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Count ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + else if (do_count) { + bound::add(out_ptr_NC0, oox, wxx, sxx); + } + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if (do_grad) { + if (trgt_K == 0) { + // Diff w.r.t. pull/push + scalar_t* src_ptr_NC = src_ptr_NC0; + scalar_t dot = static_cast(0); + for (offset_t c = 0; c < C; ++c, src_ptr_NC += src_sC) { + scalar_t src = bound::get(src_ptr_NC, osx, sxx); + dot += (trgt_ptr ? src * target[c] : src); + // trgt_ptr == 0 in the backward pass of 'count' + } + ogx += gxx * dot; + } else { + // Diff w.r.t. sgrad + scalar_t* src_ptr_NC = src_ptr_NC0; + scalar_t dot; + dot = static_cast(0); + for (offset_t c = 0; c < C; ++c, src_ptr_NC += src_sC) { + scalar_t src = bound::get(src_ptr_NC, osx, sxx); + dot += src * target[c]; + } + ogx += hxx * dot; + } + } + + } // x + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if (do_grad) { + scalar_t* grad_ptr_NX = grad_ptr + n * grad_sN + w * grad_sX; + (*grad_ptr_NX) = ogx; + } + } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // LINEAR INTERPOLATION 3D // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1182,7 +1521,7 @@ MONAI_NAMESPACE_DEVICE { // cuda scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ; scalar_t* src_ptr_NC = src_ptr + n * src_sN; - if (trgt_K == 1) { + if (trgt_K == 0) { // backward w.r.t. push/pull for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC, src_ptr_NC += src_sC) { scalar_t src; @@ -1344,7 +1683,7 @@ MONAI_NAMESPACE_DEVICE { // cuda o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ; scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ; scalar_t* out_ptr_NC = out_ptr + n * out_sN; - if (trgt_K == 1) { + if (trgt_K == 0) { // Diff w.r.t. push/pull for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC, out_ptr_NC += out_sC) { scalar_t trgt = *trgt_ptr_NCXYZ; @@ -1467,7 +1806,7 @@ MONAI_NAMESPACE_DEVICE { // cuda scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY; scalar_t* src_ptr_NC = src_ptr + n * src_sN; - if (trgt_K == 1) { + if (trgt_K == 0) { // backward w.r.t. push/pull for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC, src_ptr_NC += src_sC) { scalar_t src; @@ -1514,9 +1853,9 @@ MONAI_NAMESPACE_DEVICE { // cuda } } - scalar_t* grad_ptr_NXYZ = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY; - (*grad_ptr_NXYZ) = gx; - grad_ptr_NXYZ[grad_sC] = gy; + scalar_t* grad_ptr_NXY = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY; + (*grad_ptr_NXY) = gx; + grad_ptr_NXY[grad_sC] = gy; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { @@ -1558,7 +1897,7 @@ MONAI_NAMESPACE_DEVICE { // cuda o11 = ix1 * out_sX + iy1 * out_sY; scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY; scalar_t* out_ptr_NC = out_ptr + n * out_sN; - if (trgt_K == 1) { + if (trgt_K == 0) { // Diff w.r.t. push/pull for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC, out_ptr_NC += out_sC) { scalar_t trgt = *trgt_ptr_NCXY; @@ -1599,6 +1938,123 @@ MONAI_NAMESPACE_DEVICE { // cuda } } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // LINEAR INTERPOLATION 1D + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + template + MONAI_DEVICE void PushPullImpl::interpolate1d_linear(scalar_t x, offset_t w, offset_t n) const { + // Get corner pixel values from (x) + offset_t ix0 = static_cast(std::floor(x)); + + // Interpolation weights (inversely proportional to distance) + scalar_t w1 = x - ix0; + scalar_t w0 = 1. - w1; + + // Sign (/!\ compute sign before warping indices) + int8_t s1 = bound::sign(bound0, ix0 + 1, src_X); + int8_t s0 = bound::sign(bound0, ix0, src_X); + + // Warp indices + offset_t ix1; + ix1 = bound::index(bound0, ix0 + 1, src_X); + ix0 = bound::index(bound0, ix0, src_X); + + // Offsets into source volume + offset_t o0, o1; + if (do_pull || do_grad || do_sgrad) { + o0 = ix0 * src_sX; + o1 = ix1 * src_sX; + } + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ + if (do_grad) { + if (trgt_K == 0) { + // backward w.r.t. push/pull + + o0 = ix0 * src_sX; + o1 = ix1 * src_sX; + scalar_t gx = static_cast(0); + scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; + scalar_t* src_ptr_NC = src_ptr + n * src_sN; + + for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, src_ptr_NC += src_sC) { + scalar_t src; + scalar_t trgt = trgt_ptr ? *trgt_ptr_NCX : static_cast(1); + // ^ trgt_ptr == 0 during the backward pass of count + src = bound::get(src_ptr_NC, o0, s0); + if (trgt_ptr) + src *= trgt; + gx -= src; + src = bound::get(src_ptr_NC, o1, s1); + if (trgt_ptr) + src *= trgt; + gx += src; + } + + scalar_t* grad_ptr_NX = grad_ptr + n * grad_sN + w * grad_sX; + (*grad_ptr_NX) = gx; + } else { + // backward w.r.t. sgrad + // -> zero (make sure this is done at initialization) + } + } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if (do_pull) { + o0 = ix0 * src_sX; + o1 = ix1 * src_sX; + scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; + scalar_t* src_ptr_NC = src_ptr + n * src_sN; + for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) { + *out_ptr_NCX = bound::get(src_ptr_NC, o0, s0) * w0 + bound::get(src_ptr_NC, o1, s1) * w1; + } + } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + else if (do_sgrad) { + o0 = ix0 * src_sX; + o1 = ix1 * src_sX; + scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; + scalar_t* src_ptr_NC = src_ptr + n * src_sN; + + for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) { + *out_ptr_NCX = bound::get(src_ptr_NC, o1, s1) - bound::get(src_ptr_NC, o0, s0); + } + } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + else if (do_push) { + // Offsets into 'push' volume + o0 = ix0 * out_sX; + o1 = ix1 * out_sX; + scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; + scalar_t* out_ptr_NC = out_ptr + n * out_sN; + if (trgt_K == 0) { + // Diff w.r.t. push/pull + for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, out_ptr_NC += out_sC) { + scalar_t trgt = *trgt_ptr_NCX; + bound::add(out_ptr_NC, o0, w0 * trgt, s0); + bound::add(out_ptr_NC, o1, w1 * trgt, s1); + } + } else { + // Diff w.r.t. sgrad + for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, out_ptr_NC += out_sC) { + scalar_t trgt0 = *trgt_ptr_NCX; + bound::add(out_ptr_NC, o0, -trgt0, s0); + bound::add(out_ptr_NC, o1, trgt0, s1); + } + } + } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + else if (do_count) { + // Offsets into 'push' volume + o0 = ix0 * out_sX; + o1 = ix1 * out_sX; + + scalar_t* out_ptr_N = out_ptr + n * out_sN; + bound::add(out_ptr_N, o0, w0, s0); + bound::add(out_ptr_N, o1, w1, s1); + } + } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // NEAREST NEIGHBOR INTERPOLATION 3D // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1633,7 +2089,7 @@ MONAI_NAMESPACE_DEVICE { // cuda scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC) *out_ptr_NCXYZ = bound::get(src_ptr_NC, o, s); - } else if (do_push && trgt_K == 1) { + } else if (do_push && trgt_K == 0) { offset_t o = iz * out_sZ + iy * out_sY + ix * out_sX; scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ; scalar_t* out_ptr_NC = out_ptr + n * out_sN; @@ -1676,7 +2132,7 @@ MONAI_NAMESPACE_DEVICE { // cuda scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC) *out_ptr_NCXY = bound::get(src_ptr_NC, o, s); - } else if (do_push && trgt_K == 1) { + } else if (do_push && trgt_K == 0) { offset_t o = iy * out_sY + ix * out_sX; scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY; scalar_t* out_ptr_NC = out_ptr + n * out_sN; @@ -1689,6 +2145,39 @@ MONAI_NAMESPACE_DEVICE { // cuda bound::add(out_ptr_NC, o, static_cast(1), s); } } + + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // NEAREST NEIGHBOR INTERPOLATION 1D + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + template + MONAI_DEVICE void PushPullImpl::interpolate1d_nearest(scalar_t x, offset_t w, offset_t n) const { + offset_t i = static_cast(std::round(x)); + + // Boundary condition (/!\ compute sign before warping indices) + int8_t s = bound::sign(bound0, i, src_X); + i = bound::index(bound0, i, src_X); + + if (do_pull) { + offset_t o = i * src_sX; + scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; + scalar_t* src_ptr_NC = src_ptr + n * src_sN; + for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) + *out_ptr_NCX = bound::get(src_ptr_NC, o, s); + } else if (do_push && trgt_K == 0) { + offset_t o = i * out_sX; + scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; + scalar_t* out_ptr_NC = out_ptr + n * out_sN; + for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, out_ptr_NC += out_sC) + bound::add(out_ptr_NC, o, *trgt_ptr_NCX, s); + } else if (do_count) { + offset_t o = i * out_sX; + scalar_t* out_ptr_NC = out_ptr + n * out_sN; + for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC) + bound::add(out_ptr_NC, o, static_cast(1), s); + } + } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // LINEAR INTERPOLATION 3D + SLIDING BOUNDARY // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1736,8 +2225,6 @@ MONAI_NAMESPACE_DEVICE { // cuda PUSHPULL_INSTANTIATE1(BoundType); \ PUSHPULL_INSTANTIATE1(BoundVectorRef) - // ~~~ CUDA ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // Two arguments (source, grid) // > `bound` and `interpolation` can be single arguments or vectors. template @@ -1752,12 +2239,20 @@ MONAI_NAMESPACE_DEVICE { // cuda bool do_count, bool do_grad, bool do_sgrad) { + PushPullAllocator info( + grid.dim() - 2, bound, interpolation, extrapolate, do_pull, do_push, do_count, do_grad, do_sgrad); + info.ioset(source, grid); + return AT_DISPATCH_FLOATING_TYPES_AND_HALF(grid.scalar_type(), "pushpull", [&] { - PushPullImpl f( - grid.dim() - 2, bound, interpolation, extrapolate, do_pull, do_push, do_count, do_grad, do_sgrad); - f.ioset(source, grid); - pushpull_kernel<<>>(f); - return f.output; + if (info.canUse32BitIndexMath()) { + PushPullImpl algo(info); + pushpull_kernel<<>>(algo); + return algo.output; + } else { + PushPullImpl algo(info); + pushpull_kernel<<>>(algo); + return algo.output; + } }); } @@ -1777,17 +2272,24 @@ MONAI_NAMESPACE_DEVICE { // cuda bool do_count, bool do_grad, bool do_sgrad) { + PushPullAllocator info( + grid.dim() - 2, bound, interpolation, extrapolate, do_pull, do_push, do_count, do_grad, do_sgrad); + info.ioset(source, grid, target); + return AT_DISPATCH_FLOATING_TYPES_AND_HALF(grid.scalar_type(), "pushpull", [&] { - PushPullImpl f( - grid.dim() - 2, bound, interpolation, extrapolate, do_pull, do_push, do_count, do_grad, do_sgrad); - f.ioset(source, grid, target); - pushpull_kernel<<>>(f); - return f.output; + if (info.canUse32BitIndexMath()) { + PushPullImpl algo(info); + pushpull_kernel<<>>(algo); + return algo.output; + } else { + PushPullImpl algo(info); + pushpull_kernel<<>>(algo); + return algo.output; + } }); } PUSHPULL_INSTANTIATE; -} // namespace - +} // namespace gpu } // namespace monai diff --git a/runtests.sh b/runtests.sh index 76692e731b..1395ccdcfd 100755 --- a/runtests.sh +++ b/runtests.sh @@ -159,6 +159,7 @@ function clean_py { find ${TO_CLEAN} -depth -maxdepth 1 -type d -name ".mypy_cache" -exec rm -r "{}" + find ${TO_CLEAN} -depth -maxdepth 1 -type d -name ".pytype" -exec rm -r "{}" + find ${TO_CLEAN} -depth -maxdepth 1 -type d -name ".coverage" -exec rm -r "{}" + + find ${TO_CLEAN} -depth -maxdepth 1 -type d -name "__pycache__" -exec rm -r "{}" + } function torch_validate { From e974e46257ca3f66eb6b72b0d536ebecef6e665e Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 3 Feb 2021 23:29:20 +0000 Subject: [PATCH 08/14] tests enums Signed-off-by: Wenqi Li --- monai/networks/layers/spatial_transforms.py | 137 +++++++++++--------- tests/test_enum_bound_interp.py | 73 +++++++++++ 2 files changed, 146 insertions(+), 64 deletions(-) create mode 100644 tests/test_enum_bound_interp.py diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index 175fd05694..ac1112a68b 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -60,7 +60,9 @@ def grid_pull(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", b - 2 or 'quadratic' or InterpolationType.quadratic - 3 or 'cubic' or InterpolationType.cubic - 4 or 'fourth' or InterpolationType.fourth - - etc. + - 5 or 'fifth' or InterpolationType.fifth + - 6 or 'sixth' or InterpolationType.sixth + - 7 or 'seventh' or InterpolationType.seventh A list of values can be provided, in the order [W, H, D], to specify dimension-specific interpolation orders. @@ -68,14 +70,13 @@ def grid_pull(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", b `bound` can be an int, a string or a BoundType. Possible values are:: - - 0 or 'replicate' or BoundType.replicate - - 1 or 'dct1' or BoundType.dct1 - - 2 or 'dct2' or BoundType.dct2 - - 3 or 'dst1' or BoundType.dst1 - - 4 or 'dst2' or BoundType.dst2 - - 5 or 'dft' or BoundType.dft - - 6 or 'sliding' or BoundType.sliding [not implemented] - - 7 or 'zero' or BoundType.zero + - 0 or 'replicate' or 'nearest' or BoundType.replicate + - 1 or 'dct1' or 'mirror' or BoundType.dct1 + - 2 or 'dct2' or 'reflect' or BoundType.dct2 + - 3 or 'dst1' or 'antimirror' or BoundType.dst1 + - 4 or 'dst2' or 'antireflect' or BoundType.dst2 + - 5 or 'dft' or 'wrap' or BoundType.dft + - 7 or 'zero' or BoundType.zero A list of values can be provided, in the order [W, H, D], to specify dimension-specific boundary conditions. @@ -87,15 +88,17 @@ def grid_pull(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", b - `dct2` corresponds to Neumann boundary conditions (symmetric) - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric) - See: - https://en.wikipedia.org/wiki/Discrete_cosine_transform - https://en.wikipedia.org/wiki/Discrete_sine_transform + See Also: + - https://en.wikipedia.org/wiki/Discrete_cosine_transform + - https://en.wikipedia.org/wiki/Discrete_sine_transform + - ``help(monai._C.BoundType)`` + - ``help(monai._C.InterpolationType)`` Args: input: Input image. `(B, C, Wi, Hi, Di)`. - grid: Deformation field. `(B, Wo, Ho, Do, 2|3)`. + grid: Deformation field. `(B, Wo, Ho, Do, 1|2|3)`. interpolation (int or list[int] , optional): Interpolation order. - Defaults to `1`. + Defaults to `'linear'`. bound (BoundType, or list[BoundType], optional): Boundary conditions. Defaults to `'zero'`. extrapolate: Extrapolate out-of-bound data. @@ -106,11 +109,10 @@ def grid_pull(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", b """ # Convert parameters - bound = ensure_tuple(bound) - interpolation = ensure_tuple(interpolation) - bound = [_C.BoundType.__members__[b] if isinstance(b, str) else _C.BoundType(b) for b in bound] + bound = [_C.BoundType.__members__[b] if isinstance(b, str) else _C.BoundType(b) for b in ensure_tuple(bound)] interpolation = [ - _C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i) for i in interpolation + _C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i) + for i in ensure_tuple(interpolation) ] return _GridPull.apply(input, grid, interpolation, bound, extrapolate) @@ -156,7 +158,9 @@ def grid_push( - 2 or 'quadratic' or InterpolationType.quadratic - 3 or 'cubic' or InterpolationType.cubic - 4 or 'fourth' or InterpolationType.fourth - - etc. + - 5 or 'fifth' or InterpolationType.fifth + - 6 or 'sixth' or InterpolationType.sixth + - 7 or 'seventh' or InterpolationType.seventh A list of values can be provided, in the order `[W, H, D]`, to specify dimension-specific interpolation orders. @@ -164,14 +168,13 @@ def grid_push( `bound` can be an int, a string or a BoundType. Possible values are:: - - 0 or 'replicate' or BoundType.replicate - - 1 or 'dct1' or BoundType.dct1 - - 2 or 'dct2' or BoundType.dct2 - - 3 or 'dst1' or BoundType.dst1 - - 4 or 'dst2' or BoundType.dst2 - - 5 or 'dft' or BoundType.dft - - 6 or 'sliding' or BoundType.sliding [not implemented] - - 7 or 'zero' or BoundType.zero + - 0 or 'replicate' or 'nearest' or BoundType.replicate + - 1 or 'dct1' or 'mirror' or BoundType.dct1 + - 2 or 'dct2' or 'reflect' or BoundType.dct2 + - 3 or 'dst1' or 'antimirror' or BoundType.dst1 + - 4 or 'dst2' or 'antireflect' or BoundType.dst2 + - 5 or 'dft' or 'wrap' or BoundType.dft + - 7 or 'zero' or BoundType.zero A list of values can be provided, in the order `[W, H, D]`, to specify dimension-specific boundary conditions. @@ -183,17 +186,19 @@ def grid_push( - `dct2` corresponds to Neumann boundary conditions (symmetric) - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric) - See also: + See Also: - https://en.wikipedia.org/wiki/Discrete_cosine_transform - https://en.wikipedia.org/wiki/Discrete_sine_transform + - ``help(monai._C.BoundType)`` + - ``help(monai._C.InterpolationType)`` Args: input: Input image `(B, C, Wi, Hi, Di)`. - grid: Deformation field `(B, Wi, Hi, Di, 2|3)`. + grid: Deformation field `(B, Wi, Hi, Di, 1|2|3)`. shape: Shape of the source image. interpolation (int or list[int] , optional): Interpolation order. - Defaults to `1`. + Defaults to `'linear'`. bound (BoundType, or list[BoundType], optional): Boundary conditions. Defaults to `'zero'`. extrapolate: Extrapolate out-of-bound data. @@ -204,11 +209,10 @@ def grid_push( """ # Convert parameters - bound = ensure_tuple(bound) - interpolation = ensure_tuple(interpolation) - bound = [_C.BoundType.__members__[b] if isinstance(b, str) else _C.BoundType(b) for b in bound] + bound = [_C.BoundType.__members__[b] if isinstance(b, str) else _C.BoundType(b) for b in ensure_tuple(bound)] interpolation = [ - _C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i) for i in interpolation + _C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i) + for i in ensure_tuple(interpolation) ] if shape is None: @@ -252,7 +256,9 @@ def grid_count(grid: torch.Tensor, shape=None, interpolation="linear", bound="ze - 2 or 'quadratic' or InterpolationType.quadratic - 3 or 'cubic' or InterpolationType.cubic - 4 or 'fourth' or InterpolationType.fourth - - etc. + - 5 or 'fifth' or InterpolationType.fifth + - 6 or 'sixth' or InterpolationType.sixth + - 7 or 'seventh' or InterpolationType.seventh A list of values can be provided, in the order [W, H, D], to specify dimension-specific interpolation orders. @@ -260,14 +266,13 @@ def grid_count(grid: torch.Tensor, shape=None, interpolation="linear", bound="ze `bound` can be an int, a string or a BoundType. Possible values are:: - - 0 or 'replicate' or BoundType.replicate - - 1 or 'dct1' or BoundType.dct1 - - 2 or 'dct2' or BoundType.dct2 - - 3 or 'dst1' or BoundType.dst1 - - 4 or 'dst2' or BoundType.dst2 - - 5 or 'dft' or BoundType.dft - - 6 or 'sliding' or BoundType.sliding [not implemented] - - 7 or 'zero' or BoundType.zero + - 0 or 'replicate' or 'nearest' or BoundType.replicate + - 1 or 'dct1' or 'mirror' or BoundType.dct1 + - 2 or 'dct2' or 'reflect' or BoundType.dct2 + - 3 or 'dst1' or 'antimirror' or BoundType.dst1 + - 4 or 'dst2' or 'antireflect' or BoundType.dst2 + - 5 or 'dft' or 'wrap' or BoundType.dft + - 7 or 'zero' or BoundType.zero A list of values can be provided, in the order [W, H, D], to specify dimension-specific boundary conditions. @@ -283,12 +288,14 @@ def grid_count(grid: torch.Tensor, shape=None, interpolation="linear", bound="ze - https://en.wikipedia.org/wiki/Discrete_cosine_transform - https://en.wikipedia.org/wiki/Discrete_sine_transform + - ``help(monai._C.BoundType)`` + - ``help(monai._C.InterpolationType)`` Args: grid: Deformation field `(B, Wi, Hi, Di, 2|3)`. shape: shape of the source image. interpolation (int or list[int] , optional): Interpolation order. - Defaults to `1`. + Defaults to `'linear'`. bound (BoundType, or list[BoundType], optional): Boundary conditions. Defaults to `'zero'`. extrapolate (bool, optional): Extrapolate out-of-bound data. @@ -299,11 +306,10 @@ def grid_count(grid: torch.Tensor, shape=None, interpolation="linear", bound="ze """ # Convert parameters - bound = ensure_tuple(bound) - interpolation = ensure_tuple(interpolation) - bound = [_C.BoundType.__members__[b] if isinstance(b, str) else _C.BoundType(b) for b in bound] + bound = [_C.BoundType.__members__[b] if isinstance(b, str) else _C.BoundType(b) for b in ensure_tuple(bound)] interpolation = [ - _C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i) for i in interpolation + _C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i) + for i in ensure_tuple(interpolation) ] if shape is None: @@ -351,7 +357,9 @@ def grid_grad(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", b - 2 or 'quadratic' or InterpolationType.quadratic - 3 or 'cubic' or InterpolationType.cubic - 4 or 'fourth' or InterpolationType.fourth - - etc. + - 5 or 'fifth' or InterpolationType.fifth + - 6 or 'sixth' or InterpolationType.sixth + - 7 or 'seventh' or InterpolationType.seventh A list of values can be provided, in the order [W, H, D], to specify dimension-specific interpolation orders. @@ -359,14 +367,13 @@ def grid_grad(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", b `bound` can be an int, a string or a BoundType. Possible values are:: - - 0 or 'replicate' or BoundType.replicate - - 1 or 'dct1' or BoundType.dct1 - - 2 or 'dct2' or BoundType.dct2 - - 3 or 'dst1' or BoundType.dst1 - - 4 or 'dst2' or BoundType.dst2 - - 5 or 'dft' or BoundType.dft - - 6 or 'sliding' or BoundType.sliding [not implemented] - - 7 or 'zero' or BoundType.zero + - 0 or 'replicate' or 'nearest' or BoundType.replicate + - 1 or 'dct1' or 'mirror' or BoundType.dct1 + - 2 or 'dct2' or 'reflect' or BoundType.dct2 + - 3 or 'dst1' or 'antimirror' or BoundType.dst1 + - 4 or 'dst2' or 'antireflect' or BoundType.dst2 + - 5 or 'dft' or 'wrap' or BoundType.dft + - 7 or 'zero' or BoundType.zero A list of values can be provided, in the order [W, H, D], to specify dimension-specific boundary conditions. @@ -378,30 +385,32 @@ def grid_grad(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", b - `dct2` corresponds to Neumann boundary conditions (symmetric) - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric) - See also: + See Also: - https://en.wikipedia.org/wiki/Discrete_cosine_transform - https://en.wikipedia.org/wiki/Discrete_sine_transform + - ``help(monai._C.BoundType)`` + - ``help(monai._C.InterpolationType)`` + Args: input: Input image. `(B, C, Wi, Hi, Di)`. grid: Deformation field. `(B, Wo, Ho, Do, 2|3)`. interpolation (int or list[int] , optional): Interpolation order. - Defaults to `1`. + Defaults to `'linear'`. bound (BoundType, or list[BoundType], optional): Boundary conditions. Defaults to `'zero'`. extrapolate: Extrapolate out-of-bound data. Defaults to `True`. Returns: - output (torch.Tensor): Sampled gradients (B, C, Wo, Ho, Do, 2|3). + output (torch.Tensor): Sampled gradients (B, C, Wo, Ho, Do, 1|2|3). """ # Convert parameters - bound = ensure_tuple(bound) - interpolation = ensure_tuple(interpolation) - bound = [_C.BoundType.__members__[b] if isinstance(b, str) else _C.BoundType(b) for b in bound] + bound = [_C.BoundType.__members__[b] if isinstance(b, str) else _C.BoundType(b) for b in ensure_tuple(bound)] interpolation = [ - _C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i) for i in interpolation + _C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i) + for i in ensure_tuple(interpolation) ] return _GridGrad.apply(input, grid, interpolation, bound, extrapolate) diff --git a/tests/test_enum_bound_interp.py b/tests/test_enum_bound_interp.py new file mode 100644 index 0000000000..c20471f945 --- /dev/null +++ b/tests/test_enum_bound_interp.py @@ -0,0 +1,73 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from monai.utils import optional_import +from tests.utils import skip_if_no_cpp_extention + +b, _ = optional_import("monai._C", name="BoundType") +p, _ = optional_import("monai._C", name="InterpolationType") + + +@skip_if_no_cpp_extention +class TestEnumBoundInterp(unittest.TestCase): + def test_bound(self): + self.assertEqual(str(b.replicate), "BoundType.replicate") + self.assertEqual(str(b.nearest), "BoundType.replicate") + self.assertEqual(str(b.dct1), "BoundType.dct1") + self.assertEqual(str(b.mirror), "BoundType.dct1") + self.assertEqual(str(b.dct2), "BoundType.dct2") + self.assertEqual(str(b.reflect), "BoundType.dct2") + self.assertEqual(str(b.dst1), "BoundType.dst1") + self.assertEqual(str(b.antimirror), "BoundType.dst1") + self.assertEqual(str(b.dst2), "BoundType.dst2") + self.assertEqual(str(b.antireflect), "BoundType.dst2") + self.assertEqual(str(b.dft), "BoundType.dft") + self.assertEqual(str(b.wrap), "BoundType.dft") + self.assertEqual(str(b.zero), "BoundType.zero") + + self.assertEqual(int(b.replicate), 0) + self.assertEqual(int(b.nearest), 0) + self.assertEqual(int(b.dct1), 1) + self.assertEqual(int(b.mirror), 1) + self.assertEqual(int(b.dct2), 2) + self.assertEqual(int(b.reflect), 2) + self.assertEqual(int(b.dst1), 3) + self.assertEqual(int(b.antimirror), 3) + self.assertEqual(int(b.dst2), 4) + self.assertEqual(int(b.antireflect), 4) + self.assertEqual(int(b.dft), 5) + self.assertEqual(int(b.wrap), 5) + self.assertEqual(int(b.zero), 7) + + def test_interp(self): + self.assertEqual(str(p.nearest), "InterpolationType.nearest") + self.assertEqual(str(p.linear), "InterpolationType.linear") + self.assertEqual(str(p.quadratic), "InterpolationType.quadratic") + self.assertEqual(str(p.cubic), "InterpolationType.cubic") + self.assertEqual(str(p.fourth), "InterpolationType.fourth") + self.assertEqual(str(p.fifth), "InterpolationType.fifth") + self.assertEqual(str(p.sixth), "InterpolationType.sixth") + self.assertEqual(str(p.seventh), "InterpolationType.seventh") + + self.assertEqual(int(p.nearest), 0) + self.assertEqual(int(p.linear), 1) + self.assertEqual(int(p.quadratic), 2) + self.assertEqual(int(p.cubic), 3) + self.assertEqual(int(p.fourth), 4) + self.assertEqual(int(p.fifth), 5) + self.assertEqual(int(p.sixth), 6) + self.assertEqual(int(p.seventh), 7) + + +if __name__ == "__main__": + unittest.main() From b3e9a354ec76768f652f78191864c81010866641 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 4 Feb 2021 11:14:12 +0000 Subject: [PATCH 09/14] init. test grid pull Signed-off-by: Wenqi Li --- monai/csrc/utils/common_utils.h | 8 ++--- tests/test_grid_pull.py | 60 +++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 4 deletions(-) create mode 100644 tests/test_grid_pull.py diff --git a/monai/csrc/utils/common_utils.h b/monai/csrc/utils/common_utils.h index 3b90221ac3..4d09377e65 100644 --- a/monai/csrc/utils/common_utils.h +++ b/monai/csrc/utils/common_utils.h @@ -42,18 +42,18 @@ limitations under the License. #define CHECK_SAME_DEVICE(value1, value2) \ TORCH_CHECK( \ value1.device() == value2.device(), \ - "(): expected " #value2 " and " #value2 \ + "(): expected " #value1 " and " #value2 \ " to be on same device, " \ - "but " #value2 " is on ", \ + "but " #value1 " is on ", \ value1.device(), \ " and " #value2 " is on ", \ value2.device()); #define CHECK_SAME_DTYPE(value1, value2) \ TORCH_CHECK( \ value1.dtype() == value2.dtype(), \ - "(): expected " #value2 " and " #value2 \ + "(): expected " #value1 " and " #value2 \ " to have the same dtype, " \ - "but " #value2 " has ", \ + "but " #value1 " has ", \ value1.dtype(), \ " and " #value2 " has ", \ value2.dtype()); diff --git a/tests/test_grid_pull.py b/tests/test_grid_pull.py new file mode 100644 index 0000000000..c9eed6dd25 --- /dev/null +++ b/tests/test_grid_pull.py @@ -0,0 +1,60 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.layers import grid_pull +from monai.utils import optional_import +from tests.utils import skip_if_no_cpp_extention + +BType, has_b_type = optional_import("monai._C", name="BoundType") +PType, has_p_type = optional_import("monai._C", name="InterpolationType") + + +def make_grid(shape, dtype=None, device=None): + ranges = [torch.arange(float(s), dtype=dtype, device=device) for s in shape] + grid = torch.stack(torch.meshgrid(*ranges), dim=-1) + return grid[None] + + +# 1D combinations of bounds/interpolations +bounds = set(BType.__members__.values()) if has_b_type else [] +interps = set(PType.__members__.values()) if has_p_type else [] +Expected_1D_BP_fwd = [torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])] * 56 +assert len(bounds) * len(interps) == len(Expected_1D_BP_fwd) # all combinations +TEST_1D_BP_fwd = [] +for bound in bounds: + for interp in interps: + test_case = [ + { + "input": torch.arange(10, dtype=torch.float).reshape((1, 1, 10)), + "grid": make_grid((20,), dtype=torch.float) + 0.5, + "interpolation": interp, + "bound": bound, + }, + Expected_1D_BP_fwd.pop(0), + ] + TEST_1D_BP_fwd.append(test_case) + +@skip_if_no_cpp_extention +class TestGridPull(unittest.TestCase): + @parameterized.expand(TEST_1D_BP_fwd) + def test_grid_pull(self, input_param, expected_val): + result = grid_pull(**input_param) + print(input_param["interpolation"], input_param["bound"], result) + # np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + unittest.main() From 45a1e9519675e8d7ad97c7f1560a269c0c72acd5 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 8 Feb 2021 13:50:14 +0000 Subject: [PATCH 10/14] update Signed-off-by: Wenqi Li --- tests/test_enum_bound_interp.py | 4 ++-- tests/test_grid_pull.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_enum_bound_interp.py b/tests/test_enum_bound_interp.py index c20471f945..f788f8ba17 100644 --- a/tests/test_enum_bound_interp.py +++ b/tests/test_enum_bound_interp.py @@ -12,13 +12,13 @@ import unittest from monai.utils import optional_import -from tests.utils import skip_if_no_cpp_extention +from tests.utils import skip_if_no_cpp_extension b, _ = optional_import("monai._C", name="BoundType") p, _ = optional_import("monai._C", name="InterpolationType") -@skip_if_no_cpp_extention +@skip_if_no_cpp_extension class TestEnumBoundInterp(unittest.TestCase): def test_bound(self): self.assertEqual(str(b.replicate), "BoundType.replicate") diff --git a/tests/test_grid_pull.py b/tests/test_grid_pull.py index c9eed6dd25..caf7ef74aa 100644 --- a/tests/test_grid_pull.py +++ b/tests/test_grid_pull.py @@ -16,7 +16,7 @@ from monai.networks.layers import grid_pull from monai.utils import optional_import -from tests.utils import skip_if_no_cpp_extention +from tests.utils import skip_if_no_cpp_extension BType, has_b_type = optional_import("monai._C", name="BoundType") PType, has_p_type = optional_import("monai._C", name="InterpolationType") @@ -47,7 +47,7 @@ def make_grid(shape, dtype=None, device=None): ] TEST_1D_BP_fwd.append(test_case) -@skip_if_no_cpp_extention +@skip_if_no_cpp_extension class TestGridPull(unittest.TestCase): @parameterized.expand(TEST_1D_BP_fwd) def test_grid_pull(self, input_param, expected_val): From 44dcf7ea33dbb8d63f5f4ce9c3810ff5fc4519ab Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 8 Feb 2021 14:28:46 +0000 Subject: [PATCH 11/14] test grid_pull Signed-off-by: Wenqi Li --- tests/test_grid_pull.py | 9 ++-- tests/testing_data/cpp_resample_answers.py | 25 +++++++++ tests/testing_data/cpp_resample_answers.txt | 56 +++++++++++++++++++++ 3 files changed, 86 insertions(+), 4 deletions(-) create mode 100644 tests/testing_data/cpp_resample_answers.py create mode 100644 tests/testing_data/cpp_resample_answers.txt diff --git a/tests/test_grid_pull.py b/tests/test_grid_pull.py index caf7ef74aa..bdd00f0bbd 100644 --- a/tests/test_grid_pull.py +++ b/tests/test_grid_pull.py @@ -11,11 +11,13 @@ import unittest +import numpy as np import torch from parameterized import parameterized from monai.networks.layers import grid_pull from monai.utils import optional_import +from tests.testing_data.cpp_resample_answers import Expected_1D_BP_fwd from tests.utils import skip_if_no_cpp_extension BType, has_b_type = optional_import("monai._C", name="BoundType") @@ -31,7 +33,6 @@ def make_grid(shape, dtype=None, device=None): # 1D combinations of bounds/interpolations bounds = set(BType.__members__.values()) if has_b_type else [] interps = set(PType.__members__.values()) if has_p_type else [] -Expected_1D_BP_fwd = [torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])] * 56 assert len(bounds) * len(interps) == len(Expected_1D_BP_fwd) # all combinations TEST_1D_BP_fwd = [] for bound in bounds: @@ -43,17 +44,17 @@ def make_grid(shape, dtype=None, device=None): "interpolation": interp, "bound": bound, }, - Expected_1D_BP_fwd.pop(0), + torch.tensor([[Expected_1D_BP_fwd.pop(0)]]), ] TEST_1D_BP_fwd.append(test_case) + @skip_if_no_cpp_extension class TestGridPull(unittest.TestCase): @parameterized.expand(TEST_1D_BP_fwd) def test_grid_pull(self, input_param, expected_val): result = grid_pull(**input_param) - print(input_param["interpolation"], input_param["bound"], result) - # np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) + np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/testing_data/cpp_resample_answers.py b/tests/testing_data/cpp_resample_answers.py new file mode 100644 index 0000000000..19d50bd9a3 --- /dev/null +++ b/tests/testing_data/cpp_resample_answers.py @@ -0,0 +1,25 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import os + +Expected_1D_BP_fwd = [] +pwd = os.path.dirname(os.path.abspath(__file__)) # current file's location +with open(os.path.join(pwd, "cpp_resample_answers.txt")) as f: + res_reader = csv.reader(f, delimiter=",") + for r in res_reader: + res_row = [] + for item in r: + if item.strip().startswith("#"): + continue + res_row.append(float(item)) + Expected_1D_BP_fwd.append(res_row) diff --git a/tests/testing_data/cpp_resample_answers.txt b/tests/testing_data/cpp_resample_answers.txt new file mode 100644 index 0000000000..a620d59dff --- /dev/null +++ b/tests/testing_data/cpp_resample_answers.txt @@ -0,0 +1,56 @@ +1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, # InterpolationType.nearest BoundType.replicate +0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, # InterpolationType.linear BoundType.replicate +0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, # InterpolationType.quadratic BoundType.replicate +0.5208, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.4792, 8.9792, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, # InterpolationType.cubic BoundType.replicate +0.5417, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.4583, 8.9583, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, # InterpolationType.fourth BoundType.replicate +0.5622, 1.5003, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4997, 8.4378, 8.9378, 8.9997, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, # InterpolationType.fifth BoundType.replicate +0.5819, 1.5014, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4986, 8.4181, 8.9181, 8.9986, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, # InterpolationType.sixth BoundType.replicate +0.6008, 1.5034, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4966, 8.3992, 8.8992, 8.9966, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, # InterpolationType.seventh BoundType.replicate +1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0, 1.0, 2.0, # InterpolationType.nearest BoundType.dct1 +0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 8.5, 7.5, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5, 0.5, 0.5, 1.5, # InterpolationType.linear BoundType.dct1 +0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 8.5, 7.5, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5, 0.5, 0.5, 1.5, # InterpolationType.quadratic BoundType.dct1 +0.5417, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.4583, 8.4583, 7.5, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5, 0.5417, 0.5417, 1.5, # InterpolationType.cubic BoundType.dct1 +0.5833, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.4167, 8.4167, 7.5, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5, 0.5833, 0.5833, 1.5, # InterpolationType.fourth BoundType.dct1 +0.6245, 1.5005, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4995, 8.3755, 8.3755, 7.4995, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5005, 0.6245, 0.6245, 1.5005, # InterpolationType.fifth BoundType.dct1 +0.6639, 1.5028, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4972, 8.3361, 8.3361, 7.4972, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5028, 0.6639, 0.6639, 1.5028, # InterpolationType.sixth BoundType.dct1 +0.7016, 1.5068, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4932, 8.2984, 8.2984, 7.4932, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5068, 0.7016, 0.7016, 1.5068, # InterpolationType.seventh BoundType.dct1 +1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0, 0.0, # InterpolationType.nearest BoundType.dct2 +0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.0, 8.5, 7.5, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5, 0.5, 0.0, # InterpolationType.linear BoundType.dct2 +0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.0, 8.5, 7.5, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5, 0.5, 0.0, # InterpolationType.quadratic BoundType.dct2 +0.5208, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.4792, 8.9583, 8.4792, 7.5, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5, 0.5208, 0.0417, # InterpolationType.cubic BoundType.dct2 +0.5417, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.4583, 8.9167, 8.4583, 7.5, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5, 0.5417, 0.0833, # InterpolationType.fourth BoundType.dct2 +0.5625, 1.5003, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4997, 8.4375, 8.8755, 8.4375, 7.4997, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5003, 0.5625, 0.1245, # InterpolationType.fifth BoundType.dct2 +0.5833, 1.5014, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4986, 8.4167, 8.8361, 8.4167, 7.4986, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5014, 0.5833, 0.1639, # InterpolationType.sixth BoundType.dct2 +0.6042, 1.5034, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4966, 8.3958, 8.7984, 8.3958, 7.4966, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5034, 0.6042, 0.2016, # InterpolationType.seventh BoundType.dct2 +1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, -9.0, -8.0, -7.0, -6.0, -5.0, -4.0, -3.0, -2.0, -1.0, -0.0, # InterpolationType.nearest BoundType.dst1 +0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 4.5, -4.5, -8.5, -7.5, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5, -0.5, # InterpolationType.linear BoundType.dst1 +0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 4.5, -4.5, -8.5, -7.5, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5, -0.5, # InterpolationType.quadratic BoundType.dst1 +0.5208, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.2917, 4.2917, -4.2917, -8.2917, -7.5, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5, -0.5208, # InterpolationType.cubic BoundType.dst1 +0.5417, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.0833, 4.0833, -4.0833, -8.0833, -7.5, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5, -0.5417, # InterpolationType.fourth BoundType.dst1 +0.5622, 1.5003, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4974, 7.8776, 3.8802, -3.8802, -7.8776, -7.4974, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5003, -0.5622, # InterpolationType.fifth BoundType.dst1 +0.5819, 1.5014, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4861, 7.6806, 3.6944, -3.6944, -7.6806, -7.4861, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5014, -0.5819, # InterpolationType.sixth BoundType.dst1 +0.6008, 1.5034, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4662, 7.4922, 3.5260, -3.5260, -7.4922, -7.4662, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5034, -0.6008, # InterpolationType.seventh BoundType.dst1 +1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, -9.0, -8.0, -7.0, -6.0, -5.0, -4.0, -3.0, -2.0, -1.0, -0.0, 0.0, # InterpolationType.nearest BoundType.dst2 +0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 0.0, -8.5, -7.5, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5, -0.5, 0.0, # InterpolationType.linear BoundType.dst2 +0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 0.0, -8.5, -7.5, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5, -0.5, 0.0, # InterpolationType.quadratic BoundType.dst2 +5.2083e-01, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.1042, -1.6391e-07, -8.1042, -7.5, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5, -5.2083e-01, 0.0, # InterpolationType.cubic BoundType.dst2 +5.4167e-01, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 7.7083, 1.4901e-07, -7.7083, -7.5, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5, -5.4167e-01, 0.0, # InterpolationType.fourth BoundType.dst2 +5.6198e-01, 1.5003, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4951, 7.3224, 1.2107e-07, -7.3224, -7.4951, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5003, -5.6198e-01, 5.2387e-10, # InterpolationType.fifth BoundType.dst2 +5.8056e-01, 1.5014, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4736, 6.9694, -1.0896e-07, -6.9694, -7.4736, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5014, -5.8056e-01, 2.3283e-10, # InterpolationType.sixth BoundType.dst2 +0.59740, 1.5034, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4358, 6.6493, 0.0, -6.6493, -7.4358, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5034, -0.59740, 0.0, # InterpolationType.seventh BoundType.dst2 +1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, # InterpolationType.nearest BoundType.dft +0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 4.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 4.5, # InterpolationType.linear BoundType.dft +0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 4.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 4.5, # InterpolationType.quadratic BoundType.dft +0.7083, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.2917, 4.5, 0.7083, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.2917, 4.5, # InterpolationType.cubic BoundType.dft +0.9167, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.0833, 4.5, 0.9167, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.0833, 4.5, # InterpolationType.fourth BoundType.dft +1.1198, 1.5026, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4974, 7.8802, 4.5, 1.1198, 1.5026, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4974, 7.8802, 4.5, # InterpolationType.fifth BoundType.dft +1.3056, 1.5139, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4861, 7.6944, 4.5, 1.3056, 1.5139, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4861, 7.6944, 4.5, # InterpolationType.sixth BoundType.dft +1.4740, 1.5338, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4662, 7.5260, 4.5, 1.4740, 1.5338, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4662, 7.5260, 4.5, # InterpolationType.seventh BoundType.dft +1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, # InterpolationType.nearest BoundType.zero +0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 4.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, # InterpolationType.linear BoundType.zero +0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 4.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, # InterpolationType.quadratic BoundType.zero +0.5208, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.2917, 4.4792, 0.1875, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, # InterpolationType.cubic BoundType.zero +0.5417, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.0833, 4.4583, 0.3750, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, # InterpolationType.fourth BoundType.zero +5.6224e-01, 1.5003, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4974, 7.8799, 4.4378, 5.5755e-01, 2.3438e-03, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, # InterpolationType.fifth BoundType.zero +0.5819, 1.5014, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4861, 7.6931, 4.4181, 0.7236, 0.0125, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, # InterpolationType.sixth BoundType.zero +6.0078e-01, 1.5034, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4662, 7.5226, 4.3992, 8.7325e-01, 3.0411e-02, 1.3951e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, # InterpolationType.seventh BoundType.zero From c08450c52db97f979ef8b9f694e58f4503e80f3a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 8 Feb 2021 19:24:06 +0000 Subject: [PATCH 12/14] fixes min test Signed-off-by: Wenqi Li --- monai/networks/layers/spatial_transforms.py | 8 ++-- tests/test_grid_pull.py | 30 ++++++++++----- ...cpp_resample_answers.txt => 1D_BP_fwd.txt} | 0 tests/testing_data/cpp_resample_answers.py | 37 +++++++++++++------ 4 files changed, 51 insertions(+), 24 deletions(-) rename tests/testing_data/{cpp_resample_answers.txt => 1D_BP_fwd.txt} (100%) diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index ac1112a68b..768836afff 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -35,7 +35,7 @@ def forward(ctx, input, grid, interpolation, bound, extrapolate): @staticmethod def backward(ctx, grad): - var = ctx.saved_variables + var = ctx.saved_tensors opt = ctx.opt grad_input = grad_grid = None grads = _C.grid_pull_backward(grad, *var, *opt) @@ -131,7 +131,7 @@ def forward(ctx, input, grid, shape, interpolation, bound, extrapolate): @staticmethod def backward(ctx, grad): - var = ctx.saved_variables + var = ctx.saved_tensors opt = ctx.opt grad_input = grad_grid = None grads = _C.grid_push_backward(grad, *var, *opt) @@ -234,7 +234,7 @@ def forward(ctx, grid, shape, interpolation, bound, extrapolate): @staticmethod def backward(ctx, grad): - var = ctx.saved_variables + var = ctx.saved_tensors opt = ctx.opt grad_grid = None if ctx.needs_input_grad[0]: @@ -331,7 +331,7 @@ def forward(ctx, input, grid, interpolation, bound, extrapolate): @staticmethod def backward(ctx, grad): - var = ctx.saved_variables + var = ctx.saved_tensors opt = ctx.opt grad_input = grad_grid = None if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: diff --git a/tests/test_grid_pull.py b/tests/test_grid_pull.py index bdd00f0bbd..93db2c4cbd 100644 --- a/tests/test_grid_pull.py +++ b/tests/test_grid_pull.py @@ -17,7 +17,7 @@ from monai.networks.layers import grid_pull from monai.utils import optional_import -from tests.testing_data.cpp_resample_answers import Expected_1D_BP_fwd +from tests.testing_data.cpp_resample_answers import Expected_1D_GP_fwd from tests.utils import skip_if_no_cpp_extension BType, has_b_type = optional_import("monai._C", name="BoundType") @@ -25,7 +25,7 @@ def make_grid(shape, dtype=None, device=None): - ranges = [torch.arange(float(s), dtype=dtype, device=device) for s in shape] + ranges = [torch.arange(float(s), dtype=dtype, device=device, requires_grad=True) for s in shape] grid = torch.stack(torch.meshgrid(*ranges), dim=-1) return grid[None] @@ -33,28 +33,40 @@ def make_grid(shape, dtype=None, device=None): # 1D combinations of bounds/interpolations bounds = set(BType.__members__.values()) if has_b_type else [] interps = set(PType.__members__.values()) if has_p_type else [] -assert len(bounds) * len(interps) == len(Expected_1D_BP_fwd) # all combinations -TEST_1D_BP_fwd = [] +TEST_1D_GP_fwd = [] for bound in bounds: for interp in interps: + if not Expected_1D_GP_fwd: + break test_case = [ { - "input": torch.arange(10, dtype=torch.float).reshape((1, 1, 10)), + "input": torch.arange(10, dtype=torch.float, requires_grad=True).reshape((1, 1, 10)), "grid": make_grid((20,), dtype=torch.float) + 0.5, "interpolation": interp, "bound": bound, }, - torch.tensor([[Expected_1D_BP_fwd.pop(0)]]), + torch.tensor([[Expected_1D_GP_fwd.pop(0)]]), ] - TEST_1D_BP_fwd.append(test_case) + TEST_1D_GP_fwd.append(test_case) @skip_if_no_cpp_extension class TestGridPull(unittest.TestCase): - @parameterized.expand(TEST_1D_BP_fwd) + @parameterized.expand(TEST_1D_GP_fwd, skip_on_empty=True) def test_grid_pull(self, input_param, expected_val): result = grid_pull(**input_param) - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) + + @parameterized.expand(TEST_1D_GP_fwd, skip_on_empty=True) + def test_grid_pull_grad(self, input_param, expected_val): + result = grid_pull(**input_param) + input_param["input"].retain_grad() + input_param["grid"].retain_grad() + result.sum().backward() + print("--" * 15) + print(input_param["interpolation"], input_param["bound"]) + print(input_param["input"].grad) + print(input_param["grid"].grad) if __name__ == "__main__": diff --git a/tests/testing_data/cpp_resample_answers.txt b/tests/testing_data/1D_BP_fwd.txt similarity index 100% rename from tests/testing_data/cpp_resample_answers.txt rename to tests/testing_data/1D_BP_fwd.txt diff --git a/tests/testing_data/cpp_resample_answers.py b/tests/testing_data/cpp_resample_answers.py index 19d50bd9a3..b50b1d7395 100644 --- a/tests/testing_data/cpp_resample_answers.py +++ b/tests/testing_data/cpp_resample_answers.py @@ -11,15 +11,30 @@ import csv import os +import warnings +from typing import List, Optional -Expected_1D_BP_fwd = [] -pwd = os.path.dirname(os.path.abspath(__file__)) # current file's location -with open(os.path.join(pwd, "cpp_resample_answers.txt")) as f: - res_reader = csv.reader(f, delimiter=",") - for r in res_reader: - res_row = [] - for item in r: - if item.strip().startswith("#"): - continue - res_row.append(float(item)) - Expected_1D_BP_fwd.append(res_row) + +def _read_testing_data_answers(fname: Optional[str] = None, delimiter=",") -> List: + answers: List = [] + if not fname: + return answers + # read answers from directory of the current file + pwd = os.path.dirname(os.path.abspath(__file__)) + filename = os.path.join(pwd, fname) + if not os.path.isfile(filename): + warnings.warn("test data {} not found.".format(filename)) + return answers + with open(filename) as f: + res_reader = csv.reader(f, delimiter=delimiter) + for r in res_reader: + res_row = [] + for item in r: + if item.strip().startswith("#"): + continue # allow for some simple comments in the file + res_row.append(float(item)) + answers.append(res_row) + return answers + + +Expected_1D_GP_fwd: List = _read_testing_data_answers(fname="1D_BP_fwd.txt") From 39f6d5d8c9977e1d1e997379fed943ed25889227 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 19 Feb 2021 05:20:46 -0500 Subject: [PATCH 13/14] adds device tests Signed-off-by: Wenqi Li --- monai/networks/layers/spatial_transforms.py | 24 +++++++++------------ tests/test_grid_pull.py | 19 ++++++++++------ 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index 768836afff..2b8737f2fa 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -35,17 +35,15 @@ def forward(ctx, input, grid, interpolation, bound, extrapolate): @staticmethod def backward(ctx, grad): + if not (ctx.needs_input_grad[0] or ctx.needs_input_grad[1]): + return None, None, None, None, None var = ctx.saved_tensors opt = ctx.opt - grad_input = grad_grid = None grads = _C.grid_pull_backward(grad, *var, *opt) if ctx.needs_input_grad[0]: - grad_input = grads[0] - if ctx.needs_input_grad[1]: - grad_grid = grads[1] - elif ctx.needs_input_grad[1]: - grad_grid = grads[0] - return grad_input, grad_grid, None, None, None + return grads[0], grads[1] if ctx.needs_input_grad[1] else None, None, None, None + if ctx.needs_input_grad[1]: + return None, grads[0], None, None, None def grid_pull(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", bound="zero", extrapolate: bool = True): @@ -131,17 +129,15 @@ def forward(ctx, input, grid, shape, interpolation, bound, extrapolate): @staticmethod def backward(ctx, grad): + if not (ctx.needs_input_grad[0] or ctx.needs_input_grad[1]): + return None, None, None, None, None, None var = ctx.saved_tensors opt = ctx.opt - grad_input = grad_grid = None grads = _C.grid_push_backward(grad, *var, *opt) if ctx.needs_input_grad[0]: - grad_input = grads[0] - if ctx.needs_input_grad[1]: - grad_grid = grads[1] - elif ctx.needs_input_grad[1]: - grad_grid = grads[0] - return grad_input, grad_grid, None, None, None, None + return grads[0], grads[1] if ctx.needs_input_grad[1] else None, None, None, None, None + if ctx.needs_input_grad[1]: + return None, grads[0], None, None, None, None def grid_push( diff --git a/tests/test_grid_pull.py b/tests/test_grid_pull.py index 93db2c4cbd..daecbd4709 100644 --- a/tests/test_grid_pull.py +++ b/tests/test_grid_pull.py @@ -33,19 +33,23 @@ def make_grid(shape, dtype=None, device=None): # 1D combinations of bounds/interpolations bounds = set(BType.__members__.values()) if has_b_type else [] interps = set(PType.__members__.values()) if has_p_type else [] +device = "cuda" if torch.cuda.is_available() else "cpu" TEST_1D_GP_fwd = [] for bound in bounds: for interp in interps: if not Expected_1D_GP_fwd: - break + break # skip if the testing data are unavailable test_case = [ { - "input": torch.arange(10, dtype=torch.float, requires_grad=True).reshape((1, 1, 10)), - "grid": make_grid((20,), dtype=torch.float) + 0.5, + "input": torch.arange(10, dtype=torch.float, requires_grad=True, device=device).reshape((1, 1, 10)), + "grid": make_grid((20,), dtype=torch.float, device=device) + 0.5, "interpolation": interp, "bound": bound, }, - torch.tensor([[Expected_1D_GP_fwd.pop(0)]]), + { + "val": torch.tensor([[Expected_1D_GP_fwd.pop(0)]]), + "device": device, + }, ] TEST_1D_GP_fwd.append(test_case) @@ -53,12 +57,13 @@ def make_grid(shape, dtype=None, device=None): @skip_if_no_cpp_extension class TestGridPull(unittest.TestCase): @parameterized.expand(TEST_1D_GP_fwd, skip_on_empty=True) - def test_grid_pull(self, input_param, expected_val): + def test_grid_pull(self, input_param, expected): result = grid_pull(**input_param) - np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) + self.assertTrue("{}".format(result.device).startswith(expected["device"])) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected["val"].cpu().numpy(), rtol=1e-4, atol=1e-4) @parameterized.expand(TEST_1D_GP_fwd, skip_on_empty=True) - def test_grid_pull_grad(self, input_param, expected_val): + def test_grid_pull_grad(self, input_param, expected): result = grid_pull(**input_param) input_param["input"].retain_grad() input_param["grid"].retain_grad() From 4eec57499ec4214a5d8a174487bbd0d6af610742 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 19 Feb 2021 14:27:29 +0000 Subject: [PATCH 14/14] bwd tests Signed-off-by: Wenqi Li --- monai/networks/layers/spatial_transforms.py | 26 +-- tests/test_grid_pull.py | 76 ++++--- tests/testing_data/1D_BP_bwd.txt | 224 ++++++++++++++++++++ tests/testing_data/cpp_resample_answers.py | 1 + 4 files changed, 282 insertions(+), 45 deletions(-) create mode 100644 tests/testing_data/1D_BP_bwd.txt diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index 2b8737f2fa..03031f3340 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -230,12 +230,11 @@ def forward(ctx, grid, shape, interpolation, bound, extrapolate): @staticmethod def backward(ctx, grad): - var = ctx.saved_tensors - opt = ctx.opt - grad_grid = None if ctx.needs_input_grad[0]: - grad_grid = _C.grid_count_backward(grad, *var, *opt) - return grad_grid, None, None, None, None + var = ctx.saved_tensors + opt = ctx.opt + return _C.grid_count_backward(grad, *var, *opt), None, None, None, None + return None, None, None, None, None def grid_count(grid: torch.Tensor, shape=None, interpolation="linear", bound="zero", extrapolate: bool = True): @@ -327,18 +326,15 @@ def forward(ctx, input, grid, interpolation, bound, extrapolate): @staticmethod def backward(ctx, grad): + if not (ctx.needs_input_grad[0] or ctx.needs_input_grad[1]): + return None, None, None, None, None var = ctx.saved_tensors opt = ctx.opt - grad_input = grad_grid = None - if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: - grads = _C.grid_grad_backward(grad, *var, *opt) - if ctx.needs_input_grad[0]: - grad_input = grads[0] - if ctx.needs_input_grad[1]: - grad_grid = grads[1] - elif ctx.needs_input_grad[1]: - grad_grid = grads[0] - return grad_input, grad_grid, None, None, None + grads = _C.grid_grad_backward(grad, *var, *opt) + if ctx.needs_input_grad[0]: + return grads[0], grads[1] if ctx.needs_input_grad[1] else None, None, None, None + if ctx.needs_input_grad[1]: + return None, grads[0], None, None, None def grid_grad(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", bound="zero", extrapolate: bool = True): diff --git a/tests/test_grid_pull.py b/tests/test_grid_pull.py index daecbd4709..9e4d2e8237 100644 --- a/tests/test_grid_pull.py +++ b/tests/test_grid_pull.py @@ -17,15 +17,15 @@ from monai.networks.layers import grid_pull from monai.utils import optional_import -from tests.testing_data.cpp_resample_answers import Expected_1D_GP_fwd +from tests.testing_data.cpp_resample_answers import Expected_1D_GP_bwd, Expected_1D_GP_fwd from tests.utils import skip_if_no_cpp_extension BType, has_b_type = optional_import("monai._C", name="BoundType") PType, has_p_type = optional_import("monai._C", name="InterpolationType") -def make_grid(shape, dtype=None, device=None): - ranges = [torch.arange(float(s), dtype=dtype, device=device, requires_grad=True) for s in shape] +def make_grid(shape, dtype=None, device=None, requires_grad=True): + ranges = [torch.arange(float(s), dtype=dtype, device=device, requires_grad=requires_grad) for s in shape] grid = torch.stack(torch.meshgrid(*ranges), dim=-1) return grid[None] @@ -34,44 +34,60 @@ def make_grid(shape, dtype=None, device=None): bounds = set(BType.__members__.values()) if has_b_type else [] interps = set(PType.__members__.values()) if has_p_type else [] device = "cuda" if torch.cuda.is_available() else "cpu" -TEST_1D_GP_fwd = [] +TEST_1D_GP = [] for bound in bounds: for interp in interps: - if not Expected_1D_GP_fwd: + if not Expected_1D_GP_fwd or not Expected_1D_GP_bwd: break # skip if the testing data are unavailable - test_case = [ - { - "input": torch.arange(10, dtype=torch.float, requires_grad=True, device=device).reshape((1, 1, 10)), - "grid": make_grid((20,), dtype=torch.float, device=device) + 0.5, - "interpolation": interp, - "bound": bound, - }, - { - "val": torch.tensor([[Expected_1D_GP_fwd.pop(0)]]), - "device": device, - }, - ] - TEST_1D_GP_fwd.append(test_case) + expected_val = Expected_1D_GP_fwd.pop(0) + + for input_g in (True, False): + for grid_g in (True, False): + expected_grad = Expected_1D_GP_bwd.pop(0) + test_case = [ + { + "input": torch.arange(10, dtype=torch.float, requires_grad=input_g, device=device).reshape( + (1, 1, 10) + ), + "grid": make_grid((20,), dtype=torch.float, device=device, requires_grad=grid_g) + 0.5, + "interpolation": interp, + "bound": bound, + }, + { + "val": torch.tensor([[expected_val]]), + "device": device, + "grad": torch.tensor(expected_grad), + }, + ] + TEST_1D_GP.append(test_case) @skip_if_no_cpp_extension class TestGridPull(unittest.TestCase): - @parameterized.expand(TEST_1D_GP_fwd, skip_on_empty=True) + @parameterized.expand(TEST_1D_GP, skip_on_empty=True) def test_grid_pull(self, input_param, expected): result = grid_pull(**input_param) + if input_param["input"].requires_grad: + input_param["input"].retain_grad() + if input_param["grid"].requires_grad: + input_param["grid"].retain_grad() + if input_param["input"].requires_grad or input_param["grid"].requires_grad: + result.sum().backward() + + grads = [] + if input_param["input"].requires_grad: + grads.append(input_param["input"].grad.view(-1)) + if input_param["grid"].requires_grad: + grads.append(input_param["grid"].grad.view(-1)) + if not grads: + grads = torch.tensor(0.0, device=result.device) + elif len(grads) == 1: + grads = grads[0] + else: + grads = torch.cat(grads, dim=0) self.assertTrue("{}".format(result.device).startswith(expected["device"])) np.testing.assert_allclose(result.detach().cpu().numpy(), expected["val"].cpu().numpy(), rtol=1e-4, atol=1e-4) - - @parameterized.expand(TEST_1D_GP_fwd, skip_on_empty=True) - def test_grid_pull_grad(self, input_param, expected): - result = grid_pull(**input_param) - input_param["input"].retain_grad() - input_param["grid"].retain_grad() - result.sum().backward() - print("--" * 15) - print(input_param["interpolation"], input_param["bound"]) - print(input_param["input"].grad) - print(input_param["grid"].grad) + np.testing.assert_allclose(grads.detach().cpu().numpy(), expected["grad"].cpu().numpy(), rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/testing_data/1D_BP_bwd.txt b/tests/testing_data/1D_BP_bwd.txt new file mode 100644 index 0000000000..de43270e94 --- /dev/null +++ b/tests/testing_data/1D_BP_bwd.txt @@ -0,0 +1,224 @@ +0., 1., 1., 1., 1., 1., 1., 1., 1.,12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., # InterpolationType.nearest BoundType.replicate +0., 1., 1., 1., 1., 1., 1., 1., 1.,12., # InterpolationType.nearest BoundType.replicate +0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.replicate +0., # InterpolationType.nearest BoundType.replicate +0.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,11.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.linear BoundType.replicate +0.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,11.5, # InterpolationType.linear BoundType.replicate +1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.linear BoundType.replicate +0., # InterpolationType.linear BoundType.replicate +0.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,11.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.quadratic BoundType.replicate +0.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,11.5, # InterpolationType.quadratic BoundType.replicate +1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.quadratic BoundType.replicate +0., # InterpolationType.quadratic BoundType.replicate +0.5208333 , 0.9791666 , 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994,11.5 , 0.875 , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.875 , 0.125 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.cubic BoundType.replicate +0.5208333 , 0.9791666 , 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994,11.5 , # InterpolationType.cubic BoundType.replicate +0.875,1. ,1. ,1. ,1. ,1. ,1. ,1. ,0.875,0.125,0. ,0. ,0. ,0. ,0. ,0. ,0. ,0. ,0. ,0. , # InterpolationType.cubic BoundType.replicate +0., # InterpolationType.cubic BoundType.replicate +0.5416667 , 0.9583334 , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,11.5 , 0.8333334 , 1. , 1. , 1. , 1. , 1. , 0.9999999 , 1. , 0.833333 , 0.16666651, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.fourth BoundType.replicate +0.5416667, 0.9583334, 1. , 1. , 1. , 1. , 1. , 1. , 1. ,11.5 , # InterpolationType.fourth BoundType.replicate +0.8333334 ,1. ,1. ,1. ,1. ,1. ,0.9999999 ,1. ,0.833333 ,0.16666651,0. ,0. ,0. ,0. ,0. ,0. ,0. ,0. ,0. ,0. , # InterpolationType.fourth BoundType.replicate +0., # InterpolationType.fourth BoundType.replicate +5.6223959e-01,9.3802083e-01,9.9973959e-01,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.1499999e+01,7.9947913e-01,9.9739581e-01,1.0000000e+00,1.0000000e+00,9.9999994e-01,1.0000001e+00,9.9999976e-01,9.9739575e-01,7.9947948e-01,2.0052099e-01,2.6040077e-03,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.fifth BoundType.replicate +0.5622396, 0.9380208, 0.9997396, 1. , 1. , 1. , 1. , 1. , 1. ,11.499999 , # InterpolationType.fifth BoundType.replicate +0.7994791 ,0.9973958 ,1. ,1. ,0.99999994,1.0000001 ,0.99999976,0.99739575,0.7994795 ,0.20052099,0.00260401,0. ,0. ,0. ,0. ,0. ,0. ,0. ,0. ,0. , # InterpolationType.fifth BoundType.replicate +0., # InterpolationType.fifth BoundType.replicate +5.8194447e-01,9.1944444e-01,9.9861109e-01,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.1499997e+01,7.7499998e-01,9.9166673e-01,1.0000000e+00,1.0000000e+00,1.0000000e+00,9.9999982e-01,1.0000004e+00,9.9166673e-01,7.7499980e-01,2.2500010e-01,8.3333999e-03,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07, # InterpolationType.sixth BoundType.replicate +0.58194447, 0.91944444, 0.9986111 , 1. , 1. , 1. , 1. , 1. , 1. ,11.499997 , # InterpolationType.sixth BoundType.replicate +7.7499998e-01,9.9166673e-01,1.0000000e+00,1.0000000e+00,1.0000000e+00,9.9999982e-01,1.0000004e+00,9.9166673e-01,7.7499980e-01,2.2500010e-01,8.3333999e-03,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07, # InterpolationType.sixth BoundType.replicate +0., # InterpolationType.sixth BoundType.replicate +6.0078436e-01,9.0259641e-01,9.9662077e-01,9.9999845e-01,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.1500004e+01,7.5551212e-01,9.8430985e-01,9.9997836e-01,9.9999994e-01,1.0000000e+00,1.0000001e+00,9.9997842e-01,9.8431003e-01,7.5551212e-01,2.4448761e-01,1.5690181e-02,2.1788481e-05,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07, # InterpolationType.seventh BoundType.replicate +0.60078436, 0.9025964 , 0.9966208 , 0.99999845, 1. , 1. , 1. , 1. , 1. ,11.500004 , # InterpolationType.seventh BoundType.replicate +7.5551212e-01,9.8430985e-01,9.9997836e-01,9.9999994e-01,1.0000000e+00,1.0000001e+00,9.9997842e-01,9.8431003e-01,7.5551212e-01,2.4448761e-01,1.5690181e-02,2.1788481e-05,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07, # InterpolationType.seventh BoundType.replicate +0., # InterpolationType.seventh BoundType.replicate +1.,3.,3.,2.,2.,2.,2.,2.,2.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dct1 +1.,3.,3.,2.,2.,2.,2.,2.,2.,1., # InterpolationType.nearest BoundType.dct1 +0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dct1 +0., # InterpolationType.nearest BoundType.dct1 +1.5, 3. , 2.5, 2. , 2. , 2. , 2. , 2. , 2. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. , 1. , 1. , # InterpolationType.linear BoundType.dct1 +1.5,3. ,2.5,2. ,2. ,2. ,2. ,2. ,2. ,1. , # InterpolationType.linear BoundType.dct1 +1., 1., 1., 1., 1., 1., 1., 1., 1.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1., 1., 1., # InterpolationType.linear BoundType.dct1 +0., # InterpolationType.linear BoundType.dct1 +1.5, 3. , 2.5, 2. , 2. , 2. , 2. , 2. , 2. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. , 1. , 1. , # InterpolationType.quadratic BoundType.dct1 +1.5,3. ,2.5,2. ,2. ,2. ,2. ,2. ,2. ,1. , # InterpolationType.quadratic BoundType.dct1 +1., 1., 1., 1., 1., 1., 1., 1., 1.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1., 1., 1., # InterpolationType.quadratic BoundType.dct1 +0., # InterpolationType.quadratic BoundType.dct1 +1.5 , 2.9791667 , 2.5 , 2.0208333 , 1.9999999 , 1.9999999 , 1.9999999 , 1.9999999 , 1.9999999 , 0.99999994, 0.75 , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.75 ,-0.75 ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-0.75 , 0.75 , 1. , # InterpolationType.cubic BoundType.dct1 +1.5 ,2.9791667 ,2.5 ,2.0208333 ,1.9999999 ,1.9999999 ,1.9999999 ,1.9999999 ,1.9999999 ,0.99999994, # InterpolationType.cubic BoundType.dct1 +0.75, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.75,-0.75,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-0.75, 0.75, 1. , # InterpolationType.cubic BoundType.dct1 +0., # InterpolationType.cubic BoundType.dct1 +1.5 , 2.9583333 , 2.5 , 2.0416667 , 2. , 2. , 2. , 2. , 2. , 1. , 0.6666666 , 1. , 1. , 1. , 1. , 1. , 0.9999999 , 1. , 0.6666664 ,-0.66666675,-1. ,-1.0000001 ,-1.0000002 ,-1. ,-1.0000001 ,-1.0000001 ,-1. ,-0.6666667 , 0.6666666 , 1. , # InterpolationType.fourth BoundType.dct1 +1.5 ,2.9583333,2.5 ,2.0416667,2. ,2. ,2. ,2. ,2. ,1. , # InterpolationType.fourth BoundType.dct1 +0.6666666 , 1. , 1. , 1. , 1. , 1. , 0.9999999 , 1. , 0.6666664 ,-0.66666675,-1. ,-1.0000001 ,-1.0000002 ,-1. ,-1.0000001 ,-1.0000001 ,-1. ,-0.6666667 , 0.6666666 , 1. , # InterpolationType.fourth BoundType.dct1 +0., # InterpolationType.fourth BoundType.dct1 +1.4997395 , 2.9380207 , 2.5 , 2.061979 , 2.0002604 , 2. , 2. , 2. , 2. , 1. , 0.5989583 , 0.9947917 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.99479157, 0.5989587 ,-0.59895825,-0.9947917 ,-0.9999998 ,-1.0000002 ,-1. ,-0.9999998 ,-1. ,-0.9947917 ,-0.5989583 , 0.5989583 , 0.9947917 , # InterpolationType.fifth BoundType.dct1 +1.4997395,2.9380207,2.5 ,2.061979 ,2.0002604,2. ,2. ,2. ,2. ,1. , # InterpolationType.fifth BoundType.dct1 +0.5989583 , 0.9947917 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.99479157, 0.5989587 ,-0.59895825,-0.9947917 ,-0.9999998 ,-1.0000002 ,-1. ,-0.9999998 ,-1. ,-0.9947917 ,-0.5989583 , 0.5989583 , 0.9947917 , # InterpolationType.fifth BoundType.dct1 +0., # InterpolationType.fifth BoundType.dct1 +1.498611 , 2.919444 , 2.5 , 2.0805554 , 2.0013888 , 2. , 2. , 2. , 2. , 1. , 0.54999995, 0.9833334 , 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.9833334 , 0.5499998 ,-0.5499999 ,-0.9833334 ,-1.0000004 ,-1.0000001 ,-1.0000001 ,-1. ,-0.99999994,-0.98333335,-0.55 , 0.54999995, 0.9833334 , # InterpolationType.sixth BoundType.dct1 +1.498611 ,2.919444 ,2.5 ,2.0805554,2.0013888,2. ,2. ,2. ,2. ,1. , # InterpolationType.sixth BoundType.dct1 +0.54999995, 0.9833334 , 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.9833334 , 0.5499998 ,-0.5499999 ,-0.9833334 ,-1.0000004 ,-1.0000001 ,-1.0000001 ,-1. ,-0.99999994,-0.98333335,-0.55 , 0.54999995, 0.9833334 , # InterpolationType.sixth BoundType.dct1 +0., # InterpolationType.sixth BoundType.dct1 +1.4966209 , 2.9025953 , 2.5000002 , 2.097404 , 2.0033796 , 2.000002 , 2.0000002 , 2.0000002 , 2.0000002 , 1. , 0.5110243 , 0.9686197 , 0.99995667, 0.99999994, 1. , 1.0000001 , 0.9999567 , 0.96861994, 0.51102436,-0.5110245 ,-0.9686197 ,-0.99995685,-1. ,-1. ,-1.0000001 ,-0.99995655,-0.9686198 ,-0.5110243 , 0.5110243 , 0.9686197 , # InterpolationType.seventh BoundType.dct1 +1.4966209,2.9025953,2.5000002,2.097404 ,2.0033796,2.000002 ,2.0000002,2.0000002,2.0000002,1. , # InterpolationType.seventh BoundType.dct1 +0.5110243 , 0.9686197 , 0.99995667, 0.99999994, 1. , 1.0000001 , 0.9999567 , 0.96861994, 0.51102436,-0.5110245 ,-0.9686197 ,-0.99995685,-1. ,-1. ,-1.0000001 ,-0.99995655,-0.9686198 ,-0.5110243 , 0.5110243 , 0.9686197 , # InterpolationType.seventh BoundType.dct1 +0., # InterpolationType.seventh BoundType.dct1 +2.,2.,2.,2.,2.,2.,2.,2.,2.,2.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dct2 +2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.nearest BoundType.dct2 +0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dct2 +0., # InterpolationType.nearest BoundType.dct2 +2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1., 0., # InterpolationType.linear BoundType.dct2 +2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.linear BoundType.dct2 +1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1., 0., # InterpolationType.linear BoundType.dct2 +0., # InterpolationType.linear BoundType.dct2 +2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1., 0., # InterpolationType.quadratic BoundType.dct2 +2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.quadratic BoundType.dct2 +1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1., 0., # InterpolationType.quadratic BoundType.dct2 +0., # InterpolationType.quadratic BoundType.dct2 +1.9999999, 2. , 1.9999999, 1.9999999, 1.9999999, 1.9999999, 1.9999999, 1.9999999, 1.9999999, 2. , 0.875 , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.875 , 0. ,-0.875 ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-0.875 , 0. , # InterpolationType.cubic BoundType.dct2 +1.9999999,2. ,1.9999999,1.9999999,1.9999999,1.9999999,1.9999999,1.9999999,1.9999999,2. , # InterpolationType.cubic BoundType.dct2 +0.875, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.875, 0. ,-0.875,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-0.875, 0. , # InterpolationType.cubic BoundType.dct2 +0., # InterpolationType.cubic BoundType.dct2 +2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 8.3333337e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9999988e-01, 1.0000000e+00, 8.3333302e-01,-1.1920929e-07,-8.3333325e-01,-1.0000000e+00,-1.0000001e+00,-1.0000002e+00,-1.0000000e+00,-1.0000001e+00,-1.0000001e+00,-1.0000000e+00,-8.3333337e-01, 0., # InterpolationType.fourth BoundType.dct2 +2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.fourth BoundType.dct2 +8.3333337e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9999988e-01, 1.0000000e+00, 8.3333302e-01,-1.1920929e-07,-8.3333325e-01,-1.0000000e+00,-1.0000001e+00,-1.0000002e+00,-1.0000000e+00,-1.0000001e+00,-1.0000001e+00,-1.0000000e+00,-8.3333337e-01, 0., # InterpolationType.fourth BoundType.dct2 +0., # InterpolationType.fourth BoundType.dct2 +2.0000000e+00, 1.9999999e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 7.9687500e-01, 9.9739581e-01, 1.0000000e+00, 1.0000000e+00, 9.9999994e-01, 1.0000001e+00, 9.9999976e-01, 9.9739575e-01, 7.9687530e-01, 1.6018748e-07,-7.9687524e-01,-9.9739569e-01,-9.9999982e-01,-1.0000002e+00,-1.0000000e+00,-9.9999982e-01,-1.0000000e+00,-9.9739587e-01,-7.9687494e-01, 5.1222742e-09, # InterpolationType.fifth BoundType.dct2 +2. ,1.9999999,2. ,2. ,2. ,2. ,2. ,2. ,2. ,2. , # InterpolationType.fifth BoundType.dct2 +7.9687500e-01, 9.9739581e-01, 1.0000000e+00, 1.0000000e+00, 9.9999994e-01, 1.0000001e+00, 9.9999976e-01, 9.9739575e-01, 7.9687530e-01, 1.6018748e-07,-7.9687524e-01,-9.9739569e-01,-9.9999982e-01,-1.0000002e+00,-1.0000000e+00,-9.9999982e-01,-1.0000000e+00,-9.9739587e-01,-7.9687494e-01, 5.1222742e-09, # InterpolationType.fifth BoundType.dct2 +0., # InterpolationType.fifth BoundType.dct2 +2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 7.6666665e-01, 9.9166673e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9999982e-01, 1.0000004e+00, 9.9166673e-01, 7.6666647e-01, 5.9604645e-08,-7.6666659e-01,-9.9166662e-01,-1.0000004e+00,-1.0000001e+00,-1.0000001e+00,-1.0000000e+00,-9.9999994e-01,-9.9166667e-01,-7.6666665e-01, 1.8626451e-09, # InterpolationType.sixth BoundType.dct2 +2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.sixth BoundType.dct2 +7.6666665e-01, 9.9166673e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9999982e-01, 1.0000004e+00, 9.9166673e-01, 7.6666647e-01, 5.9604645e-08,-7.6666659e-01,-9.9166662e-01,-1.0000004e+00,-1.0000001e+00,-1.0000001e+00,-1.0000000e+00,-9.9999994e-01,-9.9166667e-01,-7.6666665e-01, 1.8626451e-09, # InterpolationType.sixth BoundType.dct2 +0., # InterpolationType.sixth BoundType.dct2 +2.0000002e+00, 2.0000000e+00, 2.0000000e+00, 2.0000002e+00, 2.0000002e+00, 2.0000002e+00, 2.0000002e+00, 2.0000002e+00, 2.0000002e+00, 2.0000002e+00, 7.3982203e-01, 9.8428816e-01, 9.9997836e-01, 9.9999994e-01, 1.0000000e+00, 1.0000001e+00, 9.9997842e-01, 9.8428833e-01, 7.3982203e-01,-1.6936974e-07,-7.3982191e-01,-9.8428810e-01,-9.9997830e-01,-1.0000000e+00,-1.0000000e+00,-1.0000001e+00,-9.9997824e-01,-9.8428822e-01,-7.3982203e-01,-2.7284841e-09, # InterpolationType.seventh BoundType.dct2 +2.0000002,2. ,2. ,2.0000002,2.0000002,2.0000002,2.0000002,2.0000002,2.0000002,2.0000002, # InterpolationType.seventh BoundType.dct2 +7.3982203e-01, 9.8428816e-01, 9.9997836e-01, 9.9999994e-01, 1.0000000e+00, 1.0000001e+00, 9.9997842e-01, 9.8428833e-01, 7.3982203e-01,-1.6936974e-07,-7.3982191e-01,-9.8428810e-01,-9.9997830e-01,-1.0000000e+00,-1.0000000e+00,-1.0000001e+00,-9.9997824e-01,-9.8428822e-01,-7.3982203e-01,-2.7284841e-09, # InterpolationType.seventh BoundType.dct2 +0., # InterpolationType.seventh BoundType.dct2 +-1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., # InterpolationType.nearest BoundType.dst1 +-1., 0., 0., 0., 0., 0., 0., 0., 0., 0., # InterpolationType.nearest BoundType.dst1 +0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dst1 +0., # InterpolationType.nearest BoundType.dst1 +0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-9.,-9., 1., 1., 1., 1., 1., 1., 1., 1., 1., # InterpolationType.linear BoundType.dst1 +0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.linear BoundType.dst1 +1., 1., 1., 1., 1., 1., 1., 1., 1.,-9.,-9., 1., 1., 1., 1., 1., 1., 1., 1., 1., # InterpolationType.linear BoundType.dst1 +0., # InterpolationType.linear BoundType.dst1 +0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-9.,-9., 1., 1., 1., 1., 1., 1., 1., 1., 1., # InterpolationType.quadratic BoundType.dst1 +0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.quadratic BoundType.dst1 +1., 1., 1., 1., 1., 1., 1., 1., 1.,-9.,-9., 1., 1., 1., 1., 1., 1., 1., 1., 1., # InterpolationType.quadratic BoundType.dst1 +0., # InterpolationType.quadratic BoundType.dst1 +0., 0.,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08, 8.7500000e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,-2.5000000e-01,-7.7500000e+00,-7.7500000e+00,-2.5000000e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 8.7500000e-01, # InterpolationType.cubic BoundType.dst1 +0., 0.,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08, # InterpolationType.cubic BoundType.dst1 +0.875, 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-0.25 ,-7.75 ,-7.75 ,-0.25 , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.875, # InterpolationType.cubic BoundType.dst1 +0., # InterpolationType.cubic BoundType.dst1 +0., 0.,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-1.1175871e-08, 8.3333337e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9999988e-01, 1.0000000e+00,-6.6666698e-01,-7.3333335e+00,-7.3333335e+00,-6.6666675e-01, 1.0000000e+00, 1.0000001e+00, 1.0000002e+00, 1.0000000e+00, 1.0000001e+00, 1.0000001e+00, 1.0000000e+00, 8.3333337e-01, # InterpolationType.fourth BoundType.dst1 +0., 0.,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-1.1175871e-08, # InterpolationType.fourth BoundType.dst1 +0.8333334 , 1. , 1. , 1. , 1. , 1. , 0.9999999 , 1. ,-0.666667 ,-7.3333335 ,-7.3333335 ,-0.66666675, 1. , 1.0000001 , 1.0000002 , 1. , 1.0000001 , 1.0000001 , 1. , 0.8333334 , # InterpolationType.fourth BoundType.dst1 +0., # InterpolationType.fourth BoundType.dst1 +3.9872248e-09, 0., 1.1175871e-08, 7.1886461e-09, 7.1886461e-09, 7.1886461e-09, 7.1886461e-09, 7.1886461e-09, 7.1886461e-09, 7.1886461e-09, 7.9947913e-01, 9.9739581e-01, 1.0000000e+00, 1.0000000e+00, 9.9999994e-01, 1.0000001e+00, 9.9999976e-01, 9.7395825e-01,-1.0052080e+00,-6.9687500e+00,-6.9687500e+00,-1.0052083e+00, 9.7395819e-01, 9.9999982e-01, 1.0000002e+00, 1.0000000e+00, 9.9999982e-01, 1.0000000e+00, 9.9739587e-01, 7.9947913e-01, # InterpolationType.fifth BoundType.dst1 +3.9872248e-09,0.,1.1175871e-08,7.1886461e-09,7.1886461e-09,7.1886461e-09,7.1886461e-09,7.1886461e-09,7.1886461e-09,7.1886461e-09, # InterpolationType.fifth BoundType.dst1 +0.7994791 , 0.9973958 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.97395825,-1.005208 ,-6.96875 ,-6.96875 ,-1.0052083 , 0.9739582 , 0.9999998 , 1.0000002 , 1. , 0.9999998 , 1. , 0.9973959 , 0.7994791 , # InterpolationType.fifth BoundType.dst1 +0., # InterpolationType.fifth BoundType.dst1 +4.1094609e-08, 0.,-1.4901161e-08, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09,-2.6193447e-08, 7.7499998e-01, 9.9166673e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9999982e-01, 1.0000004e+00, 9.1666675e-01,-1.2500002e+00,-6.6666665e+00,-6.6666665e+00,-1.2500000e+00, 9.1666681e-01, 1.0000004e+00, 1.0000001e+00, 1.0000001e+00, 1.0000000e+00, 9.9999994e-01, 9.9166667e-01, 7.7499998e-01, # InterpolationType.sixth BoundType.dst1 +4.1094609e-08, 0.,-1.4901161e-08, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09,-2.6193447e-08, # InterpolationType.sixth BoundType.dst1 +0.775 , 0.99166673, 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.91666675,-1.2500002 ,-6.6666665 ,-6.6666665 ,-1.25 , 0.9166668 , 1.0000004 , 1.0000001 , 1.0000001 , 1. , 0.99999994, 0.9916667 , 0.775 , # InterpolationType.sixth BoundType.dst1 +0., # InterpolationType.sixth BoundType.dst1 +-9.7788870e-09, 3.7846348e-10,-7.4505806e-09, 2.3283064e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 7.5553381e-01, 9.8430985e-01, 9.9997836e-01, 9.9999994e-01, 1.0000000e+00, 1.0000001e+00, 9.9978310e-01, 8.4309906e-01,-1.4446614e+00,-6.3982205e+00,-6.3982205e+00,-1.4446614e+00, 8.4309900e-01, 9.9978304e-01, 1.0000000e+00, 1.0000000e+00, 1.0000001e+00, 9.9997824e-01, 9.8430991e-01, 7.5553381e-01, # InterpolationType.seventh BoundType.dst1 +-9.7788870e-09, 3.7846348e-10,-7.4505806e-09, 2.3283064e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, # InterpolationType.seventh BoundType.dst1 +0.7555338 , 0.98430985, 0.99997836, 0.99999994, 1. , 1.0000001 , 0.9997831 , 0.84309906,-1.4446614 ,-6.3982205 ,-6.3982205 ,-1.4446614 , 0.843099 , 0.99978304, 1. , 1. , 1.0000001 , 0.99997824, 0.9843099 , 0.7555338 , # InterpolationType.seventh BoundType.dst1 +0., # InterpolationType.seventh BoundType.dst1 +0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dst2 +0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dst2 +0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dst2 +0., # InterpolationType.nearest BoundType.dst2 + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-18., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., # InterpolationType.linear BoundType.dst2 +0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.linear BoundType.dst2 + 1., 1., 1., 1., 1., 1., 1., 1., 1.,-18., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., # InterpolationType.linear BoundType.dst2 +0., # InterpolationType.linear BoundType.dst2 + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-18., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., # InterpolationType.quadratic BoundType.dst2 +0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.quadratic BoundType.dst2 + 1., 1., 1., 1., 1., 1., 1., 1., 1.,-18., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., # InterpolationType.quadratic BoundType.dst2 +0., # InterpolationType.quadratic BoundType.dst2 +0., 0.,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08, 9.3132257e-09, 8.7500000e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,-1.3750000e+00,-1.3250000e+01,-1.3750000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 8.7500000e-01, 2.5000000e-01, # InterpolationType.cubic BoundType.dst2 +0., 0.,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08, 9.3132257e-09, # InterpolationType.cubic BoundType.dst2 + 0.875, 1. , 1. , 1. , 1. , 1. , 1. , 1. , -1.375,-13.25 , -1.375, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.875, 0.25 , # InterpolationType.cubic BoundType.dst2 +0., # InterpolationType.cubic BoundType.dst2 +0., 0.,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-1.1175871e-08, 8.3333337e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9999988e-01, 1.0000000e+00,-2.1666670e+00,-1.1666667e+01,-2.1666667e+00, 1.0000000e+00, 1.0000001e+00, 1.0000002e+00, 1.0000000e+00, 1.0000001e+00, 1.0000001e+00, 1.0000000e+00, 8.3333337e-01, 3.3333334e-01, # InterpolationType.fourth BoundType.dst2 +0., 0.,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-1.1175871e-08, # InterpolationType.fourth BoundType.dst2 + 0.8333334 , 1. , 1. , 1. , 1. , 1. , 0.9999999 , 1. , -2.166667 ,-11.666667 , -2.1666667 , 1. , 1.0000001 , 1.0000002 , 1. , 1.0000001 , 1.0000001 , 1. , 0.8333334 , 0.33333334, # InterpolationType.fourth BoundType.dst2 +0., # InterpolationType.fourth BoundType.dst2 +0., 3.7252903e-09, 1.1175871e-08, 7.1886461e-09, 7.1886461e-09, 7.1886461e-09, 7.1886461e-09, 7.1886461e-09, 7.1886461e-09, 1.0913936e-08, 8.0208331e-01, 9.9739581e-01, 1.0000000e+00, 1.0000000e+00, 9.9999994e-01, 1.0000001e+00, 9.9999976e-01, 9.5052075e-01,-2.7604163e+00,-1.0380208e+01,-2.7604165e+00, 9.5052069e-01, 9.9999982e-01, 1.0000002e+00, 1.0000000e+00, 9.9999982e-01, 1.0000000e+00, 9.9739587e-01, 8.0208331e-01, 4.0104166e-01, # InterpolationType.fifth BoundType.dst2 +0.,3.7252903e-09,1.1175871e-08,7.1886461e-09,7.1886461e-09,7.1886461e-09,7.1886461e-09,7.1886461e-09,7.1886461e-09,1.0913936e-08, # InterpolationType.fifth BoundType.dst2 + 0.8020833 , 0.9973958 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.95052075, -2.7604163 ,-10.380208 , -2.7604165 , 0.9505207 , 0.9999998 , 1.0000002 , 1. , 0.9999998 , 1. , 0.9973959 , 0.8020833 , 0.40104166, # InterpolationType.fifth BoundType.dst2 +0., # InterpolationType.fifth BoundType.dst2 +5.9604645e-08,-1.4901161e-08,-1.4901161e-08, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09,-1.1292286e-08, 7.8333330e-01, 9.9166673e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9999982e-01, 1.0000004e+00, 8.4166676e-01,-3.1166668e+00,-9.4499998e+00,-3.1166666e+00, 8.4166652e-01, 1.0000004e+00, 1.0000001e+00, 1.0000001e+00, 1.0000000e+00, 9.9999994e-01, 9.9166667e-01, 7.8333330e-01, 4.5000002e-01, # InterpolationType.sixth BoundType.dst2 +5.9604645e-08,-1.4901161e-08,-1.4901161e-08, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09,-1.1292286e-08, # InterpolationType.sixth BoundType.dst2 +0.7833333 , 0.99166673, 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.84166676,-3.1166668 ,-9.45 ,-3.1166666 , 0.8416665 , 1.0000004 , 1.0000001 , 1.0000001 , 1. , 0.99999994, 0.9916667 , 0.7833333 , 0.45000002, # InterpolationType.sixth BoundType.dst2 +0., # InterpolationType.sixth BoundType.dst2 +0.,-7.4505806e-09,-6.9849193e-09, 2.3283064e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09,-5.0350764e-09, 7.7120221e-01, 9.8433155e-01, 9.9997836e-01, 9.9999994e-01, 1.0000000e+00, 1.0000001e+00, 9.9958777e-01, 7.0230043e-01,-3.3471570e+00,-8.7094622e+00,-3.3471570e+00, 7.0230043e-01, 9.9958777e-01, 1.0000000e+00, 1.0000000e+00, 1.0000001e+00, 9.9997824e-01, 9.8433161e-01, 7.7120221e-01, 4.8897570e-01, # InterpolationType.seventh BoundType.dst2 +0.,-7.4505806e-09,-6.9849193e-09, 2.3283064e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09,-5.0350764e-09, # InterpolationType.seventh BoundType.dst2 +0.7712022 , 0.98433155, 0.99997836, 0.99999994, 1. , 1.0000001 , 0.9995878 , 0.7023004 ,-3.347157 ,-8.709462 ,-3.347157 , 0.7023004 , 0.9995878 , 1. , 1. , 1.0000001 , 0.99997824, 0.9843316 , 0.7712022 , 0.4889757 , # InterpolationType.seventh BoundType.dst2 +0., # InterpolationType.seventh BoundType.dst2 +2.,2.,2.,2.,2.,2.,2.,2.,2.,2.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dft +2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.nearest BoundType.dft +0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dft +0., # InterpolationType.nearest BoundType.dft +2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., # InterpolationType.linear BoundType.dft +2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.linear BoundType.dft +1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., # InterpolationType.linear BoundType.dft +0., # InterpolationType.linear BoundType.dft +2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., # InterpolationType.quadratic BoundType.dft +2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.quadratic BoundType.dft +1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., # InterpolationType.quadratic BoundType.dft +0., # InterpolationType.quadratic BoundType.dft +2. , 2. , 1.9999999, 1.9999999, 1.9999999, 1.9999999, 1.9999999, 1.9999999, 1.9999999, 2. ,-0.25 , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-0.25 ,-6.5 ,-0.25 , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-0.25 ,-6.5 , # InterpolationType.cubic BoundType.dft +2. ,2. ,1.9999999,1.9999999,1.9999999,1.9999999,1.9999999,1.9999999,1.9999999,2. , # InterpolationType.cubic BoundType.dft +-0.25, 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-0.25,-6.5 ,-0.25, 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-0.25,-6.5 , # InterpolationType.cubic BoundType.dft +0., # InterpolationType.cubic BoundType.dft +2. , 2. , 2. , 2. , 2. , 2. , 2. , 2. , 2. , 2. ,-0.6666666, 1. , 1. , 1. , 1. , 1. , 0.9999999, 1. ,-0.666667 ,-5.666667 ,-0.6666666, 1. , 1. , 1. , 1. , 1. , 0.9999999, 1. ,-0.666667 ,-5.666667 , # InterpolationType.fourth BoundType.dft +2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.fourth BoundType.dft +-0.6666666, 1. , 1. , 1. , 1. , 1. , 0.9999999, 1. ,-0.666667 ,-5.666667 ,-0.6666666, 1. , 1. , 1. , 1. , 1. , 0.9999999, 1. ,-0.666667 ,-5.666667 , # InterpolationType.fourth BoundType.dft +0., # InterpolationType.fourth BoundType.dft +2. , 2. , 2. , 2. , 2. , 2. , 2. , 2. , 2. , 2. ,-0.97916675, 0.9739583 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.97395825,-0.9791663 ,-4.989583 ,-0.97916675, 0.9739583 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.97395825,-0.9791663 ,-4.989583 , # InterpolationType.fifth BoundType.dft +2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.fifth BoundType.dft +-0.97916675, 0.9739583 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.97395825,-0.9791663 ,-4.989583 ,-0.97916675, 0.9739583 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.97395825,-0.9791663 ,-4.989583 , # InterpolationType.fifth BoundType.dft +0., # InterpolationType.fifth BoundType.dft +1.9999999 , 2. , 2. , 2. , 2. , 2. , 2. , 2. , 2. , 2. ,-1.1666667 , 0.9166667 , 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.91666675,-1.1666669 ,-4.4999995 ,-1.1666667 , 0.9166667 , 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.91666675,-1.1666669 ,-4.4999995 , # InterpolationType.sixth BoundType.dft +1.9999999,2. ,2. ,2. ,2. ,2. ,2. ,2. ,2. ,2. , # InterpolationType.sixth BoundType.dft +-1.1666667 , 0.9166667 , 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.91666675,-1.1666669 ,-4.4999995 ,-1.1666667 , 0.9166667 , 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.91666675,-1.1666669 ,-4.4999995 , # InterpolationType.sixth BoundType.dft +0., # InterpolationType.sixth BoundType.dft +2.0000002 , 2.0000002 , 2.0000002 , 2.0000002 , 2.0000002 , 2.0000002 , 2.0000002 , 2.0000002 , 2.0000002 , 2.0000002 ,-1.2879773 , 0.84331596, 0.999783 , 0.99999994, 1. , 1.0000001 , 0.9997831 , 0.8433161 ,-1.2879775 ,-4.110243 ,-1.2879773 , 0.84331596, 0.999783 , 0.99999994, 1. , 1.0000001 , 0.9997831 , 0.8433161 ,-1.2879775 ,-4.110243 , # InterpolationType.seventh BoundType.dft +2.0000002,2.0000002,2.0000002,2.0000002,2.0000002,2.0000002,2.0000002,2.0000002,2.0000002,2.0000002, # InterpolationType.seventh BoundType.dft +-1.2879773 , 0.84331596, 0.999783 , 0.99999994, 1. , 1.0000001 , 0.9997831 , 0.8433161 ,-1.2879775 ,-4.110243 ,-1.2879773 , 0.84331596, 0.999783 , 0.99999994, 1. , 1.0000001 , 0.9997831 , 0.8433161 ,-1.2879775 ,-4.110243 , # InterpolationType.seventh BoundType.dft +0., # InterpolationType.seventh BoundType.dft +0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.zero +0.,1.,1.,1.,1.,1.,1.,1.,1.,1., # InterpolationType.nearest BoundType.zero +0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.zero +0., # InterpolationType.nearest BoundType.zero +0.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-9. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.linear BoundType.zero +0.5,1. ,1. ,1. ,1. ,1. ,1. ,1. ,1. ,1. , # InterpolationType.linear BoundType.zero +1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., # InterpolationType.linear BoundType.zero +0., # InterpolationType.linear BoundType.zero +0.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-9. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.quadratic BoundType.zero +0.5,1. ,1. ,1. ,1. ,1. ,1. ,1. ,1. ,1. , # InterpolationType.quadratic BoundType.zero +1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., # InterpolationType.quadratic BoundType.zero +0., # InterpolationType.quadratic BoundType.zero +0.5 , 0.9791666 , 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.875 , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-0.25 ,-6.625 ,-1.125 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.cubic BoundType.zero +0.5 ,0.9791666 ,0.99999994,0.99999994,0.99999994,0.99999994,0.99999994,0.99999994,0.99999994,0.99999994, # InterpolationType.cubic BoundType.zero +0.875, 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-0.25 ,-6.625,-1.125, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.cubic BoundType.zero +0., # InterpolationType.cubic BoundType.zero +0.5 , 0.9583334, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.8333334, 1. , 1. , 1. , 1. , 1. , 0.9999999, 1. ,-0.666667 ,-5.8333335,-1.5 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.fourth BoundType.zero +0.5 ,0.9583334,1. ,1. ,1. ,1. ,1. ,1. ,1. ,1. , # InterpolationType.fourth BoundType.zero +0.8333334, 1. , 1. , 1. , 1. , 1. , 0.9999999, 1. ,-0.666667 ,-5.8333335,-1.5 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.fourth BoundType.zero +0., # InterpolationType.fourth BoundType.zero +0.5 , 0.9380208 , 0.9997396 , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.7994791 , 0.9973958 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.97395825,-0.9817705 ,-5.190104 ,-1.7786459 ,-0.0234375 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.fifth BoundType.zero +0.5 ,0.9380208,0.9997396,1. ,1. ,1. ,1. ,1. ,1. ,1. , # InterpolationType.fifth BoundType.zero +0.7994791 , 0.9973958 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.97395825,-0.9817705 ,-5.190104 ,-1.7786459 ,-0.0234375 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.fifth BoundType.zero +0., # InterpolationType.fifth BoundType.zero +0.49999997, 0.91944444, 0.9986111 , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.775 , 0.99166673, 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.91666675,-1.1750002 ,-4.725 ,-1.9416667 ,-0.075 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.sixth BoundType.zero +0.49999997,0.91944444,0.9986111 ,1. ,1. ,1. ,1. ,1. ,1. ,1. , # InterpolationType.sixth BoundType.zero +0.775 , 0.99166673, 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.91666675,-1.1750002 ,-4.725 ,-1.9416667 ,-0.075 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.sixth BoundType.zero +0., # InterpolationType.sixth BoundType.zero +5.0000000e-01, 9.0259641e-01, 9.9662077e-01, 9.9999845e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 7.5551212e-01, 9.8430985e-01, 9.9997836e-01, 9.9999994e-01, 1.0000000e+00, 1.0000001e+00, 9.9978310e-01, 8.4329438e-01,-1.3036675e+00,-4.3547311e+00,-2.0434895e+00,-1.4099392e-01,-1.9531250e-04, 0., 0., 0., 0., 0., 0., 0., # InterpolationType.seventh BoundType.zero +0.5 ,0.9025964 ,0.9966208 ,0.99999845,1. ,1. ,1. ,1. ,1. ,1. , # InterpolationType.seventh BoundType.zero +7.5551212e-01, 9.8430985e-01, 9.9997836e-01, 9.9999994e-01, 1.0000000e+00, 1.0000001e+00, 9.9978310e-01, 8.4329438e-01,-1.3036675e+00,-4.3547311e+00,-2.0434895e+00,-1.4099392e-01,-1.9531250e-04, 0., 0., 0., 0., 0., 0., 0., # InterpolationType.seventh BoundType.zero +0., # InterpolationType.seventh BoundType.zero diff --git a/tests/testing_data/cpp_resample_answers.py b/tests/testing_data/cpp_resample_answers.py index b50b1d7395..51ac6ccda9 100644 --- a/tests/testing_data/cpp_resample_answers.py +++ b/tests/testing_data/cpp_resample_answers.py @@ -38,3 +38,4 @@ def _read_testing_data_answers(fname: Optional[str] = None, delimiter=",") -> Li Expected_1D_GP_fwd: List = _read_testing_data_answers(fname="1D_BP_fwd.txt") +Expected_1D_GP_bwd: List = _read_testing_data_answers(fname="1D_BP_bwd.txt")