From 3e5ea5297fe655070571e9dbfae5ea9dae9e6338 Mon Sep 17 00:00:00 2001 From: yangjiang Date: Mon, 17 Oct 2022 22:15:15 +0800 Subject: [PATCH 1/5] Factorize common AND factors out of OR predicates to support filterPushDown as possible Signed-off-by: yangjiang --- datafusion/optimizer/src/filter_push_down.rs | 27 ++++- datafusion/optimizer/src/utils.rs | 115 +++++++++++++++++++ 2 files changed, 141 insertions(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/filter_push_down.rs index 08ba71cda0cd..d41e620ccdc5 100644 --- a/datafusion/optimizer/src/filter_push_down.rs +++ b/datafusion/optimizer/src/filter_push_down.rs @@ -529,7 +529,8 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { } LogicalPlan::Analyze { .. } => push_down(&state, plan), LogicalPlan::Filter(filter) => { - let predicates = utils::split_conjunction(filter.predicate()); + let filter_cnf = utils::CnfHelper::new().rewrite_to_cnf_impl(filter.predicate()); + let predicates = utils::split_conjunction(&filter_cnf); predicates .into_iter() @@ -952,6 +953,30 @@ mod tests { Ok(()) } + #[test] + fn filter_keep_partial_agg() -> Result<()> { + let table_scan = test_table_scan()?; + let f1 = col("c").eq(lit(1i64)).and(col("b").gt(lit(2i64))); + let f2 = col("c").eq(lit(1i64)).and(col("b").gt(lit(3i64))); + let filter = f1.or(f2); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])? + .filter(filter)? 
+ .build()?; + // filter of aggregate is after aggregation since they are non-commutative + // (c = 1 AND b > 2) OR (c = 1 AND b > 3) + // rewrite to CNF + // (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR c = 1) AND (b > 2 OR b > 3) + + let expected = "\ + Filter: test.c = Int64(1) OR b > Int64(3) AND b > Int64(2) OR test.c = Int64(1) AND b > Int64(2) OR b > Int64(3)\ + \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ + \n Filter: test.c = Int64(1) OR test.c = Int64(1)\ + \n TableScan: test"; + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written #[test] fn alias() -> Result<()> { diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 57702a71f8d8..276c8c0cde44 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -119,6 +119,121 @@ fn split_conjunction_owned_impl(expr: Expr, mut exprs: Vec) -> Vec { } } +/// Converts an expression to conjunctive normal form (CNF). +/// +/// The following expression is in CNF: +/// `(a OR b) AND (c OR d)` +/// The following is not in CNF: +/// `(a AND b) OR c`. +/// But it can be rewritten to an equivalent CNF expression: +/// `(a OR c) AND (b OR c)`. 
+/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_optimizer::utils::CnfHelper; +/// // (a=1 AND b=2)OR c = 3 +/// let expr1 = col("a").eq(lit(1)).and(col("b").eq(lit(2))); +/// let expr2 = col("c").eq(lit(3)); +/// let expr = expr1.or(expr2); +/// +/// //(a=1 or c=3)AND(b=2 or c=3) +/// let expr1 = col("a").eq(lit(1)).or(col("c").eq(lit(3))); +/// let expr2 = col("b").eq(lit(2)).or(col("c").eq(lit(3))); +/// let expect = expr1.and(expr2); +/// // use split_conjunction_owned to split them +/// assert_eq!(CnfHelper::new().rewrite_to_cnf_impl(&expr), expect); +/// ``` +/// +pub struct CnfHelper { + max_count: usize, + current_count: usize, + exprs: Vec, + original_expr: Option, +} + +impl CnfHelper { + pub fn new() -> Self { + CnfHelper { + max_count: 100, + current_count: 0, + exprs: vec![], + original_expr: None, + } + } + + pub fn new_with_max_count(max_count: usize) -> Self { + CnfHelper { + max_count, + current_count: 0, + exprs: vec![], + original_expr: None, + } + } + + fn increment_and_check_overload(&mut self) -> bool { + self.current_count += 1; + self.current_count >= self.max_count + } + + pub fn rewrite_to_cnf_impl(&mut self, expr: &Expr) -> Expr { + if self.original_expr.is_none() { + self.original_expr = Some(expr.clone()); + } + match expr { + Expr::BinaryExpr { left, op, right } => { + match op { + Operator::And => { + if self.increment_and_check_overload() { + return expr.clone(); + } + self.rewrite_to_cnf_impl(left); + self.rewrite_to_cnf_impl(right); + } + // (a AND b) OR (c AND d) = (a OR b) AND (a OR c) AND (b OR c) AND (b OR d) + Operator::Or => { + if self.increment_and_check_overload() { + return expr.clone(); + } + split_conjunction_owned(*left.clone()).iter().for_each(|l| { + split_conjunction_owned(*right.clone()) + .iter() + .for_each(|r| { + self.exprs.push(Expr::BinaryExpr { + left: Box::new(l.clone()), + op: Operator::Or, + right: Box::new(r.clone()), + }) + }) + }) + } + _ => { + if 
self.increment_and_check_overload() { + return expr.clone(); + } + self.exprs.push(expr.clone()); + } + } + } + other => { + self.exprs.push(other.clone()); + } + } + if self.current_count >= self.max_count { + self.original_expr.as_ref().unwrap().clone() + } else { + conjunction(self.exprs.clone()) + .unwrap_or_else(|| self.original_expr.as_ref().unwrap().clone()) + } + } +} + +impl Default for CnfHelper { + fn default() -> Self { + Self::new() + } +} + /// Combines an array of filter expressions into a single filter /// expression consisting of the input filter expressions joined with /// logical AND. From 9b893980e81d9b98c80c750bbb262fd241e73c28 Mon Sep 17 00:00:00 2001 From: yangjiang Date: Tue, 18 Oct 2022 21:28:19 +0800 Subject: [PATCH 2/5] add test and use `ExprRewriter` framework Signed-off-by: yangjiang --- datafusion/optimizer/src/utils.rs | 144 +++++++++++++++++++++++++----- 1 file changed, 121 insertions(+), 23 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 276c8c0cde44..d463c831b907 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -20,7 +20,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_common::{plan_err, Column, DFSchemaRef}; -use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter}; +use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; use datafusion_expr::{ and, col, @@ -131,6 +131,7 @@ fn split_conjunction_owned_impl(expr: Expr, mut exprs: Vec) -> Vec { /// # Example /// ``` /// # use datafusion_expr::{col, lit}; +/// # use datafusion_expr::expr_rewriter::ExprRewritable; /// # use datafusion_optimizer::utils::CnfHelper; /// // (a=1 AND b=2)OR c = 3 /// let expr1 = col("a").eq(lit(1)).and(col("b").eq(lit(2))); @@ -142,7 +143,7 @@ fn split_conjunction_owned_impl(expr: Expr, mut exprs: 
Vec) -> Vec { /// let expr2 = col("b").eq(lit(2)).or(col("c").eq(lit(3))); /// let expect = expr1.and(expr2); /// // use split_conjunction_owned to split them -/// assert_eq!(CnfHelper::new().rewrite_to_cnf_impl(&expr), expect); +/// assert_eq!(expr.rewrite(& mut CnfHelper::new()).unwrap(), expect); /// ``` /// pub struct CnfHelper { @@ -175,9 +176,12 @@ impl CnfHelper { self.current_count += 1; self.current_count >= self.max_count } +} - pub fn rewrite_to_cnf_impl(&mut self, expr: &Expr) -> Expr { - if self.original_expr.is_none() { +impl ExprRewriter for CnfHelper { + fn pre_visit(&mut self, expr: &Expr) -> Result { + let is_root = self.original_expr.is_none(); + if is_root { self.original_expr = Some(expr.clone()); } match expr { @@ -185,45 +189,59 @@ impl CnfHelper { match op { Operator::And => { if self.increment_and_check_overload() { - return expr.clone(); + return Ok(RewriteRecursion::Mutate); } - self.rewrite_to_cnf_impl(left); - self.rewrite_to_cnf_impl(right); } // (a AND b) OR (c AND d) = (a OR b) AND (a OR c) AND (b OR c) AND (b OR d) Operator::Or => { + let left = split_conjunction_owned(*left.clone()); + let right = split_conjunction_owned(*right.clone()); + let count = left.len() * right.len() - 1; + self.current_count += count; if self.increment_and_check_overload() { - return expr.clone(); + return Ok(RewriteRecursion::Mutate); } - split_conjunction_owned(*left.clone()).iter().for_each(|l| { - split_conjunction_owned(*right.clone()) - .iter() - .for_each(|r| { - self.exprs.push(Expr::BinaryExpr { - left: Box::new(l.clone()), - op: Operator::Or, - right: Box::new(r.clone()), - }) + left.iter().for_each(|l| { + right.iter().for_each(|r| { + self.exprs.push(Expr::BinaryExpr { + left: Box::new(l.clone()), + op: Operator::Or, + right: Box::new(r.clone()), }) - }) + }) + }); + return Ok(RewriteRecursion::Mutate); } _ => { if self.increment_and_check_overload() { - return expr.clone(); + return Ok(RewriteRecursion::Mutate); } 
self.exprs.push(expr.clone()); + return Ok(RewriteRecursion::Stop); } } } other => { + if self.increment_and_check_overload() { + return Ok(RewriteRecursion::Mutate); + } self.exprs.push(other.clone()); + return Ok(RewriteRecursion::Stop); } } + if is_root { + Ok(RewriteRecursion::Continue) + } else { + Ok(RewriteRecursion::Skip) + } + } + + fn mutate(&mut self, _expr: Expr) -> Result { if self.current_count >= self.max_count { - self.original_expr.as_ref().unwrap().clone() + Ok(self.original_expr.as_ref().unwrap().clone()) } else { - conjunction(self.exprs.clone()) - .unwrap_or_else(|| self.original_expr.as_ref().unwrap().clone()) + Ok(conjunction(self.exprs.clone()) + .unwrap_or_else(|| self.original_expr.as_ref().unwrap().clone())) } } } @@ -579,7 +597,7 @@ mod tests { use super::*; use arrow::datatypes::DataType; use datafusion_common::Column; - use datafusion_expr::{col, lit, utils::expr_to_columns}; + use datafusion_expr::{col, lit, or, utils::expr_to_columns}; use std::collections::HashSet; use std::ops::Add; @@ -773,4 +791,84 @@ mod tests { "mismatch rewriting expr_from: {expr_from} to {rewrite_to}" ) } + + #[test] + fn test_rewrite_cnf() { + let a_1 = col("a").eq(lit(1i64)); + let a_2 = col("a").eq(lit(2i64)); + + let b_1 = col("b").eq(lit(1i64)); + let b_2 = col("b").eq(lit(2i64)); + + // Test rewrite on a1_and_b2 and a2_and_b1 -> not change + let mut helper = CnfHelper::new(); + let expr1 = and(a_1.clone(), b_2.clone()); + let expect = expr1.clone(); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_and_b2 and a2_and_b1 -> (((a1 and b2) and a2) and b1) + let mut helper = CnfHelper::new(); + let expr1 = and(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); + let expect = and(a_1.clone(), b_2.clone()) + .and(a_2.clone()) + .and(b_1.clone()); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_or_b2 -> not change + let mut helper = 
CnfHelper::new(); + let expr1 = or(a_1.clone(), b_2.clone()); + let expect = expr1.clone(); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_and_b2 or a2_and_b1 -> a1_or_a2 and a1_or_b1 and b2_or_a2 and b2_or_b1 + let mut helper = CnfHelper::new(); + let expr1 = or(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); + let a1_or_a2 = or(a_1.clone(), a_2.clone()); + let a1_or_b1 = or(a_1.clone(), b_1.clone()); + let b2_or_a2 = or(b_2.clone(), a_2.clone()); + let b2_or_b1 = or(b_2.clone(), b_1.clone()); + let expect = and(a1_or_a2, a1_or_b1).and(b2_or_a2).and(b2_or_b1); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_or_b2 or a2_and_b1 -> (a1_or_b2 or a2) and (a1_or_b2 or b1) + let mut helper = CnfHelper::new(); + let a1_or_b2 = or(a_1.clone(), b_2.clone()); + let expr1 = or(or(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); + let expect = or(a1_or_b2.clone(), a_2.clone()).and(or(a1_or_b2, b_1.clone())); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_or_b2 or a2_or_b1 -> not change + let mut helper = CnfHelper::new(); + let expr1 = or(or(a_1, b_2), or(a_2, b_1)); + let expect = expr1.clone(); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + } + + #[test] + fn test_rewrite_cnf_overflow() { + // in this situation: + // AND = (a=1 and b=2) + // rewrite (AND * 10) or (AND * 10), it will produce 10 * 10 = 100 (a=1 or b=2) + // which causes size expansion. 
+ + let mut expr1 = col("test1").eq(lit(1i64)); + let expr2 = col("test2").eq(lit(2i64)); + + for _i in 0..9 { + expr1 = expr1.clone().and(expr2.clone()); + } + let expr3 = expr1.clone(); + let expr = or(expr1, expr3); + let mut helper = CnfHelper::new(); + let res = expr.clone().rewrite(&mut helper).unwrap(); + assert_eq!(100, helper.current_count); + assert_eq!(res, expr); + assert!(helper.current_count >= helper.max_count); + } } From 3e076a0c68ba1334a56a5eac2b8ab9018f73b6a8 Mon Sep 17 00:00:00 2001 From: yangjiang Date: Tue, 18 Oct 2022 21:43:34 +0800 Subject: [PATCH 3/5] add test and use framework 2 Signed-off-by: yangjiang --- datafusion/optimizer/src/filter_push_down.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/filter_push_down.rs index 3be39013c8d0..a768b0a7fed5 100644 --- a/datafusion/optimizer/src/filter_push_down.rs +++ b/datafusion/optimizer/src/filter_push_down.rs @@ -14,6 +14,7 @@ //! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan +use crate::utils::{split_conjunction, CnfHelper}; use crate::{utils, OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DFSchema, DataFusionError, Result}; use datafusion_expr::{ @@ -28,6 +29,7 @@ use datafusion_expr::{ utils::{expr_to_columns, exprlist_to_columns, from_plan}, Expr, Operator, TableProviderFilterPushDown, }; +use log::error; use std::collections::{HashMap, HashSet}; use std::iter::once; @@ -530,8 +532,14 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { } LogicalPlan::Analyze { .. 
} => push_down(&state, plan), LogicalPlan::Filter(filter) => { - let filter_cnf = utils::CnfHelper::new().rewrite_to_cnf_impl(filter.predicate()); - let predicates = utils::split_conjunction(&filter_cnf); + let filter_cnf = filter.predicate().clone().rewrite(&mut CnfHelper::new()); + let predicates = match filter_cnf { + Ok(ref expr) => split_conjunction(expr), + Err(e) => { + error!("Fail at CnfHelper rewrite: {}.", e); + split_conjunction(filter.predicate()) + } + }; predicates .into_iter() @@ -2369,7 +2377,7 @@ mod tests { .filter(filter)? .build()?; - let expected = "Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ + let expected = "Filter: test.a = d OR test.b = e AND test.a = d OR test.c < UInt32(10) AND test.b > UInt32(1) OR test.b = e\ \n CrossJoin:\ \n Projection: test.a, test.b, test.c\ \n Filter: test.b > UInt32(1) OR test.c < UInt32(10)\ From fce97fecb24b57910350d1b5dd8bd84186882b6e Mon Sep 17 00:00:00 2001 From: yangjiang Date: Tue, 18 Oct 2022 21:59:37 +0800 Subject: [PATCH 4/5] rebase master for binaryExpr change Signed-off-by: yangjiang --- datafusion/optimizer/src/utils.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 3d3740ceed2f..856195432622 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -186,7 +186,7 @@ impl ExprRewriter for CnfHelper { self.original_expr = Some(expr.clone()); } match expr { - Expr::BinaryExpr { left, op, right } => { + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { match op { Operator::And => { if self.increment_and_check_overload() { @@ -204,11 +204,11 @@ impl ExprRewriter for CnfHelper { } left.iter().for_each(|l| { right.iter().for_each(|r| { - self.exprs.push(Expr::BinaryExpr { + self.exprs.push(Expr::BinaryExpr(BinaryExpr { left: Box::new(l.clone()), op: Operator::Or, right: Box::new(r.clone()), - }) + })) }) }); return 
Ok(RewriteRecursion::Mutate); From e8cda828f4e1dafb7ddc91031b7a00a0507f040e Mon Sep 17 00:00:00 2001 From: yangjiang Date: Wed, 19 Oct 2022 15:07:20 +0800 Subject: [PATCH 5/5] fix tests and support split_conjunction on other type Signed-off-by: yangjiang --- benchmarks/expected-plans/q7.txt | 2 +- .../physical_plan/file_format/row_filter.rs | 4 +- datafusion/core/tests/sql/joins.rs | 7 +- datafusion/optimizer/src/utils.rs | 67 ++++++++++++------- 4 files changed, 52 insertions(+), 28 deletions(-) diff --git a/benchmarks/expected-plans/q7.txt b/benchmarks/expected-plans/q7.txt index a1d1806f9189..73fe8574a627 100644 --- a/benchmarks/expected-plans/q7.txt +++ b/benchmarks/expected-plans/q7.txt @@ -3,7 +3,7 @@ Sort: shipping.supp_nation ASC NULLS LAST, shipping.cust_nation ASC NULLS LAST, Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[SUM(shipping.volume)]] Projection: shipping.supp_nation, shipping.cust_nation, shipping.l_year, shipping.volume, alias=shipping Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, datepart(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS volume, alias=shipping - Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE") + Filter: n1.n_name = Utf8("FRANCE") OR n2.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") Inner Join: customer.c_nationkey = n2.n_nationkey Inner Join: supplier.s_nationkey = n1.n_nationkey Inner Join: orders.o_custkey = customer.c_custkey diff --git a/datafusion/core/src/physical_plan/file_format/row_filter.rs b/datafusion/core/src/physical_plan/file_format/row_filter.rs index dd9c8fb650fd..2ac55d368bf9 100644 --- a/datafusion/core/src/physical_plan/file_format/row_filter.rs +++ 
b/datafusion/core/src/physical_plan/file_format/row_filter.rs @@ -22,7 +22,7 @@ use arrow::record_batch::RecordBatch; use datafusion_common::{Column, DataFusionError, Result, ScalarValue, ToDFSchema}; use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; -use datafusion_expr::Expr; +use datafusion_expr::{Expr, Operator}; use datafusion_optimizer::utils::split_conjunction_owned; use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; @@ -253,7 +253,7 @@ pub fn build_row_filter( metadata: &ParquetMetaData, reorder_predicates: bool, ) -> Result> { - let predicates = split_conjunction_owned(expr); + let predicates = split_conjunction_owned(expr, Operator::And); let mut candidates: Vec = predicates .into_iter() diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 2ff4947b3214..1ba8cf7ac42e 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -1468,10 +1468,15 @@ async fn reduce_left_join_2() -> Result<()> { .expect(&msg); let state = ctx.state(); let plan = state.optimize(&plan)?; + + // filter expr: `t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')` + // can be rewritten to: `(t1.t1_int > 2 or t2.t2_int < 10) and (t2.t2_name != 'w' or t2.t2_int < 10)` + // the right part `(t2.t2_name != 'w' or t2.t2_int < 10)` can be pushed below the join and removed from the filter. 
+ let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 856195432622..f088085b8812 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -84,7 +84,7 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& /// /// # Example /// ``` -/// # use datafusion_expr::{col, lit}; +/// # use datafusion_expr::{col, lit, Operator}; /// # use datafusion_optimizer::utils::split_conjunction_owned; /// // a=1 AND b=2 /// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); @@ -96,23 +96,23 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& /// ]; /// /// // use split_conjunction_owned to split them -/// assert_eq!(split_conjunction_owned(expr), split); +/// assert_eq!(split_conjunction_owned(expr, Operator::And), split); /// ``` -pub fn split_conjunction_owned(expr: Expr) -> Vec { - split_conjunction_owned_impl(expr, vec![]) +pub fn 
split_conjunction_owned(expr: Expr, op: Operator) -> Vec { + split_conjunction_owned_impl(expr, op, vec![]) } -fn split_conjunction_owned_impl(expr: Expr, mut exprs: Vec) -> Vec { +fn split_conjunction_owned_impl( + expr: Expr, + operator: Operator, + mut exprs: Vec, +) -> Vec { match expr { - Expr::BinaryExpr(BinaryExpr { - right, - op: Operator::And, - left, - }) => { - let exprs = split_conjunction_owned_impl(*left, exprs); - split_conjunction_owned_impl(*right, exprs) + Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { + let exprs = split_conjunction_owned_impl(*left, Operator::And, exprs); + split_conjunction_owned_impl(*right, Operator::And, exprs) } - Expr::Alias(expr, _) => split_conjunction_owned_impl(*expr, exprs), + Expr::Alias(expr, _) => split_conjunction_owned_impl(*expr, Operator::And, exprs), other => { exprs.push(other); exprs @@ -157,7 +157,7 @@ pub struct CnfHelper { impl CnfHelper { pub fn new() -> Self { CnfHelper { - max_count: 100, + max_count: 50, current_count: 0, exprs: vec![], original_expr: None, @@ -195,15 +195,25 @@ impl ExprRewriter for CnfHelper { } // (a AND b) OR (c AND d) = (a OR b) AND (a OR c) AND (b OR c) AND (b OR d) Operator::Or => { - let left = split_conjunction_owned(*left.clone()); - let right = split_conjunction_owned(*right.clone()); - let count = left.len() * right.len() - 1; - self.current_count += count; + let left_and_split = + split_conjunction_owned(*left.clone(), Operator::And); + let right_and_split = + split_conjunction_owned(*right.clone(), Operator::And); + // Avoid creating too many Exprs, as happens in TPC-H q19. 
+ let lc = split_conjunction_owned(*left.clone(), Operator::Or) + .into_iter() + .flat_map(|e| split_conjunction_owned(e, Operator::And)) + .count(); + let rc = split_conjunction_owned(*right.clone(), Operator::Or) + .into_iter() + .flat_map(|e| split_conjunction_owned(e, Operator::And)) + .count(); + self.current_count += lc * rc - 1; if self.increment_and_check_overload() { return Ok(RewriteRecursion::Mutate); } - left.iter().for_each(|l| { - right.iter().for_each(|r| { + left_and_split.iter().for_each(|l| { + right_and_split.iter().for_each(|r| { self.exprs.push(Expr::BinaryExpr(BinaryExpr { left: Box::new(l.clone()), op: Operator::Or, @@ -643,13 +653,16 @@ mod tests { #[test] fn test_split_conjunction_owned() { let expr = col("a"); - assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); + assert_eq!( + split_conjunction_owned(expr.clone(), Operator::And), + vec![expr] + ); } #[test] fn test_split_conjunction_owned_two() { assert_eq!( - split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))), + split_conjunction_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), vec![col("a").eq(lit(5)), col("b")] ); } @@ -657,7 +670,10 @@ mod tests { #[test] fn test_split_conjunction_owned_alias() { assert_eq!( - split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))), + split_conjunction_owned( + col("a").eq(lit(5)).and(col("b").alias("the_alias")), + Operator::And + ), vec![ col("a").eq(lit(5)), // no alias on b @@ -703,7 +719,10 @@ mod tests { #[test] fn test_split_conjunction_owned_or() { let expr = col("a").eq(lit(5)).or(col("b")); - assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); + assert_eq!( + split_conjunction_owned(expr.clone(), Operator::And), + vec![expr] + ); } #[test]