From 57a73a29362ccb22cbd5d527a920b00c00209249 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 20 Jun 2023 17:40:19 -0400 Subject: [PATCH 01/11] WIP: Test passing --- cpp/src/arrow/compute/row/grouper.h | 2 +- python/pyarrow/conftest.py | 5 +- python/pyarrow/src/arrow/python/udf.cc | 281 +++++++++++++++++++++++-- python/pyarrow/tests/test_udf.py | 22 +- 4 files changed, 290 insertions(+), 20 deletions(-) diff --git a/cpp/src/arrow/compute/row/grouper.h b/cpp/src/arrow/compute/row/grouper.h index 15f00eaac21..e273367a507 100644 --- a/cpp/src/arrow/compute/row/grouper.h +++ b/cpp/src/arrow/compute/row/grouper.h @@ -149,7 +149,7 @@ class ARROW_EXPORT Grouper { /// [] /// ] static Result> MakeGroupings( - const UInt32Array& ids, uint32_t num_groups, + const UInt32Array& ids, uint32_t , ExecContext* ctx = default_exec_context()); /// \brief Produce a ListArray whose slots are selections of `array` which correspond to diff --git a/python/pyarrow/conftest.py b/python/pyarrow/conftest.py index f32cbf01efc..96ff173d6d1 100644 --- a/python/pyarrow/conftest.py +++ b/python/pyarrow/conftest.py @@ -288,13 +288,14 @@ def unary_agg_func_fixture(): from pyarrow import compute as pc import numpy as np - def func(ctx, x): + def func(ctx, x, *args): return pa.scalar(np.nanmean(x)) - func_name = "y=avg(x)" + func_name = "mean_udf" func_doc = {"summary": "y=avg(x)", "description": "find mean of x"} + breakpoint() pc.register_aggregate_function(func, func_name, func_doc, diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 06c116af820..dec1d649164 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -15,16 +15,28 @@ // specific language governing permissions and limitations // under the License. 
+#include + #include "arrow/python/udf.h" #include "arrow/compute/api_aggregate.h" +#include "arrow/buffer_builder.h" +#include "arrow/array/builder_base.h" +#include "arrow/compute/api_vector.h" #include "arrow/compute/function.h" +#include "arrow/compute/row/grouper.h" #include "arrow/compute/kernel.h" #include "arrow/python/common.h" +#include "arrow/util/logging.h" #include "arrow/table.h" #include "arrow/util/checked_cast.h" namespace arrow { using internal::checked_cast; +using compute::KernelState; +using compute::KernelContext; +using compute::ExecSpan; +using compute::Grouper; + namespace py { namespace { @@ -73,6 +85,13 @@ struct ScalarUdfAggregator : public compute::KernelState { virtual Status Finalize(compute::KernelContext* ctx, Datum* out) = 0; }; +struct HashUdfAggregator : public compute::KernelState { + virtual Status Resize(KernelContext* ctx, int64_t size) = 0; + virtual Status Consume(KernelContext* ctx, const ExecSpan& batch) = 0; + virtual Status Merge(KernelContext* ct, KernelState&& other, const ArrayData&) = 0; + virtual Status Finalize(KernelContext* ctx, Datum* out) = 0; +}; + arrow::Status AggregateUdfConsume(compute::KernelContext* ctx, const compute::ExecSpan& batch) { return checked_cast(ctx->state())->Consume(ctx, batch); @@ -87,6 +106,22 @@ arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* ou return checked_cast(ctx->state())->Finalize(ctx, out); } +arrow::Status HashAggregateUdfResize(KernelContext* ctx, int64_t size) { + return checked_cast(ctx->state())->Resize(ctx, size); +} + +arrow::Status HashAggregateUdfConsume(KernelContext* ctx, const ExecSpan& batch) { + return checked_cast(ctx->state())->Consume(ctx, batch); +} + +arrow::Status HashAggregateUdfMerge(KernelContext* ctx, KernelState&& src, const ArrayData& group_id_mapping) { + return checked_cast(ctx->state())->Merge(ctx, std::move(src), group_id_mapping); +} + +arrow::Status HashAggregateUdfFinalize(KernelContext* ctx, Datum* out) { + return checked_cast(ctx->state())->Finalize(ctx, out); +} + struct PythonTableUdfKernelInit { PythonTableUdfKernelInit(std::shared_ptr function_maker, UdfWrapperCallback cb) @@ -222,6 +257,145 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { std::shared_ptr output_type; }; +struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { + PythonUdfHashAggregatorImpl(UdfWrapperCallback agg_cb, + std::shared_ptr function, + std::vector> input_types, + std::shared_ptr output_type) + : agg_cb(std::move(agg_cb)), + function(function), + output_type(std::move(output_type)) { + Py_INCREF(function->obj()); + std::vector> fields; + for (size_t i = 0; i < input_types.size(); i++) { + fields.push_back(field("", input_types[i])); + } + input_schema = schema(std::move(fields)); + }; + + ~PythonUdfHashAggregatorImpl() override { + if (_Py_IsFinalizing()) { + function->detach(); + } + } + + static Result ApplyGroupings( + const ListArray& groupings, const std::shared_ptr& batch) { + ARROW_ASSIGN_OR_RAISE(Datum sorted, + compute::Take(batch, groupings.data()->child_data[0])); + + const auto& sorted_batch = *sorted.record_batch(); + + RecordBatchVector out(static_cast(groupings.length())); + for (size_t i = 0; i < out.size(); ++i) { + out[i] = sorted_batch.Slice(groupings.value_offset(i), groupings.value_length(i)); + } + + return out; +} + + Status Resize(KernelContext* ctx, int64_t new_num_groups) { + num_groups = new_num_groups; + return Status::OK(); + } + Status Consume(KernelContext* ctx, const ExecSpan& batch) { + // last 
array is the group id + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr rb, batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool())); + + const ArraySpan& groups_array_data = batch[batch.num_values() - 1].array; + int64_t batch_num_values = groups_array_data.length; + const auto* batch_groups = groups_array_data.GetValues(1, 0); + DCHECK_EQ(groups_array_data.offset, 0); + RETURN_NOT_OK(groups.Append(batch_groups, batch_num_values)); + values.push_back(std::move(rb)); + num_values += batch_num_values; + return Status::OK(); + } + Status Merge(KernelContext* ctx, KernelState&& other_state, const ArrayData& group_id_mapping) { + auto& other = checked_cast(other_state); + auto& other_values = other.values; + const uint32_t* other_raw_groups = other.groups.data(); + values.insert(values.end(), std::make_move_iterator(other_values.begin()), + std::make_move_iterator(other_values.end())); + + auto g = group_id_mapping.GetValues(1); + for (uint32_t other_g = 0; static_cast(other_g) < other.num_values; ++other_g) { + RETURN_NOT_OK(groups.Append(g[other_raw_groups[other_g]])); + } + + num_values += other.num_values; + return Status::OK(); + } + + Status Finalize(KernelContext* ctx, Datum* out) { + const int num_args = input_schema->num_fields(); + + ARROW_ASSIGN_OR_RAISE(auto groups_buffer, groups.Finish()); + ARROW_ASSIGN_OR_RAISE( + auto groupings, + Grouper::MakeGroupings(UInt32Array(num_values, groups_buffer), static_cast(num_groups)) + ); + + ARROW_ASSIGN_OR_RAISE(auto table, + arrow::Table::FromRecordBatches(input_schema, values)); + ARROW_ASSIGN_OR_RAISE(auto rb, table->CombineChunksToBatch(ctx->memory_pool())); + UdfContext udf_context{ctx->memory_pool(), table->num_rows()}; + + if (rb->num_rows() == 0) { + return Status::Invalid("Finalized is called with empty inputs"); + } + + ARROW_ASSIGN_OR_RAISE(RecordBatchVector rbs, ApplyGroupings(*groupings, rb)); + + return SafeCallIntoPython([&] { + ARROW_ASSIGN_OR_RAISE(std::unique_ptr builder, MakeBuilder(output_type, ctx->memory_pool())); + for (auto& group_rb : rbs) { + std::unique_ptr result; + OwnedRef arg_tuple(PyTuple_New(num_args)); + RETURN_NOT_OK(CheckPyError()); + + // Exclude the last column which is the group id + for (int arg_id = 0; arg_id < num_args; arg_id++) { + // Since we combined chunks there is only one chunk + std::shared_ptr c_data = group_rb->column(arg_id); + PyObject* data = wrap_array(c_data); + PyTuple_SetItem(arg_tuple.obj(), arg_id, data); + } + + result = std::make_unique(agg_cb(function->obj(), udf_context, arg_tuple.obj())); + RETURN_NOT_OK(CheckPyError()); + + // unwrapping the output for expected output type + if (is_scalar(result->obj())) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_scalar(result->obj())); + if (*output_type != *val->type) { + return Status::TypeError("Expected output datatype ", output_type->ToString(), + ", but function returned datatype ", + val->type->ToString()); + } + ARROW_RETURN_NOT_OK(builder->AppendScalar(std::move(*val))); + } else { + return Status::TypeError("Unexpected output type: ", + Py_TYPE(result->obj())->tp_name, " (expected Scalar)"); + } + } + ARROW_ASSIGN_OR_RAISE(auto result, builder->Finish()); + out->value = std::move(result->data()); + return Status::OK(); + }); + } + + int64_t num_groups = 0; + int64_t num_values = 0; + UdfWrapperCallback agg_cb; + std::vector> values; + TypedBufferBuilder groups; + std::shared_ptr function; + std::shared_ptr input_schema; + std::shared_ptr output_type; +}; + struct PythonUdf : public PythonUdfKernelState { 
PythonUdf(std::shared_ptr function, UdfWrapperCallback cb, std::vector input_types, compute::OutputType output_type) @@ -355,16 +529,7 @@ Status RegisterTabularFunction(PyObject* user_function, UdfWrapperCallback wrapp wrapper, options, registry); } -Status AddAggKernel(std::shared_ptr sig, - compute::KernelInit init, compute::ScalarAggregateFunction* func) { - compute::ScalarAggregateKernel kernel(std::move(sig), std::move(init), - AggregateUdfConsume, AggregateUdfMerge, - AggregateUdfFinalize, /*ordered=*/false); - RETURN_NOT_OK(func->AddKernel(std::move(kernel))); - return Status::OK(); -} - -Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_wrapper, +Status RegisterScalarAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_wrapper, const UdfOptions& options, compute::FunctionRegistry* registry) { if (!PyCallable_Check(agg_function)) { @@ -401,15 +566,101 @@ Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_ options.output_type); }; - RETURN_NOT_OK(AddAggKernel( - compute::KernelSignature::Make(std::move(input_types), std::move(output_type), - options.arity.is_varargs), - init, aggregate_func.get())); - + auto sig = compute::KernelSignature::Make( + std::move(input_types), std::move(output_type), options.arity.is_varargs); + compute::ScalarAggregateKernel kernel(std::move(sig), std::move(init), + AggregateUdfConsume, AggregateUdfMerge, + AggregateUdfFinalize, /*ordered=*/false); + RETURN_NOT_OK(aggregate_func->AddKernel(std::move(kernel))); RETURN_NOT_OK(registry->AddFunction(std::move(aggregate_func))); return Status::OK(); } +/// @brief Create a new UdfOptions with adjustment for hash kernel +/// @param options User provided udf options +UdfOptions AdjustForHashAggregate(const UdfOptions& options) { + UdfOptions hash_options; + // Append hash_ before the function name to seperate from the scalar + // version + hash_options.func_name = "hash_" + options.func_name; + // Extend input types with group id. Group id is appended by the group + // aggregation node. Here we change both arity and input types + if (options.arity.is_varargs) { + hash_options.arity = options.arity; + } else { + hash_options.arity = compute::Arity(options.arity.num_args + 1, false); + } + // Changing the function doc shouldn't be necessarily because group id + // is not user visible, however, this is currently needed to pass the + // function validation. 
The name group_id_array is consistent with + // hash kernels in hash_aggregate.cc + hash_options.func_doc = options.func_doc; + hash_options.func_doc.arg_names.emplace_back("group_id_array"); + std::vector> input_dtypes = options.input_types; + input_dtypes.emplace_back(uint32()); + hash_options.input_types = std::move(input_dtypes); + hash_options.output_type = options.output_type; + return hash_options; +} + +Status RegisterHashAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_wrapper, + const UdfOptions& options, + compute::FunctionRegistry* registry) { + if (!PyCallable_Check(agg_function)) { + return Status::TypeError("Expected a callable Python object."); + } + + if (registry == NULLPTR) { + registry = compute::GetFunctionRegistry(); + } + + // Py_INCREF here so that once a function is registered + // its refcount gets increased by 1 and doesn't get gced + // if all existing refs are gone + Py_INCREF(agg_function); + UdfOptions hash_options = AdjustForHashAggregate(options); + + std::vector input_types; + for (const auto& in_dtype : hash_options.input_types) { + input_types.emplace_back(in_dtype); + } + compute::OutputType output_type(hash_options.output_type); + + static auto default_hash_aggregate_options = + compute::ScalarAggregateOptions::Defaults(); + auto hash_aggregate_func = std::make_shared( + hash_options.func_name, hash_options.arity, hash_options.func_doc, + &default_hash_aggregate_options); + + compute::KernelInit init = [agg_wrapper, agg_function, hash_options]( + compute::KernelContext* ctx, + const compute::KernelInitArgs& args) + -> Result> { + return std::make_unique( + agg_wrapper, std::make_shared(agg_function), hash_options.input_types, + hash_options.output_type); + }; + + auto sig = compute::KernelSignature::Make(std::move(input_types), std::move(output_type), + hash_options.arity.is_varargs); + + compute::HashAggregateKernel kernel(std::move(sig), std::move(init), + HashAggregateUdfResize, HashAggregateUdfConsume, + HashAggregateUdfMerge, HashAggregateUdfFinalize, /*ordered=*/false); + RETURN_NOT_OK(hash_aggregate_func->AddKernel(std::move(kernel))); + RETURN_NOT_OK(registry->AddFunction(std::move(hash_aggregate_func))); + return Status::OK(); +} + +Status RegisterAggregateFunction(PyObject* function, UdfWrapperCallback wrapper, + const UdfOptions& options, + compute::FunctionRegistry* registry) { + RETURN_NOT_OK(RegisterScalarAggregateFunction(function, wrapper, options, registry)); + RETURN_NOT_OK(RegisterHashAggregateFunction(function, wrapper, options, registry)); + + return Status::OK(); +} + Result> CallTabularFunction( const std::string& func_name, const std::vector& args, compute::FunctionRegistry* registry) { diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index c0cfd3d26e8..22f4f98f5d3 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -658,16 +658,33 @@ def test_udt_datasource1_exception(): def test_agg_basic(unary_agg_func_fixture): arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64()) - result = pc.call_function("y=avg(x)", [arr]) + result = pc.call_function("mean_udf", [arr]) expected = pa.scalar(30.0) assert result == expected +def test_hash_agg_basic(unary_agg_func_fixture): + arr1 = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64()) + arr2 = pa.array([4, 2, 1, 2, 1], pa.int32()) + + arr3 = pa.array([60.0, 70.0, 80.0, 90.0, 100.0], pa.float64()) + arr4 = pa.array([5, 1, 1, 4, 1], pa.int32()) + + table1 = pa.table([arr2, arr1], names=["id", "value"]) 
+ table2 = pa.table([arr4, arr3], names=["id", "value"]) + table = pa.concat_tables([table1, table2]) + + breakpoint() + result = table.group_by("id").aggregate([("value", "mean_udf")]) + expected = table.group_by("id").aggregate([("value", "mean")]).rename_columns(['id', 'value_mean_udf']) + + assert result.sort_by('id') == expected.sort_by('id') + def test_agg_empty(unary_agg_func_fixture): empty = pa.array([], pa.float64()) with pytest.raises(pa.ArrowInvalid, match='empty inputs'): - pc.call_function("y=avg(x)", [empty]) + pc.call_function("mean_udf", [empty]) def test_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture): @@ -698,3 +715,4 @@ def test_agg_exception(exception_agg_func_fixture): with pytest.raises(RuntimeError, match='Oops'): pc.call_function("y=exception_len(x)", [arr]) + From 76c2b620134040b15beb80546d37ca24f3f37cb0 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 20 Jun 2023 17:49:02 -0400 Subject: [PATCH 02/11] Remove breakpoints(); Fix num_args --- python/pyarrow/conftest.py | 1 - python/pyarrow/src/arrow/python/udf.cc | 4 ++-- python/pyarrow/tests/test_udf.py | 1 - 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/python/pyarrow/conftest.py b/python/pyarrow/conftest.py index 96ff173d6d1..de001a22f2e 100644 --- a/python/pyarrow/conftest.py +++ b/python/pyarrow/conftest.py @@ -295,7 +295,6 @@ def func(ctx, x, *args): func_doc = {"summary": "y=avg(x)", "description": "find mean of x"} - breakpoint() pc.register_aggregate_function(func, func_name, func_doc, diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index dec1d649164..5ef38a36913 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -329,7 +329,8 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { } Status Finalize(KernelContext* ctx, Datum* out) { - const int num_args = input_schema->num_fields(); + // Exclude the last column which is the group id + const int num_args = input_schema->num_fields() - 1; ARROW_ASSIGN_OR_RAISE(auto groups_buffer, groups.Finish()); ARROW_ASSIGN_OR_RAISE( @@ -355,7 +356,6 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { OwnedRef arg_tuple(PyTuple_New(num_args)); RETURN_NOT_OK(CheckPyError()); - // Exclude the last column which is the group id for (int arg_id = 0; arg_id < num_args; arg_id++) { // Since we combined chunks there is only one chunk std::shared_ptr c_data = group_rb->column(arg_id); diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 22f4f98f5d3..e9383c0fbab 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -674,7 +674,6 @@ def test_hash_agg_basic(unary_agg_func_fixture): table2 = pa.table([arr4, arr3], names=["id", "value"]) table = pa.concat_tables([table1, table2]) - breakpoint() result = table.group_by("id").aggregate([("value", "mean_udf")]) expected = table.group_by("id").aggregate([("value", "mean")]).rename_columns(['id', 'value_mean_udf']) From 1a804f142ae7f0b8bc996dd506e57e3e09e1b117 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 22 Jun 2023 13:51:12 -0400 Subject: [PATCH 03/11] Add more tests --- python/pyarrow/conftest.py | 8 +- python/pyarrow/tests/test_substrait.py | 173 ++++++++++++++++++++++++- python/pyarrow/tests/test_udf.py | 45 ++++++- 3 files changed, 219 insertions(+), 7 deletions(-) diff --git a/python/pyarrow/conftest.py b/python/pyarrow/conftest.py index de001a22f2e..da7fdffc3cf 100644 --- a/python/pyarrow/conftest.py +++ 
b/python/pyarrow/conftest.py @@ -20,6 +20,8 @@ from pyarrow import Codec from pyarrow import fs +import numpy as np + groups = [ 'acero', 'brotli', @@ -283,10 +285,9 @@ def unary_function(ctx, x): @pytest.fixture(scope="session") def unary_agg_func_fixture(): """ - Register a unary aggregate function + Register a unary aggregate function (mean) """ from pyarrow import compute as pc - import numpy as np def func(ctx, x, *args): return pa.scalar(np.nanmean(x)) @@ -312,7 +313,6 @@ def varargs_agg_func_fixture(): Register a unary aggregate function """ from pyarrow import compute as pc - import numpy as np def func(ctx, *args): sum = 0.0 @@ -320,7 +320,7 @@ def func(ctx, *args): sum += np.nanmean(arg) return pa.scalar(sum) - func_name = "y=sum_mean(x...)" + func_name = "sum_mean" func_doc = {"summary": "Varargs aggregate", "description": "Varargs aggregate"} diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index 34faaa157af..0258170263e 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -607,7 +607,7 @@ def table_provider(names, schema): assert res_tb == expected -def test_aggregate_udf_basic(varargs_agg_func_fixture): +def test_scalar_aggregate_udf_basic(varargs_agg_func_fixture): test_table = pa.Table.from_pydict( {"k": [1, 1, 2, 2], "v1": [1, 2, 3, 4], @@ -630,7 +630,7 @@ def table_provider(names, _): "extensionFunction": { "extensionUriReference": 1, "functionAnchor": 1, - "name": "y=sum_mean(x...)" + "name": "sum_mean" } } ], @@ -753,3 +753,172 @@ def table_provider(names, _): }) assert res_tb == expected_tb + +def test_hash_aggregate_udf_basic(varargs_agg_func_fixture): + + test_table = pa.Table.from_pydict( + {"t": [1, 1, 1, 1, 2, 2, 2, 2], + "k": [1, 0, 0, 1, 0, 1, 0, 1], + "v1": [1, 2, 3, 4, 5, 6, 7, 8], + "v2": [1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0]} + ) + + def table_provider(names, _): + return test_table + + substrait_query = b""" +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "urn:arrow:substrait_simple_extension_function" + }, + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "sum_mean" + } + } + ], + "relations": [ + { + "root": { + "input": { + "extensionSingle": { + "common": { + "emit": { + "outputMapping": [ + 0, + 1, + 2 + ] + } + }, + "input": { + "read": { + "baseSchema": { + "names": [ + "t", + "k", + "v1", + "v2", + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["t1"] + } + } + }, + "detail": { + "@type": "/arrow.substrait_ext.SegmentedAggregateRel", + "groupingKeys": [ + { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + ], + "segmentKeys": [ + { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + ], + "measures": [ + { + "measure": { + "functionReference": 1, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + 
"field": 3 + } + }, + "rootReference": {} + } + } + } + ] + } + } + ] + } + } + }, + "names": [ + "t", + "k", + "v_avg" + ] + } + } + ], +} +""" + buf = pa._substrait._parse_json_plan(substrait_query) + reader = pa.substrait.run_query( + buf, table_provider=table_provider, use_threads=False) + res_tb = reader.read_all() + + expected_tb = pa.Table.from_pydict({ + 't': [1, 1, 2, 2], + 'k': [1, 0, 0, 1], + 'v_avg': [3.5, 3.5, 9.0, 11.0] + }) + + # Ordering of k is deterministic because this is running with serial execution + assert res_tb == expected_tb \ No newline at end of file diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index e9383c0fbab..286b9bc9758 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -18,6 +18,8 @@ import pytest +import numpy as np + import pyarrow as pa from pyarrow import compute as pc @@ -42,6 +44,27 @@ class MyError(RuntimeError): pass +@pytest.fixture(scope="session") +def sum_agg_func_fixture(): + """ + Register a unary aggregate function (mean) + """ + def func(ctx, x, *args): + return pa.scalar(np.nansum(x)) + + func_name = "sum_udf" + func_doc = empty_udf_doc + + pc.register_aggregate_function(func, + func_name, + func_doc, + { + "x": pa.float64(), + }, + pa.float64() + ) + return func, func_name + @pytest.fixture(scope="session") def exception_agg_func_fixture(): def func(ctx, x): @@ -679,6 +702,26 @@ def test_hash_agg_basic(unary_agg_func_fixture): assert result.sort_by('id') == expected.sort_by('id') +def test_hash_agg_random(sum_agg_func_fixture): + """Test hash aggregate udf with randomly sampled data""" + + value_num = 1000000 + group_num = 1000 + seed = 1 + + rng = np.random.default_rng(seed=seed) + + arr1 = pa.array(np.repeat(1, value_num), pa.float64()) + arr2 = pa.array(rng.choice(group_num, value_num), pa.int32()) + + table = pa.table([arr2, arr1], names=['id', 'value']) + + result = table.group_by("id").aggregate([("value", "sum_udf")]) + expected = table.group_by("id").aggregate([("value", "sum")]).rename_columns(['id', 'value_sum_udf']) + + assert result.sort_by('id') == expected.sort_by('id') + + def test_agg_empty(unary_agg_func_fixture): empty = pa.array([], pa.float64()) @@ -703,7 +746,7 @@ def test_agg_varargs(varargs_agg_func_fixture): arr2 = pa.array([1.0, 2.0, 3.0, 4.0, 5.0], pa.float64()) result = pc.call_function( - "y=sum_mean(x...)", [arr1, arr2] + "sum_mean", [arr1, arr2] ) expected = pa.scalar(33.0) assert result == expected From d1c60ec4d62141c99899ff8ea195e7fe23568645 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 22 Jun 2023 13:58:01 -0400 Subject: [PATCH 04/11] Lint --- cpp/src/arrow/compute/row/grouper.h | 3 +- python/pyarrow/src/arrow/python/udf.cc | 191 +++++++++++++------------ python/pyarrow/tests/test_substrait.py | 3 +- python/pyarrow/tests/test_udf.py | 9 +- 4 files changed, 109 insertions(+), 97 deletions(-) diff --git a/cpp/src/arrow/compute/row/grouper.h b/cpp/src/arrow/compute/row/grouper.h index e273367a507..d8788f084ed 100644 --- a/cpp/src/arrow/compute/row/grouper.h +++ b/cpp/src/arrow/compute/row/grouper.h @@ -149,8 +149,7 @@ class ARROW_EXPORT Grouper { /// [] /// ] static Result> MakeGroupings( - const UInt32Array& ids, uint32_t , - ExecContext* ctx = default_exec_context()); + const UInt32Array& ids, uint32_t, ExecContext* ctx = default_exec_context()); /// \brief Produce a ListArray whose slots are selections of `array` which correspond to /// the provided groupings. 
diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 5ef38a36913..2667d5c93b5 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -17,25 +17,25 @@ #include -#include "arrow/python/udf.h" -#include "arrow/compute/api_aggregate.h" -#include "arrow/buffer_builder.h" #include "arrow/array/builder_base.h" +#include "arrow/buffer_builder.h" +#include "arrow/compute/api_aggregate.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/function.h" -#include "arrow/compute/row/grouper.h" #include "arrow/compute/kernel.h" +#include "arrow/compute/row/grouper.h" #include "arrow/python/common.h" -#include "arrow/util/logging.h" +#include "arrow/python/udf.h" #include "arrow/table.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/logging.h" namespace arrow { -using internal::checked_cast; -using compute::KernelState; -using compute::KernelContext; using compute::ExecSpan; using compute::Grouper; +using compute::KernelContext; +using compute::KernelState; +using internal::checked_cast; namespace py { namespace { @@ -107,15 +107,17 @@ arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* ou } arrow::Status HashAggregateUdfResize(KernelContext* ctx, int64_t size) { - return checked_cast(ctx->state())->Resize(ctx, size); + return checked_cast(ctx->state())->Resize(ctx, size); } arrow::Status HashAggregateUdfConsume(KernelContext* ctx, const ExecSpan& batch) { - return checked_cast(ctx->state())->Consume(ctx, batch); + return checked_cast(ctx->state())->Consume(ctx, batch); } -arrow::Status HashAggregateUdfMerge(KernelContext* ctx, KernelState&& src, const ArrayData& group_id_mapping) { - return checked_cast(ctx->state())->Merge(ctx, std::move(src), group_id_mapping); +arrow::Status HashAggregateUdfMerge(KernelContext* ctx, KernelState&& src, + const ArrayData& group_id_mapping) { + return checked_cast(ctx->state()) + ->Merge(ctx, std::move(src), group_id_mapping); } arrow::Status HashAggregateUdfFinalize(KernelContext* ctx, Datum* out) { @@ -259,9 +261,9 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { PythonUdfHashAggregatorImpl(UdfWrapperCallback agg_cb, - std::shared_ptr function, - std::vector> input_types, - std::shared_ptr output_type) + std::shared_ptr function, + std::vector> input_types, + std::shared_ptr output_type) : agg_cb(std::move(agg_cb)), function(function), output_type(std::move(output_type)) { @@ -280,19 +282,19 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { } static Result ApplyGroupings( - const ListArray& groupings, const std::shared_ptr& batch) { - ARROW_ASSIGN_OR_RAISE(Datum sorted, - compute::Take(batch, groupings.data()->child_data[0])); + const ListArray& groupings, const std::shared_ptr& batch) { + ARROW_ASSIGN_OR_RAISE(Datum sorted, + compute::Take(batch, groupings.data()->child_data[0])); - const auto& sorted_batch = *sorted.record_batch(); + const auto& sorted_batch = *sorted.record_batch(); - RecordBatchVector out(static_cast(groupings.length())); - for (size_t i = 0; i < out.size(); ++i) { - out[i] = sorted_batch.Slice(groupings.value_offset(i), groupings.value_length(i)); - } + RecordBatchVector out(static_cast(groupings.length())); + for (size_t i = 0; i < out.size(); ++i) { + out[i] = sorted_batch.Slice(groupings.value_offset(i), groupings.value_length(i)); + } - return out; -} + return out; + } Status Resize(KernelContext* ctx, 
int64_t new_num_groups) { num_groups = new_num_groups; @@ -301,7 +303,8 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { Status Consume(KernelContext* ctx, const ExecSpan& batch) { // last array is the group id ARROW_ASSIGN_OR_RAISE( - std::shared_ptr rb, batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool())); + std::shared_ptr rb, + batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool())); const ArraySpan& groups_array_data = batch[batch.num_values() - 1].array; int64_t batch_num_values = groups_array_data.length; @@ -312,7 +315,8 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { num_values += batch_num_values; return Status::OK(); } - Status Merge(KernelContext* ctx, KernelState&& other_state, const ArrayData& group_id_mapping) { + Status Merge(KernelContext* ctx, KernelState&& other_state, + const ArrayData& group_id_mapping) { auto& other = checked_cast(other_state); auto& other_values = other.values; const uint32_t* other_raw_groups = other.groups.data(); @@ -320,7 +324,8 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { std::make_move_iterator(other_values.end())); auto g = group_id_mapping.GetValues(1); - for (uint32_t other_g = 0; static_cast(other_g) < other.num_values; ++other_g) { + for (uint32_t other_g = 0; static_cast(other_g) < other.num_values; + ++other_g) { RETURN_NOT_OK(groups.Append(g[other_raw_groups[other_g]])); } @@ -329,61 +334,63 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { } Status Finalize(KernelContext* ctx, Datum* out) { - // Exclude the last column which is the group id - const int num_args = input_schema->num_fields() - 1; - - ARROW_ASSIGN_OR_RAISE(auto groups_buffer, groups.Finish()); - ARROW_ASSIGN_OR_RAISE( - auto groupings, - Grouper::MakeGroupings(UInt32Array(num_values, groups_buffer), static_cast(num_groups)) - ); - - ARROW_ASSIGN_OR_RAISE(auto table, - arrow::Table::FromRecordBatches(input_schema, values)); - ARROW_ASSIGN_OR_RAISE(auto rb, table->CombineChunksToBatch(ctx->memory_pool())); - UdfContext udf_context{ctx->memory_pool(), table->num_rows()}; - - if (rb->num_rows() == 0) { - return Status::Invalid("Finalized is called with empty inputs"); - } + // Exclude the last column which is the group id + const int num_args = input_schema->num_fields() - 1; - ARROW_ASSIGN_OR_RAISE(RecordBatchVector rbs, ApplyGroupings(*groupings, rb)); + ARROW_ASSIGN_OR_RAISE(auto groups_buffer, groups.Finish()); + ARROW_ASSIGN_OR_RAISE(auto groupings, + Grouper::MakeGroupings(UInt32Array(num_values, groups_buffer), + static_cast(num_groups))); - return SafeCallIntoPython([&] { - ARROW_ASSIGN_OR_RAISE(std::unique_ptr builder, MakeBuilder(output_type, ctx->memory_pool())); - for (auto& group_rb : rbs) { - std::unique_ptr result; - OwnedRef arg_tuple(PyTuple_New(num_args)); - RETURN_NOT_OK(CheckPyError()); + ARROW_ASSIGN_OR_RAISE(auto table, + arrow::Table::FromRecordBatches(input_schema, values)); + ARROW_ASSIGN_OR_RAISE(auto rb, table->CombineChunksToBatch(ctx->memory_pool())); + UdfContext udf_context{ctx->memory_pool(), table->num_rows()}; - for (int arg_id = 0; arg_id < num_args; arg_id++) { - // Since we combined chunks there is only one chunk - std::shared_ptr c_data = group_rb->column(arg_id); - PyObject* data = wrap_array(c_data); - PyTuple_SetItem(arg_tuple.obj(), arg_id, data); - } + if (rb->num_rows() == 0) { + return Status::Invalid("Finalized is called with empty inputs"); + } - result = std::make_unique(agg_cb(function->obj(), udf_context, arg_tuple.obj())); 
- RETURN_NOT_OK(CheckPyError()); - - // unwrapping the output for expected output type - if (is_scalar(result->obj())) { - ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_scalar(result->obj())); - if (*output_type != *val->type) { - return Status::TypeError("Expected output datatype ", output_type->ToString(), - ", but function returned datatype ", - val->type->ToString()); - } - ARROW_RETURN_NOT_OK(builder->AppendScalar(std::move(*val))); - } else { - return Status::TypeError("Unexpected output type: ", - Py_TYPE(result->obj())->tp_name, " (expected Scalar)"); + ARROW_ASSIGN_OR_RAISE(RecordBatchVector rbs, ApplyGroupings(*groupings, rb)); + + return SafeCallIntoPython([&] { + ARROW_ASSIGN_OR_RAISE(std::unique_ptr builder, + MakeBuilder(output_type, ctx->memory_pool())); + for (auto& group_rb : rbs) { + std::unique_ptr result; + OwnedRef arg_tuple(PyTuple_New(num_args)); + RETURN_NOT_OK(CheckPyError()); + + for (int arg_id = 0; arg_id < num_args; arg_id++) { + // Since we combined chunks there is only one chunk + std::shared_ptr c_data = group_rb->column(arg_id); + PyObject* data = wrap_array(c_data); + PyTuple_SetItem(arg_tuple.obj(), arg_id, data); + } + + result = std::make_unique( + agg_cb(function->obj(), udf_context, arg_tuple.obj())); + RETURN_NOT_OK(CheckPyError()); + + // unwrapping the output for expected output type + if (is_scalar(result->obj())) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, + unwrap_scalar(result->obj())); + if (*output_type != *val->type) { + return Status::TypeError("Expected output datatype ", output_type->ToString(), + ", but function returned datatype ", + val->type->ToString()); } + ARROW_RETURN_NOT_OK(builder->AppendScalar(std::move(*val))); + } else { + return Status::TypeError("Unexpected output type: ", + Py_TYPE(result->obj())->tp_name, " (expected Scalar)"); } - ARROW_ASSIGN_OR_RAISE(auto result, builder->Finish()); - out->value = std::move(result->data()); - return Status::OK(); - }); + } + ARROW_ASSIGN_OR_RAISE(auto result, builder->Finish()); + out->value = std::move(result->data()); + return Status::OK(); + }); } int64_t num_groups = 0; @@ -529,9 +536,10 @@ Status RegisterTabularFunction(PyObject* user_function, UdfWrapperCallback wrapp wrapper, options, registry); } -Status RegisterScalarAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_wrapper, - const UdfOptions& options, - compute::FunctionRegistry* registry) { +Status RegisterScalarAggregateFunction(PyObject* agg_function, + UdfWrapperCallback agg_wrapper, + const UdfOptions& options, + compute::FunctionRegistry* registry) { if (!PyCallable_Check(agg_function)) { return Status::TypeError("Expected a callable Python object."); } @@ -567,7 +575,7 @@ Status RegisterScalarAggregateFunction(PyObject* agg_function, UdfWrapperCallbac }; auto sig = compute::KernelSignature::Make( - std::move(input_types), std::move(output_type), options.arity.is_varargs); + std::move(input_types), std::move(output_type), options.arity.is_varargs); compute::ScalarAggregateKernel kernel(std::move(sig), std::move(init), AggregateUdfConsume, AggregateUdfMerge, AggregateUdfFinalize, /*ordered=*/false); @@ -603,9 +611,10 @@ UdfOptions AdjustForHashAggregate(const UdfOptions& options) { return hash_options; } -Status RegisterHashAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_wrapper, - const UdfOptions& options, - compute::FunctionRegistry* registry) { +Status RegisterHashAggregateFunction(PyObject* agg_function, + UdfWrapperCallback agg_wrapper, + const UdfOptions& options, + 
compute::FunctionRegistry* registry) { if (!PyCallable_Check(agg_function)) { return Status::TypeError("Expected a callable Python object."); } @@ -637,16 +646,16 @@ Status RegisterHashAggregateFunction(PyObject* agg_function, UdfWrapperCallback const compute::KernelInitArgs& args) -> Result> { return std::make_unique( - agg_wrapper, std::make_shared(agg_function), hash_options.input_types, - hash_options.output_type); + agg_wrapper, std::make_shared(agg_function), + hash_options.input_types, hash_options.output_type); }; - auto sig = compute::KernelSignature::Make(std::move(input_types), std::move(output_type), - hash_options.arity.is_varargs); + auto sig = compute::KernelSignature::Make( + std::move(input_types), std::move(output_type), hash_options.arity.is_varargs); - compute::HashAggregateKernel kernel(std::move(sig), std::move(init), - HashAggregateUdfResize, HashAggregateUdfConsume, - HashAggregateUdfMerge, HashAggregateUdfFinalize, /*ordered=*/false); + compute::HashAggregateKernel kernel( + std::move(sig), std::move(init), HashAggregateUdfResize, HashAggregateUdfConsume, + HashAggregateUdfMerge, HashAggregateUdfFinalize, /*ordered=*/false); RETURN_NOT_OK(hash_aggregate_func->AddKernel(std::move(kernel))); RETURN_NOT_OK(registry->AddFunction(std::move(hash_aggregate_func))); return Status::OK(); diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index 0258170263e..93ecae7bfa1 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -754,6 +754,7 @@ def table_provider(names, _): assert res_tb == expected_tb + def test_hash_aggregate_udf_basic(varargs_agg_func_fixture): test_table = pa.Table.from_pydict( @@ -921,4 +922,4 @@ def table_provider(names, _): }) # Ordering of k is deterministic because this is running with serial execution - assert res_tb == expected_tb \ No newline at end of file + assert res_tb == expected_tb diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 286b9bc9758..06f5e7e3250 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -65,6 +65,7 @@ def func(ctx, x, *args): ) return func, func_name + @pytest.fixture(scope="session") def exception_agg_func_fixture(): def func(ctx, x): @@ -698,10 +699,12 @@ def test_hash_agg_basic(unary_agg_func_fixture): table = pa.concat_tables([table1, table2]) result = table.group_by("id").aggregate([("value", "mean_udf")]) - expected = table.group_by("id").aggregate([("value", "mean")]).rename_columns(['id', 'value_mean_udf']) + expected = table.group_by("id").aggregate( + [("value", "mean")]).rename_columns(['id', 'value_mean_udf']) assert result.sort_by('id') == expected.sort_by('id') + def test_hash_agg_random(sum_agg_func_fixture): """Test hash aggregate udf with randomly sampled data""" @@ -717,7 +720,8 @@ def test_hash_agg_random(sum_agg_func_fixture): table = pa.table([arr2, arr1], names=['id', 'value']) result = table.group_by("id").aggregate([("value", "sum_udf")]) - expected = table.group_by("id").aggregate([("value", "sum")]).rename_columns(['id', 'value_sum_udf']) + expected = table.group_by("id").aggregate( + [("value", "sum")]).rename_columns(['id', 'value_sum_udf']) assert result.sort_by('id') == expected.sort_by('id') @@ -757,4 +761,3 @@ def test_agg_exception(exception_agg_func_fixture): with pytest.raises(RuntimeError, match='Oops'): pc.call_function("y=exception_len(x)", [arr]) - From be8474b750094ebbc45067a8c75091c55c69e30e Mon Sep 17 00:00:00 2001 From: Li Jin 
Date: Thu, 22 Jun 2023 16:56:22 -0400 Subject: [PATCH 05/11] Lint, self review and documentation --- cpp/src/arrow/compute/row/grouper.h | 3 +- python/pyarrow/_compute.pyx | 9 +++ python/pyarrow/conftest.py | 2 +- python/pyarrow/src/arrow/python/udf.cc | 104 +++++++++++++------------ 4 files changed, 65 insertions(+), 53 deletions(-) diff --git a/cpp/src/arrow/compute/row/grouper.h b/cpp/src/arrow/compute/row/grouper.h index d8788f084ed..15f00eaac21 100644 --- a/cpp/src/arrow/compute/row/grouper.h +++ b/cpp/src/arrow/compute/row/grouper.h @@ -149,7 +149,8 @@ class ARROW_EXPORT Grouper { /// [] /// ] static Result> MakeGroupings( - const UInt32Array& ids, uint32_t, ExecContext* ctx = default_exec_context()); + const UInt32Array& ids, uint32_t num_groups, + ExecContext* ctx = default_exec_context()); /// \brief Produce a ListArray whose slots are selections of `array` which correspond to /// the provided groupings. diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index eaf9d1dfb65..380fd7e4227 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2823,6 +2823,15 @@ def register_aggregate_function(func, function_name, function_doc, in_types, out >>> answer = pc.call_function(func_name, [pa.array([20, 40])]) >>> answer + >>> table = pa.table([pa.array([1, 1, 2, 2]), pa.array([10, 20, 30, 40])], names=['k', 'v']) + >>> result = table.group_by('k').aggregate(['v', 'py_compute_median']) + >>> result + pyarrow.Table + k: int64 + v_py_compute_median: double + ---- + k: [[1,2]] + v_py_compute_median: [[15,35]] """ return _register_user_defined_function(get_register_aggregate_function(), func, function_name, function_doc, in_types, diff --git a/python/pyarrow/conftest.py b/python/pyarrow/conftest.py index da7fdffc3cf..6f6807e907d 100644 --- a/python/pyarrow/conftest.py +++ b/python/pyarrow/conftest.py @@ -289,7 +289,7 @@ def unary_agg_func_fixture(): """ from pyarrow import compute as pc - def func(ctx, x, *args): + def func(ctx, x): return pa.scalar(np.nanmean(x)) func_name = "mean_udf" diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 2667d5c93b5..d761520ac0c 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -161,14 +161,12 @@ struct PythonTableUdfKernelInit { }; struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { - PythonUdfScalarAggregatorImpl(UdfWrapperCallback agg_cb, - std::shared_ptr agg_function, + PythonUdfScalarAggregatorImpl(std::shared_ptr function, + UdfWrapperCallback cb, std::vector> input_types, std::shared_ptr output_type) - : agg_cb(std::move(agg_cb)), - agg_function(agg_function), - output_type(std::move(output_type)) { - Py_INCREF(agg_function->obj()); + : function(function), cb(std::move(cb)), output_type(std::move(output_type)) { + Py_INCREF(function->obj()); std::vector> fields; for (size_t i = 0; i < input_types.size(); i++) { fields.push_back(field("", input_types[i])); @@ -178,7 +176,7 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { ~PythonUdfScalarAggregatorImpl() override { if (_Py_IsFinalizing()) { - agg_function->detach(); + function->detach(); } } @@ -201,7 +199,6 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { Status Finalize(compute::KernelContext* ctx, Datum* out) override { auto state = arrow::internal::checked_cast(ctx->state()); - std::shared_ptr& function = state->agg_function; const int num_args = input_schema->num_fields(); // Note: The way that batches 
are concatenated together @@ -232,8 +229,8 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { PyObject* data = wrap_array(c_data); PyTuple_SetItem(arg_tuple.obj(), arg_id, data); } - result = std::make_unique( - agg_cb(function->obj(), udf_context, arg_tuple.obj())); + result = + std::make_unique(cb(function->obj(), udf_context, arg_tuple.obj())); RETURN_NOT_OK(CheckPyError()); // unwrapping the output for expected output type if (is_scalar(result->obj())) { @@ -252,21 +249,19 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { return Status::OK(); } - UdfWrapperCallback agg_cb; + std::shared_ptr function; + UdfWrapperCallback cb; std::vector> values; - std::shared_ptr agg_function; std::shared_ptr input_schema; std::shared_ptr output_type; }; struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { - PythonUdfHashAggregatorImpl(UdfWrapperCallback agg_cb, - std::shared_ptr function, + PythonUdfHashAggregatorImpl(std::shared_ptr function, + UdfWrapperCallback cb, std::vector> input_types, std::shared_ptr output_type) - : agg_cb(std::move(agg_cb)), - function(function), - output_type(std::move(output_type)) { + : function(function), cb(std::move(cb)), output_type(std::move(output_type)) { Py_INCREF(function->obj()); std::vector> fields; for (size_t i = 0; i < input_types.size(); i++) { @@ -281,6 +276,8 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { } } + /// @brief Same as ApplyGrouping in parition.cc + /// Replicated the code here to avoid complicating the dependencies static Result ApplyGroupings( const ListArray& groupings, const std::shared_ptr& batch) { ARROW_ASSIGN_OR_RAISE(Datum sorted, @@ -297,19 +294,23 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { } Status Resize(KernelContext* ctx, int64_t new_num_groups) { + // We only need to change num_groups in resize + // similar to other hash aggregate kernels num_groups = new_num_groups; return Status::OK(); } + Status Consume(KernelContext* ctx, const ExecSpan& batch) { - // last array is the group id ARROW_ASSIGN_OR_RAISE( std::shared_ptr rb, batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool())); + // This is similar to GroupedListImpl + // last array is the group id const ArraySpan& groups_array_data = batch[batch.num_values() - 1].array; + DCHECK_EQ(groups_array_data.offset, 0); int64_t batch_num_values = groups_array_data.length; const auto* batch_groups = groups_array_data.GetValues(1, 0); - DCHECK_EQ(groups_array_data.offset, 0); RETURN_NOT_OK(groups.Append(batch_groups, batch_num_values)); values.push_back(std::move(rb)); num_values += batch_num_values; @@ -317,6 +318,7 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { } Status Merge(KernelContext* ctx, KernelState&& other_state, const ArrayData& group_id_mapping) { + // This is similar to GroupedListImpl auto& other = checked_cast(other_state); auto& other_values = other.values; const uint32_t* other_raw_groups = other.groups.data(); @@ -326,6 +328,8 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { auto g = group_id_mapping.GetValues(1); for (uint32_t other_g = 0; static_cast(other_g) < other.num_values; ++other_g) { + // Different state can have different group_id mappings, so we + // need to translate the ids RETURN_NOT_OK(groups.Append(g[other_raw_groups[other_g]])); } @@ -368,8 +372,8 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { PyTuple_SetItem(arg_tuple.obj(), arg_id, data); } - result = std::make_unique( - 
agg_cb(function->obj(), udf_context, arg_tuple.obj())); + result = + std::make_unique(cb(function->obj(), udf_context, arg_tuple.obj())); RETURN_NOT_OK(CheckPyError()); // unwrapping the output for expected output type @@ -393,12 +397,14 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { }); } - int64_t num_groups = 0; - int64_t num_values = 0; - UdfWrapperCallback agg_cb; + std::shared_ptr function; + UdfWrapperCallback cb; + // Accumulated input batches std::vector> values; + // Group ids - extracted from the last column from the batch TypedBufferBuilder groups; - std::shared_ptr function; + int64_t num_groups = 0; + int64_t num_values = 0; std::shared_ptr input_schema; std::shared_ptr output_type; }; @@ -513,15 +519,15 @@ Status RegisterUdf(PyObject* user_function, compute::KernelInit kernel_init, } // namespace -Status RegisterScalarFunction(PyObject* user_function, UdfWrapperCallback wrapper, +Status RegisterScalarFunction(PyObject* function, UdfWrapperCallback cb, const UdfOptions& options, compute::FunctionRegistry* registry) { - return RegisterUdf(user_function, - PythonUdfKernelInit{std::make_shared(user_function)}, - wrapper, options, registry); + return RegisterUdf(function, + PythonUdfKernelInit{std::make_shared(function)}, cb, + options, registry); } -Status RegisterTabularFunction(PyObject* user_function, UdfWrapperCallback wrapper, +Status RegisterTabularFunction(PyObject* function, UdfWrapperCallback cb, const UdfOptions& options, compute::FunctionRegistry* registry) { if (options.arity.num_args != 0 || options.arity.is_varargs) { @@ -531,16 +537,14 @@ Status RegisterTabularFunction(PyObject* user_function, UdfWrapperCallback wrapp return Status::Invalid("tabular function with non-struct output"); } return RegisterUdf( - user_function, - PythonTableUdfKernelInit{std::make_shared(user_function), wrapper}, - wrapper, options, registry); + function, PythonTableUdfKernelInit{std::make_shared(function), cb}, + cb, options, registry); } -Status RegisterScalarAggregateFunction(PyObject* agg_function, - UdfWrapperCallback agg_wrapper, +Status RegisterScalarAggregateFunction(PyObject* function, UdfWrapperCallback cb, const UdfOptions& options, compute::FunctionRegistry* registry) { - if (!PyCallable_Check(agg_function)) { + if (!PyCallable_Check(function)) { return Status::TypeError("Expected a callable Python object."); } @@ -551,7 +555,7 @@ Status RegisterScalarAggregateFunction(PyObject* agg_function, // Py_INCREF here so that once a function is registered // its refcount gets increased by 1 and doesn't get gced // if all existing refs are gone - Py_INCREF(agg_function); + Py_INCREF(function); static auto default_scalar_aggregate_options = compute::ScalarAggregateOptions::Defaults(); @@ -565,12 +569,11 @@ Status RegisterScalarAggregateFunction(PyObject* agg_function, } compute::OutputType output_type(options.output_type); - compute::KernelInit init = [agg_wrapper, agg_function, options]( - compute::KernelContext* ctx, - const compute::KernelInitArgs& args) + compute::KernelInit init = [cb, function, options](compute::KernelContext* ctx, + const compute::KernelInitArgs& args) -> Result> { return std::make_unique( - agg_wrapper, std::make_shared(agg_function), options.input_types, + std::make_shared(function), cb, options.input_types, options.output_type); }; @@ -611,11 +614,10 @@ UdfOptions AdjustForHashAggregate(const UdfOptions& options) { return hash_options; } -Status RegisterHashAggregateFunction(PyObject* agg_function, - UdfWrapperCallback agg_wrapper, 
+Status RegisterHashAggregateFunction(PyObject* function, UdfWrapperCallback cb, const UdfOptions& options, compute::FunctionRegistry* registry) { - if (!PyCallable_Check(agg_function)) { + if (!PyCallable_Check(function)) { return Status::TypeError("Expected a callable Python object."); } @@ -626,7 +628,7 @@ Status RegisterHashAggregateFunction(PyObject* agg_function, // Py_INCREF here so that once a function is registered // its refcount gets increased by 1 and doesn't get gced // if all existing refs are gone - Py_INCREF(agg_function); + Py_INCREF(function); UdfOptions hash_options = AdjustForHashAggregate(options); std::vector input_types; @@ -641,13 +643,13 @@ Status RegisterHashAggregateFunction(PyObject* agg_function, hash_options.func_name, hash_options.arity, hash_options.func_doc, &default_hash_aggregate_options); - compute::KernelInit init = [agg_wrapper, agg_function, hash_options]( + compute::KernelInit init = [function, cb, hash_options]( compute::KernelContext* ctx, const compute::KernelInitArgs& args) -> Result> { return std::make_unique( - agg_wrapper, std::make_shared(agg_function), - hash_options.input_types, hash_options.output_type); + std::make_shared(function), cb, hash_options.input_types, + hash_options.output_type); }; auto sig = compute::KernelSignature::Make( @@ -661,11 +663,11 @@ Status RegisterHashAggregateFunction(PyObject* agg_function, return Status::OK(); } -Status RegisterAggregateFunction(PyObject* function, UdfWrapperCallback wrapper, +Status RegisterAggregateFunction(PyObject* function, UdfWrapperCallback cb, const UdfOptions& options, compute::FunctionRegistry* registry) { - RETURN_NOT_OK(RegisterScalarAggregateFunction(function, wrapper, options, registry)); - RETURN_NOT_OK(RegisterHashAggregateFunction(function, wrapper, options, registry)); + RETURN_NOT_OK(RegisterScalarAggregateFunction(function, cb, options, registry)); + RETURN_NOT_OK(RegisterHashAggregateFunction(function, cb, options, registry)); return Status::OK(); } From 1519e456e861d1917a9d46a5c48d6c6aff4a9037 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 22 Jun 2023 17:28:26 -0400 Subject: [PATCH 06/11] More tests for error case --- python/pyarrow/tests/test_udf.py | 113 ++++++++++++++++++++----------- 1 file changed, 73 insertions(+), 40 deletions(-) diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 06f5e7e3250..48667b35369 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -680,13 +680,50 @@ def test_udt_datasource1_exception(): _test_datasource1_udt(datasource1_exception) -def test_agg_basic(unary_agg_func_fixture): +def test_scalar_agg_basic(unary_agg_func_fixture): arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64()) result = pc.call_function("mean_udf", [arr]) expected = pa.scalar(30.0) assert result == expected +def test_scalar_agg_empty(unary_agg_func_fixture): + empty = pa.array([], pa.float64()) + + with pytest.raises(pa.ArrowInvalid, match='empty inputs'): + pc.call_function("mean_udf", [empty]) + + +def test_scalar_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture): + arr = pa.array([10, 20, 30, 40, 50], pa.int64()) + with pytest.raises(pa.ArrowTypeError, match="output datatype"): + pc.call_function("y=wrong_output_dtype(x)", [arr]) + + +def test_scalar_agg_wrong_output_type(wrong_output_type_agg_func_fixture): + arr = pa.array([10, 20, 30, 40, 50], pa.int64()) + with pytest.raises(pa.ArrowTypeError, match="output type"): + pc.call_function("y=wrong_output_type(x)", [arr]) + + +def 
test_scalar_agg_varargs(varargs_agg_func_fixture): + arr1 = pa.array([10, 20, 30, 40, 50], pa.int64()) + arr2 = pa.array([1.0, 2.0, 3.0, 4.0, 5.0], pa.float64()) + + result = pc.call_function( + "sum_mean", [arr1, arr2] + ) + expected = pa.scalar(33.0) + assert result == expected + + +def test_scalar_agg_exception(exception_agg_func_fixture): + arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64()) + + with pytest.raises(RuntimeError, match='Oops'): + pc.call_function("y=exception_len(x)", [arr]) + + def test_hash_agg_basic(unary_agg_func_fixture): arr1 = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64()) arr2 = pa.array([4, 2, 1, 2, 1], pa.int32()) @@ -705,59 +742,55 @@ def test_hash_agg_basic(unary_agg_func_fixture): assert result.sort_by('id') == expected.sort_by('id') -def test_hash_agg_random(sum_agg_func_fixture): - """Test hash aggregate udf with randomly sampled data""" - - value_num = 1000000 - group_num = 1000 - seed = 1 +def test_hash_agg_empty(unary_agg_func_fixture): + arr1 = pa.array([], pa.float64()) + arr2 = pa.array([], pa.int32()) + table = pa.table([arr2, arr1], names=["id", "value"]) - rng = np.random.default_rng(seed=seed) + with pytest.raises(pa.ArrowInvalid, match='empty inputs'): + result = table.group_by("id").aggregate([("value", "mean_udf")]) - arr1 = pa.array(np.repeat(1, value_num), pa.float64()) - arr2 = pa.array(rng.choice(group_num, value_num), pa.int32()) - table = pa.table([arr2, arr1], names=['id', 'value']) - - result = table.group_by("id").aggregate([("value", "sum_udf")]) - expected = table.group_by("id").aggregate( - [("value", "sum")]).rename_columns(['id', 'value_sum_udf']) +def test_hash_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture): + arr1 = pa.array([10, 20, 30, 40, 50], pa.int64()) + arr2 = pa.array([4, 2, 1, 2, 1], pa.int32()) - assert result.sort_by('id') == expected.sort_by('id') + table = pa.table([arr2, arr1], names=["id", "value"]) + with pytest.raises(pa.ArrowTypeError, match="output datatype"): + result = table.group_by("id").aggregate([("value", "y=wrong_output_dtype(x)")]) -def test_agg_empty(unary_agg_func_fixture): - empty = pa.array([], pa.float64()) +def test_hash_agg_wrong_output_type(wrong_output_type_agg_func_fixture): + arr1 = pa.array([10, 20, 30, 40, 50], pa.int64()) + arr2 = pa.array([4, 2, 1, 2, 1], pa.int32()) + table = pa.table([arr2, arr1], names=["id", "value"]) - with pytest.raises(pa.ArrowInvalid, match='empty inputs'): - pc.call_function("mean_udf", [empty]) + with pytest.raises(pa.ArrowTypeError, match="output type"): + result = table.group_by("id").aggregate([("value", "y=wrong_output_type(x)")]) -def test_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture): - arr = pa.array([10, 20, 30, 40, 50], pa.int64()) - with pytest.raises(pa.ArrowTypeError, match="output datatype"): - pc.call_function("y=wrong_output_dtype(x)", [arr]) +def test_hash_agg_exception(exception_agg_func_fixture): + arr1 = pa.array([10, 20, 30, 40, 50], pa.int64()) + arr2 = pa.array([4, 2, 1, 2, 1], pa.int32()) + table = pa.table([arr2, arr1], names=["id", "value"]) + with pytest.raises(RuntimeError, match='Oops'): + result = table.group_by("id").aggregate([("value", "y=exception_len(x)")]) -def test_agg_wrong_output_type(wrong_output_type_agg_func_fixture): - arr = pa.array([10, 20, 30, 40, 50], pa.int64()) - with pytest.raises(pa.ArrowTypeError, match="output type"): - pc.call_function("y=wrong_output_type(x)", [arr]) +def test_hash_agg_random(sum_agg_func_fixture): + """Test hash aggregate udf with randomly sampled 
From 14b1207b629838ea5fa6b88840bdac042d6577fd Mon Sep 17 00:00:00 2001
From: Li Jin
Date: Thu, 22 Jun 2023 17:41:12 -0400
Subject: [PATCH 07/11] Lint

---
 python/pyarrow/tests/test_udf.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py
index 48667b35369..84f4985abe2 100644
--- a/python/pyarrow/tests/test_udf.py
+++ b/python/pyarrow/tests/test_udf.py
@@ -748,7 +748,7 @@ def test_hash_agg_empty(unary_agg_func_fixture):
     table = pa.table([arr2, arr1], names=["id", "value"])
 
     with pytest.raises(pa.ArrowInvalid, match='empty inputs'):
-        result = table.group_by("id").aggregate([("value", "mean_udf")])
+        table.group_by("id").aggregate([("value", "mean_udf")])
 
 
 def test_hash_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture):
@@ -757,7 +757,7 @@ def test_hash_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture):
     table = pa.table([arr2, arr1], names=["id", "value"])
 
     with pytest.raises(pa.ArrowTypeError, match="output datatype"):
-        result = table.group_by("id").aggregate([("value", "y=wrong_output_dtype(x)")])
+        table.group_by("id").aggregate([("value", "y=wrong_output_dtype(x)")])
 
 
 def test_hash_agg_wrong_output_type(wrong_output_type_agg_func_fixture):
@@ -766,7 +766,7 @@ def test_hash_agg_wrong_output_type(wrong_output_type_agg_func_fixture):
     table = pa.table([arr2, arr1], names=["id", "value"])
 
     with pytest.raises(pa.ArrowTypeError, match="output type"):
-        result = table.group_by("id").aggregate([("value", "y=wrong_output_type(x)")])
+        table.group_by("id").aggregate([("value", "y=wrong_output_type(x)")])
 
 
 def test_hash_agg_exception(exception_agg_func_fixture):
@@ -775,7 +775,7 @@ def test_hash_agg_exception(exception_agg_func_fixture):
     table = pa.table([arr2, arr1], names=["id", "value"])
 
     with pytest.raises(RuntimeError, match='Oops'):
-        result = table.group_by("id").aggregate([("value", "y=exception_len(x)")])
+        table.group_by("id").aggregate([("value", "y=exception_len(x)")])
 
 
 def test_hash_agg_random(sum_agg_func_fixture):

From d2d00766f9e03cc47c0bd173ba372d359480741b Mon Sep 17 00:00:00 2001
From: Li Jin
Date: Fri, 23 Jun 2023 09:23:58 -0400
Subject: [PATCH 08/11] Fix doc example

---
 python/pyarrow/_compute.pyx | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index 380fd7e4227..3de9598992c 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -2824,7 +2824,7 @@ def register_aggregate_function(func, function_name, function_doc, in_types, out
     >>> answer
     >>> table = pa.table([pa.array([1, 1, 2, 2]), pa.array([10, 20, 30, 40])], names=['k', 'v'])
-    >>> result = table.group_by('k').aggregate(['v', 'py_compute_median'])
+    >>> result = table.group_by('k').aggregate([('v', 'py_compute_median')])
     >>> result
     pyarrow.Table
     k: int64
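
The one-line doc fix above concerns TableGroupBy.aggregate's calling convention: each requested aggregation is a (column, function) tuple, not a flat list of names. UDF-backed aggregations registered through register_aggregate_function use the same syntax as the built-in kernels, as in this small illustration (sample data made up):

    import pyarrow as pa

    table = pa.table({"k": [1, 1, 2, 2], "v": [10, 20, 30, 40]})
    # One (target_column, aggregation_name) tuple per requested aggregation.
    table.group_by("k").aggregate([("v", "sum"), ("v", "mean")])
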
From 007221a9588a24f0e1fdbdd609e0aff7f1f6b5dd Mon Sep 17 00:00:00 2001
From: Li Jin
Date: Mon, 26 Jun 2023 09:59:32 -0400
Subject: [PATCH 09/11] Apply suggestions from code review

Co-authored-by: Weston Pace
---
 python/pyarrow/src/arrow/python/udf.cc | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc
index d761520ac0c..f0516e44190 100644
--- a/python/pyarrow/src/arrow/python/udf.cc
+++ b/python/pyarrow/src/arrow/python/udf.cc
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-#include
 
 #include "arrow/array/builder_base.h"
 #include "arrow/buffer_builder.h"
@@ -264,6 +263,7 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
       : function(function), cb(std::move(cb)), output_type(std::move(output_type)) {
     Py_INCREF(function->obj());
     std::vector<std::shared_ptr<Field>> fields;
+    fields.reserve(input_types.size());
     for (size_t i = 0; i < input_types.size(); i++) {
       fields.push_back(field("", input_types[i]));
     }
@@ -276,8 +276,8 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
     }
   }
 
-  /// @brief Same as ApplyGrouping in parition.cc
-  /// Replicated the code here to avoid complicating the dependencies
+  // same as ApplyGrouping in parition.cc
+  // replicated the code here to avoid complicating the dependencies
  static Result<RecordBatchVector> ApplyGroupings(
      const ListArray& groupings, const std::shared_ptr<RecordBatch>& batch) {
    ARROW_ASSIGN_OR_RAISE(Datum sorted,
@@ -587,8 +587,8 @@ Status RegisterScalarAggregateFunction(PyObject* function, UdfWrapperCallback cb
   return Status::OK();
 }
 
-/// @brief Create a new UdfOptions with adjustment for hash kernel
-/// @param options User provided udf options
+/// \brief Create a new UdfOptions with adjustment for hash kernel
+/// \param options User provided udf options
 UdfOptions AdjustForHashAggregate(const UdfOptions& options) {
   UdfOptions hash_options;
   // Append hash_ before the function name to seperate from the scalar

From 378eb48c8fa35207f0afd61b8b020bb8e1f62d5d Mon Sep 17 00:00:00 2001
From: Li Jin
Date: Mon, 26 Jun 2023 13:45:58 -0400
Subject: [PATCH 10/11] Address PR comments

---
 python/pyarrow/_compute.pyx            | 3 +++
 python/pyarrow/src/arrow/python/udf.cc | 8 ++++----
 python/pyarrow/tests/test_udf.py       | 7 +++++--
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index 3de9598992c..bec985ca034 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -2767,6 +2767,9 @@ def register_aggregate_function(func, function_name, function_doc, in_types, out
     This is often used with ordered or segmented aggregation where groups
     can be emit before accumulating all of the input data.
 
+    Note that currently size of any input column can not exceed 2 GB limit
+    (all groups combined).
+
     Parameters
     ----------
     func : callable
diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc
index f0516e44190..435c89f596d 100644
--- a/python/pyarrow/src/arrow/python/udf.cc
+++ b/python/pyarrow/src/arrow/python/udf.cc
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-
+#include "arrow/python/udf.h"
 
 #include "arrow/array/builder_base.h"
 #include "arrow/buffer_builder.h"
 #include "arrow/compute/api_aggregate.h"
@@ -24,7 +24,6 @@
 #include "arrow/compute/kernel.h"
 #include "arrow/compute/row/grouper.h"
 #include "arrow/python/common.h"
-#include "arrow/python/udf.h"
 #include "arrow/table.h"
 #include "arrow/util/checked_cast.h"
 #include "arrow/util/logging.h"
@@ -310,7 +309,7 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
     const ArraySpan& groups_array_data = batch[batch.num_values() - 1].array;
     DCHECK_EQ(groups_array_data.offset, 0);
     int64_t batch_num_values = groups_array_data.length;
-    const auto* batch_groups = groups_array_data.GetValues<uint32_t>(1, 0);
+    const auto* batch_groups = groups_array_data.GetValues<uint32_t>(1);
     RETURN_NOT_OK(groups.Append(batch_groups, batch_num_values));
     values.push_back(std::move(rb));
     num_values += batch_num_values;
@@ -352,7 +351,8 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
     UdfContext udf_context{ctx->memory_pool(), table->num_rows()};
 
     if (rb->num_rows() == 0) {
-      return Status::Invalid("Finalized is called with empty inputs");
+      *out = Datum();
+      return Status::OK();
     }
 
     ARROW_ASSIGN_OR_RAISE(RecordBatchVector rbs, ApplyGroupings(*groupings, rb));
diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py
index 84f4985abe2..5631e19455c 100644
--- a/python/pyarrow/tests/test_udf.py
+++ b/python/pyarrow/tests/test_udf.py
@@ -747,8 +747,11 @@ def test_hash_agg_empty(unary_agg_func_fixture):
     arr2 = pa.array([], pa.int32())
     table = pa.table([arr2, arr1], names=["id", "value"])
 
-    with pytest.raises(pa.ArrowInvalid, match='empty inputs'):
-        table.group_by("id").aggregate([("value", "mean_udf")])
+    result = table.group_by("id").aggregate([("value", "mean_udf")])
+    expected = pa.table([pa.array([], pa.int32()), pa.array(
+        [], pa.float64())], names=['id', 'value_mean_udf'])
+
+    assert result == expected
 
 
 def test_hash_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture):
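
After the Finalize change above, an empty input is no longer an error for the hash kernel: grouping an empty table with a UDF aggregation is expected to return an empty table whose result column follows the usual <column>_<function> naming, which is what the updated test asserts. Roughly (mean_udf is the fixture-registered UDF from the tests):

    import pyarrow as pa

    empty = pa.table([pa.array([], pa.int32()), pa.array([], pa.float64())],
                     names=["id", "value"])
    result = empty.group_by("id").aggregate([("value", "mean_udf")])
    # result: a 0-row table with columns ["id", "value_mean_udf"]
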
From c29ae7260a21009bb1a667769a6bb9eac3597e0d Mon Sep 17 00:00:00 2001
From: Li Jin
Date: Wed, 28 Jun 2023 18:57:12 -0400
Subject: [PATCH 11/11] Apply suggestions from code review

Co-authored-by: Weston Pace
---
 python/pyarrow/_compute.pyx | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index bec985ca034..d0b1ef35fc7 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -2767,8 +2767,8 @@ def register_aggregate_function(func, function_name, function_doc, in_types, out
     This is often used with ordered or segmented aggregation where groups
     can be emit before accumulating all of the input data.
 
-    Note that currently size of any input column can not exceed 2 GB limit
-    (all groups combined).
+    Note that currently the size of any input column can not exceed 2 GB
+    for a single segment (all groups combined).
 
     Parameters
     ----------