diff --git a/benchmarks/expected-plans/q7.txt b/benchmarks/expected-plans/q7.txt index a1d1806f91893..73fe8574a6272 100644 --- a/benchmarks/expected-plans/q7.txt +++ b/benchmarks/expected-plans/q7.txt @@ -3,7 +3,7 @@ Sort: shipping.supp_nation ASC NULLS LAST, shipping.cust_nation ASC NULLS LAST, Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[SUM(shipping.volume)]] Projection: shipping.supp_nation, shipping.cust_nation, shipping.l_year, shipping.volume, alias=shipping Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, datepart(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS volume, alias=shipping - Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE") + Filter: n1.n_name = Utf8("FRANCE") OR n2.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") Inner Join: customer.c_nationkey = n2.n_nationkey Inner Join: supplier.s_nationkey = n1.n_nationkey Inner Join: orders.o_custkey = customer.c_custkey diff --git a/datafusion/core/src/physical_plan/file_format/row_filter.rs b/datafusion/core/src/physical_plan/file_format/row_filter.rs index dd9c8fb650fd1..2ac55d368bf99 100644 --- a/datafusion/core/src/physical_plan/file_format/row_filter.rs +++ b/datafusion/core/src/physical_plan/file_format/row_filter.rs @@ -22,7 +22,7 @@ use arrow::record_batch::RecordBatch; use datafusion_common::{Column, DataFusionError, Result, ScalarValue, ToDFSchema}; use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; -use datafusion_expr::Expr; +use datafusion_expr::{Expr, Operator}; use datafusion_optimizer::utils::split_conjunction_owned; use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; @@ -253,7 +253,7 @@ pub fn build_row_filter( metadata: &ParquetMetaData, reorder_predicates: bool, ) -> Result> { - let predicates = split_conjunction_owned(expr); + let predicates = split_conjunction_owned(expr, Operator::And); let mut candidates: Vec = predicates .into_iter() diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 2ff4947b3214a..1ba8cf7ac42ef 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -1468,10 +1468,15 @@ async fn reduce_left_join_2() -> Result<()> { .expect(&msg); let state = ctx.state(); let plan = state.optimize(&plan)?; + + // filter expr: `t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')` + // could be write to: `(t1.t1_int > 2 or t2.t2_int < 10) and (t2.t2_name != 'w' or t2.t2_int < 10)` + // the right part `(t2.t2_name != 'w' or t2.t2_int < 10)` could be push down left join side and remove in filter. + let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/filter_push_down.rs index 6396f1fbfd6c1..a768b0a7fed58 100644 --- a/datafusion/optimizer/src/filter_push_down.rs +++ b/datafusion/optimizer/src/filter_push_down.rs @@ -14,6 +14,7 @@ //! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan +use crate::utils::{split_conjunction, CnfHelper}; use crate::{utils, OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DFSchema, DataFusionError, Result}; use datafusion_expr::{ @@ -28,6 +29,7 @@ use datafusion_expr::{ utils::{expr_to_columns, exprlist_to_columns, from_plan}, Expr, Operator, TableProviderFilterPushDown, }; +use log::error; use std::collections::{HashMap, HashSet}; use std::iter::once; @@ -530,7 +532,14 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { } LogicalPlan::Analyze { .. } => push_down(&state, plan), LogicalPlan::Filter(filter) => { - let predicates = utils::split_conjunction(filter.predicate()); + let filter_cnf = filter.predicate().clone().rewrite(&mut CnfHelper::new()); + let predicates = match filter_cnf { + Ok(ref expr) => split_conjunction(expr), + Err(e) => { + error!("Fail at CnfHelper rewrite: {}.", e); + split_conjunction(filter.predicate()) + } + }; predicates .into_iter() @@ -953,6 +962,30 @@ mod tests { Ok(()) } + #[test] + fn filter_keep_partial_agg() -> Result<()> { + let table_scan = test_table_scan()?; + let f1 = col("c").eq(lit(1i64)).and(col("b").gt(lit(2i64))); + let f2 = col("c").eq(lit(1i64)).and(col("b").gt(lit(3i64))); + let filter = f1.or(f2); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])? + .filter(filter)? + .build()?; + // filter of aggregate is after aggregation since they are non-commutative + // (c =1 AND b > 2) OR (c = 1 AND b > 3) + // rewrite to CNF + // (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR C = 1) AND (b > 2 OR b > 3) + + let expected = "\ + Filter: test.c = Int64(1) OR b > Int64(3) AND b > Int64(2) OR test.c = Int64(1) AND b > Int64(2) OR b > Int64(3)\ + \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ + \n Filter: test.c = Int64(1) OR test.c = Int64(1)\ + \n TableScan: test"; + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written #[test] fn alias() -> Result<()> { @@ -2344,7 +2377,7 @@ mod tests { .filter(filter)? .build()?; - let expected = "Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ + let expected = "Filter: test.a = d OR test.b = e AND test.a = d OR test.c < UInt32(10) AND test.b > UInt32(1) OR test.b = e\ \n CrossJoin:\ \n Projection: test.a, test.b, test.c\ \n Filter: test.b > UInt32(1) OR test.c < UInt32(10)\ diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 130df3e0e6efc..f088085b8812b 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -21,7 +21,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_common::{plan_err, Column, DFSchemaRef}; use datafusion_expr::expr::BinaryExpr; -use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter}; +use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; use datafusion_expr::{ and, col, @@ -84,7 +84,7 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& /// /// # Example /// ``` -/// # use datafusion_expr::{col, lit}; +/// # use datafusion_expr::{col, lit, Operator}; /// # use datafusion_optimizer::utils::split_conjunction_owned; /// // a=1 AND b=2 /// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); @@ -96,23 +96,23 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& /// ]; /// /// // use split_conjunction_owned to split them -/// assert_eq!(split_conjunction_owned(expr), split); +/// assert_eq!(split_conjunction_owned(expr, Operator::And), split); /// ``` -pub fn split_conjunction_owned(expr: Expr) -> Vec { - split_conjunction_owned_impl(expr, vec![]) +pub fn split_conjunction_owned(expr: Expr, op: Operator) -> Vec { + split_conjunction_owned_impl(expr, op, vec![]) } -fn split_conjunction_owned_impl(expr: Expr, mut exprs: Vec) -> Vec { +fn split_conjunction_owned_impl( + expr: Expr, + operator: Operator, + mut exprs: Vec, +) -> Vec { match expr { - Expr::BinaryExpr(BinaryExpr { - right, - op: Operator::And, - left, - }) => { - let exprs = split_conjunction_owned_impl(*left, exprs); - split_conjunction_owned_impl(*right, exprs) + Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { + let exprs = split_conjunction_owned_impl(*left, Operator::And, exprs); + split_conjunction_owned_impl(*right, Operator::And, exprs) } - Expr::Alias(expr, _) => split_conjunction_owned_impl(*expr, exprs), + Expr::Alias(expr, _) => split_conjunction_owned_impl(*expr, Operator::And, exprs), other => { exprs.push(other); exprs @@ -120,6 +120,149 @@ fn split_conjunction_owned_impl(expr: Expr, mut exprs: Vec) -> Vec { } } +/// Converts an expression to conjunctive normal form (CNF). +/// +/// The following expression is in CNF: +/// `(a OR b) AND (c OR d)` +/// The following is not in CNF: +/// `(a AND b) OR c`. +/// But could be rewrite to a CNF expression: +/// `(a OR c) AND (b OR c)`. +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_expr::expr_rewriter::ExprRewritable; +/// # use datafusion_optimizer::utils::CnfHelper; +/// // (a=1 AND b=2)OR c = 3 +/// let expr1 = col("a").eq(lit(1)).and(col("b").eq(lit(2))); +/// let expr2 = col("c").eq(lit(3)); +/// let expr = expr1.or(expr2); +/// +/// //(a=1 or c=3)AND(b=2 or c=3) +/// let expr1 = col("a").eq(lit(1)).or(col("c").eq(lit(3))); +/// let expr2 = col("b").eq(lit(2)).or(col("c").eq(lit(3))); +/// let expect = expr1.and(expr2); +/// // use split_conjunction_owned to split them +/// assert_eq!(expr.rewrite(& mut CnfHelper::new()).unwrap(), expect); +/// ``` +/// +pub struct CnfHelper { + max_count: usize, + current_count: usize, + exprs: Vec, + original_expr: Option, +} + +impl CnfHelper { + pub fn new() -> Self { + CnfHelper { + max_count: 50, + current_count: 0, + exprs: vec![], + original_expr: None, + } + } + + pub fn new_with_max_count(max_count: usize) -> Self { + CnfHelper { + max_count, + current_count: 0, + exprs: vec![], + original_expr: None, + } + } + + fn increment_and_check_overload(&mut self) -> bool { + self.current_count += 1; + self.current_count >= self.max_count + } +} + +impl ExprRewriter for CnfHelper { + fn pre_visit(&mut self, expr: &Expr) -> Result { + let is_root = self.original_expr.is_none(); + if is_root { + self.original_expr = Some(expr.clone()); + } + match expr { + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + match op { + Operator::And => { + if self.increment_and_check_overload() { + return Ok(RewriteRecursion::Mutate); + } + } + // (a AND b) OR (c AND d) = (a OR b) AND (a OR c) AND (b OR c) AND (b OR d) + Operator::Or => { + let left_and_split = + split_conjunction_owned(*left.clone(), Operator::And); + let right_and_split = + split_conjunction_owned(*right.clone(), Operator::And); + // Avoid create to much Expr like in tpch q19. + let lc = split_conjunction_owned(*left.clone(), Operator::Or) + .into_iter() + .flat_map(|e| split_conjunction_owned(e, Operator::And)) + .count(); + let rc = split_conjunction_owned(*right.clone(), Operator::Or) + .into_iter() + .flat_map(|e| split_conjunction_owned(e, Operator::And)) + .count(); + self.current_count += lc * rc - 1; + if self.increment_and_check_overload() { + return Ok(RewriteRecursion::Mutate); + } + left_and_split.iter().for_each(|l| { + right_and_split.iter().for_each(|r| { + self.exprs.push(Expr::BinaryExpr(BinaryExpr { + left: Box::new(l.clone()), + op: Operator::Or, + right: Box::new(r.clone()), + })) + }) + }); + return Ok(RewriteRecursion::Mutate); + } + _ => { + if self.increment_and_check_overload() { + return Ok(RewriteRecursion::Mutate); + } + self.exprs.push(expr.clone()); + return Ok(RewriteRecursion::Stop); + } + } + } + other => { + if self.increment_and_check_overload() { + return Ok(RewriteRecursion::Mutate); + } + self.exprs.push(other.clone()); + return Ok(RewriteRecursion::Stop); + } + } + if is_root { + Ok(RewriteRecursion::Continue) + } else { + Ok(RewriteRecursion::Skip) + } + } + + fn mutate(&mut self, _expr: Expr) -> Result { + if self.current_count >= self.max_count { + Ok(self.original_expr.as_ref().unwrap().clone()) + } else { + Ok(conjunction(self.exprs.clone()) + .unwrap_or_else(|| self.original_expr.as_ref().unwrap().clone())) + } + } +} + +impl Default for CnfHelper { + fn default() -> Self { + Self::new() + } +} + /// Combines an array of filter expressions into a single filter /// expression consisting of the input filter expressions joined with /// logical AND. @@ -469,7 +612,7 @@ mod tests { use super::*; use arrow::datatypes::DataType; use datafusion_common::Column; - use datafusion_expr::{col, lit, utils::expr_to_columns}; + use datafusion_expr::{col, lit, or, utils::expr_to_columns}; use std::collections::HashSet; use std::ops::Add; @@ -510,13 +653,16 @@ mod tests { #[test] fn test_split_conjunction_owned() { let expr = col("a"); - assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); + assert_eq!( + split_conjunction_owned(expr.clone(), Operator::And), + vec![expr] + ); } #[test] fn test_split_conjunction_owned_two() { assert_eq!( - split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))), + split_conjunction_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), vec![col("a").eq(lit(5)), col("b")] ); } @@ -524,7 +670,10 @@ mod tests { #[test] fn test_split_conjunction_owned_alias() { assert_eq!( - split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))), + split_conjunction_owned( + col("a").eq(lit(5)).and(col("b").alias("the_alias")), + Operator::And + ), vec![ col("a").eq(lit(5)), // no alias on b @@ -570,7 +719,10 @@ mod tests { #[test] fn test_split_conjunction_owned_or() { let expr = col("a").eq(lit(5)).or(col("b")); - assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]); + assert_eq!( + split_conjunction_owned(expr.clone(), Operator::And), + vec![expr] + ); } #[test] @@ -663,4 +815,84 @@ mod tests { "mismatch rewriting expr_from: {expr_from} to {rewrite_to}" ) } + + #[test] + fn test_rewrite_cnf() { + let a_1 = col("a").eq(lit(1i64)); + let a_2 = col("a").eq(lit(2i64)); + + let b_1 = col("b").eq(lit(1i64)); + let b_2 = col("b").eq(lit(2i64)); + + // Test rewrite on a1_and_b2 and a2_and_b1 -> not change + let mut helper = CnfHelper::new(); + let expr1 = and(a_1.clone(), b_2.clone()); + let expect = expr1.clone(); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_and_b2 and a2_and_b1 -> (((a1 and b2) and a2) and b1) + let mut helper = CnfHelper::new(); + let expr1 = and(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); + let expect = and(a_1.clone(), b_2.clone()) + .and(a_2.clone()) + .and(b_1.clone()); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_or_b2 -> not change + let mut helper = CnfHelper::new(); + let expr1 = or(a_1.clone(), b_2.clone()); + let expect = expr1.clone(); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_and_b2 or a2_and_b1 -> a1_or_a2 and a1_or_b1 and b2_or_a2 and b2_or_b1 + let mut helper = CnfHelper::new(); + let expr1 = or(and(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); + let a1_or_a2 = or(a_1.clone(), a_2.clone()); + let a1_or_b1 = or(a_1.clone(), b_1.clone()); + let b2_or_a2 = or(b_2.clone(), a_2.clone()); + let b2_or_b1 = or(b_2.clone(), b_1.clone()); + let expect = and(a1_or_a2, a1_or_b1).and(b2_or_a2).and(b2_or_b1); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_or_b2 or a2_and_b1 -> ( a1_or_a2 or a2 ) and (a1_or_a2 or b1) + let mut helper = CnfHelper::new(); + let a1_or_b2 = or(a_1.clone(), b_2.clone()); + let expr1 = or(or(a_1.clone(), b_2.clone()), and(a_2.clone(), b_1.clone())); + let expect = or(a1_or_b2.clone(), a_2.clone()).and(or(a1_or_b2, b_1.clone())); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + + // Test rewrite on a1_or_b2 or a2_or_b1 -> not change + let mut helper = CnfHelper::new(); + let expr1 = or(or(a_1, b_2), or(a_2, b_1)); + let expect = expr1.clone(); + let res = expr1.rewrite(&mut helper).unwrap(); + assert_eq!(expect, res); + } + + #[test] + fn test_rewrite_cnf_overflow() { + // in this situation: + // AND = (a=1 and b=2) + // rewrite (AND * 10) or (AND * 10), it will produce 10 * 10 = 100 (a=1 or b=2) + // which cause size expansion. + + let mut expr1 = col("test1").eq(lit(1i64)); + let expr2 = col("test2").eq(lit(2i64)); + + for _i in 0..9 { + expr1 = expr1.clone().and(expr2.clone()); + } + let expr3 = expr1.clone(); + let expr = or(expr1, expr3); + let mut helper = CnfHelper::new(); + let res = expr.clone().rewrite(&mut helper).unwrap(); + assert_eq!(100, helper.current_count); + assert_eq!(res, expr); + assert!(helper.current_count >= helper.max_count); + } }