Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions rust/datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,136 @@ impl Expr {
nulls_first,
}
}

/// Performs a depth first walk of an expression and
/// its children, calling [`ExpressionVisitor::pre_visit`] and
/// [`ExpressionVisitor::post_visit`].
///
/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to
/// separate expression algorithms from the structure of the
/// `Expr` tree and make it easier to add new types of expressions
/// and algorithms that walk the tree.
///
/// For an expression tree such as
/// ```text
/// BinaryExpr (GT)
///    left: Column("foo")
///    right: Column("bar")
/// ```
///
/// The nodes are visited using the following order
/// ```text
/// pre_visit(BinaryExpr(GT))
/// pre_visit(Column("foo"))
/// post_visit(Column("foo"))
/// pre_visit(Column("bar"))
/// post_visit(Column("bar"))
/// post_visit(BinaryExpr(GT))
/// ```
///
/// The visitor is taken by value, threaded through every call, and
/// handed back to the caller in the `Ok` result.
///
/// If an `Err` result is returned by the visitor, recursion is stopped
/// immediately and that error is returned.
///
/// If `Recursion::Stop` is returned on a call to `pre_visit`, no
/// children of that expression are visited, nor is `post_visit`
/// called on that expression.
pub fn accept<V: ExpressionVisitor>(&self, visitor: V) -> Result<V> {
    let visitor = match visitor.pre_visit(self)? {
        Recursion::Continue(visitor) => visitor,
        // If the recursion should stop, do not visit children (and
        // skip post_visit for this expression as well).
        Recursion::Stop(visitor) => return Ok(visitor),
    };

    // Recurse into the children. Every variant is still named explicitly
    // (no `_` arm) so that adding a new `Expr` variant fails to compile
    // here until its traversal is decided.
    let visitor = match self {
        // Leaf expressions: nothing to descend into.
        Expr::Column(..)
        | Expr::ScalarVariable(..)
        | Expr::Literal(..)
        | Expr::Wildcard => visitor,
        // Expressions wrapping exactly one child.
        Expr::Alias(expr, _)
        | Expr::Not(expr)
        | Expr::IsNotNull(expr)
        | Expr::IsNull(expr)
        | Expr::Negative(expr)
        | Expr::Cast { expr, .. }
        | Expr::Sort { expr, .. } => expr.accept(visitor)?,
        Expr::BinaryExpr { left, right, .. } => {
            let visitor = left.accept(visitor)?;
            right.accept(visitor)?
        }
        Expr::Between {
            expr, low, high, ..
        } => {
            let visitor = expr.accept(visitor)?;
            let visitor = low.accept(visitor)?;
            high.accept(visitor)?
        }
        Expr::Case {
            expr,
            when_then_expr,
            else_expr,
        } => {
            // Order: optional base expression, then each (when, then)
            // pair in sequence, then the optional else expression.
            let visitor = match expr {
                Some(base) => base.accept(visitor)?,
                None => visitor,
            };
            let visitor =
                when_then_expr
                    .iter()
                    .try_fold(visitor, |visitor, (when, then)| {
                        let visitor = when.accept(visitor)?;
                        then.accept(visitor)
                    })?;
            match else_expr {
                Some(otherwise) => otherwise.accept(visitor)?,
                None => visitor,
            }
        }
        // Function-like expressions: visit each argument in order.
        Expr::ScalarFunction { args, .. }
        | Expr::ScalarUDF { args, .. }
        | Expr::AggregateFunction { args, .. }
        | Expr::AggregateUDF { args, .. } => args
            .iter()
            .try_fold(visitor, |visitor, arg| arg.accept(visitor))?,
        Expr::InList { expr, list, .. } => {
            let visitor = expr.accept(visitor)?;
            list.iter()
                .try_fold(visitor, |visitor, item| item.accept(visitor))?
        }
    };

    visitor.post_visit(self)
}
}

/// Controls how the visitor recursion should proceed.
///
/// Returned from `ExpressionVisitor::pre_visit`; each variant hands the
/// visitor back so that `Expr::accept` can keep threading it through
/// the walk.
pub enum Recursion<V: ExpressionVisitor> {
/// Attempt to visit all the children, recursively, of this expression.
Continue(V),
/// Do not visit the children of this expression, and do not call
/// `post_visit` on it either; the walk of parents of this
/// expression will not be affected
Stop(V),
}

/// Encode the traversal of an expression tree. When passed to
/// `Expr::accept`, `ExpressionVisitor::pre_visit` and
/// `ExpressionVisitor::post_visit` are invoked recursively on the
/// nodes of an expression tree. See the comments
/// on `Expr::accept` for details on its use
///
/// Both methods take the visitor by value and return it, so state is
/// carried from one call to the next by moving the visitor along.
pub trait ExpressionVisitor: Sized {
/// Invoked before any children of `expr` are visited.
///
/// Return `Recursion::Continue` to descend into the children of
/// `expr`, or `Recursion::Stop` to skip them.
fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>>;

/// Invoked after all children of `expr` are visited. Default
/// implementation does nothing and returns the visitor unchanged.
fn post_visit(self, _expr: &Expr) -> Result<Self> {
Ok(self)
}
}

