diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index eaf9d1dfb65..d0b1ef35fc7 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2767,6 +2767,9 @@ def register_aggregate_function(func, function_name, function_doc, in_types, out This is often used with ordered or segmented aggregation where groups can be emitted before accumulating all of the input data. + Note that currently the size of any input column cannot exceed 2 GB + for a single segment (all groups combined). + Parameters ---------- func : callable @@ -2823,6 +2826,15 @@ def register_aggregate_function(func, function_name, function_doc, in_types, out >>> answer = pc.call_function(func_name, [pa.array([20, 40])]) >>> answer <pyarrow.DoubleScalar: 30.0> + >>> table = pa.table([pa.array([1, 1, 2, 2]), pa.array([10, 20, 30, 40])], names=['k', 'v']) + >>> result = table.group_by('k').aggregate([('v', 'py_compute_median')]) + >>> result + pyarrow.Table + k: int64 + v_py_compute_median: double + ---- + k: [[1,2]] + v_py_compute_median: [[15,35]] """ return _register_user_defined_function(get_register_aggregate_function(), func, function_name, function_doc, in_types, diff --git a/python/pyarrow/conftest.py b/python/pyarrow/conftest.py index f32cbf01efc..6f6807e907d 100644 --- a/python/pyarrow/conftest.py +++ b/python/pyarrow/conftest.py @@ -20,6 +20,8 @@ from pyarrow import Codec from pyarrow import fs +import numpy as np + groups = [ 'acero', 'brotli', @@ -283,15 +285,14 @@ def unary_function(ctx, x): @pytest.fixture(scope="session") def unary_agg_func_fixture(): """ - Register a unary aggregate function + Register a unary aggregate function (mean) """ from pyarrow import compute as pc - import numpy as np def func(ctx, x): return pa.scalar(np.nanmean(x)) - func_name = "y=avg(x)" + func_name = "mean_udf" func_doc = {"summary": "y=avg(x)", "description": "find mean of x"} @@ -312,7 +313,6 @@ def varargs_agg_func_fixture(): Register a varargs aggregate function """ from pyarrow import compute as pc - import numpy as np def func(ctx, *args): sum = 0.0 @@ -320,7 +320,7 @@ def func(ctx, *args): sum += np.nanmean(arg) return pa.scalar(sum) - func_name = "y=sum_mean(x...)" + func_name = "sum_mean" func_doc = {"summary": "Varargs aggregate", "description": "Varargs aggregate"}
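A minimal sketch of the flow these fixtures exercise, assuming a build with this patch: a single registration call installs both a scalar aggregate kernel (callable by its registered name) and a "hash_"-prefixed hash kernel that grouped aggregation dispatches to. Names and doc values mirror the mean_udf fixture above.

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc

def func(ctx, x):
    return pa.scalar(np.nanmean(x))

pc.register_aggregate_function(
    func, "mean_udf",
    {"summary": "y=avg(x)", "description": "find mean of x"},
    {"x": pa.float64()}, pa.float64())

# Scalar path: aggregate a whole array through the registered name.
assert pc.call_function("mean_udf", [pa.array([10.0, 20.0, 30.0])]) == pa.scalar(20.0)

# Hash path: group_by() resolves "mean_udf" to the auto-registered "hash_mean_udf".
table = pa.table({"id": [1, 1, 2], "value": [10.0, 20.0, 30.0]})
print(table.group_by("id").aggregate([("value", "mean_udf")]))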
#include "arrow/python/udf.h" +#include "arrow/array/builder_base.h" +#include "arrow/buffer_builder.h" #include "arrow/compute/api_aggregate.h" +#include "arrow/compute/api_vector.h" #include "arrow/compute/function.h" #include "arrow/compute/kernel.h" +#include "arrow/compute/row/grouper.h" #include "arrow/python/common.h" #include "arrow/table.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/logging.h" namespace arrow { +using compute::ExecSpan; +using compute::Grouper; +using compute::KernelContext; +using compute::KernelState; using internal::checked_cast; + namespace py { namespace { @@ -73,6 +83,13 @@ struct ScalarUdfAggregator : public compute::KernelState { virtual Status Finalize(compute::KernelContext* ctx, Datum* out) = 0; }; +struct HashUdfAggregator : public compute::KernelState { + virtual Status Resize(KernelContext* ctx, int64_t size) = 0; + virtual Status Consume(KernelContext* ctx, const ExecSpan& batch) = 0; + virtual Status Merge(KernelContext* ct, KernelState&& other, const ArrayData&) = 0; + virtual Status Finalize(KernelContext* ctx, Datum* out) = 0; +}; + arrow::Status AggregateUdfConsume(compute::KernelContext* ctx, const compute::ExecSpan& batch) { return checked_cast(ctx->state())->Consume(ctx, batch); @@ -87,6 +104,24 @@ arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* ou return checked_cast(ctx->state())->Finalize(ctx, out); } +arrow::Status HashAggregateUdfResize(KernelContext* ctx, int64_t size) { + return checked_cast(ctx->state())->Resize(ctx, size); +} + +arrow::Status HashAggregateUdfConsume(KernelContext* ctx, const ExecSpan& batch) { + return checked_cast(ctx->state())->Consume(ctx, batch); +} + +arrow::Status HashAggregateUdfMerge(KernelContext* ctx, KernelState&& src, + const ArrayData& group_id_mapping) { + return checked_cast(ctx->state()) + ->Merge(ctx, std::move(src), group_id_mapping); +} + +arrow::Status HashAggregateUdfFinalize(KernelContext* ctx, Datum* out) { + return checked_cast(ctx->state())->Finalize(ctx, out); +} + struct PythonTableUdfKernelInit { PythonTableUdfKernelInit(std::shared_ptr function_maker, UdfWrapperCallback cb) @@ -124,14 +159,12 @@ struct PythonTableUdfKernelInit { }; struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { - PythonUdfScalarAggregatorImpl(UdfWrapperCallback agg_cb, - std::shared_ptr agg_function, + PythonUdfScalarAggregatorImpl(std::shared_ptr function, + UdfWrapperCallback cb, std::vector> input_types, std::shared_ptr output_type) - : agg_cb(std::move(agg_cb)), - agg_function(agg_function), - output_type(std::move(output_type)) { - Py_INCREF(agg_function->obj()); + : function(function), cb(std::move(cb)), output_type(std::move(output_type)) { + Py_INCREF(function->obj()); std::vector> fields; for (size_t i = 0; i < input_types.size(); i++) { fields.push_back(field("", input_types[i])); @@ -141,7 +174,7 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { ~PythonUdfScalarAggregatorImpl() override { if (_Py_IsFinalizing()) { - agg_function->detach(); + function->detach(); } } @@ -164,7 +197,6 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { Status Finalize(compute::KernelContext* ctx, Datum* out) override { auto state = arrow::internal::checked_cast(ctx->state()); - std::shared_ptr& function = state->agg_function; const int num_args = input_schema->num_fields(); // Note: The way that batches are concatenated together @@ -195,8 +227,8 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { PyObject* 
data = wrap_array(c_data); PyTuple_SetItem(arg_tuple.obj(), arg_id, data); } - result = std::make_unique<OwnedRef>( - agg_cb(function->obj(), udf_context, arg_tuple.obj())); + result = + std::make_unique<OwnedRef>(cb(function->obj(), udf_context, arg_tuple.obj())); RETURN_NOT_OK(CheckPyError()); // unwrapping the output for expected output type if (is_scalar(result->obj())) { @@ -215,9 +247,164 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { return Status::OK(); } - UdfWrapperCallback agg_cb; + std::shared_ptr<OwnedRefNoGIL> function; + UdfWrapperCallback cb; std::vector<std::shared_ptr<RecordBatch>> values; - std::shared_ptr<OwnedRefNoGIL> agg_function; + std::shared_ptr<Schema> input_schema; std::shared_ptr<DataType> output_type; }; + +struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { + PythonUdfHashAggregatorImpl(std::shared_ptr<OwnedRefNoGIL> function, + UdfWrapperCallback cb, + std::vector<std::shared_ptr<DataType>> input_types, + std::shared_ptr<DataType> output_type) + : function(function), cb(std::move(cb)), output_type(std::move(output_type)) { + Py_INCREF(function->obj()); + std::vector<std::shared_ptr<Field>> fields; + fields.reserve(input_types.size()); + for (size_t i = 0; i < input_types.size(); i++) { + fields.push_back(field("", input_types[i])); + } + input_schema = schema(std::move(fields)); + }; + + ~PythonUdfHashAggregatorImpl() override { + if (_Py_IsFinalizing()) { + function->detach(); + } + } + + // same as ApplyGroupings in partition.cc + // replicated the code here to avoid complicating the dependencies + static Result<RecordBatchVector> ApplyGroupings( + const ListArray& groupings, const std::shared_ptr<RecordBatch>& batch) { + ARROW_ASSIGN_OR_RAISE(Datum sorted, + compute::Take(batch, groupings.data()->child_data[0])); + + const auto& sorted_batch = *sorted.record_batch(); + + RecordBatchVector out(static_cast<size_t>(groupings.length())); + for (size_t i = 0; i < out.size(); ++i) { + out[i] = sorted_batch.Slice(groupings.value_offset(i), groupings.value_length(i)); + } + + return out; + } + + Status Resize(KernelContext* ctx, int64_t new_num_groups) { + // We only need to change num_groups in resize + // similar to other hash aggregate kernels + num_groups = new_num_groups; + return Status::OK(); + } + + Status Consume(KernelContext* ctx, const ExecSpan& batch) { + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr<RecordBatch> rb, + batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool())); + + // This is similar to GroupedListImpl + // last array is the group id + const ArraySpan& groups_array_data = batch[batch.num_values() - 1].array; + DCHECK_EQ(groups_array_data.offset, 0); + int64_t batch_num_values = groups_array_data.length; + const auto* batch_groups = groups_array_data.GetValues<uint32_t>(1); + RETURN_NOT_OK(groups.Append(batch_groups, batch_num_values)); + values.push_back(std::move(rb)); + num_values += batch_num_values; + return Status::OK(); + } + Status Merge(KernelContext* ctx, KernelState&& other_state, + const ArrayData& group_id_mapping) { + // This is similar to GroupedListImpl + auto& other = checked_cast<PythonUdfHashAggregatorImpl&>(other_state); + auto& other_values = other.values; + const uint32_t* other_raw_groups = other.groups.data(); + values.insert(values.end(), std::make_move_iterator(other_values.begin()), + std::make_move_iterator(other_values.end())); + + auto g = group_id_mapping.GetValues<uint32_t>(1); + for (uint32_t other_g = 0; static_cast<int64_t>(other_g) < other.num_values; + ++other_g) { + // Different states can have different group_id mappings, so we + // need to translate the ids + RETURN_NOT_OK(groups.Append(g[other_raw_groups[other_g]])); + } + + num_values += other.num_values; + return Status::OK(); + } + + Status Finalize(KernelContext* ctx, Datum*
out) { + // Exclude the last column which is the group id + const int num_args = input_schema->num_fields() - 1; + + ARROW_ASSIGN_OR_RAISE(auto groups_buffer, groups.Finish()); + ARROW_ASSIGN_OR_RAISE(auto groupings, + Grouper::MakeGroupings(UInt32Array(num_values, groups_buffer), + static_cast<uint32_t>(num_groups))); + + ARROW_ASSIGN_OR_RAISE(auto table, + arrow::Table::FromRecordBatches(input_schema, values)); + ARROW_ASSIGN_OR_RAISE(auto rb, table->CombineChunksToBatch(ctx->memory_pool())); + UdfContext udf_context{ctx->memory_pool(), table->num_rows()}; + + if (rb->num_rows() == 0) { + *out = Datum(); + return Status::OK(); + } + + ARROW_ASSIGN_OR_RAISE(RecordBatchVector rbs, ApplyGroupings(*groupings, rb)); + + return SafeCallIntoPython([&] { + ARROW_ASSIGN_OR_RAISE(std::unique_ptr<ArrayBuilder> builder, + MakeBuilder(output_type, ctx->memory_pool())); + for (auto& group_rb : rbs) { + std::unique_ptr<OwnedRef> result; + OwnedRef arg_tuple(PyTuple_New(num_args)); + RETURN_NOT_OK(CheckPyError()); + + for (int arg_id = 0; arg_id < num_args; arg_id++) { + // Since we combined chunks there is only one chunk + std::shared_ptr<Array> c_data = group_rb->column(arg_id); + PyObject* data = wrap_array(c_data); + PyTuple_SetItem(arg_tuple.obj(), arg_id, data); + } + + result = + std::make_unique<OwnedRef>(cb(function->obj(), udf_context, arg_tuple.obj())); + RETURN_NOT_OK(CheckPyError()); + + // unwrapping the output for expected output type + if (is_scalar(result->obj())) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> val, + unwrap_scalar(result->obj())); + if (*output_type != *val->type) { + return Status::TypeError("Expected output datatype ", output_type->ToString(), + ", but function returned datatype ", + val->type->ToString()); + } + ARROW_RETURN_NOT_OK(builder->AppendScalar(std::move(*val))); + } else { + return Status::TypeError("Unexpected output type: ", + Py_TYPE(result->obj())->tp_name, " (expected Scalar)"); + } + } + ARROW_ASSIGN_OR_RAISE(auto result, builder->Finish()); + out->value = std::move(result->data()); + return Status::OK(); + }); + } + + std::shared_ptr<OwnedRefNoGIL> function; + UdfWrapperCallback cb; + // Accumulated input batches + std::vector<std::shared_ptr<RecordBatch>> values; + // Group ids, extracted from the last column of each batch + TypedBufferBuilder<uint32_t> groups; + int64_t num_groups = 0; + int64_t num_values = 0; std::shared_ptr<Schema> input_schema; std::shared_ptr<DataType> output_type; }; @@ -332,15 +519,15 @@ Status RegisterUdf(PyObject* user_function, compute::KernelInit kernel_init, } // namespace -Status RegisterScalarFunction(PyObject* user_function, UdfWrapperCallback wrapper, +Status RegisterScalarFunction(PyObject* function, UdfWrapperCallback cb, const UdfOptions& options, compute::FunctionRegistry* registry) { - return RegisterUdf(user_function, - PythonUdfKernelInit{std::make_shared<OwnedRefNoGIL>(user_function)}, - wrapper, options, registry); + return RegisterUdf(function, + PythonUdfKernelInit{std::make_shared<OwnedRefNoGIL>(function)}, cb, + options, registry); } -Status RegisterTabularFunction(PyObject* user_function, UdfWrapperCallback wrapper, +Status RegisterTabularFunction(PyObject* function, UdfWrapperCallback cb, const UdfOptions& options, compute::FunctionRegistry* registry) { if (options.arity.num_args != 0 || options.arity.is_varargs) { @@ -350,24 +537,14 @@ Status RegisterTabularFunction(PyObject* user_function, UdfWrapperCallback wrapp return Status::Invalid("tabular function with non-struct output"); } return RegisterUdf( - user_function, - PythonTableUdfKernelInit{std::make_shared<OwnedRefNoGIL>(user_function), wrapper}, - wrapper, options, registry); -} - -Status
AddAggKernel(std::shared_ptr<compute::KernelSignature> sig, - compute::KernelInit init, compute::ScalarAggregateFunction* func) { - compute::ScalarAggregateKernel kernel(std::move(sig), std::move(init), - AggregateUdfConsume, AggregateUdfMerge, - AggregateUdfFinalize, /*ordered=*/false); - RETURN_NOT_OK(func->AddKernel(std::move(kernel))); - return Status::OK(); + function, PythonTableUdfKernelInit{std::make_shared<OwnedRefNoGIL>(function), cb}, + cb, options, registry); } -Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_wrapper, - const UdfOptions& options, - compute::FunctionRegistry* registry) { - if (!PyCallable_Check(agg_function)) { +Status RegisterScalarAggregateFunction(PyObject* function, UdfWrapperCallback cb, + const UdfOptions& options, + compute::FunctionRegistry* registry) { + if (!PyCallable_Check(function)) { return Status::TypeError("Expected a callable Python object."); } @@ -378,7 +555,7 @@ Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_ // Py_INCREF here so that once a function is registered // its refcount gets increased by 1 and doesn't get gced // if all existing refs are gone - Py_INCREF(agg_function); + Py_INCREF(function); static auto default_scalar_aggregate_options = compute::ScalarAggregateOptions::Defaults(); @@ -392,24 +569,109 @@ Status RegisterAggregateFunction(PyObject* agg_function, UdfWrapperCallback agg_ } compute::OutputType output_type(options.output_type); - compute::KernelInit init = [agg_wrapper, agg_function, options]( - compute::KernelContext* ctx, - const compute::KernelInitArgs& args) + compute::KernelInit init = [cb, function, options](compute::KernelContext* ctx, + const compute::KernelInitArgs& args) -> Result<std::unique_ptr<compute::KernelState>> { return std::make_unique<PythonUdfScalarAggregatorImpl>( - agg_wrapper, std::make_shared<OwnedRefNoGIL>(agg_function), options.input_types, + std::make_shared<OwnedRefNoGIL>(function), cb, options.input_types, options.output_type); }; - RETURN_NOT_OK(AddAggKernel( - compute::KernelSignature::Make(std::move(input_types), std::move(output_type), - options.arity.is_varargs), - init, aggregate_func.get())); - + auto sig = compute::KernelSignature::Make( + std::move(input_types), std::move(output_type), options.arity.is_varargs); + compute::ScalarAggregateKernel kernel(std::move(sig), std::move(init), + AggregateUdfConsume, AggregateUdfMerge, + AggregateUdfFinalize, /*ordered=*/false); + RETURN_NOT_OK(aggregate_func->AddKernel(std::move(kernel))); RETURN_NOT_OK(registry->AddFunction(std::move(aggregate_func))); return Status::OK(); } +/// \brief Create a new UdfOptions with adjustment for hash kernel +/// \param options User provided udf options +UdfOptions AdjustForHashAggregate(const UdfOptions& options) { + UdfOptions hash_options; + // Prepend hash_ to the function name to separate it from the scalar + // version + hash_options.func_name = "hash_" + options.func_name; + // Extend input types with group id. Group id is appended by the group + // aggregation node. Here we change both arity and input types + if (options.arity.is_varargs) { + hash_options.arity = options.arity; + } else { + hash_options.arity = compute::Arity(options.arity.num_args + 1, false); + } + // Changing the function doc shouldn't be necessary because the group id + // is not user visible, however, this is currently needed to pass the + // function validation.
The name group_id_array is consistent with + // hash kernels in hash_aggregate.cc + hash_options.func_doc = options.func_doc; + hash_options.func_doc.arg_names.emplace_back("group_id_array"); + std::vector<std::shared_ptr<DataType>> input_dtypes = options.input_types; + input_dtypes.emplace_back(uint32()); + hash_options.input_types = std::move(input_dtypes); + hash_options.output_type = options.output_type; + return hash_options; +} + +Status RegisterHashAggregateFunction(PyObject* function, UdfWrapperCallback cb, + const UdfOptions& options, + compute::FunctionRegistry* registry) { + if (!PyCallable_Check(function)) { + return Status::TypeError("Expected a callable Python object."); + } + + if (registry == NULLPTR) { + registry = compute::GetFunctionRegistry(); + } + + // Py_INCREF here so that once a function is registered + // its refcount gets increased by 1 and doesn't get gced + // if all existing refs are gone + Py_INCREF(function); + UdfOptions hash_options = AdjustForHashAggregate(options); + + std::vector<compute::InputType> input_types; + for (const auto& in_dtype : hash_options.input_types) { + input_types.emplace_back(in_dtype); + } + compute::OutputType output_type(hash_options.output_type); + + static auto default_hash_aggregate_options = + compute::ScalarAggregateOptions::Defaults(); + auto hash_aggregate_func = std::make_shared<compute::HashAggregateFunction>( + hash_options.func_name, hash_options.arity, hash_options.func_doc, + &default_hash_aggregate_options); + + compute::KernelInit init = [function, cb, hash_options]( + compute::KernelContext* ctx, + const compute::KernelInitArgs& args) + -> Result<std::unique_ptr<compute::KernelState>> { + return std::make_unique<PythonUdfHashAggregatorImpl>( + std::make_shared<OwnedRefNoGIL>(function), cb, hash_options.input_types, + hash_options.output_type); + }; + + auto sig = compute::KernelSignature::Make( + std::move(input_types), std::move(output_type), hash_options.arity.is_varargs); + + compute::HashAggregateKernel kernel( + std::move(sig), std::move(init), HashAggregateUdfResize, HashAggregateUdfConsume, + HashAggregateUdfMerge, HashAggregateUdfFinalize, /*ordered=*/false); + RETURN_NOT_OK(hash_aggregate_func->AddKernel(std::move(kernel))); + RETURN_NOT_OK(registry->AddFunction(std::move(hash_aggregate_func))); + return Status::OK(); +} + +Status RegisterAggregateFunction(PyObject* function, UdfWrapperCallback cb, + const UdfOptions& options, + compute::FunctionRegistry* registry) { + RETURN_NOT_OK(RegisterScalarAggregateFunction(function, cb, options, registry)); + RETURN_NOT_OK(RegisterHashAggregateFunction(function, cb, options, registry)); + + return Status::OK(); +} + Result<std::shared_ptr<RecordBatchReader>> CallTabularFunction( const std::string& func_name, const std::vector<Datum>& args, compute::FunctionRegistry* registry) { diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index 34faaa157af..93ecae7bfa1 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -607,7 +607,7 @@ def table_provider(names, schema): assert res_tb == expected -def test_aggregate_udf_basic(varargs_agg_func_fixture): +def test_scalar_aggregate_udf_basic(varargs_agg_func_fixture): test_table = pa.Table.from_pydict( {"k": [1, 1, 2, 2], "v1": [1, 2, 3, 4], @@ -630,7 +630,7 @@ def table_provider(names, _): "extensionFunction": { "extensionUriReference": 1, "functionAnchor": 1, - "name": "y=sum_mean(x...)" + "name": "sum_mean" } } ], @@ -753,3 +753,173 @@ def table_provider(names, _): }) assert res_tb == expected_tb
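A quick hand check of the plan in the new test below: it segments by t, groups by k within each segment, and the varargs sum_mean UDF produces mean(v1) + mean(v2) per group. A sketch of the arithmetic behind the expected v_avg values, using the test data that follows:

import numpy as np

t = [1, 1, 1, 1, 2, 2, 2, 2]
k = [1, 0, 0, 1, 0, 1, 0, 1]
v1 = [1, 2, 3, 4, 5, 6, 7, 8]
v2 = [1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0]

for seg, key in [(1, 1), (1, 0), (2, 0), (2, 1)]:
    rows = [i for i in range(8) if t[i] == seg and k[i] == key]
    v_avg = np.nanmean([v1[i] for i in rows]) + np.nanmean([v2[i] for i in rows])
    print(seg, key, v_avg)  # prints 3.5, 3.5, 9.0, 11.0, matching expected_tb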
"k": [1, 0, 0, 1, 0, 1, 0, 1], + "v1": [1, 2, 3, 4, 5, 6, 7, 8], + "v2": [1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0]} + ) + + def table_provider(names, _): + return test_table + + substrait_query = b""" +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "urn:arrow:substrait_simple_extension_function" + }, + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "sum_mean" + } + } + ], + "relations": [ + { + "root": { + "input": { + "extensionSingle": { + "common": { + "emit": { + "outputMapping": [ + 0, + 1, + 2 + ] + } + }, + "input": { + "read": { + "baseSchema": { + "names": [ + "t", + "k", + "v1", + "v2", + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["t1"] + } + } + }, + "detail": { + "@type": "/arrow.substrait_ext.SegmentedAggregateRel", + "groupingKeys": [ + { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + ], + "segmentKeys": [ + { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + ], + "measures": [ + { + "measure": { + "functionReference": 1, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": {} + } + } + } + ] + } + } + ] + } + } + }, + "names": [ + "t", + "k", + "v_avg" + ] + } + } + ], +} +""" + buf = pa._substrait._parse_json_plan(substrait_query) + reader = pa.substrait.run_query( + buf, table_provider=table_provider, use_threads=False) + res_tb = reader.read_all() + + expected_tb = pa.Table.from_pydict({ + 't': [1, 1, 2, 2], + 'k': [1, 0, 0, 1], + 'v_avg': [3.5, 3.5, 9.0, 11.0] + }) + + # Ordering of k is deterministic because this is running with serial execution + assert res_tb == expected_tb diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index c0cfd3d26e8..5631e19455c 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -18,6 +18,8 @@ import pytest +import numpy as np + import pyarrow as pa from pyarrow import compute as pc @@ -42,6 +44,28 @@ class MyError(RuntimeError): pass +@pytest.fixture(scope="session") +def sum_agg_func_fixture(): + """ + Register a unary aggregate function (mean) + """ + def func(ctx, x, *args): + return pa.scalar(np.nansum(x)) + + func_name = "sum_udf" + func_doc = empty_udf_doc + + pc.register_aggregate_function(func, + func_name, + func_doc, + { + "x": pa.float64(), + }, + pa.float64() + ) + return func, func_name + + @pytest.fixture(scope="session") def exception_agg_func_fixture(): def func(ctx, x): @@ -656,45 +680,120 @@ def test_udt_datasource1_exception(): _test_datasource1_udt(datasource1_exception) -def test_agg_basic(unary_agg_func_fixture): +def test_scalar_agg_basic(unary_agg_func_fixture): arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64()) - result = pc.call_function("y=avg(x)", [arr]) + result = pc.call_function("mean_udf", [arr]) expected = pa.scalar(30.0) 
assert result == expected -def test_agg_empty(unary_agg_func_fixture): +def test_scalar_agg_empty(unary_agg_func_fixture): empty = pa.array([], pa.float64()) with pytest.raises(pa.ArrowInvalid, match='empty inputs'): - pc.call_function("y=avg(x)", [empty]) + pc.call_function("mean_udf", [empty]) -def test_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture): +def test_scalar_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture): arr = pa.array([10, 20, 30, 40, 50], pa.int64()) with pytest.raises(pa.ArrowTypeError, match="output datatype"): pc.call_function("y=wrong_output_dtype(x)", [arr]) -def test_agg_wrong_output_type(wrong_output_type_agg_func_fixture): +def test_scalar_agg_wrong_output_type(wrong_output_type_agg_func_fixture): arr = pa.array([10, 20, 30, 40, 50], pa.int64()) with pytest.raises(pa.ArrowTypeError, match="output type"): pc.call_function("y=wrong_output_type(x)", [arr]) -def test_agg_varargs(varargs_agg_func_fixture): +def test_scalar_agg_varargs(varargs_agg_func_fixture): arr1 = pa.array([10, 20, 30, 40, 50], pa.int64()) arr2 = pa.array([1.0, 2.0, 3.0, 4.0, 5.0], pa.float64()) result = pc.call_function( - "y=sum_mean(x...)", [arr1, arr2] + "sum_mean", [arr1, arr2] ) expected = pa.scalar(33.0) assert result == expected -def test_agg_exception(exception_agg_func_fixture): +def test_scalar_agg_exception(exception_agg_func_fixture): arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64()) with pytest.raises(RuntimeError, match='Oops'): pc.call_function("y=exception_len(x)", [arr]) + + +def test_hash_agg_basic(unary_agg_func_fixture): + arr1 = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64()) + arr2 = pa.array([4, 2, 1, 2, 1], pa.int32()) + + arr3 = pa.array([60.0, 70.0, 80.0, 90.0, 100.0], pa.float64()) + arr4 = pa.array([5, 1, 1, 4, 1], pa.int32()) + + table1 = pa.table([arr2, arr1], names=["id", "value"]) + table2 = pa.table([arr4, arr3], names=["id", "value"]) + table = pa.concat_tables([table1, table2]) + + result = table.group_by("id").aggregate([("value", "mean_udf")]) + expected = table.group_by("id").aggregate( + [("value", "mean")]).rename_columns(['id', 'value_mean_udf']) + + assert result.sort_by('id') == expected.sort_by('id') + + +def test_hash_agg_empty(unary_agg_func_fixture): + arr1 = pa.array([], pa.float64()) + arr2 = pa.array([], pa.int32()) + table = pa.table([arr2, arr1], names=["id", "value"]) + + result = table.group_by("id").aggregate([("value", "mean_udf")]) + expected = pa.table([pa.array([], pa.int32()), pa.array( + [], pa.float64())], names=['id', 'value_mean_udf']) + + assert result == expected + + +def test_hash_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture): + arr1 = pa.array([10, 20, 30, 40, 50], pa.int64()) + arr2 = pa.array([4, 2, 1, 2, 1], pa.int32()) + + table = pa.table([arr2, arr1], names=["id", "value"]) + with pytest.raises(pa.ArrowTypeError, match="output datatype"): + table.group_by("id").aggregate([("value", "y=wrong_output_dtype(x)")]) + + +def test_hash_agg_wrong_output_type(wrong_output_type_agg_func_fixture): + arr1 = pa.array([10, 20, 30, 40, 50], pa.int64()) + arr2 = pa.array([4, 2, 1, 2, 1], pa.int32()) + table = pa.table([arr2, arr1], names=["id", "value"]) + + with pytest.raises(pa.ArrowTypeError, match="output type"): + table.group_by("id").aggregate([("value", "y=wrong_output_type(x)")]) + + +def test_hash_agg_exception(exception_agg_func_fixture): + arr1 = pa.array([10, 20, 30, 40, 50], pa.int64()) + arr2 = pa.array([4, 2, 1, 2, 1], pa.int32()) + table = pa.table([arr2, arr1], names=["id", 
"value"]) + + with pytest.raises(RuntimeError, match='Oops'): + table.group_by("id").aggregate([("value", "y=exception_len(x)")]) + + +def test_hash_agg_random(sum_agg_func_fixture): + """Test hash aggregate udf with randomly sampled data""" + + value_num = 1000000 + group_num = 1000 + + arr1 = pa.array(np.repeat(1, value_num), pa.float64()) + arr2 = pa.array(np.random.choice(group_num, value_num), pa.int32()) + + table = pa.table([arr2, arr1], names=['id', 'value']) + + result = table.group_by("id").aggregate([("value", "sum_udf")]) + expected = table.group_by("id").aggregate( + [("value", "sum")]).rename_columns(['id', 'value_sum_udf']) + + assert result.sort_by('id') == expected.sort_by('id')