diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index 78b0a3172b7..c8688257074 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -155,6 +155,10 @@ Below is a checklist of what you need to do to add a new scalar function to Data * a new line in `create_physical_expr` mapping the built-in to the implementation * tests to the function. * In [tests/sql.rs](tests/sql.rs), add a new test where the function is called through SQL against well known data and returns the expected result. +* In [src/logical_plan/expr](src/logical_plan/expr.rs), add: + * a new entry of the `unary_scalar_expr!` macro for the new function. +* In [src/logical_plan/mod](src/logical_plan/mod.rs), add: + * a new entry in the `pub use expr::{}` set. ## How to add a new aggregate function diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index ca72948f1af..3bcfd078b3c 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -617,7 +617,7 @@ pub fn lit<T: Literal>(n: T) -> Expr { } /// Create an convenience function representing a unary scalar function -macro_rules! unary_math_expr { +macro_rules! unary_scalar_expr { ($ENUM:ident, $FUNC:ident) => { #[allow(missing_docs)] pub fn $FUNC(e: Expr) -> Expr { @@ -629,24 +629,27 @@ macro_rules! 
unary_math_expr { }; } -// generate methods for creating the supported unary math expressions -unary_math_expr!(Sqrt, sqrt); -unary_math_expr!(Sin, sin); -unary_math_expr!(Cos, cos); -unary_math_expr!(Tan, tan); -unary_math_expr!(Asin, asin); -unary_math_expr!(Acos, acos); -unary_math_expr!(Atan, atan); -unary_math_expr!(Floor, floor); -unary_math_expr!(Ceil, ceil); -unary_math_expr!(Round, round); -unary_math_expr!(Trunc, trunc); -unary_math_expr!(Abs, abs); -unary_math_expr!(Signum, signum); -unary_math_expr!(Exp, exp); -unary_math_expr!(Log, ln); -unary_math_expr!(Log2, log2); -unary_math_expr!(Log10, log10); +// generate methods for creating the supported unary expressions +unary_scalar_expr!(Sqrt, sqrt); +unary_scalar_expr!(Sin, sin); +unary_scalar_expr!(Cos, cos); +unary_scalar_expr!(Tan, tan); +unary_scalar_expr!(Asin, asin); +unary_scalar_expr!(Acos, acos); +unary_scalar_expr!(Atan, atan); +unary_scalar_expr!(Floor, floor); +unary_scalar_expr!(Ceil, ceil); +unary_scalar_expr!(Round, round); +unary_scalar_expr!(Trunc, trunc); +unary_scalar_expr!(Abs, abs); +unary_scalar_expr!(Signum, signum); +unary_scalar_expr!(Exp, exp); +unary_scalar_expr!(Log, ln); +unary_scalar_expr!(Log2, log2); +unary_scalar_expr!(Log10, log10); +unary_scalar_expr!(Lower, lower); +unary_scalar_expr!(Trim, trim); +unary_scalar_expr!(Upper, upper); /// returns the length of a string in bytes pub fn length(e: Expr) -> Expr { diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index fc38057dfe0..4cd4d99587b 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -36,8 +36,8 @@ pub use display::display_schema; pub use expr::{ abs, acos, and, array, asin, atan, avg, binary_expr, case, ceil, col, concat, cos, count, create_udaf, create_udf, exp, exprlist_to_fields, floor, length, lit, ln, - log10, log2, max, min, or, round, signum, sin, sqrt, sum, tan, trunc, when, Expr, - Literal, + log10, log2, lower, 
max, min, or, round, signum, sin, sqrt, sum, tan, trim, trunc, + upper, when, Expr, Literal, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index fb6c63b709f..94e47bd56a0 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -61,7 +61,7 @@ pub enum Signature { VariadicEqual, /// fixed number of arguments of an arbitrary but equal type out of a list of valid types // A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])` - // A function of two arguments of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])` + // A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])` Uniform(usize, Vec<DataType>), /// exact number of arguments of an exact type Exact(Vec<DataType>), @@ -118,6 +118,12 @@ pub enum BuiltinScalarFunction { Length, /// concat Concat, + /// lower + Lower, + /// upper + Upper, + /// trim + Trim, /// to_timestamp ToTimestamp, /// construct an array from columns @@ -155,7 +161,12 @@ impl FromStr for BuiltinScalarFunction { "abs" => BuiltinScalarFunction::Abs, "signum" => BuiltinScalarFunction::Signum, "length" => BuiltinScalarFunction::Length, + "char_length" => BuiltinScalarFunction::Length, + "character_length" => BuiltinScalarFunction::Length, "concat" => BuiltinScalarFunction::Concat, + "lower" => BuiltinScalarFunction::Lower, + "trim" => BuiltinScalarFunction::Trim, + "upper" => BuiltinScalarFunction::Upper, "to_timestamp" => BuiltinScalarFunction::ToTimestamp, "array" => BuiltinScalarFunction::Array, "nullif" => BuiltinScalarFunction::NullIf, @@ -203,6 +214,36 @@ pub fn return_type( } }), BuiltinScalarFunction::Concat => Ok(DataType::Utf8), + BuiltinScalarFunction::Lower => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // 
this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The lower function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Trim => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The trim function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Upper => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The upper function can only accept strings.".to_string(), + )); + } + }), BuiltinScalarFunction::ToTimestamp => { Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) } @@ -249,6 +290,30 @@ pub fn create_physical_expr( BuiltinScalarFunction::Concat => { |args| Ok(Arc::new(string_expressions::concatenate(args)?)) } + BuiltinScalarFunction::Lower => |args| match args[0].data_type() { + DataType::Utf8 => Ok(Arc::new(string_expressions::lower::<i32>(args)?)), + DataType::LargeUtf8 => Ok(Arc::new(string_expressions::lower::<i64>(args)?)), + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function lower", + other, + ))), + }, + BuiltinScalarFunction::Trim => |args| match args[0].data_type() { + DataType::Utf8 => Ok(Arc::new(string_expressions::trim::<i32>(args)?)), + DataType::LargeUtf8 => Ok(Arc::new(string_expressions::trim::<i64>(args)?)), + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function trim", + other, + ))), + }, + BuiltinScalarFunction::Upper => |args| match args[0].data_type() { + DataType::Utf8 => Ok(Arc::new(string_expressions::upper::<i32>(args)?)), + DataType::LargeUtf8 => Ok(Arc::new(string_expressions::upper::<i64>(args)?)), + other => 
Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function upper", + other, + ))), + }, BuiltinScalarFunction::ToTimestamp => { |args| Ok(Arc::new(datetime_expressions::to_timestamp(args)?)) } @@ -280,6 +345,15 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8]) } BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]), + BuiltinScalarFunction::Lower => { + Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8]) + } + BuiltinScalarFunction::Upper => { + Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8]) + } + BuiltinScalarFunction::Trim => { + Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8]) + } BuiltinScalarFunction::ToTimestamp => Signature::Uniform(1, vec![DataType::Utf8]), BuiltinScalarFunction::Array => { Signature::Variadic(array_expressions::SUPPORTED_ARRAY_TYPES.to_vec()) diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index fb65f91ce3a..328e24b8b92 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -18,7 +18,10 @@ //! String expressions use crate::error::{DataFusionError, Result}; -use arrow::array::{Array, ArrayRef, StringArray, StringBuilder}; +use arrow::array::{ + Array, ArrayRef, GenericStringArray, StringArray, StringBuilder, + StringOffsetSizeTrait, +}; macro_rules! downcast_vec { ($ARGS:expr, $ARRAY_TYPE:ident) => {{ @@ -66,3 +69,23 @@ pub fn concatenate(args: &[ArrayRef]) -> Result<StringArray> { } Ok(builder.finish()) } + +macro_rules! 
string_unary_function { + ($NAME:ident, $FUNC:ident) => { + /// string function that accepts Utf8 or LargeUtf8 and returns Utf8 or LargeUtf8 + pub fn $NAME<T: StringOffsetSizeTrait>( + args: &[ArrayRef], + ) -> Result<GenericStringArray<T>> { + let array = args[0] + .as_any() + .downcast_ref::<GenericStringArray<T>>() + .unwrap(); + // first map is the iterator, second is for the `Option<_>` + Ok(array.iter().map(|x| x.map(|x| x.$FUNC())).collect()) + } + }; +} + +string_unary_function!(lower, to_ascii_lowercase); +string_unary_function!(upper, to_ascii_uppercase); +string_unary_function!(trim, trim); diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 4d469c83059..0d5c819e1db 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -550,6 +550,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }, SQLExpr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())), + SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Utf8(None))), + SQLExpr::Identifier(ref id) => { if &id.value[0..1] == "@" { let var_names = vec![id.value.clone()]; diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 84c90a5d326..48a8d70dba3 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -1826,3 +1826,27 @@ async fn csv_between_expr_negated() -> Result<()> { assert_eq!(expected, actual); Ok(()) } + +#[tokio::test] +async fn string_expressions() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let sql = "SELECT + char_length('tom') AS char_length + ,char_length(NULL) AS char_length_null + ,character_length('tom') AS character_length + ,character_length(NULL) AS character_length_null + ,lower('TOM') AS lower + ,lower(NULL) AS lower_null + ,upper('tom') AS upper + ,upper(NULL) AS upper_null + ,trim(' tom ') AS trim + ,trim(NULL) AS trim_null + "; + let actual = execute(&mut ctx, sql).await; + + let expected = vec![vec![ + "3", "NULL", "3", "NULL", "tom", "NULL", "TOM", "NULL", "tom", "NULL", + ]]; + 
assert_eq!(expected, actual); + Ok(()) +}