diff --git a/datafusion-physical-expr/Cargo.toml b/datafusion-physical-expr/Cargo.toml index 5c437caf5257d..85cd66fca8310 100644 --- a/datafusion-physical-expr/Cargo.toml +++ b/datafusion-physical-expr/Cargo.toml @@ -33,6 +33,10 @@ name = "datafusion_physical_expr" path = "src/lib.rs" [features] +default = ["crypto_expressions", "regex_expressions", "unicode_expressions"] +crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] +regex_expressions = ["regex"] +unicode_expressions = ["unicode-segmentation"] [dependencies] datafusion-common = { path = "../datafusion-common", version = "7.0.0" } @@ -41,3 +45,13 @@ arrow = { version = "9.0.0", features = ["prettyprint"] } paste = "^1.0" ahash = { version = "0.7", default-features = false } ordered-float = "2.10" +lazy_static = { version = "^1.4.0" } +md-5 = { version = "^0.10.0", optional = true } +sha2 = { version = "^0.10.1", optional = true } +blake2 = { version = "^0.10.2", optional = true } +blake3 = { version = "1.0", optional = true } +rand = "0.8" +hashbrown = { version = "0.12", features = ["raw"] } +chrono = { version = "0.4", default-features = false } +regex = { version = "^1.4.3", optional = true } +unicode-segmentation = { version = "^1.7.1", optional = true } diff --git a/datafusion-physical-expr/src/aggregate_expr.rs b/datafusion-physical-expr/src/aggregate_expr.rs index fc0f39e6977c4..a22472a496f70 100644 --- a/datafusion-physical-expr/src/aggregate_expr.rs +++ b/datafusion-physical-expr/src/aggregate_expr.rs @@ -16,13 +16,11 @@ // under the License. use crate::PhysicalExpr; - use arrow::datatypes::Field; use datafusion_common::Result; use datafusion_expr::Accumulator; -use std::fmt::Debug; - use std::any::Any; +use std::fmt::Debug; use std::sync::Arc; /// An aggregate expression that: diff --git a/datafusion/src/physical_plan/array_expressions.rs b/datafusion-physical-expr/src/array_expressions.rs similarity index 98% rename from datafusion/src/physical_plan/array_expressions.rs rename to datafusion-physical-expr/src/array_expressions.rs index a7e03b70e5d21..ca396d0b7b51d 100644 --- a/datafusion/src/physical_plan/array_expressions.rs +++ b/datafusion-physical-expr/src/array_expressions.rs @@ -17,13 +17,12 @@ //! Array expressions -use crate::error::{DataFusionError, Result}; use arrow::array::*; use arrow::datatypes::DataType; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::ColumnarValue; use std::sync::Arc; -use super::ColumnarValue; - macro_rules! downcast_vec { ($ARGS:expr, $ARRAY_TYPE:ident) => {{ $ARGS diff --git a/datafusion-physical-expr/src/coercion_rule/aggregate_rule.rs b/datafusion-physical-expr/src/coercion_rule/aggregate_rule.rs new file mode 100644 index 0000000000000..279fe7d31b7cb --- /dev/null +++ b/datafusion-physical-expr/src/coercion_rule/aggregate_rule.rs @@ -0,0 +1,262 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Support the coercion rule for aggregate function. + +use crate::expressions::{ + is_approx_percentile_cont_supported_arg_type, is_avg_support_arg_type, + is_correlation_support_arg_type, is_covariance_support_arg_type, + is_stddev_support_arg_type, is_sum_support_arg_type, is_variance_support_arg_type, + try_cast, +}; +use crate::PhysicalExpr; +use arrow::datatypes::DataType; +use arrow::datatypes::Schema; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::AggregateFunction; +use datafusion_expr::{Signature, TypeSignature}; +use std::ops::Deref; +use std::sync::Arc; + +/// Returns the coerced data type for each `input_types`. +/// Different aggregate function with different input data type will get corresponding coerced data type. +pub fn coerce_types( + agg_fun: &AggregateFunction, + input_types: &[DataType], + signature: &Signature, +) -> Result> { + // Validate input_types matches (at least one of) the func signature. + check_arg_count(agg_fun, input_types, &signature.type_signature)?; + + match agg_fun { + AggregateFunction::Count | AggregateFunction::ApproxDistinct => { + Ok(input_types.to_vec()) + } + AggregateFunction::ArrayAgg => Ok(input_types.to_vec()), + AggregateFunction::Min | AggregateFunction::Max => { + // min and max support the dictionary data type + // unpack the dictionary to get the value + get_min_max_result_type(input_types) + } + AggregateFunction::Sum => { + // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc + // smallint, int, bigint, real, double precision, decimal, or interval. + if !is_sum_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::Avg => { + // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc + // smallint, int, bigint, real, double precision, decimal, or interval + if !is_avg_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::Variance => { + if !is_variance_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::VariancePop => { + if !is_variance_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::Covariance => { + if !is_covariance_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::CovariancePop => { + if !is_covariance_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::Stddev => { + if !is_stddev_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs 
of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::StddevPop => { + if !is_stddev_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::Correlation => { + if !is_correlation_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::ApproxPercentileCont => { + if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + if !matches!(input_types[1], DataType::Float64) { + return Err(DataFusionError::Plan(format!( + "The percentile argument for {:?} must be Float64, not {:?}.", + agg_fun, input_types[1] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::ApproxMedian => { + if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + } +} + +/// Validate the length of `input_types` matches the `signature` for `agg_fun`. +/// +/// This method DOES NOT validate the argument types - only that (at least one, +/// in the case of [`TypeSignature::OneOf`]) signature matches the desired +/// number of input types. +fn check_arg_count( + agg_fun: &AggregateFunction, + input_types: &[DataType], + signature: &TypeSignature, +) -> Result<()> { + match signature { + TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { + if input_types.len() != *agg_count { + return Err(DataFusionError::Plan(format!( + "The function {:?} expects {:?} arguments, but {:?} were provided", + agg_fun, + agg_count, + input_types.len() + ))); + } + } + TypeSignature::Exact(types) => { + if types.len() != input_types.len() { + return Err(DataFusionError::Plan(format!( + "The function {:?} expects {:?} arguments, but {:?} were provided", + agg_fun, + types.len(), + input_types.len() + ))); + } + } + TypeSignature::OneOf(variants) => { + let ok = variants + .iter() + .any(|v| check_arg_count(agg_fun, input_types, v).is_ok()); + if !ok { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not accept {:?} function arguments.", + agg_fun, + input_types.len() + ))); + } + } + _ => { + return Err(DataFusionError::Internal(format!( + "Aggregate functions do not support this {:?}", + signature + ))); + } + } + Ok(()) +} + +fn get_min_max_result_type(input_types: &[DataType]) -> Result> { + // make sure that the input types only has one element. + assert_eq!(input_types.len(), 1); + // min and max support the dictionary data type + // unpack the dictionary to get the value + match &input_types[0] { + DataType::Dictionary(_, dict_value_type) => { + // TODO add checker, if the value type is complex data type + Ok(vec![dict_value_type.deref().clone()]) + } + // TODO add checker for datatype which min and max supported + // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function + _ => Ok(input_types.to_vec()), + } +} + +/// Returns the coerced exprs for each `input_exprs`. 
+/// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the +/// data type of `input_exprs` needs to be coerced. +pub fn coerce_exprs( + agg_fun: &AggregateFunction, + input_exprs: &[Arc<dyn PhysicalExpr>], + schema: &Schema, + signature: &Signature, +) -> Result<Vec<Arc<dyn PhysicalExpr>>> { + if input_exprs.is_empty() { + return Ok(vec![]); + } + let input_types = input_exprs + .iter() + .map(|e| e.data_type(schema)) + .collect::<Result<Vec<_>>>()?; + + // get the coerced data types + let coerced_types = coerce_types(agg_fun, &input_types, signature)?; + + // try cast if needed + input_exprs + .iter() + .zip(coerced_types.into_iter()) + .map(|(expr, coerced_type)| try_cast(expr.clone(), schema, coerced_type)) + .collect::<Result<Vec<_>>>() +} diff --git a/datafusion-physical-expr/src/coercion_rule/mod.rs b/datafusion-physical-expr/src/coercion_rule/mod.rs index a98154867f5a1..fa8d4da3c13a1 100644 --- a/datafusion-physical-expr/src/coercion_rule/mod.rs +++ b/datafusion-physical-expr/src/coercion_rule/mod.rs @@ -20,4 +20,5 @@ //! Aggregate function rule //! Binary operation rule +pub mod aggregate_rule; pub mod binary_rule; diff --git a/datafusion/src/physical_plan/crypto_expressions.rs b/datafusion-physical-expr/src/crypto_expressions.rs similarity index 98% rename from datafusion/src/physical_plan/crypto_expressions.rs rename to datafusion-physical-expr/src/crypto_expressions.rs index 2507a8d192489..95bedd4af41db 100644 --- a/datafusion/src/physical_plan/crypto_expressions.rs +++ b/datafusion-physical-expr/src/crypto_expressions.rs @@ -16,11 +16,7 @@ // under the License. //! Crypto expressions -use super::ColumnarValue; -use crate::{ - error::{DataFusionError, Result}, - scalar::ScalarValue, -}; + use arrow::{ array::{ Array, ArrayRef, BinaryArray, GenericStringArray, StringArray, @@ -30,6 +26,9 @@ use arrow::{ }; use blake2::{Blake2b512, Blake2s256, Digest}; use blake3::Hasher as Blake3; +use datafusion_common::ScalarValue; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::ColumnarValue; use md5::Md5; use sha2::{Sha224, Sha256, Sha384, Sha512}; use std::any::type_name; diff --git a/datafusion/src/physical_plan/datetime_expressions.rs b/datafusion-physical-expr/src/datetime_expressions.rs similarity index 99% rename from datafusion/src/physical_plan/datetime_expressions.rs rename to datafusion-physical-expr/src/datetime_expressions.rs index d1533d04a24f2..0dfd260d498e9 100644 --- a/datafusion/src/physical_plan/datetime_expressions.rs +++ b/datafusion-physical-expr/src/datetime_expressions.rs @@ -16,13 +16,7 @@ // under the License. //! 
DateTime expressions -use std::sync::Arc; -use super::ColumnarValue; -use crate::{ - error::{DataFusionError, Result}, - scalar::{ScalarType, ScalarValue}, -}; use arrow::{ array::{Array, ArrayRef, GenericStringArray, PrimitiveArray, StringOffsetSizeTrait}, compute::kernels::cast_utils::string_to_timestamp_nanos, @@ -42,7 +36,11 @@ use arrow::{ }; use chrono::prelude::*; use chrono::Duration; +use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{ScalarType, ScalarValue}; +use datafusion_expr::ColumnarValue; use std::borrow::Borrow; +use std::sync::Arc; /// given a function `op` that maps a `&str` to a Result of an arrow native type, /// returns a `PrimitiveArray` after the application diff --git a/datafusion-physical-expr/src/expressions/cume_dist.rs b/datafusion-physical-expr/src/expressions/cume_dist.rs index 7376f37b9b5fe..9cd28a3db3c6a 100644 --- a/datafusion-physical-expr/src/expressions/cume_dist.rs +++ b/datafusion-physical-expr/src/expressions/cume_dist.rs @@ -18,8 +18,8 @@ //! Defines physical expression for `cume_dist` that can evaluated //! at runtime during query execution +use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::BuiltInWindowFunctionExpr; -use crate::window::PartitionEvaluator; use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::array::Float64Array; diff --git a/datafusion-physical-expr/src/expressions/lead_lag.rs b/datafusion-physical-expr/src/expressions/lead_lag.rs index 333810711b7fc..4e286d59e768d 100644 --- a/datafusion-physical-expr/src/expressions/lead_lag.rs +++ b/datafusion-physical-expr/src/expressions/lead_lag.rs @@ -18,8 +18,8 @@ //! Defines physical expression for `lead` and `lag` that can evaluated //! at runtime during query execution +use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::BuiltInWindowFunctionExpr; -use crate::window::PartitionEvaluator; use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::compute::cast; diff --git a/datafusion-physical-expr/src/expressions/nth_value.rs b/datafusion-physical-expr/src/expressions/nth_value.rs index 21df07aa0fa74..e0a6b2bd7a7c1 100644 --- a/datafusion-physical-expr/src/expressions/nth_value.rs +++ b/datafusion-physical-expr/src/expressions/nth_value.rs @@ -18,8 +18,8 @@ //! Defines physical expressions for `first_value`, `last_value`, and `nth_value` //! that can evaluated at runtime during query execution +use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::BuiltInWindowFunctionExpr; -use crate::window::PartitionEvaluator; use crate::PhysicalExpr; use arrow::array::{new_null_array, ArrayRef}; use arrow::compute::kernels::window::shift; diff --git a/datafusion-physical-expr/src/expressions/rank.rs b/datafusion-physical-expr/src/expressions/rank.rs index f3f1143c83790..18bcf266b6676 100644 --- a/datafusion-physical-expr/src/expressions/rank.rs +++ b/datafusion-physical-expr/src/expressions/rank.rs @@ -18,8 +18,8 @@ //! Defines physical expression for `rank`, `dense_rank`, and `percent_rank` that can evaluated //! 
at runtime during query execution +use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::BuiltInWindowFunctionExpr; -use crate::window::PartitionEvaluator; use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::array::{Float64Array, UInt64Array}; diff --git a/datafusion-physical-expr/src/expressions/row_number.rs b/datafusion-physical-expr/src/expressions/row_number.rs index f9dccee5023ab..8a720d28d6195 100644 --- a/datafusion-physical-expr/src/expressions/row_number.rs +++ b/datafusion-physical-expr/src/expressions/row_number.rs @@ -17,8 +17,8 @@ //! Defines physical expression for `row_number` that can evaluated at runtime during query execution +use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::BuiltInWindowFunctionExpr; -use crate::window::PartitionEvaluator; use crate::PhysicalExpr; use arrow::array::{ArrayRef, UInt64Array}; use arrow::datatypes::{DataType, Field}; diff --git a/datafusion-physical-expr/src/functions.rs b/datafusion-physical-expr/src/functions.rs new file mode 100644 index 0000000000000..1350d49510d58 --- /dev/null +++ b/datafusion-physical-expr/src/functions.rs @@ -0,0 +1,148 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Declaration of built-in (scalar) functions. +//! This module contains built-in functions' enumeration and metadata. +//! +//! Generally, a function has: +//! * a signature +//! * a return type, that is a function of the incoming argument's types +//! * the computation, that must accept each valid signature +//! +//! * Signature: see `Signature` +//! * Return type: a function `(arg_types) -> return_type`. E.g. for sqrt, ([f32]) -> f32, ([f64]) -> f64. +//! +//! This module also has a set of coercion rules to improve user experience: if an argument i32 is passed +//! to a function that supports f64, it is coerced to f64. 
+ +use crate::PhysicalExpr; +use arrow::datatypes::{DataType, Schema}; +use arrow::record_batch::RecordBatch; +use datafusion_common::Result; +use datafusion_expr::BuiltinScalarFunction; +use datafusion_expr::ColumnarValue; +pub use datafusion_expr::NullColumnarValue; +use datafusion_expr::ScalarFunctionImplementation; +use std::any::Any; +use std::fmt::Debug; +use std::fmt::{self, Formatter}; +use std::sync::Arc; + +/// Physical expression of a scalar function +pub struct ScalarFunctionExpr { + fun: ScalarFunctionImplementation, + name: String, + args: Vec<Arc<dyn PhysicalExpr>>, + return_type: DataType, +} + +impl Debug for ScalarFunctionExpr { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("ScalarFunctionExpr") + .field("fun", &"<FUNC>") + .field("name", &self.name) + .field("args", &self.args) + .field("return_type", &self.return_type) + .finish() + } +} + +impl ScalarFunctionExpr { + /// Create a new Scalar function + pub fn new( + name: &str, + fun: ScalarFunctionImplementation, + args: Vec<Arc<dyn PhysicalExpr>>, + return_type: &DataType, + ) -> Self { + Self { + fun, + name: name.to_owned(), + args, + return_type: return_type.clone(), + } + } + + /// Get the scalar function implementation + pub fn fun(&self) -> &ScalarFunctionImplementation { + &self.fun + } + + /// The name for this expression + pub fn name(&self) -> &str { + &self.name + } + + /// Input arguments + pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] { + &self.args + } + + /// Data type produced by this expression + pub fn return_type(&self) -> &DataType { + &self.return_type + } +} + +impl fmt::Display for ScalarFunctionExpr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{}({})", + self.name, + self.args + .iter() + .map(|e| format!("{}", e)) + .collect::<Vec<String>>() + .join(", ") + ) + } +} + +impl PhysicalExpr for ScalarFunctionExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result<DataType> { + Ok(self.return_type.clone()) + } + + fn nullable(&self, _input_schema: &Schema) -> Result<bool> { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> { + // evaluate the arguments, if there are no arguments we'll instead pass in a null array + // indicating the batch size (as a convention) + let inputs = match (self.args.len(), self.name.parse::<BuiltinScalarFunction>()) { + (0, Ok(scalar_fun)) if scalar_fun.supports_zero_argument() => { + vec![NullColumnarValue::from(batch)] + } + _ => self + .args + .iter() + .map(|e| e.evaluate(batch)) + .collect::<Result<Vec<_>>>()?, + }; + + // evaluate the function + let fun = self.fun.as_ref(); + (fun)(&inputs) + } +} diff --git a/datafusion-physical-expr/src/lib.rs b/datafusion-physical-expr/src/lib.rs index ead338fc2ebd3..8a2fe25046418 100644 --- a/datafusion-physical-expr/src/lib.rs +++ b/datafusion-physical-expr/src/lib.rs @@ -16,15 +16,27 @@ // under the License. 
mod aggregate_expr; +pub mod array_expressions; pub mod coercion_rule; +#[cfg(feature = "crypto_expressions")] +pub mod crypto_expressions; +pub mod datetime_expressions; pub mod expressions; pub mod field_util; +mod functions; mod hyperloglog; +pub mod math_expressions; mod physical_expr; +#[cfg(feature = "regex_expressions")] +pub mod regex_expressions; mod sort_expr; +pub mod string_expressions; mod tdigest; +#[cfg(feature = "unicode_expressions")] +pub mod unicode_expressions; pub mod window; pub use aggregate_expr::AggregateExpr; +pub use functions::ScalarFunctionExpr; pub use physical_expr::PhysicalExpr; pub use sort_expr::PhysicalSortExpr; diff --git a/datafusion/src/physical_plan/math_expressions.rs b/datafusion-physical-expr/src/math_expressions.rs similarity index 97% rename from datafusion/src/physical_plan/math_expressions.rs rename to datafusion-physical-expr/src/math_expressions.rs index eabacfc6eb183..b16a59634f505 100644 --- a/datafusion/src/physical_plan/math_expressions.rs +++ b/datafusion-physical-expr/src/math_expressions.rs @@ -16,10 +16,12 @@ // under the License. //! Math expressions -use super::{ColumnarValue, ScalarValue}; -use crate::error::{DataFusionError, Result}; + use arrow::array::{Float32Array, Float64Array}; use arrow::datatypes::DataType; +use datafusion_common::ScalarValue; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::ColumnarValue; use rand::{thread_rng, Rng}; use std::iter; use std::sync::Arc; diff --git a/datafusion/src/physical_plan/regex_expressions.rs b/datafusion-physical-expr/src/regex_expressions.rs similarity index 95% rename from datafusion/src/physical_plan/regex_expressions.rs rename to datafusion-physical-expr/src/regex_expressions.rs index cf997c0328407..69de68e166f65 100644 --- a/datafusion/src/physical_plan/regex_expressions.rs +++ b/datafusion-physical-expr/src/regex_expressions.rs @@ -21,15 +21,14 @@ //! Regex expressions -use std::any::type_name; -use std::sync::Arc; - -use crate::error::{DataFusionError, Result}; use arrow::array::{ArrayRef, GenericStringArray, StringOffsetSizeTrait}; use arrow::compute; +use datafusion_common::{DataFusionError, Result}; use hashbrown::HashMap; use lazy_static::lazy_static; use regex::Regex; +use std::any::type_name; +use std::sync::Arc; macro_rules! downcast_string_arg { ($ARG:expr, $NAME:expr, $T:ident) => {{ @@ -182,14 +181,13 @@ pub fn regexp_replace(args: &[ArrayRef]) -> Result = GenericStringBuilder::new(0); let mut expected_builder = ListBuilder::new(elem_builder); @@ -209,10 +207,10 @@ mod tests { #[test] fn test_case_insensitive_regexp_match() { - let values = StringArray::from_slice(&["abc"; 5]); + let values = StringArray::from(vec!["abc"; 5]); let patterns = - StringArray::from_slice(&["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); - let flags = StringArray::from_slice(&["i"; 5]); + StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); + let flags = StringArray::from(vec!["i"; 5]); let elem_builder: GenericStringBuilder = GenericStringBuilder::new(0); let mut expected_builder = ListBuilder::new(elem_builder); diff --git a/datafusion-physical-expr/src/sort_expr.rs b/datafusion-physical-expr/src/sort_expr.rs index e8172dd9979eb..79656725d4f44 100644 --- a/datafusion-physical-expr/src/sort_expr.rs +++ b/datafusion-physical-expr/src/sort_expr.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! 
Sort expressions + use crate::PhysicalExpr; use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::record_batch::RecordBatch; diff --git a/datafusion/src/physical_plan/string_expressions.rs b/datafusion-physical-expr/src/string_expressions.rs similarity index 99% rename from datafusion/src/physical_plan/string_expressions.rs rename to datafusion-physical-expr/src/string_expressions.rs index a9e4c2fc54b1d..b0b569d99eca6 100644 --- a/datafusion/src/physical_plan/string_expressions.rs +++ b/datafusion-physical-expr/src/string_expressions.rs @@ -21,13 +21,6 @@ //! String expressions -use std::any::type_name; -use std::sync::Arc; - -use crate::{ - error::{DataFusionError, Result}, - scalar::ScalarValue, -}; use arrow::{ array::{ Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, Int64Array, @@ -35,8 +28,11 @@ use arrow::{ }, datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, }; - -use super::ColumnarValue; +use datafusion_common::ScalarValue; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::ColumnarValue; +use std::any::type_name; +use std::sync::Arc; macro_rules! downcast_string_arg { ($ARG:expr, $NAME:expr, $T:ident) => {{ diff --git a/datafusion/src/physical_plan/unicode_expressions.rs b/datafusion-physical-expr/src/unicode_expressions.rs similarity index 99% rename from datafusion/src/physical_plan/unicode_expressions.rs rename to datafusion-physical-expr/src/unicode_expressions.rs index 1d2c4b765a2dc..86a2ef7ba9a06 100644 --- a/datafusion/src/physical_plan/unicode_expressions.rs +++ b/datafusion-physical-expr/src/unicode_expressions.rs @@ -21,18 +21,17 @@ //! Unicode expressions -use std::any::type_name; -use std::cmp::Ordering; -use std::sync::Arc; - -use crate::error::{DataFusionError, Result}; use arrow::{ array::{ ArrayRef, GenericStringArray, Int64Array, PrimitiveArray, StringOffsetSizeTrait, }, datatypes::{ArrowNativeType, ArrowPrimitiveType}, }; +use datafusion_common::{DataFusionError, Result}; use hashbrown::HashMap; +use std::any::type_name; +use std::cmp::Ordering; +use std::sync::Arc; use unicode_segmentation::UnicodeSegmentation; macro_rules! downcast_string_arg { diff --git a/datafusion/src/physical_plan/windows/aggregate.rs b/datafusion-physical-expr/src/window/aggregate.rs similarity index 95% rename from datafusion/src/physical_plan/windows/aggregate.rs rename to datafusion-physical-expr/src/window/aggregate.rs index 30e0b2994f50b..9caa847c02c5a 100644 --- a/datafusion/src/physical_plan/windows/aggregate.rs +++ b/datafusion-physical-expr/src/window/aggregate.rs @@ -17,15 +17,16 @@ //! Physical exec for aggregate window function expressions. 
-use crate::error::{DataFusionError, Result}; -use crate::physical_plan::{ - expressions::PhysicalSortExpr, Accumulator, AggregateExpr, PhysicalExpr, WindowExpr, -}; +use crate::window::partition_evaluator::find_ranges_in_range; +use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; +use crate::{window::WindowExpr, AggregateExpr}; use arrow::compute::concat; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; +use datafusion_common::DataFusionError; +use datafusion_common::Result; +use datafusion_expr::Accumulator; use datafusion_expr::{WindowFrame, WindowFrameUnits}; -use datafusion_physical_expr::window::find_ranges_in_range; use std::any::Any; use std::iter::IntoIterator; use std::ops::Range; @@ -42,7 +43,7 @@ pub struct AggregateWindowExpr { impl AggregateWindowExpr { /// create a new aggregate window function expression - pub(super) fn new( + pub fn new( aggregate: Arc, partition_by: &[Arc], order_by: &[PhysicalSortExpr], diff --git a/datafusion/src/physical_plan/windows/built_in.rs b/datafusion-physical-expr/src/window/built_in.rs similarity index 93% rename from datafusion/src/physical_plan/windows/built_in.rs rename to datafusion-physical-expr/src/window/built_in.rs index 3ded850432bf2..2fa1f808fda83 100644 --- a/datafusion/src/physical_plan/windows/built_in.rs +++ b/datafusion-physical-expr/src/window/built_in.rs @@ -17,12 +17,14 @@ //! Physical exec for built-in window function expressions. -use crate::error::{DataFusionError, Result}; -use crate::physical_plan::{expressions::PhysicalSortExpr, PhysicalExpr, WindowExpr}; +use super::BuiltInWindowFunctionExpr; +use super::WindowExpr; +use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; use arrow::compute::concat; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_physical_expr::window::BuiltInWindowFunctionExpr; +use datafusion_common::DataFusionError; +use datafusion_common::Result; use std::any::Any; use std::sync::Arc; @@ -36,7 +38,7 @@ pub struct BuiltInWindowExpr { impl BuiltInWindowExpr { /// create a new built-in window function expression - pub(super) fn new( + pub fn new( expr: Arc, partition_by: &[Arc], order_by: &[PhysicalSortExpr], diff --git a/datafusion-physical-expr/src/window/mod.rs b/datafusion-physical-expr/src/window/mod.rs index 48a6e8b4f589d..044cd1491a9a3 100644 --- a/datafusion-physical-expr/src/window/mod.rs +++ b/datafusion-physical-expr/src/window/mod.rs @@ -15,11 +15,13 @@ // specific language governing permissions and limitations // under the License. +mod aggregate; +mod built_in; mod built_in_window_function_expr; -mod partition_evaluator; +pub(crate) mod partition_evaluator; mod window_expr; +pub use aggregate::AggregateWindowExpr; +pub use built_in::BuiltInWindowExpr; pub use built_in_window_function_expr::BuiltInWindowFunctionExpr; -pub use partition_evaluator::find_ranges_in_range; -pub use partition_evaluator::PartitionEvaluator; pub use window_expr::WindowExpr; diff --git a/datafusion-physical-expr/src/window/partition_evaluator.rs b/datafusion-physical-expr/src/window/partition_evaluator.rs index 9afdf3860d0ec..c3a88367a2c24 100644 --- a/datafusion-physical-expr/src/window/partition_evaluator.rs +++ b/datafusion-physical-expr/src/window/partition_evaluator.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! 
partition evaluation module + use arrow::array::ArrayRef; use datafusion_common::DataFusionError; use datafusion_common::Result; @@ -25,7 +27,7 @@ use std::ops::Range; /// boundaries would align (what's sorted on [partition columns...] would definitely be sorted /// on finer columns), so this will use binary search to find ranges that are within the /// partition range and return the valid slice. -pub fn find_ranges_in_range<'a>( +pub(crate) fn find_ranges_in_range<'a>( partition_range: &Range<usize>, sort_partition_points: &'a [Range<usize>], ) -> &'a [Range<usize>] { diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 81722fd91c5ba..9a270199dc906 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -40,9 +40,9 @@ path = "src/lib.rs" [features] default = ["crypto_expressions", "regex_expressions", "unicode_expressions"] simd = ["arrow/simd"] -crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] -regex_expressions = ["regex"] -unicode_expressions = ["unicode-segmentation"] +crypto_expressions = ["datafusion-physical-expr/crypto_expressions"] +unicode_expressions = ["datafusion-physical-expr/unicode_expressions"] +regex_expressions = ["datafusion-physical-expr/regex_expressions"] pyarrow = ["pyo3", "arrow/pyarrow", "datafusion-common/pyarrow"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = [] @@ -72,13 +72,7 @@ pin-project-lite= "^0.2.7" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } tokio-stream = "0.1" log = "^0.4" -md-5 = { version = "^0.10.0", optional = true } -sha2 = { version = "^0.10.1", optional = true } -blake2 = { version = "^0.10.2", optional = true } -blake3 = { version = "1.0", optional = true } ordered-float = "2.10" -unicode-segmentation = { version = "^1.7.1", optional = true } -regex = { version = "^1.4.3", optional = true } lazy_static = { version = "^1.4.0" } smallvec = { version = "1.6", features = ["union"] } rand = "0.8" diff --git a/datafusion/src/physical_plan/aggregate_rule.rs b/datafusion/src/physical_plan/aggregate_rule.rs index 2e510129a1d23..41ff4a65c9cf8 100644 --- a/datafusion/src/physical_plan/aggregate_rule.rs +++ b/datafusion/src/physical_plan/aggregate_rule.rs @@ -17,249 +17,9 @@ //! Support the coercion rule for aggregate function. -use arrow::datatypes::DataType; -use arrow::datatypes::Schema; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::AggregateFunction; -use datafusion_expr::{Signature, TypeSignature}; -use datafusion_physical_expr::expressions::is_approx_percentile_cont_supported_arg_type; -use datafusion_physical_expr::expressions::{ - is_avg_support_arg_type, is_correlation_support_arg_type, - is_covariance_support_arg_type, is_stddev_support_arg_type, is_sum_support_arg_type, - is_variance_support_arg_type, try_cast, +pub use datafusion_physical_expr::coercion_rule::aggregate_rule::{ + coerce_exprs, coerce_types, }; -use datafusion_physical_expr::PhysicalExpr; -use std::ops::Deref; -use std::sync::Arc; - -/// Returns the coerced data type for each `input_types`. -/// Different aggregate function with different input data type will get corresponding coerced data type. -pub fn coerce_types( - agg_fun: &AggregateFunction, - input_types: &[DataType], - signature: &Signature, -) -> Result<Vec<DataType>> { - // Validate input_types matches (at least one of) the func signature. 
- check_arg_count(agg_fun, input_types, &signature.type_signature)?; - - match agg_fun { - AggregateFunction::Count | AggregateFunction::ApproxDistinct => { - Ok(input_types.to_vec()) - } - AggregateFunction::ArrayAgg => Ok(input_types.to_vec()), - AggregateFunction::Min | AggregateFunction::Max => { - // min and max support the dictionary data type - // unpack the dictionary to get the value - get_min_max_result_type(input_types) - } - AggregateFunction::Sum => { - // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc - // smallint, int, bigint, real, double precision, decimal, or interval. - if !is_sum_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); - } - Ok(input_types.to_vec()) - } - AggregateFunction::Avg => { - // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc - // smallint, int, bigint, real, double precision, decimal, or interval - if !is_avg_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); - } - Ok(input_types.to_vec()) - } - AggregateFunction::Variance => { - if !is_variance_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); - } - Ok(input_types.to_vec()) - } - AggregateFunction::VariancePop => { - if !is_variance_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); - } - Ok(input_types.to_vec()) - } - AggregateFunction::Covariance => { - if !is_covariance_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); - } - Ok(input_types.to_vec()) - } - AggregateFunction::CovariancePop => { - if !is_covariance_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); - } - Ok(input_types.to_vec()) - } - AggregateFunction::Stddev => { - if !is_stddev_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); - } - Ok(input_types.to_vec()) - } - AggregateFunction::StddevPop => { - if !is_stddev_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); - } - Ok(input_types.to_vec()) - } - AggregateFunction::Correlation => { - if !is_correlation_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); - } - Ok(input_types.to_vec()) - } - AggregateFunction::ApproxPercentileCont => { - if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); - } - if !matches!(input_types[1], DataType::Float64) { - return Err(DataFusionError::Plan(format!( - "The percentile argument for {:?} must be Float64, not {:?}.", - agg_fun, input_types[1] - ))); - } 
- Ok(input_types.to_vec()) - } - AggregateFunction::ApproxMedian => { - if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); - } - Ok(input_types.to_vec()) - } - } -} - -/// Validate the length of `input_types` matches the `signature` for `agg_fun`. -/// -/// This method DOES NOT validate the argument types - only that (at least one, -/// in the case of [`TypeSignature::OneOf`]) signature matches the desired -/// number of input types. -fn check_arg_count( - agg_fun: &AggregateFunction, - input_types: &[DataType], - signature: &TypeSignature, -) -> Result<()> { - match signature { - TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { - if input_types.len() != *agg_count { - return Err(DataFusionError::Plan(format!( - "The function {:?} expects {:?} arguments, but {:?} were provided", - agg_fun, - agg_count, - input_types.len() - ))); - } - } - TypeSignature::Exact(types) => { - if types.len() != input_types.len() { - return Err(DataFusionError::Plan(format!( - "The function {:?} expects {:?} arguments, but {:?} were provided", - agg_fun, - types.len(), - input_types.len() - ))); - } - } - TypeSignature::OneOf(variants) => { - let ok = variants - .iter() - .any(|v| check_arg_count(agg_fun, input_types, v).is_ok()); - if !ok { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not accept {:?} function arguments.", - agg_fun, - input_types.len() - ))); - } - } - _ => { - return Err(DataFusionError::Internal(format!( - "Aggregate functions do not support this {:?}", - signature - ))); - } - } - Ok(()) -} - -fn get_min_max_result_type(input_types: &[DataType]) -> Result> { - // make sure that the input types only has one element. - assert_eq!(input_types.len(), 1); - // min and max support the dictionary data type - // unpack the dictionary to get the value - match &input_types[0] { - DataType::Dictionary(_, dict_value_type) => { - // TODO add checker, if the value type is complex data type - Ok(vec![dict_value_type.deref().clone()]) - } - // TODO add checker for datatype which min and max supported - // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function - _ => Ok(input_types.to_vec()), - } -} - -/// Returns the coerced exprs for each `input_exprs`. -/// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the -/// data type of `input_exprs` need to be coerced. 
-pub fn coerce_exprs( - agg_fun: &AggregateFunction, - input_exprs: &[Arc], - schema: &Schema, - signature: &Signature, -) -> Result>> { - if input_exprs.is_empty() { - return Ok(vec![]); - } - let input_types = input_exprs - .iter() - .map(|e| e.data_type(schema)) - .collect::>>()?; - - // get the coerced data types - let coerced_types = coerce_types(agg_fun, &input_types, signature)?; - - // try cast if need - input_exprs - .iter() - .zip(coerced_types.into_iter()) - .map(|(expr, coerced_type)| try_cast(expr.clone(), schema, coerced_type)) - .collect::>>() -} #[cfg(test)] mod tests { diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index bf0aee9e6aa00..07151dd20f60a 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -34,13 +34,9 @@ use super::{ ColumnarValue, PhysicalExpr, }; use crate::execution::context::ExecutionProps; -use crate::physical_plan::array_expressions; -use crate::physical_plan::datetime_expressions; use crate::physical_plan::expressions::{ cast_column, nullif_func, DEFAULT_DATAFUSION_CAST_OPTIONS, SUPPORTED_NULLIF_TYPES, }; -use crate::physical_plan::math_expressions; -use crate::physical_plan::string_expressions; use crate::{ error::{DataFusionError, Result}, scalar::ScalarValue, @@ -50,27 +46,14 @@ use arrow::{ compute::kernels::length::{bit_length, length}, datatypes::TimeUnit, datatypes::{DataType, Field, Int32Type, Int64Type, Schema}, - record_batch::RecordBatch, }; -pub use datafusion_expr::NullColumnarValue; +use datafusion_expr::ScalarFunctionImplementation; pub use datafusion_expr::{BuiltinScalarFunction, Signature, TypeSignature, Volatility}; -use fmt::{Debug, Formatter}; -use std::{any::Any, fmt, sync::Arc}; - -/// Scalar function -/// -/// The Fn param is the wrapped function but be aware that the function will -/// be passed with the slice / vec of columnar values (either scalar or array) -/// with the exception of zero param function, where a singular element vec -/// will be passed. In that case the single element is a null array to indicate -/// the batch's row count (so that the generative zero-argument function can know -/// the result array size). -pub type ScalarFunctionImplementation = - Arc Result + Send + Sync>; - -/// A function's return type -pub type ReturnTypeFunction = - Arc Result> + Send + Sync>; +use datafusion_physical_expr::array_expressions; +use datafusion_physical_expr::datetime_expressions; +use datafusion_physical_expr::math_expressions; +use datafusion_physical_expr::string_expressions; +use std::sync::Arc; macro_rules! make_utf8_to_return_type { ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { @@ -256,597 +239,106 @@ pub fn return_type( } } -#[cfg(feature = "crypto_expressions")] -macro_rules! invoke_if_crypto_expressions_feature_flag { - ($FUNC:ident, $NAME:expr) => {{ - use crate::physical_plan::crypto_expressions; - crypto_expressions::$FUNC - }}; -} - -#[cfg(not(feature = "crypto_expressions"))] -macro_rules! invoke_if_crypto_expressions_feature_flag { - ($FUNC:ident, $NAME:expr) => { - |_: &[ColumnarValue]| -> Result { - Err(DataFusionError::Internal(format!( - "function {} requires compilation with feature flag: crypto_expressions.", - $NAME - ))) - } - }; -} - -#[cfg(feature = "regex_expressions")] -macro_rules! 
invoke_if_regex_expressions_feature_flag { - ($FUNC:ident, $T:tt, $NAME:expr) => {{ - use crate::physical_plan::regex_expressions; - regex_expressions::$FUNC::<$T> - }}; -} - -#[cfg(not(feature = "regex_expressions"))] -macro_rules! invoke_if_regex_expressions_feature_flag { - ($FUNC:ident, $T:tt, $NAME:expr) => { - |_: &[ArrayRef]| -> Result { - Err(DataFusionError::Internal(format!( - "function {} requires compilation with feature flag: regex_expressions.", - $NAME - ))) - } - }; -} +/// Create a physical (function) expression. +/// This function errors when `args`' can't be coerced to a valid argument type of the function. +pub fn create_physical_expr( + fun: &BuiltinScalarFunction, + input_phy_exprs: &[Arc], + input_schema: &Schema, + execution_props: &ExecutionProps, +) -> Result> { + let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, &signature(fun))?; -#[cfg(feature = "unicode_expressions")] -macro_rules! invoke_if_unicode_expressions_feature_flag { - ($FUNC:ident, $T:tt, $NAME:expr) => {{ - use crate::physical_plan::unicode_expressions; - unicode_expressions::$FUNC::<$T> - }}; -} + let coerced_expr_types = coerced_phy_exprs + .iter() + .map(|e| e.data_type(input_schema)) + .collect::>>()?; -#[cfg(not(feature = "unicode_expressions"))] -macro_rules! invoke_if_unicode_expressions_feature_flag { - ($FUNC:ident, $T:tt, $NAME:expr) => { - |_: &[ArrayRef]| -> Result { - Err(DataFusionError::Internal(format!( - "function {} requires compilation with feature flag: unicode_expressions.", - $NAME - ))) - } - }; -} + let data_type = return_type(fun, &coerced_expr_types)?; -/// Create a physical scalar function. -pub fn create_physical_fun( - fun: &BuiltinScalarFunction, - execution_props: &ExecutionProps, -) -> Result { - Ok(match fun { - // math functions - BuiltinScalarFunction::Abs => Arc::new(math_expressions::abs), - BuiltinScalarFunction::Acos => Arc::new(math_expressions::acos), - BuiltinScalarFunction::Asin => Arc::new(math_expressions::asin), - BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan), - BuiltinScalarFunction::Ceil => Arc::new(math_expressions::ceil), - BuiltinScalarFunction::Cos => Arc::new(math_expressions::cos), - BuiltinScalarFunction::Exp => Arc::new(math_expressions::exp), - BuiltinScalarFunction::Floor => Arc::new(math_expressions::floor), - BuiltinScalarFunction::Log => Arc::new(math_expressions::log10), - BuiltinScalarFunction::Ln => Arc::new(math_expressions::ln), - BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10), - BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2), - BuiltinScalarFunction::Random => Arc::new(math_expressions::random), - BuiltinScalarFunction::Round => Arc::new(math_expressions::round), - BuiltinScalarFunction::Signum => Arc::new(math_expressions::signum), - BuiltinScalarFunction::Sin => Arc::new(math_expressions::sin), - BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt), - BuiltinScalarFunction::Tan => Arc::new(math_expressions::tan), - BuiltinScalarFunction::Trunc => Arc::new(math_expressions::trunc), - // string functions - BuiltinScalarFunction::Array => Arc::new(array_expressions::array), - BuiltinScalarFunction::Ascii => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::ascii::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::ascii::)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function ascii", - other, - ))), - }), - 
BuiltinScalarFunction::BitLength => Arc::new(|args| match &args[0] { - ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), - ColumnarValue::Scalar(v) => match v { - ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( - v.as_ref().map(|x| (x.len() * 8) as i32), - ))), - ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( - ScalarValue::Int64(v.as_ref().map(|x| (x.len() * 8) as i64)), - )), - _ => unreachable!(), - }, - }), - BuiltinScalarFunction::Btrim => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::btrim::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::btrim::)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function btrim", - other, - ))), - }), - BuiltinScalarFunction::CharacterLength => { - Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - character_length, - Int32Type, - "character_length" - ); - make_scalar_function(func)(args) + let fun_expr: ScalarFunctionImplementation = match fun { + // These functions need args and input schema to pick an implementation + // Unlike the string functions, which actually figure out the function to use with each array, + // here we return either a cast fn or string timestamp translation based on the expression data type + // so we don't have to pay a per-array/batch cost. + BuiltinScalarFunction::ToTimestamp => { + Arc::new(match coerced_phy_exprs[0].data_type(input_schema) { + Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { + |col_values: &[ColumnarValue]| { + cast_column( + &col_values[0], + &DataType::Timestamp(TimeUnit::Nanosecond, None), + &DEFAULT_DATAFUSION_CAST_OPTIONS, + ) + } } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - character_length, - Int64Type, - "character_length" - ); - make_scalar_function(func)(args) + Ok(DataType::Utf8) => datetime_expressions::to_timestamp, + other => { + return Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function to_timestamp", + other, + ))) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function character_length", - other, - ))), }) } - BuiltinScalarFunction::Chr => { - Arc::new(|args| make_scalar_function(string_expressions::chr)(args)) - } - BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat), - BuiltinScalarFunction::ConcatWithSeparator => { - Arc::new(|args| make_scalar_function(string_expressions::concat_ws)(args)) + BuiltinScalarFunction::ToTimestampMillis => { + Arc::new(match coerced_phy_exprs[0].data_type(input_schema) { + Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { + |col_values: &[ColumnarValue]| { + cast_column( + &col_values[0], + &DataType::Timestamp(TimeUnit::Millisecond, None), + &DEFAULT_DATAFUSION_CAST_OPTIONS, + ) + } + } + Ok(DataType::Utf8) => datetime_expressions::to_timestamp_millis, + other => { + return Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function to_timestamp_millis", + other, + ))) + } + }) } - BuiltinScalarFunction::DatePart => Arc::new(datetime_expressions::date_part), - BuiltinScalarFunction::DateTrunc => Arc::new(datetime_expressions::date_trunc), - BuiltinScalarFunction::Now => { - // bind value for now at plan time - Arc::new(datetime_expressions::make_now( - execution_props.query_execution_start_time, - )) + 
BuiltinScalarFunction::ToTimestampMicros => { + Arc::new(match coerced_phy_exprs[0].data_type(input_schema) { + Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { + |col_values: &[ColumnarValue]| { + cast_column( + &col_values[0], + &DataType::Timestamp(TimeUnit::Microsecond, None), + &DEFAULT_DATAFUSION_CAST_OPTIONS, + ) + } + } + Ok(DataType::Utf8) => datetime_expressions::to_timestamp_micros, + other => { + return Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function to_timestamp_micros", + other, + ))) + } + }) } - BuiltinScalarFunction::InitCap => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::initcap::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::initcap::)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function initcap", - other, - ))), - }), - BuiltinScalarFunction::Left => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(left, i32, "left"); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(left, i64, "left"); - make_scalar_function(func)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function left", - other, - ))), - }), - BuiltinScalarFunction::Lower => Arc::new(string_expressions::lower), - BuiltinScalarFunction::Lpad => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(lpad, i32, "lpad"); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(lpad, i64, "lpad"); - make_scalar_function(func)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function lpad", - other, - ))), - }), - BuiltinScalarFunction::Ltrim => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::ltrim::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::ltrim::)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function ltrim", - other, - ))), - }), - BuiltinScalarFunction::MD5 => { - Arc::new(invoke_if_crypto_expressions_feature_flag!(md5, "md5")) - } - BuiltinScalarFunction::Digest => { - Arc::new(invoke_if_crypto_expressions_feature_flag!(digest, "digest")) - } - BuiltinScalarFunction::NullIf => Arc::new(nullif_func), - BuiltinScalarFunction::OctetLength => Arc::new(|args| match &args[0] { - ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), - ColumnarValue::Scalar(v) => match v { - ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( - v.as_ref().map(|x| x.len() as i32), - ))), - ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( - ScalarValue::Int64(v.as_ref().map(|x| x.len() as i64)), - )), - _ => unreachable!(), - }, - }), - BuiltinScalarFunction::RegexpMatch => { - Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_regex_expressions_feature_flag!( - regexp_match, - i32, - "regexp_match" - ); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_regex_expressions_feature_flag!( - regexp_match, - i64, - "regexp_match" - ); - make_scalar_function(func)(args) - } - other => Err(DataFusionError::Internal(format!( - 
"Unsupported data type {:?} for function regexp_match", - other - ))), - }) - } - BuiltinScalarFunction::RegexpReplace => { - Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_regex_expressions_feature_flag!( - regexp_replace, - i32, - "regexp_replace" - ); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_regex_expressions_feature_flag!( - regexp_replace, - i64, - "regexp_replace" - ); - make_scalar_function(func)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function regexp_replace", - other, - ))), - }) - } - BuiltinScalarFunction::Repeat => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::repeat::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::repeat::)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function repeat", - other, - ))), - }), - BuiltinScalarFunction::Replace => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::replace::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::replace::)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function replace", - other, - ))), - }), - BuiltinScalarFunction::Reverse => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(reverse, i32, "reverse"); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(reverse, i64, "reverse"); - make_scalar_function(func)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function reverse", - other, - ))), - }), - BuiltinScalarFunction::Right => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(right, i32, "right"); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(right, i64, "right"); - make_scalar_function(func)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function right", - other, - ))), - }), - BuiltinScalarFunction::Rpad => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(rpad, i32, "rpad"); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(rpad, i64, "rpad"); - make_scalar_function(func)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function rpad", - other, - ))), - }), - BuiltinScalarFunction::Rtrim => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::rtrim::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::rtrim::)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function rtrim", - other, - ))), - }), - BuiltinScalarFunction::SHA224 => { - Arc::new(invoke_if_crypto_expressions_feature_flag!(sha224, "sha224")) - } - BuiltinScalarFunction::SHA256 => { - Arc::new(invoke_if_crypto_expressions_feature_flag!(sha256, "sha256")) - } - BuiltinScalarFunction::SHA384 => { - 
Arc::new(invoke_if_crypto_expressions_feature_flag!(sha384, "sha384")) - } - BuiltinScalarFunction::SHA512 => { - Arc::new(invoke_if_crypto_expressions_feature_flag!(sha512, "sha512")) - } - BuiltinScalarFunction::SplitPart => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::split_part::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::split_part::)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function split_part", - other, - ))), - }), - BuiltinScalarFunction::StartsWith => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::starts_with::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::starts_with::)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function starts_with", - other, - ))), - }), - BuiltinScalarFunction::Strpos => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - strpos, Int32Type, "strpos" - ); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - strpos, Int64Type, "strpos" - ); - make_scalar_function(func)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function strpos", - other, - ))), - }), - BuiltinScalarFunction::Substr => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(substr, i32, "substr"); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(substr, i64, "substr"); - make_scalar_function(func)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function substr", - other, - ))), - }), - BuiltinScalarFunction::ToHex => Arc::new(|args| match args[0].data_type() { - DataType::Int32 => { - make_scalar_function(string_expressions::to_hex::)(args) - } - DataType::Int64 => { - make_scalar_function(string_expressions::to_hex::)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function to_hex", - other, - ))), - }), - BuiltinScalarFunction::Translate => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - translate, - i32, - "translate" - ); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - translate, - i64, - "translate" - ); - make_scalar_function(func)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function translate", - other, - ))), - }), - BuiltinScalarFunction::Trim => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::btrim::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::btrim::)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function trim", - other, - ))), - }), - BuiltinScalarFunction::Upper => Arc::new(string_expressions::upper), - _ => { - return Err(DataFusionError::Internal(format!( - "create_physical_fun: Unsupported scalar function {:?}", - fun - ))) - } - }) -} - -/// Create a physical (function) expression. 
-/// This function errors when `args`' can't be coerced to a valid argument type of the function. -pub fn create_physical_expr( - fun: &BuiltinScalarFunction, - input_phy_exprs: &[Arc], - input_schema: &Schema, - execution_props: &ExecutionProps, -) -> Result> { - let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, &signature(fun))?; - - let coerced_expr_types = coerced_phy_exprs - .iter() - .map(|e| e.data_type(input_schema)) - .collect::>>()?; - - let data_type = return_type(fun, &coerced_expr_types)?; - - let fun_expr: ScalarFunctionImplementation = match fun { - // These functions need args and input schema to pick an implementation - // Unlike the string functions, which actually figure out the function to use with each array, - // here we return either a cast fn or string timestamp translation based on the expression data type - // so we don't have to pay a per-array/batch cost. - BuiltinScalarFunction::ToTimestamp => { - Arc::new(match coerced_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Nanosecond, None), - &DEFAULT_DATAFUSION_CAST_OPTIONS, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp, - other => { - return Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function to_timestamp", - other, - ))) - } - }) - } - BuiltinScalarFunction::ToTimestampMillis => { - Arc::new(match coerced_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Millisecond, None), - &DEFAULT_DATAFUSION_CAST_OPTIONS, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp_millis, - other => { - return Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function to_timestamp_millis", - other, - ))) - } - }) - } - BuiltinScalarFunction::ToTimestampMicros => { - Arc::new(match coerced_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Microsecond, None), - &DEFAULT_DATAFUSION_CAST_OPTIONS, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp_micros, - other => { - return Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function to_timestamp_micros", - other, - ))) - } - }) - } - BuiltinScalarFunction::ToTimestampSeconds => Arc::new({ - match coerced_phy_exprs[0].data_type(input_schema) { - Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { - |col_values: &[ColumnarValue]| { - cast_column( - &col_values[0], - &DataType::Timestamp(TimeUnit::Second, None), - &DEFAULT_DATAFUSION_CAST_OPTIONS, - ) - } - } - Ok(DataType::Utf8) => datetime_expressions::to_timestamp_seconds, - other => { - return Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function to_timestamp_seconds", - other, - ))) - } + BuiltinScalarFunction::ToTimestampSeconds => Arc::new({ + match coerced_phy_exprs[0].data_type(input_schema) { + Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => { + |col_values: &[ColumnarValue]| { + cast_column( + &col_values[0], + &DataType::Timestamp(TimeUnit::Second, None), + &DEFAULT_DATAFUSION_CAST_OPTIONS, + ) + } + } + Ok(DataType::Utf8) => datetime_expressions::to_timestamp_seconds, + other => { + 
return Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function to_timestamp_seconds", + other, + ))) + } } }), // These don't need args and input schema @@ -962,317 +454,705 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { ], fun.volatility(), ), - BuiltinScalarFunction::ToTimestampMicros => Signature::uniform( - 1, + BuiltinScalarFunction::ToTimestampMicros => Signature::uniform( + 1, + vec![ + DataType::Utf8, + DataType::Int64, + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Timestamp(TimeUnit::Second, None), + ], + fun.volatility(), + ), + BuiltinScalarFunction::ToTimestampSeconds => Signature::uniform( + 1, + vec![ + DataType::Utf8, + DataType::Int64, + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp(TimeUnit::Microsecond, None), + DataType::Timestamp(TimeUnit::Millisecond, None), + ], + fun.volatility(), + ), + BuiltinScalarFunction::Digest => { + Signature::exact(vec![DataType::Utf8, DataType::Utf8], fun.volatility()) + } + BuiltinScalarFunction::DateTrunc => Signature::exact( + vec![ + DataType::Utf8, + DataType::Timestamp(TimeUnit::Nanosecond, None), + ], + fun.volatility(), + ), + BuiltinScalarFunction::DatePart => Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Utf8, DataType::Date32]), + TypeSignature::Exact(vec![DataType::Utf8, DataType::Date64]), + TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::Timestamp(TimeUnit::Second, None), + ]), + TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::Timestamp(TimeUnit::Microsecond, None), + ]), + TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::Timestamp(TimeUnit::Millisecond, None), + ]), + TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::Timestamp(TimeUnit::Nanosecond, None), + ]), + ], + fun.volatility(), + ), + BuiltinScalarFunction::SplitPart => Signature::one_of( + vec![ + TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::Utf8, + DataType::Int64, + ]), + TypeSignature::Exact(vec![ + DataType::LargeUtf8, + DataType::Utf8, + DataType::Int64, + ]), + TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::LargeUtf8, + DataType::Int64, + ]), + TypeSignature::Exact(vec![ + DataType::LargeUtf8, + DataType::LargeUtf8, + DataType::Int64, + ]), + ], + fun.volatility(), + ), + + BuiltinScalarFunction::Strpos | BuiltinScalarFunction::StartsWith => { + Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]), + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), + ], + fun.volatility(), + ) + } + + BuiltinScalarFunction::Substr => Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]), + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Int64]), + TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::Int64, + DataType::Int64, + ]), + TypeSignature::Exact(vec![ + DataType::LargeUtf8, + DataType::Int64, + DataType::Int64, + ]), + ], + fun.volatility(), + ), + + BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { + Signature::one_of( + vec![TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::Utf8, + DataType::Utf8, + ])], + fun.volatility(), + ) + } + BuiltinScalarFunction::RegexpReplace => Signature::one_of( + vec![ + TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::Utf8, + DataType::Utf8, + ]), + 
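As the comment in `create_physical_expr` notes, the `to_timestamp*` family resolves its implementation once at planning time: depending on the coerced input type it returns either a cast function or the string-parsing path, so no per-array/batch type matching is paid at execution time. A standalone sketch of that idea, using toy stand-ins rather than the Arrow/DataFusion types:

```rust
// Toy sketch of plan-time dispatch: inspect the input type once, hand back a
// function pointer, and every batch afterwards only pays for the call itself.
#[derive(Debug)]
enum DataType {
    Int64,
    Utf8,
}

type TimestampFn = fn(&[String]) -> Result<Vec<i64>, String>;

fn cast_path(values: &[String]) -> Result<Vec<i64>, String> {
    // stand-in for the cast_column(..., Timestamp(Nanosecond, None), ...) path
    values
        .iter()
        .map(|v| v.parse::<i64>().map_err(|e| e.to_string()))
        .collect()
}

fn parse_path(values: &[String]) -> Result<Vec<i64>, String> {
    // stand-in for datetime_expressions::to_timestamp (string parsing path)
    values
        .iter()
        .map(|v| {
            v.strip_suffix('Z')
                .and_then(|s| s.parse::<i64>().ok())
                .ok_or_else(|| format!("cannot parse {:?}", v))
        })
        .collect()
}

fn plan_to_timestamp(input_type: &DataType) -> Result<TimestampFn, String> {
    match input_type {
        DataType::Int64 => Ok(cast_path),
        DataType::Utf8 => Ok(parse_path),
    }
}

fn main() -> Result<(), String> {
    // dispatch happens once here ...
    let f = plan_to_timestamp(&DataType::Int64)?;
    // ... and each batch just calls the chosen function.
    assert_eq!(f(&["1".into(), "2".into()])?, vec![1, 2]);
    Ok(())
}
```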
TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::Utf8, + DataType::Utf8, + DataType::Utf8, + ]), + ], + fun.volatility(), + ), + + BuiltinScalarFunction::NullIf => { + Signature::uniform(2, SUPPORTED_NULLIF_TYPES.to_vec(), fun.volatility()) + } + BuiltinScalarFunction::RegexpMatch => Signature::one_of( vec![ - DataType::Utf8, - DataType::Int64, - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Timestamp(TimeUnit::Second, None), + TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), + TypeSignature::Exact(vec![ + DataType::Utf8, + DataType::Utf8, + DataType::Utf8, + ]), + TypeSignature::Exact(vec![ + DataType::LargeUtf8, + DataType::Utf8, + DataType::Utf8, + ]), ], fun.volatility(), ), - BuiltinScalarFunction::ToTimestampSeconds => Signature::uniform( + BuiltinScalarFunction::Random => Signature::exact(vec![], fun.volatility()), + // math expressions expect 1 argument of type f64 or f32 + // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we + // return the best approximation for it (in f64). + // We accept f32 because in this case it is clear that the best approximation + // will be as good as the number of digits in the number + _ => Signature::uniform( 1, - vec![ - DataType::Utf8, - DataType::Int64, - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Timestamp(TimeUnit::Microsecond, None), - DataType::Timestamp(TimeUnit::Millisecond, None), - ], + vec![DataType::Float64, DataType::Float32], fun.volatility(), ), + } +} + +pub use datafusion_physical_expr::ScalarFunctionExpr; + +#[cfg(feature = "crypto_expressions")] +macro_rules! invoke_if_crypto_expressions_feature_flag { + ($FUNC:ident, $NAME:expr) => {{ + use datafusion_physical_expr::crypto_expressions; + crypto_expressions::$FUNC + }}; +} + +#[cfg(not(feature = "crypto_expressions"))] +macro_rules! invoke_if_crypto_expressions_feature_flag { + ($FUNC:ident, $NAME:expr) => { + |_: &[ColumnarValue]| -> Result { + Err(DataFusionError::Internal(format!( + "function {} requires compilation with feature flag: crypto_expressions.", + $NAME + ))) + } + }; +} + +#[cfg(feature = "regex_expressions")] +macro_rules! invoke_if_regex_expressions_feature_flag { + ($FUNC:ident, $T:tt, $NAME:expr) => {{ + use datafusion_physical_expr::regex_expressions; + regex_expressions::$FUNC::<$T> + }}; +} + +#[cfg(not(feature = "regex_expressions"))] +macro_rules! invoke_if_regex_expressions_feature_flag { + ($FUNC:ident, $T:tt, $NAME:expr) => { + |_: &[ArrayRef]| -> Result { + Err(DataFusionError::Internal(format!( + "function {} requires compilation with feature flag: regex_expressions.", + $NAME + ))) + } + }; +} + +#[cfg(feature = "unicode_expressions")] +macro_rules! invoke_if_unicode_expressions_feature_flag { + ($FUNC:ident, $T:tt, $NAME:expr) => {{ + use datafusion_physical_expr::unicode_expressions; + unicode_expressions::$FUNC::<$T> + }}; +} + +#[cfg(not(feature = "unicode_expressions"))] +macro_rules! invoke_if_unicode_expressions_feature_flag { + ($FUNC:ident, $T:tt, $NAME:expr) => { + |_: &[ArrayRef]| -> Result { + Err(DataFusionError::Internal(format!( + "function {} requires compilation with feature flag: unicode_expressions.", + $NAME + ))) + } + }; +} + +/// decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function +/// and vice-versa after evaluation. 
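The `invoke_if_*_feature_flag!` macros above all use the same cfg-gated pattern: when the crate feature is compiled in, the macro expands to the real kernel from `datafusion_physical_expr`; when it is not, it expands to a stub that fails at call time with a message naming the missing feature. A condensed, self-contained illustration of the same pattern, using a hypothetical `demo_expressions` feature and `String`-based types rather than the real crate features and kernels:

```rust
// Condensed illustration of the cfg-gated feature-flag macro pattern.
// "demo_expressions" is a made-up feature name used only for this sketch.
type Result<T> = std::result::Result<T, String>;

#[cfg(feature = "demo_expressions")]
mod demo_expressions {
    pub fn upper(args: &[String]) -> super::Result<Vec<String>> {
        Ok(args.iter().map(|s| s.to_uppercase()).collect())
    }
}

#[cfg(feature = "demo_expressions")]
macro_rules! invoke_if_demo_expressions_feature_flag {
    ($FUNC:ident, $NAME:expr) => {{
        // feature enabled: name the real kernel
        demo_expressions::$FUNC
    }};
}

#[cfg(not(feature = "demo_expressions"))]
macro_rules! invoke_if_demo_expressions_feature_flag {
    ($FUNC:ident, $NAME:expr) => {
        // feature disabled: substitute a stub that errors at call time
        |_: &[String]| -> Result<Vec<String>> {
            Err(format!(
                "function {} requires compilation with feature flag: demo_expressions.",
                $NAME
            ))
        }
    };
}

fn main() {
    // resolves to either the real kernel or the error stub at compile time
    let upper = invoke_if_demo_expressions_feature_flag!(upper, "upper");
    println!("{:?}", upper(&["abc".to_string()]));
}
```

The design keeps the function-dispatch table identical whether or not the optional dependencies (regex, crypto hashes, unicode segmentation) are compiled in; only the behavior at call time changes.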
+pub fn make_scalar_function(inner: F) -> ScalarFunctionImplementation +where + F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, +{ + Arc::new(move |args: &[ColumnarValue]| { + // first, identify if any of the arguments is an Array. If yes, store its `len`, + // as any scalar will need to be converted to an array of len `len`. + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + // to array + let args = if let Some(len) = len { + args.iter() + .map(|arg| arg.clone().into_array(len)) + .collect::>() + } else { + args.iter() + .map(|arg| arg.clone().into_array(1)) + .collect::>() + }; + + let result = (inner)(&args); + + // maybe back to scalar + if len.is_some() { + result.map(ColumnarValue::Array) + } else { + ScalarValue::try_from_array(&result?, 0).map(ColumnarValue::Scalar) + } + }) +} + +/// Create a physical scalar function. +pub fn create_physical_fun( + fun: &BuiltinScalarFunction, + execution_props: &ExecutionProps, +) -> Result { + Ok(match fun { + // math functions + BuiltinScalarFunction::Abs => Arc::new(math_expressions::abs), + BuiltinScalarFunction::Acos => Arc::new(math_expressions::acos), + BuiltinScalarFunction::Asin => Arc::new(math_expressions::asin), + BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan), + BuiltinScalarFunction::Ceil => Arc::new(math_expressions::ceil), + BuiltinScalarFunction::Cos => Arc::new(math_expressions::cos), + BuiltinScalarFunction::Exp => Arc::new(math_expressions::exp), + BuiltinScalarFunction::Floor => Arc::new(math_expressions::floor), + BuiltinScalarFunction::Log => Arc::new(math_expressions::log10), + BuiltinScalarFunction::Ln => Arc::new(math_expressions::ln), + BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10), + BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2), + BuiltinScalarFunction::Random => Arc::new(math_expressions::random), + BuiltinScalarFunction::Round => Arc::new(math_expressions::round), + BuiltinScalarFunction::Signum => Arc::new(math_expressions::signum), + BuiltinScalarFunction::Sin => Arc::new(math_expressions::sin), + BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt), + BuiltinScalarFunction::Tan => Arc::new(math_expressions::tan), + BuiltinScalarFunction::Trunc => Arc::new(math_expressions::trunc), + // string functions + BuiltinScalarFunction::Array => Arc::new(array_expressions::array), + BuiltinScalarFunction::Ascii => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::ascii::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::ascii::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function ascii", + other, + ))), + }), + BuiltinScalarFunction::BitLength => Arc::new(|args| match &args[0] { + ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), + ColumnarValue::Scalar(v) => match v { + ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( + v.as_ref().map(|x| (x.len() * 8) as i32), + ))), + ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( + ScalarValue::Int64(v.as_ref().map(|x| (x.len() * 8) as i64)), + )), + _ => unreachable!(), + }, + }), + BuiltinScalarFunction::Btrim => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::btrim::)(args) + } + DataType::LargeUtf8 => { + 
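The doc comment on `make_scalar_function` describes its scalar/array convention: scalar arguments are broadcast to arrays before the inner function runs, and the result is collapsed back to a scalar only when every argument was a scalar. A standalone sketch of that logic, with `Vec<i64>` standing in for Arrow arrays and a simplified, infallible `ColumnarValue`:

```rust
// Sketch of the make_scalar_function convention: broadcast scalars to arrays,
// call the array-based inner function, and re-wrap the result.
#[derive(Debug, Clone, PartialEq)]
enum ColumnarValue {
    Array(Vec<i64>),
    Scalar(i64),
}

fn make_scalar_function<F>(inner: F) -> impl Fn(&[ColumnarValue]) -> ColumnarValue
where
    F: Fn(&[Vec<i64>]) -> Vec<i64>,
{
    move |args: &[ColumnarValue]| {
        // length of the first array argument, if any
        let len = args.iter().find_map(|a| match a {
            ColumnarValue::Array(v) => Some(v.len()),
            ColumnarValue::Scalar(_) => None,
        });

        // broadcast everything to arrays of that length (or length 1 if all scalar)
        let arrays: Vec<Vec<i64>> = args
            .iter()
            .map(|a| match a {
                ColumnarValue::Array(v) => v.clone(),
                ColumnarValue::Scalar(s) => vec![*s; len.unwrap_or(1)],
            })
            .collect();

        let result = inner(&arrays);

        // collapse back to a scalar only when no argument was an array
        match len {
            Some(_) => ColumnarValue::Array(result),
            None => ColumnarValue::Scalar(result[0]),
        }
    }
}

fn main() {
    let add = make_scalar_function(|args: &[Vec<i64>]| {
        args[0].iter().zip(&args[1]).map(|(a, b)| a + b).collect()
    });

    // array + scalar -> array
    assert_eq!(
        add(&[ColumnarValue::Array(vec![1, 2, 3]), ColumnarValue::Scalar(10)]),
        ColumnarValue::Array(vec![11, 12, 13])
    );
    // scalar + scalar -> scalar
    assert_eq!(
        add(&[ColumnarValue::Scalar(1), ColumnarValue::Scalar(2)]),
        ColumnarValue::Scalar(3)
    );
}
```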
make_scalar_function(string_expressions::btrim::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function btrim", + other, + ))), + }), + BuiltinScalarFunction::CharacterLength => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + character_length, + Int32Type, + "character_length" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + character_length, + Int64Type, + "character_length" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function character_length", + other, + ))), + }) + } + BuiltinScalarFunction::Chr => { + Arc::new(|args| make_scalar_function(string_expressions::chr)(args)) + } + BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat), + BuiltinScalarFunction::ConcatWithSeparator => { + Arc::new(|args| make_scalar_function(string_expressions::concat_ws)(args)) + } + BuiltinScalarFunction::DatePart => Arc::new(datetime_expressions::date_part), + BuiltinScalarFunction::DateTrunc => Arc::new(datetime_expressions::date_trunc), + BuiltinScalarFunction::Now => { + // bind value for now at plan time + Arc::new(datetime_expressions::make_now( + execution_props.query_execution_start_time, + )) + } + BuiltinScalarFunction::InitCap => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::initcap::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::initcap::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function initcap", + other, + ))), + }), + BuiltinScalarFunction::Left => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!(left, i32, "left"); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!(left, i64, "left"); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function left", + other, + ))), + }), + BuiltinScalarFunction::Lower => Arc::new(string_expressions::lower), + BuiltinScalarFunction::Lpad => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!(lpad, i32, "lpad"); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!(lpad, i64, "lpad"); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function lpad", + other, + ))), + }), + BuiltinScalarFunction::Ltrim => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::ltrim::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::ltrim::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function ltrim", + other, + ))), + }), + BuiltinScalarFunction::MD5 => { + Arc::new(invoke_if_crypto_expressions_feature_flag!(md5, "md5")) + } BuiltinScalarFunction::Digest => { - Signature::exact(vec![DataType::Utf8, DataType::Utf8], fun.volatility()) + Arc::new(invoke_if_crypto_expressions_feature_flag!(digest, "digest")) } - BuiltinScalarFunction::DateTrunc => Signature::exact( - 
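The `Now` arm above binds the value at plan time: `make_now` captures `execution_props.query_execution_start_time` when the physical expression is created, so every call to `now()` within the same query observes the same timestamp. A dependency-free sketch of that idea using `std::time` instead of chrono:

```rust
// Sketch of "bind now() at plan time": capture the query start time once and
// return a closure, so every batch sees the identical value.
use std::time::{Duration, SystemTime, UNIX_EPOCH};

fn make_now(query_start: SystemTime) -> impl Fn() -> u128 {
    let nanos = query_start
        .duration_since(UNIX_EPOCH)
        .expect("query start time before unix epoch")
        .as_nanos();
    move || nanos
}

fn main() {
    // captured once, at "planning" time
    let now = make_now(SystemTime::now());

    // both "batches" observe the same timestamp
    let a = now();
    std::thread::sleep(Duration::from_millis(5));
    let b = now();
    assert_eq!(a, b);
}
```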
vec![ - DataType::Utf8, - DataType::Timestamp(TimeUnit::Nanosecond, None), - ], - fun.volatility(), - ), - BuiltinScalarFunction::DatePart => Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Date32]), - TypeSignature::Exact(vec![DataType::Utf8, DataType::Date64]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Timestamp(TimeUnit::Second, None), - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Timestamp(TimeUnit::Microsecond, None), - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Timestamp(TimeUnit::Millisecond, None), - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Timestamp(TimeUnit::Nanosecond, None), - ]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::SplitPart => Signature::one_of( - vec![ - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Utf8, - DataType::Int64, - ]), - TypeSignature::Exact(vec![ - DataType::LargeUtf8, - DataType::Utf8, - DataType::Int64, - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::LargeUtf8, - DataType::Int64, - ]), - TypeSignature::Exact(vec![ - DataType::LargeUtf8, - DataType::LargeUtf8, - DataType::Int64, - ]), - ], - fun.volatility(), - ), - - BuiltinScalarFunction::Strpos | BuiltinScalarFunction::StartsWith => { - Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), - TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), - ], - fun.volatility(), - ) + BuiltinScalarFunction::NullIf => Arc::new(nullif_func), + BuiltinScalarFunction::OctetLength => Arc::new(|args| match &args[0] { + ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), + ColumnarValue::Scalar(v) => match v { + ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( + v.as_ref().map(|x| x.len() as i32), + ))), + ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( + ScalarValue::Int64(v.as_ref().map(|x| x.len() as i64)), + )), + _ => unreachable!(), + }, + }), + BuiltinScalarFunction::RegexpMatch => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_match, + i32, + "regexp_match" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_match, + i64, + "regexp_match" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function regexp_match", + other + ))), + }) } - - BuiltinScalarFunction::Substr => Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Int64]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Int64, - DataType::Int64, - ]), - TypeSignature::Exact(vec![ - DataType::LargeUtf8, - DataType::Int64, - DataType::Int64, - ]), - ], - fun.volatility(), - ), - - BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { - Signature::one_of( - vec![TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Utf8, - DataType::Utf8, - ])], - fun.volatility(), - ) + BuiltinScalarFunction::RegexpReplace => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_replace, + i32, + "regexp_replace" + ); + 
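The `Signature::one_of` / `TypeSignature::Exact` listings above enumerate the argument-type combinations a function accepts; the planner checks the actual input types against that list (and coerces where allowed) before building the physical expression. A toy sketch of the exact-match check only, with a stand-in `DataType` enum and no coercion:

```rust
// Toy sketch of checking input types against a one_of / Exact signature list.
// The real coercion rules also insert casts; this only checks exact matches.
#[derive(Debug, PartialEq)]
enum DataType {
    Utf8,
    LargeUtf8,
    Int64,
}

enum TypeSignature {
    Exact(Vec<DataType>),
}

fn matches_one_of(input: &[DataType], accepted: &[TypeSignature]) -> bool {
    accepted.iter().any(|sig| match sig {
        TypeSignature::Exact(types) => types.as_slice() == input,
    })
}

fn main() {
    // roughly the shape of the split_part signature: (string, string, int)
    let split_part = vec![
        TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Int64]),
        TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Int64]),
    ];

    assert!(matches_one_of(
        &[DataType::Utf8, DataType::Utf8, DataType::Int64],
        &split_part
    ));
    assert!(!matches_one_of(&[DataType::Int64], &split_part));
}
```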
make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_replace, + i64, + "regexp_replace" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function regexp_replace", + other, + ))), + }) + } + BuiltinScalarFunction::Repeat => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::repeat::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::repeat::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function repeat", + other, + ))), + }), + BuiltinScalarFunction::Replace => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::replace::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::replace::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function replace", + other, + ))), + }), + BuiltinScalarFunction::Reverse => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = + invoke_if_unicode_expressions_feature_flag!(reverse, i32, "reverse"); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = + invoke_if_unicode_expressions_feature_flag!(reverse, i64, "reverse"); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function reverse", + other, + ))), + }), + BuiltinScalarFunction::Right => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = + invoke_if_unicode_expressions_feature_flag!(right, i32, "right"); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = + invoke_if_unicode_expressions_feature_flag!(right, i64, "right"); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function right", + other, + ))), + }), + BuiltinScalarFunction::Rpad => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!(rpad, i32, "rpad"); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!(rpad, i64, "rpad"); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function rpad", + other, + ))), + }), + BuiltinScalarFunction::Rtrim => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::rtrim::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::rtrim::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function rtrim", + other, + ))), + }), + BuiltinScalarFunction::SHA224 => { + Arc::new(invoke_if_crypto_expressions_feature_flag!(sha224, "sha224")) } - BuiltinScalarFunction::RegexpReplace => Signature::one_of( - vec![ - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Utf8, - DataType::Utf8, - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Utf8, - DataType::Utf8, - DataType::Utf8, - ]), - ], - fun.volatility(), - ), - - BuiltinScalarFunction::NullIf => { - Signature::uniform(2, SUPPORTED_NULLIF_TYPES.to_vec(), fun.volatility()) + BuiltinScalarFunction::SHA256 => { + 
Arc::new(invoke_if_crypto_expressions_feature_flag!(sha256, "sha256")) } - BuiltinScalarFunction::RegexpMatch => Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Utf8, - DataType::Utf8, - ]), - TypeSignature::Exact(vec![ - DataType::LargeUtf8, - DataType::Utf8, - DataType::Utf8, - ]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::Random => Signature::exact(vec![], fun.volatility()), - // math expressions expect 1 argument of type f64 or f32 - // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we - // return the best approximation for it (in f64). - // We accept f32 because in this case it is clear that the best approximation - // will be as good as the number of digits in the number - _ => Signature::uniform( - 1, - vec![DataType::Float64, DataType::Float32], - fun.volatility(), - ), - } -} - -/// Physical expression of a scalar function -pub struct ScalarFunctionExpr { - fun: ScalarFunctionImplementation, - name: String, - args: Vec>, - return_type: DataType, -} - -impl Debug for ScalarFunctionExpr { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("ScalarFunctionExpr") - .field("fun", &"") - .field("name", &self.name) - .field("args", &self.args) - .field("return_type", &self.return_type) - .finish() - } -} - -impl ScalarFunctionExpr { - /// Create a new Scalar function - pub fn new( - name: &str, - fun: ScalarFunctionImplementation, - args: Vec>, - return_type: &DataType, - ) -> Self { - Self { - fun, - name: name.to_owned(), - args, - return_type: return_type.clone(), + BuiltinScalarFunction::SHA384 => { + Arc::new(invoke_if_crypto_expressions_feature_flag!(sha384, "sha384")) } - } - - /// Get the scalar function implementation - pub fn fun(&self) -> &ScalarFunctionImplementation { - &self.fun - } - - /// The name for this expression - pub fn name(&self) -> &str { - &self.name - } - - /// Input arguments - pub fn args(&self) -> &[Arc] { - &self.args - } - - /// Data type produced by this expression - pub fn return_type(&self) -> &DataType { - &self.return_type - } -} - -impl fmt::Display for ScalarFunctionExpr { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "{}({})", - self.name, - self.args - .iter() - .map(|e| format!("{}", e)) - .collect::>() - .join(", ") - ) - } -} - -impl PhysicalExpr for ScalarFunctionExpr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, _input_schema: &Schema) -> Result { - Ok(self.return_type.clone()) - } - - fn nullable(&self, _input_schema: &Schema) -> Result { - Ok(true) - } - - fn evaluate(&self, batch: &RecordBatch) -> Result { - // evaluate the arguments, if there are no arguments we'll instead pass in a null array - // indicating the batch size (as a convention) - let inputs = match (self.args.len(), self.name.parse::()) { - (0, Ok(scalar_fun)) if scalar_fun.supports_zero_argument() => { - vec![NullColumnarValue::from(batch)] + BuiltinScalarFunction::SHA512 => { + Arc::new(invoke_if_crypto_expressions_feature_flag!(sha512, "sha512")) + } + BuiltinScalarFunction::SplitPart => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::split_part::)(args) } - _ => self - .args - .iter() - .map(|e| e.evaluate(batch)) - .collect::>>()?, - }; - - // evaluate the function 
- let fun = self.fun.as_ref(); - (fun)(&inputs) - } -} - -/// decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function -/// and vice-versa after evaluation. -pub fn make_scalar_function(inner: F) -> ScalarFunctionImplementation -where - F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, -{ - Arc::new(move |args: &[ColumnarValue]| { - // first, identify if any of the arguments is an Array. If yes, store its `len`, - // as any scalar will need to be converted to an array of len `len`. - let len = args - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }); - - // to array - let args = if let Some(len) = len { - args.iter() - .map(|arg| arg.clone().into_array(len)) - .collect::>() - } else { - args.iter() - .map(|arg| arg.clone().into_array(1)) - .collect::>() - }; - - let result = (inner)(&args); - - // maybe back to scalar - if len.is_some() { - result.map(ColumnarValue::Array) - } else { - ScalarValue::try_from_array(&result?, 0).map(ColumnarValue::Scalar) + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::split_part::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function split_part", + other, + ))), + }), + BuiltinScalarFunction::StartsWith => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::starts_with::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::starts_with::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function starts_with", + other, + ))), + }), + BuiltinScalarFunction::Strpos => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + strpos, Int32Type, "strpos" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + strpos, Int64Type, "strpos" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function strpos", + other, + ))), + }), + BuiltinScalarFunction::Substr => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = + invoke_if_unicode_expressions_feature_flag!(substr, i32, "substr"); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = + invoke_if_unicode_expressions_feature_flag!(substr, i64, "substr"); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function substr", + other, + ))), + }), + BuiltinScalarFunction::ToHex => Arc::new(|args| match args[0].data_type() { + DataType::Int32 => { + make_scalar_function(string_expressions::to_hex::)(args) + } + DataType::Int64 => { + make_scalar_function(string_expressions::to_hex::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function to_hex", + other, + ))), + }), + BuiltinScalarFunction::Translate => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + translate, + i32, + "translate" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + translate, + i64, + "translate" + ); + make_scalar_function(func)(args) + } + other => 
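The removed `ScalarFunctionExpr::evaluate` body (now re-exported from `datafusion_physical_expr`) also documents a convention worth noting: for zero-argument functions such as `random()`, a placeholder null column sized to the batch is passed in so the implementation still knows how many rows to produce. A simplified standalone sketch of that convention, with stand-in types in place of `RecordBatch` and `NullColumnarValue`:

```rust
// Sketch of the zero-argument convention: a placeholder carrying the batch's
// row count is passed to functions that take no arguments.
#[derive(Debug)]
enum ColumnarValue {
    Array(Vec<f64>),
    // stand-in for NullColumnarValue::from(batch): only the row count matters
    Placeholder { num_rows: usize },
}

fn random(args: &[ColumnarValue]) -> ColumnarValue {
    // a zero-argument function reads the batch size from the placeholder
    let num_rows = match &args[0] {
        ColumnarValue::Placeholder { num_rows } => *num_rows,
        ColumnarValue::Array(a) => a.len(),
    };
    // not actually random; just one value per row for the sketch
    ColumnarValue::Array((0..num_rows).map(|i| i as f64 * 0.5).collect())
}

fn main() {
    let batch_rows = 4;
    // evaluate() passes the placeholder because random() supports zero arguments
    let out = random(&[ColumnarValue::Placeholder { num_rows: batch_rows }]);
    match out {
        ColumnarValue::Array(v) => assert_eq!(v.len(), batch_rows),
        _ => unreachable!(),
    }
}
```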
Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function translate", + other, + ))), + }), + BuiltinScalarFunction::Trim => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::btrim::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::btrim::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function trim", + other, + ))), + }), + BuiltinScalarFunction::Upper => Arc::new(string_expressions::upper), + _ => { + return Err(DataFusionError::Internal(format!( + "create_physical_fun: Unsupported scalar function {:?}", + fun + ))) } }) } diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 7d2599958c85d..e511df1dee90f 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -514,14 +514,10 @@ pub fn project_schema( pub mod aggregates; pub mod analyze; -pub mod array_expressions; pub mod coalesce_batches; pub mod coalesce_partitions; pub mod common; pub mod cross_join; -#[cfg(feature = "crypto_expressions")] -pub mod crypto_expressions; -pub mod datetime_expressions; pub mod display; pub mod empty; pub mod explain; @@ -535,22 +531,16 @@ pub mod hash_join; pub mod hash_utils; pub mod join_utils; pub mod limit; -pub mod math_expressions; pub mod memory; pub mod metrics; pub mod planner; pub mod projection; -#[cfg(feature = "regex_expressions")] -pub mod regex_expressions; pub mod repartition; pub mod sorts; pub mod stream; -pub mod string_expressions; pub mod type_coercion; pub mod udaf; pub mod udf; -#[cfg(feature = "unicode_expressions")] -pub mod unicode_expressions; pub mod union; pub mod values; pub mod window_functions; diff --git a/datafusion/src/physical_plan/windows/mod.rs b/datafusion/src/physical_plan/windows/mod.rs index 34f9337b0036f..e833c57c5b5ee 100644 --- a/datafusion/src/physical_plan/windows/mod.rs +++ b/datafusion/src/physical_plan/windows/mod.rs @@ -26,7 +26,7 @@ use crate::physical_plan::{ }, type_coercion::coerce, window_functions::{signature_for_built_in, BuiltInWindowFunction, WindowFunction}, - PhysicalExpr, WindowExpr, + PhysicalExpr, }; use crate::scalar::ScalarValue; use arrow::datatypes::Schema; @@ -35,12 +35,11 @@ use datafusion_physical_expr::window::BuiltInWindowFunctionExpr; use std::convert::TryInto; use std::sync::Arc; -mod aggregate; -mod built_in; mod window_agg_exec; -pub use aggregate::AggregateWindowExpr; -pub use built_in::BuiltInWindowExpr; +pub use datafusion_physical_expr::window::{ + AggregateWindowExpr, BuiltInWindowExpr, WindowExpr, +}; pub use window_agg_exec::WindowAggExec; /// Create a physical expression for window function