diff --git a/rust/arrow/Cargo.toml b/rust/arrow/Cargo.toml index 0b14b5bfae8..5ab1f8cc02b 100644 --- a/rust/arrow/Cargo.toml +++ b/rust/arrow/Cargo.toml @@ -114,6 +114,10 @@ harness = false name = "length_kernel" harness = false +[[bench]] +name = "bit_length_kernel" +harness = false + [[bench]] name = "sort_kernel" harness = false diff --git a/rust/arrow/benches/bit_length_kernel.rs b/rust/arrow/benches/bit_length_kernel.rs new file mode 100644 index 00000000000..51d31345712 --- /dev/null +++ b/rust/arrow/benches/bit_length_kernel.rs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +use criterion::Criterion; + +extern crate arrow; + +use arrow::{array::*, compute::kernels::length::bit_length}; + +fn bench_bit_length(array: &StringArray) { + criterion::black_box(bit_length(array).unwrap()); +} + +fn add_benchmark(c: &mut Criterion) { + fn double_vec(v: Vec) -> Vec { + [&v[..], &v[..]].concat() + } + + // double ["hello", " ", "world", "!"] 10 times + let mut values = vec!["one", "on", "o", ""]; + for _ in 0..10 { + values = double_vec(values); + } + let array = StringArray::from(values); + + c.bench_function("bit_length", |b| b.iter(|| bench_bit_length(&array))); +} + +criterion_group!(benches, add_benchmark); +criterion_main!(benches); diff --git a/rust/arrow/src/compute/kernels/length.rs b/rust/arrow/src/compute/kernels/length.rs index 740bb2b68c8..ed1fda4a062 100644 --- a/rust/arrow/src/compute/kernels/length.rs +++ b/rust/arrow/src/compute/kernels/length.rs @@ -17,26 +17,33 @@ //! Defines kernel for length of a string array -use crate::{array::*, buffer::Buffer}; use crate::{ - datatypes::DataType, + array::*, + buffer::Buffer, + datatypes::{ArrowNativeType, ArrowPrimitiveType}, +}; +use crate::{ + datatypes::{DataType, Int32Type, Int64Type}, error::{ArrowError, Result}, }; use std::sync::Arc; -#[allow(clippy::unnecessary_wraps)] -fn length_string(array: &Array, data_type: DataType) -> Result +fn unary_offsets_string( + array: &GenericStringArray, + data_type: DataType, + op: F, +) -> ArrayRef where - OffsetSize: OffsetSizeTrait, + O: StringOffsetSizeTrait + ArrowNativeType, + F: Fn(O) -> O, { // note: offsets are stored as u8, but they can be interpreted as OffsetSize let offsets = &array.data_ref().buffers()[0]; // this is a 30% improvement over iterating over u8s and building OffsetSize, which // justifies the usage of `unsafe`. - let slice: &[OffsetSize] = - &unsafe { offsets.typed_data::() }[array.offset()..]; + let slice: &[O] = &unsafe { offsets.typed_data::() }[array.offset()..]; - let lengths = slice.windows(2).map(|offset| offset[1] - offset[0]); + let lengths = slice.windows(2).map(|offset| op(offset[1] - offset[0])); // JUSTIFICATION // Benefit @@ -60,18 +67,45 @@ where vec![buffer], vec![], ); - Ok(make_array(Arc::new(data))) + make_array(Arc::new(data)) } -/// Returns an array of Int32/Int64 denoting the number of characters in each string in the array. +fn octet_length( + array: &dyn Array, +) -> ArrayRef +where + T::Native: StringOffsetSizeTrait, +{ + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + unary_offsets_string::(array, T::DATA_TYPE, |x| x) +} + +fn bit_length_impl( + array: &dyn Array, +) -> ArrayRef +where + T::Native: StringOffsetSizeTrait, +{ + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let bits_in_bytes = O::from_usize(8).unwrap(); + unary_offsets_string::(array, T::DATA_TYPE, |x| x * bits_in_bytes) +} + +/// Returns an array of Int32/Int64 denoting the number of bytes in each string in the array. /// /// * this only accepts StringArray/Utf8 and LargeString/LargeUtf8 /// * length of null is null. /// * length is in number of bytes pub fn length(array: &Array) -> Result { match array.data_type() { - DataType::Utf8 => length_string::(array, DataType::Int32), - DataType::LargeUtf8 => length_string::(array, DataType::Int64), + DataType::Utf8 => Ok(octet_length::(array)), + DataType::LargeUtf8 => Ok(octet_length::(array)), _ => Err(ArrowError::ComputeError(format!( "length not supported for {:?}", array.data_type() @@ -79,11 +113,27 @@ pub fn length(array: &Array) -> Result { } } +/// Returns an array of Int32/Int64 denoting the number of bits in each string in the array. +/// +/// * this only accepts StringArray/Utf8 and LargeString/LargeUtf8 +/// * bit_length of null is null. +/// * bit_length is in number of bits +pub fn bit_length(array: &Array) -> Result { + match array.data_type() { + DataType::Utf8 => Ok(bit_length_impl::(array)), + DataType::LargeUtf8 => Ok(bit_length_impl::(array)), + _ => Err(ArrowError::ComputeError(format!( + "bit_length not supported for {:?}", + array.data_type() + ))), + } +} + #[cfg(test)] mod tests { use super::*; - fn cases() -> Vec<(Vec<&'static str>, usize, Vec)> { + fn length_cases() -> Vec<(Vec<&'static str>, usize, Vec)> { fn double_vec(v: Vec) -> Vec { [&v[..], &v[..]].concat() } @@ -105,34 +155,38 @@ mod tests { } #[test] - fn test_string() -> Result<()> { - cases().into_iter().try_for_each(|(input, len, expected)| { - let array = StringArray::from(input); - let result = length(&array)?; - assert_eq!(len, result.len()); - let result = result.as_any().downcast_ref::().unwrap(); - expected.iter().enumerate().for_each(|(i, value)| { - assert_eq!(*value, result.value(i)); - }); - Ok(()) - }) + fn length_test_string() -> Result<()> { + length_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = StringArray::from(input); + let result = length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + expected.iter().enumerate().for_each(|(i, value)| { + assert_eq!(*value, result.value(i)); + }); + Ok(()) + }) } #[test] - fn test_large_string() -> Result<()> { - cases().into_iter().try_for_each(|(input, len, expected)| { - let array = LargeStringArray::from(input); - let result = length(&array)?; - assert_eq!(len, result.len()); - let result = result.as_any().downcast_ref::().unwrap(); - expected.iter().enumerate().for_each(|(i, value)| { - assert_eq!(*value as i64, result.value(i)); - }); - Ok(()) - }) - } - - fn null_cases() -> Vec<(Vec>, usize, Vec>)> { + fn length_test_large_string() -> Result<()> { + length_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = LargeStringArray::from(input); + let result = length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + expected.iter().enumerate().for_each(|(i, value)| { + assert_eq!(*value as i64, result.value(i)); + }); + Ok(()) + }) + } + + fn length_null_cases() -> Vec<(Vec>, usize, Vec>)> { vec![( vec![Some("one"), None, Some("three"), Some("four")], 4, @@ -141,8 +195,8 @@ mod tests { } #[test] - fn null_string() -> Result<()> { - null_cases() + fn length_null_string() -> Result<()> { + length_null_cases() .into_iter() .try_for_each(|(input, len, expected)| { let array = StringArray::from(input); @@ -157,8 +211,8 @@ mod tests { } #[test] - fn null_large_string() -> Result<()> { - null_cases() + fn length_null_large_string() -> Result<()> { + length_null_cases() .into_iter() .try_for_each(|(input, len, expected)| { let array = LargeStringArray::from(input); @@ -179,7 +233,7 @@ mod tests { /// Tests that length is not valid for u64. #[test] - fn wrong_type() { + fn length_wrong_type() { let array: UInt64Array = vec![1u64].into(); assert!(length(&array).is_err()); @@ -187,7 +241,7 @@ mod tests { /// Tests with an offset #[test] - fn offsets() -> Result<()> { + fn length_offsets() -> Result<()> { let a = StringArray::from(vec!["hello", " ", "world"]); let b = make_array( ArrayData::builder(DataType::Utf8) @@ -203,4 +257,130 @@ mod tests { Ok(()) } + + fn bit_length_cases() -> Vec<(Vec<&'static str>, usize, Vec)> { + fn double_vec(v: Vec) -> Vec { + [&v[..], &v[..]].concat() + } + + // a large array + let mut values = vec!["one", "on", "o", ""]; + let mut expected = vec![24, 16, 8, 0]; + for _ in 0..10 { + values = double_vec(values); + expected = double_vec(expected); + } + + vec![ + (vec!["hello", " ", "world", "!"], 4, vec![40, 8, 40, 8]), + (vec!["💖"], 1, vec![32]), + (vec!["josé"], 1, vec![40]), + (values, 4096, expected), + ] + } + + #[test] + fn bit_length_test_string() -> Result<()> { + bit_length_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = StringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + expected.iter().enumerate().for_each(|(i, value)| { + assert_eq!(*value, result.value(i)); + }); + Ok(()) + }) + } + + #[test] + fn bit_length_test_large_string() -> Result<()> { + bit_length_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = LargeStringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + expected.iter().enumerate().for_each(|(i, value)| { + assert_eq!(*value as i64, result.value(i)); + }); + Ok(()) + }) + } + + fn bit_length_null_cases() -> Vec<(Vec>, usize, Vec>)> + { + vec![( + vec![Some("one"), None, Some("three"), Some("four")], + 4, + vec![Some(24), None, Some(40), Some(32)], + )] + } + + #[test] + fn bit_length_null_string() -> Result<()> { + bit_length_null_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = StringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + + let expected: Int32Array = expected.into(); + assert_eq!(expected.data(), result.data()); + Ok(()) + }) + } + + #[test] + fn bit_length_null_large_string() -> Result<()> { + bit_length_null_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = LargeStringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + + // convert to i64 + let expected: Int64Array = expected + .iter() + .map(|e| e.map(|e| e as i64)) + .collect::>() + .into(); + assert_eq!(expected.data(), result.data()); + Ok(()) + }) + } + + /// Tests that bit_length is not valid for u64. + #[test] + fn bit_length_wrong_type() { + let array: UInt64Array = vec![1u64].into(); + + assert!(bit_length(&array).is_err()); + } + + /// Tests with an offset + #[test] + fn bit_length_offsets() -> Result<()> { + let a = StringArray::from(vec!["hello", " ", "world"]); + let b = make_array( + ArrayData::builder(DataType::Utf8) + .len(2) + .offset(1) + .buffers(a.data_ref().buffers().to_vec()) + .build(), + ); + let result = bit_length(b.as_ref())?; + + let expected = Int32Array::from(vec![8, 40]); + assert_eq!(expected.data(), result.data()); + + Ok(()) + } } diff --git a/rust/datafusion/Cargo.toml b/rust/datafusion/Cargo.toml index ea556662e32..11cc63bbdc3 100644 --- a/rust/datafusion/Cargo.toml +++ b/rust/datafusion/Cargo.toml @@ -64,6 +64,7 @@ log = "^0.4" md-5 = "^0.9.1" sha2 = "^0.9.1" ordered-float = "2.0" +unicode-segmentation = "^1.7.1" [dev-dependencies] rand = "0.8" diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index 7a122506e67..b4cb04321e7 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -57,7 +57,11 @@ DataFusion includes a simple command-line interactive SQL utility. See the [CLI - [x] UDAFs (user-defined aggregate functions) - [x] Common math functions - String functions - - [x] Length + - [x] bit_Length + - [x] char_length + - [x] character_length + - [x] length + - [x] octet_length - [x] Concatenate - Miscellaneous/Boolean functions - [x] nullif diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 7f358cb31b0..65a5f716b47 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -812,6 +812,8 @@ macro_rules! unary_scalar_expr { } // generate methods for creating the supported unary expressions + +// math functions unary_scalar_expr!(Sqrt, sqrt); unary_scalar_expr!(Sin, sin); unary_scalar_expr!(Cos, cos); @@ -829,24 +831,22 @@ unary_scalar_expr!(Exp, exp); unary_scalar_expr!(Log, ln); unary_scalar_expr!(Log2, log2); unary_scalar_expr!(Log10, log10); + +// string functions +unary_scalar_expr!(BitLength, bit_length); +unary_scalar_expr!(CharacterLength, character_length); +unary_scalar_expr!(CharacterLength, length); unary_scalar_expr!(Lower, lower); -unary_scalar_expr!(Trim, trim); unary_scalar_expr!(Ltrim, ltrim); -unary_scalar_expr!(Rtrim, rtrim); -unary_scalar_expr!(Upper, upper); unary_scalar_expr!(MD5, md5); +unary_scalar_expr!(OctetLength, octet_length); +unary_scalar_expr!(Rtrim, rtrim); unary_scalar_expr!(SHA224, sha224); unary_scalar_expr!(SHA256, sha256); unary_scalar_expr!(SHA384, sha384); unary_scalar_expr!(SHA512, sha512); - -/// returns the length of a string in bytes -pub fn length(e: Expr) -> Expr { - Expr::ScalarFunction { - fun: functions::BuiltinScalarFunction::Length, - args: vec![e], - } -} +unary_scalar_expr!(Trim, trim); +unary_scalar_expr!(Upper, upper); /// returns the concatenation of string expressions pub fn concat(args: Vec) -> Expr { diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index fbad5e26606..6244387e180 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -34,11 +34,12 @@ pub use builder::LogicalPlanBuilder; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ - abs, acos, and, array, asin, atan, avg, binary_expr, case, ceil, col, - combine_filters, concat, cos, count, count_distinct, create_udaf, create_udf, exp, - exprlist_to_fields, floor, in_list, length, lit, ln, log10, log2, lower, ltrim, max, - md5, min, or, round, rtrim, sha224, sha256, sha384, sha512, signum, sin, sqrt, sum, - tan, trim, trunc, upper, when, Expr, ExpressionVisitor, Literal, Recursion, + abs, acos, and, array, asin, atan, avg, binary_expr, bit_length, case, ceil, + character_length, col, combine_filters, concat, cos, count, count_distinct, + create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, length, lit, ln, + log10, log2, lower, ltrim, max, md5, min, octet_length, or, round, rtrim, sha224, + sha256, sha384, sha512, signum, sin, sqrt, sum, tan, trim, trunc, upper, when, Expr, + ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index c5cd01f93c5..baacf949270 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -45,9 +45,9 @@ use crate::{ }; use arrow::{ array::ArrayRef, - compute::kernels::length::length, + compute::kernels::length::{bit_length, length}, datatypes::TimeUnit, - datatypes::{DataType, Field, Schema}, + datatypes::{DataType, Field, Int32Type, Int64Type, Schema}, record_batch::RecordBatch, }; use fmt::{Debug, Formatter}; @@ -118,8 +118,6 @@ pub enum BuiltinScalarFunction { Abs, /// signum Signum, - /// length - Length, /// concat Concat, /// lower @@ -150,6 +148,12 @@ pub enum BuiltinScalarFunction { SHA384, /// SHA512, SHA512, + /// bit_length + BitLength, + /// character_length + CharacterLength, + /// octet_length + OctetLength, } impl fmt::Display for BuiltinScalarFunction { @@ -180,9 +184,6 @@ impl FromStr for BuiltinScalarFunction { "truc" => BuiltinScalarFunction::Trunc, "abs" => BuiltinScalarFunction::Abs, "signum" => BuiltinScalarFunction::Signum, - "length" => BuiltinScalarFunction::Length, - "char_length" => BuiltinScalarFunction::Length, - "character_length" => BuiltinScalarFunction::Length, "concat" => BuiltinScalarFunction::Concat, "lower" => BuiltinScalarFunction::Lower, "trim" => BuiltinScalarFunction::Trim, @@ -198,6 +199,11 @@ impl FromStr for BuiltinScalarFunction { "sha256" => BuiltinScalarFunction::SHA256, "sha384" => BuiltinScalarFunction::SHA384, "sha512" => BuiltinScalarFunction::SHA512, + "bit_length" => BuiltinScalarFunction::BitLength, + "octet_length" => BuiltinScalarFunction::OctetLength, + "length" => BuiltinScalarFunction::CharacterLength, + "char_length" => BuiltinScalarFunction::CharacterLength, + "character_length" => BuiltinScalarFunction::CharacterLength, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -231,16 +237,6 @@ pub fn return_type( // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. match fun { - BuiltinScalarFunction::Length => Ok(match arg_types[0] { - DataType::LargeUtf8 => DataType::Int64, - DataType::Utf8 => DataType::Int32, - _ => { - // this error is internal as `data_types` should have captured this. - return Err(DataFusionError::Internal( - "The length function can only accept strings.".to_string(), - )); - } - }), BuiltinScalarFunction::Concat => Ok(DataType::Utf8), BuiltinScalarFunction::Lower => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, @@ -357,6 +353,36 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::BitLength => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The bit_length function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::CharacterLength => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The character_length function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::OctetLength => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The octet_length function can only accept strings.".to_string(), + )); + } + }), _ => Ok(DataType::Float64), } } @@ -392,7 +418,41 @@ pub fn create_physical_expr( BuiltinScalarFunction::SHA256 => crypto_expressions::sha256, BuiltinScalarFunction::SHA384 => crypto_expressions::sha384, BuiltinScalarFunction::SHA512 => crypto_expressions::sha512, - BuiltinScalarFunction::Length => |args| match &args[0] { + BuiltinScalarFunction::Concat => string_expressions::concatenate, + BuiltinScalarFunction::Lower => string_expressions::lower, + BuiltinScalarFunction::Trim => string_expressions::trim, + BuiltinScalarFunction::Ltrim => string_expressions::ltrim, + BuiltinScalarFunction::Rtrim => string_expressions::rtrim, + BuiltinScalarFunction::Upper => string_expressions::upper, + BuiltinScalarFunction::ToTimestamp => datetime_expressions::to_timestamp, + BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc, + BuiltinScalarFunction::Array => array_expressions::array, + BuiltinScalarFunction::BitLength => |args| match &args[0] { + ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), + ColumnarValue::Scalar(v) => match v { + ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( + v.as_ref().map(|x| (x.len() * 8) as i32), + ))), + ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( + ScalarValue::Int64(v.as_ref().map(|x| (x.len() * 8) as i64)), + )), + _ => unreachable!(), + }, + }, + BuiltinScalarFunction::CharacterLength => |args| match args[0].data_type() { + DataType::Utf8 => make_scalar_function( + string_expressions::character_length::, + )(args), + DataType::LargeUtf8 => make_scalar_function( + string_expressions::character_length::, + )(args), + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function character_length", + other, + ))), + }, + BuiltinScalarFunction::OctetLength => |args| match &args[0] { + ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( v.as_ref().map(|x| x.len() as i32), @@ -402,17 +462,7 @@ pub fn create_physical_expr( )), _ => unreachable!(), }, - ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), }, - BuiltinScalarFunction::Concat => string_expressions::concatenate, - BuiltinScalarFunction::Lower => string_expressions::lower, - BuiltinScalarFunction::Trim => string_expressions::trim, - BuiltinScalarFunction::Ltrim => string_expressions::ltrim, - BuiltinScalarFunction::Rtrim => string_expressions::rtrim, - BuiltinScalarFunction::Upper => string_expressions::upper, - BuiltinScalarFunction::ToTimestamp => datetime_expressions::to_timestamp, - BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc, - BuiltinScalarFunction::Array => array_expressions::array, }); // coerce let args = coerce(args, input_schema, &signature(fun))?; @@ -439,7 +489,9 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]), BuiltinScalarFunction::Upper | BuiltinScalarFunction::Lower - | BuiltinScalarFunction::Length + | BuiltinScalarFunction::BitLength + | BuiltinScalarFunction::CharacterLength + | BuiltinScalarFunction::OctetLength | BuiltinScalarFunction::Trim | BuiltinScalarFunction::Ltrim | BuiltinScalarFunction::Rtrim @@ -617,48 +669,135 @@ mod tests { }; use arrow::{ array::{ - ArrayRef, FixedSizeListArray, Float64Array, Int32Array, StringArray, + Array, ArrayRef, FixedSizeListArray, Float64Array, Int32Array, StringArray, UInt32Array, UInt64Array, }, datatypes::Field, record_batch::RecordBatch, }; - fn generic_test_math(value: ScalarValue, expected: &str) -> Result<()> { - // any type works here: we evaluate against a literal of `value` - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; - - let arg = lit(value); - - let expr = create_physical_expr(&BuiltinScalarFunction::Exp, &[arg], &schema)?; - - // type is correct - assert_eq!(expr.data_type(&schema)?, DataType::Float64); - - // evaluate works - let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - - // downcast works - let result = result.as_any().downcast_ref::().unwrap(); - - // value is correct - assert_eq!(result.value(0).to_string(), expected); - - Ok(()) + /// $FUNC function to test + /// $ARGS arguments (vec) to pass to function + /// $EXPECTED a Result> where Result allows testing errors and Option allows testing Null + /// $EXPECTED_TYPE is the expected value type + /// $DATA_TYPE is the function to test result type + /// $ARRAY_TYPE is the column type after function applied + macro_rules! test_function { + ($FUNC:ident, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $DATA_TYPE: ident, $ARRAY_TYPE:ident) => { + // used to provide type annotation + let expected: Result> = $EXPECTED; + + // any type works here: we evaluate against a literal of `value` + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; + + let expr = + create_physical_expr(&BuiltinScalarFunction::$FUNC, $ARGS, &schema)?; + + // type is correct + assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TYPE); + + let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; + + match expected { + Ok(expected) => { + let result = expr.evaluate(&batch)?; + let result = result.into_array(batch.num_rows()); + let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); + + // value is correct + match expected { + Some(v) => assert_eq!(result.value(0), v), + None => assert!(result.is_null(0)), + }; + } + Err(expected_error) => { + // evaluate is expected error - cannot use .expect_err() due to Debug not being implemented + match expr.evaluate(&batch) { + Ok(_) => assert!(false, "expected error"), + Err(error) => { + assert_eq!(error.to_string(), expected_error.to_string()); + } + } + } + }; + }; } #[test] - fn test_math_function() -> Result<()> { - // 2.71828182845904523536... : https://oeis.org/A001113 - let exp_f64 = "2.718281828459045"; - let exp_f32 = "2.7182817459106445"; - generic_test_math(ScalarValue::from(1i32), exp_f64)?; - generic_test_math(ScalarValue::from(1u32), exp_f64)?; - generic_test_math(ScalarValue::from(1u64), exp_f64)?; - generic_test_math(ScalarValue::from(1f64), exp_f64)?; - generic_test_math(ScalarValue::from(1f32), exp_f32)?; + fn test_functions() -> Result<()> { + test_function!( + CharacterLength, + &[lit(ScalarValue::Utf8(Some("chars".to_string())))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + test_function!( + CharacterLength, + &[lit(ScalarValue::Utf8(Some("josé".to_string())))], + Ok(Some(4)), + i32, + Int32, + Int32Array + ); + test_function!( + CharacterLength, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + CharacterLength, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + i32, + Int32, + Int32Array + ); + test_function!( + Exp, + &[lit(ScalarValue::Int32(Some(1)))], + Ok(Some((1.0_f64).exp())), + f64, + Float64, + Float64Array + ); + test_function!( + Exp, + &[lit(ScalarValue::UInt32(Some(1)))], + Ok(Some((1.0_f64).exp())), + f64, + Float64, + Float64Array + ); + test_function!( + Exp, + &[lit(ScalarValue::UInt64(Some(1)))], + Ok(Some((1.0_f64).exp())), + f64, + Float64, + Float64Array + ); + test_function!( + Exp, + &[lit(ScalarValue::Float64(Some(1.0)))], + Ok(Some((1.0_f64).exp())), + f64, + Float64, + Float64Array + ); + test_function!( + Exp, + &[lit(ScalarValue::Float32(Some(1.0)))], + Ok(Some((1.0_f32).exp() as f64)), + f64, + Float64, + Float64Array + ); Ok(()) } diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index a4ccef08681..81d2c67eec6 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -24,9 +24,13 @@ use crate::{ scalar::ScalarValue, }; use arrow::{ - array::{Array, GenericStringArray, StringArray, StringOffsetSizeTrait}, - datatypes::DataType, + array::{ + Array, ArrayRef, GenericStringArray, PrimitiveArray, StringArray, + StringOffsetSizeTrait, + }, + datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, }; +use unicode_segmentation::UnicodeSegmentation; use super::ColumnarValue; @@ -115,6 +119,27 @@ where } } +/// Returns number of characters in the string. +/// character_length('josé') = 4 +pub fn character_length(args: &[ArrayRef]) -> Result +where + T::Native: StringOffsetSizeTrait, +{ + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result = string_array + .iter() + .map(|x| { + x.map(|x: &str| T::Native::from_usize(x.graphemes(true).count()).unwrap()) + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + /// concatenate string columns together. pub fn concatenate(args: &[ColumnarValue]) -> Result { // downcast all arguments to strings diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs index 4575de19c66..26e03c7453e 100644 --- a/rust/datafusion/src/prelude.rs +++ b/rust/datafusion/src/prelude.rs @@ -28,8 +28,8 @@ pub use crate::dataframe::DataFrame; pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::logical_plan::{ - array, avg, col, concat, count, create_udf, in_list, length, lit, lower, ltrim, max, - md5, min, rtrim, sha224, sha256, sha384, sha512, sum, trim, upper, JoinType, - Partitioning, + array, avg, bit_length, character_length, col, concat, count, create_udf, in_list, + length, lit, lower, ltrim, max, md5, min, octet_length, rtrim, sha224, sha256, + sha384, sha512, sum, trim, upper, JoinType, Partitioning, }; pub use crate::physical_plan::csv::CsvReadOptions;