From 0c1c1b2c6b4b8fdb8a6cbbc5195db986df102085 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 9 May 2023 11:30:33 -0400 Subject: [PATCH 01/19] Add initial implementation and test --- python/pyarrow/_compute.pyx | 84 +++++++++++++ python/pyarrow/compute.py | 1 + python/pyarrow/includes/libarrow.pxd | 4 + python/pyarrow/src/arrow/python/udf.cc | 165 +++++++++++++++++++++++++ python/pyarrow/src/arrow/python/udf.h | 11 +- python/pyarrow/tests/test_udf.py | 31 +++++ 6 files changed, 295 insertions(+), 1 deletion(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index a5db5be5514..2ecf489b495 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2665,6 +2665,12 @@ cdef get_register_tabular_function(): return reg +cdef get_register_aggregate_function(): + cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf) + reg.register_func = RegisterAggregateFunction + return reg + + def register_scalar_function(func, function_name, function_doc, in_types, out_type, func_registry=None): """ @@ -2743,6 +2749,84 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty out_type, func_registry) +def register_aggregate_function(func, function_name, function_doc, in_types, out_type, + func_registry=None): + """ + Register a user-defined scalar function. + + A scalar function is a function that executes elementwise + operations on arrays or scalars, i.e. a scalar function must + be computed row-by-row with no state where each output row + is computed only from its corresponding input row. + In other words, all argument arrays have the same length, + and the output array is of the same length as the arguments. + Scalar functions are the only functions allowed in query engine + expressions. + + Parameters + ---------- + func : callable + A callable implementing the user-defined function. + The first argument is the context argument of type + ScalarUdfContext. + Then, it must take arguments equal to the number of + in_types defined. It must return an Array or Scalar + matching the out_type. It must return a Scalar if + all arguments are scalar, else it must return an Array. + + To define a varargs function, pass a callable that takes + varargs. The last in_type will be the type of all varargs + arguments. + function_name : str + Name of the function. This name must be globally unique. + function_doc : dict + A dictionary object with keys "summary" (str), + and "description" (str). + in_types : Dict[str, DataType] + A dictionary mapping function argument names to + their respective DataType. + The argument names will be used to generate + documentation for the function. The number of + arguments specified here determines the function + arity. + out_type : DataType + Output type of the function. + func_registry : FunctionRegistry + Optional function registry to use instead of the default global one. + + Examples + -------- + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> + >>> func_doc = {} + >>> func_doc["summary"] = "simple udf" + >>> func_doc["description"] = "add a constant to a scalar" + >>> + >>> def add_constant(ctx, array): + ... return pc.add(array, 1, memory_pool=ctx.memory_pool) + >>> + >>> func_name = "py_add_func" + >>> in_types = {"array": pa.int64()} + >>> out_type = pa.int64() + >>> pc.register_scalar_function(add_constant, func_name, func_doc, + ... 
in_types, out_type) + >>> + >>> func = pc.get_function(func_name) + >>> func.name + 'py_add_func' + >>> answer = pc.call_function(func_name, [pa.array([20])]) + >>> answer + + [ + 21 + ] + """ + return _register_scalar_like_function(get_register_aggregate_function(), + func, function_name, function_doc, in_types, + out_type, func_registry) + + def register_tabular_function(func, function_name, function_doc, in_types, out_type, func_registry=None): """ diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index e299d44c04e..4d622932df3 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -84,6 +84,7 @@ call_tabular_function, register_scalar_function, register_tabular_function, + register_aggregate_function, ScalarUdfContext, # Expressions Expression, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 3190877ea09..a6fb2c4e9e2 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2805,5 +2805,9 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil: function[CallbackUdf] wrapper, const CUdfOptions& options, CFunctionRegistry* registry) + CStatus RegisterAggregateFunction(PyObject* function, + function[CallbackUdf] wrapper, const CUdfOptions& options, + CFunctionRegistry* registry) + CResult[shared_ptr[CRecordBatchReader]] CallTabularFunction( const c_string& func_name, const vector[CDatum]& args, CFunctionRegistry* registry) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 7d63adb8352..9bed9555c70 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -15,13 +15,18 @@ // specific language governing permissions and limitations // under the License. 
+#include <iostream>
+
 #include "arrow/python/udf.h"
+#include "arrow/table.h"
+#include "arrow/compute/api_aggregate.h"
 #include "arrow/compute/function.h"
 #include "arrow/compute/kernel.h"
 #include "arrow/python/common.h"
 #include "arrow/util/checked_cast.h"
 
 namespace arrow {
+
 using internal::checked_cast;
 namespace py {
 namespace {
 
@@ -65,6 +70,26 @@ struct PythonUdfKernelInit {
   std::shared_ptr<OwnedRefNoGIL> function;
 };
 
+struct ScalarUdfAggregator : public compute::KernelState {
+  virtual Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) = 0;
+  virtual Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) = 0;
+  virtual Status Finalize(compute::KernelContext* ctx, Datum* out) = 0;
+};
+
+arrow::Status AggregateUdfConsume(compute::KernelContext* ctx, const compute::ExecSpan& batch) {
+  return checked_cast<ScalarUdfAggregator*>(ctx->state())->Consume(ctx, batch);
+}
+
+arrow::Status AggregateUdfMerge(compute::KernelContext* ctx, compute::KernelState&& src,
+                                compute::KernelState* dst) {
+  return checked_cast<ScalarUdfAggregator*>(dst)->MergeFrom(ctx, std::move(src));
+}
+
+arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* out) {
+  auto udf = checked_cast<ScalarUdfAggregator*>(ctx->state());
+  return SafeCallIntoPython([&]() -> Status { return udf->Finalize(ctx, out); });
+}
+
 struct PythonTableUdfKernelInit {
   PythonTableUdfKernelInit(std::shared_ptr<OwnedRefNoGIL> function_maker,
                            UdfWrapperCallback cb)
@@ -101,6 +126,98 @@ struct PythonTableUdfKernelInit {
   UdfWrapperCallback cb;
 };
 
+struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
+  PythonUdfScalarAggregatorImpl(UdfWrapperCallback agg_cb,
+                                std::shared_ptr<OwnedRefNoGIL> agg_function,
+                                std::vector<std::shared_ptr<DataType>> input_types,
+                                std::shared_ptr<DataType> output_type):
+      agg_cb(agg_cb),
+      agg_function(agg_function),
+      output_type(output_type) {
+    std::vector<std::shared_ptr<Field>> fields;
+    for (size_t i = 0; i < input_types.size(); i++) {
+      fields.push_back(field("", input_types[i]));
+    }
+    input_schema = schema(fields);
+  };
+
+  ~PythonUdfScalarAggregatorImpl() {
+    if (_Py_IsFinalizing()) {
+      agg_function->detach();
+    }
+  }
+
+  Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) {
+    num_rows = batch.length;
+    ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool()));
+    values.push_back(rb);
+    return Status::OK();
+  }
+
+  Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) {
+    const auto& other_state = checked_cast<const PythonUdfScalarAggregatorImpl&>(src);
+    num_rows += other_state.num_rows;
+    values.insert(values.end(), other_state.values.begin(), other_state.values.end());
+    return Status::OK();
+  }
+
+  Status Finalize(compute::KernelContext* ctx, Datum* out) {
+    auto state = arrow::internal::checked_cast<PythonUdfScalarAggregatorImpl*>(ctx->state());
+    std::shared_ptr<OwnedRefNoGIL>& function = state->agg_function;
+    const int num_args = input_schema->num_fields();
+    // Ignore batch length here
+    ScalarUdfContext udf_context{ctx->memory_pool(), 0};
+
+
+    OwnedRef arg_tuple(PyTuple_New(num_args));
+    RETURN_NOT_OK(CheckPyError());
+
+    ARROW_ASSIGN_OR_RAISE(
+      auto table,
+      arrow::Table::FromRecordBatches(input_schema, values)
+    );
+    ARROW_ASSIGN_OR_RAISE(
+      table, table->CombineChunks(ctx->memory_pool())
+    );
+
+    for (int arg_id = 0; arg_id < num_args; arg_id++) {
+      // Since we combined chunks there is only one chunk
+      std::shared_ptr<Array> c_data = table->column(arg_id)->chunk(0);
+      PyObject* data = wrap_array(c_data);
+      PyTuple_SetItem(arg_tuple.obj(), arg_id, data);
+    }
+
+    OwnedRef result(agg_cb(function->obj(), udf_context, arg_tuple.obj()));
+    RETURN_NOT_OK(CheckPyError());
+
+    // unwrapping the output
for expected output type + if (is_array(result.obj())) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_array(result.obj())); + std::cout << val->type()->ToString() << std::endl; + if (*output_type != *val->type()) { + return Status::TypeError("Expected output datatype ", output_type->ToString(), + ", but function returned datatype ", + val->type()->ToString()); + } + out->value = std::move(val->data()); + return Status::OK(); + } else { + return Status::TypeError("Unexpected output type: ", Py_TYPE(result.obj())->tp_name, + " (expected Array)"); + } + // *out = Datum((int32_t)table->num_rows()); + return Status::OK(); + } + + int32_t num_rows = 0; + UdfWrapperCallback agg_cb; + std::vector> values; + std::shared_ptr agg_function; + std::shared_ptr input_schema; + std::shared_ptr output_type; + }; + struct PythonUdf : public PythonUdfKernelState { PythonUdf(std::shared_ptr function, UdfWrapperCallback cb, std::vector input_types, compute::OutputType output_type) @@ -234,6 +351,54 @@ Status RegisterTabularFunction(PyObject* user_function, UdfWrapperCallback wrapp wrapper, options, registry); } +Status AddAggKernel(std::shared_ptr sig, compute::KernelInit init, + compute::ScalarAggregateFunction* func) { + + compute::ScalarAggregateKernel kernel(std::move(sig), std::move(init), AggregateUdfConsume, AggregateUdfMerge, AggregateUdfFinalize, /*ordered=*/false); + RETURN_NOT_OK(func->AddKernel(std::move(kernel))); + return Status::OK(); +} + +Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_wrapper, + const UdfOptions& options, + compute::FunctionRegistry* registry) { + if (!PyCallable_Check(agg_function)) { + return Status::TypeError("Expected a callable Python object."); + } + + if (registry == NULLPTR) { + registry = compute::GetFunctionRegistry(); + } + + static auto default_scalar_aggregate_options = compute::ScalarAggregateOptions::Defaults(); + auto aggregate_func = std::make_shared( + options.func_name, options.arity, options.func_doc, &default_scalar_aggregate_options); + + Py_INCREF(agg_function); + std::vector input_types; + for (const auto& in_dtype : options.input_types) { + input_types.emplace_back(in_dtype); + } + compute::OutputType output_type(options.output_type); + + auto init = [agg_wrapper, agg_function, options]( + compute::KernelContext* ctx, + const compute::KernelInitArgs& args) -> Result> { + return std::make_unique( + agg_wrapper, + std::make_shared(agg_function), + options.input_types, + options.output_type); + }; + + RETURN_NOT_OK( + AddAggKernel(compute::KernelSignature::Make(input_types, output_type), + init, aggregate_func.get())); + + RETURN_NOT_OK(registry->AddFunction(std::move(aggregate_func))); + return Status::OK(); +} + Result> CallTabularFunction( const std::string& func_name, const std::vector& args, compute::FunctionRegistry* registry) { diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index b3dcc9ccf44..ebcc1d3281a 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -59,7 +59,16 @@ Status ARROW_PYTHON_EXPORT RegisterScalarFunction( /// \brief register a Table user-defined-function from Python Status ARROW_PYTHON_EXPORT RegisterTabularFunction( - PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options, + PyObject* user_function, UdfWrapperCallback wrapper, + const UdfOptions& options, compute::FunctionRegistry* registry = NULLPTR); + +/// \brief register a Aggregate user-defined-function from Python +Status 
ARROW_PYTHON_EXPORT RegisterAggregateFunction( + PyObject* user_function, UdfWrapperCallback wrapper, + const UdfOptions& options, compute::FunctionRegistry* registry = NULLPTR); + +Result> ARROW_PYTHON_EXPORT CallTabularFunction( + const std::string& func_name, const std::vector& args, compute::FunctionRegistry* registry = NULLPTR); Result> ARROW_PYTHON_EXPORT diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 0f336555f76..bc3e2a5ef46 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -39,6 +39,30 @@ class MyError(RuntimeError): pass +@pytest.fixture(scope="session") +def unary_agg_func_fixture(): + """ + Register a unary aggregate function + """ + + def func(ctx, x): + return pa.array([len(x)]) + + func_name = "y=len(x)" + func_doc = {"summary": "y=len(x)", + "description": "find length of x"} + + pc.register_aggregate_function(func, + func_name, + func_doc, + { + "x": pa.int64(), + }, + pa.int64() + ) + return func, func_name + + @pytest.fixture(scope="session") def binary_func_fixture(): """ @@ -593,3 +617,10 @@ def test_udt_datasource1_generator(): def test_udt_datasource1_exception(): with pytest.raises(RuntimeError, match='datasource1_exception'): _test_datasource1_udt(datasource1_exception) + + +def test_aggregate_udf_basic(unary_agg_func_fixture): + arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64()) + result = pc.call_function("y=len(x)", [arr]) + expected = pa.array([6]) + assert result == expected From 94c9710cb0eb1b2b0d0881e41f85dea69130c53a Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 9 May 2023 16:51:18 -0400 Subject: [PATCH 02/19] Add repro for segfault --- python/pyarrow/src/arrow/python/udf.cc | 1 - python/pyarrow/src/arrow/python/udf.h | 4 -- python/pyarrow/tests/test_udf.py | 51 +++++++++++++++++++++++++- 3 files changed, 49 insertions(+), 7 deletions(-) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 9bed9555c70..df52c6ed34b 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -169,7 +169,6 @@ struct PythonTableUdfKernelInit { // Ignore batch length here ScalarUdfContext udf_context{ctx->memory_pool(), 0}; - OwnedRef arg_tuple(PyTuple_New(num_args)); RETURN_NOT_OK(CheckPyError()); diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index ebcc1d3281a..f96bd725bdf 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -67,10 +67,6 @@ Status ARROW_PYTHON_EXPORT RegisterAggregateFunction( PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options, compute::FunctionRegistry* registry = NULLPTR); -Result> ARROW_PYTHON_EXPORT CallTabularFunction( - const std::string& func_name, const std::vector& args, - compute::FunctionRegistry* registry = NULLPTR); - Result> ARROW_PYTHON_EXPORT CallTabularFunction(const std::string& func_name, const std::vector& args, compute::FunctionRegistry* registry = NULLPTR); diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index bc3e2a5ef46..0995266801e 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -63,6 +63,31 @@ def func(ctx, x): return func, func_name +@pytest.fixture(scope="session") +def bad_unary_agg_func_fixture(): + """ + Register a unary aggregate function + """ + + def func(ctx, x): + raise RuntimeError("Oops") + return pa.array([len(x)]) + + func_name = "y=bad_len(x)" + func_doc = {"summary": 
"y=bad_len(x)", + "description": "find length of"} + + pc.register_aggregate_function(func, + func_name, + func_doc, + { + "x": pa.int64(), + }, + pa.int64() + ) + return func, func_name + + @pytest.fixture(scope="session") def binary_func_fixture(): """ @@ -252,11 +277,11 @@ def check_scalar_function(func_fixture, if all_scalar: batch_length = 1 - expected_output = function(mock_scalar_udf_context(batch_length), *inputs) func = pc.get_function(name) assert func.name == name result = pc.call_function(name, inputs, length=batch_length) + expected_output = function(mock_scalar_udf_context(batch_length), *inputs) assert result == expected_output # At the moment there is an issue when handling nullary functions. # See: ARROW-15286 and ARROW-16290. @@ -619,8 +644,30 @@ def test_udt_datasource1_exception(): _test_datasource1_udt(datasource1_exception) -def test_aggregate_udf_basic(unary_agg_func_fixture): +def test_aggregate_basic(unary_agg_func_fixture): arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64()) result = pc.call_function("y=len(x)", [arr]) expected = pa.array([6]) assert result == expected + + +def test_aggregate_exception(bad_unary_agg_func_fixture): + arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64()) + with pytest.raises(RuntimeError, match='Oops'): + try: + pc.call_function("y=bad_len(x)", [arr]) + except Exception as e: + raise e + + +def test_aggregate_segfault(bad_unary_agg_func_fixture): + # This test will segfault the python process. + arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64()) + try: + result = pc.call_function("y=bad_len(x)", [arr]) + except Exception as e: + print("Before segfault") + raise e + + expected = pa.array([6]) + assert result == expected From 241e970551419c7ff03a753ee11e0b0ec74380d3 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 15 May 2023 10:27:16 -0400 Subject: [PATCH 03/19] Fix python object reference counter --- python/pyarrow/src/arrow/python/udf.cc | 1 + python/pyarrow/tests/test_udf.py | 13 ------------- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index df52c6ed34b..d83da5d2c60 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -383,6 +383,7 @@ Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_ auto init = [agg_wrapper, agg_function, options]( compute::KernelContext* ctx, const compute::KernelInitArgs& args) -> Result> { + Py_INCREF(agg_function); return std::make_unique( agg_wrapper, std::make_shared(agg_function), diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 0995266801e..682b9279964 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -658,16 +658,3 @@ def test_aggregate_exception(bad_unary_agg_func_fixture): pc.call_function("y=bad_len(x)", [arr]) except Exception as e: raise e - - -def test_aggregate_segfault(bad_unary_agg_func_fixture): - # This test will segfault the python process. 
- arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64()) - try: - result = pc.call_function("y=bad_len(x)", [arr]) - except Exception as e: - print("Before segfault") - raise e - - expected = pa.array([6]) - assert result == expected From 15194faf308da9136e0bfc5b1d2f36f536cb74c6 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 15 May 2023 18:07:47 -0400 Subject: [PATCH 04/19] Implement support for calling from Acero/Substrait --- .../arrow/engine/substrait/extension_set.cc | 124 ++++++++--------- python/pyarrow/conftest.py | 26 ++++ python/pyarrow/src/arrow/python/udf.cc | 6 +- python/pyarrow/tests/test_substrait.py | 130 ++++++++++++++++++ python/pyarrow/tests/test_udf.py | 30 +--- 5 files changed, 225 insertions(+), 91 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 5501889d7a2..b6b3aa10850 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -363,6 +363,68 @@ ExtensionIdRegistry::SubstraitAggregateToArrow kSimpleSubstraitAggregateToArrow return DecodeBasicAggregate(std::string(call.id().name))(call); }; +ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate( + const std::string& arrow_function_name) { + return [arrow_function_name](const SubstraitCall& call) -> Result { + std::string fixed_arrow_func; + if (call.is_hash()) { + fixed_arrow_func = "hash_"; + } + + switch (call.size()) { + case 0: { + if (call.id().name == "count") { + fixed_arrow_func += "count_all"; + return compute::Aggregate{std::move(fixed_arrow_func), ""}; + } + return Status::Invalid("Expected aggregate call ", call.id().uri, "#", + call.id().name, " to have at least one argument"); + } + case 1: { + std::shared_ptr options = nullptr; + if (arrow_function_name == "stddev" || arrow_function_name == "variance") { + // See the following URL for the spec of stddev and variance: + // https://github.com/substrait-io/substrait/blob/ + // 73228b4112d79eb1011af0ebb41753ce23ca180c/ + // extensions/functions_arithmetic.yaml#L1240 + auto maybe_dist = call.GetOption("distribution"); + if (maybe_dist) { + auto& prefs = **maybe_dist; + if (prefs.size() != 1) { + return Status::Invalid("expected a single preference for ", + arrow_function_name, " but got ", prefs.size()); + } + int ddof; + if (prefs[0] == "POPULATION") { + ddof = 1; + } else if (prefs[0] == "SAMPLE") { + ddof = 0; + } else { + return Status::Invalid("unknown distribution preference ", prefs[0]); + } + options = std::make_shared(ddof); + } + } + fixed_arrow_func += arrow_function_name; + + ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(0)); + const FieldRef* arg_ref = arg.field_ref(); + if (!arg_ref) { + return Status::Invalid("Expected an aggregate call ", call.id().uri, "#", + call.id().name, " to have a direct reference"); + } + + return compute::Aggregate{std::move(fixed_arrow_func), + options ? 
std::move(options) : nullptr, *arg_ref, ""}; + } + default: + break; + } + return Status::NotImplemented( + "Only nullary and unary aggregate functions are currently supported"); + }; +} + struct ExtensionIdRegistryImpl : ExtensionIdRegistry { ExtensionIdRegistryImpl() : parent_(nullptr) {} explicit ExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) : parent_(parent) {} @@ -937,68 +999,6 @@ ExtensionIdRegistry::SubstraitCallToArrow DecodeConcatMapping() { }; } -ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate( - const std::string& arrow_function_name) { - return [arrow_function_name](const SubstraitCall& call) -> Result { - std::string fixed_arrow_func; - if (call.is_hash()) { - fixed_arrow_func = "hash_"; - } - - switch (call.size()) { - case 0: { - if (call.id().name == "count") { - fixed_arrow_func += "count_all"; - return compute::Aggregate{std::move(fixed_arrow_func), ""}; - } - return Status::Invalid("Expected aggregate call ", call.id().uri, "#", - call.id().name, " to have at least one argument"); - } - case 1: { - std::shared_ptr options = nullptr; - if (arrow_function_name == "stddev" || arrow_function_name == "variance") { - // See the following URL for the spec of stddev and variance: - // https://github.com/substrait-io/substrait/blob/ - // 73228b4112d79eb1011af0ebb41753ce23ca180c/ - // extensions/functions_arithmetic.yaml#L1240 - auto maybe_dist = call.GetOption("distribution"); - if (maybe_dist) { - auto& prefs = **maybe_dist; - if (prefs.size() != 1) { - return Status::Invalid("expected a single preference for ", - arrow_function_name, " but got ", prefs.size()); - } - int ddof; - if (prefs[0] == "POPULATION") { - ddof = 1; - } else if (prefs[0] == "SAMPLE") { - ddof = 0; - } else { - return Status::Invalid("unknown distribution preference ", prefs[0]); - } - options = std::make_shared(ddof); - } - } - fixed_arrow_func += arrow_function_name; - - ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(0)); - const FieldRef* arg_ref = arg.field_ref(); - if (!arg_ref) { - return Status::Invalid("Expected an aggregate call ", call.id().uri, "#", - call.id().name, " to have a direct reference"); - } - - return compute::Aggregate{std::move(fixed_arrow_func), - options ? 
std::move(options) : nullptr, *arg_ref, ""}; - } - default: - break; - } - return Status::NotImplemented( - "Only nullary and unary aggregate functions are currently supported"); - }; -} - struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { DefaultExtensionIdRegistry() { // ----------- Extension Types ---------------------------- diff --git a/python/pyarrow/conftest.py b/python/pyarrow/conftest.py index ef09393cfbd..a107cb791f8 100644 --- a/python/pyarrow/conftest.py +++ b/python/pyarrow/conftest.py @@ -278,3 +278,29 @@ def unary_function(ctx, x): {"array": pa.int64()}, pa.int64()) return unary_function, func_name + + +@pytest.fixture(scope="session") +def unary_agg_func_fixture(): + """ + Register a unary aggregate function + """ + from pyarrow import compute as pc + import numpy as np + + def func(ctx, x): + return pa.array([np.nanmean(x)]) + + func_name = "y=avg(x)" + func_doc = {"summary": "y=avg(x)", + "description": "find mean of x"} + + pc.register_aggregate_function(func, + func_name, + func_doc, + { + "x": pa.float64(), + }, + pa.float64() + ) + return func, func_name diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index d83da5d2c60..43a7f5af51e 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -199,11 +199,13 @@ struct PythonTableUdfKernelInit { ", but function returned datatype ", val->type()->ToString()); } - out->value = std::move(val->data()); + ARROW_ASSIGN_OR_RAISE(auto scalar_val , val->GetScalar(0)); + *out = Datum(scalar_val); + // out->value = std::move(val->data()); return Status::OK(); } else { return Status::TypeError("Unexpected output type: ", Py_TYPE(result.obj())->tp_name, - " (expected Array)"); + " (expected Array)"); } // *out = Datum((int32_t)table->num_rows()); return Status::OK(); diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index d0da517ea7f..4e2d7b7655e 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -436,6 +436,7 @@ def table_provider(names, _): """ buf = pa._substrait._parse_json_plan(substrait_query) + breakpoint() reader = pa.substrait.run_query( buf, table_provider=table_provider, use_threads=use_threads) res_tb = reader.read_all() @@ -605,3 +606,132 @@ def table_provider(names, schema): expected = pa.Table.from_pydict({"out": [1, 2, 3]}) assert res_tb == expected + + +def test_agg_udf(unary_agg_func_fixture): + + test_table = pa.Table.from_pydict( + {"k": [1, 1, 2, 2], "v": [1.0, 2.0, 3.0, 4.0]} + ) + + def table_provider(names, _): + return test_table + + substrait_query = b""" +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "urn:arrow:substrait_simple_extension_function" + }, + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "y=avg(x)" + } + } + ], + "relations": [ + { + "root": { + "input": { + "extensionSingle": { + "common": { + "emit": { + "outputMapping": [ + 0, + 1 + ] + } + }, + "input": { + "read": { + "baseSchema": { + "names": [ + "time", + "price" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["t1"] + } + } + }, + "detail": { + "@type": "/arrow.substrait_ext.SegmentedAggregateRel", + "segmentKeys": [ + { + "directReference": { + "structField": {} + }, + 
"rootReference": {} + } + ], + "measures": [ + { + "measure": { + "functionReference": 1, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + } + ] + } + } + ] + } + } + }, + "names": [ + "k", + "v_avg" + ] + } + } + ], +} +""" + buf = pa._substrait._parse_json_plan(substrait_query) + reader = pa.substrait.run_query( + buf, table_provider=table_provider, use_threads=False) + res_tb = reader.read_all() + + expected_tb = pa.Table.from_pydict({ + 'k': [1, 2], + 'v_avg': [1.5, 3.5] + }) + + assert res_tb == expected_tb diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 682b9279964..b2db724aa4b 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -39,30 +39,6 @@ class MyError(RuntimeError): pass -@pytest.fixture(scope="session") -def unary_agg_func_fixture(): - """ - Register a unary aggregate function - """ - - def func(ctx, x): - return pa.array([len(x)]) - - func_name = "y=len(x)" - func_doc = {"summary": "y=len(x)", - "description": "find length of x"} - - pc.register_aggregate_function(func, - func_name, - func_doc, - { - "x": pa.int64(), - }, - pa.int64() - ) - return func, func_name - - @pytest.fixture(scope="session") def bad_unary_agg_func_fixture(): """ @@ -645,9 +621,9 @@ def test_udt_datasource1_exception(): def test_aggregate_basic(unary_agg_func_fixture): - arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64()) - result = pc.call_function("y=len(x)", [arr]) - expected = pa.array([6]) + arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0, 60.0], pa.float64()) + result = pc.call_function("y=avg(x)", [arr]) + expected = pa.scalar(35.0) assert result == expected From e91b8828978f1888e264934448dc7a2f1c806be7 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 16 May 2023 16:39:28 -0400 Subject: [PATCH 05/19] First around of self review --- python/pyarrow/_compute.pxd | 6 +- python/pyarrow/_compute.pyx | 117 +++++++++++-------------- python/pyarrow/compute.py | 2 +- python/pyarrow/conftest.py | 2 +- python/pyarrow/includes/libarrow.pxd | 4 +- python/pyarrow/src/arrow/python/udf.cc | 43 +++++---- python/pyarrow/src/arrow/python/udf.h | 10 +-- python/pyarrow/tests/test_substrait.py | 11 ++- python/pyarrow/tests/test_udf.py | 10 +-- 9 files changed, 93 insertions(+), 112 deletions(-) diff --git a/python/pyarrow/_compute.pxd b/python/pyarrow/_compute.pxd index 2dc0de2d0bf..29b37da3ac4 100644 --- a/python/pyarrow/_compute.pxd +++ b/python/pyarrow/_compute.pxd @@ -21,11 +21,11 @@ from pyarrow.lib cimport * from pyarrow.includes.common cimport * from pyarrow.includes.libarrow cimport * -cdef class ScalarUdfContext(_Weakrefable): +cdef class UdfContext(_Weakrefable): cdef: - CScalarUdfContext c_context + CUdfContext c_context - cdef void init(self, const CScalarUdfContext& c_context) + cdef void init(self, const CUdfContext& c_context) cdef class FunctionOptions(_Weakrefable): diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 2ecf489b495..6e62d8501ee 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2559,7 +2559,7 @@ cdef CExpression _bind(Expression filter, Schema schema) except *: deref(pyarrow_unwrap_schema(schema).get()))) -cdef class ScalarUdfContext: +cdef class UdfContext: """ Per-invocation function context/state. 
@@ -2571,7 +2571,7 @@ cdef class UdfContext:
         raise TypeError("Do not call {}'s constructor directly"
                         .format(self.__class__.__name__))
 
-    cdef void init(self, const CScalarUdfContext &c_context):
+    cdef void init(self, const CUdfContext &c_context):
         self.c_context = c_context
 
     @property
@@ -2620,26 +2620,26 @@ cdef inline CFunctionDoc _make_function_doc(dict func_doc) except *:
     return f_doc
 
 
-cdef object box_scalar_udf_context(const CScalarUdfContext& c_context):
-    cdef ScalarUdfContext context = ScalarUdfContext.__new__(ScalarUdfContext)
+cdef object box_udf_context(const CUdfContext& c_context):
+    cdef UdfContext context = UdfContext.__new__(UdfContext)
     context.init(c_context)
     return context
 
 
-cdef _udf_callback(user_function, const CScalarUdfContext& c_context, inputs):
+cdef _udf_callback(user_function, const CUdfContext& c_context, inputs):
     """
-    Helper callback function used to wrap the ScalarUdfContext from Python to C++
+    Helper callback function used to wrap the UdfContext from Python to C++
     execution.
     """
-    context = box_scalar_udf_context(c_context)
+    context = box_udf_context(c_context)
     return user_function(context, *inputs)
 
 
-def _get_scalar_udf_context(memory_pool, batch_length):
-    cdef CScalarUdfContext c_context
+def _get_udf_context(memory_pool, batch_length):
+    cdef CUdfContext c_context
    c_context.pool = maybe_unbox_memory_pool(memory_pool)
    c_context.batch_length = batch_length
-    context = box_scalar_udf_context(c_context)
+    context = box_udf_context(c_context)
    return context
 
 
@@ -2690,7 +2690,7 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty
     func : callable
         A callable implementing the user-defined function.
         The first argument is the context argument of type
-        ScalarUdfContext.
+        UdfContext.
         Then, it must take arguments equal to the number of
         in_types defined. It must return an Array or Scalar
         matching the out_type. It must return a Scalar if
@@ -2744,35 +2744,33 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty
         21
       ]
     """
-    return _register_scalar_like_function(get_register_scalar_function(),
-                                          func, function_name, function_doc, in_types,
-                                          out_type, func_registry)
+    return _register_user_defined_function(get_register_scalar_function(),
+                                           func, function_name, function_doc, in_types,
+                                           out_type, func_registry)
 
 
 def register_aggregate_function(func, function_name, function_doc, in_types, out_type,
                                 func_registry=None):
     """
-    Register a user-defined scalar function.
+    Register a user-defined non-decomposable aggregate function.
 
-    A scalar function is a function that executes elementwise
-    operations on arrays or scalars, i.e. a scalar function must
-    be computed row-by-row with no state where each output row
-    is computed only from its corresponding input row.
-    In other words, all argument arrays have the same length,
-    and the output array is of the same length as the arguments.
-    Scalar functions are the only functions allowed in query engine
-    expressions.
+    A non-decomposable aggregate function is a function that executes
+    aggregate operations on the whole of the data that it is aggregating.
+    In other words, a non-decomposable aggregate function cannot be
+    split into consume/merge/finalize steps.
+
+    This is mostly useful with segmented aggregation, where the data
+    to be aggregated is continuous.
 
     Parameters
     ----------
     func : callable
         A callable implementing the user-defined function.
         The first argument is the context argument of type
-        ScalarUdfContext.
+        UdfContext.
Then, it must take arguments equal to the number of
-        in_types defined. It must return an Array or Scalar
-        matching the out_type. It must return a Scalar if
-        all arguments are scalar, else it must return an Array.
+        in_types defined. It must return a Scalar matching the
+        out_type.
 
         To define a varargs function, pass a callable that takes
         varargs. The last in_type will be the type of all varargs
         arguments.
     function_name : str
         Name of the function. This name must be globally unique.
     function_doc : dict
         A dictionary object with keys "summary" (str),
         and "description" (str).
     in_types : Dict[str, DataType]
         A dictionary mapping function argument names to
         their respective DataType.
         The argument names will be used to generate
         documentation for the function. The number of
         arguments specified here determines the function
         arity.
     out_type : DataType
         Output type of the function.
     func_registry : FunctionRegistry
         Optional function registry to use instead of the default global one.
 
     Examples
     --------
+    >>> import numpy as np
     >>> import pyarrow as pa
     >>> import pyarrow.compute as pc
     >>>
     >>> func_doc = {}
-    >>> func_doc["summary"] = "simple udf"
-    >>> func_doc["description"] = "add a constant to a scalar"
+    >>> func_doc["summary"] = "simple mean udf"
+    >>> func_doc["description"] = "compute mean"
     >>>
-    >>> def add_constant(ctx, array):
-    ...     return pc.add(array, 1, memory_pool=ctx.memory_pool)
+    >>> def compute_mean(ctx, array):
+    ...     return pa.scalar(np.nanmean(array))
     >>>
-    >>> func_name = "py_add_func"
+    >>> func_name = "py_compute_mean"
     >>> in_types = {"array": pa.int64()}
-    >>> out_type = pa.int64()
-    >>> pc.register_scalar_function(add_constant, func_name, func_doc,
+    >>> out_type = pa.float64()
+    >>> pc.register_aggregate_function(compute_mean, func_name, func_doc,
     ...                             in_types, out_type)
     >>>
     >>> func = pc.get_function(func_name)
     >>> func.name
-    'py_add_func'
-    >>> answer = pc.call_function(func_name, [pa.array([20])])
+    'py_compute_mean'
+    >>> answer = pc.call_function(func_name, [pa.array([20, 40])])
     >>> answer
-    <pyarrow.lib.Int64Array object at ...>
-    [
-      21
-    ]
+    <pyarrow.DoubleScalar: 30.0>
     """
-    return _register_scalar_like_function(get_register_aggregate_function(),
-                                          func, function_name, function_doc, in_types,
-                                          out_type, func_registry)
+    return _register_user_defined_function(get_register_aggregate_function(),
+                                           func, function_name, function_doc, in_types,
+                                           out_type, func_registry)
 
 
 def register_tabular_function(func, function_name, function_doc, in_types, out_type,
@@ -2833,7 +2829,7 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t
     Register a user-defined tabular function.
 
     A tabular function is one accepting a context argument of type
-    ScalarUdfContext and returning a generator of struct arrays.
+    UdfContext and returning a generator of struct arrays.
     The in_types argument must be empty and the out_type argument
     specifies a schema. Each struct array must have field types
     corresponding to the schema.
@@ -2843,7 +2839,7 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t
     func : callable
         A callable implementing the user-defined function.
         The only argument is the context argument of type
-        ScalarUdfContext. It must return a callable that
+        UdfContext. It must return a callable that
         returns on each invocation a StructArray matching the
         out_type, where an empty array indicates end.
function_name : str @@ -2867,36 +2863,25 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t with nogil: c_type = make_shared[CStructType](deref(c_schema).fields()) out_type = pyarrow_wrap_data_type(c_type) - return _register_scalar_like_function(get_register_tabular_function(), - func, function_name, function_doc, in_types, - out_type, func_registry) + return _register_user_defined_function(get_register_tabular_function(), + func, function_name, function_doc, in_types, + out_type, func_registry) -def _register_scalar_like_function(register_func, func, function_name, function_doc, in_types, - out_type, func_registry=None): +def _register_user_defined_function(register_func, func, function_name, function_doc, in_types, + out_type, func_registry=None): """ - Register a user-defined scalar-like function. + Register a user-defined function. - A scalar-like function is a callable accepting a first - context argument of type ScalarUdfContext as well as - possibly additional Arrow arguments, and returning a - an Arrow result appropriate for the kind of function. - A scalar function and a tabular function are examples - for scalar-like functions. - This function is normally not called directly but via - register_scalar_function or register_tabular_function. + This method itself doesn't care what the type of the UDF + (i.e., scalar vs tabular vs aggregate) Parameters ---------- register_func: object - An object holding a CRegisterUdf in a "register_func" attribute. Use - get_register_scalar_function() for a scalar function and - get_register_tabular_function() for a tabular function. + An object holding a CRegisterUdf in a "register_func" attribute. func : callable A callable implementing the user-defined function. - See register_scalar_function and - register_tabular_function for details. - function_name : str Name of the function. This name must be globally unique. function_doc : dict @@ -2905,8 +2890,6 @@ def _register_scalar_like_function(register_func, func, function_name, function_ in_types : Dict[str, DataType] A dictionary mapping function argument names to their respective DataType. - See register_scalar_function and - register_tabular_function for details. out_type : DataType Output type of the function. 
func_registry : FunctionRegistry diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 4d622932df3..e92f0935477 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -85,7 +85,7 @@ register_scalar_function, register_tabular_function, register_aggregate_function, - ScalarUdfContext, + UdfContext, # Expressions Expression, ) diff --git a/python/pyarrow/conftest.py b/python/pyarrow/conftest.py index a107cb791f8..1a03ecaf6ea 100644 --- a/python/pyarrow/conftest.py +++ b/python/pyarrow/conftest.py @@ -289,7 +289,7 @@ def unary_agg_func_fixture(): import numpy as np def func(ctx, x): - return pa.array([np.nanmean(x)]) + return pa.scalar(np.nanmean(x)) func_name = "y=avg(x)" func_doc = {"summary": "y=avg(x)", diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index a6fb2c4e9e2..86f21f4b528 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2775,7 +2775,7 @@ cdef extern from "arrow/util/byte_size.h" namespace "arrow::util" nogil: int64_t TotalBufferSize(const CRecordBatch& record_batch) int64_t TotalBufferSize(const CTable& table) -ctypedef PyObject* CallbackUdf(object user_function, const CScalarUdfContext& context, object inputs) +ctypedef PyObject* CallbackUdf(object user_function, const CUdfContext& context, object inputs) cdef extern from "arrow/api.h" namespace "arrow" nogil: @@ -2786,7 +2786,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil: - cdef cppclass CScalarUdfContext" arrow::py::ScalarUdfContext": + cdef cppclass CUdfContext" arrow::py::UdfContext": CMemoryPool *pool int64_t batch_length diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 43a7f5af51e..0a81cdfceb9 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -28,7 +28,6 @@ namespace arrow { using internal::checked_cast; namespace py { - namespace { struct PythonUdfKernelState : public compute::KernelState { @@ -107,12 +106,12 @@ struct PythonTableUdfKernelInit { Result> operator()( compute::KernelContext* ctx, const compute::KernelInitArgs&) { - ScalarUdfContext scalar_udf_context{ctx->memory_pool(), /*batch_length=*/0}; + UdfContext udf_context{ctx->memory_pool(), /*batch_length=*/0}; std::unique_ptr function; - RETURN_NOT_OK(SafeCallIntoPython([this, &scalar_udf_context, &function] { + RETURN_NOT_OK(SafeCallIntoPython([this, &udf_context, &function] { OwnedRef empty_tuple(PyTuple_New(0)); function = std::make_unique( - cb(function_maker->obj(), scalar_udf_context, empty_tuple.obj())); + cb(function_maker->obj(), udf_context, empty_tuple.obj())); RETURN_NOT_OK(CheckPyError()); return Status::OK(); })); @@ -149,7 +148,6 @@ struct PythonTableUdfKernelInit { } Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) { - num_rows = batch.length; ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool())); values.push_back(rb); return Status::OK(); @@ -157,7 +155,6 @@ struct PythonTableUdfKernelInit { Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) { const auto& other_state = checked_cast(src); - num_rows += other_state.num_rows; values.insert(values.end(), other_state.values.begin(), other_state.values.end()); return Status::OK(); } @@ -166,12 +163,18 @@ struct PythonTableUdfKernelInit { auto state = arrow::internal::checked_cast(ctx->state()); 
std::shared_ptr<OwnedRefNoGIL>& function = state->agg_function;
     const int num_args = input_schema->num_fields();
-    // Ignore batch length here
-    ScalarUdfContext udf_context{ctx->memory_pool(), 0};
 
     OwnedRef arg_tuple(PyTuple_New(num_args));
     RETURN_NOT_OK(CheckPyError());
 
+    // Note: concatenating the consumed batches this way uses double the
+    // amount of memory. This is OK for now because a non-decomposable
+    // aggregate UDF is expected to be used with segmented aggregation,
+    // where the size of each segment is more or less constant, so
+    // doubling it is not a big deal. This can also be improved in the
+    // future by concatenating in a more efficient way.
     ARROW_ASSIGN_OR_RAISE(
       auto table,
       arrow::Table::FromRecordBatches(input_schema, values)
     );
     ARROW_ASSIGN_OR_RAISE(
       table, table->CombineChunks(ctx->memory_pool())
     );
 
+    UdfContext udf_context{ctx->memory_pool(), table->num_rows()};
     for (int arg_id = 0; arg_id < num_args; arg_id++) {
       // Since we combined chunks there is only one chunk
       std::shared_ptr<Array> c_data = table->column(arg_id)->chunk(0);
       PyObject* data = wrap_array(c_data);
       PyTuple_SetItem(arg_tuple.obj(), arg_id, data);
     }
 
     OwnedRef result(agg_cb(function->obj(), udf_context, arg_tuple.obj()));
     RETURN_NOT_OK(CheckPyError());
 
     // unwrapping the output for expected output type
-    if (is_array(result.obj())) {
-      ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> val, unwrap_array(result.obj()));
-      std::cout << val->type()->ToString() << std::endl;
-      if (*output_type != *val->type()) {
+    if (is_scalar(result.obj())) {
+      ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> val, unwrap_scalar(result.obj()));
+      if (*output_type != *val->type) {
         return Status::TypeError("Expected output datatype ", output_type->ToString(),
                                  ", but function returned datatype ",
-                                 val->type()->ToString());
+                                 val->type->ToString());
       }
-      ARROW_ASSIGN_OR_RAISE(auto scalar_val , val->GetScalar(0));
-      *out = Datum(scalar_val);
-      // out->value = std::move(val->data());
+      out->value = std::move(val);
       return Status::OK();
     } else {
       return Status::TypeError("Unexpected output type: ", Py_TYPE(result.obj())->tp_name,
-                               " (expected Array)");
+                               " (expected Scalar)");
     }
-    // *out = Datum((int32_t)table->num_rows());
     return Status::OK();
   }
 
-  int32_t num_rows = 0;
   UdfWrapperCallback agg_cb;
   std::vector<std::shared_ptr<RecordBatch>> values;
   std::shared_ptr<OwnedRefNoGIL> agg_function;
   std::shared_ptr<Schema> input_schema;
   std::shared_ptr<DataType> output_type;
 };
 
@@ -248,7 +247,7 @@ struct PythonUdf : public PythonUdfKernelState {
     auto state = arrow::internal::checked_cast<PythonUdf*>(ctx->state());
     std::shared_ptr<OwnedRefNoGIL>& function = state->function;
     const int num_args = batch.num_values();
-    ScalarUdfContext scalar_udf_context{ctx->memory_pool(), batch.length};
+    UdfContext udf_context{ctx->memory_pool(), batch.length};
 
     OwnedRef arg_tuple(PyTuple_New(num_args));
     RETURN_NOT_OK(CheckPyError());
@@ -264,7 +263,7 @@ struct PythonUdf : public PythonUdfKernelState {
       }
     }
 
-    OwnedRef result(cb(function->obj(), scalar_udf_context, arg_tuple.obj()));
+    OwnedRef result(cb(function->obj(), udf_context, arg_tuple.obj()));
     RETURN_NOT_OK(CheckPyError());
     // unwrapping the output for expected output type
     if (is_array(result.obj())) {
diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h
index f96bd725bdf..cc2f3ab62f5 100644
--- a/python/pyarrow/src/arrow/python/udf.h
+++ b/python/pyarrow/src/arrow/python/udf.h
@@ -43,19 +43,19 @@ struct ARROW_PYTHON_EXPORT UdfOptions {
   std::shared_ptr<DataType> output_type;
 };
 
-/// \brief A context passed as the first argument of scalar UDF functions.
-struct ARROW_PYTHON_EXPORT ScalarUdfContext { +/// \brief A context passed as the first argument of UDF functions. +struct ARROW_PYTHON_EXPORT UdfContext { MemoryPool* pool; int64_t batch_length; }; using UdfWrapperCallback = std::function; + PyObject* user_function, const UdfContext& context, PyObject* inputs)>; /// \brief register a Scalar user-defined-function from Python Status ARROW_PYTHON_EXPORT RegisterScalarFunction( - PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options, - compute::FunctionRegistry* registry = NULLPTR); + PyObject* user_function, UdfWrapperCallback wrapper, + const UdfOptions& options, compute::FunctionRegistry* registry = NULLPTR); /// \brief register a Table user-defined-function from Python Status ARROW_PYTHON_EXPORT RegisterTabularFunction( diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index 4e2d7b7655e..6ba38bfed18 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -34,9 +34,9 @@ pytestmark = [pytest.mark.dataset, pytest.mark.substrait] -def mock_scalar_udf_context(batch_length=10): - from pyarrow._compute import _get_scalar_udf_context - return _get_scalar_udf_context(pa.default_memory_pool(), batch_length) +def mock_udf_context(batch_length=10): + from pyarrow._compute import _get_udf_context + return _get_udf_context(pa.default_memory_pool(), batch_length) def _write_dummy_data_to_disk(tmpdir, file_name, table): @@ -436,14 +436,13 @@ def table_provider(names, _): """ buf = pa._substrait._parse_json_plan(substrait_query) - breakpoint() reader = pa.substrait.run_query( buf, table_provider=table_provider, use_threads=use_threads) res_tb = reader.read_all() function, name = unary_func_fixture expected_tb = test_table.add_column(1, 'y', function( - mock_scalar_udf_context(10), test_table['x'])) + mock_udf_context(10), test_table['x'])) assert res_tb == expected_tb @@ -608,7 +607,7 @@ def table_provider(names, schema): assert res_tb == expected -def test_agg_udf(unary_agg_func_fixture): +def test_aggregate_udf_basic(unary_agg_func_fixture): test_table = pa.Table.from_pydict( {"k": [1, 1, 2, 2], "v": [1.0, 2.0, 3.0, 4.0]} diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index b2db724aa4b..6b71a5b87c5 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -30,9 +30,9 @@ ds = None -def mock_scalar_udf_context(batch_length=10): - from pyarrow._compute import _get_scalar_udf_context - return _get_scalar_udf_context(pa.default_memory_pool(), batch_length) +def mock_udf_context(batch_length=10): + from pyarrow._compute import _get_udf_context + return _get_udf_context(pa.default_memory_pool(), batch_length) class MyError(RuntimeError): @@ -47,7 +47,7 @@ def bad_unary_agg_func_fixture(): def func(ctx, x): raise RuntimeError("Oops") - return pa.array([len(x)]) + return pa.scalar(len(x)) func_name = "y=bad_len(x)" func_doc = {"summary": "y=bad_len(x)", @@ -257,7 +257,7 @@ def check_scalar_function(func_fixture, assert func.name == name result = pc.call_function(name, inputs, length=batch_length) - expected_output = function(mock_scalar_udf_context(batch_length), *inputs) + expected_output = function(mock_udf_context(batch_length), *inputs) assert result == expected_output # At the moment there is an issue when handling nullary functions. # See: ARROW-15286 and ARROW-16290. 
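A minimal end-to-end sketch of the API as it stands after PATCH 05. The function
name, inputs, and expected result are taken from unary_agg_func_fixture and
test_aggregate_basic in this series; it assumes a PyArrow build with these
patches applied, and pc.call_function invokes the registered aggregate kernel
directly, as the tests do:

    import numpy as np
    import pyarrow as pa
    import pyarrow.compute as pc

    def func(ctx, x):
        # ctx is the UdfContext introduced above; x arrives as a single
        # Array covering all consumed batches (combined in Finalize).
        return pa.scalar(np.nanmean(x))

    pc.register_aggregate_function(
        func,
        "y=avg(x)",
        {"summary": "y=avg(x)", "description": "find mean of x"},
        {"x": pa.float64()},
        pa.float64())

    arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64())
    result = pc.call_function("y=avg(x)", [arr])
    assert result == pa.scalar(30.0)
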
From a2b89c67720973d464d07ee9233860b36a234eac Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 16 May 2023 17:38:27 -0400 Subject: [PATCH 06/19] Support varargs --- .../arrow/engine/substrait/extension_set.cc | 17 ++++++++-- python/pyarrow/conftest.py | 29 +++++++++++++++++ python/pyarrow/src/arrow/python/udf.cc | 3 +- python/pyarrow/tests/test_substrait.py | 31 +++++++++++++++---- python/pyarrow/tests/test_udf.py | 17 ++++++++-- 5 files changed, 85 insertions(+), 12 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index b6b3aa10850..96343118037 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -417,8 +417,21 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate( return compute::Aggregate{std::move(fixed_arrow_func), options ? std::move(options) : nullptr, *arg_ref, ""}; } - default: - break; + default: { + + fixed_arrow_func += arrow_function_name; + std::vector target; + for (int i = 0; i < call.size(); i++) { + ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(i)); + const FieldRef* arg_ref = arg.field_ref(); + if (!arg_ref) { + return Status::Invalid("Expected an aggregate call ", call.id().uri, "#", + call.id().name, " to have a direct reference"); + } + target.emplace_back(*arg_ref); + } + return compute::Aggregate{std::move(fixed_arrow_func), nullptr, target, ""}; + } } return Status::NotImplemented( "Only nullary and unary aggregate functions are currently supported"); diff --git a/python/pyarrow/conftest.py b/python/pyarrow/conftest.py index 1a03ecaf6ea..fdd4fae2ca9 100644 --- a/python/pyarrow/conftest.py +++ b/python/pyarrow/conftest.py @@ -304,3 +304,32 @@ def func(ctx, x): pa.float64() ) return func, func_name + +@pytest.fixture(scope="session") +def varargs_agg_func_fixture(): + """ + Register a unary aggregate function + """ + from pyarrow import compute as pc + import numpy as np + + def func(ctx, *args): + sum = 0.0 + for arg in args: + sum += np.nanmean(arg) + return pa.scalar(sum) + + func_name = "y=sum_mean(x...)" + func_doc = {"summary": "Varargs aggregate", + "description": "Varargs aggregate"} + + pc.register_aggregate_function(func, + func_name, + func_doc, + { + "x": pa.float64(), + "y": pa.float64() + }, + pa.float64() + ) + return func, func_name \ No newline at end of file diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 0a81cdfceb9..2971dfcc3ed 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -393,7 +393,8 @@ Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_ }; RETURN_NOT_OK( - AddAggKernel(compute::KernelSignature::Make(input_types, output_type), + AddAggKernel(compute::KernelSignature::Make( + std::move(input_types), std::move(output_type), options.arity.is_varargs), init, aggregate_func.get())); RETURN_NOT_OK(registry->AddFunction(std::move(aggregate_func))); diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index 6ba38bfed18..b42c038a29f 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -607,10 +607,11 @@ def table_provider(names, schema): assert res_tb == expected -def test_aggregate_udf_basic(unary_agg_func_fixture): +def test_aggregate_udf_basic(varargs_agg_func_fixture): test_table = pa.Table.from_pydict( - {"k": [1, 1, 2, 2], "v": [1.0, 2.0, 3.0, 4.0]} + {"k": 
[1, 1, 2, 2], "v1": [1.0, 2.0, 3.0, 4.0], + "v2": [1.0, 1.0, 1.0, 1.0]} ) def table_provider(names, _): @@ -629,7 +630,7 @@ def table_provider(names, _): "extensionFunction": { "extensionUriReference": 1, "functionAnchor": 1, - "name": "y=avg(x)" + "name": "y=sum_mean(x...)" } } ], @@ -650,8 +651,9 @@ def table_provider(names, _): "read": { "baseSchema": { "names": [ - "time", - "price" + "k", + "v1", + "v2", ], "struct": { "types": [ @@ -660,6 +662,11 @@ def table_provider(names, _): "nullability": "NULLABILITY_REQUIRED" } }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, { "fp64": { "nullability": "NULLABILITY_NULLABLE" @@ -706,6 +713,18 @@ def table_provider(names, _): "rootReference": {} } } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + } } ] } @@ -730,7 +749,7 @@ def table_provider(names, _): expected_tb = pa.Table.from_pydict({ 'k': [1, 2], - 'v_avg': [1.5, 3.5] + 'v_avg': [2.5, 4.5] }) assert res_tb == expected_tb diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 6b71a5b87c5..1dd5ac5e5c8 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -63,7 +63,6 @@ def func(ctx, x): ) return func, func_name - @pytest.fixture(scope="session") def binary_func_fixture(): """ @@ -621,14 +620,26 @@ def test_udt_datasource1_exception(): def test_aggregate_basic(unary_agg_func_fixture): - arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0, 60.0], pa.float64()) + arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64()) result = pc.call_function("y=avg(x)", [arr]) - expected = pa.scalar(35.0) + expected = pa.scalar(30.0) + assert result == expected + + +def test_aggregate_varargs(varargs_agg_func_fixture): + arr1 = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64()) + arr2 = pa.array([1.0, 2.0, 3.0, 4.0, 5.0], pa.float64()) + + result = pc.call_function( + "y=sum_mean(x...)", [arr1, arr2] + ) + expected = pa.scalar(33.0) assert result == expected def test_aggregate_exception(bad_unary_agg_func_fixture): arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64()) + with pytest.raises(RuntimeError, match='Oops'): try: pc.call_function("y=bad_len(x)", [arr]) From c5780574b0624e539a3c5ba8531866ee307318f4 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 17 May 2023 10:58:38 -0400 Subject: [PATCH 07/19] Lint. Revert unneeded changes. 
--- .../arrow/engine/substrait/extension_set.cc | 137 ++++++++---------- python/pyarrow/conftest.py | 3 +- python/pyarrow/tests/test_udf.py | 1 + 3 files changed, 65 insertions(+), 76 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 96343118037..70721f2cb99 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -363,81 +363,6 @@ ExtensionIdRegistry::SubstraitAggregateToArrow kSimpleSubstraitAggregateToArrow return DecodeBasicAggregate(std::string(call.id().name))(call); }; -ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate( - const std::string& arrow_function_name) { - return [arrow_function_name](const SubstraitCall& call) -> Result { - std::string fixed_arrow_func; - if (call.is_hash()) { - fixed_arrow_func = "hash_"; - } - - switch (call.size()) { - case 0: { - if (call.id().name == "count") { - fixed_arrow_func += "count_all"; - return compute::Aggregate{std::move(fixed_arrow_func), ""}; - } - return Status::Invalid("Expected aggregate call ", call.id().uri, "#", - call.id().name, " to have at least one argument"); - } - case 1: { - std::shared_ptr options = nullptr; - if (arrow_function_name == "stddev" || arrow_function_name == "variance") { - // See the following URL for the spec of stddev and variance: - // https://github.com/substrait-io/substrait/blob/ - // 73228b4112d79eb1011af0ebb41753ce23ca180c/ - // extensions/functions_arithmetic.yaml#L1240 - auto maybe_dist = call.GetOption("distribution"); - if (maybe_dist) { - auto& prefs = **maybe_dist; - if (prefs.size() != 1) { - return Status::Invalid("expected a single preference for ", - arrow_function_name, " but got ", prefs.size()); - } - int ddof; - if (prefs[0] == "POPULATION") { - ddof = 1; - } else if (prefs[0] == "SAMPLE") { - ddof = 0; - } else { - return Status::Invalid("unknown distribution preference ", prefs[0]); - } - options = std::make_shared(ddof); - } - } - fixed_arrow_func += arrow_function_name; - - ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(0)); - const FieldRef* arg_ref = arg.field_ref(); - if (!arg_ref) { - return Status::Invalid("Expected an aggregate call ", call.id().uri, "#", - call.id().name, " to have a direct reference"); - } - - return compute::Aggregate{std::move(fixed_arrow_func), - options ? 
std::move(options) : nullptr, *arg_ref, ""}; - } - default: { - - fixed_arrow_func += arrow_function_name; - std::vector target; - for (int i = 0; i < call.size(); i++) { - ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(i)); - const FieldRef* arg_ref = arg.field_ref(); - if (!arg_ref) { - return Status::Invalid("Expected an aggregate call ", call.id().uri, "#", - call.id().name, " to have a direct reference"); - } - target.emplace_back(*arg_ref); - } - return compute::Aggregate{std::move(fixed_arrow_func), nullptr, target, ""}; - } - } - return Status::NotImplemented( - "Only nullary and unary aggregate functions are currently supported"); - }; -} - struct ExtensionIdRegistryImpl : ExtensionIdRegistry { ExtensionIdRegistryImpl() : parent_(nullptr) {} explicit ExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) : parent_(parent) {} @@ -1012,6 +937,68 @@ ExtensionIdRegistry::SubstraitCallToArrow DecodeConcatMapping() { }; } +ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate( + const std::string& arrow_function_name) { + return [arrow_function_name](const SubstraitCall& call) -> Result { + std::string fixed_arrow_func; + if (call.is_hash()) { + fixed_arrow_func = "hash_"; + } + + switch (call.size()) { + case 0: { + if (call.id().name == "count") { + fixed_arrow_func += "count_all"; + return compute::Aggregate{std::move(fixed_arrow_func), ""}; + } + return Status::Invalid("Expected aggregate call ", call.id().uri, "#", + call.id().name, " to have at least one argument"); + } + default: { + // Handles all arity > 0 + + std::shared_ptr options = nullptr; + if (arrow_function_name == "stddev" || arrow_function_name == "variance") { + // See the following URL for the spec of stddev and variance: + // https://github.com/substrait-io/substrait/blob/ + // 73228b4112d79eb1011af0ebb41753ce23ca180c/ + // extensions/functions_arithmetic.yaml#L1240 + auto maybe_dist = call.GetOption("distribution"); + if (maybe_dist) { + auto& prefs = **maybe_dist; + if (prefs.size() != 1) { + return Status::Invalid("expected a single preference for ", + arrow_function_name, " but got ", prefs.size()); + } + int ddof; + if (prefs[0] == "POPULATION") { + ddof = 1; + } else if (prefs[0] == "SAMPLE") { + ddof = 0; + } else { + return Status::Invalid("unknown distribution preference ", prefs[0]); + } + options = std::make_shared(ddof); + } + } + + fixed_arrow_func += arrow_function_name; + std::vector target; + for (int i = 0; i < call.size(); i++) { + ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(i)); + const FieldRef* arg_ref = arg.field_ref(); + if (!arg_ref) { + return Status::Invalid("Expected an aggregate call ", call.id().uri, "#", + call.id().name, " to have a direct reference"); + } + target.emplace_back(*arg_ref); + } + return compute::Aggregate{std::move(fixed_arrow_func), nullptr, target, ""}; + } + } + }; +} + struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { DefaultExtensionIdRegistry() { // ----------- Extension Types ---------------------------- diff --git a/python/pyarrow/conftest.py b/python/pyarrow/conftest.py index fdd4fae2ca9..fb08797c660 100644 --- a/python/pyarrow/conftest.py +++ b/python/pyarrow/conftest.py @@ -305,6 +305,7 @@ def func(ctx, x): ) return func, func_name + @pytest.fixture(scope="session") def varargs_agg_func_fixture(): """ @@ -332,4 +333,4 @@ def func(ctx, *args): }, pa.float64() ) - return func, func_name \ No newline at end of file + return func, func_name diff --git a/python/pyarrow/tests/test_udf.py 
b/python/pyarrow/tests/test_udf.py index 1dd5ac5e5c8..4f32698fb96 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -63,6 +63,7 @@ def func(ctx, x): ) return func, func_name + @pytest.fixture(scope="session") def binary_func_fixture(): """ From ff04234400226dae0e3dbc6571def74766ac61a1 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 17 May 2023 11:45:21 -0400 Subject: [PATCH 08/19] Second round of self review --- cpp/src/arrow/engine/substrait/extension_set.cc | 5 +++-- python/pyarrow/_compute.pyx | 6 +++--- python/pyarrow/conftest.py | 2 +- python/pyarrow/tests/test_substrait.py | 4 ++-- python/pyarrow/tests/test_udf.py | 2 +- 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 70721f2cb99..008818c9ce3 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -991,9 +991,10 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate( return Status::Invalid("Expected an aggregate call ", call.id().uri, "#", call.id().name, " to have a direct reference"); } - target.emplace_back(*arg_ref); + target.emplace_back(std::move(*arg_ref)); } - return compute::Aggregate{std::move(fixed_arrow_func), nullptr, target, ""}; + return compute::Aggregate{std::move(fixed_arrow_func), + options ? std::move(options) : nullptr, target, ""}; } } }; diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 6e62d8501ee..c303e333c16 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2771,10 +2771,10 @@ def register_aggregate_function(func, function_name, function_doc, in_types, out Then, it must take arguments equal to the number of in_types defined. It must return Scalar matching the out_type. - To define a varargs function, pass a callable that takes - varargs. The last in_type will be the type of all varargs - arguments. + varargs. The in_type needs to match in type of inputs when + the function gets called. + function_name : str Name of the function. This name must be globally unique. 
function_doc : dict diff --git a/python/pyarrow/conftest.py b/python/pyarrow/conftest.py index fb08797c660..f32cbf01efc 100644 --- a/python/pyarrow/conftest.py +++ b/python/pyarrow/conftest.py @@ -328,7 +328,7 @@ def func(ctx, *args): func_name, func_doc, { - "x": pa.float64(), + "x": pa.int64(), "y": pa.float64() }, pa.float64() diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index b42c038a29f..34faaa157af 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -610,7 +610,7 @@ def table_provider(names, schema): def test_aggregate_udf_basic(varargs_agg_func_fixture): test_table = pa.Table.from_pydict( - {"k": [1, 1, 2, 2], "v1": [1.0, 2.0, 3.0, 4.0], + {"k": [1, 1, 2, 2], "v1": [1, 2, 3, 4], "v2": [1.0, 1.0, 1.0, 1.0]} ) @@ -663,7 +663,7 @@ def table_provider(names, _): } }, { - "fp64": { + "i64": { "nullability": "NULLABILITY_NULLABLE" } }, diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 4f32698fb96..67af4dd6664 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -628,7 +628,7 @@ def test_aggregate_basic(unary_agg_func_fixture): def test_aggregate_varargs(varargs_agg_func_fixture): - arr1 = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64()) + arr1 = pa.array([10, 20, 30, 40, 50], pa.int64()) arr2 = pa.array([1.0, 2.0, 3.0, 4.0, 5.0], pa.float64()) result = pc.call_function( From b1d51f7a4b21c09e5eeb7fb6faa1903103401aee Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 31 May 2023 15:22:16 -0400 Subject: [PATCH 09/19] Apply suggestions from code review Co-authored-by: Weston Pace --- python/pyarrow/_compute.pyx | 2 +- python/pyarrow/src/arrow/python/udf.cc | 12 ++++-------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index c303e333c16..77967c5a3da 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2769,7 +2769,7 @@ def register_aggregate_function(func, function_name, function_doc, in_types, out The first argument is the context argument of type UdfContext. Then, it must take arguments equal to the number of - in_types defined. It must return Scalar matching the + in_types defined. It must return a Scalar matching the out_type. To define a varargs function, pass a callable that takes varargs. The in_type needs to match in type of inputs when diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 2971dfcc3ed..58a0c3f5b6c 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-#include - #include "arrow/python/udf.h" #include "arrow/table.h" #include "arrow/compute/api_aggregate.h" @@ -149,7 +147,7 @@ struct PythonTableUdfKernelInit { Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) { ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool())); - values.push_back(rb); + values.push_back(std::move(rb)); return Status::OK(); } @@ -185,7 +183,7 @@ struct PythonTableUdfKernelInit { UdfContext udf_context{ctx->memory_pool(), table->num_rows()}; for (int arg_id = 0; arg_id < num_args; arg_id++) { - // Since we combined chunks thComere is only one chunk + // Since we combined chunks there is only one chunk std::shared_ptr c_data = table->column(arg_id)->chunk(0); PyObject* data = wrap_array(c_data); PyTuple_SetItem(arg_tuple.obj(), arg_id, data); @@ -204,11 +202,9 @@ struct PythonTableUdfKernelInit { } out->value = std::move(val); return Status::OK(); - } else { - return Status::TypeError("Unexpected output type: ", Py_TYPE(result.obj())->tp_name, - " (expected Scalar)"); } - return Status::OK(); + return Status::TypeError("Unexpected output type: ", Py_TYPE(result.obj())->tp_name, + " (expected Scalar)"); } UdfWrapperCallback agg_cb; From 3daefeaca75af4a945adcd08ce9d6c77dac58cd6 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 1 Jun 2023 14:26:12 -0400 Subject: [PATCH 10/19] wip --- .../arrow/engine/substrait/extension_set.cc | 6 +-- python/pyarrow/_compute.pyx | 37 ++++++++++-------- python/pyarrow/src/arrow/python/udf.cc | 39 +++++++++++-------- python/pyarrow/tests/test_udf.py | 10 +++-- 4 files changed, 53 insertions(+), 39 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 008818c9ce3..8f8df5ae113 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -981,12 +981,12 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate( options = std::make_shared(ddof); } } - fixed_arrow_func += arrow_function_name; + std::vector target; for (int i = 0; i < call.size(); i++) { ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(i)); - const FieldRef* arg_ref = arg.field_ref(); + FieldRef* arg_ref = arg.field_ref(); if (!arg_ref) { return Status::Invalid("Expected an aggregate call ", call.id().uri, "#", call.id().name, " to have a direct reference"); @@ -994,7 +994,7 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate( target.emplace_back(std::move(*arg_ref)); } return compute::Aggregate{std::move(fixed_arrow_func), - options ? std::move(options) : nullptr, target, ""}; + options ? std::move(options) : nullptr, std::move(target), ""}; } } }; diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 77967c5a3da..5e02ba33854 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2697,10 +2697,11 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty all arguments are scalar, else it must return an Array. To define a varargs function, pass a callable that takes - varargs. The last in_type will be the type of all varargs + *args. The last in_type will be the type of all varargs arguments. function_name : str - Name of the function. This name must be globally unique. + Name of the function. There should only be one function + registered with this name in the function registry. 
    function_doc : dict
        A dictionary object with keys "summary" (str),
        and "description" (str).
@@ -2759,8 +2760,8 @@ def register_aggregate_function(func, function_name, function_doc, in_types, out
     In other words, non-decomposable aggregate function cannot be
     split into consume/merge/finalize steps.
 
-    This is mostly useful with segemented aggregation, where the data
-    to be aggregated is continuous.
+    This is often used with ordered or segmented aggregation where groups
+    can be emitted before accumulating all of the input data.
 
     Parameters
     ----------
@@ -2772,11 +2773,13 @@ def register_aggregate_function(func, function_name, function_doc, in_types, out
         in_types defined. It must return a Scalar matching the
         out_type.
         To define a varargs function, pass a callable that takes
-        *args. The in_type needs to match in type of inputs when
+        *args. The in_type needs to match the type of the inputs when
         the function gets called.
     function_name : str
-        Name of the function. This name must be globally unique.
+        Name of the function. This name must be unique, i.e.,
+        there should only be one function registered with
+        this name in the function registry.
     function_doc : dict
         A dictionary object with keys "summary" (str),
         and "description" (str).
@@ -2799,21 +2802,21 @@ def register_aggregate_function(func, function_name, function_doc, in_types, out
     >>> import pyarrow.compute as pc
     >>>
     >>> func_doc = {}
-    >>> func_doc["summary"] = "simple mean udf"
-    >>> func_doc["description"] = "compute mean"
+    >>> func_doc["summary"] = "simple median udf"
+    >>> func_doc["description"] = "compute median"
     >>>
-    >>> def compute_mean(ctx, array):
-    ...     return pa.scalar(np.nanmean(array))
+    >>> def compute_median(ctx, array):
+    ...     return pa.scalar(np.median(array))
     >>>
-    >>> func_name = "py_compute_mean"
+    >>> func_name = "py_compute_median"
     >>> in_types = {"array": pa.int64()}
     >>> out_type = pa.float64()
-    >>> pc.register_aggregate_function(compute_mean, func_name, func_doc,
+    >>> pc.register_aggregate_function(compute_median, func_name, func_doc,
     ...                                in_types, out_type)
     >>>
     >>> func = pc.get_function(func_name)
     >>> func.name
-    'py_compute_mean'
+    'py_compute_median'
     >>> answer = pc.call_function(func_name, [pa.array([20, 40])])
     >>> answer
@@ -2843,7 +2846,8 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t
         returns on each invocation a StructArray matching the
         out_type, where an empty array indicates end.
     function_name : str
-        Name of the function. This name must be globally unique.
+        Name of the function. There should only be one function
+        registered with this name in the function registry.
     function_doc : dict
         A dictionary object with keys "summary" (str),
         and "description" (str).
@@ -2873,7 +2877,7 @@ def _register_user_defined_function(register_func, func, function_name, function
     """
     Register a user-defined function.
 
-    This method itself doesn't care what the type of the UDF
+    This method itself doesn't care about the type of the UDF
     (i.e., scalar vs tabular vs aggregate)
 
     Parameters
     ----------
@@ -2883,7 +2887,8 @@ def _register_user_defined_function(register_func, func, function_name, function
     func : callable
         A callable implementing the user-defined function.
     function_name : str
-        Name of the function. This name must be globally unique.
+        Name of the function. There should only be one function
+        registered with this name in the function registry.
     function_doc : dict
         A dictionary object with keys "summary" (str),
         and "description" (str).
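Written out as a self-contained script, the median example from the docstring above looks like this (the doctest context assumes numpy is already imported; here the import is made explicit):

    import numpy as np
    import pyarrow as pa
    import pyarrow.compute as pc

    func_doc = {"summary": "simple median udf",
                "description": "compute median"}

    def compute_median(ctx, array):
        # Called once with the full input column; must return a Scalar
        # of the declared out_type.
        return pa.scalar(np.median(array))

    pc.register_aggregate_function(compute_median, "py_compute_median",
                                   func_doc, {"array": pa.int64()},
                                   pa.float64())

    assert pc.call_function("py_compute_median",
                            [pa.array([20, 40])]) == pa.scalar(30.0)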
diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 58a0c3f5b6c..7aaeb1f16d5 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#include + #include "arrow/python/udf.h" #include "arrow/table.h" #include "arrow/compute/api_aggregate.h" @@ -83,8 +85,7 @@ arrow::Status AggregateUdfMerge(compute::KernelContext* ctx, compute::KernelStat } arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* out) { - auto udf = checked_cast(ctx->state()); - return SafeCallIntoPython([&]() -> Status {return udf->Finalize(ctx, out);}); + return checked_cast(ctx->state())->Finalize(ctx, out); } struct PythonTableUdfKernelInit { @@ -162,9 +163,6 @@ struct PythonTableUdfKernelInit { std::shared_ptr& function = state->agg_function; const int num_args = input_schema->num_fields(); - OwnedRef arg_tuple(PyTuple_New(num_args)); - RETURN_NOT_OK(CheckPyError()); - // Note: The way that batches are concatenated together // would result in using double amount of the memory. // This is OK for now because non decomposable aggregate @@ -180,21 +178,30 @@ struct PythonTableUdfKernelInit { ARROW_ASSIGN_OR_RAISE( table, table->CombineChunks(ctx->memory_pool()) ); - UdfContext udf_context{ctx->memory_pool(), table->num_rows()}; - for (int arg_id = 0; arg_id < num_args; arg_id++) { - // Since we combined chunks there is only one chunk - std::shared_ptr c_data = table->column(arg_id)->chunk(0); - PyObject* data = wrap_array(c_data); - PyTuple_SetItem(arg_tuple.obj(), arg_id, data); + + if (table->num_rows() == 0) { + return Status::Invalid("Finalized is called with empty inputs"); } - OwnedRef result(agg_cb(function->obj(), udf_context, arg_tuple.obj())); - RETURN_NOT_OK(CheckPyError()); + std::unique_ptr result; + RETURN_NOT_OK(SafeCallIntoPython([&] { + OwnedRef arg_tuple(PyTuple_New(num_args)); + RETURN_NOT_OK(CheckPyError()); + for (int arg_id = 0; arg_id < num_args; arg_id++) { + // Since we combined chunks there is only one chunk + std::shared_ptr c_data = table->column(arg_id)->chunk(0); + PyObject* data = wrap_array(c_data); + PyTuple_SetItem(arg_tuple.obj(), arg_id, data); + } + result = std::make_unique(agg_cb(function->obj(), udf_context, arg_tuple.obj())); + RETURN_NOT_OK(CheckPyError()); + return Status::OK(); + })); // unwrapping the output for expected output type - if (is_scalar(result.obj())) { - ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_scalar(result.obj())); + if (is_scalar(result->obj())) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_scalar(result->obj())); if (*output_type != *val->type) { return Status::TypeError("Expected output datatype ", output_type->ToString(), ", but function returned datatype ", @@ -203,7 +210,7 @@ struct PythonTableUdfKernelInit { out->value = std::move(val); return Status::OK(); } - return Status::TypeError("Unexpected output type: ", Py_TYPE(result.obj())->tp_name, + return Status::TypeError("Unexpected output type: ", Py_TYPE(result->obj())->tp_name, " (expected Scalar)"); } diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 67af4dd6664..8d46c24823e 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -626,6 +626,11 @@ def test_aggregate_basic(unary_agg_func_fixture): expected = pa.scalar(30.0) assert result == expected +def test_aggregate_empty(unary_agg_func_fixture): + arr = 
pa.array([], pa.float64()) + + with pytest.raises(RuntimeError, match='.*empty inputs.*'): + pc.call_function("y=avg(x)", [arr]) def test_aggregate_varargs(varargs_agg_func_fixture): arr1 = pa.array([10, 20, 30, 40, 50], pa.int64()) @@ -642,7 +647,4 @@ def test_aggregate_exception(bad_unary_agg_func_fixture): arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64()) with pytest.raises(RuntimeError, match='Oops'): - try: - pc.call_function("y=bad_len(x)", [arr]) - except Exception as e: - raise e + pc.call_function("y=bad_len(x)", [arr]) From 8fd8c9601b20742c5c7a4c1b01316ded5f5e37a9 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 5 Jun 2023 14:48:05 -0400 Subject: [PATCH 11/19] Address PR comments --- .../arrow/engine/substrait/extension_set.cc | 8 +- python/pyarrow/src/arrow/python/udf.cc | 12 ++- python/pyarrow/tests/test_udf.py | 82 +++++++++++++++---- 3 files changed, 79 insertions(+), 23 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 8f8df5ae113..d89248383b7 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -986,15 +986,17 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate( std::vector target; for (int i = 0; i < call.size(); i++) { ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(i)); - FieldRef* arg_ref = arg.field_ref(); + const FieldRef* arg_ref = arg.field_ref(); if (!arg_ref) { return Status::Invalid("Expected an aggregate call ", call.id().uri, "#", call.id().name, " to have a direct reference"); } - target.emplace_back(std::move(*arg_ref)); + // Copy arg_ref here because field_ref() return const FieldRef* + target.emplace_back(*arg_ref); } return compute::Aggregate{std::move(fixed_arrow_func), - options ? std::move(options) : nullptr, std::move(target), ""}; + options ? 
std::move(options) : nullptr, + std::move(target), ""}; } } }; diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 7aaeb1f16d5..edc49c18051 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -135,9 +135,9 @@ struct PythonTableUdfKernelInit { output_type(output_type) { std::vector> fields; for (size_t i = 0; i < input_types.size(); i++) { - fields.push_back(field("", input_types[i])); + fields.push_back(std::move(field("", input_types[i]))); } - input_schema = schema(fields); + input_schema = schema(std::move(fields)); }; ~PythonUdfScalarAggregatorImpl() { @@ -153,8 +153,12 @@ struct PythonTableUdfKernelInit { } Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) { - const auto& other_state = checked_cast(src); - values.insert(values.end(), other_state.values.begin(), other_state.values.end()); + auto& other_values = checked_cast(src).values; + values.insert(values.end(), + std::make_move_iterator(other_values.begin()), + std::make_move_iterator(other_values.end())); + + other_values.erase(other_values.begin(), other_values.end()); return Status::OK(); } diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 8d46c24823e..c0cfd3d26e8 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -24,6 +24,9 @@ # UDFs are all tested with a dataset scan pytestmark = pytest.mark.dataset +# For convience, most of the test here doesn't care about udf func docs +empty_udf_doc = {"summary": "", "description": ""} + try: import pyarrow.dataset as ds except ImportError: @@ -40,18 +43,51 @@ class MyError(RuntimeError): @pytest.fixture(scope="session") -def bad_unary_agg_func_fixture(): - """ - Register a unary aggregate function - """ - +def exception_agg_func_fixture(): def func(ctx, x): raise RuntimeError("Oops") return pa.scalar(len(x)) - func_name = "y=bad_len(x)" - func_doc = {"summary": "y=bad_len(x)", - "description": "find length of"} + func_name = "y=exception_len(x)" + func_doc = empty_udf_doc + + pc.register_aggregate_function(func, + func_name, + func_doc, + { + "x": pa.int64(), + }, + pa.int64() + ) + return func, func_name + + +@pytest.fixture(scope="session") +def wrong_output_dtype_agg_func_fixture(scope="session"): + def func(ctx, x): + return pa.scalar(len(x), pa.int32()) + + func_name = "y=wrong_output_dtype(x)" + func_doc = empty_udf_doc + + pc.register_aggregate_function(func, + func_name, + func_doc, + { + "x": pa.int64(), + }, + pa.int64() + ) + return func, func_name + + +@pytest.fixture(scope="session") +def wrong_output_type_agg_func_fixture(scope="session"): + def func(ctx, x): + return len(x) + + func_name = "y=wrong_output_type(x)" + func_doc = empty_udf_doc pc.register_aggregate_function(func, func_name, @@ -620,19 +656,33 @@ def test_udt_datasource1_exception(): _test_datasource1_udt(datasource1_exception) -def test_aggregate_basic(unary_agg_func_fixture): +def test_agg_basic(unary_agg_func_fixture): arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64()) result = pc.call_function("y=avg(x)", [arr]) expected = pa.scalar(30.0) assert result == expected -def test_aggregate_empty(unary_agg_func_fixture): - arr = pa.array([], pa.float64()) - with pytest.raises(RuntimeError, match='.*empty inputs.*'): - pc.call_function("y=avg(x)", [arr]) +def test_agg_empty(unary_agg_func_fixture): + empty = pa.array([], pa.float64()) + + with pytest.raises(pa.ArrowInvalid, match='empty inputs'): + 
pc.call_function("y=avg(x)", [empty]) + + +def test_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture): + arr = pa.array([10, 20, 30, 40, 50], pa.int64()) + with pytest.raises(pa.ArrowTypeError, match="output datatype"): + pc.call_function("y=wrong_output_dtype(x)", [arr]) + + +def test_agg_wrong_output_type(wrong_output_type_agg_func_fixture): + arr = pa.array([10, 20, 30, 40, 50], pa.int64()) + with pytest.raises(pa.ArrowTypeError, match="output type"): + pc.call_function("y=wrong_output_type(x)", [arr]) + -def test_aggregate_varargs(varargs_agg_func_fixture): +def test_agg_varargs(varargs_agg_func_fixture): arr1 = pa.array([10, 20, 30, 40, 50], pa.int64()) arr2 = pa.array([1.0, 2.0, 3.0, 4.0, 5.0], pa.float64()) @@ -643,8 +693,8 @@ def test_aggregate_varargs(varargs_agg_func_fixture): assert result == expected -def test_aggregate_exception(bad_unary_agg_func_fixture): +def test_agg_exception(exception_agg_func_fixture): arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64()) with pytest.raises(RuntimeError, match='Oops'): - pc.call_function("y=bad_len(x)", [arr]) + pc.call_function("y=exception_len(x)", [arr]) From 8381f082b73637d8dec69b984100930fc155bf61 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 5 Jun 2023 15:48:18 -0400 Subject: [PATCH 12/19] Minor updates --- python/pyarrow/_compute.pyx | 12 ++++++++---- python/pyarrow/src/arrow/python/udf.cc | 4 ++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 5e02ba33854..34497af000f 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2676,6 +2676,8 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty """ Register a user-defined scalar function. + This API is EXPERIMENTAL. + A scalar function is a function that executes elementwise operations on arrays or scalars, i.e. a scalar function must be computed row-by-row with no state where each output row @@ -2752,8 +2754,9 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty def register_aggregate_function(func, function_name, function_doc, in_types, out_type, func_registry=None): - """ - Register a user-defined non-decomposable aggregate function. + """Register a user-defined non-decomposable aggregate function. + + This API is EXPERIMENTAL. A non-decomposable aggregation function is a function that executes aggregate operations on the whole data that it is aggregating. @@ -2828,8 +2831,9 @@ def register_aggregate_function(func, function_name, function_doc, in_types, out def register_tabular_function(func, function_name, function_doc, in_types, out_type, func_registry=None): - """ - Register a user-defined tabular function. + """Register a user-defined tabular function. + + This API is EXPERIMENTAL. A tabular function is one accepting a context argument of type UdfContext and returning a generator of struct arrays. 
diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index edc49c18051..efa1cc00b36 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -381,16 +381,16 @@ Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_ auto aggregate_func = std::make_shared( options.func_name, options.arity, options.func_doc, &default_scalar_aggregate_options); - Py_INCREF(agg_function); std::vector input_types; for (const auto& in_dtype : options.input_types) { input_types.emplace_back(in_dtype); } compute::OutputType output_type(options.output_type); - auto init = [agg_wrapper, agg_function, options]( + compute::KernelInit init = [agg_wrapper, agg_function, options]( compute::KernelContext* ctx, const compute::KernelInitArgs& args) -> Result> { + // Py_INCREF because OwnedRefNoGIL will call Py_XDECREF in destructor Py_INCREF(agg_function); return std::make_unique( agg_wrapper, From 9d7fd9d49f75e7355aceaab5fdfc990aca49b6a1 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 5 Jun 2023 16:16:50 -0400 Subject: [PATCH 13/19] Try reverting Py ref count change --- python/pyarrow/src/arrow/python/udf.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index efa1cc00b36..32f020587f6 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -377,6 +377,8 @@ Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_ registry = compute::GetFunctionRegistry(); } + Py_INCREF(agg_function); + static auto default_scalar_aggregate_options = compute::ScalarAggregateOptions::Defaults(); auto aggregate_func = std::make_shared( options.func_name, options.arity, options.func_doc, &default_scalar_aggregate_options); From dc1d7340187dd3beb9e630a866fb5f7ea853cd8e Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 5 Jun 2023 16:25:13 -0400 Subject: [PATCH 14/19] Fix lint (clang-format) --- python/pyarrow/src/arrow/python/udf.cc | 220 ++++++++++++------------- python/pyarrow/src/arrow/python/udf.h | 12 +- 2 files changed, 115 insertions(+), 117 deletions(-) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 32f020587f6..a856f0e398b 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -17,16 +17,16 @@ #include -#include "arrow/python/udf.h" -#include "arrow/table.h" #include "arrow/compute/api_aggregate.h" #include "arrow/compute/function.h" #include "arrow/compute/kernel.h" #include "arrow/python/common.h" +#include "arrow/python/udf.h" +#include "arrow/table.h" #include "arrow/util/checked_cast.h" namespace arrow { - using internal::checked_cast; +using internal::checked_cast; namespace py { namespace { @@ -75,7 +75,8 @@ struct ScalarUdfAggregator : public compute::KernelState { virtual Status Finalize(compute::KernelContext* ctx, Datum* out) = 0; }; -arrow::Status AggregateUdfConsume(compute::KernelContext* ctx, const compute::ExecSpan& batch) { +arrow::Status AggregateUdfConsume(compute::KernelContext* ctx, + const compute::ExecSpan& batch) { return checked_cast(ctx->state())->Consume(ctx, batch); } @@ -124,106 +125,101 @@ struct PythonTableUdfKernelInit { UdfWrapperCallback cb; }; - struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { - - PythonUdfScalarAggregatorImpl(UdfWrapperCallback agg_cb, - std::shared_ptr agg_function, - std::vector> input_types, - std::shared_ptr 
output_type): - agg_cb(agg_cb), - agg_function(agg_function), - output_type(output_type) { - std::vector> fields; - for (size_t i = 0; i < input_types.size(); i++) { - fields.push_back(std::move(field("", input_types[i]))); - } - input_schema = schema(std::move(fields)); - }; - - ~PythonUdfScalarAggregatorImpl() { - if (_Py_IsFinalizing()) { - agg_function->detach(); - } +struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { + PythonUdfScalarAggregatorImpl(UdfWrapperCallback agg_cb, + std::shared_ptr agg_function, + std::vector> input_types, + std::shared_ptr output_type) + : agg_cb(agg_cb), agg_function(agg_function), output_type(output_type) { + std::vector> fields; + for (size_t i = 0; i < input_types.size(); i++) { + fields.push_back(std::move(field("", input_types[i]))); } + input_schema = schema(std::move(fields)); + }; - Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) { - ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool())); - values.push_back(std::move(rb)); - return Status::OK(); + ~PythonUdfScalarAggregatorImpl() { + if (_Py_IsFinalizing()) { + agg_function->detach(); } + } - Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) { - auto& other_values = checked_cast(src).values; - values.insert(values.end(), - std::make_move_iterator(other_values.begin()), - std::make_move_iterator(other_values.end())); + Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) { + ARROW_ASSIGN_OR_RAISE( + auto rb, batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool())); + values.push_back(std::move(rb)); + return Status::OK(); + } - other_values.erase(other_values.begin(), other_values.end()); - return Status::OK(); + Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) { + auto& other_values = checked_cast(src).values; + values.insert(values.end(), std::make_move_iterator(other_values.begin()), + std::make_move_iterator(other_values.end())); + + other_values.erase(other_values.begin(), other_values.end()); + return Status::OK(); + } + + Status Finalize(compute::KernelContext* ctx, Datum* out) { + auto state = + arrow::internal::checked_cast(ctx->state()); + std::shared_ptr& function = state->agg_function; + const int num_args = input_schema->num_fields(); + + // Note: The way that batches are concatenated together + // would result in using double amount of the memory. + // This is OK for now because non decomposable aggregate + // UDF is supposed to be used with segmented aggregation + // where the size of the segment is more or less constant + // so doubling that is not a big deal. This can be also + // improved in the future to use more efficient way to + // concatenate. + ARROW_ASSIGN_OR_RAISE(auto table, + arrow::Table::FromRecordBatches(input_schema, values)); + ARROW_ASSIGN_OR_RAISE(table, table->CombineChunks(ctx->memory_pool())); + UdfContext udf_context{ctx->memory_pool(), table->num_rows()}; + + if (table->num_rows() == 0) { + return Status::Invalid("Finalized is called with empty inputs"); } - Status Finalize(compute::KernelContext* ctx, Datum* out) { - auto state = arrow::internal::checked_cast(ctx->state()); - std::shared_ptr& function = state->agg_function; - const int num_args = input_schema->num_fields(); - - // Note: The way that batches are concatenated together - // would result in using double amount of the memory. 
- // This is OK for now because non decomposable aggregate - // UDF is supposed to be used with segmented aggregation - // where the size of the segment is more or less constant - // so doubling that is not a big deal. This can be also - // improved in the future to use more efficient way to - // concatenate. - ARROW_ASSIGN_OR_RAISE( - auto table, - arrow::Table::FromRecordBatches(input_schema, values) - ); - ARROW_ASSIGN_OR_RAISE( - table, table->CombineChunks(ctx->memory_pool()) - ); - UdfContext udf_context{ctx->memory_pool(), table->num_rows()}; - - if (table->num_rows() == 0) { - return Status::Invalid("Finalized is called with empty inputs"); - } + std::unique_ptr result; + RETURN_NOT_OK(SafeCallIntoPython([&] { + OwnedRef arg_tuple(PyTuple_New(num_args)); + RETURN_NOT_OK(CheckPyError()); - std::unique_ptr result; - RETURN_NOT_OK(SafeCallIntoPython([&] { - OwnedRef arg_tuple(PyTuple_New(num_args)); - RETURN_NOT_OK(CheckPyError()); - - for (int arg_id = 0; arg_id < num_args; arg_id++) { - // Since we combined chunks there is only one chunk - std::shared_ptr c_data = table->column(arg_id)->chunk(0); - PyObject* data = wrap_array(c_data); - PyTuple_SetItem(arg_tuple.obj(), arg_id, data); - } - result = std::make_unique(agg_cb(function->obj(), udf_context, arg_tuple.obj())); - RETURN_NOT_OK(CheckPyError()); - return Status::OK(); - })); - // unwrapping the output for expected output type - if (is_scalar(result->obj())) { - ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_scalar(result->obj())); - if (*output_type != *val->type) { - return Status::TypeError("Expected output datatype ", output_type->ToString(), - ", but function returned datatype ", - val->type->ToString()); - } - out->value = std::move(val); - return Status::OK(); + for (int arg_id = 0; arg_id < num_args; arg_id++) { + // Since we combined chunks there is only one chunk + std::shared_ptr c_data = table->column(arg_id)->chunk(0); + PyObject* data = wrap_array(c_data); + PyTuple_SetItem(arg_tuple.obj(), arg_id, data); + } + result = std::make_unique( + agg_cb(function->obj(), udf_context, arg_tuple.obj())); + RETURN_NOT_OK(CheckPyError()); + return Status::OK(); + })); + // unwrapping the output for expected output type + if (is_scalar(result->obj())) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_scalar(result->obj())); + if (*output_type != *val->type) { + return Status::TypeError("Expected output datatype ", output_type->ToString(), + ", but function returned datatype ", + val->type->ToString()); } - return Status::TypeError("Unexpected output type: ", Py_TYPE(result->obj())->tp_name, - " (expected Scalar)"); + out->value = std::move(val); + return Status::OK(); } + return Status::TypeError("Unexpected output type: ", Py_TYPE(result->obj())->tp_name, + " (expected Scalar)"); + } - UdfWrapperCallback agg_cb; - std::vector> values; - std::shared_ptr agg_function; - std::shared_ptr input_schema; - std::shared_ptr output_type; - }; + UdfWrapperCallback agg_cb; + std::vector> values; + std::shared_ptr agg_function; + std::shared_ptr input_schema; + std::shared_ptr output_type; +}; struct PythonUdf : public PythonUdfKernelState { PythonUdf(std::shared_ptr function, UdfWrapperCallback cb, @@ -358,17 +354,18 @@ Status RegisterTabularFunction(PyObject* user_function, UdfWrapperCallback wrapp wrapper, options, registry); } -Status AddAggKernel(std::shared_ptr sig, compute::KernelInit init, - compute::ScalarAggregateFunction* func) { - - compute::ScalarAggregateKernel kernel(std::move(sig), std::move(init), 
AggregateUdfConsume, AggregateUdfMerge, AggregateUdfFinalize, /*ordered=*/false); +Status AddAggKernel(std::shared_ptr sig, + compute::KernelInit init, compute::ScalarAggregateFunction* func) { + compute::ScalarAggregateKernel kernel(std::move(sig), std::move(init), + AggregateUdfConsume, AggregateUdfMerge, + AggregateUdfFinalize, /*ordered=*/false); RETURN_NOT_OK(func->AddKernel(std::move(kernel))); return Status::OK(); } Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_wrapper, - const UdfOptions& options, - compute::FunctionRegistry* registry) { + const UdfOptions& options, + compute::FunctionRegistry* registry) { if (!PyCallable_Check(agg_function)) { return Status::TypeError("Expected a callable Python object."); } @@ -379,9 +376,11 @@ Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_ Py_INCREF(agg_function); - static auto default_scalar_aggregate_options = compute::ScalarAggregateOptions::Defaults(); + static auto default_scalar_aggregate_options = + compute::ScalarAggregateOptions::Defaults(); auto aggregate_func = std::make_shared( - options.func_name, options.arity, options.func_doc, &default_scalar_aggregate_options); + options.func_name, options.arity, options.func_doc, + &default_scalar_aggregate_options); std::vector input_types; for (const auto& in_dtype : options.input_types) { @@ -390,21 +389,20 @@ Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_ compute::OutputType output_type(options.output_type); compute::KernelInit init = [agg_wrapper, agg_function, options]( - compute::KernelContext* ctx, - const compute::KernelInitArgs& args) -> Result> { + compute::KernelContext* ctx, + const compute::KernelInitArgs& args) + -> Result> { // Py_INCREF because OwnedRefNoGIL will call Py_XDECREF in destructor Py_INCREF(agg_function); return std::make_unique( - agg_wrapper, - std::make_shared(agg_function), - options.input_types, - options.output_type); + agg_wrapper, std::make_shared(agg_function), options.input_types, + options.output_type); }; - RETURN_NOT_OK( - AddAggKernel(compute::KernelSignature::Make( - std::move(input_types), std::move(output_type), options.arity.is_varargs), - init, aggregate_func.get())); + RETURN_NOT_OK(AddAggKernel( + compute::KernelSignature::Make(std::move(input_types), std::move(output_type), + options.arity.is_varargs), + init, aggregate_func.get())); RETURN_NOT_OK(registry->AddFunction(std::move(aggregate_func))); return Status::OK(); diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index cc2f3ab62f5..682cbb2ffe8 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -54,18 +54,18 @@ using UdfWrapperCallback = std::function> ARROW_PYTHON_EXPORT CallTabularFunction(const std::string& func_name, const std::vector& args, From 1203346cb592aa01b22d1e4ca5351f26e6410fbf Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 6 Jun 2023 11:39:34 -0400 Subject: [PATCH 15/19] Fix core-dump when running with Python dev mode --- python/pyarrow/src/arrow/python/udf.cc | 36 ++++++++++++++------------ 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index a856f0e398b..8cc5fb659ab 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -131,6 +131,7 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { std::vector> input_types, std::shared_ptr 
output_type) : agg_cb(agg_cb), agg_function(agg_function), output_type(output_type) { + Py_INCREF(agg_function->obj()); std::vector> fields; for (size_t i = 0; i < input_types.size(); i++) { fields.push_back(std::move(field("", input_types[i]))); @@ -183,8 +184,8 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { return Status::Invalid("Finalized is called with empty inputs"); } - std::unique_ptr result; RETURN_NOT_OK(SafeCallIntoPython([&] { + std::unique_ptr result; OwnedRef arg_tuple(PyTuple_New(num_args)); RETURN_NOT_OK(CheckPyError()); @@ -197,21 +198,21 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { result = std::make_unique( agg_cb(function->obj(), udf_context, arg_tuple.obj())); RETURN_NOT_OK(CheckPyError()); - return Status::OK(); - })); - // unwrapping the output for expected output type - if (is_scalar(result->obj())) { - ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_scalar(result->obj())); - if (*output_type != *val->type) { - return Status::TypeError("Expected output datatype ", output_type->ToString(), - ", but function returned datatype ", - val->type->ToString()); + // unwrapping the output for expected output type + if (is_scalar(result->obj())) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_scalar(result->obj())); + if (*output_type != *val->type) { + return Status::TypeError("Expected output datatype ", output_type->ToString(), + ", but function returned datatype ", + val->type->ToString()); + } + out->value = std::move(val); + return Status::OK(); } - out->value = std::move(val); - return Status::OK(); - } - return Status::TypeError("Unexpected output type: ", Py_TYPE(result->obj())->tp_name, - " (expected Scalar)"); + return Status::TypeError("Unexpected output type: ", Py_TYPE(result->obj())->tp_name, + " (expected Scalar)"); + })); + return Status::OK(); } UdfWrapperCallback agg_cb; @@ -374,6 +375,9 @@ Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_ registry = compute::GetFunctionRegistry(); } + // Py_INCREF here so that once a function is registered + // its refcount gets increased by 1 and doesn't get gced + // if all existing refs are gone Py_INCREF(agg_function); static auto default_scalar_aggregate_options = @@ -392,8 +396,6 @@ Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_ compute::KernelContext* ctx, const compute::KernelInitArgs& args) -> Result> { - // Py_INCREF because OwnedRefNoGIL will call Py_XDECREF in destructor - Py_INCREF(agg_function); return std::make_unique( agg_wrapper, std::make_shared(agg_function), options.input_types, options.output_type); From 84c1e9190cc2a4d9a07555b32251ef21feb5e98e Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 6 Jun 2023 14:03:06 -0400 Subject: [PATCH 16/19] Lint --- python/pyarrow/src/arrow/python/udf.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 8cc5fb659ab..7ab8c4da4ea 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -209,8 +209,8 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { out->value = std::move(val); return Status::OK(); } - return Status::TypeError("Unexpected output type: ", Py_TYPE(result->obj())->tp_name, - " (expected Scalar)"); + return Status::TypeError("Unexpected output type: ", + Py_TYPE(result->obj())->tp_name, " (expected Scalar)"); })); return Status::OK(); } From 
17ff274e8efca6e2f772232d9fd60c929a526d7c Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 6 Jun 2023 17:53:15 -0400 Subject: [PATCH 17/19] Try fixing numpydoc lint --- python/pyarrow/_compute.pyx | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 34497af000f..eaf9d1dfb65 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2754,7 +2754,8 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty def register_aggregate_function(func, function_name, function_doc, in_types, out_type, func_registry=None): - """Register a user-defined non-decomposable aggregate function. + """ + Register a user-defined non-decomposable aggregate function. This API is EXPERIMENTAL. @@ -2778,7 +2779,6 @@ def register_aggregate_function(func, function_name, function_doc, in_types, out To define a varargs function, pass a callable that takes *args. The in_type needs to match in type of inputs when the function gets called. - function_name : str Name of the function. This name must be unique, i.e., there should only be one function registered with @@ -2831,7 +2831,8 @@ def register_aggregate_function(func, function_name, function_doc, in_types, out def register_tabular_function(func, function_name, function_doc, in_types, out_type, func_registry=None): - """Register a user-defined tabular function. + """ + Register a user-defined tabular function. This API is EXPERIMENTAL. From 7f65599db6635b78deea8ec3e9a27d52ab3523b1 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 8 Jun 2023 09:29:04 -0400 Subject: [PATCH 18/19] Apply suggestions from code review Co-authored-by: Weston Pace --- python/pyarrow/src/arrow/python/udf.cc | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 7ab8c4da4ea..432a39d3fba 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-#include 
 
 #include "arrow/compute/api_aggregate.h"
 #include "arrow/compute/function.h"
 #include "arrow/compute/kernel.h"
 #include "arrow/python/common.h"
@@ -130,29 +129,29 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
                                 std::shared_ptr<OwnedRefNoGIL> agg_function,
                                 std::vector<std::shared_ptr<DataType>> input_types,
                                 std::shared_ptr<DataType> output_type)
-      : agg_cb(agg_cb), agg_function(agg_function), output_type(output_type) {
+      : agg_cb(std::move(agg_cb)), agg_function(agg_function), output_type(std::move(output_type)) {
     Py_INCREF(agg_function->obj());
     std::vector<std::shared_ptr<Field>> fields;
     for (size_t i = 0; i < input_types.size(); i++) {
-      fields.push_back(std::move(field("", input_types[i])));
+      fields.push_back(field("", input_types[i]));
     }
     input_schema = schema(std::move(fields));
   };
 
-  ~PythonUdfScalarAggregatorImpl() {
+  ~PythonUdfScalarAggregatorImpl() override {
     if (_Py_IsFinalizing()) {
       agg_function->detach();
     }
   }
 
-  Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) {
+  Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) override {
     ARROW_ASSIGN_OR_RAISE(
         auto rb, batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool()));
     values.push_back(std::move(rb));
     return Status::OK();
   }
 
-  Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) {
+  Status MergeFrom(compute::KernelContext* ctx, compute::KernelState&& src) override {
     auto& other_values = checked_cast<PythonUdfScalarAggregatorImpl&>(src).values;
     values.insert(values.end(), std::make_move_iterator(other_values.begin()),
                   std::make_move_iterator(other_values.end()));
@@ -161,7 +160,7 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
     return Status::OK();
   }
 
-  Status Finalize(compute::KernelContext* ctx, Datum* out) {
+  Status Finalize(compute::KernelContext* ctx, Datum* out) override {
     auto state =
         arrow::internal::checked_cast<PythonUdfScalarAggregatorImpl*>(ctx->state());
     std::shared_ptr<OwnedRefNoGIL>& function = state->agg_function;

From febf6cce524df3a1076dc988293f7424e3354f6e Mon Sep 17 00:00:00 2001
From: Li Jin
Date: Thu, 8 Jun 2023 10:30:18 -0400
Subject: [PATCH 19/19] Lint fix

---
 python/pyarrow/src/arrow/python/udf.cc | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc
index 432a39d3fba..06c116af820 100644
--- a/python/pyarrow/src/arrow/python/udf.cc
+++ b/python/pyarrow/src/arrow/python/udf.cc
@@ -15,12 +15,11 @@
 // specific language governing permissions and limitations
 // under the License.
 
-
+#include "arrow/python/udf.h"
 #include "arrow/compute/api_aggregate.h"
 #include "arrow/compute/function.h"
 #include "arrow/compute/kernel.h"
 #include "arrow/python/common.h"
-#include "arrow/python/udf.h"
 #include "arrow/table.h"
 #include "arrow/util/checked_cast.h"
 
@@ -129,7 +128,9 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
                                 std::shared_ptr<OwnedRefNoGIL> agg_function,
                                 std::vector<std::shared_ptr<DataType>> input_types,
                                 std::shared_ptr<DataType> output_type)
-      : agg_cb(std::move(agg_cb)), agg_function(agg_function), output_type(std::move(output_type)) {
+      : agg_cb(std::move(agg_cb)),
+        agg_function(agg_function),
+        output_type(std::move(output_type)) {
     Py_INCREF(agg_function->obj());
     std::vector<std::shared_ptr<Field>> fields;
     for (size_t i = 0; i < input_types.size(); i++) {
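Taken together, the aggregator above buffers record batches on the C++ side (Consume), splices buffers when kernel states merge (MergeFrom), and crosses into Python exactly once (Finalize). A rough Python analogue of that control flow, for intuition only (the class and method names are illustrative, not part of any API, and the UdfContext argument is elided):

    import pyarrow as pa

    class ScalarAggregatorSketch:
        def __init__(self, udf, input_schema):
            self.udf = udf
            self.input_schema = input_schema
            self.values = []  # buffered record batches, like the C++ `values`

        def consume(self, batch: pa.RecordBatch):
            self.values.append(batch)

        def merge_from(self, other):
            self.values.extend(other.values)
            other.values.clear()

        def finalize(self) -> pa.Scalar:
            # Mirrors Table::FromRecordBatches + CombineChunks in Finalize.
            table = pa.Table.from_batches(self.values, schema=self.input_schema)
            table = table.combine_chunks()
            if table.num_rows == 0:
                raise ValueError("Finalize is called with empty inputs")
            # After combine_chunks each column has a single chunk,
            # as the C++ comment notes.
            args = [col.chunk(0) for col in table.columns]
            return self.udf(*args)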