diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index 20f23662cade..f2a464de4155 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -88,7 +88,9 @@ impl Correlation { signature: Signature::exact( vec![DataType::Float64, DataType::Float64], Volatility::Immutable, - ), + ) + .with_parameter_names(vec!["y".to_string(), "x".to_string()]) + .expect("valid parameter names for corr"), } } } diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs b/datafusion/functions-aggregate/src/percentile_cont.rs index 8e9e9a3144d4..39017fea5464 100644 --- a/datafusion/functions-aggregate/src/percentile_cont.rs +++ b/datafusion/functions-aggregate/src/percentile_cont.rs @@ -146,7 +146,9 @@ impl PercentileCont { variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64])); } Self { - signature: Signature::one_of(variants, Volatility::Immutable), + signature: Signature::one_of(variants, Volatility::Immutable) + .with_parameter_names(vec!["expr".to_string(), "percentile".to_string()]) + .expect("valid parameter names for percentile_cont"), aliases: vec![String::from("quantile_cont")], } } diff --git a/datafusion/functions-window/src/lead_lag.rs b/datafusion/functions-window/src/lead_lag.rs index 3910a0be574d..02d7fc290b32 100644 --- a/datafusion/functions-window/src/lead_lag.rs +++ b/datafusion/functions-window/src/lead_lag.rs @@ -137,7 +137,13 @@ impl WindowShift { TypeSignature::Any(3), ], Volatility::Immutable, - ), + ) + .with_parameter_names(vec![ + "expr".to_string(), + "offset".to_string(), + "default".to_string(), + ]) + .expect("valid parameter names for lead/lag"), kind, } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index cb34bb0f7eb7..ba6324d1f5d1 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -386,7 +386,30 @@ impl SqlToRel<'_, S> { }; if let Ok(fun) = self.find_window_func(&name) { - let args = self.function_args_to_expr(args, schema, planner_context)?; + let (args, arg_names) = + self.function_args_to_expr_with_names(args, schema, planner_context)?; + + let resolved_args = if arg_names.iter().any(|name| name.is_some()) { + let signature = match &fun { + WindowFunctionDefinition::AggregateUDF(udaf) => udaf.signature(), + WindowFunctionDefinition::WindowUDF(udwf) => udwf.signature(), + }; + + if let Some(param_names) = &signature.parameter_names { + datafusion_expr::arguments::resolve_function_arguments( + param_names, + args, + arg_names, + )? + } else { + return plan_err!( + "Window function '{}' does not support named arguments", + name + ); + } + } else { + args + }; // Plan FILTER clause if present let filter = filter @@ -396,7 +419,7 @@ impl SqlToRel<'_, S> { let mut window_expr = RawWindowExpr { func_def: fun, - args, + args: resolved_args, partition_by, order_by, window_frame, @@ -464,8 +487,8 @@ impl SqlToRel<'_, S> { ); } - let mut args = - self.function_args_to_expr(args, schema, planner_context)?; + let (mut args, mut arg_names) = + self.function_args_to_expr_with_names(args, schema, planner_context)?; let order_by = if fm.is_ordered_set_aggregate() { let within_group = self.order_by_to_sort_expr( @@ -479,6 +502,12 @@ impl SqlToRel<'_, S> { // Add the WITHIN GROUP ordering expressions to the front of the argument list // So function(arg) WITHIN GROUP (ORDER BY x) becomes function(x, arg) if !within_group.is_empty() { + // Prepend None arg names for each WITHIN GROUP expression + let within_group_count = within_group.len(); + arg_names = std::iter::repeat_n(None, within_group_count) + .chain(arg_names) + .collect(); + args = within_group .iter() .map(|sort| sort.expr.clone()) @@ -506,9 +535,26 @@ impl SqlToRel<'_, S> { .transpose()? .map(Box::new); + let resolved_args = if arg_names.iter().any(|name| name.is_some()) { + if let Some(param_names) = &fm.signature().parameter_names { + datafusion_expr::arguments::resolve_function_arguments( + param_names, + args, + arg_names, + )? + } else { + return plan_err!( + "Aggregate function '{}' does not support named arguments", + fm.name() + ); + } + } else { + args + }; + let mut aggregate_expr = RawAggregateExpr { func: fm, - args, + args: resolved_args, distinct, filter, order_by, diff --git a/datafusion/sqllogictest/test_files/named_arguments.slt b/datafusion/sqllogictest/test_files/named_arguments.slt index c93da7e7a8f9..4eab799fd261 100644 --- a/datafusion/sqllogictest/test_files/named_arguments.slt +++ b/datafusion/sqllogictest/test_files/named_arguments.slt @@ -137,3 +137,135 @@ SELECT substr(str => 'hello world', start_pos => 7, length => 5); # Reset to default dialect statement ok set datafusion.sql_parser.dialect = 'Generic'; + +############# +## Aggregate UDF Tests - using corr(y, x) function +############# + +# Setup test data +statement ok +CREATE TABLE correlation_test(col1 DOUBLE, col2 DOUBLE) AS VALUES + (1.0, 2.0), + (2.0, 4.0), + (3.0, 6.0), + (4.0, 8.0); + +# Test positional arguments (baseline) +query R +SELECT corr(col1, col2) FROM correlation_test; +---- +1 + +# Test named arguments out of order (proves named args work for aggregates) +query R +SELECT corr(x => col2, y => col1) FROM correlation_test; +---- +1 + +# Error: function doesn't support named arguments (count has no parameter names) +query error DataFusion error: Error during planning: Aggregate function 'count' does not support named arguments +SELECT count(value => col1) FROM correlation_test; + +# Cleanup +statement ok +DROP TABLE correlation_test; + +############# +## Aggregate UDF with WITHIN GROUP Tests - using percentile_cont(expression, percentile) +## This tests the special handling where WITHIN GROUP ORDER BY expressions are prepended to args +############# + +# Setup test data +statement ok +CREATE TABLE percentile_test(salary DOUBLE) AS VALUES + (50000.0), + (60000.0), + (70000.0), + (80000.0), + (90000.0); + +# Test positional arguments (baseline) - standard call without WITHIN GROUP +query R +SELECT percentile_cont(salary, 0.5) FROM percentile_test; +---- +70000 + +# Test WITHIN GROUP with positional argument +query R +SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY salary) FROM percentile_test; +---- +70000 + +# Test WITHIN GROUP with named argument for percentile +# The ORDER BY expression (salary) is prepended internally, becoming: percentile_cont(salary, 0.5) +# We use named argument for percentile, which should work correctly +query R +SELECT percentile_cont(percentile => 0.5) WITHIN GROUP (ORDER BY salary) FROM percentile_test; +---- +70000 + +# Verify the WITHIN GROUP prepending logic with different percentile value +query R +SELECT percentile_cont(percentile => 0.25) WITHIN GROUP (ORDER BY salary) FROM percentile_test; +---- +60000 + +# Cleanup +statement ok +DROP TABLE percentile_test; + +############# +## Window UDF Tests - using lead(expression, offset, default) function +############# + +# Setup test data +statement ok +CREATE TABLE window_test(id INT, value INT) AS VALUES + (1, 10), + (2, 20), + (3, 30), + (4, 40); + +# Test positional arguments (baseline) +query II +SELECT id, lead(value, 1, 0) OVER (ORDER BY id) FROM window_test ORDER BY id; +---- +1 20 +2 30 +3 40 +4 0 + +# Test named arguments out of order (proves named args work for window functions) +query II +SELECT id, lead(default => 0, offset => 1, expr => value) OVER (ORDER BY id) FROM window_test ORDER BY id; +---- +1 20 +2 30 +3 40 +4 0 + +# Test with 1 argument (offset and default use defaults) +query II +SELECT id, lead(expr => value) OVER (ORDER BY id) FROM window_test ORDER BY id; +---- +1 20 +2 30 +3 40 +4 NULL + +# Test with 2 arguments (default uses default) +query II +SELECT id, lead(expr => value, offset => 2) OVER (ORDER BY id) FROM window_test ORDER BY id; +---- +1 30 +2 40 +3 NULL +4 NULL + +# Error: function doesn't support named arguments (row_number has no parameter names) +query error DataFusion error: Error during planning: Window function 'row_number' does not support named arguments +SELECT row_number(value => 1) OVER (ORDER BY id) FROM window_test; + +# Cleanup +statement ok +DROP TABLE window_test; diff --git a/docs/source/library-user-guide/functions/adding-udfs.md b/docs/source/library-user-guide/functions/adding-udfs.md index 7581d8b6505e..e56790a4b7d8 100644 --- a/docs/source/library-user-guide/functions/adding-udfs.md +++ b/docs/source/library-user-guide/functions/adding-udfs.md @@ -588,10 +588,17 @@ For async UDF implementation details, see [`async_udf.rs`](https://github.com/ap ## Named Arguments -DataFusion supports PostgreSQL-style named arguments for scalar functions, allowing you to pass arguments by parameter name: +DataFusion supports named arguments for Scalar, Window, and Aggregate UDFs, allowing you to pass arguments by parameter name: ```sql +-- Scalar function SELECT substr(str => 'hello', start_pos => 2, length => 3); + +-- Window function +SELECT lead(expr => value, offset => 1) OVER (ORDER BY id) FROM table; + +-- Aggregate function +SELECT corr(y => col1, x => col2) FROM table; ``` Named arguments can be mixed with positional arguments, but positional arguments must come first: @@ -602,38 +609,7 @@ SELECT substr('hello', start_pos => 2, length => 3); -- Valid ### Implementing Functions with Named Arguments -To support named arguments in your UDF, add parameter names to your function's signature using `.with_parameter_names()`: - -```rust -# use arrow::datatypes::DataType; -# use datafusion_expr::{Signature, Volatility}; -# -# #[derive(Debug)] -# struct MyFunction { -# signature: Signature, -# } -# -impl MyFunction { - fn new() -> Self { - Self { - signature: Signature::uniform( - 2, - vec![DataType::Float64], - Volatility::Immutable - ) - .with_parameter_names(vec![ - "base".to_string(), - "exponent".to_string() - ]) - .expect("valid parameter names"), - } - } -} -``` - -The parameter names should match the order of arguments in your function's signature. DataFusion automatically resolves named arguments to the correct positional order before invoking your function. - -### Example +To support named arguments in your UDF, add parameter names to your function's signature using `.with_parameter_names()`. This works the same way for Scalar, Window, and Aggregate UDFs: ```rust # use std::sync::Arc; @@ -681,10 +657,14 @@ impl ScalarUDFImpl for PowerFunction { } ``` -Once registered, users can call your function with named arguments: +The parameter names should match the order of arguments in your function's signature. DataFusion automatically resolves named arguments to the correct positional order before invoking your function. + +Once registered, users can call your functions with named arguments in any order: ```sql +-- All equivalent SELECT power(base => 2.0, exponent => 3.0); +SELECT power(exponent => 3.0, base => 2.0); SELECT power(2.0, exponent => 3.0); ```