From cb128bab076869e17ccd0ea9f6203f781a194421 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 26 Oct 2022 07:55:19 +0530 Subject: [PATCH 01/13] feat(initial): custom aggregate udf added --- cpp/examples/arrow/udf_example.cc | 145 +++++++++++++++++++++++++++++- 1 file changed, 141 insertions(+), 4 deletions(-) diff --git a/cpp/examples/arrow/udf_example.cc b/cpp/examples/arrow/udf_example.cc index 573b5ccc78a..a1482f85164 100644 --- a/cpp/examples/arrow/udf_example.cc +++ b/cpp/examples/arrow/udf_example.cc @@ -83,15 +83,152 @@ arrow::Status Execute() { ARROW_ASSIGN_OR_RAISE(auto res, cp::CallFunction(name, {x, y, z})); auto res_array = res.make_array(); - std::cout << "Result" << std::endl; + std::cout << "Scalar UDF Result" << std::endl; std::cout << res_array->ToString() << std::endl; return arrow::Status::OK(); } +// User-defined Scalar Aggregate Function Example +struct ScalarUdfAggregator : public cp::KernelState { + virtual arrow::Status Consume(cp::KernelContext* ctx, const cp::ExecSpan& batch) = 0; + virtual arrow::Status MergeFrom(cp::KernelContext* ctx, cp::KernelState&& src) = 0; + virtual arrow::Status Finalize(cp::KernelContext* ctx, arrow::Datum* out) = 0; +}; + +class SimpleCountFunctionOptionsType : public cp::FunctionOptionsType { + const char* type_name() const override { return "SimpleCountFunctionOptionsType"; } + std::string Stringify(const cp::FunctionOptions&) const override { + return "SimpleCountFunctionOptionsType"; + } + bool Compare(const cp::FunctionOptions&, const cp::FunctionOptions&) const override { + return true; + } + std::unique_ptr Copy(const cp::FunctionOptions&) const override; +}; + +cp::FunctionOptionsType* GetSimpleCountFunctionOptionsType() { + static SimpleCountFunctionOptionsType options_type; + return &options_type; +} + +class SimpleCountOptions : public cp::FunctionOptions { + public: + SimpleCountOptions() : cp::FunctionOptions(GetSimpleCountFunctionOptionsType()) {} + static constexpr char const kTypeName[] = "SimpleCountOptions"; + static SimpleCountOptions Defaults() { return SimpleCountOptions{}; } +}; + +std::unique_ptr SimpleCountFunctionOptionsType::Copy( + const cp::FunctionOptions&) const { + return std::make_unique(); +} + +const cp::FunctionDoc simple_count_doc{ + "SimpleCount the number of null / non-null values", + ("By default, only non-null values are counted.\n" + "This can be changed through SimpleCountOptions."), + {"array"}, + "SimpleCountOptions"}; + +struct SimpleCountImpl : public ScalarUdfAggregator { + explicit SimpleCountImpl(SimpleCountOptions options) : options(std::move(options)) {} + + arrow::Status Consume(cp::KernelContext*, const cp::ExecSpan& batch) override { + if (batch[0].is_array()) { + const arrow::ArraySpan& input = batch[0].array; + const int64_t nulls = input.GetNullCount(); + this->non_nulls += input.length - nulls; + } else { + const arrow::Scalar& input = *batch[0].scalar; + this->non_nulls += input.is_valid * batch.length; + } + return arrow::Status::OK(); + } + + arrow::Status MergeFrom(cp::KernelContext*, cp::KernelState&& src) override { + const auto& other_state = arrow::internal::checked_cast(src); + this->non_nulls += other_state.non_nulls; + return arrow::Status::OK(); + } + + arrow::Status Finalize(cp::KernelContext* ctx, arrow::Datum* out) override { + const auto& state = + arrow::internal::checked_cast(*ctx->state()); + *out = arrow::Datum(state.non_nulls); + return arrow::Status::OK(); + } + + SimpleCountOptions options; + int64_t non_nulls = 0; +}; + +arrow::Result> SimpleCountInit( + cp::KernelContext*, const cp::KernelInitArgs& args) { + return std::make_unique( + static_cast(*args.options)); +} + +arrow::Status AggregateUdfConsume(cp::KernelContext* ctx, const cp::ExecSpan& batch) { + return arrow::internal::checked_cast(ctx->state()) + ->Consume(ctx, batch); +} + +arrow::Status AggregateUdfMerge(cp::KernelContext* ctx, cp::KernelState&& src, + cp::KernelState* dst) { + return arrow::internal::checked_cast(dst)->MergeFrom( + ctx, std::move(src)); +} + +arrow::Status AggregateUdfFinalize(cp::KernelContext* ctx, arrow::Datum* out) { + return arrow::internal::checked_cast(ctx->state()) + ->Finalize(ctx, out); +} + +arrow::Status AddAggKernel(std::shared_ptr sig, cp::KernelInit init, + cp::ScalarAggregateFunction* func) { + cp::ScalarAggregateKernel kernel(std::move(sig), std::move(init), AggregateUdfConsume, + AggregateUdfMerge, AggregateUdfFinalize); + ARROW_RETURN_NOT_OK(func->AddKernel(std::move(kernel))); + return arrow::Status::OK(); +} + +arrow::Status ExecuteAggregate() { + auto registry = cp::GetFunctionRegistry(); + static auto default_scalar_aggregate_options = cp::ScalarAggregateOptions::Defaults(); + static auto default_count_options = SimpleCountOptions::Defaults(); + const std::string name = "simple_count"; + auto func = std::make_shared( + name, cp::Arity::Unary(), simple_count_doc, &default_count_options); + + // Takes any input, outputs int64 scalar + cp::InputType any_input; + ARROW_RETURN_NOT_OK( + AddAggKernel(cp::KernelSignature::Make({arrow::int64()}, arrow::int64()), + SimpleCountInit, func.get())); + ARROW_RETURN_NOT_OK(registry->AddFunction(std::move(func))); + + ARROW_ASSIGN_OR_RAISE(auto x, GetArrayDataSample({1, 2, 3, 4, 5, 6})); + + ARROW_ASSIGN_OR_RAISE(auto res, cp::CallFunction(name, {x})); + auto res_scalar = res.scalar(); + std::cout << "Aggregate UDF Result" << std::endl; + std::cout << res_scalar->ToString() << std::endl; + + return arrow::Status::OK(); +} + int main(int argc, char** argv) { - auto status = Execute(); - if (!status.ok()) { - std::cerr << "Error occurred : " << status.message() << std::endl; + std::cout << "Sample Scalar UDF Execution" << std::endl; + auto s1 = Execute(); + if (!s1.ok()) { + std::cerr << "Error occurred : " << s1.message() << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Sample Aggregate UDF Execution" << std::endl; + auto s2 = ExecuteAggregate(); + if (!s2.ok()) { + std::cerr << "Error occurred : " << s2.message() << std::endl; return EXIT_FAILURE; } return EXIT_SUCCESS; From 65d40d1d5dc970905930ee532fac6e3fa86dd5e7 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 26 Oct 2022 13:17:03 +0530 Subject: [PATCH 02/13] feat(temp-init) --- cpp/examples/arrow/udf_example.cc | 3 +- python/pyarrow/src/arrow/python/udf.cc | 182 +++++++++++++++++++++++++ python/pyarrow/src/arrow/python/udf.h | 26 ++++ 3 files changed, 210 insertions(+), 1 deletion(-) diff --git a/cpp/examples/arrow/udf_example.cc b/cpp/examples/arrow/udf_example.cc index a1482f85164..fce4c7a880d 100644 --- a/cpp/examples/arrow/udf_example.cc +++ b/cpp/examples/arrow/udf_example.cc @@ -130,6 +130,7 @@ const cp::FunctionDoc simple_count_doc{ {"array"}, "SimpleCountOptions"}; +// Need Python interface for this Class struct SimpleCountImpl : public ScalarUdfAggregator { explicit SimpleCountImpl(SimpleCountOptions options) : options(std::move(options)) {} @@ -162,6 +163,7 @@ struct SimpleCountImpl : public ScalarUdfAggregator { int64_t non_nulls = 0; }; +// TODO: need a Python interface for this function arrow::Result> SimpleCountInit( cp::KernelContext*, const cp::KernelInitArgs& args) { return std::make_unique( @@ -201,7 +203,6 @@ arrow::Status ExecuteAggregate() { name, cp::Arity::Unary(), simple_count_doc, &default_count_options); // Takes any input, outputs int64 scalar - cp::InputType any_input; ARROW_RETURN_NOT_OK( AddAggKernel(cp::KernelSignature::Make({arrow::int64()}, arrow::int64()), SimpleCountInit, func.get())); diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 81bf47c0ade..55d058f0570 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -17,6 +17,7 @@ #include "arrow/python/udf.h" #include "arrow/compute/function.h" +#include "arrow/compute/api_aggregate.h" #include "arrow/python/common.h" namespace arrow { @@ -120,6 +121,187 @@ Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback return Status::OK(); } +// Scalar Aggregate Functions + +struct ScalarUdfAggregator : public compute::KernelState { + virtual Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) = 0; + virtual Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) = 0; + virtual Status Finalize(compute::KernelContext* ctx, Datum* out) = 0; +}; + +arrow::Status AggregateUdfConsume(compute::KernelContext* ctx, const compute::ExecSpan& batch) { + return arrow::internal::checked_cast(ctx->state()) + ->Consume(ctx, batch); +} + +arrow::Status AggregateUdfMerge(compute::KernelContext* ctx, compute::KernelState&& src, + compute::KernelState* dst) { + return arrow::internal::checked_cast(dst)->MergeFrom( + ctx, std::move(src)); +} + +arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* out) { + return arrow::internal::checked_cast(ctx->state()) + ->Finalize(ctx, out); +} + +struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { + + ScalarAggregateConsumeUdfWrapperCallback consume_cb; + ScalarAggregateMergeUdfWrapperCallback merge_cb; + ScalarAggregateFinalizeUdfWrapperCallback finalize_cb; + std::shared_ptr consume_function; + std::shared_ptr merge_function; + std::shared_ptr finalize_function; + std::shared_ptr output_type; + + + PythonScalarUdfAggregatorImpl(ScalarAggregateConsumeUdfWrapperCallback consume_cb, + ScalarAggregateMergeUdfWrapperCallback merge_cb, + ScalarAggregateFinalizeUdfWrapperCallback finalize_cb, + std::shared_ptr consume_function, + std::shared_ptr merge_function, + std::shared_ptr finalize_function, + const std::shared_ptr& output_type) : consume_cb(consume_cb), + merge_cb(merge_cb), + finalize_cb(finalize_cb), + consume_function(consume_function), + merge_function(merge_function), + finalize_function(finalize_function), + output_type(output_type) {} + + ~PythonScalarUdfAggregatorImpl() { + if (_Py_IsFinalizing()) { + consume_function->detach(); + merge_function->detach(); + finalize_function->detach(); + } + } + + Status ConsumeBatch(compute::KernelContext* ctx, const compute::ExecSpan& batch) { + const int num_args = batch.num_values(); + this->batch_length = batch.length; + ScalarAggregateUdfContext udf_context{ctx->memory_pool(), batch.length}; + // TODO: think about guaranteeing DRY (following logic already used in ScalarUDFs) + OwnedRef arg_tuple(PyTuple_New(num_args)); + RETURN_NOT_OK(CheckPyError()); + for (int arg_id = 0; arg_id < num_args; arg_id++) { + if (batch[arg_id].is_scalar()) { + std::shared_ptr c_data = batch[arg_id].scalar->GetSharedPtr(); + PyObject* data = wrap_scalar(c_data); + PyTuple_SetItem(arg_tuple.obj(), arg_id, data); + } else { + std::shared_ptr c_data = batch[arg_id].array.ToArray(); + PyObject* data = wrap_array(c_data); + PyTuple_SetItem(arg_tuple.obj(), arg_id, data); + } + } + consume_cb(consume_function->obj(), udf_context, arg_tuple.obj()); + RETURN_NOT_OK(CheckPyError()); + return Status::OK(); + } + + Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) override { + RETURN_NOT_OK(ConsumeBatch(ctx, batch)); + return Status::OK(); + } + + Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) override { + ScalarAggregateUdfContext udf_context{ctx->memory_pool(), this->batch_length}; + merge_cb(merge_function->obj(), udf_context); + return Status::OK(); + }; + + Status Finalize(compute::KernelContext* ctx, arrow::Datum* out) override { + ScalarAggregateUdfContext udf_context{ctx->memory_pool(), this->batch_length}; + OwnedRef result(finalize_cb(finalize_function->obj(), udf_context)); + RETURN_NOT_OK(CheckPyError()); + // unwrapping the output for expected output type + if (is_array(result.obj())) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_array(result.obj())); + if (!output_type->Equals(*val->type())) { + return Status::TypeError("Expected output datatype ", output_type->ToString(), + ", but function returned datatype ", + val->type()->ToString()); + } + out = std::move(new Datum(std::move(val))); + return Status::OK(); + } else { + return Status::TypeError("Unexpected output type: ", Py_TYPE(result.obj())->tp_name, + " (expected Array)"); + } + return Status::OK(); + }; + +private: + int batch_length = 1; +}; + +Status AddAggKernel(std::shared_ptr sig, compute::KernelInit init, + compute::ScalarAggregateFunction* func) { + compute::ScalarAggregateKernel kernel(std::move(sig), std::move(init), AggregateUdfConsume, + AggregateUdfMerge, AggregateUdfFinalize); + RETURN_NOT_OK(func->AddKernel(std::move(kernel))); + return Status::OK(); +} + +Status RegisterScalarAggregateFunction(PyObject* consume_function, + ScalarAggregateConsumeUdfWrapperCallback consume_wrapper, + PyObject* merge_function, + ScalarAggregateMergeUdfWrapperCallback merge_wrapper, + PyObject* finalize_function, + ScalarAggregateFinalizeUdfWrapperCallback finalize_wrapper, + const ScalarUdfOptions& options) { + if (!PyCallable_Check(consume_function) || !PyCallable_Check(merge_function) || !PyCallable_Check(finalize_function)) { + return Status::TypeError("Expected a callable Python object."); + } + static auto default_scalar_aggregate_options = compute::ScalarAggregateOptions::Defaults(); + auto aggregate_func = std::make_shared( + options.func_name, options.arity, options.func_doc, &default_scalar_aggregate_options); + + Py_INCREF(consume_function); + Py_INCREF(merge_function); + Py_INCREF(finalize_function); + + std::vector input_types; + for (const auto& in_dtype : options.input_types) { + input_types.emplace_back(in_dtype); + } + compute::OutputType output_type(options.output_type); + auto udf_data = std::make_shared( + consume_wrapper, + merge_wrapper, + finalize_wrapper, + std::make_shared(consume_function), + std::make_shared(merge_function), + std::make_shared(finalize_function), + options.output_type); + + auto init = [aggregate_func]( + compute::KernelContext* ctx, + const compute::KernelInitArgs& args) -> Result> { + ARROW_ASSIGN_OR_RAISE(auto kernel, aggregate_func->DispatchExact(args.inputs)); + compute::KernelInitArgs new_args{kernel, args.inputs, args.options}; + return kernel->init(ctx, new_args); + }; + + RETURN_NOT_OK( + AddAggKernel(compute::KernelSignature::Make(input_types, output_type), + init, aggregate_func.get())); + + compute::ScalarKernel kernel( + compute::KernelSignature::Make(std::move(input_types), std::move(output_type), + options.arity.is_varargs), + PythonUdfExec); + kernel.data = std::move(udf_data); + + kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE; + kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE; + auto registry = compute::GetFunctionRegistry(); + RETURN_NOT_OK(registry->AddFunction(std::move(aggregate_func))); + return Status::OK(); +} + } // namespace py } // namespace arrow diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index 9a3666459fd..584570996b4 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -47,6 +47,13 @@ struct ARROW_PYTHON_EXPORT ScalarUdfContext { int64_t batch_length; }; + +struct ARROW_PYTHON_EXPORT ScalarAggregateUdfContext { + MemoryPool* pool; + int64_t batch_length; +}; + + using ScalarUdfWrapperCallback = std::function; @@ -55,6 +62,25 @@ Status ARROW_PYTHON_EXPORT RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback wrapper, const ScalarUdfOptions& options); +using ScalarAggregateConsumeUdfWrapperCallback = std::function; + +using ScalarAggregateMergeUdfWrapperCallback = std::function; + +using ScalarAggregateFinalizeUdfWrapperCallback = std::function; + +/// \brief register a Scalar Aggregate user-defined-function from Python +Status ARROW_PYTHON_EXPORT RegisterScalarAggregateFunction(PyObject* consume_function, + ScalarAggregateConsumeUdfWrapperCallback consume_wrapper, + PyObject* merge_function, + ScalarAggregateMergeUdfWrapperCallback merge_wrapper, + PyObject* finalize_function, + ScalarAggregateFinalizeUdfWrapperCallback finalize_wrapper, + const ScalarUdfOptions& options); + + } // namespace py } // namespace arrow From 47b85305741ac70d8968dae036aebbc989590283 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 26 Oct 2022 13:31:18 +0530 Subject: [PATCH 03/13] fix(init-method) --- python/pyarrow/src/arrow/python/udf.cc | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 55d058f0570..26ce845b742 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -268,7 +268,12 @@ Status RegisterScalarAggregateFunction(PyObject* consume_function, input_types.emplace_back(in_dtype); } compute::OutputType output_type(options.output_type); - auto udf_data = std::make_shared( + + auto init = [consume_wrapper, merge_wrapper, finalize_wrapper, + consume_function, merge_function, finalize_function, options]( + compute::KernelContext* ctx, + const compute::KernelInitArgs& args) -> Result> { + return std::make_unique( consume_wrapper, merge_wrapper, finalize_wrapper, @@ -276,27 +281,12 @@ Status RegisterScalarAggregateFunction(PyObject* consume_function, std::make_shared(merge_function), std::make_shared(finalize_function), options.output_type); - - auto init = [aggregate_func]( - compute::KernelContext* ctx, - const compute::KernelInitArgs& args) -> Result> { - ARROW_ASSIGN_OR_RAISE(auto kernel, aggregate_func->DispatchExact(args.inputs)); - compute::KernelInitArgs new_args{kernel, args.inputs, args.options}; - return kernel->init(ctx, new_args); }; RETURN_NOT_OK( AddAggKernel(compute::KernelSignature::Make(input_types, output_type), init, aggregate_func.get())); - compute::ScalarKernel kernel( - compute::KernelSignature::Make(std::move(input_types), std::move(output_type), - options.arity.is_varargs), - PythonUdfExec); - kernel.data = std::move(udf_data); - - kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE; - kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE; auto registry = compute::GetFunctionRegistry(); RETURN_NOT_OK(registry->AddFunction(std::move(aggregate_func))); return Status::OK(); From 52f9fbaa19f7612481ea3f3d30076ad8f2372fe6 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 26 Oct 2022 18:11:48 +0530 Subject: [PATCH 04/13] feat(initial-python): wip --- python/pyarrow/_compute.pxd | 6 + python/pyarrow/_compute.pyx | 145 +++++++++++++++++++++++++ python/pyarrow/compute.py | 2 + python/pyarrow/includes/libarrow.pxd | 18 +++ python/pyarrow/src/arrow/python/udf.cc | 59 ++++++---- python/pyarrow/tests/test_udf.py | 55 ++++++++++ 6 files changed, 263 insertions(+), 22 deletions(-) diff --git a/python/pyarrow/_compute.pxd b/python/pyarrow/_compute.pxd index 8b09cbd445e..b34feb5a3a9 100644 --- a/python/pyarrow/_compute.pxd +++ b/python/pyarrow/_compute.pxd @@ -27,6 +27,12 @@ cdef class ScalarUdfContext(_Weakrefable): cdef void init(self, const CScalarUdfContext& c_context) +cdef class ScalarAggregateUdfContext(_Weakrefable): + cdef: + CScalarAggregateUdfContext c_context + + cdef void init(self, const CScalarAggregateUdfContext& c_context) + cdef class FunctionOptions(_Weakrefable): cdef: shared_ptr[CFunctionOptions] wrapped diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 43f40a86c77..c182cc676d1 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2474,6 +2474,45 @@ cdef class ScalarUdfContext: return box_memory_pool(self.c_context.pool) +cdef class ScalarAggregateUdfContext: + """ + Per-invocation function context/state. + + This object will always be the first argument to a user-defined + function. It should not be used outside of a call to the function. + """ + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly" + .format(self.__class__.__name__)) + + cdef void init(self, const CScalarAggregateUdfContext &c_context): + self.c_context = c_context + + @property + def batch_length(self): + """ + The common length of all input arguments (int). + + In the case that all arguments are scalars, this value + is used to pass the "actual length" of the arguments, + e.g. because the scalar values are encoding a column + with a constant value. + """ + return self.c_context.batch_length + + @property + def memory_pool(self): + """ + A memory pool for allocations (:class:`MemoryPool`). + + This is the memory pool supplied by the user when they invoked + the function and it should be used in any calls to arrow that the + UDF makes if that call accepts a memory_pool. + """ + return box_memory_pool(self.c_context.pool) + + cdef inline CFunctionDoc _make_function_doc(dict func_doc) except *: """ Helper function to generate the FunctionDoc @@ -2502,6 +2541,12 @@ cdef object box_scalar_udf_context(const CScalarUdfContext& c_context): return context +cdef object box_scalar_udf_agg_context(const CScalarAggregateUdfContext& c_context): + cdef ScalarAggregateUdfContext context = ScalarAggregateUdfContext.__new__(ScalarAggregateUdfContext) + context.init(c_context) + return context + + cdef _scalar_udf_callback(user_function, const CScalarUdfContext& c_context, inputs): """ Helper callback function used to wrap the ScalarUdfContext from Python to C++ @@ -2511,6 +2556,33 @@ cdef _scalar_udf_callback(user_function, const CScalarUdfContext& c_context, inp return user_function(context, *inputs) +cdef _scalar_agg_consume_udf_callback(consume_function, const CScalarAggregateUdfContext& c_context, inputs): + """ + Helper aggregate consume callback function used to wrap the ScalarAggregateUdfContext from Python to C++ + execution. + """ + context = box_scalar_udf_agg_context(c_context) + return consume_function(context, *inputs) + + +cdef _scalar_agg_merge_udf_callback(merge_function, const CScalarAggregateUdfContext& c_context): + """ + Helper aggregate merge callback function used to wrap the ScalarAggregateUdfContext from Python to C++ + execution. + """ + context = box_scalar_udf_agg_context(c_context) + return merge_function(context) + + +cdef _scalar_agg_finalize_udf_callback(finalize_function, const CScalarAggregateUdfContext& c_context): + """ + Helper aggregate finalize callback function used to wrap the ScalarAggregateUdfContext from Python to C++ + execution. + """ + context = box_scalar_udf_agg_context(c_context) + return finalize_function(context) + + def _get_scalar_udf_context(memory_pool, batch_length): cdef CScalarUdfContext c_context c_context.pool = maybe_unbox_memory_pool(memory_pool) @@ -2641,3 +2713,76 @@ def register_scalar_function(func, function_name, function_doc, in_types, check_status(RegisterScalarFunction(c_function, &_scalar_udf_callback, c_options)) + + +def register_scalar_aggregate_function(consume_func, merge_func, finalize_func, + function_name, function_doc, in_types, out_type): + + cdef: + c_string c_func_name + CArity c_arity + CFunctionDoc c_func_doc + vector[shared_ptr[CDataType]] c_in_types + PyObject* c_consume_function + PyObject* c_merge_function + PyObject* c_finalize_function + shared_ptr[CDataType] c_out_type + CScalarUdfOptions c_options + + if callable(consume_func): + c_consume_function = consume_func + else: + raise TypeError("consume_func must be a callable") + + if callable(merge_func): + c_merge_function = merge_func + else: + raise TypeError("consume_func must be a callable") + + if callable(finalize_func): + c_finalize_function = finalize_func + else: + raise TypeError("consume_func must be a callable") + + c_func_name = tobytes(function_name) + + func_spec = inspect.getfullargspec(consume_func) + num_args = -1 + if isinstance(in_types, dict): + for in_type in in_types.values(): + c_in_types.push_back( + pyarrow_unwrap_data_type(ensure_type(in_type))) + function_doc["arg_names"] = in_types.keys() + num_args = len(in_types) + else: + raise TypeError( + "in_types must be a dictionary of DataType") + + c_arity = CArity( num_args, func_spec.varargs) + + if "summary" not in function_doc: + raise ValueError("Function doc must contain a summary") + + if "description" not in function_doc: + raise ValueError("Function doc must contain a description") + + if "arg_names" not in function_doc: + raise ValueError("Function doc must contain arg_names") + + c_func_doc = _make_function_doc(function_doc) + + c_out_type = pyarrow_unwrap_data_type(ensure_type(out_type)) + + c_options.func_name = c_func_name + c_options.arity = c_arity + c_options.func_doc = c_func_doc + c_options.input_types = c_in_types + c_options.output_type = c_out_type + + check_status(RegisterScalarAggregateFunction(c_consume_function, + &_scalar_agg_consume_udf_callback, + c_merge_function, + &_scalar_agg_merge_udf_callback, + c_finalize_function, + &_scalar_agg_finalize_udf_callback, + c_options)) \ No newline at end of file diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 5873571c5a0..d72a28df84e 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -80,7 +80,9 @@ _group_by, # Udf register_scalar_function, + register_scalar_aggregate_function, ScalarUdfContext, + ScalarAggregateUdfContext, # Expressions Expression, ) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index fbedb0fce36..c4f05ab1611 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2761,11 +2761,21 @@ cdef extern from "arrow/util/byte_size.h" namespace "arrow::util" nogil: ctypedef PyObject* CallbackUdf(object user_function, const CScalarUdfContext& context, object inputs) +ctypedef PyObject* CallbackAggConsumeUdf(object consume_function, const CScalarAggregateUdfContext& context, object inputs) + +ctypedef void CallbackAggMergeUdf(object merge_function, const CScalarAggregateUdfContext& context) + +ctypedef PyObject* CallbackAggFinalizeUdf(object finalize_function, const CScalarAggregateUdfContext& context) + cdef extern from "arrow/python/udf.h" namespace "arrow::py": cdef cppclass CScalarUdfContext" arrow::py::ScalarUdfContext": CMemoryPool *pool int64_t batch_length + cdef cppclass CScalarAggregateUdfContext" arrow::py::ScalarAggregateUdfContext": + CMemoryPool *pool + int64_t batch_length + cdef cppclass CScalarUdfOptions" arrow::py::ScalarUdfOptions": c_string func_name CArity arity @@ -2775,3 +2785,11 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py": CStatus RegisterScalarFunction(PyObject* function, function[CallbackUdf] wrapper, const CScalarUdfOptions& options) + + CStatus RegisterScalarAggregateFunction(PyObject* consume_function, + function[CallbackAggConsumeUdf] consume_wrapper, + PyObject* merge_function, + function[CallbackAggMergeUdf] merge_wrapper, + PyObject* finalize_function, + function[CallbackAggFinalizeUdf] finalize_wrapper, + const CScalarUdfOptions& options) \ No newline at end of file diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 26ce845b742..e25f85713ec 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -20,6 +20,9 @@ #include "arrow/compute/api_aggregate.h" #include "arrow/python/common.h" +// TODO REMOVE +#include + namespace arrow { using compute::ExecResult; @@ -179,10 +182,13 @@ struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { } Status ConsumeBatch(compute::KernelContext* ctx, const compute::ExecSpan& batch) { + std::cout << "ConsumeBatch" << std::endl; const int num_args = batch.num_values(); this->batch_length = batch.length; ScalarAggregateUdfContext udf_context{ctx->memory_pool(), batch.length}; // TODO: think about guaranteeing DRY (following logic already used in ScalarUDFs) + std::cout << "Num Args : " << num_args << std::endl; + std::cout << "Batch length : " << this->batch_length << std::endl; OwnedRef arg_tuple(PyTuple_New(num_args)); RETURN_NOT_OK(CheckPyError()); for (int arg_id = 0; arg_id < num_args; arg_id++) { @@ -196,41 +202,49 @@ struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { PyTuple_SetItem(arg_tuple.obj(), arg_id, data); } } + std::cout << "Args set " << std::endl; consume_cb(consume_function->obj(), udf_context, arg_tuple.obj()); + std::cout << "Function executed" << std::endl; RETURN_NOT_OK(CheckPyError()); return Status::OK(); } Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) override { - RETURN_NOT_OK(ConsumeBatch(ctx, batch)); - return Status::OK(); + return SafeCallIntoPython([&]() -> Status { return ConsumeBatch(ctx, batch); }); } Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) override { ScalarAggregateUdfContext udf_context{ctx->memory_pool(), this->batch_length}; - merge_cb(merge_function->obj(), udf_context); - return Status::OK(); + return SafeCallIntoPython([&]() -> Status { + merge_cb(merge_function->obj(), udf_context); + return Status::OK(); + }); }; Status Finalize(compute::KernelContext* ctx, arrow::Datum* out) override { - ScalarAggregateUdfContext udf_context{ctx->memory_pool(), this->batch_length}; - OwnedRef result(finalize_cb(finalize_function->obj(), udf_context)); - RETURN_NOT_OK(CheckPyError()); - // unwrapping the output for expected output type - if (is_array(result.obj())) { - ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_array(result.obj())); - if (!output_type->Equals(*val->type())) { - return Status::TypeError("Expected output datatype ", output_type->ToString(), - ", but function returned datatype ", - val->type()->ToString()); - } - out = std::move(new Datum(std::move(val))); - return Status::OK(); - } else { - return Status::TypeError("Unexpected output type: ", Py_TYPE(result.obj())->tp_name, - " (expected Array)"); - } - return Status::OK(); + return SafeCallIntoPython([&]() -> Status { + ScalarAggregateUdfContext udf_context{ctx->memory_pool(), this->batch_length}; + OwnedRef result(finalize_cb(finalize_function->obj(), udf_context)); + std::cout << "Finalize Python Call finished in C++" << std::endl; + RETURN_NOT_OK(CheckPyError()); + std::cout << "CheckPyError done" << std::endl; + // unwrapping the output for expected output type + if (is_array(result.obj())) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_array(result.obj())); + if (!output_type->Equals(*val->type())) { + return Status::TypeError("Expected output datatype ", output_type->ToString(), + ", but function returned datatype ", + val->type()->ToString()); + } + std::cout << "Finalize called to C++ : " << val->ToString() << std::endl; + *out = Datum(std::move(val)); + std::cout << "Final value set" << std::endl; + return Status::OK(); + } else { + return Status::TypeError("Unexpected output type: ", Py_TYPE(result.obj())->tp_name, + " (expected Array)"); + } + }); }; private: @@ -252,6 +266,7 @@ Status RegisterScalarAggregateFunction(PyObject* consume_function, PyObject* finalize_function, ScalarAggregateFinalizeUdfWrapperCallback finalize_wrapper, const ScalarUdfOptions& options) { + std::cout << "RegisterScalarAggregateFunction" << std::endl; if (!PyCallable_Check(consume_function) || !PyCallable_Check(merge_function) || !PyCallable_Check(finalize_function)) { return Status::TypeError("Expected a callable Python object."); } diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index e711619582d..176ca3bde4d 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -504,3 +504,58 @@ def test_input_lifetime(unary_func_fixture): # Calling a UDF should not have kept `v` alive longer than required v = None assert proxy_pool.bytes_allocated() == 0 + + +def test_aggregate_udf(): + + class SimpleCount: + + def __init__(self): + self._count = 0 + + def consume(self, ctx, x): + if isinstance(x, pa.Array): + self._count = self._count + len(x) + elif isinstance(x, pa.Scalar): + self._count = self._count + 1 + + def merge(self, ctx): + pass + + def finalize(self, ctx): + return pa.scalar(self._count) + + + def consume(ctx, x): + if isinstance(x, pa.Array): + print("consume: array: ", len(x) + 1) + elif isinstance(x, pa.Scalar): + print(1) + + def merge(ctx): + print("call merge") + pass + + def finalize(ctx): + print("call finalize") + return pa.array([10]) + + func_name = "simple_count" + unary_doc = {"summary": "count function", + "description": "test agg count function"} + simple_count = SimpleCount() + pc.register_scalar_aggregate_function(consume, + merge, + finalize, + func_name, + unary_doc, + {"array": pa.int64()}, + pa.int64()) + + print(pc.get_function(func_name)) + + pc.call_function(func_name, [pa.array([10, 20])]) + + + + \ No newline at end of file From f3c8df58c97a82632b15be297d32c604c8ed0fd5 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Thu, 27 Oct 2022 13:40:21 +0530 Subject: [PATCH 05/13] feat(initial): functional vanilla aggregate udfs --- cpp/examples/arrow/udf_example.cc | 1 + python/pyarrow/_compute.pyx | 51 ++++++-- python/pyarrow/includes/libarrow.pxd | 21 ++-- python/pyarrow/src/arrow/python/udf.cc | 166 +++++++++++++++++++++---- python/pyarrow/src/arrow/python/udf.h | 12 +- python/pyarrow/tests/test_udf.py | 78 ++++++------ 6 files changed, 239 insertions(+), 90 deletions(-) diff --git a/cpp/examples/arrow/udf_example.cc b/cpp/examples/arrow/udf_example.cc index fce4c7a880d..f1d47610364 100644 --- a/cpp/examples/arrow/udf_example.cc +++ b/cpp/examples/arrow/udf_example.cc @@ -148,6 +148,7 @@ struct SimpleCountImpl : public ScalarUdfAggregator { arrow::Status MergeFrom(cp::KernelContext*, cp::KernelState&& src) override { const auto& other_state = arrow::internal::checked_cast(src); + std::cout << "This non_nulls: " << this->non_nulls << std::endl; this->non_nulls += other_state.non_nulls; return arrow::Status::OK(); } diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index c182cc676d1..a2a8b760d21 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2512,6 +2512,20 @@ cdef class ScalarAggregateUdfContext: """ return box_memory_pool(self.c_context.pool) + @property + def state(self): + """ + An object which maintains the state of aggregation + """ + obj = self.c_context.state + if obj is None: + raise RuntimeError("Error occurred in extracting state") + return obj + + @state.setter + def state(self, value): + self.c_context.state = value + cdef inline CFunctionDoc _make_function_doc(dict func_doc) except *: """ @@ -2565,13 +2579,13 @@ cdef _scalar_agg_consume_udf_callback(consume_function, const CScalarAggregateUd return consume_function(context, *inputs) -cdef _scalar_agg_merge_udf_callback(merge_function, const CScalarAggregateUdfContext& c_context): +cdef _scalar_agg_merge_udf_callback(merge_function, const CScalarAggregateUdfContext& c_context, current_state, other_state): """ Helper aggregate merge callback function used to wrap the ScalarAggregateUdfContext from Python to C++ execution. """ context = box_scalar_udf_agg_context(c_context) - return merge_function(context) + return merge_function(context, *current_state, *other_state) cdef _scalar_agg_finalize_udf_callback(finalize_function, const CScalarAggregateUdfContext& c_context): @@ -2582,6 +2596,13 @@ cdef _scalar_agg_finalize_udf_callback(finalize_function, const CScalarAggregate context = box_scalar_udf_agg_context(c_context) return finalize_function(context) +cdef _scalar_agg_init_udf_callback(init_function): + """ + Helper aggregate initialize callback function used to wrap the ScalarAggregateUdfContext from Python to C++ + execution. + """ + return init_function() + def _get_scalar_udf_context(memory_pool, batch_length): cdef CScalarUdfContext c_context @@ -2715,25 +2736,31 @@ def register_scalar_function(func, function_name, function_doc, in_types, &_scalar_udf_callback, c_options)) -def register_scalar_aggregate_function(consume_func, merge_func, finalize_func, - function_name, function_doc, in_types, out_type): +def register_scalar_aggregate_function(init_func, consume_func, merge_func, finalize_func, + function_name, function_doc, in_types, out_type): cdef: c_string c_func_name CArity c_arity CFunctionDoc c_func_doc vector[shared_ptr[CDataType]] c_in_types + PyObject* c_init_function PyObject* c_consume_function PyObject* c_merge_function PyObject* c_finalize_function shared_ptr[CDataType] c_out_type CScalarUdfOptions c_options + if callable(init_func): + c_init_function = init_func + else: + raise TypeError("init_func must be a callable") + if callable(consume_func): c_consume_function = consume_func else: raise TypeError("consume_func must be a callable") - + if callable(merge_func): c_merge_function = merge_func else: @@ -2780,9 +2807,11 @@ def register_scalar_aggregate_function(consume_func, merge_func, finalize_func, c_options.output_type = c_out_type check_status(RegisterScalarAggregateFunction(c_consume_function, - &_scalar_agg_consume_udf_callback, - c_merge_function, - &_scalar_agg_merge_udf_callback, - c_finalize_function, - &_scalar_agg_finalize_udf_callback, - c_options)) \ No newline at end of file + &_scalar_agg_consume_udf_callback, + c_merge_function, + &_scalar_agg_merge_udf_callback, + c_finalize_function, + &_scalar_agg_finalize_udf_callback, + c_init_function, + &_scalar_agg_init_udf_callback, + c_options)) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index c4f05ab1611..b707c44c437 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2761,9 +2761,11 @@ cdef extern from "arrow/util/byte_size.h" namespace "arrow::util" nogil: ctypedef PyObject* CallbackUdf(object user_function, const CScalarUdfContext& context, object inputs) +ctypedef PyObject* CallbackAggInitUdf(object init_function) + ctypedef PyObject* CallbackAggConsumeUdf(object consume_function, const CScalarAggregateUdfContext& context, object inputs) -ctypedef void CallbackAggMergeUdf(object merge_function, const CScalarAggregateUdfContext& context) +ctypedef PyObject* CallbackAggMergeUdf(object merge_function, const CScalarAggregateUdfContext& context, object current_state, object other_state) ctypedef PyObject* CallbackAggFinalizeUdf(object finalize_function, const CScalarAggregateUdfContext& context) @@ -2775,7 +2777,8 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py": cdef cppclass CScalarAggregateUdfContext" arrow::py::ScalarAggregateUdfContext": CMemoryPool *pool int64_t batch_length - + PyObject* state + cdef cppclass CScalarUdfOptions" arrow::py::ScalarUdfOptions": c_string func_name CArity arity @@ -2787,9 +2790,11 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py": function[CallbackUdf] wrapper, const CScalarUdfOptions& options) CStatus RegisterScalarAggregateFunction(PyObject* consume_function, - function[CallbackAggConsumeUdf] consume_wrapper, - PyObject* merge_function, - function[CallbackAggMergeUdf] merge_wrapper, - PyObject* finalize_function, - function[CallbackAggFinalizeUdf] finalize_wrapper, - const CScalarUdfOptions& options) \ No newline at end of file + function[CallbackAggConsumeUdf] consume_wrapper, + PyObject* merge_function, + function[CallbackAggMergeUdf] merge_wrapper, + PyObject* finalize_function, + function[CallbackAggFinalizeUdf] finalize_wrapper, + PyObject* init_function, + function[CallbackAggInitUdf] init_wrapper, + const CScalarUdfOptions& options) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index e25f85713ec..e47527e8760 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -148,47 +148,144 @@ arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* ou ->Finalize(ctx, out); } +// TODO remove functions + +// debug functions +void PrintPyObject(std::string&& msg, PyObject* obj) { + std::cout << std::string('*', 100) << std::endl; + std::cout << "PrintPython Object:: " << msg << std::endl; + if(obj) { + PyObject *object_repr = PyObject_Repr(obj); + const char *s = PyUnicode_AsUTF8(object_repr); + std::cout << s << std::endl; + } else { + std::cout << "null object" << std::endl; + } + + std::cout << std::string('*', 80) << std::endl; +} + +Status PrintArrayObject(std::string&& msg, const OwnedRefNoGIL& owned_state) { + std::cout << std::string('X', 100) << std::endl; + std::cout << "Print Array Object : " << msg << std::endl; + if (owned_state) { + if(is_array(owned_state.obj())) { + std::cout << "is array" << std::endl; + ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_array(owned_state.obj())); + std::cout << "Value : " << val->ToString() << std::endl; + } else { + std::cout << "Non array state" << std::endl; + } + } else { + std::cout << "no state found" << std::endl; + } + std::cout << std::string('X', 100) << std::endl; + return Status::OK(); +} + +Status PrintArrayJustObject(std::string&& msg, PyObject* obj) { + std::cout << std::string('k', 100) << std::endl; + std::cout << "Print Just Array Object : " << msg << std::endl; + if (obj) { + if(is_array(obj)) { + std::cout << "is array" << std::endl; + ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_array(obj)); + std::cout << "Value : " << val->ToString() << std::endl; + } else { + std::cout << "Non array object" << std::endl; + } + } else { + std::cout << "no object" << std::endl; + } + std::cout << std::string('k', 100) << std::endl; + return Status::OK(); +} + +Status CheckUdfContext(std::string&& msg, ScalarAggregateUdfContext udf_context) { + std::cout << std::string('*', 100) << std::endl; + std::cout << "Check UDF COntext: " << msg << std::endl; + if(udf_context.state) { + std::cout << "udf_context_.state is ok" << std::endl; + if(is_array(udf_context.state)) { + std::cout << "is array" << std::endl; + ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_array(udf_context.state)); + std::cout << val->ToString() << std::endl; + } + } else { + std::cout << "this->udf_context_.state is null" << std::endl; + } + std::cout << std::string('*', 100) << std::endl; + return Status::OK(); +} + + struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { + ScalarAggregateInitUdfWrapperCallback init_cb; ScalarAggregateConsumeUdfWrapperCallback consume_cb; ScalarAggregateMergeUdfWrapperCallback merge_cb; ScalarAggregateFinalizeUdfWrapperCallback finalize_cb; + std::shared_ptr init_function; std::shared_ptr consume_function; std::shared_ptr merge_function; std::shared_ptr finalize_function; std::shared_ptr output_type; - PythonScalarUdfAggregatorImpl(ScalarAggregateConsumeUdfWrapperCallback consume_cb, + PythonScalarUdfAggregatorImpl(ScalarAggregateInitUdfWrapperCallback init_cb, + ScalarAggregateConsumeUdfWrapperCallback consume_cb, ScalarAggregateMergeUdfWrapperCallback merge_cb, ScalarAggregateFinalizeUdfWrapperCallback finalize_cb, + std::shared_ptr init_function, std::shared_ptr consume_function, std::shared_ptr merge_function, std::shared_ptr finalize_function, - const std::shared_ptr& output_type) : consume_cb(consume_cb), + const std::shared_ptr& output_type) : init_cb(init_cb), + consume_cb(consume_cb), merge_cb(merge_cb), finalize_cb(finalize_cb), + init_function(init_function), consume_function(consume_function), merge_function(merge_function), finalize_function(finalize_function), - output_type(output_type) {} + output_type(output_type) { + Init(init_cb, init_function); + } ~PythonScalarUdfAggregatorImpl() { if (_Py_IsFinalizing()) { + init_function->detach(); consume_function->detach(); merge_function->detach(); finalize_function->detach(); } } + void Init(ScalarAggregateInitUdfWrapperCallback& init_cb , std::shared_ptr& init_function) { + auto st = SafeCallIntoPython([&]() -> Status { + OwnedRef result(init_cb(init_function->obj())); + PyObject* init_res = result.obj(); + Py_INCREF(init_res); + this->udf_context_ = ScalarAggregateUdfContext{default_memory_pool(), 0, std::move(init_res)}; + this->owned_state_.reset(result.obj()); + RETURN_NOT_OK(CheckPyError()); + return Status::OK(); + }); + if (!st.ok()) { + throw std::runtime_error(st.ToString()); + } + } + Status ConsumeBatch(compute::KernelContext* ctx, const compute::ExecSpan& batch) { - std::cout << "ConsumeBatch" << std::endl; + const auto& current_state = + arrow::internal::checked_cast(*ctx->state()); const int num_args = batch.num_values(); - this->batch_length = batch.length; - ScalarAggregateUdfContext udf_context{ctx->memory_pool(), batch.length}; + this->batch_length_ = batch.length; + this->udf_context_.batch_length = batch.length; + Py_INCREF(this->udf_context_.state); + this->udf_context_.state = this->owned_state_.obj(); // TODO: think about guaranteeing DRY (following logic already used in ScalarUDFs) - std::cout << "Num Args : " << num_args << std::endl; - std::cout << "Batch length : " << this->batch_length << std::endl; + OwnedRef arg_tuple(PyTuple_New(num_args)); RETURN_NOT_OK(CheckPyError()); for (int arg_id = 0; arg_id < num_args; arg_id++) { @@ -202,32 +299,44 @@ struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { PyTuple_SetItem(arg_tuple.obj(), arg_id, data); } } - std::cout << "Args set " << std::endl; - consume_cb(consume_function->obj(), udf_context, arg_tuple.obj()); - std::cout << "Function executed" << std::endl; + OwnedRef result(consume_cb(consume_function->obj(), this->udf_context_, arg_tuple.obj())); + RETURN_NOT_OK(CheckPyError()); + PyObject* consume_res = result.obj(); + Py_INCREF(consume_res); + this->owned_state_.reset(consume_res); + Py_INCREF(this->udf_context_.state); + this->udf_context_.state = this->owned_state_.obj(); RETURN_NOT_OK(CheckPyError()); return Status::OK(); } Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) override { - return SafeCallIntoPython([&]() -> Status { return ConsumeBatch(ctx, batch); }); + RETURN_NOT_OK(SafeCallIntoPython([&]() -> Status { + RETURN_NOT_OK(ConsumeBatch(ctx, batch)); + return Status::OK(); + })); + return Status::OK(); } Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) override { - ScalarAggregateUdfContext udf_context{ctx->memory_pool(), this->batch_length}; - return SafeCallIntoPython([&]() -> Status { - merge_cb(merge_function->obj(), udf_context); + const auto& other_state = arrow::internal::checked_cast(src); + return SafeCallIntoPython([&]() -> Status { + OwnedRef result(merge_cb(merge_function->obj(), other_state.udf_context_, this->owned_state_.obj(), other_state.owned_state_.obj())); + RETURN_NOT_OK(CheckPyError()); + PyObject* merge_res = result.obj(); + Py_INCREF(merge_res); + this->owned_state_.reset(merge_res); + Py_INCREF(this->udf_context_.state); + this->udf_context_.state = this->owned_state_.obj(); return Status::OK(); - }); - }; + }); + } Status Finalize(compute::KernelContext* ctx, arrow::Datum* out) override { + // TODO: consider the this_state and return the accurate value return SafeCallIntoPython([&]() -> Status { - ScalarAggregateUdfContext udf_context{ctx->memory_pool(), this->batch_length}; - OwnedRef result(finalize_cb(finalize_function->obj(), udf_context)); - std::cout << "Finalize Python Call finished in C++" << std::endl; + OwnedRef result(finalize_cb(finalize_function->obj(), this->udf_context_)); RETURN_NOT_OK(CheckPyError()); - std::cout << "CheckPyError done" << std::endl; // unwrapping the output for expected output type if (is_array(result.obj())) { ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_array(result.obj())); @@ -236,9 +345,7 @@ struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { ", but function returned datatype ", val->type()->ToString()); } - std::cout << "Finalize called to C++ : " << val->ToString() << std::endl; *out = Datum(std::move(val)); - std::cout << "Final value set" << std::endl; return Status::OK(); } else { return Status::TypeError("Unexpected output type: ", Py_TYPE(result.obj())->tp_name, @@ -248,7 +355,10 @@ struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { }; private: - int batch_length = 1; + int batch_length_ = 1; + // Think about how this is going to be standardized + OwnedRefNoGIL owned_state_; + ScalarAggregateUdfContext udf_context_; }; Status AddAggKernel(std::shared_ptr sig, compute::KernelInit init, @@ -265,6 +375,8 @@ Status RegisterScalarAggregateFunction(PyObject* consume_function, ScalarAggregateMergeUdfWrapperCallback merge_wrapper, PyObject* finalize_function, ScalarAggregateFinalizeUdfWrapperCallback finalize_wrapper, + PyObject* init_function, + ScalarAggregateInitUdfWrapperCallback init_wrapper, const ScalarUdfOptions& options) { std::cout << "RegisterScalarAggregateFunction" << std::endl; if (!PyCallable_Check(consume_function) || !PyCallable_Check(merge_function) || !PyCallable_Check(finalize_function)) { @@ -284,14 +396,16 @@ Status RegisterScalarAggregateFunction(PyObject* consume_function, } compute::OutputType output_type(options.output_type); - auto init = [consume_wrapper, merge_wrapper, finalize_wrapper, - consume_function, merge_function, finalize_function, options]( + auto init = [init_wrapper, consume_wrapper, merge_wrapper, finalize_wrapper, + init_function, consume_function, merge_function, finalize_function, options]( compute::KernelContext* ctx, const compute::KernelInitArgs& args) -> Result> { return std::make_unique( + init_wrapper, consume_wrapper, merge_wrapper, finalize_wrapper, + std::make_shared(init_function), std::make_shared(consume_function), std::make_shared(merge_function), std::make_shared(finalize_function), diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index 584570996b4..c7ee3507222 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -51,6 +51,9 @@ struct ARROW_PYTHON_EXPORT ScalarUdfContext { struct ARROW_PYTHON_EXPORT ScalarAggregateUdfContext { MemoryPool* pool; int64_t batch_length; + // TODO: do we need to standardize this + // Meaning: do we have to Create a PythonAggregateState object or something separately. + PyObject* state; }; @@ -62,11 +65,12 @@ Status ARROW_PYTHON_EXPORT RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback wrapper, const ScalarUdfOptions& options); -using ScalarAggregateConsumeUdfWrapperCallback = std::function; +using ScalarAggregateConsumeUdfWrapperCallback = std::function; -using ScalarAggregateMergeUdfWrapperCallback = std::function; +using ScalarAggregateMergeUdfWrapperCallback = std::function; using ScalarAggregateFinalizeUdfWrapperCallback = std::function; @@ -78,6 +82,8 @@ Status ARROW_PYTHON_EXPORT RegisterScalarAggregateFunction(PyObject* consume_fun ScalarAggregateMergeUdfWrapperCallback merge_wrapper, PyObject* finalize_function, ScalarAggregateFinalizeUdfWrapperCallback finalize_wrapper, + PyObject* init_function, + ScalarAggregateInitUdfWrapperCallback init_wrapper, const ScalarUdfOptions& options); diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 176ca3bde4d..7ba13efa981 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -507,55 +507,49 @@ def test_input_lifetime(unary_func_fixture): def test_aggregate_udf(): - - class SimpleCount: - - def __init__(self): - self._count = 0 - - def consume(self, ctx, x): - if isinstance(x, pa.Array): - self._count = self._count + len(x) - elif isinstance(x, pa.Scalar): - self._count = self._count + 1 - - def merge(self, ctx): - pass - - def finalize(self, ctx): - return pa.scalar(self._count) - - + + class State: + def __init__(self, count): + self._count = count + + @property + def count(self): + return self._count + + @count.setter + def count(self, value): + self._count = value + + def init(): + return pa.array([0]) + def consume(ctx, x): if isinstance(x, pa.Array): - print("consume: array: ", len(x) + 1) + count = len(x) elif isinstance(x, pa.Scalar): - print(1) - - def merge(ctx): - print("call merge") - pass - + count = 1 + return pc.add(pa.array([count]), ctx.state) + + def merge(ctx, current_state, other_state): + new_state = pc.add(current_state, other_state) + return pa.array([new_state.as_py()]) + def finalize(ctx): - print("call finalize") - return pa.array([10]) - + return ctx.state + func_name = "simple_count" unary_doc = {"summary": "count function", "description": "test agg count function"} - simple_count = SimpleCount() - pc.register_scalar_aggregate_function(consume, + + pc.register_scalar_aggregate_function(init, + consume, merge, finalize, - func_name, - unary_doc, - {"array": pa.int64()}, - pa.int64()) - + func_name, + unary_doc, + {"array": pa.int64()}, + pa.int64()) + print(pc.get_function(func_name)) - - pc.call_function(func_name, [pa.array([10, 20])]) - - - - \ No newline at end of file + + print(pc.call_function(func_name, [pa.array([10, 20])])) From f723ad1e09504d16dbafcf4b6e1e9852e94dd33c Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Mon, 31 Oct 2022 19:33:10 +0530 Subject: [PATCH 06/13] feat(status-custom): initial --- python/pyarrow/_compute.pyx | 9 ++- python/pyarrow/includes/libarrow.pxd | 2 +- python/pyarrow/src/arrow/python/udf.cc | 25 +++++- python/pyarrow/src/arrow/python/udf.h | 2 +- python/pyarrow/tests/test_udf.py | 102 +++++++++++++++++++++---- 5 files changed, 120 insertions(+), 20 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index a2a8b760d21..34a9d8c2bd9 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2579,13 +2579,18 @@ cdef _scalar_agg_consume_udf_callback(consume_function, const CScalarAggregateUd return consume_function(context, *inputs) -cdef _scalar_agg_merge_udf_callback(merge_function, const CScalarAggregateUdfContext& c_context, current_state, other_state): +cdef _scalar_agg_merge_udf_callback(merge_function, const CScalarAggregateUdfContext& c_context, other_state): """ Helper aggregate merge callback function used to wrap the ScalarAggregateUdfContext from Python to C++ execution. """ + print("_scalar_agg_merge_udf_callback") context = box_scalar_udf_agg_context(c_context) - return merge_function(context, *current_state, *other_state) + print("context: ", context) + print("context.state: ", context.state) + print("other_state: ", other_state) + print("*other_state: ", *other_state) + return merge_function(context, *other_state) cdef _scalar_agg_finalize_udf_callback(finalize_function, const CScalarAggregateUdfContext& c_context): diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index b707c44c437..9b7093a1bc3 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2765,7 +2765,7 @@ ctypedef PyObject* CallbackAggInitUdf(object init_function) ctypedef PyObject* CallbackAggConsumeUdf(object consume_function, const CScalarAggregateUdfContext& context, object inputs) -ctypedef PyObject* CallbackAggMergeUdf(object merge_function, const CScalarAggregateUdfContext& context, object current_state, object other_state) +ctypedef PyObject* CallbackAggMergeUdf(object merge_function, const CScalarAggregateUdfContext& context, object other_state) ctypedef PyObject* CallbackAggFinalizeUdf(object finalize_function, const CScalarAggregateUdfContext& context) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index e47527e8760..782ba57e73e 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -210,6 +210,8 @@ Status CheckUdfContext(std::string&& msg, ScalarAggregateUdfContext udf_context) std::cout << "is array" << std::endl; ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_array(udf_context.state)); std::cout << val->ToString() << std::endl; + } else { + PrintPyObject("non-arrow-object", udf_context.state); } } else { std::cout << "this->udf_context_.state is null" << std::endl; @@ -285,7 +287,7 @@ struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { Py_INCREF(this->udf_context_.state); this->udf_context_.state = this->owned_state_.obj(); // TODO: think about guaranteeing DRY (following logic already used in ScalarUDFs) - + CheckUdfContext("check udf context @ConsumeBatch Start", this->udf_context_); OwnedRef arg_tuple(PyTuple_New(num_args)); RETURN_NOT_OK(CheckPyError()); for (int arg_id = 0; arg_id < num_args; arg_id++) { @@ -306,6 +308,7 @@ struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { this->owned_state_.reset(consume_res); Py_INCREF(this->udf_context_.state); this->udf_context_.state = this->owned_state_.obj(); + CheckUdfContext("check udf context @ConsumeBatch End", this->udf_context_); RETURN_NOT_OK(CheckPyError()); return Status::OK(); } @@ -319,21 +322,39 @@ struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { } Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) override { + std::cout << "MergeFrom Start" << std::endl; const auto& other_state = arrow::internal::checked_cast(src); return SafeCallIntoPython([&]() -> Status { - OwnedRef result(merge_cb(merge_function->obj(), other_state.udf_context_, this->owned_state_.obj(), other_state.owned_state_.obj())); + CheckUdfContext("\tcheck this->udf_context @MergeFrom", this->udf_context_); + CheckUdfContext("\tcheck other_state->udf_context @MergeFrom", other_state.udf_context_); + std::cout << "\tJust before callback exec" << std::endl; + if (this->owned_state_.obj() == Py_None) { + std::cout << "\t this->owned_state_.obj() == Py_None" << std::endl; + } + if (other_state.owned_state_.obj() == Py_None) { + std::cout << "\t other_state.owned_state_.obj() == Py_None" << std::endl; + } + if(this->udf_context_.state == Py_None) { + std::cout << "\t this->udf_context_.state == PyNone" << std::endl; + } + OwnedRef result(merge_cb(merge_function->obj(), this->udf_context_, other_state.owned_state_.obj())); RETURN_NOT_OK(CheckPyError()); + std::cout << "\t Exec callback finished" << std::endl; PyObject* merge_res = result.obj(); Py_INCREF(merge_res); this->owned_state_.reset(merge_res); + std::cout << "Results stored in owned_state" << std::endl; Py_INCREF(this->udf_context_.state); this->udf_context_.state = this->owned_state_.obj(); + std::cout << "Results stored in udf_context._state" << std::endl; + std::cout << "MergeFrom End" << std::endl; return Status::OK(); }); } Status Finalize(compute::KernelContext* ctx, arrow::Datum* out) override { // TODO: consider the this_state and return the accurate value + std::cout << "Finalize" << std::endl; return SafeCallIntoPython([&]() -> Status { OwnedRef result(finalize_cb(finalize_function->obj(), this->udf_context_)); RETURN_NOT_OK(CheckPyError()); diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index c7ee3507222..e355f8ddba1 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -70,7 +70,7 @@ using ScalarAggregateConsumeUdfWrapperCallback = std::function; using ScalarAggregateMergeUdfWrapperCallback = std::function; + PyObject* user_merge_func, const ScalarAggregateUdfContext& context, PyObject* other_state)>; using ScalarAggregateFinalizeUdfWrapperCallback = std::function; diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 7ba13efa981..842f47429d7 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -508,18 +508,6 @@ def test_input_lifetime(unary_func_fixture): def test_aggregate_udf(): - class State: - def __init__(self, count): - self._count = count - - @property - def count(self): - return self._count - - @count.setter - def count(self, value): - self._count = value - def init(): return pa.array([0]) @@ -530,8 +518,8 @@ def consume(ctx, x): count = 1 return pc.add(pa.array([count]), ctx.state) - def merge(ctx, current_state, other_state): - new_state = pc.add(current_state, other_state) + def merge(ctx, other_state): + new_state = pc.add(ctx.state, other_state) return pa.array([new_state.as_py()]) def finalize(ctx): @@ -553,3 +541,89 @@ def finalize(ctx): print(pc.get_function(func_name)) print(pc.call_function(func_name, [pa.array([10, 20])])) + +def test_aggregate_udf_with_custom_state(): + + class State: + def __init__(self, non_null): + self._non_null = non_null + + @property + def non_null(self): + return self._non_null + + @non_null.setter + def non_null(self, value): + self._non_null = value + + def __repr__(self): + if self._non_null is None: + return "no values stored" + else: + return "count: " + str(self._non_null) + + def next(self): + return self._non_null + + def __iter__(self): + return self + + + def init(): + print(">>> Init") + state = State(1) + return state + + def consume(ctx, x): + print(">>> consume") + if isinstance(x, pa.Array): + count = len(x) + elif isinstance(x, pa.Scalar): + count = 1 + print("state: ", ctx.state) + return pc.add(pa.array([count]), pa.array([ctx.state.non_null])) + + def merge(ctx, other_state): + print(">>> merge") + print("os: ", other_state) + return pa.array([1]) + + def finalize(ctx): + print(">>> finalize") + print(ctx.state) + return pa.array([2]) + + func_name = "simple_count_1" + unary_doc = {"summary": "count function", + "description": "test agg count function"} + + pc.register_scalar_aggregate_function(init, + consume, + merge, + finalize, + func_name, + unary_doc, + {"array": pa.int64()}, + pa.int64()) + + print(pc.get_function(func_name)) + + print(pc.call_function(func_name, [pa.array([10, 20])])) + + + +def test_segfault_error(): + func_name = "simple_count_x" + unary_doc = {"summary": "count function", + "description": "test agg count function"} + + def func(ctx, x): + return x + + pc.register_scalar_function(func, + func_name, + unary_doc, + {"array": pa.int64()}, + pa.int64()) + + print(pc.call_function(func_name, [pa.array([10])])) From 4f6c5379f616f53b8c3b2a0fc52c70a356caed53 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Tue, 1 Nov 2022 08:36:58 +0530 Subject: [PATCH 07/13] fix(state): state-based computations testing --- python/pyarrow/_compute.pyx | 3 +- python/pyarrow/src/arrow/python/udf.cc | 18 ++--- python/pyarrow/src/arrow/python/udf.h | 2 + python/pyarrow/tests/test_udf.py | 107 +++++++------------------ 4 files changed, 40 insertions(+), 90 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 34a9d8c2bd9..ef9a2532e18 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2589,8 +2589,7 @@ cdef _scalar_agg_merge_udf_callback(merge_function, const CScalarAggregateUdfCon print("context: ", context) print("context.state: ", context.state) print("other_state: ", other_state) - print("*other_state: ", *other_state) - return merge_function(context, *other_state) + return merge_function(context, other_state) cdef _scalar_agg_finalize_udf_callback(finalize_function, const CScalarAggregateUdfContext& c_context): diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 782ba57e73e..8a097d97fee 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -220,6 +220,11 @@ Status CheckUdfContext(std::string&& msg, ScalarAggregateUdfContext udf_context) return Status::OK(); } +ScalarAggregateUdfContext::~ScalarAggregateUdfContext() { + if (_Py_IsFinalizing()) { + Py_DECREF(this->state); + } +} struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { @@ -328,16 +333,8 @@ struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { CheckUdfContext("\tcheck this->udf_context @MergeFrom", this->udf_context_); CheckUdfContext("\tcheck other_state->udf_context @MergeFrom", other_state.udf_context_); std::cout << "\tJust before callback exec" << std::endl; - if (this->owned_state_.obj() == Py_None) { - std::cout << "\t this->owned_state_.obj() == Py_None" << std::endl; - } - if (other_state.owned_state_.obj() == Py_None) { - std::cout << "\t other_state.owned_state_.obj() == Py_None" << std::endl; - } - if(this->udf_context_.state == Py_None) { - std::cout << "\t this->udf_context_.state == PyNone" << std::endl; - } - OwnedRef result(merge_cb(merge_function->obj(), this->udf_context_, other_state.owned_state_.obj())); + OwnedRef result(merge_cb(merge_function->obj(), + this->udf_context_, other_state.owned_state_.obj())); RETURN_NOT_OK(CheckPyError()); std::cout << "\t Exec callback finished" << std::endl; PyObject* merge_res = result.obj(); @@ -353,7 +350,6 @@ struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { } Status Finalize(compute::KernelContext* ctx, arrow::Datum* out) override { - // TODO: consider the this_state and return the accurate value std::cout << "Finalize" << std::endl; return SafeCallIntoPython([&]() -> Status { OwnedRef result(finalize_cb(finalize_function->obj(), this->udf_context_)); diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index e355f8ddba1..004662ff4f7 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -54,6 +54,8 @@ struct ARROW_PYTHON_EXPORT ScalarAggregateUdfContext { // TODO: do we need to standardize this // Meaning: do we have to Create a PythonAggregateState object or something separately. PyObject* state; + + ~ScalarAggregateUdfContext(); }; diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 842f47429d7..2cc0dcb7d78 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -506,47 +506,11 @@ def test_input_lifetime(unary_func_fixture): assert proxy_pool.bytes_allocated() == 0 -def test_aggregate_udf(): - - def init(): - return pa.array([0]) - - def consume(ctx, x): - if isinstance(x, pa.Array): - count = len(x) - elif isinstance(x, pa.Scalar): - count = 1 - return pc.add(pa.array([count]), ctx.state) - - def merge(ctx, other_state): - new_state = pc.add(ctx.state, other_state) - return pa.array([new_state.as_py()]) - - def finalize(ctx): - return ctx.state - - func_name = "simple_count" - unary_doc = {"summary": "count function", - "description": "test agg count function"} - - pc.register_scalar_aggregate_function(init, - consume, - merge, - finalize, - func_name, - unary_doc, - {"array": pa.int64()}, - pa.int64()) - - print(pc.get_function(func_name)) - - print(pc.call_function(func_name, [pa.array([10, 20])])) - def test_aggregate_udf_with_custom_state(): - class State: - def __init__(self, non_null): + def __init__(self, non_null=0, msg=""): self._non_null = non_null + self._msg = msg @property def non_null(self): @@ -555,43 +519,49 @@ def non_null(self): @non_null.setter def non_null(self, value): self._non_null = value - + + @property + def msg(self): + return self._msg + def __repr__(self): if self._non_null is None: return "no values stored" else: - return "count: " + str(self._non_null) - - def next(self): - return self._non_null - - def __iter__(self): - return self + return "count: " + str(self.non_null) \ + + ", msg: " + str(self.msg) + + def next(self): + print("next: ", self.msg, self.non_null) + return self.non_null + def __iter__(self): + print("iter: ", self.msg, self.non_null) + yield self.non_null + + def __del__(self): + print("State.__del__, msg: " + str(self.msg)) def init(): - print(">>> Init") - state = State(1) + state = State(0, "@init") return state def consume(ctx, x): - print(">>> consume") if isinstance(x, pa.Array): - count = len(x) + count = pc.sum(pc.invert(pc.is_nan(x))).as_py() elif isinstance(x, pa.Scalar): - count = 1 - print("state: ", ctx.state) - return pc.add(pa.array([count]), pa.array([ctx.state.non_null])) + if x.as_py(): + count = 1 + state_val = pc.add(pa.array([count]), pa.array([ctx.state.non_null])) + return State(state_val[0].as_py(), "@consume") def merge(ctx, other_state): - print(">>> merge") - print("os: ", other_state) - return pa.array([1]) + merged_state_val = ctx.state.non_null + other_state.non_null + return State(merged_state_val, "@merge") def finalize(ctx): - print(">>> finalize") print(ctx.state) - return pa.array([2]) + return pa.array([ctx.state.non_null]) func_name = "simple_count_1" unary_doc = {"summary": "count function", @@ -608,22 +578,5 @@ def finalize(ctx): print(pc.get_function(func_name)) - print(pc.call_function(func_name, [pa.array([10, 20])])) - - - -def test_segfault_error(): - func_name = "simple_count_x" - unary_doc = {"summary": "count function", - "description": "test agg count function"} - - def func(ctx, x): - return x - - pc.register_scalar_function(func, - func_name, - unary_doc, - {"array": pa.int64()}, - pa.int64()) - - print(pc.call_function(func_name, [pa.array([10])])) + print(pc.call_function(func_name, [ + pa.array([10, 20, None, 30, None, 40])])) From 19b8551771eeff56ada4a30509983a1a7e52c081 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Tue, 1 Nov 2022 15:43:17 +0530 Subject: [PATCH 08/13] fix(format) --- python/pyarrow/_compute.pyx | 4 -- python/pyarrow/tests/test_udf.py | 105 +++++++++++++++++++++++++++---- 2 files changed, 92 insertions(+), 17 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index ef9a2532e18..f28026fd813 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2584,11 +2584,7 @@ cdef _scalar_agg_merge_udf_callback(merge_function, const CScalarAggregateUdfCon Helper aggregate merge callback function used to wrap the ScalarAggregateUdfContext from Python to C++ execution. """ - print("_scalar_agg_merge_udf_callback") context = box_scalar_udf_agg_context(c_context) - print("context: ", context) - print("context.state: ", context.state) - print("other_state: ", other_state) return merge_function(context, other_state) diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 2cc0dcb7d78..440fe003eb8 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -531,14 +531,6 @@ def __repr__(self): return "count: " + str(self.non_null) \ + ", msg: " + str(self.msg) - def next(self): - print("next: ", self.msg, self.non_null) - return self.non_null - - def __iter__(self): - print("iter: ", self.msg, self.non_null) - yield self.non_null - def __del__(self): print("State.__del__, msg: " + str(self.msg)) @@ -548,19 +540,18 @@ def init(): def consume(ctx, x): if isinstance(x, pa.Array): - count = pc.sum(pc.invert(pc.is_nan(x))).as_py() + non_null = pc.sum(pc.invert(pc.is_nan(x))).as_py() elif isinstance(x, pa.Scalar): if x.as_py(): - count = 1 - state_val = pc.add(pa.array([count]), pa.array([ctx.state.non_null])) - return State(state_val[0].as_py(), "@consume") + non_null = 1 + non_null = non_null + ctx.state.non_null + return State(non_null, "@consume") def merge(ctx, other_state): merged_state_val = ctx.state.non_null + other_state.non_null return State(merged_state_val, "@merge") def finalize(ctx): - print(ctx.state) return pa.array([ctx.state.non_null]) func_name = "simple_count_1" @@ -580,3 +571,91 @@ def finalize(ctx): print(pc.call_function(func_name, [ pa.array([10, 20, None, 30, None, 40])])) + + +def test_aggregate_udf_with_custom_state_multi_attr(): + class State: + def __init__(self, non_null=0, null=0, msg=""): + self._non_null = non_null + self._null = null + self._msg = msg + + @property + def non_null(self): + return self._non_null + + @non_null.setter + def non_null(self, value): + self._non_null = value + + @property + def null(self): + return self._null + + @null.setter + def null(self, value): + self._null = value + + @property + def msg(self): + return self._msg + + def __repr__(self): + if self._non_null is None: + return "no values stored" + else: + return "non_null: " + str(self.non_null) \ + + ", null: " + str(self.null) \ + + ", msg: " + str(self.msg) + + def __del__(self): + print("State.__del__, msg: " + str(self.msg)) + + def init(): + print(">>> init") + state = State(0, 0, "@init") + return state + + def consume(ctx, x): + print(">>> consume") + null = 0 + non_null = 0 + if isinstance(x, pa.Array): + non_null = pc.sum(pc.invert(pc.is_nan(x))).as_py() + null = len(x) - non_null + elif isinstance(x, pa.Scalar): + if x.as_py(): + non_null = 1 + else: + null = 1 + non_null = non_null + ctx.state.non_null + return State(non_null, null, "@consume") + + def merge(ctx, other_state): + print(">>> merge") + merged_st_non_null = ctx.state.non_null + other_state.non_null + merged_st_null = ctx.state.null + other_state.null + return State(merged_st_non_null, merged_st_null, "@merge") + + def finalize(ctx): + print(">>> finalize") + print(ctx.state) + return pa.array([ctx.state.non_null, ctx.state.null]) + + func_name = "basic_count" + unary_doc = {"summary": "count function for null and non-null", + "description": "test agg count function"} + + pc.register_scalar_aggregate_function(init, + consume, + merge, + finalize, + func_name, + unary_doc, + {"array": pa.int64()}, + pa.int64()) + + print(pc.get_function(func_name)) + + print(pc.call_function(func_name, [ + pa.array([10, 20, None, 30, None, 40])])) From f945b4081da9212869f2bbc4bc1a7c94fc5766c8 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 2 Nov 2022 13:57:32 +0530 Subject: [PATCH 09/13] fix(cleanup) --- python/pyarrow/_compute.pyx | 12 ---- python/pyarrow/includes/libarrow.pxd | 1 - python/pyarrow/src/arrow/python/udf.cc | 92 +------------------------- python/pyarrow/src/arrow/python/udf.h | 6 +- python/pyarrow/tests/test_udf.py | 16 ++--- 5 files changed, 8 insertions(+), 119 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index f28026fd813..8c9c13d7392 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2489,18 +2489,6 @@ cdef class ScalarAggregateUdfContext: cdef void init(self, const CScalarAggregateUdfContext &c_context): self.c_context = c_context - @property - def batch_length(self): - """ - The common length of all input arguments (int). - - In the case that all arguments are scalars, this value - is used to pass the "actual length" of the arguments, - e.g. because the scalar values are encoding a column - with a constant value. - """ - return self.c_context.batch_length - @property def memory_pool(self): """ diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 9b7093a1bc3..192f166c8c4 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2776,7 +2776,6 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py": cdef cppclass CScalarAggregateUdfContext" arrow::py::ScalarAggregateUdfContext": CMemoryPool *pool - int64_t batch_length PyObject* state cdef cppclass CScalarUdfOptions" arrow::py::ScalarUdfOptions": diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 8a097d97fee..676869e90f4 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -148,78 +148,6 @@ arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* ou ->Finalize(ctx, out); } -// TODO remove functions - -// debug functions -void PrintPyObject(std::string&& msg, PyObject* obj) { - std::cout << std::string('*', 100) << std::endl; - std::cout << "PrintPython Object:: " << msg << std::endl; - if(obj) { - PyObject *object_repr = PyObject_Repr(obj); - const char *s = PyUnicode_AsUTF8(object_repr); - std::cout << s << std::endl; - } else { - std::cout << "null object" << std::endl; - } - - std::cout << std::string('*', 80) << std::endl; -} - -Status PrintArrayObject(std::string&& msg, const OwnedRefNoGIL& owned_state) { - std::cout << std::string('X', 100) << std::endl; - std::cout << "Print Array Object : " << msg << std::endl; - if (owned_state) { - if(is_array(owned_state.obj())) { - std::cout << "is array" << std::endl; - ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_array(owned_state.obj())); - std::cout << "Value : " << val->ToString() << std::endl; - } else { - std::cout << "Non array state" << std::endl; - } - } else { - std::cout << "no state found" << std::endl; - } - std::cout << std::string('X', 100) << std::endl; - return Status::OK(); -} - -Status PrintArrayJustObject(std::string&& msg, PyObject* obj) { - std::cout << std::string('k', 100) << std::endl; - std::cout << "Print Just Array Object : " << msg << std::endl; - if (obj) { - if(is_array(obj)) { - std::cout << "is array" << std::endl; - ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_array(obj)); - std::cout << "Value : " << val->ToString() << std::endl; - } else { - std::cout << "Non array object" << std::endl; - } - } else { - std::cout << "no object" << std::endl; - } - std::cout << std::string('k', 100) << std::endl; - return Status::OK(); -} - -Status CheckUdfContext(std::string&& msg, ScalarAggregateUdfContext udf_context) { - std::cout << std::string('*', 100) << std::endl; - std::cout << "Check UDF COntext: " << msg << std::endl; - if(udf_context.state) { - std::cout << "udf_context_.state is ok" << std::endl; - if(is_array(udf_context.state)) { - std::cout << "is array" << std::endl; - ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_array(udf_context.state)); - std::cout << val->ToString() << std::endl; - } else { - PrintPyObject("non-arrow-object", udf_context.state); - } - } else { - std::cout << "this->udf_context_.state is null" << std::endl; - } - std::cout << std::string('*', 100) << std::endl; - return Status::OK(); -} - ScalarAggregateUdfContext::~ScalarAggregateUdfContext() { if (_Py_IsFinalizing()) { Py_DECREF(this->state); @@ -273,7 +201,7 @@ struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { OwnedRef result(init_cb(init_function->obj())); PyObject* init_res = result.obj(); Py_INCREF(init_res); - this->udf_context_ = ScalarAggregateUdfContext{default_memory_pool(), 0, std::move(init_res)}; + this->udf_context_ = ScalarAggregateUdfContext{default_memory_pool(), std::move(init_res)}; this->owned_state_.reset(result.obj()); RETURN_NOT_OK(CheckPyError()); return Status::OK(); @@ -284,15 +212,10 @@ struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { } Status ConsumeBatch(compute::KernelContext* ctx, const compute::ExecSpan& batch) { - const auto& current_state = - arrow::internal::checked_cast(*ctx->state()); const int num_args = batch.num_values(); - this->batch_length_ = batch.length; - this->udf_context_.batch_length = batch.length; Py_INCREF(this->udf_context_.state); this->udf_context_.state = this->owned_state_.obj(); // TODO: think about guaranteeing DRY (following logic already used in ScalarUDFs) - CheckUdfContext("check udf context @ConsumeBatch Start", this->udf_context_); OwnedRef arg_tuple(PyTuple_New(num_args)); RETURN_NOT_OK(CheckPyError()); for (int arg_id = 0; arg_id < num_args; arg_id++) { @@ -313,7 +236,6 @@ struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { this->owned_state_.reset(consume_res); Py_INCREF(this->udf_context_.state); this->udf_context_.state = this->owned_state_.obj(); - CheckUdfContext("check udf context @ConsumeBatch End", this->udf_context_); RETURN_NOT_OK(CheckPyError()); return Status::OK(); } @@ -327,30 +249,21 @@ struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { } Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) override { - std::cout << "MergeFrom Start" << std::endl; const auto& other_state = arrow::internal::checked_cast(src); return SafeCallIntoPython([&]() -> Status { - CheckUdfContext("\tcheck this->udf_context @MergeFrom", this->udf_context_); - CheckUdfContext("\tcheck other_state->udf_context @MergeFrom", other_state.udf_context_); - std::cout << "\tJust before callback exec" << std::endl; OwnedRef result(merge_cb(merge_function->obj(), this->udf_context_, other_state.owned_state_.obj())); RETURN_NOT_OK(CheckPyError()); - std::cout << "\t Exec callback finished" << std::endl; PyObject* merge_res = result.obj(); Py_INCREF(merge_res); this->owned_state_.reset(merge_res); - std::cout << "Results stored in owned_state" << std::endl; Py_INCREF(this->udf_context_.state); this->udf_context_.state = this->owned_state_.obj(); - std::cout << "Results stored in udf_context._state" << std::endl; - std::cout << "MergeFrom End" << std::endl; return Status::OK(); }); } Status Finalize(compute::KernelContext* ctx, arrow::Datum* out) override { - std::cout << "Finalize" << std::endl; return SafeCallIntoPython([&]() -> Status { OwnedRef result(finalize_cb(finalize_function->obj(), this->udf_context_)); RETURN_NOT_OK(CheckPyError()); @@ -372,8 +285,6 @@ struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { }; private: - int batch_length_ = 1; - // Think about how this is going to be standardized OwnedRefNoGIL owned_state_; ScalarAggregateUdfContext udf_context_; }; @@ -395,7 +306,6 @@ Status RegisterScalarAggregateFunction(PyObject* consume_function, PyObject* init_function, ScalarAggregateInitUdfWrapperCallback init_wrapper, const ScalarUdfOptions& options) { - std::cout << "RegisterScalarAggregateFunction" << std::endl; if (!PyCallable_Check(consume_function) || !PyCallable_Check(merge_function) || !PyCallable_Check(finalize_function)) { return Status::TypeError("Expected a callable Python object."); } diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index 004662ff4f7..064c8df8848 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -47,14 +47,10 @@ struct ARROW_PYTHON_EXPORT ScalarUdfContext { int64_t batch_length; }; - +/// \brief A context passed as the first argument of scalar aggregate UDF functions. struct ARROW_PYTHON_EXPORT ScalarAggregateUdfContext { MemoryPool* pool; - int64_t batch_length; - // TODO: do we need to standardize this - // Meaning: do we have to Create a PythonAggregateState object or something separately. PyObject* state; - ~ScalarAggregateUdfContext(); }; diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 440fe003eb8..d6f8683a1dc 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -554,7 +554,7 @@ def merge(ctx, other_state): def finalize(ctx): return pa.array([ctx.state.non_null]) - func_name = "simple_count_1" + func_name = "simple_count" unary_doc = {"summary": "count function", "description": "test agg count function"} @@ -567,10 +567,8 @@ def finalize(ctx): {"array": pa.int64()}, pa.int64()) - print(pc.get_function(func_name)) - - print(pc.call_function(func_name, [ - pa.array([10, 20, None, 30, None, 40])])) + assert pc.call_function(func_name, [pa.array( + [10, 20, None, 30, None, 40])]) == pa.array([4]) def test_aggregate_udf_with_custom_state_multi_attr(): @@ -642,7 +640,7 @@ def finalize(ctx): print(ctx.state) return pa.array([ctx.state.non_null, ctx.state.null]) - func_name = "basic_count" + func_name = "advance_count" unary_doc = {"summary": "count function for null and non-null", "description": "test agg count function"} @@ -655,7 +653,5 @@ def finalize(ctx): {"array": pa.int64()}, pa.int64()) - print(pc.get_function(func_name)) - - print(pc.call_function(func_name, [ - pa.array([10, 20, None, 30, None, 40])])) + assert pc.call_function(func_name, [ + pa.array([10, 20, None, 30, None, 40])]) == pa.array([4, 2]) From 838adf4c39b682df8ba2d926a3ec1bc7e261b870 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 2 Nov 2022 14:15:36 +0530 Subject: [PATCH 10/13] fix(cleanup) --- python/pyarrow/src/arrow/python/udf.cc | 41 +++++++++++--------------- 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 676869e90f4..3b9d8068f49 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -16,8 +16,8 @@ // under the License. #include "arrow/python/udf.h" -#include "arrow/compute/function.h" #include "arrow/compute/api_aggregate.h" +#include "arrow/compute/function.h" #include "arrow/python/common.h" // TODO REMOVE @@ -32,6 +32,20 @@ namespace py { namespace { +void SetUpPythonArgs(int num_args, const ExecSpan& batch, OwnedRef& arg_tuple) { + for (int arg_id = 0; arg_id < num_args; arg_id++) { + if (batch[arg_id].is_scalar()) { + std::shared_ptr c_data = batch[arg_id].scalar->GetSharedPtr(); + PyObject* data = wrap_scalar(c_data); + PyTuple_SetItem(arg_tuple.obj(), arg_id, data); + } else { + std::shared_ptr c_data = batch[arg_id].array.ToArray(); + PyObject* data = wrap_array(c_data); + PyTuple_SetItem(arg_tuple.obj(), arg_id, data); + } + } +} + struct PythonUdf : public compute::KernelState { ScalarUdfWrapperCallback cb; std::shared_ptr function; @@ -55,18 +69,7 @@ struct PythonUdf : public compute::KernelState { OwnedRef arg_tuple(PyTuple_New(num_args)); RETURN_NOT_OK(CheckPyError()); - for (int arg_id = 0; arg_id < num_args; arg_id++) { - if (batch[arg_id].is_scalar()) { - std::shared_ptr c_data = batch[arg_id].scalar->GetSharedPtr(); - PyObject* data = wrap_scalar(c_data); - PyTuple_SetItem(arg_tuple.obj(), arg_id, data); - } else { - std::shared_ptr c_data = batch[arg_id].array.ToArray(); - PyObject* data = wrap_array(c_data); - PyTuple_SetItem(arg_tuple.obj(), arg_id, data); - } - } - + SetUpPythonArgs(num_args, batch, arg_tuple); OwnedRef result(cb(function->obj(), udf_context, arg_tuple.obj())); RETURN_NOT_OK(CheckPyError()); // unwrapping the output for expected output type @@ -218,17 +221,7 @@ struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { // TODO: think about guaranteeing DRY (following logic already used in ScalarUDFs) OwnedRef arg_tuple(PyTuple_New(num_args)); RETURN_NOT_OK(CheckPyError()); - for (int arg_id = 0; arg_id < num_args; arg_id++) { - if (batch[arg_id].is_scalar()) { - std::shared_ptr c_data = batch[arg_id].scalar->GetSharedPtr(); - PyObject* data = wrap_scalar(c_data); - PyTuple_SetItem(arg_tuple.obj(), arg_id, data); - } else { - std::shared_ptr c_data = batch[arg_id].array.ToArray(); - PyObject* data = wrap_array(c_data); - PyTuple_SetItem(arg_tuple.obj(), arg_id, data); - } - } + SetUpPythonArgs(num_args, batch, arg_tuple); OwnedRef result(consume_cb(consume_function->obj(), this->udf_context_, arg_tuple.obj())); RETURN_NOT_OK(CheckPyError()); PyObject* consume_res = result.obj(); From 96f8ba1d1027461ee03717b594e1229dd9364297 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 2 Nov 2022 14:20:46 +0530 Subject: [PATCH 11/13] fix(cleanup-python) --- python/pyarrow/tests/test_udf.py | 42 ++++++++------------------------ 1 file changed, 10 insertions(+), 32 deletions(-) diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index d6f8683a1dc..6071229d2ec 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -508,9 +508,8 @@ def test_input_lifetime(unary_func_fixture): def test_aggregate_udf_with_custom_state(): class State: - def __init__(self, non_null=0, msg=""): + def __init__(self, non_null=0): self._non_null = non_null - self._msg = msg @property def non_null(self): @@ -520,22 +519,14 @@ def non_null(self): def non_null(self, value): self._non_null = value - @property - def msg(self): - return self._msg - def __repr__(self): if self._non_null is None: return "no values stored" else: - return "count: " + str(self.non_null) \ - + ", msg: " + str(self.msg) - - def __del__(self): - print("State.__del__, msg: " + str(self.msg)) + return "count: " + str(self.non_null) def init(): - state = State(0, "@init") + state = State(0) return state def consume(ctx, x): @@ -545,11 +536,11 @@ def consume(ctx, x): if x.as_py(): non_null = 1 non_null = non_null + ctx.state.non_null - return State(non_null, "@consume") + return State(non_null) def merge(ctx, other_state): merged_state_val = ctx.state.non_null + other_state.non_null - return State(merged_state_val, "@merge") + return State(merged_state_val) def finalize(ctx): return pa.array([ctx.state.non_null]) @@ -573,10 +564,9 @@ def finalize(ctx): def test_aggregate_udf_with_custom_state_multi_attr(): class State: - def __init__(self, non_null=0, null=0, msg=""): + def __init__(self, non_null=0, null=0): self._non_null = non_null self._null = null - self._msg = msg @property def non_null(self): @@ -594,28 +584,18 @@ def null(self): def null(self, value): self._null = value - @property - def msg(self): - return self._msg - def __repr__(self): if self._non_null is None: return "no values stored" else: return "non_null: " + str(self.non_null) \ - + ", null: " + str(self.null) \ - + ", msg: " + str(self.msg) - - def __del__(self): - print("State.__del__, msg: " + str(self.msg)) + + ", null: " + str(self.null) def init(): - print(">>> init") - state = State(0, 0, "@init") + state = State(0, 0) return state def consume(ctx, x): - print(">>> consume") null = 0 non_null = 0 if isinstance(x, pa.Array): @@ -627,16 +607,14 @@ def consume(ctx, x): else: null = 1 non_null = non_null + ctx.state.non_null - return State(non_null, null, "@consume") + return State(non_null, null) def merge(ctx, other_state): - print(">>> merge") merged_st_non_null = ctx.state.non_null + other_state.non_null merged_st_null = ctx.state.null + other_state.null - return State(merged_st_non_null, merged_st_null, "@merge") + return State(merged_st_non_null, merged_st_null) def finalize(ctx): - print(">>> finalize") print(ctx.state) return pa.array([ctx.state.non_null, ctx.state.null]) From 8207f995fabe97a7ca2c6dbc187524accdfc64bd Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 2 Nov 2022 14:24:07 +0530 Subject: [PATCH 12/13] fix(minor-style) --- python/pyarrow/src/arrow/python/udf.h | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index 064c8df8848..140ac642de9 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -64,6 +64,7 @@ Status ARROW_PYTHON_EXPORT RegisterScalarFunction(PyObject* user_function, const ScalarUdfOptions& options); using ScalarAggregateInitUdfWrapperCallback = std::function; + using ScalarAggregateConsumeUdfWrapperCallback = std::function; From a6679cfa49d2ffba9728fa91adf4ab5730c228a1 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 2 Nov 2022 15:00:05 +0530 Subject: [PATCH 13/13] feat(updated-docs) --- python/pyarrow/_compute.pyx | 116 +++++++++++++++++++++++++ python/pyarrow/src/arrow/python/udf.cc | 15 ++-- 2 files changed, 123 insertions(+), 8 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 8c9c13d7392..7cf4bf42561 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2726,6 +2726,122 @@ def register_scalar_function(func, function_name, function_doc, in_types, def register_scalar_aggregate_function(init_func, consume_func, merge_func, finalize_func, function_name, function_doc, in_types, out_type): + """ + Register a user-defined scalar aggregate function. + + A scalar aggregate function is a set of 4 functions which formulates + the operation pieces of an scalar aggregation. The base behavior in + terms of computation is very much similar to scalar functions. + + Parameters + ---------- + init_func : callable + A callable implementing the user-defined initialization function. + This function is used to set the state for the aggregate operation + and returns the state object. + consume_func : callable + A callable implementing the user-defined consume function. + The first argument is the context argument of type + ScalarAggregateUdfContext. + Then, it must take arguments equal to the number of + in_types defined. + To define a varargs function, pass a callable that takes + varargs. The last in_type will be the type of all varargs + arguments. + + This function returns the updated state after consuming the + received data. + merge_func: callable + A callable implementing the user-defined merge function. + The first argument is the context argument of type + ScalarAggregateUdfContext. + Then, the second argument it takes is an state object. + This object holds the state with which the current state + must be merged with. The current state can be retrieved from + the context object which can be acessed by `context.state`. + The state doesn't need to be set in the Python side and it is + autonomously handled in the C++ backend. The updated state must + be returned from this function. + finalize_func: callable + A callable implementing the user-defined finalize function. + The first argument is the context argument of type + ScalarUdfContext. + Using the context argument the state can be extracted and return + type must be an array matching the `out_type`. + function_name : str + Name of the function. This name must be globally unique. + function_doc : dict + A dictionary object with keys "summary" (str), + and "description" (str). + in_types : Dict[str, DataType] + A dictionary mapping function argument names to + their respective DataType. + The argument names will be used to generate + documentation for the function. The number of + arguments specified here determines the function + arity. + out_type : DataType + Output type of the function. + + Examples + -------- + >>> class State: + ... def __init__(self, non_null=0): + ... self._non_null = non_null + ... + ... @property + ... def non_null(self): + ... return self._non_null + ... + ... @non_null.setter + ... def non_null(self, value): + ... self._non_null = value + ... + ... def __repr__(self): + ... if self._non_null is None: + ... return "no values stored" + ... else: + ... return "count: " + str(self.non_null) + + >>> def init(): + ... state = State(0) + ... return state + + >>> def consume(ctx, x): + ... if isinstance(x, pa.Array): + ... non_null = pc.sum(pc.invert(pc.is_nan(x))).as_py() + ... elif isinstance(x, pa.Scalar): + ... if x.as_py(): + ... non_null = 1 + ... non_null = non_null + ctx.state.non_null + ... return State(non_null) + + >>> def merge(ctx, other_state): + ... merged_state_val = ctx.state.non_null + other_state.non_null + ... return State(merged_state_val) + + >>> def finalize(ctx): + ... return pa.array([ctx.state.non_null]) + + >>> func_doc = {} + >>> func_doc["summary"] = "simple aggregate udf" + >>> func_doc["description"] = "simple count operation" + + >>> pc.register_scalar_aggregate_function(init, + ... consume, + ... merge, + ... finalize, + ... func_name, + ... unary_doc, + ... {"array": pa.int64()}, + ... pa.int64()) + + >>> pc.call_function(func_name, [pa.array([10, 20, None, 30, None, 40])]) + + [ + 4 + ] + """ cdef: c_string c_func_name diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 3b9d8068f49..04615639f0b 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -25,6 +25,7 @@ namespace arrow { +using internal::checked_cast; using compute::ExecResult; using compute::ExecSpan; @@ -136,19 +137,16 @@ struct ScalarUdfAggregator : public compute::KernelState { }; arrow::Status AggregateUdfConsume(compute::KernelContext* ctx, const compute::ExecSpan& batch) { - return arrow::internal::checked_cast(ctx->state()) - ->Consume(ctx, batch); + return checked_cast(ctx->state())->Consume(ctx, batch); } arrow::Status AggregateUdfMerge(compute::KernelContext* ctx, compute::KernelState&& src, compute::KernelState* dst) { - return arrow::internal::checked_cast(dst)->MergeFrom( - ctx, std::move(src)); + return checked_cast(dst)->MergeFrom(ctx, std::move(src)); } arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* out) { - return arrow::internal::checked_cast(ctx->state()) - ->Finalize(ctx, out); + return checked_cast(ctx->state())->Finalize(ctx, out); } ScalarAggregateUdfContext::~ScalarAggregateUdfContext() { @@ -242,7 +240,7 @@ struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { } Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) override { - const auto& other_state = arrow::internal::checked_cast(src); + const auto& other_state = checked_cast(src); return SafeCallIntoPython([&]() -> Status { OwnedRef result(merge_cb(merge_function->obj(), this->udf_context_, other_state.owned_state_.obj())); @@ -257,8 +255,9 @@ struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { } Status Finalize(compute::KernelContext* ctx, arrow::Datum* out) override { + const auto& state = checked_cast(*ctx->state()); return SafeCallIntoPython([&]() -> Status { - OwnedRef result(finalize_cb(finalize_function->obj(), this->udf_context_)); + OwnedRef result(finalize_cb(finalize_function->obj(), state.udf_context_)); RETURN_NOT_OK(CheckPyError()); // unwrapping the output for expected output type if (is_array(result.obj())) {