diff --git a/cpp/examples/arrow/udf_example.cc b/cpp/examples/arrow/udf_example.cc index 573b5ccc78a..f1d47610364 100644 --- a/cpp/examples/arrow/udf_example.cc +++ b/cpp/examples/arrow/udf_example.cc @@ -83,15 +83,154 @@ arrow::Status Execute() { ARROW_ASSIGN_OR_RAISE(auto res, cp::CallFunction(name, {x, y, z})); auto res_array = res.make_array(); - std::cout << "Result" << std::endl; + std::cout << "Scalar UDF Result" << std::endl; std::cout << res_array->ToString() << std::endl; return arrow::Status::OK(); } +// User-defined Scalar Aggregate Function Example +struct ScalarUdfAggregator : public cp::KernelState { + virtual arrow::Status Consume(cp::KernelContext* ctx, const cp::ExecSpan& batch) = 0; + virtual arrow::Status MergeFrom(cp::KernelContext* ctx, cp::KernelState&& src) = 0; + virtual arrow::Status Finalize(cp::KernelContext* ctx, arrow::Datum* out) = 0; +}; + +class SimpleCountFunctionOptionsType : public cp::FunctionOptionsType { + const char* type_name() const override { return "SimpleCountFunctionOptionsType"; } + std::string Stringify(const cp::FunctionOptions&) const override { + return "SimpleCountFunctionOptionsType"; + } + bool Compare(const cp::FunctionOptions&, const cp::FunctionOptions&) const override { + return true; + } + std::unique_ptr Copy(const cp::FunctionOptions&) const override; +}; + +cp::FunctionOptionsType* GetSimpleCountFunctionOptionsType() { + static SimpleCountFunctionOptionsType options_type; + return &options_type; +} + +class SimpleCountOptions : public cp::FunctionOptions { + public: + SimpleCountOptions() : cp::FunctionOptions(GetSimpleCountFunctionOptionsType()) {} + static constexpr char const kTypeName[] = "SimpleCountOptions"; + static SimpleCountOptions Defaults() { return SimpleCountOptions{}; } +}; + +std::unique_ptr SimpleCountFunctionOptionsType::Copy( + const cp::FunctionOptions&) const { + return std::make_unique(); +} + +const cp::FunctionDoc simple_count_doc{ + "SimpleCount the number of null / non-null values", + ("By default, only non-null values are counted.\n" + "This can be changed through SimpleCountOptions."), + {"array"}, + "SimpleCountOptions"}; + +// Need Python interface for this Class +struct SimpleCountImpl : public ScalarUdfAggregator { + explicit SimpleCountImpl(SimpleCountOptions options) : options(std::move(options)) {} + + arrow::Status Consume(cp::KernelContext*, const cp::ExecSpan& batch) override { + if (batch[0].is_array()) { + const arrow::ArraySpan& input = batch[0].array; + const int64_t nulls = input.GetNullCount(); + this->non_nulls += input.length - nulls; + } else { + const arrow::Scalar& input = *batch[0].scalar; + this->non_nulls += input.is_valid * batch.length; + } + return arrow::Status::OK(); + } + + arrow::Status MergeFrom(cp::KernelContext*, cp::KernelState&& src) override { + const auto& other_state = arrow::internal::checked_cast(src); + std::cout << "This non_nulls: " << this->non_nulls << std::endl; + this->non_nulls += other_state.non_nulls; + return arrow::Status::OK(); + } + + arrow::Status Finalize(cp::KernelContext* ctx, arrow::Datum* out) override { + const auto& state = + arrow::internal::checked_cast(*ctx->state()); + *out = arrow::Datum(state.non_nulls); + return arrow::Status::OK(); + } + + SimpleCountOptions options; + int64_t non_nulls = 0; +}; + +// TODO: need a Python interface for this function +arrow::Result> SimpleCountInit( + cp::KernelContext*, const cp::KernelInitArgs& args) { + return std::make_unique( + static_cast(*args.options)); +} + +arrow::Status AggregateUdfConsume(cp::KernelContext* ctx, const cp::ExecSpan& batch) { + return arrow::internal::checked_cast(ctx->state()) + ->Consume(ctx, batch); +} + +arrow::Status AggregateUdfMerge(cp::KernelContext* ctx, cp::KernelState&& src, + cp::KernelState* dst) { + return arrow::internal::checked_cast(dst)->MergeFrom( + ctx, std::move(src)); +} + +arrow::Status AggregateUdfFinalize(cp::KernelContext* ctx, arrow::Datum* out) { + return arrow::internal::checked_cast(ctx->state()) + ->Finalize(ctx, out); +} + +arrow::Status AddAggKernel(std::shared_ptr sig, cp::KernelInit init, + cp::ScalarAggregateFunction* func) { + cp::ScalarAggregateKernel kernel(std::move(sig), std::move(init), AggregateUdfConsume, + AggregateUdfMerge, AggregateUdfFinalize); + ARROW_RETURN_NOT_OK(func->AddKernel(std::move(kernel))); + return arrow::Status::OK(); +} + +arrow::Status ExecuteAggregate() { + auto registry = cp::GetFunctionRegistry(); + static auto default_scalar_aggregate_options = cp::ScalarAggregateOptions::Defaults(); + static auto default_count_options = SimpleCountOptions::Defaults(); + const std::string name = "simple_count"; + auto func = std::make_shared( + name, cp::Arity::Unary(), simple_count_doc, &default_count_options); + + // Takes any input, outputs int64 scalar + ARROW_RETURN_NOT_OK( + AddAggKernel(cp::KernelSignature::Make({arrow::int64()}, arrow::int64()), + SimpleCountInit, func.get())); + ARROW_RETURN_NOT_OK(registry->AddFunction(std::move(func))); + + ARROW_ASSIGN_OR_RAISE(auto x, GetArrayDataSample({1, 2, 3, 4, 5, 6})); + + ARROW_ASSIGN_OR_RAISE(auto res, cp::CallFunction(name, {x})); + auto res_scalar = res.scalar(); + std::cout << "Aggregate UDF Result" << std::endl; + std::cout << res_scalar->ToString() << std::endl; + + return arrow::Status::OK(); +} + int main(int argc, char** argv) { - auto status = Execute(); - if (!status.ok()) { - std::cerr << "Error occurred : " << status.message() << std::endl; + std::cout << "Sample Scalar UDF Execution" << std::endl; + auto s1 = Execute(); + if (!s1.ok()) { + std::cerr << "Error occurred : " << s1.message() << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Sample Aggregate UDF Execution" << std::endl; + auto s2 = ExecuteAggregate(); + if (!s2.ok()) { + std::cerr << "Error occurred : " << s2.message() << std::endl; return EXIT_FAILURE; } return EXIT_SUCCESS; diff --git a/python/pyarrow/_compute.pxd b/python/pyarrow/_compute.pxd index 8b09cbd445e..b34feb5a3a9 100644 --- a/python/pyarrow/_compute.pxd +++ b/python/pyarrow/_compute.pxd @@ -27,6 +27,12 @@ cdef class ScalarUdfContext(_Weakrefable): cdef void init(self, const CScalarUdfContext& c_context) +cdef class ScalarAggregateUdfContext(_Weakrefable): + cdef: + CScalarAggregateUdfContext c_context + + cdef void init(self, const CScalarAggregateUdfContext& c_context) + cdef class FunctionOptions(_Weakrefable): cdef: shared_ptr[CFunctionOptions] wrapped diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 43f40a86c77..7cf4bf42561 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2474,6 +2474,47 @@ cdef class ScalarUdfContext: return box_memory_pool(self.c_context.pool) +cdef class ScalarAggregateUdfContext: + """ + Per-invocation function context/state. + + This object will always be the first argument to a user-defined + function. It should not be used outside of a call to the function. + """ + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly" + .format(self.__class__.__name__)) + + cdef void init(self, const CScalarAggregateUdfContext &c_context): + self.c_context = c_context + + @property + def memory_pool(self): + """ + A memory pool for allocations (:class:`MemoryPool`). + + This is the memory pool supplied by the user when they invoked + the function and it should be used in any calls to arrow that the + UDF makes if that call accepts a memory_pool. + """ + return box_memory_pool(self.c_context.pool) + + @property + def state(self): + """ + An object which maintains the state of aggregation + """ + obj = self.c_context.state + if obj is None: + raise RuntimeError("Error occurred in extracting state") + return obj + + @state.setter + def state(self, value): + self.c_context.state = value + + cdef inline CFunctionDoc _make_function_doc(dict func_doc) except *: """ Helper function to generate the FunctionDoc @@ -2502,6 +2543,12 @@ cdef object box_scalar_udf_context(const CScalarUdfContext& c_context): return context +cdef object box_scalar_udf_agg_context(const CScalarAggregateUdfContext& c_context): + cdef ScalarAggregateUdfContext context = ScalarAggregateUdfContext.__new__(ScalarAggregateUdfContext) + context.init(c_context) + return context + + cdef _scalar_udf_callback(user_function, const CScalarUdfContext& c_context, inputs): """ Helper callback function used to wrap the ScalarUdfContext from Python to C++ @@ -2511,6 +2558,40 @@ cdef _scalar_udf_callback(user_function, const CScalarUdfContext& c_context, inp return user_function(context, *inputs) +cdef _scalar_agg_consume_udf_callback(consume_function, const CScalarAggregateUdfContext& c_context, inputs): + """ + Helper aggregate consume callback function used to wrap the ScalarAggregateUdfContext from Python to C++ + execution. + """ + context = box_scalar_udf_agg_context(c_context) + return consume_function(context, *inputs) + + +cdef _scalar_agg_merge_udf_callback(merge_function, const CScalarAggregateUdfContext& c_context, other_state): + """ + Helper aggregate merge callback function used to wrap the ScalarAggregateUdfContext from Python to C++ + execution. + """ + context = box_scalar_udf_agg_context(c_context) + return merge_function(context, other_state) + + +cdef _scalar_agg_finalize_udf_callback(finalize_function, const CScalarAggregateUdfContext& c_context): + """ + Helper aggregate finalize callback function used to wrap the ScalarAggregateUdfContext from Python to C++ + execution. + """ + context = box_scalar_udf_agg_context(c_context) + return finalize_function(context) + +cdef _scalar_agg_init_udf_callback(init_function): + """ + Helper aggregate initialize callback function used to wrap the ScalarAggregateUdfContext from Python to C++ + execution. + """ + return init_function() + + def _get_scalar_udf_context(memory_pool, batch_length): cdef CScalarUdfContext c_context c_context.pool = maybe_unbox_memory_pool(memory_pool) @@ -2641,3 +2722,200 @@ def register_scalar_function(func, function_name, function_doc, in_types, check_status(RegisterScalarFunction(c_function, &_scalar_udf_callback, c_options)) + + +def register_scalar_aggregate_function(init_func, consume_func, merge_func, finalize_func, + function_name, function_doc, in_types, out_type): + """ + Register a user-defined scalar aggregate function. + + A scalar aggregate function is a set of 4 functions which formulates + the operation pieces of an scalar aggregation. The base behavior in + terms of computation is very much similar to scalar functions. + + Parameters + ---------- + init_func : callable + A callable implementing the user-defined initialization function. + This function is used to set the state for the aggregate operation + and returns the state object. + consume_func : callable + A callable implementing the user-defined consume function. + The first argument is the context argument of type + ScalarAggregateUdfContext. + Then, it must take arguments equal to the number of + in_types defined. + To define a varargs function, pass a callable that takes + varargs. The last in_type will be the type of all varargs + arguments. + + This function returns the updated state after consuming the + received data. + merge_func: callable + A callable implementing the user-defined merge function. + The first argument is the context argument of type + ScalarAggregateUdfContext. + Then, the second argument it takes is an state object. + This object holds the state with which the current state + must be merged with. The current state can be retrieved from + the context object which can be acessed by `context.state`. + The state doesn't need to be set in the Python side and it is + autonomously handled in the C++ backend. The updated state must + be returned from this function. + finalize_func: callable + A callable implementing the user-defined finalize function. + The first argument is the context argument of type + ScalarUdfContext. + Using the context argument the state can be extracted and return + type must be an array matching the `out_type`. + function_name : str + Name of the function. This name must be globally unique. + function_doc : dict + A dictionary object with keys "summary" (str), + and "description" (str). + in_types : Dict[str, DataType] + A dictionary mapping function argument names to + their respective DataType. + The argument names will be used to generate + documentation for the function. The number of + arguments specified here determines the function + arity. + out_type : DataType + Output type of the function. + + Examples + -------- + >>> class State: + ... def __init__(self, non_null=0): + ... self._non_null = non_null + ... + ... @property + ... def non_null(self): + ... return self._non_null + ... + ... @non_null.setter + ... def non_null(self, value): + ... self._non_null = value + ... + ... def __repr__(self): + ... if self._non_null is None: + ... return "no values stored" + ... else: + ... return "count: " + str(self.non_null) + + >>> def init(): + ... state = State(0) + ... return state + + >>> def consume(ctx, x): + ... if isinstance(x, pa.Array): + ... non_null = pc.sum(pc.invert(pc.is_nan(x))).as_py() + ... elif isinstance(x, pa.Scalar): + ... if x.as_py(): + ... non_null = 1 + ... non_null = non_null + ctx.state.non_null + ... return State(non_null) + + >>> def merge(ctx, other_state): + ... merged_state_val = ctx.state.non_null + other_state.non_null + ... return State(merged_state_val) + + >>> def finalize(ctx): + ... return pa.array([ctx.state.non_null]) + + >>> func_doc = {} + >>> func_doc["summary"] = "simple aggregate udf" + >>> func_doc["description"] = "simple count operation" + + >>> pc.register_scalar_aggregate_function(init, + ... consume, + ... merge, + ... finalize, + ... func_name, + ... unary_doc, + ... {"array": pa.int64()}, + ... pa.int64()) + + >>> pc.call_function(func_name, [pa.array([10, 20, None, 30, None, 40])]) + + [ + 4 + ] + """ + + cdef: + c_string c_func_name + CArity c_arity + CFunctionDoc c_func_doc + vector[shared_ptr[CDataType]] c_in_types + PyObject* c_init_function + PyObject* c_consume_function + PyObject* c_merge_function + PyObject* c_finalize_function + shared_ptr[CDataType] c_out_type + CScalarUdfOptions c_options + + if callable(init_func): + c_init_function = init_func + else: + raise TypeError("init_func must be a callable") + + if callable(consume_func): + c_consume_function = consume_func + else: + raise TypeError("consume_func must be a callable") + + if callable(merge_func): + c_merge_function = merge_func + else: + raise TypeError("consume_func must be a callable") + + if callable(finalize_func): + c_finalize_function = finalize_func + else: + raise TypeError("consume_func must be a callable") + + c_func_name = tobytes(function_name) + + func_spec = inspect.getfullargspec(consume_func) + num_args = -1 + if isinstance(in_types, dict): + for in_type in in_types.values(): + c_in_types.push_back( + pyarrow_unwrap_data_type(ensure_type(in_type))) + function_doc["arg_names"] = in_types.keys() + num_args = len(in_types) + else: + raise TypeError( + "in_types must be a dictionary of DataType") + + c_arity = CArity( num_args, func_spec.varargs) + + if "summary" not in function_doc: + raise ValueError("Function doc must contain a summary") + + if "description" not in function_doc: + raise ValueError("Function doc must contain a description") + + if "arg_names" not in function_doc: + raise ValueError("Function doc must contain arg_names") + + c_func_doc = _make_function_doc(function_doc) + + c_out_type = pyarrow_unwrap_data_type(ensure_type(out_type)) + + c_options.func_name = c_func_name + c_options.arity = c_arity + c_options.func_doc = c_func_doc + c_options.input_types = c_in_types + c_options.output_type = c_out_type + + check_status(RegisterScalarAggregateFunction(c_consume_function, + &_scalar_agg_consume_udf_callback, + c_merge_function, + &_scalar_agg_merge_udf_callback, + c_finalize_function, + &_scalar_agg_finalize_udf_callback, + c_init_function, + &_scalar_agg_init_udf_callback, + c_options)) diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 5873571c5a0..d72a28df84e 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -80,7 +80,9 @@ _group_by, # Udf register_scalar_function, + register_scalar_aggregate_function, ScalarUdfContext, + ScalarAggregateUdfContext, # Expressions Expression, ) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index fbedb0fce36..192f166c8c4 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2761,11 +2761,23 @@ cdef extern from "arrow/util/byte_size.h" namespace "arrow::util" nogil: ctypedef PyObject* CallbackUdf(object user_function, const CScalarUdfContext& context, object inputs) +ctypedef PyObject* CallbackAggInitUdf(object init_function) + +ctypedef PyObject* CallbackAggConsumeUdf(object consume_function, const CScalarAggregateUdfContext& context, object inputs) + +ctypedef PyObject* CallbackAggMergeUdf(object merge_function, const CScalarAggregateUdfContext& context, object other_state) + +ctypedef PyObject* CallbackAggFinalizeUdf(object finalize_function, const CScalarAggregateUdfContext& context) + cdef extern from "arrow/python/udf.h" namespace "arrow::py": cdef cppclass CScalarUdfContext" arrow::py::ScalarUdfContext": CMemoryPool *pool int64_t batch_length + cdef cppclass CScalarAggregateUdfContext" arrow::py::ScalarAggregateUdfContext": + CMemoryPool *pool + PyObject* state + cdef cppclass CScalarUdfOptions" arrow::py::ScalarUdfOptions": c_string func_name CArity arity @@ -2775,3 +2787,13 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py": CStatus RegisterScalarFunction(PyObject* function, function[CallbackUdf] wrapper, const CScalarUdfOptions& options) + + CStatus RegisterScalarAggregateFunction(PyObject* consume_function, + function[CallbackAggConsumeUdf] consume_wrapper, + PyObject* merge_function, + function[CallbackAggMergeUdf] merge_wrapper, + PyObject* finalize_function, + function[CallbackAggFinalizeUdf] finalize_wrapper, + PyObject* init_function, + function[CallbackAggInitUdf] init_wrapper, + const CScalarUdfOptions& options) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 81bf47c0ade..04615639f0b 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -16,11 +16,16 @@ // under the License. #include "arrow/python/udf.h" +#include "arrow/compute/api_aggregate.h" #include "arrow/compute/function.h" #include "arrow/python/common.h" +// TODO REMOVE +#include + namespace arrow { +using internal::checked_cast; using compute::ExecResult; using compute::ExecSpan; @@ -28,6 +33,20 @@ namespace py { namespace { +void SetUpPythonArgs(int num_args, const ExecSpan& batch, OwnedRef& arg_tuple) { + for (int arg_id = 0; arg_id < num_args; arg_id++) { + if (batch[arg_id].is_scalar()) { + std::shared_ptr c_data = batch[arg_id].scalar->GetSharedPtr(); + PyObject* data = wrap_scalar(c_data); + PyTuple_SetItem(arg_tuple.obj(), arg_id, data); + } else { + std::shared_ptr c_data = batch[arg_id].array.ToArray(); + PyObject* data = wrap_array(c_data); + PyTuple_SetItem(arg_tuple.obj(), arg_id, data); + } + } +} + struct PythonUdf : public compute::KernelState { ScalarUdfWrapperCallback cb; std::shared_ptr function; @@ -51,18 +70,7 @@ struct PythonUdf : public compute::KernelState { OwnedRef arg_tuple(PyTuple_New(num_args)); RETURN_NOT_OK(CheckPyError()); - for (int arg_id = 0; arg_id < num_args; arg_id++) { - if (batch[arg_id].is_scalar()) { - std::shared_ptr c_data = batch[arg_id].scalar->GetSharedPtr(); - PyObject* data = wrap_scalar(c_data); - PyTuple_SetItem(arg_tuple.obj(), arg_id, data); - } else { - std::shared_ptr c_data = batch[arg_id].array.ToArray(); - PyObject* data = wrap_array(c_data); - PyTuple_SetItem(arg_tuple.obj(), arg_id, data); - } - } - + SetUpPythonArgs(num_args, batch, arg_tuple); OwnedRef result(cb(function->obj(), udf_context, arg_tuple.obj())); RETURN_NOT_OK(CheckPyError()); // unwrapping the output for expected output type @@ -120,6 +128,218 @@ Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback return Status::OK(); } +// Scalar Aggregate Functions + +struct ScalarUdfAggregator : public compute::KernelState { + virtual Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) = 0; + virtual Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) = 0; + virtual Status Finalize(compute::KernelContext* ctx, Datum* out) = 0; +}; + +arrow::Status AggregateUdfConsume(compute::KernelContext* ctx, const compute::ExecSpan& batch) { + return checked_cast(ctx->state())->Consume(ctx, batch); +} + +arrow::Status AggregateUdfMerge(compute::KernelContext* ctx, compute::KernelState&& src, + compute::KernelState* dst) { + return checked_cast(dst)->MergeFrom(ctx, std::move(src)); +} + +arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* out) { + return checked_cast(ctx->state())->Finalize(ctx, out); +} + +ScalarAggregateUdfContext::~ScalarAggregateUdfContext() { + if (_Py_IsFinalizing()) { + Py_DECREF(this->state); + } +} + +struct PythonScalarUdfAggregatorImpl : public ScalarUdfAggregator { + + ScalarAggregateInitUdfWrapperCallback init_cb; + ScalarAggregateConsumeUdfWrapperCallback consume_cb; + ScalarAggregateMergeUdfWrapperCallback merge_cb; + ScalarAggregateFinalizeUdfWrapperCallback finalize_cb; + std::shared_ptr init_function; + std::shared_ptr consume_function; + std::shared_ptr merge_function; + std::shared_ptr finalize_function; + std::shared_ptr output_type; + + + PythonScalarUdfAggregatorImpl(ScalarAggregateInitUdfWrapperCallback init_cb, + ScalarAggregateConsumeUdfWrapperCallback consume_cb, + ScalarAggregateMergeUdfWrapperCallback merge_cb, + ScalarAggregateFinalizeUdfWrapperCallback finalize_cb, + std::shared_ptr init_function, + std::shared_ptr consume_function, + std::shared_ptr merge_function, + std::shared_ptr finalize_function, + const std::shared_ptr& output_type) : init_cb(init_cb), + consume_cb(consume_cb), + merge_cb(merge_cb), + finalize_cb(finalize_cb), + init_function(init_function), + consume_function(consume_function), + merge_function(merge_function), + finalize_function(finalize_function), + output_type(output_type) { + Init(init_cb, init_function); + } + + ~PythonScalarUdfAggregatorImpl() { + if (_Py_IsFinalizing()) { + init_function->detach(); + consume_function->detach(); + merge_function->detach(); + finalize_function->detach(); + } + } + + void Init(ScalarAggregateInitUdfWrapperCallback& init_cb , std::shared_ptr& init_function) { + auto st = SafeCallIntoPython([&]() -> Status { + OwnedRef result(init_cb(init_function->obj())); + PyObject* init_res = result.obj(); + Py_INCREF(init_res); + this->udf_context_ = ScalarAggregateUdfContext{default_memory_pool(), std::move(init_res)}; + this->owned_state_.reset(result.obj()); + RETURN_NOT_OK(CheckPyError()); + return Status::OK(); + }); + if (!st.ok()) { + throw std::runtime_error(st.ToString()); + } + } + + Status ConsumeBatch(compute::KernelContext* ctx, const compute::ExecSpan& batch) { + const int num_args = batch.num_values(); + Py_INCREF(this->udf_context_.state); + this->udf_context_.state = this->owned_state_.obj(); + // TODO: think about guaranteeing DRY (following logic already used in ScalarUDFs) + OwnedRef arg_tuple(PyTuple_New(num_args)); + RETURN_NOT_OK(CheckPyError()); + SetUpPythonArgs(num_args, batch, arg_tuple); + OwnedRef result(consume_cb(consume_function->obj(), this->udf_context_, arg_tuple.obj())); + RETURN_NOT_OK(CheckPyError()); + PyObject* consume_res = result.obj(); + Py_INCREF(consume_res); + this->owned_state_.reset(consume_res); + Py_INCREF(this->udf_context_.state); + this->udf_context_.state = this->owned_state_.obj(); + RETURN_NOT_OK(CheckPyError()); + return Status::OK(); + } + + Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) override { + RETURN_NOT_OK(SafeCallIntoPython([&]() -> Status { + RETURN_NOT_OK(ConsumeBatch(ctx, batch)); + return Status::OK(); + })); + return Status::OK(); + } + + Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) override { + const auto& other_state = checked_cast(src); + return SafeCallIntoPython([&]() -> Status { + OwnedRef result(merge_cb(merge_function->obj(), + this->udf_context_, other_state.owned_state_.obj())); + RETURN_NOT_OK(CheckPyError()); + PyObject* merge_res = result.obj(); + Py_INCREF(merge_res); + this->owned_state_.reset(merge_res); + Py_INCREF(this->udf_context_.state); + this->udf_context_.state = this->owned_state_.obj(); + return Status::OK(); + }); + } + + Status Finalize(compute::KernelContext* ctx, arrow::Datum* out) override { + const auto& state = checked_cast(*ctx->state()); + return SafeCallIntoPython([&]() -> Status { + OwnedRef result(finalize_cb(finalize_function->obj(), state.udf_context_)); + RETURN_NOT_OK(CheckPyError()); + // unwrapping the output for expected output type + if (is_array(result.obj())) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_array(result.obj())); + if (!output_type->Equals(*val->type())) { + return Status::TypeError("Expected output datatype ", output_type->ToString(), + ", but function returned datatype ", + val->type()->ToString()); + } + *out = Datum(std::move(val)); + return Status::OK(); + } else { + return Status::TypeError("Unexpected output type: ", Py_TYPE(result.obj())->tp_name, + " (expected Array)"); + } + }); + }; + +private: + OwnedRefNoGIL owned_state_; + ScalarAggregateUdfContext udf_context_; +}; + +Status AddAggKernel(std::shared_ptr sig, compute::KernelInit init, + compute::ScalarAggregateFunction* func) { + compute::ScalarAggregateKernel kernel(std::move(sig), std::move(init), AggregateUdfConsume, + AggregateUdfMerge, AggregateUdfFinalize); + RETURN_NOT_OK(func->AddKernel(std::move(kernel))); + return Status::OK(); +} + +Status RegisterScalarAggregateFunction(PyObject* consume_function, + ScalarAggregateConsumeUdfWrapperCallback consume_wrapper, + PyObject* merge_function, + ScalarAggregateMergeUdfWrapperCallback merge_wrapper, + PyObject* finalize_function, + ScalarAggregateFinalizeUdfWrapperCallback finalize_wrapper, + PyObject* init_function, + ScalarAggregateInitUdfWrapperCallback init_wrapper, + const ScalarUdfOptions& options) { + if (!PyCallable_Check(consume_function) || !PyCallable_Check(merge_function) || !PyCallable_Check(finalize_function)) { + return Status::TypeError("Expected a callable Python object."); + } + static auto default_scalar_aggregate_options = compute::ScalarAggregateOptions::Defaults(); + auto aggregate_func = std::make_shared( + options.func_name, options.arity, options.func_doc, &default_scalar_aggregate_options); + + Py_INCREF(consume_function); + Py_INCREF(merge_function); + Py_INCREF(finalize_function); + + std::vector input_types; + for (const auto& in_dtype : options.input_types) { + input_types.emplace_back(in_dtype); + } + compute::OutputType output_type(options.output_type); + + auto init = [init_wrapper, consume_wrapper, merge_wrapper, finalize_wrapper, + init_function, consume_function, merge_function, finalize_function, options]( + compute::KernelContext* ctx, + const compute::KernelInitArgs& args) -> Result> { + return std::make_unique( + init_wrapper, + consume_wrapper, + merge_wrapper, + finalize_wrapper, + std::make_shared(init_function), + std::make_shared(consume_function), + std::make_shared(merge_function), + std::make_shared(finalize_function), + options.output_type); + }; + + RETURN_NOT_OK( + AddAggKernel(compute::KernelSignature::Make(input_types, output_type), + init, aggregate_func.get())); + + auto registry = compute::GetFunctionRegistry(); + RETURN_NOT_OK(registry->AddFunction(std::move(aggregate_func))); + return Status::OK(); +} + } // namespace py } // namespace arrow diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index 9a3666459fd..140ac642de9 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -47,6 +47,14 @@ struct ARROW_PYTHON_EXPORT ScalarUdfContext { int64_t batch_length; }; +/// \brief A context passed as the first argument of scalar aggregate UDF functions. +struct ARROW_PYTHON_EXPORT ScalarAggregateUdfContext { + MemoryPool* pool; + PyObject* state; + ~ScalarAggregateUdfContext(); +}; + + using ScalarUdfWrapperCallback = std::function; @@ -55,6 +63,29 @@ Status ARROW_PYTHON_EXPORT RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback wrapper, const ScalarUdfOptions& options); +using ScalarAggregateInitUdfWrapperCallback = std::function; + +using ScalarAggregateConsumeUdfWrapperCallback = std::function; + +using ScalarAggregateMergeUdfWrapperCallback = std::function; + +using ScalarAggregateFinalizeUdfWrapperCallback = std::function; + +/// \brief register a Scalar Aggregate user-defined-function from Python +Status ARROW_PYTHON_EXPORT RegisterScalarAggregateFunction(PyObject* consume_function, + ScalarAggregateConsumeUdfWrapperCallback consume_wrapper, + PyObject* merge_function, + ScalarAggregateMergeUdfWrapperCallback merge_wrapper, + PyObject* finalize_function, + ScalarAggregateFinalizeUdfWrapperCallback finalize_wrapper, + PyObject* init_function, + ScalarAggregateInitUdfWrapperCallback init_wrapper, + const ScalarUdfOptions& options); + + } // namespace py } // namespace arrow diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index e711619582d..6071229d2ec 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -504,3 +504,132 @@ def test_input_lifetime(unary_func_fixture): # Calling a UDF should not have kept `v` alive longer than required v = None assert proxy_pool.bytes_allocated() == 0 + + +def test_aggregate_udf_with_custom_state(): + class State: + def __init__(self, non_null=0): + self._non_null = non_null + + @property + def non_null(self): + return self._non_null + + @non_null.setter + def non_null(self, value): + self._non_null = value + + def __repr__(self): + if self._non_null is None: + return "no values stored" + else: + return "count: " + str(self.non_null) + + def init(): + state = State(0) + return state + + def consume(ctx, x): + if isinstance(x, pa.Array): + non_null = pc.sum(pc.invert(pc.is_nan(x))).as_py() + elif isinstance(x, pa.Scalar): + if x.as_py(): + non_null = 1 + non_null = non_null + ctx.state.non_null + return State(non_null) + + def merge(ctx, other_state): + merged_state_val = ctx.state.non_null + other_state.non_null + return State(merged_state_val) + + def finalize(ctx): + return pa.array([ctx.state.non_null]) + + func_name = "simple_count" + unary_doc = {"summary": "count function", + "description": "test agg count function"} + + pc.register_scalar_aggregate_function(init, + consume, + merge, + finalize, + func_name, + unary_doc, + {"array": pa.int64()}, + pa.int64()) + + assert pc.call_function(func_name, [pa.array( + [10, 20, None, 30, None, 40])]) == pa.array([4]) + + +def test_aggregate_udf_with_custom_state_multi_attr(): + class State: + def __init__(self, non_null=0, null=0): + self._non_null = non_null + self._null = null + + @property + def non_null(self): + return self._non_null + + @non_null.setter + def non_null(self, value): + self._non_null = value + + @property + def null(self): + return self._null + + @null.setter + def null(self, value): + self._null = value + + def __repr__(self): + if self._non_null is None: + return "no values stored" + else: + return "non_null: " + str(self.non_null) \ + + ", null: " + str(self.null) + + def init(): + state = State(0, 0) + return state + + def consume(ctx, x): + null = 0 + non_null = 0 + if isinstance(x, pa.Array): + non_null = pc.sum(pc.invert(pc.is_nan(x))).as_py() + null = len(x) - non_null + elif isinstance(x, pa.Scalar): + if x.as_py(): + non_null = 1 + else: + null = 1 + non_null = non_null + ctx.state.non_null + return State(non_null, null) + + def merge(ctx, other_state): + merged_st_non_null = ctx.state.non_null + other_state.non_null + merged_st_null = ctx.state.null + other_state.null + return State(merged_st_non_null, merged_st_null) + + def finalize(ctx): + print(ctx.state) + return pa.array([ctx.state.non_null, ctx.state.null]) + + func_name = "advance_count" + unary_doc = {"summary": "count function for null and non-null", + "description": "test agg count function"} + + pc.register_scalar_aggregate_function(init, + consume, + merge, + finalize, + func_name, + unary_doc, + {"array": pa.int64()}, + pa.int64()) + + assert pc.call_function(func_name, [ + pa.array([10, 20, None, 30, None, 40])]) == pa.array([4, 2])