diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs index 4d8f06fb2844d..e1f36cdeee1c3 100644 --- a/datafusion/src/optimizer/constant_folding.rs +++ b/datafusion/src/optimizer/constant_folding.rs @@ -15,24 +15,27 @@ // specific language governing permissions and limitations // under the License. -//! Boolean comparison rule rewrites redundant comparison expression involving boolean literal into -//! unary expression. +//! This module contains an optimizer which performs boolean simplification and constant folding use std::sync::Arc; -use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; use arrow::datatypes::DataType; use crate::error::Result; use crate::execution::context::ExecutionProps; -use crate::logical_plan::{DFSchemaRef, Expr, ExprRewriter, LogicalPlan, Operator}; +use crate::logical_plan::{DFSchemaRef, Expr, LogicalPlan, Operator}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; -use crate::physical_plan::functions::BuiltinScalarFunction; +use crate::optimizer::utils::evaluate_const_expr_unchecked; +use crate::physical_plan::functions::{BuiltinScalarFunction, Volatility}; use crate::scalar::ScalarValue; -use arrow::compute::{kernels, DEFAULT_CAST_OPTIONS}; -/// Optimizer that simplifies comparison expressions involving boolean literals. +struct ConstantRewriter<'a> { + execution_props: &'a ExecutionProps, + schemas: Vec<&'a DFSchemaRef>, +} + +/// Optimizer that evaluates scalar expressions and simplifies comparison expressions involving boolean literals. 
/// /// Recursively go through all expressions and simplify the following cases: /// * `expr = true` and `expr != false` to `expr` when `expr` is of boolean type @@ -49,27 +52,24 @@ impl ConstantFolding { Self {} } } - impl OptimizerRule for ConstantFolding { fn optimize( &self, plan: &LogicalPlan, execution_props: &ExecutionProps, ) -> Result { - // We need to pass down the all schemas within the plan tree to `optimize_expr` in order to - // to evaluate expression types. For example, a projection plan's schema will only include - // projected columns. With just the projected schema, it's not possible to infer types for - // expressions that references non-projected columns within the same project plan or its - // children plans. let mut rewriter = ConstantRewriter { - schemas: plan.all_schemas(), execution_props, + schemas: plan.all_schemas(), }; match plan { LogicalPlan::Filter { predicate, input } => Ok(LogicalPlan::Filter { - predicate: predicate.clone().rewrite(&mut rewriter)?, - input: Arc::new(self.optimize(input, execution_props)?), + predicate: rewriter.rewrite(predicate.clone()), + input: match self.optimize(input, execution_props) { + Ok(plan) => Arc::new(plan), + _ => input.clone(), + }, }), // Rest: recurse into plan, apply optimization where possible LogicalPlan::Projection { .. 
} @@ -89,14 +89,17 @@ impl OptimizerRule for ConstantFolding { let inputs = plan.inputs(); let new_inputs = inputs .iter() - .map(|plan| self.optimize(plan, execution_props)) - .collect::>>()?; + .map(|plan| match self.optimize(plan, execution_props) { + Ok(opt_plan) => opt_plan, + _ => (*plan).clone(), + }) + .collect::>(); let expr = plan .expressions() .into_iter() - .map(|e| e.rewrite(&mut rewriter)) - .collect::>>()?; + .map(|e| rewriter.rewrite(e)) + .collect::>(); utils::from_plan(plan, &expr, &new_inputs) } @@ -107,16 +110,10 @@ impl OptimizerRule for ConstantFolding { } fn name(&self) -> &str { - "constant_folding" + "const_folder" } } -struct ConstantRewriter<'a> { - /// input schemas - schemas: Vec<&'a DFSchemaRef>, - execution_props: &'a ExecutionProps, -} - impl<'a> ConstantRewriter<'a> { fn is_boolean_type(&self, expr: &Expr) -> bool { for schema in &self.schemas { @@ -127,164 +124,365 @@ impl<'a> ConstantRewriter<'a> { false } -} -impl<'a> ExprRewriter for ConstantRewriter<'a> { - /// rewrite the expression simplifying any constant expressions - fn mutate(&mut self, expr: Expr) -> Result { - let new_expr = match expr { - Expr::BinaryExpr { left, op, right } => match op { - Operator::Eq => match (left.as_ref(), right.as_ref()) { - ( - Expr::Literal(ScalarValue::Boolean(l)), - Expr::Literal(ScalarValue::Boolean(r)), - ) => match (l, r) { - (Some(l), Some(r)) => { - Expr::Literal(ScalarValue::Boolean(Some(l == r))) - } - _ => Expr::Literal(ScalarValue::Boolean(None)), - }, - (Expr::Literal(ScalarValue::Boolean(b)), _) - if self.is_boolean_type(&right) => - { - match b { - Some(true) => *right, - Some(false) => Expr::Not(right), - None => Expr::Literal(ScalarValue::Boolean(None)), + pub fn rewrite(&mut self, mut expr: Expr) -> Expr { + let name = match &expr { + Expr::Alias(_, name) => Some(name.clone()), + _ => None, + }; + + let rewrite_root = self.rewrite_const_expr(&mut expr); + if rewrite_root { + match evaluate_const_expr_unchecked(&expr) { + Ok(s) 
=> expr = Expr::Literal(s), + Err(_) => return expr, + } + } + match name { + Some(name) => { + let existing_alias = match &expr { + Expr::Alias(_, new_alias) => Some(new_alias.as_str()), + _ => None, + }; + let apply_new_alias = match existing_alias { + Some(new) => *new != name, + None => true, + }; + if apply_new_alias { + expr = Expr::Alias(Box::new(expr), name); + } + expr + } + None => expr, + } + } + + ///Folds every constant expression in the list into a literal, in place; entries that are not constant or fail to evaluate are left unchanged. + fn const_fold_list_eager(&mut self, args: &mut Vec) { + for arg in args.iter_mut() { + if self.rewrite_const_expr(arg) { + if let Ok(s) = evaluate_const_expr_unchecked(arg) { + *arg = Expr::Literal(s); + } + } + } + } + ///Tests whether every expression in the list is constant; if so it returns true without evaluating them. + ///If some expressions are not constant then the constant ones are evaluated into literals in place and it returns false. + fn const_fold_list(&mut self, args: &mut Vec) -> bool { + let can_rewrite = args + .iter_mut() + .map(|e| self.rewrite_const_expr(e)) + .collect::>(); + if can_rewrite.iter().all(|f| *f) { + return true; + } else { + for (rewrite_expr, expr) in can_rewrite.iter().zip(args) { + if *rewrite_expr { + if let Ok(s) = evaluate_const_expr_unchecked(expr) { + *expr = Expr::Literal(s); + } + } + } + } + false + } + ///This attempts to simplify expressions of the form col(Boolean) = Boolean and col(Boolean) != Boolean + /// e.g. col(Boolean) = Some(true) -> col(Boolean). It also handles == and != between two boolean literals as + /// the binary operator physical expression currently doesn't handle them.
+ + fn binary_column_const_fold( + &mut self, + left: &mut Box, + op: &Operator, + right: &mut Box, + ) -> Option { + let expr = match (left.as_ref(), op, right.as_ref()) { + ( + Expr::Literal(ScalarValue::Boolean(l)), + Operator::Eq, + Expr::Literal(ScalarValue::Boolean(r)), + ) => { + let literal_bool = Expr::Literal(ScalarValue::Boolean(match (l, r) { + (Some(l), Some(r)) => Some(*l == *r), + _ => None, + })); + Some(literal_bool) + } + ( + Expr::Literal(ScalarValue::Boolean(l)), + Operator::NotEq, + Expr::Literal(ScalarValue::Boolean(r)), + ) => { + let literal_bool = match (l, r) { + (Some(l), Some(r)) => { + Expr::Literal(ScalarValue::Boolean(Some(l != r))) + } + _ => Expr::Literal(ScalarValue::Boolean(None)), + }; + Some(literal_bool) + } + (Expr::Literal(ScalarValue::Boolean(b)), Operator::Eq, col) + | (col, Operator::Eq, Expr::Literal(ScalarValue::Boolean(b))) + if self.is_boolean_type(col) => + { + Some(match b { + Some(true) => col.clone(), + Some(false) => Expr::Not(Box::new(col.clone())), + None => Expr::Literal(ScalarValue::Boolean(None)), + }) + } + (Expr::Literal(ScalarValue::Boolean(b)), Operator::NotEq, col) + | (col, Operator::NotEq, Expr::Literal(ScalarValue::Boolean(b))) + if self.is_boolean_type(col) => + { + Some(match b { + Some(true) => Expr::Not(Box::new(col.clone())), + Some(false) => col.clone(), + None => Expr::Literal(ScalarValue::Boolean(None)), + }) + } + _ => None, + }; + expr + } + + fn rewrite_const_expr(&mut self, expr: &mut Expr) -> bool { + let can_rewrite = match expr { + Expr::Alias(e, _) => self.rewrite_const_expr(e), + Expr::Column(_) => false, + Expr::ScalarVariable(_) => false, + Expr::Literal(_) => true, + Expr::BinaryExpr { left, op, right } => { + //Check if left and right are const, much like the Not Not optimization this is done first to make sure any + //Non-scalar execution optimizations, such as col("test") = NULL->false are performed first + let left_const = self.rewrite_const_expr(left); + let right_const = 
self.rewrite_const_expr(right); + let mut can_rewrite = match (left_const, right_const) { + (true, true) => true, + (false, false) => false, + (true, false) => { + if let Ok(s) = evaluate_const_expr_unchecked(left) { + *left.as_mut() = Expr::Literal(s); } + false } - (_, Expr::Literal(ScalarValue::Boolean(b))) - if self.is_boolean_type(&left) => - { - match b { - Some(true) => *left, - Some(false) => Expr::Not(left), - None => Expr::Literal(ScalarValue::Boolean(None)), + (false, true) => { + if let Ok(s) = evaluate_const_expr_unchecked(right) { + *right.as_mut() = Expr::Literal(s); } + false } - _ => Expr::BinaryExpr { - left, - op: Operator::Eq, - right, - }, - }, - Operator::NotEq => match (left.as_ref(), right.as_ref()) { - ( - Expr::Literal(ScalarValue::Boolean(l)), - Expr::Literal(ScalarValue::Boolean(r)), - ) => match (l, r) { - (Some(l), Some(r)) => { - Expr::Literal(ScalarValue::Boolean(Some(l != r))) + }; + + can_rewrite = match self.binary_column_const_fold(left, op, right) { + Some(e) => { + let expr: &mut Expr = expr; + *expr = e; + self.rewrite_const_expr(expr) + } + None => can_rewrite, + }; + + can_rewrite + } + + Expr::Not(e) => { + //Check if the expression can be rewritten. This may trigger simplifications such as col("b") = false -> NOT col("b") + //Then check if inner expression is Not and if so replace expr with the inner + let can_rewrite = self.rewrite_const_expr(e); + match e.as_mut() { + Expr::Not(inner) => { + let inner = std::mem::replace(inner.as_mut(), Expr::Wildcard); + *expr = inner; + self.rewrite_const_expr(expr) + } + _ => can_rewrite, + } + } + Expr::IsNotNull(e) => self.rewrite_const_expr(e), + Expr::IsNull(e) => self.rewrite_const_expr(e), + Expr::Negative(e) => self.rewrite_const_expr(e), + Expr::Between { + expr, low, high, .. 
+ } => match ( + self.rewrite_const_expr(expr), + self.rewrite_const_expr(low), + self.rewrite_const_expr(high), + ) { + (true, true, true) => true, + (expr_const, low_const, high_const) => { + if expr_const { + if let Ok(s) = evaluate_const_expr_unchecked(expr) { + let expr: &mut Expr = expr; + *expr = Expr::Literal(s); } - _ => Expr::Literal(ScalarValue::Boolean(None)), - }, - (Expr::Literal(ScalarValue::Boolean(b)), _) - if self.is_boolean_type(&right) => - { - match b { - Some(true) => Expr::Not(right), - Some(false) => *right, - None => Expr::Literal(ScalarValue::Boolean(None)), + } + if low_const { //fold the lower bound itself, not the subject expression + if let Ok(s) = evaluate_const_expr_unchecked(low) { + let low: &mut Expr = low; + *low = Expr::Literal(s); } } - (_, Expr::Literal(ScalarValue::Boolean(b))) - if self.is_boolean_type(&left) => - { - match b { - Some(true) => Expr::Not(left), - Some(false) => *left, - None => Expr::Literal(ScalarValue::Boolean(None)), + if high_const { //fold the upper bound itself, not the subject expression + if let Ok(s) = evaluate_const_expr_unchecked(high) { + let high: &mut Expr = high; + *high = Expr::Literal(s); } } - _ => Expr::BinaryExpr { - left, - op: Operator::NotEq, - right, - }, - }, - _ => Expr::BinaryExpr { left, op, right }, + false + } }, - Expr::Not(inner) => { - // Not(Not(expr)) --> expr - if let Expr::Not(negated_inner) = *inner { - *negated_inner - } else { - Expr::Not(inner) + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + if expr + .as_mut() + .map(|e| self.rewrite_const_expr(e)) + .unwrap_or(false) + { + let expr_inner = expr.as_mut().unwrap(); + if let Ok(s) = evaluate_const_expr_unchecked(expr_inner) { + *expr_inner.as_mut() = Expr::Literal(s); + } + } + + if else_expr + .as_mut() + .map(|e| self.rewrite_const_expr(e)) + .unwrap_or(false) + { + let expr_inner = else_expr.as_mut().unwrap(); + if let Ok(s) = evaluate_const_expr_unchecked(expr_inner) { + *expr_inner.as_mut() = Expr::Literal(s); + } } + + for (when, then) in when_then_expr { + let when: &mut Expr = when; + let then: &mut Expr
= then; + if self.rewrite_const_expr(when) { + if let Ok(s) = evaluate_const_expr_unchecked(when) { + *when = Expr::Literal(s); + } + } + if self.rewrite_const_expr(then) { + if let Ok(s) = evaluate_const_expr_unchecked(then) { + *then = Expr::Literal(s); + } + } + } + false + } + Expr::Cast { expr, .. } => self.rewrite_const_expr(expr), + Expr::TryCast { expr, .. } => self.rewrite_const_expr(expr), + Expr::Sort { expr, .. } => { + if self.rewrite_const_expr(expr) { + if let Ok(s) = evaluate_const_expr_unchecked(expr) { + let expr: &mut Expr = expr; + *expr = Expr::Literal(s); + } + } + false } Expr::ScalarFunction { fun: BuiltinScalarFunction::Now, .. - } => Expr::Literal(ScalarValue::TimestampNanosecond(Some( - self.execution_props - .query_execution_start_time - .timestamp_nanos(), - ))), - Expr::ScalarFunction { - fun: BuiltinScalarFunction::ToTimestamp, - args, } => { - if !args.is_empty() { - match &args[0] { - Expr::Literal(ScalarValue::Utf8(Some(val))) => { - match string_to_timestamp_nanos(val) { - Ok(timestamp) => Expr::Literal( - ScalarValue::TimestampNanosecond(Some(timestamp)), - ), - _ => Expr::ScalarFunction { - fun: BuiltinScalarFunction::ToTimestamp, - args, - }, - } + *expr = Expr::Literal(ScalarValue::TimestampNanosecond(Some( + self.execution_props + .query_execution_start_time + .timestamp_nanos(), + ))); + true + } + Expr::ScalarFunction { fun, args } => { + if args.is_empty() { + false + } else { + let volatility = fun.volatility(); + match volatility { + Volatility::Immutable => self.const_fold_list(args), + _ => { + self.const_fold_list_eager(args); + false } - _ => Expr::ScalarFunction { - fun: BuiltinScalarFunction::ToTimestamp, - args, - }, } + } + } + Expr::ScalarUDF { fun, args } => { + if args.is_empty() { + false } else { - Expr::ScalarFunction { - fun: BuiltinScalarFunction::ToTimestamp, - args, + let volatility = fun.volatility(); + match volatility { + Volatility::Immutable => self.const_fold_list(args), + _ => { + 
self.const_fold_list_eager(args); + false + } } } } - Expr::Cast { - expr: inner, - data_type, - } => match inner.as_ref() { - Expr::Literal(val) => { - let scalar_array = val.to_array(); - let cast_array = kernels::cast::cast_with_options( - &scalar_array, - &data_type, - &DEFAULT_CAST_OPTIONS, - )?; - let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; - Expr::Literal(cast_scalar) + Expr::AggregateFunction { args, .. } => { + self.const_fold_list_eager(args); + false + } + Expr::WindowFunction { + args, + partition_by, + order_by, + .. + } => { + self.const_fold_list_eager(args); + self.const_fold_list_eager(partition_by); + self.const_fold_list_eager(order_by); + false + } + Expr::AggregateUDF { args, .. } => { + self.const_fold_list_eager(args); + false + } + Expr::InList { expr, list, .. } => { + let expr_const = self.rewrite_const_expr(expr); + let list_literals = self.const_fold_list(list); + match (expr_const, list_literals) { + (true, true) => true, + + (false, false) => false, + (true, false) => { + if let Ok(s) = evaluate_const_expr_unchecked(expr) { + let expr: &mut Expr = expr; + *expr = Expr::Literal(s); + } + false + } + (false, true) => { + self.const_fold_list_eager(list); + false + } } - _ => Expr::Cast { - expr: inner, - data_type, - }, - }, - expr => { - // no rewrite possible - expr } + Expr::Wildcard => false, }; - Ok(new_expr) + can_rewrite } } #[cfg(test)] mod tests { use super::*; - use crate::logical_plan::{ - col, lit, max, min, DFField, DFSchema, LogicalPlanBuilder, + use crate::{ + logical_plan::{ + abs, col, create_udf, lit, max, min, DFField, DFSchema, LogicalPlanBuilder, + }, + physical_plan::{functions::make_scalar_function, udf::ScalarUDF}, }; + use arrow::array::{ArrayRef, Float64Array}; + use arrow::datatypes::{Field, Schema}; - use arrow::datatypes::*; use chrono::{DateTime, Utc}; fn test_table_scan() -> Result { @@ -293,6 +491,7 @@ mod tests { Field::new("b", DataType::Boolean, false), Field::new("c", 
DataType::Boolean, false), Field::new("d", DataType::UInt32, false), + Field::new("e", DataType::Float64, false), ]); LogicalPlanBuilder::scan_empty(Some("test"), &schema, None)?.build() } @@ -316,7 +515,7 @@ mod tests { }; assert_eq!( - (col("c2").not().not().not()).rewrite(&mut rewriter)?, + rewriter.rewrite(col("c2").not().not().not()), col("c2").not(), ); @@ -333,26 +532,27 @@ mod tests { // x = null is always null assert_eq!( - (lit(true).eq(lit(ScalarValue::Boolean(None)))).rewrite(&mut rewriter)?, + rewriter.rewrite(lit(true).eq(lit(ScalarValue::Boolean(None)))), lit(ScalarValue::Boolean(None)), ); // null != null is always null assert_eq!( - (lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None)))) - .rewrite(&mut rewriter)?, + rewriter.rewrite( + lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None))) + ), lit(ScalarValue::Boolean(None)), ); // x != null is always null assert_eq!( - (col("c2").not_eq(lit(ScalarValue::Boolean(None)))).rewrite(&mut rewriter)?, + rewriter.rewrite(col("c2").not_eq(lit(ScalarValue::Boolean(None)))), lit(ScalarValue::Boolean(None)), ); // null = x is always null assert_eq!( - (lit(ScalarValue::Boolean(None)).eq(col("c2"))).rewrite(&mut rewriter)?, + rewriter.rewrite(lit(ScalarValue::Boolean(None)).eq(col("c2"))), lit(ScalarValue::Boolean(None)), ); @@ -370,22 +570,16 @@ mod tests { assert_eq!(col("c2").get_type(&schema)?, DataType::Boolean); // true = ture -> true - assert_eq!((lit(true).eq(lit(true))).rewrite(&mut rewriter)?, lit(true),); + assert_eq!(rewriter.rewrite(lit(true).eq(lit(true))), lit(true),); // true = false -> false - assert_eq!( - (lit(true).eq(lit(false))).rewrite(&mut rewriter)?, - lit(false), - ); + assert_eq!(rewriter.rewrite(lit(true).eq(lit(false))), lit(false),); // c2 = true -> c2 - assert_eq!((col("c2").eq(lit(true))).rewrite(&mut rewriter)?, col("c2"),); + assert_eq!(rewriter.rewrite(col("c2").eq(lit(true))), col("c2"),); // c2 = false => !c2 - assert_eq!( - 
(col("c2").eq(lit(false))).rewrite(&mut rewriter)?, - col("c2").not(), - ); + assert_eq!(rewriter.rewrite(col("c2").eq(lit(false))), col("c2").not(),); Ok(()) } @@ -406,24 +600,21 @@ mod tests { // don't fold c1 = true assert_eq!( - (col("c1").eq(lit(true))).rewrite(&mut rewriter)?, + rewriter.rewrite(col("c1").eq(lit(true))), col("c1").eq(lit(true)), ); // don't fold c1 = false assert_eq!( - (col("c1").eq(lit(false))).rewrite(&mut rewriter)?, + rewriter.rewrite(col("c1").eq(lit(false))), col("c1").eq(lit(false)), ); // test constant operands - assert_eq!( - (lit(1).eq(lit(true))).rewrite(&mut rewriter)?, - lit(1).eq(lit(true)), - ); + assert_eq!(rewriter.rewrite(lit(1).eq(lit(true))), lit(1).eq(lit(true)),); assert_eq!( - (lit("a").eq(lit(false))).rewrite(&mut rewriter)?, + rewriter.rewrite(lit("a").eq(lit(false))), lit("a").eq(lit(false)), ); @@ -442,26 +633,17 @@ mod tests { // c2 != true -> !c2 assert_eq!( - (col("c2").not_eq(lit(true))).rewrite(&mut rewriter)?, + rewriter.rewrite(col("c2").not_eq(lit(true))), col("c2").not(), ); // c2 != false -> c2 - assert_eq!( - (col("c2").not_eq(lit(false))).rewrite(&mut rewriter)?, - col("c2"), - ); + assert_eq!(rewriter.rewrite(col("c2").not_eq(lit(false))), col("c2"),); // test constant - assert_eq!( - (lit(true).not_eq(lit(true))).rewrite(&mut rewriter)?, - lit(false), - ); + assert_eq!(rewriter.rewrite(lit(true).not_eq(lit(true))), lit(false),); - assert_eq!( - (lit(true).not_eq(lit(false))).rewrite(&mut rewriter)?, - lit(true), - ); + assert_eq!(rewriter.rewrite(lit(true).not_eq(lit(false))), lit(true),); Ok(()) } @@ -479,23 +661,23 @@ mod tests { assert_eq!(col("c1").get_type(&schema)?, DataType::Utf8); assert_eq!( - (col("c1").not_eq(lit(true))).rewrite(&mut rewriter)?, + rewriter.rewrite(col("c1").not_eq(lit(true))), col("c1").not_eq(lit(true)), ); assert_eq!( - (col("c1").not_eq(lit(false))).rewrite(&mut rewriter)?, + rewriter.rewrite(col("c1").not_eq(lit(false))), col("c1").not_eq(lit(false)), ); // test 
constants assert_eq!( - (lit(1).not_eq(lit(true))).rewrite(&mut rewriter)?, + rewriter.rewrite(lit(1).not_eq(lit(true))), lit(1).not_eq(lit(true)), ); assert_eq!( - (lit("a").not_eq(lit(false))).rewrite(&mut rewriter)?, + rewriter.rewrite(lit("a").not_eq(lit(false))), lit("a").not_eq(lit(false)), ); @@ -511,15 +693,14 @@ mod tests { }; assert_eq!( - (Box::new(Expr::Case { + rewriter.rewrite(Expr::Case { expr: None, when_then_expr: vec![( Box::new(col("c2").not_eq(lit(false))), Box::new(lit("ok").eq(lit(true))), )], else_expr: Some(Box::new(col("c2").eq(lit(true)))), - })) - .rewrite(&mut rewriter)?, + }), Expr::Case { expr: None, when_then_expr: vec![( @@ -623,7 +804,6 @@ mod tests { .filter(col("b").eq(lit(false)).not())? .project(vec![col("a")])? .build()?; - let expected = "\ Projection: #test.a\ \n Filter: #test.b\ @@ -767,7 +947,7 @@ mod tests { #[test] fn cast_expr_wrong_arg() { let table_scan = test_table_scan().unwrap(); - let proj = vec![Expr::Cast { + let proj = vec![Expr::TryCast { expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some("".to_string())))), data_type: DataType::Int32, }]; @@ -840,4 +1020,114 @@ mod tests { assert_eq!(actual, expected); } + + fn create_pow_with_volatilty(volatilty: Volatility) -> ScalarUDF { + let pow = |args: &[ArrayRef]| { + assert_eq!(args.len(), 2); + + let base = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + let exponent = &args[1] + .as_any() + .downcast_ref::() + .expect("cast failed"); + + assert_eq!(exponent.len(), base.len()); + + let array = base + .iter() + .zip(exponent.iter()) + .map(|(base, exponent)| match (base, exponent) { + (Some(base), Some(exponent)) => Some(base.powf(exponent)), + _ => None, + }) + .collect::(); + Ok(Arc::new(array) as ArrayRef) + }; + let pow = make_scalar_function(pow); + let name = match volatilty { + Volatility::Immutable => "pow", + Volatility::Stable => "pow_stable", + Volatility::Volatile => "pow_vol", + }; + create_udf( + name, + vec![DataType::Float64, 
DataType::Float64], + Arc::new(DataType::Float64), + volatilty, + pow, + ) + } + + #[test] + fn test_constant_evaluate_binop() -> Result<()> { + let scan = test_table_scan()?; + + //Trying to get non literal expression that has the value Boolean(NULL) so that the symbolic constant_folding can be tested + let proj = vec![Expr::TryCast { + expr: Box::new(lit("")), + data_type: DataType::Int32, + } + .eq(lit(0)) + .eq(col("a"))]; + let time = chrono::Utc::now(); + let plan = LogicalPlanBuilder::from(scan.clone()) + .project(proj) + .unwrap() + .build() + .unwrap(); + let actual = get_optimized_plan_formatted(&plan, &time); + let expected = "Projection: Boolean(NULL)\ + \n TableScan: test projection=None"; + assert_eq!(actual, expected); + + //Another test for boolean expression constant folding true = #test.a -> #test.a + let proj = vec![Expr::TryCast { + expr: Box::new(lit("0")), + data_type: DataType::Int32, + } + .eq(lit(0)) + .eq(col("a"))]; + let time = chrono::Utc::now(); + let plan = LogicalPlanBuilder::from(scan) + .project(proj) + .unwrap() + .build() + .unwrap(); + let actual = get_optimized_plan_formatted(&plan, &time); + let expected = "Projection: #test.a\ + \n TableScan: test projection=None"; + assert_eq!(actual, expected); + + Ok(()) + } + + //Testing that immutable scalar UDFs are inlined, stable or volatile UDFs are not inlined, and the arguments to a stable or volatile UDF are still folded + #[test] + fn test_udf_inlining() -> Result<()> { + let scan = test_table_scan()?; + let pow_immut = create_pow_with_volatilty(Volatility::Immutable); + let pow_stab = create_pow_with_volatilty(Volatility::Stable); + let pow_vol = create_pow_with_volatilty(Volatility::Volatile); + let pow_res_2 = vec![abs(lit(1.0) - lit(3)), lit(2)]; + let proj = vec![ + pow_immut.call(pow_res_2.clone()) * col("e").alias("constant"), + pow_stab.call(pow_res_2.clone()) * col("e").alias("stable"), + pow_vol.call(pow_res_2) * col("e"), + ]; + let time = chrono::Utc::now(); + let plan =
LogicalPlanBuilder::from(scan) + .project(proj) + .unwrap() + .build() + .unwrap(); + let actual = get_optimized_plan_formatted(&plan, &time); + let expected = "Projection: Float64(4) * #test.e AS constant, pow_stable(Float64(2), Int32(2)) * #test.e AS stable, pow_vol(Float64(2), Int32(2)) * #test.e\ + \n TableScan: test projection=None".to_string(); + + assert_eq!(actual, expected); + Ok(()) + } } diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 6e64bf39b2e2d..2ccd903aa3375 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -23,6 +23,7 @@ use crate::logical_plan::{ build_join_schema, Column, DFSchemaRef, Expr, LogicalPlan, LogicalPlanBuilder, Operator, Partitioning, Recursion, }; +use crate::physical_plan::functions::Volatility; use crate::prelude::lit; use crate::scalar::ScalarValue; use crate::{ @@ -31,6 +32,10 @@ use crate::{ }; use std::{collections::HashSet, sync::Arc}; +use crate::logical_plan::DFSchema; +use arrow::datatypes::{DataType, Field}; +use arrow::record_batch::RecordBatch; + const CASE_EXPR_MARKER: &str = "__DATAFUSION_CASE_EXPR__"; const CASE_ELSE_MARKER: &str = "__DATAFUSION_CASE_ELSE__"; const WINDOW_PARTITION_MARKER: &str = "__DATAFUSION_WINDOW_PARTITION__"; @@ -468,6 +473,117 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { } } +///Tests an expression to see if it contains only literal expressions such as 3+4.5 and immutable scalar builtins or UDFs. +pub fn expr_is_const(expr: &Expr) -> bool { + match expr { + Expr::Column(_) + | Expr::ScalarVariable(_) + | Expr::AggregateFunction { .. } + | Expr::WindowFunction { .. } + | Expr::AggregateUDF { .. } + | Expr::Sort { .. } + | Expr::Wildcard => false, + + Expr::Literal(_) => true, + + Expr::Alias(child, _) + | Expr::Not(child) + | Expr::IsNotNull(child) + | Expr::IsNull(child) + | Expr::Negative(child) + | Expr::Cast { expr: child, .. } + | Expr::TryCast { expr: child, .. 
} => expr_is_const(child), + + Expr::ScalarFunction { fun, args } => match fun.volatility() { + Volatility::Immutable => args.iter().all(|arg| expr_is_const(arg)), + Volatility::Stable | Volatility::Volatile => false, + }, + Expr::BinaryExpr { left, right, .. } => { + expr_is_const(left) && expr_is_const(right) + } + Expr::Between { + expr, low, high, .. + } => expr_is_const(expr) && expr_is_const(low) && expr_is_const(high), + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + let expr_constant = expr.as_ref().map(|e| expr_is_const(e)).unwrap_or(true); + let else_constant = + else_expr.as_ref().map(|e| expr_is_const(e)).unwrap_or(true); + let when_then_constant = when_then_expr + .iter() + .all(|(w, th)| expr_is_const(w) && expr_is_const(th)); + expr_constant && else_constant && when_then_constant + } + + Expr::ScalarUDF { fun, args } => match fun.volatility() { + Volatility::Immutable => args.iter().all(|arg| expr_is_const(arg)), + Volatility::Stable | Volatility::Volatile => false, + }, + + Expr::InList { expr, list, .. } => { + expr_is_const(expr) && list.iter().all(|e| expr_is_const(e)) + } + } +} + +///Evaluates an expression after checking with `expr_is_const` that it is constant. If a non-literal expression or non-immutable function is found then it returns an execution error instead of evaluating. +pub fn evalute_const_expr(expr: &Expr) -> Result { + if !expr_is_const(expr) { + return Err(DataFusionError::Execution("The expression was not constant and could not be evaluated. This means it contained stable or volatile functions, aggregates, or column expressions.".to_owned())); + } + evaluate_const_expr_unchecked(expr) +} + +///Evaluates an expression. Note that this will return incorrect results if the expression contains function calls which change return values call to call, such as Now().
+pub fn evaluate_const_expr_unchecked(expr: &Expr) -> Result { + if let Expr::Literal(s) = expr { + return Ok(s.clone()); + } + //The dummy column name shouldn't really matter as only scalar expressions will be evaluated; a one-row batch holding a single null Float64 column is built just to drive the physical expression + static DUMMY_COL_NAME: &str = "."; + let dummy_df_schema = DFSchema::empty(); + let dummy_input_schema = arrow::datatypes::Schema::new(vec![Field::new( + DUMMY_COL_NAME, + DataType::Float64, + true, + )]); + let ctx_state = crate::execution::context::ExecutionContextState::new(); + + let planner = crate::physical_plan::planner::DefaultPhysicalPlanner::default(); + let phys_expr = planner.create_physical_expr( + expr, + &dummy_df_schema, + &dummy_input_schema, + &ctx_state, + )?; + let col = { + let mut builder = arrow::array::Float64Array::builder(1); + builder.append_null()?; + builder.finish() + }; + let record_batch = RecordBatch::try_new( + std::sync::Arc::new(dummy_input_schema), + vec![std::sync::Arc::new(col)], + )?; + let col_val = phys_expr.evaluate(&record_batch)?; + match col_val { + crate::physical_plan::ColumnarValue::Array(a) => { + if a.len() != 1 { + Err(DataFusionError::Execution(format!( + "Could not evaluate the expression, found a result of length {}", + a.len() + ))) + } else { + Ok(ScalarValue::try_from_array(&a, 0)?)
+ } + } + crate::physical_plan::ColumnarValue::Scalar(s) => Ok(s), + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion/src/physical_plan/udf.rs index 39265c0eb5efa..824eb672d865b 100644 --- a/datafusion/src/physical_plan/udf.rs +++ b/datafusion/src/physical_plan/udf.rs @@ -25,6 +25,7 @@ use arrow::datatypes::Schema; use crate::error::Result; use crate::{logical_plan::Expr, physical_plan::PhysicalExpr}; +use super::functions::Volatility; use super::{ functions::{ ReturnTypeFunction, ScalarFunctionExpr, ScalarFunctionImplementation, Signature, @@ -95,6 +96,10 @@ impl ScalarUDF { fun: fun.clone(), } } + ///Returns the volatility declared in this UDF's signature + pub fn volatility(&self) -> Volatility { + self.signature.volatility + } /// creates a logical expression with a call of the UDF /// This utility allows using the UDF without requiring access to the registry.