From b074280362525894a44cee862ef269f84487f4f7 Mon Sep 17 00:00:00 2001
From: Chongchen Chen
Date: Sat, 26 Apr 2025 17:13:02 +0800
Subject: [PATCH 1/7] feat: simplify count distinct logical plan

---
 .../src/single_distinct_to_groupby.rs         | 12 ++++++++++++
 datafusion/sqllogictest/test_files/joins.slt  | 33 ++++++++-----------
 2 files changed, 25 insertions(+), 20 deletions(-)

diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 7337d2ffce5c3..9ae1266f8e5ec 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -27,6 +27,7 @@ use datafusion_common::{
 };
 use datafusion_expr::builder::project;
 use datafusion_expr::expr::AggregateFunctionParams;
+use datafusion_expr::AggregateUDF;
 use datafusion_expr::{
     col,
     expr::AggregateFunction,
@@ -66,6 +67,7 @@ impl SingleDistinctToGroupBy {
 fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result<bool> {
     let mut fields_set = HashSet::new();
     let mut aggregate_count = 0;
+    let mut distinct_func: Option<Arc<AggregateUDF>> = None;
     for expr in aggr_expr {
         if let Expr::AggregateFunction(AggregateFunction {
             func,
@@ -87,6 +89,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result<bool> {
                 for e in args {
                     fields_set.insert(e);
                 }
+                distinct_func = Some(func.clone());
             } else if func.name() != "sum"
                 && func.name().to_lowercase() != "min"
                 && func.name().to_lowercase() != "max"
@@ -97,6 +100,15 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result<bool> {
             return Ok(false);
         }
     }
+
+    if aggregate_count == aggr_expr.len() && fields_set.len() == 1 {
+        if let Some(distinct_func) = distinct_func {
+            if distinct_func.name() == "count" {
+                return Ok(false);
+            }
+        }
+    }
+
     Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1)
 }

diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt
index ddf701ba04efe..ec79e94949978 100644
--- a/datafusion/sqllogictest/test_files/joins.slt
+++ b/datafusion/sqllogictest/test_files/joins.slt
@@ -1413,27 +1413,20 @@
 from join_t1 inner join join_t2 on join_t1.t1_id = join_t2.t2_id
 ----
 logical_plan
-01)Projection: count(alias1) AS count(DISTINCT join_t1.t1_id)
-02)--Aggregate: groupBy=[[]], aggr=[[count(alias1)]]
-03)----Aggregate: groupBy=[[join_t1.t1_id AS alias1]], aggr=[[]]
-04)------Projection: join_t1.t1_id
-05)--------Inner Join: join_t1.t1_id = join_t2.t2_id
-06)----------TableScan: join_t1 projection=[t1_id]
-07)----------TableScan: join_t2 projection=[t2_id]
+01)Aggregate: groupBy=[[]], aggr=[[count(DISTINCT join_t1.t1_id)]]
+02)--Projection: join_t1.t1_id
+03)----Inner Join: join_t1.t1_id = join_t2.t2_id
+04)------TableScan: join_t1 projection=[t1_id]
+05)------TableScan: join_t2 projection=[t2_id]
 physical_plan
-01)ProjectionExec: expr=[count(alias1)@0 as count(DISTINCT join_t1.t1_id)]
-02)--AggregateExec: mode=Final, gby=[], aggr=[count(alias1)]
-03)----CoalescePartitionsExec
-04)------AggregateExec: mode=Partial, gby=[], aggr=[count(alias1)]
-05)--------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[]
-06)----------CoalesceBatchesExec: target_batch_size=2
-07)------------RepartitionExec: partitioning=Hash([alias1@0], 2), input_partitions=2
-08)--------------AggregateExec: mode=Partial, gby=[t1_id@0 as alias1], aggr=[]
-09)----------------CoalesceBatchesExec: target_batch_size=2
-10)------------------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0]
-11)--------------------DataSourceExec: partitions=1, partition_sizes=[1]
-12)--------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
-13)----------------------DataSourceExec: partitions=1, partition_sizes=[1]
+01)AggregateExec: mode=Final, gby=[], aggr=[count(DISTINCT join_t1.t1_id)]
+02)--CoalescePartitionsExec
+03)----AggregateExec: mode=Partial, gby=[], aggr=[count(DISTINCT join_t1.t1_id)]
+04)------CoalesceBatchesExec: target_batch_size=2
+05)--------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0]
+06)----------DataSourceExec: partitions=1, partition_sizes=[1]
+07)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
+08)------------DataSourceExec: partitions=1, partition_sizes=[1]

 statement ok
 set datafusion.explain.logical_plan_only = true;
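For context, the shape of query this patch stops rewriting: a lone count(DISTINCT ...) with no other aggregates. A minimal sketch using the same builder helpers and test schema that the tests in single_distinct_to_groupby.rs already use:

    // Before this patch the rule rewrote this plan into two stacked
    // Aggregates plus a Projection over an `alias1` column; with the new
    // early return in is_single_distinct_agg it now stays a single
    //   Aggregate: groupBy=[[]], aggr=[[count(DISTINCT test.b)]]
    // and patch 7 below handles DISTINCT inside the count accumulator.
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .aggregate(Vec::<Expr>::new(), vec![count_distinct(col("b"))])?
        .build()?;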
From b5b99a942e14864936085592d56702037f953779 Mon Sep 17 00:00:00 2001
From: Chongchen Chen
Date: Sat, 26 Apr 2025 17:18:53 +0800
Subject: [PATCH 2/7] fmt

---
 datafusion/optimizer/src/single_distinct_to_groupby.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 9ae1266f8e5ec..953abf31dbb8a 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -89,7 +89,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result<bool> {
                 for e in args {
                     fields_set.insert(e);
                 }
-                distinct_func = Some(func.clone());
+                distinct_func = Some(Arc::clone(func));
             } else if func.name() != "sum"
                 && func.name().to_lowercase() != "min"
                 && func.name().to_lowercase() != "max"
@@ -101,14 +101,14 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result<bool> {
         }
     }
 
-    if aggregate_count == aggr_expr.len() && fields_set.len() == 1 {
+    if aggregate_count == 1 && fields_set.len() == 1 {
         if let Some(distinct_func) = distinct_func {
             if distinct_func.name() == "count" {
                 return Ok(false);
             }
         }
     }
-    
+
     Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1)
 }
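One note on the guard tightened above: aggregate_count counts every aggregate expression, not only the distinct one, so `aggregate_count == 1` still misses queries that mix a distinct count with other aggregates. A sketch of the case, mirroring the sum/max tests updated in patch 5 below:

    // aggregate_count == 3 but only one aggregate is DISTINCT, so the
    // count(DISTINCT ...) early return above never fires for this plan.
    let aggr = vec![sum(col("c")), max(col("c")), count_distinct(col("b"))];

Patch 5 switches the guard to a dedicated distinct_count to cover exactly this case.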
From a6d9788087e98d5701af0a7eaacd808e4fdd19c0 Mon Sep 17 00:00:00 2001
From: Chongchen Chen
Date: Sat, 26 Apr 2025 17:37:22 +0800
Subject: [PATCH 3/7] update ut

---
 .../src/single_distinct_to_groupby.rs         | 24 ++++++++----------------
 1 file changed, 8 insertions(+), 16 deletions(-)

diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 953abf31dbb8a..b5f277a410a32 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -346,10 +346,8 @@ mod tests {
             .build()?;
 
         // Should work
-        let expected = "Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64]\
-        \n  Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\
-        \n    Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\
-        \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
+        let expected = "Aggregate: groupBy=[[]], aggr=[[count(DISTINCT test.b)]] [count(DISTINCT test.b):Int64]\
+        \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
 
         assert_optimized_plan_equal(plan, expected)
     }
@@ -420,10 +418,8 @@ mod tests {
             .aggregate(Vec::<Expr>::new(), vec![count_distinct(lit(2) * col("b"))])?
             .build()?;
 
-        let expected = "Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64]\
-        \n  Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\
-        \n    Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int64]\
-        \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
+        let expected = "Aggregate: groupBy=[[]], aggr=[[count(DISTINCT Int32(2) * test.b)]] [count(DISTINCT Int32(2) * test.b):Int64]\
+        \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
 
         assert_optimized_plan_equal(plan, expected)
     }
@@ -437,10 +433,8 @@ mod tests {
             .build()?;
 
         // Should work
-        let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64]\
-        \n  Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64]\
-        \n    Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
-        \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
+        let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b)]] [a:UInt32, count(DISTINCT test.b):Int64]\
+        \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
 
         assert_optimized_plan_equal(plan, expected)
     }
@@ -509,10 +503,8 @@ mod tests {
             .build()?;
 
         // Should work
-        let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int64, count(DISTINCT test.c):Int64]\
-        \n  Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int64, count(alias1):Int64]\
-        \n    Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int64, alias1:UInt32]\
-        \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
+        let expected = "Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[count(DISTINCT test.c)]] [test.a + Int32(1):Int64, count(DISTINCT test.c):Int64]\
+        \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
 
         assert_optimized_plan_equal(plan, expected)
     }
From 191d7cf5dd06cc6bd49a4839569a331a9f8092d7 Mon Sep 17 00:00:00 2001
From: Chongchen Chen
Date: Sat, 26 Apr 2025 18:23:44 +0800
Subject: [PATCH 4/7] update

---
 .../test_files/tpch/plans/q16.slt.part        | 73 +++++++++----------
 1 file changed, 34 insertions(+), 39 deletions(-)

diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part
index c648f164c8094..663f32da7251c 100644
--- a/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part
+++ b/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part
@@ -51,49 +51,44 @@ limit 10;
 ----
 logical_plan
 01)Sort: supplier_cnt DESC NULLS FIRST, part.p_brand ASC NULLS LAST, part.p_type ASC NULLS LAST, part.p_size ASC NULLS LAST, fetch=10
-02)--Projection: part.p_brand, part.p_type, part.p_size, count(alias1) AS supplier_cnt
-03)----Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size]], aggr=[[count(alias1)]]
-04)------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey AS alias1]], aggr=[[]]
-05)--------LeftAnti Join: partsupp.ps_suppkey = __correlated_sq_1.s_suppkey
-06)----------Projection: partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size
-07)------------Inner Join: partsupp.ps_partkey = part.p_partkey
-08)--------------TableScan: partsupp projection=[ps_partkey, ps_suppkey]
-09)--------------Filter: part.p_brand != Utf8("Brand#45") AND part.p_type NOT LIKE Utf8("MEDIUM POLISHED%") AND part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)])
-10)----------------TableScan: part projection=[p_partkey, p_brand, p_type, p_size], partial_filters=[part.p_brand != Utf8("Brand#45"), part.p_type NOT LIKE Utf8("MEDIUM POLISHED%"), part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)])]
-11)----------SubqueryAlias: __correlated_sq_1
-12)------------Projection: supplier.s_suppkey
-13)--------------Filter: supplier.s_comment LIKE Utf8("%Customer%Complaints%")
-14)----------------TableScan: supplier projection=[s_suppkey, s_comment], partial_filters=[supplier.s_comment LIKE Utf8("%Customer%Complaints%")]
+02)--Projection: part.p_brand, part.p_type, part.p_size, count(DISTINCT partsupp.ps_suppkey) AS supplier_cnt
+03)----Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size]], aggr=[[count(DISTINCT partsupp.ps_suppkey)]]
+04)------LeftAnti Join: partsupp.ps_suppkey = __correlated_sq_1.s_suppkey
+05)--------Projection: partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size
+06)----------Inner Join: partsupp.ps_partkey = part.p_partkey
+07)------------TableScan: partsupp projection=[ps_partkey, ps_suppkey]
+08)------------Filter: part.p_brand != Utf8("Brand#45") AND part.p_type NOT LIKE Utf8("MEDIUM POLISHED%") AND part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)])
+09)--------------TableScan: part projection=[p_partkey, p_brand, p_type, p_size], partial_filters=[part.p_brand != Utf8("Brand#45"), part.p_type NOT LIKE Utf8("MEDIUM POLISHED%"), part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)])]
+10)--------SubqueryAlias: __correlated_sq_1
+11)----------Projection: supplier.s_suppkey
+12)------------Filter: supplier.s_comment LIKE Utf8("%Customer%Complaints%")
+13)--------------TableScan: supplier projection=[s_suppkey, s_comment], partial_filters=[supplier.s_comment LIKE Utf8("%Customer%Complaints%")]
 physical_plan
 01)SortPreservingMergeExec: [supplier_cnt@3 DESC, p_brand@0 ASC NULLS LAST, p_type@1 ASC NULLS LAST, p_size@2 ASC NULLS LAST], fetch=10
 02)--SortExec: TopK(fetch=10), expr=[supplier_cnt@3 DESC, p_brand@0 ASC NULLS LAST, p_type@1 ASC NULLS LAST, p_size@2 ASC NULLS LAST], preserve_partitioning=[true]
-03)----ProjectionExec: expr=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, count(alias1)@3 as supplier_cnt]
-04)------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[count(alias1)]
+03)----ProjectionExec: expr=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, count(DISTINCT partsupp.ps_suppkey)@3 as supplier_cnt]
+04)------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[count(DISTINCT partsupp.ps_suppkey)]
 05)--------CoalesceBatchesExec: target_batch_size=8192
 06)----------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2], 4), input_partitions=4
-07)------------AggregateExec: mode=Partial, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[count(alias1)]
-08)--------------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, alias1@3 as alias1], aggr=[]
-09)----------------CoalesceBatchesExec: target_batch_size=8192
-10)------------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2, alias1@3], 4), input_partitions=4
-11)--------------------AggregateExec: mode=Partial, gby=[p_brand@1 as p_brand, p_type@2 as p_type, p_size@3 as p_size, ps_suppkey@0 as alias1], aggr=[]
+07)------------AggregateExec: mode=Partial, gby=[p_brand@1 as p_brand, p_type@2 as p_type, p_size@3 as p_size], aggr=[count(DISTINCT partsupp.ps_suppkey)]
+08)--------------CoalesceBatchesExec: target_batch_size=8192
+09)----------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(ps_suppkey@0, s_suppkey@0)]
+10)------------------CoalesceBatchesExec: target_batch_size=8192
+11)--------------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4
 12)----------------------CoalesceBatchesExec: target_batch_size=8192
-13)------------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(ps_suppkey@0, s_suppkey@0)]
+13)------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_partkey@0, p_partkey@0)], projection=[ps_suppkey@1, p_brand@3, p_type@4, p_size@5]
 14)--------------------------CoalesceBatchesExec: target_batch_size=8192
-15)----------------------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4
-16)------------------------------CoalesceBatchesExec: target_batch_size=8192
-17)--------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_partkey@0, p_partkey@0)], projection=[ps_suppkey@1, p_brand@3, p_type@4, p_size@5]
-18)----------------------------------CoalesceBatchesExec: target_batch_size=8192
-19)------------------------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4
-20)--------------------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey], file_type=csv, has_header=false
-21)----------------------------------CoalesceBatchesExec: target_batch_size=8192
-22)------------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4
-23)--------------------------------------CoalesceBatchesExec: target_batch_size=8192
-24)----------------------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND Use p_size@3 IN (SET) ([Literal { value: Int32(49) }, Literal { value: Int32(14) }, Literal { value: Int32(23) }, Literal { value: Int32(45) }, Literal { value: Int32(19) }, Literal { value: Int32(3) }, Literal { value: Int32(36) }, Literal { value: Int32(9) }])
-25)------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-26)--------------------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_type, p_size], file_type=csv, has_header=false
-27)--------------------------CoalesceBatchesExec: target_batch_size=8192
-28)----------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4
-29)------------------------------CoalesceBatchesExec: target_batch_size=8192
-30)--------------------------------FilterExec: s_comment@1 LIKE %Customer%Complaints%, projection=[s_suppkey@0]
-31)----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-32)------------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_comment], file_type=csv, has_header=false
+15)----------------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4
+16)------------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey], file_type=csv, has_header=false
+17)--------------------------CoalesceBatchesExec: target_batch_size=8192
+18)----------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4
+19)------------------------------CoalesceBatchesExec: target_batch_size=8192
+20)--------------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND Use p_size@3 IN (SET) ([Literal { value: Int32(49) }, Literal { value: Int32(14) }, Literal { value: Int32(23) }, Literal { value: Int32(45) }, Literal { value: Int32(19) }, Literal { value: Int32(3) }, Literal { value: Int32(36) }, Literal { value: Int32(9) }])
+21)----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+22)------------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_type, p_size], file_type=csv, has_header=false
+23)------------------CoalesceBatchesExec: target_batch_size=8192
+24)--------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4
+25)----------------------CoalesceBatchesExec: target_batch_size=8192
+26)------------------------FilterExec: s_comment@1 LIKE %Customer%Complaints%, projection=[s_suppkey@0]
+27)--------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+28)----------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_comment], file_type=csv, has_header=false
From bf0bb3a6990cebc38b1b871984cc3666daf48c57 Mon Sep 17 00:00:00 2001
From: Chongchen Chen
Date: Sun, 27 Apr 2025 08:09:04 +0800
Subject: [PATCH 5/7] update

---
 .../optimizer/src/single_distinct_to_groupby.rs | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index b5f277a410a32..375d30dd7dc1d 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -67,6 +67,7 @@ impl SingleDistinctToGroupBy {
 fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result<bool> {
     let mut fields_set = HashSet::new();
     let mut aggregate_count = 0;
+    let mut distinct_count = 0;
     let mut distinct_func: Option<Arc<AggregateUDF>> = None;
     for expr in aggr_expr {
         if let Expr::AggregateFunction(AggregateFunction {
@@ -86,6 +87,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result<bool> {
             }
             aggregate_count += 1;
             if *distinct {
+                distinct_count += 1;
                 for e in args {
                     fields_set.insert(e);
                 }
@@ -101,7 +103,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result<bool> {
         }
     }
 
-    if aggregate_count == 1 && fields_set.len() == 1 {
+    if distinct_count == 1 && fields_set.len() == 1 {
         if let Some(distinct_func) = distinct_func {
             if distinct_func.name() == "count" {
                 return Ok(false);
@@ -543,10 +545,8 @@ mod tests {
         )?
         .build()?;
         // Should work
-        let expected = "Projection: test.a, sum(alias2) AS sum(test.c), max(alias3) AS max(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, max(test.c):UInt32;N, count(DISTINCT test.b):Int64]\
-        \n  Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), max(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, max(alias3):UInt32;N, count(alias1):Int64]\
-        \n    Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2, max(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\
-        \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
+        let expected = "Aggregate: groupBy=[[test.a]], aggr=[[sum(test.c), max(test.c), count(DISTINCT test.b)]] [a:UInt32, sum(test.c):UInt64;N, max(test.c):UInt32;N, count(DISTINCT test.b):Int64]\
+        \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
 
         assert_optimized_plan_equal(plan, expected)
     }
@@ -562,10 +562,8 @@ mod tests {
         )?
         .build()?;
         // Should work
-        let expected = "Projection: test.c, min(alias2) AS min(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, min(test.a):UInt32;N, count(DISTINCT test.b):Int64]\
-        \n  Aggregate: groupBy=[[test.c]], aggr=[[min(alias2), count(alias1)]] [c:UInt32, min(alias2):UInt32;N, count(alias1):Int64]\
-        \n    Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[min(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\
-        \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
+        let expected = "Aggregate: groupBy=[[test.c]], aggr=[[min(test.a), count(DISTINCT test.b)]] [c:UInt32, min(test.a):UInt32;N, count(DISTINCT test.b):Int64]\
+        \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
 
         assert_optimized_plan_equal(plan, expected)
     }
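Taken together, patches 1, 2, and 5 leave is_single_distinct_agg with this acceptance rule: bail out when the single distinct aggregate is a plain count (the group accumulator in patch 7 handles it), otherwise accept under the original conditions. A compact restatement of the control flow (a sketch, not the literal code):

    fn qualifies(
        aggregate_count: usize,  // all aggregate expressions
        distinct_count: usize,   // those marked DISTINCT
        distinct_fields: usize,  // distinct argument sets seen
        distinct_is_count: bool, // the distinct aggregate is `count`
        total_aggr_exprs: usize,
    ) -> bool {
        if distinct_count == 1 && distinct_fields == 1 && distinct_is_count {
            return false; // leave count(DISTINCT x) to the group accumulator
        }
        aggregate_count == total_aggr_exprs && distinct_fields == 1
    }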
From 4e865e93a343bc6e7ed0ed33fd44e2a8827e5b4c Mon Sep 17 00:00:00 2001
From: Chongchen Chen
Date: Sun, 27 Apr 2025 10:09:20 +0800
Subject: [PATCH 6/7] update

---
 datafusion/optimizer/src/single_distinct_to_groupby.rs | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 375d30dd7dc1d..c11fe27e8209a 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -27,7 +27,6 @@ use datafusion_common::{
 };
 use datafusion_expr::builder::project;
 use datafusion_expr::expr::AggregateFunctionParams;
-use datafusion_expr::AggregateUDF;
 use datafusion_expr::{
     col,
     expr::AggregateFunction,
@@ -68,7 +67,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result<bool> {
     let mut fields_set = HashSet::new();
     let mut aggregate_count = 0;
     let mut distinct_count = 0;
-    let mut distinct_func: Option<Arc<AggregateUDF>> = None;
+    let mut distinct_func: Option<&str> = None;
     for expr in aggr_expr {
         if let Expr::AggregateFunction(AggregateFunction {
             func,
@@ -90,7 +90,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result<bool> {
                 for e in args {
                     fields_set.insert(e);
                 }
-                distinct_func = Some(Arc::clone(func));
+                distinct_func = Some(func.name());
             } else if func.name() != "sum"
                 && func.name().to_lowercase() != "min"
                 && func.name().to_lowercase() != "max"
@@ -104,7 +104,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result<bool> {
 
     if distinct_count == 1 && fields_set.len() == 1 {
         if let Some(distinct_func) = distinct_func {
-            if distinct_func.name() == "count" {
+            if distinct_func == "count" {
                 return Ok(false);
             }
         }
From 6426a826e3837ab9074b4d789bdc3a8dd302d3eb Mon Sep 17 00:00:00 2001
From: Jay Zhan
Date: Tue, 29 Apr 2025 10:27:41 +0800
Subject: [PATCH 7/7] distinct count

---
 datafusion/functions-aggregate/Cargo.toml   |   1 +
 datafusion/functions-aggregate/src/count.rs | 202 +++++++++++++++++++-
 2 files changed, 198 insertions(+), 5 deletions(-)

diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml
index ec6e6b633bb81..c3462c69448db 100644
--- a/datafusion/functions-aggregate/Cargo.toml
+++ b/datafusion/functions-aggregate/Cargo.toml
@@ -49,6 +49,7 @@ datafusion-macros = { workspace = true }
 datafusion-physical-expr = { workspace = true }
 datafusion-physical-expr-common = { workspace = true }
 half = { workspace = true }
+hashbrown = { workspace = true }
 log = { workspace = true }
 paste = "1.0.14"

diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs
index 2d995b4a41793..9f8ee0a81593e 100644
--- a/datafusion/functions-aggregate/src/count.rs
+++ b/datafusion/functions-aggregate/src/count.rs
@@ -16,6 +16,8 @@
 // under the License.
 
 use ahash::RandomState;
+use arrow::array::{ArrowNativeTypeOp, ListArray};
+use datafusion_common::hash_utils::combine_hashes;
 use datafusion_common::stats::Precision;
 use datafusion_expr::expr::WindowFunction;
 use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
@@ -346,18 +348,22 @@ impl AggregateUDFImpl for Count {
     fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
         // groups accumulator only supports `COUNT(c1)`, not
         // `COUNT(c1, c2)`, etc
-        if args.is_distinct {
-            return false;
-        }
+        // if args.is_distinct {
+        //     return false;
+        // }
         args.exprs.len() == 1
     }
 
     fn create_groups_accumulator(
         &self,
-        _args: AccumulatorArgs,
+        args: AccumulatorArgs,
     ) -> Result<Box<dyn GroupsAccumulator>> {
         // instantiate specialized accumulator
-        Ok(Box::new(CountGroupsAccumulator::new()))
+        if args.is_distinct {
+            Ok(Box::new(DistinctCountGroupsAccumulator::new()))
+        } else {
+            Ok(Box::new(CountGroupsAccumulator::new()))
+        }
     }
 
     fn reverse_expr(&self) -> ReversedUDAF {
@@ -623,6 +629,192 @@ impl GroupsAccumulator for CountGroupsAccumulator {
     }
 }
 
+/// An accumulator to compute `COUNT(DISTINCT c1)` per group over a
+/// [`PrimitiveArray`], deduplicating on (group index, value) pairs.
+///
+/// Unlike most other accumulators, COUNT never produces NULLs. If no
+/// non-null values are seen in any group the output is 0. Thus, this
+/// accumulator has no additional null or seen filter tracking.
+#[derive(Debug)]
+struct DistinctCountGroupsAccumulator<T: ArrowPrimitiveType> {
+    /// Distinct values seen so far, per group; shipped to the Final
+    /// phase as one `ListArray` row per group by `state()`.
+    counts: Vec<Option<Vec<Option<T::Native>>>>,
+    /// Distinct count per group, filled in by `merge_batch`.
+    ///
+    /// Note this is an i64 and not a u64 (or usize) because the
+    /// output type of count is `DataType::Int64`. Thus by using `i64`
+    /// for the counts, the output [`Int64Array`] can be created
+    /// without copy.
+    final_count: Vec<i64>,
+
+    /// Set of (group index, value) pairs already seen; entries are
+    /// indices into `values` / `group_indices`.
+    map: hashbrown::HashTable<usize>,
+    values: Vec<T::Native>,
+    group_indices: Vec<usize>,
+    random_state: RandomState,
+}
+
+impl<T: ArrowPrimitiveType> DistinctCountGroupsAccumulator<T>
+where
+    T::Native: ArrowNativeTypeOp + std::hash::Hash,
+{
+    pub fn new() -> Self {
+        Self {
+            counts: vec![],
+            final_count: vec![],
+            random_state: Default::default(),
+            map: hashbrown::HashTable::with_capacity(128),
+            values: Vec::with_capacity(128),
+            group_indices: Vec::with_capacity(128),
+        }
+    }
+}
+
+impl<T: ArrowPrimitiveType> GroupsAccumulator for DistinctCountGroupsAccumulator<T>
+where
+    T::Native: ArrowNativeTypeOp + std::hash::Hash,
+{
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        _opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 1, "single argument to update_batch");
+        let values = &values[0];
+
+        // One (optional) distinct-value list per group
+        self.counts.resize(total_num_groups, None);
+
+        // Split the borrows: the closures below read the value and group
+        // vectors while `map` is borrowed mutably.
+        let Self {
+            counts,
+            map,
+            values: seen_values,
+            group_indices: seen_groups,
+            random_state,
+            ..
+        } = self;
+
+        // Combine each group index with its value and insert the pair into
+        // the hash table; only first-time insertions record the value.
+        let arr = values.as_primitive::<T>();
+        for (i, v) in arr.iter().enumerate() {
+            if let Some(key) = v {
+                let group_index = group_indices[i];
+                let hash = combine_hashes(
+                    random_state.hash_one(key),
+                    random_state.hash_one(group_index),
+                );
+
+                let insert = map.entry(
+                    hash,
+                    |g| unsafe {
+                        seen_groups.get_unchecked(*g) == &group_index
+                            && seen_values.get_unchecked(*g) == &key
+                    },
+                    |g| unsafe {
+                        let v = seen_values.get_unchecked(*g);
+                        let g = seen_groups.get_unchecked(*g);
+                        combine_hashes(
+                            random_state.hash_one(v),
+                            random_state.hash_one(g),
+                        )
+                    },
+                );
+
+                match insert {
+                    hashbrown::hash_table::Entry::Occupied(_) => {}
+                    hashbrown::hash_table::Entry::Vacant(slot) => {
+                        slot.insert(seen_values.len());
+                        seen_values.push(key);
+                        seen_groups.push(group_index);
+
+                        // First time this value shows up in this group
+                        if let Some(existing_keys) = &mut counts[group_index] {
+                            existing_keys.push(Some(key));
+                        } else {
+                            counts[group_index] = Some(vec![Some(key)]);
+                        }
+                    }
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        _opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        debug_assert_eq!(values.len(), 1);
+
+        self.final_count.resize(total_num_groups, 0);
+
+        let Self {
+            final_count,
+            map,
+            values: seen_values,
+            group_indices: seen_groups,
+            random_state,
+            ..
+        } = self;
+
+        // Each input row is the list of distinct values one partial
+        // accumulator saw for this group; de-duplicate them against the
+        // table and count first-time insertions.
+        let list_arr = values[0].as_list::<i32>();
+        for (i, counts) in list_arr.iter().enumerate() {
+            let group_index = group_indices[i];
+            if let Some(counts) = counts {
+                let counts_in_row = counts.as_primitive::<T>();
+                for key in counts_in_row.iter().flatten() {
+                    let hash = combine_hashes(
+                        random_state.hash_one(key),
+                        random_state.hash_one(group_index),
+                    );
+
+                    let insert = map.entry(
+                        hash,
+                        |g| unsafe {
+                            seen_groups.get_unchecked(*g) == &group_index
+                                && seen_values.get_unchecked(*g) == &key
+                        },
+                        |g| unsafe {
+                            let v = seen_values.get_unchecked(*g);
+                            let g = seen_groups.get_unchecked(*g);
+                            combine_hashes(
+                                random_state.hash_one(v),
+                                random_state.hash_one(g),
+                            )
+                        },
+                    );
+
+                    match insert {
+                        hashbrown::hash_table::Entry::Occupied(_) => {}
+                        hashbrown::hash_table::Entry::Vacant(slot) => {
+                            slot.insert(seen_values.len());
+                            seen_values.push(key);
+                            seen_groups.push(group_index);
+
+                            final_count[group_index] += 1;
+                        }
+                    }
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+        let counts = emit_to.take_needed(&mut self.final_count);
+
+        // COUNT never produces NULLs
+        let nulls = None;
+        let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
+
+        Ok(Arc::new(array))
+    }
+
+    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        let counts = emit_to.take_needed(&mut self.counts);
+        let list_array = ListArray::from_iter_primitive::<T, _, _>(counts);
+        Ok(vec![Arc::new(list_array)])
+    }
+
+    fn size(&self) -> usize {
+        self.counts.capacity() * size_of::<Option<Vec<Option<T::Native>>>>()
+            + self.final_count.capacity() * size_of::<i64>()
+            + self.map.capacity() * size_of::<usize>()
+            + self.values.capacity() * size_of::<T::Native>()
+            + self.group_indices.capacity() * size_of::<usize>()
+    }
+}
+
 /// count null values for multiple columns
 /// for each row if one column value is null, then null_count + 1
 fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
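The raw hashbrown::HashTable in DistinctCountGroupsAccumulator is a fast path over a simple idea: deduplicate rows on the (group index, value) pair and count each group's first-time insertions. A conceptual model using std's HashSet (my sketch of the behavior, not the patch's code; note the draft above also leaves the concrete choice of T at the create_groups_accumulator call site open):

    use std::collections::HashSet;

    /// Per-group distinct count over an i64 column, one group index per
    /// row; this is what update_batch tracks incrementally per batch.
    fn distinct_counts(
        values: &[i64],
        group_indices: &[usize],
        num_groups: usize,
    ) -> Vec<i64> {
        let mut seen: HashSet<(usize, i64)> = HashSet::new();
        let mut counts = vec![0i64; num_groups];
        for (&value, &group) in values.iter().zip(group_indices) {
            if seen.insert((group, value)) {
                counts[group] += 1;
            }
        }
        counts
    }

In the two-phase plan, the Partial side ships each group's value list out through state() as a ListArray row; merge_batch on the Final side re-deduplicates those lists against its own table while bumping final_count, which evaluate() then emits as the Int64 result.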