Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions datafusion/expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ arrow = { workspace = true }
datafusion-common = { path = "../common", version = "25.0.0" }
lazy_static = { version = "^1.4.0" }
sqlparser = "0.34"
strum = { version = "0.24", features = ["derive"] }
strum_macros = "0.24"

[dev-dependencies]
ctor = "0.2.0"
Expand Down
267 changes: 139 additions & 128 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@

use crate::Volatility;
use datafusion_common::{DataFusionError, Result};
use lazy_static::lazy_static;
use std::collections::HashMap;
use std::fmt;
use std::str::FromStr;
use strum::IntoEnumIterator;
use strum_macros::EnumIter;

use lazy_static::lazy_static;

/// Enum of all built-in scalar functions
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter, Copy)]
pub enum BuiltinScalarFunction {
// math functions
/// abs
Expand Down Expand Up @@ -204,117 +208,25 @@ pub enum BuiltinScalarFunction {
}

lazy_static! {
/// Mapping between SQL function names to `BuiltinScalarFunction` types.
/// Note that multiple SQL function names can represent the same `BuiltinScalarFunction`. These are treated as aliases.
/// In case of such aliases, the first SQL function name in the vector is used when displaying the function.
static ref NAME_TO_FUNCTION: Vec<(&'static str, BuiltinScalarFunction)> = vec![
// math functions
("abs", BuiltinScalarFunction::Abs),
("acos", BuiltinScalarFunction::Acos),
("acosh", BuiltinScalarFunction::Acosh),
("asin", BuiltinScalarFunction::Asin),
("asinh", BuiltinScalarFunction::Asinh),
("atan", BuiltinScalarFunction::Atan),
("atanh", BuiltinScalarFunction::Atanh),
("atan2", BuiltinScalarFunction::Atan2),
("cbrt", BuiltinScalarFunction::Cbrt),
("ceil", BuiltinScalarFunction::Ceil),
("cos", BuiltinScalarFunction::Cos),
("cosh", BuiltinScalarFunction::Cosh),
("degrees", BuiltinScalarFunction::Degrees),
("exp", BuiltinScalarFunction::Exp),
("factorial", BuiltinScalarFunction::Factorial),
("floor", BuiltinScalarFunction::Floor),
("gcd", BuiltinScalarFunction::Gcd),
("lcm", BuiltinScalarFunction::Lcm),
("ln", BuiltinScalarFunction::Ln),
("log", BuiltinScalarFunction::Log),
("log10", BuiltinScalarFunction::Log10),
("log2", BuiltinScalarFunction::Log2),
("pi", BuiltinScalarFunction::Pi),
("power", BuiltinScalarFunction::Power),
("pow", BuiltinScalarFunction::Power),
("radians", BuiltinScalarFunction::Radians),
("random", BuiltinScalarFunction::Random),
("round", BuiltinScalarFunction::Round),
("signum", BuiltinScalarFunction::Signum),
("sin", BuiltinScalarFunction::Sin),
("sinh", BuiltinScalarFunction::Sinh),
("sqrt", BuiltinScalarFunction::Sqrt),
("tan", BuiltinScalarFunction::Tan),
("tanh", BuiltinScalarFunction::Tanh),
("trunc", BuiltinScalarFunction::Trunc),

// conditional functions
("coalesce", BuiltinScalarFunction::Coalesce),
("nullif", BuiltinScalarFunction::NullIf),
/// Lookup table from every accepted SQL function name (canonical names and
/// aliases alike) to the `BuiltinScalarFunction` it denotes.
static ref NAME_TO_FUNCTION: HashMap<&'static str, BuiltinScalarFunction> = {
    let mut by_name = HashMap::new();
    // Walk every enum variant and register each of its names.
    for func in BuiltinScalarFunction::iter() {
        for alias in aliases(&func) {
            by_name.insert(*alias, func);
        }
    }
    by_name
};

// string functions
("ascii", BuiltinScalarFunction::Ascii),
("bit_length", BuiltinScalarFunction::BitLength),
("btrim", BuiltinScalarFunction::Btrim),
("character_length", BuiltinScalarFunction::CharacterLength),
("char_length", BuiltinScalarFunction::CharacterLength),
("concat", BuiltinScalarFunction::Concat),
("concat_ws", BuiltinScalarFunction::ConcatWithSeparator),
("chr", BuiltinScalarFunction::Chr),
("initcap", BuiltinScalarFunction::InitCap),
("left", BuiltinScalarFunction::Left),
("length", BuiltinScalarFunction::CharacterLength),
("lower", BuiltinScalarFunction::Lower),
("lpad", BuiltinScalarFunction::Lpad),
("ltrim", BuiltinScalarFunction::Ltrim),
("octet_length", BuiltinScalarFunction::OctetLength),
("repeat", BuiltinScalarFunction::Repeat),
("replace", BuiltinScalarFunction::Replace),
("reverse", BuiltinScalarFunction::Reverse),
("right", BuiltinScalarFunction::Right),
("rpad", BuiltinScalarFunction::Rpad),
("rtrim", BuiltinScalarFunction::Rtrim),
("split_part", BuiltinScalarFunction::SplitPart),
("starts_with", BuiltinScalarFunction::StartsWith),
("strpos", BuiltinScalarFunction::Strpos),
("substr", BuiltinScalarFunction::Substr),
("to_hex", BuiltinScalarFunction::ToHex),
("translate", BuiltinScalarFunction::Translate),
("trim", BuiltinScalarFunction::Trim),
("upper", BuiltinScalarFunction::Upper),
("uuid", BuiltinScalarFunction::Uuid),

// regex functions
("regexp_match", BuiltinScalarFunction::RegexpMatch),
("regexp_replace", BuiltinScalarFunction::RegexpReplace),

// time/date functions
("now", BuiltinScalarFunction::Now),
("current_date", BuiltinScalarFunction::CurrentDate),
("current_time", BuiltinScalarFunction::CurrentTime),
("date_bin", BuiltinScalarFunction::DateBin),
("date_trunc", BuiltinScalarFunction::DateTrunc),
("datetrunc", BuiltinScalarFunction::DateTrunc),
("date_part", BuiltinScalarFunction::DatePart),
("datepart", BuiltinScalarFunction::DatePart),
("to_timestamp", BuiltinScalarFunction::ToTimestamp),
("to_timestamp_millis", BuiltinScalarFunction::ToTimestampMillis),
("to_timestamp_micros", BuiltinScalarFunction::ToTimestampMicros),
("to_timestamp_seconds", BuiltinScalarFunction::ToTimestampSeconds),
("from_unixtime", BuiltinScalarFunction::FromUnixtime),

// hashing functions
("digest", BuiltinScalarFunction::Digest),
("md5", BuiltinScalarFunction::MD5),
("sha224", BuiltinScalarFunction::SHA224),
("sha256", BuiltinScalarFunction::SHA256),
("sha384", BuiltinScalarFunction::SHA384),
("sha512", BuiltinScalarFunction::SHA512),

// other functions
("struct", BuiltinScalarFunction::Struct),
("arrow_typeof", BuiltinScalarFunction::ArrowTypeof),

// array functions
("make_array", BuiltinScalarFunction::MakeArray),
];
/// Lookup table from each `BuiltinScalarFunction` to the SQL name used when
/// displaying it: the first entry of the slice returned by `aliases`.
/// A variant with an empty alias slice is recorded as `"NO_ALIAS"`.
static ref FUNCTION_TO_NAME: HashMap<BuiltinScalarFunction, &'static str> = {
    BuiltinScalarFunction::iter()
        .map(|func| {
            let display_name = aliases(&func).first().copied().unwrap_or("NO_ALIAS");
            (func, display_name)
        })
        .collect()
};
}

