diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 63086172c97..7a0defaccd6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -543,7 +543,38 @@ struct ResolveIfElseExec { } }; -void AddPrimitiveIfElseKernels(const std::shared_ptr& scalar_function, +struct IfElseFunction : ScalarFunction { + using ScalarFunction::ScalarFunction; + + Result DispatchBest(std::vector* values) const override { + RETURN_NOT_OK(CheckArity(*values)); + + using arrow::compute::detail::DispatchExactImpl; + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + + // if 0th descriptor is null, replace with bool + if (values->at(0).type->id() == Type::NA) { + values->at(0).type = boolean(); + } + + // if-else 0'th descriptor is bool, so skip it + std::vector values_copy(values->begin() + 1, values->end()); + internal::EnsureDictionaryDecoded(&values_copy); + internal::ReplaceNullWithOtherType(&values_copy); + + if (auto type = internal::CommonNumeric(values_copy)) { + internal::ReplaceTypes(type, &values_copy); + } + + std::move(values_copy.begin(), values_copy.end(), values->begin() + 1); + + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + + return arrow::compute::detail::NoMatchingKernel(this, *values); + } +}; + +void AddPrimitiveIfElseKernels(const std::shared_ptr& scalar_function, const std::vector>& types) { for (auto&& type : types) { auto exec = internal::GenerateTypeAgnosticPrimitive(*type); @@ -572,7 +603,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { scalar_kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; scalar_kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; - auto func = std::make_shared("if_else", Arity::Ternary(), &if_else_doc); + auto func = std::make_shared("if_else", Arity::Ternary(), &if_else_doc); AddPrimitiveIfElseKernels(func, NumericTypes()); AddPrimitiveIfElseKernels(func, TemporalTypes()); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 5d3d22210d2..0fb0a1fc2d8 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -271,5 +271,42 @@ TEST_F(TestIfElseKernel, IfElseNull) { ArrayFromJSON(null(), "[null, null, null, null]")); } +TEST_F(TestIfElseKernel, IfElseMultiType) { + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"), + ArrayFromJSON(int32(), "[1, 2, 3, 4]"), + ArrayFromJSON(float32(), "[5, 6, 7, 8]"), + ArrayFromJSON(float32(), "[1, 2, 3, 8]")); +} + +TEST_F(TestIfElseKernel, IfElseDispatchBest) { + std::string name = "if_else"; + CheckDispatchBest(name, {boolean(), int32(), int32()}, {boolean(), int32(), int32()}); + CheckDispatchBest(name, {boolean(), int32(), null()}, {boolean(), int32(), int32()}); + CheckDispatchBest(name, {boolean(), null(), int32()}, {boolean(), int32(), int32()}); + + CheckDispatchBest(name, {boolean(), int32(), int8()}, {boolean(), int32(), int32()}); + CheckDispatchBest(name, {boolean(), int32(), int16()}, {boolean(), int32(), int32()}); + CheckDispatchBest(name, {boolean(), int32(), int32()}, {boolean(), int32(), int32()}); + CheckDispatchBest(name, {boolean(), int32(), int64()}, {boolean(), int64(), int64()}); + + CheckDispatchBest(name, {boolean(), int32(), uint8()}, {boolean(), int32(), int32()}); + CheckDispatchBest(name, {boolean(), int32(), uint16()}, {boolean(), int32(), int32()}); + CheckDispatchBest(name, {boolean(), int32(), uint32()}, {boolean(), int64(), int64()}); + CheckDispatchBest(name, {boolean(), int32(), uint64()}, {boolean(), int64(), int64()}); + + CheckDispatchBest(name, {boolean(), uint8(), uint8()}, {boolean(), uint8(), uint8()}); + CheckDispatchBest(name, {boolean(), uint8(), uint16()}, + {boolean(), uint16(), uint16()}); + + CheckDispatchBest(name, {boolean(), int32(), float32()}, + {boolean(), float32(), float32()}); + CheckDispatchBest(name, {boolean(), float32(), int64()}, + {boolean(), float32(), float32()}); + CheckDispatchBest(name, {boolean(), float64(), int32()}, + {boolean(), float64(), float64()}); + + CheckDispatchBest(name, {null(), uint8(), int8()}, {boolean(), int16(), int16()}); +} + } // namespace compute } // namespace arrow