diff --git a/rust/datafusion/Cargo.toml b/rust/datafusion/Cargo.toml index 9f677eb47e1..b713b773328 100644 --- a/rust/datafusion/Cargo.toml +++ b/rust/datafusion/Cargo.toml @@ -40,10 +40,12 @@ name = "datafusion-cli" path = "src/bin/main.rs" [features] -default = ["cli", "crypto_expressions"] +default = ["cli", "crypto_expressions", "regex_expressions", "unicode_expressions"] cli = ["rustyline"] simd = ["arrow/simd"] crypto_expressions = ["md-5", "sha2"] +regex_expressions = ["regex", "lazy_static"] +unicode_expressions = ["unicode-segmentation"] [dependencies] ahash = "0.7" @@ -65,7 +67,9 @@ log = "^0.4" md-5 = { version = "^0.9.1", optional = true } sha2 = { version = "^0.9.1", optional = true } ordered-float = "2.0" -unicode-segmentation = "^1.7.1" +unicode-segmentation = { version = "^1.7.1", optional = true } +regex = { version = "^1.4.3", optional = true } +lazy_static = { version = "^1.4.0", optional = true } [dev-dependencies] rand = "0.8" diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index 0213f74b9cf..f8d0d92b516 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -71,15 +71,20 @@ DataFusion includes a simple command-line interactive SQL utility. See the [CLI - [x] lpad - [x] ltrim - [x] octet_length + - [x] regexp_replace - [x] repeat + - [x] replace - [x] reverse - [x] right - [x] rpad - [x] rtrim + - [x] split_part + - [x] starts_with + - [x] strpos - [x] substr - [x] to_hex + - [x] translate - [x] trim - - Miscellaneous/Boolean functions - [x] nullif - Common date/time functions diff --git a/rust/datafusion/src/lib.rs b/rust/datafusion/src/lib.rs index 1e86f664c2d..73dca51a8cf 100644 --- a/rust/datafusion/src/lib.rs +++ b/rust/datafusion/src/lib.rs @@ -170,3 +170,7 @@ pub mod variable; #[cfg(test)] pub mod test; + +#[macro_use] +#[cfg(feature = "regex_expressions")] +extern crate lazy_static; diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 5b0876a79e0..1eaa02b1e41 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -1090,6 +1090,8 @@ unary_scalar_expr!(Lpad, lpad); unary_scalar_expr!(Ltrim, ltrim); unary_scalar_expr!(MD5, md5); unary_scalar_expr!(OctetLength, octet_length); +unary_scalar_expr!(RegexpReplace, regexp_replace); +unary_scalar_expr!(Replace, replace); unary_scalar_expr!(Repeat, repeat); unary_scalar_expr!(Reverse, reverse); unary_scalar_expr!(Right, right); @@ -1099,8 +1101,12 @@ unary_scalar_expr!(SHA224, sha224); unary_scalar_expr!(SHA256, sha256); unary_scalar_expr!(SHA384, sha384); unary_scalar_expr!(SHA512, sha512); +unary_scalar_expr!(SplitPart, split_part); +unary_scalar_expr!(StartsWith, starts_with); +unary_scalar_expr!(Strpos, strpos); unary_scalar_expr!(Substr, substr); unary_scalar_expr!(ToHex, to_hex); +unary_scalar_expr!(Translate, translate); unary_scalar_expr!(Trim, trim); unary_scalar_expr!(Upper, upper); diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index ab787ef82f4..0e7e61981b1 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -37,8 +37,9 @@ pub use expr::{ ceil, character_length, chr, col, combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min, - octet_length, or, repeat, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, - sha512, signum, sin, sqrt, substr, sum, tan, to_hex, trim, trunc, upper, when, Expr, + octet_length, or, regexp_replace, repeat, replace, reverse, right, round, rpad, + rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, + strpos, substr, sum, tan, to_hex, translate, trim, trunc, upper, when, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index ae8d128fc30..9dc54a4113f 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -158,8 +158,12 @@ pub enum BuiltinScalarFunction { NullIf, /// octet_length OctetLength, + /// regexp_replace + RegexpReplace, /// repeat Repeat, + /// replace + Replace, /// reverse Reverse, /// right @@ -176,12 +180,20 @@ pub enum BuiltinScalarFunction { SHA384, /// Sha512 SHA512, + /// split_part + SplitPart, + /// starts_with + StartsWith, + /// strpos + Strpos, /// substr Substr, /// to_hex ToHex, /// to_timestamp ToTimestamp, + /// translate + Translate, /// trim Trim, /// upper @@ -239,7 +251,9 @@ impl FromStr for BuiltinScalarFunction { "md5" => BuiltinScalarFunction::MD5, "nullif" => BuiltinScalarFunction::NullIf, "octet_length" => BuiltinScalarFunction::OctetLength, + "regexp_replace" => BuiltinScalarFunction::RegexpReplace, "repeat" => BuiltinScalarFunction::Repeat, + "replace" => BuiltinScalarFunction::Replace, "reverse" => BuiltinScalarFunction::Reverse, "right" => BuiltinScalarFunction::Right, "rpad" => BuiltinScalarFunction::Rpad, @@ -248,9 +262,13 @@ impl FromStr for BuiltinScalarFunction { "sha256" => BuiltinScalarFunction::SHA256, "sha384" => BuiltinScalarFunction::SHA384, "sha512" => BuiltinScalarFunction::SHA512, + "split_part" => BuiltinScalarFunction::SplitPart, + "starts_with" => BuiltinScalarFunction::StartsWith, + "strpos" => BuiltinScalarFunction::Strpos, "substr" => BuiltinScalarFunction::Substr, "to_hex" => BuiltinScalarFunction::ToHex, "to_timestamp" => BuiltinScalarFunction::ToTimestamp, + "translate" => BuiltinScalarFunction::Translate, "trim" => BuiltinScalarFunction::Trim, "upper" => BuiltinScalarFunction::Upper, @@ -404,6 +422,16 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::RegexpReplace => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The regexp_replace function can only accept strings.".to_string(), + )); + } + }), BuiltinScalarFunction::Repeat => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, @@ -414,6 +442,16 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::Replace => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The replace function can only accept strings.".to_string(), + )); + } + }), BuiltinScalarFunction::Reverse => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, @@ -494,6 +532,27 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::SplitPart => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The split_part function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::StartsWith => Ok(DataType::Boolean), + BuiltinScalarFunction::Strpos => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The strpos function can only accept strings.".to_string(), + )); + } + }), BuiltinScalarFunction::Substr => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, @@ -518,6 +577,16 @@ pub fn return_type( BuiltinScalarFunction::ToTimestamp => { Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) } + BuiltinScalarFunction::Translate => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The translate function can only accept strings.".to_string(), + )); + } + }), BuiltinScalarFunction::Trim => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, @@ -579,6 +648,46 @@ macro_rules! invoke_if_crypto_expressions_feature_flag { }; } +#[cfg(feature = "regex_expressions")] +macro_rules! invoke_if_regex_expressions_feature_flag { + ($FUNC:ident, $T:tt, $NAME:expr) => {{ + use crate::physical_plan::regex_expressions; + regex_expressions::$FUNC::<$T> + }}; +} + +#[cfg(not(feature = "regex_expressions"))] +macro_rules! invoke_if_regex_expressions_feature_flag { + ($FUNC:ident, $T:tt, $NAME:expr) => { + |_: &[ArrayRef]| -> Result { + Err(DataFusionError::Internal(format!( + "function {} requires compilation with feature flag: regex_expressions.", + $NAME + ))) + } + }; +} + +#[cfg(feature = "unicode_expressions")] +macro_rules! invoke_if_unicode_expressions_feature_flag { + ($FUNC:ident, $T:tt, $NAME:expr) => {{ + use crate::physical_plan::unicode_expressions; + unicode_expressions::$FUNC::<$T> + }}; +} + +#[cfg(not(feature = "unicode_expressions"))] +macro_rules! invoke_if_unicode_expressions_feature_flag { + ($FUNC:ident, $T:tt, $NAME:expr) => { + |_: &[ArrayRef]| -> Result { + Err(DataFusionError::Internal(format!( + "function {} requires compilation with feature flag: unicode_expressions.", + $NAME + ))) + } + }; +} + /// Create a physical (function) expression. /// This function errors when `args`' can't be coerced to a valid argument type of the function. pub fn create_physical_expr( @@ -645,12 +754,22 @@ pub fn create_physical_expr( ))), }, BuiltinScalarFunction::CharacterLength => |args| match args[0].data_type() { - DataType::Utf8 => make_scalar_function( - string_expressions::character_length::, - )(args), - DataType::LargeUtf8 => make_scalar_function( - string_expressions::character_length::, - )(args), + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + character_length, + Int32Type, + "character_length" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + character_length, + Int64Type, + "character_length" + ); + make_scalar_function(func)(args) + } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function character_length", other, @@ -678,9 +797,13 @@ pub fn create_physical_expr( ))), }, BuiltinScalarFunction::Left => |args| match args[0].data_type() { - DataType::Utf8 => make_scalar_function(string_expressions::left::)(args), + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!(left, i32, "left"); + make_scalar_function(func)(args) + } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::left::)(args) + let func = invoke_if_unicode_expressions_feature_flag!(left, i64, "left"); + make_scalar_function(func)(args) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function left", @@ -689,9 +812,13 @@ pub fn create_physical_expr( }, BuiltinScalarFunction::Lower => string_expressions::lower, BuiltinScalarFunction::Lpad => |args| match args[0].data_type() { - DataType::Utf8 => make_scalar_function(string_expressions::lpad::)(args), + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!(lpad, i32, "lpad"); + make_scalar_function(func)(args) + } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::lpad::)(args) + let func = invoke_if_unicode_expressions_feature_flag!(lpad, i64, "lpad"); + make_scalar_function(func)(args) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function lpad", @@ -726,6 +853,28 @@ pub fn create_physical_expr( _ => unreachable!(), }, }, + BuiltinScalarFunction::RegexpReplace => |args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_replace, + i32, + "regexp_replace" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_replace, + i64, + "regexp_replace" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function regexp_replace", + other, + ))), + }, BuiltinScalarFunction::Repeat => |args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::repeat::)(args) @@ -738,12 +887,28 @@ pub fn create_physical_expr( other, ))), }, + BuiltinScalarFunction::Replace => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::replace::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::replace::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function replace", + other, + ))), + }, BuiltinScalarFunction::Reverse => |args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::reverse::)(args) + let func = + invoke_if_unicode_expressions_feature_flag!(reverse, i32, "reverse"); + make_scalar_function(func)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::reverse::)(args) + let func = + invoke_if_unicode_expressions_feature_flag!(reverse, i64, "reverse"); + make_scalar_function(func)(args) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function reverse", @@ -752,10 +917,14 @@ pub fn create_physical_expr( }, BuiltinScalarFunction::Right => |args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::right::)(args) + let func = + invoke_if_unicode_expressions_feature_flag!(right, i32, "right"); + make_scalar_function(func)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::right::)(args) + let func = + invoke_if_unicode_expressions_feature_flag!(right, i64, "right"); + make_scalar_function(func)(args) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function right", @@ -763,9 +932,13 @@ pub fn create_physical_expr( ))), }, BuiltinScalarFunction::Rpad => |args| match args[0].data_type() { - DataType::Utf8 => make_scalar_function(string_expressions::rpad::)(args), + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!(rpad, i32, "rpad"); + make_scalar_function(func)(args) + } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::rpad::)(args) + let func = invoke_if_unicode_expressions_feature_flag!(rpad, i64, "rpad"); + make_scalar_function(func)(args) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function rpad", @@ -796,12 +969,58 @@ pub fn create_physical_expr( BuiltinScalarFunction::SHA512 => { invoke_if_crypto_expressions_feature_flag!(sha512, "sha512") } + BuiltinScalarFunction::SplitPart => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::split_part::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::split_part::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function split_part", + other, + ))), + }, + BuiltinScalarFunction::StartsWith => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::starts_with::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::starts_with::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function starts_with", + other, + ))), + }, + BuiltinScalarFunction::Strpos => |args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + strpos, Int32Type, "strpos" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + strpos, Int64Type, "strpos" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function strpos", + other, + ))), + }, BuiltinScalarFunction::Substr => |args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::substr::)(args) + let func = + invoke_if_unicode_expressions_feature_flag!(substr, i32, "substr"); + make_scalar_function(func)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::substr::)(args) + let func = + invoke_if_unicode_expressions_feature_flag!(substr, i64, "substr"); + make_scalar_function(func)(args) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function substr", @@ -821,6 +1040,28 @@ pub fn create_physical_expr( ))), }, BuiltinScalarFunction::ToTimestamp => datetime_expressions::to_timestamp, + BuiltinScalarFunction::Translate => |args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + translate, + i32, + "translate" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + translate, + i64, + "translate" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function translate", + other, + ))), + }, BuiltinScalarFunction::Trim => |args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::btrim::)(args) @@ -941,12 +1182,50 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { DataType::Timestamp(TimeUnit::Nanosecond, None), ]), ]), + BuiltinScalarFunction::SplitPart => Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Int64]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Int64]), + Signature::Exact(vec![DataType::Utf8, DataType::LargeUtf8, DataType::Int64]), + Signature::Exact(vec![ + DataType::LargeUtf8, + DataType::LargeUtf8, + DataType::Int64, + ]), + ]), + + BuiltinScalarFunction::Strpos | BuiltinScalarFunction::StartsWith => { + Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8, DataType::Utf8]), + Signature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), + ]) + } + BuiltinScalarFunction::Substr => Signature::OneOf(vec![ Signature::Exact(vec![DataType::Utf8, DataType::Int64]), Signature::Exact(vec![DataType::LargeUtf8, DataType::Int64]), Signature::Exact(vec![DataType::Utf8, DataType::Int64, DataType::Int64]), Signature::Exact(vec![DataType::LargeUtf8, DataType::Int64, DataType::Int64]), ]), + + BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { + Signature::OneOf(vec![Signature::Exact(vec![ + DataType::Utf8, + DataType::Utf8, + DataType::Utf8, + ])]) + } + BuiltinScalarFunction::RegexpReplace => Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]), + Signature::Exact(vec![ + DataType::Utf8, + DataType::Utf8, + DataType::Utf8, + DataType::Utf8, + ]), + ]), + BuiltinScalarFunction::NullIf => { Signature::Uniform(2, SUPPORTED_NULLIF_TYPES.to_vec()) } @@ -1106,8 +1385,8 @@ mod tests { }; use arrow::{ array::{ - Array, ArrayRef, BinaryArray, FixedSizeListArray, Float64Array, Int32Array, - StringArray, UInt32Array, UInt64Array, + Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, Float64Array, + Int32Array, StringArray, UInt32Array, UInt64Array, }, datatypes::Field, record_batch::RecordBatch, @@ -1311,6 +1590,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( CharacterLength, &[lit(ScalarValue::Utf8(Some("chars".to_string())))], @@ -1319,6 +1599,7 @@ mod tests { Int32, Int32Array ); + #[cfg(feature = "unicode_expressions")] test_function!( CharacterLength, &[lit(ScalarValue::Utf8(Some("josé".to_string())))], @@ -1327,6 +1608,7 @@ mod tests { Int32, Int32Array ); + #[cfg(feature = "unicode_expressions")] test_function!( CharacterLength, &[lit(ScalarValue::Utf8(Some("".to_string())))], @@ -1335,6 +1617,7 @@ mod tests { Int32, Int32Array ); + #[cfg(feature = "unicode_expressions")] test_function!( CharacterLength, &[lit(ScalarValue::Utf8(None))], @@ -1343,6 +1626,17 @@ mod tests { Int32, Int32Array ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + CharacterLength, + &[lit(ScalarValue::Utf8(Some("josé".to_string())))], + Err(DataFusionError::Internal( + "function character_length requires compilation with feature flag: unicode_expressions.".to_string() + )), + i32, + Int32, + Int32Array + ); test_function!( Chr, &[lit(ScalarValue::Int64(Some(128175)))], @@ -1557,6 +1851,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Left, &[ @@ -1568,6 +1863,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Left, &[ @@ -1579,6 +1875,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Left, &[ @@ -1590,6 +1887,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Left, &[ @@ -1601,6 +1899,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Left, &[ @@ -1612,6 +1911,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Left, &[ @@ -1623,6 +1923,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Left, &[ @@ -1634,6 +1935,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Left, &[ @@ -1645,6 +1947,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Left, &[ @@ -1656,6 +1959,21 @@ mod tests { Utf8, StringArray ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + Left, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int8(Some(2))), + ], + Err(DataFusionError::Internal( + "function left requires compilation with feature flag: unicode_expressions.".to_string() + )), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] test_function!( Lpad, &[ @@ -1667,6 +1985,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Lpad, &[ @@ -1678,6 +1997,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Lpad, &[ @@ -1689,6 +2009,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Lpad, &[ @@ -1700,6 +2021,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Lpad, &[ @@ -1711,6 +2033,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Lpad, &[ @@ -1723,6 +2046,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Lpad, &[ @@ -1735,6 +2059,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Lpad, &[ @@ -1747,6 +2072,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Lpad, &[ @@ -1759,6 +2085,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Lpad, &[ @@ -1771,6 +2098,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Lpad, &[ @@ -1783,6 +2111,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Lpad, &[ @@ -1795,6 +2124,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Lpad, &[ @@ -1807,6 +2137,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Lpad, &[ @@ -1819,6 +2150,20 @@ mod tests { Utf8, StringArray ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + Lpad, + &[ + lit(ScalarValue::Utf8(Some("josé".to_string()))), + lit(ScalarValue::Int64(Some(5))), + ], + Err(DataFusionError::Internal( + "function lpad requires compilation with feature flag: unicode_expressions.".to_string() + )), + &str, + Utf8, + StringArray + ); test_function!( Ltrim, &[lit(ScalarValue::Utf8(Some(" trim".to_string())))], @@ -1937,6 +2282,159 @@ mod tests { Int32, Int32Array ); + #[cfg(feature = "regex_expressions")] + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("Thomas".to_string()))), + lit(ScalarValue::Utf8(Some(".[mN]a.".to_string()))), + lit(ScalarValue::Utf8(Some("M".to_string()))), + ], + Ok(Some("ThM")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "regex_expressions")] + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("foobarbaz".to_string()))), + lit(ScalarValue::Utf8(Some("b..".to_string()))), + lit(ScalarValue::Utf8(Some("X".to_string()))), + ], + Ok(Some("fooXbaz")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "regex_expressions")] + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("foobarbaz".to_string()))), + lit(ScalarValue::Utf8(Some("b..".to_string()))), + lit(ScalarValue::Utf8(Some("X".to_string()))), + lit(ScalarValue::Utf8(Some("g".to_string()))), + ], + Ok(Some("fooXX")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "regex_expressions")] + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("foobarbaz".to_string()))), + lit(ScalarValue::Utf8(Some("b(..)".to_string()))), + lit(ScalarValue::Utf8(Some("X\\1Y".to_string()))), + lit(ScalarValue::Utf8(Some("g".to_string()))), + ], + Ok(Some("fooXarYXazY")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "regex_expressions")] + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("b(..)".to_string()))), + lit(ScalarValue::Utf8(Some("X\\1Y".to_string()))), + lit(ScalarValue::Utf8(Some("g".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "regex_expressions")] + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("foobarbaz".to_string()))), + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("X\\1Y".to_string()))), + lit(ScalarValue::Utf8(Some("g".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "regex_expressions")] + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("foobarbaz".to_string()))), + lit(ScalarValue::Utf8(Some("b(..)".to_string()))), + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("g".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "regex_expressions")] + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("foobarbaz".to_string()))), + lit(ScalarValue::Utf8(Some("b(..)".to_string()))), + lit(ScalarValue::Utf8(Some("X\\1Y".to_string()))), + lit(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "regex_expressions")] + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("ABCabcABC".to_string()))), + lit(ScalarValue::Utf8(Some("(abc)".to_string()))), + lit(ScalarValue::Utf8(Some("X".to_string()))), + lit(ScalarValue::Utf8(Some("gi".to_string()))), + ], + Ok(Some("XXX")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "regex_expressions")] + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("ABCabcABC".to_string()))), + lit(ScalarValue::Utf8(Some("(abc)".to_string()))), + lit(ScalarValue::Utf8(Some("X".to_string()))), + lit(ScalarValue::Utf8(Some("i".to_string()))), + ], + Ok(Some("XabcABC")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "regex_expressions"))] + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("foobarbaz".to_string()))), + lit(ScalarValue::Utf8(Some("b..".to_string()))), + lit(ScalarValue::Utf8(Some("X".to_string()))), + ], + Err(DataFusionError::Internal( + "function regexp_replace requires compilation with feature flag: regex_expressions.".to_string() + )), + &str, + Utf8, + StringArray + ); test_function!( Repeat, &[ @@ -1970,6 +2468,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Reverse, &[lit(ScalarValue::Utf8(Some("abcde".to_string())))], @@ -1978,6 +2477,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Reverse, &[lit(ScalarValue::Utf8(Some("loẅks".to_string())))], @@ -1986,6 +2486,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Reverse, &[lit(ScalarValue::Utf8(Some("loẅks".to_string())))], @@ -1994,6 +2495,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Reverse, &[lit(ScalarValue::Utf8(None))], @@ -2002,6 +2504,18 @@ mod tests { Utf8, StringArray ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + Reverse, + &[lit(ScalarValue::Utf8(Some("abcde".to_string())))], + Err(DataFusionError::Internal( + "function reverse requires compilation with feature flag: unicode_expressions.".to_string() + )), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] test_function!( Right, &[ @@ -2013,6 +2527,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Right, &[ @@ -2024,6 +2539,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Right, &[ @@ -2035,6 +2551,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Right, &[ @@ -2046,6 +2563,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Right, &[ @@ -2057,6 +2575,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Right, &[ @@ -2068,6 +2587,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Right, &[ @@ -2079,6 +2599,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Right, &[ @@ -2090,6 +2611,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Right, &[ @@ -2101,6 +2623,21 @@ mod tests { Utf8, StringArray ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int8(Some(2))), + ], + Err(DataFusionError::Internal( + "function right requires compilation with feature flag: unicode_expressions.".to_string() + )), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] test_function!( Rpad, &[ @@ -2112,6 +2649,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Rpad, &[ @@ -2123,6 +2661,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Rpad, &[ @@ -2134,6 +2673,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Rpad, &[ @@ -2145,6 +2685,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Rpad, &[ @@ -2156,6 +2697,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Rpad, &[ @@ -2168,6 +2710,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Rpad, &[ @@ -2180,6 +2723,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Rpad, &[ @@ -2192,6 +2736,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Rpad, &[ @@ -2204,6 +2749,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Rpad, &[ @@ -2216,6 +2762,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Rpad, &[ @@ -2228,6 +2775,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Rpad, &[ @@ -2240,6 +2788,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Rpad, &[ @@ -2252,6 +2801,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Rpad, &[ @@ -2264,6 +2814,20 @@ mod tests { Utf8, StringArray ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("josé".to_string()))), + lit(ScalarValue::Int64(Some(5))), + ], + Err(DataFusionError::Internal( + "function rpad requires compilation with feature flag: unicode_expressions.".to_string() + )), + &str, + Utf8, + StringArray + ); test_function!( Rtrim, &[lit(ScalarValue::Utf8(Some("trim ".to_string())))], @@ -2506,6 +3070,175 @@ mod tests { Binary, BinaryArray ); + test_function!( + SplitPart, + &[ + lit(ScalarValue::Utf8(Some("abc~@~def~@~ghi".to_string()))), + lit(ScalarValue::Utf8(Some("~@~".to_string()))), + lit(ScalarValue::Int64(Some(2))), + ], + Ok(Some("def")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPart, + &[ + lit(ScalarValue::Utf8(Some("abc~@~def~@~ghi".to_string()))), + lit(ScalarValue::Utf8(Some("~@~".to_string()))), + lit(ScalarValue::Int64(Some(20))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPart, + &[ + lit(ScalarValue::Utf8(Some("abc~@~def~@~ghi".to_string()))), + lit(ScalarValue::Utf8(Some("~@~".to_string()))), + lit(ScalarValue::Int64(Some(-1))), + ], + Err(DataFusionError::Execution( + "field position must be greater than zero".to_string(), + )), + &str, + Utf8, + StringArray + ); + test_function!( + StartsWith, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Utf8(Some("alph".to_string()))), + ], + Ok(Some(true)), + bool, + Boolean, + BooleanArray + ); + test_function!( + StartsWith, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Utf8(Some("blph".to_string()))), + ], + Ok(Some(false)), + bool, + Boolean, + BooleanArray + ); + test_function!( + StartsWith, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("alph".to_string()))), + ], + Ok(None), + bool, + Boolean, + BooleanArray + ); + test_function!( + StartsWith, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Utf8(None)), + ], + Ok(None), + bool, + Boolean, + BooleanArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + Strpos, + &[ + lit(ScalarValue::Utf8(Some("abc".to_string()))), + lit(ScalarValue::Utf8(Some("c".to_string()))), + ], + Ok(Some(3)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + Strpos, + &[ + lit(ScalarValue::Utf8(Some("josé".to_string()))), + lit(ScalarValue::Utf8(Some("é".to_string()))), + ], + Ok(Some(4)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + Strpos, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Utf8(Some("so".to_string()))), + ], + Ok(Some(6)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + Strpos, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Utf8(Some("abc".to_string()))), + ], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + Strpos, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("abc".to_string()))), + ], + Ok(None), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + Strpos, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Utf8(None)), + ], + Ok(None), + i32, + Int32, + Int32Array + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + Strpos, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Utf8(None)), + ], + Err(DataFusionError::Internal( + "function strpos requires compilation with feature flag: unicode_expressions.".to_string() + )), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ @@ -2517,6 +3250,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ @@ -2528,6 +3262,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ @@ -2539,6 +3274,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ @@ -2550,6 +3286,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ @@ -2561,6 +3298,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ @@ -2572,6 +3310,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ @@ -2583,6 +3322,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ @@ -2594,6 +3334,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ @@ -2606,6 +3347,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ @@ -2618,6 +3360,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ @@ -2630,6 +3373,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ @@ -2642,6 +3386,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ @@ -2656,6 +3401,7 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ @@ -2668,6 +3414,100 @@ mod tests { Utf8, StringArray ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(0))), + ], + Err(DataFusionError::Internal( + "function substr requires compilation with feature flag: unicode_expressions.".to_string() + )), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + Translate, + &[ + lit(ScalarValue::Utf8(Some("12345".to_string()))), + lit(ScalarValue::Utf8(Some("143".to_string()))), + lit(ScalarValue::Utf8(Some("ax".to_string()))), + ], + Ok(Some("a2x5")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + Translate, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("143".to_string()))), + lit(ScalarValue::Utf8(Some("ax".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + Translate, + &[ + lit(ScalarValue::Utf8(Some("12345".to_string()))), + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("ax".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + Translate, + &[ + lit(ScalarValue::Utf8(Some("12345".to_string()))), + lit(ScalarValue::Utf8(Some("143".to_string()))), + lit(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + Translate, + &[ + lit(ScalarValue::Utf8(Some("é2íñ5".to_string()))), + lit(ScalarValue::Utf8(Some("éñí".to_string()))), + lit(ScalarValue::Utf8(Some("óü".to_string()))), + ], + Ok(Some("ó2ü5")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + Translate, + &[ + lit(ScalarValue::Utf8(Some("12345".to_string()))), + lit(ScalarValue::Utf8(Some("143".to_string()))), + lit(ScalarValue::Utf8(Some("ax".to_string()))), + ], + Err(DataFusionError::Internal( + "function translate requires compilation with feature flag: unicode_expressions.".to_string() + )), + &str, + Utf8, + StringArray + ); test_function!( Trim, &[lit(ScalarValue::Utf8(Some(" trim ".to_string())))], @@ -2700,6 +3540,30 @@ mod tests { Utf8, StringArray ); + test_function!( + Upper, + &[lit(ScalarValue::Utf8(Some("upper".to_string())))], + Ok(Some("UPPER")), + &str, + Utf8, + StringArray + ); + test_function!( + Upper, + &[lit(ScalarValue::Utf8(Some("UPPER".to_string())))], + Ok(Some("UPPER")), + &str, + Utf8, + StringArray + ); + test_function!( + Upper, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); Ok(()) } diff --git a/rust/datafusion/src/physical_plan/mod.rs b/rust/datafusion/src/physical_plan/mod.rs index a3471c88ca5..7702ff2c4da 100644 --- a/rust/datafusion/src/physical_plan/mod.rs +++ b/rust/datafusion/src/physical_plan/mod.rs @@ -298,9 +298,13 @@ pub mod merge; pub mod parquet; pub mod planner; pub mod projection; +#[cfg(feature = "regex_expressions")] +pub mod regex_expressions; pub mod repartition; pub mod sort; pub mod string_expressions; pub mod type_coercion; pub mod udaf; pub mod udf; +#[cfg(feature = "unicode_expressions")] +pub mod unicode_expressions; diff --git a/rust/datafusion/src/physical_plan/regex_expressions.rs b/rust/datafusion/src/physical_plan/regex_expressions.rs new file mode 100644 index 00000000000..8df9a822f31 --- /dev/null +++ b/rust/datafusion/src/physical_plan/regex_expressions.rs @@ -0,0 +1,156 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Some of these functions reference the Postgres documentation +// or implementation to ensure compatibility and are subject to +// the Postgres license. + +//! Regex expressions + +use std::any::type_name; +use std::sync::Arc; + +use crate::error::{DataFusionError, Result}; +use arrow::array::{ArrayRef, GenericStringArray, StringOffsetSizeTrait}; +use hashbrown::HashMap; +use regex::Regex; + +macro_rules! downcast_string_arg { + ($ARG:expr, $NAME:expr, $T:ident) => {{ + $ARG.as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast {} to {}", + $NAME, + type_name::>() + )) + })? + }}; +} + +/// replace POSIX capture groups (like \1) with Rust Regex group (like ${1}) +/// used by regexp_replace +fn regex_replace_posix_groups(replacement: &str) -> String { + lazy_static! { + static ref CAPTURE_GROUPS_RE: Regex = Regex::new("(\\\\)(\\d*)").unwrap(); + } + CAPTURE_GROUPS_RE + .replace_all(replacement, "$${$2}") + .into_owned() +} + +/// Replaces substring(s) matching a POSIX regular expression +/// regexp_replace('Thomas', '.[mN]a.', 'M') = 'ThM' +pub fn regexp_replace(args: &[ArrayRef]) -> Result { + // creating Regex is expensive so create hashmap for memoization + let mut patterns: HashMap = HashMap::new(); + + match args.len() { + 3 => { + let string_array = downcast_string_arg!(args[0], "string", T); + let pattern_array = downcast_string_arg!(args[1], "pattern", T); + let replacement_array = downcast_string_arg!(args[2], "replacement", T); + + let result = string_array + .iter() + .zip(pattern_array.iter()) + .zip(replacement_array.iter()) + .map(|((string, pattern), replacement)| match (string, pattern, replacement) { + (Some(string), Some(pattern), Some(replacement)) => { + let replacement = regex_replace_posix_groups(replacement); + + // if patterns hashmap already has regexp then use else else create and return + let re = match patterns.get(pattern) { + Some(re) => Ok(re.clone()), + None => { + match Regex::new(pattern) { + Ok(re) => { + patterns.insert(pattern.to_string(), re.clone()); + Ok(re) + }, + Err(err) => Err(DataFusionError::Execution(err.to_string())), + } + } + }; + + Some(re.map(|re| re.replace(string, replacement.as_str()))).transpose() + } + _ => Ok(None) + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + 4 => { + let string_array = downcast_string_arg!(args[0], "string", T); + let pattern_array = downcast_string_arg!(args[1], "pattern", T); + let replacement_array = downcast_string_arg!(args[2], "replacement", T); + let flags_array = downcast_string_arg!(args[3], "flags", T); + + let result = string_array + .iter() + .zip(pattern_array.iter()) + .zip(replacement_array.iter()) + .zip(flags_array.iter()) + .map(|(((string, pattern), replacement), flags)| match (string, pattern, replacement, flags) { + (Some(string), Some(pattern), Some(replacement), Some(flags)) => { + let replacement = regex_replace_posix_groups(replacement); + + // format flags into rust pattern + let (pattern, replace_all) = if flags == "g" { + (pattern.to_string(), true) + } else if flags.contains('g') { + (format!("(?{}){}", flags.to_string().replace("g", ""), pattern), true) + } else { + (format!("(?{}){}", flags, pattern), false) + }; + + // if patterns hashmap already has regexp then use else else create and return + let re = match patterns.get(&pattern) { + Some(re) => Ok(re.clone()), + None => { + match Regex::new(pattern.as_str()) { + Ok(re) => { + patterns.insert(pattern, re.clone()); + Ok(re) + }, + Err(err) => Err(DataFusionError::Execution(err.to_string())), + } + } + }; + + Some(re.map(|re| { + if replace_all { + re.replace_all(string, replacement.as_str()) + } else { + re.replace(string, replacement.as_str()) + } + })).transpose() + } + _ => Ok(None) + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "regexp_replace was called with {} arguments. It requires at least 3 and at most 4.", + other + ))), + } +} diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index bc0e7633379..882fe30502f 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -22,8 +22,6 @@ //! String expressions use std::any::type_name; -use std::cmp::Ordering; -use std::str::from_utf8; use std::sync::Arc; use crate::{ @@ -32,12 +30,11 @@ use crate::{ }; use arrow::{ array::{ - Array, ArrayRef, GenericStringArray, Int32Array, Int64Array, PrimitiveArray, - StringArray, StringOffsetSizeTrait, + Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, Int64Array, + PrimitiveArray, StringArray, StringOffsetSizeTrait, }, datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, }; -use unicode_segmentation::UnicodeSegmentation; use super::ColumnarValue; @@ -241,31 +238,6 @@ pub fn btrim(args: &[ArrayRef]) -> Result { } } -/// Returns number of characters in the string. -/// character_length('josé') = 4 -pub fn character_length(args: &[ArrayRef]) -> Result -where - T::Native: StringOffsetSizeTrait, -{ - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal("could not cast string to StringArray".to_string()) - })?; - - let result = string_array - .iter() - .map(|string| { - string.map(|string: &str| { - T::Native::from_usize(string.graphemes(true).count()).unwrap() - }) - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. /// chr(65) = 'A' pub fn chr(args: &[ArrayRef]) -> Result { @@ -426,142 +398,12 @@ pub fn initcap(args: &[ArrayRef]) -> Result Ok(Arc::new(result) as ArrayRef) } -/// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. -/// left('abcde', 2) = 'ab' -pub fn left(args: &[ArrayRef]) -> Result { - let string_array = downcast_string_arg!(args[0], "string", T); - let n_array = downcast_arg!(args[1], "n", Int64Array); - - let result = string_array - .iter() - .zip(n_array.iter()) - .map(|(string, n)| match (string, n) { - (None, _) => None, - (_, None) => None, - (Some(string), Some(n)) => match n.cmp(&0) { - Ordering::Equal => Some(""), - Ordering::Greater => Some( - string - .grapheme_indices(true) - .nth(n as usize) - .map_or(string, |(i, _)| { - &from_utf8(&string.as_bytes()[..i]).unwrap() - }), - ), - Ordering::Less => Some( - string - .grapheme_indices(true) - .rev() - .nth(n.abs() as usize - 1) - .map_or("", |(i, _)| { - &from_utf8(&string.as_bytes()[..i]).unwrap() - }), - ), - }, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - /// Converts the string to all lower case. /// lower('TOM') = 'tom' pub fn lower(args: &[ColumnarValue]) -> Result { handle(args, |string| string.to_ascii_lowercase(), "lower") } -/// Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right). -/// lpad('hi', 5, 'xy') = 'xyxhi' -pub fn lpad(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => { - let string_array = downcast_string_arg!(args[0], "string", T); - let length_array = downcast_arg!(args[1], "length", Int64Array); - - let result = string_array - .iter() - .zip(length_array.iter()) - .map(|(string, length)| match (string, length) { - (None, _) => None, - (_, None) => None, - (Some(string), Some(length)) => { - let length = length as usize; - if length == 0 { - Some("".to_string()) - } else { - let graphemes = string.graphemes(true).collect::>(); - if length < graphemes.len() { - Some(graphemes[..length].concat()) - } else { - let mut s = string.to_string(); - s.insert_str( - 0, - " ".repeat(length - graphemes.len()).as_str(), - ); - Some(s) - } - } - } - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - 3 => { - let string_array = downcast_string_arg!(args[0], "string", T); - let length_array = downcast_arg!(args[1], "length", Int64Array); - let fill_array = downcast_string_arg!(args[2], "fill", T); - - let result = string_array - .iter() - .zip(length_array.iter()) - .zip(fill_array.iter()) - .map(|((string, length), fill)| match (string, length, fill) { - (None, _, _) => None, - (_, None, _) => None, - (_, _, None) => None, - (Some(string), Some(length), Some(fill)) => { - let length = length as usize; - - if length == 0 { - Some("".to_string()) - } else { - let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); - - if length < graphemes.len() { - Some(graphemes[..length].concat()) - } else if fill_chars.is_empty() { - Some(string.to_string()) - } else { - let mut s = string.to_string(); - let mut char_vector = - Vec::::with_capacity(length - graphemes.len()); - for l in 0..length - graphemes.len() { - char_vector.push( - *fill_chars.get(l % fill_chars.len()).unwrap(), - ); - } - s.insert_str( - 0, - char_vector.iter().collect::().as_str(), - ); - Some(s) - } - } - } - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - other => Err(DataFusionError::Internal(format!( - "lpad was called with {} arguments. It requires at least 2 and at most 3.", - other - ))), - } -} - /// Removes the longest string containing only characters in characters (a space by default) from the start of string. /// ltrim('zzzytest', 'xyz') = 'test' pub fn ltrim(args: &[ArrayRef]) -> Result { @@ -584,12 +426,11 @@ pub fn ltrim(args: &[ArrayRef]) -> Result { .iter() .zip(characters_array.iter()) .map(|(string, characters)| match (string, characters) { - (None, _) => None, - (_, None) => None, (Some(string), Some(characters)) => { let chars: Vec = characters.chars().collect(); Some(string.trim_start_matches(&chars[..])) } + _ => None, }) .collect::>(); @@ -612,148 +453,34 @@ pub fn repeat(args: &[ArrayRef]) -> Result { .iter() .zip(number_array.iter()) .map(|(string, number)| match (string, number) { - (None, _) => None, - (_, None) => None, (Some(string), Some(number)) => Some(string.repeat(number as usize)), + _ => None, }) .collect::>(); Ok(Arc::new(result) as ArrayRef) } -/// Reverses the order of the characters in the string. -/// reverse('abcde') = 'edcba' -pub fn reverse(args: &[ArrayRef]) -> Result { - let string_array = downcast_string_arg!(args[0], "string", T); - - let result = string_array - .iter() - .map(|string| { - string.map(|string: &str| string.graphemes(true).rev().collect::()) - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Returns last n characters in the string, or when n is negative, returns all but first |n| characters. -/// right('abcde', 2) = 'de' -pub fn right(args: &[ArrayRef]) -> Result { +/// Replaces all occurrences in string of substring from with substring to. +/// replace('abcdefabcdef', 'cd', 'XX') = 'abXXefabXXef' +pub fn replace(args: &[ArrayRef]) -> Result { let string_array = downcast_string_arg!(args[0], "string", T); - let n_array = downcast_arg!(args[1], "n", Int64Array); + let from_array = downcast_string_arg!(args[1], "from", T); + let to_array = downcast_string_arg!(args[2], "to", T); let result = string_array .iter() - .zip(n_array.iter()) - .map(|(string, n)| match (string, n) { - (None, _) => None, - (_, None) => None, - (Some(string), Some(n)) => match n.cmp(&0) { - Ordering::Equal => Some(""), - Ordering::Greater => Some( - string - .grapheme_indices(true) - .rev() - .nth(n as usize - 1) - .map_or(string, |(i, _)| { - &from_utf8(&string.as_bytes()[i..]).unwrap() - }), - ), - Ordering::Less => Some( - string - .grapheme_indices(true) - .nth(n.abs() as usize) - .map_or("", |(i, _)| { - &from_utf8(&string.as_bytes()[i..]).unwrap() - }), - ), - }, + .zip(from_array.iter()) + .zip(to_array.iter()) + .map(|((string, from), to)| match (string, from, to) { + (Some(string), Some(from), Some(to)) => Some(string.replace(from, to)), + _ => None, }) .collect::>(); Ok(Arc::new(result) as ArrayRef) } -/// Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. -/// rpad('hi', 5, 'xy') = 'hixyx' -pub fn rpad(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => { - let string_array = downcast_string_arg!(args[0], "string", T); - let length_array = downcast_arg!(args[1], "length", Int64Array); - - let result = string_array - .iter() - .zip(length_array.iter()) - .map(|(string, length)| match (string, length) { - (None, _) => None, - (_, None) => None, - (Some(string), Some(length)) => { - let length = length as usize; - if length == 0 { - Some("".to_string()) - } else { - let graphemes = string.graphemes(true).collect::>(); - if length < graphemes.len() { - Some(graphemes[..length].concat()) - } else { - let mut s = string.to_string(); - s.push_str(" ".repeat(length - graphemes.len()).as_str()); - Some(s) - } - } - } - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - 3 => { - let string_array = downcast_string_arg!(args[0], "string", T); - let length_array = downcast_arg!(args[1], "length", Int64Array); - let fill_array = downcast_string_arg!(args[2], "fill", T); - - let result = string_array - .iter() - .zip(length_array.iter()) - .zip(fill_array.iter()) - .map(|((string, length), fill)| match (string, length, fill) { - (None, _, _) => None, - (_, None, _) => None, - (_, _, None) => None, - (Some(string), Some(length), Some(fill)) => { - let length = length as usize; - let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); - - if length < graphemes.len() { - Some(graphemes[..length].concat()) - } else if fill_chars.is_empty() { - Some(string.to_string()) - } else { - let mut s = string.to_string(); - let mut char_vector = - Vec::::with_capacity(length - graphemes.len()); - for l in 0..length - graphemes.len() { - char_vector - .push(*fill_chars.get(l % fill_chars.len()).unwrap()); - } - s.push_str(char_vector.iter().collect::().as_str()); - Some(s) - } - } - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - other => Err(DataFusionError::Internal(format!( - "rpad was called with {} arguments. It requires at least 2 and at most 3.", - other - ))), - } -} - /// Removes the longest string containing only characters in characters (a space by default) from the end of string. /// rtrim('testxxzx', 'xyz') = 'test' pub fn rtrim(args: &[ArrayRef]) -> Result { @@ -776,12 +503,11 @@ pub fn rtrim(args: &[ArrayRef]) -> Result { .iter() .zip(characters_array.iter()) .map(|(string, characters)| match (string, characters) { - (None, _) => None, - (_, None) => None, (Some(string), Some(characters)) => { let chars: Vec = characters.chars().collect(); Some(string.trim_end_matches(&chars[..])) } + _ => None, }) .collect::>(); @@ -794,85 +520,54 @@ pub fn rtrim(args: &[ArrayRef]) -> Result { } } -/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) -/// substr('alphabet', 3) = 'phabet' -/// substr('alphabet', 3, 2) = 'ph' -pub fn substr(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => { - let string_array = downcast_string_arg!(args[0], "string", T); - let start_array = downcast_arg!(args[1], "start", Int64Array); +/// Splits string at occurrences of delimiter and returns the n'th field (counting from one). +/// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def' +pub fn split_part(args: &[ArrayRef]) -> Result { + let string_array = downcast_string_arg!(args[0], "string", T); + let delimiter_array = downcast_string_arg!(args[1], "delimiter", T); + let n_array = downcast_arg!(args[2], "n", Int64Array); - let result = string_array - .iter() - .zip(start_array.iter()) - .map(|(string, start)| match (string, start) { - (None, _) => None, - (_, None) => None, - (Some(string), Some(start)) => { - if start <= 0 { - Some(string.to_string()) - } else { - let graphemes = string.graphemes(true).collect::>(); - let start_pos = start as usize - 1; - if graphemes.len() < start_pos { - Some("".to_string()) - } else { - Some(graphemes[start_pos..].concat()) - } - } + let result = string_array + .iter() + .zip(delimiter_array.iter()) + .zip(n_array.iter()) + .map(|((string, delimiter), n)| match (string, delimiter, n) { + (Some(string), Some(delimiter), Some(n)) => { + if n <= 0 { + Err(DataFusionError::Execution( + "field position must be greater than zero".to_string(), + )) + } else { + let split_string: Vec<&str> = string.split(delimiter).collect(); + match split_string.get(n as usize - 1) { + Some(s) => Ok(Some(*s)), + None => Ok(Some("")), } - }) - .collect::>(); + } + } + _ => Ok(None), + }) + .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) - } - 3 => { - let string_array = downcast_string_arg!(args[0], "string", T); - let start_array = downcast_arg!(args[1], "start", Int64Array); - let count_array = downcast_arg!(args[2], "count", Int64Array); + Ok(Arc::new(result) as ArrayRef) +} - let result = string_array - .iter() - .zip(start_array.iter()) - .zip(count_array.iter()) - .map(|((string, start), count)| match (string, start, count) { - (None, _, _) => Ok(None), - (_, None, _) => Ok(None), - (_, _, None) => Ok(None), - (Some(string), Some(start), Some(count)) => { - if count < 0 { - Err(DataFusionError::Execution( - "negative substring length not allowed".to_string(), - )) - } else if start <= 0 { - Ok(Some(string.to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - let start_pos = start as usize - 1; - let count_usize = count as usize; - if graphemes.len() < start_pos { - Ok(Some("".to_string())) - } else if graphemes.len() < start_pos + count_usize { - Ok(Some(graphemes[start_pos..].concat())) - } else { - Ok(Some( - graphemes[start_pos..start_pos + count_usize] - .concat(), - )) - } - } - } - }) - .collect::>>()?; +/// Returns true if string starts with prefix. +/// starts_with('alphabet', 'alph') = 't' +pub fn starts_with(args: &[ArrayRef]) -> Result { + let string_array = downcast_string_arg!(args[0], "string", T); + let prefix_array = downcast_string_arg!(args[1], "prefix", T); - Ok(Arc::new(result) as ArrayRef) - } - other => Err(DataFusionError::Internal(format!( - "substr was called with {} arguments. It requires 2 or 3.", - other - ))), - } + let result = string_array + .iter() + .zip(prefix_array.iter()) + .map(|(string, prefix)| match (string, prefix) { + (Some(string), Some(prefix)) => Some(string.starts_with(prefix)), + _ => None, + }) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) } /// Converts the number to its equivalent hexadecimal representation. diff --git a/rust/datafusion/src/physical_plan/unicode_expressions.rs b/rust/datafusion/src/physical_plan/unicode_expressions.rs new file mode 100644 index 00000000000..787ea7ea267 --- /dev/null +++ b/rust/datafusion/src/physical_plan/unicode_expressions.rs @@ -0,0 +1,532 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Some of these functions reference the Postgres documentation +// or implementation to ensure compatibility and are subject to +// the Postgres license. + +//! Unicode expressions + +use std::any::type_name; +use std::cmp::Ordering; +use std::sync::Arc; + +use crate::error::{DataFusionError, Result}; +use arrow::{ + array::{ + ArrayRef, GenericStringArray, Int64Array, PrimitiveArray, StringOffsetSizeTrait, + }, + datatypes::{ArrowNativeType, ArrowPrimitiveType}, +}; +use hashbrown::HashMap; +use unicode_segmentation::UnicodeSegmentation; + +macro_rules! downcast_string_arg { + ($ARG:expr, $NAME:expr, $T:ident) => {{ + $ARG.as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast {} to {}", + $NAME, + type_name::>() + )) + })? + }}; +} + +macro_rules! downcast_arg { + ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast {} to {}", + $NAME, + type_name::<$ARRAY_TYPE>() + )) + })? + }}; +} + +/// Returns number of characters in the string. +/// character_length('josé') = 4 +pub fn character_length(args: &[ArrayRef]) -> Result +where + T::Native: StringOffsetSizeTrait, +{ + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal("could not cast string to StringArray".to_string()) + })?; + + let result = string_array + .iter() + .map(|string| { + string.map(|string: &str| { + T::Native::from_usize(string.graphemes(true).count()).expect( + "should not fail as graphemes.count will always return integer", + ) + }) + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. +/// left('abcde', 2) = 'ab' +pub fn left(args: &[ArrayRef]) -> Result { + let string_array = downcast_string_arg!(args[0], "string", T); + let n_array = downcast_arg!(args[1], "n", Int64Array); + + let result = string_array + .iter() + .zip(n_array.iter()) + .map(|(string, n)| match (string, n) { + (Some(string), Some(n)) => match n.cmp(&0) { + Ordering::Less => { + let graphemes = string.graphemes(true); + let len = graphemes.clone().count() as i64; + match n.abs().cmp(&len) { + Ordering::Less => { + Some(graphemes.take((len + n) as usize).collect::()) + } + Ordering::Equal => Some("".to_string()), + Ordering::Greater => Some("".to_string()), + } + } + Ordering::Equal => Some("".to_string()), + Ordering::Greater => { + Some(string.graphemes(true).take(n as usize).collect::()) + } + }, + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Extends the string to length 'length' by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right). +/// lpad('hi', 5, 'xy') = 'xyxhi' +pub fn lpad(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let string_array = downcast_string_arg!(args[0], "string", T); + let length_array = downcast_arg!(args[1], "length", Int64Array); + + let result = string_array + .iter() + .zip(length_array.iter()) + .map(|(string, length)| match (string, length) { + (Some(string), Some(length)) => { + let length = length as usize; + if length == 0 { + Some("".to_string()) + } else { + let graphemes = string.graphemes(true).collect::>(); + if length < graphemes.len() { + Some(graphemes[..length].concat()) + } else { + let mut s = string.to_string(); + s.insert_str( + 0, + " ".repeat(length - graphemes.len()).as_str(), + ); + Some(s) + } + } + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 3 => { + let string_array = downcast_string_arg!(args[0], "string", T); + let length_array = downcast_arg!(args[1], "length", Int64Array); + let fill_array = downcast_string_arg!(args[2], "fill", T); + + let result = string_array + .iter() + .zip(length_array.iter()) + .zip(fill_array.iter()) + .map(|((string, length), fill)| match (string, length, fill) { + (Some(string), Some(length), Some(fill)) => { + let length = length as usize; + + if length == 0 { + Some("".to_string()) + } else { + let graphemes = string.graphemes(true).collect::>(); + let fill_chars = fill.chars().collect::>(); + + if length < graphemes.len() { + Some(graphemes[..length].concat()) + } else if fill_chars.is_empty() { + Some(string.to_string()) + } else { + let mut s = string.to_string(); + let mut char_vector = + Vec::::with_capacity(length - graphemes.len()); + for l in 0..length - graphemes.len() { + char_vector.push( + *fill_chars.get(l % fill_chars.len()).unwrap(), + ); + } + s.insert_str( + 0, + char_vector.iter().collect::().as_str(), + ); + Some(s) + } + } + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "lpad was called with {} arguments. It requires at least 2 and at most 3.", + other + ))), + } +} + +/// Reverses the order of the characters in the string. +/// reverse('abcde') = 'edcba' +pub fn reverse(args: &[ArrayRef]) -> Result { + let string_array = downcast_string_arg!(args[0], "string", T); + + let result = string_array + .iter() + .map(|string| { + string.map(|string: &str| string.graphemes(true).rev().collect::()) + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Returns last n characters in the string, or when n is negative, returns all but first |n| characters. +/// right('abcde', 2) = 'de' +pub fn right(args: &[ArrayRef]) -> Result { + let string_array = downcast_string_arg!(args[0], "string", T); + let n_array = downcast_arg!(args[1], "n", Int64Array); + + let result = string_array + .iter() + .zip(n_array.iter()) + .map(|(string, n)| match (string, n) { + (Some(string), Some(n)) => match n.cmp(&0) { + Ordering::Less => { + let graphemes = string.graphemes(true).rev(); + let len = graphemes.clone().count() as i64; + match n.abs().cmp(&len) { + Ordering::Less => Some( + graphemes + .take((len + n) as usize) + .collect::>() + .iter() + .rev() + .copied() + .collect::(), + ), + Ordering::Equal => Some("".to_string()), + Ordering::Greater => Some("".to_string()), + } + } + Ordering::Equal => Some("".to_string()), + Ordering::Greater => Some( + string + .graphemes(true) + .rev() + .take(n as usize) + .collect::>() + .iter() + .rev() + .copied() + .collect::(), + ), + }, + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. +/// rpad('hi', 5, 'xy') = 'hixyx' +pub fn rpad(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let string_array = downcast_string_arg!(args[0], "string", T); + let length_array = downcast_arg!(args[1], "length", Int64Array); + + let result = string_array + .iter() + .zip(length_array.iter()) + .map(|(string, length)| match (string, length) { + (Some(string), Some(length)) => { + let length = length as usize; + if length == 0 { + Some("".to_string()) + } else { + let graphemes = string.graphemes(true).collect::>(); + if length < graphemes.len() { + Some(graphemes[..length].concat()) + } else { + let mut s = string.to_string(); + s.push_str(" ".repeat(length - graphemes.len()).as_str()); + Some(s) + } + } + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 3 => { + let string_array = downcast_string_arg!(args[0], "string", T); + let length_array = downcast_arg!(args[1], "length", Int64Array); + let fill_array = downcast_string_arg!(args[2], "fill", T); + + let result = string_array + .iter() + .zip(length_array.iter()) + .zip(fill_array.iter()) + .map(|((string, length), fill)| match (string, length, fill) { + (Some(string), Some(length), Some(fill)) => { + let length = length as usize; + let graphemes = string.graphemes(true).collect::>(); + let fill_chars = fill.chars().collect::>(); + + if length < graphemes.len() { + Some(graphemes[..length].concat()) + } else if fill_chars.is_empty() { + Some(string.to_string()) + } else { + let mut s = string.to_string(); + let mut char_vector = + Vec::::with_capacity(length - graphemes.len()); + for l in 0..length - graphemes.len() { + char_vector + .push(*fill_chars.get(l % fill_chars.len()).unwrap()); + } + s.push_str(char_vector.iter().collect::().as_str()); + Some(s) + } + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "rpad was called with {} arguments. It requires at least 2 and at most 3.", + other + ))), + } +} + +/// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.) +/// strpos('high', 'ig') = 2 +pub fn strpos(args: &[ArrayRef]) -> Result +where + T::Native: StringOffsetSizeTrait, +{ + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal("could not cast string to StringArray".to_string()) + })?; + + let substring_array: &GenericStringArray = args[1] + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast substring to StringArray".to_string(), + ) + })?; + + let result = string_array + .iter() + .zip(substring_array.iter()) + .map(|(string, substring)| match (string, substring) { + (Some(string), Some(substring)) => { + // the rfind method returns the byte index of the substring which may or may not be the same as the character index due to UTF8 encoding + // this method first finds the matching byte using rfind + // then maps that to the character index by matching on the grapheme_index of the byte_index + Some( + T::Native::from_usize(string.to_string().rfind(substring).map_or( + 0, + |byte_offset| { + string + .grapheme_indices(true) + .collect::>() + .iter() + .enumerate() + .filter(|(_, (offset, _))| *offset == byte_offset) + .map(|(index, _)| index) + .collect::>() + .first() + .expect("should not fail as grapheme_indices and byte offsets are tightly coupled") + .to_owned() + + 1 + }, + )) + .expect("should not fail due to map_or default value") + ) + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) +/// substr('alphabet', 3) = 'phabet' +/// substr('alphabet', 3, 2) = 'ph' +pub fn substr(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let string_array = downcast_string_arg!(args[0], "string", T); + let start_array = downcast_arg!(args[1], "start", Int64Array); + + let result = string_array + .iter() + .zip(start_array.iter()) + .map(|(string, start)| match (string, start) { + (Some(string), Some(start)) => { + if start <= 0 { + Some(string.to_string()) + } else { + let graphemes = string.graphemes(true).collect::>(); + let start_pos = start as usize - 1; + if graphemes.len() < start_pos { + Some("".to_string()) + } else { + Some(graphemes[start_pos..].concat()) + } + } + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 3 => { + let string_array = downcast_string_arg!(args[0], "string", T); + let start_array = downcast_arg!(args[1], "start", Int64Array); + let count_array = downcast_arg!(args[2], "count", Int64Array); + + let result = string_array + .iter() + .zip(start_array.iter()) + .zip(count_array.iter()) + .map(|((string, start), count)| match (string, start, count) { + (Some(string), Some(start), Some(count)) => { + if count < 0 { + Err(DataFusionError::Execution( + "negative substring length not allowed".to_string(), + )) + } else if start <= 0 { + Ok(Some(string.to_string())) + } else { + let graphemes = string.graphemes(true).collect::>(); + let start_pos = start as usize - 1; + let count_usize = count as usize; + if graphemes.len() < start_pos { + Ok(Some("".to_string())) + } else if graphemes.len() < start_pos + count_usize { + Ok(Some(graphemes[start_pos..].concat())) + } else { + Ok(Some( + graphemes[start_pos..start_pos + count_usize] + .concat(), + )) + } + } + } + _ => Ok(None), + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "substr was called with {} arguments. It requires 2 or 3.", + other + ))), + } +} + +/// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted. +/// translate('12345', '143', 'ax') = 'a2x5' +pub fn translate(args: &[ArrayRef]) -> Result { + let string_array = downcast_string_arg!(args[0], "string", T); + let from_array = downcast_string_arg!(args[1], "from", T); + let to_array = downcast_string_arg!(args[2], "to", T); + + let result = string_array + .iter() + .zip(from_array.iter()) + .zip(to_array.iter()) + .map(|((string, from), to)| match (string, from, to) { + (Some(string), Some(from), Some(to)) => { + // create a hashmap of [char, index] to change from O(n) to O(1) for from list + let from_map: HashMap<&str, usize> = from + .graphemes(true) + .collect::>() + .iter() + .enumerate() + .map(|(index, c)| (c.to_owned(), index)) + .collect(); + + let to = to.graphemes(true).collect::>(); + + Some( + string + .graphemes(true) + .collect::>() + .iter() + .flat_map(|c| match from_map.get(*c) { + Some(n) => to.get(*n).copied(), + None => Some(*c), + }) + .collect::>() + .concat(), + ) + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs index 5ab5e760246..0edc82a98af 100644 --- a/rust/datafusion/src/prelude.rs +++ b/rust/datafusion/src/prelude.rs @@ -30,7 +30,8 @@ pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::logical_plan::{ array, ascii, avg, bit_length, btrim, character_length, chr, col, concat, concat_ws, count, create_udf, in_list, initcap, left, length, lit, lower, lpad, ltrim, max, md5, - min, octet_length, repeat, reverse, right, rpad, rtrim, sha224, sha256, sha384, - sha512, substr, sum, to_hex, trim, upper, JoinType, Partitioning, + min, octet_length, regexp_replace, repeat, replace, reverse, right, rpad, rtrim, + sha224, sha256, sha384, sha512, split_part, starts_with, strpos, substr, sum, to_hex, + translate, trim, upper, JoinType, Partitioning, }; pub use crate::physical_plan::csv::CsvReadOptions; diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index c8e198cb13c..d7bea20abe6 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -1582,11 +1582,13 @@ async fn generic_query_length>>( } #[tokio::test] +#[cfg_attr(not(feature = "unicode_expressions"), ignore)] async fn query_length() -> Result<()> { generic_query_length::(DataType::Utf8).await } #[tokio::test] +#[cfg_attr(not(feature = "unicode_expressions"), ignore)] async fn query_large_length() -> Result<()> { generic_query_length::(DataType::LargeUtf8).await } @@ -2033,125 +2035,42 @@ macro_rules! test_expression { } #[tokio::test] -async fn test_string_expressions() -> Result<()> { - test_expression!("ascii('')", "0"); - test_expression!("ascii('x')", "120"); - test_expression!("ascii(NULL)", "NULL"); - test_expression!("bit_length('')", "0"); - test_expression!("bit_length('chars')", "40"); - test_expression!("bit_length('josé')", "40"); - test_expression!("bit_length(NULL)", "NULL"); - test_expression!("btrim(' xyxtrimyyx ', NULL)", "NULL"); - test_expression!("btrim(' xyxtrimyyx ')", "xyxtrimyyx"); - test_expression!("btrim('\n xyxtrimyyx \n')", "\n xyxtrimyyx \n"); - test_expression!("btrim('xyxtrimyyx', 'xyz')", "trim"); - test_expression!("btrim('\nxyxtrimyyx\n', 'xyz\n')", "trim"); - test_expression!("btrim(NULL, 'xyz')", "NULL"); - test_expression!("char_length('')", "0"); - test_expression!("char_length('chars')", "5"); - test_expression!("char_length(NULL)", "NULL"); - test_expression!("character_length('')", "0"); - test_expression!("character_length('chars')", "5"); - test_expression!("character_length('josé')", "4"); - test_expression!("character_length(NULL)", "NULL"); - test_expression!("chr(CAST(120 AS int))", "x"); - test_expression!("chr(CAST(128175 AS int))", "💯"); - test_expression!("chr(CAST(NULL AS int))", "NULL"); - test_expression!("concat('a','b','c')", "abc"); - test_expression!("concat('abcde', 2, NULL, 22)", "abcde222"); - test_expression!("concat(NULL)", ""); - test_expression!("concat_ws(',', 'abcde', 2, NULL, 22)", "abcde,2,22"); - test_expression!("concat_ws('|','a','b','c')", "a|b|c"); - test_expression!("concat_ws('|',NULL)", ""); - test_expression!("concat_ws(NULL,'a',NULL,'b','c')", "NULL"); - test_expression!("initcap('')", ""); - test_expression!("initcap('hi THOMAS')", "Hi Thomas"); - test_expression!("initcap(NULL)", "NULL"); - test_expression!("left('abcde', -2)", "abc"); - test_expression!("left('abcde', -200)", ""); - test_expression!("left('abcde', 0)", ""); - test_expression!("left('abcde', 2)", "ab"); - test_expression!("left('abcde', 200)", "abcde"); - test_expression!("left('abcde', CAST(NULL AS INT))", "NULL"); - test_expression!("left(NULL, 2)", "NULL"); - test_expression!("left(NULL, CAST(NULL AS INT))", "NULL"); - test_expression!("lower('')", ""); - test_expression!("lower('TOM')", "tom"); - test_expression!("lower(NULL)", "NULL"); - test_expression!("lpad('hi', 5, 'xy')", "xyxhi"); - test_expression!("lpad('hi', 0)", ""); - test_expression!("lpad('hi', 21, 'abcdef')", "abcdefabcdefabcdefahi"); - test_expression!("lpad('hi', 5, 'xy')", "xyxhi"); - test_expression!("lpad('hi', 5, NULL)", "NULL"); - test_expression!("lpad('hi', 5)", " hi"); - test_expression!("lpad('hi', CAST(NULL AS INT), 'xy')", "NULL"); - test_expression!("lpad('hi', CAST(NULL AS INT))", "NULL"); - test_expression!("lpad('xyxhi', 3)", "xyx"); - test_expression!("lpad(NULL, 0)", "NULL"); - test_expression!("lpad(NULL, 5, 'xy')", "NULL"); - test_expression!("ltrim(' zzzytest ', NULL)", "NULL"); - test_expression!("ltrim(' zzzytest ')", "zzzytest "); - test_expression!("ltrim('zzzytest', 'xyz')", "test"); - test_expression!("ltrim(NULL, 'xyz')", "NULL"); - test_expression!("octet_length('')", "0"); - test_expression!("octet_length('chars')", "5"); - test_expression!("octet_length('josé')", "5"); - test_expression!("octet_length(NULL)", "NULL"); - test_expression!("repeat('Pg', 4)", "PgPgPgPg"); - test_expression!("repeat('Pg', CAST(NULL AS INT))", "NULL"); - test_expression!("repeat(NULL, 4)", "NULL"); - test_expression!("reverse('abcde')", "edcba"); - test_expression!("reverse('loẅks')", "skẅol"); - test_expression!("reverse(NULL)", "NULL"); - test_expression!("right('abcde', -2)", "cde"); - test_expression!("right('abcde', -200)", ""); - test_expression!("right('abcde', 0)", ""); - test_expression!("right('abcde', 2)", "de"); - test_expression!("right('abcde', 200)", "abcde"); - test_expression!("right('abcde', CAST(NULL AS INT))", "NULL"); - test_expression!("right(NULL, 2)", "NULL"); - test_expression!("right(NULL, CAST(NULL AS INT))", "NULL"); - test_expression!("rpad('hi', 5, 'xy')", "hixyx"); - test_expression!("rpad('hi', 0)", ""); - test_expression!("rpad('hi', 21, 'abcdef')", "hiabcdefabcdefabcdefa"); - test_expression!("rpad('hi', 5, 'xy')", "hixyx"); - test_expression!("rpad('hi', 5, NULL)", "NULL"); - test_expression!("rpad('hi', 5)", "hi "); - test_expression!("rpad('hi', CAST(NULL AS INT), 'xy')", "NULL"); - test_expression!("rpad('hi', CAST(NULL AS INT))", "NULL"); - test_expression!("rpad('xyxhi', 3)", "xyx"); - test_expression!("rtrim(' testxxzx ')", " testxxzx"); - test_expression!("rtrim(' zzzytest ', NULL)", "NULL"); - test_expression!("rtrim('testxxzx', 'xyz')", "test"); - test_expression!("rtrim(NULL, 'xyz')", "NULL"); - test_expression!("substr('alphabet', -3)", "alphabet"); - test_expression!("substr('alphabet', 0)", "alphabet"); - test_expression!("substr('alphabet', 1)", "alphabet"); - test_expression!("substr('alphabet', 2)", "lphabet"); - test_expression!("substr('alphabet', 3)", "phabet"); - test_expression!("substr('alphabet', 30)", ""); - test_expression!("substr('alphabet', CAST(NULL AS int))", "NULL"); - test_expression!("substr('alphabet', 3, 2)", "ph"); - test_expression!("substr('alphabet', 3, 20)", "phabet"); - test_expression!("substr('alphabet', CAST(NULL AS int), 20)", "NULL"); - test_expression!("substr('alphabet', 3, CAST(NULL AS int))", "NULL"); - test_expression!("to_hex(2147483647)", "7fffffff"); - test_expression!("to_hex(9223372036854775807)", "7fffffffffffffff"); - test_expression!("to_hex(CAST(NULL AS int))", "NULL"); - test_expression!("trim(' tom ')", "tom"); - test_expression!("trim(' tom')", "tom"); - test_expression!("trim('')", ""); - test_expression!("trim('tom ')", "tom"); - test_expression!("upper('')", ""); - test_expression!("upper('tom')", "TOM"); - test_expression!("upper(NULL)", "NULL"); +async fn test_boolean_expressions() -> Result<()> { + test_expression!("true", "true"); + test_expression!("false", "false"); Ok(()) } #[tokio::test] -async fn test_boolean_expressions() -> Result<()> { - test_expression!("true", "true"); - test_expression!("false", "false"); +#[cfg_attr(not(feature = "crypto_expressions"), ignore)] +async fn test_crypto_expressions() -> Result<()> { + test_expression!("md5('tom')", "34b7da764b21d298ef307d04d8152dc5"); + test_expression!("md5('')", "d41d8cd98f00b204e9800998ecf8427e"); + test_expression!("md5(NULL)", "NULL"); + test_expression!( + "sha224('tom')", + "0bf6cb62649c42a9ae3876ab6f6d92ad36cb5414e495f8873292be4d" + ); + test_expression!( + "sha224('')", + "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f" + ); + test_expression!("sha224(NULL)", "NULL"); + test_expression!( + "sha256('tom')", + "e1608f75c5d7813f3d4031cb30bfb786507d98137538ff8e128a6ff74e84e643" + ); + test_expression!( + "sha256('')", + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + ); + test_expression!("sha256(NULL)", "NULL"); + test_expression!("sha384('tom')", "096f5b68aa77848e4fdf5c1c0b350de2dbfad60ffd7c25d9ea07c6c19b8a4d55a9187eb117c557883f58c16dfac3e343"); + test_expression!("sha384('')", "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b"); + test_expression!("sha384(NULL)", "NULL"); + test_expression!("sha512('tom')", "6e1b9b3fe840680e37051f7ad5e959d6f39ad0f8885d855166f55c659469d3c8b78118c44a2a49c72ddb481cd6d8731034e11cc030070ba843a90b3495cb8d3e"); + test_expression!("sha512('')", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e"); + test_expression!("sha512(NULL)", "NULL"); Ok(()) } @@ -2273,35 +2192,178 @@ async fn test_interval_expressions() -> Result<()> { } #[tokio::test] -#[cfg_attr(not(feature = "crypto_expressions"), ignore)] -async fn test_crypto_expressions() -> Result<()> { - test_expression!("md5('tom')", "34b7da764b21d298ef307d04d8152dc5"); - test_expression!("md5('')", "d41d8cd98f00b204e9800998ecf8427e"); - test_expression!("md5(NULL)", "NULL"); - test_expression!( - "sha224('tom')", - "0bf6cb62649c42a9ae3876ab6f6d92ad36cb5414e495f8873292be4d" - ); +async fn test_string_expressions() -> Result<()> { + test_expression!("ascii('')", "0"); + test_expression!("ascii('x')", "120"); + test_expression!("ascii(NULL)", "NULL"); + test_expression!("bit_length('')", "0"); + test_expression!("bit_length('chars')", "40"); + test_expression!("bit_length('josé')", "40"); + test_expression!("bit_length(NULL)", "NULL"); + test_expression!("btrim(' xyxtrimyyx ', NULL)", "NULL"); + test_expression!("btrim(' xyxtrimyyx ')", "xyxtrimyyx"); + test_expression!("btrim('\n xyxtrimyyx \n')", "\n xyxtrimyyx \n"); + test_expression!("btrim('xyxtrimyyx', 'xyz')", "trim"); + test_expression!("btrim('\nxyxtrimyyx\n', 'xyz\n')", "trim"); + test_expression!("btrim(NULL, 'xyz')", "NULL"); + test_expression!("chr(CAST(120 AS int))", "x"); + test_expression!("chr(CAST(128175 AS int))", "💯"); + test_expression!("chr(CAST(NULL AS int))", "NULL"); + test_expression!("concat('a','b','c')", "abc"); + test_expression!("concat('abcde', 2, NULL, 22)", "abcde222"); + test_expression!("concat(NULL)", ""); + test_expression!("concat_ws(',', 'abcde', 2, NULL, 22)", "abcde,2,22"); + test_expression!("concat_ws('|','a','b','c')", "a|b|c"); + test_expression!("concat_ws('|',NULL)", ""); + test_expression!("concat_ws(NULL,'a',NULL,'b','c')", "NULL"); + test_expression!("initcap('')", ""); + test_expression!("initcap('hi THOMAS')", "Hi Thomas"); + test_expression!("initcap(NULL)", "NULL"); + test_expression!("lower('')", ""); + test_expression!("lower('TOM')", "tom"); + test_expression!("lower(NULL)", "NULL"); + test_expression!("ltrim(' zzzytest ', NULL)", "NULL"); + test_expression!("ltrim(' zzzytest ')", "zzzytest "); + test_expression!("ltrim('zzzytest', 'xyz')", "test"); + test_expression!("ltrim(NULL, 'xyz')", "NULL"); + test_expression!("octet_length('')", "0"); + test_expression!("octet_length('chars')", "5"); + test_expression!("octet_length('josé')", "5"); + test_expression!("octet_length(NULL)", "NULL"); + test_expression!("repeat('Pg', 4)", "PgPgPgPg"); + test_expression!("repeat('Pg', CAST(NULL AS INT))", "NULL"); + test_expression!("repeat(NULL, 4)", "NULL"); + test_expression!("replace('abcdefabcdef', 'cd', 'XX')", "abXXefabXXef"); + test_expression!("replace('abcdefabcdef', 'cd', NULL)", "NULL"); + test_expression!("replace('abcdefabcdef', 'notmatch', 'XX')", "abcdefabcdef"); + test_expression!("replace('abcdefabcdef', NULL, 'XX')", "NULL"); + test_expression!("replace(NULL, 'cd', 'XX')", "NULL"); + test_expression!("rtrim(' testxxzx ')", " testxxzx"); + test_expression!("rtrim(' zzzytest ', NULL)", "NULL"); + test_expression!("rtrim('testxxzx', 'xyz')", "test"); + test_expression!("rtrim(NULL, 'xyz')", "NULL"); + test_expression!("split_part('abc~@~def~@~ghi', '~@~', 2)", "def"); + test_expression!("split_part('abc~@~def~@~ghi', '~@~', 20)", ""); + test_expression!("split_part(NULL, '~@~', 20)", "NULL"); + test_expression!("split_part('abc~@~def~@~ghi', NULL, 20)", "NULL"); test_expression!( - "sha224('')", - "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f" + "split_part('abc~@~def~@~ghi', '~@~', CAST(NULL AS INT))", + "NULL" ); - test_expression!("sha224(NULL)", "NULL"); + test_expression!("starts_with('alphabet', 'alph')", "true"); + test_expression!("starts_with('alphabet', 'blph')", "false"); + test_expression!("starts_with(NULL, 'blph')", "NULL"); + test_expression!("starts_with('alphabet', NULL)", "NULL"); + test_expression!("to_hex(2147483647)", "7fffffff"); + test_expression!("to_hex(9223372036854775807)", "7fffffffffffffff"); + test_expression!("to_hex(CAST(NULL AS int))", "NULL"); + test_expression!("trim(' tom ')", "tom"); + test_expression!("trim(' tom')", "tom"); + test_expression!("trim('')", ""); + test_expression!("trim('tom ')", "tom"); + test_expression!("upper('')", ""); + test_expression!("upper('tom')", "TOM"); + test_expression!("upper(NULL)", "NULL"); + Ok(()) +} + +#[tokio::test] +#[cfg_attr(not(feature = "unicode_expressions"), ignore)] +async fn test_unicode_expressions() -> Result<()> { + test_expression!("char_length('')", "0"); + test_expression!("char_length('chars')", "5"); + test_expression!("char_length('josé')", "4"); + test_expression!("char_length(NULL)", "NULL"); + test_expression!("character_length('')", "0"); + test_expression!("character_length('chars')", "5"); + test_expression!("character_length('josé')", "4"); + test_expression!("character_length(NULL)", "NULL"); + test_expression!("left('abcde', -2)", "abc"); + test_expression!("left('abcde', -200)", ""); + test_expression!("left('abcde', 0)", ""); + test_expression!("left('abcde', 2)", "ab"); + test_expression!("left('abcde', 200)", "abcde"); + test_expression!("left('abcde', CAST(NULL AS INT))", "NULL"); + test_expression!("left(NULL, 2)", "NULL"); + test_expression!("left(NULL, CAST(NULL AS INT))", "NULL"); + test_expression!("length('')", "0"); + test_expression!("length('chars')", "5"); + test_expression!("length('josé')", "4"); + test_expression!("length(NULL)", "NULL"); + test_expression!("lpad('hi', 5, 'xy')", "xyxhi"); + test_expression!("lpad('hi', 0)", ""); + test_expression!("lpad('hi', 21, 'abcdef')", "abcdefabcdefabcdefahi"); + test_expression!("lpad('hi', 5, 'xy')", "xyxhi"); + test_expression!("lpad('hi', 5, NULL)", "NULL"); + test_expression!("lpad('hi', 5)", " hi"); + test_expression!("lpad('hi', CAST(NULL AS INT), 'xy')", "NULL"); + test_expression!("lpad('hi', CAST(NULL AS INT))", "NULL"); + test_expression!("lpad('xyxhi', 3)", "xyx"); + test_expression!("lpad(NULL, 0)", "NULL"); + test_expression!("lpad(NULL, 5, 'xy')", "NULL"); + test_expression!("reverse('abcde')", "edcba"); + test_expression!("reverse('loẅks')", "skẅol"); + test_expression!("reverse(NULL)", "NULL"); + test_expression!("right('abcde', -2)", "cde"); + test_expression!("right('abcde', -200)", ""); + test_expression!("right('abcde', 0)", ""); + test_expression!("right('abcde', 2)", "de"); + test_expression!("right('abcde', 200)", "abcde"); + test_expression!("right('abcde', CAST(NULL AS INT))", "NULL"); + test_expression!("right(NULL, 2)", "NULL"); + test_expression!("right(NULL, CAST(NULL AS INT))", "NULL"); + test_expression!("rpad('hi', 5, 'xy')", "hixyx"); + test_expression!("rpad('hi', 0)", ""); + test_expression!("rpad('hi', 21, 'abcdef')", "hiabcdefabcdefabcdefa"); + test_expression!("rpad('hi', 5, 'xy')", "hixyx"); + test_expression!("rpad('hi', 5, NULL)", "NULL"); + test_expression!("rpad('hi', 5)", "hi "); + test_expression!("rpad('hi', CAST(NULL AS INT), 'xy')", "NULL"); + test_expression!("rpad('hi', CAST(NULL AS INT))", "NULL"); + test_expression!("rpad('xyxhi', 3)", "xyx"); + test_expression!("strpos('abc', 'c')", "3"); + test_expression!("strpos('josé', 'é')", "4"); + test_expression!("strpos('joséésoj', 'so')", "6"); + test_expression!("strpos('joséésoj', 'abc')", "0"); + test_expression!("strpos(NULL, 'abc')", "NULL"); + test_expression!("strpos('joséésoj', NULL)", "NULL"); + test_expression!("substr('alphabet', -3)", "alphabet"); + test_expression!("substr('alphabet', 0)", "alphabet"); + test_expression!("substr('alphabet', 1)", "alphabet"); + test_expression!("substr('alphabet', 2)", "lphabet"); + test_expression!("substr('alphabet', 3)", "phabet"); + test_expression!("substr('alphabet', 30)", ""); + test_expression!("substr('alphabet', CAST(NULL AS int))", "NULL"); + test_expression!("substr('alphabet', 3, 2)", "ph"); + test_expression!("substr('alphabet', 3, 20)", "phabet"); + test_expression!("substr('alphabet', CAST(NULL AS int), 20)", "NULL"); + test_expression!("substr('alphabet', 3, CAST(NULL AS int))", "NULL"); + test_expression!("translate('12345', '143', 'ax')", "a2x5"); + test_expression!("translate(NULL, '143', 'ax')", "NULL"); + test_expression!("translate('12345', NULL, 'ax')", "NULL"); + test_expression!("translate('12345', '143', NULL)", "NULL"); + Ok(()) +} + +#[tokio::test] +#[cfg_attr(not(feature = "regex_expressions"), ignore)] +async fn test_regex_expressions() -> Result<()> { + test_expression!("regexp_replace('ABCabcABC', '(abc)', 'X', 'gi')", "XXX"); + test_expression!("regexp_replace('ABCabcABC', '(abc)', 'X', 'i')", "XabcABC"); + test_expression!("regexp_replace('foobarbaz', 'b..', 'X', 'g')", "fooXX"); + test_expression!("regexp_replace('foobarbaz', 'b..', 'X')", "fooXbaz"); test_expression!( - "sha256('tom')", - "e1608f75c5d7813f3d4031cb30bfb786507d98137538ff8e128a6ff74e84e643" + "regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g')", + "fooXarYXazY" ); test_expression!( - "sha256('')", - "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + "regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', NULL)", + "NULL" ); - test_expression!("sha256(NULL)", "NULL"); - test_expression!("sha384('tom')", "096f5b68aa77848e4fdf5c1c0b350de2dbfad60ffd7c25d9ea07c6c19b8a4d55a9187eb117c557883f58c16dfac3e343"); - test_expression!("sha384('')", "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b"); - test_expression!("sha384(NULL)", "NULL"); - test_expression!("sha512('tom')", "6e1b9b3fe840680e37051f7ad5e959d6f39ad0f8885d855166f55c659469d3c8b78118c44a2a49c72ddb481cd6d8731034e11cc030070ba843a90b3495cb8d3e"); - test_expression!("sha512('')", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e"); - test_expression!("sha512(NULL)", "NULL"); + test_expression!("regexp_replace('foobarbaz', 'b(..)', NULL, 'g')", "NULL"); + test_expression!("regexp_replace('foobarbaz', NULL, 'X\\1Y', 'g')", "NULL"); + test_expression!("regexp_replace('Thomas', '.[mN]a.', 'M')", "ThM"); + test_expression!("regexp_replace(NULL, 'b(..)', 'X\\1Y', 'g')", "NULL"); Ok(()) }