diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index 5dcab04399e..6f9cd85deaa 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -64,8 +64,12 @@ DataFusion includes a simple command-line interactive SQL utility. See the [CLI - [x] concat - [x] concat_ws - [x] length + - [x] left + - [x] lpad - [x] ltrim - [x] octet_length + - [x] right + - [x] rpad - [x] rtrim - [x] substr - [x] trim diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 775ab64ac14..38039b41f1e 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -1069,10 +1069,14 @@ unary_scalar_expr!(CharacterLength, character_length); unary_scalar_expr!(CharacterLength, length); unary_scalar_expr!(Concat, concat); unary_scalar_expr!(ConcatWithSeparator, concat_ws); +unary_scalar_expr!(Left, left); unary_scalar_expr!(Lower, lower); +unary_scalar_expr!(Lpad, lpad); unary_scalar_expr!(Ltrim, ltrim); unary_scalar_expr!(MD5, md5); unary_scalar_expr!(OctetLength, octet_length); +unary_scalar_expr!(Right, right); +unary_scalar_expr!(Rpad, rpad); unary_scalar_expr!(Rtrim, rtrim); unary_scalar_expr!(SHA224, sha224); unary_scalar_expr!(SHA256, sha256); diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 90c35dc3a23..08ba81c4271 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -36,9 +36,10 @@ pub use expr::{ abs, acos, and, array, asin, atan, avg, binary_expr, bit_length, btrim, case, ceil, character_length, col, combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, - length, lit, ln, log10, log2, lower, ltrim, max, md5, min, octet_length, or, round, - rtrim, sha224, sha256, sha384, sha512, signum, sin, sqrt, substr, sum, tan, trim, - trunc, upper, when, Expr, ExprRewriter, ExpressionVisitor, Literal, 
Recursion, + left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min, octet_length, + or, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sqrt, + substr, sum, tan, trim, trunc, upper, when, Expr, ExprRewriter, ExpressionVisitor, + Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 1c82d0fea45..86dc67b2074 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -139,6 +139,10 @@ pub enum BuiltinScalarFunction { DatePart, /// Date truncate DateTrunc, + /// left + Left, + /// lpad + Lpad, /// lower Lower, /// trim left @@ -149,6 +153,10 @@ pub enum BuiltinScalarFunction { NullIf, /// octet_length OctetLength, + /// right + Right, + /// rpad + Rpad, /// trim right Rtrim, /// SHA224 @@ -209,12 +217,16 @@ impl FromStr for BuiltinScalarFunction { "concat_ws" => BuiltinScalarFunction::ConcatWithSeparator, "date_part" => BuiltinScalarFunction::DatePart, "date_trunc" => BuiltinScalarFunction::DateTrunc, + "left" => BuiltinScalarFunction::Left, "length" => BuiltinScalarFunction::CharacterLength, "lower" => BuiltinScalarFunction::Lower, + "lpad" => BuiltinScalarFunction::Lpad, "ltrim" => BuiltinScalarFunction::Ltrim, "md5" => BuiltinScalarFunction::MD5, "nullif" => BuiltinScalarFunction::NullIf, "octet_length" => BuiltinScalarFunction::OctetLength, + "right" => BuiltinScalarFunction::Right, + "rpad" => BuiltinScalarFunction::Rpad, "rtrim" => BuiltinScalarFunction::Rtrim, "sha224" => BuiltinScalarFunction::SHA224, "sha256" => BuiltinScalarFunction::SHA256, @@ -298,6 +310,16 @@ pub fn return_type( BuiltinScalarFunction::DateTrunc => { Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) } + BuiltinScalarFunction::Left => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + 
_ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The left function can only accept strings.".to_string(), + )); + } + }), BuiltinScalarFunction::Lower => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, @@ -308,6 +330,16 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::Lpad => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The lpad function can only accept strings.".to_string(), + )); + } + }), BuiltinScalarFunction::Ltrim => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, @@ -343,6 +375,26 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::Right => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The right function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Rpad => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. 
+ return Err(DataFusionError::Internal( + "The rpad function can only accept strings.".to_string(), + )); + } + }), BuiltinScalarFunction::Rtrim => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, @@ -518,7 +570,27 @@ pub fn create_physical_expr( } BuiltinScalarFunction::DatePart => datetime_expressions::date_part, BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc, + BuiltinScalarFunction::Left => |args| match args[0].data_type() { + DataType::Utf8 => make_scalar_function(string_expressions::left::)(args), + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::left::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function left", + other, + ))), + }, BuiltinScalarFunction::Lower => string_expressions::lower, + BuiltinScalarFunction::Lpad => |args| match args[0].data_type() { + DataType::Utf8 => make_scalar_function(string_expressions::lpad::)(args), + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::lpad::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function lpad", + other, + ))), + }, BuiltinScalarFunction::Ltrim => |args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::ltrim::)(args) @@ -545,6 +617,28 @@ pub fn create_physical_expr( _ => unreachable!(), }, }, + BuiltinScalarFunction::Right => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::right::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::right::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function right", + other, + ))), + }, + BuiltinScalarFunction::Rpad => |args| match args[0].data_type() { + DataType::Utf8 => make_scalar_function(string_expressions::rpad::)(args), + DataType::LargeUtf8 => { + 
make_scalar_function(string_expressions::rpad::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function rpad", + other, + ))), + }, BuiltinScalarFunction::Rtrim => |args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::rtrim::)(args) @@ -635,6 +729,34 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { Signature::Exact(vec![DataType::Utf8]), Signature::Exact(vec![DataType::Utf8, DataType::Utf8]), ]), + BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => { + Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8, DataType::Int64]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Int64]), + Signature::Exact(vec![DataType::Utf8, DataType::Int64, DataType::Utf8]), + Signature::Exact(vec![ + DataType::LargeUtf8, + DataType::Int64, + DataType::Utf8, + ]), + Signature::Exact(vec![ + DataType::Utf8, + DataType::Int64, + DataType::LargeUtf8, + ]), + Signature::Exact(vec![ + DataType::LargeUtf8, + DataType::Int64, + DataType::LargeUtf8, + ]), + ]) + } + BuiltinScalarFunction::Left | BuiltinScalarFunction::Right => { + Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8, DataType::Int64]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Int64]), + ]) + } BuiltinScalarFunction::ToTimestamp => Signature::Uniform(1, vec![DataType::Utf8]), BuiltinScalarFunction::DateTrunc => Signature::Exact(vec![ DataType::Utf8, @@ -1137,203 +1259,225 @@ mod tests { Float64Array ); test_function!( - Ltrim, - &[lit(ScalarValue::Utf8(Some(" trim".to_string())))], - Ok(Some("trim")), + Left, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int8(Some(2))), + ], + Ok(Some("ab")), &str, Utf8, StringArray ); test_function!( - Ltrim, - &[lit(ScalarValue::Utf8(Some(" trim ".to_string())))], - Ok(Some("trim ")), + Left, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(Some(200))), + ], + Ok(Some("abcde")), 
&str, Utf8, StringArray ); test_function!( - Ltrim, - &[lit(ScalarValue::Utf8(Some("trim ".to_string())))], - Ok(Some("trim ")), + Left, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(Some(-2))), + ], + Ok(Some("abc")), &str, Utf8, StringArray ); test_function!( - Ltrim, - &[lit(ScalarValue::Utf8(Some("trim".to_string())))], - Ok(Some("trim")), + Left, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(Some(-200))), + ], + Ok(Some("")), &str, Utf8, StringArray ); test_function!( - Ltrim, - &[lit(ScalarValue::Utf8(Some("\n trim ".to_string())))], - Ok(Some("\n trim ")), + Left, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(Some(0))), + ], + Ok(Some("")), &str, Utf8, StringArray ); test_function!( - Ltrim, - &[lit(ScalarValue::Utf8(None))], + Left, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Int64(Some(2))), + ], Ok(None), &str, Utf8, StringArray ); test_function!( - OctetLength, - &[lit(ScalarValue::Utf8(Some("chars".to_string())))], - Ok(Some(5)), - i32, - Int32, - Int32Array - ); - test_function!( - OctetLength, - &[lit(ScalarValue::Utf8(Some("josé".to_string())))], - Ok(Some(5)), - i32, - Int32, - Int32Array + Left, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray ); test_function!( - OctetLength, - &[lit(ScalarValue::Utf8(Some("".to_string())))], - Ok(Some(0)), - i32, - Int32, - Int32Array + Left, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some("joséé")), + &str, + Utf8, + StringArray ); test_function!( - OctetLength, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - i32, - Int32, - Int32Array + Left, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(-3))), + ], + Ok(Some("joséé")), + &str, + Utf8, + StringArray ); test_function!( - Substr, + Lpad, &[ - 
lit(ScalarValue::Utf8(Some("alphabet".to_string()))), - lit(ScalarValue::Int64(Some(0))), + lit(ScalarValue::Utf8(Some("josé".to_string()))), + lit(ScalarValue::Int64(Some(5))), ], - Ok(Some("alphabet")), + Ok(Some(" josé")), &str, Utf8, StringArray ); test_function!( - Substr, + Lpad, &[ - lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Utf8(Some("hi".to_string()))), lit(ScalarValue::Int64(Some(5))), ], - Ok(Some("ésoj")), + Ok(Some(" hi")), &str, Utf8, StringArray ); test_function!( - Substr, + Lpad, &[ - lit(ScalarValue::Utf8(Some("alphabet".to_string()))), - lit(ScalarValue::Int64(Some(1))), + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(0))), ], - Ok(Some("alphabet")), + Ok(Some("")), &str, Utf8, StringArray ); test_function!( - Substr, + Lpad, &[ - lit(ScalarValue::Utf8(Some("alphabet".to_string()))), - lit(ScalarValue::Int64(Some(2))), + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(None)), ], - Ok(Some("lphabet")), + Ok(None), &str, Utf8, StringArray ); test_function!( - Substr, + Lpad, &[ - lit(ScalarValue::Utf8(Some("alphabet".to_string()))), - lit(ScalarValue::Int64(Some(3))), + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Int64(Some(5))), ], - Ok(Some("phabet")), + Ok(None), &str, Utf8, StringArray ); test_function!( - Substr, + Lpad, &[ - lit(ScalarValue::Utf8(Some("alphabet".to_string()))), - lit(ScalarValue::Int64(Some(-3))), + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(Some("xy".to_string()))), ], - Ok(Some("alphabet")), + Ok(Some("xyxhi")), &str, Utf8, StringArray ); test_function!( - Substr, + Lpad, &[ - lit(ScalarValue::Utf8(Some("alphabet".to_string()))), - lit(ScalarValue::Int64(Some(30))), + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(21))), + lit(ScalarValue::Utf8(Some("abcdef".to_string()))), ], - Ok(Some("")), + Ok(Some("abcdefabcdefabcdefahi")), &str, 
Utf8, StringArray ); test_function!( - Substr, + Lpad, &[ - lit(ScalarValue::Utf8(Some("alphabet".to_string()))), - lit(ScalarValue::Int64(None)), + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(Some(" ".to_string()))), ], - Ok(None), + Ok(Some(" hi")), &str, Utf8, StringArray ); test_function!( - Substr, + Lpad, &[ - lit(ScalarValue::Utf8(Some("alphabet".to_string()))), - lit(ScalarValue::Int64(Some(3))), - lit(ScalarValue::Int64(Some(2))), + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(Some("".to_string()))), ], - Ok(Some("ph")), + Ok(Some("hi")), &str, Utf8, StringArray ); test_function!( - Substr, + Lpad, &[ - lit(ScalarValue::Utf8(Some("alphabet".to_string()))), - lit(ScalarValue::Int64(Some(3))), - lit(ScalarValue::Int64(Some(20))), + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(Some("xy".to_string()))), ], - Ok(Some("phabet")), + Ok(None), &str, Utf8, StringArray ); test_function!( - Substr, + Lpad, &[ - lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Utf8(Some("hi".to_string()))), lit(ScalarValue::Int64(None)), - lit(ScalarValue::Int64(Some(20))), + lit(ScalarValue::Utf8(Some("xy".to_string()))), ], Ok(None), &str, @@ -1341,11 +1485,11 @@ mod tests { StringArray ); test_function!( - Substr, + Lpad, &[ - lit(ScalarValue::Utf8(Some("alphabet".to_string()))), - lit(ScalarValue::Int64(Some(3))), - lit(ScalarValue::Int64(None)), + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(None)), ], Ok(None), &str, @@ -1353,79 +1497,581 @@ mod tests { StringArray ); test_function!( - Substr, + Lpad, &[ - lit(ScalarValue::Utf8(Some("alphabet".to_string()))), - lit(ScalarValue::Int64(Some(1))), - lit(ScalarValue::Int64(Some(-1))), + lit(ScalarValue::Utf8(Some("josé".to_string()))), + lit(ScalarValue::Int64(Some(10))), + 
lit(ScalarValue::Utf8(Some("xy".to_string()))), ], - Err(DataFusionError::Execution( - "negative substring length not allowed".to_string(), - )), + Ok(Some("xyxyxyjosé")), &str, Utf8, StringArray ); test_function!( - Substr, + Lpad, &[ - lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), - lit(ScalarValue::Int64(Some(5))), - lit(ScalarValue::Int64(Some(2))), + lit(ScalarValue::Utf8(Some("josé".to_string()))), + lit(ScalarValue::Int64(Some(10))), + lit(ScalarValue::Utf8(Some("éñ".to_string()))), ], - Ok(Some("és")), + Ok(Some("éñéñéñjosé")), &str, Utf8, StringArray ); test_function!( - Rtrim, - &[lit(ScalarValue::Utf8(Some("trim ".to_string())))], + Ltrim, + &[lit(ScalarValue::Utf8(Some(" trim".to_string())))], Ok(Some("trim")), &str, Utf8, StringArray ); test_function!( - Rtrim, + Ltrim, &[lit(ScalarValue::Utf8(Some(" trim ".to_string())))], - Ok(Some(" trim")), + Ok(Some("trim ")), &str, Utf8, StringArray ); test_function!( - Rtrim, - &[lit(ScalarValue::Utf8(Some(" trim \n".to_string())))], - Ok(Some(" trim \n")), + Ltrim, + &[lit(ScalarValue::Utf8(Some("trim ".to_string())))], + Ok(Some("trim ")), &str, Utf8, StringArray ); test_function!( - Rtrim, - &[lit(ScalarValue::Utf8(Some(" trim".to_string())))], - Ok(Some(" trim")), + Ltrim, + &[lit(ScalarValue::Utf8(Some("trim".to_string())))], + Ok(Some("trim")), &str, Utf8, StringArray ); test_function!( - Rtrim, - &[lit(ScalarValue::Utf8(Some("trim".to_string())))], - Ok(Some("trim")), + Ltrim, + &[lit(ScalarValue::Utf8(Some("\n trim ".to_string())))], + Ok(Some("\n trim ")), &str, Utf8, StringArray ); test_function!( - Rtrim, + Ltrim, &[lit(ScalarValue::Utf8(None))], Ok(None), &str, Utf8, StringArray ); + test_function!( + OctetLength, + &[lit(ScalarValue::Utf8(Some("chars".to_string())))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLength, + &[lit(ScalarValue::Utf8(Some("josé".to_string())))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLength, + 
&[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLength, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + i32, + Int32, + Int32Array + ); + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int8(Some(2))), + ], + Ok(Some("de")), + &str, + Utf8, + StringArray + ); + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(Some(200))), + ], + Ok(Some("abcde")), + &str, + Utf8, + StringArray + ); + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(Some(-2))), + ], + Ok(Some("cde")), + &str, + Utf8, + StringArray + ); + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(Some(-200))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(Some(0))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Int64(Some(2))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some("éésoj")), + &str, + Utf8, + StringArray + ); + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(-3))), + ], + Ok(Some("éésoj")), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("josé".to_string()))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some("josé ")), + &str, + Utf8, + StringArray + ); + 
test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some("hi ")), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(0))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(Some("xy".to_string()))), + ], + Ok(Some("hixyx")), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(21))), + lit(ScalarValue::Utf8(Some("abcdef".to_string()))), + ], + Ok(Some("hiabcdefabcdefabcdefa")), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(Some(" ".to_string()))), + ], + Ok(Some("hi ")), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(Some("".to_string()))), + ], + Ok(Some("hi")), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(Some("xy".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(None)), + lit(ScalarValue::Utf8(Some("xy".to_string()))), + ], + Ok(None), + &str, + Utf8, + 
StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("josé".to_string()))), + lit(ScalarValue::Int64(Some(10))), + lit(ScalarValue::Utf8(Some("xy".to_string()))), + ], + Ok(Some("joséxyxyxy")), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("josé".to_string()))), + lit(ScalarValue::Int64(Some(10))), + lit(ScalarValue::Utf8(Some("éñ".to_string()))), + ], + Ok(Some("josééñéñéñ")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some("trim ".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some(" trim ".to_string())))], + Ok(Some(" trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some(" trim \n".to_string())))], + Ok(Some(" trim \n")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some(" trim".to_string())))], + Ok(Some(" trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some("trim".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(0))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some("ésoj")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(1))), + 
], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(2))), + ], + Ok(Some("lphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(3))), + ], + Ok(Some("phabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(-3))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(30))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(3))), + lit(ScalarValue::Int64(Some(2))), + ], + Ok(Some("ph")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(3))), + lit(ScalarValue::Int64(Some(20))), + ], + Ok(Some("phabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(None)), + lit(ScalarValue::Int64(Some(20))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(3))), + lit(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(1))), + 
lit(ScalarValue::Int64(Some(-1))), + ], + Err(DataFusionError::Execution( + "negative substring length not allowed".to_string(), + )), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Int64(Some(2))), + ], + Ok(Some("és")), + &str, + Utf8, + StringArray + ); test_function!( Trim, &[lit(ScalarValue::Utf8(Some(" trim ".to_string())))], diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index 7ab0f9f215b..5d3c4d83a24 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -21,6 +21,8 @@ //! String expressions +use std::cmp::Ordering; +use std::str::from_utf8; use std::sync::Arc; use crate::{ @@ -308,6 +310,192 @@ pub fn concat_ws(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } +/// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. 
+/// left('abcde', 2) = 'ab', left('abcde', -2) = 'abc'; characters are grapheme clusters and a NULL argument yields NULL.
+pub fn left<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+    let string_array: &GenericStringArray<T> = args[0]
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .ok_or_else(|| {
+            DataFusionError::Internal("could not cast string to StringArray".to_string())
+        })?;
+
+    let n_array: &Int64Array =
+        args[1]
+            .as_any()
+            .downcast_ref::<Int64Array>()
+            .ok_or_else(|| {
+                DataFusionError::Internal("could not cast n to Int64Array".to_string())
+            })?;
+
+    let result = string_array
+        .iter()
+        .zip(n_array.iter())
+        .map(|(string, n)| match (string, n) {
+            (None, _) => None,
+            (_, None) => None,
+            (Some(string), Some(n)) => match n.cmp(&0) {
+                Ordering::Equal => Some(""),
+                Ordering::Greater => Some(
+                    string
+                        .grapheme_indices(true)
+                        .nth(n as usize) // byte offset just past the first n graphemes
+                        .map_or(string, |(i, _)| {
+                            &from_utf8(&string.as_bytes()[..i]).unwrap()
+                        }),
+                ),
+                Ordering::Less => Some(
+                    string
+                        .grapheme_indices(true)
+                        .rev()
+                        .nth(n.wrapping_abs() as usize - 1) // wrapping_abs: plain abs() panics on i64::MIN in debug builds
+                        .map_or("", |(i, _)| {
+                            &from_utf8(&string.as_bytes()[..i]).unwrap()
+                        }),
+                ),
+            },
+        })
+        .collect::<GenericStringArray<T>>();
+
+    Ok(Arc::new(result) as ArrayRef)
+}
+
+/// Converts the string to all lower case.
+/// lower('TOM') = 'tom'
+pub fn lower(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    handle(args, |x| x.to_ascii_lowercase(), "lower")
+}
+
+/// Extends the string to length `length` by prepending the characters `fill` (a space by default). If the string is already longer than `length` then it is truncated (on the right).
+/// lpad('hi', 5, 'xy') = 'xyxhi'; a negative length is treated as 0 (empty result) and any NULL argument yields NULL.
+pub fn lpad<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+    match args.len() {
+        2 => {
+            let string_array: &GenericStringArray<T> = args[0]
+                .as_any()
+                .downcast_ref::<GenericStringArray<T>>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(
+                        "could not cast string to StringArray".to_string(),
+                    )
+                })?;
+
+            let length_array: &Int64Array = args[1]
+                .as_any()
+                .downcast_ref::<Int64Array>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(
+                        "could not cast length to Int64Array".to_string(),
+                    )
+                })?;
+
+            let result = string_array
+                .iter()
+                .zip(length_array.iter())
+                .map(|(string, length)| match (string, length) {
+                    (None, _) => None,
+                    (_, None) => None,
+                    (Some(string), Some(length)) => {
+                        let length = length.max(0) as usize; // clamp: a negative i64 would wrap to a huge usize
+                        if length == 0 {
+                            Some("".to_string())
+                        } else {
+                            let graphemes = string.graphemes(true).collect::<Vec<&str>>();
+                            if length < graphemes.len() {
+                                Some(graphemes[..length].concat())
+                            } else {
+                                let mut s = string.to_string();
+                                s.insert_str(
+                                    0,
+                                    " ".repeat(length - graphemes.len()).as_str(),
+                                );
+                                Some(s)
+                            }
+                        }
+                    }
+                })
+                .collect::<GenericStringArray<T>>();
+
+            Ok(Arc::new(result) as ArrayRef)
+        }
+        3 => {
+            let string_array: &GenericStringArray<T> = args[0]
+                .as_any()
+                .downcast_ref::<GenericStringArray<T>>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(
+                        "could not cast string to StringArray".to_string(),
+                    )
+                })?;
+
+            let length_array: &Int64Array = args[1]
+                .as_any()
+                .downcast_ref::<Int64Array>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(
+                        "could not cast length to Int64Array".to_string(),
+                    )
+                })?;
+
+            let fill_array: &GenericStringArray<T> = args[2]
+                .as_any()
+                .downcast_ref::<GenericStringArray<T>>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(
+                        "could not cast fill to StringArray".to_string(),
+                    )
+                })?;
+
+            let result = string_array
+                .iter()
+                .zip(length_array.iter())
+                .zip(fill_array.iter())
+                .map(|((string, length), fill)| match (string, length, fill) {
+                    (None, _, _) => None,
+                    (_, None, _) => None,
+                    (_, _, None) => None,
+                    (Some(string), Some(length), Some(fill)) => {
+                        let length = length.max(0) as usize; // clamp: a negative i64 would wrap to a huge usize
+
+                        if length == 0 {
+                            Some("".to_string())
+                        } else {
+                            let graphemes = string.graphemes(true).collect::<Vec<&str>>();
+                            let fill_chars = fill.chars().collect::<Vec<char>>(); // fill is cycled per char
+
+                            if length < graphemes.len() {
+                                Some(graphemes[..length].concat())
+                            } else if fill_chars.is_empty() {
+                                Some(string.to_string())
+                            } else {
+                                let mut s = string.to_string();
+                                let mut char_vector =
+                                    Vec::<char>::with_capacity(length - graphemes.len());
+                                for l in 0..length - graphemes.len() {
+                                    char_vector.push(
+                                        *fill_chars.get(l % fill_chars.len()).unwrap(),
+                                    );
+                                }
+                                s.insert_str(
+                                    0,
+                                    char_vector.iter().collect::<String>().as_str(),
+                                );
+                                Some(s)
+                            }
+                        }
+                    }
+                })
+                .collect::<GenericStringArray<T>>();
+
+            Ok(Arc::new(result) as ArrayRef)
+        }
+        other => Err(DataFusionError::Internal(format!(
+            "lpad was called with {} arguments. It requires at least 2 and at most 3.",
+            other
+        ))),
+    }
+}
+
 /// Removes the longest string containing only characters in characters (a space by default) from the start of string.
 /// ltrim('zzzytest', 'xyz') = 'test'
 pub fn ltrim<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
