From b074280362525894a44cee862ef269f84487f4f7 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Sat, 26 Apr 2025 17:13:02 +0800 Subject: [PATCH 1/6] feat: simplify count distinct logical plan --- .../src/single_distinct_to_groupby.rs | 12 +++++++ datafusion/sqllogictest/test_files/joins.slt | 33 ++++++++----------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 7337d2ffce5c3..9ae1266f8e5ec 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -27,6 +27,7 @@ use datafusion_common::{ }; use datafusion_expr::builder::project; use datafusion_expr::expr::AggregateFunctionParams; +use datafusion_expr::AggregateUDF; use datafusion_expr::{ col, expr::AggregateFunction, @@ -66,6 +67,7 @@ impl SingleDistinctToGroupBy { fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { let mut fields_set = HashSet::new(); let mut aggregate_count = 0; + let mut distinct_func: Option> = None; for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { func, @@ -87,6 +89,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { for e in args { fields_set.insert(e); } + distinct_func = Some(func.clone()); } else if func.name() != "sum" && func.name().to_lowercase() != "min" && func.name().to_lowercase() != "max" @@ -97,6 +100,15 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { return Ok(false); } } + + if aggregate_count == aggr_expr.len() && fields_set.len() == 1 { + if let Some(distinct_func) = distinct_func { + if distinct_func.name() == "count" { + return Ok(false); + } + } + } + Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1) } diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index ddf701ba04efe..ec79e94949978 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -1413,27 +1413,20 @@ from join_t1 inner join join_t2 on join_t1.t1_id = join_t2.t2_id ---- logical_plan -01)Projection: count(alias1) AS count(DISTINCT join_t1.t1_id) -02)--Aggregate: groupBy=[[]], aggr=[[count(alias1)]] -03)----Aggregate: groupBy=[[join_t1.t1_id AS alias1]], aggr=[[]] -04)------Projection: join_t1.t1_id -05)--------Inner Join: join_t1.t1_id = join_t2.t2_id -06)----------TableScan: join_t1 projection=[t1_id] -07)----------TableScan: join_t2 projection=[t2_id] +01)Aggregate: groupBy=[[]], aggr=[[count(DISTINCT join_t1.t1_id)]] +02)--Projection: join_t1.t1_id +03)----Inner Join: join_t1.t1_id = join_t2.t2_id +04)------TableScan: join_t1 projection=[t1_id] +05)------TableScan: join_t2 projection=[t2_id] physical_plan -01)ProjectionExec: expr=[count(alias1)@0 as count(DISTINCT join_t1.t1_id)] -02)--AggregateExec: mode=Final, gby=[], aggr=[count(alias1)] -03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[count(alias1)] -05)--------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[] -06)----------CoalesceBatchesExec: target_batch_size=2 -07)------------RepartitionExec: partitioning=Hash([alias1@0], 2), input_partitions=2 -08)--------------AggregateExec: mode=Partial, gby=[t1_id@0 as alias1], aggr=[] -09)----------------CoalesceBatchesExec: target_batch_size=2 -10)------------------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0] -11)--------------------DataSourceExec: partitions=1, partition_sizes=[1] -12)--------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -13)----------------------DataSourceExec: partitions=1, partition_sizes=[1] +01)AggregateExec: mode=Final, gby=[], aggr=[count(DISTINCT join_t1.t1_id)] +02)--CoalescePartitionsExec +03)----AggregateExec: mode=Partial, gby=[], aggr=[count(DISTINCT join_t1.t1_id)] +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] +07)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +08)------------DataSourceExec: partitions=1, partition_sizes=[1] statement ok set datafusion.explain.logical_plan_only = true; From b5b99a942e14864936085592d56702037f953779 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Sat, 26 Apr 2025 17:18:53 +0800 Subject: [PATCH 2/6] fmt --- datafusion/optimizer/src/single_distinct_to_groupby.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 9ae1266f8e5ec..953abf31dbb8a 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -89,7 +89,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { for e in args { fields_set.insert(e); } - distinct_func = Some(func.clone()); + distinct_func = Some(Arc::clone(func)); } else if func.name() != "sum" && func.name().to_lowercase() != "min" && func.name().to_lowercase() != "max" @@ -101,14 +101,14 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { } } - if aggregate_count == aggr_expr.len() && fields_set.len() == 1 { + if aggregate_count == 1 && fields_set.len() == 1 { if let Some(distinct_func) = distinct_func { if distinct_func.name() == "count" { return Ok(false); } } } - + Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1) } From a6d9788087e98d5701af0a7eaacd808e4fdd19c0 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Sat, 26 Apr 2025 17:37:22 +0800 Subject: [PATCH 3/6] update ut --- .../src/single_distinct_to_groupby.rs | 24 +++++++------------ 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 953abf31dbb8a..b5f277a410a32 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -346,10 +346,8 @@ mod tests { .build()?; // Should work - let expected = "Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64]\ - \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Aggregate: groupBy=[[]], aggr=[[count(DISTINCT test.b)]] [count(DISTINCT test.b):Int64]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) } @@ -420,10 +418,8 @@ mod tests { .aggregate(Vec::::new(), vec![count_distinct(lit(2) * col("b"))])? .build()?; - let expected = "Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64]\ - \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\ - \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Aggregate: groupBy=[[]], aggr=[[count(DISTINCT Int32(2) * test.b)]] [count(DISTINCT Int32(2) * test.b):Int64]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) } @@ -437,10 +433,8 @@ mod tests { .build()?; // Should work - let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b)]] [a:UInt32, count(DISTINCT test.b):Int64]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) } @@ -509,10 +503,8 @@ mod tests { .build()?; // Should work - let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int64, count(DISTINCT test.c):Int64]\ - \n Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int64, count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int64, alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[count(DISTINCT test.c)]] [test.a + Int32(1):Int64, count(DISTINCT test.c):Int64]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) } From 191d7cf5dd06cc6bd49a4839569a331a9f8092d7 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Sat, 26 Apr 2025 18:23:44 +0800 Subject: [PATCH 4/6] update --- .../test_files/tpch/plans/q16.slt.part | 73 +++++++++---------- 1 file changed, 34 insertions(+), 39 deletions(-) diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part index c648f164c8094..663f32da7251c 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part @@ -51,49 +51,44 @@ limit 10; ---- logical_plan 01)Sort: supplier_cnt DESC NULLS FIRST, part.p_brand ASC NULLS LAST, part.p_type ASC NULLS LAST, part.p_size ASC NULLS LAST, fetch=10 -02)--Projection: part.p_brand, part.p_type, part.p_size, count(alias1) AS supplier_cnt -03)----Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size]], aggr=[[count(alias1)]] -04)------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey AS alias1]], aggr=[[]] -05)--------LeftAnti Join: partsupp.ps_suppkey = __correlated_sq_1.s_suppkey -06)----------Projection: partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size -07)------------Inner Join: partsupp.ps_partkey = part.p_partkey -08)--------------TableScan: partsupp projection=[ps_partkey, ps_suppkey] -09)--------------Filter: part.p_brand != Utf8("Brand#45") AND part.p_type NOT LIKE Utf8("MEDIUM POLISHED%") AND part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)]) -10)----------------TableScan: part projection=[p_partkey, p_brand, p_type, p_size], partial_filters=[part.p_brand != Utf8("Brand#45"), part.p_type NOT LIKE Utf8("MEDIUM POLISHED%"), part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)])] -11)----------SubqueryAlias: __correlated_sq_1 -12)------------Projection: supplier.s_suppkey -13)--------------Filter: supplier.s_comment LIKE Utf8("%Customer%Complaints%") -14)----------------TableScan: supplier projection=[s_suppkey, s_comment], partial_filters=[supplier.s_comment LIKE Utf8("%Customer%Complaints%")] +02)--Projection: part.p_brand, part.p_type, part.p_size, count(DISTINCT partsupp.ps_suppkey) AS supplier_cnt +03)----Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size]], aggr=[[count(DISTINCT partsupp.ps_suppkey)]] +04)------LeftAnti Join: partsupp.ps_suppkey = __correlated_sq_1.s_suppkey +05)--------Projection: partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size +06)----------Inner Join: partsupp.ps_partkey = part.p_partkey +07)------------TableScan: partsupp projection=[ps_partkey, ps_suppkey] +08)------------Filter: part.p_brand != Utf8("Brand#45") AND part.p_type NOT LIKE Utf8("MEDIUM POLISHED%") AND part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)]) +09)--------------TableScan: part projection=[p_partkey, p_brand, p_type, p_size], partial_filters=[part.p_brand != Utf8("Brand#45"), part.p_type NOT LIKE Utf8("MEDIUM POLISHED%"), part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)])] +10)--------SubqueryAlias: __correlated_sq_1 +11)----------Projection: supplier.s_suppkey +12)------------Filter: supplier.s_comment LIKE Utf8("%Customer%Complaints%") +13)--------------TableScan: supplier projection=[s_suppkey, s_comment], partial_filters=[supplier.s_comment LIKE Utf8("%Customer%Complaints%")] physical_plan 01)SortPreservingMergeExec: [supplier_cnt@3 DESC, p_brand@0 ASC NULLS LAST, p_type@1 ASC NULLS LAST, p_size@2 ASC NULLS LAST], fetch=10 02)--SortExec: TopK(fetch=10), expr=[supplier_cnt@3 DESC, p_brand@0 ASC NULLS LAST, p_type@1 ASC NULLS LAST, p_size@2 ASC NULLS LAST], preserve_partitioning=[true] -03)----ProjectionExec: expr=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, count(alias1)@3 as supplier_cnt] -04)------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[count(alias1)] +03)----ProjectionExec: expr=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, count(DISTINCT partsupp.ps_suppkey)@3 as supplier_cnt] +04)------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[count(DISTINCT partsupp.ps_suppkey)] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2], 4), input_partitions=4 -07)------------AggregateExec: mode=Partial, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[count(alias1)] -08)--------------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, alias1@3 as alias1], aggr=[] -09)----------------CoalesceBatchesExec: target_batch_size=8192 -10)------------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2, alias1@3], 4), input_partitions=4 -11)--------------------AggregateExec: mode=Partial, gby=[p_brand@1 as p_brand, p_type@2 as p_type, p_size@3 as p_size, ps_suppkey@0 as alias1], aggr=[] +07)------------AggregateExec: mode=Partial, gby=[p_brand@1 as p_brand, p_type@2 as p_type, p_size@3 as p_size], aggr=[count(DISTINCT partsupp.ps_suppkey)] +08)--------------CoalesceBatchesExec: target_batch_size=8192 +09)----------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(ps_suppkey@0, s_suppkey@0)] +10)------------------CoalesceBatchesExec: target_batch_size=8192 +11)--------------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 12)----------------------CoalesceBatchesExec: target_batch_size=8192 -13)------------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(ps_suppkey@0, s_suppkey@0)] +13)------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_partkey@0, p_partkey@0)], projection=[ps_suppkey@1, p_brand@3, p_type@4, p_size@5] 14)--------------------------CoalesceBatchesExec: target_batch_size=8192 -15)----------------------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 -16)------------------------------CoalesceBatchesExec: target_batch_size=8192 -17)--------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_partkey@0, p_partkey@0)], projection=[ps_suppkey@1, p_brand@3, p_type@4, p_size@5] -18)----------------------------------CoalesceBatchesExec: target_batch_size=8192 -19)------------------------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 -20)--------------------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey], file_type=csv, has_header=false -21)----------------------------------CoalesceBatchesExec: target_batch_size=8192 -22)------------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 -23)--------------------------------------CoalesceBatchesExec: target_batch_size=8192 -24)----------------------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND Use p_size@3 IN (SET) ([Literal { value: Int32(49) }, Literal { value: Int32(14) }, Literal { value: Int32(23) }, Literal { value: Int32(45) }, Literal { value: Int32(19) }, Literal { value: Int32(3) }, Literal { value: Int32(36) }, Literal { value: Int32(9) }]) -25)------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -26)--------------------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_type, p_size], file_type=csv, has_header=false -27)--------------------------CoalesceBatchesExec: target_batch_size=8192 -28)----------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 -29)------------------------------CoalesceBatchesExec: target_batch_size=8192 -30)--------------------------------FilterExec: s_comment@1 LIKE %Customer%Complaints%, projection=[s_suppkey@0] -31)----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -32)------------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_comment], file_type=csv, has_header=false +15)----------------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 +16)------------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey], file_type=csv, has_header=false +17)--------------------------CoalesceBatchesExec: target_batch_size=8192 +18)----------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 +19)------------------------------CoalesceBatchesExec: target_batch_size=8192 +20)--------------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND Use p_size@3 IN (SET) ([Literal { value: Int32(49) }, Literal { value: Int32(14) }, Literal { value: Int32(23) }, Literal { value: Int32(45) }, Literal { value: Int32(19) }, Literal { value: Int32(3) }, Literal { value: Int32(36) }, Literal { value: Int32(9) }]) +21)----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +22)------------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_type, p_size], file_type=csv, has_header=false +23)------------------CoalesceBatchesExec: target_batch_size=8192 +24)--------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 +25)----------------------CoalesceBatchesExec: target_batch_size=8192 +26)------------------------FilterExec: s_comment@1 LIKE %Customer%Complaints%, projection=[s_suppkey@0] +27)--------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +28)----------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_comment], file_type=csv, has_header=false From bf0bb3a6990cebc38b1b871984cc3666daf48c57 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Sun, 27 Apr 2025 08:09:04 +0800 Subject: [PATCH 5/6] update --- .../optimizer/src/single_distinct_to_groupby.rs | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index b5f277a410a32..375d30dd7dc1d 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -67,6 +67,7 @@ impl SingleDistinctToGroupBy { fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { let mut fields_set = HashSet::new(); let mut aggregate_count = 0; + let mut distinct_count = 0; let mut distinct_func: Option> = None; for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { @@ -86,6 +87,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { } aggregate_count += 1; if *distinct { + distinct_count += 1; for e in args { fields_set.insert(e); } @@ -101,7 +103,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { } } - if aggregate_count == 1 && fields_set.len() == 1 { + if distinct_count == 1 && fields_set.len() == 1 { if let Some(distinct_func) = distinct_func { if distinct_func.name() == "count" { return Ok(false); @@ -543,10 +545,8 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: test.a, sum(alias2) AS sum(test.c), max(alias3) AS max(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, max(test.c):UInt32;N, count(DISTINCT test.b):Int64]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), max(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, max(alias3):UInt32;N, count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2, max(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Aggregate: groupBy=[[test.a]], aggr=[[sum(test.c), max(test.c), count(DISTINCT test.b)]] [a:UInt32, sum(test.c):UInt64;N, max(test.c):UInt32;N, count(DISTINCT test.b):Int64]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) } @@ -562,10 +562,8 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: test.c, min(alias2) AS min(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, min(test.a):UInt32;N, count(DISTINCT test.b):Int64]\ - \n Aggregate: groupBy=[[test.c]], aggr=[[min(alias2), count(alias1)]] [c:UInt32, min(alias2):UInt32;N, count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[min(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[min(test.a), count(DISTINCT test.b)]] [c:UInt32, min(test.a):UInt32;N, count(DISTINCT test.b):Int64]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) } From 4e865e93a343bc6e7ed0ed33fd44e2a8827e5b4c Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Sun, 27 Apr 2025 10:09:20 +0800 Subject: [PATCH 6/6] update --- datafusion/optimizer/src/single_distinct_to_groupby.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 375d30dd7dc1d..c11fe27e8209a 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -27,7 +27,6 @@ use datafusion_common::{ }; use datafusion_expr::builder::project; use datafusion_expr::expr::AggregateFunctionParams; -use datafusion_expr::AggregateUDF; use datafusion_expr::{ col, expr::AggregateFunction, @@ -68,7 +67,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { let mut fields_set = HashSet::new(); let mut aggregate_count = 0; let mut distinct_count = 0; - let mut distinct_func: Option> = None; + let mut distinct_func: Option<&str> = None; for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { func, @@ -91,7 +90,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { for e in args { fields_set.insert(e); } - distinct_func = Some(Arc::clone(func)); + distinct_func = Some(func.name()); } else if func.name() != "sum" && func.name().to_lowercase() != "min" && func.name().to_lowercase() != "max" @@ -105,7 +104,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { if distinct_count == 1 && fields_set.len() == 1 { if let Some(distinct_func) = distinct_func { - if distinct_func.name() == "count" { + if distinct_func == "count" { return Ok(false); } }