Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions rust/datafusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ Below is a checklist of what you need to do to add a new scalar function to Data
* a new line in `create_physical_expr` mapping the built-in to the implementation
* tests to the function.
* In [tests/sql.rs](tests/sql.rs), add a new test where the function is called through SQL against well known data and returns the expected result.
* In [src/logical_plan/expr.rs](src/logical_plan/expr.rs), add:
* a new invocation of the `unary_scalar_expr!` macro for the new function.
* In [src/logical_plan/mod.rs](src/logical_plan/mod.rs), add:
* a new entry in the `pub use expr::{}` set.

## How to add a new aggregate function

Expand Down
41 changes: 22 additions & 19 deletions rust/datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ pub fn lit<T: Literal>(n: T) -> Expr {
}

/// Create a convenience function representing a unary scalar function
macro_rules! unary_math_expr {
macro_rules! unary_scalar_expr {
($ENUM:ident, $FUNC:ident) => {
#[allow(missing_docs)]
pub fn $FUNC(e: Expr) -> Expr {
Expand All @@ -629,24 +629,27 @@ macro_rules! unary_math_expr {
};
}

// generate methods for creating the supported unary math expressions
unary_math_expr!(Sqrt, sqrt);
unary_math_expr!(Sin, sin);
unary_math_expr!(Cos, cos);
unary_math_expr!(Tan, tan);
unary_math_expr!(Asin, asin);
unary_math_expr!(Acos, acos);
unary_math_expr!(Atan, atan);
unary_math_expr!(Floor, floor);
unary_math_expr!(Ceil, ceil);
unary_math_expr!(Round, round);
unary_math_expr!(Trunc, trunc);
unary_math_expr!(Abs, abs);
unary_math_expr!(Signum, signum);
unary_math_expr!(Exp, exp);
unary_math_expr!(Log, ln);
unary_math_expr!(Log2, log2);
unary_math_expr!(Log10, log10);
// generate methods for creating the supported unary expressions
unary_scalar_expr!(Sqrt, sqrt);
unary_scalar_expr!(Sin, sin);
unary_scalar_expr!(Cos, cos);
unary_scalar_expr!(Tan, tan);
unary_scalar_expr!(Asin, asin);
unary_scalar_expr!(Acos, acos);
unary_scalar_expr!(Atan, atan);
unary_scalar_expr!(Floor, floor);
unary_scalar_expr!(Ceil, ceil);
unary_scalar_expr!(Round, round);
unary_scalar_expr!(Trunc, trunc);
unary_scalar_expr!(Abs, abs);
unary_scalar_expr!(Signum, signum);
unary_scalar_expr!(Exp, exp);
unary_scalar_expr!(Log, ln);
unary_scalar_expr!(Log2, log2);
unary_scalar_expr!(Log10, log10);
unary_scalar_expr!(Lower, lower);
unary_scalar_expr!(Trim, trim);
unary_scalar_expr!(Upper, upper);

/// returns the length of a string in bytes
pub fn length(e: Expr) -> Expr {
Expand Down
4 changes: 2 additions & 2 deletions rust/datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ pub use display::display_schema;
pub use expr::{
abs, acos, and, array, asin, atan, avg, binary_expr, case, ceil, col, concat, cos,
count, create_udaf, create_udf, exp, exprlist_to_fields, floor, length, lit, ln,
log10, log2, max, min, or, round, signum, sin, sqrt, sum, tan, trunc, when, Expr,
Literal,
log10, log2, lower, max, min, or, round, signum, sin, sqrt, sum, tan, trim, trunc,
upper, when, Expr, Literal,
};
pub use extension::UserDefinedLogicalNode;
pub use operators::Operator;
Expand Down
76 changes: 75 additions & 1 deletion rust/datafusion/src/physical_plan/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub enum Signature {
VariadicEqual,
/// fixed number of arguments of an arbitrary but equal type out of a list of valid types
// A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])`
// A function of two arguments of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])`
// A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])`
Uniform(usize, Vec<DataType>),
/// exact number of arguments of an exact type
Exact(Vec<DataType>),
Expand Down Expand Up @@ -118,6 +118,12 @@ pub enum BuiltinScalarFunction {
Length,
/// concat
Concat,
/// lower
Lower,
/// upper
Upper,
/// trim
Trim,
/// to_timestamp
ToTimestamp,
/// construct an array from columns
Expand Down Expand Up @@ -155,7 +161,12 @@ impl FromStr for BuiltinScalarFunction {
"abs" => BuiltinScalarFunction::Abs,
"signum" => BuiltinScalarFunction::Signum,
"length" => BuiltinScalarFunction::Length,
"char_length" => BuiltinScalarFunction::Length,
"character_length" => BuiltinScalarFunction::Length,
"concat" => BuiltinScalarFunction::Concat,
"lower" => BuiltinScalarFunction::Lower,
"trim" => BuiltinScalarFunction::Trim,
"upper" => BuiltinScalarFunction::Upper,
"to_timestamp" => BuiltinScalarFunction::ToTimestamp,
"array" => BuiltinScalarFunction::Array,
"nullif" => BuiltinScalarFunction::NullIf,
Expand Down Expand Up @@ -203,6 +214,36 @@ pub fn return_type(
}
}),
BuiltinScalarFunction::Concat => Ok(DataType::Utf8),
// `lower` keeps the string type of its single argument: Utf8 stays Utf8 and
// LargeUtf8 stays LargeUtf8. Any other type is rejected.
BuiltinScalarFunction::Lower => Ok(match arg_types[0] {
    DataType::LargeUtf8 => DataType::LargeUtf8,
    DataType::Utf8 => DataType::Utf8,
    _ => {
        // this error is internal as `data_types` should have captured this.
        return Err(DataFusionError::Internal(
            // Name the function that actually failed (was a copy-paste of "upper").
            "The lower function can only accept strings.".to_string(),
        ));
    }
}),
BuiltinScalarFunction::Trim => Ok(match arg_types[0] {
DataType::LargeUtf8 => DataType::LargeUtf8,
DataType::Utf8 => DataType::Utf8,
_ => {
// this error is internal as `data_types` should have captured this.
return Err(DataFusionError::Internal(
"The trim function can only accept strings.".to_string(),
));
}
}),
BuiltinScalarFunction::Upper => Ok(match arg_types[0] {
DataType::LargeUtf8 => DataType::LargeUtf8,
DataType::Utf8 => DataType::Utf8,
_ => {
// this error is internal as `data_types` should have captured this.
return Err(DataFusionError::Internal(
"The upper function can only accept strings.".to_string(),
));
}
}),
BuiltinScalarFunction::ToTimestamp => {
Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
}
Expand Down Expand Up @@ -249,6 +290,30 @@ pub fn create_physical_expr(
BuiltinScalarFunction::Concat => {
|args| Ok(Arc::new(string_expressions::concatenate(args)?))
}
BuiltinScalarFunction::Lower => |args| match args[0].data_type() {
DataType::Utf8 => Ok(Arc::new(string_expressions::lower::<i32>(args)?)),
DataType::LargeUtf8 => Ok(Arc::new(string_expressions::lower::<i64>(args)?)),
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function lower",
other,
))),
},
BuiltinScalarFunction::Trim => |args| match args[0].data_type() {
DataType::Utf8 => Ok(Arc::new(string_expressions::trim::<i32>(args)?)),
DataType::LargeUtf8 => Ok(Arc::new(string_expressions::trim::<i64>(args)?)),
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function trim",
other,
))),
},
BuiltinScalarFunction::Upper => |args| match args[0].data_type() {
DataType::Utf8 => Ok(Arc::new(string_expressions::upper::<i32>(args)?)),
DataType::LargeUtf8 => Ok(Arc::new(string_expressions::upper::<i64>(args)?)),
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function upper",
other,
))),
},
BuiltinScalarFunction::ToTimestamp => {
|args| Ok(Arc::new(datetime_expressions::to_timestamp(args)?))
}
Expand Down Expand Up @@ -280,6 +345,15 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature {
Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8])
}
BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]),
BuiltinScalarFunction::Lower => {
Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8])
}
BuiltinScalarFunction::Upper => {
Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8])
}
BuiltinScalarFunction::Trim => {
Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8])
}
BuiltinScalarFunction::ToTimestamp => Signature::Uniform(1, vec![DataType::Utf8]),
BuiltinScalarFunction::Array => {
Signature::Variadic(array_expressions::SUPPORTED_ARRAY_TYPES.to_vec())
Expand Down
25 changes: 24 additions & 1 deletion rust/datafusion/src/physical_plan/string_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
//! String expressions

use crate::error::{DataFusionError, Result};
use arrow::array::{Array, ArrayRef, StringArray, StringBuilder};
use arrow::array::{
Array, ArrayRef, GenericStringArray, StringArray, StringBuilder,
StringOffsetSizeTrait,
};

macro_rules! downcast_vec {
($ARGS:expr, $ARRAY_TYPE:ident) => {{
Expand Down Expand Up @@ -66,3 +69,23 @@ pub fn concatenate(args: &[ArrayRef]) -> Result<StringArray> {
}
Ok(builder.finish())
}

// Generates a public function `$NAME` that applies the `str` method `$FUNC`
// to every non-null value of the first argument and collects the results into
// a new string array. Null slots stay null.
//
// The generated function is generic over the offset size `T`, so the caller
// picks `i32` for Utf8 input and `i64` for LargeUtf8 input.
//
// Errors: returns `DataFusionError::Internal` when no argument is supplied or
// when the first argument is not a `GenericStringArray<T>`. Both cases mean
// the caller dispatched incorrectly, so they are reported instead of panicking.
macro_rules! string_unary_function {
    ($NAME:ident, $FUNC:ident) => {
        /// string function that accepts Utf8 or LargeUtf8 and returns Utf8 or LargeUtf8
        pub fn $NAME<T: StringOffsetSizeTrait>(
            args: &[ArrayRef],
        ) -> Result<GenericStringArray<T>> {
            let array = args
                .first()
                .and_then(|arg| arg.as_any().downcast_ref::<GenericStringArray<T>>())
                .ok_or_else(|| {
                    DataFusionError::Internal(format!(
                        "function {} expects one string array argument matching the requested offset size",
                        stringify!($NAME)
                    ))
                })?;
            // first map is the iterator, second is for the `Option<_>`
            Ok(array.iter().map(|x| x.map(|x| x.$FUNC())).collect())
        }
    };
}

// Instantiate the unary string kernels. Each invocation expands to a public
// generic function named after the first identifier, which calls the method
// named by the second identifier on every non-null value.
// `lower` and `upper` use the ASCII-only case conversions, so characters
// outside the ASCII range are passed through unchanged.
string_unary_function!(lower, to_ascii_lowercase);
string_unary_function!(upper, to_ascii_uppercase);
// `trim` calls the `trim` method on each value.
string_unary_function!(trim, trim);
2 changes: 2 additions & 0 deletions rust/datafusion/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
},
SQLExpr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())),

SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Utf8(None))),

SQLExpr::Identifier(ref id) => {
if &id.value[0..1] == "@" {
let var_names = vec![id.value.clone()];
Expand Down
24 changes: 24 additions & 0 deletions rust/datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1826,3 +1826,27 @@ async fn csv_between_expr_negated() -> Result<()> {
assert_eq!(expected, actual);
Ok(())
}

// End-to-end check of the string built-ins through SQL. Each of `char_length`,
// `character_length`, `lower`, `upper` and `trim` is called once with a string
// literal and once with NULL.
#[tokio::test]
async fn string_expressions() -> Result<()> {
// No table is registered: the query below only uses literals.
let mut ctx = ExecutionContext::new();
let sql = "SELECT
char_length('tom') AS char_length
,char_length(NULL) AS char_length_null
,character_length('tom') AS character_length
,character_length(NULL) AS character_length_null
,lower('TOM') AS lower
,lower(NULL) AS lower_null
,upper('tom') AS upper
,upper(NULL) AS upper_null
,trim(' tom ') AS trim
,trim(NULL) AS trim_null
";
// NOTE(review): `execute` is a helper defined elsewhere in this test file;
// presumably it returns each result row as a vector of strings. Confirm.
let actual = execute(&mut ctx, sql).await;

// A single row whose values follow the order of the SELECT list. Every NULL
// input is expected to come back as the text "NULL".
let expected = vec![vec![
"3", "NULL", "3", "NULL", "tom", "NULL", "TOM", "NULL", "tom", "NULL",
]];
assert_eq!(expected, actual);
Ok(())
}