impl BuiltinScalarFunction {
Expand Down Expand Up @@ -429,31 +341,130 @@ impl BuiltinScalarFunction {
}
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/// First alias in the array is used to display function names

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is that?

impl fmt::Display for BuiltinScalarFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
for (func_name, func) in NAME_TO_FUNCTION.iter() {
if func == self {
return write!(f, "{}", func_name);
}
fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I commented on the wrong line previously @comphead , should be here.

Adding a comment for aliases():
/// First alias in the array is used to display function names

match func {
BuiltinScalarFunction::Abs => &["abs"],
BuiltinScalarFunction::Acos => &["acos"],
BuiltinScalarFunction::Acosh => &["acosh"],
BuiltinScalarFunction::Asin => &["asin"],
BuiltinScalarFunction::Asinh => &["asinh"],
BuiltinScalarFunction::Atan => &["atan"],
BuiltinScalarFunction::Atanh => &["atanh"],
BuiltinScalarFunction::Atan2 => &["atan2"],
BuiltinScalarFunction::Cbrt => &["cbrt"],
BuiltinScalarFunction::Ceil => &["ceil"],
BuiltinScalarFunction::Cos => &["cos"],
BuiltinScalarFunction::Cosh => &["cosh"],
BuiltinScalarFunction::Degrees => &["degrees"],
BuiltinScalarFunction::Exp => &["exp"],
BuiltinScalarFunction::Factorial => &["factorial"],
BuiltinScalarFunction::Floor => &["floor"],
BuiltinScalarFunction::Gcd => &["gcd"],
BuiltinScalarFunction::Lcm => &["lcm"],
BuiltinScalarFunction::Ln => &["ln"],
BuiltinScalarFunction::Log => &["log"],
BuiltinScalarFunction::Log10 => &["log10"],
BuiltinScalarFunction::Log2 => &["log2"],
BuiltinScalarFunction::Pi => &["pi"],
BuiltinScalarFunction::Power => &["power", "pow"],
BuiltinScalarFunction::Radians => &["radians"],
BuiltinScalarFunction::Random => &["random"],
BuiltinScalarFunction::Round => &["round"],
BuiltinScalarFunction::Signum => &["signum"],
BuiltinScalarFunction::Sin => &["sin"],
BuiltinScalarFunction::Sinh => &["sinh"],
BuiltinScalarFunction::Sqrt => &["sqrt"],
BuiltinScalarFunction::Tan => &["tan"],
BuiltinScalarFunction::Tanh => &["tanh"],
BuiltinScalarFunction::Trunc => &["trunc"],

// conditional functions
BuiltinScalarFunction::Coalesce => &["coalesce"],
BuiltinScalarFunction::NullIf => &["nullif"],

// string functions
BuiltinScalarFunction::Ascii => &["ascii"],
BuiltinScalarFunction::BitLength => &["bit_length"],
BuiltinScalarFunction::Btrim => &["btrim"],
BuiltinScalarFunction::CharacterLength => {
&["character_length", "char_length", "length"]
}
BuiltinScalarFunction::Concat => &["concat"],
BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"],
BuiltinScalarFunction::Chr => &["chr"],
BuiltinScalarFunction::InitCap => &["initcap"],
BuiltinScalarFunction::Left => &["left"],
BuiltinScalarFunction::Lower => &["lower"],
BuiltinScalarFunction::Lpad => &["lpad"],
BuiltinScalarFunction::Ltrim => &["ltrim"],
BuiltinScalarFunction::OctetLength => &["octet_length"],
BuiltinScalarFunction::Repeat => &["repeat"],
BuiltinScalarFunction::Replace => &["replace"],
BuiltinScalarFunction::Reverse => &["reverse"],
BuiltinScalarFunction::Right => &["right"],
BuiltinScalarFunction::Rpad => &["rpad"],
BuiltinScalarFunction::Rtrim => &["rtrim"],
BuiltinScalarFunction::SplitPart => &["split_part"],
BuiltinScalarFunction::StartsWith => &["starts_with"],
BuiltinScalarFunction::Strpos => &["strpos"],
BuiltinScalarFunction::Substr => &["substr"],
BuiltinScalarFunction::ToHex => &["to_hex"],
BuiltinScalarFunction::Translate => &["translate"],
BuiltinScalarFunction::Trim => &["trim"],
BuiltinScalarFunction::Upper => &["upper"],
BuiltinScalarFunction::Uuid => &["uuid"],

// regex functions
BuiltinScalarFunction::RegexpMatch => &["regexp_match"],
BuiltinScalarFunction::RegexpReplace => &["regexp_replace"],

// time/date functions
BuiltinScalarFunction::Now => &["now"],
BuiltinScalarFunction::CurrentDate => &["current_date"],
BuiltinScalarFunction::CurrentTime => &["current_time"],
BuiltinScalarFunction::DateBin => &["date_bin"],
BuiltinScalarFunction::DateTrunc => &["date_trunc", "datetrunc"],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like how this formulation makes it easier to understand what functions have aliases and which do not 👍

BuiltinScalarFunction::DatePart => &["date_part", "datepart"],
BuiltinScalarFunction::ToTimestamp => &["to_timestamp"],
BuiltinScalarFunction::ToTimestampMillis => &["to_timestamp_millis"],
BuiltinScalarFunction::ToTimestampMicros => &["to_timestamp_micros"],
BuiltinScalarFunction::ToTimestampSeconds => &["to_timestamp_seconds"],
BuiltinScalarFunction::FromUnixtime => &["from_unixtime"],

// hashing functions
BuiltinScalarFunction::Digest => &["digest"],
BuiltinScalarFunction::MD5 => &["md5"],
BuiltinScalarFunction::SHA224 => &["sha224"],
BuiltinScalarFunction::SHA256 => &["sha256"],
BuiltinScalarFunction::SHA384 => &["sha384"],
BuiltinScalarFunction::SHA512 => &["sha512"],

// Should not be reached
write!(f, "{}", format!("{self:?}").to_lowercase())
// other functions
BuiltinScalarFunction::Struct => &["struct"],
BuiltinScalarFunction::ArrowTypeof => &["arrow_typeof"],

// array functions
BuiltinScalarFunction::MakeArray => &["make_array"],
}
}

impl fmt::Display for BuiltinScalarFunction {
    /// Writes the SQL name recorded for this function in `FUNCTION_TO_NAME`.
    ///
    /// # Panics
    ///
    /// Panics if `FUNCTION_TO_NAME` has no entry for `self`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // `FUNCTION_TO_NAME` is filled by iterating over every enum variant,
        // so a missing entry is an internal bug, not a recoverable error.
        let name = FUNCTION_TO_NAME
            .get(self)
            .expect("every BuiltinScalarFunction variant has an entry in FUNCTION_TO_NAME");
        write!(f, "{name}")
    }
}

impl FromStr for BuiltinScalarFunction {
type Err = DataFusionError;
fn from_str(name: &str) -> Result<BuiltinScalarFunction> {
for (func_name, func) in NAME_TO_FUNCTION.iter() {
if name == *func_name {
return Ok(func.clone());
}
if let Some(func) = NAME_TO_FUNCTION.get(name) {
Ok(*func)
} else {
Err(DataFusionError::Plan(format!(
"There is no built-in function named {name}"
)))
}

Err(DataFusionError::Plan(format!(
"There is no built-in function named {name}"
)))
}
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ where
))),
Expr::ScalarFunction(ScalarFunction { fun, args }) => {
Ok(Expr::ScalarFunction(ScalarFunction::new(
fun.clone(),
*fun,
args.iter()
.map(|e| clone_with_replacement(e, replacement_fn))
.collect::<Result<Vec<Expr>>>()?,
Expand Down