From 4fd28886829c7c11bcd4498e28f4ad308a675f1a Mon Sep 17 00:00:00 2001 From: Mike Seddon Date: Thu, 4 Mar 2021 08:44:56 +1100 Subject: [PATCH 1/2] implement misc string functions --- rust/datafusion/README.md | 9 +- rust/datafusion/src/logical_plan/expr.rs | 6 + rust/datafusion/src/logical_plan/mod.rs | 12 +- .../datafusion/src/physical_plan/functions.rs | 905 ++++++++++++------ .../src/physical_plan/string_expressions.rs | 237 ++++- rust/datafusion/src/prelude.rs | 8 +- rust/datafusion/tests/sql.rs | 18 + 7 files changed, 875 insertions(+), 320 deletions(-) diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index 6f9cd85deaa..0213f74b9cf 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -57,22 +57,29 @@ DataFusion includes a simple command-line interactive SQL utility. See the [CLI - [x] UDAFs (user-defined aggregate functions) - [x] Common math functions - String functions + - [x] ascii - [x] bit_Length - [x] btrim - [x] char_length - [x] character_length + - [x] chr - [x] concat - [x] concat_ws - - [x] length + - [x] initcap - [x] left + - [x] length - [x] lpad - [x] ltrim - [x] octet_length + - [x] repeat + - [x] reverse - [x] right - [x] rpad - [x] rtrim - [x] substr + - [x] to_hex - [x] trim + - Miscellaneous/Boolean functions - [x] nullif - Common date/time functions diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index d98cc80c24c..5b0876a79e0 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -1075,18 +1075,23 @@ unary_scalar_expr!(Log2, log2); unary_scalar_expr!(Log10, log10); // string functions +unary_scalar_expr!(Ascii, ascii); unary_scalar_expr!(BitLength, bit_length); unary_scalar_expr!(Btrim, btrim); unary_scalar_expr!(CharacterLength, character_length); unary_scalar_expr!(CharacterLength, length); +unary_scalar_expr!(Chr, chr); unary_scalar_expr!(Concat, concat); unary_scalar_expr!(ConcatWithSeparator, concat_ws); +unary_scalar_expr!(InitCap, initcap); unary_scalar_expr!(Left, left); unary_scalar_expr!(Lower, lower); unary_scalar_expr!(Lpad, lpad); unary_scalar_expr!(Ltrim, ltrim); unary_scalar_expr!(MD5, md5); unary_scalar_expr!(OctetLength, octet_length); +unary_scalar_expr!(Repeat, repeat); +unary_scalar_expr!(Reverse, reverse); unary_scalar_expr!(Right, right); unary_scalar_expr!(Rpad, rpad); unary_scalar_expr!(Rtrim, rtrim); @@ -1095,6 +1100,7 @@ unary_scalar_expr!(SHA256, sha256); unary_scalar_expr!(SHA384, sha384); unary_scalar_expr!(SHA512, sha512); unary_scalar_expr!(Substr, substr); +unary_scalar_expr!(ToHex, to_hex); unary_scalar_expr!(Trim, trim); unary_scalar_expr!(Upper, upper); diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 08ba81c4271..ab787ef82f4 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -33,13 +33,13 @@ pub use builder::LogicalPlanBuilder; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ - abs, acos, and, array, asin, atan, avg, binary_expr, bit_length, btrim, case, ceil, - character_length, col, combine_filters, concat, concat_ws, cos, count, + abs, acos, and, array, ascii, asin, atan, avg, binary_expr, bit_length, btrim, case, + ceil, character_length, chr, col, combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, - left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min, octet_length, - or, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sqrt, - substr, sum, tan, trim, trunc, upper, when, Expr, ExprRewriter, ExpressionVisitor, - Literal, Recursion, + initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min, + octet_length, or, repeat, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, + sha512, signum, sin, sqrt, substr, sum, tan, to_hex, trim, trunc, upper, when, Expr, + ExprRewriter, ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 7c443f83297..ae8d128fc30 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -124,50 +124,62 @@ pub enum BuiltinScalarFunction { // string functions /// construct an array from columns Array, + /// ascii + Ascii, /// bit_length BitLength, /// btrim Btrim, /// character_length CharacterLength, + /// chr + Chr, /// concat Concat, /// concat_ws ConcatWithSeparator, - /// Date part + /// date_part DatePart, - /// Date truncate + /// date_trunc DateTrunc, + /// initcap + InitCap, /// left Left, /// lpad Lpad, /// lower Lower, - /// trim left + /// ltrim Ltrim, - /// MD5 + /// md5 MD5, - /// SQL NULLIF() + /// nullif NullIf, /// octet_length OctetLength, + /// repeat + Repeat, + /// reverse + Reverse, /// right Right, /// rpad Rpad, - /// trim right + /// rtrim Rtrim, - /// SHA224 + /// sha224 SHA224, - /// SHA256 + /// sha256 SHA256, - /// SHA384 + /// sha384 SHA384, - /// SHA512 + /// Sha512 SHA512, /// substr Substr, + /// to_hex + ToHex, /// to_timestamp ToTimestamp, /// trim @@ -208,14 +220,17 @@ impl FromStr for BuiltinScalarFunction { // string functions "array" => BuiltinScalarFunction::Array, + "ascii" => BuiltinScalarFunction::Ascii, "bit_length" => BuiltinScalarFunction::BitLength, "btrim" => BuiltinScalarFunction::Btrim, "char_length" => BuiltinScalarFunction::CharacterLength, "character_length" => BuiltinScalarFunction::CharacterLength, "concat" => BuiltinScalarFunction::Concat, "concat_ws" => BuiltinScalarFunction::ConcatWithSeparator, + "chr" => BuiltinScalarFunction::Chr, "date_part" => BuiltinScalarFunction::DatePart, "date_trunc" => BuiltinScalarFunction::DateTrunc, + "initcap" => BuiltinScalarFunction::InitCap, "left" => BuiltinScalarFunction::Left, "length" => BuiltinScalarFunction::CharacterLength, "lower" => BuiltinScalarFunction::Lower, @@ -224,6 +239,8 @@ impl FromStr for BuiltinScalarFunction { "md5" => BuiltinScalarFunction::MD5, "nullif" => BuiltinScalarFunction::NullIf, "octet_length" => BuiltinScalarFunction::OctetLength, + "repeat" => BuiltinScalarFunction::Repeat, + "reverse" => BuiltinScalarFunction::Reverse, "right" => BuiltinScalarFunction::Right, "rpad" => BuiltinScalarFunction::Rpad, "rtrim" => BuiltinScalarFunction::Rtrim, @@ -232,6 +249,7 @@ impl FromStr for BuiltinScalarFunction { "sha384" => BuiltinScalarFunction::SHA384, "sha512" => BuiltinScalarFunction::SHA512, "substr" => BuiltinScalarFunction::Substr, + "to_hex" => BuiltinScalarFunction::ToHex, "to_timestamp" => BuiltinScalarFunction::ToTimestamp, "trim" => BuiltinScalarFunction::Trim, "upper" => BuiltinScalarFunction::Upper, @@ -273,6 +291,7 @@ pub fn return_type( Box::new(Field::new("item", arg_types[0].clone(), true)), arg_types.len() as i32, )), + BuiltinScalarFunction::Ascii => Ok(DataType::Int32), BuiltinScalarFunction::BitLength => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::Int64, DataType::Utf8 => DataType::Int32, @@ -303,12 +322,23 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::Chr => Ok(DataType::Utf8), BuiltinScalarFunction::Concat => Ok(DataType::Utf8), BuiltinScalarFunction::ConcatWithSeparator => Ok(DataType::Utf8), BuiltinScalarFunction::DatePart => Ok(DataType::Int32), BuiltinScalarFunction::DateTrunc => { Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) } + BuiltinScalarFunction::InitCap => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The initcap function can only accept strings.".to_string(), + )); + } + }), BuiltinScalarFunction::Left => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, @@ -374,6 +404,26 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::Repeat => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The repeat function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Reverse => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The reverse function can only accept strings.".to_string(), + )); + } + }), BuiltinScalarFunction::Right => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, @@ -454,6 +504,17 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::ToHex => Ok(match arg_types[0] { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + DataType::Utf8 + } + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The to_hex function can only accept integers.".to_string(), + )); + } + }), BuiltinScalarFunction::ToTimestamp => { Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) } @@ -547,6 +608,18 @@ pub fn create_physical_expr( // string functions BuiltinScalarFunction::Array => array_expressions::array, + BuiltinScalarFunction::Ascii => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::ascii::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::ascii::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function ascii", + other, + ))), + }, BuiltinScalarFunction::BitLength => |args| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { @@ -583,12 +656,27 @@ pub fn create_physical_expr( other, ))), }, + BuiltinScalarFunction::Chr => { + |args| make_scalar_function(string_expressions::chr)(args) + } BuiltinScalarFunction::Concat => string_expressions::concat, BuiltinScalarFunction::ConcatWithSeparator => { |args| make_scalar_function(string_expressions::concat_ws)(args) } BuiltinScalarFunction::DatePart => datetime_expressions::date_part, BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc, + BuiltinScalarFunction::InitCap => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::initcap::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::initcap::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function initcap", + other, + ))), + }, BuiltinScalarFunction::Left => |args| match args[0].data_type() { DataType::Utf8 => make_scalar_function(string_expressions::left::)(args), DataType::LargeUtf8 => { @@ -638,6 +726,30 @@ pub fn create_physical_expr( _ => unreachable!(), }, }, + BuiltinScalarFunction::Repeat => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::repeat::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::repeat::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function repeat", + other, + ))), + }, + BuiltinScalarFunction::Reverse => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::reverse::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::reverse::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function reverse", + other, + ))), + }, BuiltinScalarFunction::Right => |args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::right::)(args) @@ -696,6 +808,18 @@ pub fn create_physical_expr( other, ))), }, + BuiltinScalarFunction::ToHex => |args| match args[0].data_type() { + DataType::Int32 => { + make_scalar_function(string_expressions::to_hex::)(args) + } + DataType::Int64 => { + make_scalar_function(string_expressions::to_hex::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function to_hex", + other, + ))), + }, BuiltinScalarFunction::ToTimestamp => datetime_expressions::to_timestamp, BuiltinScalarFunction::Trim => |args| match args[0].data_type() { DataType::Utf8 => { @@ -739,11 +863,14 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { BuiltinScalarFunction::Concat | BuiltinScalarFunction::ConcatWithSeparator => { Signature::Variadic(vec![DataType::Utf8]) } - BuiltinScalarFunction::BitLength + BuiltinScalarFunction::Ascii + | BuiltinScalarFunction::BitLength | BuiltinScalarFunction::CharacterLength + | BuiltinScalarFunction::InitCap | BuiltinScalarFunction::Lower | BuiltinScalarFunction::MD5 | BuiltinScalarFunction::OctetLength + | BuiltinScalarFunction::Reverse | BuiltinScalarFunction::SHA224 | BuiltinScalarFunction::SHA256 | BuiltinScalarFunction::SHA384 @@ -758,6 +885,9 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { Signature::Exact(vec![DataType::Utf8]), Signature::Exact(vec![DataType::Utf8, DataType::Utf8]), ]), + BuiltinScalarFunction::Chr | BuiltinScalarFunction::ToHex => { + Signature::Uniform(1, vec![DataType::Int64]) + } BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => { Signature::OneOf(vec![ Signature::Exact(vec![DataType::Utf8, DataType::Int64]), @@ -780,12 +910,12 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { ]), ]) } - BuiltinScalarFunction::Left | BuiltinScalarFunction::Right => { - Signature::OneOf(vec![ - Signature::Exact(vec![DataType::Utf8, DataType::Int64]), - Signature::Exact(vec![DataType::LargeUtf8, DataType::Int64]), - ]) - } + BuiltinScalarFunction::Left + | BuiltinScalarFunction::Repeat + | BuiltinScalarFunction::Right => Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8, DataType::Int64]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Int64]), + ]), BuiltinScalarFunction::ToTimestamp => Signature::Uniform(1, vec![DataType::Utf8]), BuiltinScalarFunction::DateTrunc => Signature::Exact(vec![ DataType::Utf8, @@ -1033,6 +1163,54 @@ mod tests { #[test] fn test_functions() -> Result<()> { + test_function!( + Ascii, + &[lit(ScalarValue::Utf8(Some("x".to_string())))], + Ok(Some(120)), + i32, + Int32, + Int32Array + ); + test_function!( + Ascii, + &[lit(ScalarValue::Utf8(Some("ésoj".to_string())))], + Ok(Some(233)), + i32, + Int32, + Int32Array + ); + test_function!( + Ascii, + &[lit(ScalarValue::Utf8(Some("💯".to_string())))], + Ok(Some(128175)), + i32, + Int32, + Int32Array + ); + test_function!( + Ascii, + &[lit(ScalarValue::Utf8(Some("💯a".to_string())))], + Ok(Some(128175)), + i32, + Int32, + Int32Array + ); + test_function!( + Ascii, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + Ascii, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + i32, + Int32, + Int32Array + ); test_function!( BitLength, &[lit(ScalarValue::Utf8(Some("chars".to_string())))], @@ -1165,6 +1343,66 @@ mod tests { Int32, Int32Array ); + test_function!( + Chr, + &[lit(ScalarValue::Int64(Some(128175)))], + Ok(Some("💯")), + &str, + Utf8, + StringArray + ); + test_function!( + Chr, + &[lit(ScalarValue::Int64(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Chr, + &[lit(ScalarValue::Int64(Some(120)))], + Ok(Some("x")), + &str, + Utf8, + StringArray + ); + test_function!( + Chr, + &[lit(ScalarValue::Int64(Some(128175)))], + Ok(Some("💯")), + &str, + Utf8, + StringArray + ); + test_function!( + Chr, + &[lit(ScalarValue::Int64(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Chr, + &[lit(ScalarValue::Int64(Some(0)))], + Err(DataFusionError::Execution( + "null character not permitted.".to_string(), + )), + &str, + Utf8, + StringArray + ); + test_function!( + Chr, + &[lit(ScalarValue::Int64(Some(i64::MAX)))], + Err(DataFusionError::Execution( + "requested character too large for encoding.".to_string(), + )), + &str, + Utf8, + StringArray + ); test_function!( Concat, &[ @@ -1287,6 +1525,38 @@ mod tests { Float64, Float64Array ); + test_function!( + InitCap, + &[lit(ScalarValue::Utf8(Some("hi THOMAS".to_string())))], + Ok(Some("Hi Thomas")), + &str, + Utf8, + StringArray + ); + test_function!( + InitCap, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + InitCap, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + InitCap, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); test_function!( Left, &[ @@ -1353,311 +1623,79 @@ mod tests { Utf8, StringArray ); - #[cfg(feature = "crypto_expressions")] test_function!( - MD5, - &[lit(ScalarValue::Utf8(Some("tom".to_string())))], - Ok(Some("34b7da764b21d298ef307d04d8152dc5")), + Left, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(None)), + ], + Ok(None), &str, Utf8, StringArray ); - #[cfg(feature = "crypto_expressions")] test_function!( - MD5, - &[lit(ScalarValue::Utf8(Some("".to_string())))], - Ok(Some("d41d8cd98f00b204e9800998ecf8427e")), + Left, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some("joséé")), &str, Utf8, StringArray ); - #[cfg(feature = "crypto_expressions")] test_function!( - MD5, - &[lit(ScalarValue::Utf8(None))], - Ok(None), + Left, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(-3))), + ], + Ok(Some("joséé")), &str, Utf8, StringArray ); - #[cfg(not(feature = "crypto_expressions"))] test_function!( - MD5, - &[lit(ScalarValue::Utf8(Some("tom".to_string())))], - Err(DataFusionError::Internal( - "function md5 requires compilation with feature flag: crypto_expressions.".to_string() - )), + Lpad, + &[ + lit(ScalarValue::Utf8(Some("josé".to_string()))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some(" josé")), &str, Utf8, StringArray ); test_function!( - Left, + Lpad, &[ - lit(ScalarValue::Utf8(Some("abcde".to_string()))), - lit(ScalarValue::Int64(None)), + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), ], - Ok(None), + Ok(Some(" hi")), &str, Utf8, StringArray ); test_function!( - Left, + Lpad, &[ - lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), - lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(0))), ], - Ok(Some("joséé")), + Ok(Some("")), &str, Utf8, StringArray ); test_function!( - Left, + Lpad, &[ - lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), - lit(ScalarValue::Int64(Some(-3))), + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(None)), ], - Ok(Some("joséé")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "crypto_expressions")] - test_function!( - SHA224, - &[lit(ScalarValue::Utf8(Some("tom".to_string())))], - Ok(Some(&[ - 11u8, 246u8, 203u8, 98u8, 100u8, 156u8, 66u8, 169u8, 174u8, 56u8, 118u8, - 171u8, 111u8, 109u8, 146u8, 173u8, 54u8, 203u8, 84u8, 20u8, 228u8, 149u8, - 248u8, 135u8, 50u8, 146u8, 190u8, 77u8 - ])), - &[u8], - Binary, - BinaryArray - ); - #[cfg(feature = "crypto_expressions")] - test_function!( - SHA224, - &[lit(ScalarValue::Utf8(Some("".to_string())))], - Ok(Some(&[ - 209u8, 74u8, 2u8, 140u8, 42u8, 58u8, 43u8, 201u8, 71u8, 97u8, 2u8, 187u8, - 40u8, 130u8, 52u8, 196u8, 21u8, 162u8, 176u8, 31u8, 130u8, 142u8, 166u8, - 42u8, 197u8, 179u8, 228u8, 47u8 - ])), - &[u8], - Binary, - BinaryArray - ); - #[cfg(feature = "crypto_expressions")] - test_function!( - SHA224, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &[u8], - Binary, - BinaryArray - ); - #[cfg(not(feature = "crypto_expressions"))] - test_function!( - SHA224, - &[lit(ScalarValue::Utf8(Some("tom".to_string())))], - Err(DataFusionError::Internal( - "function sha224 requires compilation with feature flag: crypto_expressions.".to_string() - )), - &[u8], - Binary, - BinaryArray - ); - #[cfg(feature = "crypto_expressions")] - test_function!( - SHA256, - &[lit(ScalarValue::Utf8(Some("tom".to_string())))], - Ok(Some(&[ - 225u8, 96u8, 143u8, 117u8, 197u8, 215u8, 129u8, 63u8, 61u8, 64u8, 49u8, - 203u8, 48u8, 191u8, 183u8, 134u8, 80u8, 125u8, 152u8, 19u8, 117u8, 56u8, - 255u8, 142u8, 18u8, 138u8, 111u8, 247u8, 78u8, 132u8, 230u8, 67u8 - ])), - &[u8], - Binary, - BinaryArray - ); - #[cfg(feature = "crypto_expressions")] - test_function!( - SHA256, - &[lit(ScalarValue::Utf8(Some("".to_string())))], - Ok(Some(&[ - 227u8, 176u8, 196u8, 66u8, 152u8, 252u8, 28u8, 20u8, 154u8, 251u8, 244u8, - 200u8, 153u8, 111u8, 185u8, 36u8, 39u8, 174u8, 65u8, 228u8, 100u8, 155u8, - 147u8, 76u8, 164u8, 149u8, 153u8, 27u8, 120u8, 82u8, 184u8, 85u8 - ])), - &[u8], - Binary, - BinaryArray - ); - #[cfg(feature = "crypto_expressions")] - test_function!( - SHA256, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &[u8], - Binary, - BinaryArray - ); - #[cfg(not(feature = "crypto_expressions"))] - test_function!( - SHA256, - &[lit(ScalarValue::Utf8(Some("tom".to_string())))], - Err(DataFusionError::Internal( - "function sha256 requires compilation with feature flag: crypto_expressions.".to_string() - )), - &[u8], - Binary, - BinaryArray - ); - #[cfg(feature = "crypto_expressions")] - test_function!( - SHA384, - &[lit(ScalarValue::Utf8(Some("tom".to_string())))], - Ok(Some(&[ - 9u8, 111u8, 91u8, 104u8, 170u8, 119u8, 132u8, 142u8, 79u8, 223u8, 92u8, - 28u8, 11u8, 53u8, 13u8, 226u8, 219u8, 250u8, 214u8, 15u8, 253u8, 124u8, - 37u8, 217u8, 234u8, 7u8, 198u8, 193u8, 155u8, 138u8, 77u8, 85u8, 169u8, - 24u8, 126u8, 177u8, 23u8, 197u8, 87u8, 136u8, 63u8, 88u8, 193u8, 109u8, - 250u8, 195u8, 227u8, 67u8 - ])), - &[u8], - Binary, - BinaryArray - ); - #[cfg(feature = "crypto_expressions")] - test_function!( - SHA384, - &[lit(ScalarValue::Utf8(Some("".to_string())))], - Ok(Some(&[ - 56u8, 176u8, 96u8, 167u8, 81u8, 172u8, 150u8, 56u8, 76u8, 217u8, 50u8, - 126u8, 177u8, 177u8, 227u8, 106u8, 33u8, 253u8, 183u8, 17u8, 20u8, 190u8, - 7u8, 67u8, 76u8, 12u8, 199u8, 191u8, 99u8, 246u8, 225u8, 218u8, 39u8, - 78u8, 222u8, 191u8, 231u8, 111u8, 101u8, 251u8, 213u8, 26u8, 210u8, - 241u8, 72u8, 152u8, 185u8, 91u8 - ])), - &[u8], - Binary, - BinaryArray - ); - #[cfg(feature = "crypto_expressions")] - test_function!( - SHA384, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &[u8], - Binary, - BinaryArray - ); - #[cfg(not(feature = "crypto_expressions"))] - test_function!( - SHA384, - &[lit(ScalarValue::Utf8(Some("tom".to_string())))], - Err(DataFusionError::Internal( - "function sha384 requires compilation with feature flag: crypto_expressions.".to_string() - )), - &[u8], - Binary, - BinaryArray - ); - #[cfg(feature = "crypto_expressions")] - test_function!( - SHA512, - &[lit(ScalarValue::Utf8(Some("tom".to_string())))], - Ok(Some(&[ - 110u8, 27u8, 155u8, 63u8, 232u8, 64u8, 104u8, 14u8, 55u8, 5u8, 31u8, - 122u8, 213u8, 233u8, 89u8, 214u8, 243u8, 154u8, 208u8, 248u8, 136u8, - 93u8, 133u8, 81u8, 102u8, 245u8, 92u8, 101u8, 148u8, 105u8, 211u8, 200u8, - 183u8, 129u8, 24u8, 196u8, 74u8, 42u8, 73u8, 199u8, 45u8, 219u8, 72u8, - 28u8, 214u8, 216u8, 115u8, 16u8, 52u8, 225u8, 28u8, 192u8, 48u8, 7u8, - 11u8, 168u8, 67u8, 169u8, 11u8, 52u8, 149u8, 203u8, 141u8, 62u8 - ])), - &[u8], - Binary, - BinaryArray - ); - #[cfg(feature = "crypto_expressions")] - test_function!( - SHA512, - &[lit(ScalarValue::Utf8(Some("".to_string())))], - Ok(Some(&[ - 207u8, 131u8, 225u8, 53u8, 126u8, 239u8, 184u8, 189u8, 241u8, 84u8, 40u8, - 80u8, 214u8, 109u8, 128u8, 7u8, 214u8, 32u8, 228u8, 5u8, 11u8, 87u8, - 21u8, 220u8, 131u8, 244u8, 169u8, 33u8, 211u8, 108u8, 233u8, 206u8, 71u8, - 208u8, 209u8, 60u8, 93u8, 133u8, 242u8, 176u8, 255u8, 131u8, 24u8, 210u8, - 135u8, 126u8, 236u8, 47u8, 99u8, 185u8, 49u8, 189u8, 71u8, 65u8, 122u8, - 129u8, 165u8, 56u8, 50u8, 122u8, 249u8, 39u8, 218u8, 62u8 - ])), - &[u8], - Binary, - BinaryArray - ); - #[cfg(feature = "crypto_expressions")] - test_function!( - SHA512, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &[u8], - Binary, - BinaryArray - ); - #[cfg(not(feature = "crypto_expressions"))] - test_function!( - SHA512, - &[lit(ScalarValue::Utf8(Some("tom".to_string())))], - Err(DataFusionError::Internal( - "function sha512 requires compilation with feature flag: crypto_expressions.".to_string() - )), - &[u8], - Binary, - BinaryArray - ); - test_function!( - Lpad, - &[ - lit(ScalarValue::Utf8(Some("josé".to_string()))), - lit(ScalarValue::Int64(Some(5))), - ], - Ok(Some(" josé")), - &str, - Utf8, - StringArray - ); - test_function!( - Lpad, - &[ - lit(ScalarValue::Utf8(Some("hi".to_string()))), - lit(ScalarValue::Int64(Some(5))), - ], - Ok(Some(" hi")), - &str, - Utf8, - StringArray - ); - test_function!( - Lpad, - &[ - lit(ScalarValue::Utf8(Some("hi".to_string()))), - lit(ScalarValue::Int64(Some(0))), - ], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - test_function!( - Lpad, - &[ - lit(ScalarValue::Utf8(Some("hi".to_string()))), - lit(ScalarValue::Int64(None)), - ], - Ok(None), + Ok(None), &str, Utf8, StringArray @@ -1829,6 +1867,44 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "crypto_expressions")] + test_function!( + MD5, + &[lit(ScalarValue::Utf8(Some("tom".to_string())))], + Ok(Some("34b7da764b21d298ef307d04d8152dc5")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "crypto_expressions")] + test_function!( + MD5, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some("d41d8cd98f00b204e9800998ecf8427e")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "crypto_expressions")] + test_function!( + MD5, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "crypto_expressions"))] + test_function!( + MD5, + &[lit(ScalarValue::Utf8(Some("tom".to_string())))], + Err(DataFusionError::Internal( + "function md5 requires compilation with feature flag: crypto_expressions.".to_string() + )), + &str, + Utf8, + StringArray + ); test_function!( OctetLength, &[lit(ScalarValue::Utf8(Some("chars".to_string())))], @@ -1861,6 +1937,71 @@ mod tests { Int32, Int32Array ); + test_function!( + Repeat, + &[ + lit(ScalarValue::Utf8(Some("Pg".to_string()))), + lit(ScalarValue::Int64(Some(4))), + ], + Ok(Some("PgPgPgPg")), + &str, + Utf8, + StringArray + ); + test_function!( + Repeat, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Int64(Some(4))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Repeat, + &[ + lit(ScalarValue::Utf8(Some("Pg".to_string()))), + lit(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Reverse, + &[lit(ScalarValue::Utf8(Some("abcde".to_string())))], + Ok(Some("edcba")), + &str, + Utf8, + StringArray + ); + test_function!( + Reverse, + &[lit(ScalarValue::Utf8(Some("loẅks".to_string())))], + Ok(Some("skẅol")), + &str, + Utf8, + StringArray + ); + test_function!( + Reverse, + &[lit(ScalarValue::Utf8(Some("loẅks".to_string())))], + Ok(Some("skẅol")), + &str, + Utf8, + StringArray + ); + test_function!( + Reverse, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); test_function!( Right, &[ @@ -2171,6 +2312,200 @@ mod tests { Utf8, StringArray ); + #[cfg(feature = "crypto_expressions")] + test_function!( + SHA224, + &[lit(ScalarValue::Utf8(Some("tom".to_string())))], + Ok(Some(&[ + 11u8, 246u8, 203u8, 98u8, 100u8, 156u8, 66u8, 169u8, 174u8, 56u8, 118u8, + 171u8, 111u8, 109u8, 146u8, 173u8, 54u8, 203u8, 84u8, 20u8, 228u8, 149u8, + 248u8, 135u8, 50u8, 146u8, 190u8, 77u8 + ])), + &[u8], + Binary, + BinaryArray + ); + #[cfg(feature = "crypto_expressions")] + test_function!( + SHA224, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some(&[ + 209u8, 74u8, 2u8, 140u8, 42u8, 58u8, 43u8, 201u8, 71u8, 97u8, 2u8, 187u8, + 40u8, 130u8, 52u8, 196u8, 21u8, 162u8, 176u8, 31u8, 130u8, 142u8, 166u8, + 42u8, 197u8, 179u8, 228u8, 47u8 + ])), + &[u8], + Binary, + BinaryArray + ); + #[cfg(feature = "crypto_expressions")] + test_function!( + SHA224, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &[u8], + Binary, + BinaryArray + ); + #[cfg(not(feature = "crypto_expressions"))] + test_function!( + SHA224, + &[lit(ScalarValue::Utf8(Some("tom".to_string())))], + Err(DataFusionError::Internal( + "function sha224 requires compilation with feature flag: crypto_expressions.".to_string() + )), + &[u8], + Binary, + BinaryArray + ); + #[cfg(feature = "crypto_expressions")] + test_function!( + SHA256, + &[lit(ScalarValue::Utf8(Some("tom".to_string())))], + Ok(Some(&[ + 225u8, 96u8, 143u8, 117u8, 197u8, 215u8, 129u8, 63u8, 61u8, 64u8, 49u8, + 203u8, 48u8, 191u8, 183u8, 134u8, 80u8, 125u8, 152u8, 19u8, 117u8, 56u8, + 255u8, 142u8, 18u8, 138u8, 111u8, 247u8, 78u8, 132u8, 230u8, 67u8 + ])), + &[u8], + Binary, + BinaryArray + ); + #[cfg(feature = "crypto_expressions")] + test_function!( + SHA256, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some(&[ + 227u8, 176u8, 196u8, 66u8, 152u8, 252u8, 28u8, 20u8, 154u8, 251u8, 244u8, + 200u8, 153u8, 111u8, 185u8, 36u8, 39u8, 174u8, 65u8, 228u8, 100u8, 155u8, + 147u8, 76u8, 164u8, 149u8, 153u8, 27u8, 120u8, 82u8, 184u8, 85u8 + ])), + &[u8], + Binary, + BinaryArray + ); + #[cfg(feature = "crypto_expressions")] + test_function!( + SHA256, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &[u8], + Binary, + BinaryArray + ); + #[cfg(not(feature = "crypto_expressions"))] + test_function!( + SHA256, + &[lit(ScalarValue::Utf8(Some("tom".to_string())))], + Err(DataFusionError::Internal( + "function sha256 requires compilation with feature flag: crypto_expressions.".to_string() + )), + &[u8], + Binary, + BinaryArray + ); + #[cfg(feature = "crypto_expressions")] + test_function!( + SHA384, + &[lit(ScalarValue::Utf8(Some("tom".to_string())))], + Ok(Some(&[ + 9u8, 111u8, 91u8, 104u8, 170u8, 119u8, 132u8, 142u8, 79u8, 223u8, 92u8, + 28u8, 11u8, 53u8, 13u8, 226u8, 219u8, 250u8, 214u8, 15u8, 253u8, 124u8, + 37u8, 217u8, 234u8, 7u8, 198u8, 193u8, 155u8, 138u8, 77u8, 85u8, 169u8, + 24u8, 126u8, 177u8, 23u8, 197u8, 87u8, 136u8, 63u8, 88u8, 193u8, 109u8, + 250u8, 195u8, 227u8, 67u8 + ])), + &[u8], + Binary, + BinaryArray + ); + #[cfg(feature = "crypto_expressions")] + test_function!( + SHA384, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some(&[ + 56u8, 176u8, 96u8, 167u8, 81u8, 172u8, 150u8, 56u8, 76u8, 217u8, 50u8, + 126u8, 177u8, 177u8, 227u8, 106u8, 33u8, 253u8, 183u8, 17u8, 20u8, 190u8, + 7u8, 67u8, 76u8, 12u8, 199u8, 191u8, 99u8, 246u8, 225u8, 218u8, 39u8, + 78u8, 222u8, 191u8, 231u8, 111u8, 101u8, 251u8, 213u8, 26u8, 210u8, + 241u8, 72u8, 152u8, 185u8, 91u8 + ])), + &[u8], + Binary, + BinaryArray + ); + #[cfg(feature = "crypto_expressions")] + test_function!( + SHA384, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &[u8], + Binary, + BinaryArray + ); + #[cfg(not(feature = "crypto_expressions"))] + test_function!( + SHA384, + &[lit(ScalarValue::Utf8(Some("tom".to_string())))], + Err(DataFusionError::Internal( + "function sha384 requires compilation with feature flag: crypto_expressions.".to_string() + )), + &[u8], + Binary, + BinaryArray + ); + #[cfg(feature = "crypto_expressions")] + test_function!( + SHA512, + &[lit(ScalarValue::Utf8(Some("tom".to_string())))], + Ok(Some(&[ + 110u8, 27u8, 155u8, 63u8, 232u8, 64u8, 104u8, 14u8, 55u8, 5u8, 31u8, + 122u8, 213u8, 233u8, 89u8, 214u8, 243u8, 154u8, 208u8, 248u8, 136u8, + 93u8, 133u8, 81u8, 102u8, 245u8, 92u8, 101u8, 148u8, 105u8, 211u8, 200u8, + 183u8, 129u8, 24u8, 196u8, 74u8, 42u8, 73u8, 199u8, 45u8, 219u8, 72u8, + 28u8, 214u8, 216u8, 115u8, 16u8, 52u8, 225u8, 28u8, 192u8, 48u8, 7u8, + 11u8, 168u8, 67u8, 169u8, 11u8, 52u8, 149u8, 203u8, 141u8, 62u8 + ])), + &[u8], + Binary, + BinaryArray + ); + #[cfg(feature = "crypto_expressions")] + test_function!( + SHA512, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some(&[ + 207u8, 131u8, 225u8, 53u8, 126u8, 239u8, 184u8, 189u8, 241u8, 84u8, 40u8, + 80u8, 214u8, 109u8, 128u8, 7u8, 214u8, 32u8, 228u8, 5u8, 11u8, 87u8, + 21u8, 220u8, 131u8, 244u8, 169u8, 33u8, 211u8, 108u8, 233u8, 206u8, 71u8, + 208u8, 209u8, 60u8, 93u8, 133u8, 242u8, 176u8, 255u8, 131u8, 24u8, 210u8, + 135u8, 126u8, 236u8, 47u8, 99u8, 185u8, 49u8, 189u8, 71u8, 65u8, 122u8, + 129u8, 165u8, 56u8, 50u8, 122u8, 249u8, 39u8, 218u8, 62u8 + ])), + &[u8], + Binary, + BinaryArray + ); + #[cfg(feature = "crypto_expressions")] + test_function!( + SHA512, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &[u8], + Binary, + BinaryArray + ); + #[cfg(not(feature = "crypto_expressions"))] + test_function!( + SHA512, + &[lit(ScalarValue::Utf8(Some("tom".to_string())))], + Err(DataFusionError::Internal( + "function sha512 requires compilation with feature flag: crypto_expressions.".to_string() + )), + &[u8], + Binary, + BinaryArray + ); test_function!( Substr, &[ diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index 5d3c4d83a24..6210b0faeae 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -31,8 +31,8 @@ use crate::{ }; use arrow::{ array::{ - Array, ArrayRef, GenericStringArray, Int64Array, PrimitiveArray, StringArray, - StringOffsetSizeTrait, + Array, ArrayRef, GenericStringArray, Int32Array, Int64Array, PrimitiveArray, + StringArray, StringOffsetSizeTrait, }, datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, }; @@ -136,6 +136,29 @@ macro_rules! downcast_vec { }}; } +/// Returns the numeric code of the first character of the argument. +/// ascii('x') = 120 +pub fn ascii(args: &[ArrayRef]) -> Result { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal("could not cast string to StringArray".to_string()) + })?; + + let result = string_array + .iter() + .map(|string| { + string.map(|string: &str| { + let mut chars = string.chars(); + chars.next().map_or(0, |v| v as i32) + }) + }) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} + /// Removes the longest string containing only characters in characters (a space by default) from the start and end of string. /// btrim('xyxtrimyyx', 'xyz') = 'trim' pub fn btrim(args: &[ArrayRef]) -> Result { @@ -144,11 +167,19 @@ pub fn btrim(args: &[ArrayRef]) -> Result { let string_array: &GenericStringArray = args[0] .as_any() .downcast_ref::>() - .unwrap(); + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast string to StringArray".to_string(), + ) + })?; let result = string_array .iter() - .map(|x| x.map(|x: &str| x.trim_start_matches(' ').trim_end_matches(' '))) + .map(|string| { + string.map(|string: &str| { + string.trim_start_matches(' ').trim_end_matches(' ') + }) + }) .collect::>(); Ok(Arc::new(result) as ArrayRef) @@ -157,26 +188,34 @@ pub fn btrim(args: &[ArrayRef]) -> Result { let string_array: &GenericStringArray = args[0] .as_any() .downcast_ref::>() - .unwrap(); + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast string to StringArray".to_string(), + ) + })?; let characters_array: &GenericStringArray = args[1] .as_any() .downcast_ref::>() - .unwrap(); + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast characters to StringArray".to_string(), + ) + })?; let result = string_array .iter() - .enumerate() - .map(|(i, x)| { - if characters_array.is_null(i) { - None - } else { - x.map(|x: &str| { - let chars: Vec = - characters_array.value(i).chars().collect(); - x.trim_start_matches(&chars[..]) - .trim_end_matches(&chars[..]) - }) + .zip(characters_array.iter()) + .map(|(string, characters)| match (string, characters) { + (None, _) => None, + (_, None) => None, + (Some(string), Some(characters)) => { + let chars: Vec = characters.chars().collect(); + Some( + string + .trim_start_matches(&chars[..]) + .trim_end_matches(&chars[..]), + ) } }) .collect::>(); @@ -184,7 +223,7 @@ pub fn btrim(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } other => Err(DataFusionError::Internal(format!( - "btrim was called with {} arguments. It requires at most 2.", + "btrim was called with {} arguments. It requires at least 1 and at most 2.", other ))), } @@ -199,18 +238,58 @@ where let string_array: &GenericStringArray = args[0] .as_any() .downcast_ref::>() - .unwrap(); + .ok_or_else(|| { + DataFusionError::Internal("could not cast string to StringArray".to_string()) + })?; let result = string_array .iter() - .map(|x| { - x.map(|x: &str| T::Native::from_usize(x.graphemes(true).count()).unwrap()) + .map(|string| { + string.map(|string: &str| { + T::Native::from_usize(string.graphemes(true).count()).unwrap() + }) }) .collect::>(); Ok(Arc::new(result) as ArrayRef) } +/// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. +/// chr(65) = 'A' +pub fn chr(args: &[ArrayRef]) -> Result { + let integer_array: &Int64Array = args[0] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal("could not cast integer to Int64Array".to_string()) + })?; + + // first map is the iterator, second is for the `Option<_>` + let result = integer_array + .iter() + .map(|integer: Option| { + integer + .map(|integer| { + if integer == 0 { + Err(DataFusionError::Execution( + "null character not permitted.".to_string(), + )) + } else { + match core::char::from_u32(integer as u32) { + Some(integer) => Ok(integer.to_string()), + None => Err(DataFusionError::Execution( + "requested character too large for encoding.".to_string(), + )), + } + } + }) + .transpose() + }) + .collect::>()?; + + Ok(Arc::new(result) as ArrayRef) +} + /// Concatenates the text representations of all the arguments. NULL arguments are ignored. /// concat('abcde', 2, NULL, 22) = 'abcde222' pub fn concat(args: &[ColumnarValue]) -> Result { @@ -310,6 +389,41 @@ pub fn concat_ws(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } +/// Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters. +/// initcap('hi THOMAS') = 'Hi Thomas' +pub fn initcap(args: &[ArrayRef]) -> Result { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal("could not cast string to StringArray".to_string()) + })?; + + // first map is the iterator, second is for the `Option<_>` + let result = string_array + .iter() + .map(|string| { + string.map(|string: &str| { + let mut char_vector = Vec::::new(); + let mut previous_character_letter_or_number = false; + for c in string.chars() { + if previous_character_letter_or_number { + char_vector.push(c.to_ascii_lowercase()); + } else { + char_vector.push(c.to_ascii_uppercase()); + } + previous_character_letter_or_number = ('A'..='Z').contains(&c) + || ('a'..='z').contains(&c) + || ('0'..='9').contains(&c); + } + char_vector.iter().collect::() + }) + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + /// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. /// left('abcde', 2) = 'ab' pub fn left(args: &[ArrayRef]) -> Result { @@ -363,7 +477,7 @@ pub fn left(args: &[ArrayRef]) -> Result { /// Converts the string to all lower case. /// lower('TOM') = 'tom' pub fn lower(args: &[ColumnarValue]) -> Result { - handle(args, |x| x.to_ascii_lowercase(), "lower") + handle(args, |string| string.to_ascii_lowercase(), "lower") } /// Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right). @@ -512,7 +626,7 @@ pub fn ltrim(args: &[ArrayRef]) -> Result { let result = string_array .iter() - .map(|x| x.map(|x: &str| x.trim_start_matches(' '))) + .map(|string| string.map(|string: &str| string.trim_start_matches(' '))) .collect::>(); Ok(Arc::new(result) as ArrayRef) @@ -558,6 +672,56 @@ pub fn ltrim(args: &[ArrayRef]) -> Result { } } +/// Repeats string the specified number of times. +/// repeat('Pg', 4) = 'PgPgPgPg' +pub fn repeat(args: &[ArrayRef]) -> Result { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal("could not cast string to StringArray".to_string()) + })?; + + let number_array: &Int64Array = args[1] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal("could not cast number to Int64Array".to_string()) + })?; + + let result = string_array + .iter() + .zip(number_array.iter()) + .map(|(string, number)| match (string, number) { + (None, _) => None, + (_, None) => None, + (Some(string), Some(number)) => Some(string.repeat(number as usize)), + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Reverses the order of the characters in the string. +/// reverse('abcde') = 'edcba' +pub fn reverse(args: &[ArrayRef]) -> Result { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal("could not cast string to StringArray".to_string()) + })?; + + let result = string_array + .iter() + .map(|string| { + string.map(|string: &str| string.graphemes(true).rev().collect::()) + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + /// Returns last n characters in the string, or when n is negative, returns all but first |n| characters. /// right('abcde', 2) = 'de' pub fn right(args: &[ArrayRef]) -> Result { @@ -907,8 +1071,33 @@ pub fn substr(args: &[ArrayRef]) -> Result { } } +/// Converts the number to its equivalent hexadecimal representation. +/// to_hex(2147483647) = '7fffffff' +pub fn to_hex(args: &[ArrayRef]) -> Result +where + T::Native: StringOffsetSizeTrait, +{ + let integer_array: &PrimitiveArray = args[0] + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast integer to PrimitiveArray".to_string(), + ) + })?; + + let result = integer_array + .iter() + .map(|integer| { + integer.map(|integer| format!("{:x}", integer.to_usize().unwrap())) + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + /// Converts the string to all upper case. /// upper('tom') = 'TOM' pub fn upper(args: &[ColumnarValue]) -> Result { - handle(args, |x| x.to_ascii_uppercase(), "upper") + handle(args, |string| string.to_ascii_uppercase(), "upper") } diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs index 1f8588c9c52..5ab5e760246 100644 --- a/rust/datafusion/src/prelude.rs +++ b/rust/datafusion/src/prelude.rs @@ -28,9 +28,9 @@ pub use crate::dataframe::DataFrame; pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::logical_plan::{ - array, avg, bit_length, btrim, character_length, col, concat, concat_ws, count, - create_udf, in_list, left, length, lit, lower, lpad, ltrim, max, md5, min, - octet_length, right, rpad, rtrim, sha224, sha256, sha384, sha512, substr, sum, trim, - upper, JoinType, Partitioning, + array, ascii, avg, bit_length, btrim, character_length, chr, col, concat, concat_ws, + count, create_udf, in_list, initcap, left, length, lit, lower, lpad, ltrim, max, md5, + min, octet_length, repeat, reverse, right, rpad, rtrim, sha224, sha256, sha384, + sha512, substr, sum, to_hex, trim, upper, JoinType, Partitioning, }; pub use crate::physical_plan::csv::CsvReadOptions; diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 5bddcbe1c07..c8e198cb13c 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -2034,6 +2034,9 @@ macro_rules! test_expression { #[tokio::test] async fn test_string_expressions() -> Result<()> { + test_expression!("ascii('')", "0"); + test_expression!("ascii('x')", "120"); + test_expression!("ascii(NULL)", "NULL"); test_expression!("bit_length('')", "0"); test_expression!("bit_length('chars')", "40"); test_expression!("bit_length('josé')", "40"); @@ -2051,6 +2054,9 @@ async fn test_string_expressions() -> Result<()> { test_expression!("character_length('chars')", "5"); test_expression!("character_length('josé')", "4"); test_expression!("character_length(NULL)", "NULL"); + test_expression!("chr(CAST(120 AS int))", "x"); + test_expression!("chr(CAST(128175 AS int))", "💯"); + test_expression!("chr(CAST(NULL AS int))", "NULL"); test_expression!("concat('a','b','c')", "abc"); test_expression!("concat('abcde', 2, NULL, 22)", "abcde222"); test_expression!("concat(NULL)", ""); @@ -2058,6 +2064,9 @@ async fn test_string_expressions() -> Result<()> { test_expression!("concat_ws('|','a','b','c')", "a|b|c"); test_expression!("concat_ws('|',NULL)", ""); test_expression!("concat_ws(NULL,'a',NULL,'b','c')", "NULL"); + test_expression!("initcap('')", ""); + test_expression!("initcap('hi THOMAS')", "Hi Thomas"); + test_expression!("initcap(NULL)", "NULL"); test_expression!("left('abcde', -2)", "abc"); test_expression!("left('abcde', -200)", ""); test_expression!("left('abcde', 0)", ""); @@ -2088,6 +2097,12 @@ async fn test_string_expressions() -> Result<()> { test_expression!("octet_length('chars')", "5"); test_expression!("octet_length('josé')", "5"); test_expression!("octet_length(NULL)", "NULL"); + test_expression!("repeat('Pg', 4)", "PgPgPgPg"); + test_expression!("repeat('Pg', CAST(NULL AS INT))", "NULL"); + test_expression!("repeat(NULL, 4)", "NULL"); + test_expression!("reverse('abcde')", "edcba"); + test_expression!("reverse('loẅks')", "skẅol"); + test_expression!("reverse(NULL)", "NULL"); test_expression!("right('abcde', -2)", "cde"); test_expression!("right('abcde', -200)", ""); test_expression!("right('abcde', 0)", ""); @@ -2120,6 +2135,9 @@ async fn test_string_expressions() -> Result<()> { test_expression!("substr('alphabet', 3, 20)", "phabet"); test_expression!("substr('alphabet', CAST(NULL AS int), 20)", "NULL"); test_expression!("substr('alphabet', 3, CAST(NULL AS int))", "NULL"); + test_expression!("to_hex(2147483647)", "7fffffff"); + test_expression!("to_hex(9223372036854775807)", "7fffffffffffffff"); + test_expression!("to_hex(CAST(NULL AS int))", "NULL"); test_expression!("trim(' tom ')", "tom"); test_expression!("trim(' tom')", "tom"); test_expression!("trim('')", ""); From 190547a0b5e91fff3dc00eb9f2e38a5c663448a1 Mon Sep 17 00:00:00 2001 From: Mike Seddon Date: Thu, 4 Mar 2021 11:05:45 +1100 Subject: [PATCH 2/2] create downcast macros --- .../src/physical_plan/string_expressions.rs | 387 +++++------------- 1 file changed, 92 insertions(+), 295 deletions(-) diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index 6210b0faeae..bc0e7633379 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -21,6 +21,7 @@ //! String expressions +use std::any::type_name; use std::cmp::Ordering; use std::str::from_utf8; use std::sync::Arc; @@ -40,6 +41,57 @@ use unicode_segmentation::UnicodeSegmentation; use super::ColumnarValue; +macro_rules! downcast_string_arg { + ($ARG:expr, $NAME:expr, $T:ident) => {{ + $ARG.as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast {} to {}", + $NAME, + type_name::>() + )) + })? + }}; +} + +macro_rules! downcast_primitive_array_arg { + ($ARG:expr, $NAME:expr, $T:ident) => {{ + $ARG.as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast {} to {}", + $NAME, + type_name::>() + )) + })? + }}; +} + +macro_rules! downcast_arg { + ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast {} to {}", + $NAME, + type_name::<$ARRAY_TYPE>() + )) + })? + }}; +} + +macro_rules! downcast_vec { + ($ARGS:expr, $ARRAY_TYPE:ident) => {{ + $ARGS + .iter() + .map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() { + Some(array) => Ok(array), + _ => Err(DataFusionError::Internal("failed to downcast".to_string())), + }) + }}; +} + /// applies a unary expression to `args[0]` that is expected to be downcastable to /// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset) /// # Errors @@ -65,15 +117,13 @@ where ))); } - let array = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal("failed to downcast to string".to_string()) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); // first map is the iterator, second is for the `Option<_>` - Ok(array.iter().map(|x| x.map(|x| op(x))).collect()) + Ok(string_array + .iter() + .map(|string| string.map(|string| op(string))) + .collect()) } fn handle<'a, F, R>(args: &'a [ColumnarValue], op: F, name: &str) -> Result @@ -125,26 +175,10 @@ where } } -macro_rules! downcast_vec { - ($ARGS:expr, $ARRAY_TYPE:ident) => {{ - $ARGS - .iter() - .map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() { - Some(array) => Ok(array), - _ => Err(DataFusionError::Internal("failed to downcast".to_string())), - }) - }}; -} - /// Returns the numeric code of the first character of the argument. /// ascii('x') = 120 pub fn ascii(args: &[ArrayRef]) -> Result { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal("could not cast string to StringArray".to_string()) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); let result = string_array .iter() @@ -164,14 +198,7 @@ pub fn ascii(args: &[ArrayRef]) -> Result { pub fn btrim(args: &[ArrayRef]) -> Result { match args.len() { 1 => { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast string to StringArray".to_string(), - ) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); let result = string_array .iter() @@ -185,23 +212,8 @@ pub fn btrim(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } 2 => { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast string to StringArray".to_string(), - ) - })?; - - let characters_array: &GenericStringArray = args[1] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast characters to StringArray".to_string(), - ) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); + let characters_array = downcast_string_arg!(args[1], "characters", T); let result = string_array .iter() @@ -257,12 +269,7 @@ where /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. /// chr(65) = 'A' pub fn chr(args: &[ArrayRef]) -> Result { - let integer_array: &Int64Array = args[0] - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal("could not cast integer to Int64Array".to_string()) - })?; + let integer_array = downcast_arg!(args[0], "integer", Int64Array); // first map is the iterator, second is for the `Option<_>` let result = integer_array @@ -392,12 +399,7 @@ pub fn concat_ws(args: &[ArrayRef]) -> Result { /// Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters. /// initcap('hi THOMAS') = 'Hi Thomas' pub fn initcap(args: &[ArrayRef]) -> Result { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal("could not cast string to StringArray".to_string()) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); // first map is the iterator, second is for the `Option<_>` let result = string_array @@ -427,20 +429,8 @@ pub fn initcap(args: &[ArrayRef]) -> Result /// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. /// left('abcde', 2) = 'ab' pub fn left(args: &[ArrayRef]) -> Result { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal("could not cast string to StringArray".to_string()) - })?; - - let n_array: &Int64Array = - args[1] - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal("could not cast n to Int64Array".to_string()) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); + let n_array = downcast_arg!(args[1], "n", Int64Array); let result = string_array .iter() @@ -485,23 +475,8 @@ pub fn lower(args: &[ColumnarValue]) -> Result { pub fn lpad(args: &[ArrayRef]) -> Result { match args.len() { 2 => { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast string to StringArray".to_string(), - ) - })?; - - let length_array: &Int64Array = args[1] - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast length to Int64Array".to_string(), - ) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); + let length_array = downcast_arg!(args[1], "length", Int64Array); let result = string_array .iter() @@ -533,32 +508,9 @@ pub fn lpad(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } 3 => { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast string to StringArray".to_string(), - ) - })?; - - let length_array: &Int64Array = args[1] - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast length to Int64Array".to_string(), - ) - })?; - - let fill_array: &GenericStringArray = args[2] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast fill to StringArray".to_string(), - ) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); + let length_array = downcast_arg!(args[1], "length", Int64Array); + let fill_array = downcast_string_arg!(args[2], "fill", T); let result = string_array .iter() @@ -615,14 +567,7 @@ pub fn lpad(args: &[ArrayRef]) -> Result { pub fn ltrim(args: &[ArrayRef]) -> Result { match args.len() { 1 => { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast string to StringArray".to_string(), - ) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); let result = string_array .iter() @@ -632,23 +577,8 @@ pub fn ltrim(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } 2 => { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast string to StringArray".to_string(), - ) - })?; - - let characters_array: &GenericStringArray = args[1] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast characters to StringArray".to_string(), - ) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); + let characters_array = downcast_string_arg!(args[1], "characters", T); let result = string_array .iter() @@ -675,19 +605,8 @@ pub fn ltrim(args: &[ArrayRef]) -> Result { /// Repeats string the specified number of times. /// repeat('Pg', 4) = 'PgPgPgPg' pub fn repeat(args: &[ArrayRef]) -> Result { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal("could not cast string to StringArray".to_string()) - })?; - - let number_array: &Int64Array = args[1] - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal("could not cast number to Int64Array".to_string()) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); + let number_array = downcast_arg!(args[1], "number", Int64Array); let result = string_array .iter() @@ -705,12 +624,7 @@ pub fn repeat(args: &[ArrayRef]) -> Result { /// Reverses the order of the characters in the string. /// reverse('abcde') = 'edcba' pub fn reverse(args: &[ArrayRef]) -> Result { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal("could not cast string to StringArray".to_string()) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); let result = string_array .iter() @@ -725,20 +639,8 @@ pub fn reverse(args: &[ArrayRef]) -> Result /// Returns last n characters in the string, or when n is negative, returns all but first |n| characters. /// right('abcde', 2) = 'de' pub fn right(args: &[ArrayRef]) -> Result { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal("could not cast string to StringArray".to_string()) - })?; - - let n_array: &Int64Array = - args[1] - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal("could not cast n to Int64Array".to_string()) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); + let n_array = downcast_arg!(args[1], "n", Int64Array); let result = string_array .iter() @@ -777,23 +679,8 @@ pub fn right(args: &[ArrayRef]) -> Result { pub fn rpad(args: &[ArrayRef]) -> Result { match args.len() { 2 => { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast string to StringArray".to_string(), - ) - })?; - - let length_array: &Int64Array = args[1] - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast length to Int64Array".to_string(), - ) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); + let length_array = downcast_arg!(args[1], "length", Int64Array); let result = string_array .iter() @@ -822,32 +709,9 @@ pub fn rpad(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } 3 => { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast string to StringArray".to_string(), - ) - })?; - - let length_array: &Int64Array = args[1] - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast length to Int64Array".to_string(), - ) - })?; - - let fill_array: &GenericStringArray = args[2] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast fill to StringArray".to_string(), - ) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); + let length_array = downcast_arg!(args[1], "length", Int64Array); + let fill_array = downcast_string_arg!(args[2], "fill", T); let result = string_array .iter() @@ -895,14 +759,7 @@ pub fn rpad(args: &[ArrayRef]) -> Result { pub fn rtrim(args: &[ArrayRef]) -> Result { match args.len() { 1 => { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast string to StringArray".to_string(), - ) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); let result = string_array .iter() @@ -912,23 +769,8 @@ pub fn rtrim(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } 2 => { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast string to StringArray".to_string(), - ) - })?; - - let characters_array: &GenericStringArray = args[1] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast characters to StringArray".to_string(), - ) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); + let characters_array = downcast_string_arg!(args[1], "characters", T); let result = string_array .iter() @@ -958,23 +800,8 @@ pub fn rtrim(args: &[ArrayRef]) -> Result { pub fn substr(args: &[ArrayRef]) -> Result { match args.len() { 2 => { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast string to StringArray".to_string(), - ) - })?; - - let start_array: &Int64Array = args[1] - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast start to Int64Array".to_string(), - ) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); + let start_array = downcast_arg!(args[1], "start", Int64Array); let result = string_array .iter() @@ -1001,32 +828,9 @@ pub fn substr(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } 3 => { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast string to StringArray".to_string(), - ) - })?; - - let start_array: &Int64Array = args[1] - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast start to Int64Array".to_string(), - ) - })?; - - let count_array: &Int64Array = args[2] - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast count to Int64Array".to_string(), - ) - })?; + let string_array = downcast_string_arg!(args[0], "string", T); + let start_array = downcast_arg!(args[1], "start", Int64Array); + let count_array = downcast_arg!(args[2], "count", Int64Array); let result = string_array .iter() @@ -1077,14 +881,7 @@ pub fn to_hex(args: &[ArrayRef]) -> Result where T::Native: StringOffsetSizeTrait, { - let integer_array: &PrimitiveArray = args[0] - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - DataFusionError::Internal( - "could not cast integer to PrimitiveArray".to_string(), - ) - })?; + let integer_array = downcast_primitive_array_arg!(args[0], "integer", T); let result = integer_array .iter()