diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index dd220e16e6b2c..b61f0c25e5d1c 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -19,7 +19,6 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{DFSchema, Result}; -use datafusion_expr::utils::grouping_set_to_exprlist; use datafusion_expr::{ col, logical_plan::{Aggregate, LogicalPlan, Projection}, @@ -64,80 +63,98 @@ fn optimize(plan: &LogicalPlan) -> Result { group_expr, }) => { if is_single_distinct_agg(plan)? && !contains_grouping_set(group_expr) { - let mut group_fields_set = HashSet::new(); - let base_group_expr = grouping_set_to_exprlist(group_expr)?; - let mut all_group_args: Vec = group_expr.clone(); + // alias all original group_by exprs + let mut group_expr_alias = Vec::with_capacity(group_expr.len()); + let mut inner_group_exprs = group_expr + .iter() + .enumerate() + .map(|(i, group_expr)| { + let alias_str = format!("group_alias_{}", i); + let alias_expr = group_expr.clone().alias(&alias_str); + group_expr_alias.push((alias_str, schema.fields()[i].clone())); + alias_expr + }) + .collect::>(); + + // and they can be referenced by the alias in the outer aggr plan + let outer_group_exprs = group_expr_alias + .iter() + .map(|(alias, _)| col(alias)) + .collect::>(); - // remove distinct and collection args - let mut new_aggr_expr = Vec::with_capacity(aggr_expr.len()); - for agg_expr in aggr_expr { - let x = match agg_expr { + // replace the distinct arg with alias + let mut group_fields_set = HashSet::new(); + let new_aggr_exprs = aggr_expr + .iter() + .map(|aggr_expr| match aggr_expr { Expr::AggregateFunction { fun, args, .. } => { // is_single_distinct_agg ensure args.len=1 if group_fields_set.insert(args[0].name()?) { - all_group_args + inner_group_exprs .push(args[0].clone().alias(SINGLE_DISTINCT_ALIAS)); } - Expr::AggregateFunction { + Ok(Expr::AggregateFunction { fun: fun.clone(), args: vec![col(SINGLE_DISTINCT_ALIAS)], distinct: false, // intentional to remove distinct here - } + }) } - _ => agg_expr.clone(), - }; - new_aggr_expr.push(x); - } - - let all_group_expr = grouping_set_to_exprlist(&all_group_args)?; + _ => Ok(aggr_expr.clone()), + }) + .collect::>>()?; - let all_field = all_group_expr + // construct the inner AggrPlan + let inner_fields = inner_group_exprs .iter() .map(|expr| expr.to_field(input.schema())) .collect::>>()?; - - let grouped_schema = DFSchema::new_with_metadata( - all_field, + let inner_schema = DFSchema::new_with_metadata( + inner_fields, input.schema().metadata().clone(), )?; - let grouped_agg = LogicalPlan::Aggregate(Aggregate::try_new( + let grouped_aggr = LogicalPlan::Aggregate(Aggregate::try_new( input.clone(), - all_group_args, + inner_group_exprs, Vec::new(), - Arc::new(grouped_schema.clone()), + Arc::new(inner_schema.clone()), )?); - let grouped_agg = optimize_children(&grouped_agg); - let final_agg_schema = Arc::new(DFSchema::new_with_metadata( - base_group_expr + let inner_agg = optimize_children(&grouped_aggr)?; + + let outer_aggr_schema = Arc::new(DFSchema::new_with_metadata( + outer_group_exprs .iter() - .chain(new_aggr_expr.iter()) - .map(|expr| expr.to_field(&grouped_schema)) + .chain(new_aggr_exprs.iter()) + .map(|expr| expr.to_field(&inner_schema)) .collect::>>()?, input.schema().metadata().clone(), )?); // so the aggregates are displayed in the same way even after the rewrite + // this optimizer has two kinds of alias: + // - group_by aggr + // - aggr expr let mut alias_expr: Vec = Vec::new(); - base_group_expr - .iter() - .chain(new_aggr_expr.iter()) - .enumerate() - .for_each(|(i, field)| { - alias_expr.push(columnize_expr( - field.clone().alias(schema.clone().fields()[i].name()), - &final_agg_schema, - )); - }); - - let final_agg = LogicalPlan::Aggregate(Aggregate::try_new( - Arc::new(grouped_agg?), - group_expr.clone(), - new_aggr_expr, - final_agg_schema, + for (alias, original_field) in group_expr_alias { + alias_expr.push(col(&alias).alias(original_field.name())); + } + for (i, expr) in new_aggr_exprs.iter().enumerate() { + alias_expr.push(columnize_expr( + expr.clone() + .alias(schema.clone().fields()[i + group_expr.len()].name()), + &outer_aggr_schema, + )); + } + + let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(inner_agg), + outer_group_exprs, + new_aggr_exprs, + outer_aggr_schema, )?); + Ok(LogicalPlan::Projection(Projection::try_new_with_schema( alias_expr, - Arc::new(final_agg), + Arc::new(outer_aggr), schema.clone(), None, )?)) @@ -159,27 +176,30 @@ fn optimize_children(plan: &LogicalPlan) -> Result { from_plan(plan, &expr, &new_inputs) } +/// Check whether all aggregate exprs are distinct on a single field. fn is_single_distinct_agg(plan: &LogicalPlan) -> Result { match plan { LogicalPlan::Aggregate(Aggregate { aggr_expr, .. }) => { let mut fields_set = HashSet::new(); - let mut count = 0; + let mut distinct_count = 0; for expr in aggr_expr { if let Expr::AggregateFunction { distinct, args, .. } = expr { if *distinct { - count += 1; + distinct_count += 1; } - for expr in args { - fields_set.insert(expr.name()?); + for e in args { + fields_set.insert(e.name()?); } } } - Ok(count == aggr_expr.len() && fields_set.len() == 1) + let res = distinct_count == aggr_expr.len() && fields_set.len() == 1; + Ok(res) } _ => Ok(false), } } +/// Check if the first expr is [Expr::GroupingSet]. fn contains_grouping_set(expr: &[Expr]) -> bool { matches!(expr.first(), Some(Expr::GroupingSet(_))) } @@ -341,9 +361,9 @@ mod tests { .build()?; // Should work - let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1)]] [a:UInt32, COUNT(alias1):Int64;N]\ - \n Aggregate: groupBy=[[#test.a, #test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ + let expected = "Projection: #group_alias_0 AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[#group_alias_0]], aggr=[[COUNT(#alias1)]] [group_alias_0:UInt32, COUNT(alias1):Int64;N]\ + \n Aggregate: groupBy=[[#test.a AS group_alias_0, #test.b AS alias1]], aggr=[[]] [group_alias_0:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_eq(&plan, expected); @@ -387,9 +407,9 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b), #MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1), MAX(#alias1)]] [a:UInt32, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ - \n Aggregate: groupBy=[[#test.a, #test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ + let expected = "Projection: #group_alias_0 AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b), #MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[#group_alias_0]], aggr=[[COUNT(#alias1), MAX(#alias1)]] [group_alias_0:UInt32, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ + \n Aggregate: groupBy=[[#test.a AS group_alias_0, #test.b AS alias1]], aggr=[[]] [group_alias_0:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_eq(&plan, expected); @@ -414,4 +434,23 @@ mod tests { assert_optimized_plan_eq(&plan, expected); Ok(()) } + + #[test] + fn group_by_with_expr() { + let table_scan = test_table_scan().unwrap(); + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a") + lit(1)], vec![count_distinct(col("c"))]) + .unwrap() + .build() + .unwrap(); + + // Should work + let expected = "Projection: #group_alias_0 AS test.a + Int32(1), #COUNT(alias1) AS COUNT(DISTINCT test.c) [test.a + Int32(1):Int32, COUNT(DISTINCT test.c):Int64;N]\ + \n Aggregate: groupBy=[[#group_alias_0]], aggr=[[COUNT(#alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ + \n Aggregate: groupBy=[[#test.a + Int32(1) AS group_alias_0, #test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq(&plan, expected); + } }