Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion datafusion/functions-aggregate/src/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ impl Correlation {
signature: Signature::exact(
vec![DataType::Float64, DataType::Float64],
Volatility::Immutable,
),
)
.with_parameter_names(vec!["y".to_string(), "x".to_string()])
.expect("valid parameter names for corr"),
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion datafusion/functions-aggregate/src/percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ impl PercentileCont {
variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
}
Self {
signature: Signature::one_of(variants, Volatility::Immutable),
signature: Signature::one_of(variants, Volatility::Immutable)
.with_parameter_names(vec!["expr".to_string(), "percentile".to_string()])
.expect("valid parameter names for percentile_cont"),
aliases: vec![String::from("quantile_cont")],
}
}
Expand Down
8 changes: 7 additions & 1 deletion datafusion/functions-window/src/lead_lag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,13 @@ impl WindowShift {
TypeSignature::Any(3),
],
Volatility::Immutable,
),
)
.with_parameter_names(vec![
"expr".to_string(),
"offset".to_string(),
"default".to_string(),
])
.expect("valid parameter names for lead/lag"),
kind,
}
}
Expand Down
56 changes: 51 additions & 5 deletions datafusion/sql/src/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,30 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
};

if let Ok(fun) = self.find_window_func(&name) {
let args = self.function_args_to_expr(args, schema, planner_context)?;
let (args, arg_names) =
self.function_args_to_expr_with_names(args, schema, planner_context)?;

let resolved_args = if arg_names.iter().any(|name| name.is_some()) {
let signature = match &fun {
WindowFunctionDefinition::AggregateUDF(udaf) => udaf.signature(),
WindowFunctionDefinition::WindowUDF(udwf) => udwf.signature(),
};

if let Some(param_names) = &signature.parameter_names {
datafusion_expr::arguments::resolve_function_arguments(
param_names,
args,
arg_names,
)?
} else {
return plan_err!(
"Window function '{}' does not support named arguments",
name
);
}
} else {
args
};

// Plan FILTER clause if present
let filter = filter
Expand All @@ -396,7 +419,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {

let mut window_expr = RawWindowExpr {
func_def: fun,
args,
args: resolved_args,
partition_by,
order_by,
window_frame,
Expand Down Expand Up @@ -464,8 +487,8 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
);
}

let mut args =
self.function_args_to_expr(args, schema, planner_context)?;
let (mut args, mut arg_names) =
self.function_args_to_expr_with_names(args, schema, planner_context)?;

let order_by = if fm.is_ordered_set_aggregate() {
let within_group = self.order_by_to_sort_expr(
Expand All @@ -479,6 +502,12 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
// Add the WITHIN GROUP ordering expressions to the front of the argument list
// So function(arg) WITHIN GROUP (ORDER BY x) becomes function(x, arg)
if !within_group.is_empty() {
// Prepend None arg names for each WITHIN GROUP expression
let within_group_count = within_group.len();
arg_names = std::iter::repeat_n(None, within_group_count)
.chain(arg_names)
.collect();

args = within_group
.iter()
.map(|sort| sort.expr.clone())
Expand Down Expand Up @@ -506,9 +535,26 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
.transpose()?
.map(Box::new);

let resolved_args = if arg_names.iter().any(|name| name.is_some()) {
if let Some(param_names) = &fm.signature().parameter_names {
datafusion_expr::arguments::resolve_function_arguments(
param_names,
args,
arg_names,
)?
} else {
return plan_err!(
"Aggregate function '{}' does not support named arguments",
fm.name()
);
}
} else {
args
};

let mut aggregate_expr = RawAggregateExpr {
func: fm,
args,
args: resolved_args,
distinct,
filter,
order_by,
Expand Down
132 changes: 132 additions & 0 deletions datafusion/sqllogictest/test_files/named_arguments.slt
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,135 @@ SELECT substr(str => 'hello world', start_pos => 7, length => 5);
# Reset to default dialect
statement ok
set datafusion.sql_parser.dialect = 'Generic';

#############
## Aggregate UDF Tests - using corr(y, x) function
#############

# Setup test data
statement ok
CREATE TABLE correlation_test(col1 DOUBLE, col2 DOUBLE) AS VALUES
(1.0, 2.0),
(2.0, 4.0),
(3.0, 6.0),
(4.0, 8.0);

# Test positional arguments (baseline)
query R
SELECT corr(col1, col2) FROM correlation_test;
----
1

# Test named arguments out of order (proves named args work for aggregates)
query R
SELECT corr(x => col2, y => col1) FROM correlation_test;
----
1

# Error: function doesn't support named arguments (count has no parameter names)
query error DataFusion error: Error during planning: Aggregate function 'count' does not support named arguments
SELECT count(value => col1) FROM correlation_test;

# Cleanup
statement ok
DROP TABLE correlation_test;

#############
## Aggregate UDF with WITHIN GROUP Tests - using percentile_cont(expr, percentile)
## This tests the special handling where WITHIN GROUP ORDER BY expressions are prepended to args
#############

# Setup test data
statement ok
CREATE TABLE percentile_test(salary DOUBLE) AS VALUES
(50000.0),
(60000.0),
(70000.0),
(80000.0),
(90000.0);

# Test positional arguments (baseline) - standard call without WITHIN GROUP
query R
SELECT percentile_cont(salary, 0.5) FROM percentile_test;
----
70000

# Test WITHIN GROUP with positional argument
query R
SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY salary) FROM percentile_test;
----
70000

# Test WITHIN GROUP with named argument for percentile
# The ORDER BY expression (salary) is prepended internally, becoming: percentile_cont(salary, 0.5)
# We use named argument for percentile, which should work correctly
query R
SELECT percentile_cont(percentile => 0.5) WITHIN GROUP (ORDER BY salary) FROM percentile_test;
----
70000

# Verify the WITHIN GROUP prepending logic with a different percentile value
query R
SELECT percentile_cont(percentile => 0.25) WITHIN GROUP (ORDER BY salary) FROM percentile_test;
----
60000

# Cleanup
statement ok
DROP TABLE percentile_test;

#############
## Window UDF Tests - using lead(expr, offset, default) function
#############

# Setup test data
statement ok
CREATE TABLE window_test(id INT, value INT) AS VALUES
(1, 10),
(2, 20),
(3, 30),
(4, 40);

# Test positional arguments (baseline)
query II
SELECT id, lead(value, 1, 0) OVER (ORDER BY id) FROM window_test ORDER BY id;
----
1 20
2 30
3 40
4 0

# Test named arguments out of order (proves named args work for window functions)
query II
SELECT id, lead(default => 0, offset => 1, expr => value) OVER (ORDER BY id) FROM window_test ORDER BY id;
----
1 20
2 30
3 40
4 0

# Test with 1 named argument (the omitted offset and default parameters fall back to their default values)
query II
SELECT id, lead(expr => value) OVER (ORDER BY id) FROM window_test ORDER BY id;
----
1 20
2 30
3 40
4 NULL

# Test with 2 named arguments (the omitted default parameter falls back to its default value)
query II
SELECT id, lead(expr => value, offset => 2) OVER (ORDER BY id) FROM window_test ORDER BY id;
----
1 30
2 40
3 NULL
4 NULL

# Error: function doesn't support named arguments (row_number has no parameter names)
query error DataFusion error: Error during planning: Window function 'row_number' does not support named arguments
SELECT row_number(value => 1) OVER (ORDER BY id) FROM window_test;

# Cleanup
statement ok
DROP TABLE window_test;
48 changes: 14 additions & 34 deletions docs/source/library-user-guide/functions/adding-udfs.md
Original file line number Diff line number Diff line change
Expand Up @@ -588,10 +588,17 @@ For async UDF implementation details, see [`async_udf.rs`](https://github.com/ap

## Named Arguments

DataFusion supports PostgreSQL-style named arguments for scalar functions, allowing you to pass arguments by parameter name:
DataFusion supports named arguments for Scalar, Window, and Aggregate UDFs, allowing you to pass arguments by parameter name:

```sql
-- Scalar function
SELECT substr(str => 'hello', start_pos => 2, length => 3);

-- Window function
SELECT lead(expr => value, offset => 1) OVER (ORDER BY id) FROM my_table;

-- Aggregate function
SELECT corr(y => col1, x => col2) FROM my_table;
```

Named arguments can be mixed with positional arguments, but positional arguments must come first:
Expand All @@ -602,38 +609,7 @@ SELECT substr('hello', start_pos => 2, length => 3); -- Valid

### Implementing Functions with Named Arguments

To support named arguments in your UDF, add parameter names to your function's signature using `.with_parameter_names()`:

```rust
# use arrow::datatypes::DataType;
# use datafusion_expr::{Signature, Volatility};
#
# #[derive(Debug)]
# struct MyFunction {
# signature: Signature,
# }
#
impl MyFunction {
fn new() -> Self {
Self {
signature: Signature::uniform(
2,
vec![DataType::Float64],
Volatility::Immutable
)
.with_parameter_names(vec![
"base".to_string(),
"exponent".to_string()
])
.expect("valid parameter names"),
}
}
}
```

The parameter names should match the order of arguments in your function's signature. DataFusion automatically resolves named arguments to the correct positional order before invoking your function.

### Example
To support named arguments in your UDF, add parameter names to your function's signature using `.with_parameter_names()`. This works the same way for Scalar, Window, and Aggregate UDFs:

```rust
# use std::sync::Arc;
Expand Down Expand Up @@ -681,10 +657,14 @@ impl ScalarUDFImpl for PowerFunction {
}
```

Once registered, users can call your function with named arguments:
The parameter names must be listed in the same order as the arguments in your function's signature. DataFusion automatically resolves named arguments to the correct positional order before invoking your function.

Once registered, users can call your functions with named arguments in any order:

```sql
-- All equivalent
SELECT power(base => 2.0, exponent => 3.0);
SELECT power(exponent => 3.0, base => 2.0);
SELECT power(2.0, exponent => 3.0);
```

Expand Down