From c853c6f127fa4de19d41df6f1c44ecf5a08e3fff Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 20 Feb 2021 05:34:07 -0500 Subject: [PATCH 1/2] ARROW-11710: [Rust][DataFusion] Implement ExpressionRewriter --- rust/datafusion/src/logical_plan/expr.rs | 259 +++++++++++ rust/datafusion/src/logical_plan/mod.rs | 1 + .../src/optimizer/constant_folding.rs | 401 +++++++----------- 3 files changed, 424 insertions(+), 237 deletions(-) diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 6dadefea548..775ab64ac14 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -563,6 +563,179 @@ impl Expr { visitor.post_visit(self) } + + /// Performs a depth first walk of an expression and its children + /// to rewrite an expression, consuming `self` producing a new + /// [`Expr`]. + /// + /// Implements a modified version of the [visitor + /// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to + /// separate algorithms from the structure of the `Expr` tree and + /// make it easier to write new, efficient expression + /// transformation algorithms. + /// + /// For an expression tree such as + /// ```text + /// BinaryExpr (GT) + /// left: Column("foo") + /// right: Column("bar") + /// ``` + /// + /// The nodes are visited using the following order + /// ```text + /// pre_visit(BinaryExpr(GT)) + /// pre_visit(Column("foo")) + /// mutate(Column("foo")) + /// pre_visit(Column("bar")) + /// mutate(Column("bar")) + /// mutate(BinaryExpr(GT)) + /// ``` + /// + /// If an Err result is returned, recursion is stopped immediately + /// + /// If [`false`] is returned on a call to pre_visit, no + /// children of that expression are visited, nor is mutate + /// called on that expression + /// + pub fn rewrite(self, rewriter: &mut R) -> Result + where + R: ExprRewriter, + { + if !rewriter.pre_visit(&self)? 
{ + return Ok(self); + }; + + // recurse into all sub expressions(and cover all expression types) + let expr = match self { + Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name), + Expr::Column(name) => Expr::Column(name), + Expr::ScalarVariable(names) => Expr::ScalarVariable(names), + Expr::Literal(value) => Expr::Literal(value), + Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr { + left: rewrite_boxed(left, rewriter)?, + op, + right: rewrite_boxed(right, rewriter)?, + }, + Expr::Not(expr) => Expr::Not(rewrite_boxed(expr, rewriter)?), + Expr::IsNotNull(expr) => Expr::IsNotNull(rewrite_boxed(expr, rewriter)?), + Expr::IsNull(expr) => Expr::IsNull(rewrite_boxed(expr, rewriter)?), + Expr::Negative(expr) => Expr::Negative(rewrite_boxed(expr, rewriter)?), + Expr::Between { + expr, + low, + high, + negated, + } => Expr::Between { + expr: rewrite_boxed(expr, rewriter)?, + low: rewrite_boxed(low, rewriter)?, + high: rewrite_boxed(high, rewriter)?, + negated, + }, + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + let expr = rewrite_option_box(expr, rewriter)?; + let when_then_expr = when_then_expr + .into_iter() + .map(|(when, then)| { + Ok(( + rewrite_boxed(when, rewriter)?, + rewrite_boxed(then, rewriter)?, + )) + }) + .collect::>>()?; + + let else_expr = rewrite_option_box(else_expr, rewriter)?; + + Expr::Case { + expr, + when_then_expr, + else_expr, + } + } + Expr::Cast { expr, data_type } => Expr::Cast { + expr: rewrite_boxed(expr, rewriter)?, + data_type, + }, + Expr::Sort { + expr, + asc, + nulls_first, + } => Expr::Sort { + expr: rewrite_boxed(expr, rewriter)?, + asc, + nulls_first, + }, + Expr::ScalarFunction { args, fun } => Expr::ScalarFunction { + args: rewrite_vec(args, rewriter)?, + fun, + }, + Expr::ScalarUDF { args, fun } => Expr::ScalarUDF { + args: rewrite_vec(args, rewriter)?, + fun, + }, + Expr::AggregateFunction { + args, + fun, + distinct, + } => Expr::AggregateFunction { + args: rewrite_vec(args, 
rewriter)?, + fun, + distinct, + }, + Expr::AggregateUDF { args, fun } => Expr::AggregateUDF { + args: rewrite_vec(args, rewriter)?, + fun, + }, + Expr::InList { + expr, + list, + negated, + } => Expr::InList { + expr: rewrite_boxed(expr, rewriter)?, + list, + negated, + }, + Expr::Wildcard => Expr::Wildcard, + }; + + // now rewrite this expression itself + rewriter.mutate(expr) + } +} + +#[allow(clippy::boxed_local)] +fn rewrite_boxed(boxed_expr: Box, rewriter: &mut R) -> Result> +where + R: ExprRewriter, +{ + // TODO: It might be possible to avoid an allocation (the + // Box::new) below by reusing the box. + let expr: Expr = *boxed_expr; + let rewritten_expr = expr.rewrite(rewriter)?; + Ok(Box::new(rewritten_expr)) +} + +fn rewrite_option_box( + option_box: Option>, + rewriter: &mut R, +) -> Result>> +where + R: ExprRewriter, +{ + option_box + .map(|expr| rewrite_boxed(expr, rewriter)) + .transpose() +} + +/// rewrite a `Vec` of `Expr`s with the rewriter +fn rewrite_vec(v: Vec, rewriter: &mut R) -> Result> +where + R: ExprRewriter, +{ + v.into_iter().map(|expr| expr.rewrite(rewriter)).collect() } /// Controls how the visitor recursion should proceed. @@ -589,6 +762,22 @@ pub trait ExpressionVisitor: Sized { } } +/// Trait for potentially recursively rewriting an [`Expr`] expression +/// tree. When passed to `Expr::rewrite`, `ExprRewriter::mutate` is +/// invoked recursively on all nodes of an expression tree. See the +/// comments on `Expr::rewrite` for details on its use +pub trait ExprRewriter: Sized { + /// Invoked before any children of `expr` are rewritten / + /// visited. Default implementation returns `Ok(true)` + fn pre_visit(&mut self, _expr: &Expr) -> Result { + Ok(true) + } + + /// Invoked after all children of `expr` have been mutated and + /// returns a potentially modified expr. 
+ fn mutate(&mut self, expr: Expr) -> Result; +} + pub struct CaseBuilder { expr: Option>, when_expr: Vec, @@ -1180,4 +1369,74 @@ mod tests { .end(); assert!(maybe_expr.is_err()); } + + #[test] + fn rewriter_visit() { + let mut rewriter = RecordingRewriter::default(); + col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap(); + + assert_eq!( + rewriter.v, + vec![ + "Previsited #state Eq Utf8(\"CO\")", + "Previsited #state", + "Mutated #state", + "Previsited Utf8(\"CO\")", + "Mutated Utf8(\"CO\")", + "Mutated #state Eq Utf8(\"CO\")" + ] + ) + } + + #[derive(Default)] + struct RecordingRewriter { + v: Vec, + } + impl ExprRewriter for RecordingRewriter { + fn mutate(&mut self, expr: Expr) -> Result { + self.v.push(format!("Mutated {:?}", expr)); + Ok(expr) + } + + fn pre_visit(&mut self, expr: &Expr) -> Result { + self.v.push(format!("Previsited {:?}", expr)); + Ok(true) + } + } + + #[test] + fn rewriter_rewrite() { + let mut rewriter = FooBarRewriter {}; + + // rewrites "foo" --> "bar" + let rewritten = col("state").eq(lit("foo")).rewrite(&mut rewriter).unwrap(); + assert_eq!(rewritten, col("state").eq(lit("bar"))); + + // doesn't rewrite + let rewritten = col("state").eq(lit("baz")).rewrite(&mut rewriter).unwrap(); + assert_eq!(rewritten, col("state").eq(lit("baz"))); + } + + /// rewrites all "foo" string literals to "bar" + struct FooBarRewriter {} + impl ExprRewriter for FooBarRewriter { + fn mutate(&mut self, expr: Expr) -> Result { + match expr { + Expr::Literal(scalar) => { + if let ScalarValue::Utf8(Some(utf8_val)) = scalar { + let utf8_val = if utf8_val == "foo" { + "bar".to_string() + } else { + utf8_val + }; + Ok(lit(utf8_val)) + } else { + Ok(Expr::Literal(scalar)) + } + } + // otherwise, return the expression unchanged + expr => Ok(expr), + } + } + } } diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 99c35fafd54..2c902a000b2 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ 
b/rust/datafusion/src/logical_plan/mod.rs @@ -39,6 +39,7 @@ pub use expr::{ length, lit, ln, log10, log2, lower, ltrim, max, md5, min, octet_length, or, round, rtrim, sha224, sha256, sha384, sha512, signum, sin, sqrt, substr, sum, tan, trim, trunc, upper, when, Expr, ExpressionVisitor, Literal, Recursion, + ExprRewriter }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/optimizer/constant_folding.rs b/rust/datafusion/src/optimizer/constant_folding.rs index 86cadf6405e..409b2c7733a 100644 --- a/rust/datafusion/src/optimizer/constant_folding.rs +++ b/rust/datafusion/src/optimizer/constant_folding.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use arrow::datatypes::DataType; use crate::error::Result; -use crate::logical_plan::{DFSchemaRef, Expr, LogicalPlan, Operator}; +use crate::logical_plan::{DFSchemaRef, Expr, ExprRewriter, LogicalPlan, Operator}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; use crate::scalar::ScalarValue; @@ -53,10 +53,13 @@ impl OptimizerRule for ConstantFolding { // projected columns. With just the projected schema, it's not possible to infer types for // expressions that references non-projected columns within the same project plan or its // children plans. 
+ let mut rewriter = ConstantRewriter { + schemas: plan.all_schemas(), + }; match plan { LogicalPlan::Filter { predicate, input } => Ok(LogicalPlan::Filter { - predicate: optimize_expr(predicate, &plan.all_schemas())?, + predicate: predicate.clone().rewrite(&mut rewriter)?, input: Arc::new(self.optimize(input)?), }), // Rest: recurse into plan, apply optimization where possible @@ -76,10 +79,9 @@ impl OptimizerRule for ConstantFolding { .map(|plan| self.optimize(plan)) .collect::>>()?; - let schemas = plan.all_schemas(); let expr = utils::expressions(plan) - .iter() - .map(|e| optimize_expr(e, &schemas)) + .into_iter() + .map(|e| e.rewrite(&mut rewriter)) .collect::>>()?; utils::from_plan(plan, &expr, &new_inputs) @@ -95,208 +97,122 @@ impl OptimizerRule for ConstantFolding { } } -fn is_boolean_type(expr: &Expr, schemas: &[&DFSchemaRef]) -> bool { - for schema in schemas { - if let Ok(DataType::Boolean) = expr.get_type(schema) { - return true; +struct ConstantRewriter<'a> { + /// input schemas + schemas: Vec<&'a DFSchemaRef>, +} + +impl<'a> ConstantRewriter<'a> { + fn is_boolean_type(&self, expr: &Expr) -> bool { + for schema in &self.schemas { + if let Ok(DataType::Boolean) = expr.get_type(schema) { + return true; + } } - } - false + false + } } -/// Recursively transverses the expression tree. 
-fn optimize_expr(e: &Expr, schemas: &[&DFSchemaRef]) -> Result { - Ok(match e { - Expr::BinaryExpr { left, op, right } => { - let left = optimize_expr(left, schemas)?; - let right = optimize_expr(right, schemas)?; - match op { - Operator::Eq => match (&left, &right) { - ( - Expr::Literal(ScalarValue::Boolean(l)), - Expr::Literal(ScalarValue::Boolean(r)), - ) => match (l, r) { - (Some(l), Some(r)) => { - Expr::Literal(ScalarValue::Boolean(Some(l == r))) - } - _ => Expr::Literal(ScalarValue::Boolean(None)), - }, - (Expr::Literal(ScalarValue::Boolean(b)), _) - if is_boolean_type(&right, schemas) => - { - match b { - Some(true) => right, - Some(false) => Expr::Not(Box::new(right)), - None => Expr::Literal(ScalarValue::Boolean(None)), +impl<'a> ExprRewriter for ConstantRewriter<'a> { + /// rewrite the expression simplifying any constant expressions + fn mutate(&mut self, expr: Expr) -> Result { + let new_expr = match expr { + Expr::BinaryExpr { left, op, right } => { + let left = left.rewrite(self)?; + let right = right.rewrite(self)?; + match op { + Operator::Eq => match (&left, &right) { + ( + Expr::Literal(ScalarValue::Boolean(l)), + Expr::Literal(ScalarValue::Boolean(r)), + ) => match (l, r) { + (Some(l), Some(r)) => { + Expr::Literal(ScalarValue::Boolean(Some(l == r))) + } + _ => Expr::Literal(ScalarValue::Boolean(None)), + }, + (Expr::Literal(ScalarValue::Boolean(b)), _) + if self.is_boolean_type(&right) => + { + match b { + Some(true) => right, + Some(false) => Expr::Not(Box::new(right)), + None => Expr::Literal(ScalarValue::Boolean(None)), + } } - } - (_, Expr::Literal(ScalarValue::Boolean(b))) - if is_boolean_type(&left, schemas) => - { - match b { - Some(true) => left, - Some(false) => Expr::Not(Box::new(left)), - None => Expr::Literal(ScalarValue::Boolean(None)), + (_, Expr::Literal(ScalarValue::Boolean(b))) + if self.is_boolean_type(&left) => + { + match b { + Some(true) => left, + Some(false) => Expr::Not(Box::new(left)), + None => 
Expr::Literal(ScalarValue::Boolean(None)), + } } - } - _ => Expr::BinaryExpr { - left: Box::new(left), - op: Operator::Eq, - right: Box::new(right), - }, - }, - Operator::NotEq => match (&left, &right) { - ( - Expr::Literal(ScalarValue::Boolean(l)), - Expr::Literal(ScalarValue::Boolean(r)), - ) => match (l, r) { - (Some(l), Some(r)) => { - Expr::Literal(ScalarValue::Boolean(Some(l != r))) - } - _ => Expr::Literal(ScalarValue::Boolean(None)), + _ => Expr::BinaryExpr { + left: Box::new(left), + op: Operator::Eq, + right: Box::new(right), + }, }, - (Expr::Literal(ScalarValue::Boolean(b)), _) - if is_boolean_type(&right, schemas) => - { - match b { - Some(true) => Expr::Not(Box::new(right)), - Some(false) => right, - None => Expr::Literal(ScalarValue::Boolean(None)), + Operator::NotEq => match (&left, &right) { + ( + Expr::Literal(ScalarValue::Boolean(l)), + Expr::Literal(ScalarValue::Boolean(r)), + ) => match (l, r) { + (Some(l), Some(r)) => { + Expr::Literal(ScalarValue::Boolean(Some(l != r))) + } + _ => Expr::Literal(ScalarValue::Boolean(None)), + }, + (Expr::Literal(ScalarValue::Boolean(b)), _) + if self.is_boolean_type(&right) => + { + match b { + Some(true) => Expr::Not(Box::new(right)), + Some(false) => right, + None => Expr::Literal(ScalarValue::Boolean(None)), + } } - } - (_, Expr::Literal(ScalarValue::Boolean(b))) - if is_boolean_type(&left, schemas) => - { - match b { - Some(true) => Expr::Not(Box::new(left)), - Some(false) => left, - None => Expr::Literal(ScalarValue::Boolean(None)), + (_, Expr::Literal(ScalarValue::Boolean(b))) + if self.is_boolean_type(&left) => + { + match b { + Some(true) => Expr::Not(Box::new(left)), + Some(false) => left, + None => Expr::Literal(ScalarValue::Boolean(None)), + } } - } + _ => Expr::BinaryExpr { + left: Box::new(left), + op: Operator::NotEq, + right: Box::new(right), + }, + }, _ => Expr::BinaryExpr { left: Box::new(left), - op: Operator::NotEq, + op, right: Box::new(right), }, - }, - _ => Expr::BinaryExpr { - left: 
Box::new(left), - op: *op, - right: Box::new(right), - }, + } } - } - Expr::Not(expr) => match &**expr { - Expr::Not(inner) => optimize_expr(&inner, schemas)?, - _ => Expr::Not(Box::new(optimize_expr(&expr, schemas)?)), - }, - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { - // recurse into CASE WHEN condition expressions - Expr::Case { - expr: match expr { - Some(e) => Some(Box::new(optimize_expr(e, schemas)?)), - None => None, - }, - when_then_expr: when_then_expr - .iter() - .map(|(when, then)| { - Ok(( - Box::new(optimize_expr(when, schemas)?), - Box::new(optimize_expr(then, schemas)?), - )) - }) - .collect::>()?, - else_expr: match else_expr { - Some(e) => Some(Box::new(optimize_expr(e, schemas)?)), - None => None, - }, + Expr::Not(inner) => { + // Not(Not(expr)) --> expr + let inner = inner.rewrite(self)?; + if let Expr::Not(negated_inner) = inner { + *negated_inner + } else { + Expr::Not(Box::new(inner)) + } } - } - Expr::Alias(expr, name) => { - Expr::Alias(Box::new(optimize_expr(expr, schemas)?), name.clone()) - } - Expr::Negative(expr) => Expr::Negative(Box::new(optimize_expr(expr, schemas)?)), - Expr::InList { - expr, - list, - negated, - } => Expr::InList { - expr: Box::new(optimize_expr(expr, schemas)?), - list: list - .iter() - .map(|e| optimize_expr(e, schemas)) - .collect::>()?, - negated: *negated, - }, - Expr::IsNotNull(expr) => Expr::IsNotNull(Box::new(optimize_expr(expr, schemas)?)), - Expr::IsNull(expr) => Expr::IsNull(Box::new(optimize_expr(expr, schemas)?)), - Expr::Cast { expr, data_type } => Expr::Cast { - expr: Box::new(optimize_expr(expr, schemas)?), - data_type: data_type.clone(), - }, - Expr::Between { - expr, - negated, - low, - high, - } => Expr::Between { - expr: Box::new(optimize_expr(expr, schemas)?), - negated: *negated, - low: Box::new(optimize_expr(low, schemas)?), - high: Box::new(optimize_expr(high, schemas)?), - }, - Expr::ScalarFunction { fun, args } => Expr::ScalarFunction { - fun: fun.clone(), - args: args - 
.iter() - .map(|e| optimize_expr(e, schemas)) - .collect::>()?, - }, - Expr::ScalarUDF { fun, args } => Expr::ScalarUDF { - fun: fun.clone(), - args: args - .iter() - .map(|e| optimize_expr(e, schemas)) - .collect::>()?, - }, - Expr::AggregateFunction { - fun, - args, - distinct, - } => Expr::AggregateFunction { - fun: fun.clone(), - args: args - .iter() - .map(|e| optimize_expr(e, schemas)) - .collect::>()?, - distinct: *distinct, - }, - Expr::AggregateUDF { fun, args } => Expr::AggregateUDF { - fun: fun.clone(), - args: args - .iter() - .map(|e| optimize_expr(e, schemas)) - .collect::>()?, - }, - Expr::Sort { - expr, - asc, - nulls_first, - } => Expr::Sort { - expr: Box::new(optimize_expr(expr, schemas)?), - asc: *asc, - nulls_first: *nulls_first, - }, - Expr::Column { .. } - | Expr::ScalarVariable { .. } - | Expr::Literal { .. } - | Expr::Wildcard => e.clone(), - }) + expr => { + // no rewrite possible + expr + } + }; + Ok(new_expr) + } } #[cfg(test)] @@ -331,8 +247,12 @@ mod tests { #[test] fn optimize_expr_not_not() -> Result<()> { let schema = expr_test_schema(); + let mut rewriter = ConstantRewriter { + schemas: vec![&schema], + }; + assert_eq!( - optimize_expr(&col("c2").not().not().not(), &[&schema])?, + (col("c2").not().not().not()).rewrite(&mut rewriter)?, col("c2").not(), ); @@ -342,34 +262,32 @@ mod tests { #[test] fn optimize_expr_null_comparision() -> Result<()> { let schema = expr_test_schema(); + let mut rewriter = ConstantRewriter { + schemas: vec![&schema], + }; // x = null is always null assert_eq!( - optimize_expr(&lit(true).eq(lit(ScalarValue::Boolean(None))), &[&schema])?, + (lit(true).eq(lit(ScalarValue::Boolean(None)))).rewrite(&mut rewriter)?, lit(ScalarValue::Boolean(None)), ); // null != null is always null assert_eq!( - optimize_expr( - &lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None))), - &[&schema], - )?, + (lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None)))) + .rewrite(&mut rewriter)?, 
lit(ScalarValue::Boolean(None)), ); // x != null is always null assert_eq!( - optimize_expr( - &col("c2").not_eq(lit(ScalarValue::Boolean(None))), - &[&schema], - )?, + (col("c2").not_eq(lit(ScalarValue::Boolean(None)))).rewrite(&mut rewriter)?, lit(ScalarValue::Boolean(None)), ); // null = x is always null assert_eq!( - optimize_expr(&lit(ScalarValue::Boolean(None)).eq(col("c2")), &[&schema])?, + (lit(ScalarValue::Boolean(None)).eq(col("c2"))).rewrite(&mut rewriter)?, lit(ScalarValue::Boolean(None)), ); @@ -379,29 +297,27 @@ mod tests { #[test] fn optimize_expr_eq() -> Result<()> { let schema = expr_test_schema(); + let mut rewriter = ConstantRewriter { + schemas: vec![&schema], + }; + assert_eq!(col("c2").get_type(&schema)?, DataType::Boolean); // true = ture -> true - assert_eq!( - optimize_expr(&lit(true).eq(lit(true)), &[&schema])?, - lit(true), - ); + assert_eq!((lit(true).eq(lit(true))).rewrite(&mut rewriter)?, lit(true),); // true = false -> false assert_eq!( - optimize_expr(&lit(true).eq(lit(false)), &[&schema])?, + (lit(true).eq(lit(false))).rewrite(&mut rewriter)?, lit(false), ); // c2 = true -> c2 - assert_eq!( - optimize_expr(&col("c2").eq(lit(true)), &[&schema])?, - col("c2"), - ); + assert_eq!((col("c2").eq(lit(true))).rewrite(&mut rewriter)?, col("c2"),); // c2 = false => !c2 assert_eq!( - optimize_expr(&col("c2").eq(lit(false)), &[&schema])?, + (col("c2").eq(lit(false))).rewrite(&mut rewriter)?, col("c2").not(), ); @@ -411,6 +327,9 @@ mod tests { #[test] fn optimize_expr_eq_skip_nonboolean_type() -> Result<()> { let schema = expr_test_schema(); + let mut rewriter = ConstantRewriter { + schemas: vec![&schema], + }; // When one of the operand is not of boolean type, folding the other boolean constant will // change return type of expression to non-boolean. 
@@ -420,24 +339,24 @@ mod tests { // don't fold c1 = true assert_eq!( - optimize_expr(&col("c1").eq(lit(true)), &[&schema])?, + (col("c1").eq(lit(true))).rewrite(&mut rewriter)?, col("c1").eq(lit(true)), ); // don't fold c1 = false assert_eq!( - optimize_expr(&col("c1").eq(lit(false)), &[&schema],)?, + (col("c1").eq(lit(false))).rewrite(&mut rewriter)?, col("c1").eq(lit(false)), ); // test constant operands assert_eq!( - optimize_expr(&lit(1).eq(lit(true)), &[&schema],)?, + (lit(1).eq(lit(true))).rewrite(&mut rewriter)?, lit(1).eq(lit(true)), ); assert_eq!( - optimize_expr(&lit("a").eq(lit(false)), &[&schema],)?, + (lit("a").eq(lit(false))).rewrite(&mut rewriter)?, lit("a").eq(lit(false)), ); @@ -447,28 +366,32 @@ mod tests { #[test] fn optimize_expr_not_eq() -> Result<()> { let schema = expr_test_schema(); + let mut rewriter = ConstantRewriter { + schemas: vec![&schema], + }; + assert_eq!(col("c2").get_type(&schema)?, DataType::Boolean); // c2 != true -> !c2 assert_eq!( - optimize_expr(&col("c2").not_eq(lit(true)), &[&schema])?, + (col("c2").not_eq(lit(true))).rewrite(&mut rewriter)?, col("c2").not(), ); // c2 != false -> c2 assert_eq!( - optimize_expr(&col("c2").not_eq(lit(false)), &[&schema])?, + (col("c2").not_eq(lit(false))).rewrite(&mut rewriter)?, col("c2"), ); // test constant assert_eq!( - optimize_expr(&lit(true).not_eq(lit(true)), &[&schema])?, + (lit(true).not_eq(lit(true))).rewrite(&mut rewriter)?, lit(false), ); assert_eq!( - optimize_expr(&lit(true).not_eq(lit(false)), &[&schema])?, + (lit(true).not_eq(lit(false))).rewrite(&mut rewriter)?, lit(true), ); @@ -478,29 +401,32 @@ mod tests { #[test] fn optimize_expr_not_eq_skip_nonboolean_type() -> Result<()> { let schema = expr_test_schema(); + let mut rewriter = ConstantRewriter { + schemas: vec![&schema], + }; // when one of the operand is not of boolean type, folding the other boolean constant will // change return type of expression to non-boolean. 
assert_eq!(col("c1").get_type(&schema)?, DataType::Utf8); assert_eq!( - optimize_expr(&col("c1").not_eq(lit(true)), &[&schema])?, + (col("c1").not_eq(lit(true))).rewrite(&mut rewriter)?, col("c1").not_eq(lit(true)), ); assert_eq!( - optimize_expr(&col("c1").not_eq(lit(false)), &[&schema])?, + (col("c1").not_eq(lit(false))).rewrite(&mut rewriter)?, col("c1").not_eq(lit(false)), ); // test constants assert_eq!( - optimize_expr(&lit(1).not_eq(lit(true)), &[&schema])?, + (lit(1).not_eq(lit(true))).rewrite(&mut rewriter)?, lit(1).not_eq(lit(true)), ); assert_eq!( - optimize_expr(&lit("a").not_eq(lit(false)), &[&schema],)?, + (lit("a").not_eq(lit(false))).rewrite(&mut rewriter)?, lit("a").not_eq(lit(false)), ); @@ -510,19 +436,20 @@ mod tests { #[test] fn optimize_expr_case_when_then_else() -> Result<()> { let schema = expr_test_schema(); + let mut rewriter = ConstantRewriter { + schemas: vec![&schema], + }; assert_eq!( - optimize_expr( - &Box::new(Expr::Case { - expr: None, - when_then_expr: vec![( - Box::new(col("c2").not_eq(lit(false))), - Box::new(lit("ok").eq(lit(true))), - )], - else_expr: Some(Box::new(col("c2").eq(lit(true)))), - }), - &[&schema], - )?, + (Box::new(Expr::Case { + expr: None, + when_then_expr: vec![( + Box::new(col("c2").not_eq(lit(false))), + Box::new(lit("ok").eq(lit(true))), + )], + else_expr: Some(Box::new(col("c2").eq(lit(true)))), + })) + .rewrite(&mut rewriter)?, Expr::Case { expr: None, when_then_expr: vec![( @@ -627,7 +554,7 @@ mod tests { let expected = "\ Projection: #a\ - \n Filter: NOT NOT #b\ + \n Filter: #b\ \n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); From 67d35b79dde420fa689ce80a521b0e34f2b23ae1 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 24 Feb 2021 18:20:01 -0500 Subject: [PATCH 2/2] Implement code review suggestions; avoid redundant rewrites --- rust/datafusion/src/logical_plan/mod.rs | 3 +- .../src/optimizer/constant_folding.rs | 141 ++++++++---------- 2 files changed, 67 
insertions(+), 77 deletions(-) diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 2c902a000b2..90c35dc3a23 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -38,8 +38,7 @@ pub use expr::{ count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, length, lit, ln, log10, log2, lower, ltrim, max, md5, min, octet_length, or, round, rtrim, sha224, sha256, sha384, sha512, signum, sin, sqrt, substr, sum, tan, trim, - trunc, upper, when, Expr, ExpressionVisitor, Literal, Recursion, - ExprRewriter + trunc, upper, when, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/optimizer/constant_folding.rs b/rust/datafusion/src/optimizer/constant_folding.rs index 409b2c7733a..62f5ee30c62 100644 --- a/rust/datafusion/src/optimizer/constant_folding.rs +++ b/rust/datafusion/src/optimizer/constant_folding.rs @@ -118,92 +118,83 @@ impl<'a> ExprRewriter for ConstantRewriter<'a> { /// rewrite the expression simplifying any constant expressions fn mutate(&mut self, expr: Expr) -> Result { let new_expr = match expr { - Expr::BinaryExpr { left, op, right } => { - let left = left.rewrite(self)?; - let right = right.rewrite(self)?; - match op { - Operator::Eq => match (&left, &right) { - ( - Expr::Literal(ScalarValue::Boolean(l)), - Expr::Literal(ScalarValue::Boolean(r)), - ) => match (l, r) { - (Some(l), Some(r)) => { - Expr::Literal(ScalarValue::Boolean(Some(l == r))) - } - _ => Expr::Literal(ScalarValue::Boolean(None)), - }, - (Expr::Literal(ScalarValue::Boolean(b)), _) - if self.is_boolean_type(&right) => - { - match b { - Some(true) => right, - Some(false) => Expr::Not(Box::new(right)), - None => Expr::Literal(ScalarValue::Boolean(None)), - } + Expr::BinaryExpr { left, op, right } => match op { + Operator::Eq => match (left.as_ref(), right.as_ref()) { 
+ ( + Expr::Literal(ScalarValue::Boolean(l)), + Expr::Literal(ScalarValue::Boolean(r)), + ) => match (l, r) { + (Some(l), Some(r)) => { + Expr::Literal(ScalarValue::Boolean(Some(l == r))) } - (_, Expr::Literal(ScalarValue::Boolean(b))) - if self.is_boolean_type(&left) => - { - match b { - Some(true) => left, - Some(false) => Expr::Not(Box::new(left)), - None => Expr::Literal(ScalarValue::Boolean(None)), - } - } - _ => Expr::BinaryExpr { - left: Box::new(left), - op: Operator::Eq, - right: Box::new(right), - }, + _ => Expr::Literal(ScalarValue::Boolean(None)), }, - Operator::NotEq => match (&left, &right) { - ( - Expr::Literal(ScalarValue::Boolean(l)), - Expr::Literal(ScalarValue::Boolean(r)), - ) => match (l, r) { - (Some(l), Some(r)) => { - Expr::Literal(ScalarValue::Boolean(Some(l != r))) - } - _ => Expr::Literal(ScalarValue::Boolean(None)), - }, - (Expr::Literal(ScalarValue::Boolean(b)), _) - if self.is_boolean_type(&right) => - { - match b { - Some(true) => Expr::Not(Box::new(right)), - Some(false) => right, - None => Expr::Literal(ScalarValue::Boolean(None)), - } + (Expr::Literal(ScalarValue::Boolean(b)), _) + if self.is_boolean_type(&right) => + { + match b { + Some(true) => *right, + Some(false) => Expr::Not(right), + None => Expr::Literal(ScalarValue::Boolean(None)), + } + } + (_, Expr::Literal(ScalarValue::Boolean(b))) + if self.is_boolean_type(&left) => + { + match b { + Some(true) => *left, + Some(false) => Expr::Not(left), + None => Expr::Literal(ScalarValue::Boolean(None)), } - (_, Expr::Literal(ScalarValue::Boolean(b))) - if self.is_boolean_type(&left) => - { - match b { - Some(true) => Expr::Not(Box::new(left)), - Some(false) => left, - None => Expr::Literal(ScalarValue::Boolean(None)), - } + } + _ => Expr::BinaryExpr { + left, + op: Operator::Eq, + right, + }, + }, + Operator::NotEq => match (left.as_ref(), right.as_ref()) { + ( + Expr::Literal(ScalarValue::Boolean(l)), + Expr::Literal(ScalarValue::Boolean(r)), + ) => match (l, r) { + (Some(l), 
Some(r)) => { + Expr::Literal(ScalarValue::Boolean(Some(l != r))) } - _ => Expr::BinaryExpr { - left: Box::new(left), - op: Operator::NotEq, - right: Box::new(right), - }, + _ => Expr::Literal(ScalarValue::Boolean(None)), }, + (Expr::Literal(ScalarValue::Boolean(b)), _) + if self.is_boolean_type(&right) => + { + match b { + Some(true) => Expr::Not(right), + Some(false) => *right, + None => Expr::Literal(ScalarValue::Boolean(None)), + } + } + (_, Expr::Literal(ScalarValue::Boolean(b))) + if self.is_boolean_type(&left) => + { + match b { + Some(true) => Expr::Not(left), + Some(false) => *left, + None => Expr::Literal(ScalarValue::Boolean(None)), + } + } _ => Expr::BinaryExpr { - left: Box::new(left), - op, - right: Box::new(right), + left, + op: Operator::NotEq, + right, }, - } - } + }, + _ => Expr::BinaryExpr { left, op, right }, + }, Expr::Not(inner) => { // Not(Not(expr)) --> expr - let inner = inner.rewrite(self)?; - if let Expr::Not(negated_inner) = inner { + if let Expr::Not(negated_inner) = *inner { *negated_inner } else { - Expr::Not(Box::new(inner)) + Expr::Not(inner) } } expr => {