diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs
index 5fd4518e2e57..38eef077c5af 100644
--- a/datafusion/expr-common/src/signature.rs
+++ b/datafusion/expr-common/src/signature.rs
@@ -22,9 +22,9 @@ use std::hash::Hash;
 
 use crate::type_coercion::aggregates::NUMERICS;
 use arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
-use datafusion_common::internal_err;
 use datafusion_common::types::{LogicalType, LogicalTypeRef, NativeType};
 use datafusion_common::utils::ListCoercion;
+use datafusion_common::{internal_err, plan_err, Result};
 use indexmap::IndexSet;
 use itertools::Itertools;
 
@@ -84,6 +84,15 @@ pub enum Volatility {
     Volatile,
 }
 
+/// Represents the arity (number of arguments) of a function signature
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum Arity {
+    /// Fixed number of arguments
+    Fixed(usize),
+    /// Variable number of arguments (e.g., Variadic, VariadicAny, UserDefined)
+    Variable,
+}
+
 /// The types of arguments for which a function has implementations.
 ///
 /// [`TypeSignature`] **DOES NOT** define the types that a user query could call the
@@ -245,6 +254,69 @@ impl TypeSignature {
     pub fn is_one_of(&self) -> bool {
         matches!(self, TypeSignature::OneOf(_))
     }
+
+    /// Returns the arity (expected number of arguments) for this type signature.
+    ///
+    /// * Signatures with a specific argument count return [`Arity::Fixed`].
+    /// * `Variadic`, `VariadicAny` and `UserDefined` return [`Arity::Variable`].
+    /// * `OneOf` returns [`Arity::Variable`] if any alternative is variable or the
+    ///   list is empty; otherwise the largest arity among its alternatives.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # use datafusion_expr_common::signature::{TypeSignature, Arity};
+    /// # use arrow::datatypes::DataType;
+    /// // Exact signature has fixed arity
+    /// let sig = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]);
+    /// assert_eq!(sig.arity(), Arity::Fixed(2));
+    ///
+    /// // Variadic signature has variable arity
+    /// let sig = TypeSignature::VariadicAny;
+    /// assert_eq!(sig.arity(), Arity::Variable);
+    ///
+    /// // OneOf reports the largest arity among its fixed alternatives
+    /// let sig = TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(3)]);
+    /// assert_eq!(sig.arity(), Arity::Fixed(3));
+    /// ```
+    pub fn arity(&self) -> Arity {
+        match self {
+            TypeSignature::Exact(types) => Arity::Fixed(types.len()),
+            TypeSignature::Coercible(types) => Arity::Fixed(types.len()),
+            TypeSignature::Uniform(count, _)
+            | TypeSignature::Numeric(count)
+            | TypeSignature::String(count)
+            | TypeSignature::Comparable(count)
+            | TypeSignature::Any(count) => Arity::Fixed(*count),
+            TypeSignature::Nullary => Arity::Fixed(0),
+            TypeSignature::ArraySignature(array_sig) => match array_sig {
+                ArrayFunctionSignature::Array { arguments, .. } => {
+                    Arity::Fixed(arguments.len())
+                }
+                ArrayFunctionSignature::RecursiveArray
+                | ArrayFunctionSignature::MapArray => Arity::Fixed(1),
+            },
+            TypeSignature::OneOf(variants) => {
+                // Visit each alternative exactly once: calling `arity()` twice per
+                // alternative would double the work at every level of nesting.
+                let mut max_fixed: Option<usize> = None;
+                for variant in variants {
+                    match variant.arity() {
+                        Arity::Variable => return Arity::Variable,
+                        Arity::Fixed(n) => {
+                            max_fixed = Some(max_fixed.map_or(n, |m| m.max(n)));
+                        }
+                    }
+                }
+                // An empty `OneOf` has no fixed alternative to report.
+                max_fixed.map_or(Arity::Variable, Arity::Fixed)
+            }
+            // No fixed argument count is declared for these signatures.
+            TypeSignature::Variadic(_)
+            | TypeSignature::VariadicAny
+            | TypeSignature::UserDefined => Arity::Variable,
+        }
+    }
 }
 
 /// Represents the class of types that can be used in a function signature.
