diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 414217612d1e4..3565cf07c7157 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -24,6 +24,7 @@ use crate::{OptimizerConfig, OptimizerRule};
 use datafusion_common::Result;
 use datafusion_expr::{
+    aggregate_function::AggregateFunction::{Max, Min, Sum},
     col,
     expr::AggregateFunction,
     logical_plan::{Aggregate, LogicalPlan},
@@ -34,17 +35,19 @@ use hashbrown::HashSet;
 
 /// single distinct to group by optimizer rule
 /// ```text
-/// SELECT F1(DISTINCT s),F2(DISTINCT s)
-/// ...
-/// GROUP BY k
+/// Before:
+/// SELECT a, COUNT(DISTINCT b), SUM(c)
+/// FROM t
+/// GROUP BY a
 ///
-/// Into
-///
-/// SELECT F1(alias1),F2(alias1)
+/// After:
+/// SELECT a, COUNT(alias1), SUM(alias2)
 /// FROM (
-///   SELECT s as alias1, k ... GROUP BY s, k
+///   SELECT a, b as alias1, SUM(c) as alias2
+///   FROM t
+///   GROUP BY a, b
 /// )
-/// GROUP BY k
+/// GROUP BY a
 /// ```
 #[derive(Default)]
 pub struct SingleDistinctToGroupBy {}
@@ -58,27 +61,37 @@ impl SingleDistinctToGroupBy {
     }
 }
 
-/// Check whether all aggregate exprs are distinct on a single field.
+/// Check whether all distinct aggregate exprs refer to a single field, any non-distinct aggregates are `SUM`, `MIN`, or `MAX`, and no aggregate has a filter.
 fn is_single_distinct_agg(plan: &LogicalPlan) -> Result<bool> {
     match plan {
         LogicalPlan::Aggregate(Aggregate { aggr_expr, .. }) => {
             let mut fields_set = HashSet::new();
-            let mut distinct_count = 0;
+            let mut aggregate_count = 0;
             for expr in aggr_expr {
                 if let Expr::AggregateFunction(AggregateFunction {
-                    distinct, args, ..
+                    fun,
+                    distinct,
+                    args,
+                    filter,
+                    ..
                 }) = expr
                 {
-                    if *distinct {
-                        distinct_count += 1;
-                    }
-                    for e in args {
-                        fields_set.insert(e.canonical_name());
+                    match filter {
+                        Some(_) => return Ok(false),
+                        None => {
+                            aggregate_count += 1;
+                            if *distinct {
+                                for e in args {
+                                    fields_set.insert(e.canonical_name());
+                                }
+                            } else if !matches!(fun, Sum | Min | Max) {
+                                return Ok(false);
+                            }
+                        }
                     }
                 }
             }
-            let res = distinct_count == aggr_expr.len() && fields_set.len() == 1;
-            Ok(res)
+            Ok(fields_set.len() == 1 && aggregate_count == aggr_expr.len())
         }
         _ => Ok(false),
     }
@@ -151,31 +164,60 @@ impl OptimizerRule for SingleDistinctToGroupBy {
                     .collect::<Vec<_>>();
 
                 // replace the distinct arg with alias
+                let mut index = 1;
                 let mut group_fields_set = HashSet::new();
+                let mut inner_aggr_exprs = vec![];
                 let outer_aggr_exprs = aggr_expr
                     .iter()
                     .map(|aggr_expr| match aggr_expr {
                         Expr::AggregateFunction(AggregateFunction {
                             fun,
                             args,
-                            filter,
                             order_by,
+                            distinct,
                             ..
                         }) => {
                             // is_single_distinct_agg ensure args.len=1
-                            if group_fields_set.insert(args[0].display_name()?) {
+                            if *distinct
+                                && group_fields_set.insert(args[0].display_name()?)
+                            {
                                 inner_group_exprs.push(
                                     args[0].clone().alias(SINGLE_DISTINCT_ALIAS),
                                 );
                             }
-                            Ok(Expr::AggregateFunction(AggregateFunction::new(
-                                fun.clone(),
-                                vec![col(SINGLE_DISTINCT_ALIAS)],
-                                false, // intentional to remove distinct here
-                                filter.clone(),
-                                order_by.clone(),
-                            ))
-                            .alias(aggr_expr.display_name()?))
+
+                            // If the aggregate function is not distinct, rewrite it as a two-phase aggregation
+                            if !(*distinct) {
+                                index += 1;
+                                let alias_str = format!("alias{}", index);
+                                let inner_expr =
+                                    Expr::AggregateFunction(AggregateFunction::new(
+                                        fun.clone(),
+                                        args.clone(),
+                                        false,
+                                        None,
+                                        order_by.clone(),
+                                    ))
+                                    .alias(&alias_str);
+                                inner_aggr_exprs.push(inner_expr);
+                                Ok(Expr::AggregateFunction(AggregateFunction::new(
+                                    fun.clone(),
+                                    vec![col(&alias_str)],
+                                    false,
+                                    None,
+                                    order_by.clone(),
+                                ))
+                                .alias(aggr_expr.display_name()?))
+                            } else {
+                                Ok(Expr::AggregateFunction(AggregateFunction::new(
+                                    fun.clone(),
+                                    vec![col(SINGLE_DISTINCT_ALIAS)],
+                                    false, // intentional to remove distinct here
+                                    None,
+                                    order_by.clone(),
+                                ))
+                                .alias(aggr_expr.display_name()?))
+                            }
                         }
                         _ => Ok(aggr_expr.clone()),
                     })
@@ -185,7 +227,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
                 let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new(
                     input.clone(),
                     inner_group_exprs,
-                    Vec::new(),
+                    inner_aggr_exprs,
                 )?);
 
                 Ok(Some(LogicalPlan::Aggregate(Aggregate::try_new(
@@ -217,7 +259,7 @@ mod tests {
     use datafusion_expr::expr;
     use datafusion_expr::expr::GroupingSet;
     use datafusion_expr::{
-        col, count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max,
+        col, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max, sum,
         AggregateFunction,
    };
 
@@ -396,6 +438,52 @@ mod tests {
         assert_optimized_plan_equal(&plan, expected)
     }
 
+    #[test]
+    fn two_distinct_and_one_common() -> Result<()> {
+        let table_scan = test_table_scan()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(
+                vec![col("a")],
+                vec![
+                    sum(col("c")),
+                    count_distinct(col("b")),
+                    Expr::AggregateFunction(expr::AggregateFunction::new(
+                        AggregateFunction::Max,
+                        vec![col("b")],
+                        true,
+                        None,
+                        None,
+                    )),
+                ],
+            )?
+            .build()?;
+        // Should work
+        let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2) AS SUM(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b)]] [a:UInt32, SUM(test.c):UInt64;N, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\
+        \n  Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\
+        \n    TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_optimized_plan_equal(&plan, expected)
+    }
+
+    #[test]
+    fn one_distinct_and_two() -> Result<()> {
+        let table_scan = test_table_scan()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(
+                vec![col("a")],
+                vec![sum(col("c")), max(col("c")), count_distinct(col("b"))],
+            )?
+            .build()?;
+        // Should work
+        let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2) AS SUM(test.c), MAX(alias3) AS MAX(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b)]] [a:UInt32, SUM(test.c):UInt64;N, MAX(test.c):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\
+        \n  Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\
+        \n    TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_optimized_plan_equal(&plan, expected)
+    }
+
     #[test]
     fn distinct_and_common() -> Result<()> {
         let table_scan = test_table_scan()?;
@@ -403,13 +491,14 @@ mod tests {
         let plan = LogicalPlanBuilder::from(table_scan)
             .aggregate(
                 vec![col("a")],
-                vec![count_distinct(col("b")), count(col("c"))],
+                vec![count_distinct(col("b")), sum(col("c"))],
             )?
             .build()?;
-        // Do nothing
-        let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT test.b), COUNT(test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(test.c):Int64;N]\
-        \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
+        // Should work
+        let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT test.b), SUM(alias2) AS SUM(test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, SUM(test.c):UInt64;N]\
+        \n  Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\
+        \n    TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
 
         assert_optimized_plan_equal(&plan, expected)
     }
diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt
index 105f11f21628d..5bee97231ae06 100644
--- a/datafusion/sqllogictest/test_files/groupby.slt
+++ b/datafusion/sqllogictest/test_files/groupby.slt
@@ -3841,3 +3841,54 @@ ProjectionExec: expr=[SUM(DISTINCT t1.x)@1 as SUM(DISTINCT t1.x), MAX(DISTINCT t
 ------------------AggregateExec: mode=Partial, gby=[y@1 as y, CAST(t1.x AS Float64)t1.x@0 as alias1], aggr=[]
 --------------------ProjectionExec: expr=[CAST(x@0 AS Float64) as CAST(t1.x AS Float64)t1.x, y@1 as y]
 ----------------------MemoryExec: partitions=1, partition_sizes=[1]
+
+statement ok
+CREATE EXTERNAL TABLE aggregate_test_100 (
+  c1  VARCHAR NOT NULL,
+  c2  TINYINT NOT NULL,
+  c3  SMALLINT NOT NULL,
+  c4  SMALLINT,
+  c5  INT,
+  c6  BIGINT NOT NULL,
+  c7  SMALLINT NOT NULL,
+  c8  INT NOT NULL,
+  c9  INT UNSIGNED NOT NULL,
+  c10 BIGINT UNSIGNED NOT NULL,
+  c11 FLOAT NOT NULL,
+  c12 DOUBLE NOT NULL,
+  c13 VARCHAR NOT NULL
+)
+STORED AS CSV
+WITH HEADER ROW
+LOCATION '../../testing/data/csv/aggregate_test_100.csv'
+
+query TIIII
+SELECT c1, count(distinct c2), min(distinct c2), min(c3), max(c4) FROM aggregate_test_100 GROUP BY c1 ORDER BY c1;
+----
+a 5 1 -101 32064
+b 5 1 -117 25286
+c 5 1 -117 29106
+d 5 1 -99 31106
+e 5 1 -95 32514
+
+query TT
+EXPLAIN SELECT c1, count(distinct c2), min(distinct c2), sum(c3), max(c4) FROM aggregate_test_100 GROUP BY c1 ORDER BY c1;
+----
+logical_plan
+Sort: aggregate_test_100.c1 ASC NULLS LAST
+--Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT aggregate_test_100.c2), MIN(alias1) AS MIN(DISTINCT aggregate_test_100.c2), SUM(alias2) AS SUM(aggregate_test_100.c3), MAX(alias3) AS MAX(aggregate_test_100.c4)]]
+----Aggregate: groupBy=[[aggregate_test_100.c1, aggregate_test_100.c2 AS alias1]], aggr=[[SUM(CAST(aggregate_test_100.c3 AS Int64)) AS alias2, MAX(aggregate_test_100.c4) AS alias3]]
+------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4]
+physical_plan
+SortPreservingMergeExec: [c1@0 ASC NULLS LAST]
+--SortExec: expr=[c1@0 ASC NULLS LAST]
+----AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[COUNT(DISTINCT aggregate_test_100.c2), MIN(DISTINCT aggregate_test_100.c2), SUM(aggregate_test_100.c3), MAX(aggregate_test_100.c4)]
+------CoalesceBatchesExec: target_batch_size=2
+--------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8
+----------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[COUNT(DISTINCT aggregate_test_100.c2), MIN(DISTINCT aggregate_test_100.c2), SUM(aggregate_test_100.c3), MAX(aggregate_test_100.c4)]
+------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1, alias1@1 as alias1], aggr=[alias2, alias3]
+--------------CoalesceBatchesExec: target_batch_size=2
+----------------RepartitionExec: partitioning=Hash([c1@0, alias1@1], 8), input_partitions=8
+------------------AggregateExec: mode=Partial, gby=[c1@0 as c1, c2@1 as alias1], aggr=[alias2, alias3]
+--------------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1
+----------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4], has_header=true
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index ca2b4d48c4602..a89e32fd721de 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -208,7 +208,7 @@ async fn simple_aggregate() -> Result<()> {
 
 #[tokio::test]
 async fn aggregate_distinct_with_having() -> Result<()> {
-    roundtrip("SELECT a, count(distinct b) FROM data GROUP BY a, c HAVING count(b) > 100")
+    roundtrip("SELECT a, count(distinct b), sum(distinct e) FROM data GROUP BY a, c HAVING count(b) > 100")
         .await
 }
 
@@ -267,6 +267,33 @@ async fn select_distinct_two_fields() -> Result<()> {
         .await
 }
 
+#[tokio::test]
+async fn simple_distinct_aggregate() -> Result<()> {
+    test_alias(
+        "SELECT a, COUNT(DISTINCT b) FROM data GROUP BY a",
+        "SELECT a, COUNT(b) FROM (SELECT a, b FROM data GROUP BY a, b) GROUP BY a",
+    )
+    .await
+}
+
+#[tokio::test]
+async fn select_distinct_aggregate_two_fields() -> Result<()> {
+    test_alias(
+        "SELECT a, COUNT(DISTINCT b), MAX(DISTINCT b) FROM data GROUP BY a",
+        "SELECT a, COUNT(b), MAX(b) FROM (SELECT a, b FROM data GROUP BY a, b) GROUP BY a",
+    )
+    .await
+}
+
+#[tokio::test]
+async fn select_distinct_aggregate_and_no_distinct_aggregate() -> Result<()> {
+    test_alias(
+        "SELECT a, COUNT(DISTINCT b), SUM(e) FROM data GROUP BY a",
+        "SELECT a, COUNT(b), SUM(\"SUM(data.e)\") FROM (SELECT a, b, SUM(e) FROM data GROUP BY a, b) GROUP BY a",
+    )
+    .await
+}
+
 #[tokio::test]
 async fn simple_alias() -> Result<()> {
     test_alias("SELECT d1.a, d1.b FROM data d1", "SELECT a, b FROM data").await
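
The rewrite these tests exercise can be sketched in SQL using the same placeholder table `t` and columns `a`, `b`, `c` as the updated doc comment above; this is an illustration of the optimizer's output shape, not part of the patch itself.

-- Before: one single-field DISTINCT aggregate mixed with a plain SUM
SELECT a, COUNT(DISTINCT b), SUM(c) FROM t GROUP BY a;

-- After: the inner aggregate de-duplicates b (alias1) and pre-aggregates SUM(c)
-- per (a, b) group (alias2); the outer aggregate then counts the de-duplicated
-- values and sums the partial sums, giving the same result per group of a.
SELECT a, COUNT(alias1), SUM(alias2)
FROM (SELECT a, b AS alias1, SUM(c) AS alias2 FROM t GROUP BY a, b)
GROUP BY a;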