From c55c64da3604b5b968d454f73a415b9021577df2 Mon Sep 17 00:00:00 2001 From: jackwener Date: Thu, 24 Nov 2022 22:12:29 +0800 Subject: [PATCH 1/2] reimplement `push_down_filter` --- datafusion/optimizer/src/lib.rs | 2 +- datafusion/optimizer/src/optimizer.rs | 5 +- ...ilter_push_down.rs => push_down_filter.rs} | 1119 ++++++++--------- 3 files changed, 499 insertions(+), 627 deletions(-) rename datafusion/optimizer/src/{filter_push_down.rs => push_down_filter.rs} (73%) diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index cdfe7fc9b31d1..aba53a3d8a2f3 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -23,12 +23,12 @@ pub mod eliminate_filter; pub mod eliminate_limit; pub mod eliminate_outer_join; pub mod filter_null_join_keys; -pub mod filter_push_down; pub mod inline_table_scan; pub mod limit_push_down; pub mod optimizer; pub mod projection_push_down; pub mod propagate_empty_relation; +pub mod push_down_filter; pub mod scalar_subquery_to_join; pub mod simplify_expressions; pub mod single_distinct_to_groupby; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index afa391d416969..e1b42f9e6b00d 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -25,11 +25,11 @@ use crate::eliminate_filter::EliminateFilter; use crate::eliminate_limit::EliminateLimit; use crate::eliminate_outer_join::EliminateOuterJoin; use crate::filter_null_join_keys::FilterNullJoinKeys; -use crate::filter_push_down::FilterPushDown; use crate::inline_table_scan::InlineTableScan; use crate::limit_push_down::LimitPushDown; use crate::projection_push_down::ProjectionPushDown; use crate::propagate_empty_relation::PropagateEmptyRelation; +use crate::push_down_filter::PushDownFilter; use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; use crate::simplify_expressions::SimplifyExpressions; @@ -184,8 +184,9 @@ impl Optimizer { rules.push(Arc::new(FilterNullJoinKeys::default())); } rules.push(Arc::new(EliminateOuterJoin::new())); - rules.push(Arc::new(FilterPushDown::new())); + // Filter can't pushdown Limit, we should do PushDownFilter after LimitPushDown rules.push(Arc::new(LimitPushDown::new())); + rules.push(Arc::new(PushDownFilter::new())); rules.push(Arc::new(SingleDistinctToGroupBy::new())); // The previous optimizations added expressions and projections, diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/push_down_filter.rs similarity index 73% rename from datafusion/optimizer/src/filter_push_down.rs rename to datafusion/optimizer/src/push_down_filter.rs index 2f8a8a8b4d881..edde57fa3d14a 100644 --- a/datafusion/optimizer/src/filter_push_down.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -12,26 +12,25 @@ // specific language governing permissions and limitations // under the License. -//! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan +//! Push Down Filter optimizer rule ensures that filters are applied as early as possible in the plan +use crate::utils::conjunction; use crate::{utils, OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DFSchema, DataFusionError, Result}; +use datafusion_expr::utils::exprlist_to_columns; use datafusion_expr::{ - and, col, - expr::BinaryExpr, + and, expr_rewriter::{replace_col, ExprRewritable, ExprRewriter}, - logical_plan::{ - Aggregate, CrossJoin, Join, JoinType, Limit, LogicalPlan, Projection, TableScan, - Union, - }, + logical_plan::{CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union}, or, - utils::{expr_to_columns, exprlist_to_columns, from_plan}, - Expr, Operator, TableProviderFilterPushDown, + utils::from_plan, + BinaryExpr, Expr, Filter, Operator, TableProviderFilterPushDown, }; use std::collections::{HashMap, HashSet}; use std::iter::once; +use std::sync::Arc; -/// Filter Push Down optimizer rule pushes filter clauses down the plan +/// Push Down Filter optimizer rule pushes filter clauses down the plan /// # Introduction /// A filter-commutative operation is an operation whose result of filter(op(data)) = op(filter(data)). /// An example of a filter-commutative operation is a projection; a counter-example is `limit`. @@ -57,96 +56,7 @@ use std::iter::once; /// When it passes through a projection, it re-writes the filter's expression taking into account that projection. /// When multiple filters would have been written, it `AND` their expressions into a single expression. #[derive(Default)] -pub struct FilterPushDown {} - -/// Filter predicate represented by tuple of expression and its columns -type Predicate = (Expr, HashSet); - -/// Multiple filter predicates represented by tuple of expressions vector -/// and corresponding expression columns vector -type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet>); - -#[derive(Debug, Clone, Default)] -struct State { - // (predicate, columns on the predicate) - filters: Vec, -} - -impl State { - fn append_predicates(&mut self, predicates: Predicates) { - predicates - .0 - .into_iter() - .zip(predicates.1) - .for_each(|(expr, cols)| self.filters.push((expr.clone(), cols.clone()))) - } -} - -/// returns all predicates in `state` that depend on any of `used_columns` -/// or the ones that does not reference any columns (e.g. WHERE 1=1) -fn get_predicates<'a>( - state: &'a State, - used_columns: &HashSet, -) -> Predicates<'a> { - state - .filters - .iter() - .filter(|(_, columns)| { - columns.is_empty() - || !columns - .intersection(used_columns) - .collect::>() - .is_empty() - }) - .map(|&(ref a, ref b)| (a, b)) - .unzip() -} - -/// Optimizes the plan -fn push_down(state: &State, plan: &LogicalPlan) -> Result { - let new_inputs = plan - .inputs() - .iter() - .map(|input| optimize(input, state.clone())) - .collect::>>()?; - - let expr = plan.expressions(); - from_plan(plan, &expr, &new_inputs) -} - -// remove all filters from `filters` that are in `predicate_columns` -fn remove_filters( - filters: &[Predicate], - predicate_columns: &[&HashSet], -) -> Vec { - filters - .iter() - .filter(|(_, columns)| !predicate_columns.contains(&columns)) - .cloned() - .collect::>() -} - -/// builds a new [LogicalPlan] from `plan` by issuing new [LogicalPlan::Filter] if any of the filters -/// in `state` depend on the columns `used_columns`. -fn issue_filters( - mut state: State, - used_columns: HashSet, - plan: &LogicalPlan, -) -> Result { - let (predicates, predicate_columns) = get_predicates(&state, &used_columns); - - if predicates.is_empty() { - // all filters can be pushed down => optimize inputs and return new plan - return push_down(&state, plan); - } - - let plan = utils::add_filter(plan.clone(), &predicates)?; - - state.filters = remove_filters(&state.filters, &predicate_columns); - - // continue optimization over all input nodes by cloning the current state (i.e. each node is independent) - push_down(&state, &plan) -} +pub struct PushDownFilter {} // For a given JOIN logical plan, determine whether each side of the join is preserved. // We say a join side is preserved if the join returns all or a subset of the rows from @@ -220,15 +130,7 @@ fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> { // or not the side's rows are preserved when joining. If the side is not preserved, we // do not push down anything. Otherwise we can push down predicates where all of the // relevant columns are contained on the relevant join side's schema. -fn get_pushable_join_predicates<'a>( - filters: &'a [Predicate], - schema: &DFSchema, - preserved: bool, -) -> Predicates<'a> { - if !preserved { - return (vec![], vec![]); - } - +fn can_pushdown_join_predicate(predicate: &Expr, schema: &DFSchema) -> Result { let schema_columns = schema .fields() .iter() @@ -240,19 +142,13 @@ fn get_pushable_join_predicates<'a>( ] }) .collect::>(); + let columns = predicate.to_columns()?; - filters - .iter() - .filter(|(_, columns)| { - let all_columns_in_schema = schema_columns - .intersection(columns) - .collect::>() - .len() - == columns.len(); - all_columns_in_schema - }) - .map(|(a, b)| (a, b)) - .unzip() + Ok(schema_columns + .intersection(&columns) + .collect::>() + .len() + == columns.len()) } // examine OR clause to see if any useful clauses can be extracted and push down. @@ -292,9 +188,9 @@ fn extract_or_clauses_for_join( filters: &[&Expr], schema: &DFSchema, preserved: bool, -) -> (Vec, Vec>) { +) -> Vec { if !preserved { - return (vec![], vec![]); + return vec![]; } let schema_columns = schema @@ -310,7 +206,6 @@ fn extract_or_clauses_for_join( .collect::>(); let mut exprs = vec![]; - let mut expr_columns = vec![]; for expr in filters.iter() { if let Expr::BinaryExpr(BinaryExpr { left, @@ -323,17 +218,13 @@ fn extract_or_clauses_for_join( // If nothing can be extracted from any sub clauses, do nothing for this OR clause. if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) { - let predicate = or(left_expr, right_expr); - let columns = predicate.to_columns().ok().unwrap(); - - exprs.push(predicate); - expr_columns.push(columns); + exprs.push(or(left_expr, right_expr)); } } } // new formed OR clauses and their column references - (exprs, expr_columns) + exprs } // extract qual from OR sub-clause. @@ -403,94 +294,90 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet) -> Option, plan: &LogicalPlan, left: &LogicalPlan, right: &LogicalPlan, - on_filter: Vec, + on_filter: Vec, ) -> Result { + let on_filter_empty = on_filter.is_empty(); // Get pushable predicates from current optimizer state let (left_preserved, right_preserved) = lr_is_preserved(plan)?; - let to_left = - get_pushable_join_predicates(&state.filters, left.schema(), left_preserved); - let to_right = - get_pushable_join_predicates(&state.filters, right.schema(), right_preserved); - let to_keep: Predicates = state - .filters - .iter() - .filter(|(e, _)| !to_left.0.contains(&e) && !to_right.0.contains(&e)) - .map(|(a, b)| (a, b)) - .unzip(); + let mut left_push = vec![]; + let mut right_push = vec![]; + + let mut keep_predicates = vec![]; + for predicate in predicates { + if left_preserved && can_pushdown_join_predicate(&predicate, left.schema())? { + left_push.push(predicate); + } else if right_preserved + && can_pushdown_join_predicate(&predicate, right.schema())? + { + right_push.push(predicate); + } else { + keep_predicates.push(predicate); + } + } - // Get pushable predicates from join filter - let (on_to_left, on_to_right, on_to_keep) = if on_filter.is_empty() { - ((vec![], vec![]), (vec![], vec![]), vec![]) - } else { + let mut keep_condition = vec![]; + if !on_filter.is_empty() { let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(plan)?; - let on_to_left = - get_pushable_join_predicates(&on_filter, left.schema(), on_left_preserved); - let on_to_right = - get_pushable_join_predicates(&on_filter, right.schema(), on_right_preserved); - let on_to_keep = on_filter - .iter() - .filter(|(e, _)| !on_to_left.0.contains(&e) && !on_to_right.0.contains(&e)) - .map(|(a, _)| a.clone()) - .collect::>(); - - (on_to_left, on_to_right, on_to_keep) - }; + for on in on_filter { + if on_left_preserved && can_pushdown_join_predicate(&on, left.schema())? { + left_push.push(on) + } else if on_right_preserved + && can_pushdown_join_predicate(&on, right.schema())? + { + right_push.push(on) + } else { + keep_condition.push(on) + } + } + } // Extract from OR clause, generate new predicates for both side of join if possible. // We only track the unpushable predicates above. - let or_to_left = - extract_or_clauses_for_join(&to_keep.0, left.schema(), left_preserved); - let or_to_right = - extract_or_clauses_for_join(&to_keep.0, right.schema(), right_preserved); + // TODO: we just get, but don't remove them from origin expr. + let or_to_left = extract_or_clauses_for_join( + &keep_predicates.iter().collect::>(), + left.schema(), + left_preserved, + ); + let or_to_right = extract_or_clauses_for_join( + &keep_predicates.iter().collect::>(), + right.schema(), + right_preserved, + ); let on_or_to_left = extract_or_clauses_for_join( - &on_to_keep.iter().collect::>(), + &keep_condition.iter().collect::>(), left.schema(), left_preserved, ); let on_or_to_right = extract_or_clauses_for_join( - &on_to_keep.iter().collect::>(), + &keep_condition.iter().collect::>(), right.schema(), right_preserved, ); - // Build new filter states using pushable predicates - // from current optimizer states and from ON clause. - // Then recursively call optimization for both join inputs - let mut left_state = State::default(); - left_state.append_predicates(to_left); - left_state.append_predicates(on_to_left); - or_to_left - .0 - .into_iter() - .zip(or_to_left.1) - .for_each(|(expr, cols)| left_state.filters.push((expr, cols))); - on_or_to_left - .0 - .into_iter() - .zip(on_or_to_left.1) - .for_each(|(expr, cols)| left_state.filters.push((expr, cols))); - let left = optimize(left, left_state)?; - - let mut right_state = State::default(); - right_state.append_predicates(to_right); - right_state.append_predicates(on_to_right); - or_to_right - .0 - .into_iter() - .zip(or_to_right.1) - .for_each(|(expr, cols)| right_state.filters.push((expr, cols))); - on_or_to_right - .0 - .into_iter() - .zip(on_or_to_right.1) - .for_each(|(expr, cols)| right_state.filters.push((expr, cols))); - let right = optimize(right, right_state)?; + left_push.extend(or_to_left); + left_push.extend(on_or_to_left); + right_push.extend(or_to_right); + right_push.extend(on_or_to_right); + let left = match conjunction(left_push) { + Some(predicate) => { + LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left.clone()))?) + } + None => left.clone(), + }; + let right = match conjunction(right_push) { + Some(predicate) => { + LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(right.clone()))?) + } + None => right.clone(), + }; // Create a new Join with the new `left` and `right` // // expressions() output for Join is a vector consisting of @@ -500,302 +387,336 @@ fn optimize_join( // vector will contain only join keys (without additional // element representing filter). let expr = plan.expressions(); - let expr = if !on_filter.is_empty() && on_to_keep.is_empty() { + let expr = if !on_filter_empty && keep_condition.is_empty() { // New filter expression is None - should remove last element expr[..expr.len() - 1].to_vec() - } else if !on_to_keep.is_empty() { + } else if !keep_condition.is_empty() { // Replace last element with new filter expression expr[..expr.len() - 1] .iter() .cloned() - .chain(once(on_to_keep.into_iter().reduce(Expr::and).unwrap())) + .chain(once(keep_condition.into_iter().reduce(Expr::and).unwrap())) .collect() } else { plan.expressions() }; let plan = from_plan(plan, &expr, &[left, right])?; - if to_keep.0.is_empty() { + if keep_predicates.is_empty() { Ok(plan) } else { // wrap the join on the filter whose predicates must be kept - let plan = utils::add_filter(plan, &to_keep.0)?; - state.filters = remove_filters(&state.filters, &to_keep.1); - - Ok(plan) + match conjunction(keep_predicates) { + Some(predicate) => Ok(LogicalPlan::Filter(Filter::try_new( + predicate, + Arc::new(plan), + )?)), + None => Ok(plan), + } } } -fn optimize(plan: &LogicalPlan, mut state: State) -> Result { - match plan { - LogicalPlan::Explain { .. } => { - // push the optimization to the plan of this explain - push_down(&state, plan) - } - LogicalPlan::Analyze { .. } => push_down(&state, plan), - LogicalPlan::Filter(filter) => { - let predicate = utils::cnf_rewrite(filter.predicate().clone()); - - utils::split_conjunction_owned(predicate) - .into_iter() - .try_for_each::<_, Result<()>>(|predicate| { - let columns = predicate.to_columns()?; - state.filters.push((predicate, columns)); - Ok(()) - })?; - - optimize(filter.input(), state) +fn push_down_join( + plan: &LogicalPlan, + join: &Join, + parent_predicate: Option<&Expr>, +) -> Result> { + let mut predicates = match parent_predicate { + Some(parent_predicate) => { + utils::split_conjunction_owned(utils::cnf_rewrite(parent_predicate.clone())) } - LogicalPlan::Projection(Projection { - input, - expr, - schema, - }) => { - // A projection is filter-commutable, but re-writes all predicate expressions - // collect projection. - let projection = schema - .fields() - .iter() - .enumerate() - .flat_map(|(i, field)| { - // strip alias, as they should not be part of filters - let expr = match &expr[i] { - Expr::Alias(expr, _) => expr.as_ref().clone(), - expr => expr.clone(), + None => vec![], + }; + + // Convert JOIN ON predicate to Predicates + let on_filters = join + .filter + .as_ref() + .map(|e| utils::split_conjunction_owned(e.clone())) + .unwrap_or_else(Vec::new); + + if join.join_type == JoinType::Inner { + // For inner joins, duplicate filters for joined columns so filters can be pushed down + // to both sides. Take the following query as an example: + // + // ```sql + // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 + // ``` + // + // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while + // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. + // + // Join clauses with `Using` constraints also take advantage of this logic to make sure + // predicates reference the shared join columns are pushed to both sides. + // This logic should also been applied to conditions in JOIN ON clause + let join_side_filters = predicates + .iter() + .chain(on_filters.iter()) + .filter_map(|predicate| { + let mut join_cols_to_replace = HashMap::new(); + let columns = match predicate.to_columns() { + Ok(columns) => columns, + Err(e) => return Some(Err(e)), + }; + + for col in columns.iter() { + for (l, r) in join.on.iter() { + if col == l { + join_cols_to_replace.insert(col, r); + break; + } else if col == r { + join_cols_to_replace.insert(col, l); + break; + } + } + } + + if join_cols_to_replace.is_empty() { + return None; + } + + let join_side_predicate = + match replace_col(predicate.clone(), &join_cols_to_replace) { + Ok(p) => p, + Err(e) => { + return Some(Err(e)); + } }; - // Convert both qualified and unqualified fields - [ - (field.name().clone(), expr.clone()), - (field.qualified_name(), expr), - ] - }) - .collect::>(); + Some(Ok(join_side_predicate)) + }) + .collect::>>()?; + predicates.extend(join_side_filters); + } + if on_filters.is_empty() && predicates.is_empty() { + return Ok(None); + } + Ok(Some(push_down_all_join( + predicates, + plan, + &join.left, + &join.right, + on_filters, + )?)) +} - // re-write all filters based on this projection - // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" - for (predicate, columns) in state.filters.iter_mut() { - *predicate = replace_cols_by_name(predicate.clone(), &projection)?; +impl OptimizerRule for PushDownFilter { + fn name(&self) -> &str { + "push_down_filter" + } - columns.clear(); - expr_to_columns(predicate, columns)?; + fn optimize( + &self, + plan: &LogicalPlan, + optimizer_config: &mut OptimizerConfig, + ) -> Result { + let filter = match plan { + LogicalPlan::Filter(filter) => filter, + // we also need to pushdown filter in Join. + LogicalPlan::Join(join) => { + let optimized_plan = push_down_join(plan, join, None)?; + return match optimized_plan { + Some(optimized_plan) => { + utils::optimize_children(self, &optimized_plan, optimizer_config) + } + None => utils::optimize_children(self, plan, optimizer_config), + }; } + _ => return utils::optimize_children(self, plan, optimizer_config), + }; - // optimize inner - let new_input = optimize(input, state)?; - Ok(from_plan(plan, expr, &[new_input])?) - } - LogicalPlan::Aggregate(Aggregate { aggr_expr, .. }) => { - // An aggregate's aggreagate columns are _not_ filter-commutable => collect these: - // * columns whose aggregation expression depends on - // * the aggregation columns themselves - - // construct set of columns that `aggr_expr` depends on - let mut used_columns = HashSet::new(); - exprlist_to_columns(aggr_expr, &mut used_columns)?; - - let agg_columns = aggr_expr - .iter() - .map(|x| Ok(Column::from_name(x.display_name()?))) - .collect::>>()?; - used_columns.extend(agg_columns); - - issue_filters(state, used_columns, plan) - } - LogicalPlan::Sort { .. } => { - // sort is filter-commutable - push_down(&state, plan) - } - LogicalPlan::Union(Union { inputs: _, schema }) => { - // union changing all qualifiers while building logical plan so we need - // to rewrite filters to push unqualified columns to inputs - let projection = schema - .fields() - .iter() - .map(|field| (field.qualified_name(), col(field.name()))) - .collect::>(); - - // rewriting predicate expressions using unqualified names as replacements - if !projection.is_empty() { - for (predicate, columns) in state.filters.iter_mut() { - *predicate = replace_cols_by_name(predicate.clone(), &projection)?; - - columns.clear(); - expr_to_columns(predicate, columns)?; - } + let child_plan = &**filter.input(); + let new_plan = match child_plan { + LogicalPlan::Filter(child_filter) => { + let new_predicate = + and(filter.predicate().clone(), child_filter.predicate().clone()); + let new_plan = LogicalPlan::Filter(Filter::try_new( + new_predicate, + child_filter.input().clone(), + )?); + return self.optimize(&new_plan, optimizer_config); } - - push_down(&state, plan) - } - LogicalPlan::Limit(Limit { input, .. }) => { - // limit is _not_ filter-commutable => collect all columns from its input - let used_columns = input - .schema() - .fields() - .iter() - .map(|f| f.qualified_column()) - .collect::>(); - issue_filters(state, used_columns, plan) - } - LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - optimize_join(state, plan, left, right, vec![]) - } - LogicalPlan::Join(Join { - left, - right, - on, - filter, - join_type, - .. - }) => { - // Convert JOIN ON predicate to Predicates - let on_filters = filter - .as_ref() - .map(|e| { - let predicates = utils::split_conjunction(e); - - predicates - .into_iter() - .map(|e| Ok((e.clone(), e.to_columns()?))) - .collect::>>() - }) - .unwrap_or_else(|| Ok(vec![]))?; - - if *join_type == JoinType::Inner { - // For inner joins, duplicate filters for joined columns so filters can be pushed down - // to both sides. Take the following query as an example: - // - // ```sql - // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 - // ``` - // - // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while - // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. - // - // Join clauses with `Using` constraints also take advantage of this logic to make sure - // predicates reference the shared join columns are pushed to both sides. - // This logic should also been applied to conditions in JOIN ON clause - let join_side_filters = state - .filters + LogicalPlan::Repartition(_) + | LogicalPlan::Distinct(_) + | LogicalPlan::Sort(_) => { + // commutable + let new_filter = + plan.with_new_inputs(&[ + (**(child_plan.inputs().get(0).unwrap())).clone() + ])?; + child_plan.with_new_inputs(&[new_filter])? + } + LogicalPlan::Projection(projection) => { + // A projection is filter-commutable, but re-writes all predicate expressions + // collect projection. + let replace_map = projection + .schema + .fields() .iter() - .chain(on_filters.iter()) - .filter_map(|(predicate, columns)| { - let mut join_cols_to_replace = HashMap::new(); - for col in columns.iter() { - for (l, r) in on { - if col == l { - join_cols_to_replace.insert(col, r); - break; - } else if col == r { - join_cols_to_replace.insert(col, l); - break; - } - } - } + .enumerate() + .map(|(i, field)| { + // strip alias, as they should not be part of filters + let expr = match &projection.expr[i] { + Expr::Alias(expr, _) => expr.as_ref().clone(), + expr => expr.clone(), + }; + + (field.qualified_name(), expr) + }) + .collect::>(); - if join_cols_to_replace.is_empty() { - return None; - } + // re-write all filters based on this projection + // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" + let new_filter = LogicalPlan::Filter(Filter::try_new( + replace_cols_by_name(filter.predicate().clone(), &replace_map)?, + projection.input.clone(), + )?); - let join_side_predicate = - match replace_col(predicate.clone(), &join_cols_to_replace) { - Ok(p) => p, - Err(e) => { - return Some(Err(e)); - } - }; - - let join_side_columns = columns - .clone() - .into_iter() - // replace keys in join_cols_to_replace with values in resulting column - // set - .filter(|c| !join_cols_to_replace.contains_key(c)) - .chain(join_cols_to_replace.values().map(|v| (*v).clone())) - .collect(); - - Some(Ok((join_side_predicate, join_side_columns))) - }) - .collect::>>()?; - state.filters.extend(join_side_filters); + child_plan.with_new_inputs(&[new_filter])? } + LogicalPlan::Union(union) => { + let mut inputs = Vec::with_capacity(union.inputs.len()); + for input in &union.inputs { + let mut replace_map = HashMap::new(); + for (i, field) in input.schema().fields().iter().enumerate() { + replace_map.insert( + union.schema.fields().get(i).unwrap().qualified_name(), + Expr::Column(field.qualified_column()), + ); + } - optimize_join(state, plan, left, right, on_filters) - } - LogicalPlan::TableScan(TableScan { - source, - projected_schema, - filters, - projection, - table_name, - fetch, - }) => { - let mut used_columns = HashSet::new(); - let mut new_filters = filters.clone(); - - for (filter_expr, cols) in &state.filters { - let (preserve_filter_node, add_to_provider) = - match source.supports_filter_pushdown(filter_expr)? { - TableProviderFilterPushDown::Unsupported => (true, false), - TableProviderFilterPushDown::Inexact => (true, true), - TableProviderFilterPushDown::Exact => (false, true), - }; - - if preserve_filter_node { - used_columns.extend(cols.clone()); + let push_predicate = + replace_cols_by_name(filter.predicate().clone(), &replace_map)?; + inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new( + push_predicate, + input.clone(), + )?))) } - - if add_to_provider { - // Don't add expression again if it's already present in - // pushed down filters. - if new_filters.contains(filter_expr) { - continue; + LogicalPlan::Union(Union { + inputs, + schema: plan.schema().clone(), + }) + } + LogicalPlan::Aggregate(agg) => { + // An aggregate's aggregate columns are _not_ filter-commutable => collect these: + // * columns whose aggregation expression depends on + // * the aggregation columns themselves + + // construct set of columns that `aggr_expr` depends on + let mut used_columns = HashSet::new(); + exprlist_to_columns(&agg.aggr_expr, &mut used_columns)?; + let agg_columns = agg + .aggr_expr + .iter() + .map(|x| Ok(Column::from_name(x.display_name()?))) + .collect::>>()?; + used_columns.extend(agg_columns); + + let predicates = utils::split_conjunction_owned(utils::cnf_rewrite( + filter.predicate().clone(), + )); + + let mut keep_predicates = vec![]; + let mut push_predicates = vec![]; + for expr in predicates { + let columns = expr.to_columns()?; + if columns.is_empty() + || !columns + .intersection(&used_columns) + .collect::>() + .is_empty() + { + keep_predicates.push(expr); + } else { + push_predicates.push(expr); } - new_filters.push(filter_expr.clone()); + } + + let child = match conjunction(push_predicates) { + Some(predicate) => LogicalPlan::Filter(Filter::try_new( + predicate, + Arc::new((*agg.input).clone()), + )?), + None => (*agg.input).clone(), + }; + let new_agg = from_plan( + filter.input(), + &filter.input().expressions(), + &vec![child], + )?; + match conjunction(keep_predicates) { + Some(predicate) => LogicalPlan::Filter(Filter::try_new( + predicate, + Arc::new(new_agg), + )?), + None => new_agg, } } + LogicalPlan::Join(join) => { + match push_down_join(filter.input(), join, Some(filter.predicate()))? { + Some(optimized_plan) => optimized_plan, + None => plan.clone(), + } + } + LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { + let predicates = utils::split_conjunction_owned(utils::cnf_rewrite( + filter.predicate().clone(), + )); - issue_filters( - state, - used_columns, - &LogicalPlan::TableScan(TableScan { - source: source.clone(), - projection: projection.clone(), - projected_schema: projected_schema.clone(), - table_name: table_name.clone(), - filters: new_filters, - fetch: *fetch, - }), - ) - } - _ => { - // all other plans are _not_ filter-commutable - let used_columns = plan - .schema() - .fields() - .iter() - .map(|f| f.qualified_column()) - .collect::>(); - issue_filters(state, used_columns, plan) - } - } -} + push_down_all_join(predicates, filter.input(), left, right, vec![])? + } + LogicalPlan::TableScan(scan) => { + let mut new_scan_filters = scan.filters.clone(); + let mut new_predicate = vec![]; + + let filter_predicates = utils::split_conjunction_owned( + utils::cnf_rewrite(filter.predicate().clone()), + ); + + for filter_expr in &filter_predicates { + let (preserve_filter_node, add_to_provider) = + match scan.source.supports_filter_pushdown(filter_expr)? { + TableProviderFilterPushDown::Unsupported => (true, false), + TableProviderFilterPushDown::Inexact => (true, true), + TableProviderFilterPushDown::Exact => (false, true), + }; + if preserve_filter_node { + new_predicate.push(filter_expr.clone()); + } + if add_to_provider { + // avoid reduplicated filter expr. + if new_scan_filters.contains(filter_expr) { + continue; + } + new_scan_filters.push(filter_expr.clone()); + } + } -impl OptimizerRule for FilterPushDown { - fn name(&self) -> &str { - "filter_push_down" - } + let new_scan = LogicalPlan::TableScan(TableScan { + source: scan.source.clone(), + projection: scan.projection.clone(), + projected_schema: scan.projected_schema.clone(), + table_name: scan.table_name.clone(), + filters: new_scan_filters, + fetch: scan.fetch, + }); + + match conjunction(new_predicate) { + Some(predicate) => LogicalPlan::Filter(Filter::try_new( + predicate, + Arc::new(new_scan), + )?), + None => new_scan, + } + } + _ => plan.clone(), + }; - fn optimize( - &self, - plan: &LogicalPlan, - _: &mut OptimizerConfig, - ) -> Result { - optimize(plan, State::default()) + utils::optimize_children(self, &new_plan, optimizer_config) } } -impl FilterPushDown { +impl PushDownFilter { #[allow(missing_docs)] pub fn new() -> Self { Self {} @@ -832,21 +753,19 @@ mod tests { use async_trait::async_trait; use datafusion_common::DFSchema; use datafusion_expr::{ - and, col, in_list, in_subquery, lit, logical_plan::JoinType, sum, Expr, - LogicalPlanBuilder, Operator, TableSource, TableType, + and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, sum, BinaryExpr, + Expr, LogicalPlanBuilder, Operator, TableSource, TableType, }; use std::sync::Arc; - fn optimize_plan(plan: &LogicalPlan) -> LogicalPlan { - let rule = FilterPushDown::new(); - rule.optimize(plan, &mut OptimizerConfig::new()) - .expect("failed to optimize plan") - } - - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { - let optimized_plan = optimize_plan(plan); + fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + let optimized_plan = PushDownFilter::new() + .optimize(plan, &mut OptimizerConfig::new()) + .expect("failed to optimize plan"); let formatted_plan = format!("{:?}", optimized_plan); - assert_eq!(formatted_plan, expected); + assert_eq!(plan.schema(), optimized_plan.schema()); + assert_eq!(expected, formatted_plan); + Ok(()) } #[test] @@ -861,8 +780,7 @@ mod tests { Projection: test.a, test.b\ \n Filter: test.a = Int64(1)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } #[test] @@ -879,8 +797,7 @@ mod tests { \n Limit: skip=0, fetch=10\ \n Projection: test.a, test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } #[test] @@ -892,8 +809,7 @@ mod tests { let expected = "\ Filter: Int64(0) = Int64(1)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } #[test] @@ -910,8 +826,7 @@ mod tests { \n Projection: test.a, test.b, test.c\ \n Filter: test.a = Int64(1)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } #[test] @@ -926,8 +841,7 @@ mod tests { Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS total_salary]]\ \n Filter: test.a > Int64(10)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } #[test] @@ -942,8 +856,7 @@ mod tests { Filter: b > Int64(10)\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } #[test] @@ -966,8 +879,7 @@ mod tests { \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ \n Filter: test.c = Int64(1) OR test.c = Int64(1)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written @@ -983,8 +895,7 @@ mod tests { Projection: test.a AS b, test.c\ \n Filter: test.a = Int64(1)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } fn add(left: Expr, right: Expr) -> Expr { @@ -1029,8 +940,7 @@ mod tests { Projection: test.a * Int32(2) + test.c AS b, test.c\ \n Filter: test.a * Int32(2) + test.c = Int64(1)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written @@ -1063,8 +973,7 @@ mod tests { \n Projection: test.a * Int32(2) + test.c AS b, test.c\ \n Filter: (test.a * Int32(2) + test.c) * Int32(3) = Int64(1)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed @@ -1098,9 +1007,7 @@ mod tests { \n Projection: test.a AS b, test.c\ \n Filter: test.a > Int64(10)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed @@ -1135,9 +1042,7 @@ mod tests { \n Projection: test.a AS b, test.c\ \n Filter: test.a > Int64(10)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// verifies that when two limits are in place, we jump neither @@ -1159,26 +1064,24 @@ mod tests { \n Limit: skip=0, fetch=20\ \n Projection: test.a, test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } #[test] fn union_all() -> Result<()> { let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan.clone()) - .union(LogicalPlanBuilder::from(table_scan).build()?)? + let table_scan2 = test_table_scan_with_name("test2")?; + let plan = LogicalPlanBuilder::from(table_scan) + .union(LogicalPlanBuilder::from(table_scan2).build()?)? .filter(col("a").eq(lit(1i64)))? .build()?; // filter appears below Union - let expected = "\ - Union\ - \n Filter: a = Int64(1)\ - \n TableScan: test\ - \n Filter: a = Int64(1)\ - \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + let expected = "Union\ + \n Filter: test.a = Int64(1)\ + \n TableScan: test\ + \n Filter: test2.a = Int64(1)\ + \n TableScan: test2"; + assert_optimized_plan_eq(&plan, expected) } #[test] @@ -1194,16 +1097,15 @@ mod tests { // filter appears below Union let expected = "Union\ - \n SubqueryAlias: test2\ - \n Projection: test.a AS b\ - \n Filter: test.a = Int64(1)\ + \n Filter: test2.b = Int64(1)\ + \n SubqueryAlias: test2\ + \n Projection: test.a AS b\ \n TableScan: test\ - \n SubqueryAlias: test2\ - \n Projection: test.a AS b\ - \n Filter: test.a = Int64(1)\ + \n Filter: test2.b = Int64(1)\ + \n SubqueryAlias: test2\ + \n Projection: test.a AS b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// verifies that filters with the same columns are correctly placed @@ -1238,8 +1140,7 @@ mod tests { \n Filter: test.a <= Int64(1)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// verifies that filters to be placed on the same depth are ANDed @@ -1269,8 +1170,7 @@ mod tests { \n Limit: skip=0, fetch=1\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// verifies that filters on a plan with user nodes are not lost @@ -1292,8 +1192,7 @@ mod tests { // not part of the test assert_eq!(format!("{:?}", plan), expected); - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// post-on-join predicates on a column common to both sides is pushed to both sides @@ -1318,8 +1217,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( format!("{:?}", plan), - "\ - Filter: test.a <= Int64(1)\ + "Filter: test.a <= Int64(1)\ \n Inner Join: test.a = test2.a\ \n TableScan: test\ \n Projection: test2.a\ @@ -1334,8 +1232,7 @@ mod tests { \n Projection: test2.a\ \n Filter: test2.a <= Int64(1)\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// post-using-join predicates on a column common to both sides is pushed to both sides @@ -1359,8 +1256,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( format!("{:?}", plan), - "\ - Filter: test.a <= Int64(1)\ + "Filter: test.a <= Int64(1)\ \n Inner Join: Using test.a = test2.a\ \n TableScan: test\ \n Projection: test2.a\ @@ -1375,8 +1271,7 @@ mod tests { \n Projection: test2.a\ \n Filter: test2.a <= Int64(1)\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// post-join predicates with columns from both sides are not pushed @@ -1404,8 +1299,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( format!("{:?}", plan), - "\ - Filter: test.c <= test2.b\ + "Filter: test.c <= test2.b\ \n Inner Join: test.a = test2.a\ \n Projection: test.a, test.c\ \n TableScan: test\ @@ -1415,8 +1309,7 @@ mod tests { // expected is equal: no push-down let expected = &format!("{:?}", plan); - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// post-join predicates with columns from one side of a join are pushed only to that side @@ -1444,8 +1337,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( format!("{:?}", plan), - "\ - Filter: test.b <= Int64(1)\ + "Filter: test.b <= Int64(1)\ \n Inner Join: test.a = test2.a\ \n Projection: test.a, test.b\ \n TableScan: test\ @@ -1460,8 +1352,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a, test2.c\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// post-join predicates on the right side of a left join are not duplicated @@ -1486,8 +1377,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( format!("{:?}", plan), - "\ - Filter: test2.a <= Int64(1)\ + "Filter: test2.a <= Int64(1)\ \n Left Join: Using test.a = test2.a\ \n TableScan: test\ \n Projection: test2.a\ @@ -1501,12 +1391,10 @@ mod tests { \n TableScan: test\ \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// post-join predicates on the left side of a right join are not duplicated - /// TODO: In this case we can sometimes convert the join to an INNER join #[test] fn filter_using_right_join() -> Result<()> { let table_scan = test_table_scan()?; @@ -1527,8 +1415,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( format!("{:?}", plan), - "\ - Filter: test.a <= Int64(1)\ + "Filter: test.a <= Int64(1)\ \n Right Join: Using test.a = test2.a\ \n TableScan: test\ \n Projection: test2.a\ @@ -1542,8 +1429,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// post-left-join predicate on a column common to both sides is only pushed to the left side @@ -1568,8 +1454,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( format!("{:?}", plan), - "\ - Filter: test.a <= Int64(1)\ + "Filter: test.a <= Int64(1)\ \n Left Join: Using test.a = test2.a\ \n TableScan: test\ \n Projection: test2.a\ @@ -1583,8 +1468,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// post-right-join predicate on a column common to both sides is only pushed to the right side @@ -1609,8 +1493,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( format!("{:?}", plan), - "\ - Filter: test2.a <= Int64(1)\ + "Filter: test2.a <= Int64(1)\ \n Right Join: Using test.a = test2.a\ \n TableScan: test\ \n Projection: test2.a\ @@ -1624,8 +1507,7 @@ mod tests { \n Projection: test2.a\ \n Filter: test2.a <= Int64(1)\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// single table predicate parts of ON condition should be pushed to both inputs @@ -1655,8 +1537,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( format!("{:?}", plan), - "\ - Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ + "Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ \n Projection: test.a, test.b, test.c\ \n TableScan: test\ \n Projection: test2.a, test2.b, test2.c\ @@ -1671,8 +1552,7 @@ mod tests { \n Projection: test2.a, test2.b, test2.c\ \n Filter: test2.c > UInt32(4)\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// join filter should be completely removed after pushdown @@ -1701,8 +1581,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( format!("{:?}", plan), - "\ - Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)\ + "Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)\ \n Projection: test.a, test.b, test.c\ \n TableScan: test\ \n Projection: test2.a, test2.b, test2.c\ @@ -1717,8 +1596,7 @@ mod tests { \n Projection: test2.a, test2.b, test2.c\ \n Filter: test2.c > UInt32(4)\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// predicate on join key in filter expression should be pushed down to both inputs @@ -1745,8 +1623,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( format!("{:?}", plan), - "\ - Inner Join: test.a = test2.b Filter: test.a > UInt32(1)\ + "Inner Join: test.a = test2.b Filter: test.a > UInt32(1)\ \n Projection: test.a\ \n TableScan: test\ \n Projection: test2.b\ @@ -1761,8 +1638,7 @@ mod tests { \n Projection: test2.b\ \n Filter: test2.b > UInt32(1)\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// single table predicate parts of ON condition should be pushed to right input @@ -1792,8 +1668,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( format!("{:?}", plan), - "\ - Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ + "Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ \n Projection: test.a, test.b, test.c\ \n TableScan: test\ \n Projection: test2.a, test2.b, test2.c\ @@ -1807,8 +1682,7 @@ mod tests { \n Projection: test2.a, test2.b, test2.c\ \n Filter: test2.c > UInt32(4)\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// single table predicate parts of ON condition should be pushed to left input @@ -1838,8 +1712,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( format!("{:?}", plan), - "\ - Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ + "Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ \n Projection: test.a, test.b, test.c\ \n TableScan: test\ \n Projection: test2.a, test2.b, test2.c\ @@ -1853,8 +1726,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// single table predicate parts of ON condition should not be pushed @@ -1884,8 +1756,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( format!("{:?}", plan), - "\ - Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ + "Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ \n Projection: test.a, test.b, test.c\ \n TableScan: test\ \n Projection: test2.a, test2.b, test2.c\ @@ -1893,8 +1764,7 @@ mod tests { ); let expected = &format!("{:?}", plan); - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } struct PushDownProvider { @@ -1961,8 +1831,7 @@ mod tests { let expected = "\ TableScan: test, full_filters=[a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } #[test] @@ -1973,8 +1842,7 @@ mod tests { let expected = "\ Filter: a = Int64(1)\ \n TableScan: test, partial_filters=[a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } #[test] @@ -1982,7 +1850,9 @@ mod tests { let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?; - let optimised_plan = optimize_plan(&plan); + let optimised_plan = PushDownFilter::new() + .optimize(&plan, &mut OptimizerConfig::new()) + .expect("failed to optimize plan"); let expected = "\ Filter: a = Int64(1)\ @@ -1990,8 +1860,7 @@ mod tests { // Optimizing the same plan multiple times should produce the same plan // each time. - assert_optimized_plan_eq(&optimised_plan, expected); - Ok(()) + assert_optimized_plan_eq(&optimised_plan, expected) } #[test] @@ -2002,8 +1871,7 @@ mod tests { let expected = "\ Filter: a = Int64(1)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } #[test] @@ -2028,13 +1896,11 @@ mod tests { .project(vec![col("a"), col("b")])? .build()?; - let expected ="Projection: a, b\ + let expected = "Projection: a, b\ \n Filter: a = Int64(10) AND b > Int64(11)\ \n TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]"; - assert_optimized_plan_eq(&plan, expected); - - Ok(()) + assert_optimized_plan_eq(&plan, expected) } #[test] @@ -2051,11 +1917,9 @@ mod tests { // filter on col b assert_eq!( format!("{:?}", plan), - "\ - Filter: b > Int64(10) AND test.c > Int64(10)\ + "Filter: b > Int64(10) AND test.c > Int64(10)\ \n Projection: test.a AS b, test.c\ - \n TableScan: test\ - " + \n TableScan: test" ); // rewrite filter col b to test.a @@ -2065,9 +1929,7 @@ mod tests { \n TableScan: test\ "; - assert_optimized_plan_eq(&plan, expected); - - Ok(()) + assert_optimized_plan_eq(&plan, expected) } #[test] @@ -2085,8 +1947,7 @@ mod tests { // filter on col b assert_eq!( format!("{:?}", plan), - "\ - Filter: b > Int64(10) AND test.c > Int64(10)\ + "Filter: b > Int64(10) AND test.c > Int64(10)\ \n Projection: b, test.c\ \n Projection: test.a AS b, test.c\ \n TableScan: test\ @@ -2101,9 +1962,7 @@ mod tests { \n TableScan: test\ "; - assert_optimized_plan_eq(&plan, expected); - - Ok(()) + assert_optimized_plan_eq(&plan, expected) } #[test] @@ -2117,8 +1976,7 @@ mod tests { // filter on col b and d assert_eq!( format!("{:?}", plan), - "\ - Filter: b > Int64(10) AND d > Int64(10)\ + "Filter: b > Int64(10) AND d > Int64(10)\ \n Projection: test.a AS b, test.c AS d\ \n TableScan: test\ " @@ -2131,9 +1989,7 @@ mod tests { \n TableScan: test\ "; - assert_optimized_plan_eq(&plan, expected); - - Ok(()) + assert_optimized_plan_eq(&plan, expected) } /// predicate on join key in filter expression should be pushed down to both inputs @@ -2159,8 +2015,7 @@ mod tests { assert_eq!( format!("{:?}", plan), - "\ - Inner Join: c = d Filter: c > UInt32(1)\ + "Inner Join: c = d Filter: c > UInt32(1)\ \n Projection: test.a AS c\ \n TableScan: test\ \n Projection: test2.b AS d\ @@ -2176,8 +2031,7 @@ mod tests { \n Projection: test2.b AS d\ \n Filter: test2.b > UInt32(1)\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) } #[test] @@ -2195,8 +2049,7 @@ mod tests { // filter on col b assert_eq!( format!("{:?}", plan), - "\ - Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\ + "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\ \n Projection: test.a AS b, test.c\ \n TableScan: test\ " @@ -2209,9 +2062,7 @@ mod tests { \n TableScan: test\ "; - assert_optimized_plan_eq(&plan, expected); - - Ok(()) + assert_optimized_plan_eq(&plan, expected) } #[test] @@ -2230,8 +2081,7 @@ mod tests { // filter on col b assert_eq!( format!("{:?}", plan), - "\ - Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\ + "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\ \n Projection: b, test.c\ \n Projection: test.a AS b, test.c\ \n TableScan: test\ @@ -2246,9 +2096,7 @@ mod tests { \n TableScan: test\ "; - assert_optimized_plan_eq(&plan, expected); - - Ok(()) + assert_optimized_plan_eq(&plan, expected) } #[test] @@ -2285,9 +2133,7 @@ mod tests { \n Projection: sq.c\ \n TableScan: sq\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected_after); - - Ok(()) + assert_optimized_plan_eq(&plan, expected_after) } #[test] @@ -2318,9 +2164,7 @@ mod tests { \n SubqueryAlias: b\ \n Projection: Int64(0) AS a\ \n EmptyRelation"; - assert_optimized_plan_eq(&plan, expected_after); - - Ok(()) + assert_optimized_plan_eq(&plan, expected_after) } #[test] @@ -2351,7 +2195,34 @@ mod tests { \n TableScan: test\ \n Projection: test1.a AS d, test1.a AS e\ \n TableScan: test1"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_eq(&plan, expected) + } + + #[test] + fn test_project_same_name_different_qualifier() -> Result<()> { + let table_scan = test_table_scan()?; + let left = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b"), col("c")])? + .build()?; + let right_table_scan = test_table_scan_with_name("test1")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("b"), col("c")])? + .build()?; + let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2))); + let plan = LogicalPlanBuilder::from(left) + .cross_join(&right)? + .project(vec![col("test.a"), col("test1.a")])? + .filter(filter)? + .build()?; + + let expected = "Projection: test.a, test1.a\ + \n CrossJoin:\ + \n Projection: test.a, test.b, test.c\ + \n Filter: test.a = Int32(1)\ + \n TableScan: test\ + \n Projection: test1.a, test1.b, test1.c\ + \n Filter: test1.a > Int32(2)\ + \n TableScan: test1"; + assert_optimized_plan_eq(&plan, expected) } } From ac9b9b40cad7e3e9c7861d03e8ffc517f1fb4a36 Mon Sep 17 00:00:00 2001 From: jackwener Date: Tue, 29 Nov 2022 20:30:03 +0800 Subject: [PATCH 2/2] fix regression for push_down_filter meet subquery-alias --- benchmarks/expected-plans/q21.txt | 8 ++-- benchmarks/expected-plans/q7.txt | 8 ++-- datafusion/core/tests/sql/joins.rs | 13 +++--- datafusion/optimizer/src/push_down_filter.rs | 46 ++++++++++++++----- .../optimizer/tests/integration-test.rs | 31 ++++++------- 5 files changed, 63 insertions(+), 43 deletions(-) diff --git a/benchmarks/expected-plans/q21.txt b/benchmarks/expected-plans/q21.txt index 397e0a8d8cf61..3ef6269dee48a 100644 --- a/benchmarks/expected-plans/q21.txt +++ b/benchmarks/expected-plans/q21.txt @@ -7,8 +7,8 @@ Sort: numwait DESC NULLS FIRST, supplier.s_name ASC NULLS LAST Inner Join: l1.l_orderkey = orders.o_orderkey Inner Join: supplier.s_suppkey = l1.l_suppkey TableScan: supplier projection=[s_suppkey, s_name, s_nationkey] - Filter: l1.l_receiptdate > l1.l_commitdate - SubqueryAlias: l1 + SubqueryAlias: l1 + Filter: lineitem.l_receiptdate > lineitem.l_commitdate TableScan: lineitem projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate] Filter: orders.o_orderstatus = Utf8("F") TableScan: orders projection=[o_orderkey, o_orderstatus] @@ -16,6 +16,6 @@ Sort: numwait DESC NULLS FIRST, supplier.s_name ASC NULLS LAST TableScan: nation projection=[n_nationkey, n_name] SubqueryAlias: l2 TableScan: lineitem projection=[l_orderkey, l_suppkey] - Filter: l3.l_receiptdate > l3.l_commitdate - SubqueryAlias: l3 + SubqueryAlias: l3 + Filter: lineitem.l_receiptdate > lineitem.l_commitdate TableScan: lineitem projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate] \ No newline at end of file diff --git a/benchmarks/expected-plans/q7.txt b/benchmarks/expected-plans/q7.txt index 74857c6f94ace..53deda1b87c04 100644 --- a/benchmarks/expected-plans/q7.txt +++ b/benchmarks/expected-plans/q7.txt @@ -14,9 +14,9 @@ Sort: shipping.supp_nation ASC NULLS LAST, shipping.cust_nation ASC NULLS LAST, TableScan: lineitem projection=[l_orderkey, l_suppkey, l_extendedprice, l_discount, l_shipdate] TableScan: orders projection=[o_orderkey, o_custkey] TableScan: customer projection=[c_custkey, c_nationkey] - Filter: n1.n_name = Utf8("FRANCE") OR n1.n_name = Utf8("GERMANY") - SubqueryAlias: n1 + SubqueryAlias: n1 + Filter: nation.n_name = Utf8("FRANCE") OR nation.n_name = Utf8("GERMANY") TableScan: nation projection=[n_nationkey, n_name] - Filter: n2.n_name = Utf8("GERMANY") OR n2.n_name = Utf8("FRANCE") - SubqueryAlias: n2 + SubqueryAlias: n2 + Filter: nation.n_name = Utf8("GERMANY") OR nation.n_name = Utf8("FRANCE") TableScan: nation projection=[n_nationkey, n_name] \ No newline at end of file diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 87fb594c79b3b..7129fc7ed644d 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -1636,15 +1636,14 @@ async fn reduce_left_join_3() -> Result<()> { "Explain [plan_type:Utf8, plan:Utf8]", " Projection: t3.t1_id, t3.t1_name, t3.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " Left Join: t3.t1_int = t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: t3.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " SubqueryAlias: t3 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " SubqueryAlias: t3 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: t1.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Filter: t2.t2_int < UInt32(3) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: t2.t2_int < UInt32(3) AND t2.t2_id < UInt32(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - ] - ; + ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); assert_eq!( diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index edde57fa3d14a..ebdac394bea04 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -546,6 +546,29 @@ impl OptimizerRule for PushDownFilter { ])?; child_plan.with_new_inputs(&[new_filter])? } + LogicalPlan::SubqueryAlias(subquery_alias) => { + let mut replace_map = HashMap::new(); + for (i, field) in + subquery_alias.input.schema().fields().iter().enumerate() + { + replace_map.insert( + subquery_alias + .schema + .fields() + .get(i) + .unwrap() + .qualified_name(), + Expr::Column(field.qualified_column()), + ); + } + let new_predicate = + replace_cols_by_name(filter.predicate().clone(), &replace_map)?; + let new_filter = LogicalPlan::Filter(Filter::try_new( + new_predicate, + subquery_alias.input.clone(), + )?); + child_plan.with_new_inputs(&[new_filter])? + } LogicalPlan::Projection(projection) => { // A projection is filter-commutable, but re-writes all predicate expressions // collect projection. @@ -1096,14 +1119,13 @@ mod tests { .build()?; // filter appears below Union - let expected = "Union\ - \n Filter: test2.b = Int64(1)\ - \n SubqueryAlias: test2\ - \n Projection: test.a AS b\ + let expected = "Union\n SubqueryAlias: test2\ + \n Projection: test.a AS b\ + \n Filter: test.a = Int64(1)\ \n TableScan: test\ - \n Filter: test2.b = Int64(1)\ - \n SubqueryAlias: test2\ - \n Projection: test.a AS b\ + \n SubqueryAlias: test2\ + \n Projection: test.a AS b\ + \n Filter: test.a = Int64(1)\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected) } @@ -2158,11 +2180,11 @@ mod tests { // Ensure that the predicate without any columns (0 = 1) is // still there. let expected_after = "Projection: b.a\ - \n Filter: b.a = Int64(1)\ - \n SubqueryAlias: b\ - \n Projection: b.a\ - \n SubqueryAlias: b\ - \n Projection: Int64(0) AS a\ + \n SubqueryAlias: b\ + \n Projection: b.a\ + \n SubqueryAlias: b\ + \n Projection: Int64(0) AS a\ + \n Filter: Int64(0) = Int64(1)\ \n EmptyRelation"; assert_optimized_plan_eq(&plan, expected_after) } diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index c4911439cd3c3..457ea833ef3a9 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -274,12 +274,12 @@ fn join_keys_in_subquery_alias() { let plan = test_sql(sql).unwrap(); let expected = "Projection: a.col_int32, a.col_uint32, a.col_utf8, a.col_date32, a.col_date64, a.col_ts_nano_none, a.col_ts_nano_utc, b.key\ \n Inner Join: a.col_int32 = b.key\ - \n Filter: a.col_int32 IS NOT NULL\ - \n SubqueryAlias: a\ + \n SubqueryAlias: a\ + \n Filter: test.col_int32 IS NOT NULL\ \n TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc]\ - \n Filter: b.key IS NOT NULL\ - \n SubqueryAlias: b\ - \n Projection: test.col_int32 AS key\ + \n SubqueryAlias: b\ + \n Projection: test.col_int32 AS key\ + \n Filter: test.col_int32 IS NOT NULL\ \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{:?}", plan)); } @@ -288,20 +288,19 @@ fn join_keys_in_subquery_alias() { fn join_keys_in_subquery_alias_1() { let sql = "SELECT * FROM test AS A, ( SELECT test.col_int32 AS key FROM test JOIN test AS C on test.col_int32 = C.col_int32 ) AS B where A.col_int32 = B.key;"; let plan = test_sql(sql).unwrap(); - let expected = "Projection: a.col_int32, a.col_uint32, a.col_utf8, a.col_date32, a.col_date64, a.col_ts_nano_none, a.col_ts_nano_utc, b.key\ + let expected = "Projection: a.col_int32, a.col_uint32, a.col_utf8, a.col_date32, a.col_date64, a.col_ts_nano_none, a.col_ts_nano_utc, b.key\ \n Inner Join: a.col_int32 = b.key\ - \n Filter: a.col_int32 IS NOT NULL\ - \n SubqueryAlias: a\ + \n SubqueryAlias: a\ + \n Filter: test.col_int32 IS NOT NULL\ \n TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc]\ - \n Filter: b.key IS NOT NULL\ - \n SubqueryAlias: b\ - \n Projection: test.col_int32 AS key\ - \n Inner Join: test.col_int32 = c.col_int32\ + \n SubqueryAlias: b\ + \n Projection: test.col_int32 AS key\ + \n Inner Join: test.col_int32 = c.col_int32\ + \n Filter: test.col_int32 IS NOT NULL\ + \n TableScan: test projection=[col_int32]\ + \n SubqueryAlias: c\ \n Filter: test.col_int32 IS NOT NULL\ - \n TableScan: test projection=[col_int32]\ - \n Filter: c.col_int32 IS NOT NULL\ - \n SubqueryAlias: c\ - \n TableScan: test projection=[col_int32]"; + \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{:?}", plan)); }