From 42d99cac386f4e75c9c0c0126846bdfc1bab36bd Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 18 Oct 2021 16:07:10 -0400 Subject: [PATCH] Sketch of an alternate algorithm for partial evaluation --- datafusion/src/optimizer/utils.rs | 197 +++++++++++++++++++++++++++++- 1 file changed, 193 insertions(+), 4 deletions(-) diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 6e64bf39b2e2d..731d662fe80cf 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -17,12 +17,18 @@ //! Collection of utility functions that are leveraged by the query optimizer rules +use arrow::array::new_null_array; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; + use super::optimizer::OptimizerRule; -use crate::execution::context::ExecutionProps; +use crate::execution::context::{ExecutionContextState, ExecutionProps}; use crate::logical_plan::{ - build_join_schema, Column, DFSchemaRef, Expr, LogicalPlan, LogicalPlanBuilder, - Operator, Partitioning, Recursion, + build_join_schema, Column, DFSchema, DFSchemaRef, Expr, ExprRewriter, LogicalPlan, + LogicalPlanBuilder, Operator, Partitioning, Recursion, }; +use crate::physical_plan::functions::Volatility; +use crate::physical_plan::planner::DefaultPhysicalPlanner; use crate::prelude::lit; use crate::scalar::ScalarValue; use crate::{ @@ -468,10 +474,144 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { } } +/// Evaluates any sub expressions that are constants within `expr`. +/// +/// For example, will rewrite `'foo' != bar OR col1 = 'baz'` to `false +/// OR col1 = 'baz'` +pub fn partially_evaluate_expr(expr: Expr) -> Result { + let mut evaluator = ExprEvaluator::new(); + + expr.rewrite(&mut evaluator) +} + +struct ExprEvaluator { + /// can_evaluate[N] represents the state of traversal when we are + /// N levels deep in the tree. when mutate is called (after + /// visiting all siblings) if can_evauate.top() is true, means there were no non-constants for any siblings + /// no non-constant values found in either this Expr or any + can_evaluate: Vec, + + ctx_state: ExecutionContextState, + planner: DefaultPhysicalPlanner, + input_schema: DFSchema, + input_batch: RecordBatch, +} + +impl ExprRewriter for ExprEvaluator { + fn mutate(&mut self, expr: Expr) -> Result { + // check for reasons we can't evaluate this node + let self_ok_to_evaluate = match &expr { + Expr::Column(_) => false, + Expr::ScalarFunction { fun, .. } => Self::volatility_ok(fun.volatility()), + Expr::ScalarUDF { fun, .. } => Self::volatility_ok(fun.signature.volatility), + _ => true, + }; + + // if this expr is not ok to evaluate, mark entire parent stack as not ok + if !self_ok_to_evaluate { + // walk back up stack, marking first parent that is not mutable + let mut parent_iter = self.can_evaluate.iter_mut().rev(); + while let Some(p) = parent_iter.next() { + if !*p { + // optimization: if we find an element on the + // stack already marked, know all elements above are also marked + break; + } + *p = false; + } + } + + // pre_visit pushed, can pop here + let ok_to_evaluate = self.can_evaluate.pop().unwrap(); + + if ok_to_evaluate { + let scalar = self.evaluate_to_scalar(expr)?; + Ok(Expr::Literal(scalar)) + } else { + Ok(expr) + } + } + + fn pre_visit( + &mut self, + _expr: &Expr, + ) -> Result { + // Default to being able to evaluate this node + self.can_evaluate.push(true); + + Ok(crate::logical_plan::RewriteRecursion::Continue) + } +} + +impl ExprEvaluator { + pub fn new() -> Self { + let planner = DefaultPhysicalPlanner::default(); + let ctx_state = ExecutionContextState::new(); + let input_schema = DFSchema::empty(); + + //The dummy column name shouldn't really matter as only scalar expressions will be evaluated + static DUMMY_COL_NAME: &str = "."; + let schema = + Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Float64, true)]); + + let col = new_null_array(&DataType::Float64, 1); + + let input_batch = + RecordBatch::try_new(std::sync::Arc::new(schema), vec![col]).unwrap(); + + Self { + can_evaluate: vec![], + ctx_state, + planner, + input_schema, + input_batch, + } + } + + /// Can a fuction of the specified volatility be evaluated? + fn volatility_ok(volatility: Volatility) -> bool { + match volatility { + Volatility::Immutable => true, + Volatility::Stable => true, + Volatility::Volatile => false, + } + } + + fn evaluate_to_scalar(&self, expr: Expr) -> Result { + if let Expr::Literal(s) = expr { + return Ok(s.clone()); + } + + let phys_expr = self.planner.create_physical_expr( + &expr, + &self.input_schema, + &self.input_batch.schema(), + &self.ctx_state, + )?; + let col_val = phys_expr.evaluate(&self.input_batch)?; + match col_val { + crate::physical_plan::ColumnarValue::Array(a) => { + if a.len() != 1 { + Err(DataFusionError::Execution(format!( + "Could not evaluate the expressison, found a result of length {}", + a.len() + ))) + } else { + Ok(ScalarValue::try_from_array(&a, 0)?) + } + } + crate::physical_plan::ColumnarValue::Scalar(s) => Ok(s), + } + } +} + #[cfg(test)] mod tests { use super::*; - use crate::logical_plan::col; + use crate::{ + logical_plan::{col, lit_timestamp_nano}, + physical_plan::functions::BuiltinScalarFunction, + }; use arrow::datatypes::DataType; use std::collections::HashSet; @@ -496,4 +636,53 @@ mod tests { assert!(accum.contains(&Column::from_name("a"))); Ok(()) } + + #[test] + fn test_expr_evaluator() { + test_evaluate(lit(true), lit(true)); + test_evaluate(lit(true).or(lit(true)), lit(true)); + test_evaluate(lit(true).or(lit(false)), lit(true)); + + // "foo" == "foo" + test_evaluate(lit("foo").eq(lit("foo")), lit(true)); + // "foo" != "foo" + test_evaluate(lit("foo").not_eq(lit("foo")), lit(false)); + + // c = 1 + test_evaluate(col("c").eq(lit(1)), col("c").eq(lit(1))); + // c = 1 + 2 --> c + 3 + test_evaluate(col("c").eq(lit(1) + lit(2)), col("c").eq(lit(3))); + test_evaluate( + (lit("foo").not_eq(lit("foo"))).or(col("c").eq(lit(1))), + lit(false).or(col("c").eq(lit(1))), + ); + + // test function evaluation + let to_timestamp = Expr::ScalarFunction { + args: vec![lit("foo"), lit("bar")], + fun: BuiltinScalarFunction::Concat, + }; + test_evaluate(to_timestamp, lit("foobar")); + + // test function evaluation + let to_timestamp = Expr::ScalarFunction { + args: vec![lit("2020-09-08T12:00:00+00:00")], + fun: BuiltinScalarFunction::ToTimestamp, + }; + test_evaluate(to_timestamp, lit_timestamp_nano(1599566400000000000i64)); + + // TODO write some more tests for: + // to timestamp with col arguments + // now() + // volatile functions, etc (rand) + } + + fn test_evaluate(input_expr: Expr, expected_expr: Expr) { + let evaluated_expr = partially_evaluate_expr(input_expr.clone()).unwrap(); + assert_eq!( + evaluated_expr, expected_expr, + "Mismatch evaluating {}\n Expected:{}\n Got:{}", + input_expr, expected_expr, evaluated_expr + ); + } }