@@ -316,7 +504,11 @@ pub fn ltrim<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
             let string_array: &GenericStringArray<T> = args[0]
                 .as_any()
                 .downcast_ref::<GenericStringArray<T>>()
-                .unwrap();
+                .ok_or_else(|| {
+                    DataFusionError::Internal(
+                        "could not cast string to StringArray".to_string(),
+                    )
+                })?;
 
             let result = string_array
                 .iter()
@@ -329,25 +521,30 @@ pub fn ltrim<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
             let string_array: &GenericStringArray<T> = args[0]
                 .as_any()
                 .downcast_ref::<GenericStringArray<T>>()
-                .unwrap();
+                .ok_or_else(|| {
+                    DataFusionError::Internal(
+                        "could not cast string to StringArray".to_string(),
+                    )
+                })?;
 
             let characters_array: &GenericStringArray<T> = args[1]
                 .as_any()
                 .downcast_ref::<GenericStringArray<T>>()
-                .unwrap();
+                .ok_or_else(|| {
+                    DataFusionError::Internal(
+                        "could not cast characters to StringArray".to_string(),
+                    )
+                })?;
 
             let result = string_array
                 .iter()
-                .enumerate()
-                .map(|(i, x)| {
-                    if characters_array.is_null(i) {
-                        None
-                    } else {
-                        x.map(|x: &str| {
-                            let chars: Vec<char> =
-                                characters_array.value(i).chars().collect();
-                            x.trim_start_matches(&chars[..])
-                        })
+                .zip(characters_array.iter())
+                .map(|(string, characters)| match (string, characters) {
+                    (None, _) => None,
+                    (_, None) => None,
+                    (Some(string), Some(characters)) => {
+                        let chars: Vec<char> = characters.chars().collect();
+                        Some(string.trim_start_matches(&chars[..]))
                     }
                 })
                 .collect::<GenericStringArray<T>>();
@@ -355,16 +552,178 @@ pub fn ltrim<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
             Ok(Arc::new(result) as ArrayRef)
         }
         other => Err(DataFusionError::Internal(format!(
-            "ltrim was called with {} arguments. It requires at most 2.",
+            "ltrim was called with {} arguments. It requires at least 1 and at most 2.",
             other
         ))),
     }
 }
 