pub struct CaseBuilder {
Expand Down
2 changes: 1 addition & 1 deletion rust/datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub use expr::{
count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor,
in_list, length, lit, ln, log10, log2, lower, ltrim, max, md5, min, or, round, rtrim,
sha224, sha256, sha384, sha512, signum, sin, sqrt, sum, tan, trim, trunc, upper,
when, Expr, Literal,
when, Expr, ExpressionVisitor, Literal, Recursion,
};
pub use extension::UserDefinedLogicalNode;
pub use operators::Operator;
Expand Down
107 changes: 42 additions & 65 deletions rust/datafusion/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@ use std::{collections::HashSet, sync::Arc};
use arrow::datatypes::Schema;

use super::optimizer::OptimizerRule;
use crate::error::{DataFusionError, Result};
use crate::logical_plan::{
Expr, LogicalPlan, Operator, Partitioning, PlanType, StringifiedPlan, ToDFSchema,
Expr, LogicalPlan, Operator, Partitioning, PlanType, Recursion, StringifiedPlan,
ToDFSchema,
};
use crate::prelude::{col, lit};
use crate::scalar::ScalarValue;
use crate::{
error::{DataFusionError, Result},
logical_plan::ExpressionVisitor,
};

const CASE_EXPR_MARKER: &str = "__DATAFUSION_CASE_EXPR__";
const CASE_ELSE_MARKER: &str = "__DATAFUSION_CASE_ELSE__";
Expand All @@ -46,75 +50,48 @@ pub fn exprlist_to_column_names(

/// Recursively walk an expression tree, collecting the unique set of column names
/// referenced in the expression
pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet<String>) -> Result<()> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a pretty good example of the kind of repetition that can be removed using this visitor pattern.
Note that I still left all expr types enumerated so that anyone who adds a new Expr type needs to update this code, and (hopefully) think about whether they need to add special handling for that new expr type.

match expr {
Expr::Alias(expr, _) => expr_to_column_names(expr, accum),
Expr::Column(name) => {
accum.insert(name.clone());
Ok(())
}
Expr::ScalarVariable(var_names) => {
accum.insert(var_names.join("."));
Ok(())
}
Expr::Literal(_) => {
// not needed
Ok(())
}
Expr::Not(e) => expr_to_column_names(e, accum),
Expr::Negative(e) => expr_to_column_names(e, accum),
Expr::IsNull(e) => expr_to_column_names(e, accum),
Expr::IsNotNull(e) => expr_to_column_names(e, accum),
Expr::BinaryExpr { left, right, .. } => {
expr_to_column_names(left, accum)?;
expr_to_column_names(right, accum)?;
Ok(())
}
Expr::Case {
expr,
when_then_expr,
else_expr,
..
} => {
if let Some(e) = expr {
expr_to_column_names(e, accum)?;
}
for (w, t) in when_then_expr {
expr_to_column_names(w, accum)?;
expr_to_column_names(t, accum)?;
}
if let Some(e) = else_expr {
expr_to_column_names(e, accum)?
/// Expression visitor that records the names referenced by an
/// expression: column names, and scalar variable names joined with ".".
struct ColumnNameVisitor<'a> {
// Caller-owned set that receives every name found during the walk.
accum: &'a mut HashSet<String>,
}

impl ExpressionVisitor for ColumnNameVisitor<'_> {
fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> {
match expr {
Expr::Column(name) => {
self.accum.insert(name.clone());
}
Ok(())
}
Expr::Cast { expr, .. } => expr_to_column_names(expr, accum),
Expr::Sort { expr, .. } => expr_to_column_names(expr, accum),
Expr::AggregateFunction { args, .. } => exprlist_to_column_names(args, accum),
Expr::AggregateUDF { args, .. } => exprlist_to_column_names(args, accum),
Expr::ScalarFunction { args, .. } => exprlist_to_column_names(args, accum),
Expr::ScalarUDF { args, .. } => exprlist_to_column_names(args, accum),
Expr::Between {
expr, low, high, ..
} => {
expr_to_column_names(expr, accum)?;
expr_to_column_names(low, accum)?;
expr_to_column_names(high, accum)?;
Ok(())
}
Expr::InList { expr, list, .. } => {
expr_to_column_names(expr, accum)?;
for list_expr in list {
expr_to_column_names(list_expr, accum)?;
Expr::ScalarVariable(var_names) => {
self.accum.insert(var_names.join("."));
}
Ok(())
Expr::Alias(_, _) => {}
Expr::Literal(_) => {}
Expr::BinaryExpr { .. } => {}
Expr::Not(_) => {}
Expr::IsNotNull(_) => {}
Expr::IsNull(_) => {}
Expr::Negative(_) => {}
Expr::Between { .. } => {}
Expr::Case { .. } => {}
Expr::Cast { .. } => {}
Expr::Sort { .. } => {}
Expr::ScalarFunction { .. } => {}
Expr::ScalarUDF { .. } => {}
Expr::AggregateFunction { .. } => {}
Expr::AggregateUDF { .. } => {}
Expr::InList { .. } => {}
Expr::Wildcard => {}
}
Expr::Wildcard => Err(DataFusionError::Internal(
"Wildcard expressions are not valid in a logical query plan".to_owned(),
)),
Ok(Recursion::Continue(self))
}
}

/// Collects into `accum` the unique set of column names referenced
/// anywhere in `expr`, by walking the expression tree with a
/// `ColumnNameVisitor`.
///
/// Any error produced during the walk is returned unchanged.
pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet<String>) -> Result<()> {
    // The visitor only exists to carry the borrow of `accum`; once the
    // walk has finished it can be discarded.
    expr.accept(ColumnNameVisitor { accum }).map(|_| ())
}

/// Create a `LogicalPlan::Explain` node by running `optimizer` on the
/// input plan and capturing the resulting plan string
pub fn optimize_explain(
Expand Down
Loading