diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc index 090a901cb5e..8d3dcf0f2cd 100644 --- a/cpp/src/arrow/compute/exec.cc +++ b/cpp/src/arrow/compute/exec.cc @@ -33,6 +33,7 @@ #include "arrow/chunked_array.h" #include "arrow/compute/exec_internal.h" #include "arrow/compute/function.h" +#include "arrow/compute/function_internal.h" #include "arrow/compute/kernel.h" #include "arrow/compute/registry.h" #include "arrow/datum.h" @@ -883,6 +884,7 @@ class ScalarExecutor : public KernelExecutorImpl { } } if (kernel_->mem_allocation == MemAllocation::PREALLOCATE) { + data_preallocated_.clear(); ComputeDataPreallocate(*output_type_.type, &data_preallocated_); } @@ -966,6 +968,7 @@ class VectorExecutor : public KernelExecutorImpl { (kernel_->null_handling != NullHandling::COMPUTED_NO_PREALLOCATE && kernel_->null_handling != NullHandling::OUTPUT_NOT_NULL); if (kernel_->mem_allocation == MemAllocation::PREALLOCATE) { + data_preallocated_.clear(); ComputeDataPreallocate(*output_type_.type, &data_preallocated_); } @@ -1316,5 +1319,25 @@ Result CallFunction(const std::string& func_name, const ExecBatch& batch, return CallFunction(func_name, batch, /*options=*/nullptr, ctx); } +Result> GetFunctionExecutor( + const std::string& func_name, std::vector in_types, + const FunctionOptions* options, FunctionRegistry* func_registry) { + if (func_registry == NULLPTR) { + func_registry = GetFunctionRegistry(); + } + ARROW_ASSIGN_OR_RAISE(std::shared_ptr func, + func_registry->GetFunction(func_name)); + ARROW_ASSIGN_OR_RAISE(auto func_exec, func->GetBestExecutor(std::move(in_types))); + ARROW_RETURN_NOT_OK(func_exec->Init(options)); + return func_exec; +} + +Result> GetFunctionExecutor( + const std::string& func_name, const std::vector& args, + const FunctionOptions* options, FunctionRegistry* func_registry) { + ARROW_ASSIGN_OR_RAISE(auto in_types, internal::GetFunctionArgumentTypes(args)); + return GetFunctionExecutor(func_name, std::move(in_types), options, 
func_registry); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h index b7598593886..a1f72b3e501 100644 --- a/cpp/src/arrow/compute/exec.h +++ b/cpp/src/arrow/compute/exec.h @@ -47,9 +47,6 @@ class CpuInfo; namespace compute { -class FunctionOptions; -class FunctionRegistry; - // It seems like 64K might be a good default chunksize to use for execution // based on the experience of other query processing systems. The current // default is not to chunk contiguous arrays, though, but this may change in @@ -440,5 +437,30 @@ Result CallFunction(const std::string& func_name, const ExecBatch& batch, /// @} +/// \defgroup compute-function-executor One-shot calls to obtain function executors +/// +/// @{ + +/// \brief One-shot executor provider for all types of functions. +/// +/// This function creates and initializes a `FunctionExecutor` appropriate +/// for the given function name, input types and function options. +ARROW_EXPORT +Result> GetFunctionExecutor( + const std::string& func_name, std::vector in_types, + const FunctionOptions* options = NULLPTR, FunctionRegistry* func_registry = NULLPTR); + +/// \brief One-shot executor provider for all types of functions. +/// +/// This function creates and initializes a `FunctionExecutor` appropriate +/// for the given function name, input types (taken from the Datum arguments) +/// and function options. 
+ARROW_EXPORT +Result> GetFunctionExecutor( + const std::string& func_name, const std::vector& args, + const FunctionOptions* options = NULLPTR, FunctionRegistry* func_registry = NULLPTR); + +/// @} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec_test.cc b/cpp/src/arrow/compute/exec_test.cc index f18af71dba1..cab9bd6a1d6 100644 --- a/cpp/src/arrow/compute/exec_test.cc +++ b/cpp/src/arrow/compute/exec_test.cc @@ -959,7 +959,7 @@ struct ExampleState : public KernelState { Result> InitStateful(KernelContext*, const KernelInitArgs& args) { auto func_options = static_cast(args.options); - return std::make_unique(func_options->value); + return std::make_unique(func_options ? func_options->value : nullptr); } Status ExecStateful(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { @@ -1073,36 +1073,134 @@ class TestCallScalarFunction : public TestComputeInternals { bool TestCallScalarFunction::initialized_ = false; -TEST_F(TestCallScalarFunction, ArgumentValidation) { +class FunctionCaller { + public: + virtual ~FunctionCaller() = default; + + virtual Result Call(const std::vector& args, + const FunctionOptions* options, + ExecContext* ctx = NULLPTR) = 0; + virtual Result Call(const std::vector& args, + ExecContext* ctx = NULLPTR) = 0; +}; + +using FunctionCallerMaker = std::function>( + const std::string& func_name, std::vector in_types)>; + +class SimpleFunctionCaller : public FunctionCaller { + public: + explicit SimpleFunctionCaller(const std::string& func_name) : func_name(func_name) {} + + static Result> Make(const std::string& func_name) { + return std::make_shared(func_name); + } + + static Result> Maker(const std::string& func_name, + std::vector in_types) { + return Make(func_name); + } + + Result Call(const std::vector& args, const FunctionOptions* options, + ExecContext* ctx) override { + return CallFunction(func_name, args, options, ctx); + } + Result Call(const std::vector& args, ExecContext* ctx) override { + 
return CallFunction(func_name, args, ctx); + } + + std::string func_name; +}; + +class ExecFunctionCaller : public FunctionCaller { + public: + explicit ExecFunctionCaller(std::shared_ptr func_exec) + : func_exec(std::move(func_exec)) {} + + static Result> Make( + const std::string& func_name, const std::vector& args, + const FunctionOptions* options = nullptr, + FunctionRegistry* func_registry = nullptr) { + ARROW_ASSIGN_OR_RAISE(auto func_exec, + GetFunctionExecutor(func_name, args, options, func_registry)); + return std::make_shared(std::move(func_exec)); + } + + static Result> Make( + const std::string& func_name, std::vector in_types, + const FunctionOptions* options = nullptr, + FunctionRegistry* func_registry = nullptr) { + ARROW_ASSIGN_OR_RAISE( + auto func_exec, GetFunctionExecutor(func_name, in_types, options, func_registry)); + return std::make_shared(std::move(func_exec)); + } + + static Result> Maker(const std::string& func_name, + std::vector in_types) { + return Make(func_name, std::move(in_types)); + } + + Result Call(const std::vector& args, const FunctionOptions* options, + ExecContext* ctx) override { + ARROW_RETURN_NOT_OK(func_exec->Init(options, ctx)); + return func_exec->Execute(args); + } + Result Call(const std::vector& args, ExecContext* ctx) override { + return Call(args, nullptr, ctx); + } + + std::shared_ptr func_exec; +}; + +class TestCallScalarFunctionArgumentValidation : public TestCallScalarFunction { + protected: + void DoTest(FunctionCallerMaker caller_maker); +}; + +void TestCallScalarFunctionArgumentValidation::DoTest(FunctionCallerMaker caller_maker) { + ASSERT_OK_AND_ASSIGN(auto test_copy, caller_maker("test_copy", {int32()})); + // Copy accepts only a single array argument Datum d1(GetInt32Array(10)); // Too many args std::vector args = {d1, d1}; - ASSERT_RAISES(Invalid, CallFunction("test_copy", args)); + ASSERT_RAISES(Invalid, test_copy->Call(args)); // Too few args = {}; - ASSERT_RAISES(Invalid, CallFunction("test_copy", 
args)); + ASSERT_RAISES(Invalid, test_copy->Call(args)); // Cannot do scalar Datum d1_scalar(std::make_shared(5)); - ASSERT_OK_AND_ASSIGN(auto result, CallFunction("test_copy", {d1})); - ASSERT_OK_AND_ASSIGN(result, CallFunction("test_copy", {d1_scalar})); + ASSERT_OK_AND_ASSIGN(auto result, test_copy->Call({d1})); + ASSERT_OK_AND_ASSIGN(result, test_copy->Call({d1_scalar})); +} + +TEST_F(TestCallScalarFunctionArgumentValidation, SimpleCall) { + TestCallScalarFunctionArgumentValidation::DoTest(SimpleFunctionCaller::Maker); +} + +TEST_F(TestCallScalarFunctionArgumentValidation, ExecCall) { + TestCallScalarFunctionArgumentValidation::DoTest(ExecFunctionCaller::Maker); } -TEST_F(TestCallScalarFunction, PreallocationCases) { +class TestCallScalarFunctionPreallocationCases : public TestCallScalarFunction { + protected: + void DoTest(FunctionCallerMaker caller_maker); +}; + +void TestCallScalarFunctionPreallocationCases::DoTest(FunctionCallerMaker caller_maker) { double null_prob = 0.2; auto arr = GetUInt8Array(100, null_prob); - auto CheckFunction = [&](std::string func_name) { + auto CheckFunction = [&](std::shared_ptr test_copy) { ResetContexts(); // The default should be a single array output { std::vector args = {Datum(arr)}; - ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(func_name, args)); + ASSERT_OK_AND_ASSIGN(Datum result, test_copy->Call(args)); ASSERT_EQ(Datum::ARRAY, result.kind()); AssertArraysEqual(*arr, *result.make_array()); } @@ -1112,7 +1210,7 @@ TEST_F(TestCallScalarFunction, PreallocationCases) { { std::vector args = {Datum(arr)}; exec_ctx_->set_exec_chunksize(80); - ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(func_name, args, exec_ctx_.get())); + ASSERT_OK_AND_ASSIGN(Datum result, test_copy->Call(args, exec_ctx_.get())); AssertArraysEqual(*arr, *result.make_array()); } @@ -1120,7 +1218,7 @@ TEST_F(TestCallScalarFunction, PreallocationCases) { // Chunksize not multiple of 8 std::vector args = {Datum(arr)}; exec_ctx_->set_exec_chunksize(11); - 
ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(func_name, args, exec_ctx_.get())); + ASSERT_OK_AND_ASSIGN(Datum result, test_copy->Call(args, exec_ctx_.get())); AssertArraysEqual(*arr, *result.make_array()); } @@ -1129,7 +1227,7 @@ TEST_F(TestCallScalarFunction, PreallocationCases) { auto carr = std::make_shared(ArrayVector{arr->Slice(0, 10), arr->Slice(10)}); std::vector args = {Datum(carr)}; - ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(func_name, args, exec_ctx_.get())); + ASSERT_OK_AND_ASSIGN(Datum result, test_copy->Call(args, exec_ctx_.get())); std::shared_ptr actual = result.chunked_array(); ASSERT_EQ(1, actual->num_chunks()); AssertChunkedEquivalent(*carr, *actual); @@ -1140,7 +1238,7 @@ TEST_F(TestCallScalarFunction, PreallocationCases) { std::vector args = {Datum(arr)}; exec_ctx_->set_preallocate_contiguous(false); exec_ctx_->set_exec_chunksize(40); - ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(func_name, args, exec_ctx_.get())); + ASSERT_OK_AND_ASSIGN(Datum result, test_copy->Call(args, exec_ctx_.get())); ASSERT_EQ(Datum::CHUNKED_ARRAY, result.kind()); const ChunkedArray& carr = *result.chunked_array(); ASSERT_EQ(3, carr.num_chunks()); @@ -1150,11 +1248,28 @@ TEST_F(TestCallScalarFunction, PreallocationCases) { } }; - CheckFunction("test_copy"); - CheckFunction("test_copy_computed_bitmap"); + ASSERT_OK_AND_ASSIGN(auto test_copy, caller_maker("test_copy", {uint8()})); + CheckFunction(test_copy); + ASSERT_OK_AND_ASSIGN(auto test_copy_computed_bitmap, + caller_maker("test_copy_computed_bitmap", {uint8()})); + CheckFunction(test_copy_computed_bitmap); +} + +TEST_F(TestCallScalarFunctionPreallocationCases, SimpleCaller) { + TestCallScalarFunctionPreallocationCases::DoTest(SimpleFunctionCaller::Maker); } -TEST_F(TestCallScalarFunction, BasicNonStandardCases) { +TEST_F(TestCallScalarFunctionPreallocationCases, ExecCaller) { + TestCallScalarFunctionPreallocationCases::DoTest(ExecFunctionCaller::Maker); +} + +class 
TestCallScalarFunctionBasicNonStandardCases : public TestCallScalarFunction { + protected: + void DoTest(FunctionCallerMaker caller_maker); +}; + +void TestCallScalarFunctionBasicNonStandardCases::DoTest( + FunctionCallerMaker caller_maker) { // Test a handful of cases // // * Validity bitmap computed by kernel rather than using PropagateNulls @@ -1166,19 +1281,19 @@ TEST_F(TestCallScalarFunction, BasicNonStandardCases) { auto arr = GetUInt8Array(1000, null_prob); std::vector args = {Datum(arr)}; - auto CheckFunction = [&](std::string func_name) { + auto CheckFunction = [&](std::shared_ptr test_nopre) { ResetContexts(); // The default should be a single array output { - ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(func_name, args)); + ASSERT_OK_AND_ASSIGN(Datum result, test_nopre->Call(args)); AssertArraysEqual(*arr, *result.make_array(), true); } // Split execution into 3 chunks { exec_ctx_->set_exec_chunksize(400); - ASSERT_OK_AND_ASSIGN(Datum result, CallFunction(func_name, args, exec_ctx_.get())); + ASSERT_OK_AND_ASSIGN(Datum result, test_nopre->Call(args, exec_ctx_.get())); ASSERT_EQ(Datum::CHUNKED_ARRAY, result.kind()); const ChunkedArray& carr = *result.chunked_array(); ASSERT_EQ(3, carr.num_chunks()); @@ -1188,31 +1303,73 @@ TEST_F(TestCallScalarFunction, BasicNonStandardCases) { } }; - CheckFunction("test_nopre_data"); - CheckFunction("test_nopre_validity_or_data"); + ASSERT_OK_AND_ASSIGN(auto test_nopre_data, caller_maker("test_nopre_data", {uint8()})); + CheckFunction(test_nopre_data); + ASSERT_OK_AND_ASSIGN(auto test_nopre_validity_or_data, + caller_maker("test_nopre_validity_or_data", {uint8()})); + CheckFunction(test_nopre_validity_or_data); +} + +TEST_F(TestCallScalarFunctionBasicNonStandardCases, SimpleCall) { + TestCallScalarFunctionBasicNonStandardCases::DoTest(SimpleFunctionCaller::Maker); +} + +TEST_F(TestCallScalarFunctionBasicNonStandardCases, ExecCall) { + TestCallScalarFunctionBasicNonStandardCases::DoTest(ExecFunctionCaller::Maker); } 
-TEST_F(TestCallScalarFunction, StatefulKernel) { +class TestCallScalarFunctionStatefulKernel : public TestCallScalarFunction { + protected: + void DoTest(FunctionCallerMaker caller_maker); +}; + +void TestCallScalarFunctionStatefulKernel::DoTest(FunctionCallerMaker caller_maker) { + ASSERT_OK_AND_ASSIGN(auto test_stateful, caller_maker("test_stateful", {int32()})); + auto input = ArrayFromJSON(int32(), "[1, 2, 3, null, 5]"); auto multiplier = std::make_shared(2); auto expected = ArrayFromJSON(int32(), "[2, 4, 6, null, 10]"); ExampleOptions options(multiplier); std::vector args = {Datum(input)}; - ASSERT_OK_AND_ASSIGN(Datum result, CallFunction("test_stateful", args, &options)); + ASSERT_OK_AND_ASSIGN(Datum result, test_stateful->Call(args, &options)); AssertArraysEqual(*expected, *result.make_array()); } -TEST_F(TestCallScalarFunction, ScalarFunction) { +TEST_F(TestCallScalarFunctionStatefulKernel, Simplecall) { + TestCallScalarFunctionStatefulKernel::DoTest(SimpleFunctionCaller::Maker); +} + +TEST_F(TestCallScalarFunctionStatefulKernel, ExecCall) { + TestCallScalarFunctionStatefulKernel::DoTest(ExecFunctionCaller::Maker); +} + +class TestCallScalarFunctionScalarFunction : public TestCallScalarFunction { + protected: + void DoTest(FunctionCallerMaker caller_maker); +}; + +void TestCallScalarFunctionScalarFunction::DoTest(FunctionCallerMaker caller_maker) { + ASSERT_OK_AND_ASSIGN(auto test_scalar_add_int32, + caller_maker("test_scalar_add_int32", {int32(), int32()})); + std::vector args = {Datum(std::make_shared(5)), Datum(std::make_shared(7))}; - ASSERT_OK_AND_ASSIGN(Datum result, CallFunction("test_scalar_add_int32", args)); + ASSERT_OK_AND_ASSIGN(Datum result, test_scalar_add_int32->Call(args)); ASSERT_EQ(Datum::SCALAR, result.kind()); auto expected = std::make_shared(12); ASSERT_TRUE(expected->Equals(*result.scalar())); } +TEST_F(TestCallScalarFunctionScalarFunction, SimpleCall) { + TestCallScalarFunctionScalarFunction::DoTest(SimpleFunctionCaller::Maker); +} + 
+TEST_F(TestCallScalarFunctionScalarFunction, ExecCall) { + TestCallScalarFunctionScalarFunction::DoTest(ExecFunctionCaller::Maker); +} + } // namespace detail } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc index 12d80a8c9ae..90e754f6150 100644 --- a/cpp/src/arrow/compute/function.cc +++ b/cpp/src/arrow/compute/function.cc @@ -97,6 +97,18 @@ Status Function::CheckArity(size_t num_args) const { return CheckArityImpl(*this, static_cast(num_args)); } +namespace { + +Status CheckOptions(const Function& function, const FunctionOptions* options) { + if (options == nullptr && function.doc().options_required) { + return Status::Invalid("Function '", function.name(), + "' cannot be called without options"); + } + return Status::OK(); +} + +} // namespace + namespace detail { Status NoMatchingKernel(const Function* func, const std::vector& types) { @@ -167,6 +179,118 @@ const Kernel* DispatchExactImpl(const Function* func, return nullptr; } +struct FunctionExecutorImpl : public FunctionExecutor { + FunctionExecutorImpl(std::vector in_types, const Kernel* kernel, + std::unique_ptr executor, + const Function& func) + : in_types(std::move(in_types)), + kernel(kernel), + kernel_ctx(default_exec_context(), kernel), + executor(std::move(executor)), + func(func), + state(), + options(NULLPTR), + inited(false) {} + virtual ~FunctionExecutorImpl() {} + + Status KernelInit(const FunctionOptions* options) { + RETURN_NOT_OK(CheckOptions(func, options)); + if (options == NULLPTR) { + options = func.default_options(); + } + if (kernel->init) { + ARROW_ASSIGN_OR_RAISE(state, + kernel->init(&kernel_ctx, {kernel, in_types, options})); + kernel_ctx.SetState(state.get()); + } + + RETURN_NOT_OK(executor->Init(&kernel_ctx, {kernel, in_types, options})); + this->options = options; + inited = true; + return Status::OK(); + } + + Status Init(const FunctionOptions* options, ExecContext* exec_ctx) override { + if 
(exec_ctx == NULLPTR) { + exec_ctx = default_exec_context(); + } + kernel_ctx = KernelContext{exec_ctx, kernel}; + return KernelInit(options); + } + + Result Execute(const std::vector& args, int64_t passed_length) override { + util::tracing::Span span; + + auto func_kind = func.kind(); + const auto& func_name = func.name(); + START_COMPUTE_SPAN(span, func_name, + {{"function.name", func_name}, + {"function.options", options ? options->ToString() : ""}, + {"function.kind", func_kind}}); + + if (in_types.size() != args.size()) { + return Status::Invalid("Execution of '", func_name, "' expected ", in_types.size(), + " arguments but got ", args.size()); + } + if (!inited) { + ARROW_RETURN_NOT_OK(Init(NULLPTR, default_exec_context())); + } + ExecContext* ctx = kernel_ctx.exec_context(); + // Cast arguments if necessary + std::vector args_with_cast(args.size()); + for (size_t i = 0; i != args.size(); ++i) { + const auto& in_type = in_types[i]; + auto arg = args[i]; + if (in_type != args[i].type()) { + ARROW_ASSIGN_OR_RAISE(arg, Cast(args[i], CastOptions::Safe(in_type), ctx)); + } + args_with_cast[i] = std::move(arg); + } + + detail::DatumAccumulator listener; + + ExecBatch input(std::move(args_with_cast), /*length=*/0); + if (input.num_values() == 0) { + if (passed_length != -1) { + input.length = passed_length; + } + } else { + bool all_same_length = false; + int64_t inferred_length = detail::InferBatchLength(input.values, &all_same_length); + input.length = inferred_length; + if (func_kind == Function::SCALAR) { + if (passed_length != -1 && passed_length != inferred_length) { + return Status::Invalid( + "Passed batch length for execution did not match actual" + " length of values for execution of scalar function '", + func_name, "'"); + } + } else if (func_kind == Function::VECTOR) { + auto vkernel = static_cast(kernel); + if (!all_same_length && vkernel->can_execute_chunkwise) { + return Status::Invalid("Arguments for execution of vector kernel function '", + 
func_name, "' must all be the same length"); + } + } + } + RETURN_NOT_OK(executor->Execute(input, &listener)); + const auto out = executor->WrapResults(input.values, listener.values()); +#ifndef NDEBUG + DCHECK_OK(executor->CheckResultType(out, func_name.c_str())); +#endif + return out; + } + + std::vector in_types; + const Kernel* kernel; + KernelContext kernel_ctx; + std::unique_ptr executor; + const Function& func; + std::unique_ptr state; + const FunctionOptions* options; + bool inited; +}; + } // namespace detail Result Function::DispatchExact( @@ -187,114 +311,34 @@ Result Function::DispatchBest(std::vector* values) co return DispatchExact(*values); } -namespace { - -Status CheckAllArrayOrScalar(const std::vector& values) { - for (const auto& value : values) { - if (!value.is_value()) { - return Status::Invalid("Tried executing function with non-value type: ", - value.ToString()); - } - } - return Status::OK(); -} - -Status CheckOptions(const Function& function, const FunctionOptions* options) { - if (options == nullptr && function.doc().options_required) { - return Status::Invalid("Function '", function.name(), - "' cannot be called without options"); - } - return Status::OK(); -} - -Result ExecuteInternal(const Function& func, std::vector args, - int64_t passed_length, const FunctionOptions* options, - ExecContext* ctx) { - std::unique_ptr default_ctx; - if (options == nullptr) { - RETURN_NOT_OK(CheckOptions(func, options)); - options = func.default_options(); - } - if (ctx == nullptr) { - default_ctx.reset(new ExecContext()); - ctx = default_ctx.get(); - } - - util::tracing::Span span; - - START_COMPUTE_SPAN(span, func.name(), - {{"function.name", func.name()}, - {"function.options", options ? options->ToString() : ""}, - {"function.kind", func.kind()}}); - - // type-check Datum arguments here. 
Really we'd like to avoid this as much as - // possible - RETURN_NOT_OK(CheckAllArrayOrScalar(args)); - std::vector in_types(args.size()); - for (size_t i = 0; i != args.size(); ++i) { - in_types[i] = args[i].type().get(); - } - +Result> Function::GetBestExecutor( + std::vector inputs) const { std::unique_ptr executor; - if (func.kind() == Function::SCALAR) { + if (kind() == Function::SCALAR) { executor = detail::KernelExecutor::MakeScalar(); - } else if (func.kind() == Function::VECTOR) { + } else if (kind() == Function::VECTOR) { executor = detail::KernelExecutor::MakeVector(); - } else if (func.kind() == Function::SCALAR_AGGREGATE) { + } else if (kind() == Function::SCALAR_AGGREGATE) { executor = detail::KernelExecutor::MakeScalarAggregate(); } else { return Status::NotImplemented("Direct execution of HASH_AGGREGATE functions"); } - ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, func.DispatchBest(&in_types)); - - // Cast arguments if necessary - for (size_t i = 0; i != args.size(); ++i) { - if (in_types[i] != args[i].type()) { - ARROW_ASSIGN_OR_RAISE(args[i], Cast(args[i], CastOptions::Safe(in_types[i]), ctx)); - } - } - - KernelContext kernel_ctx{ctx, kernel}; + ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, DispatchBest(&inputs)); - std::unique_ptr state; - if (kernel->init) { - ARROW_ASSIGN_OR_RAISE(state, kernel->init(&kernel_ctx, {kernel, in_types, options})); - kernel_ctx.SetState(state.get()); - } - - RETURN_NOT_OK(executor->Init(&kernel_ctx, {kernel, in_types, options})); + return std::make_shared(std::move(inputs), kernel, + std::move(executor), *this); +} - detail::DatumAccumulator listener; +namespace { - ExecBatch input(std::move(args), /*length=*/0); - if (input.num_values() == 0) { - if (passed_length != -1) { - input.length = passed_length; - } - } else { - bool all_same_length = false; - int64_t inferred_length = detail::InferBatchLength(input.values, &all_same_length); - input.length = inferred_length; - if (func.kind() == Function::SCALAR) { - if 
(passed_length != -1 && passed_length != inferred_length) { - return Status::Invalid( - "Passed batch length for execution did not match actual" - " length of values for scalar function execution"); - } - } else if (func.kind() == Function::VECTOR) { - auto vkernel = static_cast(kernel); - if (!(all_same_length || !vkernel->can_execute_chunkwise)) { - return Status::Invalid("Vector kernel arguments must all be the same length"); - } - } - } - RETURN_NOT_OK(executor->Execute(input, &listener)); - const auto out = executor->WrapResults(input.values, listener.values()); -#ifndef NDEBUG - DCHECK_OK(executor->CheckResultType(out, func.name().c_str())); -#endif - return out; +Result ExecuteInternal(const Function& func, std::vector args, + int64_t passed_length, const FunctionOptions* options, + ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(auto inputs, internal::GetFunctionArgumentTypes(args)); + ARROW_ASSIGN_OR_RAISE(auto func_exec, func.GetBestExecutor(inputs)); + ARROW_RETURN_NOT_OK(func_exec->Init(options, ctx)); + return func_exec->Execute(args, passed_length); } } // namespace diff --git a/cpp/src/arrow/compute/function.h b/cpp/src/arrow/compute/function.h index 7f2fba68caf..8a1b0da424a 100644 --- a/cpp/src/arrow/compute/function.h +++ b/cpp/src/arrow/compute/function.h @@ -159,6 +159,29 @@ struct ARROW_EXPORT FunctionDoc { static const FunctionDoc& Empty(); }; +/// \brief An executor of a function with a preconfigured kernel +class ARROW_EXPORT FunctionExecutor { + public: + virtual ~FunctionExecutor() = default; + /// \brief Initialize or re-initialize the preconfigured kernel + /// + /// This method may be called zero or more times. Depending on how + /// the FunctionExecutor was obtained, it may already have been initialized. 
+ virtual Status Init(const FunctionOptions* options = NULLPTR, + ExecContext* exec_ctx = NULLPTR) = 0; + /// \brief Execute the preconfigured kernel with arguments that must fit it + /// + /// The method requires the arguments be castable to the preconfigured types. + /// + /// \param[in] args Arguments to execute the function on + /// \param[in] length Length of arguments batch or -1 to default it. If the + /// function has no parameters, this determines the batch length, defaulting + /// to 0. Otherwise, if the function is scalar, this must equal the argument + /// batch's inferred length or be -1 to default to it. This is ignored for + /// vector functions. + virtual Result Execute(const std::vector& args, int64_t length = -1) = 0; +}; + /// \brief Base class for compute functions. Function implementations contain a /// collection of "kernels" which are implementations of the function for /// specific argument types. Selecting a viable kernel for executing a function @@ -225,6 +248,13 @@ class ARROW_EXPORT Function { /// required by the kernel. virtual Result DispatchBest(std::vector* values) const; + /// \brief Get a function executor with a best-matching kernel + /// + /// The returned executor will by default work with the default FunctionOptions + /// and KernelContext. If you want to change that, call `FunctionExecutor::Init`. + virtual Result> GetBestExecutor( + std::vector inputs) const; + /// \brief Execute the function eagerly with the passed input arguments with /// kernel dispatch, batch iteration, and memory allocation details taken /// care of. 
diff --git a/cpp/src/arrow/compute/function_internal.cc b/cpp/src/arrow/compute/function_internal.cc index 0a926e0a39c..cd73462e953 100644 --- a/cpp/src/arrow/compute/function_internal.cc +++ b/cpp/src/arrow/compute/function_internal.cc @@ -108,6 +108,27 @@ Result> DeserializeFunctionOptions( return FunctionOptionsFromStructScalar(scalar); } +Status CheckAllArrayOrScalar(const std::vector& values) { + for (const auto& value : values) { + if (!value.is_value()) { + return Status::TypeError( + "Tried executing function with non-array, non-scalar type: ", value.ToString()); + } + } + return Status::OK(); +} + +Result> GetFunctionArgumentTypes(const std::vector& args) { + // type-check Datum arguments here. Really we'd like to avoid this as much as + // possible + RETURN_NOT_OK(CheckAllArrayOrScalar(args)); + std::vector inputs(args.size()); + for (size_t i = 0; i != args.size(); ++i) { + inputs[i] = TypeHolder(args[i].type()); + } + return inputs; +} + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/function_internal.h b/cpp/src/arrow/compute/function_internal.h index 427d1a97a10..8ed4493e742 100644 --- a/cpp/src/arrow/compute/function_internal.h +++ b/cpp/src/arrow/compute/function_internal.h @@ -74,7 +74,7 @@ Result ValidateEnumValue(CType raw) { return Status::Invalid("Invalid value for ", EnumTraits::name(), ": ", raw); } -class GenericOptionsType : public FunctionOptionsType { +class ARROW_EXPORT GenericOptionsType : public FunctionOptionsType { public: Result> Serialize(const FunctionOptions&) const override; Result> Deserialize( @@ -664,6 +664,11 @@ const FunctionOptionsType* GetFunctionOptionsType(const Properties&... 
propertie return &instance; } +Status CheckAllArrayOrScalar(const std::vector& values); + +ARROW_EXPORT +Result> GetFunctionArgumentTypes(const std::vector& args); + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/function_test.cc b/cpp/src/arrow/compute/function_test.cc index ea151e81f0b..b71e5a12b50 100644 --- a/cpp/src/arrow/compute/function_test.cc +++ b/cpp/src/arrow/compute/function_test.cc @@ -23,16 +23,20 @@ #include #include +#include "arrow/array/builder_primitive.h" #include "arrow/compute/api_aggregate.h" #include "arrow/compute/api_scalar.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/cast.h" +#include "arrow/compute/function_internal.h" #include "arrow/compute/kernel.h" #include "arrow/datum.h" #include "arrow/status.h" #include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" #include "arrow/type.h" #include "arrow/util/key_value_metadata.h" +#include "arrow/util/logging.h" namespace arrow { namespace compute { @@ -351,5 +355,106 @@ TEST(ScalarAggregateFunction, DispatchExact) { ASSERT_TRUE(selected_kernel->signature->MatchesInputs(dispatch_args)); } +namespace { + +struct TestFunctionOptions : public FunctionOptions { + TestFunctionOptions(); + + static const char* kTypeName; + + int value; +}; + +static auto kTestFunctionOptionsType = + internal::GetFunctionOptionsType(); + +TestFunctionOptions::TestFunctionOptions() : FunctionOptions(kTestFunctionOptionsType) {} + +const char* TestFunctionOptions::kTypeName = "test_options"; + +} // namespace + +TEST(FunctionExecutor, Basics) { + VectorFunction func("vector_test", Arity::Binary(), /*doc=*/FunctionDoc::Empty()); + int init_calls = 0; + int expected_optval = 0; + ExecContext exec_ctx; + TestFunctionOptions options; + options.value = 1; + auto init = + [&](KernelContext* kernel_ctx, + const KernelInitArgs& init_args) -> Result> { + if (&exec_ctx != kernel_ctx->exec_context()) { + return Status::Invalid("expected 
exec context not found in kernel context"); + } + if (init_args.options != nullptr) { + const auto* test_opts = checked_cast(init_args.options); + if (test_opts->value != expected_optval) { + return Status::Invalid("bad options value"); + } + } + if (&options != init_args.options) { + return Status::Invalid("expected options not found in kernel init args"); + } + ++init_calls; + return nullptr; + }; + auto exec = [](KernelContext* ctx, const ExecSpan& args, ExecResult* out) -> Status { + [&]() { // gtest ASSERT macros require a void function + ASSERT_EQ(2, args.values.size()); + const int32_t* vals[2]; + for (size_t i = 0; i < 2; i++) { + ASSERT_TRUE(args.values[i].is_array()); + const ArraySpan& array = args.values[i].array; + ASSERT_EQ(array.type->id(), Type::INT32); + vals[i] = array.GetValues(1); + } + ASSERT_TRUE(out->is_array_data()); + auto out_data = out->array_data(); + Int32Builder builder; + for (int64_t i = 0; i < args.length; i++) { + ASSERT_OK(builder.Append(vals[0][i] + vals[1][i])); + } + ASSERT_OK_AND_ASSIGN(auto array, builder.Finish()); + *out_data.get() = *array->data(); + }(); + return Status::OK(); + }; + std::vector in_types = {int32(), int32()}; + OutputType out_type = int32(); + ASSERT_OK(func.AddKernel(in_types, out_type, exec, init)); + + ASSERT_OK_AND_ASSIGN(const Kernel* dispatched, func.DispatchExact({int32(), int32()})); + ASSERT_EQ(exec, static_cast(dispatched)->exec); + std::vector inputs = {int32(), int32()}; + + ASSERT_OK_AND_ASSIGN(auto func_exec, func.GetBestExecutor(inputs)); + ASSERT_EQ(0, init_calls); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("options not found"), + func_exec->Init(nullptr, &exec_ctx)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("bad options value"), + func_exec->Init(&options, &exec_ctx)); + ExecContext other_exec_ctx; + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("exec context not found"), + func_exec->Init(&options, &other_exec_ctx)); + + 
ArrayVector arrays = {ArrayFromJSON(int32(), "[1]"), ArrayFromJSON(int32(), "[2]"), + ArrayFromJSON(int32(), "[3]"), ArrayFromJSON(int32(), "[4]")}; + ArrayVector expected = {ArrayFromJSON(int32(), "[3]"), ArrayFromJSON(int32(), "[5]"), + ArrayFromJSON(int32(), "[7]")}; + for (int n = 1; n <= 3; n++) { + expected_optval = options.value = n; + ASSERT_OK(func_exec->Init(&options, &exec_ctx)); + ASSERT_EQ(n, init_calls); + for (int32_t i = 1; i <= 3; i++) { + std::vector values = {arrays[i - 1], arrays[i]}; + ASSERT_OK_AND_ASSIGN(auto result, func_exec->Execute(values, 1)); + ASSERT_TRUE(result.is_array()); + auto actual = result.make_array(); + AssertArraysEqual(*expected[i - 1], *actual); + } + } +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/type_fwd.h b/cpp/src/arrow/compute/type_fwd.h index 11c45fde091..67dc5a278b4 100644 --- a/cpp/src/arrow/compute/type_fwd.h +++ b/cpp/src/arrow/compute/type_fwd.h @@ -27,7 +27,9 @@ struct TypeHolder; namespace compute { class Function; +class FunctionExecutor; class FunctionOptions; +class FunctionRegistry; class CastOptions; diff --git a/docs/source/cpp/api/compute.rst b/docs/source/cpp/api/compute.rst index 288e280d0cf..5e490fc4089 100644 --- a/docs/source/cpp/api/compute.rst +++ b/docs/source/cpp/api/compute.rst @@ -31,6 +31,13 @@ Abstract Function classes :content-only: :members: +Function execution +------------------ + +.. doxygengroup:: compute-function-executor + :content-only: + :members: + Function registry ----------------- diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 9e5d0e554d6..1c0dba9acd8 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -2927,5 +2927,5 @@ def test_expression_call_function(): def test_cast_table_raises(): table = pa.table({'a': [1, 2]}) - with pytest.raises(pa.lib.ArrowInvalid): + with pytest.raises(pa.lib.ArrowTypeError): pc.cast(table, pa.int64())