diff --git a/rust/arrow/benches/length_kernel.rs b/rust/arrow/benches/length_kernel.rs index b70f6374f8f..e10c8a4431a 100644 --- a/rust/arrow/benches/length_kernel.rs +++ b/rust/arrow/benches/length_kernel.rs @@ -21,8 +21,7 @@ use criterion::Criterion; extern crate arrow; -use arrow::array::*; -use arrow::compute::kernels::length::length; +use arrow::{array::*, compute::kernels::length::length}; fn bench_length(array: &StringArray) { criterion::black_box(length(array).unwrap()); diff --git a/rust/arrow/src/compute/kernels/bit_length.rs b/rust/arrow/src/compute/kernels/bit_length.rs new file mode 100644 index 00000000000..fee771315df --- /dev/null +++ b/rust/arrow/src/compute/kernels/bit_length.rs @@ -0,0 +1,208 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines kernel for length of a string array + +use crate::{array::*, buffer::Buffer}; +use crate::{ + datatypes::DataType, + error::{ArrowError, Result}, +}; +use std::sync::Arc; + +fn bit_length_string(array: &Array, data_type: DataType) -> ArrayRef +where + OffsetSize: OffsetSizeTrait, +{ + // note: offsets are stored as u8, but they can be interpreted as OffsetSize + let offsets = &array.data_ref().buffers()[0]; + // this is a 30% improvement over iterating over u8s and building OffsetSize, which + // justifies the usage of `unsafe`. + let slice: &[OffsetSize] = + &unsafe { offsets.typed_data::() }[array.offset()..]; + + let bit_size = OffsetSize::from_usize(8).unwrap(); + let lengths = slice + .windows(2) + .map(|offset| (offset[1] - offset[0]) * bit_size); + + // JUSTIFICATION + // Benefit + // ~60% speedup + // Soundness + // `values` is an iterator with a known size. + let buffer = unsafe { Buffer::from_trusted_len_iter(lengths) }; + + let null_bit_buffer = array + .data_ref() + .null_bitmap() + .as_ref() + .map(|b| b.bits.clone()); + + let data = ArrayData::new( + data_type, + array.len(), + None, + null_bit_buffer, + 0, + vec![buffer], + vec![], + ); + make_array(Arc::new(data)) +} + +/// Returns an array of Int32/Int64 denoting the number of bits in each string in the array. +/// +/// * this only accepts StringArray/Utf8 and LargeString/LargeUtf8 +/// * bit_length of null is null. +/// * bit_length is in number of bits +pub fn bit_length(array: &Array) -> Result { + match array.data_type() { + DataType::Utf8 => Ok(bit_length_string::(array, DataType::Int32)), + DataType::LargeUtf8 => Ok(bit_length_string::(array, DataType::Int64)), + _ => Err(ArrowError::ComputeError(format!( + "bit_length not supported for {:?}", + array.data_type() + ))), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn cases() -> Vec<(Vec<&'static str>, usize, Vec)> { + fn double_vec(v: Vec) -> Vec { + [&v[..], &v[..]].concat() + } + + // a large array + let mut values = vec!["one", "on", "o", ""]; + let mut expected = vec![24, 16, 8, 0]; + for _ in 0..10 { + values = double_vec(values); + expected = double_vec(expected); + } + + vec![ + (vec!["hello", " ", "world", "!"], 4, vec![40, 8, 40, 8]), + (vec!["💖"], 1, vec![32]), + (vec!["josé"], 1, vec![40]), + (values, 4096, expected), + ] + } + + #[test] + fn test_string() -> Result<()> { + cases().into_iter().try_for_each(|(input, len, expected)| { + let array = StringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + expected.iter().enumerate().for_each(|(i, value)| { + assert_eq!(*value, result.value(i)); + }); + Ok(()) + }) + } + + #[test] + fn test_large_string() -> Result<()> { + cases().into_iter().try_for_each(|(input, len, expected)| { + let array = LargeStringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + expected.iter().enumerate().for_each(|(i, value)| { + assert_eq!(*value as i64, result.value(i)); + }); + Ok(()) + }) + } + + fn null_cases() -> Vec<(Vec>, usize, Vec>)> { + vec![( + vec![Some("one"), None, Some("three"), Some("four")], + 4, + vec![Some(24), None, Some(40), Some(32)], + )] + } + + #[test] + fn null_string() -> Result<()> { + null_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = StringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + + let expected: Int32Array = expected.into(); + assert_eq!(expected.data(), result.data()); + Ok(()) + }) + } + + #[test] + fn null_large_string() -> Result<()> { + null_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = LargeStringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + + // convert to i64 + let expected: Int64Array = expected + .iter() + .map(|e| e.map(|e| e as i64)) + .collect::>() + .into(); + assert_eq!(expected.data(), result.data()); + Ok(()) + }) + } + + /// Tests that bit_length is not valid for u64. + #[test] + fn wrong_type() { + let array: UInt64Array = vec![1u64].into(); + + assert!(bit_length(&array).is_err()); + } + + /// Tests with an offset + #[test] + fn offsets() -> Result<()> { + let a = StringArray::from(vec!["hello", " ", "world"]); + let b = make_array( + ArrayData::builder(DataType::Utf8) + .len(2) + .offset(1) + .buffers(a.data_ref().buffers().to_vec()) + .build(), + ); + let result = bit_length(b.as_ref())?; + + let expected = Int32Array::from(vec![8, 40]); + assert_eq!(expected.data(), result.data()); + + Ok(()) + } +} diff --git a/rust/arrow/src/compute/kernels/mod.rs b/rust/arrow/src/compute/kernels/mod.rs index a8d24979e04..ef2f7736f4d 100644 --- a/rust/arrow/src/compute/kernels/mod.rs +++ b/rust/arrow/src/compute/kernels/mod.rs @@ -20,6 +20,7 @@ pub mod aggregate; pub mod arithmetic; pub mod arity; +pub mod bit_length; pub mod boolean; pub mod cast; pub mod cast_utils; diff --git a/rust/datafusion/Cargo.toml b/rust/datafusion/Cargo.toml index 11cc63bbdc3..f38bb487bdc 100644 --- a/rust/datafusion/Cargo.toml +++ b/rust/datafusion/Cargo.toml @@ -65,6 +65,8 @@ md-5 = "^0.9.1" sha2 = "^0.9.1" ordered-float = "2.0" unicode-segmentation = "^1.7.1" +regex = "1" +lazy_static = "^1.4.0" [dev-dependencies] rand = "0.8" diff --git a/rust/datafusion/src/lib.rs b/rust/datafusion/src/lib.rs index 1e86f664c2d..4c2ea9497af 100644 --- a/rust/datafusion/src/lib.rs +++ b/rust/datafusion/src/lib.rs @@ -156,6 +156,9 @@ extern crate arrow; extern crate sqlparser; +#[macro_use] +extern crate lazy_static; + pub mod dataframe; pub mod datasource; pub mod error; diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 245ca3aaaa8..f224b519aa2 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -874,29 +874,39 @@ unary_scalar_expr!(Log2, log2); unary_scalar_expr!(Log10, log10); // string functions +unary_scalar_expr!(Ascii, ascii); unary_scalar_expr!(BitLength, bit_length); +unary_scalar_expr!(Btrim, btrim); unary_scalar_expr!(CharacterLength, character_length); -unary_scalar_expr!(CharacterLength, length); +unary_scalar_expr!(Chr, chr); +unary_scalar_expr!(Concat, concat); +unary_scalar_expr!(ConcatWithSeparator, concat_ws); +unary_scalar_expr!(InitCap, initcap); +unary_scalar_expr!(Left, left); unary_scalar_expr!(Lower, lower); +unary_scalar_expr!(Lpad, lpad); unary_scalar_expr!(Ltrim, ltrim); unary_scalar_expr!(MD5, md5); unary_scalar_expr!(OctetLength, octet_length); +unary_scalar_expr!(RegexpReplace, regexp_replace); +unary_scalar_expr!(Repeat, repeat); +unary_scalar_expr!(Replace, replace); +unary_scalar_expr!(Reverse, reverse); +unary_scalar_expr!(Right, right); +unary_scalar_expr!(Rpad, rpad); unary_scalar_expr!(Rtrim, rtrim); unary_scalar_expr!(SHA224, sha224); unary_scalar_expr!(SHA256, sha256); unary_scalar_expr!(SHA384, sha384); unary_scalar_expr!(SHA512, sha512); +unary_scalar_expr!(SplitPart, split_part); +unary_scalar_expr!(StartsWith, starts_with); +unary_scalar_expr!(Strpos, strpos); +unary_scalar_expr!(Substr, substr); +unary_scalar_expr!(Translate, translate); unary_scalar_expr!(Trim, trim); unary_scalar_expr!(Upper, upper); -/// returns the concatenation of string expressions -pub fn concat(args: Vec) -> Expr { - Expr::ScalarFunction { - fun: functions::BuiltinScalarFunction::Concat, - args, - } -} - /// returns an array of fixed size with each argument on it. pub fn array(args: Vec) -> Expr { Expr::ScalarFunction { diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 0de0a032520..81fc6550542 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -33,12 +33,14 @@ pub use builder::LogicalPlanBuilder; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ - abs, acos, and, array, asin, atan, avg, binary_expr, bit_length, case, ceil, - character_length, col, combine_filters, concat, cos, count, count_distinct, - create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, length, lit, ln, - log10, log2, lower, ltrim, max, md5, min, octet_length, or, round, rtrim, sha224, - sha256, sha384, sha512, signum, sin, sqrt, sum, tan, trim, trunc, upper, when, Expr, - ExpressionVisitor, Literal, Recursion, + abs, acos, and, array, ascii, asin, atan, avg, binary_expr, bit_length, btrim, case, + ceil, character_length, chr, col, combine_filters, concat, concat_ws, cos, count, + count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, + initcap, left, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min, octet_length, + or, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, + sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, + sum, tan, translate, trim, trunc, upper, when, Expr, ExpressionVisitor, Literal, + Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 51941188bb4..d9a91628eb1 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -56,6 +56,17 @@ use std::{any::Any, fmt, str::FromStr, sync::Arc}; /// A function's signature, which defines the function's supported argument types. #[derive(Debug, Clone, PartialEq)] pub enum Signature { + /// fixed number of arguments of arbitrary types + Any(usize), + /// exact number of arguments of an exact type + Exact(Vec), + /// One of a list of signatures + OneOf(Vec), + /// fixed number of arguments of an arbitrary but equal type out of a list of valid types + // A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])` + // A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])` + // A function of two arguments where both arguments must be the same type of f64 or f32 is `Uniform(2, vec![DataType::Float32, DataType::Float64])` + Uniform(usize, Vec), /// arbitrary number of arguments of an common type out of a list of valid types // A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])` Variadic(Vec), @@ -63,16 +74,6 @@ pub enum Signature { // A function such as `array` is `VariadicEqual` // The first argument decides the type used for coercion VariadicEqual, - /// fixed number of arguments of an arbitrary but equal type out of a list of valid types - // A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])` - // A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])` - Uniform(usize, Vec), - /// exact number of arguments of an exact type - Exact(Vec), - /// fixed number of arguments of arbitrary types - Any(usize), - /// One of a list of signatures - OneOf(Vec), } /// Scalar function @@ -120,20 +121,64 @@ pub enum BuiltinScalarFunction { Abs, /// signum Signum, + /// ascii + Ascii, + /// bit_length + BitLength, + /// btrim + Btrim, + /// character_length + CharacterLength, + /// chr + Chr, /// concat Concat, + /// concat_ws + ConcatWithSeparator, + /// initcap + InitCap, + /// left + Left, + /// lpad + Lpad, /// lower Lower, - /// upper - Upper, - /// trim - Trim, /// trim left Ltrim, - /// trim right + /// length + OctetLength, + /// regexp_replace + RegexpReplace, + /// repeat + Repeat, + /// replace + Replace, + /// reverse + Reverse, + /// right + Right, + /// rpad + Rpad, + /// rtrim Rtrim, + /// split_part + SplitPart, + /// starts_with + StartsWith, + /// strpos + Strpos, + /// substr + Substr, + /// to_hex + ToHex, + /// translate + Translate, + /// trim + Trim, /// to_timestamp ToTimestamp, + /// upper + Upper, /// construct an array from columns Array, /// SQL NULLIF() @@ -152,12 +197,6 @@ pub enum BuiltinScalarFunction { SHA384, /// SHA512, SHA512, - /// bit_length - BitLength, - /// character_length - CharacterLength, - /// octet_length - OctetLength, } impl fmt::Display for BuiltinScalarFunction { @@ -188,11 +227,35 @@ impl FromStr for BuiltinScalarFunction { "truc" => BuiltinScalarFunction::Trunc, "abs" => BuiltinScalarFunction::Abs, "signum" => BuiltinScalarFunction::Signum, + "ascii" => BuiltinScalarFunction::Ascii, + "bit_length" => BuiltinScalarFunction::BitLength, + "btrim" => BuiltinScalarFunction::Btrim, + "chr" => BuiltinScalarFunction::Chr, + "char_length" => BuiltinScalarFunction::CharacterLength, + "character_length" => BuiltinScalarFunction::CharacterLength, "concat" => BuiltinScalarFunction::Concat, + "concat_ws" => BuiltinScalarFunction::ConcatWithSeparator, + "initcap" => BuiltinScalarFunction::InitCap, + "left" => BuiltinScalarFunction::Left, + "length" => BuiltinScalarFunction::CharacterLength, + "lpad" => BuiltinScalarFunction::Lpad, "lower" => BuiltinScalarFunction::Lower, - "trim" => BuiltinScalarFunction::Trim, "ltrim" => BuiltinScalarFunction::Ltrim, + "octet_length" => BuiltinScalarFunction::OctetLength, + "regexp_replace" => BuiltinScalarFunction::RegexpReplace, + "repeat" => BuiltinScalarFunction::Repeat, + "replace" => BuiltinScalarFunction::Replace, + "reverse" => BuiltinScalarFunction::Reverse, + "right" => BuiltinScalarFunction::Right, + "rpad" => BuiltinScalarFunction::Rpad, "rtrim" => BuiltinScalarFunction::Rtrim, + "split_part" => BuiltinScalarFunction::SplitPart, + "starts_with" => BuiltinScalarFunction::StartsWith, + "strpos" => BuiltinScalarFunction::Strpos, + "substr" => BuiltinScalarFunction::Substr, + "to_hex" => BuiltinScalarFunction::ToHex, + "translate" => BuiltinScalarFunction::Translate, + "trim" => BuiltinScalarFunction::Trim, "upper" => BuiltinScalarFunction::Upper, "to_timestamp" => BuiltinScalarFunction::ToTimestamp, "date_trunc" => BuiltinScalarFunction::DateTrunc, @@ -204,11 +267,6 @@ impl FromStr for BuiltinScalarFunction { "sha256" => BuiltinScalarFunction::SHA256, "sha384" => BuiltinScalarFunction::SHA384, "sha512" => BuiltinScalarFunction::SHA512, - "bit_length" => BuiltinScalarFunction::BitLength, - "octet_length" => BuiltinScalarFunction::OctetLength, - "length" => BuiltinScalarFunction::CharacterLength, - "char_length" => BuiltinScalarFunction::CharacterLength, - "character_length" => BuiltinScalarFunction::CharacterLength, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -242,14 +300,77 @@ pub fn return_type( // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. match fun { + BuiltinScalarFunction::Ascii => Ok(DataType::Int32), + BuiltinScalarFunction::BitLength => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The bit_length function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Btrim => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The btrim function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::CharacterLength => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The character_length function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Chr => Ok(DataType::Utf8), BuiltinScalarFunction::Concat => Ok(DataType::Utf8), + BuiltinScalarFunction::ConcatWithSeparator => Ok(DataType::Utf8), + BuiltinScalarFunction::InitCap => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The initcap function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Left => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The left function can only accept strings.".to_string(), + )); + } + }), BuiltinScalarFunction::Lower => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( - "The upper function can only accept strings.".to_string(), + "The left function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Lpad => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The lpad function can only accept strings.".to_string(), )); } }), @@ -263,6 +384,76 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::OctetLength => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The octet_length function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::RegexpReplace => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The regexp_replace function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Repeat => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The repeat function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Replace => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The replace function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Reverse => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The reverse function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Right => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The right function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Rpad => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The rpad function can only accept strings.".to_string(), + )); + } + }), BuiltinScalarFunction::Rtrim => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, @@ -273,6 +464,58 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::SplitPart => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The split_part function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::StartsWith => Ok(DataType::Boolean), + BuiltinScalarFunction::Strpos => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The strpos function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::Substr => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The substr function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::ToHex => Ok(match arg_types[0] { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + DataType::Utf8 + } + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The to_hex function can only accept integers.".to_string(), + )); + } + }), + BuiltinScalarFunction::Translate => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The translate function can only accept strings.".to_string(), + )); + } + }), BuiltinScalarFunction::Trim => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, @@ -359,36 +602,6 @@ pub fn return_type( )); } }), - BuiltinScalarFunction::BitLength => Ok(match arg_types[0] { - DataType::LargeUtf8 => DataType::Int64, - DataType::Utf8 => DataType::Int32, - _ => { - // this error is internal as `data_types` should have captured this. - return Err(DataFusionError::Internal( - "The bit_length function can only accept strings.".to_string(), - )); - } - }), - BuiltinScalarFunction::CharacterLength => Ok(match arg_types[0] { - DataType::LargeUtf8 => DataType::Int64, - DataType::Utf8 => DataType::Int32, - _ => { - // this error is internal as `data_types` should have captured this. - return Err(DataFusionError::Internal( - "The character_length function can only accept strings.".to_string(), - )); - } - }), - BuiltinScalarFunction::OctetLength => Ok(match arg_types[0] { - DataType::LargeUtf8 => DataType::Int64, - DataType::Utf8 => DataType::Int32, - _ => { - // this error is internal as `data_types` should have captured this. - return Err(DataFusionError::Internal( - "The octet_length function can only accept strings.".to_string(), - )); - } - }), _ => Ok(DataType::Float64), } } @@ -424,15 +637,21 @@ pub fn create_physical_expr( BuiltinScalarFunction::SHA256 => crypto_expressions::sha256, BuiltinScalarFunction::SHA384 => crypto_expressions::sha384, BuiltinScalarFunction::SHA512 => crypto_expressions::sha512, - BuiltinScalarFunction::Concat => string_expressions::concatenate, - BuiltinScalarFunction::Lower => string_expressions::lower, - BuiltinScalarFunction::Trim => string_expressions::trim, - BuiltinScalarFunction::Ltrim => string_expressions::ltrim, - BuiltinScalarFunction::Rtrim => string_expressions::rtrim, - BuiltinScalarFunction::Upper => string_expressions::upper, BuiltinScalarFunction::ToTimestamp => datetime_expressions::to_timestamp, BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc, BuiltinScalarFunction::Array => array_expressions::array, + BuiltinScalarFunction::Ascii => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::ascii::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::ascii::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function ascii", + other, + ))), + }, BuiltinScalarFunction::BitLength => |args| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { @@ -445,6 +664,18 @@ pub fn create_physical_expr( _ => unreachable!(), }, }, + BuiltinScalarFunction::Btrim => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::btrim::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::btrim::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function btrim", + other, + ))), + }, BuiltinScalarFunction::CharacterLength => |args| match args[0].data_type() { DataType::Utf8 => make_scalar_function( string_expressions::character_length::, @@ -457,19 +688,238 @@ pub fn create_physical_expr( other, ))), }, - BuiltinScalarFunction::OctetLength => |args| match &args[0] { - ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), - ColumnarValue::Scalar(v) => match v { - ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( - v.as_ref().map(|x| x.len() as i32), - ))), - ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( - ScalarValue::Int64(v.as_ref().map(|x| x.len() as i64)), - )), - _ => unreachable!(), - }, - }, + BuiltinScalarFunction::Chr => { + |args| make_scalar_function(string_expressions::chr)(args) + } + BuiltinScalarFunction::Concat => string_expressions::concat, + BuiltinScalarFunction::ConcatWithSeparator => { + |args| make_scalar_function(string_expressions::concat_ws)(args) + } + BuiltinScalarFunction::InitCap => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::initcap::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::initcap::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function initcap", + other, + ))), + }, + BuiltinScalarFunction::Left => |args| match args[0].data_type() { + DataType::Utf8 => make_scalar_function(string_expressions::left::)(args), + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::left::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function left", + other, + ))), + }, + BuiltinScalarFunction::Lower => string_expressions::lower, + BuiltinScalarFunction::Lpad => |args| match args[0].data_type() { + DataType::Utf8 => make_scalar_function(string_expressions::lpad::)(args), + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::lpad::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function lpad", + other, + ))), + }, + BuiltinScalarFunction::Ltrim => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::ltrim::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::ltrim::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function ltrim", + other, + ))), + }, + BuiltinScalarFunction::OctetLength => |args| match &args[0] { + ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), + ColumnarValue::Scalar(v) => match v { + ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( + v.as_ref().map(|x| x.len() as i32), + ))), + ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( + ScalarValue::Int64(v.as_ref().map(|x| x.len() as i64)), + )), + _ => unreachable!(), + }, + }, BuiltinScalarFunction::DatePart => datetime_expressions::date_part, + BuiltinScalarFunction::RegexpReplace => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::regexp_replace::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::regexp_replace::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function regexp_replace", + other, + ))), + }, + BuiltinScalarFunction::Repeat => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::repeat::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::repeat::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function repeat", + other, + ))), + }, + BuiltinScalarFunction::Replace => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::replace::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::replace::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function replace", + other, + ))), + }, + BuiltinScalarFunction::Reverse => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::reverse::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::reverse::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function reverse", + other, + ))), + }, + BuiltinScalarFunction::Right => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::right::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::right::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function right", + other, + ))), + }, + BuiltinScalarFunction::Rpad => |args| match args[0].data_type() { + DataType::Utf8 => make_scalar_function(string_expressions::rpad::)(args), + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::rpad::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function rpad", + other, + ))), + }, + BuiltinScalarFunction::Rtrim => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::rtrim::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::rtrim::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function rtrim", + other, + ))), + }, + BuiltinScalarFunction::SplitPart => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::split_part::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::split_part::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function split_part", + other, + ))), + }, + BuiltinScalarFunction::StartsWith => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::starts_with::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::starts_with::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function starts_with", + other, + ))), + }, + BuiltinScalarFunction::Strpos => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::strpos::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::strpos::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function strpos", + other, + ))), + }, + BuiltinScalarFunction::Substr => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::substr::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::substr::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function substr", + other, + ))), + }, + BuiltinScalarFunction::ToHex => |args| match args[0].data_type() { + DataType::Int32 => { + make_scalar_function(string_expressions::to_hex::)(args) + } + DataType::Int64 => { + make_scalar_function(string_expressions::to_hex::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function to_hex", + other, + ))), + }, + BuiltinScalarFunction::Translate => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::translate::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::translate::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function translate", + other, + ))), + }, + BuiltinScalarFunction::Trim => |args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::btrim::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::btrim::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function trim", + other, + ))), + }, + BuiltinScalarFunction::Upper => string_expressions::upper, }); // coerce let args = coerce(args, input_schema, &signature(fun))?; @@ -491,24 +941,82 @@ pub fn create_physical_expr( fn signature(fun: &BuiltinScalarFunction) -> Signature { // note: the physical expression must accept the type returned by this function or the execution panics. - // for now, the list is small, as we do not have many built-in functions. match fun { BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]), - BuiltinScalarFunction::Upper - | BuiltinScalarFunction::Lower + BuiltinScalarFunction::ConcatWithSeparator => { + Signature::Variadic(vec![DataType::Utf8]) + } + BuiltinScalarFunction::Ascii | BuiltinScalarFunction::BitLength | BuiltinScalarFunction::CharacterLength - | BuiltinScalarFunction::OctetLength - | BuiltinScalarFunction::Trim - | BuiltinScalarFunction::Ltrim - | BuiltinScalarFunction::Rtrim + | BuiltinScalarFunction::InitCap + | BuiltinScalarFunction::Lower | BuiltinScalarFunction::MD5 + | BuiltinScalarFunction::OctetLength + | BuiltinScalarFunction::Reverse | BuiltinScalarFunction::SHA224 | BuiltinScalarFunction::SHA256 | BuiltinScalarFunction::SHA384 - | BuiltinScalarFunction::SHA512 => { + | BuiltinScalarFunction::SHA512 + | BuiltinScalarFunction::Trim + | BuiltinScalarFunction::Upper => { Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8]) } + BuiltinScalarFunction::Btrim + | BuiltinScalarFunction::Ltrim + | BuiltinScalarFunction::Rtrim => { + Signature::Variadic(vec![DataType::Utf8, DataType::LargeUtf8]) + } + BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => { + Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8, DataType::Int64]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Int64]), + Signature::Exact(vec![DataType::Utf8, DataType::Int64, DataType::Utf8]), + Signature::Exact(vec![ + DataType::LargeUtf8, + DataType::Int64, + DataType::Utf8, + ]), + Signature::Exact(vec![ + DataType::Utf8, + DataType::Int64, + DataType::LargeUtf8, + ]), + Signature::Exact(vec![ + DataType::LargeUtf8, + DataType::Int64, + DataType::LargeUtf8, + ]), + ]) + } + BuiltinScalarFunction::SplitPart => Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Int64]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Int64]), + Signature::Exact(vec![DataType::Utf8, DataType::LargeUtf8, DataType::Int64]), + Signature::Exact(vec![ + DataType::LargeUtf8, + DataType::LargeUtf8, + DataType::Int64, + ]), + ]), + BuiltinScalarFunction::Strpos | BuiltinScalarFunction::StartsWith => { + Signature::Uniform(2, vec![DataType::Utf8, DataType::LargeUtf8]) + } + BuiltinScalarFunction::Substr => Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8, DataType::Int64]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Int64]), + Signature::Exact(vec![DataType::Utf8, DataType::Int64, DataType::Int64]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Int64, DataType::Int64]), + ]), + BuiltinScalarFunction::RegexpReplace => Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]), + Signature::Exact(vec![ + DataType::Utf8, + DataType::Utf8, + DataType::Utf8, + DataType::Utf8, + ]), + ]), BuiltinScalarFunction::ToTimestamp => Signature::Uniform(1, vec![DataType::Utf8]), BuiltinScalarFunction::DateTrunc => Signature::Exact(vec![ DataType::Utf8, @@ -540,6 +1048,21 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { BuiltinScalarFunction::NullIf => { Signature::Uniform(2, SUPPORTED_NULLIF_TYPES.to_vec()) } + BuiltinScalarFunction::Left + | BuiltinScalarFunction::Repeat + | BuiltinScalarFunction::Right => Signature::OneOf(vec![Signature::Exact(vec![ + DataType::Utf8, + DataType::Int64, + ])]), + BuiltinScalarFunction::Chr => Signature::Uniform(1, vec![DataType::Int64]), + BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { + Signature::Uniform(3, vec![DataType::Utf8, DataType::LargeUtf8]) + } + BuiltinScalarFunction::ToHex => Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Int32]), + Signature::Exact(vec![DataType::Int64]), + ]), + // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we // return the best approximation for it (in f64). @@ -696,8 +1219,8 @@ mod tests { }; use arrow::{ array::{ - Array, ArrayRef, FixedSizeListArray, Float64Array, Int32Array, StringArray, - UInt32Array, UInt64Array, + Array, ArrayRef, BooleanArray, FixedSizeListArray, Float64Array, Int32Array, + StringArray, UInt32Array, UInt64Array, }, datatypes::Field, record_batch::RecordBatch, @@ -753,6 +1276,154 @@ mod tests { #[test] fn test_functions() -> Result<()> { + test_function!( + Ascii, + &[lit(ScalarValue::Utf8(Some("x".to_string())))], + Ok(Some(120)), + i32, + Int32, + Int32Array + ); + test_function!( + Ascii, + &[lit(ScalarValue::Utf8(Some("ésoj".to_string())))], + Ok(Some(233)), + i32, + Int32, + Int32Array + ); + test_function!( + Ascii, + &[lit(ScalarValue::Utf8(Some("💯".to_string())))], + Ok(Some(128175)), + i32, + Int32, + Int32Array + ); + test_function!( + Ascii, + &[lit(ScalarValue::Utf8(Some("💯a".to_string())))], + Ok(Some(128175)), + i32, + Int32, + Int32Array + ); + test_function!( + Ascii, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + Ascii, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + i32, + Int32, + Int32Array + ); + test_function!( + BitLength, + &[lit(ScalarValue::Utf8(Some("chars".to_string())))], + Ok(Some(40)), + i32, + Int32, + Int32Array + ); + test_function!( + BitLength, + &[lit(ScalarValue::Utf8(Some("josé".to_string())))], + Ok(Some(40)), + i32, + Int32, + Int32Array + ); + test_function!( + BitLength, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + Btrim, + &[lit(ScalarValue::Utf8(Some(" trim ".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[lit(ScalarValue::Utf8(Some(" trim".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[lit(ScalarValue::Utf8(Some("trim ".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[lit(ScalarValue::Utf8(Some("\n trim \n".to_string())))], + Ok(Some("\n trim \n")), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[ + lit(ScalarValue::Utf8(Some("xyxtrimyyx".to_string()))), + lit(ScalarValue::Utf8(Some("xyz".to_string()))), + ], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[ + lit(ScalarValue::Utf8(Some("\nxyxtrimyyx\n".to_string()))), + lit(ScalarValue::Utf8(Some("xyz\n".to_string()))), + ], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("xyz".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Btrim, + &[ + lit(ScalarValue::Utf8(Some("xyxtrimyyx".to_string()))), + lit(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); test_function!( CharacterLength, &[lit(ScalarValue::Utf8(Some("chars".to_string())))], @@ -785,6 +1456,148 @@ mod tests { Int32, Int32Array ); + test_function!( + Chr, + &[lit(ScalarValue::Int64(Some(128175)))], + Ok(Some("💯")), + &str, + Utf8, + StringArray + ); + test_function!( + Chr, + &[lit(ScalarValue::Int64(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Chr, + &[lit(ScalarValue::Int64(Some(120)))], + Ok(Some("x")), + &str, + Utf8, + StringArray + ); + test_function!( + Chr, + &[lit(ScalarValue::Int64(Some(128175)))], + Ok(Some("💯")), + &str, + Utf8, + StringArray + ); + test_function!( + Chr, + &[lit(ScalarValue::Int64(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Chr, + &[lit(ScalarValue::Int64(Some(0)))], + Err(DataFusionError::Execution( + "null character not permitted.".to_string(), + )), + &str, + Utf8, + StringArray + ); + test_function!( + Chr, + &[lit(ScalarValue::Int64(Some(i64::MAX)))], + Err(DataFusionError::Execution( + "requested character too large for encoding.".to_string(), + )), + &str, + Utf8, + StringArray + ); + test_function!( + Concat, + &[ + lit(ScalarValue::Utf8(Some("aa".to_string()))), + lit(ScalarValue::Utf8(Some("bb".to_string()))), + lit(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(Some("aabbcc")), + &str, + Utf8, + StringArray + ); + test_function!( + Concat, + &[ + lit(ScalarValue::Utf8(Some("aa".to_string()))), + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(Some("aacc")), + &str, + Utf8, + StringArray + ); + test_function!( + Concat, + &[lit(ScalarValue::Utf8(None))], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + ConcatWithSeparator, + &[ + lit(ScalarValue::Utf8(Some("|".to_string()))), + lit(ScalarValue::Utf8(Some("aa".to_string()))), + lit(ScalarValue::Utf8(Some("bb".to_string()))), + lit(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(Some("aa|bb|cc")), + &str, + Utf8, + StringArray + ); + test_function!( + ConcatWithSeparator, + &[ + lit(ScalarValue::Utf8(Some("|".to_string()))), + lit(ScalarValue::Utf8(None)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + ConcatWithSeparator, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("aa".to_string()))), + lit(ScalarValue::Utf8(Some("bb".to_string()))), + lit(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + ConcatWithSeparator, + &[ + lit(ScalarValue::Utf8(Some("|".to_string()))), + lit(ScalarValue::Utf8(Some("aa".to_string()))), + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(Some("aa|cc")), + &str, + Utf8, + StringArray + ); test_function!( Exp, &[lit(ScalarValue::Int32(Some(1)))], @@ -825,42 +1638,1432 @@ mod tests { Float64, Float64Array ); + test_function!( + InitCap, + &[lit(ScalarValue::Utf8(Some("hi THOMAS".to_string())))], + Ok(Some("Hi Thomas")), + &str, + Utf8, + StringArray + ); + test_function!( + InitCap, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + InitCap, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + InitCap, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Left, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int8(Some(2))), + ], + Ok(Some("ab")), + &str, + Utf8, + StringArray + ); + test_function!( + Left, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(Some(200))), + ], + Ok(Some("abcde")), + &str, + Utf8, + StringArray + ); + test_function!( + Left, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(Some(-2))), + ], + Ok(Some("abc")), + &str, + Utf8, + StringArray + ); + test_function!( + Left, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(Some(-200))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + Left, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(Some(0))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + Left, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Int64(Some(2))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Left, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Left, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some("joséé")), + &str, + Utf8, + StringArray + ); + test_function!( + Left, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(-3))), + ], + Ok(Some("joséé")), + &str, + Utf8, + StringArray + ); + test_function!( + Lpad, + &[ + lit(ScalarValue::Utf8(Some("josé".to_string()))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some(" josé")), + &str, + Utf8, + StringArray + ); + test_function!( + Lpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some(" hi")), + &str, + Utf8, + StringArray + ); + test_function!( + Lpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(0))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + Lpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Lpad, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Lpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(Some("xy".to_string()))), + ], + Ok(Some("xyxhi")), + &str, + Utf8, + StringArray + ); + test_function!( + Lpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(21))), + lit(ScalarValue::Utf8(Some("abcdef".to_string()))), + ], + Ok(Some("abcdefabcdefabcdefahi")), + &str, + Utf8, + StringArray + ); + test_function!( + Lpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(Some(" ".to_string()))), + ], + Ok(Some(" hi")), + &str, + Utf8, + StringArray + ); + test_function!( + Lpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(Some("".to_string()))), + ], + Ok(Some("hi")), + &str, + Utf8, + StringArray + ); + test_function!( + Lpad, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(Some("xy".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Lpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(None)), + lit(ScalarValue::Utf8(Some("xy".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Lpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Lpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Utf8(Some("5".to_string()))), + lit(ScalarValue::Utf8(Some("xy".to_string()))), + ], + Ok(Some("xyxhi")), + &str, + Utf8, + StringArray + ); + test_function!( + Lpad, + &[ + lit(ScalarValue::Utf8(Some("josé".to_string()))), + lit(ScalarValue::Int64(Some(10))), + lit(ScalarValue::Utf8(Some("xy".to_string()))), + ], + Ok(Some("xyxyxyjosé")), + &str, + Utf8, + StringArray + ); + test_function!( + Lpad, + &[ + lit(ScalarValue::Utf8(Some("josé".to_string()))), + lit(ScalarValue::Int64(Some(10))), + lit(ScalarValue::Utf8(Some("éñ".to_string()))), + ], + Ok(Some("éñéñéñjosé")), + &str, + Utf8, + StringArray + ); + test_function!( + Lower, + &[lit(ScalarValue::Utf8(Some("LOWER".to_string())))], + Ok(Some("lower")), + &str, + Utf8, + StringArray + ); + test_function!( + Lower, + &[lit(ScalarValue::Utf8(Some("lower".to_string())))], + Ok(Some("lower")), + &str, + Utf8, + StringArray + ); + test_function!( + Lower, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Ltrim, + &[lit(ScalarValue::Utf8(Some(" trim".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Ltrim, + &[lit(ScalarValue::Utf8(Some(" trim ".to_string())))], + Ok(Some("trim ")), + &str, + Utf8, + StringArray + ); + test_function!( + Ltrim, + &[lit(ScalarValue::Utf8(Some("trim ".to_string())))], + Ok(Some("trim ")), + &str, + Utf8, + StringArray + ); + test_function!( + Ltrim, + &[lit(ScalarValue::Utf8(Some("trim".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Ltrim, + &[lit(ScalarValue::Utf8(Some("\n trim ".to_string())))], + Ok(Some("\n trim ")), + &str, + Utf8, + StringArray + ); + test_function!( + Ltrim, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + OctetLength, + &[lit(ScalarValue::Utf8(Some("chars".to_string())))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLength, + &[lit(ScalarValue::Utf8(Some("josé".to_string())))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLength, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLength, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + i32, + Int32, + Int32Array + ); + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("Thomas".to_string()))), + lit(ScalarValue::Utf8(Some(".[mN]a.".to_string()))), + lit(ScalarValue::Utf8(Some("M".to_string()))), + ], + Ok(Some("ThM")), + &str, + Utf8, + StringArray + ); + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("foobarbaz".to_string()))), + lit(ScalarValue::Utf8(Some("b..".to_string()))), + lit(ScalarValue::Utf8(Some("X".to_string()))), + ], + Ok(Some("fooXbaz")), + &str, + Utf8, + StringArray + ); + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("foobarbaz".to_string()))), + lit(ScalarValue::Utf8(Some("b..".to_string()))), + lit(ScalarValue::Utf8(Some("X".to_string()))), + lit(ScalarValue::Utf8(Some("g".to_string()))), + ], + Ok(Some("fooXX")), + &str, + Utf8, + StringArray + ); + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("foobarbaz".to_string()))), + lit(ScalarValue::Utf8(Some("b(..)".to_string()))), + lit(ScalarValue::Utf8(Some("X\\1Y".to_string()))), + lit(ScalarValue::Utf8(Some("g".to_string()))), + ], + Ok(Some("fooXarYXazY")), + &str, + Utf8, + StringArray + ); + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("b(..)".to_string()))), + lit(ScalarValue::Utf8(Some("X\\1Y".to_string()))), + lit(ScalarValue::Utf8(Some("g".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("foobarbaz".to_string()))), + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("X\\1Y".to_string()))), + lit(ScalarValue::Utf8(Some("g".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("foobarbaz".to_string()))), + lit(ScalarValue::Utf8(Some("b(..)".to_string()))), + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("g".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("foobarbaz".to_string()))), + lit(ScalarValue::Utf8(Some("b(..)".to_string()))), + lit(ScalarValue::Utf8(Some("X\\1Y".to_string()))), + lit(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("ABCabcABC".to_string()))), + lit(ScalarValue::Utf8(Some("(abc)".to_string()))), + lit(ScalarValue::Utf8(Some("X".to_string()))), + lit(ScalarValue::Utf8(Some("gi".to_string()))), + ], + Ok(Some("XXX")), + &str, + Utf8, + StringArray + ); + test_function!( + RegexpReplace, + &[ + lit(ScalarValue::Utf8(Some("ABCabcABC".to_string()))), + lit(ScalarValue::Utf8(Some("(abc)".to_string()))), + lit(ScalarValue::Utf8(Some("X".to_string()))), + lit(ScalarValue::Utf8(Some("i".to_string()))), + ], + Ok(Some("XabcABC")), + &str, + Utf8, + StringArray + ); + test_function!( + Repeat, + &[ + lit(ScalarValue::Utf8(Some("Pg".to_string()))), + lit(ScalarValue::Int64(Some(4))), + ], + Ok(Some("PgPgPgPg")), + &str, + Utf8, + StringArray + ); + test_function!( + Repeat, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Int64(Some(4))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Repeat, + &[ + lit(ScalarValue::Utf8(Some("Pg".to_string()))), + lit(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Replace, + &[ + lit(ScalarValue::Utf8(Some("abcdefabcdef".to_string()))), + lit(ScalarValue::Utf8(Some("cd".to_string()))), + lit(ScalarValue::Utf8(Some("XX".to_string()))), + ], + Ok(Some("abXXefabXXef")), + &str, + Utf8, + StringArray + ); + test_function!( + Replace, + &[ + lit(ScalarValue::Utf8(Some("abcdefabcdef".to_string()))), + lit(ScalarValue::Utf8(Some("notmatch".to_string()))), + lit(ScalarValue::Utf8(Some("XX".to_string()))), + ], + Ok(Some("abcdefabcdef")), + &str, + Utf8, + StringArray + ); + test_function!( + Replace, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("cd".to_string()))), + lit(ScalarValue::Utf8(Some("XX".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Replace, + &[ + lit(ScalarValue::Utf8(Some("abcdefabcdef".to_string()))), + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("XX".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Replace, + &[ + lit(ScalarValue::Utf8(Some("abcdefabcdef".to_string()))), + lit(ScalarValue::Utf8(Some("cd".to_string()))), + lit(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Reverse, + &[lit(ScalarValue::Utf8(Some("abcde".to_string())))], + Ok(Some("edcba")), + &str, + Utf8, + StringArray + ); + test_function!( + Reverse, + &[lit(ScalarValue::Utf8(Some("loẅks".to_string())))], + Ok(Some("skẅol")), + &str, + Utf8, + StringArray + ); + test_function!( + Reverse, + &[lit(ScalarValue::Utf8(Some("loẅks".to_string())))], + Ok(Some("skẅol")), + &str, + Utf8, + StringArray + ); + test_function!( + Reverse, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int8(Some(2))), + ], + Ok(Some("de")), + &str, + Utf8, + StringArray + ); + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(Some(200))), + ], + Ok(Some("abcde")), + &str, + Utf8, + StringArray + ); + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(Some(-2))), + ], + Ok(Some("cde")), + &str, + Utf8, + StringArray + ); + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(Some(-200))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(Some(0))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Int64(Some(2))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(Some("abcde".to_string()))), + lit(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some("éésoj")), + &str, + Utf8, + StringArray + ); + test_function!( + Right, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(-3))), + ], + Ok(Some("éésoj")), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("josé".to_string()))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some("josé ")), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some("hi ")), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(0))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(Some("xy".to_string()))), + ], + Ok(Some("hixyx")), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(21))), + lit(ScalarValue::Utf8(Some("abcdef".to_string()))), + ], + Ok(Some("hiabcdefabcdefabcdefa")), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(Some(" ".to_string()))), + ], + Ok(Some("hi ")), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(Some("".to_string()))), + ], + Ok(Some("hi")), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(Some("xy".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(None)), + lit(ScalarValue::Utf8(Some("xy".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("hi".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("josé".to_string()))), + lit(ScalarValue::Int64(Some(10))), + lit(ScalarValue::Utf8(Some("xy".to_string()))), + ], + Ok(Some("joséxyxyxy")), + &str, + Utf8, + StringArray + ); + test_function!( + Rpad, + &[ + lit(ScalarValue::Utf8(Some("josé".to_string()))), + lit(ScalarValue::Int64(Some(10))), + lit(ScalarValue::Utf8(Some("éñ".to_string()))), + ], + Ok(Some("josééñéñéñ")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some("trim ".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some(" trim ".to_string())))], + Ok(Some(" trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some(" trim \n".to_string())))], + Ok(Some(" trim \n")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some(" trim".to_string())))], + Ok(Some(" trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(Some("trim".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Rtrim, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Strpos, + &[ + lit(ScalarValue::Utf8(Some("abc".to_string()))), + lit(ScalarValue::Utf8(Some("c".to_string()))), + ], + Ok(Some(3)), + i32, + Int32, + Int32Array + ); + test_function!( + Strpos, + &[ + lit(ScalarValue::Utf8(Some("josé".to_string()))), + lit(ScalarValue::Utf8(Some("é".to_string()))), + ], + Ok(Some(4)), + i32, + Int32, + Int32Array + ); + test_function!( + Strpos, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Utf8(Some("so".to_string()))), + ], + Ok(Some(6)), + i32, + Int32, + Int32Array + ); + test_function!( + Strpos, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Utf8(Some("abc".to_string()))), + ], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + Strpos, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("abc".to_string()))), + ], + Ok(None), + i32, + Int32, + Int32Array + ); + test_function!( + Strpos, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Utf8(None)), + ], + Ok(None), + i32, + Int32, + Int32Array + ); + test_function!( + Trim, + &[lit(ScalarValue::Utf8(Some(" trim ".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Trim, + &[lit(ScalarValue::Utf8(Some("trim ".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Trim, + &[lit(ScalarValue::Utf8(Some(" trim".to_string())))], + Ok(Some("trim")), + &str, + Utf8, + StringArray + ); + test_function!( + Trim, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPart, + &[ + lit(ScalarValue::Utf8(Some("abc~@~def~@~ghi".to_string()))), + lit(ScalarValue::Utf8(Some("~@~".to_string()))), + lit(ScalarValue::Int64(Some(2))), + ], + Ok(Some("def")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPart, + &[ + lit(ScalarValue::Utf8(Some("abc~@~def~@~ghi".to_string()))), + lit(ScalarValue::Utf8(Some("~@~".to_string()))), + lit(ScalarValue::Int64(Some(20))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPart, + &[ + lit(ScalarValue::Utf8(Some("abc~@~def~@~ghi".to_string()))), + lit(ScalarValue::Utf8(Some("~@~".to_string()))), + lit(ScalarValue::Int64(Some(-1))), + ], + Err(DataFusionError::Execution( + "field position must be greater than zero".to_string(), + )), + &str, + Utf8, + StringArray + ); + test_function!( + StartsWith, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Utf8(Some("alph".to_string()))), + ], + Ok(Some(true)), + bool, + Boolean, + BooleanArray + ); + test_function!( + StartsWith, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Utf8(Some("blph".to_string()))), + ], + Ok(Some(false)), + bool, + Boolean, + BooleanArray + ); + test_function!( + StartsWith, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("alph".to_string()))), + ], + Ok(None), + bool, + Boolean, + BooleanArray + ); + test_function!( + StartsWith, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Utf8(None)), + ], + Ok(None), + bool, + Boolean, + BooleanArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(0))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some("ésoj")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(1))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(2))), + ], + Ok(Some("lphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(3))), + ], + Ok(Some("phabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(-3))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(30))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(3))), + lit(ScalarValue::Int64(Some(2))), + ], + Ok(Some("ph")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(3))), + lit(ScalarValue::Int64(Some(20))), + ], + Ok(Some("phabet")), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(None)), + lit(ScalarValue::Int64(Some(20))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(3))), + lit(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(1))), + lit(ScalarValue::Int64(Some(-1))), + ], + Err(DataFusionError::Execution( + "negative substring length not allowed".to_string(), + )), + &str, + Utf8, + StringArray + ); + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(5))), + lit(ScalarValue::Int64(Some(2))), + ], + Ok(Some("és")), + &str, + Utf8, + StringArray + ); + test_function!( + ToHex, + &[lit(ScalarValue::Int32(Some(2147483647)))], + Ok(Some("7fffffff")), + &str, + Utf8, + StringArray + ); + test_function!( + ToHex, + &[lit(ScalarValue::Int64(Some(9223372036854775807)))], + Ok(Some("7fffffffffffffff")), + &str, + Utf8, + StringArray + ); + test_function!( + ToHex, + &[lit(ScalarValue::Int32(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Translate, + &[ + lit(ScalarValue::Utf8(Some("12345".to_string()))), + lit(ScalarValue::Utf8(Some("143".to_string()))), + lit(ScalarValue::Utf8(Some("ax".to_string()))), + ], + Ok(Some("a2x5")), + &str, + Utf8, + StringArray + ); + test_function!( + Translate, + &[ + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("143".to_string()))), + lit(ScalarValue::Utf8(Some("ax".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Translate, + &[ + lit(ScalarValue::Utf8(Some("12345".to_string()))), + lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Utf8(Some("ax".to_string()))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Translate, + &[ + lit(ScalarValue::Utf8(Some("12345".to_string()))), + lit(ScalarValue::Utf8(Some("143".to_string()))), + lit(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + Translate, + &[ + lit(ScalarValue::Utf8(Some("é2íñ5".to_string()))), + lit(ScalarValue::Utf8(Some("éñí".to_string()))), + lit(ScalarValue::Utf8(Some("óü".to_string()))), + ], + Ok(Some("ó2ü5")), + &str, + Utf8, + StringArray + ); + test_function!( + Upper, + &[lit(ScalarValue::Utf8(Some("upper".to_string())))], + Ok(Some("UPPER")), + &str, + Utf8, + StringArray + ); + test_function!( + Upper, + &[lit(ScalarValue::Utf8(Some("UPPER".to_string())))], + Ok(Some("UPPER")), + &str, + Utf8, + StringArray + ); + test_function!( + Upper, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); Ok(()) } - fn test_concat(value: ScalarValue, expected: &str) -> Result<()> { - // any type works here: we evaluate against a literal of `value` - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; - - // concat(value, value) - let expr = create_physical_expr( - &BuiltinScalarFunction::Concat, - &[lit(value.clone()), lit(value)], - &schema, - )?; - - // type is correct - assert_eq!(expr.data_type(&schema)?, DataType::Utf8); - - // evaluate works - let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - - // downcast works - let result = result.as_any().downcast_ref::().unwrap(); - - // value is correct - assert_eq!(result.value(0).to_string(), expected); - - Ok(()) - } - - #[test] - fn test_concat_utf8() -> Result<()> { - test_concat(ScalarValue::Utf8(Some("aa".to_string())), "aaaa") - } - #[test] fn test_concat_error() -> Result<()> { let result = return_type(&BuiltinScalarFunction::Concat, &[]); diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index 81d2c67eec6..2de019747d8 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -14,9 +14,14 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// +// Some of these functions reference the Postgres documentation +// or implementation to ensure compatibility and are subject to +// the Postgres license. //! String expressions - +use std::cmp::Ordering; +use std::str::from_utf8; use std::sync::Arc; use crate::{ @@ -25,11 +30,13 @@ use crate::{ }; use arrow::{ array::{ - Array, ArrayRef, GenericStringArray, PrimitiveArray, StringArray, - StringOffsetSizeTrait, + Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, Int64Array, + PrimitiveArray, StringArray, StringOffsetSizeTrait, }, datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, }; +use hashbrown::HashMap; +use regex::Regex; use unicode_segmentation::UnicodeSegmentation; use super::ColumnarValue; @@ -46,10 +53,10 @@ pub(crate) fn unary_string_function<'a, T, O, F, R>( name: &str, ) -> Result> where - R: AsRef, - O: StringOffsetSizeTrait, T: StringOffsetSizeTrait, + O: StringOffsetSizeTrait, F: Fn(&'a str) -> R, + R: AsRef, { if args.len() != 1 { return Err(DataFusionError::Internal(format!( @@ -72,8 +79,8 @@ where fn handle<'a, F, R>(args: &'a [ColumnarValue], op: F, name: &str) -> Result where - R: AsRef, F: Fn(&'a str) -> R, + R: AsRef, { match &args[0] { ColumnarValue::Array(a) => match a.data_type() { @@ -119,6 +126,93 @@ where } } +macro_rules! downcast_vec { + ($ARGS:expr, $ARRAY_TYPE:ident) => {{ + $ARGS + .iter() + .map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() { + Some(array) => Ok(array), + _ => Err(DataFusionError::Internal("failed to downcast".to_string())), + }) + }}; +} + +/// Returns the numeric code of the first character of the argument. +/// ascii('x') = 120 +pub fn ascii(args: &[ArrayRef]) -> Result { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + // first map is the iterator, second is for the `Option<_>` + let result = string_array + .iter() + .map(|x| { + x.map(|x: &str| { + let mut chars = x.chars(); + chars.next().map_or(0, |v| v as i32) + }) + }) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Removes the longest string containing only characters in characters (a space by default) from the start and end of string. +/// btrim('xyxtrimyyx', 'xyz') = 'trim' +pub fn btrim(args: &[ArrayRef]) -> Result { + match args.len() { + 1 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .map(|x| x.map(|x: &str| x.trim_start_matches(' ').trim_end_matches(' '))) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 2 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let characters_array: &GenericStringArray = args[1] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if characters_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let chars: Vec = + characters_array.value(i).chars().collect(); + x.trim_start_matches(&chars[..]) + .trim_end_matches(&chars[..]) + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "btrim was called with {} arguments. It requires at most 2.", + other + ))), + } +} + /// Returns number of characters in the string. /// character_length('josé') = 4 pub fn character_length(args: &[ArrayRef]) -> Result @@ -140,16 +234,46 @@ where Ok(Arc::new(result) as ArrayRef) } -/// concatenate string columns together. -pub fn concatenate(args: &[ColumnarValue]) -> Result { - // downcast all arguments to strings - //let args = downcast_vec!(args, StringArray).collect::>>()?; +/// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. +/// chr(65) = 'A' +pub fn chr(args: &[ArrayRef]) -> Result { + let integer_array: &Int64Array = + args[0].as_any().downcast_ref::().unwrap(); + + // first map is the iterator, second is for the `Option<_>` + let result = integer_array + .iter() + .map(|x: Option| { + x.map(|x| { + if x == 0 { + Err(DataFusionError::Execution( + "null character not permitted.".to_string(), + )) + } else { + match core::char::from_u32(x as u32) { + Some(x) => Ok(x.to_string()), + None => Err(DataFusionError::Execution( + "requested character too large for encoding.".to_string(), + )), + } + } + }) + .transpose() + }) + .collect::>()?; + + Ok(Arc::new(result) as ArrayRef) +} + +/// Concatenates the text representations of all the arguments. NULL arguments are ignored. +/// concat('abcde', 2, NULL, 22) = 'abcde222' +pub fn concat(args: &[ColumnarValue]) -> Result { // do not accept 0 arguments. if args.is_empty() { - return Err(DataFusionError::Internal( - "Concatenate was called with 0 arguments. It requires at least one." - .to_string(), - )); + return Err(DataFusionError::Internal(format!( + "concat was called with {} arguments. It requires at least 1.", + args.len() + ))); } // first, decide whether to return a scalar or a vector. @@ -158,42 +282,30 @@ pub fn concatenate(args: &[ColumnarValue]) -> Result { _ => None, }); if let Some(size) = return_array.next() { - let iter = (0..size).map(|index| { - let mut owned_string: String = "".to_owned(); - - // if any is null, the result is null - let mut is_null = false; - for arg in args { - match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { - if let Some(value) = maybe_value { - owned_string.push_str(value); - } else { - is_null = true; - break; // short-circuit as we already know the result + let result = (0..size) + .map(|index| { + let mut owned_string: String = "".to_owned(); + for arg in args { + match arg { + ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { + if let Some(value) = maybe_value { + owned_string.push_str(value); + } } - } - ColumnarValue::Array(v) => { - if v.is_null(index) { - is_null = true; - break; // short-circuit as we already know the result - } else { - let v = v.as_any().downcast_ref::().unwrap(); - owned_string.push_str(&v.value(index)); + ColumnarValue::Array(v) => { + if v.is_valid(index) { + let v = v.as_any().downcast_ref::().unwrap(); + owned_string.push_str(&v.value(index)); + } } + _ => unreachable!(), } - _ => unreachable!(), } - } - if is_null { - None - } else { Some(owned_string) - } - }); - let array = iter.collect::(); + }) + .collect::(); - Ok(ColumnarValue::Array(Arc::new(array))) + Ok(ColumnarValue::Array(Arc::new(result))) } else { // short avenue with only scalars let initial = Some("".to_string()); @@ -203,9 +315,7 @@ pub fn concatenate(args: &[ColumnarValue]) -> Result { ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) => { inner.push_str(v); } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { - acc = None; - } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} _ => unreachable!(""), }; }; @@ -215,27 +325,1064 @@ pub fn concatenate(args: &[ColumnarValue]) -> Result { } } -/// lower +/// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored. +/// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22' +pub fn concat_ws(args: &[ArrayRef]) -> Result { + // downcast all arguments to strings + let args = downcast_vec!(args, StringArray).collect::>>()?; + + // do not accept 0 or 1 arguments. + if args.len() < 2 { + return Err(DataFusionError::Internal(format!( + "concat_ws was called with {} arguments. It requires at least 2.", + args.len() + ))); + } + + // first map is the iterator, second is for the `Option<_>` + let result = args[0] + .iter() + .enumerate() + .map(|(index, x)| { + x.map(|sep: &str| { + let mut owned_string: String = "".to_owned(); + for arg_index in 1..args.len() { + let arg = &args[arg_index]; + if !arg.is_null(index) { + owned_string.push_str(&arg.value(index)); + // if not last push separator + if arg_index != args.len() - 1 { + owned_string.push_str(&sep); + } + } + } + owned_string + }) + }) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters. +/// initcap('hi THOMAS') = 'Hi Thomas' +pub fn initcap(args: &[ArrayRef]) -> Result { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + // first map is the iterator, second is for the `Option<_>` + let result = string_array + .iter() + .map(|x| { + x.map(|x: &str| { + let mut char_vector = Vec::::new(); + let mut wasalnum = false; + for c in x.chars() { + if wasalnum { + char_vector.push(c.to_ascii_lowercase()); + } else { + char_vector.push(c.to_ascii_uppercase()); + } + wasalnum = ('A'..='Z').contains(&c) + || ('a'..='z').contains(&c) + || ('0'..='9').contains(&c); + } + char_vector.iter().collect::() + }) + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. +/// left('abcde', 2) = 'ab' +pub fn left(args: &[ArrayRef]) -> Result { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal("could not cast string to StringArray".to_string()) + })?; + + let n_array: &Int64Array = + args[1] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal("could not cast n to Int64Array".to_string()) + })?; + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if n_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let n: i64 = n_array.value(i); + match n.cmp(&0) { + Ordering::Equal => "", + Ordering::Greater => x + .grapheme_indices(true) + .nth(n as usize) + .map_or(x, |(i, _)| &from_utf8(&x.as_bytes()[..i]).unwrap()), + Ordering::Less => x + .grapheme_indices(true) + .rev() + .nth(n.abs() as usize - 1) + .map_or("", |(i, _)| &from_utf8(&x.as_bytes()[..i]).unwrap()), + } + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Converts the string to all lower case. +/// length('jose') = 4 pub fn lower(args: &[ColumnarValue]) -> Result { handle(args, |x| x.to_ascii_lowercase(), "lower") } -/// upper -pub fn upper(args: &[ColumnarValue]) -> Result { - handle(args, |x| x.to_ascii_uppercase(), "upper") +/// Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right). +/// lpad('hi', 5, 'xy') = 'xyxhi' +pub fn lpad(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let length_array: &Int64Array = args[1] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast length to Int64Array".to_string(), + ) + })?; + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if length_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let length = length_array.value(i) as usize; + if length == 0 { + "".to_string() + } else { + let graphemes = x.graphemes(true).collect::>(); + if length < graphemes.len() { + graphemes[..length].concat() + } else { + let mut s = x.to_string(); + s.insert_str( + 0, + " ".repeat(length - graphemes.len()).as_str(), + ); + s + } + } + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 3 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let length_array: &Int64Array = + args[1].as_any().downcast_ref::().unwrap(); + + let fill_array: &GenericStringArray = args[2] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if length_array.is_null(i) || fill_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let length = length_array.value(i) as usize; + + if length == 0 { + "".to_string() + } else { + let graphemes = x.graphemes(true).collect::>(); + let fill_chars = + fill_array.value(i).chars().collect::>(); + + if length < graphemes.len() { + graphemes[..length].concat() + } else if fill_chars.is_empty() { + x.to_string() + } else { + let mut s = x.to_string(); + let mut char_vector = Vec::::with_capacity( + length - graphemes.len(), + ); + for l in 0..length - graphemes.len() { + char_vector.push( + *fill_chars + .get(l % fill_chars.len()) + .unwrap(), + ); + } + s.insert_str( + 0, + char_vector.iter().collect::().as_str(), + ); + s + } + } + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "lpad was called with {} arguments. It requires at least 2 and at most 3.", + other + ))), + } } -/// trim -pub fn trim(args: &[ColumnarValue]) -> Result { - handle(args, |x: &str| x.trim(), "trim") +/// Removes the longest string containing only characters in characters (a space by default) from the start of string. +/// ltrim('zzzytest', 'xyz') = 'test' +pub fn ltrim(args: &[ArrayRef]) -> Result { + match args.len() { + 1 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .map(|x| x.map(|x: &str| x.trim_start_matches(' '))) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 2 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let characters_array: &GenericStringArray = args[1] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if characters_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let chars: Vec = + characters_array.value(i).chars().collect(); + x.trim_start_matches(&chars[..]) + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "ltrim was called with {} arguments. It requires at most 2.", + other + ))), + } } -/// ltrim -pub fn ltrim(args: &[ColumnarValue]) -> Result { - handle(args, |x| x.trim_start(), "ltrim") +/// replace POSIX capture groups (like \1) with Rust Regex group (like ${1}) +/// used by regexp_replace +fn regex_replace_posix_groups(replacement: &str) -> String { + lazy_static! { + static ref CAPTURE_GROUPS_RE: Regex = Regex::new("(\\\\)(\\d*)").unwrap(); + } + CAPTURE_GROUPS_RE + .replace_all(replacement, "$${$2}") + .into_owned() } -/// rtrim -pub fn rtrim(args: &[ColumnarValue]) -> Result { - handle(args, |x| x.trim_end(), "rtrim") +/// Replaces substring(s) matching a POSIX regular expression +/// regexp_replace('Thomas', '.[mN]a.', 'M') = 'ThM' +pub fn regexp_replace(args: &[ArrayRef]) -> Result { + // creating Regex is expensive so create hashmap for memoization + let mut patterns: HashMap = HashMap::new(); + + match args.len() { + 3 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let pattern_array: &StringArray = args[1] + .as_any() + .downcast_ref::() + .unwrap(); + + let replacement_array: &StringArray = args[2] + .as_any() + .downcast_ref::() + .unwrap(); + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if pattern_array.is_null(i) || replacement_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let pattern = pattern_array.value(i).to_string(); + let replacement = regex_replace_posix_groups(replacement_array.value(i)); + let re = match patterns.get(pattern_array.value(i)) { + Some(re) => Ok(re.clone()), + None => { + match Regex::new(pattern.as_str()) { + Ok(re) => { + patterns.insert(pattern, re.clone()); + Ok(re) + }, + Err(err) => Err(DataFusionError::Execution(err.to_string())), + } + } + }; + re.map(|re| re.replace(x, replacement.as_str())) + }) + }.transpose() + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + 4 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let pattern_array: &StringArray = args[1] + .as_any() + .downcast_ref::() + .unwrap(); + + let replacement_array: &StringArray = args[2] + .as_any() + .downcast_ref::() + .unwrap(); + + let flags_array: &StringArray = args[3] + .as_any() + .downcast_ref::() + .unwrap(); + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if pattern_array.is_null(i) || replacement_array.is_null(i) || flags_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let replacement = regex_replace_posix_groups(replacement_array.value(i)); + + let flags = flags_array.value(i); + let (pattern, replace_all) = if flags == "g" { + (pattern_array.value(i).to_string(), true) + } else if flags.contains('g') { + (format!("(?{}){}", flags.to_string().replace("g", ""), pattern_array.value(i)), true) + } else { + (format!("(?{}){}", flags, pattern_array.value(i)), false) + }; + + let re = match patterns.get(pattern_array.value(i)) { + Some(re) => Ok(re.clone()), + None => { + match Regex::new(pattern.as_str()) { + Ok(re) => { + patterns.insert(pattern, re.clone()); + Ok(re) + }, + Err(err) => Err(DataFusionError::Execution(err.to_string())), + } + } + }; + + re.map(|re| { + if replace_all { + re.replace_all(x, replacement.as_str()) + } else { + re.replace(x, replacement.as_str()) + } + }) + }) + }.transpose() + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "regexp_replace was called with {} arguments. It requires at least 3 and at most 4.", + other + ))), + } +} + +/// Repeats string the specified number of times. +/// repeat('Pg', 4) = 'PgPgPgPg' +pub fn repeat(args: &[ArrayRef]) -> Result { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let number_array: &Int64Array = + args[1].as_any().downcast_ref::().unwrap(); + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if number_array.is_null(i) { + None + } else { + x.map(|x: &str| x.repeat(number_array.value(i) as usize)) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Replaces all occurrences in string of substring from with substring to. +/// replace('abcdefabcdef', 'cd', 'XX') = 'abXXefabXXef' +pub fn replace(args: &[ArrayRef]) -> Result { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let from_array: &GenericStringArray = args[1] + .as_any() + .downcast_ref::>() + .unwrap(); + + let to_array: &GenericStringArray = args[2] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if from_array.is_null(i) || to_array.is_null(i) { + None + } else { + x.map(|x: &str| x.replace(from_array.value(i), to_array.value(i))) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Reverses the order of the characters in the string. +/// reverse('abcde') = 'edcba' +pub fn reverse(args: &[ArrayRef]) -> Result { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + // first map is the iterator, second is for the `Option<_>` + let result = string_array + .iter() + .map(|x| x.map(|x: &str| x.graphemes(true).rev().collect::())) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Returns last n characters in the string, or when n is negative, returns all but first |n| characters. +/// right('abcde', 2) = 'de' +pub fn right(args: &[ArrayRef]) -> Result { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal("could not cast string to StringArray".to_string()) + })?; + + let n_array: &Int64Array = + args[1] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal("could not cast n to Int64Array".to_string()) + })?; + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if n_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let n: i64 = n_array.value(i); + match n.cmp(&0) { + Ordering::Equal => "", + Ordering::Greater => x + .grapheme_indices(true) + .rev() + .nth(n as usize - 1) + .map_or(x, |(i, _)| &from_utf8(&x.as_bytes()[i..]).unwrap()), + Ordering::Less => x + .grapheme_indices(true) + .nth(n.abs() as usize) + .map_or("", |(i, _)| &from_utf8(&x.as_bytes()[i..]).unwrap()), + } + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. +/// rpad('hi', 5, 'xy') = 'hixyx' +pub fn rpad(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let length_array: &Int64Array = args[1] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast length to Int64Array".to_string(), + ) + })?; + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if length_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let length = length_array.value(i) as usize; + if length == 0 { + "".to_string() + } else { + let graphemes = x.graphemes(true).collect::>(); + if length < graphemes.len() { + graphemes[..length].concat() + } else { + let mut s = x.to_string(); + s.push_str( + " ".repeat(length - graphemes.len()).as_str(), + ); + s + } + } + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 3 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let length_array: &Int64Array = + args[1].as_any().downcast_ref::().unwrap(); + + let fill_array: &GenericStringArray = args[2] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if length_array.is_null(i) || fill_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let length = length_array.value(i) as usize; + + if length == 0 { + "".to_string() + } else { + let graphemes = x.graphemes(true).collect::>(); + let fill_chars = + fill_array.value(i).chars().collect::>(); + + if length < graphemes.len() { + graphemes[..length].concat() + } else if fill_chars.is_empty() { + x.to_string() + } else { + let mut s = x.to_string(); + let mut char_vector = Vec::::with_capacity( + length - graphemes.len(), + ); + for l in 0..length - graphemes.len() { + char_vector.push( + *fill_chars + .get(l % fill_chars.len()) + .unwrap(), + ); + } + s.push_str( + char_vector.iter().collect::().as_str(), + ); + s + } + } + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "rpad was called with {} arguments. It requires at least 2 and at most 3.", + other + ))), + } +} + +/// Removes the longest string containing only characters in characters (a space by default) from the end of string. +/// rtrim('testxxzx', 'xyz') = 'test' +pub fn rtrim(args: &[ArrayRef]) -> Result { + match args.len() { + 1 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .map(|x| x.map(|x: &str| x.trim_end_matches(' '))) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 2 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let characters_array: &GenericStringArray = args[1] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if characters_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let chars: Vec = + characters_array.value(i).chars().collect(); + x.trim_end_matches(&chars[..]) + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "rtrim was called with {} arguments. It requires at most two.", + other + ))), + } +} + +/// Splits string at occurrences of delimiter and returns the n'th field (counting from one). +/// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def' +pub fn split_part(args: &[ArrayRef]) -> Result { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let delimiter_array: &GenericStringArray = args[1] + .as_any() + .downcast_ref::>() + .unwrap(); + + let n_array: &Int64Array = args[2].as_any().downcast_ref::().unwrap(); + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if delimiter_array.is_null(i) || n_array.is_null(i) { + Ok(None) + } else { + x.map(|x: &str| { + let delimiter = delimiter_array.value(i); + let n = n_array.value(i); + if n <= 0 { + Err(DataFusionError::Execution( + "field position must be greater than zero".to_string(), + )) + } else { + let v: Vec<&str> = x.split(delimiter).collect(); + match v.get(n as usize - 1) { + Some(s) => Ok(*s), + None => Ok(""), + } + } + }) + .transpose() + } + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) +} + +/// Returns true if string starts with prefix. +/// starts_with('alphabet', 'alph') = 't' +pub fn starts_with(args: &[ArrayRef]) -> Result { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let prefix_array: &GenericStringArray = args[1] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if prefix_array.is_null(i) { + None + } else { + x.map(|x: &str| x.starts_with(prefix_array.value(i))) + } + }) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.) +/// strpos('high', 'ig') = 2 +pub fn strpos(args: &[ArrayRef]) -> Result +where + T::Native: StringOffsetSizeTrait, +{ + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal("could not cast string to StringArray".to_string()) + })?; + + let substring_array: &GenericStringArray = args[1] + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast substring to StringArray".to_string(), + ) + })?; + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if substring_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let substring: &str = substring_array.value(i); + // the rfind method returns the byte index which may or may not be the same as the character index due to UTF8 encoding + // this method first finds the matching byte using rfind + // then maps that to the character index by matching on the grapheme_index of the byte_index + T::Native::from_usize(x.to_string().rfind(substring).map_or( + 0, + |byte_offset| { + x.grapheme_indices(true) + .collect::>() + .iter() + .enumerate() + .filter(|(_, (offset, _))| *offset == byte_offset) + .map(|(index, _)| index) + .collect::>() + .first() + .unwrap() + .to_owned() + + 1 + }, + )) + .unwrap() + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) +/// substr('alphabet', 3) = 'phabet' +/// substr('alphabet', 3, 2) = 'ph' +pub fn substr(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast string to StringArray".to_string(), + ) + })?; + + let start_array: &Int64Array = args[1] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast start to Int64Array".to_string(), + ) + })?; + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if start_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let start: i64 = start_array.value(i); + + if start <= 0 { + x.to_string() + } else { + let graphemes = x.graphemes(true).collect::>(); + let start_pos = start as usize - 1; + if graphemes.len() < start_pos { + "".to_string() + } else { + graphemes[start_pos..].concat() + } + } + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 3 => { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast string to StringArray".to_string(), + ) + })?; + + let start_array: &Int64Array = args[1] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast start to Int64Array".to_string(), + ) + })?; + + let count_array: &Int64Array = args[2] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast count to Int64Array".to_string(), + ) + })?; + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if start_array.is_null(i) || count_array.is_null(i) { + Ok(None) + } else { + x.map(|x: &str| { + let start: i64 = start_array.value(i); + let count = count_array.value(i); + + if count < 0 { + Err(DataFusionError::Execution( + "negative substring length not allowed".to_string(), + )) + } else if start <= 0 { + Ok(x.to_string()) + } else { + let graphemes = x.graphemes(true).collect::>(); + let start_pos = start as usize - 1; + let count_usize = count as usize; + if graphemes.len() < start_pos { + Ok("".to_string()) + } else if graphemes.len() < start_pos + count_usize { + Ok(graphemes[start_pos..].concat()) + } else { + Ok(graphemes[start_pos..start_pos + count_usize] + .concat()) + } + } + }) + .transpose() + } + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + other => Err(DataFusionError::Internal(format!( + "substr was called with {} arguments. It requires 2 or 3.", + other + ))), + } +} + +/// Converts the number to its equivalent hexadecimal representation. +/// to_hex(2147483647) = '7fffffff' +pub fn to_hex(args: &[ArrayRef]) -> Result +where + T::Native: StringOffsetSizeTrait, +{ + let integer_array: &PrimitiveArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + // first map is the iterator, second is for the `Option<_>` + let result = integer_array + .iter() + .map(|x| x.map(|x| format!("{:x}", x.to_usize().unwrap()))) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted. +/// translate('12345', '143', 'ax') = 'a2x5' +pub fn translate(args: &[ArrayRef]) -> Result { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let from_array: &GenericStringArray = args[1] + .as_any() + .downcast_ref::>() + .unwrap(); + + let to_array: &GenericStringArray = args[2] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .enumerate() + .map(|(i, x)| { + if from_array.is_null(i) || to_array.is_null(i) { + None + } else { + x.map(|x: &str| { + let from = from_array.value(i).graphemes(true).collect::>(); + // create a hashmap to change from O(n) to O(1) from lookup + let from_map: HashMap<&str, usize> = from + .iter() + .enumerate() + .map(|(index, c)| (c.to_owned(), index)) + .collect(); + + let to = to_array.value(i).graphemes(true).collect::>(); + + x.graphemes(true) + .collect::>() + .iter() + .flat_map(|c| match from_map.get(*c) { + Some(n) => to.get(*n).copied(), + None => Some(*c), + }) + .collect::>() + .concat() + }) + } + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Converts the string to all upper case. +/// upper('tom') = 'TOM' +pub fn upper(args: &[ColumnarValue]) -> Result { + handle(args, |x| x.to_ascii_uppercase(), "upper") } diff --git a/rust/datafusion/src/physical_plan/type_coercion.rs b/rust/datafusion/src/physical_plan/type_coercion.rs index ae920cb870f..091133b3e35 100644 --- a/rust/datafusion/src/physical_plan/type_coercion.rs +++ b/rust/datafusion/src/physical_plan/type_coercion.rs @@ -95,22 +95,6 @@ fn get_valid_types( current_types: &[DataType], ) -> Result>> { let valid_types = match signature { - Signature::Variadic(valid_types) => valid_types - .iter() - .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect()) - .collect(), - Signature::Uniform(number, valid_types) => valid_types - .iter() - .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) - .collect(), - Signature::VariadicEqual => { - // one entry with the same len as current_types, whose type is `current_types[0]`. - vec![current_types - .iter() - .map(|_| current_types[0].clone()) - .collect()] - } - Signature::Exact(valid_types) => vec![valid_types.clone()], Signature::Any(number) => { if current_types.len() != *number { return Err(DataFusionError::Plan(format!( @@ -121,6 +105,7 @@ fn get_valid_types( } vec![(0..*number).map(|i| current_types[i].clone()).collect()] } + Signature::Exact(valid_types) => vec![valid_types.clone()], Signature::OneOf(types) => { let mut r = vec![]; for s in types { @@ -128,6 +113,21 @@ fn get_valid_types( } r } + Signature::Uniform(number, valid_types) => valid_types + .iter() + .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) + .collect(), + Signature::Variadic(valid_types) => valid_types + .iter() + .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect()) + .collect(), + Signature::VariadicEqual => { + // one entry with the same len as current_types, whose type is `current_types[0]`. + vec![current_types + .iter() + .map(|_| current_types[0].clone()) + .collect()] + } }; Ok(valid_types) @@ -168,20 +168,35 @@ fn maybe_data_types( pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { use self::DataType::*; match type_into { - Int8 => matches!(type_from, Int8), - Int16 => matches!(type_from, Int8 | Int16 | UInt8), - Int32 => matches!(type_from, Int8 | Int16 | Int32 | UInt8 | UInt16), + Int8 => matches!(type_from, Int8 | Utf8 | LargeUtf8), + Int16 => matches!(type_from, Int8 | Int16 | UInt8 | Utf8 | LargeUtf8), + Int32 => matches!( + type_from, + Int8 | Int16 | Int32 | UInt8 | UInt16 | Utf8 | LargeUtf8 + ), Int64 => matches!( type_from, - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | Utf8 | LargeUtf8 + ), + UInt8 => matches!(type_from, UInt8 | Utf8 | LargeUtf8), + UInt16 => matches!(type_from, UInt8 | UInt16 | Utf8 | LargeUtf8), + UInt32 => matches!(type_from, UInt8 | UInt16 | UInt32 | Utf8 | LargeUtf8), + UInt64 => matches!( + type_from, + UInt8 | UInt16 | UInt32 | UInt64 | Utf8 | LargeUtf8 ), - UInt8 => matches!(type_from, UInt8), - UInt16 => matches!(type_from, UInt8 | UInt16), - UInt32 => matches!(type_from, UInt8 | UInt16 | UInt32), - UInt64 => matches!(type_from, UInt8 | UInt16 | UInt32 | UInt64), Float32 => matches!( type_from, - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 | Float32 + Int8 | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float32 + | Utf8 + | LargeUtf8 ), Float64 => matches!( type_from, @@ -194,9 +209,12 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { | UInt64 | Float32 | Float64 + | Utf8 + | LargeUtf8 ), Timestamp(TimeUnit::Nanosecond, None) => matches!(type_from, Timestamp(_, None)), Utf8 => true, + LargeUtf8 => true, _ => false, } } @@ -280,17 +298,35 @@ mod tests { Signature::Uniform(1, vec![DataType::UInt32]), vec![DataType::UInt32], )?, + case( + vec![DataType::UInt16], + Signature::OneOf(vec![Signature::Exact(vec![DataType::UInt32])]), + vec![DataType::UInt32], + )?, // same type case( vec![DataType::UInt32, DataType::UInt32], Signature::Uniform(2, vec![DataType::UInt32]), vec![DataType::UInt32, DataType::UInt32], )?, + case( + vec![DataType::UInt32, DataType::UInt32], + Signature::OneOf(vec![Signature::Exact(vec![ + DataType::UInt32, + DataType::UInt32, + ])]), + vec![DataType::UInt32, DataType::UInt32], + )?, case( vec![DataType::UInt32], Signature::Uniform(1, vec![DataType::Float32, DataType::Float64]), vec![DataType::Float32], )?, + case( + vec![DataType::UInt32], + Signature::OneOf(vec![Signature::Exact(vec![DataType::Float32])]), + vec![DataType::Float32], + )?, // u32 -> f32 case( vec![DataType::UInt32, DataType::UInt32], @@ -328,7 +364,7 @@ mod tests { // we do not know how to cast bool to UInt16 => fail case( vec![DataType::Boolean], - Signature::Uniform(1, vec![DataType::UInt16]), + Signature::OneOf(vec![Signature::Exact(vec![DataType::UInt16])]), vec![], )?, // u32 and bool are not uniform diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs index 26e03c7453e..6071c1d82b6 100644 --- a/rust/datafusion/src/prelude.rs +++ b/rust/datafusion/src/prelude.rs @@ -28,8 +28,12 @@ pub use crate::dataframe::DataFrame; pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::logical_plan::{ - array, avg, bit_length, character_length, col, concat, count, create_udf, in_list, - length, lit, lower, ltrim, max, md5, min, octet_length, rtrim, sha224, sha256, - sha384, sha512, sum, trim, upper, JoinType, Partitioning, + abs, acos, and, array, ascii, asin, atan, avg, binary_expr, bit_length, btrim, case, + ceil, character_length, chr, col, concat, concat_ws, cos, count, count_distinct, + create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, initcap, left, lit, + ln, log10, log2, lower, lpad, ltrim, max, md5, min, octet_length, or, regexp_replace, + repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, + signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, translate, + trim, trunc, upper, when, Expr, JoinType, Literal, Partitioning, }; pub use crate::physical_plan::csv::CsvReadOptions; diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 2f780b662b8..f16252a7080 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -530,17 +530,6 @@ async fn sqrt_f32_vs_f64() -> Result<()> { Ok(()) } -#[tokio::test] -async fn csv_query_error() -> Result<()> { - // sin(utf8) should error - let mut ctx = create_ctx()?; - register_aggregate_csv(&mut ctx)?; - let sql = "SELECT sin(c1) FROM aggregate_test_100"; - let plan = ctx.create_logical_plan(&sql); - assert!(plan.is_err()); - Ok(()) -} - // this query used to deadlock due to the call udf(udf()) #[tokio::test] async fn csv_query_sqrt_sqrt() -> Result<()> { @@ -1601,7 +1590,7 @@ async fn query_concat() -> Result<()> { let expected = vec![ vec!["-hi-0"], vec!["a-hi-1"], - vec!["NULL"], + vec!["aa-hi-"], vec!["aaa-hi-3"], ]; assert_eq!(expected, actual); @@ -1848,7 +1837,7 @@ async fn query_on_string_dictionary() -> Result<()> { // Expression evaluation let sql = "SELECT concat(d1, '-foo') FROM test"; let actual = execute(&mut ctx, sql).await; - let expected = vec![vec!["one-foo"], vec!["NULL"], vec!["three-foo"]]; + let expected = vec![vec!["one-foo"], vec!["-foo"], vec!["three-foo"]]; assert_eq!(expected, actual); // aggregation @@ -1985,175 +1974,395 @@ async fn csv_group_by_date() -> Result<()> { Ok(()) } -#[tokio::test] -async fn string_expressions() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT - char_length('tom') AS char_length - ,char_length(NULL) AS char_length_null - ,character_length('tom') AS character_length - ,character_length(NULL) AS character_length_null - ,lower('TOM') AS lower - ,lower(NULL) AS lower_null - ,upper('tom') AS upper - ,upper(NULL) AS upper_null - ,trim(' tom ') AS trim - ,trim(NULL) AS trim_null - ,ltrim(' tom ') AS trim_left - ,rtrim(' tom ') AS trim_right - "; - let actual = execute(&mut ctx, sql).await; - - let expected = vec![vec![ - "3", "NULL", "3", "NULL", "tom", "NULL", "TOM", "NULL", "tom", "NULL", "tom ", - " tom", - ]]; - assert_eq!(expected, actual); - Ok(()) -} - -#[tokio::test] -async fn boolean_expressions() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT - true AS val_1, - false AS val_2 - "; - let actual = execute(&mut ctx, sql).await; - - let expected = vec![vec!["true", "false"]]; - assert_eq!(expected, actual); - Ok(()) -} - -#[tokio::test] -async fn interval_expressions() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT - (interval '1') as interval_1, - (interval '1 second') as interval_2, - (interval '500 milliseconds') as interval_3, - (interval '5 second') as interval_4, - (interval '1 minute') as interval_5, - (interval '0.5 minute') as interval_6, - (interval '.5 minute') as interval_7, - (interval '5 minute') as interval_8, - (interval '5 minute 1 second') as interval_9, - (interval '1 hour') as interval_10, - (interval '5 hour') as interval_11, - (interval '1 day') as interval_12, - (interval '1 day 1') as interval_13, - (interval '0.5') as interval_14, - (interval '0.5 day 1') as interval_15, - (interval '0.49 day') as interval_16, - (interval '0.499 day') as interval_17, - (interval '0.4999 day') as interval_18, - (interval '0.49999 day') as interval_19, - (interval '0.49999999999 day') as interval_20, - (interval '5 day') as interval_21, - (interval '5 day 4 hours 3 minutes 2 seconds 100 milliseconds') as interval_22, - (interval '0.5 month') as interval_23, - (interval '1 month') as interval_24, - (interval '5 month') as interval_25, - (interval '13 month') as interval_26, - (interval '0.5 year') as interval_27, - (interval '1 year') as interval_28, - (interval '2 year') as interval_29 - "; - let actual = execute(&mut ctx, sql).await; - - let expected = vec![vec![ - "0 years 0 mons 0 days 0 hours 0 mins 1.00 secs", - "0 years 0 mons 0 days 0 hours 0 mins 1.00 secs", - "0 years 0 mons 0 days 0 hours 0 mins 0.500 secs", - "0 years 0 mons 0 days 0 hours 0 mins 5.00 secs", - "0 years 0 mons 0 days 0 hours 1 mins 0.00 secs", - "0 years 0 mons 0 days 0 hours 0 mins 30.00 secs", - "0 years 0 mons 0 days 0 hours 0 mins 30.00 secs", - "0 years 0 mons 0 days 0 hours 5 mins 0.00 secs", - "0 years 0 mons 0 days 0 hours 5 mins 1.00 secs", - "0 years 0 mons 0 days 1 hours 0 mins 0.00 secs", - "0 years 0 mons 0 days 5 hours 0 mins 0.00 secs", - "0 years 0 mons 1 days 0 hours 0 mins 0.00 secs", - "0 years 0 mons 1 days 0 hours 0 mins 1.00 secs", - "0 years 0 mons 0 days 0 hours 0 mins 0.500 secs", - "0 years 0 mons 0 days 12 hours 0 mins 1.00 secs", - "0 years 0 mons 0 days 11 hours 45 mins 36.00 secs", - "0 years 0 mons 0 days 11 hours 58 mins 33.596 secs", - "0 years 0 mons 0 days 11 hours 59 mins 51.364 secs", - "0 years 0 mons 0 days 11 hours 59 mins 59.136 secs", - "0 years 0 mons 0 days 12 hours 0 mins 0.00 secs", - "0 years 0 mons 5 days 0 hours 0 mins 0.00 secs", - "0 years 0 mons 5 days 4 hours 3 mins 2.100 secs", - "0 years 0 mons 15 days 0 hours 0 mins 0.00 secs", - "0 years 1 mons 0 days 0 hours 0 mins 0.00 secs", - "0 years 5 mons 0 days 0 hours 0 mins 0.00 secs", - "1 years 1 mons 0 days 0 hours 0 mins 0.00 secs", - "0 years 6 mons 0 days 0 hours 0 mins 0.00 secs", - "1 years 0 mons 0 days 0 hours 0 mins 0.00 secs", - "2 years 0 mons 0 days 0 hours 0 mins 0.00 secs", - ]]; - assert_eq!(expected, actual); - Ok(()) -} - -#[tokio::test] -async fn crypto_expressions() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT - md5('tom') AS md5_tom, - md5('') AS md5_empty_str, - md5(null) AS md5_null, - sha224('tom') AS sha224_tom, - sha224('') AS sha224_empty_str, - sha224(null) AS sha224_null, - sha256('tom') AS sha256_tom, - sha256('') AS sha256_empty_str, - sha384('tom') AS sha348_tom, - sha384('') AS sha384_empty_str, - sha512('tom') AS sha512_tom, - sha512('') AS sha512_empty_str - "; - let actual = execute(&mut ctx, sql).await; - - let expected = vec![vec![ - "34b7da764b21d298ef307d04d8152dc5", - "d41d8cd98f00b204e9800998ecf8427e", - "NULL", - "0bf6cb62649c42a9ae3876ab6f6d92ad36cb5414e495f8873292be4d", - "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f", - "NULL", - "e1608f75c5d7813f3d4031cb30bfb786507d98137538ff8e128a6ff74e84e643", - "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", - "096f5b68aa77848e4fdf5c1c0b350de2dbfad60ffd7c25d9ea07c6c19b8a4d55a9187eb117c557883f58c16dfac3e343", - "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b", - "6e1b9b3fe840680e37051f7ad5e959d6f39ad0f8885d855166f55c659469d3c8b78118c44a2a49c72ddb481cd6d8731034e11cc030070ba843a90b3495cb8d3e", - "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e" - ]]; - assert_eq!(expected, actual); +macro_rules! test_expression { + ($SQL:expr, $EXPECTED:expr) => { + let mut ctx = ExecutionContext::new(); + println!("test_expression: {}", $SQL); + let sql = format!("SELECT {}", $SQL); + let actual = execute(&mut ctx, sql.as_str()).await; + assert_eq!($EXPECTED, actual[0][0]); + }; +} + +#[tokio::test] +async fn test_string_expressions() -> Result<()> { + test_expression!("ascii('')", "0"); + test_expression!("ascii('x')", "120"); + test_expression!("ascii(NULL)", "NULL"); + test_expression!("bit_length('')", "0"); + test_expression!("bit_length('chars')", "40"); + test_expression!("bit_length('josé')", "40"); + test_expression!("bit_length(NULL)", "NULL"); + test_expression!("btrim(' xyxtrimyyx ', NULL)", "NULL"); + test_expression!("btrim(' xyxtrimyyx ')", "xyxtrimyyx"); + test_expression!("btrim('\n xyxtrimyyx \n')", "\n xyxtrimyyx \n"); + test_expression!("btrim('xyxtrimyyx', 'xyz')", "trim"); + test_expression!("btrim('\nxyxtrimyyx\n', 'xyz\n')", "trim"); + test_expression!("btrim(NULL, 'xyz')", "NULL"); + test_expression!("char_length('')", "0"); + test_expression!("char_length('chars')", "5"); + test_expression!("char_length(NULL)", "NULL"); + test_expression!("character_length('')", "0"); + test_expression!("character_length('chars')", "5"); + test_expression!("character_length('josé')", "4"); + test_expression!("character_length(NULL)", "NULL"); + test_expression!("chr(CAST(120 AS int))", "x"); + test_expression!("chr(CAST(128175 AS int))", "💯"); + test_expression!("chr(CAST(NULL AS int))", "NULL"); + test_expression!("concat_ws(',', 'abcde', 2, NULL, 22)", "abcde,2,22"); + test_expression!("concat_ws('|','a','b','c')", "a|b|c"); + test_expression!("concat_ws('|',NULL)", ""); + test_expression!("concat_ws(NULL,'a',NULL,'b','c')", "NULL"); + test_expression!("concat('a','b','c')", "abc"); + test_expression!("concat('abcde', 2, NULL, 22)", "abcde222"); + test_expression!("concat(NULL)", ""); + test_expression!("initcap('')", ""); + test_expression!("initcap('hi THOMAS')", "Hi Thomas"); + test_expression!("initcap(NULL)", "NULL"); + test_expression!("left('abcde', -2)", "abc"); + test_expression!("left('abcde', -200)", ""); + test_expression!("left('abcde', 0)", ""); + test_expression!("left('abcde', 2)", "ab"); + test_expression!("left('abcde', 200)", "abcde"); + test_expression!("left('abcde', CAST(NULL AS INT))", "NULL"); + test_expression!("left(NULL, 2)", "NULL"); + test_expression!("left(NULL, CAST(NULL AS INT))", "NULL"); + test_expression!("length('')", "0"); + test_expression!("length('chars')", "5"); + test_expression!("length(NULL)", "NULL"); + test_expression!("lower('')", ""); + test_expression!("lower('TOM')", "tom"); + test_expression!("lower(NULL)", "NULL"); + test_expression!("lpad('hi', '5', 'xy')", "xyxhi"); + test_expression!("lpad('hi', 0)", ""); + test_expression!("lpad('hi', 21, 'abcdef')", "abcdefabcdefabcdefahi"); + test_expression!("lpad('hi', 5, 'xy')", "xyxhi"); + test_expression!("lpad('hi', 5, NULL)", "NULL"); + test_expression!("lpad('hi', 5)", " hi"); + test_expression!("lpad('hi', NULL, 'xy')", "NULL"); + test_expression!("lpad('hi', NULL)", "NULL"); + test_expression!("lpad('xyxhi', 3)", "xyx"); + test_expression!("lpad(NULL, 0)", "NULL"); + test_expression!("lpad(NULL, 5, 'xy')", "NULL"); + test_expression!("ltrim(' zzzytest ', NULL)", "NULL"); + test_expression!("ltrim(' zzzytest ')", "zzzytest "); + test_expression!("ltrim('zzzytest', 'xyz')", "test"); + test_expression!("ltrim(NULL, 'xyz')", "NULL"); + test_expression!("octet_length('')", "0"); + test_expression!("octet_length('chars')", "5"); + test_expression!("octet_length('josé')", "5"); + test_expression!("octet_length(NULL)", "NULL"); + test_expression!("regexp_replace('ABCabcABC', '(abc)', 'X', 'gi')", "XXX"); + test_expression!("regexp_replace('ABCabcABC', '(abc)', 'X', 'i')", "XabcABC"); + test_expression!("regexp_replace('foobarbaz', 'b..', 'X', 'g')", "fooXX"); + test_expression!("regexp_replace('foobarbaz', 'b..', 'X')", "fooXbaz"); + test_expression!( + "regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g')", + "fooXarYXazY" + ); + test_expression!( + "regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', NULL)", + "NULL" + ); + test_expression!("regexp_replace('foobarbaz', 'b(..)', NULL, 'g')", "NULL"); + test_expression!("regexp_replace('foobarbaz', NULL, 'X\\1Y', 'g')", "NULL"); + test_expression!("regexp_replace('Thomas', '.[mN]a.', 'M')", "ThM"); + test_expression!("regexp_replace(NULL, 'b(..)', 'X\\1Y', 'g')", "NULL"); + test_expression!("repeat('Pg', 4)", "PgPgPgPg"); + test_expression!("repeat('Pg', NULL)", "NULL"); + test_expression!("repeat(NULL, 4)", "NULL"); + test_expression!("replace('abcdefabcdef', 'cd', 'XX')", "abXXefabXXef"); + test_expression!("replace('abcdefabcdef', 'cd', NULL)", "NULL"); + test_expression!("replace('abcdefabcdef', 'notmatch', 'XX')", "abcdefabcdef"); + test_expression!("replace('abcdefabcdef', NULL, 'XX')", "NULL"); + test_expression!("replace(NULL, 'cd', 'XX')", "NULL"); + test_expression!("reverse('abcde')", "edcba"); + test_expression!("reverse('loẅks')", "skẅol"); + test_expression!("reverse(NULL)", "NULL"); + test_expression!("right('abcde', -2)", "cde"); + test_expression!("right('abcde', -200)", ""); + test_expression!("right('abcde', 0)", ""); + test_expression!("right('abcde', 2)", "de"); + test_expression!("right('abcde', 200)", "abcde"); + test_expression!("right('abcde', CAST(NULL AS INT))", "NULL"); + test_expression!("right(NULL, 2)", "NULL"); + test_expression!("right(NULL, CAST(NULL AS INT))", "NULL"); + test_expression!("rpad('hi', '5', 'xy')", "hixyx"); + test_expression!("rpad('hi', 0)", ""); + test_expression!("rpad('hi', 21, 'abcdef')", "hiabcdefabcdefabcdefa"); + test_expression!("rpad('hi', 5, 'xy')", "hixyx"); + test_expression!("rpad('hi', 5, NULL)", "NULL"); + test_expression!("rpad('hi', 5)", "hi "); + test_expression!("rpad('hi', NULL, 'xy')", "NULL"); + test_expression!("rpad('hi', NULL)", "NULL"); + test_expression!("rpad('xyxhi', 3)", "xyx"); + test_expression!("rtrim(' testxxzx ')", " testxxzx"); + test_expression!("rtrim(' zzzytest ', NULL)", "NULL"); + test_expression!("rtrim('testxxzx', 'xyz')", "test"); + test_expression!("rtrim(NULL, 'xyz')", "NULL"); + test_expression!("trim(' tom ')", "tom"); + test_expression!("trim(' tom')", "tom"); + test_expression!("trim('')", ""); + test_expression!("trim('tom ')", "tom"); + test_expression!("trim(NULL)", "NULL"); + test_expression!("split_part('abc~@~def~@~ghi', '~@~', 2)", "def"); + test_expression!("split_part('abc~@~def~@~ghi', '~@~', 20)", ""); + test_expression!("split_part(NULL, '~@~', 20)", "NULL"); + test_expression!("split_part('abc~@~def~@~ghi', NULL, 20)", "NULL"); + test_expression!("split_part('abc~@~def~@~ghi', '~@~', NULL)", "NULL"); + test_expression!("starts_with('alphabet', 'alph')", "true"); + test_expression!("starts_with('alphabet', 'blph')", "false"); + test_expression!("starts_with(NULL, 'blph')", "NULL"); + test_expression!("starts_with('alphabet', NULL)", "NULL"); + test_expression!("strpos('abc', 'c')", "3"); + test_expression!("strpos('josé', 'é')", "4"); + test_expression!("strpos('joséésoj', 'so')", "6"); + test_expression!("strpos('joséésoj', 'abc')", "0"); + test_expression!("strpos(NULL, 'abc')", "NULL"); + test_expression!("strpos('joséésoj', NULL)", "NULL"); + test_expression!("substr('alphabet', -3)", "alphabet"); + test_expression!("substr('alphabet', 0)", "alphabet"); + test_expression!("substr('alphabet', 1)", "alphabet"); + test_expression!("substr('alphabet', 2)", "lphabet"); + test_expression!("substr('alphabet', 3)", "phabet"); + test_expression!("substr('alphabet', 30)", ""); + test_expression!("substr('alphabet', NULL)", "NULL"); + test_expression!("substr('alphabet', 3, 2)", "ph"); + test_expression!("substr('alphabet', 3, 20)", "phabet"); + test_expression!("substr('alphabet', NULL, 20)", "NULL"); + test_expression!("substr('alphabet', 3, NULL)", "NULL"); + test_expression!("to_hex(2147483647)", "7fffffff"); + test_expression!("to_hex(9223372036854775807)", "7fffffffffffffff"); + test_expression!("to_hex(CAST(NULL AS int))", "NULL"); + test_expression!("translate('12345', '143', 'ax')", "a2x5"); + test_expression!("translate(NULL, '143', 'ax')", "NULL"); + test_expression!("translate('12345', NULL, 'ax')", "NULL"); + test_expression!("translate('12345', '143', NULL)", "NULL"); + test_expression!("upper('')", ""); + test_expression!("upper('tom')", "TOM"); + test_expression!("upper(NULL)", "NULL"); + Ok(()) +} + +#[tokio::test] +async fn test_boolean_expressions() -> Result<()> { + test_expression!("true", "true"); + test_expression!("false", "false"); + Ok(()) +} + +#[tokio::test] +async fn test_interval_expressions() -> Result<()> { + test_expression!( + "interval '1'", + "0 years 0 mons 0 days 0 hours 0 mins 1.00 secs" + ); + test_expression!( + "interval '1 second'", + "0 years 0 mons 0 days 0 hours 0 mins 1.00 secs" + ); + test_expression!( + "interval '500 milliseconds'", + "0 years 0 mons 0 days 0 hours 0 mins 0.500 secs" + ); + test_expression!( + "interval '5 second'", + "0 years 0 mons 0 days 0 hours 0 mins 5.00 secs" + ); + test_expression!( + "interval '0.5 minute'", + "0 years 0 mons 0 days 0 hours 0 mins 30.00 secs" + ); + test_expression!( + "interval '.5 minute'", + "0 years 0 mons 0 days 0 hours 0 mins 30.00 secs" + ); + test_expression!( + "interval '5 minute'", + "0 years 0 mons 0 days 0 hours 5 mins 0.00 secs" + ); + test_expression!( + "interval '5 minute 1 second'", + "0 years 0 mons 0 days 0 hours 5 mins 1.00 secs" + ); + test_expression!( + "interval '1 hour'", + "0 years 0 mons 0 days 1 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '5 hour'", + "0 years 0 mons 0 days 5 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '1 day'", + "0 years 0 mons 1 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '1 day 1'", + "0 years 0 mons 1 days 0 hours 0 mins 1.00 secs" + ); + test_expression!( + "interval '0.5'", + "0 years 0 mons 0 days 0 hours 0 mins 0.500 secs" + ); + test_expression!( + "interval '0.5 day 1'", + "0 years 0 mons 0 days 12 hours 0 mins 1.00 secs" + ); + test_expression!( + "interval '0.49 day'", + "0 years 0 mons 0 days 11 hours 45 mins 36.00 secs" + ); + test_expression!( + "interval '0.499 day'", + "0 years 0 mons 0 days 11 hours 58 mins 33.596 secs" + ); + test_expression!( + "interval '0.4999 day'", + "0 years 0 mons 0 days 11 hours 59 mins 51.364 secs" + ); + test_expression!( + "interval '0.49999 day'", + "0 years 0 mons 0 days 11 hours 59 mins 59.136 secs" + ); + test_expression!( + "interval '0.49999999999 day'", + "0 years 0 mons 0 days 12 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '5 day'", + "0 years 0 mons 5 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '5 day 4 hours 3 minutes 2 seconds 100 milliseconds'", + "0 years 0 mons 5 days 4 hours 3 mins 2.100 secs" + ); + test_expression!( + "interval '0.5 month'", + "0 years 0 mons 15 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '1 month'", + "0 years 1 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '5 month'", + "0 years 5 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '13 month'", + "1 years 1 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '0.5 year'", + "0 years 6 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '1 year'", + "1 years 0 mons 0 days 0 hours 0 mins 0.00 secs" + ); + test_expression!( + "interval '2 year'", + "2 years 0 mons 0 days 0 hours 0 mins 0.00 secs" + ); Ok(()) } #[tokio::test] -async fn extract_date_part() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT - date_part('hour', CAST('2020-01-01' AS DATE)) AS hr1, - EXTRACT(HOUR FROM CAST('2020-01-01' AS DATE)) AS hr2, - EXTRACT(HOUR FROM to_timestamp('2020-09-08T12:00:00+00:00')) AS hr3, - date_part('YEAR', CAST('2000-01-01' AS DATE)) AS year1, - EXTRACT(year FROM to_timestamp('2020-09-08T12:00:00+00:00')) AS year2 - "; - - let actual = execute(&mut ctx, sql).await; - - let expected = vec![vec!["0", "0", "12", "2000", "2020"]]; - assert_eq!(expected, actual); +async fn test_extract_date_part() -> Result<()> { + test_expression!("date_part('hour', CAST('2020-01-01' AS DATE))", "0"); + test_expression!("EXTRACT(HOUR FROM CAST('2020-01-01' AS DATE))", "0"); + test_expression!( + "EXTRACT(HOUR FROM to_timestamp('2020-09-08T12:00:00+00:00'))", + "12" + ); + test_expression!("date_part('YEAR', CAST('2000-01-01' AS DATE))", "2000"); + test_expression!( + "EXTRACT(year FROM to_timestamp('2020-09-08T12:00:00+00:00'))", + "2020" + ); Ok(()) } #[tokio::test] -async fn in_list_array() -> Result<()> { +async fn test_crypto_expressions() -> Result<()> { + test_expression!("md5('tom')", "34b7da764b21d298ef307d04d8152dc5"); + test_expression!("md5('')", "d41d8cd98f00b204e9800998ecf8427e"); + test_expression!("md5(NULL)", "NULL"); + test_expression!( + "sha224('tom')", + "0bf6cb62649c42a9ae3876ab6f6d92ad36cb5414e495f8873292be4d" + ); + test_expression!( + "sha224('')", + "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f" + ); + test_expression!("sha224(NULL)", "NULL"); + test_expression!( + "sha256('tom')", + "e1608f75c5d7813f3d4031cb30bfb786507d98137538ff8e128a6ff74e84e643" + ); + test_expression!( + "sha256('')", + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + ); + test_expression!("sha256(NULL)", "NULL"); + test_expression!("sha384('tom')", "096f5b68aa77848e4fdf5c1c0b350de2dbfad60ffd7c25d9ea07c6c19b8a4d55a9187eb117c557883f58c16dfac3e343"); + test_expression!("sha384('')", "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b"); + test_expression!("sha384(NULL)", "NULL"); + test_expression!("sha512('tom')", "6e1b9b3fe840680e37051f7ad5e959d6f39ad0f8885d855166f55c659469d3c8b78118c44a2a49c72ddb481cd6d8731034e11cc030070ba843a90b3495cb8d3e"); + test_expression!("sha512('')", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e"); + test_expression!("sha512(NULL)", "NULL"); + Ok(()) +} + +#[tokio::test] +async fn test_in_list_scalar() -> Result<()> { + test_expression!("'a' IN ('a','b')", "true"); + test_expression!("'c' IN ('a','b')", "false"); + test_expression!("'c' NOT IN ('a','b')", "true"); + test_expression!("'a' NOT IN ('a','b')", "false"); + test_expression!("NULL IN ('a','b')", "NULL"); + test_expression!("NULL NOT IN ('a','b')", "NULL"); + test_expression!("'a' IN ('a','b',NULL)", "true"); + test_expression!("'c' IN ('a','b',NULL)", "NULL"); + test_expression!("'a' NOT IN ('a','b',NULL)", "false"); + test_expression!("'c' NOT IN ('a','b',NULL)", "NULL"); + test_expression!("0 IN (0,1,2)", "true"); + test_expression!("3 IN (0,1,2)", "false"); + test_expression!("3 NOT IN (0,1,2)", "true"); + test_expression!("0 NOT IN (0,1,2)", "false"); + test_expression!("NULL IN (0,1,2)", "NULL"); + test_expression!("NULL NOT IN (0,1,2)", "NULL"); + test_expression!("0 IN (0,1,2,NULL)", "true"); + test_expression!("3 IN (0,1,2,NULL)", "NULL"); + test_expression!("0 NOT IN (0,1,2,NULL)", "false"); + test_expression!("3 NOT IN (0,1,2,NULL)", "NULL"); + test_expression!("0.0 IN (0.0,0.1,0.2)", "true"); + test_expression!("0.3 IN (0.0,0.1,0.2)", "false"); + test_expression!("0.3 NOT IN (0.0,0.1,0.2)", "true"); + test_expression!("0.0 NOT IN (0.0,0.1,0.2)", "false"); + test_expression!("NULL IN (0.0,0.1,0.2)", "NULL"); + test_expression!("NULL NOT IN (0.0,0.1,0.2)", "NULL"); + test_expression!("0.0 IN (0.0,0.1,0.2,NULL)", "true"); + test_expression!("0.3 IN (0.0,0.1,0.2,NULL)", "NULL"); + test_expression!("0.0 NOT IN (0.0,0.1,0.2,NULL)", "false"); + test_expression!("0.3 NOT IN (0.0,0.1,0.2,NULL)", "NULL"); + test_expression!("'1' IN ('a','b',1)", "true"); + test_expression!("'2' IN ('a','b',1)", "false"); + test_expression!("'2' NOT IN ('a','b',1)", "true"); + test_expression!("'1' NOT IN ('a','b',1)", "false"); + test_expression!("NULL IN ('a','b',1)", "NULL"); + test_expression!("NULL NOT IN ('a','b',1)", "NULL"); + test_expression!("'1' IN ('a','b',NULL,1)", "true"); + test_expression!("'2' IN ('a','b',NULL,1)", "NULL"); + test_expression!("'1' NOT IN ('a','b',NULL,1)", "false"); + test_expression!("'2' NOT IN ('a','b',NULL,1)", "NULL"); + Ok(()) +} + +#[tokio::test] +async fn test_in_list_array() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "SELECT @@ -2176,64 +2385,3 @@ async fn in_list_array() -> Result<()> { assert_eq!(expected, actual); Ok(()) } - -#[tokio::test] -async fn in_list_scalar() -> Result<()> { - let mut ctx = ExecutionContext::new(); - let sql = "SELECT - 'a' IN ('a','b') AS utf8_in_true - ,'c' IN ('a','b') AS utf8_in_false - ,'c' NOT IN ('a','b') AS utf8_not_in_true - ,'a' NOT IN ('a','b') AS utf8_not_in_false - ,NULL IN ('a','b') AS utf8_in_null - ,NULL NOT IN ('a','b') AS utf8_not_in_null - ,'a' IN ('a','b',NULL) AS utf8_in_null_true - ,'c' IN ('a','b',NULL) AS utf8_in_null_null - ,'a' NOT IN ('a','b',NULL) AS utf8_not_in_null_false - ,'c' NOT IN ('a','b',NULL) AS utf8_not_in_null_null - - ,0 IN (0,1,2) AS int64_in_true - ,3 IN (0,1,2) AS int64_in_false - ,3 NOT IN (0,1,2) AS int64_not_in_true - ,0 NOT IN (0,1,2) AS int64_not_in_false - ,NULL IN (0,1,2) AS int64_in_null - ,NULL NOT IN (0,1,2) AS int64_not_in_null - ,0 IN (0,1,2,NULL) AS int64_in_null_true - ,3 IN (0,1,2,NULL) AS int64_in_null_null - ,0 NOT IN (0,1,2,NULL) AS int64_not_in_null_false - ,3 NOT IN (0,1,2,NULL) AS int64_not_in_null_null - - ,0.0 IN (0.0,0.1,0.2) AS float64_in_true - ,0.3 IN (0.0,0.1,0.2) AS float64_in_false - ,0.3 NOT IN (0.0,0.1,0.2) AS float64_not_in_true - ,0.0 NOT IN (0.0,0.1,0.2) AS float64_not_in_false - ,NULL IN (0.0,0.1,0.2) AS float64_in_null - ,NULL NOT IN (0.0,0.1,0.2) AS float64_not_in_null - ,0.0 IN (0.0,0.1,0.2,NULL) AS float64_in_null_true - ,0.3 IN (0.0,0.1,0.2,NULL) AS float64_in_null_null - ,0.0 NOT IN (0.0,0.1,0.2,NULL) AS float64_not_in_null_false - ,0.3 NOT IN (0.0,0.1,0.2,NULL) AS float64_not_in_null_null - - ,'1' IN ('a','b',1) AS utf8_cast_in_true - ,'2' IN ('a','b',1) AS utf8_cast_in_false - ,'2' NOT IN ('a','b',1) AS utf8_cast_not_in_true - ,'1' NOT IN ('a','b',1) AS utf8_cast_not_in_false - ,NULL IN ('a','b',1) AS utf8_cast_in_null - ,NULL NOT IN ('a','b',1) AS utf8_cast_not_in_null - ,'1' IN ('a','b',NULL,1) AS utf8_cast_in_null_true - ,'2' IN ('a','b',NULL,1) AS utf8_cast_in_null_null - ,'1' NOT IN ('a','b',NULL,1) AS utf8_cast_not_in_null_false - ,'2' NOT IN ('a','b',NULL,1) AS utf8_cast_not_in_null_null - "; - let actual = execute(&mut ctx, sql).await; - - let expected = vec![vec![ - "true", "false", "true", "false", "NULL", "NULL", "true", "NULL", "false", - "NULL", "true", "false", "true", "false", "NULL", "NULL", "true", "NULL", - "false", "NULL", "true", "false", "true", "false", "NULL", "NULL", "true", - "NULL", "false", "NULL", "true", "false", "true", "false", "NULL", "NULL", - "true", "NULL", "false", "NULL", - ]]; - assert_eq!(expected, actual); - Ok(()) -}