From 1ac790264fd068d34330c07c83d58149402d1434 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 7 Jun 2021 15:30:26 -0400 Subject: [PATCH 1/5] adding dispatch best --- .../arrow/compute/kernels/scalar_if_else.cc | 34 +++++++++++++++++-- .../compute/kernels/scalar_if_else_test.cc | 7 ++++ 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 63086172c97..c3c46aea550 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -543,7 +543,37 @@ struct ResolveIfElseExec { } }; -void AddPrimitiveIfElseKernels(const std::shared_ptr& scalar_function, +struct IfElseFunction : ScalarFunction { + using ScalarFunction::ScalarFunction; + + Result DispatchBest(std::vector* values) const override { + RETURN_NOT_OK(CheckArity(*values)); + + // if-else 0'th descriptor is bool + std::vector left_right{(*values)[1], (*values)[2]}; + + using arrow::compute::detail::DispatchExactImpl; + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + + // internal::EnsureDictionaryDecoded(values); + + if (values->size() == 3) { + internal::ReplaceNullWithOtherType(&left_right); + + if (auto type = internal::CommonNumeric(left_right)) { + internal::ReplaceTypes(type, &left_right); + } + } + + if (auto kernel = + DispatchExactImpl(this, {(*values)[0], left_right[0], left_right[1]})) { + return kernel; + } + return arrow::compute::detail::NoMatchingKernel(this, *values); + } +}; + +void AddPrimitiveIfElseKernels(const std::shared_ptr& scalar_function, const std::vector>& types) { for (auto&& type : types) { auto exec = internal::GenerateTypeAgnosticPrimitive(*type); @@ -572,7 +602,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { scalar_kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; scalar_kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; - auto func = std::make_shared("if_else", Arity::Ternary(), &if_else_doc); + auto func = std::make_shared("if_else", Arity::Ternary(), &if_else_doc); AddPrimitiveIfElseKernels(func, NumericTypes()); AddPrimitiveIfElseKernels(func, TemporalTypes()); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 5d3d22210d2..4a7493a35fc 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -271,5 +271,12 @@ TEST_F(TestIfElseKernel, IfElseNull) { ArrayFromJSON(null(), "[null, null, null, null]")); } +TEST_F(TestIfElseKernel, IfElseMultiType) { + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"), + ArrayFromJSON(int32(), "[1, 2, 3, 4]"), + ArrayFromJSON(float32(), "[5, 6, 7, 8]"), + ArrayFromJSON(float32(), "[1, 2, 3, 8]")); +} + } // namespace compute } // namespace arrow From 07432422486bf64191c37607ecfd8cd646793d25 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 7 Jun 2021 16:25:53 -0400 Subject: [PATCH 2/5] fixing errors --- .../arrow/compute/kernels/scalar_if_else.cc | 23 +++++++-------- .../compute/kernels/scalar_if_else_test.cc | 28 +++++++++++++++++++ 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index c3c46aea550..a0210be658a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -549,26 +549,23 @@ struct IfElseFunction : ScalarFunction { Result DispatchBest(std::vector* values) const override { RETURN_NOT_OK(CheckArity(*values)); - // if-else 0'th descriptor is bool - std::vector left_right{(*values)[1], (*values)[2]}; + // if-else 0'th descriptor is bool, so skip it + std::vector values_copy(values->begin() + 1, values->end()); using arrow::compute::detail::DispatchExactImpl; if (auto kernel = DispatchExactImpl(this, *values)) return kernel; - // internal::EnsureDictionaryDecoded(values); + internal::EnsureDictionaryDecoded(&values_copy); + internal::ReplaceNullWithOtherType(&values_copy); - if (values->size() == 3) { - internal::ReplaceNullWithOtherType(&left_right); - - if (auto type = internal::CommonNumeric(left_right)) { - internal::ReplaceTypes(type, &left_right); - } + if (auto type = internal::CommonNumeric(values_copy)) { + internal::ReplaceTypes(type, &values_copy); } - if (auto kernel = - DispatchExactImpl(this, {(*values)[0], left_right[0], left_right[1]})) { - return kernel; - } + std::copy(values_copy.begin(), values_copy.end(), values->begin() + 1); + + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + return arrow::compute::detail::NoMatchingKernel(this, *values); } }; diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 4a7493a35fc..aa33e73677e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -278,5 +278,33 @@ TEST_F(TestIfElseKernel, IfElseMultiType) { ArrayFromJSON(float32(), "[1, 2, 3, 8]")); } +TEST_F(TestIfElseKernel, IfElseDispatchBest) { + std::string name = "if_else"; + CheckDispatchBest(name, {boolean(), int32(), int32()}, {boolean(), int32(), int32()}); + CheckDispatchBest(name, {boolean(), int32(), null()}, {boolean(), int32(), int32()}); + CheckDispatchBest(name, {boolean(), null(), int32()}, {boolean(), int32(), int32()}); + + CheckDispatchBest(name, {boolean(), int32(), int8()}, {boolean(), int32(), int32()}); + CheckDispatchBest(name, {boolean(), int32(), int16()}, {boolean(), int32(), int32()}); + CheckDispatchBest(name, {boolean(), int32(), int32()}, {boolean(), int32(), int32()}); + CheckDispatchBest(name, {boolean(), int32(), int64()}, {boolean(), int64(), int64()}); + + CheckDispatchBest(name, {boolean(), int32(), uint8()}, {boolean(), int32(), int32()}); + CheckDispatchBest(name, {boolean(), int32(), uint16()}, {boolean(), int32(), int32()}); + CheckDispatchBest(name, {boolean(), int32(), uint32()}, {boolean(), int64(), int64()}); + CheckDispatchBest(name, {boolean(), int32(), uint64()}, {boolean(), int64(), int64()}); + + CheckDispatchBest(name, {boolean(), uint8(), uint8()}, {boolean(), uint8(), uint8()}); + CheckDispatchBest(name, {boolean(), uint8(), uint16()}, + {boolean(), uint16(), uint16()}); + + CheckDispatchBest(name, {boolean(), int32(), float32()}, + {boolean(), float32(), float32()}); + CheckDispatchBest(name, {boolean(), float32(), int64()}, + {boolean(), float32(), float32()}); + CheckDispatchBest(name, {boolean(), float64(), int32()}, + {boolean(), float64(), float64()}); +} + } // namespace compute } // namespace arrow From bb9a0a93077456b8e50b97717aef13119b3cf1ba Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 7 Jun 2021 16:54:48 -0400 Subject: [PATCH 3/5] Update cpp/src/arrow/compute/kernels/scalar_if_else.cc Co-authored-by: Benjamin Kietzman --- cpp/src/arrow/compute/kernels/scalar_if_else.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index a0210be658a..04e3c5e9f58 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -549,12 +549,11 @@ struct IfElseFunction : ScalarFunction { Result DispatchBest(std::vector* values) const override { RETURN_NOT_OK(CheckArity(*values)); - // if-else 0'th descriptor is bool, so skip it - std::vector values_copy(values->begin() + 1, values->end()); - using arrow::compute::detail::DispatchExactImpl; if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + // if-else 0'th descriptor is bool, so skip it + std::vector values_copy(values->begin() + 1, values->end()); internal::EnsureDictionaryDecoded(&values_copy); internal::ReplaceNullWithOtherType(&values_copy); From de2c6ed0cfa121effcaeeb7655d89b67c0cb67ff Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 7 Jun 2021 17:10:13 -0400 Subject: [PATCH 4/5] Update cpp/src/arrow/compute/kernels/scalar_if_else.cc Co-authored-by: Benjamin Kietzman --- cpp/src/arrow/compute/kernels/scalar_if_else.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 04e3c5e9f58..ddbf1a26779 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -561,7 +561,7 @@ struct IfElseFunction : ScalarFunction { internal::ReplaceTypes(type, &values_copy); } - std::copy(values_copy.begin(), values_copy.end(), values->begin() + 1); + std::move(values_copy.begin(), values_copy.end(), values->begin() + 1); if (auto kernel = DispatchExactImpl(this, *values)) return kernel; From cda32a7431afb64ff52d106a1bb49d2e612aecc6 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 7 Jun 2021 17:14:24 -0400 Subject: [PATCH 5/5] allowing null type for 0th arg --- cpp/src/arrow/compute/kernels/scalar_if_else.cc | 5 +++++ cpp/src/arrow/compute/kernels/scalar_if_else_test.cc | 2 ++ 2 files changed, 7 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index ddbf1a26779..7a0defaccd6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -552,6 +552,11 @@ struct IfElseFunction : ScalarFunction { using arrow::compute::detail::DispatchExactImpl; if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + // if 0th descriptor is null, replace with bool + if (values->at(0).type->id() == Type::NA) { + values->at(0).type = boolean(); + } + // if-else 0'th descriptor is bool, so skip it std::vector values_copy(values->begin() + 1, values->end()); internal::EnsureDictionaryDecoded(&values_copy); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index aa33e73677e..0fb0a1fc2d8 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -304,6 +304,8 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) { {boolean(), float32(), float32()}); CheckDispatchBest(name, {boolean(), float64(), int32()}, {boolean(), float64(), float64()}); + + CheckDispatchBest(name, {null(), uint8(), int8()}, {boolean(), int16(), int16()}); } } // namespace compute