From e551bb1fe37f701943970616d216a33da087daef Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sat, 30 Jan 2021 20:23:22 +0100 Subject: [PATCH 1/8] Added support for scalar in Builtin functions. --- rust/datafusion/src/execution/context.rs | 33 ++- .../src/execution/dataframe_impl.rs | 6 +- .../src/physical_plan/crypto_expressions.rs | 185 +++++++++++---- .../datafusion/src/physical_plan/functions.rs | 134 ++++------- .../src/physical_plan/math_expressions.rs | 44 +++- .../src/physical_plan/string_expressions.rs | 224 +++++++++++++----- rust/datafusion/src/scalar.rs | 54 +++++ 7 files changed, 462 insertions(+), 218 deletions(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 2977d9816ca..976592ab6a6 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -619,8 +619,8 @@ impl FunctionRegistry for ExecutionContextState { mod tests { use super::*; - use crate::physical_plan::functions::ScalarFunctionImplementation; use crate::physical_plan::{collect, collect_partitioned}; + use crate::physical_plan::{functions::ScalarFunctionImplementation, ColumnarValue}; use crate::test; use crate::variable::VarType; use crate::{ @@ -631,7 +631,7 @@ mod tests { datasource::MemTable, logical_plan::create_udaf, physical_plan::expressions::AvgAccumulator, }; - use arrow::array::{ArrayRef, Float64Array, Int32Array}; + use arrow::array::{Float64Array, Int32Array}; use arrow::compute::add; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; @@ -1618,17 +1618,24 @@ mod tests { let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?; ctx.register_table("t", Box::new(provider)); - let myfunc: ScalarFunctionImplementation = Arc::new(|args: &[ArrayRef]| { - let l = &args[0] - .as_any() - .downcast_ref::() - .expect("cast failed"); - let r = &args[1] - .as_any() - .downcast_ref::() - .expect("cast failed"); - Ok(Arc::new(add(l, r)?)) - }); + let myfunc: 
ScalarFunctionImplementation = + Arc::new(|args: &[ColumnarValue]| { + if let (ColumnarValue::Array(l), ColumnarValue::Array(r)) = + (&args[0], &args[1]) + { + let l = l + .as_any() + .downcast_ref::() + .expect("cast failed"); + let r = r + .as_any() + .downcast_ref::() + .expect("cast failed"); + Ok(ColumnarValue::Array(Arc::new(add(l, r)?))) + } else { + unimplemented!() + } + }); ctx.register_udf(create_udf( "my_add", diff --git a/rust/datafusion/src/execution/dataframe_impl.rs b/rust/datafusion/src/execution/dataframe_impl.rs index 5a4270efa69..c9a1ff9dd26 100644 --- a/rust/datafusion/src/execution/dataframe_impl.rs +++ b/rust/datafusion/src/execution/dataframe_impl.rs @@ -158,11 +158,11 @@ impl DataFrame for DataFrameImpl { #[cfg(test)] mod tests { use super::*; - use crate::datasource::csv::CsvReadOptions; use crate::execution::context::ExecutionContext; use crate::logical_plan::*; + use crate::{datasource::csv::CsvReadOptions, physical_plan::ColumnarValue}; use crate::{physical_plan::functions::ScalarFunctionImplementation, test}; - use arrow::{array::ArrayRef, datatypes::DataType}; + use arrow::datatypes::DataType; #[test] fn select_columns() -> Result<()> { @@ -287,7 +287,7 @@ mod tests { // declare the udf let my_fn: ScalarFunctionImplementation = - Arc::new(|_: &[ArrayRef]| unimplemented!("my_fn is not implemented")); + Arc::new(|_: &[ColumnarValue]| unimplemented!("my_fn is not implemented")); // create and register the udf ctx.register_udf(create_udf( diff --git a/rust/datafusion/src/physical_plan/crypto_expressions.rs b/rust/datafusion/src/physical_plan/crypto_expressions.rs index 6a0940d4503..134a098c7d9 100644 --- a/rust/datafusion/src/physical_plan/crypto_expressions.rs +++ b/rust/datafusion/src/physical_plan/crypto_expressions.rs @@ -17,17 +17,26 @@ //! 
Crypto expressions +use std::sync::Arc; + use md5::Md5; use sha2::{ digest::Output as SHA2DigestOutput, Digest as SHA2Digest, Sha224, Sha256, Sha384, Sha512, }; -use crate::error::{DataFusionError, Result}; -use arrow::array::{ - ArrayRef, GenericBinaryArray, GenericStringArray, StringOffsetSizeTrait, +use crate::{ + error::{DataFusionError, Result}, + scalar::ScalarValue, +}; +use arrow::{ + array::{Array, GenericBinaryArray, GenericStringArray, StringOffsetSizeTrait}, + datatypes::DataType, }; +use super::{string_expressions::unary_string_function, ColumnarValue}; + +/// Computes the md5 of a string. fn md5_process(input: &str) -> String { let mut digest = Md5::default(); digest.update(&input); @@ -49,58 +58,136 @@ fn sha_process(input: &str) -> SHA2DigestOutput { digest.finalize() } -macro_rules! crypto_unary_string_function { - ($NAME:ident, $FUNC:expr) => { - /// crypto function that accepts Utf8 or LargeUtf8 and returns Utf8 string - pub fn $NAME( - args: &[ArrayRef], - ) -> Result> { - if args.len() != 1 { - return Err(DataFusionError::Internal(format!( - "{:?} args were supplied but {} takes exactly one argument", - args.len(), - String::from(stringify!($NAME)), - ))); - } +fn unary_binary_function( + args: &[&dyn Array], + op: F, + name: &str, +) -> Result> +where + R: AsRef<[u8]>, + T: StringOffsetSizeTrait, + F: Fn(&str) -> R, +{ + if args.len() != 1 { + return Err(DataFusionError::Internal(format!( + "{:?} args were supplied but {} takes exactly one argument", + args.len(), + name, + ))); + } + + let array = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); - let array = args[0] - .as_any() - .downcast_ref::>() - .unwrap(); + // first map is the iterator, second is for the `Option<_>` + Ok(array.iter().map(|x| x.map(|x| op(x))).collect()) +} + +fn handle(args: &[ColumnarValue], op: F, name: &str) -> Result +where + R: AsRef<[u8]>, + F: Fn(&str) -> R, +{ + match &args[0] { + ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8 => { + 
Ok(ColumnarValue::Array(Arc::new(unary_binary_function::< + i32, + _, + _, + >( + &[a.as_ref()], op, name + )?))) + } + DataType::LargeUtf8 => { + Ok(ColumnarValue::Array(Arc::new(unary_binary_function::< + i64, + _, + _, + >( + &[a.as_ref()], op, name + )?))) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function md5", + other, + ))), + }, + ColumnarValue::Scalar(scalar) => match scalar { + ScalarValue::Utf8(a) => { + let result = a.as_ref().map(|x| (op)(x).as_ref().to_vec()); + Ok(ColumnarValue::Scalar(ScalarValue::Binary(result))) + } + ScalarValue::LargeUtf8(a) => { + let result = a.as_ref().map(|x| (op)(x).as_ref().to_vec()); + Ok(ColumnarValue::Scalar(ScalarValue::Binary(result))) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function md5", + other, + ))), + }, + } +} - // first map is the iterator, second is for the `Option<_>` - Ok(array.iter().map(|x| x.map(|x| $FUNC(x))).collect()) - } - }; +fn md5_array( + args: &[&dyn Array], +) -> Result> { + unary_string_function::(args, md5_process, "md5") } -macro_rules! 
crypto_unary_binary_function { - ($NAME:ident, $FUNC:expr) => { - /// crypto function that accepts Utf8 or LargeUtf8 and returns Binary - pub fn $NAME( - args: &[ArrayRef], - ) -> Result> { - if args.len() != 1 { - return Err(DataFusionError::Internal(format!( - "{:?} args were supplied but {} takes exactly one argument", - args.len(), - String::from(stringify!($NAME)), - ))); +/// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`] +pub fn md5(args: &[ColumnarValue]) -> Result { + match &args[0] { + ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8 => Ok(ColumnarValue::Array(Arc::new(md5_array::(&[ + a.as_ref() + ])?))), + DataType::LargeUtf8 => { + Ok(ColumnarValue::Array(Arc::new(md5_array::(&[ + a.as_ref() + ])?))) } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function md5", + other, + ))), + }, + ColumnarValue::Scalar(scalar) => match scalar { + ScalarValue::Utf8(a) => { + let result = a.as_ref().map(|x| md5_process(x)); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + } + ScalarValue::LargeUtf8(a) => { + let result = a.as_ref().map(|x| md5_process(x)); + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function md5", + other, + ))), + }, + } +} - let array = args[0] - .as_any() - .downcast_ref::>() - .unwrap(); +/// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`] +pub fn sha224(args: &[ColumnarValue]) -> Result { + handle(args, sha_process::, "ssh224") +} - // first map is the iterator, second is for the `Option<_>` - Ok(array.iter().map(|x| x.map(|x| $FUNC(x))).collect()) - } - }; +/// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`] +pub fn sha256(args: &[ColumnarValue]) -> Result { + handle(args, sha_process::, "sha256") } -crypto_unary_string_function!(md5, md5_process); -crypto_unary_binary_function!(sha224, 
sha_process::); -crypto_unary_binary_function!(sha256, sha_process::); -crypto_unary_binary_function!(sha384, sha_process::); -crypto_unary_binary_function!(sha512, sha_process::); +/// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`] +pub fn sha384(args: &[ColumnarValue]) -> Result { + handle(args, sha_process::, "sha384") +} + +/// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`] +pub fn sha512(args: &[ColumnarValue]) -> Result { + handle(args, sha_process::, "sha512") +} diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index ca597c9e6ae..0b5105502cf 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -33,15 +33,17 @@ use super::{ type_coercion::{coerce, data_types}, ColumnarValue, PhysicalExpr, }; -use crate::error::{DataFusionError, Result}; use crate::physical_plan::array_expressions; use crate::physical_plan::crypto_expressions; use crate::physical_plan::datetime_expressions; use crate::physical_plan::expressions::{nullif_func, SUPPORTED_NULLIF_TYPES}; use crate::physical_plan::math_expressions; use crate::physical_plan::string_expressions; +use crate::{ + error::{DataFusionError, Result}, + scalar::ScalarValue, +}; use arrow::{ - array::ArrayRef, compute::kernels::length::length, datatypes::TimeUnit, datatypes::{DataType, Field, Schema}, @@ -72,7 +74,7 @@ pub enum Signature { /// Scalar function pub type ScalarFunctionImplementation = - Arc Result + Send + Sync>; + Arc Result + Send + Sync>; /// A function's return type pub type ReturnTypeFunction = @@ -383,98 +385,52 @@ pub fn create_physical_expr( BuiltinScalarFunction::Trunc => math_expressions::trunc, BuiltinScalarFunction::Abs => math_expressions::abs, BuiltinScalarFunction::Signum => math_expressions::signum, - BuiltinScalarFunction::NullIf => nullif_func, - BuiltinScalarFunction::MD5 => |args| match args[0].data_type() { 
- DataType::Utf8 => Ok(Arc::new(crypto_expressions::md5::(args)?)), - DataType::LargeUtf8 => Ok(Arc::new(crypto_expressions::md5::(args)?)), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function md5", - other, - ))), - }, - BuiltinScalarFunction::SHA224 => |args| match args[0].data_type() { - DataType::Utf8 => Ok(Arc::new(crypto_expressions::sha224::(args)?)), - DataType::LargeUtf8 => Ok(Arc::new(crypto_expressions::sha224::(args)?)), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function sha224", - other, - ))), - }, - BuiltinScalarFunction::SHA256 => |args| match args[0].data_type() { - DataType::Utf8 => Ok(Arc::new(crypto_expressions::sha256::(args)?)), - DataType::LargeUtf8 => Ok(Arc::new(crypto_expressions::sha256::(args)?)), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function sha256", - other, - ))), - }, - BuiltinScalarFunction::SHA384 => |args| match args[0].data_type() { - DataType::Utf8 => Ok(Arc::new(crypto_expressions::sha384::(args)?)), - DataType::LargeUtf8 => Ok(Arc::new(crypto_expressions::sha384::(args)?)), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function sha384", - other, - ))), - }, - BuiltinScalarFunction::SHA512 => |args| match args[0].data_type() { - DataType::Utf8 => Ok(Arc::new(crypto_expressions::sha512::(args)?)), - DataType::LargeUtf8 => Ok(Arc::new(crypto_expressions::sha512::(args)?)), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function sha512", - other, - ))), - }, - BuiltinScalarFunction::Length => |args| Ok(length(args[0].as_ref())?), - BuiltinScalarFunction::Concat => { - |args| Ok(Arc::new(string_expressions::concatenate(args)?)) - } - BuiltinScalarFunction::Lower => |args| match args[0].data_type() { - DataType::Utf8 => Ok(Arc::new(string_expressions::lower::(args)?)), - DataType::LargeUtf8 => 
Ok(Arc::new(string_expressions::lower::(args)?)), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function lower", - other, - ))), - }, - BuiltinScalarFunction::Trim => |args| match args[0].data_type() { - DataType::Utf8 => Ok(Arc::new(string_expressions::trim::(args)?)), - DataType::LargeUtf8 => Ok(Arc::new(string_expressions::trim::(args)?)), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function trim", - other, - ))), - }, - BuiltinScalarFunction::Ltrim => |args| match args[0].data_type() { - DataType::Utf8 => Ok(Arc::new(string_expressions::ltrim::(args)?)), - DataType::LargeUtf8 => Ok(Arc::new(string_expressions::ltrim::(args)?)), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function ltrim", - other, - ))), - }, - BuiltinScalarFunction::Rtrim => |args| match args[0].data_type() { - DataType::Utf8 => Ok(Arc::new(string_expressions::rtrim::(args)?)), - DataType::LargeUtf8 => Ok(Arc::new(string_expressions::rtrim::(args)?)), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function rtrim", - other, - ))), + /* + BuiltinScalarFunction::NullIf => |args| match &args[0] { + ColumnarValue::Scalar(v) => match v { + ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( + v.as_ref().map(|x| x.len() as i32), + ))), + ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( + ScalarValue::Int64(v.as_ref().map(|x| x.len() as i64)), + )), + _ => unreachable!(), + }, + ColumnarValue::Array(v) => Ok(ColumnarValue::Array(nullif_func(v.as_ref())?)), }, - BuiltinScalarFunction::Upper => |args| match args[0].data_type() { - DataType::Utf8 => Ok(Arc::new(string_expressions::upper::(args)?)), - DataType::LargeUtf8 => Ok(Arc::new(string_expressions::upper::(args)?)), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function upper", - other, - ))), + */ + BuiltinScalarFunction::MD5 => 
crypto_expressions::md5, + BuiltinScalarFunction::SHA256 => crypto_expressions::sha256, + BuiltinScalarFunction::SHA384 => crypto_expressions::sha384, + BuiltinScalarFunction::SHA512 => crypto_expressions::sha512, + BuiltinScalarFunction::Length => |args| match &args[0] { + ColumnarValue::Scalar(v) => match v { + ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( + v.as_ref().map(|x| x.len() as i32), + ))), + ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( + ScalarValue::Int64(v.as_ref().map(|x| x.len() as i64)), + )), + _ => unreachable!(), + }, + ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), }, + BuiltinScalarFunction::Concat => string_expressions::concatenate, + BuiltinScalarFunction::Lower => string_expressions::lower, + BuiltinScalarFunction::Trim => string_expressions::trim, + BuiltinScalarFunction::Ltrim => string_expressions::ltrim, + BuiltinScalarFunction::Rtrim => string_expressions::rtrim, + BuiltinScalarFunction::Upper => string_expressions::upper, + /* BuiltinScalarFunction::ToTimestamp => { |args| Ok(Arc::new(datetime_expressions::to_timestamp(args)?)) } BuiltinScalarFunction::DateTrunc => { |args| Ok(Arc::new(datetime_expressions::date_trunc(args)?)) } - BuiltinScalarFunction::Array => |args| array_expressions::array(args), + BuiltinScalarFunction::Array => |args| Ok(array_expressions::array(args)?), + */ + _ => todo!(), }); // coerce let args = coerce(args, input_schema, &signature(fun))?; @@ -622,12 +578,12 @@ impl PhysicalExpr for ScalarFunctionExpr { let inputs = self .args .iter() - .map(|e| e.evaluate(batch).map(|v| v.into_array(batch.num_rows()))) + .map(|e| e.evaluate(batch)) .collect::>>()?; // evaluate the function let fun = self.fun.as_ref(); - (fun)(&inputs).map(|a| ColumnarValue::Array(a)) + (fun)(&inputs) } } diff --git a/rust/datafusion/src/physical_plan/math_expressions.rs b/rust/datafusion/src/physical_plan/math_expressions.rs index 9ad0e2540df..772b80d409a 100644 --- 
a/rust/datafusion/src/physical_plan/math_expressions.rs +++ b/rust/datafusion/src/physical_plan/math_expressions.rs @@ -19,10 +19,11 @@ use std::sync::Arc; -use arrow::array::{make_array, Array, ArrayData, ArrayRef, Float32Array, Float64Array}; +use arrow::array::{make_array, Array, ArrayData, Float32Array, Float64Array}; use arrow::buffer::Buffer; use arrow::datatypes::{DataType, ToByteSlice}; +use super::{ColumnarValue, ScalarValue}; use crate::error::{DataFusionError, Result}; macro_rules! compute_op { @@ -58,14 +59,35 @@ macro_rules! downcast_compute_op { } macro_rules! unary_primitive_array_op { - ($ARRAY:expr, $NAME:expr, $FUNC:ident) => {{ - match ($ARRAY).data_type() { - DataType::Float32 => downcast_compute_op!($ARRAY, $NAME, $FUNC, Float32Array), - DataType::Float64 => downcast_compute_op!($ARRAY, $NAME, $FUNC, Float64Array), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function {}", - other, $NAME, - ))), + ($VALUE:expr, $NAME:expr, $FUNC:ident) => {{ + match ($VALUE) { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Float32 => { + let result = downcast_compute_op!(array, $NAME, $FUNC, Float32Array); + Ok(ColumnarValue::Array(result?)) + } + DataType::Float64 => { + let result = downcast_compute_op!(array, $NAME, $FUNC, Float64Array); + Ok(ColumnarValue::Array(result?)) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function {}", + other, $NAME, + ))), + }, + ColumnarValue::Scalar(a) => match a { + ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar( + ScalarValue::Float64(a.map(|x| x.$FUNC() as f64)), + )), + ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar( + ScalarValue::Float64(a.map(|x| x.$FUNC())), + )), + _ => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function {}", + ($VALUE).data_type(), + $NAME, + ))), + }, } }}; } @@ -73,8 +95,8 @@ macro_rules! unary_primitive_array_op { macro_rules! 
math_unary_function { ($NAME:expr, $FUNC:ident) => { /// mathematical function that accepts f32 or f64 and returns f64 - pub fn $FUNC(args: &[ArrayRef]) -> Result { - unary_primitive_array_op!(args[0], $NAME, $FUNC) + pub fn $FUNC(args: &[ColumnarValue]) -> Result { + unary_primitive_array_op!(&args[0], $NAME, $FUNC) } }; } diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index c633aa874f4..c102a629703 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -17,27 +17,100 @@ //! String expressions -use crate::error::{DataFusionError, Result}; -use arrow::array::{ - Array, ArrayRef, GenericStringArray, StringArray, StringBuilder, - StringOffsetSizeTrait, +use std::sync::Arc; + +use crate::{ + error::{DataFusionError, Result}, + scalar::ScalarValue, +}; +use arrow::{ + array::{Array, GenericStringArray, StringArray, StringOffsetSizeTrait}, + datatypes::DataType, }; -macro_rules! 
downcast_vec { - ($ARGS:expr, $ARRAY_TYPE:ident) => {{ - $ARGS - .iter() - .map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() { - Some(array) => Ok(array), - _ => Err(DataFusionError::Internal("failed to downcast".to_string())), - }) - }}; +use super::ColumnarValue; + +pub(crate) fn unary_string_function<'a, T, O, F, R>( + args: &[&'a dyn Array], + op: F, + name: &str, +) -> Result> +where + R: AsRef, + O: StringOffsetSizeTrait, + T: StringOffsetSizeTrait, + F: Fn(&'a str) -> R, +{ + if args.len() != 1 { + return Err(DataFusionError::Internal(format!( + "{:?} args were supplied but {} takes exactly one argument", + args.len(), + name, + ))); + } + + let array = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + // first map is the iterator, second is for the `Option<_>` + Ok(array.iter().map(|x| x.map(|x| op(x))).collect()) +} + +fn handle<'a, F, R>(args: &'a [ColumnarValue], op: F, name: &str) -> Result +where + R: AsRef, + F: Fn(&'a str) -> R, +{ + match &args[0] { + ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8 => { + Ok(ColumnarValue::Array(Arc::new(unary_string_function::< + i32, + i32, + _, + _, + >( + &[a.as_ref()], op, name + )?))) + } + DataType::LargeUtf8 => { + Ok(ColumnarValue::Array(Arc::new(unary_string_function::< + i64, + i64, + _, + _, + >( + &[a.as_ref()], op, name + )?))) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function md5", + other, + ))), + }, + ColumnarValue::Scalar(scalar) => match scalar { + ScalarValue::Utf8(a) => { + let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + } + ScalarValue::LargeUtf8(a) => { + let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function md5", + other, + ))), + }, + } } /// concatenate string columns together. 
-pub fn concatenate(args: &[ArrayRef]) -> Result { +pub fn concatenate(args: &[ColumnarValue]) -> Result { // downcast all arguments to strings - let args = downcast_vec!(args, StringArray).collect::>>()?; + //let args = downcast_vec!(args, StringArray).collect::>>()?; // do not accept 0 arguments. if args.is_empty() { return Err(DataFusionError::Internal( @@ -46,48 +119,93 @@ pub fn concatenate(args: &[ArrayRef]) -> Result { )); } - let mut builder = StringBuilder::new(args.len()); - // for each entry in the array - for index in 0..args[0].len() { - let mut owned_string: String = "".to_owned(); - - // if any is null, the result is null - let mut is_null = false; - for arg in &args { - if arg.is_null(index) { - is_null = true; - break; // short-circuit as we already know the result + // first, decide whether to return a scalar or a vector. + let mut return_array = args.iter().filter_map(|x| match x { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }); + if let Some(size) = return_array.next() { + let iter = (0..size).map(|index| { + let mut owned_string: String = "".to_owned(); + + // if any is null, the result is null + let mut is_null = false; + for arg in args { + match arg { + ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { + if let Some(value) = maybe_value { + owned_string.push_str(value); + } else { + is_null = true; + break; // short-circuit as we already know the result + } + } + ColumnarValue::Array(v) => { + if v.is_null(index) { + is_null = true; + break; // short-circuit as we already know the result + } else { + let v = v.as_any().downcast_ref::().unwrap(); + owned_string.push_str(&v.value(index)); + } + } + _ => unreachable!(), + } + } + if is_null { + None } else { - owned_string.push_str(&arg.value(index)); + Some(owned_string) } - } - if is_null { - builder.append_null()?; - } else { - builder.append_value(&owned_string)?; - } + }); + let array = iter.collect::(); + + Ok(ColumnarValue::Array(Arc::new(array))) + } else 
{ + // short avenue with only scalars + let initial = Some("".to_string()); + let result = args.iter().fold(initial, |mut acc, rhs| { + match acc { + Some(ref mut inner) => { + match rhs { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) => { + inner.push_str(v); + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { + acc = None; + } + _ => unreachable!(""), + }; + } + None => {} + }; + acc + }); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) } - Ok(builder.finish()) } -macro_rules! string_unary_function { - ($NAME:ident, $FUNC:ident) => { - /// string function that accepts Utf8 or LargeUtf8 and returns Utf8 or LargeUtf8 - pub fn $NAME( - args: &[ArrayRef], - ) -> Result> { - let array = args[0] - .as_any() - .downcast_ref::>() - .unwrap(); - // first map is the iterator, second is for the `Option<_>` - Ok(array.iter().map(|x| x.map(|x| x.$FUNC())).collect()) - } - }; +/// lower +pub fn lower(args: &[ColumnarValue]) -> Result { + handle(args, |x| x.to_ascii_lowercase(), "lower") +} + +/// upper +pub fn upper(args: &[ColumnarValue]) -> Result { + handle(args, |x| x.to_ascii_uppercase(), "upper") } -string_unary_function!(lower, to_ascii_lowercase); -string_unary_function!(upper, to_ascii_uppercase); -string_unary_function!(trim, trim); -string_unary_function!(ltrim, trim_start); -string_unary_function!(rtrim, trim_end); +/// trim +pub fn trim(args: &[ColumnarValue]) -> Result { + handle(args, |x: &str| x.trim(), "trim") +} + +/// ltrim +pub fn ltrim(args: &[ColumnarValue]) -> Result { + handle(args, |x| x.trim_start(), "ltrim") +} + +/// rtrim +pub fn rtrim(args: &[ColumnarValue]) -> Result { + handle(args, |x| x.trim_end(), "rtrim") +} diff --git a/rust/datafusion/src/scalar.rs b/rust/datafusion/src/scalar.rs index d358e69fe5d..88efa8897ee 100644 --- a/rust/datafusion/src/scalar.rs +++ b/rust/datafusion/src/scalar.rs @@ -54,6 +54,10 @@ pub enum ScalarValue { Utf8(Option), /// utf-8 encoded string representing a LargeString's arrow type. 
LargeUtf8(Option), + /// binary + Binary(Option>), + /// large binary + LargeBinary(Option>), /// list of nested ScalarValue List(Option>, DataType), /// Date stored as a signed 32bit int @@ -141,6 +145,8 @@ impl ScalarValue { ScalarValue::Float64(_) => DataType::Float64, ScalarValue::Utf8(_) => DataType::Utf8, ScalarValue::LargeUtf8(_) => DataType::LargeUtf8, + ScalarValue::Binary(_) => DataType::Binary, + ScalarValue::LargeBinary(_) => DataType::LargeBinary, ScalarValue::List(_, data_type) => { DataType::List(Box::new(Field::new("item", data_type.clone(), true))) } @@ -293,6 +299,28 @@ impl ScalarValue { } None => new_null_array(&DataType::LargeUtf8, size), }, + ScalarValue::Binary(e) => match e { + Some(value) => Arc::new( + repeat(Some(value.as_slice())) + .take(size) + .collect::(), + ), + None => { + Arc::new(repeat(None::<&str>).take(size).collect::()) + } + }, + ScalarValue::LargeBinary(e) => match e { + Some(value) => Arc::new( + repeat(Some(value.as_slice())) + .take(size) + .collect::(), + ), + None => Arc::new( + repeat(None::<&str>) + .take(size) + .collect::(), + ), + }, ScalarValue::List(values, data_type) => Arc::new(match data_type { DataType::Int8 => build_list!(Int8Builder, Int8, values, size), DataType::Int16 => build_list!(Int16Builder, Int16, values, size), @@ -556,6 +584,28 @@ impl fmt::Display for ScalarValue { ScalarValue::TimeNanosecond(e) => format_option!(f, e)?, ScalarValue::Utf8(e) => format_option!(f, e)?, ScalarValue::LargeUtf8(e) => format_option!(f, e)?, + ScalarValue::Binary(e) => match e { + Some(l) => write!( + f, + "{}", + l.iter() + .map(|v| format!("{}", v)) + .collect::>() + .join(",") + )?, + None => write!(f, "NULL")?, + }, + ScalarValue::LargeBinary(e) => match e { + Some(l) => write!( + f, + "{}", + l.iter() + .map(|v| format!("{}", v)) + .collect::>() + .join(",") + )?, + None => write!(f, "NULL")?, + }, ScalarValue::List(e, _) => match e { Some(l) => write!( f, @@ -596,6 +646,10 @@ impl fmt::Debug for ScalarValue { 
ScalarValue::Utf8(Some(_)) => write!(f, "Utf8(\"{}\")", self), ScalarValue::LargeUtf8(None) => write!(f, "LargeUtf8({})", self), ScalarValue::LargeUtf8(Some(_)) => write!(f, "LargeUtf8(\"{}\")", self), + ScalarValue::Binary(None) => write!(f, "Binary({})", self), + ScalarValue::Binary(Some(_)) => write!(f, "Binary(\"{}\")", self), + ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({})", self), + ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{}\")", self), ScalarValue::List(_, _) => write!(f, "List([{}])", self), ScalarValue::Date32(_) => write!(f, "Date32(\"{}\")", self), ScalarValue::Date64(_) => write!(f, "Date64(\"{}\")", self), From d9f076dca2cb1a9106a871646fe310a5593b2677 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 31 Jan 2021 10:56:46 +0100 Subject: [PATCH 2/8] Migrated remaining functions. --- .../src/physical_plan/array_expressions.rs | 23 +- .../src/physical_plan/datetime_expressions.rs | 566 +++++++++++------- .../src/physical_plan/expressions/binary.rs | 1 + .../src/physical_plan/expressions/mod.rs | 169 +----- .../src/physical_plan/expressions/nullif.rs | 188 ++++++ .../datafusion/src/physical_plan/functions.rs | 67 +-- rust/datafusion/src/scalar.rs | 23 +- 7 files changed, 628 insertions(+), 409 deletions(-) create mode 100644 rust/datafusion/src/physical_plan/expressions/nullif.rs diff --git a/rust/datafusion/src/physical_plan/array_expressions.rs b/rust/datafusion/src/physical_plan/array_expressions.rs index 9af81ad5b8d..a7e03b70e5d 100644 --- a/rust/datafusion/src/physical_plan/array_expressions.rs +++ b/rust/datafusion/src/physical_plan/array_expressions.rs @@ -22,6 +22,8 @@ use arrow::array::*; use arrow::datatypes::DataType; use std::sync::Arc; +use super::ColumnarValue; + macro_rules! downcast_vec { ($ARGS:expr, $ARRAY_TYPE:ident) => {{ $ARGS @@ -58,8 +60,7 @@ macro_rules! array { }}; } -/// put values in an array. 
-pub fn array(args: &[ArrayRef]) -> Result { +fn array_array(args: &[&dyn Array]) -> Result { // do not accept 0 arguments. if args.is_empty() { return Err(DataFusionError::Internal( @@ -88,6 +89,24 @@ pub fn array(args: &[ArrayRef]) -> Result { } } +/// put values in an array. +pub fn array(values: &[ColumnarValue]) -> Result { + let arrays: Vec<&dyn Array> = values + .iter() + .map(|value| { + if let ColumnarValue::Array(value) = value { + Ok(value.as_ref()) + } else { + Err(DataFusionError::NotImplemented( + "Array is not implemented for scalar values.".to_string(), + )) + } + }) + .collect::>()?; + + Ok(ColumnarValue::Array(array_array(&arrays)?)) +} + /// Currently supported types by the array function. /// The order of these types correspond to the order on which coercion applies /// This should thus be from least informative to most informative diff --git a/rust/datafusion/src/physical_plan/datetime_expressions.rs b/rust/datafusion/src/physical_plan/datetime_expressions.rs index 34414586983..b91611f6a40 100644 --- a/rust/datafusion/src/physical_plan/datetime_expressions.rs +++ b/rust/datafusion/src/physical_plan/datetime_expressions.rs @@ -19,163 +19,323 @@ use std::sync::Arc; -use crate::error::{DataFusionError, Result}; +use crate::{ + error::{DataFusionError, Result}, + scalar::{ScalarType, ScalarValue}, +}; +use arrow::temporal_conversions::timestamp_ns_to_datetime; use arrow::{ - array::{Array, ArrayData, ArrayRef, StringArray, TimestampNanosecondArray}, - buffer::Buffer, - compute::kernels::cast_utils::string_to_timestamp_nanos, - datatypes::{DataType, TimeUnit, ToByteSlice}, + array::{ + Array, GenericStringArray, PrimitiveArray, StringOffsetSizeTrait, + TimestampNanosecondArray, + }, + datatypes::{ArrowPrimitiveType, DataType, TimestampNanosecondType}, }; use chrono::prelude::*; use chrono::Duration; +use chrono::LocalResult; + +use super::ColumnarValue; + +#[inline] +/// Accepts a string in RFC3339 / ISO8601 standard format and some +/// variants 
and converts it to a nanosecond precision timestamp. +/// +/// Implements the `to_timestamp` function to convert a string to a +/// timestamp, following the model of spark SQL’s to_`timestamp`. +/// +/// In addition to RFC3339 / ISO8601 standard timestamps, it also +/// accepts strings that use a space ` ` to separate the date and time +/// as well as strings that have no explicit timezone offset. +/// +/// Examples of accepted inputs: +/// * `1997-01-31T09:26:56.123Z` # RCF3339 +/// * `1997-01-31T09:26:56.123-05:00` # RCF3339 +/// * `1997-01-31 09:26:56.123-05:00` # close to RCF3339 but with a space rather than T +/// * `1997-01-31T09:26:56.123` # close to RCF3339 but no timezone offset specified +/// * `1997-01-31 09:26:56.123` # close to RCF3339 but uses a space and no timezone offset +/// * `1997-01-31 09:26:56` # close to RCF3339, no fractional seconds +// +/// Internally, this function uses the `chrono` library for the +/// datetime parsing +/// +/// We hope to extend this function in the future with a second +/// parameter to specifying the format string. +/// +/// ## Timestamp Precision +/// +/// DataFusion uses the maximum precision timestamps supported by +/// Arrow (nanoseconds stored as a 64-bit integer) timestamps. This +/// means the range of dates that timestamps can represent is ~1677 AD +/// to 2262 AM +/// +/// +/// ## Timezone / Offset Handling +/// +/// By using the Arrow format, DataFusion inherits Arrow’s handling of +/// timestamp values. Specifically, the stored numerical values of +/// timestamps are stored compared to offset UTC. 
+///
+/// This function interprets strings without an explicit time zone as
+/// timestamps with offsets of the local time on the machine that ran
+/// the DataFusion query
+///
+/// For example, `1997-01-31 09:26:56.123Z` is interpreted as UTC, as
+/// it has an explicit timezone specifier (“Z” for Zulu/UTC)
+///
+/// `1997-01-31T09:26:56.123` is interpreted as a local timestamp in
+/// the timezone of the machine that ran DataFusion. For example, if
+/// the system timezone is set to Americas/New_York (UTC-5) the
+/// timestamp will be interpreted as though it were
+/// `1997-01-31T09:26:56.123-05:00`
+fn string_to_timestamp_nanos(s: &str) -> Result<i64> {
+    // Fast path: RFC3339 timestamp (with a T)
+    // Example: 2020-09-08T13:42:29.190855Z
+    if let Ok(ts) = DateTime::parse_from_rfc3339(s) {
+        return Ok(ts.timestamp_nanos());
+    }
 
-/// convert an array of strings into `Timestamp(Nanosecond, None)`
-pub fn to_timestamp(args: &[ArrayRef]) -> Result<TimestampNanosecondArray> {
-    let num_rows = args[0].len();
-    let string_args =
-        &args[0]
-            .as_any()
-            .downcast_ref::<StringArray>()
-            .ok_or_else(|| {
-                DataFusionError::Internal(
-                    "could not cast to_timestamp input to StringArray".to_string(),
-                )
-            })?;
-
-    let result = (0..num_rows)
-        .map(|i| {
-            if string_args.is_null(i) {
-                // NB: Since we use the same null bitset as the input,
-                // the output for this value will be ignored, but we
-                // need some value in the array we are building.
-                Ok(0)
-            } else {
-                string_to_timestamp_nanos(string_args.value(i))
-                    .map_err(DataFusionError::ArrowError)
+    // Implement quasi-RFC3339 support by trying to parse the
+    // timestamp with various other format specifiers to support
+    // separating the date and time with a space ' ' rather than 'T' to be
+    // (more) compatible with Apache Spark SQL
+
+    // timezone offset, using ' ' as a separator
+    // Example: 2020-09-08 13:42:29.190855-05:00
+    if let Ok(ts) = DateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f%:z") {
+        return Ok(ts.timestamp_nanos());
+    }
+
+    // with an explicit Z, using ' ' as a separator
+    // Example: 2020-09-08 13:42:29Z
+    if let Ok(ts) = Utc.datetime_from_str(s, "%Y-%m-%d %H:%M:%S%.fZ") {
+        return Ok(ts.timestamp_nanos());
+    }
+
+    // Support timestamps without an explicit timezone offset, again
+    // to be compatible with what Apache Spark SQL does.
+
+    // without a timezone specifier as a local time, using T as a separator
+    // Example: 2020-09-08T13:42:29.190855
+    if let Ok(ts) = NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S.%f") {
+        return naive_datetime_to_timestamp(s, ts);
+    }
+
+    // without a timezone specifier as a local time, using T as a
+    // separator, no fractional seconds
+    // Example: 2020-09-08T13:42:29
+    if let Ok(ts) = NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S") {
+        return naive_datetime_to_timestamp(s, ts);
+    }
+
+    // without a timezone specifier as a local time, using ' ' as a separator
+    // Example: 2020-09-08 13:42:29.190855
+    if let Ok(ts) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S.%f") {
+        return naive_datetime_to_timestamp(s, ts);
+    }
+
+    // without a timezone specifier as a local time, using ' ' as a
+    // separator, no fractional seconds
+    // Example: 2020-09-08 13:42:29
+    if let Ok(ts) = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") {
+        return naive_datetime_to_timestamp(s, ts);
+    }
+
+    // Note we don't pass along the error message from the underlying
+    // chrono parsing because we 
tried several different format
+    // strings and we don't know which the user was trying to
+    // match. Thus any of the specific error messages is likely to
+    // be more confusing than helpful
+    Err(DataFusionError::Execution(format!(
+        "Error parsing '{}' as timestamp",
+        s
+    )))
+}
+
+/// Converts the naive datetime (which has no specific timezone) to a
+/// nanosecond epoch timestamp relative to UTC.
+fn naive_datetime_to_timestamp(s: &str, datetime: NaiveDateTime) -> Result<i64> {
+    let l = Local {};
+
+    match l.from_local_datetime(&datetime) {
+        LocalResult::None => Err(DataFusionError::Execution(format!(
+            "Error parsing '{}' as timestamp: local time representation is invalid",
+            s
+        ))),
+        LocalResult::Single(local_datetime) => {
+            Ok(local_datetime.with_timezone(&Utc).timestamp_nanos())
+        }
+        // Ambiguous times can happen if the timestamp is exactly when
+        // a daylight savings time transition occurs, for example, and
+        // so the datetime could validly be said to be in two
+        // potential offsets. 
However, since we are about to convert + // to UTC anyways, we can pick one arbitrarily + LocalResult::Ambiguous(local_datetime, _) => { + Ok(local_datetime.with_timezone(&Utc).timestamp_nanos()) + } + } +} + +pub(crate) fn unary_string_to_primitive_function<'a, T, O, F>( + args: &[&'a dyn Array], + op: F, + name: &str, +) -> Result> +where + O: ArrowPrimitiveType, + T: StringOffsetSizeTrait, + F: Fn(&'a str) -> Result, +{ + if args.len() != 1 { + return Err(DataFusionError::Internal(format!( + "{:?} args were supplied but {} takes exactly one argument", + args.len(), + name, + ))); + } + + let array = args[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + // first map is the iterator, second is for the `Option<_>` + array.iter().map(|x| x.map(|x| op(x)).transpose()).collect() +} + +fn handle<'a, O, F, S>( + args: &'a [ColumnarValue], + op: F, + name: &str, +) -> Result +where + O: ArrowPrimitiveType, + S: ScalarType, + F: Fn(&'a str) -> Result, +{ + match &args[0] { + ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8 => Ok(ColumnarValue::Array(Arc::new( + unary_string_to_primitive_function::(&[a.as_ref()], op, name)?, + ))), + DataType::LargeUtf8 => Ok(ColumnarValue::Array(Arc::new( + unary_string_to_primitive_function::(&[a.as_ref()], op, name)?, + ))), + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function {}", + other, name, + ))), + }, + ColumnarValue::Scalar(scalar) => match scalar { + ScalarValue::Utf8(a) => { + let result = a.as_ref().map(|x| (op)(x)).transpose()?; + Ok(ColumnarValue::Scalar(S::into_scalar(result))) + } + ScalarValue::LargeUtf8(a) => { + let result = a.as_ref().map(|x| (op)(x)).transpose()?; + Ok(ColumnarValue::Scalar(S::into_scalar(result))) } - }) - .collect::>>()?; - - let data = ArrayData::new( - DataType::Timestamp(TimeUnit::Nanosecond, None), - num_rows, - Some(string_args.null_count()), - string_args.data().null_buffer().cloned(), - 0, - 
vec![Buffer::from(result.to_byte_slice())], - vec![], - ); - - Ok(TimestampNanosecondArray::from(Arc::new(data))) + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function {}", + other, name + ))), + }, + } +} + +/// to_timestamp SQL function +pub fn to_timestamp(args: &[ColumnarValue]) -> Result { + handle::( + args, + string_to_timestamp_nanos, + "to_timestamp", + ) +} + +fn date_trunc_single(granularity: &str, value: i64) -> Result { + let value = timestamp_ns_to_datetime(value).with_nanosecond(0); + let value = match granularity { + "second" => value, + "minute" => value.and_then(|d| d.with_second(0)), + "hour" => value + .and_then(|d| d.with_second(0)) + .and_then(|d| d.with_minute(0)), + "day" => value + .and_then(|d| d.with_second(0)) + .and_then(|d| d.with_minute(0)) + .and_then(|d| d.with_hour(0)), + "week" => value + .and_then(|d| d.with_second(0)) + .and_then(|d| d.with_minute(0)) + .and_then(|d| d.with_hour(0)) + .map(|d| d - Duration::seconds(60 * 60 * 24 * d.weekday() as i64)), + "month" => value + .and_then(|d| d.with_second(0)) + .and_then(|d| d.with_minute(0)) + .and_then(|d| d.with_hour(0)) + .and_then(|d| d.with_day0(0)), + "year" => value + .and_then(|d| d.with_second(0)) + .and_then(|d| d.with_minute(0)) + .and_then(|d| d.with_hour(0)) + .and_then(|d| d.with_day0(0)) + .and_then(|d| d.with_month0(0)), + unsupported => { + return Err(DataFusionError::Execution(format!( + "Unsupported date_trunc granularity: {}", + unsupported + ))) + } + }; + // `with_x(0)` are infalible because `0` are always a valid + Ok(value.unwrap().timestamp_nanos()) } /// date_trunc SQL function -pub fn date_trunc(args: &[ArrayRef]) -> Result { - let granularity_array = - &args[0] - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Execution( - "Could not cast date_trunc granularity input to StringArray" - .to_string(), - ) - })?; - - let array = &args[1] - .as_any() - .downcast_ref::() - .ok_or_else(|| { - 
DataFusionError::Execution( - "Could not cast date_trunc array input to TimestampNanosecondArray" - .to_string(), - ) - })?; - - let range = 0..array.len(); - let result = range - .map(|i| { - if array.is_null(i) { - Ok(0_i64) +pub fn date_trunc(args: &[ColumnarValue]) -> Result { + let (granularity, array) = (&args[0], &args[1]); + + let granularity = + if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = granularity { + v + } else { + return Err(DataFusionError::Execution( + "Granularity of `date_trunc` must be non-null scalar Utf8".to_string(), + )); + }; + + let f = |x: Option| x.map(|x| date_trunc_single(granularity, x)).transpose(); + + Ok(match array { + ColumnarValue::Scalar(scalar) => { + if let ScalarValue::TimeNanosecond(v) = scalar { + ColumnarValue::Scalar(ScalarValue::TimeNanosecond((f)(*v)?)) } else { - let date_time = match granularity_array.value(i) { - "second" => array - .value_as_datetime(i) - .and_then(|d| d.with_nanosecond(0)), - "minute" => array - .value_as_datetime(i) - .and_then(|d| d.with_nanosecond(0)) - .and_then(|d| d.with_second(0)), - "hour" => array - .value_as_datetime(i) - .and_then(|d| d.with_nanosecond(0)) - .and_then(|d| d.with_second(0)) - .and_then(|d| d.with_minute(0)), - "day" => array - .value_as_datetime(i) - .and_then(|d| d.with_nanosecond(0)) - .and_then(|d| d.with_second(0)) - .and_then(|d| d.with_minute(0)) - .and_then(|d| d.with_hour(0)), - "week" => array - .value_as_datetime(i) - .and_then(|d| d.with_nanosecond(0)) - .and_then(|d| d.with_second(0)) - .and_then(|d| d.with_minute(0)) - .and_then(|d| d.with_hour(0)) - .map(|d| { - d - Duration::seconds(60 * 60 * 24 * d.weekday() as i64) - }), - "month" => array - .value_as_datetime(i) - .and_then(|d| d.with_nanosecond(0)) - .and_then(|d| d.with_second(0)) - .and_then(|d| d.with_minute(0)) - .and_then(|d| d.with_hour(0)) - .and_then(|d| d.with_day0(0)), - "year" => array - .value_as_datetime(i) - .and_then(|d| d.with_nanosecond(0)) - .and_then(|d| 
d.with_second(0)) - .and_then(|d| d.with_minute(0)) - .and_then(|d| d.with_hour(0)) - .and_then(|d| d.with_day0(0)) - .and_then(|d| d.with_month0(0)), - unsupported => { - return Err(DataFusionError::Execution(format!( - "Unsupported date_trunc granularity: {}", - unsupported - ))) - } - }; - date_time.map(|d| d.timestamp_nanos()).ok_or_else(|| { - DataFusionError::Execution(format!( - "Can't truncate date time: {:?}", - array.value_as_datetime(i) - )) - }) + return Err(DataFusionError::Execution( + "array of `date_trunc` must be non-null scalar Utf8".to_string(), + )); } - }) - .collect::>>()?; - - let data = ArrayData::new( - DataType::Timestamp(TimeUnit::Nanosecond, None), - array.len(), - Some(array.null_count()), - array.data().null_buffer().cloned(), - 0, - vec![Buffer::from(result.to_byte_slice())], - vec![], - ); - - Ok(TimestampNanosecondArray::from(Arc::new(data))) + } + ColumnarValue::Array(array) => { + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + let array = array + .iter() + .map(f) + .collect::>()?; + + ColumnarValue::Array(Arc::new(array)) + } + }) } #[cfg(test)] mod tests { use std::sync::Arc; - use arrow::array::{Int64Array, StringBuilder}; + use arrow::array::{ArrayRef, Int64Array, StringBuilder}; use super::*; @@ -191,73 +351,77 @@ mod tests { string_builder.append_null()?; ts_builder.append_null()?; + let expected_timestamps = &ts_builder.finish() as &dyn Array; - let string_array = Arc::new(string_builder.finish()); + let string_array = + ColumnarValue::Array(Arc::new(string_builder.finish()) as ArrayRef); let parsed_timestamps = to_timestamp(&[string_array]) .expect("that to_timestamp parsed values without error"); - - let expected_timestamps = ts_builder.finish(); - - assert_eq!(parsed_timestamps.len(), 2); - assert_eq!(expected_timestamps, parsed_timestamps); + if let ColumnarValue::Array(parsed_array) = parsed_timestamps { + assert_eq!(parsed_array.len(), 2); + assert_eq!(expected_timestamps, parsed_array.as_ref()); + } 
else { + panic!("Expected a columnar array") + } Ok(()) } #[test] fn date_trunc_test() -> Result<()> { - let mut ts_builder = StringBuilder::new(2); - let mut truncated_builder = StringBuilder::new(2); - let mut string_builder = StringBuilder::new(2); - - ts_builder.append_null()?; - truncated_builder.append_null()?; - string_builder.append_value("second")?; - - ts_builder.append_value("2020-09-08T13:42:29.190855Z")?; - truncated_builder.append_value("2020-09-08T13:42:29.000000Z")?; - string_builder.append_value("second")?; - - ts_builder.append_value("2020-09-08T13:42:29.190855Z")?; - truncated_builder.append_value("2020-09-08T13:42:00.000000Z")?; - string_builder.append_value("minute")?; - - ts_builder.append_value("2020-09-08T13:42:29.190855Z")?; - truncated_builder.append_value("2020-09-08T13:00:00.000000Z")?; - string_builder.append_value("hour")?; - - ts_builder.append_value("2020-09-08T13:42:29.190855Z")?; - truncated_builder.append_value("2020-09-08T00:00:00.000000Z")?; - string_builder.append_value("day")?; - - ts_builder.append_value("2020-09-08T13:42:29.190855Z")?; - truncated_builder.append_value("2020-09-07T00:00:00.000000Z")?; - string_builder.append_value("week")?; - - ts_builder.append_value("2020-09-08T13:42:29.190855Z")?; - truncated_builder.append_value("2020-09-01T00:00:00.000000Z")?; - string_builder.append_value("month")?; - - ts_builder.append_value("2020-09-08T13:42:29.190855Z")?; - truncated_builder.append_value("2020-01-01T00:00:00.000000Z")?; - string_builder.append_value("year")?; - - ts_builder.append_value("2021-01-01T13:42:29.190855Z")?; - truncated_builder.append_value("2020-12-28T00:00:00.000000Z")?; - string_builder.append_value("week")?; - - ts_builder.append_value("2020-01-01T13:42:29.190855Z")?; - truncated_builder.append_value("2019-12-30T00:00:00.000000Z")?; - string_builder.append_value("week")?; - - let string_array = Arc::new(string_builder.finish()); - let ts_array = 
Arc::new(to_timestamp(&[Arc::new(ts_builder.finish())]).unwrap()); - let date_trunc_array = date_trunc(&[string_array, ts_array]) - .expect("that to_timestamp parsed values without error"); - - let expected_timestamps = - to_timestamp(&[Arc::new(truncated_builder.finish())]).unwrap(); - - assert_eq!(date_trunc_array, expected_timestamps); + let cases = vec![ + ( + "2020-09-08T13:42:29.190855Z", + "second", + "2020-09-08T13:42:29.000000Z", + ), + ( + "2020-09-08T13:42:29.190855Z", + "minute", + "2020-09-08T13:42:00.000000Z", + ), + ( + "2020-09-08T13:42:29.190855Z", + "hour", + "2020-09-08T13:00:00.000000Z", + ), + ( + "2020-09-08T13:42:29.190855Z", + "day", + "2020-09-08T00:00:00.000000Z", + ), + ( + "2020-09-08T13:42:29.190855Z", + "week", + "2020-09-07T00:00:00.000000Z", + ), + ( + "2020-09-08T13:42:29.190855Z", + "month", + "2020-09-01T00:00:00.000000Z", + ), + ( + "2020-09-08T13:42:29.190855Z", + "year", + "2020-01-01T00:00:00.000000Z", + ), + ( + "2021-01-01T13:42:29.190855Z", + "week", + "2020-12-28T00:00:00.000000Z", + ), + ( + "2020-01-01T13:42:29.190855Z", + "week", + "2019-12-30T00:00:00.000000Z", + ), + ]; + + cases.iter().for_each(|(original, granularity, expected)| { + let original = string_to_timestamp_nanos(original).unwrap(); + let expected = string_to_timestamp_nanos(expected).unwrap(); + let result = date_trunc_single(granularity, original).unwrap(); + assert_eq!(result, expected); + }); Ok(()) } @@ -268,10 +432,10 @@ mod tests { let mut builder = Int64Array::builder(1); builder.append_value(1)?; - let int64array = Arc::new(builder.finish()); + let int64array = ColumnarValue::Array(Arc::new(builder.finish())); let expected_err = - "Internal error: could not cast to_timestamp input to StringArray"; + "Internal error: Unsupported data type Int64 for function to_timestamp"; match to_timestamp(&[int64array]) { Ok(_) => panic!("Expected error but got success"), Err(e) => { diff --git a/rust/datafusion/src/physical_plan/expressions/binary.rs 
b/rust/datafusion/src/physical_plan/expressions/binary.rs index fb9ccda475c..0d503508d63 100644 --- a/rust/datafusion/src/physical_plan/expressions/binary.rs +++ b/rust/datafusion/src/physical_plan/expressions/binary.rs @@ -211,6 +211,7 @@ macro_rules! binary_primitive_array_op { /// The binary_array_op_scalar macro includes types that extend beyond the primitive, /// such as Utf8 strings. +#[macro_export] macro_rules! binary_array_op_scalar { ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ let result = match $LEFT.data_type() { diff --git a/rust/datafusion/src/physical_plan/expressions/mod.rs b/rust/datafusion/src/physical_plan/expressions/mod.rs index 9f2964c45fd..bf47aa1cfe8 100644 --- a/rust/datafusion/src/physical_plan/expressions/mod.rs +++ b/rust/datafusion/src/physical_plan/expressions/mod.rs @@ -22,16 +22,7 @@ use std::sync::Arc; use super::ColumnarValue; use crate::error::{DataFusionError, Result}; use crate::physical_plan::PhysicalExpr; -use arrow::array::Array; -use arrow::array::{ - ArrayRef, BooleanArray, Date32Array, Date64Array, Float32Array, Float64Array, - Int16Array, Int32Array, Int64Array, Int8Array, StringArray, TimestampNanosecondArray, - UInt16Array, UInt32Array, UInt64Array, UInt8Array, -}; -use arrow::compute::kernels::boolean::nullif; -use arrow::compute::kernels::comparison::{eq, eq_utf8}; use arrow::compute::kernels::sort::{SortColumn, SortOptions}; -use arrow::datatypes::{DataType, TimeUnit}; use arrow::record_batch::RecordBatch; mod average; @@ -49,6 +40,7 @@ mod literal; mod min_max; mod negative; mod not; +mod nullif; mod sum; pub use average::{avg_return_type, Avg, AvgAccumulator}; @@ -64,6 +56,7 @@ pub use literal::{lit, Literal}; pub use min_max::{Max, Min}; pub use negative::{negative, NegativeExpr}; pub use not::{not, NotExpr}; +pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use sum::{sum_return_type, Sum}; /// returns the name of the state @@ -71,80 +64,6 @@ pub fn format_state_name(name: &str, state_name: &str) -> 
String { format!("{}[{}]", name, state_name) } -/// Invoke a compute kernel on a primitive array and a Boolean Array -macro_rules! compute_bool_array_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - let ll = $LEFT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast array"); - let rr = $RIGHT - .as_any() - .downcast_ref::() - .expect("compute_op failed to downcast array"); - Ok(Arc::new($OP(&ll, &rr)?)) - }}; -} - -/// Binary op between primitive and boolean arrays -macro_rules! primitive_bool_array_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ - match $LEFT.data_type() { - DataType::Int8 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int8Array), - DataType::Int16 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int16Array), - DataType::Int32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int32Array), - DataType::Int64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int64Array), - DataType::UInt8 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt8Array), - DataType::UInt16 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt16Array), - DataType::UInt32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt32Array), - DataType::UInt64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt64Array), - DataType::Float32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Float32Array), - DataType::Float64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Float64Array), - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for NULLIF/primitive/boolean operator", - other - ))), - } - }}; -} - -/// -/// Implements NULLIF(expr1, expr2) -/// Args: 0 - left expr is any array -/// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed. 
-/// -pub fn nullif_func(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return Err(DataFusionError::Internal(format!( - "{:?} args were supplied but NULLIF takes exactly two args", - args.len(), - ))); - } - - // Get args0 == args1 evaluated and produce a boolean array - let cond_array = binary_array_op!(args[0], args[1], eq)?; - - // Now, invoke nullif on the result - primitive_bool_array_op!(args[0], *cond_array, nullif) -} - -/// Currently supported types by the nullif function. -/// The order of these types correspond to the order on which coercion applies -/// This should thus be from least informative to most informative -pub static SUPPORTED_NULLIF_TYPES: &[DataType] = &[ - DataType::Boolean, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, -]; - /// Represents Sort operation for a column in a RecordBatch #[derive(Clone, Debug)] pub struct PhysicalSortExpr { @@ -178,8 +97,6 @@ impl PhysicalSortExpr { mod tests { use super::*; use crate::{error::Result, physical_plan::AggregateExpr, scalar::ScalarValue}; - use arrow::array::PrimitiveArray; - use arrow::datatypes::*; /// macro to perform an aggregation and verify the result. 
#[macro_export] @@ -200,70 +117,6 @@ mod tests { }}; } - #[test] - fn nullif_int32() -> Result<()> { - let a = Int32Array::from(vec![ - Some(1), - Some(2), - None, - None, - Some(3), - None, - None, - Some(4), - Some(5), - ]); - let a = Arc::new(a); - let a_len = a.len(); - - let lit_array = Arc::new(Int32Array::from(vec![2; a.len()])); - - let result = nullif_func(&[a, lit_array])?; - - assert_eq!(result.len(), a_len); - - let expected = Int32Array::from(vec![ - Some(1), - None, - None, - None, - Some(3), - None, - None, - Some(4), - Some(5), - ]); - assert_array_eq::(expected, result); - Ok(()) - } - - #[test] - // Ensure that arrays with no nulls can also invoke NULLIF() correctly - fn nullif_int32_nonulls() -> Result<()> { - let a = Int32Array::from(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]); - let a = Arc::new(a); - let a_len = a.len(); - - let lit_array = Arc::new(Int32Array::from(vec![1; a.len()])); - - let result = nullif_func(&[a, lit_array])?; - assert_eq!(result.len(), a_len); - - let expected = Int32Array::from(vec![ - None, - Some(3), - Some(10), - Some(7), - Some(8), - None, - Some(2), - Some(4), - Some(5), - ]); - assert_array_eq::(expected, result); - Ok(()) - } - pub fn aggregate( batch: &RecordBatch, agg: Arc, @@ -278,22 +131,4 @@ mod tests { accum.update_batch(&values)?; accum.evaluate() } - - fn assert_array_eq( - expected: PrimitiveArray, - actual: ArrayRef, - ) { - let actual = actual - .as_any() - .downcast_ref::>() - .expect("Actual array should unwrap to type of expected array"); - - for i in 0..expected.len() { - if expected.is_null(i) { - assert!(actual.is_null(i)); - } else { - assert_eq!(expected.value(i), actual.value(i)); - } - } - } } diff --git a/rust/datafusion/src/physical_plan/expressions/nullif.rs b/rust/datafusion/src/physical_plan/expressions/nullif.rs new file mode 100644 index 00000000000..7cc58ed2318 --- /dev/null +++ b/rust/datafusion/src/physical_plan/expressions/nullif.rs @@ -0,0 +1,188 @@ +// Licensed to the Apache Software 
Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use super::ColumnarValue; +use crate::error::{DataFusionError, Result}; +use crate::scalar::ScalarValue; +use arrow::array::Array; +use arrow::array::{ + ArrayRef, BooleanArray, Date32Array, Date64Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, StringArray, TimestampNanosecondArray, + UInt16Array, UInt32Array, UInt64Array, UInt8Array, +}; +use arrow::compute::kernels::boolean::nullif; +use arrow::compute::kernels::comparison::{eq, eq_scalar, eq_utf8, eq_utf8_scalar}; +use arrow::datatypes::{DataType, TimeUnit}; + +/// Invoke a compute kernel on a primitive array and a Boolean Array +macro_rules! compute_bool_array_op { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ + let ll = $LEFT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + let rr = $RIGHT + .as_any() + .downcast_ref::() + .expect("compute_op failed to downcast array"); + Ok(Arc::new($OP(&ll, &rr)?) as ArrayRef) + }}; +} + +/// Binary op between primitive and boolean arrays +macro_rules! 
primitive_bool_array_op { + ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + match $LEFT.data_type() { + DataType::Int8 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int8Array), + DataType::Int16 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int16Array), + DataType::Int32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int32Array), + DataType::Int64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int64Array), + DataType::UInt8 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt8Array), + DataType::UInt16 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt16Array), + DataType::UInt32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt32Array), + DataType::UInt64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt64Array), + DataType::Float32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Float32Array), + DataType::Float64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Float64Array), + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for NULLIF/primitive/boolean operator", + other + ))), + } + }}; +} + +/// Implements NULLIF(expr1, expr2) +/// Args: 0 - left expr is any array +/// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed. 
+/// +pub fn nullif_func(args: &[ColumnarValue]) -> Result { + if args.len() != 2 { + return Err(DataFusionError::Internal(format!( + "{:?} args were supplied but NULLIF takes exactly two args", + args.len(), + ))); + } + + let (lhs, rhs) = (&args[0], &args[1]); + + match (lhs, rhs) { + (ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => { + let cond_array = binary_array_op_scalar!(lhs, rhs.clone(), eq).unwrap()?; + + let array = primitive_bool_array_op!(lhs, *cond_array, nullif)?; + + Ok(ColumnarValue::Array(array)) + } + (ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => { + // Get args0 == args1 evaluated and produce a boolean array + let cond_array = binary_array_op!(lhs, rhs, eq)?; + + // Now, invoke nullif on the result + let array = primitive_bool_array_op!(lhs, *cond_array, nullif)?; + Ok(ColumnarValue::Array(array)) + } + _ => Err(DataFusionError::NotImplemented( + "nullif does not support a literal as first argument".to_string(), + )), + } +} + +/// Currently supported types by the nullif function. 
+/// The order of these types correspond to the order on which coercion applies +/// This should thus be from least informative to most informative +pub static SUPPORTED_NULLIF_TYPES: &[DataType] = &[ + DataType::Boolean, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, +]; + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Result; + + #[test] + fn nullif_int32() -> Result<()> { + let a = Int32Array::from(vec![ + Some(1), + Some(2), + None, + None, + Some(3), + None, + None, + Some(4), + Some(5), + ]); + let a = ColumnarValue::Array(Arc::new(a)); + + let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); + + let result = nullif_func(&[a, lit_array])?; + let result = result.into_array(0); + + let expected = Arc::new(Int32Array::from(vec![ + Some(1), + None, + None, + None, + Some(3), + None, + None, + Some(4), + Some(5), + ])) as ArrayRef; + assert_eq!(expected.as_ref(), result.as_ref()); + Ok(()) + } + + #[test] + // Ensure that arrays with no nulls can also invoke NULLIF() correctly + fn nullif_int32_nonulls() -> Result<()> { + let a = Int32Array::from(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]); + let a = ColumnarValue::Array(Arc::new(a)); + + let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); + + let result = nullif_func(&[a, lit_array])?; + let result = result.into_array(0); + + let expected = Arc::new(Int32Array::from(vec![ + None, + Some(3), + Some(10), + Some(7), + Some(8), + None, + Some(2), + Some(4), + Some(5), + ])) as ArrayRef; + assert_eq!(expected.as_ref(), result.as_ref()); + Ok(()) + } +} diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 0b5105502cf..61a1bf1b9bb 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -385,21 +385,9 @@ pub fn 
create_physical_expr( BuiltinScalarFunction::Trunc => math_expressions::trunc, BuiltinScalarFunction::Abs => math_expressions::abs, BuiltinScalarFunction::Signum => math_expressions::signum, - /* - BuiltinScalarFunction::NullIf => |args| match &args[0] { - ColumnarValue::Scalar(v) => match v { - ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( - v.as_ref().map(|x| x.len() as i32), - ))), - ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( - ScalarValue::Int64(v.as_ref().map(|x| x.len() as i64)), - )), - _ => unreachable!(), - }, - ColumnarValue::Array(v) => Ok(ColumnarValue::Array(nullif_func(v.as_ref())?)), - }, - */ + BuiltinScalarFunction::NullIf => nullif_func, BuiltinScalarFunction::MD5 => crypto_expressions::md5, + BuiltinScalarFunction::SHA224 => crypto_expressions::sha224, BuiltinScalarFunction::SHA256 => crypto_expressions::sha256, BuiltinScalarFunction::SHA384 => crypto_expressions::sha384, BuiltinScalarFunction::SHA512 => crypto_expressions::sha512, @@ -421,16 +409,9 @@ pub fn create_physical_expr( BuiltinScalarFunction::Ltrim => string_expressions::ltrim, BuiltinScalarFunction::Rtrim => string_expressions::rtrim, BuiltinScalarFunction::Upper => string_expressions::upper, - /* - BuiltinScalarFunction::ToTimestamp => { - |args| Ok(Arc::new(datetime_expressions::to_timestamp(args)?)) - } - BuiltinScalarFunction::DateTrunc => { - |args| Ok(Arc::new(datetime_expressions::date_trunc(args)?)) - } - BuiltinScalarFunction::Array => |args| Ok(array_expressions::array(args)?), - */ - _ => todo!(), + BuiltinScalarFunction::ToTimestamp => datetime_expressions::to_timestamp, + BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc, + BuiltinScalarFunction::Array => array_expressions::array, }); // coerce let args = coerce(args, input_schema, &signature(fun))?; @@ -590,9 +571,16 @@ impl PhysicalExpr for ScalarFunctionExpr { #[cfg(test)] mod tests { use super::*; - use crate::{error::Result, physical_plan::expressions::lit, 
scalar::ScalarValue}; + use crate::{ + error::Result, + physical_plan::expressions::{col, lit}, + scalar::ScalarValue, + }; use arrow::{ - array::{ArrayRef, FixedSizeListArray, Float64Array, Int32Array, StringArray}, + array::{ + ArrayRef, FixedSizeListArray, Float64Array, Int32Array, StringArray, + UInt32Array, UInt64Array, + }, datatypes::Field, record_batch::RecordBatch, }; @@ -681,18 +669,21 @@ mod tests { } fn generic_test_array( - value1: ScalarValue, - value2: ScalarValue, + value1: ArrayRef, + value2: ArrayRef, expected_type: DataType, expected: &str, ) -> Result<()> { // any type works here: we evaluate against a literal of `value` - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; + let schema = Schema::new(vec![ + Field::new("a", value1.data_type().clone(), false), + Field::new("b", value2.data_type().clone(), false), + ]); + let columns: Vec = vec![value1, value2]; let expr = create_physical_expr( &BuiltinScalarFunction::Array, - &[lit(value1), lit(value2)], + &vec![col("a"), col("b")], &schema, )?; @@ -722,24 +713,24 @@ mod tests { #[test] fn test_array() -> Result<()> { generic_test_array( - ScalarValue::Utf8(Some("aa".to_string())), - ScalarValue::Utf8(Some("aa".to_string())), + Arc::new(StringArray::from(vec!["aa"])), + Arc::new(StringArray::from(vec!["bb"])), DataType::Utf8, - "StringArray\n[\n \"aa\",\n \"aa\",\n]", + "StringArray\n[\n \"aa\",\n \"bb\",\n]", )?; // different types, to validate that casting happens generic_test_array( - ScalarValue::from(1u32), - ScalarValue::from(1u64), + Arc::new(UInt32Array::from(vec![1u32])), + Arc::new(UInt64Array::from(vec![1u64])), DataType::UInt64, "PrimitiveArray\n[\n 1,\n 1,\n]", )?; // different types (another order), to validate that casting happens generic_test_array( - ScalarValue::from(1u64), - ScalarValue::from(1u32), + Arc::new(UInt64Array::from(vec![1u64])), + Arc::new(UInt32Array::from(vec![1u32])), 
DataType::UInt64, "PrimitiveArray\n[\n 1,\n 1,\n]", ) diff --git a/rust/datafusion/src/scalar.rs b/rust/datafusion/src/scalar.rs index 88efa8897ee..64034ebdb1d 100644 --- a/rust/datafusion/src/scalar.rs +++ b/rust/datafusion/src/scalar.rs @@ -19,8 +19,11 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; -use arrow::array::*; use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; +use arrow::{ + array::*, + datatypes::{ArrowNativeType, Float32Type, TimestampNanosecondType}, +}; use crate::error::{DataFusionError, Result}; @@ -663,6 +666,24 @@ impl fmt::Debug for ScalarValue { } } +/// Trait used to map +pub trait ScalarType { + /// returns a scalar from an optional T + fn into_scalar(r: Option) -> ScalarValue; +} + +impl ScalarType for Float32Type { + fn into_scalar(r: Option) -> ScalarValue { + ScalarValue::Float32(r) + } +} + +impl ScalarType for TimestampNanosecondType { + fn into_scalar(r: Option) -> ScalarValue { + ScalarValue::TimeNanosecond(r) + } +} + #[cfg(test)] mod tests { use super::*; From 87d787f4191cfa0aed5d1329bdce1aaf7ac5ece8 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 31 Jan 2021 16:13:25 +0100 Subject: [PATCH 3/8] Updated example. --- rust/datafusion/examples/simple_udf.rs | 98 +++++++++++++++++--------- 1 file changed, 64 insertions(+), 34 deletions(-) diff --git a/rust/datafusion/examples/simple_udf.rs b/rust/datafusion/examples/simple_udf.rs index 0eef801e07e..c850ce6b3dd 100644 --- a/rust/datafusion/examples/simple_udf.rs +++ b/rust/datafusion/examples/simple_udf.rs @@ -16,13 +16,17 @@ // under the License. 
use arrow::{ - array::{ArrayRef, Float32Array, Float64Array}, + array::{Array, ArrayRef, Float32Array, Float64Array}, datatypes::DataType, record_batch::RecordBatch, util::pretty, }; -use datafusion::error::Result; +use datafusion::{ + error::{DataFusionError, Result}, + physical_plan::ColumnarValue, + scalar::ScalarValue, +}; use datafusion::{physical_plan::functions::ScalarFunctionImplementation, prelude::*}; use std::sync::Arc; @@ -54,50 +58,76 @@ fn create_context() -> Result { Ok(ctx) } +// a small utility function to compute pow(base, exponent) +fn maybe_pow(base: &Option, exponent: &Option) -> Option { + match (base, exponent) { + // in arrow, any value can be null. + // Here we decide to make our UDF to return null when either base or exponent is null. + (Some(base), Some(exponent)) => Some(base.powf(*exponent)), + _ => None, + } +} + +fn pow_array(base: &dyn Array, exponent: &dyn Array) -> Result { + // 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics! + let base = base + .as_any() + .downcast_ref::() + .expect("cast failed"); + let exponent = exponent + .as_any() + .downcast_ref::() + .expect("cast failed"); + + // this is guaranteed by DataFusion. We place it just to make it obvious. + assert_eq!(exponent.len(), base.len()); + + // 2. perform the computation + let array = base + .iter() + .zip(exponent.iter()) + .map(|(base, exponent)| maybe_pow(&base, &exponent)) + .collect::(); + + // `Ok` because no error occurred during the calculation (we should add one if exponent was [0, 1[ and the base < 0 because that panics!) + // `Arc` because arrays are immutable, thread-safe, trait objects. 
+ Ok(Arc::new(array)) +} + /// In this example we will declare a single-type, single return type UDF that exponentiates f64, a^b #[tokio::main] async fn main() -> Result<()> { let mut ctx = create_context()?; // First, declare the actual implementation of the calculation - let pow: ScalarFunctionImplementation = Arc::new(|args: &[ArrayRef]| { - // in DataFusion, all `args` and output are dynamically-typed arrays, which means that we need to: + let pow: ScalarFunctionImplementation = Arc::new(|args: &[ColumnarValue]| { + // in DataFusion, all `args` and output are `ColumnarValue`, an enum of either a scalar or a dynamically-typed array. + // we can cater for both, or document that the UDF only supports some variants. + // here we will assume that al // 1. cast the values to the type we want // 2. perform the computation for every element in the array (using a loop or SIMD) and construct the result // this is guaranteed by DataFusion based on the function's signature. assert_eq!(args.len(), 2); - // 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics! - let base = &args[0] - .as_any() - .downcast_ref::() - .expect("cast failed"); - let exponent = &args[1] - .as_any() - .downcast_ref::() - .expect("cast failed"); - - // this is guaranteed by DataFusion. We place it just to make it obvious. - assert_eq!(exponent.len(), base.len()); - - // 2. perform the computation - let array = base - .iter() - .zip(exponent.iter()) - .map(|(base, exponent)| { - match (base, exponent) { - // in arrow, any value can be null. - // Here we decide to make our UDF to return null when either base or exponent is null. - (Some(base), Some(exponent)) => Some(base.powf(exponent)), - _ => None, - } - }) - .collect::(); - - // `Ok` because no error occurred during the calculation (we should add one if exponent was [0, 1[ and the base < 0 because that panics!) - // `Arc` because arrays are immutable, thread-safe, trait objects. 
- Ok(Arc::new(array)) + let (base, exponent) = (&args[0], &args[1]); + + let result = match (base, exponent) { + ( + ColumnarValue::Scalar(ScalarValue::Float64(base)), + ColumnarValue::Scalar(ScalarValue::Float64(exponent)), + ) => ColumnarValue::Scalar(ScalarValue::Float64(maybe_pow(base, exponent))), + (ColumnarValue::Array(base), ColumnarValue::Array(exponent)) => { + let array = pow_array(base.as_ref(), exponent.as_ref())?; + ColumnarValue::Array(array) + } + _ => { + return Err(DataFusionError::Execution( + "This UDF only supports f64".to_string(), + )) + } + }; + Ok(result) }); // Next: From 75eec09f2db0cb625fde8ffbd1a6a5ebc52ae651 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 31 Jan 2021 16:16:15 +0100 Subject: [PATCH 4/8] Fixed clippy --- .../src/physical_plan/string_expressions.rs | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index c102a629703..018468460db 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -165,19 +165,16 @@ pub fn concatenate(args: &[ColumnarValue]) -> Result { // short avenue with only scalars let initial = Some("".to_string()); let result = args.iter().fold(initial, |mut acc, rhs| { - match acc { - Some(ref mut inner) => { - match rhs { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) => { - inner.push_str(v); - } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { - acc = None; - } - _ => unreachable!(""), - }; - } - None => {} + if let Some(ref mut inner) = acc { + match rhs { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) => { + inner.push_str(v); + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { + acc = None; + } + _ => unreachable!(""), + }; }; acc }); From 545d2d4e125df24454defc30b6feab163d55ef0f Mon Sep 17 00:00:00 2001 From: "Jorge C. 
Leitao" Date: Sun, 31 Jan 2021 16:59:53 +0100 Subject: [PATCH 5/8] Fixed error. --- rust/datafusion/src/scalar.rs | 18 +++++++++++++++++- rust/datafusion/tests/sql.rs | 28 +++++++++++++--------------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/rust/datafusion/src/scalar.rs b/rust/datafusion/src/scalar.rs index 64034ebdb1d..64d035ec885 100644 --- a/rust/datafusion/src/scalar.rs +++ b/rust/datafusion/src/scalar.rs @@ -520,7 +520,23 @@ impl TryFrom for i32 { } } -impl_try_from!(Int64, i64); +// special implementation for i64 because of TimeNanosecond +impl TryFrom for i64 { + type Error = DataFusionError; + + fn try_from(value: ScalarValue) -> Result { + match value { + ScalarValue::Int64(Some(inner_value)) + | ScalarValue::TimeNanosecond(Some(inner_value)) => Ok(inner_value), + _ => Err(DataFusionError::Internal(format!( + "Cannot convert {:?} to {}", + value, + std::any::type_name::() + ))), + } + } +} + impl_try_from!(UInt8, u8); impl_try_from!(UInt16, u16); impl_try_from!(UInt32, u32); diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 8c1850e1df0..be6a1235089 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -28,7 +28,6 @@ use arrow::{ util::display::array_value_to_string, }; -use datafusion::error::Result; use datafusion::execution::context::ExecutionContext; use datafusion::logical_plan::{LogicalPlan, ToDFSchema}; use datafusion::prelude::create_udf; @@ -36,6 +35,7 @@ use datafusion::{ datasource::{csv::CsvReadOptions, MemTable}, physical_plan::collect, }; +use datafusion::{error::Result, physical_plan::ColumnarValue}; #[tokio::test] async fn nyc() -> Result<()> { @@ -569,21 +569,19 @@ fn create_ctx() -> Result { Ok(ctx) } -fn custom_sqrt(args: &[ArrayRef]) -> Result { - let input = &args[0] - .as_any() - .downcast_ref::() - .expect("cast failed"); - - let mut builder = Float64Builder::new(input.len()); - for i in 0..input.len() { - if input.is_null(i) { - builder.append_null()?; - } 
else { - builder.append_value(input.value(i).sqrt())?; - } +fn custom_sqrt(args: &[ColumnarValue]) -> Result { + let arg = &args[0]; + if let ColumnarValue::Array(v) = arg { + let input = v + .as_any() + .downcast_ref::() + .expect("cast failed"); + + let array: Float64Array = input.iter().map(|v| v.map(|x| x.sqrt())).collect(); + Ok(ColumnarValue::Array(Arc::new(array))) + } else { + unimplemented!() } - Ok(Arc::new(builder.finish())) } #[tokio::test] From ae0d125f83ebf436b9b3cf8ace66b546e5ac9a4f Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Tue, 2 Feb 2021 06:46:07 +0100 Subject: [PATCH 6/8] Addressed comment about UDFs. --- rust/datafusion/examples/simple_udf.rs | 105 +++++++----------- rust/datafusion/src/execution/context.rs | 34 +++--- .../datafusion/src/physical_plan/functions.rs | 39 +++++++ rust/datafusion/src/physical_plan/mod.rs | 1 + 4 files changed, 93 insertions(+), 86 deletions(-) diff --git a/rust/datafusion/examples/simple_udf.rs b/rust/datafusion/examples/simple_udf.rs index c850ce6b3dd..c37cc9cc331 100644 --- a/rust/datafusion/examples/simple_udf.rs +++ b/rust/datafusion/examples/simple_udf.rs @@ -16,18 +16,14 @@ // under the License. 
use arrow::{ - array::{Array, ArrayRef, Float32Array, Float64Array}, + array::{ArrayRef, Float32Array, Float64Array}, datatypes::DataType, record_batch::RecordBatch, util::pretty, }; -use datafusion::{ - error::{DataFusionError, Result}, - physical_plan::ColumnarValue, - scalar::ScalarValue, -}; -use datafusion::{physical_plan::functions::ScalarFunctionImplementation, prelude::*}; +use datafusion::prelude::*; +use datafusion::{error::Result, physical_plan::functions::make_scalar_function}; use std::sync::Arc; // create local execution context with an in-memory table @@ -58,77 +54,54 @@ fn create_context() -> Result { Ok(ctx) } -// a small utility function to compute pow(base, exponent) -fn maybe_pow(base: &Option, exponent: &Option) -> Option { - match (base, exponent) { - // in arrow, any value can be null. - // Here we decide to make our UDF to return null when either base or exponent is null. - (Some(base), Some(exponent)) => Some(base.powf(*exponent)), - _ => None, - } -} - -fn pow_array(base: &dyn Array, exponent: &dyn Array) -> Result { - // 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics! - let base = base - .as_any() - .downcast_ref::() - .expect("cast failed"); - let exponent = exponent - .as_any() - .downcast_ref::() - .expect("cast failed"); - - // this is guaranteed by DataFusion. We place it just to make it obvious. - assert_eq!(exponent.len(), base.len()); - - // 2. perform the computation - let array = base - .iter() - .zip(exponent.iter()) - .map(|(base, exponent)| maybe_pow(&base, &exponent)) - .collect::(); - - // `Ok` because no error occurred during the calculation (we should add one if exponent was [0, 1[ and the base < 0 because that panics!) - // `Arc` because arrays are immutable, thread-safe, trait objects. 
- Ok(Arc::new(array)) -} - /// In this example we will declare a single-type, single return type UDF that exponentiates f64, a^b #[tokio::main] async fn main() -> Result<()> { let mut ctx = create_context()?; // First, declare the actual implementation of the calculation - let pow: ScalarFunctionImplementation = Arc::new(|args: &[ColumnarValue]| { - // in DataFusion, all `args` and output are `ColumnarValue`, an enum of either a scalar or a dynamically-typed array. - // we can cater for both, or document that the UDF only supports some variants. - // here we will assume that al + let pow = |args: &[ArrayRef]| { + // in DataFusion, all `args` and output are dynamically-typed arrays, which means that we need to: // 1. cast the values to the type we want // 2. perform the computation for every element in the array (using a loop or SIMD) and construct the result // this is guaranteed by DataFusion based on the function's signature. assert_eq!(args.len(), 2); - let (base, exponent) = (&args[0], &args[1]); - - let result = match (base, exponent) { - ( - ColumnarValue::Scalar(ScalarValue::Float64(base)), - ColumnarValue::Scalar(ScalarValue::Float64(exponent)), - ) => ColumnarValue::Scalar(ScalarValue::Float64(maybe_pow(base, exponent))), - (ColumnarValue::Array(base), ColumnarValue::Array(exponent)) => { - let array = pow_array(base.as_ref(), exponent.as_ref())?; - ColumnarValue::Array(array) - } - _ => { - return Err(DataFusionError::Execution( - "This UDF only supports f64".to_string(), - )) - } - }; - Ok(result) - }); + // 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics! + let base = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + let exponent = &args[1] + .as_any() + .downcast_ref::() + .expect("cast failed"); + + // this is guaranteed by DataFusion. We place it just to make it obvious. + assert_eq!(exponent.len(), base.len()); + + // 2. 
perform the computation + let array = base + .iter() + .zip(exponent.iter()) + .map(|(base, exponent)| { + match (base, exponent) { + // in arrow, any value can be null. + // Here we decide to make our UDF to return null when either base or exponent is null. + (Some(base), Some(exponent)) => Some(base.powf(exponent)), + _ => None, + } + }) + .collect::(); + + // `Ok` because no error occurred during the calculation (we should add one if exponent was [0, 1[ and the base < 0 because that panics!) + // `Arc` because arrays are immutable, thread-safe, trait objects. + Ok(Arc::new(array) as ArrayRef) + }; + // the function above expects an `ArrayRef`, but DataFusion may pass a scalar to a UDF. + // thus, we use `make_scalar_function` to decorare the closure so that it can handle both Arrays and Scalar values. + let pow = make_scalar_function(pow); // Next: // * give it a name so that it shows nicely when the plan is printed diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 976592ab6a6..ea79acdbc66 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -619,8 +619,8 @@ impl FunctionRegistry for ExecutionContextState { mod tests { use super::*; + use crate::physical_plan::functions::make_scalar_function; use crate::physical_plan::{collect, collect_partitioned}; - use crate::physical_plan::{functions::ScalarFunctionImplementation, ColumnarValue}; use crate::test; use crate::variable::VarType; use crate::{ @@ -631,7 +631,7 @@ mod tests { datasource::MemTable, logical_plan::create_udaf, physical_plan::expressions::AvgAccumulator, }; - use arrow::array::{Float64Array, Int32Array}; + use arrow::array::{ArrayRef, Float64Array, Int32Array}; use arrow::compute::add; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; @@ -1618,24 +1618,18 @@ mod tests { let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?; ctx.register_table("t", Box::new(provider)); - 
let myfunc: ScalarFunctionImplementation = - Arc::new(|args: &[ColumnarValue]| { - if let (ColumnarValue::Array(l), ColumnarValue::Array(r)) = - (&args[0], &args[1]) - { - let l = l - .as_any() - .downcast_ref::() - .expect("cast failed"); - let r = r - .as_any() - .downcast_ref::() - .expect("cast failed"); - Ok(ColumnarValue::Array(Arc::new(add(l, r)?))) - } else { - unimplemented!() - } - }); + let myfunc = |args: &[ArrayRef]| { + let l = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + let r = &args[1] + .as_any() + .downcast_ref::() + .expect("cast failed"); + Ok(Arc::new(add(l, r)?) as ArrayRef) + }; + let myfunc = make_scalar_function(myfunc); ctx.register_udf(create_udf( "my_add", diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 61a1bf1b9bb..13674723bb7 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -44,6 +44,7 @@ use crate::{ scalar::ScalarValue, }; use arrow::{ + array::ArrayRef, compute::kernels::length::length, datatypes::TimeUnit, datatypes::{DataType, Field, Schema}, @@ -568,6 +569,44 @@ impl PhysicalExpr for ScalarFunctionExpr { } } +/// decorates a function to handle [`ScalarValue`]s by coverting them to arrays before calling the function +/// and vice-versa after evaluation. +pub fn make_scalar_function(inner: F) -> ScalarFunctionImplementation +where + F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, +{ + Arc::new(move |args: &[ColumnarValue]| { + // first, identify if any of the arguments is an Array. If yes, store its `len`, + // as any scalar will need to be converted to an array of len `len`. 
+ let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + // to array + let args = if let Some(len) = len { + args.iter() + .map(|arg| arg.clone().into_array(len)) + .collect::>() + } else { + args.iter() + .map(|arg| arg.clone().into_array(1)) + .collect::>() + }; + + let result = (inner)(&args); + + // maybe back to scalar + if len.is_some() { + result.map(ColumnarValue::Array) + } else { + ScalarValue::try_from_array(&result?, 0).map(ColumnarValue::Scalar) + } + }) +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust/datafusion/src/physical_plan/mod.rs b/rust/datafusion/src/physical_plan/mod.rs index 2dac406e2a6..e4d761a1ca3 100644 --- a/rust/datafusion/src/physical_plan/mod.rs +++ b/rust/datafusion/src/physical_plan/mod.rs @@ -159,6 +159,7 @@ pub enum Distribution { } /// Represents the result from an expression +#[derive(Clone)] pub enum ColumnarValue { /// Array of values Array(ArrayRef), From ca5bf895e9a7ca2304d73822bf855e875e3bd24f Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Tue, 2 Feb 2021 08:05:52 +0100 Subject: [PATCH 7/8] Improved comments. 
--- .../src/physical_plan/crypto_expressions.rs | 20 ++++++++++++------- .../src/physical_plan/datetime_expressions.rs | 19 +++++++++++++++--- .../src/physical_plan/string_expressions.rs | 18 ++++++++++++----- rust/datafusion/src/scalar.rs | 8 ++++---- 4 files changed, 46 insertions(+), 19 deletions(-) diff --git a/rust/datafusion/src/physical_plan/crypto_expressions.rs b/rust/datafusion/src/physical_plan/crypto_expressions.rs index 134a098c7d9..4d787082691 100644 --- a/rust/datafusion/src/physical_plan/crypto_expressions.rs +++ b/rust/datafusion/src/physical_plan/crypto_expressions.rs @@ -30,7 +30,7 @@ use crate::{ scalar::ScalarValue, }; use arrow::{ - array::{Array, GenericBinaryArray, GenericStringArray, StringOffsetSizeTrait}, + array::{Array, BinaryArray, GenericStringArray, StringOffsetSizeTrait}, datatypes::DataType, }; @@ -58,11 +58,15 @@ fn sha_process(input: &str) -> SHA2DigestOutput { digest.finalize() } +/// # Errors +/// This function errors when: +/// * the number of arguments is not 1 +/// * the first argument is not castable to a `GenericStringArray` fn unary_binary_function( args: &[&dyn Array], op: F, name: &str, -) -> Result> +) -> Result where R: AsRef<[u8]>, T: StringOffsetSizeTrait, @@ -79,7 +83,9 @@ where let array = args[0] .as_any() .downcast_ref::>() - .unwrap(); + .ok_or_else(|| { + DataFusionError::Internal("failed to downcast to string".to_string()) + })?; // first map is the iterator, second is for the `Option<_>` Ok(array.iter().map(|x| x.map(|x| op(x))).collect()) @@ -111,8 +117,8 @@ where )?))) } other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function md5", - other, + "Unsupported data type {:?} for function {}", + other, name, ))), }, ColumnarValue::Scalar(scalar) => match scalar { @@ -125,8 +131,8 @@ where Ok(ColumnarValue::Scalar(ScalarValue::Binary(result))) } other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function md5", - other, + "Unsupported data type {:?} 
for function {}", + other, name, ))), }, } diff --git a/rust/datafusion/src/physical_plan/datetime_expressions.rs b/rust/datafusion/src/physical_plan/datetime_expressions.rs index b91611f6a40..60ca8e7ac82 100644 --- a/rust/datafusion/src/physical_plan/datetime_expressions.rs +++ b/rust/datafusion/src/physical_plan/datetime_expressions.rs @@ -176,6 +176,14 @@ fn naive_datetime_to_timestamp(s: &str, datetime: NaiveDateTime) -> Result } } +// given a function `op` that maps a `&str` to a Result of an arrow native type, +// returns a `PrimitiveArray` after the application +// of the function to `args[0]`. +/// # Errors +/// This function errors iff: +/// * the number of arguments is not 1 or +/// * the first argument is not castable to a `GenericStringArray` or +/// * the function `op` errors pub(crate) fn unary_string_to_primitive_function<'a, T, O, F>( args: &[&'a dyn Array], op: F, @@ -197,12 +205,17 @@ where let array = args[0] .as_any() .downcast_ref::>() - .unwrap(); + .ok_or_else(|| { + DataFusionError::Internal("failed to downcast to string".to_string()) + })?; // first map is the iterator, second is for the `Option<_>` array.iter().map(|x| x.map(|x| op(x)).transpose()).collect() } +// given an function that maps a `&str` to a arrow native type, +// returns a `ColumnarValue` where the function is applied to either a `ArrayRef` or `ScalarValue` +// depending on the `args`'s variant. 
fn handle<'a, O, F, S>( args: &'a [ColumnarValue], op: F, @@ -229,11 +242,11 @@ where ColumnarValue::Scalar(scalar) => match scalar { ScalarValue::Utf8(a) => { let result = a.as_ref().map(|x| (op)(x)).transpose()?; - Ok(ColumnarValue::Scalar(S::into_scalar(result))) + Ok(ColumnarValue::Scalar(S::scalar(result))) } ScalarValue::LargeUtf8(a) => { let result = a.as_ref().map(|x| (op)(x)).transpose()?; - Ok(ColumnarValue::Scalar(S::into_scalar(result))) + Ok(ColumnarValue::Scalar(S::scalar(result))) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function {}", diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index 018468460db..a4ccef08681 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -30,6 +30,12 @@ use arrow::{ use super::ColumnarValue; +/// applies a unary expression to `args[0]` that is expected to be downcastable to +/// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset) +/// # Errors +/// This function errors when: +/// * the number of arguments is not 1 +/// * the first argument is not castable to a `GenericStringArray` pub(crate) fn unary_string_function<'a, T, O, F, R>( args: &[&'a dyn Array], op: F, @@ -52,7 +58,9 @@ where let array = args[0] .as_any() .downcast_ref::>() - .unwrap(); + .ok_or_else(|| { + DataFusionError::Internal("failed to downcast to string".to_string()) + })?; // first map is the iterator, second is for the `Option<_>` Ok(array.iter().map(|x| x.map(|x| op(x))).collect()) @@ -86,8 +94,8 @@ where )?))) } other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function md5", - other, + "Unsupported data type {:?} for function {}", + other, name, ))), }, ColumnarValue::Scalar(scalar) => match scalar { @@ -100,8 +108,8 @@ where 
Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) } other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function md5", - other, + "Unsupported data type {:?} for function {}", + other, name, ))), }, } diff --git a/rust/datafusion/src/scalar.rs b/rust/datafusion/src/scalar.rs index 64d035ec885..ca0e27dd687 100644 --- a/rust/datafusion/src/scalar.rs +++ b/rust/datafusion/src/scalar.rs @@ -682,20 +682,20 @@ impl fmt::Debug for ScalarValue { } } -/// Trait used to map +/// Trait used to map a NativeTime to a ScalarType. pub trait ScalarType { /// returns a scalar from an optional T - fn into_scalar(r: Option) -> ScalarValue; + fn scalar(r: Option) -> ScalarValue; } impl ScalarType for Float32Type { - fn into_scalar(r: Option) -> ScalarValue { + fn scalar(r: Option) -> ScalarValue { ScalarValue::Float32(r) } } impl ScalarType for TimestampNanosecondType { - fn into_scalar(r: Option) -> ScalarValue { + fn scalar(r: Option) -> ScalarValue { ScalarValue::TimeNanosecond(r) } } From 27b01cf8af51ef2640c0e65da0458ee13641f486 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 14 Feb 2021 05:41:35 +0100 Subject: [PATCH 8/8] Fixed. 
--- rust/datafusion/src/physical_plan/datetime_expressions.rs | 3 +-- rust/datafusion/src/physical_plan/functions.rs | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/rust/datafusion/src/physical_plan/datetime_expressions.rs b/rust/datafusion/src/physical_plan/datetime_expressions.rs index 60ca8e7ac82..8642e3b40e3 100644 --- a/rust/datafusion/src/physical_plan/datetime_expressions.rs +++ b/rust/datafusion/src/physical_plan/datetime_expressions.rs @@ -380,7 +380,7 @@ mod tests { } #[test] - fn date_trunc_test() -> Result<()> { + fn date_trunc_test() { let cases = vec![ ( "2020-09-08T13:42:29.190855Z", @@ -435,7 +435,6 @@ mod tests { let result = date_trunc_single(granularity, original).unwrap(); assert_eq!(result, expected); }); - Ok(()) } #[test] diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 13674723bb7..c5cd01f93c5 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -722,7 +722,7 @@ mod tests { let expr = create_physical_expr( &BuiltinScalarFunction::Array, - &vec![col("a"), col("b")], + &[col("a"), col("b")], &schema, )?;