From 78761f2003a52451e93611f005e2fb4d2865fcc1 Mon Sep 17 00:00:00 2001 From: comphead Date: Sat, 20 Sep 2025 13:30:06 -0700 Subject: [PATCH 01/11] feat: do not fallback to Spark for distincts --- .../main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 258d275e5b..d8bb97b867 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -558,12 +558,6 @@ object QueryPlanSerde extends Logging with CometExprShim { binding: Boolean, conf: SQLConf): Option[AggExpr] = { - if (aggExpr.isDistinct) { - // https://github.com/apache/datafusion-comet/issues/1260 - withInfo(aggExpr, s"distinct aggregate not supported: $aggExpr") - return None - } - val fn = aggExpr.aggregateFunction val cometExpr = aggrSerdeMap.get(fn.getClass) cometExpr match { From f81548d0dc58773f9acb59504dc036d9b2dd03cb Mon Sep 17 00:00:00 2001 From: comphead Date: Sun, 21 Sep 2025 09:41:04 -0700 Subject: [PATCH 02/11] fix: Adding more fuzz tests for `count(distinct)` --- .../comet/CometFuzzAggregateSuite.scala | 28 +++++++++++++++---- .../apache/comet/exec/CometExecSuite.scala | 3 -- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala index 6466f8fc29..b021f3e1e9 100644 --- a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala @@ -26,12 +26,26 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase { df.createOrReplaceTempView("t1") for (col <- df.columns) { val sql = s"SELECT count(distinct $col) FROM t1" - // Comet does not support count distinct yet - // 
https://github.com/apache/datafusion-comet/issues/2292 val (_, cometPlan) = checkSparkAnswer(sql) if (usingDataSourceExec) { assert(1 == collectNativeScans(cometPlan).length) } + + checkSparkAnswerAndOperator(sql) + } + } + + test("count distinct group by multpiple column") { + val df = spark.read.parquet(filename) + df.createOrReplaceTempView("t1") + for (col <- df.columns) { + val sql = s"SELECT c1, c2, c3, count(distinct $col) FROM t1 group by c1, c2, c3" + val (_, cometPlan) = checkSparkAnswer(sql) + if (usingDataSourceExec) { + assert(1 == collectNativeScans(cometPlan).length) + } + + checkSparkAnswerAndOperator(sql) } } @@ -39,12 +53,13 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase { val df = spark.read.parquet(filename) df.createOrReplaceTempView("t1") for (col <- df.columns) { - // cannot run fully natively due to range partitioning and sort val sql = s"SELECT $col, count(*) FROM t1 GROUP BY $col ORDER BY $col" val (_, cometPlan) = checkSparkAnswer(sql) if (usingDataSourceExec) { assert(1 == collectNativeScans(cometPlan).length) } + + checkSparkAnswerAndOperator(sql) } } @@ -53,12 +68,13 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase { df.createOrReplaceTempView("t1") val groupCol = df.columns.head for (col <- df.columns.drop(1)) { - // cannot run fully natively due to range partitioning and sort val sql = s"SELECT $groupCol, count($col) FROM t1 GROUP BY $groupCol ORDER BY $groupCol" val (_, cometPlan) = checkSparkAnswer(sql) if (usingDataSourceExec) { assert(1 == collectNativeScans(cometPlan).length) } + + checkSparkAnswerAndOperator(sql) } } @@ -67,13 +83,14 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase { df.createOrReplaceTempView("t1") val groupCol = df.columns.head val otherCol = df.columns.drop(1) - // cannot run fully natively due to range partitioning and sort val sql = s"SELECT $groupCol, count(${otherCol.mkString(", ")}) FROM t1 " + s"GROUP BY $groupCol ORDER BY $groupCol" val (_, cometPlan) = 
checkSparkAnswer(sql) if (usingDataSourceExec) { assert(1 == collectNativeScans(cometPlan).length) } + + checkSparkAnswerAndOperator(sql) } test("min/max aggregate") { @@ -88,5 +105,4 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase { } } } - } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 47d2205a08..c9992f0503 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -1031,9 +1031,6 @@ class CometExecSuite extends CometTestBase { |GROUP BY key """.stripMargin) - // The above query uses COUNT(DISTINCT) which Comet doesn't support yet, so the plan will - // have a mix of `HashAggregate` and `CometHashAggregate`. In the following we check all - // operators starting from `CometHashAggregate` are native. checkSparkAnswer(df) val subPlan = stripAQEPlan(df.queryExecution.executedPlan).collectFirst { case s: CometHashAggregateExec => s From b795ec4cfe7df7884d1352fb1cd896b8c249b7ec Mon Sep 17 00:00:00 2001 From: comphead Date: Sun, 21 Sep 2025 09:48:15 -0700 Subject: [PATCH 03/11] fix: Adding more fuzz tests for `count(distinct)` --- .../test/scala/org/apache/comet/CometFuzzAggregateSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala index b021f3e1e9..c1d1df2bf4 100644 --- a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala @@ -35,7 +35,7 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase { } } - test("count distinct group by multpiple column") { + test("count distinct group by multiple column") { val df = spark.read.parquet(filename) df.createOrReplaceTempView("t1") for (col <- df.columns) { From 
4177f63d48e24fc4f6852ac27fa11f1ddcf9240c Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 22 Sep 2025 20:33:55 -0700 Subject: [PATCH 04/11] feat: do not fallback to Spark for distincts --- .../org/apache/comet/CometFuzzAggregateSuite.scala | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala index c1d1df2bf4..9ad92ba030 100644 --- a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala @@ -30,8 +30,6 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase { if (usingDataSourceExec) { assert(1 == collectNativeScans(cometPlan).length) } - - checkSparkAnswerAndOperator(sql) } } @@ -44,8 +42,6 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase { if (usingDataSourceExec) { assert(1 == collectNativeScans(cometPlan).length) } - - checkSparkAnswerAndOperator(sql) } } @@ -58,8 +54,6 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase { if (usingDataSourceExec) { assert(1 == collectNativeScans(cometPlan).length) } - - checkSparkAnswerAndOperator(sql) } } @@ -73,8 +67,6 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase { if (usingDataSourceExec) { assert(1 == collectNativeScans(cometPlan).length) } - - checkSparkAnswerAndOperator(sql) } } @@ -89,8 +81,6 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase { if (usingDataSourceExec) { assert(1 == collectNativeScans(cometPlan).length) } - - checkSparkAnswerAndOperator(sql) } test("min/max aggregate") { From 554abaf5443c19a523d011872d9056bc9cce407c Mon Sep 17 00:00:00 2001 From: comphead Date: Tue, 23 Sep 2025 13:46:23 -0700 Subject: [PATCH 05/11] feat: do not fallback to Spark for distincts --- .../org/apache/comet/CometFuzzAggregateSuite.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git 
a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala index 9ad92ba030..14d64b2a91 100644 --- a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala @@ -45,6 +45,18 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase { } } + test("count distinct multiple values and group by multiple column") { + val df = spark.read.parquet(filename) + df.createOrReplaceTempView("t1") + for (col <- df.columns) { + val sql = s"SELECT c1, c2, c3, count(distinct $col, c4, c5) FROM t1 group by c1, c2, c3" + val (_, cometPlan) = checkSparkAnswer(sql) + if (usingDataSourceExec) { + assert(1 == collectNativeScans(cometPlan).length) + } + } + } + test("count(*) group by single column") { val df = spark.read.parquet(filename) df.createOrReplaceTempView("t1") From 1356d511b4fa7f40962b0eecfba9f59604fc87bf Mon Sep 17 00:00:00 2001 From: comphead Date: Tue, 23 Sep 2025 15:54:51 -0700 Subject: [PATCH 06/11] feat: do not fallback to Spark for distincts --- .../scala/org/apache/comet/serde/QueryPlanSerde.scala | 9 +++++++++ .../scala/org/apache/comet/CometFuzzAggregateSuite.scala | 2 ++ 2 files changed, 11 insertions(+) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index d8bb97b867..81d3b2fda8 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -558,6 +558,15 @@ object QueryPlanSerde extends Logging with CometExprShim { binding: Boolean, conf: SQLConf): Option[AggExpr] = { + // Support Count(distinct single_value) + // COUNT(DISTINCT x) - supported + // COUNT(DISTINCT x, x) - supported through transition to COUNT(DISTINCT x) + // COUNT(DISTINCT x, y) - not supported + if (aggExpr.isDistinct && 
(aggExpr.aggregateFunction.prettyName.toLowerCase == "count" && aggExpr.aggregateFunction.children.length == 1)) { + withInfo(aggExpr, s"Distinct aggregate not supported for: $aggExpr") + return None + } + val fn = aggExpr.aggregateFunction val cometExpr = aggrSerdeMap.get(fn.getClass) cometExpr match { diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala index 14d64b2a91..2748a8d7a5 100644 --- a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala @@ -30,6 +30,8 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase { if (usingDataSourceExec) { assert(1 == collectNativeScans(cometPlan).length) } + + checkSparkAnswerAndOperator(sql) } } From 66103c9e02b30790cba26ef97cc2172aaaaabbcc Mon Sep 17 00:00:00 2001 From: comphead Date: Tue, 23 Sep 2025 16:31:15 -0700 Subject: [PATCH 07/11] feat: do not fallback to Spark for distincts --- .../main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 81d3b2fda8..7a19b75ff3 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -562,7 +562,10 @@ object QueryPlanSerde extends Logging with CometExprShim { // COUNT(DISTINCT x) - supported // COUNT(DISTINCT x, x) - supported through transition to COUNT(DISTINCT x) // COUNT(DISTINCT x, y) - not supported - if (aggExpr.isDistinct && (aggExpr.aggregateFunction.prettyName.toLowerCase == "count" && aggExpr.aggregateFunction.children.length == 1)) { + if (aggExpr.isDistinct + && + (aggExpr.aggregateFunction.prettyName == "count" && + aggExpr.aggregateFunction.children.length == 1)) { withInfo(aggExpr, 
s"Distinct aggregate not supported for: $aggExpr") return None } From f029bf720804fb660377088ce2da3b161bd3b1ec Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 24 Sep 2025 08:53:22 -0700 Subject: [PATCH 08/11] feat: do not fallback to Spark for distincts --- .../apache/comet/serde/QueryPlanSerde.scala | 2 +- .../comet/CometFuzzAggregateSuite.scala | 42 +++++++++++++++++-- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 7a19b75ff3..892d8bca63 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -564,7 +564,7 @@ object QueryPlanSerde extends Logging with CometExprShim { // COUNT(DISTINCT x, y) - not supported if (aggExpr.isDistinct && - (aggExpr.aggregateFunction.prettyName == "count" && + !(aggExpr.aggregateFunction.prettyName == "count" && aggExpr.aggregateFunction.children.length == 1)) { withInfo(aggExpr, s"Distinct aggregate not supported for: $aggExpr") return None diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala index 2748a8d7a5..0c222cfcfd 100644 --- a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala @@ -19,12 +19,14 @@ package org.apache.comet +import org.apache.comet.DataTypeSupport.isComplexType + class CometFuzzAggregateSuite extends CometFuzzTestBase { - test("count distinct") { + test("count distinct - simple columns") { val df = spark.read.parquet(filename) df.createOrReplaceTempView("t1") - for (col <- df.columns) { + for (col <- df.schema.fields.filterNot(f => isComplexType(f.dataType)).map(_.name)) { val sql = s"SELECT count(distinct $col) FROM t1" val (_, cometPlan) = checkSparkAnswer(sql) if 
(usingDataSourceExec) { @@ -35,10 +37,40 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase { } } - test("count distinct group by multiple column") { + // Aggregate by complex columns not yet supported + // https://github.com/apache/datafusion-comet/issues/2382 + test("count distinct - complex columns") { val df = spark.read.parquet(filename) df.createOrReplaceTempView("t1") + for (col <- df.schema.fields.filter(f => isComplexType(f.dataType)).map(_.name)) { + val sql = s"SELECT count(distinct $col) FROM t1" + val (_, cometPlan) = checkSparkAnswer(sql) + if (usingDataSourceExec) { + assert(1 == collectNativeScans(cometPlan).length) + } + } + } + + test("count distinct group by multiple column - simple columns") { + val df = spark.read.parquet(filename) + df.createOrReplaceTempView("t1") + for (col <- df.schema.fields.filterNot(f => isComplexType(f.dataType)).map(_.name)) { + val sql = s"SELECT c1, c2, c3, count(distinct $col) FROM t1 group by c1, c2, c3" + val (_, cometPlan) = checkSparkAnswer(sql) + if (usingDataSourceExec) { + assert(1 == collectNativeScans(cometPlan).length) + } + + checkSparkAnswerAndOperator(sql) + } + } + + // Aggregate by complex columns not yet supported + // https://github.com/apache/datafusion-comet/issues/2382 + test("count distinct group by multiple column - complex columns") { + val df = spark.read.parquet(filename) + df.createOrReplaceTempView("t1") + for (col <- df.schema.fields.filter(f => isComplexType(f.dataType)).map(_.name)) { val sql = s"SELECT c1, c2, c3, count(distinct $col) FROM t1 group by c1, c2, c3" val (_, cometPlan) = checkSparkAnswer(sql) if (usingDataSourceExec) { @@ -47,6 +79,8 @@ } } + // Not yet supported + // https://github.com/apache/datafusion-comet/issues/2292 test("count distinct multiple values and group by multiple column") { val df = spark.read.parquet(filename) df.createOrReplaceTempView("t1") From 
0c20769e9ed9c7e89e4c2a72a309461dc72c376b Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 24 Sep 2025 10:04:01 -0700 Subject: [PATCH 09/11] feat: do not fallback to Spark for distincts --- .../test/scala/org/apache/comet/CometFuzzAggregateSuite.scala | 2 +- spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala index 0c222cfcfd..19812f38ce 100644 --- a/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometFuzzAggregateSuite.scala @@ -79,7 +79,7 @@ class CometFuzzAggregateSuite extends CometFuzzTestBase { } } - // Not yet supported + // COUNT(distinct x, y, z, ...) not yet supported // https://github.com/apache/datafusion-comet/issues/2292 test("count distinct multiple values and group by multiple column") { val df = spark.read.parquet(filename) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index c9992f0503..5dfc4cbac2 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -1031,6 +1031,8 @@ class CometExecSuite extends CometTestBase { |GROUP BY key """.stripMargin) + // The above query uses SUM(DISTINCT) and count(distinct value1, value2) + // which is not yet supported checkSparkAnswer(df) val subPlan = stripAQEPlan(df.queryExecution.executedPlan).collectFirst { case s: CometHashAggregateExec => s From 9bb890e5d158a2c9c5d04db8b9d0281b638f0da3 Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 24 Sep 2025 15:09:53 -0700 Subject: [PATCH 10/11] feat: do not fallback to Spark for distincts --- .../benchmark/CometAggregateBenchmark.scala | 65 +++++++++++++------ 1 file changed, 44 insertions(+), 21 
deletions(-) diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala index 47fbe354f5..3c6314fa0d 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala @@ -44,11 +44,11 @@ object CometAggregateBenchmark extends CometBenchmarkBase { def singleGroupAndAggregate( values: Int, groupingKeyCardinality: Int, - aggregateFunction: String): Unit = { + aggregateFunction: BenchAggregateFunction): Unit = { val benchmark = new Benchmark( s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCardinality), " + - s"single aggregate $aggregateFunction", + s"single aggregate ${aggregateFunction.name}", values, output = output) @@ -58,13 +58,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase { dir, spark.sql(s"SELECT value, floor(rand() * $groupingKeyCardinality) as key FROM $tbl")) - val query = s"SELECT key, $aggregateFunction(value) FROM parquetV1Table GROUP BY key" + val functionSQL = aggFunctionSQL(aggregateFunction, "value") + val query = s"SELECT key, $functionSQL FROM parquetV1Table GROUP BY key" - benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ => + benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.name})") { _ => spark.sql(query).noop() } - benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ => + benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.name})") { _ => withSQLConf( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_ENABLED.key -> "true") { @@ -77,11 +78,15 @@ object CometAggregateBenchmark extends CometBenchmarkBase { } } + def aggFunctionSQL(aggregateFunction: BenchAggregateFunction, input: String): String = { + s"${aggregateFunction.name}(${if (aggregateFunction.distinct) s"DISTINCT $input" else input})" + } + def 
singleGroupAndAggregateDecimal( values: Int, dataType: DecimalType, groupingKeyCardinality: Int, - aggregateFunction: String): Unit = { + aggregateFunction: BenchAggregateFunction): Unit = { val benchmark = new Benchmark( s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCardinality), " + @@ -99,13 +104,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase { spark.sql( s"SELECT dec as value, floor(rand() * $groupingKeyCardinality) as key FROM $tbl")) - val query = s"SELECT key, $aggregateFunction(value) FROM parquetV1Table GROUP BY key" + val functionSQL = aggFunctionSQL(aggregateFunction, "value") + val query = s"SELECT key, $functionSQL FROM parquetV1Table GROUP BY key" - benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ => + benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.name})") { _ => spark.sql(query).noop() } - benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ => + benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.name})") { _ => withSQLConf( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_ENABLED.key -> "true") { @@ -118,11 +124,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase { } } - def multiGroupKeys(values: Int, groupingKeyCard: Int, aggregateFunction: String): Unit = { + def multiGroupKeys( + values: Int, + groupingKeyCard: Int, + aggregateFunction: BenchAggregateFunction): Unit = { val benchmark = new Benchmark( s"Grouped HashAgg Exec: multiple group keys (cardinality $groupingKeyCard), " + - s"single aggregate $aggregateFunction", + s"single aggregate ${aggregateFunction.name}", values, output = output) @@ -134,14 +143,15 @@ object CometAggregateBenchmark extends CometBenchmarkBase { s"SELECT value, floor(rand() * $groupingKeyCard) as key1, " + s"floor(rand() * $groupingKeyCard) as key2 FROM $tbl")) + val functionSQL = aggFunctionSQL(aggregateFunction, "value") val query = - s"SELECT key1, key2, $aggregateFunction(value) FROM parquetV1Table 
GROUP BY key1, key2" + s"SELECT key1, key2, $functionSQL FROM parquetV1Table GROUP BY key1, key2" - benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ => + benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.name})") { _ => spark.sql(query).noop() } - benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ => + benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.name})") { _ => withSQLConf( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_ENABLED.key -> "true", @@ -155,11 +165,14 @@ object CometAggregateBenchmark extends CometBenchmarkBase { } } - def multiAggregates(values: Int, groupingKeyCard: Int, aggregateFunction: String): Unit = { + def multiAggregates( + values: Int, + groupingKeyCard: Int, + aggregateFunction: BenchAggregateFunction): Unit = { val benchmark = new Benchmark( s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCard), " + - s"multiple aggregates $aggregateFunction", + s"multiple aggregates ${aggregateFunction.name}", values, output = output) @@ -171,14 +184,17 @@ object CometAggregateBenchmark extends CometBenchmarkBase { s"SELECT value as value1, value as value2, floor(rand() * $groupingKeyCard) as key " + s"FROM $tbl")) - val query = s"SELECT key, $aggregateFunction(value1), $aggregateFunction(value2) " + + val functionSQL1 = aggFunctionSQL(aggregateFunction, "value1") + val functionSQL2 = aggFunctionSQL(aggregateFunction, "value2") + + val query = s"SELECT key, $functionSQL1, $functionSQL2 " + "FROM parquetV1Table GROUP BY key" - benchmark.addCase(s"SQL Parquet - Spark ($aggregateFunction)") { _ => + benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.name})") { _ => spark.sql(query).noop() } - benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ => + benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.name})") { _ => withSQLConf( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_ENABLED.key -> "true") { @@ -191,12 +207,19 @@ 
object CometAggregateBenchmark extends CometBenchmarkBase { } } + case class BenchAggregateFunction(name: String, distinct: Boolean = false) + private val benchmarkAggFuncs = Seq( + BenchAggregateFunction("SUM"), + BenchAggregateFunction("MIN"), + BenchAggregateFunction("MAX"), + BenchAggregateFunction("COUNT"), + BenchAggregateFunction("COUNT", distinct = true)) + override def runCometBenchmark(mainArgs: Array[String]): Unit = { val total = 1024 * 1024 * 10 val combinations = List(100, 1024, 1024 * 1024) // number of distinct groups - val aggregateFunctions = List("SUM", "MIN", "MAX", "COUNT") - aggregateFunctions.foreach { aggFunc => + benchmarkAggFuncs.foreach { aggFunc => runBenchmarkWithTable( s"Grouped Aggregate (single group key + single aggregate $aggFunc)", total) { v => From 9d3a40f9b13c10b72adc78d5e3c1ccc955d3ff65 Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 25 Sep 2025 07:49:08 -0700 Subject: [PATCH 11/11] feat: do not fallback to Spark for distincts --- .../benchmark/CometAggregateBenchmark.scala | 53 ++++++++++--------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala index 3c6314fa0d..1efd3974ed 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala @@ -41,6 +41,23 @@ object CometAggregateBenchmark extends CometBenchmarkBase { session } + // Wrapper on SQL aggregation function + case class BenchAggregateFunction(name: String, distinct: Boolean = false) { + override def toString: String = if (distinct) s"$name(DISTINCT)" else name + } + + // Aggregation functions to test + private val benchmarkAggFuncs = Seq( + BenchAggregateFunction("SUM"), + BenchAggregateFunction("MIN"), + BenchAggregateFunction("MAX"), + BenchAggregateFunction("COUNT"), + 
BenchAggregateFunction("COUNT", distinct = true)) + + def aggFunctionSQL(aggregateFunction: BenchAggregateFunction, input: String): String = { + s"${aggregateFunction.name}(${if (aggregateFunction.distinct) s"DISTINCT $input" else input})" + } + def singleGroupAndAggregate( values: Int, groupingKeyCardinality: Int, @@ -48,7 +65,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase { val benchmark = new Benchmark( s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCardinality), " + - s"single aggregate ${aggregateFunction.name}", + s"single aggregate ${aggregateFunction.toString}", values, output = output) @@ -61,11 +78,11 @@ object CometAggregateBenchmark extends CometBenchmarkBase { val functionSQL = aggFunctionSQL(aggregateFunction, "value") val query = s"SELECT key, $functionSQL FROM parquetV1Table GROUP BY key" - benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.name})") { _ => + benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ => spark.sql(query).noop() } - benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.name})") { _ => + benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ => withSQLConf( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_ENABLED.key -> "true") { @@ -78,10 +95,6 @@ object CometAggregateBenchmark extends CometBenchmarkBase { } } - def aggFunctionSQL(aggregateFunction: BenchAggregateFunction, input: String): String = { - s"${aggregateFunction.name}(${if (aggregateFunction.distinct) s"DISTINCT $input" else input})" - } - def singleGroupAndAggregateDecimal( values: Int, dataType: DecimalType, @@ -90,7 +103,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase { val benchmark = new Benchmark( s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCardinality), " + - s"single aggregate $aggregateFunction on decimal", + s"single aggregate ${aggregateFunction.toString} on decimal", values, output = output) @@ -107,11 
+120,11 @@ object CometAggregateBenchmark extends CometBenchmarkBase { val functionSQL = aggFunctionSQL(aggregateFunction, "value") val query = s"SELECT key, $functionSQL FROM parquetV1Table GROUP BY key" - benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.name})") { _ => + benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ => spark.sql(query).noop() } - benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.name})") { _ => + benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ => withSQLConf( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_ENABLED.key -> "true") { @@ -131,7 +144,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase { val benchmark = new Benchmark( s"Grouped HashAgg Exec: multiple group keys (cardinality $groupingKeyCard), " + - s"single aggregate ${aggregateFunction.name}", + s"single aggregate ${aggregateFunction.toString}", values, output = output) @@ -147,11 +160,11 @@ object CometAggregateBenchmark extends CometBenchmarkBase { val query = s"SELECT key1, key2, $functionSQL FROM parquetV1Table GROUP BY key1, key2" - benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.name})") { _ => + benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ => spark.sql(query).noop() } - benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.name})") { _ => + benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ => withSQLConf( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_ENABLED.key -> "true", @@ -172,7 +185,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase { val benchmark = new Benchmark( s"Grouped HashAgg Exec: single group key (cardinality $groupingKeyCard), " + - s"multiple aggregates ${aggregateFunction.name}", + s"multiple aggregates ${aggregateFunction.toString}", values, output = output) @@ -190,11 +203,11 @@ object CometAggregateBenchmark extends CometBenchmarkBase { val query 
= s"SELECT key, $functionSQL1, $functionSQL2 " + "FROM parquetV1Table GROUP BY key" - benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.name})") { _ => + benchmark.addCase(s"SQL Parquet - Spark (${aggregateFunction.toString})") { _ => spark.sql(query).noop() } - benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.name})") { _ => + benchmark.addCase(s"SQL Parquet - Comet (${aggregateFunction.toString})") { _ => withSQLConf( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_ENABLED.key -> "true") { @@ -207,14 +220,6 @@ object CometAggregateBenchmark extends CometBenchmarkBase { } } - case class BenchAggregateFunction(name: String, distinct: Boolean = false) - private val benchmarkAggFuncs = Seq( - BenchAggregateFunction("SUM"), - BenchAggregateFunction("MIN"), - BenchAggregateFunction("MAX"), - BenchAggregateFunction("COUNT"), - BenchAggregateFunction("COUNT", distinct = true)) - override def runCometBenchmark(mainArgs: Array[String]): Unit = { val total = 1024 * 1024 * 10 val combinations = List(100, 1024, 1024 * 1024) // number of distinct groups