From c75e71cbd1b3195cbbf8a2ebb60d62f6c3229e7a Mon Sep 17 00:00:00 2001
From: jackwener
Date: Wed, 30 Nov 2022 22:56:11 +0800
Subject: [PATCH 1/7] fix `push_down_filter` push column instead of Expr.

---
 datafusion/optimizer/src/push_down_filter.rs   | 12 +++++++++++-
 datafusion/optimizer/tests/integration-test.rs | 12 ++++++++++++
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index e59590df55157..427f55f0c268b 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -654,7 +654,17 @@ impl OptimizerRule for PushDownFilter {
                     }
                 }
 
-                let child = match conjunction(push_predicates) {
+                let mut replace_map = HashMap::new();
+                for expr in &agg.group_expr {
+                    replace_map.insert(expr.display_name()?, expr.clone());
+                }
+
+                let replaced_push_predicates = push_predicates
+                    .iter()
+                    .map(|expr| replace_cols_by_name(expr.clone(), &replace_map))
+                    .collect::<Result<Vec<_>>>()?;
+
+                let child = match conjunction(replaced_push_predicates) {
                     Some(predicate) => LogicalPlan::Filter(Filter::try_new(
                         predicate,
                         Arc::new((*agg.input).clone()),
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index 457ea833ef3a9..701d1a84c604b 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -304,6 +304,18 @@ fn join_keys_in_subquery_alias_1() {
     assert_eq!(expected, format!("{:?}", plan));
 }
 
+#[test]
+fn push_down_filter_groupby_expr_contains_alias() {
+    let sql = "SELECT * FROM (SELECT (col_int32 + col_uint32) AS c, count(*) FROM test GROUP BY 1) where c > 3";
+    let plan = test_sql(sql).unwrap();
+    let expected = "Projection: c, COUNT(UInt8(1))\
+    \n  Projection: test.col_int32 + test.col_uint32 AS c, COUNT(UInt8(1))\
+    \n    Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS Int32)]], aggr=[[COUNT(UInt8(1))]]\
+    \n      Filter: test.col_int32 + CAST(test.col_uint32 AS Int32) > Int32(3)\
+    \n        TableScan: test projection=[col_int32, col_uint32]";
+    assert_eq!(expected, format!("{:?}", plan));
+}
+
 fn test_sql(sql: &str) -> Result<LogicalPlan> {
     // parse the SQL
     let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
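Note on the change above: after CNF splitting, a predicate such as `c > 3` refers to the aggregate's output name `c`, which does not exist below the `Aggregate` node, so the copy that gets pushed down has to be rewritten in terms of the underlying group-by expression. Below is a minimal, self-contained sketch of that name-to-expression substitution on a toy expression type; it only models the idea, is not DataFusion's actual `Expr`/`replace_cols_by_name` API, and every name in it is invented for illustration.

    use std::collections::HashMap;

    // Toy expression tree standing in for DataFusion's `Expr` (illustration only).
    #[derive(Clone, Debug, PartialEq)]
    enum Expr {
        Column(String),
        Literal(i64),
        Add(Box<Expr>, Box<Expr>),
        Gt(Box<Expr>, Box<Expr>),
    }

    // Rough analogue of `display_name()`: the textual name an expression gets
    // when it becomes an output column of an Aggregate.
    fn display_name(e: &Expr) -> String {
        match e {
            Expr::Column(c) => c.clone(),
            Expr::Literal(v) => v.to_string(),
            Expr::Add(l, r) => format!("{} + {}", display_name(l), display_name(r)),
            Expr::Gt(l, r) => format!("{} > {}", display_name(l), display_name(r)),
        }
    }

    // Rough analogue of `replace_cols_by_name`: swap any column reference whose
    // name matches a group-by output name for the underlying group expression.
    fn replace_cols_by_name(e: Expr, map: &HashMap<String, Expr>) -> Expr {
        match e {
            Expr::Column(name) => map.get(&name).cloned().unwrap_or(Expr::Column(name)),
            Expr::Add(l, r) => Expr::Add(
                Box::new(replace_cols_by_name(*l, map)),
                Box::new(replace_cols_by_name(*r, map)),
            ),
            Expr::Gt(l, r) => Expr::Gt(
                Box::new(replace_cols_by_name(*l, map)),
                Box::new(replace_cols_by_name(*r, map)),
            ),
            other => other,
        }
    }

    fn main() {
        // GROUP BY (a + b); the aggregate exposes it under the name "a + b".
        let group_expr = Expr::Add(
            Box::new(Expr::Column("a".into())),
            Box::new(Expr::Column("b".into())),
        );
        let mut replace_map = HashMap::new();
        replace_map.insert(display_name(&group_expr), group_expr.clone());

        // Predicate written against the aggregate output: `a + b > 3`.
        let predicate = Expr::Gt(
            Box::new(Expr::Column("a + b".into())),
            Box::new(Expr::Literal(3)),
        );

        // Before pushing below the Aggregate, rewrite it against the input columns.
        let pushed = replace_cols_by_name(predicate, &replace_map);
        assert_eq!(
            pushed,
            Expr::Gt(Box::new(group_expr), Box::new(Expr::Literal(3)))
        );
        println!("{:?}", pushed);
    }

In the real rule the map is keyed by `display_name()` of each entry in `agg.group_expr`, exactly as the hunk above inserts into `replace_map`.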
From 018b9a583f3843ea34ef027df0b792d0047ede81 Mon Sep 17 00:00:00 2001
From: jackwener
Date: Wed, 30 Nov 2022 23:04:28 +0800
Subject: [PATCH 2/7] remove collect to avoid performance loss

---
 datafusion/optimizer/src/push_down_filter.rs | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index 427f55f0c268b..c74079c2ab6ab 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -625,14 +625,14 @@ impl OptimizerRule for PushDownFilter {
                 // * the aggregation columns themselves
 
                 // construct set of columns that `aggr_expr` depends on
-                let mut used_columns = HashSet::new();
-                exprlist_to_columns(&agg.aggr_expr, &mut used_columns)?;
+                let mut aggr_expr_columns = HashSet::new();
+                exprlist_to_columns(&agg.aggr_expr, &mut aggr_expr_columns)?;
                 let agg_columns = agg
                     .aggr_expr
                     .iter()
                     .map(|x| Ok(Column::from_name(x.display_name()?)))
                     .collect::<Result<HashSet<_>>>()?;
-                used_columns.extend(agg_columns);
+                aggr_expr_columns.extend(agg_columns);
 
                 let predicates = utils::split_conjunction_owned(utils::cnf_rewrite(
                     filter.predicate().clone(),
@@ -643,10 +643,7 @@ impl OptimizerRule for PushDownFilter {
                 for expr in predicates {
                     let columns = expr.to_columns()?;
                     if columns.is_empty()
-                        || !columns
-                            .intersection(&used_columns)
-                            .collect::<HashSet<_>>()
-                            .is_empty()
+                        || columns.intersection(&aggr_expr_columns).next().is_some()
                     {
                         keep_predicates.push(expr);
                     } else {

From 992f3afea4ea3fd43444dbcdff338f4e66832c29 Mon Sep 17 00:00:00 2001
From: jackwener
Date: Thu, 1 Dec 2022 09:00:25 +0800
Subject: [PATCH 3/7] add UT

---
 datafusion/optimizer/src/push_down_filter.rs | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index c74079c2ab6ab..779e6dd6ea52e 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -887,6 +887,19 @@ mod tests {
         assert_optimized_plan_eq(&plan, expected)
     }
 
+    #[test]
+    fn push_agg_need_replace_expr() -> Result<()> {
+        let plan = LogicalPlanBuilder::from(test_table_scan()?)
+            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
+            .filter(col("test.b + test.a").gt(lit(10i64)))?
+            .build()?;
+        let expected =
+            "Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\
+        \n  Filter: test.b + test.a > Int64(10)\
+        \n    TableScan: test";
+        assert_optimized_plan_eq(&plan, expected)
+    }
+
     #[test]
     fn filter_keep_agg() -> Result<()> {
         let table_scan = test_table_scan()?;

From 80c025bd4f0803004e8258ae80969800f88202c8 Mon Sep 17 00:00:00 2001
From: jackwener
Date: Thu, 1 Dec 2022 12:34:11 +0800
Subject: [PATCH 4/7] enhance filter push through agg

---
 datafusion/optimizer/src/push_down_filter.rs | 48 ++++++--------------
 1 file changed, 14 insertions(+), 34 deletions(-)

diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index 779e6dd6ea52e..966fa92e40b4a 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -17,7 +17,6 @@
 use crate::utils::conjunction;
 use crate::{utils, OptimizerConfig, OptimizerRule};
 use datafusion_common::{Column, DFSchema, DataFusionError, Result};
-use datafusion_expr::utils::exprlist_to_columns;
 use datafusion_expr::{
     and,
     expr_rewriter::{replace_col, ExprRewritable, ExprRewriter},
@@ -620,19 +619,12 @@ impl OptimizerRule for PushDownFilter {
                 })
             }
             LogicalPlan::Aggregate(agg) => {
-                // An aggregate's aggregate columns are _not_ filter-commutable => collect these:
-                // * columns whose aggregation expression depends on
-                // * the aggregation columns themselves
-
-                // construct set of columns that `aggr_expr` depends on
-                let mut aggr_expr_columns = HashSet::new();
-                exprlist_to_columns(&agg.aggr_expr, &mut aggr_expr_columns)?;
-                let agg_columns = agg
-                    .aggr_expr
+                // We can push down Predicate which in groupby_expr.
+                let group_expr_columns = agg
+                    .group_expr
                     .iter()
-                    .map(|x| Ok(Column::from_name(x.display_name()?)))
+                    .map(|e| Ok(Column::from_qualified_name(&(e.display_name()?))))
                     .collect::<Result<HashSet<_>>>()?;
-                aggr_expr_columns.extend(agg_columns);
 
                 let predicates = utils::split_conjunction_owned(utils::cnf_rewrite(
                     filter.predicate().clone(),
@@ -641,13 +633,11 @@ impl OptimizerRule for PushDownFilter {
                 let mut keep_predicates = vec![];
                 let mut push_predicates = vec![];
                 for expr in predicates {
-                    let columns = expr.to_columns()?;
-                    if columns.is_empty()
-                        || columns.intersection(&aggr_expr_columns).next().is_some()
-                    {
-                        keep_predicates.push(expr);
-                    } else {
+                    let cols = expr.to_columns()?;
+                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
                         push_predicates.push(expr);
+                    } else {
+                        keep_predicates.push(expr);
                     }
                 }
 
@@ -864,7 +854,7 @@ mod tests {
         let table_scan = test_table_scan()?;
         let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
-            .filter(col("a").gt(lit(10i64)))?
+            .filter(col("test.a").gt(lit(10i64)))?
             .build()?;
         // filter of key aggregation is commutative
         let expected = "\
@@ -930,11 +920,9 @@ mod tests {
         // rewrite to CNF
         // (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR C = 1) AND (b > 2 OR b > 3)
 
-        let expected = "\
-        Filter: (test.c = Int64(1) OR b > Int64(3)) AND (b > Int64(2) OR test.c = Int64(1)) AND (b > Int64(2) OR b > Int64(3))\
+        let expected = "Filter: (test.c = Int64(1) OR test.c = Int64(1)) AND (test.c = Int64(1) OR b > Int64(3)) AND (b > Int64(2) OR test.c = Int64(1)) AND (b > Int64(2) OR b > Int64(3))\
         \n  Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
-        \n    Filter: test.c = Int64(1) OR test.c = Int64(1)\
-        \n      TableScan: test";
+        \n    TableScan: test";
         assert_optimized_plan_eq(&plan, expected)
     }
 
@@ -1890,17 +1878,9 @@ mod tests {
     #[async_trait]
     impl TableSource for PushDownProvider {
         fn schema(&self) -> SchemaRef {
-            Arc::new(arrow::datatypes::Schema::new(vec![
-                arrow::datatypes::Field::new(
-                    "a",
-                    arrow::datatypes::DataType::Int32,
-                    true,
-                ),
-                arrow::datatypes::Field::new(
-                    "b",
-                    arrow::datatypes::DataType::Int32,
-                    true,
-                ),
+            Arc::new(Schema::new(vec![
+                Field::new("a", DataType::Int32, true),
+                Field::new("b", DataType::Int32, true),
             ]))
         }
 

From e51daa200fabcbc2c825ba167c9461feba1aa606 Mon Sep 17 00:00:00 2001
From: jackwener
Date: Thu, 1 Dec 2022 12:50:17 +0800
Subject: [PATCH 5/7] add comment

---
 datafusion/optimizer/src/push_down_filter.rs | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index 966fa92e40b4a..f0073c203e4ef 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -641,11 +641,13 @@ impl OptimizerRule for PushDownFilter {
                     }
                 }
 
+                // As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)]
+                // After push, we need to replace `a+b` with Column(a)+Column(b)
+                // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))}
                 let mut replace_map = HashMap::new();
                 for expr in &agg.group_expr {
                     replace_map.insert(expr.display_name()?, expr.clone());
                 }
-
                 let replaced_push_predicates = push_predicates
                     .iter()
                     .map(|expr| replace_cols_by_name(expr.clone(), &replace_map))
                     .collect::<Result<Vec<_>>>()?;

From c9c89c57b81a966dca1a59d7457a94a0c4a76c0f Mon Sep 17 00:00:00 2001
From: jackwener
Date: Thu, 1 Dec 2022 13:01:58 +0800
Subject: [PATCH 6/7] polish

---
 datafusion/optimizer/src/push_down_filter.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index f0073c203e4ef..c58e6d23589f3 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -856,7 +856,7 @@ mod tests {
         let table_scan = test_table_scan()?;
         let plan = LogicalPlanBuilder::from(table_scan)
             .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
-            .filter(col("test.a").gt(lit(10i64)))?
+            .filter(col("a").gt(lit(10i64)))?
             .build()?;
         // filter of key aggregation is commutative
         let expected = "\

From f9a7072799504567b85e22d46d1fe2759faabafe Mon Sep 17 00:00:00 2001
From: jackwener
Date: Thu, 1 Dec 2022 14:50:13 +0800
Subject: [PATCH 7/7] remove wrong UT.
---
 datafusion/optimizer/src/push_down_filter.rs | 21 --------------------
 1 file changed, 21 deletions(-)

diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index c58e6d23589f3..b61ba0f2179b0 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -907,27 +907,6 @@ mod tests {
         assert_optimized_plan_eq(&plan, expected)
     }
 
-    #[test]
-    fn filter_keep_partial_agg() -> Result<()> {
-        let table_scan = test_table_scan()?;
-        let f1 = col("c").eq(lit(1i64)).and(col("b").gt(lit(2i64)));
-        let f2 = col("c").eq(lit(1i64)).and(col("b").gt(lit(3i64)));
-        let filter = f1.or(f2);
-        let plan = LogicalPlanBuilder::from(table_scan)
-            .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
-            .filter(filter)?
-            .build()?;
-        // filter of aggregate is after aggregation since they are non-commutative
-        // (c =1 AND b > 2) OR (c = 1 AND b > 3)
-        // rewrite to CNF
-        // (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR C = 1) AND (b > 2 OR b > 3)
-
-        let expected = "Filter: (test.c = Int64(1) OR test.c = Int64(1)) AND (test.c = Int64(1) OR b > Int64(3)) AND (b > Int64(2) OR test.c = Int64(1)) AND (b > Int64(2) OR b > Int64(3))\
-        \n  Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
-        \n    TableScan: test";
-        assert_optimized_plan_eq(&plan, expected)
-    }
-
     /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written
     #[test]
     fn alias() -> Result<()> {
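Taken together, patches 4 through 7 turn the aggregate case from a blacklist (keep any conjunct that touches an aggregate column) into a whitelist: a conjunct may be pushed below the `Aggregate` only if every column it references is one of the group-by output columns, which is why the partially-pushed CNF expectation in `filter_keep_partial_agg` no longer holds and the test is dropped. Below is a minimal sketch of that classification step, using plain strings instead of DataFusion's `Column`/`Expr` types; all names in it are invented for illustration and it is not the rule's actual code.

    use std::collections::HashSet;

    // A predicate conjunct reduced to the set of column names it references.
    struct Conjunct {
        text: &'static str,
        columns: Vec<&'static str>,
    }

    // Split conjuncts into those safe to push below the Aggregate (only group-by
    // output columns referenced) and those that must stay above it.
    fn classify<'a>(
        predicates: &'a [Conjunct],
        group_by: &HashSet<&str>,
    ) -> (Vec<&'a Conjunct>, Vec<&'a Conjunct>) {
        let mut push = vec![];
        let mut keep = vec![];
        for p in predicates {
            if p.columns.iter().all(|c| group_by.contains(c)) {
                push.push(p);
            } else {
                keep.push(p);
            }
        }
        (push, keep)
    }

    fn main() {
        // GROUP BY test.a, with an aggregate output aliased as `b` (= SUM(test.b)).
        let group_by: HashSet<&str> = ["test.a"].into_iter().collect();
        let predicates = [
            Conjunct { text: "test.a > 10", columns: vec!["test.a"] },
            Conjunct { text: "b > 2", columns: vec!["b"] },
        ];

        let (push, keep) = classify(&predicates, &group_by);
        assert_eq!(push.len(), 1); // only the group-by predicate is pushed
        assert_eq!(keep.len(), 1); // the aggregate-output predicate stays above

        for p in push {
            println!("push below Aggregate: {}", p.text);
        }
        for p in keep {
            println!("keep above Aggregate: {}", p.text);
        }
    }

A conjunct that references no columns at all (for example a literal-only predicate) passes the `all` check vacuously and is pushed, matching the behaviour of the rewritten loop in patch 4.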