-
Notifications
You must be signed in to change notification settings - Fork 4k
GH-35515: [C++][Python] Add non decomposable aggregation UDF #35514
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0c1c1b2
94c9710
241e970
15194fa
e91b882
a2b89c6
c578057
ff04234
b1d51f7
3daefea
8fd8c96
8381f08
9d7fd9d
dc1d734
1203346
84c1e91
17ff274
7f65599
febf6cc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -954,7 +954,9 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate( | |
| return Status::Invalid("Expected aggregate call ", call.id().uri, "#", | ||
| call.id().name, " to have at least one argument"); | ||
| } | ||
| case 1: { | ||
| default: { | ||
| // Handles all arity > 0 | ||
|
|
||
| std::shared_ptr<compute::FunctionOptions> options = nullptr; | ||
| if (arrow_function_name == "stddev" || arrow_function_name == "variance") { | ||
| // See the following URL for the spec of stddev and variance: | ||
|
|
@@ -981,21 +983,22 @@ ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate( | |
| } | ||
| fixed_arrow_func += arrow_function_name; | ||
|
||
|
|
||
| ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(0)); | ||
| const FieldRef* arg_ref = arg.field_ref(); | ||
| if (!arg_ref) { | ||
| return Status::Invalid("Expected an aggregate call ", call.id().uri, "#", | ||
| call.id().name, " to have a direct reference"); | ||
| std::vector<FieldRef> target; | ||
| for (int i = 0; i < call.size(); i++) { | ||
| ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(i)); | ||
| const FieldRef* arg_ref = arg.field_ref(); | ||
|
||
| if (!arg_ref) { | ||
| return Status::Invalid("Expected an aggregate call ", call.id().uri, "#", | ||
| call.id().name, " to have a direct reference"); | ||
| } | ||
| // Copy arg_ref here because field_ref() returns a const FieldRef* | ||
| target.emplace_back(*arg_ref); | ||
icexelloss marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| return compute::Aggregate{std::move(fixed_arrow_func), | ||
| options ? std::move(options) : nullptr, *arg_ref, ""}; | ||
| options ? std::move(options) : nullptr, | ||
| std::move(target), ""}; | ||
| } | ||
| default: | ||
| break; | ||
| } | ||
| return Status::NotImplemented( | ||
| "Only nullary and unary aggregate functions are currently supported"); | ||
| }; | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,11 +21,11 @@ from pyarrow.lib cimport * | |
| from pyarrow.includes.common cimport * | ||
| from pyarrow.includes.libarrow cimport * | ||
|
|
||
| cdef class ScalarUdfContext(_Weakrefable): | ||
|
||
| cdef class UdfContext(_Weakrefable): | ||
| cdef: | ||
| CScalarUdfContext c_context | ||
| CUdfContext c_context | ||
|
|
||
| cdef void init(self, const CScalarUdfContext& c_context) | ||
| cdef void init(self, const CUdfContext& c_context) | ||
|
|
||
|
|
||
| cdef class FunctionOptions(_Weakrefable): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -278,3 +278,59 @@ def unary_function(ctx, x): | |
| {"array": pa.int64()}, | ||
| pa.int64()) | ||
| return unary_function, func_name | ||
|
|
||
|
|
||
| @pytest.fixture(scope="session") | ||
| def unary_agg_func_fixture(): | ||
|
||
| """ | ||
| Register a unary aggregate function | ||
| """ | ||
| from pyarrow import compute as pc | ||
| import numpy as np | ||
|
|
||
| def func(ctx, x): | ||
| return pa.scalar(np.nanmean(x)) | ||
|
|
||
| func_name = "y=avg(x)" | ||
| func_doc = {"summary": "y=avg(x)", | ||
| "description": "find mean of x"} | ||
|
|
||
| pc.register_aggregate_function(func, | ||
| func_name, | ||
| func_doc, | ||
| { | ||
| "x": pa.float64(), | ||
| }, | ||
| pa.float64() | ||
| ) | ||
| return func, func_name | ||
|
|
||
|
|
||
| @pytest.fixture(scope="session") | ||
| def varargs_agg_func_fixture(): | ||
| """ | ||
| Register a varargs aggregate function | ||
| """ | ||
| from pyarrow import compute as pc | ||
| import numpy as np | ||
|
|
||
| def func(ctx, *args): | ||
| sum = 0.0 | ||
| for arg in args: | ||
| sum += np.nanmean(arg) | ||
| return pa.scalar(sum) | ||
|
|
||
| func_name = "y=sum_mean(x...)" | ||
| func_doc = {"summary": "Varargs aggregate", | ||
| "description": "Varargs aggregate"} | ||
|
|
||
| pc.register_aggregate_function(func, | ||
| func_name, | ||
| func_doc, | ||
| { | ||
| "x": pa.int64(), | ||
| "y": pa.float64() | ||
|
||
| }, | ||
| pa.float64() | ||
| ) | ||
| return func, func_name | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is so that we can support UDFs which can have arity > 1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes