From 1299add777898f8909ff6f6130c2cabc161f9917 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 20 Jan 2021 10:59:24 -0500 Subject: [PATCH 1/2] ARROW-11330: [Rust][DataFusion] add ExpressionVisitor --- rust/datafusion/src/logical_plan/expr.rs | 130 ++++++++++++++++++++ rust/datafusion/src/logical_plan/mod.rs | 2 +- rust/datafusion/src/optimizer/utils.rs | 107 +++++++---------- rust/datafusion/src/sql/utils.rs | 145 +++++++---------------- 4 files changed, 219 insertions(+), 165 deletions(-) diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 59d6add3d71..cf5d3089e07 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -422,6 +422,136 @@ impl Expr { nulls_first, } } + + /// Performs a depth first depth first walk of an expression and + /// its children, calling `visitor.pre_visit` and + /// `visitor.post_visit`. + /// + /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to + /// separate expression algorithms from the structure of the + /// `Expr` tree and make it easier to add new types of expressions + /// and algorithms that walk the tree. + /// + /// For an expression tree such as + /// BinaryExpr (GT) + /// left: Column("foo") + /// right: Column("bar") + /// + /// The nodes are visited using the following order + /// ```text + /// pre_visit(BinaryExpr(GT)) + /// pre_visit(Column("foo")) + /// pre_visit(Column("bar")) + /// post_visit(Column("bar")) + /// post_visit(Column("bar")) + /// post_visit(BinaryExpr(GT)) + /// ``` + /// + /// If an Err result is returned, recursion is stopped immediately + /// + /// If `Recursion::Stop` is returned on a call to pre_visit, no + /// children of that expression are visited, nor is post_visit + /// called on that expression + /// + pub fn accept(&self, visitor: V) -> Result { + let visitor = match visitor.pre_visit(self)? { + Recursion::Continue(visitor) => visitor, + // If the recursion should stop, do not visit children + Recursion::Stop(visitor) => return Ok(visitor), + }; + + // recurse (and cover all expression types) + let visitor = match self { + Expr::Alias(expr, _) => expr.accept(visitor), + Expr::Column(..) => Ok(visitor), + Expr::ScalarVariable(..) => Ok(visitor), + Expr::Literal(..) => Ok(visitor), + Expr::BinaryExpr { left, right, .. } => { + let visitor = left.accept(visitor)?; + right.accept(visitor) + } + Expr::Not(expr) => expr.accept(visitor), + Expr::IsNotNull(expr) => expr.accept(visitor), + Expr::IsNull(expr) => expr.accept(visitor), + Expr::Negative(expr) => expr.accept(visitor), + Expr::Between { + expr, low, high, .. + } => { + let visitor = expr.accept(visitor)?; + let visitor = low.accept(visitor)?; + high.accept(visitor) + } + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + let visitor = if let Some(expr) = expr.as_ref() { + expr.accept(visitor) + } else { + Ok(visitor) + }?; + let visitor = when_then_expr.iter().try_fold( + visitor, + |visitor, (when, then)| { + let visitor = when.accept(visitor)?; + then.accept(visitor) + }, + )?; + if let Some(else_expr) = else_expr.as_ref() { + else_expr.accept(visitor) + } else { + Ok(visitor) + } + } + Expr::Cast { expr, .. } => expr.accept(visitor), + Expr::Sort { expr, .. } => expr.accept(visitor), + Expr::ScalarFunction { args, .. } => args + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor)), + Expr::ScalarUDF { args, .. } => args + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor)), + Expr::AggregateFunction { args, .. } => args + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor)), + Expr::AggregateUDF { args, .. } => args + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor)), + Expr::InList { expr, list, .. } => { + let visitor = expr.accept(visitor)?; + list.iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor)) + } + Expr::Wildcard => Ok(visitor), + }?; + + visitor.post_visit(self) + } +} + +/// Controls how the visitor recursion should proceed. +pub enum Recursion { + /// Attempt to visit all the children, recursively, of this expression. + Continue(V), + /// Do not visit the children of this expression, though the walk + /// of parents of this expression will not be affected + Stop(V), +} + +/// Encode the traversal of an expression tree. When passed to +/// `visit_expression`, `ExpressionVisitor::visit` is invoked +/// recursively on all nodes of an expression tree. See the comments +/// on `Expr::accept` for details on its use +pub trait ExpressionVisitor: Sized { + /// Invoked before any children of `expr` are visisted. + fn pre_visit(self, expr: &Expr) -> Result>; + + /// Invoked after all children of `expr` are visited. Default + /// implementation does nothing. + fn post_visit(self, _expr: &Expr) -> Result { + Ok(self) + } } pub struct CaseBuilder { diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 3d6d5817d21..cceb94794b5 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -38,7 +38,7 @@ pub use expr::{ count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, length, lit, ln, log10, log2, lower, ltrim, max, md5, min, or, round, rtrim, sha224, sha256, sha384, sha512, signum, sin, sqrt, sum, tan, trim, trunc, upper, - when, Expr, Literal, + when, Expr, ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index b9e67a43d7e..68143b1867f 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -22,12 +22,16 @@ use std::{collections::HashSet, sync::Arc}; use arrow::datatypes::Schema; use super::optimizer::OptimizerRule; -use crate::error::{DataFusionError, Result}; use crate::logical_plan::{ - Expr, LogicalPlan, Operator, Partitioning, PlanType, StringifiedPlan, ToDFSchema, + Expr, LogicalPlan, Operator, Partitioning, PlanType, Recursion, StringifiedPlan, + ToDFSchema, }; use crate::prelude::{col, lit}; use crate::scalar::ScalarValue; +use crate::{ + error::{DataFusionError, Result}, + logical_plan::ExpressionVisitor, +}; const CASE_EXPR_MARKER: &str = "__DATAFUSION_CASE_EXPR__"; const CASE_ELSE_MARKER: &str = "__DATAFUSION_CASE_ELSE__"; @@ -46,75 +50,48 @@ pub fn exprlist_to_column_names( /// Recursively walk an expression tree, collecting the unique set of column names /// referenced in the expression -pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet) -> Result<()> { - match expr { - Expr::Alias(expr, _) => expr_to_column_names(expr, accum), - Expr::Column(name) => { - accum.insert(name.clone()); - Ok(()) - } - Expr::ScalarVariable(var_names) => { - accum.insert(var_names.join(".")); - Ok(()) - } - Expr::Literal(_) => { - // not needed - Ok(()) - } - Expr::Not(e) => expr_to_column_names(e, accum), - Expr::Negative(e) => expr_to_column_names(e, accum), - Expr::IsNull(e) => expr_to_column_names(e, accum), - Expr::IsNotNull(e) => expr_to_column_names(e, accum), - Expr::BinaryExpr { left, right, .. } => { - expr_to_column_names(left, accum)?; - expr_to_column_names(right, accum)?; - Ok(()) - } - Expr::Case { - expr, - when_then_expr, - else_expr, - .. - } => { - if let Some(e) = expr { - expr_to_column_names(e, accum)?; - } - for (w, t) in when_then_expr { - expr_to_column_names(w, accum)?; - expr_to_column_names(t, accum)?; - } - if let Some(e) = else_expr { - expr_to_column_names(e, accum)? +struct ColumnNameVisitor<'a> { + accum: &'a mut HashSet, +} + +impl ExpressionVisitor for ColumnNameVisitor<'_> { + fn pre_visit(self, expr: &Expr) -> Result> { + match expr { + Expr::Column(name) => { + self.accum.insert(name.clone()); } - Ok(()) - } - Expr::Cast { expr, .. } => expr_to_column_names(expr, accum), - Expr::Sort { expr, .. } => expr_to_column_names(expr, accum), - Expr::AggregateFunction { args, .. } => exprlist_to_column_names(args, accum), - Expr::AggregateUDF { args, .. } => exprlist_to_column_names(args, accum), - Expr::ScalarFunction { args, .. } => exprlist_to_column_names(args, accum), - Expr::ScalarUDF { args, .. } => exprlist_to_column_names(args, accum), - Expr::Between { - expr, low, high, .. - } => { - expr_to_column_names(expr, accum)?; - expr_to_column_names(low, accum)?; - expr_to_column_names(high, accum)?; - Ok(()) - } - Expr::InList { expr, list, .. } => { - expr_to_column_names(expr, accum)?; - for list_expr in list { - expr_to_column_names(list_expr, accum)?; + Expr::ScalarVariable(var_names) => { + self.accum.insert(var_names.join(".")); } - Ok(()) + Expr::Alias(_, _) => {} + Expr::Literal(_) => {} + Expr::BinaryExpr { .. } => {} + Expr::Not(_) => {} + Expr::IsNotNull(_) => {} + Expr::IsNull(_) => {} + Expr::Negative(_) => {} + Expr::Between { .. } => {} + Expr::Case { .. } => {} + Expr::Cast { .. } => {} + Expr::Sort { .. } => {} + Expr::ScalarFunction { .. } => {} + Expr::ScalarUDF { .. } => {} + Expr::AggregateFunction { .. } => {} + Expr::AggregateUDF { .. } => {} + Expr::InList { .. } => {} + Expr::Wildcard => {} } - Expr::Wildcard => Err(DataFusionError::Internal( - "Wildcard expressions are not valid in a logical query plan".to_owned(), - )), + Ok(Recursion::Continue(self)) } } +/// Recursively walk an expression tree, collecting the unique set of column names +/// referenced in the expression +pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet) -> Result<()> { + expr.accept(ColumnNameVisitor { accum })?; + Ok(()) +} + /// Create a `LogicalPlan::Explain` node by running `optimizer` on the /// input plan and capturing the resulting plan string pub fn optimize_explain( diff --git a/rust/datafusion/src/sql/utils.rs b/rust/datafusion/src/sql/utils.rs index 976e2c574d9..34bd55df49d 100644 --- a/rust/datafusion/src/sql/utils.rs +++ b/rust/datafusion/src/sql/utils.rs @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::error::{DataFusionError, Result}; use crate::logical_plan::{DFSchema, Expr, LogicalPlan}; +use crate::{ + error::{DataFusionError, Result}, + logical_plan::{ExpressionVisitor, Recursion}, +}; /// Resolves an `Expr::Wildcard` to a collection of `Expr::Column`'s. pub(crate) fn expand_wildcard(expr: &Expr, schema: &DFSchema) -> Vec { @@ -66,113 +69,57 @@ where }) } -/// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the -/// provided test. The returned `Expr`'s are deduplicated and returned in order -/// of appearance (depth first). -fn find_exprs_in_expr(expr: &Expr, test_fn: &F) -> Vec +// Visitor that find expressions that match a particular predicate +struct Finder<'a, F> where F: Fn(&Expr) -> bool, { - let matched_exprs = if test_fn(expr) { - vec![expr.clone()] - } else { - match expr { - Expr::AggregateFunction { args, .. } => find_exprs_in_exprs(&args, test_fn), - Expr::AggregateUDF { args, .. } => find_exprs_in_exprs(&args, test_fn), - Expr::Alias(nested_expr, _) => { - find_exprs_in_expr(nested_expr.as_ref(), test_fn) - } - Expr::Between { - expr: nested_expr, - low, - high, - .. - } => { - let mut matches = vec![]; - matches.extend(find_exprs_in_expr(nested_expr.as_ref(), test_fn)); - matches.extend(find_exprs_in_expr(low.as_ref(), test_fn)); - matches.extend(find_exprs_in_expr(high.as_ref(), test_fn)); - matches - } - Expr::BinaryExpr { left, right, .. } => { - let mut matches = vec![]; - matches.extend(find_exprs_in_expr(left.as_ref(), test_fn)); - matches.extend(find_exprs_in_expr(right.as_ref(), test_fn)); - matches - } - Expr::InList { - expr: nested_expr, - list, - .. - } => { - let mut matches = vec![]; - matches.extend(find_exprs_in_expr(nested_expr.as_ref(), test_fn)); - matches.extend( - list.iter() - .flat_map(|expr| find_exprs_in_expr(expr, test_fn)) - .collect::>(), - ); - matches - } - Expr::Case { - expr: case_expr_opt, - when_then_expr, - else_expr: else_expr_opt, - } => { - let mut matches = vec![]; - - if let Some(case_expr) = case_expr_opt { - matches.extend(find_exprs_in_expr(case_expr.as_ref(), test_fn)); - } - - matches.extend( - when_then_expr - .iter() - .flat_map(|(a, b)| vec![a, b]) - .flat_map(|expr| find_exprs_in_expr(expr.as_ref(), test_fn)) - .collect::>(), - ); + test_fn: &'a F, + exprs: Vec, +} - if let Some(else_expr) = else_expr_opt { - matches.extend(find_exprs_in_expr(else_expr.as_ref(), test_fn)); - } +impl<'a, F> Finder<'a, F> +where + F: Fn(&Expr) -> bool, +{ + /// Create a new finder with the `test_fn` + fn new(test_fn: &'a F) -> Self { + Self { + test_fn, + exprs: Vec::new(), + } + } +} - matches - } - Expr::Cast { - expr: nested_expr, .. - } => find_exprs_in_expr(nested_expr.as_ref(), test_fn), - Expr::IsNotNull(nested_expr) => { - find_exprs_in_expr(nested_expr.as_ref(), test_fn) - } - Expr::IsNull(nested_expr) => { - find_exprs_in_expr(nested_expr.as_ref(), test_fn) - } - Expr::Negative(nested_expr) => { - find_exprs_in_expr(nested_expr.as_ref(), test_fn) +impl<'a, F> ExpressionVisitor for Finder<'a, F> +where + F: Fn(&Expr) -> bool, +{ + fn pre_visit(mut self, expr: &Expr) -> Result> { + if (self.test_fn)(expr) { + if !(self.exprs.contains(expr)) { + self.exprs.push(expr.clone()) } - Expr::Not(nested_expr) => find_exprs_in_expr(nested_expr.as_ref(), test_fn), - Expr::ScalarFunction { args, .. } => find_exprs_in_exprs(&args, test_fn), - Expr::ScalarUDF { args, .. } => find_exprs_in_exprs(&args, test_fn), - Expr::Sort { - expr: nested_expr, .. - } => find_exprs_in_expr(nested_expr.as_ref(), test_fn), - - // These expressions don't nest other expressions. - Expr::Column(_) - | Expr::Literal(_) - | Expr::ScalarVariable(_) - | Expr::Wildcard => vec![], + // stop recursing down this expr once we find a match + return Ok(Recursion::Stop(self)); } - }; - matched_exprs.into_iter().fold(vec![], |mut acc, expr| { - if !acc.contains(&expr) { - acc.push(expr) - } + Ok(Recursion::Continue(self)) + } +} - acc - }) +/// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the +/// provided test. The returned `Expr`'s are deduplicated and returned in order +/// of appearance (depth first). +fn find_exprs_in_expr(expr: &Expr, test_fn: &F) -> Vec +where + F: Fn(&Expr) -> bool, +{ + let Finder { exprs, .. } = expr + .accept(Finder::new(test_fn)) + // pre_visit always returns OK, so this will always too + .expect("no way to return error during recursion"); + exprs } /// Convert any `Expr` to an `Expr::Column`. From ae9fae851417157ba49531cdd2431fd2e167e158 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 20 Jan 2021 18:06:17 -0500 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: Jorge Leitao --- rust/datafusion/src/logical_plan/expr.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index cf5d3089e07..31898d3e695 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -423,8 +423,8 @@ impl Expr { } } - /// Performs a depth first depth first walk of an expression and - /// its children, calling `visitor.pre_visit` and + /// Performs a depth first walk of an expression and + /// its children, calling [`ExpressionVisitor::pre_visit`] and /// `visitor.post_visit`. /// /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to @@ -540,7 +540,7 @@ pub enum Recursion { } /// Encode the traversal of an expression tree. When passed to -/// `visit_expression`, `ExpressionVisitor::visit` is invoked +/// `Expr::accept`, `ExpressionVisitor::visit` is invoked /// recursively on all nodes of an expression tree. See the comments /// on `Expr::accept` for details on its use pub trait ExpressionVisitor: Sized {