From c5b4ac7c81c64287ef3481f0d80391fd17e76218 Mon Sep 17 00:00:00 2001 From: Mike Seddon Date: Wed, 24 Feb 2021 10:42:57 +1100 Subject: [PATCH] btrim, concat, concat_ws, ltrim, rtrim, substr, trim --- rust/datafusion/README.md | 8 +- rust/datafusion/src/logical_plan/expr.rs | 12 +- rust/datafusion/src/logical_plan/mod.rs | 12 +- .../datafusion/src/physical_plan/functions.rs | 952 ++++++++++++++---- .../src/physical_plan/string_expressions.rs | 423 ++++++-- rust/datafusion/src/prelude.rs | 6 +- rust/datafusion/tests/sql.rs | 497 +++++---- 7 files changed, 1441 insertions(+), 469 deletions(-) diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index b4cb04321e7..5dcab04399e 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -58,11 +58,17 @@ DataFusion includes a simple command-line interactive SQL utility. See the [CLI - [x] Common math functions - String functions - [x] bit_Length + - [x] btrim - [x] char_length - [x] character_length + - [x] concat + - [x] concat_ws - [x] length + - [x] ltrim - [x] octet_length - - [x] Concatenate + - [x] rtrim + - [x] substr + - [x] trim - Miscellaneous/Boolean functions - [x] nullif - Common date/time functions diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 245ca3aaaa8..6dadefea548 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -875,8 +875,11 @@ unary_scalar_expr!(Log10, log10); // string functions unary_scalar_expr!(BitLength, bit_length); +unary_scalar_expr!(Btrim, btrim); unary_scalar_expr!(CharacterLength, character_length); unary_scalar_expr!(CharacterLength, length); +unary_scalar_expr!(Concat, concat); +unary_scalar_expr!(ConcatWithSeparator, concat_ws); unary_scalar_expr!(Lower, lower); unary_scalar_expr!(Ltrim, ltrim); unary_scalar_expr!(MD5, md5); @@ -886,17 +889,10 @@ unary_scalar_expr!(SHA224, sha224); unary_scalar_expr!(SHA256, sha256); unary_scalar_expr!(SHA384, sha384); unary_scalar_expr!(SHA512, sha512); +unary_scalar_expr!(Substr, substr); unary_scalar_expr!(Trim, trim); unary_scalar_expr!(Upper, upper); -/// returns the concatenation of string expressions -pub fn concat(args: Vec) -> Expr { - Expr::ScalarFunction { - fun: functions::BuiltinScalarFunction::Concat, - args, - } -} - /// returns an array of fixed size with each argument on it. pub fn array(args: Vec) -> Expr { Expr::ScalarFunction { diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 0de0a032520..99c35fafd54 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -33,12 +33,12 @@ pub use builder::LogicalPlanBuilder; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ - abs, acos, and, array, asin, atan, avg, binary_expr, bit_length, case, ceil, - character_length, col, combine_filters, concat, cos, count, count_distinct, - create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, length, lit, ln, - log10, log2, lower, ltrim, max, md5, min, octet_length, or, round, rtrim, sha224, - sha256, sha384, sha512, signum, sin, sqrt, sum, tan, trim, trunc, upper, when, Expr, - ExpressionVisitor, Literal, Recursion, + abs, acos, and, array, asin, atan, avg, binary_expr, bit_length, btrim, case, ceil, + character_length, col, combine_filters, concat, concat_ws, cos, count, + count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, + length, lit, ln, log10, log2, lower, ltrim, max, md5, min, octet_length, or, round, + rtrim, sha224, sha256, sha384, sha512, signum, sin, sqrt, substr, sum, tan, trim, + trunc, upper, when, Expr, ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 51941188bb4..1c82d0fea45 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -86,78 +86,87 @@ pub type ReturnTypeFunction = /// Enum of all built-in scalar functions #[derive(Debug, Clone, PartialEq, Eq)] pub enum BuiltinScalarFunction { - /// sqrt - Sqrt, - /// sin - Sin, - /// cos - Cos, - /// tan - Tan, - /// asin - Asin, + // math functions + /// abs + Abs, /// acos Acos, + /// asin + Asin, /// atan Atan, + /// ceil + Ceil, + /// cos + Cos, /// exp Exp, + /// floor + Floor, /// log, also known as ln Log, - /// log2 - Log2, /// log10 Log10, - /// floor - Floor, - /// ceil - Ceil, + /// log2 + Log2, /// round Round, - /// trunc - Trunc, - /// abs - Abs, /// signum Signum, + /// sin + Sin, + /// sqrt + Sqrt, + /// tan + Tan, + /// trunc + Trunc, + + // string functions + /// construct an array from columns + Array, + /// bit_length + BitLength, + /// btrim + Btrim, + /// character_length + CharacterLength, /// concat Concat, + /// concat_ws + ConcatWithSeparator, + /// Date part + DatePart, + /// Date truncate + DateTrunc, /// lower Lower, - /// upper - Upper, - /// trim - Trim, /// trim left Ltrim, - /// trim right - Rtrim, - /// to_timestamp - ToTimestamp, - /// construct an array from columns - Array, - /// SQL NULLIF() - NullIf, - /// Date truncate - DateTrunc, - /// Date part - DatePart, /// MD5 MD5, + /// SQL NULLIF() + NullIf, + /// octet_length + OctetLength, + /// trim right + Rtrim, /// SHA224 SHA224, - /// SHA256, + /// SHA256 SHA256, /// SHA384 SHA384, - /// SHA512, + /// SHA512 SHA512, - /// bit_length - BitLength, - /// character_length - CharacterLength, - /// octet_length - OctetLength, + /// substr + Substr, + /// to_timestamp + ToTimestamp, + /// trim + Trim, + /// upper + Upper, } impl fmt::Display for BuiltinScalarFunction { @@ -171,44 +180,51 @@ impl FromStr for BuiltinScalarFunction { type Err = DataFusionError; fn from_str(name: &str) -> Result { Ok(match name { - "sqrt" => BuiltinScalarFunction::Sqrt, - "sin" => BuiltinScalarFunction::Sin, - "cos" => BuiltinScalarFunction::Cos, - "tan" => BuiltinScalarFunction::Tan, - "asin" => BuiltinScalarFunction::Asin, + // math functions + "abs" => BuiltinScalarFunction::Abs, "acos" => BuiltinScalarFunction::Acos, + "asin" => BuiltinScalarFunction::Asin, "atan" => BuiltinScalarFunction::Atan, + "ceil" => BuiltinScalarFunction::Ceil, + "cos" => BuiltinScalarFunction::Cos, "exp" => BuiltinScalarFunction::Exp, + "floor" => BuiltinScalarFunction::Floor, "log" => BuiltinScalarFunction::Log, - "log2" => BuiltinScalarFunction::Log2, "log10" => BuiltinScalarFunction::Log10, - "floor" => BuiltinScalarFunction::Floor, - "ceil" => BuiltinScalarFunction::Ceil, + "log2" => BuiltinScalarFunction::Log2, "round" => BuiltinScalarFunction::Round, - "truc" => BuiltinScalarFunction::Trunc, - "abs" => BuiltinScalarFunction::Abs, "signum" => BuiltinScalarFunction::Signum, + "sin" => BuiltinScalarFunction::Sin, + "sqrt" => BuiltinScalarFunction::Sqrt, + "tan" => BuiltinScalarFunction::Tan, + "trunc" => BuiltinScalarFunction::Trunc, + + // string functions + "array" => BuiltinScalarFunction::Array, + "bit_length" => BuiltinScalarFunction::BitLength, + "btrim" => BuiltinScalarFunction::Btrim, + "char_length" => BuiltinScalarFunction::CharacterLength, + "character_length" => BuiltinScalarFunction::CharacterLength, "concat" => BuiltinScalarFunction::Concat, + "concat_ws" => BuiltinScalarFunction::ConcatWithSeparator, + "date_part" => BuiltinScalarFunction::DatePart, + "date_trunc" => BuiltinScalarFunction::DateTrunc, + "length" => BuiltinScalarFunction::CharacterLength, "lower" => BuiltinScalarFunction::Lower, - "trim" => BuiltinScalarFunction::Trim, "ltrim" => BuiltinScalarFunction::Ltrim, - "rtrim" => BuiltinScalarFunction::Rtrim, - "upper" => BuiltinScalarFunction::Upper, - "to_timestamp" => BuiltinScalarFunction::ToTimestamp, - "date_trunc" => BuiltinScalarFunction::DateTrunc, - "date_part" => BuiltinScalarFunction::DatePart, - "array" => BuiltinScalarFunction::Array, - "nullif" => BuiltinScalarFunction::NullIf, "md5" => BuiltinScalarFunction::MD5, + "nullif" => BuiltinScalarFunction::NullIf, + "octet_length" => BuiltinScalarFunction::OctetLength, + "rtrim" => BuiltinScalarFunction::Rtrim, "sha224" => BuiltinScalarFunction::SHA224, "sha256" => BuiltinScalarFunction::SHA256, "sha384" => BuiltinScalarFunction::SHA384, "sha512" => BuiltinScalarFunction::SHA512, - "bit_length" => BuiltinScalarFunction::BitLength, - "octet_length" => BuiltinScalarFunction::OctetLength, - "length" => BuiltinScalarFunction::CharacterLength, - "char_length" => BuiltinScalarFunction::CharacterLength, - "character_length" => BuiltinScalarFunction::CharacterLength, + "substr" => BuiltinScalarFunction::Substr, + "to_timestamp" => BuiltinScalarFunction::ToTimestamp, + "trim" => BuiltinScalarFunction::Trim, + "upper" => BuiltinScalarFunction::Upper, + _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -242,80 +258,98 @@ pub fn return_type( // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. match fun { - BuiltinScalarFunction::Concat => Ok(DataType::Utf8), - BuiltinScalarFunction::Lower => Ok(match arg_types[0] { - DataType::LargeUtf8 => DataType::LargeUtf8, - DataType::Utf8 => DataType::Utf8, + BuiltinScalarFunction::Array => Ok(DataType::FixedSizeList( + Box::new(Field::new("item", arg_types[0].clone(), true)), + arg_types.len() as i32, + )), + BuiltinScalarFunction::BitLength => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( - "The upper function can only accept strings.".to_string(), + "The bit_length function can only accept strings.".to_string(), )); } }), - BuiltinScalarFunction::Ltrim => Ok(match arg_types[0] { + BuiltinScalarFunction::Btrim => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( - "The ltrim function can only accept strings.".to_string(), + "The btrim function can only accept strings.".to_string(), )); } }), - BuiltinScalarFunction::Rtrim => Ok(match arg_types[0] { + BuiltinScalarFunction::CharacterLength => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The character_length function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Concat => Ok(DataType::Utf8), + BuiltinScalarFunction::ConcatWithSeparator => Ok(DataType::Utf8), + BuiltinScalarFunction::DatePart => Ok(DataType::Int32), + BuiltinScalarFunction::DateTrunc => { + Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) + } + BuiltinScalarFunction::Lower => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( - "The rtrim function can only accept strings.".to_string(), + "The upper function can only accept strings.".to_string(), )); } }), - BuiltinScalarFunction::Trim => Ok(match arg_types[0] { + BuiltinScalarFunction::Ltrim => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( - "The trim function can only accept strings.".to_string(), + "The ltrim function can only accept strings.".to_string(), )); } }), - BuiltinScalarFunction::Upper => Ok(match arg_types[0] { + BuiltinScalarFunction::MD5 => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( - "The upper function can only accept strings.".to_string(), + "The md5 function can only accept strings.".to_string(), )); } }), - BuiltinScalarFunction::ToTimestamp => { - Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) - } - BuiltinScalarFunction::DateTrunc => { - Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) - } - BuiltinScalarFunction::DatePart => Ok(DataType::Int32), - BuiltinScalarFunction::Array => Ok(DataType::FixedSizeList( - Box::new(Field::new("item", arg_types[0].clone(), true)), - arg_types.len() as i32, - )), BuiltinScalarFunction::NullIf => { // NULLIF has two args and they might get coerced, get a preview of this let coerced_types = data_types(arg_types, &signature(fun)); coerced_types.map(|typs| typs[0].clone()) } - BuiltinScalarFunction::MD5 => Ok(match arg_types[0] { + BuiltinScalarFunction::OctetLength => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The octet_length function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Rtrim => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( - "The md5 function can only accept strings.".to_string(), + "The rtrim function can only accept strings.".to_string(), )); } }), @@ -359,37 +393,57 @@ pub fn return_type( )); } }), - BuiltinScalarFunction::BitLength => Ok(match arg_types[0] { - DataType::LargeUtf8 => DataType::Int64, - DataType::Utf8 => DataType::Int32, + BuiltinScalarFunction::Substr => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( - "The bit_length function can only accept strings.".to_string(), + "The substr function can only accept strings.".to_string(), )); } }), - BuiltinScalarFunction::CharacterLength => Ok(match arg_types[0] { - DataType::LargeUtf8 => DataType::Int64, - DataType::Utf8 => DataType::Int32, + BuiltinScalarFunction::ToTimestamp => { + Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) + } + BuiltinScalarFunction::Trim => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( - "The character_length function can only accept strings.".to_string(), + "The trim function can only accept strings.".to_string(), )); } }), - BuiltinScalarFunction::OctetLength => Ok(match arg_types[0] { - DataType::LargeUtf8 => DataType::Int64, - DataType::Utf8 => DataType::Int32, + BuiltinScalarFunction::Upper => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( - "The octet_length function can only accept strings.".to_string(), + "The upper function can only accept strings.".to_string(), )); } }), - _ => Ok(DataType::Float64), + + BuiltinScalarFunction::Abs + | BuiltinScalarFunction::Acos + | BuiltinScalarFunction::Asin + | BuiltinScalarFunction::Atan + | BuiltinScalarFunction::Ceil + | BuiltinScalarFunction::Cos + | BuiltinScalarFunction::Exp + | BuiltinScalarFunction::Floor + | BuiltinScalarFunction::Log + | BuiltinScalarFunction::Log10 + | BuiltinScalarFunction::Log2 + | BuiltinScalarFunction::Round + | BuiltinScalarFunction::Signum + | BuiltinScalarFunction::Sin + | BuiltinScalarFunction::Sqrt + | BuiltinScalarFunction::Tan + | BuiltinScalarFunction::Trunc => Ok(DataType::Float64), } } @@ -401,37 +455,26 @@ pub fn create_physical_expr( input_schema: &Schema, ) -> Result> { let fun_expr: ScalarFunctionImplementation = Arc::new(match fun { - BuiltinScalarFunction::Sqrt => math_expressions::sqrt, - BuiltinScalarFunction::Sin => math_expressions::sin, - BuiltinScalarFunction::Cos => math_expressions::cos, - BuiltinScalarFunction::Tan => math_expressions::tan, - BuiltinScalarFunction::Asin => math_expressions::asin, + // math functions + BuiltinScalarFunction::Abs => math_expressions::abs, BuiltinScalarFunction::Acos => math_expressions::acos, + BuiltinScalarFunction::Asin => math_expressions::asin, BuiltinScalarFunction::Atan => math_expressions::atan, + BuiltinScalarFunction::Ceil => math_expressions::ceil, + BuiltinScalarFunction::Cos => math_expressions::cos, BuiltinScalarFunction::Exp => math_expressions::exp, + BuiltinScalarFunction::Floor => math_expressions::floor, BuiltinScalarFunction::Log => math_expressions::ln, - BuiltinScalarFunction::Log2 => math_expressions::log2, BuiltinScalarFunction::Log10 => math_expressions::log10, - BuiltinScalarFunction::Floor => math_expressions::floor, - BuiltinScalarFunction::Ceil => math_expressions::ceil, + BuiltinScalarFunction::Log2 => math_expressions::log2, BuiltinScalarFunction::Round => math_expressions::round, - BuiltinScalarFunction::Trunc => math_expressions::trunc, - BuiltinScalarFunction::Abs => math_expressions::abs, BuiltinScalarFunction::Signum => math_expressions::signum, - BuiltinScalarFunction::NullIf => nullif_func, - BuiltinScalarFunction::MD5 => crypto_expressions::md5, - BuiltinScalarFunction::SHA224 => crypto_expressions::sha224, - BuiltinScalarFunction::SHA256 => crypto_expressions::sha256, - BuiltinScalarFunction::SHA384 => crypto_expressions::sha384, - BuiltinScalarFunction::SHA512 => crypto_expressions::sha512, - BuiltinScalarFunction::Concat => string_expressions::concatenate, - BuiltinScalarFunction::Lower => string_expressions::lower, - BuiltinScalarFunction::Trim => string_expressions::trim, - BuiltinScalarFunction::Ltrim => string_expressions::ltrim, - BuiltinScalarFunction::Rtrim => string_expressions::rtrim, - BuiltinScalarFunction::Upper => string_expressions::upper, - BuiltinScalarFunction::ToTimestamp => datetime_expressions::to_timestamp, - BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc, + BuiltinScalarFunction::Sin => math_expressions::sin, + BuiltinScalarFunction::Sqrt => math_expressions::sqrt, + BuiltinScalarFunction::Tan => math_expressions::tan, + BuiltinScalarFunction::Trunc => math_expressions::trunc, + + // string functions BuiltinScalarFunction::Array => array_expressions::array, BuiltinScalarFunction::BitLength => |args| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), @@ -445,6 +488,18 @@ pub fn create_physical_expr( _ => unreachable!(), }, }, + BuiltinScalarFunction::Btrim => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::btrim::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::btrim::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function btrim", + other, + ))), + }, BuiltinScalarFunction::CharacterLength => |args| match args[0].data_type() { DataType::Utf8 => make_scalar_function( string_expressions::character_length::, @@ -457,6 +512,27 @@ pub fn create_physical_expr( other, ))), }, + BuiltinScalarFunction::Concat => string_expressions::concat, + BuiltinScalarFunction::ConcatWithSeparator => { + |args| make_scalar_function(string_expressions::concat_ws)(args) + } + BuiltinScalarFunction::DatePart => datetime_expressions::date_part, + BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc, + BuiltinScalarFunction::Lower => string_expressions::lower, + BuiltinScalarFunction::Ltrim => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::ltrim::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::ltrim::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function ltrim", + other, + ))), + }, + BuiltinScalarFunction::MD5 => crypto_expressions::md5, + BuiltinScalarFunction::NullIf => nullif_func, BuiltinScalarFunction::OctetLength => |args| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { @@ -469,7 +545,48 @@ pub fn create_physical_expr( _ => unreachable!(), }, }, - BuiltinScalarFunction::DatePart => datetime_expressions::date_part, + BuiltinScalarFunction::Rtrim => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::rtrim::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::rtrim::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function rtrim", + other, + ))), + }, + BuiltinScalarFunction::SHA224 => crypto_expressions::sha224, + BuiltinScalarFunction::SHA256 => crypto_expressions::sha256, + BuiltinScalarFunction::SHA384 => crypto_expressions::sha384, + BuiltinScalarFunction::SHA512 => crypto_expressions::sha512, + BuiltinScalarFunction::Substr => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::substr::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::substr::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function substr", + other, + ))), + }, + BuiltinScalarFunction::ToTimestamp => datetime_expressions::to_timestamp, + BuiltinScalarFunction::Trim => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::btrim::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::btrim::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function trim", + other, + ))), + }, + BuiltinScalarFunction::Upper => string_expressions::upper, }); // coerce let args = coerce(args, input_schema, &signature(fun))?; @@ -493,22 +610,31 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { // for now, the list is small, as we do not have many built-in functions. match fun { - BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]), - BuiltinScalarFunction::Upper - | BuiltinScalarFunction::Lower - | BuiltinScalarFunction::BitLength + BuiltinScalarFunction::Array => { + Signature::Variadic(array_expressions::SUPPORTED_ARRAY_TYPES.to_vec()) + } + BuiltinScalarFunction::Concat | BuiltinScalarFunction::ConcatWithSeparator => { + Signature::Variadic(vec![DataType::Utf8]) + } + BuiltinScalarFunction::BitLength | BuiltinScalarFunction::CharacterLength - | BuiltinScalarFunction::OctetLength - | BuiltinScalarFunction::Trim - | BuiltinScalarFunction::Ltrim - | BuiltinScalarFunction::Rtrim + | BuiltinScalarFunction::Lower | BuiltinScalarFunction::MD5 + | BuiltinScalarFunction::OctetLength | BuiltinScalarFunction::SHA224 | BuiltinScalarFunction::SHA256 | BuiltinScalarFunction::SHA384 - | BuiltinScalarFunction::SHA512 => { + | BuiltinScalarFunction::SHA512 + | BuiltinScalarFunction::Trim + | BuiltinScalarFunction::Upper => { Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8]) } + BuiltinScalarFunction::Btrim + | BuiltinScalarFunction::Ltrim + | BuiltinScalarFunction::Rtrim => Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8]), + Signature::Exact(vec![DataType::Utf8, DataType::Utf8]), + ]), BuiltinScalarFunction::ToTimestamp => Signature::Uniform(1, vec![DataType::Utf8]), BuiltinScalarFunction::DateTrunc => Signature::Exact(vec![ DataType::Utf8, @@ -534,9 +660,12 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { DataType::Timestamp(TimeUnit::Nanosecond, None), ]), ]), - BuiltinScalarFunction::Array => { - Signature::Variadic(array_expressions::SUPPORTED_ARRAY_TYPES.to_vec()) - } + BuiltinScalarFunction::Substr => Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8, DataType::Int64]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Int64]), + Signature::Exact(vec![DataType::Utf8, DataType::Int64, DataType::Int64]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Int64, DataType::Int64]), + ]), BuiltinScalarFunction::NullIf => { Signature::Uniform(2, SUPPORTED_NULLIF_TYPES.to_vec()) } @@ -753,6 +882,106 @@ mod tests { #[test] fn test_functions() -> Result<()> { + test_function!( + BitLength, + &[lit(ScalarValue::Utf8(Some("chars".to_string())))], + Ok(Some(40)), + i32, + Int32, + Int32Array + ); + test_function!( + BitLength, + &[lit(ScalarValue::Utf8(Some("josé".to_string())))], + Ok(Some(40)), + i32, + Int32, + Int32Array + ); + test_function!( + BitLength, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + Btrim, + &[lit(ScalarValue::Utf8(Some(" trim ".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[lit(ScalarValue::Utf8(Some(" trim".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[lit(ScalarValue::Utf8(Some("trim ".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[lit(ScalarValue::Utf8(Some("\n trim \n".to_string())))], + Ok(Some("\n trim \n")), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[ + lit(ScalarValue::Utf8(Some("xyxtrimyyx".to_string()))), + lit(ScalarValue::Utf8(Some("xyz".to_string()))), + ], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[ + lit(ScalarValue::Utf8(Some("\nxyxtrimyyx\n".to_string()))), + lit(ScalarValue::Utf8(Some("xyz\n".to_string()))), + ], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("xyz".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[ + lit(ScalarValue::Utf8(Some("xyxtrimyyx".to_string()))), + lit(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); test_function!( CharacterLength, &[lit(ScalarValue::Utf8(Some("chars".to_string())))], @@ -785,6 +1014,88 @@ mod tests { Int32, Int32Array ); + test_function!( + Concat, + &[ + lit(ScalarValue::Utf8(Some("aa".to_string()))), + lit(ScalarValue::Utf8(Some("bb".to_string()))), + lit(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(Some("aabbcc")), + &str, + Utf8, + StringArray + ); + test_function!( + Concat, + &[ + lit(ScalarValue::Utf8(Some("aa".to_string()))), + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(Some("aacc")), + &str, + Utf8, + StringArray + ); + test_function!( + Concat, + &[lit(ScalarValue::Utf8(None))], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + ConcatWithSeparator, + &[ + lit(ScalarValue::Utf8(Some("|".to_string()))), + lit(ScalarValue::Utf8(Some("aa".to_string()))), + lit(ScalarValue::Utf8(Some("bb".to_string()))), + lit(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(Some("aa|bb|cc")), + &str, + Utf8, + StringArray + ); + test_function!( + ConcatWithSeparator, + &[ + lit(ScalarValue::Utf8(Some("|".to_string()))), + lit(ScalarValue::Utf8(None)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + ConcatWithSeparator, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("aa".to_string()))), + lit(ScalarValue::Utf8(Some("bb".to_string()))), + lit(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + ConcatWithSeparator, + &[ + lit(ScalarValue::Utf8(Some("|".to_string()))), + lit(ScalarValue::Utf8(Some("aa".to_string()))), + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(Some("aa|cc")), + &str, + Utf8, + StringArray + ); test_function!( Exp, &[lit(ScalarValue::Int32(Some(1)))], @@ -825,42 +1136,331 @@ mod tests { Float64, Float64Array ); + test_function!( + Ltrim, + &[lit(ScalarValue::Utf8(Some(" trim".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Ltrim, + &[lit(ScalarValue::Utf8(Some(" trim ".to_string())))], + Ok(Some("trim ")), + &str, + Utf8, + StringArray + ); + test_function!( + Ltrim, + &[lit(ScalarValue::Utf8(Some("trim ".to_string())))], + Ok(Some("trim ")), + &str, + Utf8, + StringArray + ); + test_function!( + Ltrim, + &[lit(ScalarValue::Utf8(Some("trim".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Ltrim, + &[lit(ScalarValue::Utf8(Some("\n trim ".to_string())))], + Ok(Some("\n trim ")), + &str, + Utf8, + StringArray + ); + test_function!( + Ltrim, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + OctetLength, + &[lit(ScalarValue::Utf8(Some("chars".to_string())))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLength, + &[lit(ScalarValue::Utf8(Some("josé".to_string())))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLength, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLength, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + i32, + Int32, + Int32Array + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(0))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some("ésoj")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(1))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(2))), + ], + Ok(Some("lphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(3))), + ], + Ok(Some("phabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(-3))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(30))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(3))), + lit(ScalarValue::Int64(Some(2))), + ], + Ok(Some("ph")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(3))), + lit(ScalarValue::Int64(Some(20))), + ], + Ok(Some("phabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(None)), + lit(ScalarValue::Int64(Some(20))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(3))), + lit(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(1))), + lit(ScalarValue::Int64(Some(-1))), + ], + Err(DataFusionError::Execution( + "negative substring length not allowed".to_string(), + )), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Int64(Some(2))), + ], + Ok(Some("és")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some("trim ".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some(" trim ".to_string())))], + Ok(Some(" trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some(" trim \n".to_string())))], + Ok(Some(" trim \n")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some(" trim".to_string())))], + Ok(Some(" trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some("trim".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Trim, + &[lit(ScalarValue::Utf8(Some(" trim ".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Trim, + &[lit(ScalarValue::Utf8(Some("trim ".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Trim, + &[lit(ScalarValue::Utf8(Some(" trim".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Trim, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); Ok(()) } - fn test_concat(value: ScalarValue, expected: &str) -> Result<()> { - // any type works here: we evaluate against a literal of `value` - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; - - // concat(value, value) - let expr = create_physical_expr( - &BuiltinScalarFunction::Concat, - &[lit(value.clone()), lit(value)], - &schema, - )?; - - // type is correct - assert_eq!(expr.data_type(&schema)?, DataType::Utf8); - - // evaluate works - let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - - // downcast works - let result = result.as_any().downcast_ref::().unwrap(); - - // value is correct - assert_eq!(result.value(0).to_string(), expected); - - Ok(()) - } - - #[test] - fn test_concat_utf8() -> Result<()> { - test_concat(ScalarValue::Utf8(Some("aa".to_string())), "aaaa") - } - #[test] fn test_concat_error() -> Result<()> { let result = return_type(&BuiltinScalarFunction::Concat, &[]); diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index 81d2c67eec6..7ab0f9f215b 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -15,6 +15,10 @@ // specific language governing permissions and limitations // under the License. +// Some of these functions reference the Postgres documentation +// or implementation to ensure compatibility and are subject to +// the Postgres license. + //! String expressions use std::sync::Arc; @@ -25,7 +29,7 @@ use crate::{ }; use arrow::{ array::{ - Array, ArrayRef, GenericStringArray, PrimitiveArray, StringArray, + Array, ArrayRef, GenericStringArray, Int64Array, PrimitiveArray, StringArray, StringOffsetSizeTrait, }, datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, @@ -119,6 +123,71 @@ where } } +macro_rules! downcast_vec { + ($ARGS:expr, $ARRAY_TYPE:ident) => {{ + $ARGS + .iter() + .map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() { + Some(array) => Ok(array), + _ => Err(DataFusionError::Internal("failed to downcast".to_string())), + }) + }}; +} + +/// Removes the longest string containing only characters in characters (a space by default) from the start and end of string. +/// btrim('xyxtrimyyx', 'xyz') = 'trim' +pub fn btrim(args: &[ArrayRef]) -> Result { + match args.len() { + 1 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .map(|x| x.map(|x: &str| x.trim_start_matches(' ').trim_end_matches(' '))) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 2 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let characters_array: &GenericStringArray = args[1] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if characters_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let chars: Vec = + characters_array.value(i).chars().collect(); + x.trim_start_matches(&chars[..]) + .trim_end_matches(&chars[..]) + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "btrim was called with {} arguments. It requires at most 2.", + other + ))), + } +} + /// Returns number of characters in the string. /// character_length('josé') = 4 pub fn character_length(args: &[ArrayRef]) -> Result @@ -140,16 +209,15 @@ where Ok(Arc::new(result) as ArrayRef) } -/// concatenate string columns together. -pub fn concatenate(args: &[ColumnarValue]) -> Result { - // downcast all arguments to strings - //let args = downcast_vec!(args, StringArray).collect::>>()?; +/// Concatenates the text representations of all the arguments. NULL arguments are ignored. +/// concat('abcde', 2, NULL, 22) = 'abcde222' +pub fn concat(args: &[ColumnarValue]) -> Result { // do not accept 0 arguments. if args.is_empty() { - return Err(DataFusionError::Internal( - "Concatenate was called with 0 arguments. It requires at least one." - .to_string(), - )); + return Err(DataFusionError::Internal(format!( + "concat was called with {} arguments. It requires at least 1.", + args.len() + ))); } // first, decide whether to return a scalar or a vector. @@ -158,42 +226,30 @@ pub fn concatenate(args: &[ColumnarValue]) -> Result { _ => None, }); if let Some(size) = return_array.next() { - let iter = (0..size).map(|index| { - let mut owned_string: String = "".to_owned(); - - // if any is null, the result is null - let mut is_null = false; - for arg in args { - match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { - if let Some(value) = maybe_value { - owned_string.push_str(value); - } else { - is_null = true; - break; // short-circuit as we already know the result + let result = (0..size) + .map(|index| { + let mut owned_string: String = "".to_owned(); + for arg in args { + match arg { + ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { + if let Some(value) = maybe_value { + owned_string.push_str(value); + } } - } - ColumnarValue::Array(v) => { - if v.is_null(index) { - is_null = true; - break; // short-circuit as we already know the result - } else { - let v = v.as_any().downcast_ref::().unwrap(); - owned_string.push_str(&v.value(index)); + ColumnarValue::Array(v) => { + if v.is_valid(index) { + let v = v.as_any().downcast_ref::().unwrap(); + owned_string.push_str(&v.value(index)); + } } + _ => unreachable!(), } - _ => unreachable!(), } - } - if is_null { - None - } else { Some(owned_string) - } - }); - let array = iter.collect::(); + }) + .collect::(); - Ok(ColumnarValue::Array(Arc::new(array))) + Ok(ColumnarValue::Array(Arc::new(result))) } else { // short avenue with only scalars let initial = Some("".to_string()); @@ -203,9 +259,7 @@ pub fn concatenate(args: &[ColumnarValue]) -> Result { ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) => { inner.push_str(v); } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { - acc = None; - } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} _ => unreachable!(""), }; }; @@ -215,27 +269,284 @@ pub fn concatenate(args: &[ColumnarValue]) -> Result { } } -/// lower +/// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored. +/// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22' +pub fn concat_ws(args: &[ArrayRef]) -> Result { + // downcast all arguments to strings + let args = downcast_vec!(args, StringArray).collect::>>()?; + + // do not accept 0 or 1 arguments. + if args.len() < 2 { + return Err(DataFusionError::Internal(format!( + "concat_ws was called with {} arguments. It requires at least 2.", + args.len() + ))); + } + + // first map is the iterator, second is for the `Option<_>` + let result = args[0] + .iter() + .enumerate() + .map(|(index, x)| { + x.map(|sep: &str| { + let mut owned_string: String = "".to_owned(); + for arg_index in 1..args.len() { + let arg = &args[arg_index]; + if !arg.is_null(index) { + owned_string.push_str(&arg.value(index)); + // if not last push separator + if arg_index != args.len() - 1 { + owned_string.push_str(&sep); + } + } + } + owned_string + }) + }) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Removes the longest string containing only characters in characters (a space by default) from the start of string. +/// ltrim('zzzytest', 'xyz') = 'test' +pub fn ltrim(args: &[ArrayRef]) -> Result { + match args.len() { + 1 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .map(|x| x.map(|x: &str| x.trim_start_matches(' '))) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 2 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let characters_array: &GenericStringArray = args[1] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if characters_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let chars: Vec = + characters_array.value(i).chars().collect(); + x.trim_start_matches(&chars[..]) + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "ltrim was called with {} arguments. It requires at most 2.", + other + ))), + } +} + +/// Converts the string to all lower case. +/// lower('TOM') = 'tom' pub fn lower(args: &[ColumnarValue]) -> Result { handle(args, |x| x.to_ascii_lowercase(), "lower") } -/// upper -pub fn upper(args: &[ColumnarValue]) -> Result { - handle(args, |x| x.to_ascii_uppercase(), "upper") -} +/// Removes the longest string containing only characters in characters (a space by default) from the end of string. +/// rtrim('testxxzx', 'xyz') = 'test' +pub fn rtrim(args: &[ArrayRef]) -> Result { + match args.len() { + 1 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .map(|x| x.map(|x: &str| x.trim_end_matches(' '))) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 2 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let characters_array: &GenericStringArray = args[1] + .as_any() + .downcast_ref::>() + .unwrap(); -/// trim -pub fn trim(args: &[ColumnarValue]) -> Result { - handle(args, |x: &str| x.trim(), "trim") + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if characters_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let chars: Vec = + characters_array.value(i).chars().collect(); + x.trim_end_matches(&chars[..]) + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "rtrim was called with {} arguments. It requires at most two.", + other + ))), + } } -/// ltrim -pub fn ltrim(args: &[ColumnarValue]) -> Result { - handle(args, |x| x.trim_start(), "ltrim") +/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) +/// substr('alphabet', 3) = 'phabet' +/// substr('alphabet', 3, 2) = 'ph' +pub fn substr(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast string to StringArray".to_string(), + ) + })?; + + let start_array: &Int64Array = args[1] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast start to Int64Array".to_string(), + ) + })?; + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if start_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let start: i64 = start_array.value(i); + + if start <= 0 { + x.to_string() + } else { + let graphemes = x.graphemes(true).collect::>(); + let start_pos = start as usize - 1; + if graphemes.len() < start_pos { + "".to_string() + } else { + graphemes[start_pos..].concat() + } + } + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 3 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast string to StringArray".to_string(), + ) + })?; + + let start_array: &Int64Array = args[1] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast start to Int64Array".to_string(), + ) + })?; + + let count_array: &Int64Array = args[2] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast count to Int64Array".to_string(), + ) + })?; + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if start_array.is_null(i) || count_array.is_null(i) { + Ok(None) + } else { + x.map(|x: &str| { + let start: i64 = start_array.value(i); + let count = count_array.value(i); + + if count < 0 { + Err(DataFusionError::Execution( + "negative substring length not allowed".to_string(), + )) + } else if start <= 0 { + Ok(x.to_string()) + } else { + let graphemes = x.graphemes(true).collect::>(); + let start_pos = start as usize - 1; + let count_usize = count as usize; + if graphemes.len() < start_pos { + Ok("".to_string()) + } else if graphemes.len() < start_pos + count_usize { + Ok(graphemes[start_pos..].concat()) + } else { + Ok(graphemes[start_pos..start_pos + count_usize] + .concat()) + } + } + }) + .transpose() + } + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "substr was called with {} arguments. It requires 2 or 3.", + other + ))), + } } -/// rtrim -pub fn rtrim(args: &[ColumnarValue]) -> Result { - handle(args, |x| x.trim_end(), "rtrim") +/// Converts the string to all upper case. +/// upper('tom') = 'TOM' +pub fn upper(args: &[ColumnarValue]) -> Result { + handle(args, |x| x.to_ascii_uppercase(), "upper") } diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs index 26e03c7453e..d60f0c32a4d 100644 --- a/rust/datafusion/src/prelude.rs +++ b/rust/datafusion/src/prelude.rs @@ -28,8 +28,8 @@ pub use crate::dataframe::DataFrame; pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::logical_plan::{ - array, avg, bit_length, character_length, col, concat, count, create_udf, in_list, - length, lit, lower, ltrim, max, md5, min, octet_length, rtrim, sha224, sha256, - sha384, sha512, sum, trim, upper, JoinType, Partitioning, + array, avg, bit_length, btrim, character_length, col, concat, concat_ws, count, + create_udf, in_list, length, lit, lower, ltrim, max, md5, min, octet_length, rtrim, + sha224, sha256, sha384, sha512, substr, sum, trim, upper, JoinType, Partitioning, }; pub use crate::physical_plan::csv::CsvReadOptions; diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 7a0666635a2..587fe299bd8 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -1639,7 +1639,7 @@ async fn query_concat() -> Result<()> { let expected = vec![ vec!["-hi-0"], vec!["a-hi-1"], - vec!["NULL"], + vec!["aa-hi-"], vec!["aaa-hi-3"], ]; assert_eq!(expected, actual); @@ -1886,7 +1886,7 @@ async fn query_on_string_dictionary() -> Result<()> { // Expression evaluation let sql = "SELECT concat(d1, '-foo') FROM test"; let actual = execute(&mut ctx, sql).await; - let expected = vec![vec!["one-foo"], vec!["NULL"], vec!["three-foo"]]; + let expected = vec![vec!["one-foo"], vec!["-foo"], vec!["three-foo"]]; assert_eq!(expected, actual); // aggregation @@ -2023,170 +2023,290 @@ async fn csv_group_by_date() -> Result<()> { Ok(()) } -#[tokio::test] -async fn string_expressions() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT - char_length('tom') AS char_length - ,char_length(NULL) AS char_length_null - ,character_length('tom') AS character_length - ,character_length(NULL) AS character_length_null - ,lower('TOM') AS lower - ,lower(NULL) AS lower_null - ,upper('tom') AS upper - ,upper(NULL) AS upper_null - ,trim(' tom ') AS trim - ,trim(NULL) AS trim_null - ,ltrim(' tom ') AS trim_left - ,rtrim(' tom ') AS trim_right - "; - let actual = execute(&mut ctx, sql).await; - - let expected = vec![vec![ - "3", "NULL", "3", "NULL", "tom", "NULL", "TOM", "NULL", "tom", "NULL", "tom ", - " tom", - ]]; - assert_eq!(expected, actual); - Ok(()) -} - -#[tokio::test] -async fn boolean_expressions() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT - true AS val_1, - false AS val_2 - "; - let actual = execute(&mut ctx, sql).await; - - let expected = vec![vec!["true", "false"]]; - assert_eq!(expected, actual); - Ok(()) -} - -#[tokio::test] -async fn interval_expressions() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT - (interval '1') as interval_1, - (interval '1 second') as interval_2, - (interval '500 milliseconds') as interval_3, - (interval '5 second') as interval_4, - (interval '1 minute') as interval_5, - (interval '0.5 minute') as interval_6, - (interval '.5 minute') as interval_7, - (interval '5 minute') as interval_8, - (interval '5 minute 1 second') as interval_9, - (interval '1 hour') as interval_10, - (interval '5 hour') as interval_11, - (interval '1 day') as interval_12, - (interval '1 day 1') as interval_13, - (interval '0.5') as interval_14, - (interval '0.5 day 1') as interval_15, - (interval '0.49 day') as interval_16, - (interval '0.499 day') as interval_17, - (interval '0.4999 day') as interval_18, - (interval '0.49999 day') as interval_19, - (interval '0.49999999999 day') as interval_20, - (interval '5 day') as interval_21, - (interval '5 day 4 hours 3 minutes 2 seconds 100 milliseconds') as interval_22, - (interval '0.5 month') as interval_23, - (interval '1 month') as interval_24, - (interval '5 month') as interval_25, - (interval '13 month') as interval_26, - (interval '0.5 year') as interval_27, - (interval '1 year') as interval_28, - (interval '2 year') as interval_29 - "; - let actual = execute(&mut ctx, sql).await; - - let expected = vec![vec![ - "0 years 0 mons 0 days 0 hours 0 mins 1.00 secs", - "0 years 0 mons 0 days 0 hours 0 mins 1.00 secs", - "0 years 0 mons 0 days 0 hours 0 mins 0.500 secs", - "0 years 0 mons 0 days 0 hours 0 mins 5.00 secs", - "0 years 0 mons 0 days 0 hours 1 mins 0.00 secs", - "0 years 0 mons 0 days 0 hours 0 mins 30.00 secs", - "0 years 0 mons 0 days 0 hours 0 mins 30.00 secs", - "0 years 0 mons 0 days 0 hours 5 mins 0.00 secs", - "0 years 0 mons 0 days 0 hours 5 mins 1.00 secs", - "0 years 0 mons 0 days 1 hours 0 mins 0.00 secs", - "0 years 0 mons 0 days 5 hours 0 mins 0.00 secs", - "0 years 0 mons 1 days 0 hours 0 mins 0.00 secs", - "0 years 0 mons 1 days 0 hours 0 mins 1.00 secs", - "0 years 0 mons 0 days 0 hours 0 mins 0.500 secs", - "0 years 0 mons 0 days 12 hours 0 mins 1.00 secs", - "0 years 0 mons 0 days 11 hours 45 mins 36.00 secs", - "0 years 0 mons 0 days 11 hours 58 mins 33.596 secs", - "0 years 0 mons 0 days 11 hours 59 mins 51.364 secs", - "0 years 0 mons 0 days 11 hours 59 mins 59.136 secs", - "0 years 0 mons 0 days 12 hours 0 mins 0.00 secs", - "0 years 0 mons 5 days 0 hours 0 mins 0.00 secs", - "0 years 0 mons 5 days 4 hours 3 mins 2.100 secs", - "0 years 0 mons 15 days 0 hours 0 mins 0.00 secs", - "0 years 1 mons 0 days 0 hours 0 mins 0.00 secs", - "0 years 5 mons 0 days 0 hours 0 mins 0.00 secs", - "1 years 1 mons 0 days 0 hours 0 mins 0.00 secs", - "0 years 6 mons 0 days 0 hours 0 mins 0.00 secs", - "1 years 0 mons 0 days 0 hours 0 mins 0.00 secs", - "2 years 0 mons 0 days 0 hours 0 mins 0.00 secs", - ]]; - assert_eq!(expected, actual); +macro_rules! test_expression { + ($SQL:expr, $EXPECTED:expr) => { + let mut ctx = ExecutionContext::new(); + let sql = format!("SELECT {}", $SQL); + let actual = execute(&mut ctx, sql.as_str()).await; + assert_eq!($EXPECTED, actual[0][0]); + }; +} + +#[tokio::test] +async fn test_string_expressions() -> Result<()> { + test_expression!("bit_length('')", "0"); + test_expression!("bit_length('chars')", "40"); + test_expression!("bit_length('josé')", "40"); + test_expression!("bit_length(NULL)", "NULL"); + test_expression!("btrim(' xyxtrimyyx ', NULL)", "NULL"); + test_expression!("btrim(' xyxtrimyyx ')", "xyxtrimyyx"); + test_expression!("btrim('\n xyxtrimyyx \n')", "\n xyxtrimyyx \n"); + test_expression!("btrim('xyxtrimyyx', 'xyz')", "trim"); + test_expression!("btrim('\nxyxtrimyyx\n', 'xyz\n')", "trim"); + test_expression!("btrim(NULL, 'xyz')", "NULL"); + test_expression!("char_length('')", "0"); + test_expression!("char_length('chars')", "5"); + test_expression!("char_length(NULL)", "NULL"); + test_expression!("character_length('')", "0"); + test_expression!("character_length('chars')", "5"); + test_expression!("character_length('josé')", "4"); + test_expression!("character_length(NULL)", "NULL"); + test_expression!("concat('a','b','c')", "abc"); + test_expression!("concat('abcde', 2, NULL, 22)", "abcde222"); + test_expression!("concat(NULL)", ""); + test_expression!("concat_ws(',', 'abcde', 2, NULL, 22)", "abcde,2,22"); + test_expression!("concat_ws('|','a','b','c')", "a|b|c"); + test_expression!("concat_ws('|',NULL)", ""); + test_expression!("concat_ws(NULL,'a',NULL,'b','c')", "NULL"); + test_expression!("ltrim(' zzzytest ', NULL)", "NULL"); + test_expression!("ltrim(' zzzytest ')", "zzzytest "); + test_expression!("ltrim('zzzytest', 'xyz')", "test"); + test_expression!("ltrim(NULL, 'xyz')", "NULL"); + test_expression!("lower('')", ""); + test_expression!("lower('TOM')", "tom"); + test_expression!("lower(NULL)", "NULL"); + test_expression!("octet_length('')", "0"); + test_expression!("octet_length('chars')", "5"); + test_expression!("octet_length('josé')", "5"); + test_expression!("octet_length(NULL)", "NULL"); + test_expression!("rtrim(' testxxzx ')", " testxxzx"); + test_expression!("rtrim(' zzzytest ', NULL)", "NULL"); + test_expression!("rtrim('testxxzx', 'xyz')", "test"); + test_expression!("rtrim(NULL, 'xyz')", "NULL"); + test_expression!("substr('alphabet', -3)", "alphabet"); + test_expression!("substr('alphabet', 0)", "alphabet"); + test_expression!("substr('alphabet', 1)", "alphabet"); + test_expression!("substr('alphabet', 2)", "lphabet"); + test_expression!("substr('alphabet', 3)", "phabet"); + test_expression!("substr('alphabet', 30)", ""); + test_expression!("substr('alphabet', CAST(NULL AS int))", "NULL"); + test_expression!("substr('alphabet', 3, 2)", "ph"); + test_expression!("substr('alphabet', 3, 20)", "phabet"); + test_expression!("substr('alphabet', CAST(NULL AS int), 20)", "NULL"); + test_expression!("substr('alphabet', 3, CAST(NULL AS int))", "NULL"); + test_expression!("trim(' tom ')", "tom"); + test_expression!("trim(' tom')", "tom"); + test_expression!("trim('')", ""); + test_expression!("trim('tom ')", "tom"); + test_expression!("upper('')", ""); + test_expression!("upper('tom')", "TOM"); + test_expression!("upper(NULL)", "NULL"); + Ok(()) +} + +#[tokio::test] +async fn test_boolean_expressions() -> Result<()> { + test_expression!("true", "true"); + test_expression!("false", "false"); + Ok(()) +} + +#[tokio::test] +async fn test_interval_expressions() -> Result<()> { + test_expression!( + "interval '1'", + "0 years 0 mons 0 days 0 hours 0 mins 1.00 secs" + ); + test_expression!( + "interval '1 second'", + "0 years 0 mons 0 days 0 hours 0 mins 1.00 secs" + ); + test_expression!( + "interval '500 milliseconds'", + "0 years 0 mons 0 days 0 hours 0 mins 0.500 secs" + ); + test_expression!( + "interval '5 second'", + "0 years 0 mons 0 days 0 hours 0 mins 5.00 secs" + ); + test_expression!( + "interval '0.5 minute'", + "0 years 0 mons 0 days 0 hours 0 mins 30.00 secs" + ); + test_expression!( + "interval '.5 minute'", + "0 years 0 mons 0 days 0 hours 0 mins 30.00 secs" + ); + test_expression!( + "interval '5 minute'", + "0 years 0 mons 0 days 0 hours 5 mins 0.00 secs" + ); + test_expression!( + "interval '5 minute 1 second'", + "0 years 0 mons 0 days 0 hours 5 mins 1.00 secs" + ); + test_expression!( + "interval '1 hour'", + "0 years 0 mons 0 days 1 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '5 hour'", + "0 years 0 mons 0 days 5 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '1 day'", + "0 years 0 mons 1 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '1 day 1'", + "0 years 0 mons 1 days 0 hours 0 mins 1.00 secs" + ); + test_expression!( + "interval '0.5'", + "0 years 0 mons 0 days 0 hours 0 mins 0.500 secs" + ); + test_expression!( + "interval '0.5 day 1'", + "0 years 0 mons 0 days 12 hours 0 mins 1.00 secs" + ); + test_expression!( + "interval '0.49 day'", + "0 years 0 mons 0 days 11 hours 45 mins 36.00 secs" + ); + test_expression!( + "interval '0.499 day'", + "0 years 0 mons 0 days 11 hours 58 mins 33.596 secs" + ); + test_expression!( + "interval '0.4999 day'", + "0 years 0 mons 0 days 11 hours 59 mins 51.364 secs" + ); + test_expression!( + "interval '0.49999 day'", + "0 years 0 mons 0 days 11 hours 59 mins 59.136 secs" + ); + test_expression!( + "interval '0.49999999999 day'", + "0 years 0 mons 0 days 12 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '5 day'", + "0 years 0 mons 5 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '5 day 4 hours 3 minutes 2 seconds 100 milliseconds'", + "0 years 0 mons 5 days 4 hours 3 mins 2.100 secs" + ); + test_expression!( + "interval '0.5 month'", + "0 years 0 mons 15 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '1 month'", + "0 years 1 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '5 month'", + "0 years 5 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '13 month'", + "1 years 1 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '0.5 year'", + "0 years 6 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '1 year'", + "1 years 0 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '2 year'", + "2 years 0 mons 0 days 0 hours 0 mins 0.00 secs" + ); Ok(()) } #[tokio::test] -async fn crypto_expressions() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT - md5('tom') AS md5_tom, - md5('') AS md5_empty_str, - md5(null) AS md5_null, - sha224('tom') AS sha224_tom, - sha224('') AS sha224_empty_str, - sha224(null) AS sha224_null, - sha256('tom') AS sha256_tom, - sha256('') AS sha256_empty_str, - sha384('tom') AS sha348_tom, - sha384('') AS sha384_empty_str, - sha512('tom') AS sha512_tom, - sha512('') AS sha512_empty_str - "; - let actual = execute(&mut ctx, sql).await; - - let expected = vec![vec![ - "34b7da764b21d298ef307d04d8152dc5", - "d41d8cd98f00b204e9800998ecf8427e", - "NULL", - "0bf6cb62649c42a9ae3876ab6f6d92ad36cb5414e495f8873292be4d", - "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f", - "NULL", - "e1608f75c5d7813f3d4031cb30bfb786507d98137538ff8e128a6ff74e84e643", - "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", - "096f5b68aa77848e4fdf5c1c0b350de2dbfad60ffd7c25d9ea07c6c19b8a4d55a9187eb117c557883f58c16dfac3e343", - "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b", - "6e1b9b3fe840680e37051f7ad5e959d6f39ad0f8885d855166f55c659469d3c8b78118c44a2a49c72ddb481cd6d8731034e11cc030070ba843a90b3495cb8d3e", - "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e" - ]]; - assert_eq!(expected, actual); +async fn test_crypto_expressions() -> Result<()> { + test_expression!("md5('tom')", "34b7da764b21d298ef307d04d8152dc5"); + test_expression!("md5('')", "d41d8cd98f00b204e9800998ecf8427e"); + test_expression!("md5(NULL)", "NULL"); + test_expression!( + "sha224('tom')", + "0bf6cb62649c42a9ae3876ab6f6d92ad36cb5414e495f8873292be4d" + ); + test_expression!( + "sha224('')", + "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f" + ); + test_expression!("sha224(NULL)", "NULL"); + test_expression!( + "sha256('tom')", + "e1608f75c5d7813f3d4031cb30bfb786507d98137538ff8e128a6ff74e84e643" + ); + test_expression!( + "sha256('')", + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + ); + test_expression!("sha256(NULL)", "NULL"); + test_expression!("sha384('tom')", "096f5b68aa77848e4fdf5c1c0b350de2dbfad60ffd7c25d9ea07c6c19b8a4d55a9187eb117c557883f58c16dfac3e343"); + test_expression!("sha384('')", "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b"); + test_expression!("sha384(NULL)", "NULL"); + test_expression!("sha512('tom')", "6e1b9b3fe840680e37051f7ad5e959d6f39ad0f8885d855166f55c659469d3c8b78118c44a2a49c72ddb481cd6d8731034e11cc030070ba843a90b3495cb8d3e"); + test_expression!("sha512('')", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e"); + test_expression!("sha512(NULL)", "NULL"); + Ok(()) +} +#[tokio::test] +async fn test_extract_date_part() -> Result<()> { + test_expression!("date_part('hour', CAST('2020-01-01' AS DATE))", "0"); + test_expression!("EXTRACT(HOUR FROM CAST('2020-01-01' AS DATE))", "0"); + test_expression!( + "EXTRACT(HOUR FROM to_timestamp('2020-09-08T12:00:00+00:00'))", + "12" + ); + test_expression!("date_part('YEAR', CAST('2000-01-01' AS DATE))", "2000"); + test_expression!( + "EXTRACT(year FROM to_timestamp('2020-09-08T12:00:00+00:00'))", + "2020" + ); Ok(()) } #[tokio::test] -async fn extract_date_part() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT - date_part('hour', CAST('2020-01-01' AS DATE)) AS hr1, - EXTRACT(HOUR FROM CAST('2020-01-01' AS DATE)) AS hr2, - EXTRACT(HOUR FROM to_timestamp('2020-09-08T12:00:00+00:00')) AS hr3, - date_part('YEAR', CAST('2000-01-01' AS DATE)) AS year1, - EXTRACT(year FROM to_timestamp('2020-09-08T12:00:00+00:00')) AS year2 - "; - - let actual = execute(&mut ctx, sql).await; - - let expected = vec![vec!["0", "0", "12", "2000", "2020"]]; - assert_eq!(expected, actual); +async fn test_in_list_scalar() -> Result<()> { + test_expression!("'a' IN ('a','b')", "true"); + test_expression!("'c' IN ('a','b')", "false"); + test_expression!("'c' NOT IN ('a','b')", "true"); + test_expression!("'a' NOT IN ('a','b')", "false"); + test_expression!("NULL IN ('a','b')", "NULL"); + test_expression!("NULL NOT IN ('a','b')", "NULL"); + test_expression!("'a' IN ('a','b',NULL)", "true"); + test_expression!("'c' IN ('a','b',NULL)", "NULL"); + test_expression!("'a' NOT IN ('a','b',NULL)", "false"); + test_expression!("'c' NOT IN ('a','b',NULL)", "NULL"); + test_expression!("0 IN (0,1,2)", "true"); + test_expression!("3 IN (0,1,2)", "false"); + test_expression!("3 NOT IN (0,1,2)", "true"); + test_expression!("0 NOT IN (0,1,2)", "false"); + test_expression!("NULL IN (0,1,2)", "NULL"); + test_expression!("NULL NOT IN (0,1,2)", "NULL"); + test_expression!("0 IN (0,1,2,NULL)", "true"); + test_expression!("3 IN (0,1,2,NULL)", "NULL"); + test_expression!("0 NOT IN (0,1,2,NULL)", "false"); + test_expression!("3 NOT IN (0,1,2,NULL)", "NULL"); + test_expression!("0.0 IN (0.0,0.1,0.2)", "true"); + test_expression!("0.3 IN (0.0,0.1,0.2)", "false"); + test_expression!("0.3 NOT IN (0.0,0.1,0.2)", "true"); + test_expression!("0.0 NOT IN (0.0,0.1,0.2)", "false"); + test_expression!("NULL IN (0.0,0.1,0.2)", "NULL"); + test_expression!("NULL NOT IN (0.0,0.1,0.2)", "NULL"); + test_expression!("0.0 IN (0.0,0.1,0.2,NULL)", "true"); + test_expression!("0.3 IN (0.0,0.1,0.2,NULL)", "NULL"); + test_expression!("0.0 NOT IN (0.0,0.1,0.2,NULL)", "false"); + test_expression!("0.3 NOT IN (0.0,0.1,0.2,NULL)", "NULL"); + test_expression!("'1' IN ('a','b',1)", "true"); + test_expression!("'2' IN ('a','b',1)", "false"); + test_expression!("'2' NOT IN ('a','b',1)", "true"); + test_expression!("'1' NOT IN ('a','b',1)", "false"); + test_expression!("NULL IN ('a','b',1)", "NULL"); + test_expression!("NULL NOT IN ('a','b',1)", "NULL"); + test_expression!("'1' IN ('a','b',NULL,1)", "true"); + test_expression!("'2' IN ('a','b',NULL,1)", "NULL"); + test_expression!("'1' NOT IN ('a','b',NULL,1)", "false"); + test_expression!("'2' NOT IN ('a','b',NULL,1)", "NULL"); Ok(()) } @@ -2215,67 +2335,6 @@ async fn in_list_array() -> Result<()> { Ok(()) } -#[tokio::test] -async fn in_list_scalar() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT - 'a' IN ('a','b') AS utf8_in_true - ,'c' IN ('a','b') AS utf8_in_false - ,'c' NOT IN ('a','b') AS utf8_not_in_true - ,'a' NOT IN ('a','b') AS utf8_not_in_false - ,NULL IN ('a','b') AS utf8_in_null - ,NULL NOT IN ('a','b') AS utf8_not_in_null - ,'a' IN ('a','b',NULL) AS utf8_in_null_true - ,'c' IN ('a','b',NULL) AS utf8_in_null_null - ,'a' NOT IN ('a','b',NULL) AS utf8_not_in_null_false - ,'c' NOT IN ('a','b',NULL) AS utf8_not_in_null_null - - ,0 IN (0,1,2) AS int64_in_true - ,3 IN (0,1,2) AS int64_in_false - ,3 NOT IN (0,1,2) AS int64_not_in_true - ,0 NOT IN (0,1,2) AS int64_not_in_false - ,NULL IN (0,1,2) AS int64_in_null - ,NULL NOT IN (0,1,2) AS int64_not_in_null - ,0 IN (0,1,2,NULL) AS int64_in_null_true - ,3 IN (0,1,2,NULL) AS int64_in_null_null - ,0 NOT IN (0,1,2,NULL) AS int64_not_in_null_false - ,3 NOT IN (0,1,2,NULL) AS int64_not_in_null_null - - ,0.0 IN (0.0,0.1,0.2) AS float64_in_true - ,0.3 IN (0.0,0.1,0.2) AS float64_in_false - ,0.3 NOT IN (0.0,0.1,0.2) AS float64_not_in_true - ,0.0 NOT IN (0.0,0.1,0.2) AS float64_not_in_false - ,NULL IN (0.0,0.1,0.2) AS float64_in_null - ,NULL NOT IN (0.0,0.1,0.2) AS float64_not_in_null - ,0.0 IN (0.0,0.1,0.2,NULL) AS float64_in_null_true - ,0.3 IN (0.0,0.1,0.2,NULL) AS float64_in_null_null - ,0.0 NOT IN (0.0,0.1,0.2,NULL) AS float64_not_in_null_false - ,0.3 NOT IN (0.0,0.1,0.2,NULL) AS float64_not_in_null_null - - ,'1' IN ('a','b',1) AS utf8_cast_in_true - ,'2' IN ('a','b',1) AS utf8_cast_in_false - ,'2' NOT IN ('a','b',1) AS utf8_cast_not_in_true - ,'1' NOT IN ('a','b',1) AS utf8_cast_not_in_false - ,NULL IN ('a','b',1) AS utf8_cast_in_null - ,NULL NOT IN ('a','b',1) AS utf8_cast_not_in_null - ,'1' IN ('a','b',NULL,1) AS utf8_cast_in_null_true - ,'2' IN ('a','b',NULL,1) AS utf8_cast_in_null_null - ,'1' NOT IN ('a','b',NULL,1) AS utf8_cast_not_in_null_false - ,'2' NOT IN ('a','b',NULL,1) AS utf8_cast_not_in_null_null - "; - let actual = execute(&mut ctx, sql).await; - - let expected = vec![vec![ - "true", "false", "true", "false", "NULL", "NULL", "true", "NULL", "false", - "NULL", "true", "false", "true", "false", "NULL", "NULL", "true", "NULL", - "false", "NULL", "true", "false", "true", "false", "NULL", "NULL", "true", - "NULL", "false", "NULL", "true", "false", "true", "false", "NULL", "NULL", - "true", "NULL", "false", "NULL", - ]]; - assert_eq!(expected, actual); - Ok(()) -} - // TODO Tests to prove correct implementation of INNER JOIN's with qualified names. // https://issues.apache.org/jira/projects/ARROW/issues/ARROW-11432. #[tokio::test]