diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index e59590df55157..b61ba0f2179b0 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -17,7 +17,6 @@
 use crate::utils::conjunction;
 use crate::{utils, OptimizerConfig, OptimizerRule};
 use datafusion_common::{Column, DFSchema, DataFusionError, Result};
-use datafusion_expr::utils::exprlist_to_columns;
 use datafusion_expr::{
     and,
     expr_rewriter::{replace_col, ExprRewritable, ExprRewriter},
@@ -620,19 +619,12 @@ impl OptimizerRule for PushDownFilter {
                 })
             }
             LogicalPlan::Aggregate(agg) => {
-                // An aggregate's aggregate columns are _not_ filter-commutable => collect these:
-                // * columns whose aggregation expression depends on
-                // * the aggregation columns themselves
-
-                // construct set of columns that `aggr_expr` depends on
-                let mut used_columns = HashSet::new();
-                exprlist_to_columns(&agg.aggr_expr, &mut used_columns)?;
-                let agg_columns = agg
-                    .aggr_expr
+                // We can push down predicates that reference only group-by expressions.
+                let group_expr_columns = agg
+                    .group_expr
                     .iter()
-                    .map(|x| Ok(Column::from_name(x.display_name()?)))
+                    .map(|e| Ok(Column::from_qualified_name(&(e.display_name()?))))
                     .collect::<Result<HashSet<_>>>()?;
-                used_columns.extend(agg_columns);
 
                 let predicates = utils::split_conjunction_owned(utils::cnf_rewrite(
                     filter.predicate().clone(),
@@ -641,20 +633,27 @@
                 let mut keep_predicates = vec![];
                 let mut push_predicates = vec![];
                 for expr in predicates {
-                    let columns = expr.to_columns()?;
-                    if columns.is_empty()
-                        || !columns
-                            .intersection(&used_columns)
-                            .collect::<HashSet<_>>()
-                            .is_empty()
-                    {
-                        keep_predicates.push(expr);
-                    } else {
+                    let cols = expr.to_columns()?;
+                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
                         push_predicates.push(expr);
+                    } else {
+                        keep_predicates.push(expr);
                     }
                 }
 
-                let child = match conjunction(push_predicates) {
+                // For a plan like Filter: "a + b" > 0 above Aggregate: groupBy=[[Column(a) + Column(b)]],
+                // the pushed predicate references the aggregate's output column named "a + b", so we
+                // must rewrite it via a replace_map: {"a + b" -> Expr(Column(a) + Column(b))}
+                let mut replace_map = HashMap::new();
+                for expr in &agg.group_expr {
+                    replace_map.insert(expr.display_name()?, expr.clone());
+                }
+                let replaced_push_predicates = push_predicates
+                    .iter()
+                    .map(|expr| replace_cols_by_name(expr.clone(), &replace_map))
+                    .collect::<Result<Vec<_>>>()?;
+
+                let child = match conjunction(replaced_push_predicates) {
                     Some(predicate) => LogicalPlan::Filter(Filter::try_new(
                         predicate,
                         Arc::new((*agg.input).clone()),
@@ -881,40 +880,30 @@ mod tests {
     }
 
     #[test]
-    fn filter_keep_agg() -> Result<()> {
-        let table_scan = test_table_scan()?;
-        let plan = LogicalPlanBuilder::from(table_scan)
-            .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
-            .filter(col("b").gt(lit(10i64)))?
+    fn push_agg_need_replace_expr() -> Result<()> {
+        let plan = LogicalPlanBuilder::from(test_table_scan()?)
+            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
+            .filter(col("test.b + test.a").gt(lit(10i64)))?
             .build()?;
-        // filter of aggregate is after aggregation since they are non-commutative
-        let expected = "\
-        Filter: b > Int64(10)\
-        \n  Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
-        \n    TableScan: test";
+        let expected =
+            "Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\
+            \n  Filter: test.b + test.a > Int64(10)\
+            \n    TableScan: test";
         assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
-    fn filter_keep_partial_agg() -> Result<()> {
+    fn filter_keep_agg() -> Result<()> {
         let table_scan = test_table_scan()?;
-        let f1 = col("c").eq(lit(1i64)).and(col("b").gt(lit(2i64)));
-        let f2 = col("c").eq(lit(1i64)).and(col("b").gt(lit(3i64)));
-        let filter = f1.or(f2);
-
         let plan = LogicalPlanBuilder::from(table_scan)
             .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
-            .filter(filter)?
+            .filter(col("b").gt(lit(10i64)))?
             .build()?;
         // filter of aggregate is after aggregation since they are non-commutative
-        // (c =1 AND b > 2) OR (c = 1 AND b > 3)
-        // rewrite to CNF
-        // (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR C = 1) AND (b > 2 OR b > 3)
-        let expected = "\
-        Filter: (test.c = Int64(1) OR b > Int64(3)) AND (b > Int64(2) OR test.c = Int64(1)) AND (b > Int64(2) OR b > Int64(3))\
-        \n  Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
-        \n    Filter: test.c = Int64(1) OR test.c = Int64(1)\
-        \n      TableScan: test";
+        let expected = "\
+        Filter: b > Int64(10)\
+        \n  Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
+        \n    TableScan: test";
         assert_optimized_plan_eq(&plan, expected)
     }
 
@@ -1870,17 +1859,9 @@ mod tests {
     #[async_trait]
     impl TableSource for PushDownProvider {
         fn schema(&self) -> SchemaRef {
-            Arc::new(arrow::datatypes::Schema::new(vec![
-                arrow::datatypes::Field::new(
-                    "a",
-                    arrow::datatypes::DataType::Int32,
-                    true,
-                ),
-                arrow::datatypes::Field::new(
-                    "b",
-                    arrow::datatypes::DataType::Int32,
-                    true,
-                ),
+            Arc::new(Schema::new(vec![
+                Field::new("a", DataType::Int32, true),
+                Field::new("b", DataType::Int32, true),
             ]))
         }
 
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index 457ea833ef3a9..701d1a84c604b 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -304,6 +304,18 @@ fn join_keys_in_subquery_alias_1() {
     assert_eq!(expected, format!("{:?}", plan));
 }
 
+#[test]
+fn push_down_filter_groupby_expr_contains_alias() {
+    let sql = "SELECT * FROM (SELECT (col_int32 + col_uint32) AS c, count(*) FROM test GROUP BY 1) where c > 3";
+    let plan = test_sql(sql).unwrap();
+    let expected = "Projection: c, COUNT(UInt8(1))\
+    \n  Projection: test.col_int32 + test.col_uint32 AS c, COUNT(UInt8(1))\
+    \n    Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS Int32)]], aggr=[[COUNT(UInt8(1))]]\
+    \n      Filter: test.col_int32 + CAST(test.col_uint32 AS Int32) > Int32(3)\
+    \n        TableScan: test projection=[col_int32, col_uint32]";
+    assert_eq!(expected, format!("{:?}", plan));
+}
+
 fn test_sql(sql: &str) -> Result<LogicalPlan> {
     // parse the SQL
     let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
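Note on the rewrite above: a predicate sitting above an Aggregate refers to the aggregate's output column, whose name is the display name of the group-by expression (e.g. "test.b + test.a"). Below the aggregate that column does not exist yet, so the pushed-down predicate must be rewritten in terms of the underlying expression via the replace_map. The following is a minimal, self-contained sketch of that replace-map idea; the Expr type, display_name, and replace_cols_by_name below are toy stand-ins written for illustration, not DataFusion's actual definitions:

use std::collections::HashMap;

#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Column(String),
    Literal(i64),
    BinaryOp { left: Box<Expr>, op: char, right: Box<Expr> },
}

impl Expr {
    /// Render the expression the way a parent plan names its output column,
    /// e.g. "a + b" (mirrors the role DataFusion's display_name plays above).
    fn display_name(&self) -> String {
        match self {
            Expr::Column(name) => name.clone(),
            Expr::Literal(v) => v.to_string(),
            Expr::BinaryOp { left, op, right } => {
                format!("{} {} {}", left.display_name(), op, right.display_name())
            }
        }
    }
}

/// Replace every column reference whose name appears in `replace_map` with the
/// mapped expression (the role replace_cols_by_name plays in the diff above).
fn replace_cols_by_name(expr: Expr, replace_map: &HashMap<String, Expr>) -> Expr {
    match expr {
        Expr::Column(name) => replace_map
            .get(&name)
            .cloned()
            .unwrap_or(Expr::Column(name)),
        Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
            left: Box::new(replace_cols_by_name(*left, replace_map)),
            op,
            right: Box::new(replace_cols_by_name(*right, replace_map)),
        },
        other => other,
    }
}

fn main() {
    // group_expr = [a + b]; the aggregate's output column is therefore named "a + b".
    let group_expr = Expr::BinaryOp {
        left: Box::new(Expr::Column("a".into())),
        op: '+',
        right: Box::new(Expr::Column("b".into())),
    };

    // replace_map: {"a + b" -> Expr(Column(a) + Column(b))}
    let mut replace_map = HashMap::new();
    replace_map.insert(group_expr.display_name(), group_expr.clone());

    // Predicate as seen above the aggregate: Column("a + b") > 10.
    let predicate = Expr::BinaryOp {
        left: Box::new(Expr::Column("a + b".into())),
        op: '>',
        right: Box::new(Expr::Literal(10)),
    };

    // After rewriting, the predicate references the underlying expression and
    // can be evaluated below the aggregate: (a + b) > 10.
    let pushed = replace_cols_by_name(predicate, &replace_map);
    assert_eq!(
        pushed,
        Expr::BinaryOp {
            left: Box::new(group_expr),
            op: '>',
            right: Box::new(Expr::Literal(10)),
        }
    );
}

Keeping the check `cols.iter().all(|c| group_expr_columns.contains(c))` separate from the rewrite means a predicate is pushed only when every column it touches is produced by a group-by expression; anything touching an aggregate output (e.g. SUM(b)) stays in keep_predicates above the Aggregate.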