-/// Converts the string to all lower case.
-/// lower('TOM') = 'tom'
-pub fn lower(args: &[ColumnarValue]) -> Result<ColumnarValue> {
-    handle(args, |x| x.to_ascii_lowercase(), "lower")
+/// Returns last n characters (grapheme clusters) in the string, or when n is negative, returns all but first |n| characters. A NULL argument yields NULL.
+/// right('abcde', 2) = 'de', right('abcde', -2) = 'cde'
+pub fn right<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+    let string_array: &GenericStringArray<T> = args[0]
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .ok_or_else(|| {
+            DataFusionError::Internal("could not cast string to StringArray".to_string())
+        })?;
+
+    let n_array: &Int64Array =
+        args[1]
+            .as_any()
+            .downcast_ref::<Int64Array>()
+            .ok_or_else(|| {
+                DataFusionError::Internal("could not cast n to Int64Array".to_string())
+            })?;
+
+    let result = string_array
+        .iter()
+        .zip(n_array.iter())
+        .map(|(string, n)| match (string, n) {
+            (None, _) => None,
+            (_, None) => None,
+            (Some(string), Some(n)) => match n.cmp(&0) {
+                Ordering::Equal => Some(""),
+                Ordering::Greater => Some(
+                    string
+                        .grapheme_indices(true)
+                        .rev()
+                        .nth(n as usize - 1) // byte offset where the last n graphemes begin
+                        .map_or(string, |(i, _)| {
+                            &from_utf8(&string.as_bytes()[i..]).unwrap()
+                        }),
+                ),
+                Ordering::Less => Some(
+                    string
+                        .grapheme_indices(true)
+                        .nth(n.wrapping_abs() as usize) // wrapping_abs: plain abs() panics on i64::MIN in debug builds
+                        .map_or("", |(i, _)| {
+                            &from_utf8(&string.as_bytes()[i..]).unwrap()
+                        }),
+                ),
+            },
+        })
+        .collect::<GenericStringArray<T>>();
+
+    Ok(Arc::new(result) as ArrayRef)
+}
+
+/// Extends the string to length `length` by appending the characters `fill` (a space by default). If the string is already longer than `length` then it is truncated. A negative `length` is treated as 0.
+/// rpad('hi', 5, 'xy') = 'hixyx'
+pub fn rpad<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+    match args.len() {
+        2 => {
+            let string_array: &GenericStringArray<T> = args[0]
+                .as_any()
+                .downcast_ref::<GenericStringArray<T>>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(
+                        "could not cast string to StringArray".to_string(),
+                    )
+                })?;
+
+            let length_array: &Int64Array = args[1]
+                .as_any()
+                .downcast_ref::<Int64Array>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(
+                        "could not cast length to Int64Array".to_string(),
+                    )
+                })?;
+
+            let result = string_array
+                .iter()
+                .zip(length_array.iter())
+                .map(|(string, length)| match (string, length) {
+                    (None, _) => None,
+                    (_, None) => None,
+                    (Some(string), Some(length)) => {
+                        let length = length.max(0) as usize; // clamp: a negative i64 would wrap to a huge usize
+                        if length == 0 {
+                            Some("".to_string())
+                        } else {
+                            let graphemes = string.graphemes(true).collect::<Vec<&str>>();
+                            if length < graphemes.len() {
+                                Some(graphemes[..length].concat())
+                            } else {
+                                let mut s = string.to_string();
+                                s.push_str(" ".repeat(length - graphemes.len()).as_str());
+                                Some(s)
+                            }
+                        }
+                    }
+                })
+                .collect::<GenericStringArray<T>>();
+
+            Ok(Arc::new(result) as ArrayRef)
+        }
+        3 => {
+            let string_array: &GenericStringArray<T> = args[0]
+                .as_any()
+                .downcast_ref::<GenericStringArray<T>>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(
+                        "could not cast string to StringArray".to_string(),
+                    )
+                })?;
+
+            let length_array: &Int64Array = args[1]
+                .as_any()
+                .downcast_ref::<Int64Array>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(
+                        "could not cast length to Int64Array".to_string(),
+                    )
+                })?;
+
+            let fill_array: &GenericStringArray<T> = args[2]
+                .as_any()
+                .downcast_ref::<GenericStringArray<T>>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(
+                        "could not cast fill to StringArray".to_string(),
+                    )
+                })?;
+
+            let result = string_array
+                .iter()
+                .zip(length_array.iter())
+                .zip(fill_array.iter())
+                .map(|((string, length), fill)| match (string, length, fill) {
+                    (None, _, _) => None,
+                    (_, None, _) => None,
+                    (_, _, None) => None,
+                    (Some(string), Some(length), Some(fill)) => {
+                        let length = length.max(0) as usize; // clamp: a negative i64 would wrap to a huge usize
+                        let graphemes = string.graphemes(true).collect::<Vec<&str>>();
+                        let fill_chars = fill.chars().collect::<Vec<char>>(); // fill is cycled per char
+
+                        if length < graphemes.len() {
+                            Some(graphemes[..length].concat())
+                        } else if fill_chars.is_empty() {
+                            Some(string.to_string())
+                        } else {
+                            let mut s = string.to_string();
+                            let mut char_vector =
+                                Vec::<char>::with_capacity(length - graphemes.len());
+                            for l in 0..length - graphemes.len() {
+                                char_vector
+                                    .push(*fill_chars.get(l % fill_chars.len()).unwrap());
+                            }
+                            s.push_str(char_vector.iter().collect::<String>().as_str());
+                            Some(s)
+                        }
+                    }
+                })
+                .collect::<GenericStringArray<T>>();
+
+            Ok(Arc::new(result) as ArrayRef)
+        }
+        other => Err(DataFusionError::Internal(format!(
+            "rpad was called with {} arguments. It requires at least 2 and at most 3.",
+            other
+        ))),
+    }
+}
+
 /// Removes the longest string containing only characters in characters (a space by default) from the end of string.
