From 8df27ec2a16c3ff820562d9f5f72bb128502a6eb Mon Sep 17 00:00:00 2001 From: Mike Seddon Date: Sat, 19 Dec 2020 11:18:10 +1100 Subject: [PATCH 1/5] add basic string functions --- .../datafusion/src/physical_plan/functions.rs | 43 ++++++++++- .../src/physical_plan/string_expressions.rs | 77 ++++++++++++++++++- rust/datafusion/tests/sql.rs | 18 +++++ 3 files changed, 136 insertions(+), 2 deletions(-) diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index fb6c63b709f..687013a3459 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -61,7 +61,7 @@ pub enum Signature { VariadicEqual, /// fixed number of arguments of an arbitrary but equal type out of a list of valid types // A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])` - // A function of two arguments of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])` + // A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])` Uniform(usize, Vec), /// exact number of arguments of an exact type Exact(Vec), @@ -118,6 +118,14 @@ pub enum BuiltinScalarFunction { Length, /// concat Concat, + /// character_length + CharacterLength, + /// lower + Lower, + /// upper + Upper, + /// ltrim + Trim, /// to_timestamp ToTimestamp, /// construct an array from columns @@ -156,6 +164,11 @@ impl FromStr for BuiltinScalarFunction { "signum" => BuiltinScalarFunction::Signum, "length" => BuiltinScalarFunction::Length, "concat" => BuiltinScalarFunction::Concat, + "char_length" => BuiltinScalarFunction::CharacterLength, + "character_length" => BuiltinScalarFunction::CharacterLength, + "lower" => BuiltinScalarFunction::Lower, + "upper" => BuiltinScalarFunction::Upper, + "trim" => BuiltinScalarFunction::Trim, "to_timestamp" => BuiltinScalarFunction::ToTimestamp, "array" => BuiltinScalarFunction::Array, "nullif" => BuiltinScalarFunction::NullIf, @@ -203,6 +216,10 @@ pub fn return_type( } }), BuiltinScalarFunction::Concat => Ok(DataType::Utf8), + BuiltinScalarFunction::CharacterLength => Ok(DataType::Int32), + BuiltinScalarFunction::Lower => Ok(DataType::Utf8), + BuiltinScalarFunction::Upper => Ok(DataType::Utf8), + BuiltinScalarFunction::Trim => Ok(DataType::Utf8), BuiltinScalarFunction::ToTimestamp => { Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) } @@ -249,6 +266,18 @@ pub fn create_physical_expr( BuiltinScalarFunction::Concat => { |args| Ok(Arc::new(string_expressions::concatenate(args)?)) } + BuiltinScalarFunction::CharacterLength => { + |args| Ok(Arc::new(string_expressions::character_length(args)?)) + } + BuiltinScalarFunction::Lower => { + |args| Ok(Arc::new(string_expressions::lower(args)?)) + } + BuiltinScalarFunction::Upper => { + |args| Ok(Arc::new(string_expressions::upper(args)?)) + } + BuiltinScalarFunction::Trim => { + |args| Ok(Arc::new(string_expressions::trim(args)?)) + } BuiltinScalarFunction::ToTimestamp => { |args| Ok(Arc::new(datetime_expressions::to_timestamp(args)?)) } @@ -280,6 +309,18 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8]) } BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]), + BuiltinScalarFunction::CharacterLength => { + Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8]) + } + BuiltinScalarFunction::Lower => { + Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8]) + } + BuiltinScalarFunction::Upper => { + Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8]) + } + BuiltinScalarFunction::Trim => { + Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8]) + } BuiltinScalarFunction::ToTimestamp => Signature::Uniform(1, vec![DataType::Utf8]), BuiltinScalarFunction::Array => { Signature::Variadic(array_expressions::SUPPORTED_ARRAY_TYPES.to_vec()) diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index fb65f91ce3a..f6d5ba760d4 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -18,7 +18,12 @@ //! String expressions use crate::error::{DataFusionError, Result}; -use arrow::array::{Array, ArrayRef, StringArray, StringBuilder}; +use arrow::{ + array::{Array, ArrayData, ArrayRef, Int32Array, StringArray, StringBuilder}, + buffer::Buffer, + datatypes::{DataType, ToByteSlice}, +}; +use std::sync::Arc; macro_rules! downcast_vec { ($ARGS:expr, $ARRAY_TYPE:ident) => {{ @@ -66,3 +71,73 @@ pub fn concatenate(args: &[ArrayRef]) -> Result { } Ok(builder.finish()) } + +/// character_length returns number of characters in the string +/// character_length('josé') = 4 +pub fn character_length(args: &[ArrayRef]) -> Result { + let num_rows = args[0].len(); + let string_args = + &args[0] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast input to StringArray".to_string(), + ) + })?; + + let result = (0..num_rows) + .map(|i| { + if string_args.is_null(i) { + // NB: Since we use the same null bitset as the input, + // the output for this value will be ignored, but we + // need some value in the array we are building. + Ok(0) + } else { + Ok(string_args.value(i).chars().count() as i32) + } + }) + .collect::>>()?; + + let data = ArrayData::new( + DataType::Int32, + num_rows, + Some(string_args.null_count()), + string_args.data().null_buffer().cloned(), + 0, + vec![Buffer::from(result.to_byte_slice())], + vec![], + ); + + Ok(Int32Array::from(Arc::new(data))) +} + +macro_rules! string_unary_function { + ($NAME:ident, $FUNC:ident) => { + /// string function that accepts utf8 and returns utf8 + pub fn $NAME(args: &[ArrayRef]) -> Result { + let string_args = &args[0] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast input to StringArray".to_string(), + ) + })?; + + let mut builder = StringBuilder::new(args.len()); + for index in 0..args[0].len() { + if string_args.is_null(index) { + builder.append_null()?; + } else { + builder.append_value(&string_args.value(index).$FUNC())?; + } + } + Ok(builder.finish()) + } + }; +} + +string_unary_function!(lower, to_ascii_lowercase); +string_unary_function!(upper, to_ascii_uppercase); +string_unary_function!(trim, trim); diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 84c90a5d326..310ed71615e 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -1826,3 +1826,21 @@ async fn csv_between_expr_negated() -> Result<()> { assert_eq!(expected, actual); Ok(()) } + +#[tokio::test] +async fn string_expressions() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx)?; + let sql = "SELECT + char_length('josé') AS char_length + ,character_length('josé') AS character_length + ,lower('TOM') AS lower + ,upper('tom') AS upper + ,trim(' tom ') AS trim + "; + let actual = execute(&mut ctx, sql).await; + + let expected = vec![vec!["4", "4", "tom", "TOM", "tom"]]; + assert_eq!(expected, actual); + Ok(()) +} From a2ccbbadbaafc176073d396822f4be6717978a9b Mon Sep 17 00:00:00 2001 From: Mike Seddon Date: Sun, 20 Dec 2020 11:31:01 +1100 Subject: [PATCH 2/5] rewrite to support Utf8 and LargeUtf8 --- rust/datafusion/src/logical_plan/expr.rs | 42 +++++---- rust/datafusion/src/logical_plan/mod.rs | 8 +- .../datafusion/src/physical_plan/functions.rs | 51 +++++++--- .../src/physical_plan/string_expressions.rs | 92 +++++++++++++------ rust/datafusion/tests/sql.rs | 1 - 5 files changed, 127 insertions(+), 67 deletions(-) diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index ca72948f1af..7d149f04c77 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -617,7 +617,7 @@ pub fn lit(n: T) -> Expr { } /// Create an convenience function representing a unary scalar function -macro_rules! unary_math_expr { +macro_rules! unary_scalar_expr { ($ENUM:ident, $FUNC:ident) => { #[allow(missing_docs)] pub fn $FUNC(e: Expr) -> Expr { @@ -629,24 +629,28 @@ macro_rules! unary_math_expr { }; } -// generate methods for creating the supported unary math expressions -unary_math_expr!(Sqrt, sqrt); -unary_math_expr!(Sin, sin); -unary_math_expr!(Cos, cos); -unary_math_expr!(Tan, tan); -unary_math_expr!(Asin, asin); -unary_math_expr!(Acos, acos); -unary_math_expr!(Atan, atan); -unary_math_expr!(Floor, floor); -unary_math_expr!(Ceil, ceil); -unary_math_expr!(Round, round); -unary_math_expr!(Trunc, trunc); -unary_math_expr!(Abs, abs); -unary_math_expr!(Signum, signum); -unary_math_expr!(Exp, exp); -unary_math_expr!(Log, ln); -unary_math_expr!(Log2, log2); -unary_math_expr!(Log10, log10); +// generate methods for creating the supported unary expressions +unary_scalar_expr!(Sqrt, sqrt); +unary_scalar_expr!(Sin, sin); +unary_scalar_expr!(Cos, cos); +unary_scalar_expr!(Tan, tan); +unary_scalar_expr!(Asin, asin); +unary_scalar_expr!(Acos, acos); +unary_scalar_expr!(Atan, atan); +unary_scalar_expr!(Floor, floor); +unary_scalar_expr!(Ceil, ceil); +unary_scalar_expr!(Round, round); +unary_scalar_expr!(Trunc, trunc); +unary_scalar_expr!(Abs, abs); +unary_scalar_expr!(Signum, signum); +unary_scalar_expr!(Exp, exp); +unary_scalar_expr!(Log, ln); +unary_scalar_expr!(Log2, log2); +unary_scalar_expr!(Log10, log10); +unary_scalar_expr!(CharacterLength, character_length); +unary_scalar_expr!(Lower, lower); +unary_scalar_expr!(Trim, trim); +unary_scalar_expr!(Upper, upper); /// returns the length of a string in bytes pub fn length(e: Expr) -> Expr { diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index fc38057dfe0..b5f6926a7ad 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -34,10 +34,10 @@ pub use builder::LogicalPlanBuilder; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ - abs, acos, and, array, asin, atan, avg, binary_expr, case, ceil, col, concat, cos, - count, create_udaf, create_udf, exp, exprlist_to_fields, floor, length, lit, ln, - log10, log2, max, min, or, round, signum, sin, sqrt, sum, tan, trunc, when, Expr, - Literal, + abs, acos, and, array, asin, atan, avg, binary_expr, case, ceil, character_length, + col, concat, cos, count, create_udaf, create_udf, exp, exprlist_to_fields, floor, + length, lit, ln, log10, log2, lower, max, min, or, round, signum, sin, sqrt, sum, + tan, trim, trunc, upper, when, Expr, Literal, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 687013a3459..d7ae269649e 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -124,7 +124,7 @@ pub enum BuiltinScalarFunction { Lower, /// upper Upper, - /// ltrim + /// trim Trim, /// to_timestamp ToTimestamp, @@ -167,8 +167,8 @@ impl FromStr for BuiltinScalarFunction { "char_length" => BuiltinScalarFunction::CharacterLength, "character_length" => BuiltinScalarFunction::CharacterLength, "lower" => BuiltinScalarFunction::Lower, - "upper" => BuiltinScalarFunction::Upper, "trim" => BuiltinScalarFunction::Trim, + "upper" => BuiltinScalarFunction::Upper, "to_timestamp" => BuiltinScalarFunction::ToTimestamp, "array" => BuiltinScalarFunction::Array, "nullif" => BuiltinScalarFunction::NullIf, @@ -216,10 +216,37 @@ pub fn return_type( } }), BuiltinScalarFunction::Concat => Ok(DataType::Utf8), - BuiltinScalarFunction::CharacterLength => Ok(DataType::Int32), - BuiltinScalarFunction::Lower => Ok(DataType::Utf8), - BuiltinScalarFunction::Upper => Ok(DataType::Utf8), - BuiltinScalarFunction::Trim => Ok(DataType::Utf8), + BuiltinScalarFunction::CharacterLength => Ok(DataType::UInt32), + BuiltinScalarFunction::Lower => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The upper function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Trim => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The trim function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Upper => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The upper function can only accept strings.".to_string(), + )); + } + }), BuiltinScalarFunction::ToTimestamp => { Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) } @@ -269,15 +296,9 @@ pub fn create_physical_expr( BuiltinScalarFunction::CharacterLength => { |args| Ok(Arc::new(string_expressions::character_length(args)?)) } - BuiltinScalarFunction::Lower => { - |args| Ok(Arc::new(string_expressions::lower(args)?)) - } - BuiltinScalarFunction::Upper => { - |args| Ok(Arc::new(string_expressions::upper(args)?)) - } - BuiltinScalarFunction::Trim => { - |args| Ok(Arc::new(string_expressions::trim(args)?)) - } + BuiltinScalarFunction::Lower => string_expressions::lower, + BuiltinScalarFunction::Trim => string_expressions::trim, + BuiltinScalarFunction::Upper => string_expressions::upper, BuiltinScalarFunction::ToTimestamp => { |args| Ok(Arc::new(datetime_expressions::to_timestamp(args)?)) } diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index f6d5ba760d4..e2877b5e79b 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -19,7 +19,10 @@ use crate::error::{DataFusionError, Result}; use arrow::{ - array::{Array, ArrayData, ArrayRef, Int32Array, StringArray, StringBuilder}, + array::{ + Array, ArrayData, ArrayRef, LargeStringArray, LargeStringBuilder, StringArray, + StringBuilder, UInt32Array, + }, buffer::Buffer, datatypes::{DataType, ToByteSlice}, }; @@ -74,7 +77,7 @@ pub fn concatenate(args: &[ArrayRef]) -> Result { /// character_length returns number of characters in the string /// character_length('josé') = 4 -pub fn character_length(args: &[ArrayRef]) -> Result { +pub fn character_length(args: &[ArrayRef]) -> Result { let num_rows = args[0].len(); let string_args = &args[0] @@ -94,13 +97,13 @@ pub fn character_length(args: &[ArrayRef]) -> Result { // need some value in the array we are building. Ok(0) } else { - Ok(string_args.value(i).chars().count() as i32) + Ok(string_args.value(i).chars().count() as u32) } }) .collect::>>()?; let data = ArrayData::new( - DataType::Int32, + DataType::UInt32, num_rows, Some(string_args.null_count()), string_args.data().null_buffer().cloned(), @@ -109,35 +112,68 @@ pub fn character_length(args: &[ArrayRef]) -> Result { vec![], ); - Ok(Int32Array::from(Arc::new(data))) + Ok(UInt32Array::from(Arc::new(data))) } -macro_rules! string_unary_function { - ($NAME:ident, $FUNC:ident) => { - /// string function that accepts utf8 and returns utf8 - pub fn $NAME(args: &[ArrayRef]) -> Result { - let string_args = &args[0] - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast input to StringArray".to_string(), - ) - })?; +macro_rules! compute_op { + ($ARRAY:expr, $FUNC:ident, $TYPE:ident, $BUILDER:ident) => {{ + let mut builder = $BUILDER::new($ARRAY.len()); + for index in 0..$ARRAY.len() { + if $ARRAY.is_null(index) { + builder.append_null()?; + } else { + builder.append_value(&$ARRAY.value(index).$FUNC())?; + } + } + Ok(Arc::new(builder.finish())) + }}; +} - let mut builder = StringBuilder::new(args.len()); - for index in 0..args[0].len() { - if string_args.is_null(index) { - builder.append_null()?; - } else { - builder.append_value(&string_args.value(index).$FUNC())?; - } +macro_rules! downcast_compute_op { + ($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident, $BUILDER:ident) => {{ + let n = $ARRAY.as_any().downcast_ref::<$TYPE>(); + match n { + Some(array) => compute_op!(array, $FUNC, $TYPE, $BUILDER), + _ => Err(DataFusionError::Internal(format!( + "Invalid data type for {}", + $NAME + ))), + } + }}; +} + +macro_rules! unary_primitive_array_op { + ($ARRAY:expr, $NAME:expr, $FUNC:ident) => {{ + match ($ARRAY).data_type() { + DataType::Utf8 => { + downcast_compute_op!($ARRAY, $NAME, $FUNC, StringArray, StringBuilder) + } + DataType::LargeUtf8 => { + downcast_compute_op!( + $ARRAY, + $NAME, + $FUNC, + LargeStringArray, + LargeStringBuilder + ) } - Ok(builder.finish()) + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function {}", + other, $NAME, + ))), + } + }}; +} + +macro_rules! string_unary_function { + ($NAME:expr, $FUNC:ident, $STRINGFUNC:ident) => { + /// string function that accepts Utf8 or LargeUtf8 and returns StringArray or LargeStringArray + pub fn $FUNC(args: &[ArrayRef]) -> Result { + unary_primitive_array_op!(args[0], $NAME, $STRINGFUNC) } }; } -string_unary_function!(lower, to_ascii_lowercase); -string_unary_function!(upper, to_ascii_uppercase); -string_unary_function!(trim, trim); +string_unary_function!("lower", lower, to_ascii_lowercase); +string_unary_function!("upper", upper, to_ascii_uppercase); +string_unary_function!("trim", trim, trim); diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 310ed71615e..9d96a66952c 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -1830,7 +1830,6 @@ async fn csv_between_expr_negated() -> Result<()> { #[tokio::test] async fn string_expressions() -> Result<()> { let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx)?; let sql = "SELECT char_length('josé') AS char_length ,character_length('josé') AS character_length From 011d6228495e5a17e08a08f71e04dca8705027ee Mon Sep 17 00:00:00 2001 From: Mike Seddon Date: Sun, 20 Dec 2020 11:36:03 +1100 Subject: [PATCH 3/5] update README with unary function addition steps --- rust/datafusion/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index 78b0a3172b7..c8688257074 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -155,6 +155,10 @@ Below is a checklist of what you need to do to add a new scalar function to Data * a new line in `create_physical_expr` mapping the built-in to the implementation * tests to the function. * In [tests/sql.rs](tests/sql.rs), add a new test where the function is called through SQL against well known data and returns the expected result. +* In [src/logical_plan/expr](src/logical_plan/expr.rs), add: + * a new entry of the `unary_scalar_expr!` macro for the new function. +* In [src/logical_plan/mod](src/logical_plan/mod.rs), add: + * a new entry in the `pub use expr::{}` set. ## How to add a new aggregate function From 826750159d0edb30ce6b055a2631b688aa134719 Mon Sep 17 00:00:00 2001 From: Mike Seddon Date: Sun, 20 Dec 2020 11:40:24 +1100 Subject: [PATCH 4/5] add support for NULL --- rust/datafusion/src/sql/planner.rs | 2 ++ rust/datafusion/tests/sql.rs | 9 ++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 4d469c83059..0d5c819e1db 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -550,6 +550,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }, SQLExpr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())), + SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Utf8(None))), + SQLExpr::Identifier(ref id) => { if &id.value[0..1] == "@" { let var_names = vec![id.value.clone()]; diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 9d96a66952c..961ccfb650c 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -1832,14 +1832,21 @@ async fn string_expressions() -> Result<()> { let mut ctx = ExecutionContext::new(); let sql = "SELECT char_length('josé') AS char_length + ,char_length(NULL) AS char_length_null ,character_length('josé') AS character_length + ,character_length(NULL) AS character_length_null ,lower('TOM') AS lower + ,lower(NULL) AS lower_null ,upper('tom') AS upper + ,upper(NULL) AS upper_null ,trim(' tom ') AS trim + ,trim(NULL) AS trim_null "; let actual = execute(&mut ctx, sql).await; - let expected = vec![vec!["4", "4", "tom", "TOM", "tom"]]; + let expected = vec![vec![ + "4", "NULL", "4", "NULL", "tom", "NULL", "TOM", "NULL", "tom", "NULL", + ]]; assert_eq!(expected, actual); Ok(()) } From 2f26b23d324944d04d64548c51f2d51de1cce027 Mon Sep 17 00:00:00 2001 From: Mike Seddon Date: Mon, 21 Dec 2020 12:34:34 +1100 Subject: [PATCH 5/5] major simplification based on feedback --- rust/datafusion/src/logical_plan/expr.rs | 1 - rust/datafusion/src/logical_plan/mod.rs | 8 +- .../datafusion/src/physical_plan/functions.rs | 40 ++++-- .../src/physical_plan/string_expressions.rs | 122 +++--------------- rust/datafusion/tests/sql.rs | 6 +- 5 files changed, 50 insertions(+), 127 deletions(-) diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 7d149f04c77..3bcfd078b3c 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -647,7 +647,6 @@ unary_scalar_expr!(Exp, exp); unary_scalar_expr!(Log, ln); unary_scalar_expr!(Log2, log2); unary_scalar_expr!(Log10, log10); -unary_scalar_expr!(CharacterLength, character_length); unary_scalar_expr!(Lower, lower); unary_scalar_expr!(Trim, trim); unary_scalar_expr!(Upper, upper); diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index b5f6926a7ad..4cd4d99587b 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -34,10 +34,10 @@ pub use builder::LogicalPlanBuilder; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ - abs, acos, and, array, asin, atan, avg, binary_expr, case, ceil, character_length, - col, concat, cos, count, create_udaf, create_udf, exp, exprlist_to_fields, floor, - length, lit, ln, log10, log2, lower, max, min, or, round, signum, sin, sqrt, sum, - tan, trim, trunc, upper, when, Expr, Literal, + abs, acos, and, array, asin, atan, avg, binary_expr, case, ceil, col, concat, cos, + count, create_udaf, create_udf, exp, exprlist_to_fields, floor, length, lit, ln, + log10, log2, lower, max, min, or, round, signum, sin, sqrt, sum, tan, trim, trunc, + upper, when, Expr, Literal, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index d7ae269649e..94e47bd56a0 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -118,8 +118,6 @@ pub enum BuiltinScalarFunction { Length, /// concat Concat, - /// character_length - CharacterLength, /// lower Lower, /// upper @@ -163,9 +161,9 @@ impl FromStr for BuiltinScalarFunction { "abs" => BuiltinScalarFunction::Abs, "signum" => BuiltinScalarFunction::Signum, "length" => BuiltinScalarFunction::Length, + "char_length" => BuiltinScalarFunction::Length, + "character_length" => BuiltinScalarFunction::Length, "concat" => BuiltinScalarFunction::Concat, - "char_length" => BuiltinScalarFunction::CharacterLength, - "character_length" => BuiltinScalarFunction::CharacterLength, "lower" => BuiltinScalarFunction::Lower, "trim" => BuiltinScalarFunction::Trim, "upper" => BuiltinScalarFunction::Upper, @@ -216,7 +214,6 @@ pub fn return_type( } }), BuiltinScalarFunction::Concat => Ok(DataType::Utf8), - BuiltinScalarFunction::CharacterLength => Ok(DataType::UInt32), BuiltinScalarFunction::Lower => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, @@ -293,12 +290,30 @@ pub fn create_physical_expr( BuiltinScalarFunction::Concat => { |args| Ok(Arc::new(string_expressions::concatenate(args)?)) } - BuiltinScalarFunction::CharacterLength => { - |args| Ok(Arc::new(string_expressions::character_length(args)?)) - } - BuiltinScalarFunction::Lower => string_expressions::lower, - BuiltinScalarFunction::Trim => string_expressions::trim, - BuiltinScalarFunction::Upper => string_expressions::upper, + BuiltinScalarFunction::Lower => |args| match args[0].data_type() { + DataType::Utf8 => Ok(Arc::new(string_expressions::lower::(args)?)), + DataType::LargeUtf8 => Ok(Arc::new(string_expressions::lower::(args)?)), + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function lower", + other, + ))), + }, + BuiltinScalarFunction::Trim => |args| match args[0].data_type() { + DataType::Utf8 => Ok(Arc::new(string_expressions::trim::(args)?)), + DataType::LargeUtf8 => Ok(Arc::new(string_expressions::trim::(args)?)), + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function trim", + other, + ))), + }, + BuiltinScalarFunction::Upper => |args| match args[0].data_type() { + DataType::Utf8 => Ok(Arc::new(string_expressions::upper::(args)?)), + DataType::LargeUtf8 => Ok(Arc::new(string_expressions::upper::(args)?)), + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function upper", + other, + ))), + }, BuiltinScalarFunction::ToTimestamp => { |args| Ok(Arc::new(datetime_expressions::to_timestamp(args)?)) } @@ -330,9 +345,6 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8]) } BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]), - BuiltinScalarFunction::CharacterLength => { - Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8]) - } BuiltinScalarFunction::Lower => { Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8]) } diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index e2877b5e79b..328e24b8b92 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -18,15 +18,10 @@ //! String expressions use crate::error::{DataFusionError, Result}; -use arrow::{ - array::{ - Array, ArrayData, ArrayRef, LargeStringArray, LargeStringBuilder, StringArray, - StringBuilder, UInt32Array, - }, - buffer::Buffer, - datatypes::{DataType, ToByteSlice}, +use arrow::array::{ + Array, ArrayRef, GenericStringArray, StringArray, StringBuilder, + StringOffsetSizeTrait, }; -use std::sync::Arc; macro_rules! downcast_vec { ($ARGS:expr, $ARRAY_TYPE:ident) => {{ @@ -75,105 +70,22 @@ pub fn concatenate(args: &[ArrayRef]) -> Result { Ok(builder.finish()) } -/// character_length returns number of characters in the string -/// character_length('josé') = 4 -pub fn character_length(args: &[ArrayRef]) -> Result { - let num_rows = args[0].len(); - let string_args = - &args[0] - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast input to StringArray".to_string(), - ) - })?; - - let result = (0..num_rows) - .map(|i| { - if string_args.is_null(i) { - // NB: Since we use the same null bitset as the input, - // the output for this value will be ignored, but we - // need some value in the array we are building. - Ok(0) - } else { - Ok(string_args.value(i).chars().count() as u32) - } - }) - .collect::>>()?; - - let data = ArrayData::new( - DataType::UInt32, - num_rows, - Some(string_args.null_count()), - string_args.data().null_buffer().cloned(), - 0, - vec![Buffer::from(result.to_byte_slice())], - vec![], - ); - - Ok(UInt32Array::from(Arc::new(data))) -} - -macro_rules! compute_op { - ($ARRAY:expr, $FUNC:ident, $TYPE:ident, $BUILDER:ident) => {{ - let mut builder = $BUILDER::new($ARRAY.len()); - for index in 0..$ARRAY.len() { - if $ARRAY.is_null(index) { - builder.append_null()?; - } else { - builder.append_value(&$ARRAY.value(index).$FUNC())?; - } - } - Ok(Arc::new(builder.finish())) - }}; -} - -macro_rules! downcast_compute_op { - ($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident, $BUILDER:ident) => {{ - let n = $ARRAY.as_any().downcast_ref::<$TYPE>(); - match n { - Some(array) => compute_op!(array, $FUNC, $TYPE, $BUILDER), - _ => Err(DataFusionError::Internal(format!( - "Invalid data type for {}", - $NAME - ))), - } - }}; -} - -macro_rules! unary_primitive_array_op { - ($ARRAY:expr, $NAME:expr, $FUNC:ident) => {{ - match ($ARRAY).data_type() { - DataType::Utf8 => { - downcast_compute_op!($ARRAY, $NAME, $FUNC, StringArray, StringBuilder) - } - DataType::LargeUtf8 => { - downcast_compute_op!( - $ARRAY, - $NAME, - $FUNC, - LargeStringArray, - LargeStringBuilder - ) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function {}", - other, $NAME, - ))), - } - }}; -} - macro_rules! string_unary_function { - ($NAME:expr, $FUNC:ident, $STRINGFUNC:ident) => { - /// string function that accepts Utf8 or LargeUtf8 and returns StringArray or LargeStringArray - pub fn $FUNC(args: &[ArrayRef]) -> Result { - unary_primitive_array_op!(args[0], $NAME, $STRINGFUNC) + ($NAME:ident, $FUNC:ident) => { + /// string function that accepts Utf8 or LargeUtf8 and returns Utf8 or LargeUtf8 + pub fn $NAME( + args: &[ArrayRef], + ) -> Result> { + let array = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + // first map is the iterator, second is for the `Option<_>` + Ok(array.iter().map(|x| x.map(|x| x.$FUNC())).collect()) } }; } -string_unary_function!("lower", lower, to_ascii_lowercase); -string_unary_function!("upper", upper, to_ascii_uppercase); -string_unary_function!("trim", trim, trim); +string_unary_function!(lower, to_ascii_lowercase); +string_unary_function!(upper, to_ascii_uppercase); +string_unary_function!(trim, trim); diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 961ccfb650c..48a8d70dba3 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -1831,9 +1831,9 @@ async fn csv_between_expr_negated() -> Result<()> { async fn string_expressions() -> Result<()> { let mut ctx = ExecutionContext::new(); let sql = "SELECT - char_length('josé') AS char_length + char_length('tom') AS char_length ,char_length(NULL) AS char_length_null - ,character_length('josé') AS character_length + ,character_length('tom') AS character_length ,character_length(NULL) AS character_length_null ,lower('TOM') AS lower ,lower(NULL) AS lower_null @@ -1845,7 +1845,7 @@ async fn string_expressions() -> Result<()> { let actual = execute(&mut ctx, sql).await; let expected = vec![vec![ - "4", "NULL", "4", "NULL", "tom", "NULL", "TOM", "NULL", "tom", "NULL", + "3", "NULL", "3", "NULL", "tom", "NULL", "TOM", "NULL", "tom", "NULL", ]]; assert_eq!(expected, actual); Ok(())