diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 57513528be64c..6fc56424299e0 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1059,6 +1059,8 @@ dependencies = [ "datafusion-common", "lazy_static", "sqlparser", + "strum", + "strum_macros", ] [[package]] @@ -2805,6 +2807,9 @@ name = "strum" version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "063e6045c0e62079840579a7e47a355ae92f60eb74daaf156fb1e84ba164e63f" +dependencies = [ + "strum_macros", +] [[package]] name = "strum_macros" diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 160d76deb13c7..52a31db485e8c 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -40,6 +40,8 @@ arrow = { workspace = true } datafusion-common = { path = "../common", version = "25.0.0" } lazy_static = { version = "^1.4.0" } sqlparser = "0.34" +strum = { version = "0.24", features = ["derive"] } +strum_macros = "0.24" [dev-dependencies] ctor = "0.2.0" diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 26b88a71a706f..3911939b4ca6e 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -19,12 +19,16 @@ use crate::Volatility; use datafusion_common::{DataFusionError, Result}; -use lazy_static::lazy_static; +use std::collections::HashMap; use std::fmt; use std::str::FromStr; +use strum::IntoEnumIterator; +use strum_macros::EnumIter; + +use lazy_static::lazy_static; /// Enum of all built-in scalar functions -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter, Copy)] pub enum BuiltinScalarFunction { // math functions /// abs @@ -204,117 +208,25 @@ pub enum BuiltinScalarFunction { } lazy_static! { - /// Mapping between SQL function names to `BuiltinScalarFunction` types. - /// Note that multiple SQL function names can represent the same `BuiltinScalarFunction`. These are treated as aliases. - /// In case of such aliases, the first SQL function name in the vector is used when displaying the function. - static ref NAME_TO_FUNCTION: Vec<(&'static str, BuiltinScalarFunction)> = vec![ - // math functions - ("abs", BuiltinScalarFunction::Abs), - ("acos", BuiltinScalarFunction::Acos), - ("acosh", BuiltinScalarFunction::Acosh), - ("asin", BuiltinScalarFunction::Asin), - ("asinh", BuiltinScalarFunction::Asinh), - ("atan", BuiltinScalarFunction::Atan), - ("atanh", BuiltinScalarFunction::Atanh), - ("atan2", BuiltinScalarFunction::Atan2), - ("cbrt", BuiltinScalarFunction::Cbrt), - ("ceil", BuiltinScalarFunction::Ceil), - ("cos", BuiltinScalarFunction::Cos), - ("cosh", BuiltinScalarFunction::Cosh), - ("degrees", BuiltinScalarFunction::Degrees), - ("exp", BuiltinScalarFunction::Exp), - ("factorial", BuiltinScalarFunction::Factorial), - ("floor", BuiltinScalarFunction::Floor), - ("gcd", BuiltinScalarFunction::Gcd), - ("lcm", BuiltinScalarFunction::Lcm), - ("ln", BuiltinScalarFunction::Ln), - ("log", BuiltinScalarFunction::Log), - ("log10", BuiltinScalarFunction::Log10), - ("log2", BuiltinScalarFunction::Log2), - ("pi", BuiltinScalarFunction::Pi), - ("power", BuiltinScalarFunction::Power), - ("pow", BuiltinScalarFunction::Power), - ("radians", BuiltinScalarFunction::Radians), - ("random", BuiltinScalarFunction::Random), - ("round", BuiltinScalarFunction::Round), - ("signum", BuiltinScalarFunction::Signum), - ("sin", BuiltinScalarFunction::Sin), - ("sinh", BuiltinScalarFunction::Sinh), - ("sqrt", BuiltinScalarFunction::Sqrt), - ("tan", BuiltinScalarFunction::Tan), - ("tanh", BuiltinScalarFunction::Tanh), - ("trunc", BuiltinScalarFunction::Trunc), - - // conditional functions - ("coalesce", BuiltinScalarFunction::Coalesce), - ("nullif", BuiltinScalarFunction::NullIf), + /// Maps the sql function name to `BuiltinScalarFunction` + static ref NAME_TO_FUNCTION: HashMap<&'static str, BuiltinScalarFunction> = { + let mut map: HashMap<&'static str, BuiltinScalarFunction> = HashMap::new(); + BuiltinScalarFunction::iter().for_each(|func| { + let a = aliases(&func); + a.iter().for_each(|a| {map.insert(a, func);}); + }); + map + }; - // string functions - ("ascii", BuiltinScalarFunction::Ascii), - ("bit_length", BuiltinScalarFunction::BitLength), - ("btrim", BuiltinScalarFunction::Btrim), - ("character_length", BuiltinScalarFunction::CharacterLength), - ("char_length", BuiltinScalarFunction::CharacterLength), - ("concat", BuiltinScalarFunction::Concat), - ("concat_ws", BuiltinScalarFunction::ConcatWithSeparator), - ("chr", BuiltinScalarFunction::Chr), - ("initcap", BuiltinScalarFunction::InitCap), - ("left", BuiltinScalarFunction::Left), - ("length", BuiltinScalarFunction::CharacterLength), - ("lower", BuiltinScalarFunction::Lower), - ("lpad", BuiltinScalarFunction::Lpad), - ("ltrim", BuiltinScalarFunction::Ltrim), - ("octet_length", BuiltinScalarFunction::OctetLength), - ("repeat", BuiltinScalarFunction::Repeat), - ("replace", BuiltinScalarFunction::Replace), - ("reverse", BuiltinScalarFunction::Reverse), - ("right", BuiltinScalarFunction::Right), - ("rpad", BuiltinScalarFunction::Rpad), - ("rtrim", BuiltinScalarFunction::Rtrim), - ("split_part", BuiltinScalarFunction::SplitPart), - ("starts_with", BuiltinScalarFunction::StartsWith), - ("strpos", BuiltinScalarFunction::Strpos), - ("substr", BuiltinScalarFunction::Substr), - ("to_hex", BuiltinScalarFunction::ToHex), - ("translate", BuiltinScalarFunction::Translate), - ("trim", BuiltinScalarFunction::Trim), - ("upper", BuiltinScalarFunction::Upper), - ("uuid", BuiltinScalarFunction::Uuid), - - // regex functions - ("regexp_match", BuiltinScalarFunction::RegexpMatch), - ("regexp_replace", BuiltinScalarFunction::RegexpReplace), - - // time/date functions - ("now", BuiltinScalarFunction::Now), - ("current_date", BuiltinScalarFunction::CurrentDate), - ("current_time", BuiltinScalarFunction::CurrentTime), - ("date_bin", BuiltinScalarFunction::DateBin), - ("date_trunc", BuiltinScalarFunction::DateTrunc), - ("datetrunc", BuiltinScalarFunction::DateTrunc), - ("date_part", BuiltinScalarFunction::DatePart), - ("datepart", BuiltinScalarFunction::DatePart), - ("to_timestamp", BuiltinScalarFunction::ToTimestamp), - ("to_timestamp_millis", BuiltinScalarFunction::ToTimestampMillis), - ("to_timestamp_micros", BuiltinScalarFunction::ToTimestampMicros), - ("to_timestamp_seconds", BuiltinScalarFunction::ToTimestampSeconds), - ("from_unixtime", BuiltinScalarFunction::FromUnixtime), - - // hashing functions - ("digest", BuiltinScalarFunction::Digest), - ("md5", BuiltinScalarFunction::MD5), - ("sha224", BuiltinScalarFunction::SHA224), - ("sha256", BuiltinScalarFunction::SHA256), - ("sha384", BuiltinScalarFunction::SHA384), - ("sha512", BuiltinScalarFunction::SHA512), - - // other functions - ("struct", BuiltinScalarFunction::Struct), - ("arrow_typeof", BuiltinScalarFunction::ArrowTypeof), - - // array functions - ("make_array", BuiltinScalarFunction::MakeArray), - ]; + /// Maps `BuiltinScalarFunction` --> canonical sql function + /// First alias in the array is used to display function names + static ref FUNCTION_TO_NAME: HashMap = { + let mut map: HashMap = HashMap::new(); + BuiltinScalarFunction::iter().for_each(|func| { + map.insert(func, aliases(&func).first().unwrap_or(&"NO_ALIAS")); + }); + map + }; } impl BuiltinScalarFunction { @@ -429,31 +341,130 @@ impl BuiltinScalarFunction { } } -impl fmt::Display for BuiltinScalarFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - for (func_name, func) in NAME_TO_FUNCTION.iter() { - if func == self { - return write!(f, "{}", func_name); - } +fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { + match func { + BuiltinScalarFunction::Abs => &["abs"], + BuiltinScalarFunction::Acos => &["acos"], + BuiltinScalarFunction::Acosh => &["acosh"], + BuiltinScalarFunction::Asin => &["asin"], + BuiltinScalarFunction::Asinh => &["asinh"], + BuiltinScalarFunction::Atan => &["atan"], + BuiltinScalarFunction::Atanh => &["atanh"], + BuiltinScalarFunction::Atan2 => &["atan2"], + BuiltinScalarFunction::Cbrt => &["cbrt"], + BuiltinScalarFunction::Ceil => &["ceil"], + BuiltinScalarFunction::Cos => &["cos"], + BuiltinScalarFunction::Cosh => &["cosh"], + BuiltinScalarFunction::Degrees => &["degrees"], + BuiltinScalarFunction::Exp => &["exp"], + BuiltinScalarFunction::Factorial => &["factorial"], + BuiltinScalarFunction::Floor => &["floor"], + BuiltinScalarFunction::Gcd => &["gcd"], + BuiltinScalarFunction::Lcm => &["lcm"], + BuiltinScalarFunction::Ln => &["ln"], + BuiltinScalarFunction::Log => &["log"], + BuiltinScalarFunction::Log10 => &["log10"], + BuiltinScalarFunction::Log2 => &["log2"], + BuiltinScalarFunction::Pi => &["pi"], + BuiltinScalarFunction::Power => &["power", "pow"], + BuiltinScalarFunction::Radians => &["radians"], + BuiltinScalarFunction::Random => &["random"], + BuiltinScalarFunction::Round => &["round"], + BuiltinScalarFunction::Signum => &["signum"], + BuiltinScalarFunction::Sin => &["sin"], + BuiltinScalarFunction::Sinh => &["sinh"], + BuiltinScalarFunction::Sqrt => &["sqrt"], + BuiltinScalarFunction::Tan => &["tan"], + BuiltinScalarFunction::Tanh => &["tanh"], + BuiltinScalarFunction::Trunc => &["trunc"], + + // conditional functions + BuiltinScalarFunction::Coalesce => &["coalesce"], + BuiltinScalarFunction::NullIf => &["nullif"], + + // string functions + BuiltinScalarFunction::Ascii => &["ascii"], + BuiltinScalarFunction::BitLength => &["bit_length"], + BuiltinScalarFunction::Btrim => &["btrim"], + BuiltinScalarFunction::CharacterLength => { + &["character_length", "char_length", "length"] } + BuiltinScalarFunction::Concat => &["concat"], + BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], + BuiltinScalarFunction::Chr => &["chr"], + BuiltinScalarFunction::InitCap => &["initcap"], + BuiltinScalarFunction::Left => &["left"], + BuiltinScalarFunction::Lower => &["lower"], + BuiltinScalarFunction::Lpad => &["lpad"], + BuiltinScalarFunction::Ltrim => &["ltrim"], + BuiltinScalarFunction::OctetLength => &["octet_length"], + BuiltinScalarFunction::Repeat => &["repeat"], + BuiltinScalarFunction::Replace => &["replace"], + BuiltinScalarFunction::Reverse => &["reverse"], + BuiltinScalarFunction::Right => &["right"], + BuiltinScalarFunction::Rpad => &["rpad"], + BuiltinScalarFunction::Rtrim => &["rtrim"], + BuiltinScalarFunction::SplitPart => &["split_part"], + BuiltinScalarFunction::StartsWith => &["starts_with"], + BuiltinScalarFunction::Strpos => &["strpos"], + BuiltinScalarFunction::Substr => &["substr"], + BuiltinScalarFunction::ToHex => &["to_hex"], + BuiltinScalarFunction::Translate => &["translate"], + BuiltinScalarFunction::Trim => &["trim"], + BuiltinScalarFunction::Upper => &["upper"], + BuiltinScalarFunction::Uuid => &["uuid"], + + // regex functions + BuiltinScalarFunction::RegexpMatch => &["regexp_match"], + BuiltinScalarFunction::RegexpReplace => &["regexp_replace"], + + // time/date functions + BuiltinScalarFunction::Now => &["now"], + BuiltinScalarFunction::CurrentDate => &["current_date"], + BuiltinScalarFunction::CurrentTime => &["current_time"], + BuiltinScalarFunction::DateBin => &["date_bin"], + BuiltinScalarFunction::DateTrunc => &["date_trunc", "datetrunc"], + BuiltinScalarFunction::DatePart => &["date_part", "datepart"], + BuiltinScalarFunction::ToTimestamp => &["to_timestamp"], + BuiltinScalarFunction::ToTimestampMillis => &["to_timestamp_millis"], + BuiltinScalarFunction::ToTimestampMicros => &["to_timestamp_micros"], + BuiltinScalarFunction::ToTimestampSeconds => &["to_timestamp_seconds"], + BuiltinScalarFunction::FromUnixtime => &["from_unixtime"], + + // hashing functions + BuiltinScalarFunction::Digest => &["digest"], + BuiltinScalarFunction::MD5 => &["md5"], + BuiltinScalarFunction::SHA224 => &["sha224"], + BuiltinScalarFunction::SHA256 => &["sha256"], + BuiltinScalarFunction::SHA384 => &["sha384"], + BuiltinScalarFunction::SHA512 => &["sha512"], - // Should not be reached - write!(f, "{}", format!("{self:?}").to_lowercase()) + // other functions + BuiltinScalarFunction::Struct => &["struct"], + BuiltinScalarFunction::ArrowTypeof => &["arrow_typeof"], + + // array functions + BuiltinScalarFunction::MakeArray => &["make_array"], + } +} + +impl fmt::Display for BuiltinScalarFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // .unwrap is safe here because compiler makes sure the map will have matches for each BuiltinScalarFunction + write!(f, "{}", FUNCTION_TO_NAME.get(self).unwrap()) } } impl FromStr for BuiltinScalarFunction { type Err = DataFusionError; fn from_str(name: &str) -> Result { - for (func_name, func) in NAME_TO_FUNCTION.iter() { - if name == *func_name { - return Ok(func.clone()); - } + if let Some(func) = NAME_TO_FUNCTION.get(name) { + Ok(*func) + } else { + Err(DataFusionError::Plan(format!( + "There is no built-in function named {name}" + ))) } - - Err(DataFusionError::Plan(format!( - "There is no built-in function named {name}" - ))) } } diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index dd96f1b2bbeaa..0929aec6e5eb4 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -304,7 +304,7 @@ where ))), Expr::ScalarFunction(ScalarFunction { fun, args }) => { Ok(Expr::ScalarFunction(ScalarFunction::new( - fun.clone(), + *fun, args.iter() .map(|e| clone_with_replacement(e, replacement_fn)) .collect::>>()?,