@@ -375,11 +734,15 @@ pub fn rtrim(args: &[ArrayRef]) -> Result { let string_array: &GenericStringArray = args[0] .as_any() .downcast_ref::>() - .unwrap(); + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast string to StringArray".to_string(), + ) + })?; let result = string_array .iter() - .map(|x| x.map(|x: &str| x.trim_end_matches(' '))) + .map(|string| string.map(|string: &str| string.trim_end_matches(' '))) .collect::>(); Ok(Arc::new(result) as ArrayRef) @@ -388,25 +751,30 @@ pub fn rtrim(args: &[ArrayRef]) -> Result { let string_array: &GenericStringArray = args[0] .as_any() .downcast_ref::>() - .unwrap(); + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast string to StringArray".to_string(), + ) + })?; let characters_array: &GenericStringArray = args[1] .as_any() .downcast_ref::>() - .unwrap(); + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast characters to StringArray".to_string(), + ) + })?; let result = string_array .iter() - .enumerate() - .map(|(i, x)| { - if characters_array.is_null(i) { - None - } else { - x.map(|x: &str| { - let chars: Vec = - characters_array.value(i).chars().collect(); - x.trim_end_matches(&chars[..]) - }) + .zip(characters_array.iter()) + .map(|(string, characters)| match (string, characters) { + (None, _) => None, + (_, None) => None, + (Some(string), Some(characters)) => { + let chars: Vec = characters.chars().collect(); + Some(string.trim_end_matches(&chars[..])) } }) .collect::>(); @@ -414,7 +782,7 @@ pub fn rtrim(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } other => Err(DataFusionError::Internal(format!( - "rtrim was called with {} arguments. It requires at most two.", + "rtrim was called with {} arguments. 
It requires at least 1 and at most 2.", other ))), } @@ -446,26 +814,22 @@ pub fn substr(args: &[ArrayRef]) -> Result { let result = string_array .iter() - .enumerate() - .map(|(i, x)| { - if start_array.is_null(i) { - None - } else { - x.map(|x: &str| { - let start: i64 = start_array.value(i); - - if start <= 0 { - x.to_string() + .zip(start_array.iter()) + .map(|(string, start)| match (string, start) { + (None, _) => None, + (_, None) => None, + (Some(string), Some(start)) => { + if start <= 0 { + Some(string.to_string()) + } else { + let graphemes = string.graphemes(true).collect::>(); + let start_pos = start as usize - 1; + if graphemes.len() < start_pos { + Some("".to_string()) } else { - let graphemes = x.graphemes(true).collect::>(); - let start_pos = start as usize - 1; - if graphemes.len() < start_pos { - "".to_string() - } else { - graphemes[start_pos..].concat() - } + Some(graphemes[start_pos..].concat()) } - }) + } } }) .collect::>(); @@ -502,36 +866,34 @@ pub fn substr(args: &[ArrayRef]) -> Result { let result = string_array .iter() - .enumerate() - .map(|(i, x)| { - if start_array.is_null(i) || count_array.is_null(i) { - Ok(None) - } else { - x.map(|x: &str| { - let start: i64 = start_array.value(i); - let count = count_array.value(i); - - if count < 0 { - Err(DataFusionError::Execution( - "negative substring length not allowed".to_string(), - )) - } else if start <= 0 { - Ok(x.to_string()) + .zip(start_array.iter()) + .zip(count_array.iter()) + .map(|((string, start), count)| match (string, start, count) { + (None, _, _) => Ok(None), + (_, None, _) => Ok(None), + (_, _, None) => Ok(None), + (Some(string), Some(start), Some(count)) => { + if count < 0 { + Err(DataFusionError::Execution( + "negative substring length not allowed".to_string(), + )) + } else if start <= 0 { + Ok(Some(string.to_string())) + } else { + let graphemes = string.graphemes(true).collect::>(); + let start_pos = start as usize - 1; + let count_usize = count as usize; + if 
graphemes.len() < start_pos { + Ok(Some("".to_string())) + } else if graphemes.len() < start_pos + count_usize { + Ok(Some(graphemes[start_pos..].concat())) } else { - let graphemes = x.graphemes(true).collect::>(); - let start_pos = start as usize - 1; - let count_usize = count as usize; - if graphemes.len() < start_pos { - Ok("".to_string()) - } else if graphemes.len() < start_pos + count_usize { - Ok(graphemes[start_pos..].concat()) - } else { - Ok(graphemes[start_pos..start_pos + count_usize] - .concat()) - } + Ok(Some( + graphemes[start_pos..start_pos + count_usize] + .concat(), + )) } - }) - .transpose() + } } }) .collect::>>()?; diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs index d60f0c32a4d..1f8588c9c52 100644 --- a/rust/datafusion/src/prelude.rs +++ b/rust/datafusion/src/prelude.rs @@ -29,7 +29,8 @@ pub use crate::dataframe::DataFrame; pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::logical_plan::{ array, avg, bit_length, btrim, character_length, col, concat, concat_ws, count, - create_udf, in_list, length, lit, lower, ltrim, max, md5, min, octet_length, rtrim, - sha224, sha256, sha384, sha512, substr, sum, trim, upper, JoinType, Partitioning, + create_udf, in_list, left, length, lit, lower, lpad, ltrim, max, md5, min, + octet_length, right, rpad, rtrim, sha224, sha256, sha384, sha512, substr, sum, trim, + upper, JoinType, Partitioning, }; pub use crate::physical_plan::csv::CsvReadOptions; diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 587fe299bd8..96032c7fc7e 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -2058,17 +2058,53 @@ async fn test_string_expressions() -> Result<()> { test_expression!("concat_ws('|','a','b','c')", "a|b|c"); test_expression!("concat_ws('|',NULL)", ""); test_expression!("concat_ws(NULL,'a',NULL,'b','c')", "NULL"); + test_expression!("left('abcde', -2)", "abc"); + test_expression!("left('abcde', 
-200)", ""); + test_expression!("left('abcde', 0)", ""); + test_expression!("left('abcde', 2)", "ab"); + test_expression!("left('abcde', 200)", "abcde"); + test_expression!("left('abcde', CAST(NULL AS INT))", "NULL"); + test_expression!("left(NULL, 2)", "NULL"); + test_expression!("left(NULL, CAST(NULL AS INT))", "NULL"); + test_expression!("lower('')", ""); + test_expression!("lower('TOM')", "tom"); + test_expression!("lower(NULL)", "NULL"); + test_expression!("lpad('hi', 5, 'xy')", "xyxhi"); + test_expression!("lpad('hi', 0)", ""); + test_expression!("lpad('hi', 21, 'abcdef')", "abcdefabcdefabcdefahi"); + test_expression!("lpad('hi', 5, 'xy')", "xyxhi"); + test_expression!("lpad('hi', 5, NULL)", "NULL"); + test_expression!("lpad('hi', 5)", " hi"); + test_expression!("lpad('hi', CAST(NULL AS INT), 'xy')", "NULL"); + test_expression!("lpad('hi', CAST(NULL AS INT))", "NULL"); + test_expression!("lpad('xyxhi', 3)", "xyx"); + test_expression!("lpad(NULL, 0)", "NULL"); + test_expression!("lpad(NULL, 5, 'xy')", "NULL"); test_expression!("ltrim(' zzzytest ', NULL)", "NULL"); test_expression!("ltrim(' zzzytest ')", "zzzytest "); test_expression!("ltrim('zzzytest', 'xyz')", "test"); test_expression!("ltrim(NULL, 'xyz')", "NULL"); - test_expression!("lower('')", ""); - test_expression!("lower('TOM')", "tom"); - test_expression!("lower(NULL)", "NULL"); test_expression!("octet_length('')", "0"); test_expression!("octet_length('chars')", "5"); test_expression!("octet_length('josé')", "5"); test_expression!("octet_length(NULL)", "NULL"); + test_expression!("right('abcde', -2)", "cde"); + test_expression!("right('abcde', -200)", ""); + test_expression!("right('abcde', 0)", ""); + test_expression!("right('abcde', 2)", "de"); + test_expression!("right('abcde', 200)", "abcde"); + test_expression!("right('abcde', CAST(NULL AS INT))", "NULL"); + test_expression!("right(NULL, 2)", "NULL"); + test_expression!("right(NULL, CAST(NULL AS INT))", "NULL"); + test_expression!("rpad('hi', 5, 
'xy')", "hixyx"); + test_expression!("rpad('hi', 0)", ""); + test_expression!("rpad('hi', 21, 'abcdef')", "hiabcdefabcdefabcdefa"); + test_expression!("rpad('hi', 5, 'xy')", "hixyx"); + test_expression!("rpad('hi', 5, NULL)", "NULL"); + test_expression!("rpad('hi', 5)", "hi "); + test_expression!("rpad('hi', CAST(NULL AS INT), 'xy')", "NULL"); + test_expression!("rpad('hi', CAST(NULL AS INT))", "NULL"); + test_expression!("rpad('xyxhi', 3)", "xyx"); test_expression!("rtrim(' testxxzx ')", " testxxzx"); test_expression!("rtrim(' zzzytest ', NULL)", "NULL"); test_expression!("rtrim('testxxzx', 'xyz')", "test");