From 5b6b98aee7728fe30ff022301bbee8496c8e1b59 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Thu, 16 Nov 2023 23:19:52 -0800 Subject: [PATCH 1/6] Refactor Expr::ScalarFunction --- .../core/src/datasource/listing/helpers.rs | 49 +++++--- datafusion/core/src/physical_planner.rs | 14 ++- datafusion/expr/src/expr.rs | 50 +++++++-- datafusion/expr/src/expr_fn.rs | 106 ++++++++++-------- datafusion/expr/src/expr_schema.rs | 56 +++++---- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/tree_node/expr.rs | 21 +++- .../optimizer/src/analyzer/type_coercion.rs | 48 +++++--- datafusion/optimizer/src/push_down_filter.rs | 27 ++++- .../simplify_expressions/expr_simplifier.rs | 25 +++-- .../src/simplify_expressions/utils.rs | 6 +- datafusion/physical-expr/src/planner.rs | 60 +++++++--- datafusion/proto/src/logical_plan/to_proto.rs | 45 +++++--- .../substrait/src/logical_plan/consumer.rs | 6 +- .../substrait/src/logical_plan/producer.rs | 15 ++- 15 files changed, 365 insertions(+), 165 deletions(-) diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 3d2a3dc928b63..f120b5b6f4922 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -38,9 +38,9 @@ use super::PartitionedFile; use crate::datasource::listing::ListingTableUrl; use crate::execution::context::SessionState; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::{Column, DFField, DFSchema, DataFusionError}; +use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError}; use datafusion_expr::expr::ScalarUDF; -use datafusion_expr::{Expr, Volatility}; +use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility}; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr::execution_props::ExecutionProps; use object_store::path::Path; @@ -54,13 +54,13 @@ use object_store::{ObjectMeta, ObjectStore}; pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { let mut is_applicable = true; expr.apply(&mut |expr| { - Ok(match expr { + match expr { Expr::Column(Column { ref name, .. }) => { is_applicable &= col_names.contains(name); if is_applicable { - VisitRecursion::Skip + Ok(VisitRecursion::Skip) } else { - VisitRecursion::Stop + Ok(VisitRecursion::Stop) } } Expr::Literal(_) @@ -89,25 +89,42 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::ScalarSubquery(_) | Expr::GetIndexedField { .. } | Expr::GroupingSet(_) - | Expr::Case { .. } => VisitRecursion::Continue, + | Expr::Case { .. } => Ok(VisitRecursion::Continue), Expr::ScalarFunction(scalar_function) => { - match scalar_function.fun.volatility() { - Volatility::Immutable => VisitRecursion::Continue, - // TODO: Stable functions could be `applicable`, but that would require access to the context - Volatility::Stable | Volatility::Volatile => { - is_applicable = false; - VisitRecursion::Stop + match &scalar_function.func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + match fun.volatility() { + Volatility::Immutable => Ok(VisitRecursion::Continue), + // TODO: Stable functions could be `applicable`, but that would require access to the context + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Ok(VisitRecursion::Stop) + } + } + } + ScalarFunctionDefinition::UDF(fun) => { + match fun.signature().volatility { + Volatility::Immutable => Ok(VisitRecursion::Continue), + // TODO: Stable functions could be `applicable`, but that would require access to the context + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Ok(VisitRecursion::Stop) + } + } + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") } } } Expr::ScalarUDF(ScalarUDF { fun, .. }) => { match fun.signature().volatility { - Volatility::Immutable => VisitRecursion::Continue, + Volatility::Immutable => Ok(VisitRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { is_applicable = false; - VisitRecursion::Stop + Ok(VisitRecursion::Stop) } } } @@ -123,9 +140,9 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::Wildcard { .. } | Expr::Placeholder(_) => { is_applicable = false; - VisitRecursion::Stop + Ok(VisitRecursion::Stop) } - }) + } }) .unwrap(); is_applicable diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 1f1ef73cae343..63a224a8cfb6d 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -90,7 +90,8 @@ use datafusion_expr::expr::{ use datafusion_expr::expr_rewriter::{unalias, unnormalize_cols}; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, DmlStatement, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, + DescribeTable, DmlStatement, ScalarFunctionDefinition, StringifiedPlan, WindowFrame, + WindowFrameBound, WriteOp, }; use datafusion_physical_expr::expressions::Literal; use datafusion_sql::utils::window_expr_common_partition_keys; @@ -218,8 +219,15 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Ok(name) } - Expr::ScalarFunction(func) => { - create_function_physical_name(&func.fun.to_string(), false, &func.args) + Expr::ScalarFunction(func_expr) => { + let func_name = match &func_expr.func_def { + ScalarFunctionDefinition::BuiltIn(fun) => Ok(fun.to_string()), + ScalarFunctionDefinition::UDF(fun) => Ok(fun.name().to_string()), + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + }; + create_function_physical_name(&func_name?, false, &func_expr.args) } Expr::ScalarUDF(ScalarUDF { fun, args }) => { create_function_physical_name(fun.name(), false, args) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 2b2d30af3bc22..52f0be05113e3 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -338,11 +338,22 @@ impl Between { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ScalarFunctionDefinition { + /// Resolved to a built in scalar function + /// (will be removed long term) + BuiltIn(built_in_function::BuiltinScalarFunction), + /// Resolved to a user defined function + UDF(Arc), + /// A scalar function that will be called by name + Name(Arc), +} + /// ScalarFunction expression invokes a built-in scalar function #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct ScalarFunction { /// The function - pub fun: built_in_function::BuiltinScalarFunction, + pub func_def: ScalarFunctionDefinition, /// List of expressions to feed to the functions as arguments pub args: Vec, } @@ -350,7 +361,10 @@ pub struct ScalarFunction { impl ScalarFunction { /// Create a new ScalarFunction expression pub fn new(fun: built_in_function::BuiltinScalarFunction, args: Vec) -> Self { - Self { fun, args } + Self { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + } } } @@ -1198,9 +1212,21 @@ impl fmt::Display for Expr { write!(f, " NULLS LAST") } } - Expr::ScalarFunction(func) => { - fmt_function(f, &func.fun.to_string(), false, &func.args, true) - } + Expr::ScalarFunction(func_expr) => match &func_expr.func_def { + ScalarFunctionDefinition::BuiltIn(builtin_func) => fmt_function( + f, + &builtin_func.to_string(), + false, + &func_expr.args, + true, + ), + ScalarFunctionDefinition::UDF(udf) => { + fmt_function(f, udf.name(), false, &func_expr.args, true) + } + ScalarFunctionDefinition::Name(func_name) => { + fmt_function(f, func_name, false, &func_expr.args, true) + } + }, Expr::ScalarUDF(ScalarUDF { fun, args }) => { fmt_function(f, fun.name(), false, args, true) } @@ -1534,9 +1560,17 @@ fn create_name(e: &Expr) -> Result { } } } - Expr::ScalarFunction(func) => { - create_function_name(&func.fun.to_string(), false, &func.args) - } + Expr::ScalarFunction(func_expr) => match &func_expr.func_def { + ScalarFunctionDefinition::BuiltIn(builtin_func) => { + create_function_name(&builtin_func.to_string(), false, &func_expr.args) + } + ScalarFunctionDefinition::UDF(udf) => { + create_function_name(udf.name(), false, &func_expr.args) + } + ScalarFunctionDefinition::Name(name) => { + create_function_name(name, false, &func_expr.args) + } + }, Expr::ScalarUDF(ScalarUDF { fun, args }) => { create_function_name(fun.name(), false, args) } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index bcf1aa0ca7e55..3c05f430914b7 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1007,7 +1007,7 @@ pub fn call_fn(name: impl AsRef, args: Vec) -> Result { #[cfg(test)] mod test { use super::*; - use crate::lit; + use crate::{lit, ScalarFunctionDefinition}; #[test] fn filter_is_null_and_is_not_null() { @@ -1022,8 +1022,10 @@ mod test { macro_rules! test_unary_scalar_expr { ($ENUM:ident, $FUNC:ident) => {{ - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = - $FUNC(col("tableA.a")) + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args, + }) = $FUNC(col("tableA.a")) { let name = built_in_function::BuiltinScalarFunction::$ENUM; assert_eq!(name, fun); @@ -1035,42 +1037,42 @@ mod test { } macro_rules! test_scalar_expr { - ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { - let expected = [$(stringify!($arg)),*]; - let result = $FUNC( + ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { + let expected = [$(stringify!($arg)),*]; + let result = $FUNC( + $( + col(stringify!($arg.to_string())) + ),* + ); + if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args }) = result { + let name = built_in_function::BuiltinScalarFunction::$ENUM; + assert_eq!(name, fun); + assert_eq!(expected.len(), args.len()); + } else { + assert!(false, "unexpected: {:?}", result); + } + }; +} + + macro_rules! test_nary_scalar_expr { + ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { + let expected = [$(stringify!($arg)),*]; + let result = $FUNC( + vec![ $( col(stringify!($arg.to_string())) ),* - ); - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = result { - let name = built_in_function::BuiltinScalarFunction::$ENUM; - assert_eq!(name, fun); - assert_eq!(expected.len(), args.len()); - } else { - assert!(false, "unexpected: {:?}", result); - } - }; - } - - macro_rules! test_nary_scalar_expr { - ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { - let expected = [$(stringify!($arg)),*]; - let result = $FUNC( - vec![ - $( - col(stringify!($arg.to_string())) - ),* - ] - ); - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = result { - let name = built_in_function::BuiltinScalarFunction::$ENUM; - assert_eq!(name, fun); - assert_eq!(expected.len(), args.len()); - } else { - assert!(false, "unexpected: {:?}", result); - } - }; - } + ] + ); + if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args }) = result { + let name = built_in_function::BuiltinScalarFunction::$ENUM; + assert_eq!(name, fun); + assert_eq!(expected.len(), args.len()); + } else { + assert!(false, "unexpected: {:?}", result); + } + }; +} #[test] fn scalar_function_definitions() { @@ -1199,9 +1201,13 @@ mod test { #[test] fn uuid_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = uuid() { + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(func_def), + args, + }) = uuid() + { let name = BuiltinScalarFunction::Uuid; - assert_eq!(name, fun); + assert_eq!(name, func_def); assert_eq!(0, args.len()); } else { unreachable!(); @@ -1210,11 +1216,13 @@ mod test { #[test] fn digest_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = - digest(col("tableA.a"), lit("md5")) + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(func_def), + args, + }) = digest(col("tableA.a"), lit("md5")) { let name = BuiltinScalarFunction::Digest; - assert_eq!(name, fun); + assert_eq!(name, func_def); assert_eq!(2, args.len()); } else { unreachable!(); @@ -1223,11 +1231,13 @@ mod test { #[test] fn encode_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = - encode(col("tableA.a"), lit("base64")) + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(func_def), + args, + }) = encode(col("tableA.a"), lit("base64")) { let name = BuiltinScalarFunction::Encode; - assert_eq!(name, fun); + assert_eq!(name, func_def); assert_eq!(2, args.len()); } else { unreachable!(); @@ -1236,11 +1246,13 @@ mod test { #[test] fn decode_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { fun, args }) = - decode(col("tableA.a"), lit("hex")) + if let Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(func_def), + args, + }) = decode(col("tableA.a"), lit("hex")) { let name = BuiltinScalarFunction::Decode; - assert_eq!(name, fun); + assert_eq!(name, func_def); assert_eq!(2, args.len()); } else { unreachable!(); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 0d06a1295199b..d3a3528a91f8c 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -18,8 +18,8 @@ use super::{Between, Expr, Like}; use crate::expr::{ AggregateFunction, AggregateUDF, Alias, BinaryExpr, Cast, GetFieldAccess, - GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, ScalarUDF, Sort, - TryCast, WindowFunction, + GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, + ScalarFunctionDefinition, ScalarUDF, Sort, TryCast, WindowFunction, }; use crate::field_util::GetFieldAccessSchema; use crate::type_coercion::binary::get_result_type; @@ -89,25 +89,39 @@ impl ExprSchemable for Expr { .collect::>>()?; Ok(fun.return_type(&data_types)?) } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let arg_data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - - // verify that input data types is consistent with function's `TypeSignature` - data_types(&arg_data_types, &fun.signature()).map_err(|_| { - plan_datafusion_err!( - "{}", - utils::generate_signature_error_msg( - &format!("{fun}"), - fun.signature(), - &arg_data_types, - ) - ) - })?; - - fun.return_type(&arg_data_types) + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + let arg_data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + + // verify that input data types is consistent with function's `TypeSignature` + data_types(&arg_data_types, &fun.signature()).map_err(|_| { + plan_datafusion_err!( + "{}", + utils::generate_signature_error_msg( + &format!("{fun}"), + fun.signature(), + &arg_data_types, + ) + ) + })?; + + fun.return_type(&arg_data_types) + } + ScalarFunctionDefinition::UDF(fun) => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + Ok(fun.return_type(&data_types)?) + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + } } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { let data_types = args diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 21c0d750a36d0..c290db2b6eaa2 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -61,7 +61,7 @@ pub use built_in_function::BuiltinScalarFunction; pub use columnar_value::ColumnarValue; pub use expr::{ Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, - Like, TryCast, + Like, ScalarFunctionDefinition, TryCast, }; pub use expr_fn::*; pub use expr_schema::ExprSchemable; diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 6b86de37ba44d..387b95b9c8be1 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -20,12 +20,12 @@ use crate::expr::{ AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Case, Cast, GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, - ScalarUDF, Sort, TryCast, WindowFunction, + ScalarFunctionDefinition, ScalarUDF, Sort, TryCast, WindowFunction, }; use crate::{Expr, GetFieldAccess}; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::Result; +use datafusion_common::{internal_err, DataFusionError, Result}; impl TreeNode for Expr { fn apply_children(&self, op: &mut F) -> Result @@ -276,9 +276,20 @@ impl TreeNode for Expr { asc, nulls_first, )), - Expr::ScalarFunction(ScalarFunction { args, fun }) => Expr::ScalarFunction( - ScalarFunction::new(fun, transform_vec(args, &mut transform)?), - ), + Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => Expr::ScalarFunction( + ScalarFunction::new(fun, transform_vec(args, &mut transform)?), + ), + ScalarFunctionDefinition::UDF(fun) => Expr::ScalarUDF(ScalarUDF::new( + fun, + transform_vec(args, &mut transform)?, + )), + ScalarFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + }, Expr::ScalarUDF(ScalarUDF { args, fun }) => { Expr::ScalarUDF(ScalarUDF::new(fun, transform_vec(args, &mut transform)?)) } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 2c5e8c8b1c457..60112e110b8ba 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -45,7 +45,8 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, type_coercion, window_function, AggregateFunction, BuiltinScalarFunction, Expr, - LogicalPlan, Operator, Projection, WindowFrame, WindowFrameBound, WindowFrameUnits, + LogicalPlan, Operator, Projection, ScalarFunctionDefinition, WindowFrame, + WindowFrameBound, WindowFrameUnits, }; use datafusion_expr::{ExprSchemable, Signature}; @@ -327,16 +328,34 @@ impl TreeNodeRewriter for TypeCoercionRewriter { )?; Ok(Expr::ScalarUDF(ScalarUDF::new(fun, new_expr))) } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let new_args = coerce_arguments_for_signature( - args.as_slice(), - &self.schema, - &fun.signature(), - )?; - let new_args = - coerce_arguments_for_fun(new_args.as_slice(), &self.schema, &fun)?; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) - } + Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + let new_args = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + &fun.signature(), + )?; + let new_args = coerce_arguments_for_fun( + new_args.as_slice(), + &self.schema, + &fun, + )?; + Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + } + ScalarFunctionDefinition::UDF(fun) => { + let new_expr = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + fun.signature(), + )?; + Ok(Expr::ScalarUDF(ScalarUDF::new(fun, new_expr))) + } + ScalarFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + }, Expr::AggregateFunction(expr::AggregateFunction { fun, args, @@ -773,7 +792,8 @@ mod test { use datafusion_expr::{ cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction, AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case, - ColumnarValue, ExprSchemable, Filter, Operator, StateTypeFunction, Subquery, + ColumnarValue, ExprSchemable, Filter, Operator, ScalarFunctionDefinition, + StateTypeFunction, Subquery, }; use datafusion_expr::{ lit, @@ -1247,7 +1267,7 @@ mod test { ), ))); let expr = Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::MakeArray, + func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), args: vec![val.clone()], }); let schema = Arc::new(DFSchema::new_with_metadata( @@ -1279,7 +1299,7 @@ mod test { )?; let expected = Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::MakeArray, + func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), args: vec![expected_casted_expr], }); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 05f4072e38573..c26b3de0823ca 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -28,7 +28,8 @@ use datafusion_expr::{ and, expr_rewriter::replace_col, logical_plan::{CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union}, - or, BinaryExpr, Expr, Filter, Operator, TableProviderFilterPushDown, + or, BinaryExpr, Expr, Filter, Operator, ScalarFunctionDefinition, + TableProviderFilterPushDown, }; use itertools::Itertools; use std::collections::{HashMap, HashSet}; @@ -977,10 +978,26 @@ fn is_volatile_expression(e: &Expr) -> bool { let mut is_volatile = false; e.apply(&mut |expr| { Ok(match expr { - Expr::ScalarFunction(f) if f.fun.volatility() == Volatility::Volatile => { - is_volatile = true; - VisitRecursion::Stop - } + Expr::ScalarFunction(f) => match &f.func_def { + ScalarFunctionDefinition::BuiltIn(fun) + if fun.volatility() == Volatility::Volatile => + { + is_volatile = true; + VisitRecursion::Stop + } + ScalarFunctionDefinition::UDF(fun) + if fun.signature().volatility == Volatility::Volatile => + { + is_volatile = true; + VisitRecursion::Stop + } + ScalarFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + _ => VisitRecursion::Continue, + }, _ => VisitRecursion::Continue, }) }) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 947a6f6070d2d..02835f7ca3bf2 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -40,7 +40,7 @@ use datafusion_common::{ use datafusion_expr::expr::{InList, InSubquery, ScalarFunction}; use datafusion_expr::{ and, expr, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, - Like, Volatility, + Like, ScalarFunctionDefinition, Volatility, }; use datafusion_physical_expr::{ create_physical_expr, execution_props::ExecutionProps, intervals::NullableInterval, @@ -345,9 +345,15 @@ impl<'a> ConstEvaluator<'a> { | Expr::GroupingSet(_) | Expr::Wildcard { .. } | Expr::Placeholder(_) => false, - Expr::ScalarFunction(ScalarFunction { fun, .. }) => { - Self::volatility_ok(fun.volatility()) - } + Expr::ScalarFunction(ScalarFunction { func_def, .. }) => match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + Self::volatility_ok(fun.volatility()) + } + ScalarFunctionDefinition::UDF(fun) => { + Self::volatility_ok(fun.signature().volatility) + } + ScalarFunctionDefinition::Name(_) => true, + }, Expr::ScalarUDF(expr::ScalarUDF { fun, .. }) => { Self::volatility_ok(fun.signature().volatility) } @@ -1201,25 +1207,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // log Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::Log, + func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Log), args, }) => simpl_log(args, <&S>::clone(&info))?, // power Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::Power, + func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Power), args, }) => simpl_power(args, <&S>::clone(&info))?, // concat Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::Concat, + func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Concat), args, }) => simpl_concat(args)?, // concat_ws Expr::ScalarFunction(ScalarFunction { - fun: BuiltinScalarFunction::ConcatWithSeparator, + func_def: + ScalarFunctionDefinition::BuiltIn( + BuiltinScalarFunction::ConcatWithSeparator, + ), args, }) => match &args[..] { [delimiter, vals @ ..] => simpl_concat_ws(delimiter, vals)?, diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 17e5d97c30062..fa91a3ace2a25 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -23,7 +23,7 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ expr::{Between, BinaryExpr, InList}, expr_fn::{and, bitwise_and, bitwise_or, concat_ws, or}, - lit, BuiltinScalarFunction, Expr, Like, Operator, + lit, BuiltinScalarFunction, Expr, Like, Operator, ScalarFunctionDefinition, }; pub static POWS_OF_TEN: [i128; 38] = [ @@ -365,7 +365,7 @@ pub fn simpl_log(current_args: Vec, info: &dyn SimplifyInfo) -> Result Ok(args[1].clone()), _ => { @@ -405,7 +405,7 @@ pub fn simpl_power(current_args: Vec, info: &dyn SimplifyInfo) -> Result Ok(args[1].clone()), _ => Ok(Expr::ScalarFunction(ScalarFunction::new( diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index f318cd3b0f4d9..beabc75954f73 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -32,7 +32,7 @@ use datafusion_common::{ use datafusion_expr::expr::{Alias, Cast, InList, ScalarFunction, ScalarUDF}; use datafusion_expr::{ binary_expr, Between, BinaryExpr, Expr, GetFieldAccess, GetIndexedField, Like, - Operator, TryCast, + Operator, ScalarFunctionDefinition, TryCast, }; use std::sync::Arc; @@ -348,20 +348,50 @@ pub fn create_physical_expr( ))) } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let physical_args = args - .iter() - .map(|e| { - create_physical_expr(e, input_dfschema, input_schema, execution_props) - }) - .collect::>>()?; - functions::create_physical_expr( - fun, - &physical_args, - input_schema, - execution_props, - ) - } + Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + let physical_args = args + .iter() + .map(|e| { + create_physical_expr( + e, + input_dfschema, + input_schema, + execution_props, + ) + }) + .collect::>>()?; + functions::create_physical_expr( + fun, + &physical_args, + input_schema, + execution_props, + ) + } + ScalarFunctionDefinition::UDF(fun) => { + let mut physical_args = vec![]; + for e in args { + physical_args.push(create_physical_expr( + e, + input_dfschema, + input_schema, + execution_props, + )?); + } + // udfs with zero params expect null array as input + if args.is_empty() { + physical_args.push(Arc::new(Literal::new(ScalarValue::Null))); + } + udf::create_physical_expr( + fun.clone().as_ref(), + &physical_args, + input_schema, + ) + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + }, Expr::ScalarUDF(ScalarUDF { fun, args }) => { let mut physical_args = vec![]; for e in args { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 649be05b88c37..a4fbe38901624 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -45,7 +45,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::{ self, Alias, Between, BinaryExpr, Cast, GetFieldAccess, GetIndexedField, GroupingSet, - InList, Like, Placeholder, ScalarFunction, ScalarUDF, Sort, + InList, Like, Placeholder, ScalarFunction, ScalarFunctionDefinition, ScalarUDF, Sort, }; use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, @@ -752,21 +752,40 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { .to_string(), )) } - Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let fun: protobuf::ScalarFunction = fun.try_into()?; - let args: Vec = args - .iter() - .map(|e| e.try_into()) - .collect::, Error>>()?; - Self { - expr_type: Some(ExprType::ScalarFunction( - protobuf::ScalarFunctionNode { - fun: fun.into(), - args, + Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + let fun: protobuf::ScalarFunction = fun.try_into()?; + let args: Vec = args + .iter() + .map(|e| e.try_into()) + .collect::, Error>>()?; + Self { + expr_type: Some(ExprType::ScalarFunction( + protobuf::ScalarFunctionNode { + fun: fun.into(), + args, + }, + )), + } + } + ScalarFunctionDefinition::UDF(fun) => Self { + expr_type: Some(ExprType::ScalarUdfExpr( + protobuf::ScalarUdfExprNode { + fun_name: fun.name().to_string(), + args: args + .iter() + .map(|expr| expr.try_into()) + .collect::, Error>>()?, }, )), + }, + ScalarFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( + "Proto serialization error: Trying to serialize a unresolved function" + .to_string(), + )); } - } + }, Expr::ScalarUDF(ScalarUDF { fun, args }) => Self { expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name: fun.name().to_string(), diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index f4c36557dac8b..9d670fddc36d6 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -25,8 +25,8 @@ use datafusion::logical_expr::{ BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator, }; use datafusion::logical_expr::{ - expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, WindowFrameBound, - WindowFrameUnits, + expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, + ScalarFunctionDefinition, WindowFrameBound, WindowFrameUnits, }; use datafusion::prelude::JoinType; use datafusion::sql::TableReference; @@ -844,7 +844,7 @@ pub async fn from_substrait_rex( args.push(arg_expr?.as_ref().clone()); } Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction { - fun, + func_def: ScalarFunctionDefinition::BuiltIn(fun), args, }))) } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 4b6aded78b49d..9c7ad0647e2f1 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -34,7 +34,7 @@ use datafusion::common::{exec_err, internal_err, not_impl_err}; use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ Alias, BinaryExpr, Case, Cast, GroupingSet, InList, - ScalarFunction as DFScalarFunction, Sort, WindowFunction, + ScalarFunction as DFScalarFunction, ScalarFunctionDefinition, Sort, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; @@ -822,7 +822,7 @@ pub fn to_substrait_rex( Ok(substrait_or_list) } } - Expr::ScalarFunction(DFScalarFunction { fun, args }) => { + Expr::ScalarFunction(DFScalarFunction { func_def, args }) => { let mut arguments: Vec = vec![]; for arg in args { arguments.push(FunctionArgument { @@ -834,7 +834,16 @@ pub fn to_substrait_rex( )?)), }); } - let function_anchor = _register_function(fun.to_string(), extension_info); + + let func_name = match &func_def { + ScalarFunctionDefinition::BuiltIn(fun) => Ok(fun.to_string()), + ScalarFunctionDefinition::UDF(fun) => Ok(fun.name().to_string()), + ScalarFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + }; + + let function_anchor = _register_function(func_name?, extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, From 6a7ee875e1f7963a76098f0ed1f06e56b4a0a468 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Fri, 17 Nov 2023 12:45:26 -0800 Subject: [PATCH 2/6] Remove Expr::ScalarUDF --- .../core/src/datasource/listing/helpers.rs | 11 ----- datafusion/core/src/physical_planner.rs | 6 +-- datafusion/expr/src/expr.rs | 41 ++++++------------- datafusion/expr/src/expr_schema.rs | 10 +---- datafusion/expr/src/tree_node/expr.rs | 14 +++---- datafusion/expr/src/udf.rs | 5 ++- datafusion/expr/src/utils.rs | 1 - .../optimizer/src/analyzer/type_coercion.rs | 26 ++++-------- datafusion/optimizer/src/push_down_filter.rs | 5 ++- .../simplify_expressions/expr_simplifier.rs | 23 ++++++----- datafusion/physical-expr/src/planner.rs | 18 +------- .../proto/src/logical_plan/from_proto.rs | 2 +- datafusion/proto/src/logical_plan/to_proto.rs | 11 +---- .../tests/cases/roundtrip_logical_plan.rs | 7 +++- datafusion/sql/src/expr/function.rs | 4 +- 15 files changed, 58 insertions(+), 126 deletions(-) diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index f120b5b6f4922..19129986e4d4b 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -39,7 +39,6 @@ use crate::datasource::listing::ListingTableUrl; use crate::execution::context::SessionState; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError}; -use datafusion_expr::expr::ScalarUDF; use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility}; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr::execution_props::ExecutionProps; @@ -118,16 +117,6 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { } } } - Expr::ScalarUDF(ScalarUDF { fun, .. }) => { - match fun.signature().volatility { - Volatility::Immutable => Ok(VisitRecursion::Continue), - // TODO: Stable functions could be `applicable`, but that would require access to the context - Volatility::Stable | Volatility::Volatile => { - is_applicable = false; - Ok(VisitRecursion::Stop) - } - } - } // TODO other expressions are not handled yet: // - AGGREGATE, WINDOW and SORT should not end up in filter conditions, except maybe in some edge cases diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 63a224a8cfb6d..9d6f742b2d26b 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -84,8 +84,7 @@ use datafusion_common::{ use datafusion_expr::dml::{CopyOptions, CopyTo}; use datafusion_expr::expr::{ self, AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Cast, - GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, ScalarUDF, TryCast, - WindowFunction, + GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, TryCast, WindowFunction, }; use datafusion_expr::expr_rewriter::{unalias, unnormalize_cols}; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; @@ -229,9 +228,6 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { }; create_function_physical_name(&func_name?, false, &func_expr.args) } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - create_function_physical_name(fun.name(), false, args) - } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { create_function_physical_name(&fun.to_string(), false, args) } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 52f0be05113e3..7adef69a7d40d 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -148,10 +148,8 @@ pub enum Expr { TryCast(TryCast), /// A sort expression, that can be used to sort values. Sort(Sort), - /// Represents the call of a built-in scalar function with a set of arguments. + /// Represents the call of a scalar function with a set of arguments. ScalarFunction(ScalarFunction), - /// Represents the call of a user-defined scalar function with arguments. - ScalarUDF(ScalarUDF), /// Represents the call of an aggregate built-in function with arguments. AggregateFunction(AggregateFunction), /// Represents the call of a window function with arguments. @@ -340,12 +338,14 @@ impl Between { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum ScalarFunctionDefinition { - /// Resolved to a built in scalar function - /// (will be removed long term) + /// Resolved to a `BuiltinScalarFunction` + /// There is plan to migrate `BuiltinScalarFunction` to UDF-based implementation (issue#8045) + /// This variant is planned to be removed in long term BuiltIn(built_in_function::BuiltinScalarFunction), /// Resolved to a user defined function UDF(Arc), - /// A scalar function that will be called by name + /// A scalar function constructed with name, could be resolved with registered functions during + /// analyzing. Name(Arc), } @@ -366,23 +366,13 @@ impl ScalarFunction { args, } } -} -/// ScalarUDF expression invokes a user-defined scalar function [`ScalarUDF`] -/// -/// [`ScalarUDF`]: crate::ScalarUDF -#[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub struct ScalarUDF { - /// The function - pub fun: Arc, - /// List of expressions to feed to the functions as arguments - pub args: Vec, -} - -impl ScalarUDF { - /// Create a new ScalarUDF expression - pub fn new(fun: Arc, args: Vec) -> Self { - Self { fun, args } + /// Create a new ScalarFunction expression with a user-defined function (UDF) + pub fn new_udf(udf: Arc, args: Vec) -> Self { + Self { + func_def: ScalarFunctionDefinition::UDF(udf), + args, + } } } @@ -750,7 +740,6 @@ impl Expr { Expr::Placeholder(_) => "Placeholder", Expr::ScalarFunction(..) => "ScalarFunction", Expr::ScalarSubquery { .. } => "ScalarSubquery", - Expr::ScalarUDF(..) => "ScalarUDF", Expr::ScalarVariable(..) => "ScalarVariable", Expr::Sort { .. } => "Sort", Expr::TryCast { .. } => "TryCast", @@ -1227,9 +1216,6 @@ impl fmt::Display for Expr { fmt_function(f, func_name, false, &func_expr.args, true) } }, - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - fmt_function(f, fun.name(), false, args, true) - } Expr::WindowFunction(WindowFunction { fun, args, @@ -1571,9 +1557,6 @@ fn create_name(e: &Expr) -> Result { create_function_name(name, false, &func_expr.args) } }, - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - create_function_name(fun.name(), false, args) - } Expr::WindowFunction(WindowFunction { fun, args, diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index d3a3528a91f8c..e13765353d587 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -19,7 +19,7 @@ use super::{Between, Expr, Like}; use crate::expr::{ AggregateFunction, AggregateUDF, Alias, BinaryExpr, Cast, GetFieldAccess, GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, - ScalarFunctionDefinition, ScalarUDF, Sort, TryCast, WindowFunction, + ScalarFunctionDefinition, Sort, TryCast, WindowFunction, }; use crate::field_util::GetFieldAccessSchema; use crate::type_coercion::binary::get_result_type; @@ -82,13 +82,6 @@ impl ExprSchemable for Expr { Expr::Case(case) => case.when_then_expr[0].1.get_type(schema), Expr::Cast(Cast { data_type, .. }) | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - Ok(fun.return_type(&data_types)?) - } Expr::ScalarFunction(ScalarFunction { func_def, args }) => { match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { @@ -257,7 +250,6 @@ impl ExprSchemable for Expr { Expr::ScalarVariable(_, _) | Expr::TryCast { .. } | Expr::ScalarFunction(..) - | Expr::ScalarUDF(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. } diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 387b95b9c8be1..f50939a8e66d4 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -20,7 +20,7 @@ use crate::expr::{ AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Case, Cast, GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, - ScalarFunctionDefinition, ScalarUDF, Sort, TryCast, WindowFunction, + ScalarFunctionDefinition, Sort, TryCast, WindowFunction, }; use crate::{Expr, GetFieldAccess}; @@ -64,7 +64,7 @@ impl TreeNode for Expr { } Expr::GroupingSet(GroupingSet::Rollup(exprs)) | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.clone(), - Expr::ScalarFunction (ScalarFunction{ args, .. } )| Expr::ScalarUDF(ScalarUDF { args, .. }) => { + Expr::ScalarFunction (ScalarFunction{ args, .. } ) => { args.clone() } Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { @@ -280,19 +280,15 @@ impl TreeNode for Expr { ScalarFunctionDefinition::BuiltIn(fun) => Expr::ScalarFunction( ScalarFunction::new(fun, transform_vec(args, &mut transform)?), ), - ScalarFunctionDefinition::UDF(fun) => Expr::ScalarUDF(ScalarUDF::new( - fun, - transform_vec(args, &mut transform)?, - )), + ScalarFunctionDefinition::UDF(fun) => Expr::ScalarFunction( + ScalarFunction::new_udf(fun, transform_vec(args, &mut transform)?), + ), ScalarFunctionDefinition::Name(_) => { return internal_err!( "Function `Expr` with name should be resolved." ); } }, - Expr::ScalarUDF(ScalarUDF { args, fun }) => { - Expr::ScalarUDF(ScalarUDF::new(fun, transform_vec(args, &mut transform)?)) - } Expr::WindowFunction(WindowFunction { args, fun, diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 22e56caaaf5f7..bc910b928a5d7 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -95,7 +95,10 @@ impl ScalarUDF { /// creates a logical expression with a call of the UDF /// This utility allows using the UDF without requiring access to the registry. pub fn call(&self, args: Vec) -> Expr { - Expr::ScalarUDF(crate::expr::ScalarUDF::new(Arc::new(self.clone()), args)) + Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf( + Arc::new(self.clone()), + args, + )) } /// Returns this function's name diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 8f13bf5f61be6..5e2ae706d37d5 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -283,7 +283,6 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::TryCast { .. } | Expr::Sort { .. } | Expr::ScalarFunction(..) - | Expr::ScalarUDF(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } | Expr::GroupingSet(_) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 60112e110b8ba..05f438bd51274 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -29,7 +29,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::{ self, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, - ScalarUDF, WindowFunction, + WindowFunction, }; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::expr_schema::cast_subquery; @@ -320,14 +320,6 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let case = coerce_case_expression(case, &self.schema)?; Ok(Expr::Case(case)) } - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - let new_expr = coerce_arguments_for_signature( - args.as_slice(), - &self.schema, - fun.signature(), - )?; - Ok(Expr::ScalarUDF(ScalarUDF::new(fun, new_expr))) - } Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { let new_args = coerce_arguments_for_signature( @@ -348,12 +340,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, fun.signature(), )?; - Ok(Expr::ScalarUDF(ScalarUDF::new(fun, new_expr))) + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_expr))) } ScalarFunctionDefinition::Name(_) => { - return internal_err!( - "Function `Expr` with name should be resolved." - ); + internal_err!("Function `Expr` with name should be resolved.") } }, Expr::AggregateFunction(expr::AggregateFunction { @@ -858,7 +848,7 @@ mod test { Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); let fun: ScalarFunctionImplementation = Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a")))); - let udf = Expr::ScalarUDF(expr::ScalarUDF::new( + let udf = Expr::ScalarFunction(expr::ScalarFunction::new_udf( Arc::new(ScalarUDF::new( "TestScalarUDF", &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), @@ -879,7 +869,7 @@ mod test { let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!()); - let udf = Expr::ScalarUDF(expr::ScalarUDF::new( + let udf = Expr::ScalarFunction(expr::ScalarFunction::new_udf( Arc::new(ScalarUDF::new( "TestScalarUDF", &Signature::uniform(1, vec![DataType::Int32], Volatility::Stable), @@ -893,9 +883,9 @@ mod test { .err() .unwrap(); assert_eq!( - "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Int32]) failed.", - err.strip_backtrace() - ); + "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Int32]) failed.", + err.strip_backtrace() + ); Ok(()) } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index c26b3de0823ca..2536b4d50381d 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -222,7 +222,10 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::InSubquery(_) | Expr::ScalarSubquery(_) | Expr::OuterReferenceColumn(_, _) - | Expr::ScalarUDF(..) => { + | Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction { + func_def: ScalarFunctionDefinition::UDF(_), + .. + }) => { is_evaluate = false; Ok(VisitRecursion::Stop) } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 02835f7ca3bf2..c84c4b9e2fec7 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -39,8 +39,8 @@ use datafusion_common::{ }; use datafusion_expr::expr::{InList, InSubquery, ScalarFunction}; use datafusion_expr::{ - and, expr, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, - Like, ScalarFunctionDefinition, Volatility, + and, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, Like, + ScalarFunctionDefinition, Volatility, }; use datafusion_physical_expr::{ create_physical_expr, execution_props::ExecutionProps, intervals::NullableInterval, @@ -354,9 +354,6 @@ impl<'a> ConstEvaluator<'a> { } ScalarFunctionDefinition::Name(_) => true, }, - Expr::ScalarUDF(expr::ScalarUDF { fun, .. }) => { - Self::volatility_ok(fun.signature().volatility) - } Expr::Literal(_) | Expr::BinaryExpr { .. } | Expr::Not(_) @@ -1561,7 +1558,7 @@ mod tests { // immutable UDF should get folded // udf_add(1+2, 30+40) --> 73 - let expr = Expr::ScalarUDF(expr::ScalarUDF::new( + let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( make_udf_add(Volatility::Immutable), args.clone(), )); @@ -1570,15 +1567,21 @@ mod tests { // stable UDF should be entirely folded // udf_add(1+2, 30+40) --> 73 let fun = make_udf_add(Volatility::Stable); - let expr = Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun), args.clone())); + let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( + Arc::clone(&fun), + args.clone(), + )); test_evaluate(expr, lit(73)); // volatile UDF should have args folded // udf_add(1+2, 30+40) --> udf_add(3, 70) let fun = make_udf_add(Volatility::Volatile); - let expr = Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun), args)); - let expected_expr = - Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun), folded_args)); + let expr = + Expr::ScalarFunction(expr::ScalarFunction::new_udf(Arc::clone(&fun), args)); + let expected_expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( + Arc::clone(&fun), + folded_args, + )); test_evaluate(expr, expected_expr); } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index beabc75954f73..5501647da2c36 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -29,7 +29,7 @@ use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::expr::{Alias, Cast, InList, ScalarFunction, ScalarUDF}; +use datafusion_expr::expr::{Alias, Cast, InList, ScalarFunction}; use datafusion_expr::{ binary_expr, Between, BinaryExpr, Expr, GetFieldAccess, GetIndexedField, Like, Operator, ScalarFunctionDefinition, TryCast, @@ -392,22 +392,6 @@ pub fn create_physical_expr( internal_err!("Function `Expr` with name should be resolved.") } }, - Expr::ScalarUDF(ScalarUDF { fun, args }) => { - let mut physical_args = vec![]; - for e in args { - physical_args.push(create_physical_expr( - e, - input_dfschema, - input_schema, - execution_props, - )?); - } - // udfs with zero params expect null array as input - if args.is_empty() { - physical_args.push(Arc::new(Literal::new(ScalarValue::Null))); - } - udf::create_physical_expr(fun.clone().as_ref(), &physical_args, input_schema) - } Expr::Between(Between { expr, negated, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 94c9f98066217..8365ef2e1d7cf 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1705,7 +1705,7 @@ pub fn parse_expr( } ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, args }) => { let scalar_fn = registry.udf(fun_name.as_str())?; - Ok(Expr::ScalarUDF(expr::ScalarUDF::new( + Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( scalar_fn, args.iter() .map(|expr| parse_expr(expr, registry)) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index a4fbe38901624..0e59f8077c80c 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -45,7 +45,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::{ self, Alias, Between, BinaryExpr, Cast, GetFieldAccess, GetIndexedField, GroupingSet, - InList, Like, Placeholder, ScalarFunction, ScalarFunctionDefinition, ScalarUDF, Sort, + InList, Like, Placeholder, ScalarFunction, ScalarFunctionDefinition, Sort, }; use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, @@ -786,15 +786,6 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { )); } }, - Expr::ScalarUDF(ScalarUDF { fun, args }) => Self { - expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { - fun_name: fun.name().to_string(), - args: args - .iter() - .map(|expr| expr.try_into()) - .collect::, Error>>()?, - })), - }, Expr::AggregateUDF(expr::AggregateUDF { fun, args, diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 75af9d2e0acb7..fc7848277f37d 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -39,7 +39,7 @@ use datafusion_common::{internal_err, not_impl_err, plan_err}; use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue}; use datafusion_expr::expr::{ self, Between, BinaryExpr, Case, Cast, GroupingSet, InList, Like, ScalarFunction, - ScalarUDF, Sort, + Sort, }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ @@ -1364,7 +1364,10 @@ fn roundtrip_scalar_udf() { scalar_fn, ); - let test_expr = Expr::ScalarUDF(ScalarUDF::new(Arc::new(udf.clone()), vec![lit("")])); + let test_expr = Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(udf.clone()), + vec![lit("")], + )); let ctx = SessionContext::new(); ctx.register_udf(udf); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index c77ef64718bbe..24ba4d1b506ae 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -19,7 +19,7 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError, Result, }; -use datafusion_expr::expr::{ScalarFunction, ScalarUDF}; +use datafusion_expr::expr::ScalarFunction; use datafusion_expr::function::suggest_valid_function; use datafusion_expr::window_frame::regularize; use datafusion_expr::{ @@ -66,7 +66,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // user-defined function (UDF) should have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_function_meta(&name) { let args = self.function_args_to_expr(args, schema, planner_context)?; - return Ok(Expr::ScalarUDF(ScalarUDF::new(fm, args))); + return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fm, args))); } // next, scalar built-in From f0ffeacf2c1a10a3e3942242831c0419104957bf Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Sun, 19 Nov 2023 21:33:40 -0800 Subject: [PATCH 3/6] review comments --- datafusion/core/src/physical_planner.rs | 16 +++---- datafusion/expr/src/expr.rs | 48 ++++++++----------- .../simplify_expressions/expr_simplifier.rs | 2 +- .../substrait/src/logical_plan/producer.rs | 13 ++--- 4 files changed, 33 insertions(+), 46 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 9d6f742b2d26b..505d2452f285d 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -218,15 +218,13 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Ok(name) } - Expr::ScalarFunction(func_expr) => { - let func_name = match &func_expr.func_def { - ScalarFunctionDefinition::BuiltIn(fun) => Ok(fun.to_string()), - ScalarFunctionDefinition::UDF(fun) => Ok(fun.name().to_string()), - ScalarFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") - } - }; - create_function_physical_name(&func_name?, false, &func_expr.args) + Expr::ScalarFunction(expr::ScalarFunction { func_def, args }) => { + // function should be resolved during `AnalyzerRule`s + if let ScalarFunctionDefinition::Name(_) = func_def { + return internal_err!("Function `Expr` with name should be resolved."); + } + + create_function_physical_name(&func_def.name(), false, args) } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { create_function_physical_name(&fun.to_string(), false, args) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 7adef69a7d40d..d1313fd5dfd92 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -337,6 +337,7 @@ impl Between { } #[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// Defines which implementation of a function for DataFusion to call. pub enum ScalarFunctionDefinition { /// Resolved to a `BuiltinScalarFunction` /// There is plan to migrate `BuiltinScalarFunction` to UDF-based implementation (issue#8045) @@ -344,8 +345,8 @@ pub enum ScalarFunctionDefinition { BuiltIn(built_in_function::BuiltinScalarFunction), /// Resolved to a user defined function UDF(Arc), - /// A scalar function constructed with name, could be resolved with registered functions during - /// analyzing. + /// A scalar function constructed with name. This variant can not be executed directly + /// and instead must be resolved to one of the other variants prior to physical planning. Name(Arc), } @@ -358,6 +359,17 @@ pub struct ScalarFunction { pub args: Vec, } +impl ScalarFunctionDefinition { + /// Function's name for display + pub fn name(&self) -> String { + match self { + ScalarFunctionDefinition::BuiltIn(builtin_func) => builtin_func.to_string(), + ScalarFunctionDefinition::UDF(udf) => udf.name().to_string(), + ScalarFunctionDefinition::Name(func_name) => func_name.as_ref().to_string(), + } + } +} + impl ScalarFunction { /// Create a new ScalarFunction expression pub fn new(fun: built_in_function::BuiltinScalarFunction, args: Vec) -> Self { @@ -1201,21 +1213,9 @@ impl fmt::Display for Expr { write!(f, " NULLS LAST") } } - Expr::ScalarFunction(func_expr) => match &func_expr.func_def { - ScalarFunctionDefinition::BuiltIn(builtin_func) => fmt_function( - f, - &builtin_func.to_string(), - false, - &func_expr.args, - true, - ), - ScalarFunctionDefinition::UDF(udf) => { - fmt_function(f, udf.name(), false, &func_expr.args, true) - } - ScalarFunctionDefinition::Name(func_name) => { - fmt_function(f, func_name, false, &func_expr.args, true) - } - }, + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + fmt_function(f, &func_def.name(), false, args, true) + } Expr::WindowFunction(WindowFunction { fun, args, @@ -1546,17 +1546,9 @@ fn create_name(e: &Expr) -> Result { } } } - Expr::ScalarFunction(func_expr) => match &func_expr.func_def { - ScalarFunctionDefinition::BuiltIn(builtin_func) => { - create_function_name(&builtin_func.to_string(), false, &func_expr.args) - } - ScalarFunctionDefinition::UDF(udf) => { - create_function_name(udf.name(), false, &func_expr.args) - } - ScalarFunctionDefinition::Name(name) => { - create_function_name(name, false, &func_expr.args) - } - }, + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + create_function_name(&func_def.name(), false, args) + } Expr::WindowFunction(WindowFunction { fun, args, diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index c84c4b9e2fec7..b4273e5cdf4ce 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -352,7 +352,7 @@ impl<'a> ConstEvaluator<'a> { ScalarFunctionDefinition::UDF(fun) => { Self::volatility_ok(fun.signature().volatility) } - ScalarFunctionDefinition::Name(_) => true, + ScalarFunctionDefinition::Name(_) => false, }, Expr::Literal(_) | Expr::BinaryExpr { .. } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 9c7ad0647e2f1..de033c8f267c0 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -835,15 +835,12 @@ pub fn to_substrait_rex( }); } - let func_name = match &func_def { - ScalarFunctionDefinition::BuiltIn(fun) => Ok(fun.to_string()), - ScalarFunctionDefinition::UDF(fun) => Ok(fun.name().to_string()), - ScalarFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") - } - }; + // function should be resolved during `AnalyzerRule` + if let ScalarFunctionDefinition::Name(_) = func_def { + return internal_err!("Function `Expr` with name should be resolved."); + } - let function_anchor = _register_function(func_name?, extension_info); + let function_anchor = _register_function(func_def.name(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, From f59edb9482c904454c2e5041c94614d17a2ed057 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Wed, 22 Nov 2023 13:14:31 -0800 Subject: [PATCH 4/6] make name() return &str --- .../core/src/datasource/listing/helpers.rs | 2 +- datafusion/core/src/physical_planner.rs | 2 +- datafusion/expr/src/expr.rs | 22 +++++++++------ datafusion/expr/src/expr_fn.rs | 22 +++++++-------- datafusion/expr/src/expr_schema.rs | 2 +- datafusion/expr/src/tree_node/expr.rs | 2 +- .../optimizer/src/analyzer/type_coercion.rs | 20 +++++++------- datafusion/optimizer/src/push_down_filter.rs | 2 +- .../simplify_expressions/expr_simplifier.rs | 27 ++++++++++++++----- .../src/simplify_expressions/utils.rs | 12 +++++++-- datafusion/physical-expr/src/planner.rs | 2 +- datafusion/proto/src/logical_plan/to_proto.rs | 2 +- .../substrait/src/logical_plan/consumer.rs | 9 +++---- .../substrait/src/logical_plan/producer.rs | 3 ++- 14 files changed, 78 insertions(+), 51 deletions(-) diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 19129986e4d4b..5f0467e7182a9 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -92,7 +92,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { Expr::ScalarFunction(scalar_function) => { match &scalar_function.func_def { - ScalarFunctionDefinition::BuiltIn(fun) => { + ScalarFunctionDefinition::BuiltIn { fun, .. } => { match fun.volatility() { Volatility::Immutable => Ok(VisitRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 505d2452f285d..07e141987d0b0 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -224,7 +224,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { return internal_err!("Function `Expr` with name should be resolved."); } - create_function_physical_name(&func_def.name(), false, args) + create_function_physical_name(func_def.name(), false, args) } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { create_function_physical_name(&fun.to_string(), false, args) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index d1313fd5dfd92..13e488dac042e 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -342,7 +342,10 @@ pub enum ScalarFunctionDefinition { /// Resolved to a `BuiltinScalarFunction` /// There is plan to migrate `BuiltinScalarFunction` to UDF-based implementation (issue#8045) /// This variant is planned to be removed in long term - BuiltIn(built_in_function::BuiltinScalarFunction), + BuiltIn { + fun: built_in_function::BuiltinScalarFunction, + name: Arc, + }, /// Resolved to a user defined function UDF(Arc), /// A scalar function constructed with name. This variant can not be executed directly @@ -361,11 +364,11 @@ pub struct ScalarFunction { impl ScalarFunctionDefinition { /// Function's name for display - pub fn name(&self) -> String { + pub fn name(&self) -> &str { match self { - ScalarFunctionDefinition::BuiltIn(builtin_func) => builtin_func.to_string(), - ScalarFunctionDefinition::UDF(udf) => udf.name().to_string(), - ScalarFunctionDefinition::Name(func_name) => func_name.as_ref().to_string(), + ScalarFunctionDefinition::BuiltIn { name, .. } => name.as_ref(), + ScalarFunctionDefinition::UDF(udf) => udf.name(), + ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(), } } } @@ -374,7 +377,10 @@ impl ScalarFunction { /// Create a new ScalarFunction expression pub fn new(fun: built_in_function::BuiltinScalarFunction, args: Vec) -> Self { Self { - func_def: ScalarFunctionDefinition::BuiltIn(fun), + func_def: ScalarFunctionDefinition::BuiltIn { + fun, + name: Arc::from(fun.to_string()), + }, args, } } @@ -1214,7 +1220,7 @@ impl fmt::Display for Expr { } } Expr::ScalarFunction(ScalarFunction { func_def, args }) => { - fmt_function(f, &func_def.name(), false, args, true) + fmt_function(f, func_def.name(), false, args, true) } Expr::WindowFunction(WindowFunction { fun, @@ -1547,7 +1553,7 @@ fn create_name(e: &Expr) -> Result { } } Expr::ScalarFunction(ScalarFunction { func_def, args }) => { - create_function_name(&func_def.name(), false, args) + create_function_name(func_def.name(), false, args) } Expr::WindowFunction(WindowFunction { fun, diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 3c05f430914b7..79fe9b2c0276f 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1023,7 +1023,7 @@ mod test { macro_rules! test_unary_scalar_expr { ($ENUM:ident, $FUNC:ident) => {{ if let Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(fun), + func_def: ScalarFunctionDefinition::BuiltIn { fun, .. }, args, }) = $FUNC(col("tableA.a")) { @@ -1044,7 +1044,7 @@ mod test { col(stringify!($arg.to_string())) ),* ); - if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args }) = result { + if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn{fun, ..}, args }) = result { let name = built_in_function::BuiltinScalarFunction::$ENUM; assert_eq!(name, fun); assert_eq!(expected.len(), args.len()); @@ -1064,7 +1064,7 @@ mod test { ),* ] ); - if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args }) = result { + if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn{fun, ..}, args }) = result { let name = built_in_function::BuiltinScalarFunction::$ENUM; assert_eq!(name, fun); assert_eq!(expected.len(), args.len()); @@ -1202,12 +1202,12 @@ mod test { #[test] fn uuid_function_definitions() { if let Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(func_def), + func_def: ScalarFunctionDefinition::BuiltIn { fun, .. }, args, }) = uuid() { let name = BuiltinScalarFunction::Uuid; - assert_eq!(name, func_def); + assert_eq!(name, fun); assert_eq!(0, args.len()); } else { unreachable!(); @@ -1217,12 +1217,12 @@ mod test { #[test] fn digest_function_definitions() { if let Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(func_def), + func_def: ScalarFunctionDefinition::BuiltIn { fun, .. }, args, }) = digest(col("tableA.a"), lit("md5")) { let name = BuiltinScalarFunction::Digest; - assert_eq!(name, func_def); + assert_eq!(name, fun); assert_eq!(2, args.len()); } else { unreachable!(); @@ -1232,12 +1232,12 @@ mod test { #[test] fn encode_function_definitions() { if let Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(func_def), + func_def: ScalarFunctionDefinition::BuiltIn { fun, .. }, args, }) = encode(col("tableA.a"), lit("base64")) { let name = BuiltinScalarFunction::Encode; - assert_eq!(name, func_def); + assert_eq!(name, fun); assert_eq!(2, args.len()); } else { unreachable!(); @@ -1247,12 +1247,12 @@ mod test { #[test] fn decode_function_definitions() { if let Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(func_def), + func_def: ScalarFunctionDefinition::BuiltIn { fun, .. }, args, }) = decode(col("tableA.a"), lit("hex")) { let name = BuiltinScalarFunction::Decode; - assert_eq!(name, func_def); + assert_eq!(name, fun); assert_eq!(2, args.len()); } else { unreachable!(); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index e13765353d587..d5d9c848b2e9a 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -84,7 +84,7 @@ impl ExprSchemable for Expr { | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), Expr::ScalarFunction(ScalarFunction { func_def, args }) => { match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => { + ScalarFunctionDefinition::BuiltIn { fun, .. } => { let arg_data_types = args .iter() .map(|e| e.get_type(schema)) diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index f50939a8e66d4..474b5f7689b95 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -277,7 +277,7 @@ impl TreeNode for Expr { nulls_first, )), Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => Expr::ScalarFunction( + ScalarFunctionDefinition::BuiltIn { fun, .. } => Expr::ScalarFunction( ScalarFunction::new(fun, transform_vec(args, &mut transform)?), ), ScalarFunctionDefinition::UDF(fun) => Expr::ScalarFunction( diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 05f438bd51274..99142e0c49eff 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -321,7 +321,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { Ok(Expr::Case(case)) } Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => { + ScalarFunctionDefinition::BuiltIn { fun, .. } => { let new_args = coerce_arguments_for_signature( args.as_slice(), &self.schema, @@ -782,7 +782,7 @@ mod test { use datafusion_expr::{ cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction, AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case, - ColumnarValue, ExprSchemable, Filter, Operator, ScalarFunctionDefinition, + ColumnarValue, ExprSchemable, Filter, Operator, StateTypeFunction, Subquery, }; use datafusion_expr::{ @@ -1256,10 +1256,10 @@ mod test { None, ), ))); - let expr = Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), - args: vec![val.clone()], - }); + let expr = Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::MakeArray, + vec![val.clone()], + )); let schema = Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified( "item", @@ -1288,10 +1288,10 @@ mod test { &schema, )?; - let expected = Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray), - args: vec![expected_casted_expr], - }); + let expected = Expr::ScalarFunction(ScalarFunction::new( + BuiltinScalarFunction::MakeArray, + vec![expected_casted_expr], + )); assert_eq!(result, expected); Ok(()) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 2536b4d50381d..7a2c6a8d8ccdd 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -982,7 +982,7 @@ fn is_volatile_expression(e: &Expr) -> bool { e.apply(&mut |expr| { Ok(match expr { Expr::ScalarFunction(f) => match &f.func_def { - ScalarFunctionDefinition::BuiltIn(fun) + ScalarFunctionDefinition::BuiltIn { fun, .. } if fun.volatility() == Volatility::Volatile => { is_volatile = true; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index b4273e5cdf4ce..53a712d14a36f 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -346,7 +346,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::Wildcard { .. } | Expr::Placeholder(_) => false, Expr::ScalarFunction(ScalarFunction { func_def, .. }) => match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => { + ScalarFunctionDefinition::BuiltIn { fun, .. } => { Self::volatility_ok(fun.volatility()) } ScalarFunctionDefinition::UDF(fun) => { @@ -1204,28 +1204,41 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // log Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Log), + func_def: + ScalarFunctionDefinition::BuiltIn { + fun: BuiltinScalarFunction::Log, + .. + }, args, }) => simpl_log(args, <&S>::clone(&info))?, // power Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Power), + func_def: + ScalarFunctionDefinition::BuiltIn { + fun: BuiltinScalarFunction::Power, + .. + }, args, }) => simpl_power(args, <&S>::clone(&info))?, // concat Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Concat), + func_def: + ScalarFunctionDefinition::BuiltIn { + fun: BuiltinScalarFunction::Concat, + .. + }, args, }) => simpl_concat(args)?, // concat_ws Expr::ScalarFunction(ScalarFunction { func_def: - ScalarFunctionDefinition::BuiltIn( - BuiltinScalarFunction::ConcatWithSeparator, - ), + ScalarFunctionDefinition::BuiltIn { + fun: BuiltinScalarFunction::ConcatWithSeparator, + .. + }, args, }) => match &args[..] { [delimiter, vals @ ..] => simpl_concat_ws(delimiter, vals)?, diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index fa91a3ace2a25..e69207b6889a0 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -365,7 +365,11 @@ pub fn simpl_log(current_args: Vec, info: &dyn SimplifyInfo) -> Result Ok(args[1].clone()), _ => { @@ -405,7 +409,11 @@ pub fn simpl_power(current_args: Vec, info: &dyn SimplifyInfo) -> Result Ok(args[1].clone()), _ => Ok(Expr::ScalarFunction(ScalarFunction::new( diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 5501647da2c36..5c5cc8e36fa7d 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -349,7 +349,7 @@ pub fn create_physical_expr( } Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => { + ScalarFunctionDefinition::BuiltIn { fun, .. } => { let physical_args = args .iter() .map(|e| { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 0e59f8077c80c..c72909039575c 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -753,7 +753,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { )) } Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => { + ScalarFunctionDefinition::BuiltIn { fun, .. } => { let fun: protobuf::ScalarFunction = fun.try_into()?; let args: Vec = args .iter() diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 9d670fddc36d6..d80f84deb468e 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -26,7 +26,7 @@ use datafusion::logical_expr::{ }; use datafusion::logical_expr::{ expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, - ScalarFunctionDefinition, WindowFrameBound, WindowFrameUnits, + WindowFrameBound, WindowFrameUnits, }; use datafusion::prelude::JoinType; use datafusion::sql::TableReference; @@ -843,10 +843,9 @@ pub async fn from_substrait_rex( }; args.push(arg_expr?.as_ref().clone()); } - Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(fun), - args, - }))) + Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction::new( + fun, args, + )))) } ScalarFunctionType::Op(op) => { if f.arguments.len() != 2 { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index de033c8f267c0..95604e6d2db97 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -840,7 +840,8 @@ pub fn to_substrait_rex( return internal_err!("Function `Expr` with name should be resolved."); } - let function_anchor = _register_function(func_def.name(), extension_info); + let function_anchor = + _register_function(func_def.name().to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, From a6997ec6efecee09afd38d59286420d3c30bc87f Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Wed, 22 Nov 2023 13:21:48 -0800 Subject: [PATCH 5/6] fix fmt --- datafusion/optimizer/src/analyzer/type_coercion.rs | 3 +-- datafusion/substrait/src/logical_plan/consumer.rs | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 99142e0c49eff..6628e8961e263 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -782,8 +782,7 @@ mod test { use datafusion_expr::{ cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction, AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case, - ColumnarValue, ExprSchemable, Filter, Operator, - StateTypeFunction, Subquery, + ColumnarValue, ExprSchemable, Filter, Operator, StateTypeFunction, Subquery, }; use datafusion_expr::{ lit, diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index d80f84deb468e..5cb72adaca4df 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -25,8 +25,8 @@ use datafusion::logical_expr::{ BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator, }; use datafusion::logical_expr::{ - expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, - WindowFrameBound, WindowFrameUnits, + expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, WindowFrameBound, + WindowFrameUnits, }; use datafusion::prelude::JoinType; use datafusion::sql::TableReference; From 71bb177249ef9a20cb5057e903056b3549c02591 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Wed, 22 Nov 2023 13:41:50 -0800 Subject: [PATCH 6/6] fix after merge --- datafusion/sql/src/expr/value.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 0f086bca68191..f33e9e8ddf78d 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -24,8 +24,8 @@ use datafusion_common::{ }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr::{BinaryExpr, Placeholder}; -use datafusion_expr::BuiltinScalarFunction; use datafusion_expr::{lit, Expr, Operator}; +use datafusion_expr::{BuiltinScalarFunction, ScalarFunctionDefinition}; use log::debug; use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value}; use sqlparser::parser::ParserError::ParserError; @@ -143,8 +143,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Expr::Literal(_) => { values.push(value); } - Expr::ScalarFunction(ref scalar_function) => { - if scalar_function.fun == BuiltinScalarFunction::MakeArray { + Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn { fun, .. }, + .. + }) => { + if fun == BuiltinScalarFunction::MakeArray { values.push(value); } else { return not_impl_err!(