From 29939a5f2257b79d8a73903943ab6198b422a76d Mon Sep 17 00:00:00 2001 From: Mike Seddon Date: Wed, 17 Feb 2021 08:21:39 +1100 Subject: [PATCH 1/6] length functions --- rust/arrow/Cargo.toml | 4 + rust/arrow/benches/bit_length_kernel.rs | 46 +++ rust/arrow/src/compute/kernels/bit_length.rs | 208 ++++++++++++++ rust/arrow/src/compute/kernels/length.rs | 2 +- rust/arrow/src/compute/kernels/mod.rs | 1 + rust/datafusion/Cargo.toml | 1 + rust/datafusion/README.md | 6 +- rust/datafusion/src/logical_plan/expr.rs | 22 +- rust/datafusion/src/logical_plan/mod.rs | 11 +- .../datafusion/src/physical_plan/functions.rs | 262 ++++++++++++++---- .../src/physical_plan/string_expressions.rs | 40 ++- rust/datafusion/src/prelude.rs | 6 +- 12 files changed, 527 insertions(+), 82 deletions(-) create mode 100644 rust/arrow/benches/bit_length_kernel.rs create mode 100644 rust/arrow/src/compute/kernels/bit_length.rs diff --git a/rust/arrow/Cargo.toml b/rust/arrow/Cargo.toml index 0b14b5bfae8..5ab1f8cc02b 100644 --- a/rust/arrow/Cargo.toml +++ b/rust/arrow/Cargo.toml @@ -114,6 +114,10 @@ harness = false name = "length_kernel" harness = false +[[bench]] +name = "bit_length_kernel" +harness = false + [[bench]] name = "sort_kernel" harness = false diff --git a/rust/arrow/benches/bit_length_kernel.rs b/rust/arrow/benches/bit_length_kernel.rs new file mode 100644 index 00000000000..b2104d7f354 --- /dev/null +++ b/rust/arrow/benches/bit_length_kernel.rs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +use criterion::Criterion; + +extern crate arrow; + +use arrow::{array::*, compute::kernels::bit_length::bit_length}; + +fn bench_bit_length(array: &StringArray) { + criterion::black_box(bit_length(array).unwrap()); +} + +fn add_benchmark(c: &mut Criterion) { + fn double_vec(v: Vec) -> Vec { + [&v[..], &v[..]].concat() + } + + // double ["one", "on", "o", ""] 10 times + let mut values = vec!["one", "on", "o", ""]; + for _ in 0..10 { + values = double_vec(values); + } + let array = StringArray::from(values); + + c.bench_function("bit_length", |b| b.iter(|| bench_bit_length(&array))); +} + +criterion_group!(benches, add_benchmark); +criterion_main!(benches); diff --git a/rust/arrow/src/compute/kernels/bit_length.rs b/rust/arrow/src/compute/kernels/bit_length.rs new file mode 100644 index 00000000000..fee771315df --- /dev/null +++ b/rust/arrow/src/compute/kernels/bit_length.rs @@ -0,0 +1,208 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines kernel for length of a string array + +use crate::{array::*, buffer::Buffer}; +use crate::{ + datatypes::DataType, + error::{ArrowError, Result}, +}; +use std::sync::Arc; + +fn bit_length_string(array: &Array, data_type: DataType) -> ArrayRef +where + OffsetSize: OffsetSizeTrait, +{ + // note: offsets are stored as u8, but they can be interpreted as OffsetSize + let offsets = &array.data_ref().buffers()[0]; + // this is a 30% improvement over iterating over u8s and building OffsetSize, which + // justifies the usage of `unsafe`. + let slice: &[OffsetSize] = + &unsafe { offsets.typed_data::() }[array.offset()..]; + + let bit_size = OffsetSize::from_usize(8).unwrap(); + let lengths = slice + .windows(2) + .map(|offset| (offset[1] - offset[0]) * bit_size); + + // JUSTIFICATION + // Benefit + // ~60% speedup + // Soundness + // `values` is an iterator with a known size. + let buffer = unsafe { Buffer::from_trusted_len_iter(lengths) }; + + let null_bit_buffer = array + .data_ref() + .null_bitmap() + .as_ref() + .map(|b| b.bits.clone()); + + let data = ArrayData::new( + data_type, + array.len(), + None, + null_bit_buffer, + 0, + vec![buffer], + vec![], + ); + make_array(Arc::new(data)) +} + +/// Returns an array of Int32/Int64 denoting the number of bits in each string in the array. +/// +/// * this only accepts StringArray/Utf8 and LargeString/LargeUtf8 +/// * bit_length of null is null. 
+/// * bit_length is in number of bits +pub fn bit_length(array: &Array) -> Result { + match array.data_type() { + DataType::Utf8 => Ok(bit_length_string::(array, DataType::Int32)), + DataType::LargeUtf8 => Ok(bit_length_string::(array, DataType::Int64)), + _ => Err(ArrowError::ComputeError(format!( + "bit_length not supported for {:?}", + array.data_type() + ))), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn cases() -> Vec<(Vec<&'static str>, usize, Vec)> { + fn double_vec(v: Vec) -> Vec { + [&v[..], &v[..]].concat() + } + + // a large array + let mut values = vec!["one", "on", "o", ""]; + let mut expected = vec![24, 16, 8, 0]; + for _ in 0..10 { + values = double_vec(values); + expected = double_vec(expected); + } + + vec![ + (vec!["hello", " ", "world", "!"], 4, vec![40, 8, 40, 8]), + (vec!["💖"], 1, vec![32]), + (vec!["josé"], 1, vec![40]), + (values, 4096, expected), + ] + } + + #[test] + fn test_string() -> Result<()> { + cases().into_iter().try_for_each(|(input, len, expected)| { + let array = StringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + expected.iter().enumerate().for_each(|(i, value)| { + assert_eq!(*value, result.value(i)); + }); + Ok(()) + }) + } + + #[test] + fn test_large_string() -> Result<()> { + cases().into_iter().try_for_each(|(input, len, expected)| { + let array = LargeStringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + expected.iter().enumerate().for_each(|(i, value)| { + assert_eq!(*value as i64, result.value(i)); + }); + Ok(()) + }) + } + + fn null_cases() -> Vec<(Vec>, usize, Vec>)> { + vec![( + vec![Some("one"), None, Some("three"), Some("four")], + 4, + vec![Some(24), None, Some(40), Some(32)], + )] + } + + #[test] + fn null_string() -> Result<()> { + null_cases() + .into_iter() + .try_for_each(|(input, len, expected)| 
{ + let array = StringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + + let expected: Int32Array = expected.into(); + assert_eq!(expected.data(), result.data()); + Ok(()) + }) + } + + #[test] + fn null_large_string() -> Result<()> { + null_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = LargeStringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + + // convert to i64 + let expected: Int64Array = expected + .iter() + .map(|e| e.map(|e| e as i64)) + .collect::>() + .into(); + assert_eq!(expected.data(), result.data()); + Ok(()) + }) + } + + /// Tests that bit_length is not valid for u64. + #[test] + fn wrong_type() { + let array: UInt64Array = vec![1u64].into(); + + assert!(bit_length(&array).is_err()); + } + + /// Tests with an offset + #[test] + fn offsets() -> Result<()> { + let a = StringArray::from(vec!["hello", " ", "world"]); + let b = make_array( + ArrayData::builder(DataType::Utf8) + .len(2) + .offset(1) + .buffers(a.data_ref().buffers().to_vec()) + .build(), + ); + let result = bit_length(b.as_ref())?; + + let expected = Int32Array::from(vec![8, 40]); + assert_eq!(expected.data(), result.data()); + + Ok(()) + } +} diff --git a/rust/arrow/src/compute/kernels/length.rs b/rust/arrow/src/compute/kernels/length.rs index 740bb2b68c8..f285ef34c08 100644 --- a/rust/arrow/src/compute/kernels/length.rs +++ b/rust/arrow/src/compute/kernels/length.rs @@ -63,7 +63,7 @@ where Ok(make_array(Arc::new(data))) } -/// Returns an array of Int32/Int64 denoting the number of characters in each string in the array. +/// Returns an array of Int32/Int64 denoting the number of bytes in each string in the array. /// /// * this only accepts StringArray/Utf8 and LargeString/LargeUtf8 /// * length of null is null. 
diff --git a/rust/arrow/src/compute/kernels/mod.rs b/rust/arrow/src/compute/kernels/mod.rs index a8d24979e04..ef2f7736f4d 100644 --- a/rust/arrow/src/compute/kernels/mod.rs +++ b/rust/arrow/src/compute/kernels/mod.rs @@ -20,6 +20,7 @@ pub mod aggregate; pub mod arithmetic; pub mod arity; +pub mod bit_length; pub mod boolean; pub mod cast; pub mod cast_utils; diff --git a/rust/datafusion/Cargo.toml b/rust/datafusion/Cargo.toml index ea556662e32..11cc63bbdc3 100644 --- a/rust/datafusion/Cargo.toml +++ b/rust/datafusion/Cargo.toml @@ -64,6 +64,7 @@ log = "^0.4" md-5 = "^0.9.1" sha2 = "^0.9.1" ordered-float = "2.0" +unicode-segmentation = "^1.7.1" [dev-dependencies] rand = "0.8" diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index 7a122506e67..b4cb04321e7 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -57,7 +57,11 @@ DataFusion includes a simple command-line interactive SQL utility. See the [CLI - [x] UDAFs (user-defined aggregate functions) - [x] Common math functions - String functions - - [x] Length + - [x] bit_length + - [x] char_length + - [x] character_length + - [x] length + - [x] octet_length - [x] Concatenate - Miscellaneous/Boolean functions - [x] nullif diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 7f358cb31b0..65a5f716b47 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -812,6 +812,8 @@ macro_rules! 
unary_scalar_expr { } // generate methods for creating the supported unary expressions + +// math functions unary_scalar_expr!(Sqrt, sqrt); unary_scalar_expr!(Sin, sin); unary_scalar_expr!(Cos, cos); @@ -829,24 +831,22 @@ unary_scalar_expr!(Exp, exp); unary_scalar_expr!(Log, ln); unary_scalar_expr!(Log2, log2); unary_scalar_expr!(Log10, log10); + +// string functions +unary_scalar_expr!(BitLength, bit_length); +unary_scalar_expr!(CharacterLength, character_length); +unary_scalar_expr!(CharacterLength, length); unary_scalar_expr!(Lower, lower); -unary_scalar_expr!(Trim, trim); unary_scalar_expr!(Ltrim, ltrim); -unary_scalar_expr!(Rtrim, rtrim); -unary_scalar_expr!(Upper, upper); unary_scalar_expr!(MD5, md5); +unary_scalar_expr!(OctetLength, octet_length); +unary_scalar_expr!(Rtrim, rtrim); unary_scalar_expr!(SHA224, sha224); unary_scalar_expr!(SHA256, sha256); unary_scalar_expr!(SHA384, sha384); unary_scalar_expr!(SHA512, sha512); - -/// returns the length of a string in bytes -pub fn length(e: Expr) -> Expr { - Expr::ScalarFunction { - fun: functions::BuiltinScalarFunction::Length, - args: vec![e], - } -} +unary_scalar_expr!(Trim, trim); +unary_scalar_expr!(Upper, upper); /// returns the concatenation of string expressions pub fn concat(args: Vec) -> Expr { diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index fbad5e26606..6244387e180 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -34,11 +34,12 @@ pub use builder::LogicalPlanBuilder; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ - abs, acos, and, array, asin, atan, avg, binary_expr, case, ceil, col, - combine_filters, concat, cos, count, count_distinct, create_udaf, create_udf, exp, - exprlist_to_fields, floor, in_list, length, lit, ln, log10, log2, lower, ltrim, max, - md5, min, or, round, rtrim, sha224, sha256, sha384, sha512, signum, sin, sqrt, 
sum, - tan, trim, trunc, upper, when, Expr, ExpressionVisitor, Literal, Recursion, + abs, acos, and, array, asin, atan, avg, binary_expr, bit_length, case, ceil, + character_length, col, combine_filters, concat, cos, count, count_distinct, + create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, length, lit, ln, + log10, log2, lower, ltrim, max, md5, min, octet_length, or, round, rtrim, sha224, + sha256, sha384, sha512, signum, sin, sqrt, sum, tan, trim, trunc, upper, when, Expr, + ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index c5cd01f93c5..2513889c6d2 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -45,6 +45,7 @@ use crate::{ }; use arrow::{ array::ArrayRef, + compute::kernels::bit_length::bit_length, compute::kernels::length::length, datatypes::TimeUnit, datatypes::{DataType, Field, Schema}, @@ -118,8 +119,6 @@ pub enum BuiltinScalarFunction { Abs, /// signum Signum, - /// length - Length, /// concat Concat, /// lower @@ -150,6 +149,12 @@ pub enum BuiltinScalarFunction { SHA384, /// SHA512, SHA512, + /// bit_length + BitLength, + /// character_length + CharacterLength, + /// octet_length + OctetLength, } impl fmt::Display for BuiltinScalarFunction { @@ -180,9 +185,6 @@ impl FromStr for BuiltinScalarFunction { "truc" => BuiltinScalarFunction::Trunc, "abs" => BuiltinScalarFunction::Abs, "signum" => BuiltinScalarFunction::Signum, - "length" => BuiltinScalarFunction::Length, - "char_length" => BuiltinScalarFunction::Length, - "character_length" => BuiltinScalarFunction::Length, "concat" => BuiltinScalarFunction::Concat, "lower" => BuiltinScalarFunction::Lower, "trim" => BuiltinScalarFunction::Trim, @@ -198,6 +200,11 @@ impl FromStr for BuiltinScalarFunction { "sha256" => BuiltinScalarFunction::SHA256, "sha384" 
=> BuiltinScalarFunction::SHA384, "sha512" => BuiltinScalarFunction::SHA512, + "bit_length" => BuiltinScalarFunction::BitLength, + "octet_length" => BuiltinScalarFunction::OctetLength, + "length" => BuiltinScalarFunction::CharacterLength, + "char_length" => BuiltinScalarFunction::CharacterLength, + "character_length" => BuiltinScalarFunction::CharacterLength, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -231,16 +238,6 @@ pub fn return_type( // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. match fun { - BuiltinScalarFunction::Length => Ok(match arg_types[0] { - DataType::LargeUtf8 => DataType::Int64, - DataType::Utf8 => DataType::Int32, - _ => { - // this error is internal as `data_types` should have captured this. - return Err(DataFusionError::Internal( - "The length function can only accept strings.".to_string(), - )); - } - }), BuiltinScalarFunction::Concat => Ok(DataType::Utf8), BuiltinScalarFunction::Lower => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, @@ -357,6 +354,36 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::BitLength => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The bit_length function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::CharacterLength => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, + _ => { + // this error is internal as `data_types` should have captured this. 
+ return Err(DataFusionError::Internal( + "The character_length function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::OctetLength => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::Int64, + DataType::Utf8 => DataType::Int32, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The octet_length function can only accept strings.".to_string(), + )); + } + }), _ => Ok(DataType::Float64), } } @@ -392,7 +419,41 @@ pub fn create_physical_expr( BuiltinScalarFunction::SHA256 => crypto_expressions::sha256, BuiltinScalarFunction::SHA384 => crypto_expressions::sha384, BuiltinScalarFunction::SHA512 => crypto_expressions::sha512, - BuiltinScalarFunction::Length => |args| match &args[0] { + BuiltinScalarFunction::Concat => string_expressions::concatenate, + BuiltinScalarFunction::Lower => string_expressions::lower, + BuiltinScalarFunction::Trim => string_expressions::trim, + BuiltinScalarFunction::Ltrim => string_expressions::ltrim, + BuiltinScalarFunction::Rtrim => string_expressions::rtrim, + BuiltinScalarFunction::Upper => string_expressions::upper, + BuiltinScalarFunction::ToTimestamp => datetime_expressions::to_timestamp, + BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc, + BuiltinScalarFunction::Array => array_expressions::array, + BuiltinScalarFunction::BitLength => |args| match &args[0] { + ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), + ColumnarValue::Scalar(v) => match v { + ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( + v.as_ref().map(|x| (x.len() * 8) as i32), + ))), + ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( + ScalarValue::Int64(v.as_ref().map(|x| (x.len() * 8) as i64)), + )), + _ => unreachable!(), + }, + }, + BuiltinScalarFunction::CharacterLength => |args| match args[0].data_type() { + DataType::Utf8 => { + 
make_scalar_function(string_expressions::character_length_i32)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::character_length_i64)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function character_length", + other, + ))), + }, + BuiltinScalarFunction::OctetLength => |args| match &args[0] { + ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( v.as_ref().map(|x| x.len() as i32), @@ -402,17 +463,7 @@ pub fn create_physical_expr( )), _ => unreachable!(), }, - ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), }, - BuiltinScalarFunction::Concat => string_expressions::concatenate, - BuiltinScalarFunction::Lower => string_expressions::lower, - BuiltinScalarFunction::Trim => string_expressions::trim, - BuiltinScalarFunction::Ltrim => string_expressions::ltrim, - BuiltinScalarFunction::Rtrim => string_expressions::rtrim, - BuiltinScalarFunction::Upper => string_expressions::upper, - BuiltinScalarFunction::ToTimestamp => datetime_expressions::to_timestamp, - BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc, - BuiltinScalarFunction::Array => array_expressions::array, }); // coerce let args = coerce(args, input_schema, &signature(fun))?; @@ -439,7 +490,9 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]), BuiltinScalarFunction::Upper | BuiltinScalarFunction::Lower - | BuiltinScalarFunction::Length + | BuiltinScalarFunction::BitLength + | BuiltinScalarFunction::CharacterLength + | BuiltinScalarFunction::OctetLength | BuiltinScalarFunction::Trim | BuiltinScalarFunction::Ltrim | BuiltinScalarFunction::Rtrim @@ -617,48 +670,137 @@ mod tests { }; use arrow::{ array::{ - ArrayRef, FixedSizeListArray, Float64Array, Int32Array, StringArray, + Array, 
ArrayRef, FixedSizeListArray, Float64Array, Int32Array, StringArray, UInt32Array, UInt64Array, }, datatypes::Field, record_batch::RecordBatch, }; - fn generic_test_math(value: ScalarValue, expected: &str) -> Result<()> { - // any type works here: we evaluate against a literal of `value` - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; - - let arg = lit(value); - - let expr = create_physical_expr(&BuiltinScalarFunction::Exp, &[arg], &schema)?; - - // type is correct - assert_eq!(expr.data_type(&schema)?, DataType::Float64); - - // evaluate works - let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - - // downcast works - let result = result.as_any().downcast_ref::().unwrap(); - - // value is correct - assert_eq!(result.value(0).to_string(), expected); - - Ok(()) + /// $FUNC function to test + /// $ARGS arguments (vec) to pass to function + /// $EXPECTED a Result> where Result allows testing errors and Option allows testing Null + /// $EXPECTED_TYPE is the expected value type + /// $DATA_TYPE is the function to test result type + /// $ARRAY_TYPE is the column type after function applied + macro_rules! 
test_function { + ($FUNC:ident, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $DATA_TYPE: ident, $ARRAY_TYPE:ident) => { + println!("{:?}", BuiltinScalarFunction::$FUNC); + + // used to provide type annotation + let expected: Result> = $EXPECTED; + + // any type works here: we evaluate against a literal of `value` + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; + + let expr = + create_physical_expr(&BuiltinScalarFunction::$FUNC, $ARGS, &schema)?; + + // type is correct + assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TYPE); + + let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; + + match expected { + Ok(expected) => { + let result = expr.evaluate(&batch)?; + let result = result.into_array(batch.num_rows()); + let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); + + // value is correct + match expected { + Some(v) => assert_eq!(result.value(0), v), + None => assert!(result.is_null(0)), + }; + } + Err(expected_error) => { + // evaluate is expected error - cannot use .expect_err() due to Debug not being implemented + match expr.evaluate(&batch) { + Ok(_) => assert!(false, "expected error"), + Err(error) => { + assert_eq!(error.to_string(), expected_error.to_string()); + } + } + } + }; + }; } #[test] - fn test_math_function() -> Result<()> { - // 2.71828182845904523536... 
: https://oeis.org/A001113 - let exp_f64 = "2.718281828459045"; - let exp_f32 = "2.7182817459106445"; - generic_test_math(ScalarValue::from(1i32), exp_f64)?; - generic_test_math(ScalarValue::from(1u32), exp_f64)?; - generic_test_math(ScalarValue::from(1u64), exp_f64)?; - generic_test_math(ScalarValue::from(1f64), exp_f64)?; - generic_test_math(ScalarValue::from(1f32), exp_f32)?; + fn test_functions() -> Result<()> { + test_function!( + CharacterLength, + &[lit(ScalarValue::Utf8(Some("chars".to_string())))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + test_function!( + CharacterLength, + &[lit(ScalarValue::Utf8(Some("josé".to_string())))], + Ok(Some(4)), + i32, + Int32, + Int32Array + ); + test_function!( + CharacterLength, + &[lit(ScalarValue::Utf8(Some("".to_string())))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + CharacterLength, + &[lit(ScalarValue::Utf8(None))], + Ok(None), + i32, + Int32, + Int32Array + ); + test_function!( + Exp, + &[lit(ScalarValue::Int32(Some(1)))], + Ok(Some((1.0_f64).exp())), + f64, + Float64, + Float64Array + ); + test_function!( + Exp, + &[lit(ScalarValue::UInt32(Some(1)))], + Ok(Some((1.0_f64).exp())), + f64, + Float64, + Float64Array + ); + test_function!( + Exp, + &[lit(ScalarValue::UInt64(Some(1)))], + Ok(Some((1.0_f64).exp())), + f64, + Float64, + Float64Array + ); + test_function!( + Exp, + &[lit(ScalarValue::Float64(Some(1.0)))], + Ok(Some((1.0_f64).exp())), + f64, + Float64, + Float64Array + ); + test_function!( + Exp, + &[lit(ScalarValue::Float32(Some(1.0)))], + Ok(Some((1.0_f32).exp() as f64)), + f64, + Float64, + Float64Array + ); Ok(()) } diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index a4ccef08681..0d761ac3073 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -24,9 +24,13 @@ use crate::{ scalar::ScalarValue, }; use arrow::{ - 
array::{Array, GenericStringArray, StringArray, StringOffsetSizeTrait}, + array::{ + Array, ArrayRef, GenericStringArray, Int32Array, Int64Array, StringArray, + StringOffsetSizeTrait, + }, datatypes::DataType, }; +use unicode_segmentation::UnicodeSegmentation; use super::ColumnarValue; @@ -115,6 +119,40 @@ where } } +/// Returns number of characters in the string. +/// character_length_i32('josé') = 4 +pub fn character_length_i32(args: &[ArrayRef]) -> Result { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + // first map is the iterator, second is for the `Option<_>` + let result = string_array + .iter() + .map(|x| x.map(|x: &str| x.graphemes(true).count() as i32)) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Returns number of characters in the string. +/// character_length_i64('josé') = 4 +pub fn character_length_i64(args: &[ArrayRef]) -> Result { + let string_array: &GenericStringArray = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + // first map is the iterator, second is for the `Option<_>` + let result = string_array + .iter() + .map(|x| x.map(|x: &str| x.graphemes(true).count() as i64)) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} + /// concatenate string columns together. 
pub fn concatenate(args: &[ColumnarValue]) -> Result { // downcast all arguments to strings diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs index 4575de19c66..26e03c7453e 100644 --- a/rust/datafusion/src/prelude.rs +++ b/rust/datafusion/src/prelude.rs @@ -28,8 +28,8 @@ pub use crate::dataframe::DataFrame; pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::logical_plan::{ - array, avg, col, concat, count, create_udf, in_list, length, lit, lower, ltrim, max, - md5, min, rtrim, sha224, sha256, sha384, sha512, sum, trim, upper, JoinType, - Partitioning, + array, avg, bit_length, character_length, col, concat, count, create_udf, in_list, + length, lit, lower, ltrim, max, md5, min, octet_length, rtrim, sha224, sha256, + sha384, sha512, sum, trim, upper, JoinType, Partitioning, }; pub use crate::physical_plan::csv::CsvReadOptions; From 04f3bba115ca056e228c46ec296d24bf0242969a Mon Sep 17 00:00:00 2001 From: Mike Seddon Date: Wed, 17 Feb 2021 18:59:49 +1100 Subject: [PATCH 2/6] address review comments --- rust/arrow/benches/bit_length_kernel.rs | 2 +- rust/arrow/src/compute/kernels/bit_length.rs | 208 ------------- rust/arrow/src/compute/kernels/length.rs | 294 +++++++++++++----- rust/arrow/src/compute/kernels/mod.rs | 1 - .../datafusion/src/physical_plan/functions.rs | 19 +- .../src/physical_plan/string_expressions.rs | 39 +-- 6 files changed, 246 insertions(+), 317 deletions(-) delete mode 100644 rust/arrow/src/compute/kernels/bit_length.rs diff --git a/rust/arrow/benches/bit_length_kernel.rs b/rust/arrow/benches/bit_length_kernel.rs index b2104d7f354..51d31345712 100644 --- a/rust/arrow/benches/bit_length_kernel.rs +++ b/rust/arrow/benches/bit_length_kernel.rs @@ -21,7 +21,7 @@ use criterion::Criterion; extern crate arrow; -use arrow::{array::*, compute::kernels::bit_length::bit_length}; +use arrow::{array::*, compute::kernels::length::bit_length}; fn bench_bit_length(array: &StringArray) { 
criterion::black_box(bit_length(array).unwrap()); diff --git a/rust/arrow/src/compute/kernels/bit_length.rs b/rust/arrow/src/compute/kernels/bit_length.rs deleted file mode 100644 index fee771315df..00000000000 --- a/rust/arrow/src/compute/kernels/bit_length.rs +++ /dev/null @@ -1,208 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines kernel for length of a string array - -use crate::{array::*, buffer::Buffer}; -use crate::{ - datatypes::DataType, - error::{ArrowError, Result}, -}; -use std::sync::Arc; - -fn bit_length_string(array: &Array, data_type: DataType) -> ArrayRef -where - OffsetSize: OffsetSizeTrait, -{ - // note: offsets are stored as u8, but they can be interpreted as OffsetSize - let offsets = &array.data_ref().buffers()[0]; - // this is a 30% improvement over iterating over u8s and building OffsetSize, which - // justifies the usage of `unsafe`. - let slice: &[OffsetSize] = - &unsafe { offsets.typed_data::() }[array.offset()..]; - - let bit_size = OffsetSize::from_usize(8).unwrap(); - let lengths = slice - .windows(2) - .map(|offset| (offset[1] - offset[0]) * bit_size); - - // JUSTIFICATION - // Benefit - // ~60% speedup - // Soundness - // `values` is an iterator with a known size. 
- let buffer = unsafe { Buffer::from_trusted_len_iter(lengths) }; - - let null_bit_buffer = array - .data_ref() - .null_bitmap() - .as_ref() - .map(|b| b.bits.clone()); - - let data = ArrayData::new( - data_type, - array.len(), - None, - null_bit_buffer, - 0, - vec![buffer], - vec![], - ); - make_array(Arc::new(data)) -} - -/// Returns an array of Int32/Int64 denoting the number of bits in each string in the array. -/// -/// * this only accepts StringArray/Utf8 and LargeString/LargeUtf8 -/// * bit_length of null is null. -/// * bit_length is in number of bits -pub fn bit_length(array: &Array) -> Result { - match array.data_type() { - DataType::Utf8 => Ok(bit_length_string::(array, DataType::Int32)), - DataType::LargeUtf8 => Ok(bit_length_string::(array, DataType::Int64)), - _ => Err(ArrowError::ComputeError(format!( - "bit_length not supported for {:?}", - array.data_type() - ))), - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn cases() -> Vec<(Vec<&'static str>, usize, Vec)> { - fn double_vec(v: Vec) -> Vec { - [&v[..], &v[..]].concat() - } - - // a large array - let mut values = vec!["one", "on", "o", ""]; - let mut expected = vec![24, 16, 8, 0]; - for _ in 0..10 { - values = double_vec(values); - expected = double_vec(expected); - } - - vec![ - (vec!["hello", " ", "world", "!"], 4, vec![40, 8, 40, 8]), - (vec!["💖"], 1, vec![32]), - (vec!["josé"], 1, vec![40]), - (values, 4096, expected), - ] - } - - #[test] - fn test_string() -> Result<()> { - cases().into_iter().try_for_each(|(input, len, expected)| { - let array = StringArray::from(input); - let result = bit_length(&array)?; - assert_eq!(len, result.len()); - let result = result.as_any().downcast_ref::().unwrap(); - expected.iter().enumerate().for_each(|(i, value)| { - assert_eq!(*value, result.value(i)); - }); - Ok(()) - }) - } - - #[test] - fn test_large_string() -> Result<()> { - cases().into_iter().try_for_each(|(input, len, expected)| { - let array = LargeStringArray::from(input); - let result = 
bit_length(&array)?; - assert_eq!(len, result.len()); - let result = result.as_any().downcast_ref::().unwrap(); - expected.iter().enumerate().for_each(|(i, value)| { - assert_eq!(*value as i64, result.value(i)); - }); - Ok(()) - }) - } - - fn null_cases() -> Vec<(Vec>, usize, Vec>)> { - vec![( - vec![Some("one"), None, Some("three"), Some("four")], - 4, - vec![Some(24), None, Some(40), Some(32)], - )] - } - - #[test] - fn null_string() -> Result<()> { - null_cases() - .into_iter() - .try_for_each(|(input, len, expected)| { - let array = StringArray::from(input); - let result = bit_length(&array)?; - assert_eq!(len, result.len()); - let result = result.as_any().downcast_ref::().unwrap(); - - let expected: Int32Array = expected.into(); - assert_eq!(expected.data(), result.data()); - Ok(()) - }) - } - - #[test] - fn null_large_string() -> Result<()> { - null_cases() - .into_iter() - .try_for_each(|(input, len, expected)| { - let array = LargeStringArray::from(input); - let result = bit_length(&array)?; - assert_eq!(len, result.len()); - let result = result.as_any().downcast_ref::().unwrap(); - - // convert to i64 - let expected: Int64Array = expected - .iter() - .map(|e| e.map(|e| e as i64)) - .collect::>() - .into(); - assert_eq!(expected.data(), result.data()); - Ok(()) - }) - } - - /// Tests that bit_length is not valid for u64. 
- #[test] - fn wrong_type() { - let array: UInt64Array = vec![1u64].into(); - - assert!(bit_length(&array).is_err()); - } - - /// Tests with an offset - #[test] - fn offsets() -> Result<()> { - let a = StringArray::from(vec!["hello", " ", "world"]); - let b = make_array( - ArrayData::builder(DataType::Utf8) - .len(2) - .offset(1) - .buffers(a.data_ref().buffers().to_vec()) - .build(), - ); - let result = bit_length(b.as_ref())?; - - let expected = Int32Array::from(vec![8, 40]); - assert_eq!(expected.data(), result.data()); - - Ok(()) - } -} diff --git a/rust/arrow/src/compute/kernels/length.rs b/rust/arrow/src/compute/kernels/length.rs index f285ef34c08..1ee6ba45923 100644 --- a/rust/arrow/src/compute/kernels/length.rs +++ b/rust/arrow/src/compute/kernels/length.rs @@ -25,44 +25,52 @@ use crate::{ use std::sync::Arc; #[allow(clippy::unnecessary_wraps)] -fn length_string(array: &Array, data_type: DataType) -> Result -where - OffsetSize: OffsetSizeTrait, -{ - // note: offsets are stored as u8, but they can be interpreted as OffsetSize - let offsets = &array.data_ref().buffers()[0]; - // this is a 30% improvement over iterating over u8s and building OffsetSize, which - // justifies the usage of `unsafe`. - let slice: &[OffsetSize] = - &unsafe { offsets.typed_data::() }[array.offset()..]; - - let lengths = slice.windows(2).map(|offset| offset[1] - offset[0]); - - // JUSTIFICATION - // Benefit - // ~60% speedup - // Soundness - // `values` is an iterator with a known size. - let buffer = unsafe { Buffer::from_trusted_len_iter(lengths) }; - - let null_bit_buffer = array - .data_ref() - .null_bitmap() - .as_ref() - .map(|b| b.bits.clone()); - - let data = ArrayData::new( - data_type, - array.len(), - None, - null_bit_buffer, - 0, - vec![buffer], - vec![], - ); - Ok(make_array(Arc::new(data))) +macro_rules! 
length_functions { + ($FUNC:ident, $EXPR:expr) => { + fn $FUNC(array: &Array, data_type: DataType) -> Result + where + OffsetSize: OffsetSizeTrait, + { + // note: offsets are stored as u8, but they can be interpreted as OffsetSize + let offsets = &array.data_ref().buffers()[0]; + // this is a 30% improvement over iterating over u8s and building OffsetSize, which + // justifies the usage of `unsafe`. + let slice: &[OffsetSize] = + &unsafe { offsets.typed_data::() }[array.offset()..]; + + let lengths = slice.windows(2).map($EXPR); + + // JUSTIFICATION + // Benefit + // ~60% speedup + // Soundness + // `values` is an iterator with a known size. + let buffer = unsafe { Buffer::from_trusted_len_iter(lengths) }; + + let null_bit_buffer = array + .data_ref() + .null_bitmap() + .as_ref() + .map(|b| b.bits.clone()); + + let data = ArrayData::new( + data_type, + array.len(), + None, + null_bit_buffer, + 0, + vec![buffer], + vec![], + ); + Ok(make_array(Arc::new(data))) + } + }; } +length_functions!(octet_length_impl, |offset| offset[1] - offset[0]); +length_functions!(bit_length_impl, |offset| (offset[1] - offset[0]) + * OffsetSize::from_usize(8).unwrap()); + /// Returns an array of Int32/Int64 denoting the number of bytes in each string in the array. /// /// * this only accepts StringArray/Utf8 and LargeString/LargeUtf8 @@ -70,8 +78,8 @@ where /// * length is in number of bytes pub fn length(array: &Array) -> Result { match array.data_type() { - DataType::Utf8 => length_string::(array, DataType::Int32), - DataType::LargeUtf8 => length_string::(array, DataType::Int64), + DataType::Utf8 => octet_length_impl::(array, DataType::Int32), + DataType::LargeUtf8 => octet_length_impl::(array, DataType::Int64), _ => Err(ArrowError::ComputeError(format!( "length not supported for {:?}", array.data_type() @@ -79,11 +87,27 @@ pub fn length(array: &Array) -> Result { } } +/// Returns an array of Int32/Int64 denoting the number of bits in each string in the array. 
+/// +/// * this only accepts StringArray/Utf8 and LargeString/LargeUtf8 +/// * bit_length of null is null. +/// * bit_length is in number of bits +pub fn bit_length(array: &Array) -> Result { + match array.data_type() { + DataType::Utf8 => bit_length_impl::(array, DataType::Int32), + DataType::LargeUtf8 => bit_length_impl::(array, DataType::Int64), + _ => Err(ArrowError::ComputeError(format!( + "bit_length not supported for {:?}", + array.data_type() + ))), + } +} + #[cfg(test)] mod tests { use super::*; - fn cases() -> Vec<(Vec<&'static str>, usize, Vec)> { + fn length_cases() -> Vec<(Vec<&'static str>, usize, Vec)> { fn double_vec(v: Vec) -> Vec { [&v[..], &v[..]].concat() } @@ -105,34 +129,38 @@ mod tests { } #[test] - fn test_string() -> Result<()> { - cases().into_iter().try_for_each(|(input, len, expected)| { - let array = StringArray::from(input); - let result = length(&array)?; - assert_eq!(len, result.len()); - let result = result.as_any().downcast_ref::().unwrap(); - expected.iter().enumerate().for_each(|(i, value)| { - assert_eq!(*value, result.value(i)); - }); - Ok(()) - }) + fn length_test_string() -> Result<()> { + length_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = StringArray::from(input); + let result = length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + expected.iter().enumerate().for_each(|(i, value)| { + assert_eq!(*value, result.value(i)); + }); + Ok(()) + }) } #[test] - fn test_large_string() -> Result<()> { - cases().into_iter().try_for_each(|(input, len, expected)| { - let array = LargeStringArray::from(input); - let result = length(&array)?; - assert_eq!(len, result.len()); - let result = result.as_any().downcast_ref::().unwrap(); - expected.iter().enumerate().for_each(|(i, value)| { - assert_eq!(*value as i64, result.value(i)); - }); - Ok(()) - }) - } - - fn null_cases() -> Vec<(Vec>, usize, Vec>)> { + fn length_test_large_string() -> Result<()> { 
+ length_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = LargeStringArray::from(input); + let result = length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + expected.iter().enumerate().for_each(|(i, value)| { + assert_eq!(*value as i64, result.value(i)); + }); + Ok(()) + }) + } + + fn length_null_cases() -> Vec<(Vec>, usize, Vec>)> { vec![( vec![Some("one"), None, Some("three"), Some("four")], 4, @@ -141,8 +169,8 @@ mod tests { } #[test] - fn null_string() -> Result<()> { - null_cases() + fn length_null_string() -> Result<()> { + length_null_cases() .into_iter() .try_for_each(|(input, len, expected)| { let array = StringArray::from(input); @@ -157,8 +185,8 @@ mod tests { } #[test] - fn null_large_string() -> Result<()> { - null_cases() + fn length_null_large_string() -> Result<()> { + length_null_cases() .into_iter() .try_for_each(|(input, len, expected)| { let array = LargeStringArray::from(input); @@ -179,7 +207,7 @@ mod tests { /// Tests that length is not valid for u64. 
#[test] - fn wrong_type() { + fn length_wrong_type() { let array: UInt64Array = vec![1u64].into(); assert!(length(&array).is_err()); @@ -187,7 +215,7 @@ mod tests { /// Tests with an offset #[test] - fn offsets() -> Result<()> { + fn length_offsets() -> Result<()> { let a = StringArray::from(vec!["hello", " ", "world"]); let b = make_array( ArrayData::builder(DataType::Utf8) @@ -203,4 +231,130 @@ mod tests { Ok(()) } + + fn bit_length_cases() -> Vec<(Vec<&'static str>, usize, Vec)> { + fn double_vec(v: Vec) -> Vec { + [&v[..], &v[..]].concat() + } + + // a large array + let mut values = vec!["one", "on", "o", ""]; + let mut expected = vec![24, 16, 8, 0]; + for _ in 0..10 { + values = double_vec(values); + expected = double_vec(expected); + } + + vec![ + (vec!["hello", " ", "world", "!"], 4, vec![40, 8, 40, 8]), + (vec!["💖"], 1, vec![32]), + (vec!["josé"], 1, vec![40]), + (values, 4096, expected), + ] + } + + #[test] + fn bit_length_test_string() -> Result<()> { + bit_length_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = StringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + expected.iter().enumerate().for_each(|(i, value)| { + assert_eq!(*value, result.value(i)); + }); + Ok(()) + }) + } + + #[test] + fn bit_length_test_large_string() -> Result<()> { + bit_length_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = LargeStringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + expected.iter().enumerate().for_each(|(i, value)| { + assert_eq!(*value as i64, result.value(i)); + }); + Ok(()) + }) + } + + fn bit_length_null_cases() -> Vec<(Vec>, usize, Vec>)> + { + vec![( + vec![Some("one"), None, Some("three"), Some("four")], + 4, + vec![Some(24), None, Some(40), Some(32)], + )] + } + + #[test] + fn 
bit_length_null_string() -> Result<()> { + bit_length_null_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = StringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + + let expected: Int32Array = expected.into(); + assert_eq!(expected.data(), result.data()); + Ok(()) + }) + } + + #[test] + fn bit_length_null_large_string() -> Result<()> { + bit_length_null_cases() + .into_iter() + .try_for_each(|(input, len, expected)| { + let array = LargeStringArray::from(input); + let result = bit_length(&array)?; + assert_eq!(len, result.len()); + let result = result.as_any().downcast_ref::().unwrap(); + + // convert to i64 + let expected: Int64Array = expected + .iter() + .map(|e| e.map(|e| e as i64)) + .collect::>() + .into(); + assert_eq!(expected.data(), result.data()); + Ok(()) + }) + } + + /// Tests that bit_length is not valid for u64. + #[test] + fn bit_length_wrong_type() { + let array: UInt64Array = vec![1u64].into(); + + assert!(bit_length(&array).is_err()); + } + + /// Tests with an offset + #[test] + fn bit_length_offsets() -> Result<()> { + let a = StringArray::from(vec!["hello", " ", "world"]); + let b = make_array( + ArrayData::builder(DataType::Utf8) + .len(2) + .offset(1) + .buffers(a.data_ref().buffers().to_vec()) + .build(), + ); + let result = bit_length(b.as_ref())?; + + let expected = Int32Array::from(vec![8, 40]); + assert_eq!(expected.data(), result.data()); + + Ok(()) + } } diff --git a/rust/arrow/src/compute/kernels/mod.rs b/rust/arrow/src/compute/kernels/mod.rs index ef2f7736f4d..a8d24979e04 100644 --- a/rust/arrow/src/compute/kernels/mod.rs +++ b/rust/arrow/src/compute/kernels/mod.rs @@ -20,7 +20,6 @@ pub mod aggregate; pub mod arithmetic; pub mod arity; -pub mod bit_length; pub mod boolean; pub mod cast; pub mod cast_utils; diff --git a/rust/datafusion/src/physical_plan/functions.rs 
b/rust/datafusion/src/physical_plan/functions.rs index 2513889c6d2..baacf949270 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -45,10 +45,9 @@ use crate::{ }; use arrow::{ array::ArrayRef, - compute::kernels::bit_length::bit_length, - compute::kernels::length::length, + compute::kernels::length::{bit_length, length}, datatypes::TimeUnit, - datatypes::{DataType, Field, Schema}, + datatypes::{DataType, Field, Int32Type, Int64Type, Schema}, record_batch::RecordBatch, }; use fmt::{Debug, Formatter}; @@ -441,12 +440,12 @@ pub fn create_physical_expr( }, }, BuiltinScalarFunction::CharacterLength => |args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::character_length_i32)(args) - } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::character_length_i64)(args) - } + DataType::Utf8 => make_scalar_function( + string_expressions::character_length::, + )(args), + DataType::LargeUtf8 => make_scalar_function( + string_expressions::character_length::, + )(args), other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function character_length", other, @@ -685,8 +684,6 @@ mod tests { /// $ARRAY_TYPE is the column type after function applied macro_rules! 
test_function { ($FUNC:ident, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $DATA_TYPE: ident, $ARRAY_TYPE:ident) => { - println!("{:?}", BuiltinScalarFunction::$FUNC); - // used to provide type annotation let expected: Result> = $EXPECTED; diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index 0d761ac3073..81d2c67eec6 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -25,10 +25,10 @@ use crate::{ }; use arrow::{ array::{ - Array, ArrayRef, GenericStringArray, Int32Array, Int64Array, StringArray, + Array, ArrayRef, GenericStringArray, PrimitiveArray, StringArray, StringOffsetSizeTrait, }, - datatypes::DataType, + datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, }; use unicode_segmentation::UnicodeSegmentation; @@ -120,35 +120,22 @@ where } /// Returns number of characters in the string. -/// character_length_i32('josé') = 4 -pub fn character_length_i32(args: &[ArrayRef]) -> Result { - let string_array: &GenericStringArray = args[0] - .as_any() - .downcast_ref::>() - .unwrap(); - - // first map is the iterator, second is for the `Option<_>` - let result = string_array - .iter() - .map(|x| x.map(|x: &str| x.graphemes(true).count() as i32)) - .collect::(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Returns number of characters in the string. 
-/// character_length_i64('josé') = 4 -pub fn character_length_i64(args: &[ArrayRef]) -> Result { - let string_array: &GenericStringArray = args[0] +/// character_length('josé') = 4 +pub fn character_length(args: &[ArrayRef]) -> Result +where + T::Native: StringOffsetSizeTrait, +{ + let string_array: &GenericStringArray = args[0] .as_any() - .downcast_ref::>() + .downcast_ref::>() .unwrap(); - // first map is the iterator, second is for the `Option<_>` let result = string_array .iter() - .map(|x| x.map(|x: &str| x.graphemes(true).count() as i64)) - .collect::(); + .map(|x| { + x.map(|x: &str| T::Native::from_usize(x.graphemes(true).count()).unwrap()) + }) + .collect::>(); Ok(Arc::new(result) as ArrayRef) } From 23b1329ba749e49d5e98a45b7973512e53867f04 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Wed, 17 Feb 2021 17:56:22 +0100 Subject: [PATCH 3/6] Added some safety. --- rust/arrow/src/compute/kernels/length.rs | 128 ++++++++++++++--------- 1 file changed, 78 insertions(+), 50 deletions(-) diff --git a/rust/arrow/src/compute/kernels/length.rs b/rust/arrow/src/compute/kernels/length.rs index 1ee6ba45923..94e8a965ac8 100644 --- a/rust/arrow/src/compute/kernels/length.rs +++ b/rust/arrow/src/compute/kernels/length.rs @@ -17,59 +17,87 @@ //! Defines kernel for length of a string array -use crate::{array::*, buffer::Buffer}; use crate::{ - datatypes::DataType, + array::*, + buffer::Buffer, + datatypes::{ArrowNativeType, ArrowPrimitiveType}, +}; +use crate::{ + datatypes::{DataType, Int32Type, Int64Type}, error::{ArrowError, Result}, }; use std::sync::Arc; -#[allow(clippy::unnecessary_wraps)] -macro_rules! 
length_functions { - ($FUNC:ident, $EXPR:expr) => { - fn $FUNC(array: &Array, data_type: DataType) -> Result - where - OffsetSize: OffsetSizeTrait, - { - // note: offsets are stored as u8, but they can be interpreted as OffsetSize - let offsets = &array.data_ref().buffers()[0]; - // this is a 30% improvement over iterating over u8s and building OffsetSize, which - // justifies the usage of `unsafe`. - let slice: &[OffsetSize] = - &unsafe { offsets.typed_data::() }[array.offset()..]; - - let lengths = slice.windows(2).map($EXPR); - - // JUSTIFICATION - // Benefit - // ~60% speedup - // Soundness - // `values` is an iterator with a known size. - let buffer = unsafe { Buffer::from_trusted_len_iter(lengths) }; - - let null_bit_buffer = array - .data_ref() - .null_bitmap() - .as_ref() - .map(|b| b.bits.clone()); - - let data = ArrayData::new( - data_type, - array.len(), - None, - null_bit_buffer, - 0, - vec![buffer], - vec![], - ); - Ok(make_array(Arc::new(data))) - } - }; +fn unary_offsets_string( + array: &GenericStringArray, + data_type: DataType, + op: F, +) -> ArrayRef +where + O: StringOffsetSizeTrait + ArrowNativeType, + F: Fn(O) -> O, +{ + // note: offsets are stored as u8, but they can be interpreted as OffsetSize + let offsets = &array.data_ref().buffers()[0]; + // this is a 30% improvement over iterating over u8s and building OffsetSize, which + // justifies the usage of `unsafe`. + let slice: &[O] = &unsafe { offsets.typed_data::() }[array.offset()..]; + + let lengths = slice.windows(2).map(|offset| op(offset[1] - offset[0])); + + // JUSTIFICATION + // Benefit + // ~60% speedup + // Soundness + // `values` is an iterator with a known size. 
+ let buffer = unsafe { Buffer::from_trusted_len_iter(lengths) }; + + let null_bit_buffer = array + .data_ref() + .null_bitmap() + .as_ref() + .map(|b| b.bits.clone()); + + let data = ArrayData::new( + data_type, + array.len(), + None, + null_bit_buffer, + 0, + vec![buffer], + vec![], + ); + make_array(Arc::new(data)) } -length_functions!(octet_length_impl, |offset| offset[1] - offset[0]); -length_functions!(bit_length_impl, |offset| (offset[1] - offset[0]) - * OffsetSize::from_usize(8).unwrap()); +fn octet_length( + array: &dyn Array, +) -> Result +where + T::Native: StringOffsetSizeTrait, +{ + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary_offsets_string::(array, T::DATA_TYPE, |x| x)) +} + +fn bit_length_impl( + array: &dyn Array, +) -> Result +where + T::Native: StringOffsetSizeTrait, +{ + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let bits_in_bytes = O::from_usize(8).unwrap(); + Ok(unary_offsets_string::(array, T::DATA_TYPE, |x| { + x * bits_in_bytes + })) +} /// Returns an array of Int32/Int64 denoting the number of bytes in each string in the array. 
/// @@ -78,8 +106,8 @@ length_functions!(bit_length_impl, |offset| (offset[1] - offset[0]) /// * length is in number of bytes pub fn length(array: &Array) -> Result { match array.data_type() { - DataType::Utf8 => octet_length_impl::(array, DataType::Int32), - DataType::LargeUtf8 => octet_length_impl::(array, DataType::Int64), + DataType::Utf8 => octet_length::(array), + DataType::LargeUtf8 => octet_length::(array), _ => Err(ArrowError::ComputeError(format!( "length not supported for {:?}", array.data_type() @@ -94,8 +122,8 @@ pub fn length(array: &Array) -> Result { /// * bit_length is in number of bits pub fn bit_length(array: &Array) -> Result { match array.data_type() { - DataType::Utf8 => bit_length_impl::(array, DataType::Int32), - DataType::LargeUtf8 => bit_length_impl::(array, DataType::Int64), + DataType::Utf8 => bit_length_impl::(array), + DataType::LargeUtf8 => bit_length_impl::(array), _ => Err(ArrowError::ComputeError(format!( "bit_length not supported for {:?}", array.data_type() From 619f2f630857d4635dc986f00b48a9e8af03f2be Mon Sep 17 00:00:00 2001 From: Mike Seddon Date: Thu, 18 Feb 2021 07:45:15 +1100 Subject: [PATCH 4/6] clippy --- rust/arrow/src/compute/kernels/length.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rust/arrow/src/compute/kernels/length.rs b/rust/arrow/src/compute/kernels/length.rs index 94e8a965ac8..a5c32646782 100644 --- a/rust/arrow/src/compute/kernels/length.rs +++ b/rust/arrow/src/compute/kernels/length.rs @@ -72,7 +72,7 @@ where fn octet_length( array: &dyn Array, -) -> Result +) -> ArrayRef where T::Native: StringOffsetSizeTrait, { @@ -80,7 +80,7 @@ where .as_any() .downcast_ref::>() .unwrap(); - Ok(unary_offsets_string::(array, T::DATA_TYPE, |x| x)) + unary_offsets_string::(array, T::DATA_TYPE, |x| x) } fn bit_length_impl( @@ -106,8 +106,8 @@ where /// * length is in number of bytes pub fn length(array: &Array) -> Result { match array.data_type() { - DataType::Utf8 => octet_length::(array), - 
DataType::LargeUtf8 => octet_length::(array), + DataType::Utf8 => Ok(octet_length::(array)), + DataType::LargeUtf8 => Ok(octet_length::(array)), _ => Err(ArrowError::ComputeError(format!( "length not supported for {:?}", array.data_type() From 76855f25ecc0f797a9a06b9a773d08e62ad983f7 Mon Sep 17 00:00:00 2001 From: Mike Seddon Date: Sun, 21 Feb 2021 09:43:34 +1100 Subject: [PATCH 5/6] clippy --- rust/arrow/src/compute/kernels/length.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/rust/arrow/src/compute/kernels/length.rs b/rust/arrow/src/compute/kernels/length.rs index a5c32646782..350c8a8f900 100644 --- a/rust/arrow/src/compute/kernels/length.rs +++ b/rust/arrow/src/compute/kernels/length.rs @@ -85,7 +85,7 @@ where fn bit_length_impl( array: &dyn Array, -) -> Result +) -> ArrayRef where T::Native: StringOffsetSizeTrait, { @@ -94,9 +94,9 @@ where .downcast_ref::>() .unwrap(); let bits_in_bytes = O::from_usize(8).unwrap(); - Ok(unary_offsets_string::(array, T::DATA_TYPE, |x| { + unary_offsets_string::(array, T::DATA_TYPE, |x| { x * bits_in_bytes - })) + }) } /// Returns an array of Int32/Int64 denoting the number of bytes in each string in the array. 
@@ -122,8 +122,8 @@ pub fn length(array: &Array) -> Result { /// * bit_length is in number of bits pub fn bit_length(array: &Array) -> Result { match array.data_type() { - DataType::Utf8 => bit_length_impl::(array), - DataType::LargeUtf8 => bit_length_impl::(array), + DataType::Utf8 => Ok(bit_length_impl::(array)), + DataType::LargeUtf8 => Ok(bit_length_impl::(array)), _ => Err(ArrowError::ComputeError(format!( "bit_length not supported for {:?}", array.data_type() From 17740e9e46815e402a1b710502222ac8eb7d4aeb Mon Sep 17 00:00:00 2001 From: Mike Seddon Date: Sun, 21 Feb 2021 10:10:37 +1100 Subject: [PATCH 6/6] fmt --- rust/arrow/src/compute/kernels/length.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/rust/arrow/src/compute/kernels/length.rs b/rust/arrow/src/compute/kernels/length.rs index 350c8a8f900..ed1fda4a062 100644 --- a/rust/arrow/src/compute/kernels/length.rs +++ b/rust/arrow/src/compute/kernels/length.rs @@ -94,9 +94,7 @@ where .downcast_ref::>() .unwrap(); let bits_in_bytes = O::from_usize(8).unwrap(); - unary_offsets_string::(array, T::DATA_TYPE, |x| { - x * bits_in_bytes - }) + unary_offsets_string::(array, T::DATA_TYPE, |x| x * bits_in_bytes) } /// Returns an array of Int32/Int64 denoting the number of bytes in each string in the array.