From 6fca1dddf5d53d50b61a14f0693c23c68f4b4702 Mon Sep 17 00:00:00 2001 From: bubulalabu Date: Sat, 11 Oct 2025 19:01:54 +0200 Subject: [PATCH 1/9] Add PostgreSQL-style named arguments support for scalar functions Implement support for calling functions with named parameters using PostgreSQL-style syntax (param => value). Features: - Parse named arguments from SQL (param => value syntax) - Resolve named arguments to positional order before execution - Support mixed positional and named arguments - Store parameter names in function signatures - Show parameter names in error messages Implementation: - Added argument resolution logic with validation - Extended Signature with parameter_names field - Updated SQL parser to handle named argument syntax - Integrated into physical planning phase - Added comprehensive tests and documentation Example usage: SELECT substr(str => 'hello', start_pos => 2, length => 3); SELECT substr('hello', start_pos => 2, length => 3); Error messages now show: Candidate functions: substr(str, start_pos) substr(str, start_pos, length) Instead of generic types like substr(Any, Any). 
Related issue: #17379 --- datafusion/expr-common/src/signature.rs | 498 ++++++++++++++++++ datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udf.rs | 274 ++++++++++ datafusion/expr/src/utils.rs | 62 ++- datafusion/functions-nested/src/replace.rs | 3 + datafusion/functions/src/unicode/substr.rs | 17 +- datafusion/sql/src/expr/function.rs | 75 ++- .../test_files/named_arguments.slt | 75 +++ .../functions/adding-udfs.md | 105 ++++ 9 files changed, 1096 insertions(+), 15 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/named_arguments.slt diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 5fd4518e2e57f..8d2738fa897e1 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -486,6 +486,109 @@ impl TypeSignature { } } + /// Return string representation of the function signature with parameter names. + /// + /// This method is similar to [`Self::to_string_repr`] but uses parameter names + /// instead of types when available. This is useful for generating more helpful + /// error messages. + /// + /// # Arguments + /// * `parameter_names` - Optional slice of parameter names. When provided, these + /// names will be used instead of type names in the output. 
+ /// + /// # Examples + /// ``` + /// # use datafusion_expr_common::signature::TypeSignature; + /// # use arrow::datatypes::DataType; + /// let sig = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]); + /// + /// // Without names: shows types + /// assert_eq!(sig.to_string_repr_with_names(None), vec!["Int32, Utf8"]); + /// + /// // With names: shows parameter names + /// assert_eq!( + /// sig.to_string_repr_with_names(Some(&["id".to_string(), "name".to_string()])), + /// vec!["id, name"] + /// ); + /// ``` + pub fn to_string_repr_with_names( + &self, + parameter_names: Option<&[String]>, + ) -> Vec { + match self { + TypeSignature::Exact(types) => { + if let Some(names) = parameter_names { + vec![names.iter().take(types.len()).cloned().collect::>().join(", ")] + } else { + vec![Self::join_types(types, ", ")] + } + } + TypeSignature::Any(count) => { + if let Some(names) = parameter_names { + vec![names.iter().take(*count).cloned().collect::>().join(", ")] + } else { + vec![std::iter::repeat_n("Any", *count) + .collect::>() + .join(", ")] + } + } + TypeSignature::Uniform(count, _types) => { + if let Some(names) = parameter_names { + vec![names.iter().take(*count).cloned().collect::>().join(", ")] + } else { + // Fallback to original representation + self.to_string_repr() + } + } + TypeSignature::Coercible(coercions) => { + if let Some(names) = parameter_names { + vec![names.iter().take(coercions.len()).cloned().collect::>().join(", ")] + } else { + vec![Self::join_types(coercions, ", ")] + } + } + TypeSignature::Comparable(count) + | TypeSignature::Numeric(count) + | TypeSignature::String(count) => { + if let Some(names) = parameter_names { + vec![names.iter().take(*count).cloned().collect::>().join(", ")] + } else { + // Fallback to original representation + self.to_string_repr() + } + } + TypeSignature::Nullary => { + // No parameters, so no names to show + self.to_string_repr() + } + TypeSignature::ArraySignature(array_sig) => { + // ArraySignature has 
fixed arity, so it can support parameter names + let arity = match array_sig { + ArrayFunctionSignature::Array { arguments, .. } => arguments.len(), + ArrayFunctionSignature::RecursiveArray => 1, + ArrayFunctionSignature::MapArray => 1, + }; + if let Some(names) = parameter_names { + vec![names.iter().take(arity).cloned().collect::>().join(", ")] + } else { + // Fallback to semantic names like "array, index, element" + self.to_string_repr() + } + } + TypeSignature::OneOf(sigs) => { + sigs.iter() + .flat_map(|s| s.to_string_repr_with_names(parameter_names)) + .collect() + } + // Variable arity signatures cannot use parameter names + TypeSignature::Variadic(_) + | TypeSignature::VariadicAny + | TypeSignature::UserDefined => { + self.to_string_repr() + } + } + } + /// Helper function to join types with specified delimiter. pub fn join_types(types: &[T], delimiter: &str) -> String { types @@ -804,6 +907,13 @@ pub struct Signature { pub type_signature: TypeSignature, /// The volatility of the function. See [Volatility] for more information. pub volatility: Volatility, + /// Optional parameter names for the function arguments. + /// + /// If provided, enables named argument notation for function calls (e.g., `func(a => 1, b => 2)`). + /// The length must match the number of arguments defined by `type_signature`. + /// + /// Defaults to `None`, meaning only positional arguments are supported. + pub parameter_names: Option>, } impl Signature { @@ -812,6 +922,7 @@ impl Signature { Signature { type_signature, volatility, + parameter_names: None, } } /// An arbitrary number of arguments with the same type, from those listed in `common_types`. @@ -819,6 +930,7 @@ impl Signature { Self { type_signature: TypeSignature::Variadic(common_types), volatility, + parameter_names: None, } } /// User-defined coercion rules for the function. 
@@ -826,6 +938,7 @@ impl Signature { Self { type_signature: TypeSignature::UserDefined, volatility, + parameter_names: None, } } @@ -834,6 +947,7 @@ impl Signature { Self { type_signature: TypeSignature::Numeric(arg_count), volatility, + parameter_names: None, } } @@ -842,6 +956,7 @@ impl Signature { Self { type_signature: TypeSignature::String(arg_count), volatility, + parameter_names: None, } } @@ -850,6 +965,7 @@ impl Signature { Self { type_signature: TypeSignature::VariadicAny, volatility, + parameter_names: None, } } /// A fixed number of arguments of the same type, from those listed in `valid_types`. @@ -861,6 +977,7 @@ impl Signature { Self { type_signature: TypeSignature::Uniform(arg_count, valid_types), volatility, + parameter_names: None, } } /// Exactly matches the types in `exact_types`, in order. @@ -868,6 +985,7 @@ impl Signature { Signature { type_signature: TypeSignature::Exact(exact_types), volatility, + parameter_names: None, } } @@ -876,6 +994,7 @@ impl Signature { Self { type_signature: TypeSignature::Coercible(target_types), volatility, + parameter_names: None, } } @@ -884,6 +1003,7 @@ impl Signature { Self { type_signature: TypeSignature::Comparable(arg_count), volatility, + parameter_names: None, } } @@ -891,6 +1011,7 @@ impl Signature { Signature { type_signature: TypeSignature::Nullary, volatility, + parameter_names: None, } } @@ -899,6 +1020,7 @@ impl Signature { Signature { type_signature: TypeSignature::Any(arg_count), volatility, + parameter_names: None, } } @@ -907,6 +1029,7 @@ impl Signature { Signature { type_signature: TypeSignature::OneOf(type_signatures), volatility, + parameter_names: None, } } @@ -923,6 +1046,7 @@ impl Signature { }, ), volatility, + parameter_names: None, } } @@ -939,6 +1063,7 @@ impl Signature { }, ), volatility, + parameter_names: None, } } @@ -956,6 +1081,7 @@ impl Signature { }, ), volatility, + parameter_names: None, } } @@ -980,6 +1106,7 @@ impl Signature { }), ]), volatility, + parameter_names: None, } 
} @@ -996,6 +1123,7 @@ impl Signature { }, ), volatility, + parameter_names: None, } } @@ -1003,6 +1131,104 @@ impl Signature { pub fn array(volatility: Volatility) -> Self { Signature::arrays(1, Some(ListCoercion::FixedSizedListToList), volatility) } + + /// Add parameter names to this signature, enabling named argument notation. + /// + /// # Example + /// ``` + /// # use datafusion_expr_common::signature::{Signature, Volatility}; + /// # use arrow::datatypes::DataType; + /// let sig = Signature::exact(vec![DataType::Int32, DataType::Utf8], Volatility::Immutable) + /// .with_parameter_names(vec!["count".to_string(), "name".to_string()]); + /// ``` + /// + /// # Errors + /// Returns an error if the number of parameter names doesn't match the signature's arity. + /// For signatures with variable arity (e.g., `Variadic`, `VariadicAny`), parameter names + /// cannot be specified. + pub fn with_parameter_names( + mut self, + names: Vec, + ) -> datafusion_common::Result { + // Validate that the number of names matches the signature + self.validate_parameter_names(&names)?; + self.parameter_names = Some(names); + Ok(self) + } + + /// Validate that parameter names are compatible with this signature + fn validate_parameter_names( + &self, + names: &[String], + ) -> datafusion_common::Result<()> { + // Get expected argument count from the type signature + let expected_count = match &self.type_signature { + TypeSignature::Exact(types) => Some(types.len()), + TypeSignature::Uniform(count, _) => Some(*count), + TypeSignature::Numeric(count) => Some(*count), + TypeSignature::String(count) => Some(*count), + TypeSignature::Comparable(count) => Some(*count), + TypeSignature::Any(count) => Some(*count), + TypeSignature::Coercible(types) => Some(types.len()), + TypeSignature::Nullary => Some(0), + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments, + .. 
+ }) => Some(arguments.len()), + TypeSignature::ArraySignature(ArrayFunctionSignature::RecursiveArray) => { + Some(1) + } + TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray) => Some(1), + // For OneOf, get the maximum arity from all variants + TypeSignature::OneOf(variants) => { + // Get max arity from all variants + let max_arity = variants.iter().filter_map(|v| match v { + TypeSignature::Any(count) + | TypeSignature::Uniform(count, _) + | TypeSignature::Numeric(count) + | TypeSignature::String(count) + | TypeSignature::Comparable(count) => Some(*count), + TypeSignature::Exact(types) => Some(types.len()), + TypeSignature::Coercible(types) => Some(types.len()), + TypeSignature::Nullary => Some(0), + _ => None, + }).max(); + max_arity + } + // Variable arity signatures cannot have parameter names + TypeSignature::Variadic(_) + | TypeSignature::VariadicAny + | TypeSignature::UserDefined => None, + }; + + if let Some(expected) = expected_count { + if names.len() != expected { + return datafusion_common::plan_err!( + "Parameter names count ({}) does not match signature arity ({})", + names.len(), + expected + ); + } + } else { + return datafusion_common::plan_err!( + "Cannot specify parameter names for variable arity signature: {:?}", + self.type_signature + ); + } + + // Validate no duplicate names + let mut seen = std::collections::HashSet::new(); + for name in names { + if !seen.insert(name) { + return datafusion_common::plan_err!( + "Duplicate parameter name: '{}'", + name + ); + } + } + + Ok(()) + } } #[cfg(test)] @@ -1167,4 +1393,276 @@ mod tests { ] ); } + + #[test] + fn test_signature_with_parameter_names() { + // Test adding parameter names to exact signature + let sig = Signature::exact(vec![DataType::Int32, DataType::Utf8], Volatility::Immutable) + .with_parameter_names(vec!["count".to_string(), "name".to_string()]) + .unwrap(); + + assert_eq!(sig.parameter_names, Some(vec!["count".to_string(), "name".to_string()])); + 
assert_eq!(sig.type_signature, TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8])); + } + + #[test] + fn test_signature_parameter_names_wrong_count() { + // Test that wrong number of names fails + let result = Signature::exact(vec![DataType::Int32, DataType::Utf8], Volatility::Immutable) + .with_parameter_names(vec!["count".to_string()]); // Only 1 name for 2 args + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("does not match signature arity")); + } + + #[test] + fn test_signature_parameter_names_duplicate() { + // Test that duplicate names fail + let result = Signature::exact(vec![DataType::Int32, DataType::Int32], Volatility::Immutable) + .with_parameter_names(vec!["count".to_string(), "count".to_string()]); + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Duplicate parameter name")); + } + + #[test] + fn test_signature_parameter_names_variadic() { + // Test that variadic signatures reject parameter names + let result = Signature::variadic(vec![DataType::Int32], Volatility::Immutable) + .with_parameter_names(vec!["arg".to_string()]); + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("variable arity signature")); + } + + #[test] + fn test_signature_without_parameter_names() { + // Test that signatures without parameter names still work + let sig = Signature::exact(vec![DataType::Int32, DataType::Utf8], Volatility::Immutable); + + assert_eq!(sig.parameter_names, None); + } + + #[test] + fn test_signature_uniform_with_parameter_names() { + // Test uniform signature with parameter names + let sig = Signature::uniform(3, vec![DataType::Float64], Volatility::Immutable) + .with_parameter_names(vec!["x".to_string(), "y".to_string(), "z".to_string()]) + .unwrap(); + + assert_eq!(sig.parameter_names, Some(vec!["x".to_string(), "y".to_string(), "z".to_string()])); + } + + #[test] + fn test_signature_numeric_with_parameter_names() { + // Test numeric signature with 
parameter names + let sig = Signature::numeric(2, Volatility::Immutable) + .with_parameter_names(vec!["a".to_string(), "b".to_string()]) + .unwrap(); + + assert_eq!(sig.parameter_names, Some(vec!["a".to_string(), "b".to_string()])); + } + + #[test] + fn test_signature_nullary_with_empty_names() { + // Test that nullary signature accepts empty parameter names + let sig = Signature::nullary(Volatility::Immutable) + .with_parameter_names(vec![]) + .unwrap(); + + assert_eq!(sig.parameter_names, Some(vec![])); + } + + #[test] + fn test_to_string_repr_with_names_exact() { + // Test Exact signature with parameter names + let sig = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]); + + // Without names: should show types + assert_eq!(sig.to_string_repr_with_names(None), vec!["Int32, Utf8"]); + + // With names: should show parameter names + let names = vec!["id".to_string(), "name".to_string()]; + assert_eq!(sig.to_string_repr_with_names(Some(&names)), vec!["id, name"]); + } + + #[test] + fn test_to_string_repr_with_names_any() { + // Test Any signature with parameter names + let sig = TypeSignature::Any(3); + + // Without names: should show "Any" for each parameter + assert_eq!(sig.to_string_repr_with_names(None), vec!["Any, Any, Any"]); + + // With names: should show parameter names + let names = vec!["x".to_string(), "y".to_string(), "z".to_string()]; + assert_eq!(sig.to_string_repr_with_names(Some(&names)), vec!["x, y, z"]); + } + + #[test] + fn test_to_string_repr_with_names_one_of() { + // Test OneOf signature with parameter names (like substr) + let sig = TypeSignature::OneOf(vec![ + TypeSignature::Any(2), + TypeSignature::Any(3), + ]); + + // Without names: should show generic "Any" types + assert_eq!( + sig.to_string_repr_with_names(None), + vec!["Any, Any", "Any, Any, Any"] + ); + + // With names: should use names for each variant + let names = vec!["str".to_string(), "start_pos".to_string(), "length".to_string()]; + assert_eq!( + 
sig.to_string_repr_with_names(Some(&names)), + vec!["str, start_pos", "str, start_pos, length"] + ); + } + + #[test] + fn test_to_string_repr_with_names_partial() { + // Test with fewer names than needed + let sig = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8, DataType::Float64]); + + // Provide only 2 names for 3 parameters + let names = vec!["a".to_string(), "b".to_string()]; + // Should only use the available names (takes first 2) + assert_eq!(sig.to_string_repr_with_names(Some(&names)), vec!["a, b"]); + } + + #[test] + fn test_to_string_repr_with_names_uniform() { + // Test Uniform signature with parameter names + let sig = TypeSignature::Uniform(2, vec![DataType::Float64]); + + // Without names: should show type representation + assert_eq!( + sig.to_string_repr_with_names(None), + vec!["Float64, Float64"] + ); + + // With names: should show parameter names + let names = vec!["x".to_string(), "y".to_string()]; + assert_eq!(sig.to_string_repr_with_names(Some(&names)), vec!["x, y"]); + } + + #[test] + fn test_to_string_repr_with_names_coercible() { + use crate::signature::{Coercion, TypeSignatureClass}; + use datafusion_common::types::logical_int32; + + // Test Coercible signature with parameter names + let sig = TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_int32())), + Coercion::new_exact(TypeSignatureClass::Native(logical_int32())), + ]); + + // With names: should show parameter names + let names = vec!["a".to_string(), "b".to_string()]; + assert_eq!(sig.to_string_repr_with_names(Some(&names)), vec!["a, b"]); + } + + #[test] + fn test_to_string_repr_with_names_comparable_numeric_string() { + // Test Comparable, Numeric, and String signatures + let comparable = TypeSignature::Comparable(3); + let numeric = TypeSignature::Numeric(2); + let string_sig = TypeSignature::String(2); + + let names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + + // All should show parameter names when provided + 
assert_eq!( + comparable.to_string_repr_with_names(Some(&names)), + vec!["a, b, c"] + ); + assert_eq!( + numeric.to_string_repr_with_names(Some(&names)), + vec!["a, b"] + ); + assert_eq!( + string_sig.to_string_repr_with_names(Some(&names)), + vec!["a, b"] + ); + } + + #[test] + fn test_to_string_repr_with_names_variadic_fallback() { + // Test that variadic variants fall back to to_string_repr() + let variadic = TypeSignature::Variadic(vec![DataType::Utf8, DataType::LargeUtf8]); + let names = vec!["x".to_string()]; + assert_eq!( + variadic.to_string_repr_with_names(Some(&names)), + variadic.to_string_repr() + ); + + let variadic_any = TypeSignature::VariadicAny; + assert_eq!( + variadic_any.to_string_repr_with_names(Some(&names)), + variadic_any.to_string_repr() + ); + + let user_defined = TypeSignature::UserDefined; + assert_eq!( + user_defined.to_string_repr_with_names(Some(&names)), + user_defined.to_string_repr() + ); + } + + #[test] + fn test_to_string_repr_with_names_nullary() { + // Test Nullary signature (no arguments) + let sig = TypeSignature::Nullary; + let names = vec!["x".to_string()]; + + // Should return empty representation, names don't apply + assert_eq!(sig.to_string_repr_with_names(Some(&names)), vec!["NullAry()"]); + assert_eq!(sig.to_string_repr_with_names(None), vec!["NullAry()"]); + } + + #[test] + fn test_to_string_repr_with_names_array_signature() { + use crate::signature::{ArrayFunctionArgument, ArrayFunctionSignature}; + + // Test ArraySignature with parameter names + let sig = TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Index, + ArrayFunctionArgument::Element, + ], + array_coercion: None, + }); + + // Without names: should show semantic types + assert_eq!( + sig.to_string_repr_with_names(None), + vec!["array, index, element"] + ); + + // With names: should show parameter names + let names = vec!["arr".to_string(), "size".to_string(), 
"value".to_string()]; + assert_eq!( + sig.to_string_repr_with_names(Some(&names)), + vec!["arr, size, value"] + ); + + // Test RecursiveArray (1 argument) + let recursive = TypeSignature::ArraySignature(ArrayFunctionSignature::RecursiveArray); + let names = vec!["array".to_string()]; + assert_eq!( + recursive.to_string_repr_with_names(Some(&names)), + vec!["array"] + ); + + // Test MapArray (1 argument) + let map_array = TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray); + let names = vec!["map".to_string()]; + assert_eq!( + map_array.to_string_repr_with_names(Some(&names)), + vec!["map"] + ); + } } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 1c9734a89bd37..cb2d84915c217 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -116,7 +116,7 @@ pub use udaf::{ udaf_default_window_function_schema_name, AggregateUDF, AggregateUDFImpl, ReversedUDAF, SetMonotonicity, StatisticsArgs, }; -pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; +pub use udf::{arguments, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; pub use udwf::{ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index d522158f7b6b7..b26939a7cf5e8 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -957,3 +957,277 @@ mod tests { hasher.finish() } } + +/// Argument resolution logic for named function parameters +pub mod arguments { + use datafusion_common::{plan_err, Result}; + use crate::Expr; + + /// Resolves function arguments, handling named and positional notation. + /// + /// This function validates and reorders arguments to match the function's parameter names + /// when named arguments are used. 
+ /// + /// # Rules + /// - All positional arguments must come before named arguments + /// - Named arguments can be in any order after positional arguments + /// - All parameter names must match the provided parameter_names + /// - No duplicate parameter names allowed + /// + /// # Arguments + /// * `param_names` - The function's parameter names in order + /// * `args` - The argument expressions + /// * `arg_names` - Optional parameter name for each argument + /// + /// # Returns + /// A vector of expressions in the correct order matching the parameter names + /// + /// # Examples + /// ```rust,ignore + /// // Given parameters ["a", "b", "c"] + /// // And call: func(10, c => 30, b => 20) + /// // Returns: [Expr(10), Expr(20), Expr(30)] + /// ``` + pub fn resolve_function_arguments( + param_names: &[String], + args: Vec, + arg_names: Vec>, + ) -> Result> { + // Validate that arg_names length matches args length + if args.len() != arg_names.len() { + return plan_err!( + "Internal error: args length ({}) != arg_names length ({})", + args.len(), + arg_names.len() + ); + } + + // Check if all arguments are positional (fast path) + if arg_names.iter().all(|name| name.is_none()) { + return Ok(args); + } + + // Validate mixed positional and named arguments + validate_argument_order(&arg_names)?; + + // Validate and reorder named arguments + reorder_named_arguments(param_names, args, arg_names) + } + + /// Validates that positional arguments come before named arguments + fn validate_argument_order(arg_names: &[Option]) -> Result<()> { + let mut seen_named = false; + for (i, arg_name) in arg_names.iter().enumerate() { + match arg_name { + Some(_) => seen_named = true, + None if seen_named => { + return plan_err!( + "Positional argument at position {} follows named argument. 
\ + All positional arguments must come before named arguments.", + i + ); + } + None => {} + } + } + Ok(()) + } + + /// Reorders arguments based on named parameters to match signature order + fn reorder_named_arguments( + param_names: &[String], + args: Vec, + arg_names: Vec>, + ) -> Result> { + // Count positional vs named arguments + let positional_count = arg_names.iter().filter(|n| n.is_none()).count(); + + // Capture args length before consuming the vector + let args_len = args.len(); + + // Create a result vector with the expected size + let expected_arg_count = param_names.len(); + let mut result: Vec> = vec![None; expected_arg_count]; + + // Track which parameters have been assigned + let mut assigned = vec![false; expected_arg_count]; + + // Process all arguments (both positional and named) + for (i, (arg, arg_name)) in args.into_iter().zip(arg_names).enumerate() { + if let Some(name) = arg_name { + // Named argument - find its position in param_names + let param_index = param_names + .iter() + .position(|p| p == &name) + .ok_or_else(|| { + datafusion_common::plan_datafusion_err!( + "Unknown parameter name '{}'. 
Valid parameters are: [{}]", + name, + param_names.join(", ") + ) + })?; + + // Check if this parameter was already assigned + if assigned[param_index] { + return plan_err!( + "Parameter '{}' specified multiple times", + name + ); + } + + result[param_index] = Some(arg); + assigned[param_index] = true; + } else { + // Positional argument - place at current position + if i >= expected_arg_count { + return plan_err!( + "Too many positional arguments: expected at most {}, got {}", + expected_arg_count, + positional_count + ); + } + result[i] = Some(arg); + assigned[i] = true; + } + } + + // Check if all required parameters were provided + // Only require parameters up to the number of arguments provided (supports optional parameters) + let required_count = args_len; + for i in 0..required_count { + if !assigned[i] { + return plan_err!( + "Missing required parameter '{}'", + param_names[i] + ); + } + } + + // Return only the assigned parameters (handles optional trailing parameters) + Ok(result.into_iter().take(required_count).map(|e| e.unwrap()).collect()) + } + + #[cfg(test)] + mod tests { + use super::*; + use crate::lit; + + #[test] + fn test_all_positional() { + let param_names = vec!["a".to_string(), "b".to_string()]; + + let args = vec![lit(1), lit("hello")]; + let arg_names = vec![None, None]; + + let result = resolve_function_arguments(¶m_names, args.clone(), arg_names).unwrap(); + assert_eq!(result.len(), 2); + } + + #[test] + fn test_all_named() { + let param_names = vec!["a".to_string(), "b".to_string()]; + + let args = vec![lit(1), lit("hello")]; + let arg_names = vec![Some("a".to_string()), Some("b".to_string())]; + + let result = resolve_function_arguments(¶m_names, args, arg_names).unwrap(); + assert_eq!(result.len(), 2); + } + + #[test] + fn test_named_reordering() { + let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + + // Call with: func(c => 3.0, a => 1, b => "hello") + let args = vec![lit(3.0), lit(1), lit("hello")]; + let 
arg_names = vec![ + Some("c".to_string()), + Some("a".to_string()), + Some("b".to_string()), + ]; + + let result = resolve_function_arguments(¶m_names, args, arg_names).unwrap(); + + // Should be reordered to [a, b, c] = [1, "hello", 3.0] + assert_eq!(result.len(), 3); + assert_eq!(result[0], lit(1)); + assert_eq!(result[1], lit("hello")); + assert_eq!(result[2], lit(3.0)); + } + + #[test] + fn test_mixed_positional_and_named() { + let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + + // Call with: func(1, c => 3.0, b => "hello") + let args = vec![lit(1), lit(3.0), lit("hello")]; + let arg_names = vec![None, Some("c".to_string()), Some("b".to_string())]; + + let result = resolve_function_arguments(¶m_names, args, arg_names).unwrap(); + + // Should be reordered to [a, b, c] = [1, "hello", 3.0] + assert_eq!(result.len(), 3); + assert_eq!(result[0], lit(1)); + assert_eq!(result[1], lit("hello")); + assert_eq!(result[2], lit(3.0)); + } + + #[test] + fn test_positional_after_named_error() { + let param_names = vec!["a".to_string(), "b".to_string()]; + + // Call with: func(a => 1, "hello") - ERROR + let args = vec![lit(1), lit("hello")]; + let arg_names = vec![Some("a".to_string()), None]; + + let result = resolve_function_arguments(¶m_names, args, arg_names); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Positional argument")); + } + + #[test] + fn test_unknown_parameter_name() { + let param_names = vec!["a".to_string(), "b".to_string()]; + + // Call with: func(x => 1, b => "hello") - ERROR + let args = vec![lit(1), lit("hello")]; + let arg_names = vec![Some("x".to_string()), Some("b".to_string())]; + + let result = resolve_function_arguments(¶m_names, args, arg_names); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Unknown parameter")); + } + + #[test] + fn test_duplicate_parameter_name() { + let param_names = vec!["a".to_string(), "b".to_string()]; + + // Call with: 
func(a => 1, a => 2) - ERROR + let args = vec![lit(1), lit(2)]; + let arg_names = vec![Some("a".to_string()), Some("a".to_string())]; + + let result = resolve_function_arguments(¶m_names, args, arg_names); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("specified multiple times")); + } + + #[test] + fn test_missing_required_parameter() { + let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + + // Call with: func(a => 1, c => 3.0) - missing 'b' + let args = vec![lit(1), lit(3.0)]; + let arg_names = vec![Some("a".to_string()), Some("c".to_string())]; + + let result = resolve_function_arguments(¶m_names, args, arg_names); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Missing required parameter")); + } + } +} diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index b91db4527b3aa..8bb94c569cb97 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -936,7 +936,7 @@ pub fn generate_signature_error_msg( ) -> String { let candidate_signatures = func_signature .type_signature - .to_string_repr() + .to_string_repr_with_names(func_signature.parameter_names.as_deref()) .iter() .map(|args_str| format!("\t{func_name}({args_str})")) .collect::>() @@ -1714,4 +1714,64 @@ mod tests { DataType::List(Arc::new(Field::new("my_union", union_type, true))); assert!(!can_hash(&list_union_type)); } + + #[test] + fn test_generate_signature_error_msg_with_parameter_names() { + use datafusion_expr_common::signature::{TypeSignature, Volatility}; + + // Create a signature like substr with parameter names + let sig = Signature::one_of( + vec![TypeSignature::Any(2), TypeSignature::Any(3)], + Volatility::Immutable, + ) + .with_parameter_names(vec![ + "str".to_string(), + "start_pos".to_string(), + "length".to_string(), + ]) + .expect("valid parameter names"); + + // Generate error message with only 1 argument provided + let error_msg = 
generate_signature_error_msg("substr", sig, &[DataType::Utf8]); + + // Error message should contain parameter names + assert!( + error_msg.contains("str, start_pos"), + "Expected 'str, start_pos' in error message, got: {}", + error_msg + ); + assert!( + error_msg.contains("str, start_pos, length"), + "Expected 'str, start_pos, length' in error message, got: {}", + error_msg + ); + + // Should NOT contain generic "Any" types + assert!( + !error_msg.contains("Any, Any"), + "Should not contain 'Any, Any', got: {}", + error_msg + ); + } + + #[test] + fn test_generate_signature_error_msg_without_parameter_names() { + use datafusion_expr_common::signature::{TypeSignature, Volatility}; + + // Create a signature without parameter names + let sig = Signature::one_of( + vec![TypeSignature::Any(2), TypeSignature::Any(3)], + Volatility::Immutable, + ); + + // Generate error message + let error_msg = generate_signature_error_msg("my_func", sig, &[DataType::Int32]); + + // Should contain generic "Any" types when no parameter names + assert!( + error_msg.contains("Any, Any"), + "Expected 'Any, Any' without parameter names, got: {}", + error_msg + ); + } } diff --git a/datafusion/functions-nested/src/replace.rs b/datafusion/functions-nested/src/replace.rs index 59f851a776a18..4314d41419bcc 100644 --- a/datafusion/functions-nested/src/replace.rs +++ b/datafusion/functions-nested/src/replace.rs @@ -105,6 +105,7 @@ impl ArrayReplace { }, ), volatility: Volatility::Immutable, + parameter_names: None, }, aliases: vec![String::from("list_replace")], } @@ -186,6 +187,7 @@ impl ArrayReplaceN { }, ), volatility: Volatility::Immutable, + parameter_names: None, }, aliases: vec![String::from("list_replace_n")], } @@ -265,6 +267,7 @@ impl ArrayReplaceAll { }, ), volatility: Volatility::Immutable, + parameter_names: None, }, aliases: vec![String::from("list_replace_all")], } diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 
0b35f664532d4..631fee1b67c37 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -70,8 +70,23 @@ impl Default for SubstrFunc { impl SubstrFunc { pub fn new() -> Self { + use datafusion_expr::TypeSignature; Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::one_of( + vec![ + // substr(str, start_pos) + TypeSignature::Any(2), + // substr(str, start_pos, length) + TypeSignature::Any(3), + ], + Volatility::Immutable, + ) + .with_parameter_names(vec![ + "str".to_string(), + "start_pos".to_string(), + "length".to_string(), + ]) + .expect("valid parameter names"), aliases: vec![String::from("substring")], } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index eabf645a5eafd..9f31ddb2d8e85 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -274,8 +274,30 @@ impl SqlToRel<'_, S> { } // User-defined function (UDF) should have precedence if let Some(fm) = self.context_provider.get_function_meta(&name) { - let args = self.function_args_to_expr(args, schema, planner_context)?; - let inner = ScalarFunction::new_udf(fm, args); + let (args, arg_names) = self.function_args_to_expr_with_names(args, schema, planner_context)?; + + // Resolve named arguments if any are present + let resolved_args = if arg_names.iter().any(|name| name.is_some()) { + // Get parameter names from the signature if available + if let Some(param_names) = &fm.signature().parameter_names { + datafusion_expr::arguments::resolve_function_arguments( + param_names, + args, + arg_names, + )? 
+ } else { + // UDF doesn't support named arguments + return plan_err!( + "Function '{}' does not support named arguments", + fm.name() + ); + } + } else { + args + }; + + // After resolution, all arguments are positional + let inner = ScalarFunction::new_udf(fm, resolved_args); if name.eq_ignore_ascii_case(inner.name()) { return Ok(Expr::ScalarFunction(inner)); @@ -624,14 +646,28 @@ impl SqlToRel<'_, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { + let (expr, _) = self.sql_fn_arg_to_logical_expr_with_name(sql, schema, planner_context)?; + Ok(expr) + } + + fn sql_fn_arg_to_logical_expr_with_name( + &self, + sql: FunctionArg, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result<(Expr, Option)> { match sql { FunctionArg::Named { - name: _, + name, arg: FunctionArgExpr::Expr(arg), operator: _, - } => self.sql_expr_to_logical_expr(arg, schema, planner_context), + } => { + let expr = self.sql_expr_to_logical_expr(arg, schema, planner_context)?; + let arg_name = crate::utils::normalize_ident(name); + Ok((expr, Some(arg_name))) + } FunctionArg::Named { - name: _, + name, arg: FunctionArgExpr::Wildcard, operator: _, } => { @@ -640,11 +676,12 @@ impl SqlToRel<'_, S> { qualifier: None, options: Box::new(WildcardOptions::default()), }; - - Ok(expr) + let arg_name = crate::utils::normalize_ident(name); + Ok((expr, Some(arg_name))) } FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) => { - self.sql_expr_to_logical_expr(arg, schema, planner_context) + let expr = self.sql_expr_to_logical_expr(arg, schema, planner_context)?; + Ok((expr, None)) } FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => { #[expect(deprecated)] @@ -652,8 +689,7 @@ impl SqlToRel<'_, S> { qualifier: None, options: Box::new(WildcardOptions::default()), }; - - Ok(expr) + Ok((expr, None)) } FunctionArg::Unnamed(FunctionArgExpr::QualifiedWildcard(object_name)) => { let qualifier = self.object_name_to_table_reference(object_name)?; @@ -668,8 +704,7 @@ impl 
SqlToRel<'_, S> { qualifier: qualifier.into(), options: Box::new(WildcardOptions::default()), }; - - Ok(expr) + Ok((expr, None)) } _ => not_impl_err!("Unsupported qualified wildcard argument: {sql:?}"), } @@ -686,6 +721,22 @@ impl SqlToRel<'_, S> { .collect::>>() } + pub(super) fn function_args_to_expr_with_names( + &self, + args: Vec, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result<(Vec, Vec>)> { + let results: Result)>> = args + .into_iter() + .map(|a| self.sql_fn_arg_to_logical_expr_with_name(a, schema, planner_context)) + .collect(); + + let pairs = results?; + let (exprs, names): (Vec, Vec>) = pairs.into_iter().unzip(); + Ok((exprs, names)) + } + pub(crate) fn check_unnest_arg(arg: &Expr, schema: &DFSchema) -> Result<()> { // Check argument type, array types are supported match arg.get_type(schema)? { diff --git a/datafusion/sqllogictest/test_files/named_arguments.slt b/datafusion/sqllogictest/test_files/named_arguments.slt new file mode 100644 index 0000000000000..679608806c5aa --- /dev/null +++ b/datafusion/sqllogictest/test_files/named_arguments.slt @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +############# +## Tests for Named Arguments (PostgreSQL-style param => value syntax) +############# + +# Test positional arguments still work (baseline) +query T +SELECT substr('hello world', 7, 5); +---- +world + +# Test named arguments in order +query T +SELECT substr(str => 'hello world', start_pos => 7, length => 5); +---- +world + +# Test named arguments out of order +query T +SELECT substr(length => 5, str => 'hello world', start_pos => 7); +---- +world + +# Test mixed positional and named arguments +query T +SELECT substr('hello world', start_pos => 7, length => 5); +---- +world + +# Test with only 2 parameters (length optional) +query T +SELECT substr(str => 'hello world', start_pos => 7); +---- +world + +# Test all parameters named with substring alias +query T +SELECT substring(str => 'hello', start_pos => 1, length => 3); +---- +hel + +# Error: positional argument after named argument +query error DataFusion error: Error during planning: Positional argument.*follows named argument +SELECT substr(str => 'hello', 1, 3); + +# Error: unknown parameter name +query error DataFusion error: Error during planning: Unknown parameter name 'invalid' +SELECT substr(invalid => 'hello', start_pos => 1, length => 3); + +# Error: duplicate parameter name +query error DataFusion error: Error during planning: Parameter 'str' specified multiple times +SELECT substr(str => 'hello', str => 'world', start_pos => 1); + +# Error: wrong number of arguments +# This query provides only 1 argument but substr requires 2 or 3 +# The error message should show parameter names like "substr(str, start_pos)" +# instead of generic types like "substr(Any, Any)" +query error DataFusion error: Error during planning: Internal error: Function 'substr' failed to match any signature +SELECT substr(str => 'hello world'); diff --git a/docs/source/library-user-guide/functions/adding-udfs.md b/docs/source/library-user-guide/functions/adding-udfs.md index 2335105882a10..16cecf316c6df 100644 --- 
a/docs/source/library-user-guide/functions/adding-udfs.md +++ b/docs/source/library-user-guide/functions/adding-udfs.md @@ -586,6 +586,111 @@ For async UDF implementation details, see [`async_udf.rs`](https://github.com/ap [`process_scalar_func_inputs`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.process_scalar_func_inputs.html [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs +## Named Arguments + +DataFusion supports PostgreSQL-style named arguments for scalar functions, allowing you to pass arguments by parameter name: + +```sql +SELECT substr(str => 'hello', start_pos => 2, length => 3); +``` + +Named arguments can be mixed with positional arguments, but positional arguments must come first: + +```sql +SELECT substr('hello', start_pos => 2, length => 3); -- Valid +``` + +### Implementing Functions with Named Arguments + +To support named arguments in your UDF, add parameter names to your function's signature using `.with_parameter_names()`: + +```rust +impl MyFunction { + fn new() -> Self { + Self { + signature: Signature::uniform( + 2, + vec![DataType::Float64], + Volatility::Immutable + ) + .with_parameter_names(vec![ + "base".to_string(), + "exponent".to_string() + ]) + .expect("valid parameter names"), + } + } +} +``` + +The parameter names should match the order of arguments in your function's signature. DataFusion automatically resolves named arguments to the correct positional order before invoking your function. 
+ +### Example + +```rust +# use std::sync::Arc; +# use std::any::Any; +# use arrow::datatypes::DataType; +# use datafusion_common::Result; +# use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility}; +# use datafusion_expr::ScalarUDFImpl; + +#[derive(Debug, PartialEq, Eq, Hash)] +struct PowerFunction { + signature: Signature, +} + +impl PowerFunction { + fn new() -> Self { + Self { + signature: Signature::uniform( + 2, + vec![DataType::Float64], + Volatility::Immutable + ) + .with_parameter_names(vec![ + "base".to_string(), + "exponent".to_string() + ]) + .expect("valid parameter names"), + } + } +} + +impl ScalarUDFImpl for PowerFunction { + fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "power" } + fn signature(&self) -> &Signature { &self.signature } + + fn return_type(&self, _args: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + // Your implementation - arguments are in correct positional order + unimplemented!() + } +} +``` + +Once registered, users can call your function with named arguments: + +```sql +SELECT power(base => 2.0, exponent => 3.0); +SELECT power(2.0, exponent => 3.0); +``` + +### Error Messages + +When a function call fails due to incorrect arguments, DataFusion will show the parameter names in error messages to help users: + +``` +No function matches the given name and argument types 'substr(Utf8)'. + Candidate functions: + substr(str, start_pos) + substr(str, start_pos, length) +``` + ## Adding a Window UDF Scalar UDFs are functions that take a row of data and return a single value. 
Window UDFs are similar, but they also have From f1bb77dc995993ca62208f9c35d446f0d1c07864 Mon Sep 17 00:00:00 2001 From: bubulalabu Date: Sat, 18 Oct 2025 09:30:27 +0200 Subject: [PATCH 2/9] added support for TypeSignature::UserDefined fixed CI issues --- datafusion/expr-common/src/signature.rs | 186 +++++++++++++----- datafusion/expr/src/udf.rs | 55 +++--- datafusion/expr/src/utils.rs | 12 +- datafusion/functions/src/unicode/substr.rs | 23 +-- datafusion/sql/src/expr/function.rs | 10 +- .../test_files/named_arguments.slt | 4 +- .../functions/adding-udfs.md | 12 +- 7 files changed, 193 insertions(+), 109 deletions(-) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 8d2738fa897e1..2097a805dd5c1 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -518,14 +518,24 @@ impl TypeSignature { match self { TypeSignature::Exact(types) => { if let Some(names) = parameter_names { - vec![names.iter().take(types.len()).cloned().collect::>().join(", ")] + vec![names + .iter() + .take(types.len()) + .cloned() + .collect::>() + .join(", ")] } else { vec![Self::join_types(types, ", ")] } } TypeSignature::Any(count) => { if let Some(names) = parameter_names { - vec![names.iter().take(*count).cloned().collect::>().join(", ")] + vec![names + .iter() + .take(*count) + .cloned() + .collect::>() + .join(", ")] } else { vec![std::iter::repeat_n("Any", *count) .collect::>() @@ -534,7 +544,12 @@ impl TypeSignature { } TypeSignature::Uniform(count, _types) => { if let Some(names) = parameter_names { - vec![names.iter().take(*count).cloned().collect::>().join(", ")] + vec![names + .iter() + .take(*count) + .cloned() + .collect::>() + .join(", ")] } else { // Fallback to original representation self.to_string_repr() @@ -542,7 +557,12 @@ impl TypeSignature { } TypeSignature::Coercible(coercions) => { if let Some(names) = parameter_names { - 
vec![names.iter().take(coercions.len()).cloned().collect::>().join(", ")] + vec![names + .iter() + .take(coercions.len()) + .cloned() + .collect::>() + .join(", ")] } else { vec![Self::join_types(coercions, ", ")] } @@ -551,7 +571,12 @@ impl TypeSignature { | TypeSignature::Numeric(count) | TypeSignature::String(count) => { if let Some(names) = parameter_names { - vec![names.iter().take(*count).cloned().collect::>().join(", ")] + vec![names + .iter() + .take(*count) + .cloned() + .collect::>() + .join(", ")] } else { // Fallback to original representation self.to_string_repr() @@ -569,23 +594,25 @@ impl TypeSignature { ArrayFunctionSignature::MapArray => 1, }; if let Some(names) = parameter_names { - vec![names.iter().take(arity).cloned().collect::>().join(", ")] + vec![names + .iter() + .take(arity) + .cloned() + .collect::>() + .join(", ")] } else { // Fallback to semantic names like "array, index, element" self.to_string_repr() } } - TypeSignature::OneOf(sigs) => { - sigs.iter() - .flat_map(|s| s.to_string_repr_with_names(parameter_names)) - .collect() - } + TypeSignature::OneOf(sigs) => sigs + .iter() + .flat_map(|s| s.to_string_repr_with_names(parameter_names)) + .collect(), // Variable arity signatures cannot use parameter names TypeSignature::Variadic(_) | TypeSignature::VariadicAny - | TypeSignature::UserDefined => { - self.to_string_repr() - } + | TypeSignature::UserDefined => self.to_string_repr(), } } @@ -1182,17 +1209,20 @@ impl Signature { // For OneOf, get the maximum arity from all variants TypeSignature::OneOf(variants) => { // Get max arity from all variants - let max_arity = variants.iter().filter_map(|v| match v { - TypeSignature::Any(count) - | TypeSignature::Uniform(count, _) - | TypeSignature::Numeric(count) - | TypeSignature::String(count) - | TypeSignature::Comparable(count) => Some(*count), - TypeSignature::Exact(types) => Some(types.len()), - TypeSignature::Coercible(types) => Some(types.len()), - TypeSignature::Nullary => Some(0), - _ => 
None, - }).max(); + let max_arity = variants + .iter() + .filter_map(|v| match v { + TypeSignature::Any(count) + | TypeSignature::Uniform(count, _) + | TypeSignature::Numeric(count) + | TypeSignature::String(count) + | TypeSignature::Comparable(count) => Some(*count), + TypeSignature::Exact(types) => Some(types.len()), + TypeSignature::Coercible(types) => Some(types.len()), + TypeSignature::Nullary => Some(0), + _ => None, + }) + .max(); max_arity } // Variable arity signatures cannot have parameter names @@ -1210,10 +1240,14 @@ impl Signature { ); } } else { - return datafusion_common::plan_err!( - "Cannot specify parameter names for variable arity signature: {:?}", - self.type_signature - ); + // For UserDefined signatures, allow parameter names + // The function implementer is responsible for validating the names match the actual arguments + if !matches!(self.type_signature, TypeSignature::UserDefined) { + return datafusion_common::plan_err!( + "Cannot specify parameter names for variable arity signature: {:?}", + self.type_signature + ); + } } // Validate no duplicate names @@ -1397,32 +1431,53 @@ mod tests { #[test] fn test_signature_with_parameter_names() { // Test adding parameter names to exact signature - let sig = Signature::exact(vec![DataType::Int32, DataType::Utf8], Volatility::Immutable) - .with_parameter_names(vec!["count".to_string(), "name".to_string()]) - .unwrap(); + let sig = Signature::exact( + vec![DataType::Int32, DataType::Utf8], + Volatility::Immutable, + ) + .with_parameter_names(vec!["count".to_string(), "name".to_string()]) + .unwrap(); - assert_eq!(sig.parameter_names, Some(vec!["count".to_string(), "name".to_string()])); - assert_eq!(sig.type_signature, TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8])); + assert_eq!( + sig.parameter_names, + Some(vec!["count".to_string(), "name".to_string()]) + ); + assert_eq!( + sig.type_signature, + TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]) + ); } #[test] fn 
test_signature_parameter_names_wrong_count() { // Test that wrong number of names fails - let result = Signature::exact(vec![DataType::Int32, DataType::Utf8], Volatility::Immutable) - .with_parameter_names(vec!["count".to_string()]); // Only 1 name for 2 args + let result = Signature::exact( + vec![DataType::Int32, DataType::Utf8], + Volatility::Immutable, + ) + .with_parameter_names(vec!["count".to_string()]); // Only 1 name for 2 args assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("does not match signature arity")); + assert!(result + .unwrap_err() + .to_string() + .contains("does not match signature arity")); } #[test] fn test_signature_parameter_names_duplicate() { // Test that duplicate names fail - let result = Signature::exact(vec![DataType::Int32, DataType::Int32], Volatility::Immutable) - .with_parameter_names(vec!["count".to_string(), "count".to_string()]); + let result = Signature::exact( + vec![DataType::Int32, DataType::Int32], + Volatility::Immutable, + ) + .with_parameter_names(vec!["count".to_string(), "count".to_string()]); assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("Duplicate parameter name")); + assert!(result + .unwrap_err() + .to_string() + .contains("Duplicate parameter name")); } #[test] @@ -1432,13 +1487,19 @@ mod tests { .with_parameter_names(vec!["arg".to_string()]); assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("variable arity signature")); + assert!(result + .unwrap_err() + .to_string() + .contains("variable arity signature")); } #[test] fn test_signature_without_parameter_names() { // Test that signatures without parameter names still work - let sig = Signature::exact(vec![DataType::Int32, DataType::Utf8], Volatility::Immutable); + let sig = Signature::exact( + vec![DataType::Int32, DataType::Utf8], + Volatility::Immutable, + ); assert_eq!(sig.parameter_names, None); } @@ -1450,7 +1511,10 @@ mod tests { 
.with_parameter_names(vec!["x".to_string(), "y".to_string(), "z".to_string()]) .unwrap(); - assert_eq!(sig.parameter_names, Some(vec!["x".to_string(), "y".to_string(), "z".to_string()])); + assert_eq!( + sig.parameter_names, + Some(vec!["x".to_string(), "y".to_string(), "z".to_string()]) + ); } #[test] @@ -1460,7 +1524,10 @@ mod tests { .with_parameter_names(vec!["a".to_string(), "b".to_string()]) .unwrap(); - assert_eq!(sig.parameter_names, Some(vec!["a".to_string(), "b".to_string()])); + assert_eq!( + sig.parameter_names, + Some(vec!["a".to_string(), "b".to_string()]) + ); } #[test] @@ -1483,7 +1550,10 @@ mod tests { // With names: should show parameter names let names = vec!["id".to_string(), "name".to_string()]; - assert_eq!(sig.to_string_repr_with_names(Some(&names)), vec!["id, name"]); + assert_eq!( + sig.to_string_repr_with_names(Some(&names)), + vec!["id, name"] + ); } #[test] @@ -1502,10 +1572,8 @@ mod tests { #[test] fn test_to_string_repr_with_names_one_of() { // Test OneOf signature with parameter names (like substr) - let sig = TypeSignature::OneOf(vec![ - TypeSignature::Any(2), - TypeSignature::Any(3), - ]); + let sig = + TypeSignature::OneOf(vec![TypeSignature::Any(2), TypeSignature::Any(3)]); // Without names: should show generic "Any" types assert_eq!( @@ -1514,7 +1582,11 @@ mod tests { ); // With names: should use names for each variant - let names = vec!["str".to_string(), "start_pos".to_string(), "length".to_string()]; + let names = vec![ + "str".to_string(), + "start_pos".to_string(), + "length".to_string(), + ]; assert_eq!( sig.to_string_repr_with_names(Some(&names)), vec!["str, start_pos", "str, start_pos, length"] @@ -1524,7 +1596,11 @@ mod tests { #[test] fn test_to_string_repr_with_names_partial() { // Test with fewer names than needed - let sig = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8, DataType::Float64]); + let sig = TypeSignature::Exact(vec![ + DataType::Int32, + DataType::Utf8, + DataType::Float64, + ]); // Provide 
only 2 names for 3 parameters let names = vec!["a".to_string(), "b".to_string()]; @@ -1618,7 +1694,10 @@ mod tests { let names = vec!["x".to_string()]; // Should return empty representation, names don't apply - assert_eq!(sig.to_string_repr_with_names(Some(&names)), vec!["NullAry()"]); + assert_eq!( + sig.to_string_repr_with_names(Some(&names)), + vec!["NullAry()"] + ); assert_eq!(sig.to_string_repr_with_names(None), vec!["NullAry()"]); } @@ -1650,7 +1729,8 @@ mod tests { ); // Test RecursiveArray (1 argument) - let recursive = TypeSignature::ArraySignature(ArrayFunctionSignature::RecursiveArray); + let recursive = + TypeSignature::ArraySignature(ArrayFunctionSignature::RecursiveArray); let names = vec!["array".to_string()]; assert_eq!( recursive.to_string_repr_with_names(Some(&names)), diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index b26939a7cf5e8..516635383098e 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -960,8 +960,8 @@ mod tests { /// Argument resolution logic for named function parameters pub mod arguments { - use datafusion_common::{plan_err, Result}; use crate::Expr; + use datafusion_common::{plan_err, Result}; /// Resolves function arguments, handling named and positional notation. 
/// @@ -983,10 +983,10 @@ pub mod arguments { /// A vector of expressions in the correct order matching the parameter names /// /// # Examples - /// ```rust,ignore - /// // Given parameters ["a", "b", "c"] - /// // And call: func(10, c => 30, b => 20) - /// // Returns: [Expr(10), Expr(20), Expr(30)] + /// ```text + /// Given parameters ["a", "b", "c"] + /// And call: func(10, c => 30, b => 20) + /// Returns: [Expr(10), Expr(20), Expr(30)] /// ``` pub fn resolve_function_arguments( param_names: &[String], @@ -1056,10 +1056,8 @@ pub mod arguments { for (i, (arg, arg_name)) in args.into_iter().zip(arg_names).enumerate() { if let Some(name) = arg_name { // Named argument - find its position in param_names - let param_index = param_names - .iter() - .position(|p| p == &name) - .ok_or_else(|| { + let param_index = + param_names.iter().position(|p| p == &name).ok_or_else(|| { datafusion_common::plan_datafusion_err!( "Unknown parameter name '{}'. Valid parameters are: [{}]", name, @@ -1069,10 +1067,7 @@ pub mod arguments { // Check if this parameter was already assigned if assigned[param_index] { - return plan_err!( - "Parameter '{}' specified multiple times", - name - ); + return plan_err!("Parameter '{}' specified multiple times", name); } result[param_index] = Some(arg); @@ -1096,15 +1091,16 @@ pub mod arguments { let required_count = args_len; for i in 0..required_count { if !assigned[i] { - return plan_err!( - "Missing required parameter '{}'", - param_names[i] - ); + return plan_err!("Missing required parameter '{}'", param_names[i]); } } // Return only the assigned parameters (handles optional trailing parameters) - Ok(result.into_iter().take(required_count).map(|e| e.unwrap()).collect()) + Ok(result + .into_iter() + .take(required_count) + .map(|e| e.unwrap()) + .collect()) } #[cfg(test)] @@ -1119,7 +1115,9 @@ pub mod arguments { let args = vec![lit(1), lit("hello")]; let arg_names = vec![None, None]; - let result = resolve_function_arguments(¶m_names, 
args.clone(), arg_names).unwrap(); + let result = + resolve_function_arguments(¶m_names, args.clone(), arg_names) + .unwrap(); assert_eq!(result.len(), 2); } @@ -1130,7 +1128,8 @@ pub mod arguments { let args = vec![lit(1), lit("hello")]; let arg_names = vec![Some("a".to_string()), Some("b".to_string())]; - let result = resolve_function_arguments(¶m_names, args, arg_names).unwrap(); + let result = + resolve_function_arguments(¶m_names, args, arg_names).unwrap(); assert_eq!(result.len(), 2); } @@ -1146,7 +1145,8 @@ pub mod arguments { Some("b".to_string()), ]; - let result = resolve_function_arguments(¶m_names, args, arg_names).unwrap(); + let result = + resolve_function_arguments(¶m_names, args, arg_names).unwrap(); // Should be reordered to [a, b, c] = [1, "hello", 3.0] assert_eq!(result.len(), 3); @@ -1163,7 +1163,8 @@ pub mod arguments { let args = vec![lit(1), lit(3.0), lit("hello")]; let arg_names = vec![None, Some("c".to_string()), Some("b".to_string())]; - let result = resolve_function_arguments(¶m_names, args, arg_names).unwrap(); + let result = + resolve_function_arguments(¶m_names, args, arg_names).unwrap(); // Should be reordered to [a, b, c] = [1, "hello", 3.0] assert_eq!(result.len(), 3); @@ -1198,7 +1199,10 @@ pub mod arguments { let result = resolve_function_arguments(¶m_names, args, arg_names); assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("Unknown parameter")); + assert!(result + .unwrap_err() + .to_string() + .contains("Unknown parameter")); } #[test] @@ -1227,7 +1231,10 @@ pub mod arguments { let result = resolve_function_arguments(¶m_names, args, arg_names); assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("Missing required parameter")); + assert!(result + .unwrap_err() + .to_string() + .contains("Missing required parameter")); } } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 8bb94c569cb97..286e643d3c2c8 100644 --- a/datafusion/expr/src/utils.rs +++ 
b/datafusion/expr/src/utils.rs @@ -1737,20 +1737,17 @@ mod tests { // Error message should contain parameter names assert!( error_msg.contains("str, start_pos"), - "Expected 'str, start_pos' in error message, got: {}", - error_msg + "Expected 'str, start_pos' in error message, got: {error_msg}" ); assert!( error_msg.contains("str, start_pos, length"), - "Expected 'str, start_pos, length' in error message, got: {}", - error_msg + "Expected 'str, start_pos, length' in error message, got: {error_msg}" ); // Should NOT contain generic "Any" types assert!( !error_msg.contains("Any, Any"), - "Should not contain 'Any, Any', got: {}", - error_msg + "Should not contain 'Any, Any', got: {error_msg}" ); } @@ -1770,8 +1767,7 @@ mod tests { // Should contain generic "Any" types when no parameter names assert!( error_msg.contains("Any, Any"), - "Expected 'Any, Any' without parameter names, got: {}", - error_msg + "Expected 'Any, Any' without parameter names, got: {error_msg}" ); } } diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 631fee1b67c37..46b3cc63d0b6d 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -70,23 +70,14 @@ impl Default for SubstrFunc { impl SubstrFunc { pub fn new() -> Self { - use datafusion_expr::TypeSignature; Self { - signature: Signature::one_of( - vec![ - // substr(str, start_pos) - TypeSignature::Any(2), - // substr(str, start_pos, length) - TypeSignature::Any(3), - ], - Volatility::Immutable, - ) - .with_parameter_names(vec![ - "str".to_string(), - "start_pos".to_string(), - "length".to_string(), - ]) - .expect("valid parameter names"), + signature: Signature::user_defined(Volatility::Immutable) + .with_parameter_names(vec![ + "str".to_string(), + "start_pos".to_string(), + "length".to_string(), + ]) + .expect("valid parameter names"), aliases: vec![String::from("substring")], } } diff --git a/datafusion/sql/src/expr/function.rs 
b/datafusion/sql/src/expr/function.rs index 9f31ddb2d8e85..c72bd669cb5d9 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -274,7 +274,8 @@ impl SqlToRel<'_, S> { } // User-defined function (UDF) should have precedence if let Some(fm) = self.context_provider.get_function_meta(&name) { - let (args, arg_names) = self.function_args_to_expr_with_names(args, schema, planner_context)?; + let (args, arg_names) = + self.function_args_to_expr_with_names(args, schema, planner_context)?; // Resolve named arguments if any are present let resolved_args = if arg_names.iter().any(|name| name.is_some()) { @@ -646,7 +647,8 @@ impl SqlToRel<'_, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let (expr, _) = self.sql_fn_arg_to_logical_expr_with_name(sql, schema, planner_context)?; + let (expr, _) = + self.sql_fn_arg_to_logical_expr_with_name(sql, schema, planner_context)?; Ok(expr) } @@ -729,7 +731,9 @@ impl SqlToRel<'_, S> { ) -> Result<(Vec, Vec>)> { let results: Result)>> = args .into_iter() - .map(|a| self.sql_fn_arg_to_logical_expr_with_name(a, schema, planner_context)) + .map(|a| { + self.sql_fn_arg_to_logical_expr_with_name(a, schema, planner_context) + }) .collect(); let pairs = results?; diff --git a/datafusion/sqllogictest/test_files/named_arguments.slt b/datafusion/sqllogictest/test_files/named_arguments.slt index 679608806c5aa..512d34e048eef 100644 --- a/datafusion/sqllogictest/test_files/named_arguments.slt +++ b/datafusion/sqllogictest/test_files/named_arguments.slt @@ -69,7 +69,5 @@ SELECT substr(str => 'hello', str => 'world', start_pos => 1); # Error: wrong number of arguments # This query provides only 1 argument but substr requires 2 or 3 -# The error message should show parameter names like "substr(str, start_pos)" -# instead of generic types like "substr(Any, Any)" -query error DataFusion error: Error during planning: Internal error: Function 'substr' failed to match any signature +query 
error DataFusion error: Error during planning: Execution error: Function 'substr' user-defined coercion failed with "Error during planning: The substr function requires 2 or 3 arguments, but got 1." SELECT substr(str => 'hello world'); diff --git a/docs/source/library-user-guide/functions/adding-udfs.md b/docs/source/library-user-guide/functions/adding-udfs.md index 16cecf316c6df..2fc6026ede6ef 100644 --- a/docs/source/library-user-guide/functions/adding-udfs.md +++ b/docs/source/library-user-guide/functions/adding-udfs.md @@ -605,6 +605,14 @@ SELECT substr('hello', start_pos => 2, length => 3); -- Valid To support named arguments in your UDF, add parameter names to your function's signature using `.with_parameter_names()`: ```rust +# use arrow::datatypes::DataType; +# use datafusion_expr::{Signature, Volatility}; +# +# #[derive(Debug)] +# struct MyFunction { +# signature: Signature, +# } +# impl MyFunction { fn new() -> Self { Self { @@ -684,8 +692,8 @@ SELECT power(2.0, exponent => 3.0); When a function call fails due to incorrect arguments, DataFusion will show the parameter names in error messages to help users: -``` -No function matches the given name and argument types 'substr(Utf8)'. +```text +No function matches the given name and argument types substr(Utf8). 
Candidate functions: substr(str, start_pos) substr(str, start_pos, length) From d67effb4ff7bd55890ded38a53fd03fcdf73404b Mon Sep 17 00:00:00 2001 From: bubulalabu Date: Sat, 18 Oct 2025 09:51:52 +0200 Subject: [PATCH 3/9] added support for TypeSignature::UserDefined to to_string_repr_with_names --- datafusion/expr-common/src/signature.rs | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 2097a805dd5c1..b9bd16295cc09 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -609,10 +609,19 @@ impl TypeSignature { .iter() .flat_map(|s| s.to_string_repr_with_names(parameter_names)) .collect(), + TypeSignature::UserDefined => { + // UserDefined signatures can have parameter names + if let Some(names) = parameter_names { + vec![names.join(", ")] + } else { + self.to_string_repr() + } + } // Variable arity signatures cannot use parameter names - TypeSignature::Variadic(_) - | TypeSignature::VariadicAny - | TypeSignature::UserDefined => self.to_string_repr(), + TypeSignature::Variadic(_) | TypeSignature::VariadicAny => { + self.to_string_repr() + } + } } @@ -1680,9 +1689,15 @@ mod tests { variadic_any.to_string_repr() ); + // UserDefined now shows parameter names when available let user_defined = TypeSignature::UserDefined; assert_eq!( user_defined.to_string_repr_with_names(Some(&names)), + vec!["x"] + ); + // Without names, falls back to to_string_repr() + assert_eq!( + user_defined.to_string_repr_with_names(None), user_defined.to_string_repr() ); } From 930d72b332abbd21443d9b0b480ce6e5ff8f11c9 Mon Sep 17 00:00:00 2001 From: bubulalabu Date: Sun, 19 Oct 2025 11:41:32 +0200 Subject: [PATCH 4/9] moved mod arguments to dedicated file added test to verify arguments can be used case-insensitively added support for PostgreSQL dialect and verified other dialects --- datafusion/expr-common/src/signature.rs | 1 - 
datafusion/expr/src/arguments.rs | 293 ++++++++++++++++++ datafusion/expr/src/lib.rs | 3 +- datafusion/expr/src/udf.rs | 281 ----------------- datafusion/sql/src/expr/function.rs | 37 +++ datafusion/sql/tests/sql_integration.rs | 79 ++++- .../src/engines/postgres_engine/mod.rs | 4 +- .../test_files/named_arguments.slt | 62 ++++ 8 files changed, 474 insertions(+), 286 deletions(-) create mode 100644 datafusion/expr/src/arguments.rs diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index b9bd16295cc09..2980867c73476 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -621,7 +621,6 @@ impl TypeSignature { TypeSignature::Variadic(_) | TypeSignature::VariadicAny => { self.to_string_repr() } - } } diff --git a/datafusion/expr/src/arguments.rs b/datafusion/expr/src/arguments.rs new file mode 100644 index 0000000000000..96b9c818613a9 --- /dev/null +++ b/datafusion/expr/src/arguments.rs @@ -0,0 +1,293 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Argument resolution logic for named function parameters + +use crate::Expr; +use datafusion_common::{plan_err, Result}; + +/// Resolves function arguments, handling named and positional notation. +/// +/// This function validates and reorders arguments to match the function's parameter names +/// when named arguments are used. +/// +/// # Rules +/// - All positional arguments must come before named arguments +/// - Named arguments can be in any order after positional arguments +/// - Parameter names follow SQL identifier rules: unquoted names are case-insensitive +/// (normalized to lowercase), quoted names are case-sensitive +/// - No duplicate parameter names allowed +/// +/// # Arguments +/// * `param_names` - The function's parameter names in order +/// * `args` - The argument expressions +/// * `arg_names` - Optional parameter name for each argument +/// +/// # Returns +/// A vector of expressions in the correct order matching the parameter names +/// +/// # Examples +/// ```text +/// Given parameters ["a", "b", "c"] +/// And call: func(10, c => 30, b => 20) +/// Returns: [Expr(10), Expr(20), Expr(30)] +/// ``` +pub fn resolve_function_arguments( + param_names: &[String], + args: Vec, + arg_names: Vec>, +) -> Result> { + // Validate that arg_names length matches args length + if args.len() != arg_names.len() { + return plan_err!( + "Internal error: args length ({}) != arg_names length ({})", + args.len(), + arg_names.len() + ); + } + + // Check if all arguments are positional (fast path) + if arg_names.iter().all(|name| name.is_none()) { + return Ok(args); + } + + // Validate mixed positional and named arguments + validate_argument_order(&arg_names)?; + + // Validate and reorder named arguments + reorder_named_arguments(param_names, args, arg_names) +} + +/// Validates that positional arguments come before named arguments +fn validate_argument_order(arg_names: &[Option]) -> Result<()> { + let mut seen_named = false; + for (i, arg_name) in 
arg_names.iter().enumerate() { + match arg_name { + Some(_) => seen_named = true, + None if seen_named => { + return plan_err!( + "Positional argument at position {} follows named argument. \ + All positional arguments must come before named arguments.", + i + ); + } + None => {} + } + } + Ok(()) +} + +/// Reorders arguments based on named parameters to match signature order +fn reorder_named_arguments( + param_names: &[String], + args: Vec, + arg_names: Vec>, +) -> Result> { + // Count positional vs named arguments + let positional_count = arg_names.iter().filter(|n| n.is_none()).count(); + + // Capture args length before consuming the vector + let args_len = args.len(); + + // Create a result vector with the expected size + let expected_arg_count = param_names.len(); + let mut result: Vec> = vec![None; expected_arg_count]; + + // Track which parameters have been assigned + let mut assigned = vec![false; expected_arg_count]; + + // Process all arguments (both positional and named) + for (i, (arg, arg_name)) in args.into_iter().zip(arg_names).enumerate() { + if let Some(name) = arg_name { + // Named argument - find its position in param_names + let param_index = + param_names.iter().position(|p| p == &name).ok_or_else(|| { + datafusion_common::plan_datafusion_err!( + "Unknown parameter name '{}'. 
Valid parameters are: [{}]", + name, + param_names.join(", ") + ) + })?; + + // Check if this parameter was already assigned + if assigned[param_index] { + return plan_err!("Parameter '{}' specified multiple times", name); + } + + result[param_index] = Some(arg); + assigned[param_index] = true; + } else { + // Positional argument - place at current position + if i >= expected_arg_count { + return plan_err!( + "Too many positional arguments: expected at most {}, got {}", + expected_arg_count, + positional_count + ); + } + result[i] = Some(arg); + assigned[i] = true; + } + } + + // Check if all required parameters were provided + // Only require parameters up to the number of arguments provided (supports optional parameters) + let required_count = args_len; + for i in 0..required_count { + if !assigned[i] { + return plan_err!("Missing required parameter '{}'", param_names[i]); + } + } + + // Return only the assigned parameters (handles optional trailing parameters) + Ok(result + .into_iter() + .take(required_count) + .map(|e| e.unwrap()) + .collect()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::lit; + + #[test] + fn test_all_positional() { + let param_names = vec!["a".to_string(), "b".to_string()]; + + let args = vec![lit(1), lit("hello")]; + let arg_names = vec![None, None]; + + let result = + resolve_function_arguments(¶m_names, args.clone(), arg_names).unwrap(); + assert_eq!(result.len(), 2); + } + + #[test] + fn test_all_named() { + let param_names = vec!["a".to_string(), "b".to_string()]; + + let args = vec![lit(1), lit("hello")]; + let arg_names = vec![Some("a".to_string()), Some("b".to_string())]; + + let result = resolve_function_arguments(¶m_names, args, arg_names).unwrap(); + assert_eq!(result.len(), 2); + } + + #[test] + fn test_named_reordering() { + let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + + // Call with: func(c => 3.0, a => 1, b => "hello") + let args = vec![lit(3.0), lit(1), lit("hello")]; + let 
arg_names = vec![ + Some("c".to_string()), + Some("a".to_string()), + Some("b".to_string()), + ]; + + let result = resolve_function_arguments(¶m_names, args, arg_names).unwrap(); + + // Should be reordered to [a, b, c] = [1, "hello", 3.0] + assert_eq!(result.len(), 3); + assert_eq!(result[0], lit(1)); + assert_eq!(result[1], lit("hello")); + assert_eq!(result[2], lit(3.0)); + } + + #[test] + fn test_mixed_positional_and_named() { + let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + + // Call with: func(1, c => 3.0, b => "hello") + let args = vec![lit(1), lit(3.0), lit("hello")]; + let arg_names = vec![None, Some("c".to_string()), Some("b".to_string())]; + + let result = resolve_function_arguments(¶m_names, args, arg_names).unwrap(); + + // Should be reordered to [a, b, c] = [1, "hello", 3.0] + assert_eq!(result.len(), 3); + assert_eq!(result[0], lit(1)); + assert_eq!(result[1], lit("hello")); + assert_eq!(result[2], lit(3.0)); + } + + #[test] + fn test_positional_after_named_error() { + let param_names = vec!["a".to_string(), "b".to_string()]; + + // Call with: func(a => 1, "hello") - ERROR + let args = vec![lit(1), lit("hello")]; + let arg_names = vec![Some("a".to_string()), None]; + + let result = resolve_function_arguments(¶m_names, args, arg_names); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Positional argument")); + } + + #[test] + fn test_unknown_parameter_name() { + let param_names = vec!["a".to_string(), "b".to_string()]; + + // Call with: func(x => 1, b => "hello") - ERROR + let args = vec![lit(1), lit("hello")]; + let arg_names = vec![Some("x".to_string()), Some("b".to_string())]; + + let result = resolve_function_arguments(¶m_names, args, arg_names); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Unknown parameter")); + } + + #[test] + fn test_duplicate_parameter_name() { + let param_names = vec!["a".to_string(), "b".to_string()]; + + // Call 
with: func(a => 1, a => 2) - ERROR + let args = vec![lit(1), lit(2)]; + let arg_names = vec![Some("a".to_string()), Some("a".to_string())]; + + let result = resolve_function_arguments(¶m_names, args, arg_names); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("specified multiple times")); + } + + #[test] + fn test_missing_required_parameter() { + let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + + // Call with: func(a => 1, c => 3.0) - missing 'b' + let args = vec![lit(1), lit(3.0)]; + let arg_names = vec![Some("a".to_string()), Some("c".to_string())]; + + let result = resolve_function_arguments(¶m_names, args, arg_names); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Missing required parameter")); + } +} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index cb2d84915c217..c0a9c05952248 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -44,6 +44,7 @@ mod udaf; mod udf; mod udwf; +pub mod arguments; pub mod conditional_expressions; pub mod execution_props; pub mod expr; @@ -116,7 +117,7 @@ pub use udaf::{ udaf_default_window_function_schema_name, AggregateUDF, AggregateUDFImpl, ReversedUDAF, SetMonotonicity, StatisticsArgs, }; -pub use udf::{arguments, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; +pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; pub use udwf::{ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 516635383098e..d522158f7b6b7 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -957,284 +957,3 @@ mod tests { hasher.finish() } } - -/// Argument resolution logic for named function parameters -pub mod arguments { - use crate::Expr; - use datafusion_common::{plan_err, Result}; - - /// Resolves 
function arguments, handling named and positional notation. - /// - /// This function validates and reorders arguments to match the function's parameter names - /// when named arguments are used. - /// - /// # Rules - /// - All positional arguments must come before named arguments - /// - Named arguments can be in any order after positional arguments - /// - All parameter names must match the provided parameter_names - /// - No duplicate parameter names allowed - /// - /// # Arguments - /// * `param_names` - The function's parameter names in order - /// * `args` - The argument expressions - /// * `arg_names` - Optional parameter name for each argument - /// - /// # Returns - /// A vector of expressions in the correct order matching the parameter names - /// - /// # Examples - /// ```text - /// Given parameters ["a", "b", "c"] - /// And call: func(10, c => 30, b => 20) - /// Returns: [Expr(10), Expr(20), Expr(30)] - /// ``` - pub fn resolve_function_arguments( - param_names: &[String], - args: Vec, - arg_names: Vec>, - ) -> Result> { - // Validate that arg_names length matches args length - if args.len() != arg_names.len() { - return plan_err!( - "Internal error: args length ({}) != arg_names length ({})", - args.len(), - arg_names.len() - ); - } - - // Check if all arguments are positional (fast path) - if arg_names.iter().all(|name| name.is_none()) { - return Ok(args); - } - - // Validate mixed positional and named arguments - validate_argument_order(&arg_names)?; - - // Validate and reorder named arguments - reorder_named_arguments(param_names, args, arg_names) - } - - /// Validates that positional arguments come before named arguments - fn validate_argument_order(arg_names: &[Option]) -> Result<()> { - let mut seen_named = false; - for (i, arg_name) in arg_names.iter().enumerate() { - match arg_name { - Some(_) => seen_named = true, - None if seen_named => { - return plan_err!( - "Positional argument at position {} follows named argument. 
\ - All positional arguments must come before named arguments.", - i - ); - } - None => {} - } - } - Ok(()) - } - - /// Reorders arguments based on named parameters to match signature order - fn reorder_named_arguments( - param_names: &[String], - args: Vec, - arg_names: Vec>, - ) -> Result> { - // Count positional vs named arguments - let positional_count = arg_names.iter().filter(|n| n.is_none()).count(); - - // Capture args length before consuming the vector - let args_len = args.len(); - - // Create a result vector with the expected size - let expected_arg_count = param_names.len(); - let mut result: Vec> = vec![None; expected_arg_count]; - - // Track which parameters have been assigned - let mut assigned = vec![false; expected_arg_count]; - - // Process all arguments (both positional and named) - for (i, (arg, arg_name)) in args.into_iter().zip(arg_names).enumerate() { - if let Some(name) = arg_name { - // Named argument - find its position in param_names - let param_index = - param_names.iter().position(|p| p == &name).ok_or_else(|| { - datafusion_common::plan_datafusion_err!( - "Unknown parameter name '{}'. 
Valid parameters are: [{}]", - name, - param_names.join(", ") - ) - })?; - - // Check if this parameter was already assigned - if assigned[param_index] { - return plan_err!("Parameter '{}' specified multiple times", name); - } - - result[param_index] = Some(arg); - assigned[param_index] = true; - } else { - // Positional argument - place at current position - if i >= expected_arg_count { - return plan_err!( - "Too many positional arguments: expected at most {}, got {}", - expected_arg_count, - positional_count - ); - } - result[i] = Some(arg); - assigned[i] = true; - } - } - - // Check if all required parameters were provided - // Only require parameters up to the number of arguments provided (supports optional parameters) - let required_count = args_len; - for i in 0..required_count { - if !assigned[i] { - return plan_err!("Missing required parameter '{}'", param_names[i]); - } - } - - // Return only the assigned parameters (handles optional trailing parameters) - Ok(result - .into_iter() - .take(required_count) - .map(|e| e.unwrap()) - .collect()) - } - - #[cfg(test)] - mod tests { - use super::*; - use crate::lit; - - #[test] - fn test_all_positional() { - let param_names = vec!["a".to_string(), "b".to_string()]; - - let args = vec![lit(1), lit("hello")]; - let arg_names = vec![None, None]; - - let result = - resolve_function_arguments(¶m_names, args.clone(), arg_names) - .unwrap(); - assert_eq!(result.len(), 2); - } - - #[test] - fn test_all_named() { - let param_names = vec!["a".to_string(), "b".to_string()]; - - let args = vec![lit(1), lit("hello")]; - let arg_names = vec![Some("a".to_string()), Some("b".to_string())]; - - let result = - resolve_function_arguments(¶m_names, args, arg_names).unwrap(); - assert_eq!(result.len(), 2); - } - - #[test] - fn test_named_reordering() { - let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; - - // Call with: func(c => 3.0, a => 1, b => "hello") - let args = vec![lit(3.0), lit(1), lit("hello")]; - 
let arg_names = vec![ - Some("c".to_string()), - Some("a".to_string()), - Some("b".to_string()), - ]; - - let result = - resolve_function_arguments(¶m_names, args, arg_names).unwrap(); - - // Should be reordered to [a, b, c] = [1, "hello", 3.0] - assert_eq!(result.len(), 3); - assert_eq!(result[0], lit(1)); - assert_eq!(result[1], lit("hello")); - assert_eq!(result[2], lit(3.0)); - } - - #[test] - fn test_mixed_positional_and_named() { - let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; - - // Call with: func(1, c => 3.0, b => "hello") - let args = vec![lit(1), lit(3.0), lit("hello")]; - let arg_names = vec![None, Some("c".to_string()), Some("b".to_string())]; - - let result = - resolve_function_arguments(¶m_names, args, arg_names).unwrap(); - - // Should be reordered to [a, b, c] = [1, "hello", 3.0] - assert_eq!(result.len(), 3); - assert_eq!(result[0], lit(1)); - assert_eq!(result[1], lit("hello")); - assert_eq!(result[2], lit(3.0)); - } - - #[test] - fn test_positional_after_named_error() { - let param_names = vec!["a".to_string(), "b".to_string()]; - - // Call with: func(a => 1, "hello") - ERROR - let args = vec![lit(1), lit("hello")]; - let arg_names = vec![Some("a".to_string()), None]; - - let result = resolve_function_arguments(¶m_names, args, arg_names); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Positional argument")); - } - - #[test] - fn test_unknown_parameter_name() { - let param_names = vec!["a".to_string(), "b".to_string()]; - - // Call with: func(x => 1, b => "hello") - ERROR - let args = vec![lit(1), lit("hello")]; - let arg_names = vec![Some("x".to_string()), Some("b".to_string())]; - - let result = resolve_function_arguments(¶m_names, args, arg_names); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Unknown parameter")); - } - - #[test] - fn test_duplicate_parameter_name() { - let param_names = vec!["a".to_string(), "b".to_string()]; - - 
// Call with: func(a => 1, a => 2) - ERROR - let args = vec![lit(1), lit(2)]; - let arg_names = vec![Some("a".to_string()), Some("a".to_string())]; - - let result = resolve_function_arguments(¶m_names, args, arg_names); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("specified multiple times")); - } - - #[test] - fn test_missing_required_parameter() { - let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; - - // Call with: func(a => 1, c => 3.0) - missing 'b' - let args = vec![lit(1), lit(3.0)]; - let arg_names = vec![Some("a".to_string()), Some("c".to_string())]; - - let result = resolve_function_arguments(¶m_names, args, arg_names); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Missing required parameter")); - } - } -} diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index c72bd669cb5d9..8f8873aea76c2 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -708,6 +708,43 @@ impl SqlToRel<'_, S> { }; Ok((expr, None)) } + // PostgreSQL dialect uses ExprNamed variant with expression for name + FunctionArg::ExprNamed { + name, + arg: FunctionArgExpr::Expr(arg), + operator: _, + } => { + let expr = self.sql_expr_to_logical_expr(arg, schema, planner_context)?; + let arg_name = match name { + SQLExpr::Identifier(ident) => crate::utils::normalize_ident(ident), + _ => { + return plan_err!( + "Named argument must use a simple identifier, got: {name:?}" + ) + } + }; + Ok((expr, Some(arg_name))) + } + FunctionArg::ExprNamed { + name, + arg: FunctionArgExpr::Wildcard, + operator: _, + } => { + #[expect(deprecated)] + let expr = Expr::Wildcard { + qualifier: None, + options: Box::new(WildcardOptions::default()), + }; + let arg_name = match name { + SQLExpr::Identifier(ident) => crate::utils::normalize_ident(ident), + _ => { + return plan_err!( + "Named argument must use a simple identifier, got: 
{name:?}" + ) + } + }; + Ok((expr, Some(arg_name))) + } _ => not_impl_err!("Unsupported qualified wildcard argument: {sql:?}"), } } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index f66af28f436e6..6933c45617ffb 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -46,7 +46,11 @@ use datafusion_functions_nested::make_array::make_array_udf; use datafusion_functions_window::{rank::rank_udwf, row_number::row_number_udwf}; use insta::{allow_duplicates, assert_snapshot}; use rstest::rstest; -use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; +use sqlparser::dialect::{ + AnsiDialect, BigQueryDialect, ClickHouseDialect, DatabricksDialect, Dialect, + DuckDbDialect, GenericDialect, HiveDialect, MsSqlDialect, MySqlDialect, + PostgreSqlDialect, RedshiftSqlDialect, SQLiteDialect, SnowflakeDialect, +}; mod cases; mod common; @@ -3949,6 +3953,79 @@ fn test_double_quoted_literal_string() { assert!(logical_plan("SELECT \"1\"").is_err()); } +#[test] +fn test_named_arguments_with_dialects() { + // Test that named arguments syntax (param => value) with different dialects + use datafusion_sql::parser::Statement as DFStatement; + use sqlparser::ast::{FunctionArg, Statement, Expr as SQLExpr}; + use sqlparser::dialect::Dialect; + + let sql = "SELECT my_func(arg1 => 'value1')"; + + // Returns None if the dialect doesn't support the => operator + let extract_first_arg = |dialect: &dyn Dialect| -> Option { + let mut statements = DFParser::parse_sql_with_dialect(sql, dialect).ok()?; + + let statement = statements.pop_front().unwrap(); + if let DFStatement::Statement(stmt) = statement { + if let Statement::Query(query) = stmt.as_ref() { + if let sqlparser::ast::SetExpr::Select(select) = query.body.as_ref() { + let projection = &select.projection[0]; + if let sqlparser::ast::SelectItem::UnnamedExpr(SQLExpr::Function(func)) = projection { + if let 
sqlparser::ast::FunctionArguments::List(arg_list) = &func.args { + return Some(arg_list.args[0].clone()); + } + } + } + } + } + panic!("Failed to extract function argument"); + }; + + let arg = extract_first_arg(&AnsiDialect {}); + assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); + + let arg = extract_first_arg(&BigQueryDialect {}); + assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); + + let arg = extract_first_arg(&ClickHouseDialect {}); + assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); + + let arg = extract_first_arg(&DatabricksDialect {}); + assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); + + let arg = extract_first_arg(&DuckDbDialect {}); + assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); + + let arg = extract_first_arg(&GenericDialect {}); + assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); + + let arg = extract_first_arg(&HiveDialect {}); + assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); + + let arg = extract_first_arg(&MsSqlDialect {}); + assert!(matches!(arg, None)); + + let arg = extract_first_arg(&MySqlDialect {}); + assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); + + let arg = extract_first_arg(&PostgreSqlDialect {}); + assert!(matches!( + arg.as_ref(), + Some(FunctionArg::ExprNamed { name, .. }) + if matches!(name, SQLExpr::Identifier(ident) if ident.value == "arg1") + )); + + let arg = extract_first_arg(&RedshiftSqlDialect {}); + assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); + + let arg = extract_first_arg(&SQLiteDialect {}); + assert!(matches!(arg, Some(FunctionArg::Named { name, .. 
}) if name.value == "arg1")); + + let arg = extract_first_arg(&SnowflakeDialect {}); + assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); +} + #[test] fn test_constant_expr_eq_join() { let sql = "SELECT id, order_id \ diff --git a/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs b/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs index 375f06d34b44f..4d310711687f2 100644 --- a/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs +++ b/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs @@ -76,8 +76,8 @@ impl Postgres { /// /// See https://docs.rs/tokio-postgres/latest/tokio_postgres/config/struct.Config.html#url for format pub async fn connect(relative_path: PathBuf, pb: ProgressBar) -> Result { - let uri = - std::env::var("PG_URI").map_or(PG_URI.to_string(), std::convert::identity); + let uri = std::env::var("PG_URI") + .map_or_else(|_| PG_URI.to_string(), std::convert::identity); info!("Using postgres connection string: {uri}"); diff --git a/datafusion/sqllogictest/test_files/named_arguments.slt b/datafusion/sqllogictest/test_files/named_arguments.slt index 512d34e048eef..9737f5b70d6eb 100644 --- a/datafusion/sqllogictest/test_files/named_arguments.slt +++ b/datafusion/sqllogictest/test_files/named_arguments.slt @@ -67,7 +67,69 @@ SELECT substr(invalid => 'hello', start_pos => 1, length => 3); query error DataFusion error: Error during planning: Parameter 'str' specified multiple times SELECT substr(str => 'hello', str => 'world', start_pos => 1); +# Test case-insensitive parameter names (unquoted identifiers) +query T +SELECT substr(STR => 'hello world', START_POS => 7, LENGTH => 5); +---- +world + +# Test case-insensitive with mixed case +query T +SELECT substr(Str => 'hello world', Start_Pos => 7); +---- +world + +# Error: case-sensitive quoted parameter names don't match +query error DataFusion error: Error during planning: Unknown parameter name 'STR' +SELECT substr("STR" => 'hello world', 
"start_pos" => 7); + # Error: wrong number of arguments # This query provides only 1 argument but substr requires 2 or 3 query error DataFusion error: Error during planning: Execution error: Function 'substr' user-defined coercion failed with "Error during planning: The substr function requires 2 or 3 arguments, but got 1." SELECT substr(str => 'hello world'); + +############# +## PostgreSQL Dialect Tests (uses ExprNamed variant) +############# + +statement ok +set datafusion.sql_parser.dialect = 'PostgreSQL'; + +# Test named arguments in order +query T +SELECT substr(str => 'hello world', start_pos => 7, length => 5); +---- +world + +# Test named arguments out of order +query T +SELECT substr(length => 5, str => 'hello world', start_pos => 7); +---- +world + +# Test mixed positional and named arguments +query T +SELECT substr('hello world', start_pos => 7, length => 5); +---- +world + +# Test with only 2 parameters (length optional) +query T +SELECT substr(str => 'hello world', start_pos => 7); +---- +world + +# Reset to default dialect +statement ok +set datafusion.sql_parser.dialect = 'Generic'; + +############# +## MsSQL Dialect Tests (does NOT support => operator) +############# + +statement ok +set datafusion.sql_parser.dialect = 'MsSQL'; + +# Error: MsSQL dialect does not support => operator +query error DataFusion error: SQL error: ParserError\("Expected: \), found: => at Line: 1, Column: 19"\) +SELECT substr(str => 'hello world', start_pos => 7, length => 5); From 907bbabd5392ccb921684f78b61569ebe5e10e13 Mon Sep 17 00:00:00 2001 From: bubulalabu Date: Sun, 19 Oct 2025 14:29:04 +0200 Subject: [PATCH 5/9] display both names and types in error messages (e.g., "str: Utf8, start_pos: Int64") fix ArraySignature to pair each name with its corresponding type move positional argument validation upfront --- datafusion/expr-common/src/signature.rs | 468 +++++++++++++----- datafusion/expr/src/arguments.rs | 44 +- datafusion/expr/src/utils.rs | 26 +- 
datafusion/sql/tests/sql_integration.rs | 11 +- .../functions/adding-udfs.md | 4 +- 5 files changed, 386 insertions(+), 167 deletions(-) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 2980867c73476..b0af49ed3889a 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -22,9 +22,9 @@ use std::hash::Hash; use crate::type_coercion::aggregates::NUMERICS; use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; -use datafusion_common::internal_err; use datafusion_common::types::{LogicalType, LogicalTypeRef, NativeType}; use datafusion_common::utils::ListCoercion; +use datafusion_common::{internal_err, plan_err, Result}; use indexmap::IndexSet; use itertools::Itertools; @@ -142,6 +142,16 @@ pub enum Volatility { /// DataType::Timestamp(TimeUnit::Nanosecond, Some(TIMEZONE_WILDCARD.into())), /// ]); /// ``` +/// +/// Represents the arity (number of arguments) of a function signature +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Arity { + /// Fixed number of arguments + Fixed(usize), + /// Variable number of arguments (e.g., Variadic, VariadicAny, UserDefined) + Variable, +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum TypeSignature { /// One or more arguments of a common type out of a list of valid types. @@ -245,6 +255,69 @@ impl TypeSignature { pub fn is_one_of(&self) -> bool { matches!(self, TypeSignature::OneOf(_)) } + + /// Returns the arity (expected number of arguments) for this type signature. + /// + /// Returns `Arity::Fixed(n)` for signatures with a specific argument count, + /// or `Arity::Variable` for variable-arity signatures like `Variadic`, `VariadicAny`, `UserDefined`. 
+ /// + /// # Examples + /// + /// ``` + /// # use datafusion_expr_common::signature::{TypeSignature, Arity}; + /// # use arrow::datatypes::DataType; + /// // Exact signature has fixed arity + /// let sig = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]); + /// assert_eq!(sig.arity(), Arity::Fixed(2)); + /// + /// // Variadic signature has variable arity + /// let sig = TypeSignature::VariadicAny; + /// assert_eq!(sig.arity(), Arity::Variable); + /// ``` + pub fn arity(&self) -> Arity { + match self { + TypeSignature::Exact(types) => Arity::Fixed(types.len()), + TypeSignature::Uniform(count, _) => Arity::Fixed(*count), + TypeSignature::Numeric(count) => Arity::Fixed(*count), + TypeSignature::String(count) => Arity::Fixed(*count), + TypeSignature::Comparable(count) => Arity::Fixed(*count), + TypeSignature::Any(count) => Arity::Fixed(*count), + TypeSignature::Coercible(types) => Arity::Fixed(types.len()), + TypeSignature::Nullary => Arity::Fixed(0), + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments, + .. + }) => Arity::Fixed(arguments.len()), + TypeSignature::ArraySignature(ArrayFunctionSignature::RecursiveArray) => { + Arity::Fixed(1) + } + TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray) => { + Arity::Fixed(1) + } + TypeSignature::OneOf(variants) => { + // If any variant is Variable, the whole OneOf is Variable + let has_variable = variants.iter().any(|v| v.arity() == Arity::Variable); + if has_variable { + return Arity::Variable; + } + // Otherwise, get max arity from all fixed arity variants + let max_arity = variants + .iter() + .filter_map(|v| match v.arity() { + Arity::Fixed(n) => Some(n), + Arity::Variable => None, + }) + .max(); + match max_arity { + Some(n) => Arity::Fixed(n), + None => Arity::Variable, + } + } + TypeSignature::Variadic(_) + | TypeSignature::VariadicAny + | TypeSignature::UserDefined => Arity::Variable, + } + } } /// Represents the class of types that can be used in a function signature. 
@@ -336,7 +409,7 @@ impl TypeSignatureClass { &self, native_type: &NativeType, origin_type: &DataType, - ) -> datafusion_common::Result { + ) -> Result { match self { TypeSignatureClass::Native(logical_type) => { logical_type.native().default_cast_for(origin_type) @@ -502,13 +575,13 @@ impl TypeSignature { /// # use arrow::datatypes::DataType; /// let sig = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]); /// - /// // Without names: shows types + /// // Without names: shows types only /// assert_eq!(sig.to_string_repr_with_names(None), vec!["Int32, Utf8"]); /// - /// // With names: shows parameter names + /// // With names: shows parameter names with types /// assert_eq!( /// sig.to_string_repr_with_names(Some(&["id".to_string(), "name".to_string()])), - /// vec!["id, name"] + /// vec!["id: Int32, name: Utf8"] /// ); /// ``` pub fn to_string_repr_with_names( @@ -518,10 +591,11 @@ impl TypeSignature { match self { TypeSignature::Exact(types) => { if let Some(names) = parameter_names { + // Combine names with types: "name: Type" vec![names .iter() - .take(types.len()) - .cloned() + .zip(types.iter()) + .map(|(name, typ)| format!("{name}: {typ}")) .collect::>() .join(", ")] } else { @@ -530,10 +604,11 @@ impl TypeSignature { } TypeSignature::Any(count) => { if let Some(names) = parameter_names { + // Combine names with "Any": "name: Any" vec![names .iter() .take(*count) - .cloned() + .map(|name| format!("{name}: Any")) .collect::>() .join(", ")] } else { @@ -542,12 +617,14 @@ impl TypeSignature { .join(", ")] } } - TypeSignature::Uniform(count, _types) => { + TypeSignature::Uniform(count, types) => { if let Some(names) = parameter_names { + // Combine names with union of types: "name: Type1/Type2" + let type_str = Self::join_types(types, "/"); vec![names .iter() .take(*count) - .cloned() + .map(|name| format!("{name}: {type_str}")) .collect::>() .join(", ")] } else { @@ -557,28 +634,53 @@ impl TypeSignature { } TypeSignature::Coercible(coercions) => { if 
let Some(names) = parameter_names { + // Combine names with coercion types: "name: Type" vec![names .iter() - .take(coercions.len()) - .cloned() + .zip(coercions.iter()) + .map(|(name, coercion)| format!("{name}: {coercion}")) .collect::>() .join(", ")] } else { vec![Self::join_types(coercions, ", ")] } } - TypeSignature::Comparable(count) - | TypeSignature::Numeric(count) - | TypeSignature::String(count) => { + TypeSignature::Comparable(count) => { if let Some(names) = parameter_names { + // Combine names with "Comparable": "name: Comparable" vec![names .iter() .take(*count) - .cloned() + .map(|name| format!("{name}: Comparable")) + .collect::>() + .join(", ")] + } else { + self.to_string_repr() + } + } + TypeSignature::Numeric(count) => { + if let Some(names) = parameter_names { + // Combine names with "Numeric": "name: Numeric" + vec![names + .iter() + .take(*count) + .map(|name| format!("{name}: Numeric")) + .collect::>() + .join(", ")] + } else { + self.to_string_repr() + } + } + TypeSignature::String(count) => { + if let Some(names) = parameter_names { + // Combine names with "String": "name: String" + vec![names + .iter() + .take(*count) + .map(|name| format!("{name}: String")) .collect::>() .join(", ")] } else { - // Fallback to original representation self.to_string_repr() } } @@ -588,18 +690,34 @@ impl TypeSignature { } TypeSignature::ArraySignature(array_sig) => { // ArraySignature has fixed arity, so it can support parameter names - let arity = match array_sig { - ArrayFunctionSignature::Array { arguments, .. } => arguments.len(), - ArrayFunctionSignature::RecursiveArray => 1, - ArrayFunctionSignature::MapArray => 1, - }; if let Some(names) = parameter_names { - vec![names - .iter() - .take(arity) - .cloned() - .collect::>() - .join(", ")] + match array_sig { + ArrayFunctionSignature::Array { arguments, .. 
} => { + // Pair each name with its corresponding array argument type + vec![names + .iter() + .zip(arguments.iter()) + .map(|(name, arg_type)| format!("{name}: {arg_type}")) + .collect::>() + .join(", ")] + } + ArrayFunctionSignature::RecursiveArray => { + vec![names + .iter() + .take(1) + .map(|name| format!("{name}: recursive_array")) + .collect::>() + .join(", ")] + } + ArrayFunctionSignature::MapArray => { + vec![names + .iter() + .take(1) + .map(|name| format!("{name}: map_array")) + .collect::>() + .join(", ")] + } + } } else { // Fallback to semantic names like "array, index, element" self.to_string_repr() @@ -610,7 +728,7 @@ impl TypeSignature { .flat_map(|s| s.to_string_repr_with_names(parameter_names)) .collect(), TypeSignature::UserDefined => { - // UserDefined signatures can have parameter names + // UserDefined signatures can have parameter names but no type info if let Some(names) = parameter_names { vec![names.join(", ")] } else { @@ -1181,10 +1299,8 @@ impl Signature { /// Returns an error if the number of parameter names doesn't match the signature's arity. /// For signatures with variable arity (e.g., `Variadic`, `VariadicAny`), parameter names /// cannot be specified. 
- pub fn with_parameter_names( - mut self, - names: Vec, - ) -> datafusion_common::Result { + pub fn with_parameter_names(mut self, names: Vec>) -> Result { + let names = names.into_iter().map(Into::into).collect::>(); // Validate that the number of names matches the signature self.validate_parameter_names(&names)?; self.parameter_names = Some(names); @@ -1192,69 +1308,27 @@ impl Signature { } /// Validate that parameter names are compatible with this signature - fn validate_parameter_names( - &self, - names: &[String], - ) -> datafusion_common::Result<()> { - // Get expected argument count from the type signature - let expected_count = match &self.type_signature { - TypeSignature::Exact(types) => Some(types.len()), - TypeSignature::Uniform(count, _) => Some(*count), - TypeSignature::Numeric(count) => Some(*count), - TypeSignature::String(count) => Some(*count), - TypeSignature::Comparable(count) => Some(*count), - TypeSignature::Any(count) => Some(*count), - TypeSignature::Coercible(types) => Some(types.len()), - TypeSignature::Nullary => Some(0), - TypeSignature::ArraySignature(ArrayFunctionSignature::Array { - arguments, - .. 
- }) => Some(arguments.len()), - TypeSignature::ArraySignature(ArrayFunctionSignature::RecursiveArray) => { - Some(1) - } - TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray) => Some(1), - // For OneOf, get the maximum arity from all variants - TypeSignature::OneOf(variants) => { - // Get max arity from all variants - let max_arity = variants - .iter() - .filter_map(|v| match v { - TypeSignature::Any(count) - | TypeSignature::Uniform(count, _) - | TypeSignature::Numeric(count) - | TypeSignature::String(count) - | TypeSignature::Comparable(count) => Some(*count), - TypeSignature::Exact(types) => Some(types.len()), - TypeSignature::Coercible(types) => Some(types.len()), - TypeSignature::Nullary => Some(0), - _ => None, - }) - .max(); - max_arity - } - // Variable arity signatures cannot have parameter names - TypeSignature::Variadic(_) - | TypeSignature::VariadicAny - | TypeSignature::UserDefined => None, - }; - - if let Some(expected) = expected_count { - if names.len() != expected { - return datafusion_common::plan_err!( - "Parameter names count ({}) does not match signature arity ({})", - names.len(), - expected - ); + fn validate_parameter_names(&self, names: &[String]) -> Result<()> { + // Check arity compatibility + match self.type_signature.arity() { + Arity::Fixed(expected) => { + if names.len() != expected { + return plan_err!( + "Parameter names count ({}) does not match signature arity ({})", + names.len(), + expected + ); + } } - } else { - // For UserDefined signatures, allow parameter names - // The function implementer is responsible for validating the names match the actual arguments - if !matches!(self.type_signature, TypeSignature::UserDefined) { - return datafusion_common::plan_err!( - "Cannot specify parameter names for variable arity signature: {:?}", - self.type_signature - ); + Arity::Variable => { + // For UserDefined signatures, allow parameter names + // The function implementer is responsible for validating the names match the 
actual arguments + if !matches!(self.type_signature, TypeSignature::UserDefined) { + return plan_err!( + "Cannot specify parameter names for variable arity signature: {:?}", + self.type_signature + ); + } } } @@ -1262,10 +1336,7 @@ impl Signature { let mut seen = std::collections::HashSet::new(); for name in names { if !seen.insert(name) { - return datafusion_common::plan_err!( - "Duplicate parameter name: '{}'", - name - ); + return plan_err!("Duplicate parameter name: '{}'", name); } } @@ -1542,7 +1613,7 @@ mod tests { fn test_signature_nullary_with_empty_names() { // Test that nullary signature accepts empty parameter names let sig = Signature::nullary(Volatility::Immutable) - .with_parameter_names(vec![]) + .with_parameter_names(Vec::::new()) .unwrap(); assert_eq!(sig.parameter_names, Some(vec![])); @@ -1556,11 +1627,11 @@ mod tests { // Without names: should show types assert_eq!(sig.to_string_repr_with_names(None), vec!["Int32, Utf8"]); - // With names: should show parameter names + // With names: should show parameter names with types let names = vec!["id".to_string(), "name".to_string()]; assert_eq!( sig.to_string_repr_with_names(Some(&names)), - vec!["id, name"] + vec!["id: Int32, name: Utf8"] ); } @@ -1572,9 +1643,12 @@ mod tests { // Without names: should show "Any" for each parameter assert_eq!(sig.to_string_repr_with_names(None), vec!["Any, Any, Any"]); - // With names: should show parameter names + // With names: should show parameter names with "Any" type let names = vec!["x".to_string(), "y".to_string(), "z".to_string()]; - assert_eq!(sig.to_string_repr_with_names(Some(&names)), vec!["x, y, z"]); + assert_eq!( + sig.to_string_repr_with_names(Some(&names)), + vec!["x: Any, y: Any, z: Any"] + ); } #[test] @@ -1589,7 +1663,7 @@ mod tests { vec!["Any, Any", "Any, Any, Any"] ); - // With names: should use names for each variant + // With names: should use names with types for each variant let names = vec![ "str".to_string(), "start_pos".to_string(), @@ 
-1597,23 +1671,25 @@ mod tests { ]; assert_eq!( sig.to_string_repr_with_names(Some(&names)), - vec!["str, start_pos", "str, start_pos, length"] + vec![ + "str: Any, start_pos: Any", + "str: Any, start_pos: Any, length: Any" + ] ); } #[test] fn test_to_string_repr_with_names_partial() { - // Test with fewer names than needed - let sig = TypeSignature::Exact(vec![ - DataType::Int32, - DataType::Utf8, - DataType::Float64, - ]); + // Test with more names than a single signature needs (valid for OneOf) + // This simulates providing max arity names for a OneOf signature + let sig = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]); - // Provide only 2 names for 3 parameters - let names = vec!["a".to_string(), "b".to_string()]; - // Should only use the available names (takes first 2) - assert_eq!(sig.to_string_repr_with_names(Some(&names)), vec!["a, b"]); + // Provide 3 names for 2-parameter signature (extra name is ignored via zip) + let names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + assert_eq!( + sig.to_string_repr_with_names(Some(&names)), + vec!["a: Int32, b: Utf8"] + ); } #[test] @@ -1627,9 +1703,12 @@ mod tests { vec!["Float64, Float64"] ); - // With names: should show parameter names + // With names: should show parameter names with types let names = vec!["x".to_string(), "y".to_string()]; - assert_eq!(sig.to_string_repr_with_names(Some(&names)), vec!["x, y"]); + assert_eq!( + sig.to_string_repr_with_names(Some(&names)), + vec!["x: Float64, y: Float64"] + ); } #[test] @@ -1643,9 +1722,13 @@ mod tests { Coercion::new_exact(TypeSignatureClass::Native(logical_int32())), ]); - // With names: should show parameter names + // With names: should show parameter names with coercion types let names = vec!["a".to_string(), "b".to_string()]; - assert_eq!(sig.to_string_repr_with_names(Some(&names)), vec!["a, b"]); + let result = sig.to_string_repr_with_names(Some(&names)); + // Check that it contains the parameter names with type annotations + 
assert_eq!(result.len(), 1); + assert!(result[0].starts_with("a: ")); + assert!(result[0].contains(", b: ")); } #[test] @@ -1657,18 +1740,18 @@ mod tests { let names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; - // All should show parameter names when provided + // All should show parameter names with type annotations assert_eq!( comparable.to_string_repr_with_names(Some(&names)), - vec!["a, b, c"] + vec!["a: Comparable, b: Comparable, c: Comparable"] ); assert_eq!( numeric.to_string_repr_with_names(Some(&names)), - vec!["a, b"] + vec!["a: Numeric, b: Numeric"] ); assert_eq!( string_sig.to_string_repr_with_names(Some(&names)), - vec!["a, b"] + vec!["a: String, b: String"] ); } @@ -1735,11 +1818,11 @@ mod tests { vec!["array, index, element"] ); - // With names: should show parameter names - let names = vec!["arr".to_string(), "size".to_string(), "value".to_string()]; + // With names: should pair each name with its array argument type + let names = vec!["arr".to_string(), "idx".to_string(), "val".to_string()]; assert_eq!( sig.to_string_repr_with_names(Some(&names)), - vec!["arr, size, value"] + vec!["arr: array, idx: index, val: element"] ); // Test RecursiveArray (1 argument) @@ -1748,7 +1831,7 @@ mod tests { let names = vec!["array".to_string()]; assert_eq!( recursive.to_string_repr_with_names(Some(&names)), - vec!["array"] + vec!["array: recursive_array"] ); // Test MapArray (1 argument) @@ -1756,7 +1839,134 @@ mod tests { let names = vec!["map".to_string()]; assert_eq!( map_array.to_string_repr_with_names(Some(&names)), - vec!["map"] + vec!["map: map_array"] ); } + + #[test] + fn test_type_signature_arity_exact() { + let sig = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]); + assert_eq!(sig.arity(), Arity::Fixed(2)); + + let sig = TypeSignature::Exact(vec![]); + assert_eq!(sig.arity(), Arity::Fixed(0)); + } + + #[test] + fn test_type_signature_arity_uniform() { + let sig = TypeSignature::Uniform(3, vec![DataType::Float64]); + 
assert_eq!(sig.arity(), Arity::Fixed(3)); + + let sig = TypeSignature::Uniform(1, vec![DataType::Int32]); + assert_eq!(sig.arity(), Arity::Fixed(1)); + } + + #[test] + fn test_type_signature_arity_numeric() { + let sig = TypeSignature::Numeric(2); + assert_eq!(sig.arity(), Arity::Fixed(2)); + } + + #[test] + fn test_type_signature_arity_string() { + let sig = TypeSignature::String(3); + assert_eq!(sig.arity(), Arity::Fixed(3)); + } + + #[test] + fn test_type_signature_arity_comparable() { + let sig = TypeSignature::Comparable(2); + assert_eq!(sig.arity(), Arity::Fixed(2)); + } + + #[test] + fn test_type_signature_arity_any() { + let sig = TypeSignature::Any(4); + assert_eq!(sig.arity(), Arity::Fixed(4)); + } + + #[test] + fn test_type_signature_arity_coercible() { + use datafusion_common::types::{logical_int32, logical_string}; + let sig = TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_int32())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ]); + assert_eq!(sig.arity(), Arity::Fixed(2)); + } + + #[test] + fn test_type_signature_arity_nullary() { + let sig = TypeSignature::Nullary; + assert_eq!(sig.arity(), Arity::Fixed(0)); + } + + #[test] + fn test_type_signature_arity_array_signature() { + // Test Array variant with 2 arguments + let sig = TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array, ArrayFunctionArgument::Index], + array_coercion: None, + }); + assert_eq!(sig.arity(), Arity::Fixed(2)); + + // Test Array variant with 3 arguments + let sig = TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Index, + ], + array_coercion: None, + }); + assert_eq!(sig.arity(), Arity::Fixed(3)); + + // Test RecursiveArray variant + let sig = TypeSignature::ArraySignature(ArrayFunctionSignature::RecursiveArray); + 
assert_eq!(sig.arity(), Arity::Fixed(1)); + + // Test MapArray variant + let sig = TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray); + assert_eq!(sig.arity(), Arity::Fixed(1)); + } + + #[test] + fn test_type_signature_arity_one_of_fixed() { + // OneOf with all fixed arity variants should return max arity + let sig = TypeSignature::OneOf(vec![ + TypeSignature::Exact(vec![DataType::Int32]), + TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]), + TypeSignature::Exact(vec![ + DataType::Int32, + DataType::Utf8, + DataType::Float64, + ]), + ]); + assert_eq!(sig.arity(), Arity::Fixed(3)); + } + + #[test] + fn test_type_signature_arity_one_of_variable() { + // OneOf with variable arity variant should return Variable + let sig = TypeSignature::OneOf(vec![ + TypeSignature::Exact(vec![DataType::Int32]), + TypeSignature::VariadicAny, + ]); + assert_eq!(sig.arity(), Arity::Variable); + } + + #[test] + fn test_type_signature_arity_variadic() { + let sig = TypeSignature::Variadic(vec![DataType::Int32]); + assert_eq!(sig.arity(), Arity::Variable); + + let sig = TypeSignature::VariadicAny; + assert_eq!(sig.arity(), Arity::Variable); + } + + #[test] + fn test_type_signature_arity_user_defined() { + let sig = TypeSignature::UserDefined; + assert_eq!(sig.arity(), Arity::Variable); + } } diff --git a/datafusion/expr/src/arguments.rs b/datafusion/expr/src/arguments.rs index 96b9c818613a9..722a5d6d31c84 100644 --- a/datafusion/expr/src/arguments.rs +++ b/datafusion/expr/src/arguments.rs @@ -19,6 +19,7 @@ use crate::Expr; use datafusion_common::{plan_err, Result}; +use std::collections::HashMap; /// Resolves function arguments, handling named and positional notation. 
/// @@ -97,6 +98,13 @@ fn reorder_named_arguments( args: Vec, arg_names: Vec>, ) -> Result> { + // Build HashMap for O(1) parameter name lookups + let param_index_map: HashMap<&str, usize> = param_names + .iter() + .enumerate() + .map(|(idx, name)| (name.as_str(), idx)) + .collect(); + // Count positional vs named arguments let positional_count = arg_names.iter().filter(|n| n.is_none()).count(); @@ -105,17 +113,24 @@ fn reorder_named_arguments( // Create a result vector with the expected size let expected_arg_count = param_names.len(); - let mut result: Vec> = vec![None; expected_arg_count]; - // Track which parameters have been assigned - let mut assigned = vec![false; expected_arg_count]; + // Validate positional argument count upfront + if positional_count > expected_arg_count { + return plan_err!( + "Too many positional arguments: expected at most {}, got {}", + expected_arg_count, + positional_count + ); + } + + let mut result: Vec> = vec![None; expected_arg_count]; // Process all arguments (both positional and named) for (i, (arg, arg_name)) in args.into_iter().zip(arg_names).enumerate() { if let Some(name) = arg_name { - // Named argument - find its position in param_names + // Named argument - O(1) lookup in HashMap let param_index = - param_names.iter().position(|p| p == &name).ok_or_else(|| { + param_index_map.get(name.as_str()).copied().ok_or_else(|| { datafusion_common::plan_datafusion_err!( "Unknown parameter name '{}'. 
Valid parameters are: [{}]", name, @@ -124,23 +139,14 @@ fn reorder_named_arguments( })?; // Check if this parameter was already assigned - if assigned[param_index] { + if result[param_index].is_some() { return plan_err!("Parameter '{}' specified multiple times", name); } result[param_index] = Some(arg); - assigned[param_index] = true; } else { // Positional argument - place at current position - if i >= expected_arg_count { - return plan_err!( - "Too many positional arguments: expected at most {}, got {}", - expected_arg_count, - positional_count - ); - } result[i] = Some(arg); - assigned[i] = true; } } @@ -148,17 +154,13 @@ fn reorder_named_arguments( // Only require parameters up to the number of arguments provided (supports optional parameters) let required_count = args_len; for i in 0..required_count { - if !assigned[i] { + if result[i].is_none() { return plan_err!("Missing required parameter '{}'", param_names[i]); } } // Return only the assigned parameters (handles optional trailing parameters) - Ok(result - .into_iter() - .take(required_count) - .map(|e| e.unwrap()) - .collect()) + Ok(result.into_iter().take(required_count).flatten().collect()) } #[cfg(test)] diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 286e643d3c2c8..409b53b7b0b68 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1720,8 +1720,16 @@ mod tests { use datafusion_expr_common::signature::{TypeSignature, Volatility}; // Create a signature like substr with parameter names + // substr(str, start_pos) or substr(str, start_pos, length) let sig = Signature::one_of( - vec![TypeSignature::Any(2), TypeSignature::Any(3)], + vec![ + TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]), + TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::Int64, + DataType::Int64, + ]), + ], Volatility::Immutable, ) .with_parameter_names(vec![ @@ -1734,20 +1742,14 @@ mod tests { // Generate error message with only 1 argument provided let error_msg 
= generate_signature_error_msg("substr", sig, &[DataType::Utf8]); - // Error message should contain parameter names + // Error message should contain parameter names with types assert!( - error_msg.contains("str, start_pos"), - "Expected 'str, start_pos' in error message, got: {error_msg}" + error_msg.contains("str: Utf8, start_pos: Int64"), + "Expected 'str: Utf8, start_pos: Int64' in error message, got: {error_msg}" ); assert!( - error_msg.contains("str, start_pos, length"), - "Expected 'str, start_pos, length' in error message, got: {error_msg}" - ); - - // Should NOT contain generic "Any" types - assert!( - !error_msg.contains("Any, Any"), - "Should not contain 'Any, Any', got: {error_msg}" + error_msg.contains("str: Utf8, start_pos: Int64, length: Int64"), + "Expected 'str: Utf8, start_pos: Int64, length: Int64' in error message, got: {error_msg}" ); } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 6933c45617ffb..8c5a9b643a8ec 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -3957,7 +3957,7 @@ fn test_double_quoted_literal_string() { fn test_named_arguments_with_dialects() { // Test that named arguments syntax (param => value) with different dialects use datafusion_sql::parser::Statement as DFStatement; - use sqlparser::ast::{FunctionArg, Statement, Expr as SQLExpr}; + use sqlparser::ast::{Expr as SQLExpr, FunctionArg, Statement}; use sqlparser::dialect::Dialect; let sql = "SELECT my_func(arg1 => 'value1')"; @@ -3971,8 +3971,13 @@ fn test_named_arguments_with_dialects() { if let Statement::Query(query) = stmt.as_ref() { if let sqlparser::ast::SetExpr::Select(select) = query.body.as_ref() { let projection = &select.projection[0]; - if let sqlparser::ast::SelectItem::UnnamedExpr(SQLExpr::Function(func)) = projection { - if let sqlparser::ast::FunctionArguments::List(arg_list) = &func.args { + if let sqlparser::ast::SelectItem::UnnamedExpr(SQLExpr::Function( 
+ func, + )) = projection + { + if let sqlparser::ast::FunctionArguments::List(arg_list) = + &func.args + { return Some(arg_list.args[0].clone()); } } diff --git a/docs/source/library-user-guide/functions/adding-udfs.md b/docs/source/library-user-guide/functions/adding-udfs.md index 2fc6026ede6ef..e143fdd7b12ea 100644 --- a/docs/source/library-user-guide/functions/adding-udfs.md +++ b/docs/source/library-user-guide/functions/adding-udfs.md @@ -695,8 +695,8 @@ When a function call fails due to incorrect arguments, DataFusion will show the ```text No function matches the given name and argument types substr(Utf8). Candidate functions: - substr(str, start_pos) - substr(str, start_pos, length) + substr(str: Any, start_pos: Any) + substr(str: Any, start_pos: Any, length: Any) ``` ## Adding a Window UDF From 6ac3a742068c658ce6115da1402ac0e6a8e54157 Mon Sep 17 00:00:00 2001 From: bubulalabu Date: Sun, 19 Oct 2025 15:12:33 +0200 Subject: [PATCH 6/9] moved all use statements out of functions --- datafusion/expr-common/src/signature.rs | 11 ++++------- datafusion/expr/src/utils.rs | 5 +---- datafusion/sql/tests/sql_integration.rs | 7 +++---- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index b0af49ed3889a..62227a8f491f0 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -1346,9 +1346,12 @@ impl Signature { #[cfg(test)] mod tests { - use datafusion_common::types::{logical_int64, logical_string}; + use datafusion_common::types::{logical_int32, logical_int64, logical_string}; use super::*; + use crate::signature::{ + ArrayFunctionArgument, ArrayFunctionSignature, Coercion, TypeSignatureClass, + }; #[test] fn supports_zero_argument_tests() { @@ -1713,9 +1716,6 @@ mod tests { #[test] fn test_to_string_repr_with_names_coercible() { - use crate::signature::{Coercion, TypeSignatureClass}; - use datafusion_common::types::logical_int32; 
- // Test Coercible signature with parameter names let sig = TypeSignature::Coercible(vec![ Coercion::new_exact(TypeSignatureClass::Native(logical_int32())), @@ -1800,8 +1800,6 @@ mod tests { #[test] fn test_to_string_repr_with_names_array_signature() { - use crate::signature::{ArrayFunctionArgument, ArrayFunctionSignature}; - // Test ArraySignature with parameter names let sig = TypeSignature::ArraySignature(ArrayFunctionSignature::Array { arguments: vec![ @@ -1887,7 +1885,6 @@ mod tests { #[test] fn test_type_signature_arity_coercible() { - use datafusion_common::types::{logical_int32, logical_string}; let sig = TypeSignature::Coercible(vec![ Coercion::new_exact(TypeSignatureClass::Native(logical_int32())), Coercion::new_exact(TypeSignatureClass::Native(logical_string())), diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 409b53b7b0b68..ae7287ae6dd46 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1295,6 +1295,7 @@ mod tests { Cast, ExprFunctionExt, WindowFunctionDefinition, }; use arrow::datatypes::{UnionFields, UnionMode}; + use datafusion_expr_common::signature::{TypeSignature, Volatility}; #[test] fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> { @@ -1717,8 +1718,6 @@ mod tests { #[test] fn test_generate_signature_error_msg_with_parameter_names() { - use datafusion_expr_common::signature::{TypeSignature, Volatility}; - // Create a signature like substr with parameter names // substr(str, start_pos) or substr(str, start_pos, length) let sig = Signature::one_of( @@ -1755,8 +1754,6 @@ mod tests { #[test] fn test_generate_signature_error_msg_without_parameter_names() { - use datafusion_expr_common::signature::{TypeSignature, Volatility}; - // Create a signature without parameter names let sig = Signature::one_of( vec![TypeSignature::Any(2), TypeSignature::Any(3)], diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 
8c5a9b643a8ec..0ca1a81243c7b 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -31,7 +31,7 @@ use datafusion_expr::{ }; use datafusion_functions::{string, unicode}; use datafusion_sql::{ - parser::DFParser, + parser::{DFParser, Statement as DFStatement}, planner::{NullOrdering, ParserOptions, SqlToRel}, }; @@ -46,6 +46,7 @@ use datafusion_functions_nested::make_array::make_array_udf; use datafusion_functions_window::{rank::rank_udwf, row_number::row_number_udwf}; use insta::{allow_duplicates, assert_snapshot}; use rstest::rstest; +use sqlparser::ast::{Expr as SQLExpr, FunctionArg, Statement}; use sqlparser::dialect::{ AnsiDialect, BigQueryDialect, ClickHouseDialect, DatabricksDialect, Dialect, DuckDbDialect, GenericDialect, HiveDialect, MsSqlDialect, MySqlDialect, @@ -3956,8 +3957,6 @@ fn test_double_quoted_literal_string() { #[test] fn test_named_arguments_with_dialects() { // Test that named arguments syntax (param => value) with different dialects - use datafusion_sql::parser::Statement as DFStatement; - use sqlparser::ast::{Expr as SQLExpr, FunctionArg, Statement}; use sqlparser::dialect::Dialect; let sql = "SELECT my_func(arg1 => 'value1')"; @@ -4009,7 +4008,7 @@ fn test_named_arguments_with_dialects() { assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); let arg = extract_first_arg(&MsSqlDialect {}); - assert!(matches!(arg, None)); + assert!(arg.is_none()); let arg = extract_first_arg(&MySqlDialect {}); assert!(matches!(arg, Some(FunctionArg::Named { name, .. 
}) if name.value == "arg1")); From bb04678e4cbfccddea23f37eb1ff60c654072d0e Mon Sep 17 00:00:00 2001 From: bubulalabu Date: Sun, 19 Oct 2025 19:15:41 +0200 Subject: [PATCH 7/9] removed noisy comments --- datafusion/expr-common/src/signature.rs | 50 +------------------------ datafusion/expr/src/arguments.rs | 10 ----- datafusion/expr/src/utils.rs | 6 --- datafusion/sql/src/expr/function.rs | 3 -- datafusion/sql/tests/sql_integration.rs | 3 -- 5 files changed, 1 insertion(+), 71 deletions(-) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 62227a8f491f0..99bcb170c4097 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -591,7 +591,6 @@ impl TypeSignature { match self { TypeSignature::Exact(types) => { if let Some(names) = parameter_names { - // Combine names with types: "name: Type" vec![names .iter() .zip(types.iter()) @@ -604,7 +603,6 @@ impl TypeSignature { } TypeSignature::Any(count) => { if let Some(names) = parameter_names { - // Combine names with "Any": "name: Any" vec![names .iter() .take(*count) @@ -619,7 +617,6 @@ impl TypeSignature { } TypeSignature::Uniform(count, types) => { if let Some(names) = parameter_names { - // Combine names with union of types: "name: Type1/Type2" let type_str = Self::join_types(types, "/"); vec![names .iter() @@ -628,13 +625,11 @@ impl TypeSignature { .collect::>() .join(", ")] } else { - // Fallback to original representation self.to_string_repr() } } TypeSignature::Coercible(coercions) => { if let Some(names) = parameter_names { - // Combine names with coercion types: "name: Type" vec![names .iter() .zip(coercions.iter()) @@ -647,7 +642,6 @@ impl TypeSignature { } TypeSignature::Comparable(count) => { if let Some(names) = parameter_names { - // Combine names with "Comparable": "name: Comparable" vec![names .iter() .take(*count) @@ -660,7 +654,6 @@ impl TypeSignature { } TypeSignature::Numeric(count) => { if let Some(names) = 
parameter_names { - // Combine names with "Numeric": "name: Numeric" vec![names .iter() .take(*count) @@ -673,7 +666,6 @@ impl TypeSignature { } TypeSignature::String(count) => { if let Some(names) = parameter_names { - // Combine names with "String": "name: String" vec![names .iter() .take(*count) @@ -684,16 +676,11 @@ impl TypeSignature { self.to_string_repr() } } - TypeSignature::Nullary => { - // No parameters, so no names to show - self.to_string_repr() - } + TypeSignature::Nullary => self.to_string_repr(), TypeSignature::ArraySignature(array_sig) => { - // ArraySignature has fixed arity, so it can support parameter names if let Some(names) = parameter_names { match array_sig { ArrayFunctionSignature::Array { arguments, .. } => { - // Pair each name with its corresponding array argument type vec![names .iter() .zip(arguments.iter()) @@ -719,7 +706,6 @@ impl TypeSignature { } } } else { - // Fallback to semantic names like "array, index, element" self.to_string_repr() } } @@ -728,7 +714,6 @@ impl TypeSignature { .flat_map(|s| s.to_string_repr_with_names(parameter_names)) .collect(), TypeSignature::UserDefined => { - // UserDefined signatures can have parameter names but no type info if let Some(names) = parameter_names { vec![names.join(", ")] } else { @@ -1309,7 +1294,6 @@ impl Signature { /// Validate that parameter names are compatible with this signature fn validate_parameter_names(&self, names: &[String]) -> Result<()> { - // Check arity compatibility match self.type_signature.arity() { Arity::Fixed(expected) => { if names.len() != expected { @@ -1332,7 +1316,6 @@ impl Signature { } } - // Validate no duplicate names let mut seen = std::collections::HashSet::new(); for name in names { if !seen.insert(name) { @@ -1512,7 +1495,6 @@ mod tests { #[test] fn test_signature_with_parameter_names() { - // Test adding parameter names to exact signature let sig = Signature::exact( vec![DataType::Int32, DataType::Utf8], Volatility::Immutable, @@ -1532,7 +1514,6 @@ mod 
tests { #[test] fn test_signature_parameter_names_wrong_count() { - // Test that wrong number of names fails let result = Signature::exact( vec![DataType::Int32, DataType::Utf8], Volatility::Immutable, @@ -1548,7 +1529,6 @@ mod tests { #[test] fn test_signature_parameter_names_duplicate() { - // Test that duplicate names fail let result = Signature::exact( vec![DataType::Int32, DataType::Int32], Volatility::Immutable, @@ -1564,7 +1544,6 @@ mod tests { #[test] fn test_signature_parameter_names_variadic() { - // Test that variadic signatures reject parameter names let result = Signature::variadic(vec![DataType::Int32], Volatility::Immutable) .with_parameter_names(vec!["arg".to_string()]); @@ -1577,7 +1556,6 @@ mod tests { #[test] fn test_signature_without_parameter_names() { - // Test that signatures without parameter names still work let sig = Signature::exact( vec![DataType::Int32, DataType::Utf8], Volatility::Immutable, @@ -1588,7 +1566,6 @@ mod tests { #[test] fn test_signature_uniform_with_parameter_names() { - // Test uniform signature with parameter names let sig = Signature::uniform(3, vec![DataType::Float64], Volatility::Immutable) .with_parameter_names(vec!["x".to_string(), "y".to_string(), "z".to_string()]) .unwrap(); @@ -1601,7 +1578,6 @@ mod tests { #[test] fn test_signature_numeric_with_parameter_names() { - // Test numeric signature with parameter names let sig = Signature::numeric(2, Volatility::Immutable) .with_parameter_names(vec!["a".to_string(), "b".to_string()]) .unwrap(); @@ -1614,7 +1590,6 @@ mod tests { #[test] fn test_signature_nullary_with_empty_names() { - // Test that nullary signature accepts empty parameter names let sig = Signature::nullary(Volatility::Immutable) .with_parameter_names(Vec::::new()) .unwrap(); @@ -1624,13 +1599,10 @@ mod tests { #[test] fn test_to_string_repr_with_names_exact() { - // Test Exact signature with parameter names let sig = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]); - // Without names: 
should show types assert_eq!(sig.to_string_repr_with_names(None), vec!["Int32, Utf8"]); - // With names: should show parameter names with types let names = vec!["id".to_string(), "name".to_string()]; assert_eq!( sig.to_string_repr_with_names(Some(&names)), @@ -1640,13 +1612,10 @@ mod tests { #[test] fn test_to_string_repr_with_names_any() { - // Test Any signature with parameter names let sig = TypeSignature::Any(3); - // Without names: should show "Any" for each parameter assert_eq!(sig.to_string_repr_with_names(None), vec!["Any, Any, Any"]); - // With names: should show parameter names with "Any" type let names = vec!["x".to_string(), "y".to_string(), "z".to_string()]; assert_eq!( sig.to_string_repr_with_names(Some(&names)), @@ -1656,17 +1625,14 @@ mod tests { #[test] fn test_to_string_repr_with_names_one_of() { - // Test OneOf signature with parameter names (like substr) let sig = TypeSignature::OneOf(vec![TypeSignature::Any(2), TypeSignature::Any(3)]); - // Without names: should show generic "Any" types assert_eq!( sig.to_string_repr_with_names(None), vec!["Any, Any", "Any, Any, Any"] ); - // With names: should use names with types for each variant let names = vec![ "str".to_string(), "start_pos".to_string(), @@ -1683,7 +1649,6 @@ mod tests { #[test] fn test_to_string_repr_with_names_partial() { - // Test with more names than a single signature needs (valid for OneOf) // This simulates providing max arity names for a OneOf signature let sig = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]); @@ -1697,16 +1662,13 @@ mod tests { #[test] fn test_to_string_repr_with_names_uniform() { - // Test Uniform signature with parameter names let sig = TypeSignature::Uniform(2, vec![DataType::Float64]); - // Without names: should show type representation assert_eq!( sig.to_string_repr_with_names(None), vec!["Float64, Float64"] ); - // With names: should show parameter names with types let names = vec!["x".to_string(), "y".to_string()]; assert_eq!( 
sig.to_string_repr_with_names(Some(&names)), @@ -1716,13 +1678,11 @@ mod tests { #[test] fn test_to_string_repr_with_names_coercible() { - // Test Coercible signature with parameter names let sig = TypeSignature::Coercible(vec![ Coercion::new_exact(TypeSignatureClass::Native(logical_int32())), Coercion::new_exact(TypeSignatureClass::Native(logical_int32())), ]); - // With names: should show parameter names with coercion types let names = vec!["a".to_string(), "b".to_string()]; let result = sig.to_string_repr_with_names(Some(&names)); // Check that it contains the parameter names with type annotations @@ -1733,7 +1693,6 @@ mod tests { #[test] fn test_to_string_repr_with_names_comparable_numeric_string() { - // Test Comparable, Numeric, and String signatures let comparable = TypeSignature::Comparable(3); let numeric = TypeSignature::Numeric(2); let string_sig = TypeSignature::String(2); @@ -1757,7 +1716,6 @@ mod tests { #[test] fn test_to_string_repr_with_names_variadic_fallback() { - // Test that variadic variants fall back to to_string_repr() let variadic = TypeSignature::Variadic(vec![DataType::Utf8, DataType::LargeUtf8]); let names = vec!["x".to_string()]; assert_eq!( @@ -1777,7 +1735,6 @@ mod tests { user_defined.to_string_repr_with_names(Some(&names)), vec!["x"] ); - // Without names, falls back to to_string_repr() assert_eq!( user_defined.to_string_repr_with_names(None), user_defined.to_string_repr() @@ -1786,7 +1743,6 @@ mod tests { #[test] fn test_to_string_repr_with_names_nullary() { - // Test Nullary signature (no arguments) let sig = TypeSignature::Nullary; let names = vec!["x".to_string()]; @@ -1800,7 +1756,6 @@ mod tests { #[test] fn test_to_string_repr_with_names_array_signature() { - // Test ArraySignature with parameter names let sig = TypeSignature::ArraySignature(ArrayFunctionSignature::Array { arguments: vec![ ArrayFunctionArgument::Array, @@ -1810,20 +1765,17 @@ mod tests { array_coercion: None, }); - // Without names: should show semantic types 
assert_eq!( sig.to_string_repr_with_names(None), vec!["array, index, element"] ); - // With names: should pair each name with its array argument type let names = vec!["arr".to_string(), "idx".to_string(), "val".to_string()]; assert_eq!( sig.to_string_repr_with_names(Some(&names)), vec!["arr: array, idx: index, val: element"] ); - // Test RecursiveArray (1 argument) let recursive = TypeSignature::ArraySignature(ArrayFunctionSignature::RecursiveArray); let names = vec!["array".to_string()]; diff --git a/datafusion/expr/src/arguments.rs b/datafusion/expr/src/arguments.rs index 722a5d6d31c84..5653993db98fe 100644 --- a/datafusion/expr/src/arguments.rs +++ b/datafusion/expr/src/arguments.rs @@ -52,7 +52,6 @@ pub fn resolve_function_arguments( args: Vec, arg_names: Vec>, ) -> Result> { - // Validate that arg_names length matches args length if args.len() != arg_names.len() { return plan_err!( "Internal error: args length ({}) != arg_names length ({})", @@ -66,10 +65,8 @@ pub fn resolve_function_arguments( return Ok(args); } - // Validate mixed positional and named arguments validate_argument_order(&arg_names)?; - // Validate and reorder named arguments reorder_named_arguments(param_names, args, arg_names) } @@ -105,16 +102,13 @@ fn reorder_named_arguments( .map(|(idx, name)| (name.as_str(), idx)) .collect(); - // Count positional vs named arguments let positional_count = arg_names.iter().filter(|n| n.is_none()).count(); // Capture args length before consuming the vector let args_len = args.len(); - // Create a result vector with the expected size let expected_arg_count = param_names.len(); - // Validate positional argument count upfront if positional_count > expected_arg_count { return plan_err!( "Too many positional arguments: expected at most {}, got {}", @@ -125,7 +119,6 @@ fn reorder_named_arguments( let mut result: Vec> = vec![None; expected_arg_count]; - // Process all arguments (both positional and named) for (i, (arg, arg_name)) in 
args.into_iter().zip(arg_names).enumerate() { if let Some(name) = arg_name { // Named argument - O(1) lookup in HashMap @@ -138,19 +131,16 @@ fn reorder_named_arguments( ) })?; - // Check if this parameter was already assigned if result[param_index].is_some() { return plan_err!("Parameter '{}' specified multiple times", name); } result[param_index] = Some(arg); } else { - // Positional argument - place at current position result[i] = Some(arg); } } - // Check if all required parameters were provided // Only require parameters up to the number of arguments provided (supports optional parameters) let required_count = args_len; for i in 0..required_count { diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index ae7287ae6dd46..74ba99847f709 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1718,8 +1718,6 @@ mod tests { #[test] fn test_generate_signature_error_msg_with_parameter_names() { - // Create a signature like substr with parameter names - // substr(str, start_pos) or substr(str, start_pos, length) let sig = Signature::one_of( vec![ TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]), @@ -1741,7 +1739,6 @@ mod tests { // Generate error message with only 1 argument provided let error_msg = generate_signature_error_msg("substr", sig, &[DataType::Utf8]); - // Error message should contain parameter names with types assert!( error_msg.contains("str: Utf8, start_pos: Int64"), "Expected 'str: Utf8, start_pos: Int64' in error message, got: {error_msg}" @@ -1754,16 +1751,13 @@ mod tests { #[test] fn test_generate_signature_error_msg_without_parameter_names() { - // Create a signature without parameter names let sig = Signature::one_of( vec![TypeSignature::Any(2), TypeSignature::Any(3)], Volatility::Immutable, ); - // Generate error message let error_msg = generate_signature_error_msg("my_func", sig, &[DataType::Int32]); - // Should contain generic "Any" types when no parameter names assert!( 
error_msg.contains("Any, Any"), "Expected 'Any, Any' without parameter names, got: {error_msg}" diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 8f8873aea76c2..3b7c717cb31e0 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -277,9 +277,7 @@ impl SqlToRel<'_, S> { let (args, arg_names) = self.function_args_to_expr_with_names(args, schema, planner_context)?; - // Resolve named arguments if any are present let resolved_args = if arg_names.iter().any(|name| name.is_some()) { - // Get parameter names from the signature if available if let Some(param_names) = &fm.signature().parameter_names { datafusion_expr::arguments::resolve_function_arguments( param_names, @@ -287,7 +285,6 @@ impl SqlToRel<'_, S> { arg_names, )? } else { - // UDF doesn't support named arguments return plan_err!( "Function '{}' does not support named arguments", fm.name() diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 0ca1a81243c7b..f4c5bb39fe960 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -3956,9 +3956,6 @@ fn test_double_quoted_literal_string() { #[test] fn test_named_arguments_with_dialects() { - // Test that named arguments syntax (param => value) with different dialects - use sqlparser::dialect::Dialect; - let sql = "SELECT my_func(arg1 => 'value1')"; // Returns None if the dialect doesn't support the => operator From 61c57f33bf57e3dd2172369cd2e6301afcdd04ec Mon Sep 17 00:00:00 2001 From: bubulalabu Date: Sun, 19 Oct 2025 19:34:16 +0200 Subject: [PATCH 8/9] moved enum Arity so it doesn't break the docstring of TypeSignature --- datafusion/expr-common/src/signature.rs | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 99bcb170c4097..38eef077c5af9 100644 --- 
a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -84,6 +84,15 @@ pub enum Volatility { Volatile, } +/// Represents the arity (number of arguments) of a function signature +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Arity { + /// Fixed number of arguments + Fixed(usize), + /// Variable number of arguments (e.g., Variadic, VariadicAny, UserDefined) + Variable, +} + /// The types of arguments for which a function has implementations. /// /// [`TypeSignature`] **DOES NOT** define the types that a user query could call the @@ -142,16 +151,6 @@ pub enum Volatility { /// DataType::Timestamp(TimeUnit::Nanosecond, Some(TIMEZONE_WILDCARD.into())), /// ]); /// ``` -/// -/// Represents the arity (number of arguments) of a function signature -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Arity { - /// Fixed number of arguments - Fixed(usize), - /// Variable number of arguments (e.g., Variadic, VariadicAny, UserDefined) - Variable, -} - #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum TypeSignature { /// One or more arguments of a common type out of a list of valid types. 
From ebc391eb03b936a0cc5a2a55f9ab7d85b8fdab7e Mon Sep 17 00:00:00 2001 From: bubulalabu Date: Sun, 26 Oct 2025 13:49:24 +0100 Subject: [PATCH 9/9] removed confusing test_named_arguments_with_dialects test reset to default dialect in named_arguments.slt after testing MsSQL dialect compacted pattern matching in sql_fn_arg_to_logical_expr_with_name --- datafusion/sql/src/expr/function.rs | 22 +---- datafusion/sql/tests/sql_integration.rs | 82 +------------------ .../test_files/named_arguments.slt | 4 + 3 files changed, 10 insertions(+), 98 deletions(-) diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 3b7c717cb31e0..cb34bb0f7eb7b 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -707,23 +707,16 @@ impl SqlToRel<'_, S> { } // PostgreSQL dialect uses ExprNamed variant with expression for name FunctionArg::ExprNamed { - name, + name: SQLExpr::Identifier(name), arg: FunctionArgExpr::Expr(arg), operator: _, } => { let expr = self.sql_expr_to_logical_expr(arg, schema, planner_context)?; - let arg_name = match name { - SQLExpr::Identifier(ident) => crate::utils::normalize_ident(ident), - _ => { - return plan_err!( - "Named argument must use a simple identifier, got: {name:?}" - ) - } - }; + let arg_name = crate::utils::normalize_ident(name); Ok((expr, Some(arg_name))) } FunctionArg::ExprNamed { - name, + name: SQLExpr::Identifier(name), arg: FunctionArgExpr::Wildcard, operator: _, } => { @@ -732,14 +725,7 @@ impl SqlToRel<'_, S> { qualifier: None, options: Box::new(WildcardOptions::default()), }; - let arg_name = match name { - SQLExpr::Identifier(ident) => crate::utils::normalize_ident(ident), - _ => { - return plan_err!( - "Named argument must use a simple identifier, got: {name:?}" - ) - } - }; + let arg_name = crate::utils::normalize_ident(name); Ok((expr, Some(arg_name))) } _ => not_impl_err!("Unsupported qualified wildcard argument: {sql:?}"), diff --git 
a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index f4c5bb39fe960..f66af28f436e6 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -31,7 +31,7 @@ use datafusion_expr::{ }; use datafusion_functions::{string, unicode}; use datafusion_sql::{ - parser::{DFParser, Statement as DFStatement}, + parser::DFParser, planner::{NullOrdering, ParserOptions, SqlToRel}, }; @@ -46,12 +46,7 @@ use datafusion_functions_nested::make_array::make_array_udf; use datafusion_functions_window::{rank::rank_udwf, row_number::row_number_udwf}; use insta::{allow_duplicates, assert_snapshot}; use rstest::rstest; -use sqlparser::ast::{Expr as SQLExpr, FunctionArg, Statement}; -use sqlparser::dialect::{ - AnsiDialect, BigQueryDialect, ClickHouseDialect, DatabricksDialect, Dialect, - DuckDbDialect, GenericDialect, HiveDialect, MsSqlDialect, MySqlDialect, - PostgreSqlDialect, RedshiftSqlDialect, SQLiteDialect, SnowflakeDialect, -}; +use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; mod cases; mod common; @@ -3954,79 +3949,6 @@ fn test_double_quoted_literal_string() { assert!(logical_plan("SELECT \"1\"").is_err()); } -#[test] -fn test_named_arguments_with_dialects() { - let sql = "SELECT my_func(arg1 => 'value1')"; - - // Returns None if the dialect doesn't support the => operator - let extract_first_arg = |dialect: &dyn Dialect| -> Option { - let mut statements = DFParser::parse_sql_with_dialect(sql, dialect).ok()?; - - let statement = statements.pop_front().unwrap(); - if let DFStatement::Statement(stmt) = statement { - if let Statement::Query(query) = stmt.as_ref() { - if let sqlparser::ast::SetExpr::Select(select) = query.body.as_ref() { - let projection = &select.projection[0]; - if let sqlparser::ast::SelectItem::UnnamedExpr(SQLExpr::Function( - func, - )) = projection - { - if let sqlparser::ast::FunctionArguments::List(arg_list) = - &func.args - { - return 
Some(arg_list.args[0].clone()); - } - } - } - } - } - panic!("Failed to extract function argument"); - }; - - let arg = extract_first_arg(&AnsiDialect {}); - assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); - - let arg = extract_first_arg(&BigQueryDialect {}); - assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); - - let arg = extract_first_arg(&ClickHouseDialect {}); - assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); - - let arg = extract_first_arg(&DatabricksDialect {}); - assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); - - let arg = extract_first_arg(&DuckDbDialect {}); - assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); - - let arg = extract_first_arg(&GenericDialect {}); - assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); - - let arg = extract_first_arg(&HiveDialect {}); - assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); - - let arg = extract_first_arg(&MsSqlDialect {}); - assert!(arg.is_none()); - - let arg = extract_first_arg(&MySqlDialect {}); - assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); - - let arg = extract_first_arg(&PostgreSqlDialect {}); - assert!(matches!( - arg.as_ref(), - Some(FunctionArg::ExprNamed { name, .. }) - if matches!(name, SQLExpr::Identifier(ident) if ident.value == "arg1") - )); - - let arg = extract_first_arg(&RedshiftSqlDialect {}); - assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); - - let arg = extract_first_arg(&SQLiteDialect {}); - assert!(matches!(arg, Some(FunctionArg::Named { name, .. }) if name.value == "arg1")); - - let arg = extract_first_arg(&SnowflakeDialect {}); - assert!(matches!(arg, Some(FunctionArg::Named { name, .. 
}) if name.value == "arg1")); -} - #[test] fn test_constant_expr_eq_join() { let sql = "SELECT id, order_id \ diff --git a/datafusion/sqllogictest/test_files/named_arguments.slt b/datafusion/sqllogictest/test_files/named_arguments.slt index 9737f5b70d6eb..c93da7e7a8f9e 100644 --- a/datafusion/sqllogictest/test_files/named_arguments.slt +++ b/datafusion/sqllogictest/test_files/named_arguments.slt @@ -133,3 +133,7 @@ set datafusion.sql_parser.dialect = 'MsSQL'; # Error: MsSQL dialect does not support => operator query error DataFusion error: SQL error: ParserError\("Expected: \), found: => at Line: 1, Column: 19"\) SELECT substr(str => 'hello world', start_pos => 7, length => 5); + +# Reset to default dialect +statement ok +set datafusion.sql_parser.dialect = 'Generic';