diff --git a/datafusion/core/src/optimizer/filter_push_down.rs b/datafusion/core/src/optimizer/filter_push_down.rs index 0fd107b40dead..cd29aeadb5f57 100644 --- a/datafusion/core/src/optimizer/filter_push_down.rs +++ b/datafusion/core/src/optimizer/filter_push_down.rs @@ -130,7 +130,7 @@ fn issue_filters( return push_down(&state, plan); } - let plan = utils::add_filter(plan.clone(), &predicates); + let plan = utils::filter_by_all(plan.clone(), &predicates); state.filters = remove_filters(&state.filters, &predicate_columns); @@ -251,7 +251,7 @@ fn optimize_join( Ok(plan) } else { // wrap the join on the filter whose predicates must be kept - let plan = utils::add_filter(plan, &to_keep.0); + let plan = utils::filter_by_all(plan, &to_keep.0); state.filters = remove_filters(&state.filters, &to_keep.1); Ok(plan) @@ -290,7 +290,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { // As those contain only literals, they could be optimized using constant folding // and removal of WHERE TRUE / WHERE FALSE if !no_col_predicates.is_empty() { - Ok(utils::add_filter( + Ok(utils::filter_by_all( optimize(input, state)?, &no_col_predicates, )) diff --git a/datafusion/core/src/optimizer/subquery_filter_to_join.rs b/datafusion/core/src/optimizer/subquery_filter_to_join.rs index 5f4583c28f75d..fc4946b2b7704 100644 --- a/datafusion/core/src/optimizer/subquery_filter_to_join.rs +++ b/datafusion/core/src/optimizer/subquery_filter_to_join.rs @@ -25,14 +25,14 @@ //! ```text //! WHERE t1.f IN (SELECT f FROM t2) OR t2.f = 'x' //! ``` -//! won't +//! 
use std::sync::Arc; use crate::error::{DataFusionError, Result}; use crate::execution::context::ExecutionProps; -use crate::logical_plan::plan::{Filter, Join}; +use crate::logical_plan::plan::{Filter, Join, Projection}; use crate::logical_plan::{ - build_join_schema, Expr, JoinConstraint, JoinType, LogicalPlan, + build_join_schema, Expr, JoinConstraint, JoinType, LogicalPlan, Operator, }; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; @@ -46,6 +46,173 @@ impl SubqueryFilterToJoin { pub fn new() -> Self { Self {} } + + fn rewrite_correlated_subquery_as_join( + &self, + outer_plan: LogicalPlan, + subquery_expr: &Expr, + execution_props: &ExecutionProps, + ) -> Result<LogicalPlan> { + match subquery_expr { + Expr::InSubquery { + expr, + subquery, + negated, + } => { + let mut correlated_join_columns = vec![]; + let subquery_ref = &*subquery.subquery; + let right_decorrelated_plan = match subquery_ref { + // NOTE: We only pattern match against Projection(Filter(..)). We will have another optimization rule + // which tries to pull up all correlated predicates in an InSubquery into a Projection(Filter(..)) + // at the root node of the InSubquery's subquery. The Projection at the root must have as its expression + // a single Column. + LogicalPlan::Projection(Projection { input, expr, .. }) => { + if expr.len() != 1 { + return Err(DataFusionError::Plan( + "Only single column allowed in InSubquery".to_string(), + )); + }; + match (&expr[0], &**input) { + ( + Expr::Column(right_key), + LogicalPlan::Filter(Filter { predicate, input }), + ) => { + // Extract correlated columns as join columns from the filter predicate + let non_correlated_predicate = + utils::extract_correlated_as_join_columns( + predicate, + outer_plan.schema(), + &mut correlated_join_columns, + ); + + // Strip the projection away and use its input for the semi/anti-join + // Note that this rule is quite quirky.
But removing a projection below a semi + // or anti join is inconsequential if it is a Column projection. + let plan = + if let Some(predicate) = non_correlated_predicate { + LogicalPlan::Filter(Filter { + input: input.clone(), + predicate, + }) + } else { + (**input).clone() + }; + Some((plan, right_key.clone())) + } + _ => None, + } + } + _ => None, + }; + + // optimize the subquery and obtain the appropriate IN join key + let (right_input, right_key) = + if let Some((plan, key)) = right_decorrelated_plan { + let right_input = self.optimize(&plan, execution_props)?; + (right_input, key) + } else { + // If we were unable to decorrelate the subquery by matching against + // the pattern, we assume the subquery itself is not correlated + // and we run the semi/anti join on its output column + let right_input = self.optimize(subquery_ref, execution_props)?; + let right_schema = right_input.schema(); + if right_schema.fields().len() != 1 { + return Err(DataFusionError::Plan( + "Only single column allowed in InSubquery".to_string(), + )); + } + let right_key = right_schema.field(0).qualified_column(); + + (right_input, right_key) + }; + + let left_key = match &**expr { + Expr::Column(col) => col.clone(), + _ => { + return Err(DataFusionError::NotImplemented( + "Filtering by expression not implemented for InSubquery" + .to_string(), + )) + } + }; + correlated_join_columns.push((left_key, right_key)); + + let join_type = if *negated { + JoinType::Anti + } else { + JoinType::Semi + }; + + let schema = build_join_schema( + outer_plan.schema(), + right_input.schema(), + &join_type, + )?; + + Ok(LogicalPlan::Join(Join { + left: Arc::new(outer_plan), + right: Arc::new(right_input), + on: correlated_join_columns, + join_type, + join_constraint: JoinConstraint::On, + schema: Arc::new(schema), + null_equals_null: false, + })) + } + Expr::Exists { subquery, negated } => { + // NOTE: We only pattern match against Filter(..).
We will have another optimization rule + // which tries to pull up all correlated predicates in an Exists into a Filter(..) + // at the root node of the Exists's subquery + let mut correlated_join_columns = vec![]; + let right_input = match &*subquery.subquery { + LogicalPlan::Filter(Filter { predicate, input }) => { + let non_correlated_predicate = + utils::extract_correlated_as_join_columns( + predicate, + outer_plan.schema(), + &mut correlated_join_columns, + ); + if let Some(predicate) = non_correlated_predicate { + Arc::new(LogicalPlan::Filter(Filter { + input: input.clone(), + predicate, + })) + } else { + input.clone() + } + } + _ => subquery.subquery.clone(), + }; + + let right_input = self.optimize(&right_input, execution_props)?; + + let join_type = if *negated { + JoinType::Anti + } else { + JoinType::Semi + }; + + let schema = build_join_schema( + outer_plan.schema(), + right_input.schema(), + &join_type, + )?; + + Ok(LogicalPlan::Join(Join { + left: Arc::new(outer_plan), + right: Arc::new(right_input), + on: correlated_join_columns, + join_type, + join_constraint: JoinConstraint::On, + schema: Arc::new(schema), + null_equals_null: false, + })) + } + _ => Err(DataFusionError::Plan( + "Unknown expression while rewriting subquery to joins".to_string(), + )), + } + } } impl OptimizerRule for SubqueryFilterToJoin { @@ -55,6 +222,8 @@ impl OptimizerRule for SubqueryFilterToJoin { execution_props: &ExecutionProps, ) -> Result<LogicalPlan> { match plan { + // Pattern match on all plans of the form + // Filter: Exists(Filter(..)) AND InSubquery(Project(Filter(..))) AND ...
LogicalPlan::Filter(Filter { predicate, input }) => { // Apply optimizer rule to current input let optimized_input = self.optimize(input, execution_props)?; @@ -64,105 +233,41 @@ impl OptimizerRule for SubqueryFilterToJoin { utils::split_conjunction(predicate, &mut filters); // Searching for subquery-based filters - let (subquery_filters, regular_filters): (Vec<&Expr>, Vec<&Expr>) = - filters - .into_iter() - .partition(|&e| matches!(e, Expr::InSubquery { .. })); - - // Check all subquery filters could be rewritten - // - // In case of expressions which could not be rewritten - // return original filter with optimized input - let mut subqueries_in_regular = vec![]; - regular_filters.iter().try_for_each(|&e| { - extract_subquery_filters(e, &mut subqueries_in_regular) - })?; - - if !subqueries_in_regular.is_empty() { - return Ok(LogicalPlan::Filter(Filter { - predicate: predicate.clone(), - input: Arc::new(optimized_input), - })); - }; - - // Add subquery joins to new_input - // optimized_input value should retain for possible optimization rollback - let opt_result = subquery_filters.iter().try_fold( - optimized_input.clone(), - |input, &e| match e { - Expr::InSubquery { - expr, - subquery, - negated, - } => { - let right_input = self.optimize( - &*subquery.subquery, - execution_props - )?; - let right_schema = right_input.schema(); - if right_schema.fields().len() != 1 { - return Err(DataFusionError::Plan( - "Only single column allowed in InSubquery" - .to_string(), - )); - }; - - let right_key = right_schema.field(0).qualified_column(); - let left_key = match *expr.clone() { - Expr::Column(col) => col, - _ => return Err(DataFusionError::NotImplemented( - "Filtering by expression not implemented for InSubquery" - .to_string(), - )), - }; - - let join_type = if *negated { - JoinType::Anti - } else { - JoinType::Semi - }; - - let schema = build_join_schema( - optimized_input.schema(), - right_schema, - &join_type, - )?; - - Ok(LogicalPlan::Join(Join { - left: 
Arc::new(input), - right: Arc::new(right_input), - on: vec![(left_key, right_key)], - join_type, - join_constraint: JoinConstraint::On, - schema: Arc::new(schema), - null_equals_null: false, - })) - } - _ => Err(DataFusionError::Plan( - "Unknown expression while rewriting subquery to joins" - .to_string(), - )), - } - ); - - // In case of expressions which could not be rewritten - // return original filter with optimized input - let new_input = match opt_result { - Ok(plan) => plan, - Err(_) => { + let (subquery_filters, remainder): (Vec<&Expr>, Vec<&Expr>) = + filters.into_iter().partition(|&e| { + matches!(e, Expr::InSubquery { .. } | Expr::Exists { .. }) + }); + + let remaining_predicate = utils::combine_conjunctive(&remainder); + + if let Some(remaining) = remaining_predicate { + // Since we are unable to simplify a subquery nested inside the remaining predicate, + // we abort and keep the complete original predicate (a row scan is needed anyway) + // + // TODO: complex expressions which are disjunctive with our subquery expressions + // can be rewritten as unions (without deduplication...)? + if utils::contains_joinable_subquery(&remaining)?
{ return Ok(LogicalPlan::Filter(Filter { predicate: predicate.clone(), input: Arc::new(optimized_input), - })) + })); } - }; - - // Apply regular filters to join output if some or just return join - if regular_filters.is_empty() { - Ok(new_input) - } else { - Ok(utils::add_filter(new_input, &regular_filters)) } + + // Add subquery joins to optimized_input + let new_input = subquery_filters.iter().try_fold( + optimized_input, + |outer_plan, &subquery_expr| { + self.rewrite_correlated_subquery_as_join( + outer_plan, + subquery_expr, + execution_props, + ) + }, + )?; + + // Apply filters to join output if any + Ok(utils::filter_by_all(new_input, &remainder)) } _ => { // Apply the optimization to all inputs of the plan @@ -176,24 +281,12 @@ impl OptimizerRule for SubqueryFilterToJoin { } } -fn extract_subquery_filters(expression: &Expr, extracted: &mut Vec<Expr>) -> Result<()> { - utils::expr_sub_expressions(expression)? - .into_iter() - .try_for_each(|se| match se { - Expr::InSubquery { ..
} => { - extracted.push(se); - Ok(()) - } - _ => extract_subquery_filters(&se, extracted), - }) -} - #[cfg(test)] mod tests { use super::*; use crate::logical_plan::{ - and, binary_expr, col, in_subquery, lit, not_in_subquery, or, LogicalPlanBuilder, - Operator, + and, binary_expr, col, exists, in_subquery, lit, not_exists, not_in_subquery, or, + LogicalPlanBuilder, Operator, }; use crate::test::*; @@ -348,11 +441,36 @@ mod tests { let expected = "Projection: #test.b [b:UInt32]\ \n Semi Join: #test.b = #sq.a [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]\ - \n Projection: #sq.a [a:UInt32]\ - \n Semi Join: #sq.a = #sq_nested.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: sq projection=None [a:UInt32, b:UInt32, c:UInt32]\ - \n Projection: #sq_nested.c [c:UInt32]\ - \n TableScan: sq_nested projection=None [a:UInt32, b:UInt32, c:UInt32]"; + \n Semi Join: #sq.a = #sq_nested.c [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: sq projection=None [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: #sq_nested.c [c:UInt32]\ + \n TableScan: sq_nested projection=None [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + + /// Test for IN subquery with additional correlated (dependent) predicate + #[test] + fn in_subquery_with_correlated_filters() -> Result<()> { + let table_a = test_table_scan_with_name("table_a")?; + let table_b = test_table_scan_with_name("table_b")?; + + let subquery = LogicalPlanBuilder::from(table_b) + .filter(col("table_a.a").eq(col("table_b.a")))? + .project(vec![col("c")])? + .build()?; + + let plan = LogicalPlanBuilder::from(table_a) + .filter(in_subquery(col("c"), Arc::new(subquery)))? + .project(vec![col("table_a.b")])? 
+ .build()?; + + let expected = "\ + Projection: #table_a.b [b:UInt32]\ + \n Semi Join: #table_a.a = #table_b.a, #table_a.c = #table_b.c [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: table_a projection=None [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: table_b projection=None [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -386,4 +504,112 @@ mod tests { assert_optimized_plan_eq(&plan, expected); Ok(()) } + + #[test] + fn test_exists_simple() -> Result<()> { + let table_a = test_table_scan_with_name("table_a")?; + let table_b = test_table_scan_with_name("table_b")?; + let subquery = LogicalPlanBuilder::from(table_b) + .filter(col("table_a.a").eq(col("table_b.a")))? + .build()?; + + let plan = LogicalPlanBuilder::from(table_a) + .filter(exists(Arc::new(subquery)))? + .project(vec![col("a"), col("b")])? + .build()?; + + let expected = "\ + Projection: #table_a.a, #table_a.b [a:UInt32, b:UInt32]\ + \n Semi Join: #table_a.a = #table_b.a [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: table_a projection=None [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: table_b projection=None [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq(&plan, expected); + + Ok(()) + } + + #[test] + fn test_exists_multiple_correlated_filters() -> Result<()> { + let table_a = test_table_scan_with_name("table_a")?; + let table_b = test_table_scan_with_name("table_b")?; + + // Test AND and nested filters will be extracted as join columns + let subquery = LogicalPlanBuilder::from(table_b) + .filter( + (col("table_a.c").eq(col("table_b.c"))).and( + (col("table_a.a").eq(col("table_b.a"))) + .and(col("table_a.b").eq(col("table_b.b"))), + ), + )? + .build()?; + + let plan = LogicalPlanBuilder::from(table_a) + .filter(exists(Arc::new(subquery)))? + .project(vec![col("a"), col("b")])? 
+ .build()?; + + let expected = "\ + Projection: #table_a.a, #table_a.b [a:UInt32, b:UInt32]\ + \n Semi Join: #table_a.c = #table_b.c, #table_a.a = #table_b.a, #table_a.b = #table_b.b [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: table_a projection=None [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: table_b projection=None [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq(&plan, expected); + + Ok(()) + } + + #[test] + fn test_exists_with_non_correlated_filter() -> Result<()> { + let table_a = test_table_scan_with_name("table_a")?; + let table_b = test_table_scan_with_name("table_b")?; + let subquery = LogicalPlanBuilder::from(table_b) + .filter( + (col("table_a.a").eq(col("table_b.a"))) + .and(col("table_b.b").gt(lit("5"))), + )? + .build()?; + + let plan = LogicalPlanBuilder::from(table_a) + .project(vec![col("a"), col("b")])? + .filter(exists(Arc::new(subquery)))? + .build()?; + let expected = "\ + Semi Join: #table_a.a = #table_b.a [a:UInt32, b:UInt32]\ + \n Projection: #table_a.a, #table_a.b [a:UInt32, b:UInt32]\ + \n TableScan: table_a projection=None [a:UInt32, b:UInt32, c:UInt32]\ + \n Filter: #table_b.b > Utf8(\"5\") [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: table_b projection=None [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq(&plan, expected); + + Ok(()) + } + + // We only test not exists for the simplest case since all the other code paths + // are covered by exists + #[test] + fn test_not_exists_simple() -> Result<()> { + let table_a = test_table_scan_with_name("table_a")?; + let table_b = test_table_scan_with_name("table_b")?; + let subquery = LogicalPlanBuilder::from(table_b) + .filter(col("table_a.a").eq(col("table_b.a")))? + .build()?; + + let plan = LogicalPlanBuilder::from(table_a) + .filter(not_exists(Arc::new(subquery)))? + .project(vec![col("a"), col("b")])? 
+ .build()?; + + let expected = "\ + Projection: #table_a.a, #table_a.b [a:UInt32, b:UInt32]\ + \n Anti Join: #table_a.a = #table_b.a [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: table_a projection=None [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: table_b projection=None [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq(&plan, expected); + + Ok(()) + } } diff --git a/datafusion/core/src/optimizer/utils.rs b/datafusion/core/src/optimizer/utils.rs index 48855df9f8e8a..d6112d9cfb523 100644 --- a/datafusion/core/src/optimizer/utils.rs +++ b/datafusion/core/src/optimizer/utils.rs @@ -25,9 +25,9 @@ use datafusion_expr::logical_plan::{ }; use crate::logical_plan::{ - and, build_join_schema, Column, CreateMemoryTable, DFSchemaRef, Expr, ExprVisitable, - Limit, LogicalPlan, LogicalPlanBuilder, Operator, Partitioning, Recursion, - Repartition, Union, Values, + and, build_join_schema, or, Column, CreateMemoryTable, DFSchemaRef, Expr, + ExprVisitable, Limit, LogicalPlan, LogicalPlanBuilder, Operator, Partitioning, + Recursion, Repartition, Union, Values, }; use crate::prelude::lit; use crate::scalar::ScalarValue; @@ -556,6 +556,32 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { } } +/// Returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] whose +/// predicate is all `predicates` ANDed, or `plan` itself when `predicates` is empty. +pub fn filter_by_all(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan { + if let Some(predicate) = combine_conjunctive(predicates) { + LogicalPlan::Filter(Filter { + predicate, + input: Arc::new(plan), + }) + } else { + plan + } +} + +/// Returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] whose +/// predicate is all `predicates` ORed, or `plan` itself when `predicates` is empty.
+pub fn filter_by_any(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan { + if let Some(predicate) = combine_disjunctive(predicates) { + LogicalPlan::Filter(Filter { + predicate, + input: Arc::new(plan), + }) + } else { + plan + } +} + /// converts "A AND B AND C" => [A, B, C] pub fn split_conjunction<'a>(predicate: &'a Expr, predicates: &mut Vec<&'a Expr>) { match predicate { @@ -574,23 +600,180 @@ pub fn split_conjunction<'a>(predicate: &'a Expr, predicates: &mut Vec<&'a Expr> } } -/// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with -/// its predicate be all `predicates` ANDed. -pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan { - // reduce filters to a single filter with an AND - let predicate = predicates - .iter() - .skip(1) - .fold(predicates[0].clone(), |acc, predicate| { - and(acc, (*predicate).to_owned()) - }); - - LogicalPlan::Filter(Filter { - predicate, - input: Arc::new(plan), - }) +/// Converts "A OR B OR C" => [A, B, C], looking through aliases +pub fn split_disjunction<'a>(predicate: &'a Expr, predicates: &mut Vec<&'a Expr>) { + match predicate { + Expr::BinaryExpr { + right, + op: Operator::Or, + left, + } => { + split_disjunction(left, predicates); + split_disjunction(right, predicates); + } + Expr::Alias(expr, _) => { + split_disjunction(expr, predicates); + } + other => predicates.push(other), + } +} + +/// Converts [A, B, C] -> A AND B AND C (`None` for an empty slice) +pub fn combine_conjunctive(predicates: &[&Expr]) -> Option<Expr> { + if predicates.is_empty() { + None + } else { + // reduce filters to a single filter with an AND + Some( + predicates + .iter() + .skip(1) + .fold(predicates[0].clone(), |acc, predicate| { + and(acc, (*predicate).to_owned()) + }), + ) + } +} + +/// Converts [A, B, C] -> A OR B OR C (`None` for an empty slice) +pub fn combine_disjunctive(predicates: &[&Expr]) -> Option<Expr> { + if predicates.is_empty() { + None + } else { + // reduce filters to a single filter with an OR + Some( + predicates + .iter() + .skip(1) + .fold(predicates[0].clone(),
|acc, predicate| { + or(acc, (*predicate).to_owned()) + }), + ) + } +} + +/// Visitor that sets its flag and stops recursing at the first joinable subquery (`IN` or `EXISTS`) +struct SubqueryVisitor<'a> { + contains_joinable_subquery: &'a mut bool, +} + +impl ExpressionVisitor for SubqueryVisitor<'_> { + fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> { + match expr { + Expr::InSubquery { .. } | Expr::Exists { .. } => { + *self.contains_joinable_subquery = true; + return Ok(Recursion::Stop(self)); + } + _ => {} + } + Ok(Recursion::Continue(self)) + } +} + +/// Recursively walk an expression tree, returning true if it encounters a joinable subquery +pub fn contains_joinable_subquery(expr: &Expr) -> Result<bool> { + let mut contains_joinable_subquery = false; + expr.accept(SubqueryVisitor { + contains_joinable_subquery: &mut contains_joinable_subquery, + })?; + Ok(contains_joinable_subquery) +} + +/// Checks if the column matches a field of the outer schema by qualified or unqualified name +pub(crate) fn column_is_correlated(outer: &Arc<DFSchema>, column: &Column) -> bool { + for field in outer.fields() { + if *column == field.qualified_column() || *column == field.unqualified_column() { + return true; + } + } + false +} + +/// Visitor that sets its flag and stops recursing at the first column of the outer schema +struct CorrelatedColumnsVisitor<'a> { + outer_schema: &'a Arc<DFSchema>, + contains_correlated_columns: &'a mut bool, +} + +impl ExpressionVisitor for CorrelatedColumnsVisitor<'_> { + fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> { + if let Expr::Column(c) = expr { + if column_is_correlated(self.outer_schema, c) { + *self.contains_correlated_columns = true; + return Ok(Recursion::Stop(self)); + } + } + Ok(Recursion::Continue(self)) + } +} + +/// Recursively walk an expression tree, returning true if it encounters a correlated column +pub fn contains_correlated_columns( + outer_schema: &Arc<DFSchema>, + expr: &Expr, +) -> Result<bool> { + let mut contains_correlated_columns = false; + expr.accept(CorrelatedColumnsVisitor { + outer_schema, +
contains_correlated_columns: &mut contains_correlated_columns, + })?; + Ok(contains_correlated_columns) +} + +/// Check if one of the columns belongs to the outer schema, then return (outer, inner) +fn maybe_correlated_columns( + outer: &Arc<DFSchema>, + column_a: &Column, + column_b: &Column, +) -> Option<(Column, Column)> { + if column_is_correlated(outer, column_a) { + return Some((column_a.clone(), column_b.clone())); + } else if column_is_correlated(outer, column_b) { + return Some((column_b.clone(), column_a.clone())); + } + None } +/// Moves conjuncts of the form Column = Column where one column belongs to the outer schema into +/// `correlated_columns` as (outer, inner) pairs; returns the other conjuncts ANDed, `None` if none remain +pub(crate) fn extract_correlated_as_join_columns( + expr: &Expr, + outer: &Arc<DFSchema>, + correlated_columns: &mut Vec<(Column, Column)>, +) -> Option<Expr> { + let mut filters = vec![]; + // This will also strip aliases + split_conjunction(expr, &mut filters); + + let mut non_correlated_predicates = vec![]; + for filter in filters { + match filter { + Expr::BinaryExpr { left, op, right } => { + let mut extracted_column = false; + if let (Expr::Column(column_a), Expr::Column(column_b)) = + (left.as_ref(), right.as_ref()) + { + if let Some(columns) = + maybe_correlated_columns(outer, column_a, column_b) + { + if *op == Operator::Eq { + correlated_columns.push(columns); + extracted_column = true; + } + } + } + if !extracted_column { + non_correlated_predicates.push(filter); + } + } + _ => non_correlated_predicates.push(filter), + } + } + + combine_conjunctive(&non_correlated_predicates) +} + + #[cfg(test)] mod tests { use super::*;