From 1b4af9adf5977199e0f93dfd4474bc3e29b4dd14 Mon Sep 17 00:00:00 2001 From: patrick Date: Wed, 13 Oct 2021 22:33:00 -0600 Subject: [PATCH 1/6] Added an evaluate function to evaluate literal expressions and extended constant folding optimizer to allow for more general folding --- datafusion/src/optimizer/constant_folding.rs | 714 ++++++++++++++----- datafusion/src/physical_plan/udf.rs | 5 + 2 files changed, 534 insertions(+), 185 deletions(-) diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs index 4d8f06fb2844d..e294c3bcaf489 100644 --- a/datafusion/src/optimizer/constant_folding.rs +++ b/datafusion/src/optimizer/constant_folding.rs @@ -15,24 +15,34 @@ // specific language governing permissions and limitations // under the License. -//! Boolean comparison rule rewrites redundant comparison expression involving boolean literal into -//! unary expression. +//! This module contains an optimizer which performs boolean simplification and constant folding use std::sync::Arc; -use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; -use arrow::datatypes::DataType; +use arrow::array::Float64Array; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; -use crate::error::Result; -use crate::execution::context::ExecutionProps; -use crate::logical_plan::{DFSchemaRef, Expr, ExprRewriter, LogicalPlan, Operator}; +use crate::error::{DataFusionError, Result}; +use crate::execution::context::{ExecutionContextState, ExecutionProps}; +use crate::logical_plan::{ + DFSchema, DFSchemaRef, Expr, LogicalPlan, Operator, +}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; -use crate::physical_plan::functions::BuiltinScalarFunction; +use crate::physical_plan::functions::{BuiltinScalarFunction, Volatility}; +use crate::physical_plan::planner::DefaultPhysicalPlanner; use crate::scalar::ScalarValue; -use arrow::compute::{kernels, DEFAULT_CAST_OPTIONS}; -/// Optimizer that simplifies comparison expressions involving boolean literals. + + + +struct ConstantRewriter<'a> { + execution_props: &'a ExecutionProps, + schemas: Vec<&'a DFSchemaRef>, +} + +/// Optimizer that evaluates scalar expressions and simplifies comparison expressions involving boolean literals. /// /// Recursively go through all expressions and simplify the following cases: /// * `expr = true` and `expr != false` to `expr` when `expr` is of boolean type @@ -41,35 +51,35 @@ use arrow::compute::{kernels, DEFAULT_CAST_OPTIONS}; /// * `false = true` and `true = false` to `false` /// * `!!expr` to `expr` /// * `expr = null` and `expr != null` to `null` -pub struct ConstantFolding {} +pub struct ConstantFolding{} -impl ConstantFolding { +impl ConstantFolding{ #[allow(missing_docs)] pub fn new() -> Self { Self {} } } - -impl OptimizerRule for ConstantFolding { +impl OptimizerRule for ConstantFolding{ fn optimize( &self, plan: &LogicalPlan, execution_props: &ExecutionProps, ) -> Result { - // We need to pass down the all schemas within the plan tree to `optimize_expr` in order to - // to evaluate expression types. For example, a projection plan's schema will only include - // projected columns. With just the projected schema, it's not possible to infer types for - // expressions that references non-projected columns within the same project plan or its - // children plans. - let mut rewriter = ConstantRewriter { + let mut rewriter = ConstantRewriter{ + execution_props: &execution_props, schemas: plan.all_schemas(), - execution_props, }; - + match plan { LogicalPlan::Filter { predicate, input } => Ok(LogicalPlan::Filter { - predicate: predicate.clone().rewrite(&mut rewriter)?, - input: Arc::new(self.optimize(input, execution_props)?), + predicate: match rewriter.rewrite(predicate.clone()){ + Ok(e)=> e, + _ => predicate.clone() + }, + input: match self.optimize(input, execution_props){ + Ok(plan) => Arc::new(plan.clone()), + _ => input.clone() + }, }), // Rest: recurse into plan, apply optimization where possible LogicalPlan::Projection { .. } @@ -89,14 +99,20 @@ impl OptimizerRule for ConstantFolding { let inputs = plan.inputs(); let new_inputs = inputs .iter() - .map(|plan| self.optimize(plan, execution_props)) - .collect::>>()?; + .map(|plan| match self.optimize(plan, execution_props){ + Ok(opt_plan) => opt_plan, + _=> (*plan).clone(), + }) + .collect::>(); let expr = plan .expressions() .into_iter() - .map(|e| e.rewrite(&mut rewriter)) - .collect::>>()?; + .map(|e| match rewriter.rewrite(e.clone()){ + Ok(expr) => expr, + Err(_) => e, + }) + .collect::>(); utils::from_plan(plan, &expr, &new_inputs) } @@ -107,17 +123,46 @@ impl OptimizerRule for ConstantFolding { } fn name(&self) -> &str { - "constant_folding" + "const_folder" } } -struct ConstantRewriter<'a> { - /// input schemas - schemas: Vec<&'a DFSchemaRef>, - execution_props: &'a ExecutionProps, +///Evaluate calculates the value of scalar expressions. This function may panic if columns are present within the expression +pub fn evaluate(expr: &Expr, exec_props: &ExecutionProps) -> Result { + if let Expr::Literal(s) = expr{ + return Ok(s.clone()); + } + //The dummy column name was chosen as to not interfere with any possible columns names in a normal schema. Unsure if this is needed as the schema of the columns + //is never used + static DUMMY_COL_NAME : &'static str = "."; + let dummy_df_schema = DFSchema::empty(); + let dummy_input_schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Float64, true)]); + let mut ctx_state = ExecutionContextState::new(); + ctx_state.execution_props = exec_props.clone(); + let planner = DefaultPhysicalPlanner::default(); + let phys_expr = planner.create_physical_expr(expr, &dummy_df_schema, &dummy_input_schema, &ctx_state)?; + let col = { + let mut builder = Float64Array::builder(1); + builder.append_null()?; + builder.finish() + }; + let record_batch = RecordBatch::try_new(Arc::new(dummy_input_schema), vec![Arc::new(col)])?; + let col_val = phys_expr.evaluate(&record_batch)?; + match col_val{ + crate::physical_plan::ColumnarValue::Array(a) => { + if a.len() != 1{ + Err(DataFusionError::Execution(format!("Could not evaluate the expressison, found a result of length {}", a.len()))) + }else{ + Ok(ScalarValue::try_from_array(&a,0)?) + } + }, + crate::physical_plan::ColumnarValue::Scalar(s) => Ok(s), + } + } impl<'a> ConstantRewriter<'a> { + fn is_boolean_type(&self, expr: &Expr) -> bool { for schema in &self.schemas { if let Ok(DataType::Boolean) = expr.get_type(schema) { @@ -127,164 +172,356 @@ impl<'a> ConstantRewriter<'a> { false } -} -impl<'a> ExprRewriter for ConstantRewriter<'a> { - /// rewrite the expression simplifying any constant expressions - fn mutate(&mut self, expr: Expr) -> Result { - let new_expr = match expr { - Expr::BinaryExpr { left, op, right } => match op { - Operator::Eq => match (left.as_ref(), right.as_ref()) { - ( - Expr::Literal(ScalarValue::Boolean(l)), - Expr::Literal(ScalarValue::Boolean(r)), - ) => match (l, r) { - (Some(l), Some(r)) => { - Expr::Literal(ScalarValue::Boolean(Some(l == r))) + + + + pub fn rewrite(&mut self, mut expr: Expr) -> Result { + let name = match &expr{ + Expr::Alias(_, name )=>Some(name.clone()), + _ => None, + }; + + let rewrite_root = self.rewrite_const_expr(&mut expr); + let expr = if rewrite_root{ + match evaluate(&expr, self.execution_props){ + Ok(s) => Expr::Literal(s), + Err(e) => { + println!("Could not rewrite: {}", e); + expr + }, + } + }else{ + expr + }; + Ok(match name{ + Some(name) => expr.alias(&name), + None => expr, + }) + } + + fn replace_expr(&self, expr: &mut Box, mut replacement: Expr) { + std::mem::swap(&mut replacement, expr); + } + + fn const_fold_list_eager(&mut self, args: &mut Vec){ + for arg in args.iter_mut(){ + if self.rewrite_const_expr(arg){ + match evaluate(arg,self.execution_props){ + Ok(s) => *arg = Expr::Literal(s), + _ => () + } + } + } + } + + fn const_fold_list(&mut self, args: &mut Vec)->bool{ + let can_rewrite= args.iter_mut().map(|e| self.rewrite_const_expr(e)).collect::>(); + if can_rewrite.iter().all(|f| *f){ + return true; + }else{ + for (rewrite_expr, expr) in can_rewrite.iter().zip(args){ + if *rewrite_expr{ + match evaluate(expr,self.execution_props){ + Ok(s) =>*expr = Expr::Literal(s), + _ =>(), + } + } + } + } + false + } + ///This attempts to simplify expressions of the form col(Boolean) = Boolean and col(Boolean) != Boolean + /// e.g. col(Boolean) = Some(true) -> col(Boolean) + + fn binary_column_const_fold(&mut self, left: &mut Box, op: &Operator, right: &mut Box)->Option{ + let expr = match (left.as_ref(), op, right.as_ref()){ + + (Expr::Literal(ScalarValue::Boolean(l)), Operator::Eq, Expr::Literal(ScalarValue::Boolean(r)))=>{ + let literal_bool = Expr::Literal(ScalarValue::Boolean( + match (l, r){ + + (Some(l), Some(r)) => Some(*l == *r), + _ => None, + })); + Some(literal_bool) + } + (Expr::Literal(ScalarValue::Boolean(l)), Operator::NotEq, Expr::Literal(ScalarValue::Boolean(r)))=>{ + let literal_bool = match (l,r){ + (Some(l), Some(r)) => { + Expr::Literal(ScalarValue::Boolean(Some(l != r))) + } + _ => Expr::Literal(ScalarValue::Boolean(None)), + }; + Some(literal_bool) + } + (Expr::Literal(ScalarValue::Boolean(b)), Operator::Eq, col) | + (col, Operator::Eq, Expr::Literal(ScalarValue::Boolean(b))) if self.is_boolean_type(&col) =>{ + Some(match b{ + Some(true)=>col.clone(), + Some(false) =>Expr::Not(Box::new(col.clone())), + None => Expr::Literal(ScalarValue::Boolean(None)), + }) + }, + (Expr::Literal(ScalarValue::Boolean(b)), Operator::NotEq, col) | + (col, Operator::NotEq, Expr::Literal(ScalarValue::Boolean(b))) if self.is_boolean_type(&col) =>{ + Some(match b{ + Some(true)=>Expr::Not(Box::new(col.clone())), + Some(false) => col.clone(), + None => Expr::Literal(ScalarValue::Boolean(None)), + }) + } + _ => None, + }; + expr + } + + + fn rewrite_const_expr(&mut self, expr: &mut Expr) -> bool { + + let can_rewrite = match expr { + Expr::Alias(e, _) => self.rewrite_const_expr(e), + Expr::Column(_) => false, + Expr::ScalarVariable(_) => false, + Expr::Literal(_) => true, + Expr::BinaryExpr { left, op, right } => { + //Check if left and right are const, much like the Not Not optimization this is done first to make sure any + //Non-scalar execution optimizations, such as col("test") = NULL->false are performed first + let left_const = self.rewrite_const_expr(left); + let right_const = self.rewrite_const_expr(right); + let mut can_rewrite = match (left_const, right_const) { + (true, true) => true, + (false, false) => false, + (true, false) => { + match evaluate(&left,self.execution_props) { + Ok(s) => { + let left: &mut Expr = left; + *left = Expr::Literal(s); + } + Err(_) => (), } - _ => Expr::Literal(ScalarValue::Boolean(None)), - }, - (Expr::Literal(ScalarValue::Boolean(b)), _) - if self.is_boolean_type(&right) => - { - match b { - Some(true) => *right, - Some(false) => Expr::Not(right), - None => Expr::Literal(ScalarValue::Boolean(None)), + false + } + (false, true) => { + match evaluate(&right,self.execution_props) { + Ok(s) => { + let right: &mut Expr = right; + *right = Expr::Literal(s); + } + Err(_) => (), + } + false + } + }; + + + can_rewrite= match self.binary_column_const_fold(left, op, right){ + Some(e) =>{ + let expr: &mut Expr = expr; + *expr = e; + self.rewrite_const_expr(expr) + } + None => can_rewrite + }; + + can_rewrite + + } + + Expr::Not(e) => { + //Check if the expression can be rewritten. This may trigger simplifications such as col("b") = false -> NOT col("b") + //Then check if inner expression is Not and if so replace expr with the inner + let can_rewrite = self.rewrite_const_expr(e); + match e.as_mut(){ + Expr::Not(inner)=>{ + let inner = std::mem::replace(inner.as_mut(),Expr::Wildcard); + *expr = inner; + self.rewrite_const_expr(expr) + } + _=>can_rewrite + } + }, + Expr::IsNotNull(e) => self.rewrite_const_expr(e), + Expr::IsNull(e) => self.rewrite_const_expr(e), + Expr::Negative(e) => self.rewrite_const_expr(e), + Expr::Between { + expr, + low, + high, .. + } => match ( + self.rewrite_const_expr(expr), + self.rewrite_const_expr(low), + self.rewrite_const_expr(high), + ) { + (true, true, true) => true, + (expr_const, low_const, high_const) => { + if expr_const { + if let Ok(s) = evaluate(expr,self.execution_props){ + self.replace_expr(expr, Expr::Literal(s)); } } - (_, Expr::Literal(ScalarValue::Boolean(b))) - if self.is_boolean_type(&left) => - { - match b { - Some(true) => *left, - Some(false) => Expr::Not(left), - None => Expr::Literal(ScalarValue::Boolean(None)), + if low_const { + if let Ok(s) = evaluate(expr,self.execution_props){ + self.replace_expr(low, Expr::Literal(s)); } } - _ => Expr::BinaryExpr { - left, - op: Operator::Eq, - right, - }, - }, - Operator::NotEq => match (left.as_ref(), right.as_ref()) { - ( - Expr::Literal(ScalarValue::Boolean(l)), - Expr::Literal(ScalarValue::Boolean(r)), - ) => match (l, r) { - (Some(l), Some(r)) => { - Expr::Literal(ScalarValue::Boolean(Some(l != r))) + if high_const { + if let Ok(s) = evaluate(expr,self.execution_props){ + self.replace_expr(high, Expr::Literal(s)); } - _ => Expr::Literal(ScalarValue::Boolean(None)), - }, - (Expr::Literal(ScalarValue::Boolean(b)), _) - if self.is_boolean_type(&right) => - { - match b { - Some(true) => Expr::Not(right), - Some(false) => *right, - None => Expr::Literal(ScalarValue::Boolean(None)), + } + false + } + }, + Expr::Case { expr, when_then_expr, else_expr } => { + if expr.as_mut().map(|e| self.rewrite_const_expr(e)).unwrap_or(false){ + let expr_inner = expr.as_mut().unwrap(); + match evaluate(expr_inner, self.execution_props){ + Ok(s) => *expr_inner.as_mut() = Expr::Literal(s), + Err(_) => (), + } + } + + if else_expr.as_mut().map(|e| self.rewrite_const_expr(e)).unwrap_or(false){ + let expr_inner = else_expr.as_mut().unwrap(); + match evaluate(expr_inner, self.execution_props){ + Ok(s) => *expr_inner.as_mut() = Expr::Literal(s), + Err(_) => (), + } + } + + for (when, then) in when_then_expr{ + let when: &mut Expr = when; + let then : &mut Expr = then; + if self.rewrite_const_expr(when){ + match evaluate(when, self.execution_props){ + Ok(s) => *when = Expr::Literal(s), + _ =>(), } } - (_, Expr::Literal(ScalarValue::Boolean(b))) - if self.is_boolean_type(&left) => - { - match b { - Some(true) => Expr::Not(left), - Some(false) => *left, - None => Expr::Literal(ScalarValue::Boolean(None)), + if self.rewrite_const_expr(then){ + match evaluate(then, self.execution_props){ + Ok(s) => *then = Expr::Literal(s), + Err(_) => (), } } - _ => Expr::BinaryExpr { - left, - op: Operator::NotEq, - right, - }, - }, - _ => Expr::BinaryExpr { left, op, right }, + } + false }, - Expr::Not(inner) => { - // Not(Not(expr)) --> expr - if let Expr::Not(negated_inner) = *inner { - *negated_inner - } else { - Expr::Not(inner) + Expr::Cast { expr, .. } => self.rewrite_const_expr(expr), + Expr::TryCast { expr, .. } => self.rewrite_const_expr(expr), + Expr::Sort { expr, .. } => { + if self.rewrite_const_expr(expr) { + match evaluate(expr,self.execution_props) { + Ok(s) => { + let expr: &mut Expr = expr; + *expr = Expr::Literal(s); + }, + Err(_) => (), + } } + false } Expr::ScalarFunction { fun: BuiltinScalarFunction::Now, .. - } => Expr::Literal(ScalarValue::TimestampNanosecond(Some( + } => { + *expr= Expr::Literal(ScalarValue::TimestampNanosecond(Some( self.execution_props .query_execution_start_time .timestamp_nanos(), - ))), - Expr::ScalarFunction { - fun: BuiltinScalarFunction::ToTimestamp, + ))); + true + }, + Expr::ScalarFunction { fun, args } => { + if args.is_empty(){ + false + }else{ + let volatility = fun.volatility(); + match volatility{ + Volatility::Immutable => self.const_fold_list(args), + _ =>{ + self.const_fold_list_eager(args); + false + } + } + } + }, + Expr::ScalarUDF { fun, args } => { + if args.is_empty(){ + false + }else{ + let volatility = fun.volatility(); + match volatility{ + Volatility::Immutable => self.const_fold_list(args), + _ =>{ + self.const_fold_list_eager(args); + false + } + } + } + }, + Expr::AggregateFunction { args, + .. } => { - if !args.is_empty() { - match &args[0] { - Expr::Literal(ScalarValue::Utf8(Some(val))) => { - match string_to_timestamp_nanos(val) { - Ok(timestamp) => Expr::Literal( - ScalarValue::TimestampNanosecond(Some(timestamp)), - ), - _ => Expr::ScalarFunction { - fun: BuiltinScalarFunction::ToTimestamp, - args, - }, - } + self.const_fold_list_eager(args); + false + }, + Expr::WindowFunction { + args, + partition_by, + order_by, + .. + } => { + self.const_fold_list_eager(args); + self.const_fold_list_eager(partition_by); + self.const_fold_list_eager(order_by); + false + }, + Expr::AggregateUDF { args,.. } => { + self.const_fold_list_eager(args); + false + }, + Expr::InList { + expr, + list, + .. + } => { + let expr_const = self.rewrite_const_expr(expr); + let list_literals = self.const_fold_list(list); + match (expr_const, list_literals){ + (true, true)=> true, + + (false, false) => false, + (true, false)=> { + if let Ok(s) = evaluate(expr,self.execution_props) { + let expr: &mut Expr = expr; + *expr = Expr::Literal(s); } - _ => Expr::ScalarFunction { - fun: BuiltinScalarFunction::ToTimestamp, - args, - }, - } - } else { - Expr::ScalarFunction { - fun: BuiltinScalarFunction::ToTimestamp, - args, + false } + (false, true)=>{ + self.const_fold_list_eager(list); + false + }, } - } - Expr::Cast { - expr: inner, - data_type, - } => match inner.as_ref() { - Expr::Literal(val) => { - let scalar_array = val.to_array(); - let cast_array = kernels::cast::cast_with_options( - &scalar_array, - &data_type, - &DEFAULT_CAST_OPTIONS, - )?; - let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; - Expr::Literal(cast_scalar) - } - _ => Expr::Cast { - expr: inner, - data_type, - }, }, - expr => { - // no rewrite possible - expr - } + Expr::Wildcard => false, }; - Ok(new_expr) + println!("Can rewrite the expr[{}]: {}", can_rewrite, expr); + can_rewrite } } + + #[cfg(test)] mod tests { use super::*; - use crate::logical_plan::{ - col, lit, max, min, DFField, DFSchema, LogicalPlanBuilder, - }; + use arrow::array::{Float64Array, ArrayRef}; + use crate::{logical_plan::{DFField, DFSchema, LogicalPlanBuilder, col, create_udf, abs, lit, max, min}, physical_plan::{functions::make_scalar_function, udf::ScalarUDF}}; - use arrow::datatypes::*; use chrono::{DateTime, Utc}; fn test_table_scan() -> Result { @@ -293,6 +530,7 @@ mod tests { Field::new("b", DataType::Boolean, false), Field::new("c", DataType::Boolean, false), Field::new("d", DataType::UInt32, false), + Field::new("e", DataType::Float64, false) ]); LogicalPlanBuilder::scan_empty(Some("test"), &schema, None)?.build() } @@ -316,7 +554,7 @@ mod tests { }; assert_eq!( - (col("c2").not().not().not()).rewrite(&mut rewriter)?, + rewriter.rewrite(col("c2").not().not().not())?, col("c2").not(), ); @@ -333,26 +571,25 @@ mod tests { // x = null is always null assert_eq!( - (lit(true).eq(lit(ScalarValue::Boolean(None)))).rewrite(&mut rewriter)?, + rewriter.rewrite(lit(true).eq(lit(ScalarValue::Boolean(None))))?, lit(ScalarValue::Boolean(None)), ); // null != null is always null assert_eq!( - (lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None)))) - .rewrite(&mut rewriter)?, + rewriter.rewrite(lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None))))?, lit(ScalarValue::Boolean(None)), ); // x != null is always null assert_eq!( - (col("c2").not_eq(lit(ScalarValue::Boolean(None)))).rewrite(&mut rewriter)?, + rewriter.rewrite(col("c2").not_eq(lit(ScalarValue::Boolean(None))))?, lit(ScalarValue::Boolean(None)), ); // null = x is always null assert_eq!( - (lit(ScalarValue::Boolean(None)).eq(col("c2"))).rewrite(&mut rewriter)?, + rewriter.rewrite(lit(ScalarValue::Boolean(None)).eq(col("c2")))?, lit(ScalarValue::Boolean(None)), ); @@ -370,20 +607,20 @@ mod tests { assert_eq!(col("c2").get_type(&schema)?, DataType::Boolean); // true = ture -> true - assert_eq!((lit(true).eq(lit(true))).rewrite(&mut rewriter)?, lit(true),); + assert_eq!(rewriter.rewrite(lit(true).eq(lit(true)))?, lit(true),); // true = false -> false assert_eq!( - (lit(true).eq(lit(false))).rewrite(&mut rewriter)?, + rewriter.rewrite(lit(true).eq(lit(false)))?, lit(false), ); // c2 = true -> c2 - assert_eq!((col("c2").eq(lit(true))).rewrite(&mut rewriter)?, col("c2"),); + assert_eq!(rewriter.rewrite(col("c2").eq(lit(true)))?, col("c2"),); // c2 = false => !c2 assert_eq!( - (col("c2").eq(lit(false))).rewrite(&mut rewriter)?, + rewriter.rewrite(col("c2").eq(lit(false)))?, col("c2").not(), ); @@ -406,24 +643,24 @@ mod tests { // don't fold c1 = true assert_eq!( - (col("c1").eq(lit(true))).rewrite(&mut rewriter)?, + rewriter.rewrite(col("c1").eq(lit(true)))?, col("c1").eq(lit(true)), ); // don't fold c1 = false assert_eq!( - (col("c1").eq(lit(false))).rewrite(&mut rewriter)?, + rewriter.rewrite(col("c1").eq(lit(false)))?, col("c1").eq(lit(false)), ); // test constant operands assert_eq!( - (lit(1).eq(lit(true))).rewrite(&mut rewriter)?, + rewriter.rewrite(lit(1).eq(lit(true)))?, lit(1).eq(lit(true)), ); assert_eq!( - (lit("a").eq(lit(false))).rewrite(&mut rewriter)?, + rewriter.rewrite(lit("a").eq(lit(false)))?, lit("a").eq(lit(false)), ); @@ -442,24 +679,24 @@ mod tests { // c2 != true -> !c2 assert_eq!( - (col("c2").not_eq(lit(true))).rewrite(&mut rewriter)?, + rewriter.rewrite(col("c2").not_eq(lit(true)))?, col("c2").not(), ); // c2 != false -> c2 assert_eq!( - (col("c2").not_eq(lit(false))).rewrite(&mut rewriter)?, + rewriter.rewrite(col("c2").not_eq(lit(false)))?, col("c2"), ); // test constant assert_eq!( - (lit(true).not_eq(lit(true))).rewrite(&mut rewriter)?, + rewriter.rewrite(lit(true).not_eq(lit(true)))?, lit(false), ); assert_eq!( - (lit(true).not_eq(lit(false))).rewrite(&mut rewriter)?, + rewriter.rewrite(lit(true).not_eq(lit(false)))?, lit(true), ); @@ -479,23 +716,23 @@ mod tests { assert_eq!(col("c1").get_type(&schema)?, DataType::Utf8); assert_eq!( - (col("c1").not_eq(lit(true))).rewrite(&mut rewriter)?, + rewriter.rewrite(col("c1").not_eq(lit(true)))?, col("c1").not_eq(lit(true)), ); assert_eq!( - (col("c1").not_eq(lit(false))).rewrite(&mut rewriter)?, + rewriter.rewrite(col("c1").not_eq(lit(false)))?, col("c1").not_eq(lit(false)), ); // test constants assert_eq!( - (lit(1).not_eq(lit(true))).rewrite(&mut rewriter)?, + rewriter.rewrite(lit(1).not_eq(lit(true)))?, lit(1).not_eq(lit(true)), ); assert_eq!( - (lit("a").not_eq(lit(false))).rewrite(&mut rewriter)?, + rewriter.rewrite(lit("a").not_eq(lit(false)))?, lit("a").not_eq(lit(false)), ); @@ -511,15 +748,14 @@ mod tests { }; assert_eq!( - (Box::new(Expr::Case { + rewriter.rewrite(Expr::Case { expr: None, when_then_expr: vec![( Box::new(col("c2").not_eq(lit(false))), Box::new(lit("ok").eq(lit(true))), )], else_expr: Some(Box::new(col("c2").eq(lit(true)))), - })) - .rewrite(&mut rewriter)?, + })?, Expr::Case { expr: None, when_then_expr: vec![( @@ -623,7 +859,6 @@ mod tests { .filter(col("b").eq(lit(false)).not())? .project(vec![col("a")])? .build()?; - let expected = "\ Projection: #test.a\ \n Filter: #test.b\ @@ -767,7 +1002,7 @@ mod tests { #[test] fn cast_expr_wrong_arg() { let table_scan = test_table_scan().unwrap(); - let proj = vec![Expr::Cast { + let proj = vec![Expr::TryCast { expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some("".to_string())))), data_type: DataType::Int32, }]; @@ -840,4 +1075,113 @@ mod tests { assert_eq!(actual, expected); } -} + + fn create_pow_with_volatilty(volatilty: Volatility)->ScalarUDF{ + let pow = |args: &[ArrayRef]| { + + assert_eq!(args.len(), 2); + + let base = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + let exponent = &args[1] + .as_any() + .downcast_ref::() + .expect("cast failed"); + + assert_eq!(exponent.len(), base.len()); + + let array = base + .iter() + .zip(exponent.iter()) + .map(|(base, exponent)| { + match (base, exponent) { + (Some(base), Some(exponent)) => Some(base.powf(exponent)), + _ => None, + } + }) + .collect::(); + Ok(Arc::new(array) as ArrayRef) + }; + let pow = make_scalar_function(pow); + let name = match volatilty{ + Volatility::Immutable => "pow", + Volatility::Stable => "pow_stable", + Volatility::Volatile => "pow_vol", + }; + create_udf( + name, + vec![DataType::Float64, DataType::Float64], + Arc::new(DataType::Float64), + volatilty, + pow, + ) + } + + #[test] + fn test_constant_evaluate_binop()->Result<()>{ + let scan = test_table_scan()?; + + //Trying to get non literal expression that has the value Boolean(NULL) so that the symbolic constant_folding can be tested + let proj = vec![Expr::TryCast{expr: Box::new(lit("")), data_type: DataType::Int32}.eq(lit(0)).eq(col("a"))]; + let time = chrono::Utc::now(); + let plan = LogicalPlanBuilder::from(scan.clone()) + .project(proj) + .unwrap() + .build() + .unwrap(); + let actual = get_optimized_plan_formatted(&plan, &time); + let expected = "Projection: Boolean(NULL)\ + \n TableScan: test projection=None"; + assert_eq!(actual, expected); + + + //Another test for boolean expression constant folding true = #test.a -> true + let proj = vec![Expr::TryCast{expr: Box::new(lit("0")), data_type: DataType::Int32}.eq(lit(0)).eq(col("a"))]; + let time = chrono::Utc::now(); + let plan = LogicalPlanBuilder::from(scan) + .project(proj) + .unwrap() + .build() + .unwrap(); + let actual = get_optimized_plan_formatted(&plan, &time); + let expected = "Projection: #test.a\ + \n TableScan: test projection=None"; + assert_eq!(actual, expected); + + Ok(()) + + } + + //Testing that immutable scalar UDFs are inlined, stable or volatile UDFs are not inlined, and the arguments to a stable or volatile UDF are still folded + #[test] + fn test_udf_inlining()->Result<()>{ + + let scan = test_table_scan()?; + let pow_immut = create_pow_with_volatilty(Volatility::Immutable); + let pow_stab = create_pow_with_volatilty(Volatility::Stable); + let pow_vol = create_pow_with_volatilty(Volatility::Volatile); + let pow_res_2 = vec![abs(lit(1.0) - lit(3)), lit(2)]; + let proj = vec![ + pow_immut.call(pow_res_2.clone())*col("e").alias("constant"), + pow_stab.call(pow_res_2.clone())*col("e").alias("stable"), + pow_vol.call(pow_res_2)*col("e") + ]; + let time = chrono::Utc::now(); + let plan = LogicalPlanBuilder::from(scan) + .project(proj) + .unwrap() + .build() + .unwrap(); + let actual = get_optimized_plan_formatted(&plan, &time); + let expected = format!( + "Projection: Float64(4) * #test.e AS constant, pow_stable(Float64(2), Int32(2)) * #test.e AS stable, pow_vol(Float64(2), Int32(2)) * #test.e\ + \n TableScan: test projection=None", + ); + + assert_eq!(actual, expected); + Ok(()) + } + +} \ No newline at end of file diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion/src/physical_plan/udf.rs index 39265c0eb5efa..835c39f2d55c7 100644 --- a/datafusion/src/physical_plan/udf.rs +++ b/datafusion/src/physical_plan/udf.rs @@ -25,6 +25,7 @@ use arrow::datatypes::Schema; use crate::error::Result; use crate::{logical_plan::Expr, physical_plan::PhysicalExpr}; +use super::functions::Volatility; use super::{ functions::{ ReturnTypeFunction, ScalarFunctionExpr, ScalarFunctionImplementation, Signature, @@ -95,6 +96,10 @@ impl ScalarUDF { fun: fun.clone(), } } + ///Returns the volatilaty of the UDF + pub fn volatility(&self)->Volatility{ + self.signature.volatility + } /// creates a logical expression with a call of the UDF /// This utility allows using the UDF without requiring access to the registry. From 18d966efba052efa6d0dc23b04237578d9a7726b Mon Sep 17 00:00:00 2001 From: patrick Date: Wed, 13 Oct 2021 22:33:33 -0600 Subject: [PATCH 2/6] formatted code --- datafusion/src/optimizer/constant_folding.rs | 482 ++++++++++--------- datafusion/src/physical_plan/udf.rs | 2 +- 2 files changed, 246 insertions(+), 238 deletions(-) diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs index e294c3bcaf489..8c19f6ad30c12 100644 --- a/datafusion/src/optimizer/constant_folding.rs +++ b/datafusion/src/optimizer/constant_folding.rs @@ -25,18 +25,13 @@ use arrow::record_batch::RecordBatch; use crate::error::{DataFusionError, Result}; use crate::execution::context::{ExecutionContextState, ExecutionProps}; -use crate::logical_plan::{ - DFSchema, DFSchemaRef, Expr, LogicalPlan, Operator, -}; +use crate::logical_plan::{DFSchema, DFSchemaRef, Expr, LogicalPlan, Operator}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; use crate::physical_plan::functions::{BuiltinScalarFunction, Volatility}; use crate::physical_plan::planner::DefaultPhysicalPlanner; use crate::scalar::ScalarValue; - - - struct ConstantRewriter<'a> { execution_props: &'a ExecutionProps, schemas: Vec<&'a DFSchemaRef>, @@ -51,34 +46,34 @@ struct ConstantRewriter<'a> { /// * `false = true` and `true = false` to `false` /// * `!!expr` to `expr` /// * `expr = null` and `expr != null` to `null` -pub struct ConstantFolding{} +pub struct ConstantFolding {} -impl ConstantFolding{ +impl ConstantFolding { #[allow(missing_docs)] pub fn new() -> Self { Self {} } } -impl OptimizerRule for ConstantFolding{ +impl OptimizerRule for ConstantFolding { fn optimize( &self, plan: &LogicalPlan, execution_props: &ExecutionProps, ) -> Result { - let mut rewriter = ConstantRewriter{ + let mut rewriter = ConstantRewriter { execution_props: &execution_props, schemas: plan.all_schemas(), }; - + match plan { LogicalPlan::Filter { predicate, input } => Ok(LogicalPlan::Filter { - predicate: match rewriter.rewrite(predicate.clone()){ - Ok(e)=> e, - _ => predicate.clone() + predicate: match rewriter.rewrite(predicate.clone()) { + Ok(e) => e, + _ => predicate.clone(), }, - input: match self.optimize(input, execution_props){ + input: match self.optimize(input, execution_props) { Ok(plan) => Arc::new(plan.clone()), - _ => input.clone() + _ => input.clone(), }, }), // Rest: recurse into plan, apply optimization where possible @@ -99,16 +94,16 @@ impl OptimizerRule for ConstantFolding{ let inputs = plan.inputs(); let new_inputs = inputs .iter() - .map(|plan| match self.optimize(plan, execution_props){ + .map(|plan| match self.optimize(plan, execution_props) { Ok(opt_plan) => opt_plan, - _=> (*plan).clone(), + _ => (*plan).clone(), }) .collect::>(); let expr = plan .expressions() .into_iter() - .map(|e| match rewriter.rewrite(e.clone()){ + .map(|e| match rewriter.rewrite(e.clone()) { Ok(expr) => expr, Err(_) => e, }) @@ -129,40 +124,47 @@ impl OptimizerRule for ConstantFolding{ ///Evaluate calculates the value of scalar expressions. This function may panic if columns are present within the expression pub fn evaluate(expr: &Expr, exec_props: &ExecutionProps) -> Result { - if let Expr::Literal(s) = expr{ + if let Expr::Literal(s) = expr { return Ok(s.clone()); } - //The dummy column name was chosen as to not interfere with any possible columns names in a normal schema. Unsure if this is needed as the schema of the columns - //is never used - static DUMMY_COL_NAME : &'static str = "."; + //The dummy column name shouldn't really matter as only scalar expressions should be evaluated + static DUMMY_COL_NAME: &'static str = "."; let dummy_df_schema = DFSchema::empty(); - let dummy_input_schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Float64, true)]); + let dummy_input_schema = + Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Float64, true)]); let mut ctx_state = ExecutionContextState::new(); ctx_state.execution_props = exec_props.clone(); let planner = DefaultPhysicalPlanner::default(); - let phys_expr = planner.create_physical_expr(expr, &dummy_df_schema, &dummy_input_schema, &ctx_state)?; - let col = { + let phys_expr = planner.create_physical_expr( + expr, + &dummy_df_schema, + &dummy_input_schema, + &ctx_state, + )?; + let col = { let mut builder = Float64Array::builder(1); builder.append_null()?; builder.finish() }; - let record_batch = RecordBatch::try_new(Arc::new(dummy_input_schema), vec![Arc::new(col)])?; + let record_batch = + RecordBatch::try_new(Arc::new(dummy_input_schema), vec![Arc::new(col)])?; let col_val = phys_expr.evaluate(&record_batch)?; - match col_val{ + match col_val { crate::physical_plan::ColumnarValue::Array(a) => { - if a.len() != 1{ - Err(DataFusionError::Execution(format!("Could not evaluate the expressison, found a result of length {}", a.len()))) - }else{ - Ok(ScalarValue::try_from_array(&a,0)?) + if a.len() != 1 { + Err(DataFusionError::Execution(format!( + "Could not evaluate the expressison, found a result of length {}", + a.len() + ))) + } else { + Ok(ScalarValue::try_from_array(&a, 0)?) } - }, + } crate::physical_plan::ColumnarValue::Scalar(s) => Ok(s), } - } impl<'a> ConstantRewriter<'a> { - fn is_boolean_type(&self, expr: &Expr) -> bool { for schema in &self.schemas { if let Ok(DataType::Boolean) = expr.get_type(schema) { @@ -173,28 +175,25 @@ impl<'a> ConstantRewriter<'a> { false } - - - pub fn rewrite(&mut self, mut expr: Expr) -> Result { - let name = match &expr{ - Expr::Alias(_, name )=>Some(name.clone()), + let name = match &expr { + Expr::Alias(_, name) => Some(name.clone()), _ => None, }; - + let rewrite_root = self.rewrite_const_expr(&mut expr); - let expr = if rewrite_root{ - match evaluate(&expr, self.execution_props){ + let expr = if rewrite_root { + match evaluate(&expr, self.execution_props) { Ok(s) => Expr::Literal(s), Err(e) => { println!("Could not rewrite: {}", e); expr - }, + } } - }else{ + } else { expr }; - Ok(match name{ + Ok(match name { Some(name) => expr.alias(&name), None => expr, }) @@ -204,27 +203,30 @@ impl<'a> ConstantRewriter<'a> { std::mem::swap(&mut replacement, expr); } - fn const_fold_list_eager(&mut self, args: &mut Vec){ - for arg in args.iter_mut(){ - if self.rewrite_const_expr(arg){ - match evaluate(arg,self.execution_props){ + fn const_fold_list_eager(&mut self, args: &mut Vec) { + for arg in args.iter_mut() { + if self.rewrite_const_expr(arg) { + match evaluate(arg, self.execution_props) { Ok(s) => *arg = Expr::Literal(s), - _ => () + _ => (), } } } } - fn const_fold_list(&mut self, args: &mut Vec)->bool{ - let can_rewrite= args.iter_mut().map(|e| self.rewrite_const_expr(e)).collect::>(); - if can_rewrite.iter().all(|f| *f){ + fn const_fold_list(&mut self, args: &mut Vec) -> bool { + let can_rewrite = args + .iter_mut() + .map(|e| self.rewrite_const_expr(e)) + .collect::>(); + if can_rewrite.iter().all(|f| *f) { return true; - }else{ - for (rewrite_expr, expr) in can_rewrite.iter().zip(args){ - if *rewrite_expr{ - match evaluate(expr,self.execution_props){ - Ok(s) =>*expr = Expr::Literal(s), - _ =>(), + } else { + for (rewrite_expr, expr) in can_rewrite.iter().zip(args) { + if *rewrite_expr { + match evaluate(expr, self.execution_props) { + Ok(s) => *expr = Expr::Literal(s), + _ => (), } } } @@ -233,21 +235,31 @@ impl<'a> ConstantRewriter<'a> { } ///This attempts to simplify expressions of the form col(Boolean) = Boolean and col(Boolean) != Boolean /// e.g. col(Boolean) = Some(true) -> col(Boolean) - - fn binary_column_const_fold(&mut self, left: &mut Box, op: &Operator, right: &mut Box)->Option{ - let expr = match (left.as_ref(), op, right.as_ref()){ - - (Expr::Literal(ScalarValue::Boolean(l)), Operator::Eq, Expr::Literal(ScalarValue::Boolean(r)))=>{ - let literal_bool = Expr::Literal(ScalarValue::Boolean( - match (l, r){ - + + fn binary_column_const_fold( + &mut self, + left: &mut Box, + op: &Operator, + right: &mut Box, + ) -> Option { + let expr = match (left.as_ref(), op, right.as_ref()) { + ( + Expr::Literal(ScalarValue::Boolean(l)), + Operator::Eq, + Expr::Literal(ScalarValue::Boolean(r)), + ) => { + let literal_bool = Expr::Literal(ScalarValue::Boolean(match (l, r) { (Some(l), Some(r)) => Some(*l == *r), _ => None, })); Some(literal_bool) } - (Expr::Literal(ScalarValue::Boolean(l)), Operator::NotEq, Expr::Literal(ScalarValue::Boolean(r)))=>{ - let literal_bool = match (l,r){ + ( + Expr::Literal(ScalarValue::Boolean(l)), + Operator::NotEq, + Expr::Literal(ScalarValue::Boolean(r)), + ) => { + let literal_bool = match (l, r) { (Some(l), Some(r)) => { Expr::Literal(ScalarValue::Boolean(Some(l != r))) } @@ -255,45 +267,47 @@ impl<'a> ConstantRewriter<'a> { }; Some(literal_bool) } - (Expr::Literal(ScalarValue::Boolean(b)), Operator::Eq, col) | - (col, Operator::Eq, Expr::Literal(ScalarValue::Boolean(b))) if self.is_boolean_type(&col) =>{ - Some(match b{ - Some(true)=>col.clone(), - Some(false) =>Expr::Not(Box::new(col.clone())), + (Expr::Literal(ScalarValue::Boolean(b)), Operator::Eq, col) + | (col, Operator::Eq, Expr::Literal(ScalarValue::Boolean(b))) + if self.is_boolean_type(&col) => + { + Some(match b { + Some(true) => col.clone(), + Some(false) => Expr::Not(Box::new(col.clone())), None => Expr::Literal(ScalarValue::Boolean(None)), }) - }, - (Expr::Literal(ScalarValue::Boolean(b)), Operator::NotEq, col) | - (col, Operator::NotEq, Expr::Literal(ScalarValue::Boolean(b))) if self.is_boolean_type(&col) =>{ - Some(match b{ - Some(true)=>Expr::Not(Box::new(col.clone())), + } + (Expr::Literal(ScalarValue::Boolean(b)), Operator::NotEq, col) + | (col, Operator::NotEq, Expr::Literal(ScalarValue::Boolean(b))) + if self.is_boolean_type(&col) => + { + Some(match b { + Some(true) => Expr::Not(Box::new(col.clone())), Some(false) => col.clone(), None => Expr::Literal(ScalarValue::Boolean(None)), }) } - _ => None, + _ => None, }; expr } - fn rewrite_const_expr(&mut self, expr: &mut Expr) -> bool { - let can_rewrite = match expr { Expr::Alias(e, _) => self.rewrite_const_expr(e), Expr::Column(_) => false, Expr::ScalarVariable(_) => false, Expr::Literal(_) => true, Expr::BinaryExpr { left, op, right } => { - //Check if left and right are const, much like the Not Not optimization this is done first to make sure any - //Non-scalar execution optimizations, such as col("test") = NULL->false are performed first + //Check if left and right are const, much like the Not Not optimization this is done first to make sure any + //Non-scalar execution optimizations, such as col("test") = NULL->false are performed first let left_const = self.rewrite_const_expr(left); let right_const = self.rewrite_const_expr(right); - let mut can_rewrite = match (left_const, right_const) { + let mut can_rewrite = match (left_const, right_const) { (true, true) => true, (false, false) => false, (true, false) => { - match evaluate(&left,self.execution_props) { + match evaluate(&left, self.execution_props) { Ok(s) => { let left: &mut Expr = left; *left = Expr::Literal(s); @@ -303,51 +317,47 @@ impl<'a> ConstantRewriter<'a> { false } (false, true) => { - match evaluate(&right,self.execution_props) { + match evaluate(&right, self.execution_props) { Ok(s) => { let right: &mut Expr = right; - *right = Expr::Literal(s); + *right = Expr::Literal(s); } Err(_) => (), } false } }; - - - can_rewrite= match self.binary_column_const_fold(left, op, right){ - Some(e) =>{ + + can_rewrite = match self.binary_column_const_fold(left, op, right) { + Some(e) => { let expr: &mut Expr = expr; *expr = e; self.rewrite_const_expr(expr) } - None => can_rewrite + None => can_rewrite, }; - - can_rewrite + can_rewrite } Expr::Not(e) => { //Check if the expression can be rewritten. This may trigger simplifications such as col("b") = false -> NOT col("b") //Then check if inner expression is Not and if so replace expr with the inner let can_rewrite = self.rewrite_const_expr(e); - match e.as_mut(){ - Expr::Not(inner)=>{ - let inner = std::mem::replace(inner.as_mut(),Expr::Wildcard); + match e.as_mut() { + Expr::Not(inner) => { + let inner = std::mem::replace(inner.as_mut(), Expr::Wildcard); *expr = inner; self.rewrite_const_expr(expr) } - _=>can_rewrite + _ => can_rewrite, } - }, + } Expr::IsNotNull(e) => self.rewrite_const_expr(e), Expr::IsNull(e) => self.rewrite_const_expr(e), Expr::Negative(e) => self.rewrite_const_expr(e), Expr::Between { - expr, - low, - high, .. + expr, low, high, .. } => match ( self.rewrite_const_expr(expr), self.rewrite_const_expr(low), @@ -356,67 +366,79 @@ impl<'a> ConstantRewriter<'a> { (true, true, true) => true, (expr_const, low_const, high_const) => { if expr_const { - if let Ok(s) = evaluate(expr,self.execution_props){ + if let Ok(s) = evaluate(expr, self.execution_props) { self.replace_expr(expr, Expr::Literal(s)); } } if low_const { - if let Ok(s) = evaluate(expr,self.execution_props){ + if let Ok(s) = evaluate(expr, self.execution_props) { self.replace_expr(low, Expr::Literal(s)); } } if high_const { - if let Ok(s) = evaluate(expr,self.execution_props){ + if let Ok(s) = evaluate(expr, self.execution_props) { self.replace_expr(high, Expr::Literal(s)); } } false } }, - Expr::Case { expr, when_then_expr, else_expr } => { - if expr.as_mut().map(|e| self.rewrite_const_expr(e)).unwrap_or(false){ - let expr_inner = expr.as_mut().unwrap(); - match evaluate(expr_inner, self.execution_props){ + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + if expr + .as_mut() + .map(|e| self.rewrite_const_expr(e)) + .unwrap_or(false) + { + let expr_inner = expr.as_mut().unwrap(); + match evaluate(expr_inner, self.execution_props) { Ok(s) => *expr_inner.as_mut() = Expr::Literal(s), - Err(_) => (), + Err(_) => (), } } - if else_expr.as_mut().map(|e| self.rewrite_const_expr(e)).unwrap_or(false){ - let expr_inner = else_expr.as_mut().unwrap(); - match evaluate(expr_inner, self.execution_props){ + if else_expr + .as_mut() + .map(|e| self.rewrite_const_expr(e)) + .unwrap_or(false) + { + let expr_inner = else_expr.as_mut().unwrap(); + match evaluate(expr_inner, self.execution_props) { Ok(s) => *expr_inner.as_mut() = Expr::Literal(s), - Err(_) => (), + Err(_) => (), } } - for (when, then) in when_then_expr{ + for (when, then) in when_then_expr { let when: &mut Expr = when; - let then : &mut Expr = then; - if self.rewrite_const_expr(when){ - match evaluate(when, self.execution_props){ + let then: &mut Expr = then; + if self.rewrite_const_expr(when) { + match evaluate(when, self.execution_props) { Ok(s) => *when = Expr::Literal(s), - _ =>(), + _ => (), } } - if self.rewrite_const_expr(then){ - match evaluate(then, self.execution_props){ + if self.rewrite_const_expr(then) { + match evaluate(then, self.execution_props) { Ok(s) => *then = Expr::Literal(s), Err(_) => (), } } } false - }, + } Expr::Cast { expr, .. } => self.rewrite_const_expr(expr), Expr::TryCast { expr, .. } => self.rewrite_const_expr(expr), Expr::Sort { expr, .. } => { if self.rewrite_const_expr(expr) { - match evaluate(expr,self.execution_props) { + match evaluate(expr, self.execution_props) { Ok(s) => { let expr: &mut Expr = expr; - *expr = Expr::Literal(s); - }, + *expr = Expr::Literal(s); + } Err(_) => (), } } @@ -426,48 +448,45 @@ impl<'a> ConstantRewriter<'a> { fun: BuiltinScalarFunction::Now, .. } => { - *expr= Expr::Literal(ScalarValue::TimestampNanosecond(Some( - self.execution_props - .query_execution_start_time - .timestamp_nanos(), + *expr = Expr::Literal(ScalarValue::TimestampNanosecond(Some( + self.execution_props + .query_execution_start_time + .timestamp_nanos(), ))); true - }, + } Expr::ScalarFunction { fun, args } => { - if args.is_empty(){ + if args.is_empty() { false - }else{ + } else { let volatility = fun.volatility(); - match volatility{ + match volatility { Volatility::Immutable => self.const_fold_list(args), - _ =>{ + _ => { self.const_fold_list_eager(args); false } } - } - }, + } + } Expr::ScalarUDF { fun, args } => { - if args.is_empty(){ + if args.is_empty() { false - }else{ + } else { let volatility = fun.volatility(); - match volatility{ + match volatility { Volatility::Immutable => self.const_fold_list(args), - _ =>{ + _ => { self.const_fold_list_eager(args); false } - } + } } - }, - Expr::AggregateFunction { - args, - .. - } => { + } + Expr::AggregateFunction { args, .. } => { self.const_fold_list_eager(args); false - }, + } Expr::WindowFunction { args, partition_by, @@ -478,35 +497,31 @@ impl<'a> ConstantRewriter<'a> { self.const_fold_list_eager(partition_by); self.const_fold_list_eager(order_by); false - }, - Expr::AggregateUDF { args,.. } => { + } + Expr::AggregateUDF { args, .. } => { self.const_fold_list_eager(args); false - }, - Expr::InList { - expr, - list, - .. - } => { + } + Expr::InList { expr, list, .. } => { let expr_const = self.rewrite_const_expr(expr); let list_literals = self.const_fold_list(list); - match (expr_const, list_literals){ - (true, true)=> true, + match (expr_const, list_literals) { + (true, true) => true, (false, false) => false, - (true, false)=> { - if let Ok(s) = evaluate(expr,self.execution_props) { + (true, false) => { + if let Ok(s) = evaluate(expr, self.execution_props) { let expr: &mut Expr = expr; - *expr = Expr::Literal(s); + *expr = Expr::Literal(s); } false } - (false, true)=>{ + (false, true) => { self.const_fold_list_eager(list); false - }, + } } - }, + } Expr::Wildcard => false, }; println!("Can rewrite the expr[{}]: {}", can_rewrite, expr); @@ -514,13 +529,16 @@ impl<'a> ConstantRewriter<'a> { } } - - #[cfg(test)] mod tests { use super::*; - use arrow::array::{Float64Array, ArrayRef}; - use crate::{logical_plan::{DFField, DFSchema, LogicalPlanBuilder, col, create_udf, abs, lit, max, min}, physical_plan::{functions::make_scalar_function, udf::ScalarUDF}}; + use crate::{ + logical_plan::{ + abs, col, create_udf, lit, max, min, DFField, DFSchema, LogicalPlanBuilder, + }, + physical_plan::{functions::make_scalar_function, udf::ScalarUDF}, + }; + use arrow::array::{ArrayRef, Float64Array}; use chrono::{DateTime, Utc}; @@ -530,7 +548,7 @@ mod tests { Field::new("b", DataType::Boolean, false), Field::new("c", DataType::Boolean, false), Field::new("d", DataType::UInt32, false), - Field::new("e", DataType::Float64, false) + Field::new("e", DataType::Float64, false), ]); LogicalPlanBuilder::scan_empty(Some("test"), &schema, None)?.build() } @@ -577,7 +595,9 @@ mod tests { // null != null is always null assert_eq!( - rewriter.rewrite(lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None))))?, + rewriter.rewrite( + lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None))) + )?, lit(ScalarValue::Boolean(None)), ); @@ -610,19 +630,13 @@ mod tests { assert_eq!(rewriter.rewrite(lit(true).eq(lit(true)))?, lit(true),); // true = false -> false - assert_eq!( - rewriter.rewrite(lit(true).eq(lit(false)))?, - lit(false), - ); + assert_eq!(rewriter.rewrite(lit(true).eq(lit(false)))?, lit(false),); // c2 = true -> c2 assert_eq!(rewriter.rewrite(col("c2").eq(lit(true)))?, col("c2"),); // c2 = false => !c2 - assert_eq!( - rewriter.rewrite(col("c2").eq(lit(false)))?, - col("c2").not(), - ); + assert_eq!(rewriter.rewrite(col("c2").eq(lit(false)))?, col("c2").not(),); Ok(()) } @@ -684,21 +698,12 @@ mod tests { ); // c2 != false -> c2 - assert_eq!( - rewriter.rewrite(col("c2").not_eq(lit(false)))?, - col("c2"), - ); + assert_eq!(rewriter.rewrite(col("c2").not_eq(lit(false)))?, col("c2"),); // test constant - assert_eq!( - rewriter.rewrite(lit(true).not_eq(lit(true)))?, - lit(false), - ); + assert_eq!(rewriter.rewrite(lit(true).not_eq(lit(true)))?, lit(false),); - assert_eq!( - rewriter.rewrite(lit(true).not_eq(lit(false)))?, - lit(true), - ); + assert_eq!(rewriter.rewrite(lit(true).not_eq(lit(false)))?, lit(true),); Ok(()) } @@ -1076,40 +1081,37 @@ mod tests { assert_eq!(actual, expected); } - fn create_pow_with_volatilty(volatilty: Volatility)->ScalarUDF{ - let pow = |args: &[ArrayRef]| { - - assert_eq!(args.len(), 2); - - let base = &args[0] - .as_any() - .downcast_ref::() - .expect("cast failed"); - let exponent = &args[1] - .as_any() - .downcast_ref::() - .expect("cast failed"); - - assert_eq!(exponent.len(), base.len()); - - let array = base - .iter() - .zip(exponent.iter()) - .map(|(base, exponent)| { - match (base, exponent) { + fn create_pow_with_volatilty(volatilty: Volatility) -> ScalarUDF { + let pow = |args: &[ArrayRef]| { + assert_eq!(args.len(), 2); + + let base = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + let exponent = &args[1] + .as_any() + .downcast_ref::() + .expect("cast failed"); + + assert_eq!(exponent.len(), base.len()); + + let array = base + .iter() + .zip(exponent.iter()) + .map(|(base, exponent)| match (base, exponent) { (Some(base), Some(exponent)) => Some(base.powf(exponent)), _ => None, - } - }) - .collect::(); - Ok(Arc::new(array) as ArrayRef) - }; - let pow = make_scalar_function(pow); - let name = match volatilty{ - Volatility::Immutable => "pow", - Volatility::Stable => "pow_stable", - Volatility::Volatile => "pow_vol", - }; + }) + .collect::(); + Ok(Arc::new(array) as ArrayRef) + }; + let pow = make_scalar_function(pow); + let name = match volatilty { + Volatility::Immutable => "pow", + Volatility::Stable => "pow_stable", + Volatility::Volatile => "pow_vol", + }; create_udf( name, vec![DataType::Float64, DataType::Float64], @@ -1120,11 +1122,16 @@ mod tests { } #[test] - fn test_constant_evaluate_binop()->Result<()>{ + fn test_constant_evaluate_binop() -> Result<()> { let scan = test_table_scan()?; //Trying to get non literal expression that has the value Boolean(NULL) so that the symbolic constant_folding can be tested - let proj = vec![Expr::TryCast{expr: Box::new(lit("")), data_type: DataType::Int32}.eq(lit(0)).eq(col("a"))]; + let proj = vec![Expr::TryCast { + expr: Box::new(lit("")), + data_type: DataType::Int32, + } + .eq(lit(0)) + .eq(col("a"))]; let time = chrono::Utc::now(); let plan = LogicalPlanBuilder::from(scan.clone()) .project(proj) @@ -1133,12 +1140,16 @@ mod tests { .unwrap(); let actual = get_optimized_plan_formatted(&plan, &time); let expected = "Projection: Boolean(NULL)\ - \n TableScan: test projection=None"; + \n TableScan: test projection=None"; assert_eq!(actual, expected); - //Another test for boolean expression constant folding true = #test.a -> true - let proj = vec![Expr::TryCast{expr: Box::new(lit("0")), data_type: DataType::Int32}.eq(lit(0)).eq(col("a"))]; + let proj = vec![Expr::TryCast { + expr: Box::new(lit("0")), + data_type: DataType::Int32, + } + .eq(lit(0)) + .eq(col("a"))]; let time = chrono::Utc::now(); let plan = LogicalPlanBuilder::from(scan) .project(proj) @@ -1147,26 +1158,24 @@ mod tests { .unwrap(); let actual = get_optimized_plan_formatted(&plan, &time); let expected = "Projection: #test.a\ - \n TableScan: test projection=None"; + \n TableScan: test projection=None"; assert_eq!(actual, expected); Ok(()) - } //Testing that immutable scalar UDFs are inlined, stable or volatile UDFs are not inlined, and the arguments to a stable or volatile UDF are still folded #[test] - fn test_udf_inlining()->Result<()>{ - + fn test_udf_inlining() -> Result<()> { let scan = test_table_scan()?; let pow_immut = create_pow_with_volatilty(Volatility::Immutable); let pow_stab = create_pow_with_volatilty(Volatility::Stable); let pow_vol = create_pow_with_volatilty(Volatility::Volatile); let pow_res_2 = vec![abs(lit(1.0) - lit(3)), lit(2)]; let proj = vec![ - pow_immut.call(pow_res_2.clone())*col("e").alias("constant"), - pow_stab.call(pow_res_2.clone())*col("e").alias("stable"), - pow_vol.call(pow_res_2)*col("e") + pow_immut.call(pow_res_2.clone()) * col("e").alias("constant"), + pow_stab.call(pow_res_2.clone()) * col("e").alias("stable"), + pow_vol.call(pow_res_2) * col("e"), ]; let time = chrono::Utc::now(); let plan = LogicalPlanBuilder::from(scan) @@ -1183,5 +1192,4 @@ mod tests { assert_eq!(actual, expected); Ok(()) } - -} \ No newline at end of file +} diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion/src/physical_plan/udf.rs index 835c39f2d55c7..824eb672d865b 100644 --- a/datafusion/src/physical_plan/udf.rs +++ b/datafusion/src/physical_plan/udf.rs @@ -97,7 +97,7 @@ impl ScalarUDF { } } ///Returns the volatilaty of the UDF - pub fn volatility(&self)->Volatility{ + pub fn volatility(&self) -> Volatility { self.signature.volatility } From b943a18932716c73b03e0e96a75d9bdf85486a74 Mon Sep 17 00:00:00 2001 From: patrick Date: Wed, 13 Oct 2021 23:01:15 -0600 Subject: [PATCH 3/6] Removed ExecutionProps argument for evaluate --- datafusion/src/optimizer/constant_folding.rs | 141 ++++++++++--------- 1 file changed, 73 insertions(+), 68 deletions(-) diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs index 8c19f6ad30c12..a3726617c7c0c 100644 --- a/datafusion/src/optimizer/constant_folding.rs +++ b/datafusion/src/optimizer/constant_folding.rs @@ -67,10 +67,7 @@ impl OptimizerRule for ConstantFolding { match plan { LogicalPlan::Filter { predicate, input } => Ok(LogicalPlan::Filter { - predicate: match rewriter.rewrite(predicate.clone()) { - Ok(e) => e, - _ => predicate.clone(), - }, + predicate: rewriter.rewrite(predicate.clone()) , input: match self.optimize(input, execution_props) { Ok(plan) => Arc::new(plan.clone()), _ => input.clone(), @@ -103,10 +100,7 @@ impl OptimizerRule for ConstantFolding { let expr = plan .expressions() .into_iter() - .map(|e| match rewriter.rewrite(e.clone()) { - Ok(expr) => expr, - Err(_) => e, - }) + .map(|e| rewriter.rewrite(e)) .collect::>(); utils::from_plan(plan, &expr, &new_inputs) @@ -123,7 +117,7 @@ impl OptimizerRule for ConstantFolding { } ///Evaluate calculates the value of scalar expressions. This function may panic if columns are present within the expression -pub fn evaluate(expr: &Expr, exec_props: &ExecutionProps) -> Result { +pub fn evaluate(expr: &Expr) -> Result { if let Expr::Literal(s) = expr { return Ok(s.clone()); } @@ -132,8 +126,8 @@ pub fn evaluate(expr: &Expr, exec_props: &ExecutionProps) -> Result let dummy_df_schema = DFSchema::empty(); let dummy_input_schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Float64, true)]); - let mut ctx_state = ExecutionContextState::new(); - ctx_state.execution_props = exec_props.clone(); + let ctx_state = ExecutionContextState::new(); + let planner = DefaultPhysicalPlanner::default(); let phys_expr = planner.create_physical_expr( expr, @@ -175,45 +169,52 @@ impl<'a> ConstantRewriter<'a> { false } - pub fn rewrite(&mut self, mut expr: Expr) -> Result { + pub fn rewrite(&mut self, mut expr: Expr) -> Expr { let name = match &expr { Expr::Alias(_, name) => Some(name.clone()), _ => None, }; let rewrite_root = self.rewrite_const_expr(&mut expr); - let expr = if rewrite_root { - match evaluate(&expr, self.execution_props) { - Ok(s) => Expr::Literal(s), - Err(e) => { - println!("Could not rewrite: {}", e); - expr - } + if rewrite_root{ + match evaluate(&expr) { + Ok(s) => expr= Expr::Literal(s), + Err(_) => return expr, } - } else { - expr - }; - Ok(match name { - Some(name) => expr.alias(&name), + } + match name { + Some(name) => { + let existing_alias = match &expr{ + Expr::Alias(_, new_alias) => Some(new_alias), + _ => None + }; + let apply_new_alias = match existing_alias{ + Some(new) => *new != name, + None => false, + }; + if apply_new_alias{ + expr = Expr::Alias(Box::new(expr), name); + } + expr + }, None => expr, - }) + } } - fn replace_expr(&self, expr: &mut Box, mut replacement: Expr) { - std::mem::swap(&mut replacement, expr); - } + ///Evaluates all literal expressions in the list. fn const_fold_list_eager(&mut self, args: &mut Vec) { for arg in args.iter_mut() { if self.rewrite_const_expr(arg) { - match evaluate(arg, self.execution_props) { + match evaluate(arg) { Ok(s) => *arg = Expr::Literal(s), _ => (), } } } } - + ///Tests to see if the list passed in is all literal expressions, if they are then it returns true. + ///If some expressions are not literal then the literal expressions are evaluated and it returns false. fn const_fold_list(&mut self, args: &mut Vec) -> bool { let can_rewrite = args .iter_mut() @@ -224,7 +225,7 @@ impl<'a> ConstantRewriter<'a> { } else { for (rewrite_expr, expr) in can_rewrite.iter().zip(args) { if *rewrite_expr { - match evaluate(expr, self.execution_props) { + match evaluate(expr) { Ok(s) => *expr = Expr::Literal(s), _ => (), } @@ -234,7 +235,8 @@ impl<'a> ConstantRewriter<'a> { false } ///This attempts to simplify expressions of the form col(Boolean) = Boolean and col(Boolean) != Boolean - /// e.g. col(Boolean) = Some(true) -> col(Boolean) + /// e.g. col(Boolean) = Some(true) -> col(Boolean). It also handles == and != between two boolean literals as + /// the binary operator physical expression currently doesn't handle them. fn binary_column_const_fold( &mut self, @@ -307,7 +309,7 @@ impl<'a> ConstantRewriter<'a> { (true, true) => true, (false, false) => false, (true, false) => { - match evaluate(&left, self.execution_props) { + match evaluate(&left) { Ok(s) => { let left: &mut Expr = left; *left = Expr::Literal(s); @@ -317,7 +319,7 @@ impl<'a> ConstantRewriter<'a> { false } (false, true) => { - match evaluate(&right, self.execution_props) { + match evaluate(&right) { Ok(s) => { let right: &mut Expr = right; *right = Expr::Literal(s); @@ -366,18 +368,21 @@ impl<'a> ConstantRewriter<'a> { (true, true, true) => true, (expr_const, low_const, high_const) => { if expr_const { - if let Ok(s) = evaluate(expr, self.execution_props) { - self.replace_expr(expr, Expr::Literal(s)); + if let Ok(s) = evaluate(expr) { + let expr: &mut Expr = expr; + *expr = Expr::Literal(s); } } if low_const { - if let Ok(s) = evaluate(expr, self.execution_props) { - self.replace_expr(low, Expr::Literal(s)); + if let Ok(s) = evaluate(expr) { + let expr: &mut Expr = expr; + *expr = Expr::Literal(s); } } if high_const { - if let Ok(s) = evaluate(expr, self.execution_props) { - self.replace_expr(high, Expr::Literal(s)); + if let Ok(s) = evaluate(expr) { + let expr: &mut Expr = expr; + *expr = Expr::Literal(s); } } false @@ -394,7 +399,7 @@ impl<'a> ConstantRewriter<'a> { .unwrap_or(false) { let expr_inner = expr.as_mut().unwrap(); - match evaluate(expr_inner, self.execution_props) { + match evaluate(expr_inner) { Ok(s) => *expr_inner.as_mut() = Expr::Literal(s), Err(_) => (), } @@ -406,7 +411,7 @@ impl<'a> ConstantRewriter<'a> { .unwrap_or(false) { let expr_inner = else_expr.as_mut().unwrap(); - match evaluate(expr_inner, self.execution_props) { + match evaluate(expr_inner) { Ok(s) => *expr_inner.as_mut() = Expr::Literal(s), Err(_) => (), } @@ -416,13 +421,13 @@ impl<'a> ConstantRewriter<'a> { let when: &mut Expr = when; let then: &mut Expr = then; if self.rewrite_const_expr(when) { - match evaluate(when, self.execution_props) { + match evaluate(when) { Ok(s) => *when = Expr::Literal(s), _ => (), } } if self.rewrite_const_expr(then) { - match evaluate(then, self.execution_props) { + match evaluate(then) { Ok(s) => *then = Expr::Literal(s), Err(_) => (), } @@ -434,7 +439,7 @@ impl<'a> ConstantRewriter<'a> { Expr::TryCast { expr, .. } => self.rewrite_const_expr(expr), Expr::Sort { expr, .. } => { if self.rewrite_const_expr(expr) { - match evaluate(expr, self.execution_props) { + match evaluate(expr) { Ok(s) => { let expr: &mut Expr = expr; *expr = Expr::Literal(s); @@ -510,7 +515,7 @@ impl<'a> ConstantRewriter<'a> { (false, false) => false, (true, false) => { - if let Ok(s) = evaluate(expr, self.execution_props) { + if let Ok(s) = evaluate(expr) { let expr: &mut Expr = expr; *expr = Expr::Literal(s); } @@ -572,7 +577,7 @@ mod tests { }; assert_eq!( - rewriter.rewrite(col("c2").not().not().not())?, + rewriter.rewrite(col("c2").not().not().not()), col("c2").not(), ); @@ -589,7 +594,7 @@ mod tests { // x = null is always null assert_eq!( - rewriter.rewrite(lit(true).eq(lit(ScalarValue::Boolean(None))))?, + rewriter.rewrite(lit(true).eq(lit(ScalarValue::Boolean(None)))), lit(ScalarValue::Boolean(None)), ); @@ -597,19 +602,19 @@ mod tests { assert_eq!( rewriter.rewrite( lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None))) - )?, + ), lit(ScalarValue::Boolean(None)), ); // x != null is always null assert_eq!( - rewriter.rewrite(col("c2").not_eq(lit(ScalarValue::Boolean(None))))?, + rewriter.rewrite(col("c2").not_eq(lit(ScalarValue::Boolean(None)))), lit(ScalarValue::Boolean(None)), ); // null = x is always null assert_eq!( - rewriter.rewrite(lit(ScalarValue::Boolean(None)).eq(col("c2")))?, + rewriter.rewrite(lit(ScalarValue::Boolean(None)).eq(col("c2"))), lit(ScalarValue::Boolean(None)), ); @@ -627,16 +632,16 @@ mod tests { assert_eq!(col("c2").get_type(&schema)?, DataType::Boolean); // true = ture -> true - assert_eq!(rewriter.rewrite(lit(true).eq(lit(true)))?, lit(true),); + assert_eq!(rewriter.rewrite(lit(true).eq(lit(true))), lit(true),); // true = false -> false - assert_eq!(rewriter.rewrite(lit(true).eq(lit(false)))?, lit(false),); + assert_eq!(rewriter.rewrite(lit(true).eq(lit(false))), lit(false),); // c2 = true -> c2 - assert_eq!(rewriter.rewrite(col("c2").eq(lit(true)))?, col("c2"),); + assert_eq!(rewriter.rewrite(col("c2").eq(lit(true))), col("c2"),); // c2 = false => !c2 - assert_eq!(rewriter.rewrite(col("c2").eq(lit(false)))?, col("c2").not(),); + assert_eq!(rewriter.rewrite(col("c2").eq(lit(false))), col("c2").not(),); Ok(()) } @@ -657,24 +662,24 @@ mod tests { // don't fold c1 = true assert_eq!( - rewriter.rewrite(col("c1").eq(lit(true)))?, + rewriter.rewrite(col("c1").eq(lit(true))), col("c1").eq(lit(true)), ); // don't fold c1 = false assert_eq!( - rewriter.rewrite(col("c1").eq(lit(false)))?, + rewriter.rewrite(col("c1").eq(lit(false))), col("c1").eq(lit(false)), ); // test constant operands assert_eq!( - rewriter.rewrite(lit(1).eq(lit(true)))?, + rewriter.rewrite(lit(1).eq(lit(true))), lit(1).eq(lit(true)), ); assert_eq!( - rewriter.rewrite(lit("a").eq(lit(false)))?, + rewriter.rewrite(lit("a").eq(lit(false))), lit("a").eq(lit(false)), ); @@ -693,17 +698,17 @@ mod tests { // c2 != true -> !c2 assert_eq!( - rewriter.rewrite(col("c2").not_eq(lit(true)))?, + rewriter.rewrite(col("c2").not_eq(lit(true))), col("c2").not(), ); // c2 != false -> c2 - assert_eq!(rewriter.rewrite(col("c2").not_eq(lit(false)))?, col("c2"),); + assert_eq!(rewriter.rewrite(col("c2").not_eq(lit(false))), col("c2"),); // test constant - assert_eq!(rewriter.rewrite(lit(true).not_eq(lit(true)))?, lit(false),); + assert_eq!(rewriter.rewrite(lit(true).not_eq(lit(true))), lit(false),); - assert_eq!(rewriter.rewrite(lit(true).not_eq(lit(false)))?, lit(true),); + assert_eq!(rewriter.rewrite(lit(true).not_eq(lit(false))), lit(true),); Ok(()) } @@ -721,23 +726,23 @@ mod tests { assert_eq!(col("c1").get_type(&schema)?, DataType::Utf8); assert_eq!( - rewriter.rewrite(col("c1").not_eq(lit(true)))?, + rewriter.rewrite(col("c1").not_eq(lit(true))), col("c1").not_eq(lit(true)), ); assert_eq!( - rewriter.rewrite(col("c1").not_eq(lit(false)))?, + rewriter.rewrite(col("c1").not_eq(lit(false))), col("c1").not_eq(lit(false)), ); // test constants assert_eq!( - rewriter.rewrite(lit(1).not_eq(lit(true)))?, + rewriter.rewrite(lit(1).not_eq(lit(true))), lit(1).not_eq(lit(true)), ); assert_eq!( - rewriter.rewrite(lit("a").not_eq(lit(false)))?, + rewriter.rewrite(lit("a").not_eq(lit(false))), lit("a").not_eq(lit(false)), ); @@ -760,7 +765,7 @@ mod tests { Box::new(lit("ok").eq(lit(true))), )], else_expr: Some(Box::new(col("c2").eq(lit(true)))), - })?, + }), Expr::Case { expr: None, when_then_expr: vec![( From 2ad8480e93a1f6efd498aa89edef633a968dc642 Mon Sep 17 00:00:00 2001 From: patrick Date: Wed, 13 Oct 2021 23:20:12 -0600 Subject: [PATCH 4/6] Fixed alias not being preserved --- datafusion/src/optimizer/constant_folding.rs | 102 +++++++------------ 1 file changed, 39 insertions(+), 63 deletions(-) diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs index a3726617c7c0c..7bd828e060870 100644 --- a/datafusion/src/optimizer/constant_folding.rs +++ b/datafusion/src/optimizer/constant_folding.rs @@ -61,15 +61,15 @@ impl OptimizerRule for ConstantFolding { execution_props: &ExecutionProps, ) -> Result { let mut rewriter = ConstantRewriter { - execution_props: &execution_props, + execution_props, schemas: plan.all_schemas(), }; match plan { LogicalPlan::Filter { predicate, input } => Ok(LogicalPlan::Filter { - predicate: rewriter.rewrite(predicate.clone()) , + predicate: rewriter.rewrite(predicate.clone()), input: match self.optimize(input, execution_props) { - Ok(plan) => Arc::new(plan.clone()), + Ok(plan) => Arc::new(plan), _ => input.clone(), }, }), @@ -122,12 +122,12 @@ pub fn evaluate(expr: &Expr) -> Result { return Ok(s.clone()); } //The dummy column name shouldn't really matter as only scalar expressions should be evaluated - static DUMMY_COL_NAME: &'static str = "."; + static DUMMY_COL_NAME: &str = "."; let dummy_df_schema = DFSchema::empty(); let dummy_input_schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Float64, true)]); let ctx_state = ExecutionContextState::new(); - + let planner = DefaultPhysicalPlanner::default(); let phys_expr = planner.create_physical_expr( expr, @@ -176,39 +176,37 @@ impl<'a> ConstantRewriter<'a> { }; let rewrite_root = self.rewrite_const_expr(&mut expr); - if rewrite_root{ + if rewrite_root { match evaluate(&expr) { - Ok(s) => expr= Expr::Literal(s), + Ok(s) => expr = Expr::Literal(s), Err(_) => return expr, } } match name { Some(name) => { - let existing_alias = match &expr{ - Expr::Alias(_, new_alias) => Some(new_alias), - _ => None + let existing_alias = match &expr { + Expr::Alias(_, new_alias) => Some(new_alias.as_str()), + _ => None, }; - let apply_new_alias = match existing_alias{ + let apply_new_alias = match existing_alias { Some(new) => *new != name, - None => false, + None => true, }; - if apply_new_alias{ + if apply_new_alias { expr = Expr::Alias(Box::new(expr), name); } expr - }, + } None => expr, } } - ///Evaluates all literal expressions in the list. fn const_fold_list_eager(&mut self, args: &mut Vec) { for arg in args.iter_mut() { if self.rewrite_const_expr(arg) { - match evaluate(arg) { - Ok(s) => *arg = Expr::Literal(s), - _ => (), + if let Ok(s) = evaluate(arg) { + *arg = Expr::Literal(s); } } } @@ -225,9 +223,8 @@ impl<'a> ConstantRewriter<'a> { } else { for (rewrite_expr, expr) in can_rewrite.iter().zip(args) { if *rewrite_expr { - match evaluate(expr) { - Ok(s) => *expr = Expr::Literal(s), - _ => (), + if let Ok(s) = evaluate(expr) { + *expr = Expr::Literal(s); } } } @@ -235,7 +232,7 @@ impl<'a> ConstantRewriter<'a> { false } ///This attempts to simplify expressions of the form col(Boolean) = Boolean and col(Boolean) != Boolean - /// e.g. col(Boolean) = Some(true) -> col(Boolean). It also handles == and != between two boolean literals as + /// e.g. col(Boolean) = Some(true) -> col(Boolean). It also handles == and != between two boolean literals as /// the binary operator physical expression currently doesn't handle them. fn binary_column_const_fold( @@ -271,7 +268,7 @@ impl<'a> ConstantRewriter<'a> { } (Expr::Literal(ScalarValue::Boolean(b)), Operator::Eq, col) | (col, Operator::Eq, Expr::Literal(ScalarValue::Boolean(b))) - if self.is_boolean_type(&col) => + if self.is_boolean_type(col) => { Some(match b { Some(true) => col.clone(), @@ -281,7 +278,7 @@ impl<'a> ConstantRewriter<'a> { } (Expr::Literal(ScalarValue::Boolean(b)), Operator::NotEq, col) | (col, Operator::NotEq, Expr::Literal(ScalarValue::Boolean(b))) - if self.is_boolean_type(&col) => + if self.is_boolean_type(col) => { Some(match b { Some(true) => Expr::Not(Box::new(col.clone())), @@ -309,22 +306,14 @@ impl<'a> ConstantRewriter<'a> { (true, true) => true, (false, false) => false, (true, false) => { - match evaluate(&left) { - Ok(s) => { - let left: &mut Expr = left; - *left = Expr::Literal(s); - } - Err(_) => (), + if let Ok(s) = evaluate(left) { + *left.as_mut() = Expr::Literal(s); } false } (false, true) => { - match evaluate(&right) { - Ok(s) => { - let right: &mut Expr = right; - *right = Expr::Literal(s); - } - Err(_) => (), + if let Ok(s) = evaluate(right) { + *right.as_mut() = Expr::Literal(s); } false } @@ -399,9 +388,8 @@ impl<'a> ConstantRewriter<'a> { .unwrap_or(false) { let expr_inner = expr.as_mut().unwrap(); - match evaluate(expr_inner) { - Ok(s) => *expr_inner.as_mut() = Expr::Literal(s), - Err(_) => (), + if let Ok(s) = evaluate(expr_inner) { + *expr_inner.as_mut() = Expr::Literal(s); } } @@ -411,9 +399,8 @@ impl<'a> ConstantRewriter<'a> { .unwrap_or(false) { let expr_inner = else_expr.as_mut().unwrap(); - match evaluate(expr_inner) { - Ok(s) => *expr_inner.as_mut() = Expr::Literal(s), - Err(_) => (), + if let Ok(s) = evaluate(expr_inner) { + *expr_inner.as_mut() = Expr::Literal(s); } } @@ -421,15 +408,13 @@ impl<'a> ConstantRewriter<'a> { let when: &mut Expr = when; let then: &mut Expr = then; if self.rewrite_const_expr(when) { - match evaluate(when) { - Ok(s) => *when = Expr::Literal(s), - _ => (), + if let Ok(s) = evaluate(when) { + *when = Expr::Literal(s); } } if self.rewrite_const_expr(then) { - match evaluate(then) { - Ok(s) => *then = Expr::Literal(s), - Err(_) => (), + if let Ok(s) = evaluate(then) { + *then = Expr::Literal(s); } } } @@ -439,12 +424,9 @@ impl<'a> ConstantRewriter<'a> { Expr::TryCast { expr, .. } => self.rewrite_const_expr(expr), Expr::Sort { expr, .. } => { if self.rewrite_const_expr(expr) { - match evaluate(expr) { - Ok(s) => { - let expr: &mut Expr = expr; - *expr = Expr::Literal(s); - } - Err(_) => (), + if let Ok(s) = evaluate(expr) { + let expr: &mut Expr = expr; + *expr = Expr::Literal(s); } } false @@ -529,7 +511,6 @@ impl<'a> ConstantRewriter<'a> { } Expr::Wildcard => false, }; - println!("Can rewrite the expr[{}]: {}", can_rewrite, expr); can_rewrite } } @@ -673,10 +654,7 @@ mod tests { ); // test constant operands - assert_eq!( - rewriter.rewrite(lit(1).eq(lit(true))), - lit(1).eq(lit(true)), - ); + assert_eq!(rewriter.rewrite(lit(1).eq(lit(true))), lit(1).eq(lit(true)),); assert_eq!( rewriter.rewrite(lit("a").eq(lit(false))), @@ -1189,10 +1167,8 @@ mod tests { .build() .unwrap(); let actual = get_optimized_plan_formatted(&plan, &time); - let expected = format!( - "Projection: Float64(4) * #test.e AS constant, pow_stable(Float64(2), Int32(2)) * #test.e AS stable, pow_vol(Float64(2), Int32(2)) * #test.e\ - \n TableScan: test projection=None", - ); + let expected = "Projection: Float64(4) * #test.e AS constant, pow_stable(Float64(2), Int32(2)) * #test.e AS stable, pow_vol(Float64(2), Int32(2)) * #test.e\ + \n TableScan: test projection=None".to_string(); assert_eq!(actual, expected); Ok(()) From a87e8c7fa20448719eb15233e02aa0b465bc2b32 Mon Sep 17 00:00:00 2001 From: patrick Date: Fri, 15 Oct 2021 17:58:30 -0600 Subject: [PATCH 5/6] Moved evaluate function --- datafusion/src/optimizer/constant_folding.rs | 55 ++----------------- .../src/physical_plan/expressions/mod.rs | 53 ++++++++++++++++++ 2 files changed, 59 insertions(+), 49 deletions(-) diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs index 7bd828e060870..a9166532ff0ff 100644 --- a/datafusion/src/optimizer/constant_folding.rs +++ b/datafusion/src/optimizer/constant_folding.rs @@ -19,17 +19,15 @@ use std::sync::Arc; -use arrow::array::Float64Array; -use arrow::datatypes::{DataType, Field, Schema}; -use arrow::record_batch::RecordBatch; +use arrow::datatypes::DataType; -use crate::error::{DataFusionError, Result}; -use crate::execution::context::{ExecutionContextState, ExecutionProps}; -use crate::logical_plan::{DFSchema, DFSchemaRef, Expr, LogicalPlan, Operator}; +use crate::error::Result; +use crate::execution::context::ExecutionProps; +use crate::logical_plan::{DFSchemaRef, Expr, LogicalPlan, Operator}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; +use crate::physical_plan::expressions::helpers::evaluate; use crate::physical_plan::functions::{BuiltinScalarFunction, Volatility}; -use crate::physical_plan::planner::DefaultPhysicalPlanner; use crate::scalar::ScalarValue; struct ConstantRewriter<'a> { @@ -116,48 +114,6 @@ impl OptimizerRule for ConstantFolding { } } -///Evaluate calculates the value of scalar expressions. This function may panic if columns are present within the expression -pub fn evaluate(expr: &Expr) -> Result { - if let Expr::Literal(s) = expr { - return Ok(s.clone()); - } - //The dummy column name shouldn't really matter as only scalar expressions should be evaluated - static DUMMY_COL_NAME: &str = "."; - let dummy_df_schema = DFSchema::empty(); - let dummy_input_schema = - Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Float64, true)]); - let ctx_state = ExecutionContextState::new(); - - let planner = DefaultPhysicalPlanner::default(); - let phys_expr = planner.create_physical_expr( - expr, - &dummy_df_schema, - &dummy_input_schema, - &ctx_state, - )?; - let col = { - let mut builder = Float64Array::builder(1); - builder.append_null()?; - builder.finish() - }; - let record_batch = - RecordBatch::try_new(Arc::new(dummy_input_schema), vec![Arc::new(col)])?; - let col_val = phys_expr.evaluate(&record_batch)?; - match col_val { - crate::physical_plan::ColumnarValue::Array(a) => { - if a.len() != 1 { - Err(DataFusionError::Execution(format!( - "Could not evaluate the expressison, found a result of length {}", - a.len() - ))) - } else { - Ok(ScalarValue::try_from_array(&a, 0)?) - } - } - crate::physical_plan::ColumnarValue::Scalar(s) => Ok(s), - } -} - impl<'a> ConstantRewriter<'a> { fn is_boolean_type(&self, expr: &Expr) -> bool { for schema in &self.schemas { @@ -525,6 +481,7 @@ mod tests { physical_plan::{functions::make_scalar_function, udf::ScalarUDF}, }; use arrow::array::{ArrayRef, Float64Array}; + use arrow::datatypes::{Field, Schema}; use chrono::{DateTime, Utc}; diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 9f7a6cc6b5fb3..3580ab531eb61 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -52,7 +52,60 @@ mod try_cast; /// Module with some convenient methods used in expression building pub mod helpers { + pub use super::min_max::{max, min}; + + use crate::error::{DataFusionError, Result}; + use crate::logical_plan::{DFSchema, Expr}; + use crate::scalar::ScalarValue; + use arrow::datatypes::{DataType, Field}; + use arrow::record_batch::RecordBatch; + ///Evaluate calculates the value of scalar expressions. This function may panic if non-constant expressions are present within the expression + pub fn evaluate(expr: &Expr) -> Result { + if let Expr::Literal(s) = expr { + return Ok(s.clone()); + } + //The dummy column name shouldn't really matter as only scalar expressions will be evaluated + static DUMMY_COL_NAME: &str = "."; + let dummy_df_schema = DFSchema::empty(); + let dummy_input_schema = arrow::datatypes::Schema::new(vec![Field::new( + DUMMY_COL_NAME, + DataType::Float64, + true, + )]); + let ctx_state = crate::execution::context::ExecutionContextState::new(); + + let planner = crate::physical_plan::planner::DefaultPhysicalPlanner::default(); + let phys_expr = planner.create_physical_expr( + expr, + &dummy_df_schema, + &dummy_input_schema, + &ctx_state, + )?; + let col = { + let mut builder = arrow::array::Float64Array::builder(1); + builder.append_null()?; + builder.finish() + }; + let record_batch = RecordBatch::try_new( + std::sync::Arc::new(dummy_input_schema), + vec![std::sync::Arc::new(col)], + )?; + let col_val = phys_expr.evaluate(&record_batch)?; + match col_val { + crate::physical_plan::ColumnarValue::Array(a) => { + if a.len() != 1 { + Err(DataFusionError::Execution(format!( + "Could not evaluate the expressison, found a result of length {}", + a.len() + ))) + } else { + Ok(ScalarValue::try_from_array(&a, 0)?) + } + } + crate::physical_plan::ColumnarValue::Scalar(s) => Ok(s), + } + } } pub use average::{avg_return_type, Avg, AvgAccumulator}; From 81972d53fece1c502029728c35717b0163226957 Mon Sep 17 00:00:00 2001 From: patrick Date: Sun, 17 Oct 2021 13:41:50 -0600 Subject: [PATCH 6/6] Moved and renamed evalute function. Added function to test if expr is literal --- datafusion/src/optimizer/constant_folding.rs | 32 ++--- datafusion/src/optimizer/utils.rs | 116 ++++++++++++++++++ .../src/physical_plan/expressions/mod.rs | 53 -------- 3 files changed, 132 insertions(+), 69 deletions(-) diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs index a9166532ff0ff..e1f36cdeee1c3 100644 --- a/datafusion/src/optimizer/constant_folding.rs +++ b/datafusion/src/optimizer/constant_folding.rs @@ -26,7 +26,7 @@ use crate::execution::context::ExecutionProps; use crate::logical_plan::{DFSchemaRef, Expr, LogicalPlan, Operator}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; -use crate::physical_plan::expressions::helpers::evaluate; +use crate::optimizer::utils::evaluate_const_expr_unchecked; use crate::physical_plan::functions::{BuiltinScalarFunction, Volatility}; use crate::scalar::ScalarValue; @@ -133,7 +133,7 @@ impl<'a> ConstantRewriter<'a> { let rewrite_root = self.rewrite_const_expr(&mut expr); if rewrite_root { - match evaluate(&expr) { + match evaluate_const_expr_unchecked(&expr) { Ok(s) => expr = Expr::Literal(s), Err(_) => return expr, } @@ -161,14 +161,14 @@ impl<'a> ConstantRewriter<'a> { fn const_fold_list_eager(&mut self, args: &mut Vec) { for arg in args.iter_mut() { if self.rewrite_const_expr(arg) { - if let Ok(s) = evaluate(arg) { + if let Ok(s) = evaluate_const_expr_unchecked(arg) { *arg = Expr::Literal(s); } } } } ///Tests to see if the list passed in is all literal expressions, if they are then it returns true. - ///If some expressions are not literal then the literal expressions are evaluated and it returns false. + ///If some expressions are not literal then the literal expressions are evaluate_const_expr_uncheckedd and it returns false. fn const_fold_list(&mut self, args: &mut Vec) -> bool { let can_rewrite = args .iter_mut() @@ -179,7 +179,7 @@ impl<'a> ConstantRewriter<'a> { } else { for (rewrite_expr, expr) in can_rewrite.iter().zip(args) { if *rewrite_expr { - if let Ok(s) = evaluate(expr) { + if let Ok(s) = evaluate_const_expr_unchecked(expr) { *expr = Expr::Literal(s); } } @@ -262,13 +262,13 @@ impl<'a> ConstantRewriter<'a> { (true, true) => true, (false, false) => false, (true, false) => { - if let Ok(s) = evaluate(left) { + if let Ok(s) = evaluate_const_expr_unchecked(left) { *left.as_mut() = Expr::Literal(s); } false } (false, true) => { - if let Ok(s) = evaluate(right) { + if let Ok(s) = evaluate_const_expr_unchecked(right) { *right.as_mut() = Expr::Literal(s); } false @@ -313,19 +313,19 @@ impl<'a> ConstantRewriter<'a> { (true, true, true) => true, (expr_const, low_const, high_const) => { if expr_const { - if let Ok(s) = evaluate(expr) { + if let Ok(s) = evaluate_const_expr_unchecked(expr) { let expr: &mut Expr = expr; *expr = Expr::Literal(s); } } if low_const { - if let Ok(s) = evaluate(expr) { + if let Ok(s) = evaluate_const_expr_unchecked(expr) { let expr: &mut Expr = expr; *expr = Expr::Literal(s); } } if high_const { - if let Ok(s) = evaluate(expr) { + if let Ok(s) = evaluate_const_expr_unchecked(expr) { let expr: &mut Expr = expr; *expr = Expr::Literal(s); } @@ -344,7 +344,7 @@ impl<'a> ConstantRewriter<'a> { .unwrap_or(false) { let expr_inner = expr.as_mut().unwrap(); - if let Ok(s) = evaluate(expr_inner) { + if let Ok(s) = evaluate_const_expr_unchecked(expr_inner) { *expr_inner.as_mut() = Expr::Literal(s); } } @@ -355,7 +355,7 @@ impl<'a> ConstantRewriter<'a> { .unwrap_or(false) { let expr_inner = else_expr.as_mut().unwrap(); - if let Ok(s) = evaluate(expr_inner) { + if let Ok(s) = evaluate_const_expr_unchecked(expr_inner) { *expr_inner.as_mut() = Expr::Literal(s); } } @@ -364,12 +364,12 @@ impl<'a> ConstantRewriter<'a> { let when: &mut Expr = when; let then: &mut Expr = then; if self.rewrite_const_expr(when) { - if let Ok(s) = evaluate(when) { + if let Ok(s) = evaluate_const_expr_unchecked(when) { *when = Expr::Literal(s); } } if self.rewrite_const_expr(then) { - if let Ok(s) = evaluate(then) { + if let Ok(s) = evaluate_const_expr_unchecked(then) { *then = Expr::Literal(s); } } @@ -380,7 +380,7 @@ impl<'a> ConstantRewriter<'a> { Expr::TryCast { expr, .. } => self.rewrite_const_expr(expr), Expr::Sort { expr, .. } => { if self.rewrite_const_expr(expr) { - if let Ok(s) = evaluate(expr) { + if let Ok(s) = evaluate_const_expr_unchecked(expr) { let expr: &mut Expr = expr; *expr = Expr::Literal(s); } @@ -453,7 +453,7 @@ impl<'a> ConstantRewriter<'a> { (false, false) => false, (true, false) => { - if let Ok(s) = evaluate(expr) { + if let Ok(s) = evaluate_const_expr_unchecked(expr) { let expr: &mut Expr = expr; *expr = Expr::Literal(s); } diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 6e64bf39b2e2d..2ccd903aa3375 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -23,6 +23,7 @@ use crate::logical_plan::{ build_join_schema, Column, DFSchemaRef, Expr, LogicalPlan, LogicalPlanBuilder, Operator, Partitioning, Recursion, }; +use crate::physical_plan::functions::Volatility; use crate::prelude::lit; use crate::scalar::ScalarValue; use crate::{ @@ -31,6 +32,10 @@ use crate::{ }; use std::{collections::HashSet, sync::Arc}; +use crate::logical_plan::DFSchema; +use arrow::datatypes::{DataType, Field}; +use arrow::record_batch::RecordBatch; + const CASE_EXPR_MARKER: &str = "__DATAFUSION_CASE_EXPR__"; const CASE_ELSE_MARKER: &str = "__DATAFUSION_CASE_ELSE__"; const WINDOW_PARTITION_MARKER: &str = "__DATAFUSION_WINDOW_PARTITION__"; @@ -468,6 +473,117 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { } } +///Tests an expression to see if it contains only literal expressions such as 3+4.5 and immutable scalar builtins or UDFs. +pub fn expr_is_const(expr: &Expr) -> bool { + match expr { + Expr::Column(_) + | Expr::ScalarVariable(_) + | Expr::AggregateFunction { .. } + | Expr::WindowFunction { .. } + | Expr::AggregateUDF { .. } + | Expr::Sort { .. } + | Expr::Wildcard => false, + + Expr::Literal(_) => true, + + Expr::Alias(child, _) + | Expr::Not(child) + | Expr::IsNotNull(child) + | Expr::IsNull(child) + | Expr::Negative(child) + | Expr::Cast { expr: child, .. } + | Expr::TryCast { expr: child, .. } => expr_is_const(child), + + Expr::ScalarFunction { fun, args } => match fun.volatility() { + Volatility::Immutable => args.iter().all(|arg| expr_is_const(arg)), + Volatility::Stable | Volatility::Volatile => false, + }, + Expr::BinaryExpr { left, right, .. } => { + expr_is_const(left) && expr_is_const(right) + } + Expr::Between { + expr, low, high, .. + } => expr_is_const(expr) && expr_is_const(low) && expr_is_const(high), + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + let expr_constant = expr.as_ref().map(|e| expr_is_const(e)).unwrap_or(true); + let else_constant = + else_expr.as_ref().map(|e| expr_is_const(e)).unwrap_or(true); + let when_then_constant = when_then_expr + .iter() + .all(|(w, th)| expr_is_const(w) && expr_is_const(th)); + expr_constant && else_constant && when_then_constant + } + + Expr::ScalarUDF { fun, args } => match fun.volatility() { + Volatility::Immutable => args.iter().all(|arg| expr_is_const(arg)), + Volatility::Stable | Volatility::Volatile => false, + }, + + Expr::InList { expr, list, .. } => { + expr_is_const(expr) && list.iter().all(|e| expr_is_const(e)) + } + } +} + +///Evaluates an expression if it only contains literal expressions. If a non-literal expression or non-immutable function is found then it returns an error. +pub fn evalute_const_expr(expr: &Expr) -> Result { + if !expr_is_const(expr) { + return Err(DataFusionError::Execution("The expression was not contsant and could not be evaluated. This means it contained stable or volatile functions, aggregates, or column expressions.".to_owned())); + } + evaluate_const_expr_unchecked(expr) +} + +///Evaluates an expression. Note that this will return incorrect results if the expression contains function calls which change return values call to call, such as Now(). +pub fn evaluate_const_expr_unchecked(expr: &Expr) -> Result { + if let Expr::Literal(s) = expr { + return Ok(s.clone()); + } + //The dummy column name shouldn't really matter as only scalar expressions will be evaluated + static DUMMY_COL_NAME: &str = "."; + let dummy_df_schema = DFSchema::empty(); + let dummy_input_schema = arrow::datatypes::Schema::new(vec![Field::new( + DUMMY_COL_NAME, + DataType::Float64, + true, + )]); + let ctx_state = crate::execution::context::ExecutionContextState::new(); + + let planner = crate::physical_plan::planner::DefaultPhysicalPlanner::default(); + let phys_expr = planner.create_physical_expr( + expr, + &dummy_df_schema, + &dummy_input_schema, + &ctx_state, + )?; + let col = { + let mut builder = arrow::array::Float64Array::builder(1); + builder.append_null()?; + builder.finish() + }; + let record_batch = RecordBatch::try_new( + std::sync::Arc::new(dummy_input_schema), + vec![std::sync::Arc::new(col)], + )?; + let col_val = phys_expr.evaluate(&record_batch)?; + match col_val { + crate::physical_plan::ColumnarValue::Array(a) => { + if a.len() != 1 { + Err(DataFusionError::Execution(format!( + "Could not evaluate the expressison, found a result of length {}", + a.len() + ))) + } else { + Ok(ScalarValue::try_from_array(&a, 0)?) + } + } + crate::physical_plan::ColumnarValue::Scalar(s) => Ok(s), + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 3580ab531eb61..9f7a6cc6b5fb3 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -52,60 +52,7 @@ mod try_cast; /// Module with some convenient methods used in expression building pub mod helpers { - pub use super::min_max::{max, min}; - - use crate::error::{DataFusionError, Result}; - use crate::logical_plan::{DFSchema, Expr}; - use crate::scalar::ScalarValue; - use arrow::datatypes::{DataType, Field}; - use arrow::record_batch::RecordBatch; - ///Evaluate calculates the value of scalar expressions. This function may panic if non-constant expressions are present within the expression - pub fn evaluate(expr: &Expr) -> Result { - if let Expr::Literal(s) = expr { - return Ok(s.clone()); - } - //The dummy column name shouldn't really matter as only scalar expressions will be evaluated - static DUMMY_COL_NAME: &str = "."; - let dummy_df_schema = DFSchema::empty(); - let dummy_input_schema = arrow::datatypes::Schema::new(vec![Field::new( - DUMMY_COL_NAME, - DataType::Float64, - true, - )]); - let ctx_state = crate::execution::context::ExecutionContextState::new(); - - let planner = crate::physical_plan::planner::DefaultPhysicalPlanner::default(); - let phys_expr = planner.create_physical_expr( - expr, - &dummy_df_schema, - &dummy_input_schema, - &ctx_state, - )?; - let col = { - let mut builder = arrow::array::Float64Array::builder(1); - builder.append_null()?; - builder.finish() - }; - let record_batch = RecordBatch::try_new( - std::sync::Arc::new(dummy_input_schema), - vec![std::sync::Arc::new(col)], - )?; - let col_val = phys_expr.evaluate(&record_batch)?; - match col_val { - crate::physical_plan::ColumnarValue::Array(a) => { - if a.len() != 1 { - Err(DataFusionError::Execution(format!( - "Could not evaluate the expressison, found a result of length {}", - a.len() - ))) - } else { - Ok(ScalarValue::try_from_array(&a, 0)?) - } - } - crate::physical_plan::ColumnarValue::Scalar(s) => Ok(s), - } - } } pub use average::{avg_return_type, Avg, AvgAccumulator};