@@ -336,7 +408,7 @@ impl TypeSignatureClass {
         &self,
         native_type: &NativeType,
         origin_type: &DataType,
-    ) -> datafusion_common::Result<DataType> {
+    ) -> Result<DataType> {
         match self {
             TypeSignatureClass::Native(logical_type) => {
                 logical_type.native().default_cast_for(origin_type)
@@ -486,6 +558,133 @@ impl TypeSignature {
         }
     }
 
+    /// Return string representation of the function signature with parameter names.
+    ///
+    /// This method is similar to [`Self::to_string_repr`], but renders each argument
+    /// as `name: type` when a parameter name is available for its position. This is
+    /// useful for generating more helpful error messages.
+    ///
+    /// # Arguments
+    /// * `parameter_names` - Optional slice of parameter names in positional order.
+    ///   `None` produces the type-only representation.
+    ///
+    /// # Name / argument count mismatches
+    /// * Surplus names are ignored, so a name list sized for the widest alternative
+    ///   of a `OneOf` can be shared by every alternative.
+    /// * Arguments beyond the end of the name list are rendered by type alone
+    ///   instead of being dropped from the output.
+    /// * `Nullary`, `Variadic` and `VariadicAny` ignore the names entirely, and
+    ///   `UserDefined` lists only the names.
+    ///
+    /// # Examples
+    /// ```
+    /// # use datafusion_expr_common::signature::TypeSignature;
+    /// # use arrow::datatypes::DataType;
+    /// let sig = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]);
+    ///
+    /// // Without names: shows types only
+    /// assert_eq!(sig.to_string_repr_with_names(None), vec!["Int32, Utf8"]);
+    ///
+    /// // With names: shows parameter names with types
+    /// assert_eq!(
+    ///     sig.to_string_repr_with_names(Some(&["id".to_string(), "name".to_string()])),
+    ///     vec!["id: Int32, name: Utf8"]
+    /// );
+    /// ```
+    pub fn to_string_repr_with_names(
+        &self,
+        parameter_names: Option<&[String]>,
+    ) -> Vec<String> {
+        let Some(names) = parameter_names else {
+            // No names supplied: type-only rendering; other variants defer to `to_string_repr`.
+            return match self {
+                TypeSignature::Exact(types) => vec![Self::join_types(types, ", ")],
+                TypeSignature::Coercible(coercions) => {
+                    vec![Self::join_types(coercions, ", ")]
+                }
+                TypeSignature::Any(count) => {
+                    vec![std::iter::repeat_n("Any", *count)
+                        .collect::<Vec<_>>()
+                        .join(", ")]
+                }
+                TypeSignature::OneOf(sigs) => sigs
+                    .iter()
+                    .flat_map(|s| s.to_string_repr_with_names(None))
+                    .collect(),
+                _ => self.to_string_repr(),
+            };
+        };
+
+        match self {
+            TypeSignature::Exact(types) => {
+                vec![Self::annotate_with_names(names, types)]
+            }
+            TypeSignature::Coercible(coercions) => {
+                vec![Self::annotate_with_names(names, coercions)]
+            }
+            TypeSignature::Uniform(count, types) => {
+                let type_str = Self::join_types(types, "/");
+                let labels = std::iter::repeat_n(type_str.as_str(), *count);
+                vec![Self::annotate_with_names(names, labels)]
+            }
+            TypeSignature::Any(count) => {
+                let labels = std::iter::repeat_n("Any", *count);
+                vec![Self::annotate_with_names(names, labels)]
+            }
+            TypeSignature::Comparable(count) => {
+                let labels = std::iter::repeat_n("Comparable", *count);
+                vec![Self::annotate_with_names(names, labels)]
+            }
+            TypeSignature::Numeric(count) => {
+                let labels = std::iter::repeat_n("Numeric", *count);
+                vec![Self::annotate_with_names(names, labels)]
+            }
+            TypeSignature::String(count) => {
+                let labels = std::iter::repeat_n("String", *count);
+                vec![Self::annotate_with_names(names, labels)]
+            }
+            TypeSignature::ArraySignature(array_sig) => match array_sig {
+                ArrayFunctionSignature::Array { arguments, .. } => {
+                    vec![Self::annotate_with_names(names, arguments)]
+                }
+                ArrayFunctionSignature::RecursiveArray => {
+                    vec![Self::annotate_with_names(names, ["recursive_array"])]
+                }
+                ArrayFunctionSignature::MapArray => {
+                    vec![Self::annotate_with_names(names, ["map_array"])]
+                }
+            },
+            TypeSignature::OneOf(sigs) => sigs
+                .iter()
+                .flat_map(|s| s.to_string_repr_with_names(parameter_names))
+                .collect(),
+            // No per-argument type information exists: list the names alone.
+            TypeSignature::UserDefined => vec![names.join(", ")],
+            // Zero-argument and variable-arity signatures cannot use names.
+            TypeSignature::Nullary
+            | TypeSignature::Variadic(_)
+            | TypeSignature::VariadicAny => self.to_string_repr(),
+        }
+    }
+
+    /// Joins `labels` with `", "`, prefixing each one with the parameter name
+    /// at the same position (`name: label`). Labels that have no matching
+    /// name are emitted unannotated; names without a label are ignored.
+    fn annotate_with_names<T: std::fmt::Display>(
+        names: &[String],
+        labels: impl IntoIterator<Item = T>,
+    ) -> String {
+        labels
+            .into_iter()
+            .enumerate()
+            .map(|(idx, label)| match names.get(idx) {
+                Some(name) => format!("{name}: {label}"),
+                None => label.to_string(),
+            })
+            .collect::<Vec<_>>()
+            .join(", ")
+    }
+
     /// Helper function to join types with specified delimiter.
     pub fn join_types<T: Display>(types: &[T], delimiter: &str) -> String {
         types
@@ -804,6 +1003,13 @@ pub struct Signature {
     pub type_signature: TypeSignature,
     /// The volatility of the function. See [Volatility] for more information.
     pub volatility: Volatility,
+    /// Optional parameter names for the function arguments.
+    ///
+    /// If provided, enables named argument notation for function calls (e.g., `func(a => 1, b => 2)`).
+    /// The length must match the number of arguments defined by `type_signature`.
+    ///
+    /// Defaults to `None`, meaning only positional arguments are supported.
+    pub parameter_names: Option<Vec<String>>,
 }
 
 impl Signature {
@@ -812,6 +1018,7 @@ impl Signature {
         Signature {
             type_signature,
             volatility,
+            parameter_names: None,
         }
     }
     /// An arbitrary number of arguments with the same type, from those listed in `common_types`.
@@ -819,6 +1067,7 @@ impl Signature { Self { type_signature: TypeSignature::Variadic(common_types), volatility, + parameter_names: None, } } /// User-defined coercion rules for the function. @@ -826,6 +1075,7 @@ impl Signature { Self { type_signature: TypeSignature::UserDefined, volatility, + parameter_names: None, } } @@ -834,6 +1084,7 @@ impl Signature { Self { type_signature: TypeSignature::Numeric(arg_count), volatility, + parameter_names: None, } } @@ -842,6 +1093,7 @@ impl Signature { Self { type_signature: TypeSignature::String(arg_count), volatility, + parameter_names: None, } } @@ -850,6 +1102,7 @@ impl Signature { Self { type_signature: TypeSignature::VariadicAny, volatility, + parameter_names: None, } } /// A fixed number of arguments of the same type, from those listed in `valid_types`. @@ -861,6 +1114,7 @@ impl Signature { Self { type_signature: TypeSignature::Uniform(arg_count, valid_types), volatility, + parameter_names: None, } } /// Exactly matches the types in `exact_types`, in order. 
@@ -868,6 +1122,7 @@ impl Signature { Signature { type_signature: TypeSignature::Exact(exact_types), volatility, + parameter_names: None, } } @@ -876,6 +1131,7 @@ impl Signature { Self { type_signature: TypeSignature::Coercible(target_types), volatility, + parameter_names: None, } } @@ -884,6 +1140,7 @@ impl Signature { Self { type_signature: TypeSignature::Comparable(arg_count), volatility, + parameter_names: None, } } @@ -891,6 +1148,7 @@ impl Signature { Signature { type_signature: TypeSignature::Nullary, volatility, + parameter_names: None, } } @@ -899,6 +1157,7 @@ impl Signature { Signature { type_signature: TypeSignature::Any(arg_count), volatility, + parameter_names: None, } } @@ -907,6 +1166,7 @@ impl Signature { Signature { type_signature: TypeSignature::OneOf(type_signatures), volatility, + parameter_names: None, } } @@ -923,6 +1183,7 @@ impl Signature { }, ), volatility, + parameter_names: None, } } @@ -939,6 +1200,7 @@ impl Signature { }, ), volatility, + parameter_names: None, } } @@ -956,6 +1218,7 @@ impl Signature { }, ), volatility, + parameter_names: None, } } @@ -980,6 +1243,7 @@ impl Signature { }), ]), volatility, + parameter_names: None, } } @@ -996,6 +1260,7 @@ impl Signature { }, ), volatility, + parameter_names: None, } } @@ -1003,13 +1268,72 @@ impl Signature { pub fn array(volatility: Volatility) -> Self { Signature::arrays(1, Some(ListCoercion::FixedSizedListToList), volatility) } + + /// Add parameter names to this signature, enabling named argument notation. + /// + /// # Example + /// ``` + /// # use datafusion_expr_common::signature::{Signature, Volatility}; + /// # use arrow::datatypes::DataType; + /// let sig = Signature::exact(vec![DataType::Int32, DataType::Utf8], Volatility::Immutable) + /// .with_parameter_names(vec!["count".to_string(), "name".to_string()]); + /// ``` + /// + /// # Errors + /// Returns an error if the number of parameter names doesn't match the signature's arity. 
+ /// For signatures with variable arity (e.g., `Variadic`, `VariadicAny`), parameter names + /// cannot be specified. + pub fn with_parameter_names(mut self, names: Vec>) -> Result { + let names = names.into_iter().map(Into::into).collect::>(); + // Validate that the number of names matches the signature + self.validate_parameter_names(&names)?; + self.parameter_names = Some(names); + Ok(self) + } + + /// Validate that parameter names are compatible with this signature + fn validate_parameter_names(&self, names: &[String]) -> Result<()> { + match self.type_signature.arity() { + Arity::Fixed(expected) => { + if names.len() != expected { + return plan_err!( + "Parameter names count ({}) does not match signature arity ({})", + names.len(), + expected + ); + } + } + Arity::Variable => { + // For UserDefined signatures, allow parameter names + // The function implementer is responsible for validating the names match the actual arguments + if !matches!(self.type_signature, TypeSignature::UserDefined) { + return plan_err!( + "Cannot specify parameter names for variable arity signature: {:?}", + self.type_signature + ); + } + } + } + + let mut seen = std::collections::HashSet::new(); + for name in names { + if !seen.insert(name) { + return plan_err!("Duplicate parameter name: '{}'", name); + } + } + + Ok(()) + } } #[cfg(test)] mod tests { - use datafusion_common::types::{logical_int64, logical_string}; + use datafusion_common::types::{logical_int32, logical_int64, logical_string}; use super::*; + use crate::signature::{ + ArrayFunctionArgument, ArrayFunctionSignature, Coercion, TypeSignatureClass, + }; #[test] fn supports_zero_argument_tests() { @@ -1167,4 +1491,430 @@ mod tests { ] ); } + + #[test] + fn test_signature_with_parameter_names() { + let sig = Signature::exact( + vec![DataType::Int32, DataType::Utf8], + Volatility::Immutable, + ) + .with_parameter_names(vec!["count".to_string(), "name".to_string()]) + .unwrap(); + + assert_eq!( + sig.parameter_names, + 
Some(vec!["count".to_string(), "name".to_string()]) + ); + assert_eq!( + sig.type_signature, + TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]) + ); + } + + #[test] + fn test_signature_parameter_names_wrong_count() { + let result = Signature::exact( + vec![DataType::Int32, DataType::Utf8], + Volatility::Immutable, + ) + .with_parameter_names(vec!["count".to_string()]); // Only 1 name for 2 args + + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("does not match signature arity")); + } + + #[test] + fn test_signature_parameter_names_duplicate() { + let result = Signature::exact( + vec![DataType::Int32, DataType::Int32], + Volatility::Immutable, + ) + .with_parameter_names(vec!["count".to_string(), "count".to_string()]); + + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Duplicate parameter name")); + } + + #[test] + fn test_signature_parameter_names_variadic() { + let result = Signature::variadic(vec![DataType::Int32], Volatility::Immutable) + .with_parameter_names(vec!["arg".to_string()]); + + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("variable arity signature")); + } + + #[test] + fn test_signature_without_parameter_names() { + let sig = Signature::exact( + vec![DataType::Int32, DataType::Utf8], + Volatility::Immutable, + ); + + assert_eq!(sig.parameter_names, None); + } + + #[test] + fn test_signature_uniform_with_parameter_names() { + let sig = Signature::uniform(3, vec![DataType::Float64], Volatility::Immutable) + .with_parameter_names(vec!["x".to_string(), "y".to_string(), "z".to_string()]) + .unwrap(); + + assert_eq!( + sig.parameter_names, + Some(vec!["x".to_string(), "y".to_string(), "z".to_string()]) + ); + } + + #[test] + fn test_signature_numeric_with_parameter_names() { + let sig = Signature::numeric(2, Volatility::Immutable) + .with_parameter_names(vec!["a".to_string(), "b".to_string()]) + .unwrap(); + + assert_eq!( + 
sig.parameter_names, + Some(vec!["a".to_string(), "b".to_string()]) + ); + } + + #[test] + fn test_signature_nullary_with_empty_names() { + let sig = Signature::nullary(Volatility::Immutable) + .with_parameter_names(Vec::::new()) + .unwrap(); + + assert_eq!(sig.parameter_names, Some(vec![])); + } + + #[test] + fn test_to_string_repr_with_names_exact() { + let sig = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]); + + assert_eq!(sig.to_string_repr_with_names(None), vec!["Int32, Utf8"]); + + let names = vec!["id".to_string(), "name".to_string()]; + assert_eq!( + sig.to_string_repr_with_names(Some(&names)), + vec!["id: Int32, name: Utf8"] + ); + } + + #[test] + fn test_to_string_repr_with_names_any() { + let sig = TypeSignature::Any(3); + + assert_eq!(sig.to_string_repr_with_names(None), vec!["Any, Any, Any"]); + + let names = vec!["x".to_string(), "y".to_string(), "z".to_string()]; + assert_eq!( + sig.to_string_repr_with_names(Some(&names)), + vec!["x: Any, y: Any, z: Any"] + ); + } + + #[test] + fn test_to_string_repr_with_names_one_of() { + let sig = + TypeSignature::OneOf(vec![TypeSignature::Any(2), TypeSignature::Any(3)]); + + assert_eq!( + sig.to_string_repr_with_names(None), + vec!["Any, Any", "Any, Any, Any"] + ); + + let names = vec![ + "str".to_string(), + "start_pos".to_string(), + "length".to_string(), + ]; + assert_eq!( + sig.to_string_repr_with_names(Some(&names)), + vec![ + "str: Any, start_pos: Any", + "str: Any, start_pos: Any, length: Any" + ] + ); + } + + #[test] + fn test_to_string_repr_with_names_partial() { + // This simulates providing max arity names for a OneOf signature + let sig = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]); + + // Provide 3 names for 2-parameter signature (extra name is ignored via zip) + let names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + assert_eq!( + sig.to_string_repr_with_names(Some(&names)), + vec!["a: Int32, b: Utf8"] + ); + } + + #[test] + fn 
test_to_string_repr_with_names_uniform() { + let sig = TypeSignature::Uniform(2, vec![DataType::Float64]); + + assert_eq!( + sig.to_string_repr_with_names(None), + vec!["Float64, Float64"] + ); + + let names = vec!["x".to_string(), "y".to_string()]; + assert_eq!( + sig.to_string_repr_with_names(Some(&names)), + vec!["x: Float64, y: Float64"] + ); + } + + #[test] + fn test_to_string_repr_with_names_coercible() { + let sig = TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_int32())), + Coercion::new_exact(TypeSignatureClass::Native(logical_int32())), + ]); + + let names = vec!["a".to_string(), "b".to_string()]; + let result = sig.to_string_repr_with_names(Some(&names)); + // Check that it contains the parameter names with type annotations + assert_eq!(result.len(), 1); + assert!(result[0].starts_with("a: ")); + assert!(result[0].contains(", b: ")); + } + + #[test] + fn test_to_string_repr_with_names_comparable_numeric_string() { + let comparable = TypeSignature::Comparable(3); + let numeric = TypeSignature::Numeric(2); + let string_sig = TypeSignature::String(2); + + let names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + + // All should show parameter names with type annotations + assert_eq!( + comparable.to_string_repr_with_names(Some(&names)), + vec!["a: Comparable, b: Comparable, c: Comparable"] + ); + assert_eq!( + numeric.to_string_repr_with_names(Some(&names)), + vec!["a: Numeric, b: Numeric"] + ); + assert_eq!( + string_sig.to_string_repr_with_names(Some(&names)), + vec!["a: String, b: String"] + ); + } + + #[test] + fn test_to_string_repr_with_names_variadic_fallback() { + let variadic = TypeSignature::Variadic(vec![DataType::Utf8, DataType::LargeUtf8]); + let names = vec!["x".to_string()]; + assert_eq!( + variadic.to_string_repr_with_names(Some(&names)), + variadic.to_string_repr() + ); + + let variadic_any = TypeSignature::VariadicAny; + assert_eq!( + 
variadic_any.to_string_repr_with_names(Some(&names)), + variadic_any.to_string_repr() + ); + + // UserDefined now shows parameter names when available + let user_defined = TypeSignature::UserDefined; + assert_eq!( + user_defined.to_string_repr_with_names(Some(&names)), + vec!["x"] + ); + assert_eq!( + user_defined.to_string_repr_with_names(None), + user_defined.to_string_repr() + ); + } + + #[test] + fn test_to_string_repr_with_names_nullary() { + let sig = TypeSignature::Nullary; + let names = vec!["x".to_string()]; + + // Should return empty representation, names don't apply + assert_eq!( + sig.to_string_repr_with_names(Some(&names)), + vec!["NullAry()"] + ); + assert_eq!(sig.to_string_repr_with_names(None), vec!["NullAry()"]); + } + + #[test] + fn test_to_string_repr_with_names_array_signature() { + let sig = TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Index, + ArrayFunctionArgument::Element, + ], + array_coercion: None, + }); + + assert_eq!( + sig.to_string_repr_with_names(None), + vec!["array, index, element"] + ); + + let names = vec!["arr".to_string(), "idx".to_string(), "val".to_string()]; + assert_eq!( + sig.to_string_repr_with_names(Some(&names)), + vec!["arr: array, idx: index, val: element"] + ); + + let recursive = + TypeSignature::ArraySignature(ArrayFunctionSignature::RecursiveArray); + let names = vec!["array".to_string()]; + assert_eq!( + recursive.to_string_repr_with_names(Some(&names)), + vec!["array: recursive_array"] + ); + + // Test MapArray (1 argument) + let map_array = TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray); + let names = vec!["map".to_string()]; + assert_eq!( + map_array.to_string_repr_with_names(Some(&names)), + vec!["map: map_array"] + ); + } + + #[test] + fn test_type_signature_arity_exact() { + let sig = TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]); + assert_eq!(sig.arity(), Arity::Fixed(2)); + + let sig 
= TypeSignature::Exact(vec![]); + assert_eq!(sig.arity(), Arity::Fixed(0)); + } + + #[test] + fn test_type_signature_arity_uniform() { + let sig = TypeSignature::Uniform(3, vec![DataType::Float64]); + assert_eq!(sig.arity(), Arity::Fixed(3)); + + let sig = TypeSignature::Uniform(1, vec![DataType::Int32]); + assert_eq!(sig.arity(), Arity::Fixed(1)); + } + + #[test] + fn test_type_signature_arity_numeric() { + let sig = TypeSignature::Numeric(2); + assert_eq!(sig.arity(), Arity::Fixed(2)); + } + + #[test] + fn test_type_signature_arity_string() { + let sig = TypeSignature::String(3); + assert_eq!(sig.arity(), Arity::Fixed(3)); + } + + #[test] + fn test_type_signature_arity_comparable() { + let sig = TypeSignature::Comparable(2); + assert_eq!(sig.arity(), Arity::Fixed(2)); + } + + #[test] + fn test_type_signature_arity_any() { + let sig = TypeSignature::Any(4); + assert_eq!(sig.arity(), Arity::Fixed(4)); + } + + #[test] + fn test_type_signature_arity_coercible() { + let sig = TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_int32())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ]); + assert_eq!(sig.arity(), Arity::Fixed(2)); + } + + #[test] + fn test_type_signature_arity_nullary() { + let sig = TypeSignature::Nullary; + assert_eq!(sig.arity(), Arity::Fixed(0)); + } + + #[test] + fn test_type_signature_arity_array_signature() { + // Test Array variant with 2 arguments + let sig = TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array, ArrayFunctionArgument::Index], + array_coercion: None, + }); + assert_eq!(sig.arity(), Arity::Fixed(2)); + + // Test Array variant with 3 arguments + let sig = TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Index, + ], + array_coercion: None, + }); + assert_eq!(sig.arity(), Arity::Fixed(3)); + + // 
Test RecursiveArray variant + let sig = TypeSignature::ArraySignature(ArrayFunctionSignature::RecursiveArray); + assert_eq!(sig.arity(), Arity::Fixed(1)); + + // Test MapArray variant + let sig = TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray); + assert_eq!(sig.arity(), Arity::Fixed(1)); + } + + #[test] + fn test_type_signature_arity_one_of_fixed() { + // OneOf with all fixed arity variants should return max arity + let sig = TypeSignature::OneOf(vec![ + TypeSignature::Exact(vec![DataType::Int32]), + TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]), + TypeSignature::Exact(vec![ + DataType::Int32, + DataType::Utf8, + DataType::Float64, + ]), + ]); + assert_eq!(sig.arity(), Arity::Fixed(3)); + } + + #[test] + fn test_type_signature_arity_one_of_variable() { + // OneOf with variable arity variant should return Variable + let sig = TypeSignature::OneOf(vec![ + TypeSignature::Exact(vec![DataType::Int32]), + TypeSignature::VariadicAny, + ]); + assert_eq!(sig.arity(), Arity::Variable); + } + + #[test] + fn test_type_signature_arity_variadic() { + let sig = TypeSignature::Variadic(vec![DataType::Int32]); + assert_eq!(sig.arity(), Arity::Variable); + + let sig = TypeSignature::VariadicAny; + assert_eq!(sig.arity(), Arity::Variable); + } + + #[test] + fn test_type_signature_arity_user_defined() { + let sig = TypeSignature::UserDefined; + assert_eq!(sig.arity(), Arity::Variable); + } } diff --git a/datafusion/expr/src/arguments.rs b/datafusion/expr/src/arguments.rs new file mode 100644 index 000000000000..5653993db98f --- /dev/null +++ b/datafusion/expr/src/arguments.rs @@ -0,0 +1,285 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Argument resolution logic for named function parameters + +use crate::Expr; +use datafusion_common::{plan_err, Result}; +use std::collections::HashMap; + +/// Resolves function arguments, handling named and positional notation. +/// +/// This function validates and reorders arguments to match the function's parameter names +/// when named arguments are used. +/// +/// # Rules +/// - All positional arguments must come before named arguments +/// - Named arguments can be in any order after positional arguments +/// - Parameter names follow SQL identifier rules: unquoted names are case-insensitive +/// (normalized to lowercase), quoted names are case-sensitive +/// - No duplicate parameter names allowed +/// +/// # Arguments +/// * `param_names` - The function's parameter names in order +/// * `args` - The argument expressions +/// * `arg_names` - Optional parameter name for each argument +/// +/// # Returns +/// A vector of expressions in the correct order matching the parameter names +/// +/// # Examples +/// ```text +/// Given parameters ["a", "b", "c"] +/// And call: func(10, c => 30, b => 20) +/// Returns: [Expr(10), Expr(20), Expr(30)] +/// ``` +pub fn resolve_function_arguments( + param_names: &[String], + args: Vec, + arg_names: Vec>, +) -> Result> { + if args.len() != arg_names.len() { + return plan_err!( + "Internal error: args length ({}) != arg_names length ({})", + 
args.len(), + arg_names.len() + ); + } + + // Check if all arguments are positional (fast path) + if arg_names.iter().all(|name| name.is_none()) { + return Ok(args); + } + + validate_argument_order(&arg_names)?; + + reorder_named_arguments(param_names, args, arg_names) +} + +/// Validates that positional arguments come before named arguments +fn validate_argument_order(arg_names: &[Option]) -> Result<()> { + let mut seen_named = false; + for (i, arg_name) in arg_names.iter().enumerate() { + match arg_name { + Some(_) => seen_named = true, + None if seen_named => { + return plan_err!( + "Positional argument at position {} follows named argument. \ + All positional arguments must come before named arguments.", + i + ); + } + None => {} + } + } + Ok(()) +} + +/// Reorders arguments based on named parameters to match signature order +fn reorder_named_arguments( + param_names: &[String], + args: Vec, + arg_names: Vec>, +) -> Result> { + // Build HashMap for O(1) parameter name lookups + let param_index_map: HashMap<&str, usize> = param_names + .iter() + .enumerate() + .map(|(idx, name)| (name.as_str(), idx)) + .collect(); + + let positional_count = arg_names.iter().filter(|n| n.is_none()).count(); + + // Capture args length before consuming the vector + let args_len = args.len(); + + let expected_arg_count = param_names.len(); + + if positional_count > expected_arg_count { + return plan_err!( + "Too many positional arguments: expected at most {}, got {}", + expected_arg_count, + positional_count + ); + } + + let mut result: Vec> = vec![None; expected_arg_count]; + + for (i, (arg, arg_name)) in args.into_iter().zip(arg_names).enumerate() { + if let Some(name) = arg_name { + // Named argument - O(1) lookup in HashMap + let param_index = + param_index_map.get(name.as_str()).copied().ok_or_else(|| { + datafusion_common::plan_datafusion_err!( + "Unknown parameter name '{}'. 
Valid parameters are: [{}]", + name, + param_names.join(", ") + ) + })?; + + if result[param_index].is_some() { + return plan_err!("Parameter '{}' specified multiple times", name); + } + + result[param_index] = Some(arg); + } else { + result[i] = Some(arg); + } + } + + // Only require parameters up to the number of arguments provided (supports optional parameters) + let required_count = args_len; + for i in 0..required_count { + if result[i].is_none() { + return plan_err!("Missing required parameter '{}'", param_names[i]); + } + } + + // Return only the assigned parameters (handles optional trailing parameters) + Ok(result.into_iter().take(required_count).flatten().collect()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::lit; + + #[test] + fn test_all_positional() { + let param_names = vec!["a".to_string(), "b".to_string()]; + + let args = vec![lit(1), lit("hello")]; + let arg_names = vec![None, None]; + + let result = + resolve_function_arguments(¶m_names, args.clone(), arg_names).unwrap(); + assert_eq!(result.len(), 2); + } + + #[test] + fn test_all_named() { + let param_names = vec!["a".to_string(), "b".to_string()]; + + let args = vec![lit(1), lit("hello")]; + let arg_names = vec![Some("a".to_string()), Some("b".to_string())]; + + let result = resolve_function_arguments(¶m_names, args, arg_names).unwrap(); + assert_eq!(result.len(), 2); + } + + #[test] + fn test_named_reordering() { + let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + + // Call with: func(c => 3.0, a => 1, b => "hello") + let args = vec![lit(3.0), lit(1), lit("hello")]; + let arg_names = vec![ + Some("c".to_string()), + Some("a".to_string()), + Some("b".to_string()), + ]; + + let result = resolve_function_arguments(¶m_names, args, arg_names).unwrap(); + + // Should be reordered to [a, b, c] = [1, "hello", 3.0] + assert_eq!(result.len(), 3); + assert_eq!(result[0], lit(1)); + assert_eq!(result[1], lit("hello")); + assert_eq!(result[2], lit(3.0)); + } + + 
#[test] + fn test_mixed_positional_and_named() { + let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + + // Call with: func(1, c => 3.0, b => "hello") + let args = vec![lit(1), lit(3.0), lit("hello")]; + let arg_names = vec![None, Some("c".to_string()), Some("b".to_string())]; + + let result = resolve_function_arguments(¶m_names, args, arg_names).unwrap(); + + // Should be reordered to [a, b, c] = [1, "hello", 3.0] + assert_eq!(result.len(), 3); + assert_eq!(result[0], lit(1)); + assert_eq!(result[1], lit("hello")); + assert_eq!(result[2], lit(3.0)); + } + + #[test] + fn test_positional_after_named_error() { + let param_names = vec!["a".to_string(), "b".to_string()]; + + // Call with: func(a => 1, "hello") - ERROR + let args = vec![lit(1), lit("hello")]; + let arg_names = vec![Some("a".to_string()), None]; + + let result = resolve_function_arguments(¶m_names, args, arg_names); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Positional argument")); + } + + #[test] + fn test_unknown_parameter_name() { + let param_names = vec!["a".to_string(), "b".to_string()]; + + // Call with: func(x => 1, b => "hello") - ERROR + let args = vec![lit(1), lit("hello")]; + let arg_names = vec![Some("x".to_string()), Some("b".to_string())]; + + let result = resolve_function_arguments(¶m_names, args, arg_names); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Unknown parameter")); + } + + #[test] + fn test_duplicate_parameter_name() { + let param_names = vec!["a".to_string(), "b".to_string()]; + + // Call with: func(a => 1, a => 2) - ERROR + let args = vec![lit(1), lit(2)]; + let arg_names = vec![Some("a".to_string()), Some("a".to_string())]; + + let result = resolve_function_arguments(¶m_names, args, arg_names); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("specified multiple times")); + } + + #[test] + fn test_missing_required_parameter() 
{ + let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + + // Call with: func(a => 1, c => 3.0) - missing 'b' + let args = vec![lit(1), lit(3.0)]; + let arg_names = vec![Some("a".to_string()), Some("c".to_string())]; + + let result = resolve_function_arguments(&param_names, args, arg_names); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Missing required parameter")); + } +} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 1c9734a89bd3..c0a9c0595224 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -44,6 +44,7 @@ mod udaf; mod udf; mod udwf; +pub mod arguments; pub mod conditional_expressions; pub mod execution_props; pub mod expr; diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index b91db4527b3a..74ba99847f70 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -936,7 +936,7 @@ pub fn generate_signature_error_msg( ) -> String { let candidate_signatures = func_signature .type_signature - .to_string_repr() + .to_string_repr_with_names(func_signature.parameter_names.as_deref()) .iter() .map(|args_str| format!("\t{func_name}({args_str})")) .collect::<Vec<String>>() @@ -1295,6 +1295,7 @@ mod tests { Cast, ExprFunctionExt, WindowFunctionDefinition, }; use arrow::datatypes::{UnionFields, UnionMode}; + use datafusion_expr_common::signature::{TypeSignature, Volatility}; #[test] fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> { @@ -1714,4 +1715,52 @@ mod tests { DataType::List(Arc::new(Field::new("my_union", union_type, true))); assert!(!can_hash(&list_union_type)); } + + #[test] + fn test_generate_signature_error_msg_with_parameter_names() { + let sig = Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]), + TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::Int64, + DataType::Int64, + ]), + ], + Volatility::Immutable, + ) + .with_parameter_names(vec![ + 
"str".to_string(), + "start_pos".to_string(), + "length".to_string(), + ]) + .expect("valid parameter names"); + + // Generate error message with only 1 argument provided + let error_msg = generate_signature_error_msg("substr", sig, &[DataType::Utf8]); + + assert!( + error_msg.contains("str: Utf8, start_pos: Int64"), + "Expected 'str: Utf8, start_pos: Int64' in error message, got: {error_msg}" + ); + assert!( + error_msg.contains("str: Utf8, start_pos: Int64, length: Int64"), + "Expected 'str: Utf8, start_pos: Int64, length: Int64' in error message, got: {error_msg}" + ); + } + + #[test] + fn test_generate_signature_error_msg_without_parameter_names() { + let sig = Signature::one_of( + vec![TypeSignature::Any(2), TypeSignature::Any(3)], + Volatility::Immutable, + ); + + let error_msg = generate_signature_error_msg("my_func", sig, &[DataType::Int32]); + + assert!( + error_msg.contains("Any, Any"), + "Expected 'Any, Any' without parameter names, got: {error_msg}" + ); + } } diff --git a/datafusion/functions-nested/src/replace.rs b/datafusion/functions-nested/src/replace.rs index 59f851a776a1..4314d41419bc 100644 --- a/datafusion/functions-nested/src/replace.rs +++ b/datafusion/functions-nested/src/replace.rs @@ -105,6 +105,7 @@ impl ArrayReplace { }, ), volatility: Volatility::Immutable, + parameter_names: None, }, aliases: vec![String::from("list_replace")], } @@ -186,6 +187,7 @@ impl ArrayReplaceN { }, ), volatility: Volatility::Immutable, + parameter_names: None, }, aliases: vec![String::from("list_replace_n")], } @@ -265,6 +267,7 @@ impl ArrayReplaceAll { }, ), volatility: Volatility::Immutable, + parameter_names: None, }, aliases: vec![String::from("list_replace_all")], } diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 0b35f664532d..46b3cc63d0b6 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -71,7 +71,13 @@ impl Default for SubstrFunc { impl 
SubstrFunc { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::user_defined(Volatility::Immutable) + .with_parameter_names(vec![ + "str".to_string(), + "start_pos".to_string(), + "length".to_string(), + ]) + .expect("valid parameter names"), aliases: vec![String::from("substring")], } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index eabf645a5eaf..cb34bb0f7eb7 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -274,8 +274,28 @@ impl<S: ContextProvider> SqlToRel<'_, S> { } // User-defined function (UDF) should have precedence if let Some(fm) = self.context_provider.get_function_meta(&name) { - let args = self.function_args_to_expr(args, schema, planner_context)?; - let inner = ScalarFunction::new_udf(fm, args); + let (args, arg_names) = + self.function_args_to_expr_with_names(args, schema, planner_context)?; + + let resolved_args = if arg_names.iter().any(|name| name.is_some()) { + if let Some(param_names) = &fm.signature().parameter_names { + datafusion_expr::arguments::resolve_function_arguments( + param_names, + args, + arg_names, + )? 
+ } else { + return plan_err!( + "Function '{}' does not support named arguments", + fm.name() + ); + } + } else { + args + }; + + // After resolution, all arguments are positional + let inner = ScalarFunction::new_udf(fm, resolved_args); if name.eq_ignore_ascii_case(inner.name()) { return Ok(Expr::ScalarFunction(inner)); @@ -624,14 +644,29 @@ impl<S: ContextProvider> SqlToRel<'_, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result<Expr> { + let (expr, _) = + self.sql_fn_arg_to_logical_expr_with_name(sql, schema, planner_context)?; + Ok(expr) + } + + fn sql_fn_arg_to_logical_expr_with_name( + &self, + sql: FunctionArg, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result<(Expr, Option<String>)> { match sql { FunctionArg::Named { - name: _, + name, arg: FunctionArgExpr::Expr(arg), operator: _, - } => self.sql_expr_to_logical_expr(arg, schema, planner_context), + } => { + let expr = self.sql_expr_to_logical_expr(arg, schema, planner_context)?; + let arg_name = crate::utils::normalize_ident(name); + Ok((expr, Some(arg_name))) + } FunctionArg::Named { - name: _, + name, arg: FunctionArgExpr::Wildcard, operator: _, } => { @@ -640,11 +675,12 @@ impl<S: ContextProvider> SqlToRel<'_, S> { qualifier: None, options: Box::new(WildcardOptions::default()), }; - - Ok(expr) + let arg_name = crate::utils::normalize_ident(name); + Ok((expr, Some(arg_name))) } FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) => { - self.sql_expr_to_logical_expr(arg, schema, planner_context) + let expr = self.sql_expr_to_logical_expr(arg, schema, planner_context)?; + Ok((expr, None)) } FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => { #[expect(deprecated)] @@ -652,8 +688,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> { qualifier: None, options: Box::new(WildcardOptions::default()), }; - - Ok(expr) + Ok((expr, None)) } FunctionArg::Unnamed(FunctionArgExpr::QualifiedWildcard(object_name)) => { let qualifier = self.object_name_to_table_reference(object_name)?; @@ -668,8 +703,30 @@ impl<S: ContextProvider> SqlToRel<'_, S> { qualifier: 
qualifier.into(), options: Box::new(WildcardOptions::default()), }; - - Ok(expr) + Ok((expr, None)) + } + // PostgreSQL dialect uses ExprNamed variant with expression for name + FunctionArg::ExprNamed { + name: SQLExpr::Identifier(name), + arg: FunctionArgExpr::Expr(arg), + operator: _, + } => { + let expr = self.sql_expr_to_logical_expr(arg, schema, planner_context)?; + let arg_name = crate::utils::normalize_ident(name); + Ok((expr, Some(arg_name))) + } + FunctionArg::ExprNamed { + name: SQLExpr::Identifier(name), + arg: FunctionArgExpr::Wildcard, + operator: _, + } => { + #[expect(deprecated)] + let expr = Expr::Wildcard { + qualifier: None, + options: Box::new(WildcardOptions::default()), + }; + let arg_name = crate::utils::normalize_ident(name); + Ok((expr, Some(arg_name))) } _ => not_impl_err!("Unsupported qualified wildcard argument: {sql:?}"), } @@ -686,6 +743,24 @@ impl<S: ContextProvider> SqlToRel<'_, S> { .collect::<Result<Vec<Expr>>>() } + pub(super) fn function_args_to_expr_with_names( + &self, + args: Vec<FunctionArg>, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result<(Vec<Expr>, Vec<Option<String>>)> { + let results: Result<Vec<(Expr, Option<String>)>> = args + .into_iter() + .map(|a| { + self.sql_fn_arg_to_logical_expr_with_name(a, schema, planner_context) + }) + .collect(); + + let pairs = results?; + let (exprs, names): (Vec<Expr>, Vec<Option<String>>) = pairs.into_iter().unzip(); + Ok((exprs, names)) + } + pub(crate) fn check_unnest_arg(arg: &Expr, schema: &DFSchema) -> Result<()> { // Check argument type, array types are supported match arg.get_type(schema)? 
{ diff --git a/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs b/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs index 375f06d34b44..4d310711687f 100644 --- a/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs +++ b/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs @@ -76,8 +76,8 @@ impl Postgres { /// /// See https://docs.rs/tokio-postgres/latest/tokio_postgres/config/struct.Config.html#url for format pub async fn connect(relative_path: PathBuf, pb: ProgressBar) -> Result<Self> { - let uri = - std::env::var("PG_URI").map_or(PG_URI.to_string(), std::convert::identity); + let uri = std::env::var("PG_URI") + .map_or_else(|_| PG_URI.to_string(), std::convert::identity); info!("Using postgres connection string: {uri}"); diff --git a/datafusion/sqllogictest/test_files/named_arguments.slt b/datafusion/sqllogictest/test_files/named_arguments.slt new file mode 100644 index 000000000000..c93da7e7a8f9 --- /dev/null +++ b/datafusion/sqllogictest/test_files/named_arguments.slt @@ -0,0 +1,139 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +############# +## Tests for Named Arguments (PostgreSQL-style param => value syntax) +############# + +# Test positional arguments still work (baseline) +query T +SELECT substr('hello world', 7, 5); +---- +world + +# Test named arguments in order +query T +SELECT substr(str => 'hello world', start_pos => 7, length => 5); +---- +world + +# Test named arguments out of order +query T +SELECT substr(length => 5, str => 'hello world', start_pos => 7); +---- +world + +# Test mixed positional and named arguments +query T +SELECT substr('hello world', start_pos => 7, length => 5); +---- +world + +# Test with only 2 parameters (length optional) +query T +SELECT substr(str => 'hello world', start_pos => 7); +---- +world + +# Test all parameters named with substring alias +query T +SELECT substring(str => 'hello', start_pos => 1, length => 3); +---- +hel + +# Error: positional argument after named argument +query error DataFusion error: Error during planning: Positional argument.*follows named argument +SELECT substr(str => 'hello', 1, 3); + +# Error: unknown parameter name +query error DataFusion error: Error during planning: Unknown parameter name 'invalid' +SELECT substr(invalid => 'hello', start_pos => 1, length => 3); + +# Error: duplicate parameter name +query error DataFusion error: Error during planning: Parameter 'str' specified multiple times +SELECT substr(str => 'hello', str => 'world', start_pos => 1); + +# Test case-insensitive parameter names (unquoted identifiers) +query T +SELECT substr(STR => 'hello world', START_POS => 7, LENGTH => 5); +---- +world + +# Test case-insensitive with mixed case +query T +SELECT substr(Str => 'hello world', Start_Pos => 7); +---- +world + +# Error: case-sensitive quoted parameter names don't match +query error DataFusion error: Error during planning: Unknown parameter name 'STR' +SELECT substr("STR" => 'hello world', "start_pos" => 7); + +# Error: wrong number of arguments +# This query provides only 1 argument but substr 
requires 2 or 3 +query error DataFusion error: Error during planning: Execution error: Function 'substr' user-defined coercion failed with "Error during planning: The substr function requires 2 or 3 arguments, but got 1." +SELECT substr(str => 'hello world'); + +############# +## PostgreSQL Dialect Tests (uses ExprNamed variant) +############# + +statement ok +set datafusion.sql_parser.dialect = 'PostgreSQL'; + +# Test named arguments in order +query T +SELECT substr(str => 'hello world', start_pos => 7, length => 5); +---- +world + +# Test named arguments out of order +query T +SELECT substr(length => 5, str => 'hello world', start_pos => 7); +---- +world + +# Test mixed positional and named arguments +query T +SELECT substr('hello world', start_pos => 7, length => 5); +---- +world + +# Test with only 2 parameters (length optional) +query T +SELECT substr(str => 'hello world', start_pos => 7); +---- +world + +# Reset to default dialect +statement ok +set datafusion.sql_parser.dialect = 'Generic'; + +############# +## MsSQL Dialect Tests (does NOT support => operator) +############# + +statement ok +set datafusion.sql_parser.dialect = 'MsSQL'; + +# Error: MsSQL dialect does not support => operator +query error DataFusion error: SQL error: ParserError\("Expected: \), found: => at Line: 1, Column: 19"\) +SELECT substr(str => 'hello world', start_pos => 7, length => 5); + +# Reset to default dialect +statement ok +set datafusion.sql_parser.dialect = 'Generic'; diff --git a/docs/source/library-user-guide/functions/adding-udfs.md b/docs/source/library-user-guide/functions/adding-udfs.md index 2335105882a1..e143fdd7b12e 100644 --- a/docs/source/library-user-guide/functions/adding-udfs.md +++ b/docs/source/library-user-guide/functions/adding-udfs.md @@ -586,6 +586,119 @@ For async UDF implementation details, see [`async_udf.rs`](https://github.com/ap [`process_scalar_func_inputs`]: 
https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.process_scalar_func_inputs.html [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs +## Named Arguments + +DataFusion supports PostgreSQL-style named arguments for scalar functions, allowing you to pass arguments by parameter name: + +```sql +SELECT substr(str => 'hello', start_pos => 2, length => 3); +``` + +Named arguments can be mixed with positional arguments, but positional arguments must come first: + +```sql +SELECT substr('hello', start_pos => 2, length => 3); -- Valid +``` + +### Implementing Functions with Named Arguments + +To support named arguments in your UDF, add parameter names to your function's signature using `.with_parameter_names()`: + +```rust +# use arrow::datatypes::DataType; +# use datafusion_expr::{Signature, Volatility}; +# +# #[derive(Debug)] +# struct MyFunction { +# signature: Signature, +# } +# +impl MyFunction { + fn new() -> Self { + Self { + signature: Signature::uniform( + 2, + vec![DataType::Float64], + Volatility::Immutable + ) + .with_parameter_names(vec![ + "base".to_string(), + "exponent".to_string() + ]) + .expect("valid parameter names"), + } + } +} +``` + +The parameter names should match the order of arguments in your function's signature. DataFusion automatically resolves named arguments to the correct positional order before invoking your function. 
+ +### Example + +```rust +# use std::sync::Arc; +# use std::any::Any; +# use arrow::datatypes::DataType; +# use datafusion_common::Result; +# use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility}; +# use datafusion_expr::ScalarUDFImpl; + +#[derive(Debug, PartialEq, Eq, Hash)] +struct PowerFunction { + signature: Signature, +} + +impl PowerFunction { + fn new() -> Self { + Self { + signature: Signature::uniform( + 2, + vec![DataType::Float64], + Volatility::Immutable + ) + .with_parameter_names(vec![ + "base".to_string(), + "exponent".to_string() + ]) + .expect("valid parameter names"), + } + } +} + +impl ScalarUDFImpl for PowerFunction { + fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "power" } + fn signature(&self) -> &Signature { &self.signature } + + fn return_type(&self, _args: &[DataType]) -> Result<DataType> { + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> { + // Your implementation - arguments are in correct positional order + unimplemented!() + } +} +``` + +Once registered, users can call your function with named arguments: + +```sql +SELECT power(base => 2.0, exponent => 3.0); +SELECT power(2.0, exponent => 3.0); +``` + +### Error Messages + +When a function call fails due to incorrect arguments, DataFusion will show the parameter names in error messages to help users: + +```text +No function matches the given name and argument types substr(Utf8). + Candidate functions: + substr(str: Any, start_pos: Any) + substr(str: Any, start_pos: Any, length: Any) +``` + ## Adding a Window UDF Scalar UDFs are functions that take a row of data and return a single value. Window UDFs are similar, but they also have