From abb7a3d906518828e35c0f90cf12042f59aaa8a0 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Wed, 21 Aug 2024 18:38:32 -0700 Subject: [PATCH 1/7] potential fix --- .../optimizer/RewriteDistinctAggregates.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 801bd2693af42..8bd2fa0ba956f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -406,10 +406,14 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // The same GROUP BY clauses can have different forms (different names for instance) in // the groupBy and aggregate expressions of an aggregate. This makes a map lookup // tricky. So we do a linear search for a semantically equal group by expression. - groupByMap - .find(ge => e.semanticEquals(ge._1)) - .map(_._2) - .getOrElse(transformations.getOrElse(e, e)) + if (e.foldable) { + e + } else { + groupByMap + .find(ge => e.semanticEquals(ge._1)) + .map(_._2) + .getOrElse(transformations.getOrElse(e, e)) + } }.asInstanceOf[NamedExpression] } Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate) From 7e65416e26da5c8481c093f343cc714d7ad61699 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Thu, 22 Aug 2024 13:53:35 -0700 Subject: [PATCH 2/7] Test --- .../RewriteDistinctAggregatesSuite.scala | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala index ac136dfb898ef..22b000c23d612 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.{Literal, Round} import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} @@ -109,4 +109,20 @@ class RewriteDistinctAggregatesSuite extends PlanTest { case _ => fail(s"Plan is not rewritten:\n$rewrite") } } + + test("plunk") { + val relation = testRelation2 + .select(Literal(6).as("gb"), $"a", $"b", $"c", $"d") + val input = relation + .groupBy($"a", $"gb")( + countDistinct($"b").as("agg1"), + countDistinct($"d").as("agg2"), + Round(sum($"c").as("sum1"), 6)).analyze + val rewriteFold = FoldablePropagation(input) + // without the fix, the below produces an unresolved plan + val rewrite = RewriteDistinctAggregates(rewriteFold) + if (!rewrite.resolved) { + fail(s"Plan is not as expected:\n$rewrite") + } + } } From 9df15f6c403ecbeaa0f83240147a92f35fb0d34b Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Sun, 25 Aug 2024 15:54:36 -0700 Subject: [PATCH 3/7] Update --- .../optimizer/RewriteDistinctAggregates.scala | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 8bd2fa0ba956f..5aef82b64ed32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -400,20 +400,17 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { (distinctAggOperatorMap.flatMap(_._2) ++ regularAggOperatorMap.map(e => (e._1, e._3))).toMap + val groupByMapNonFoldable = groupByMap.filter(!_._1.foldable) val patchedAggExpressions = a.aggregateExpressions.map { e => e.transformDown { case e: Expression => // The same GROUP BY clauses can have different forms (different names for instance) in // the groupBy and aggregate expressions of an aggregate. This makes a map lookup // tricky. So we do a linear search for a semantically equal group by expression. - if (e.foldable) { - e - } else { - groupByMap - .find(ge => e.semanticEquals(ge._1)) - .map(_._2) - .getOrElse(transformations.getOrElse(e, e)) - } + groupByMapNonFoldable + .find(ge => e.semanticEquals(ge._1)) + .map(_._2) + .getOrElse(transformations.getOrElse(e, e)) }.asInstanceOf[NamedExpression] } Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate) From 4be26aeac48e37e3a55c24c5463fbfb18a4d39cc Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Sun, 25 Aug 2024 16:15:00 -0700 Subject: [PATCH 4/7] Add test --- .../spark/sql/DataFrameAggregateSuite.scala | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 0e9d34c3bd96a..4f09a2edc4561 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -2490,6 +2490,27 @@ class DataFrameAggregateSuite extends QueryTest }) } } + + test("plunk") { + val data = Seq((1, 1.001d, 2), (2, 3.001d, 4), (2, 3.001, 4)).toDF("a", "b", "c") + withTempView("v1") { + data.createOrReplaceTempView("v1") + val df = + sql("""select + | round(sum(b), 6) as sum1, + | count(distinct a) as count1, + | count(distinct c) as count2 + | from ( + | select + | 6 as gb, + | * + | from v1 + | ) + |group by a, gb + |""".stripMargin) + checkAnswer(df, Row(1.001d, 1, 1) :: Row(6.002d, 1, 1) :: Nil) + } + } } case class B(c: Option[Double]) From 5e5788f9e433653037bafaec269fb4ba2ed17ad9 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Mon, 26 Aug 2024 12:20:08 -0700 Subject: [PATCH 5/7] Update test names --- .../optimizer/RewriteDistinctAggregatesSuite.scala | 2 +- .../apache/spark/sql/DataFrameAggregateSuite.scala | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala index 22b000c23d612..807cfa4781be1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -110,7 +110,7 @@ class RewriteDistinctAggregatesSuite extends PlanTest { } } - test("plunk") { + test("SPARK-49261: Don't patch literals in aggregate expressions with group-by expressions") { val relation = testRelation2 .select(Literal(6).as("gb"), $"a", $"b", $"c", $"d") val input = relation diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 4f09a2edc4561..9ae0b1b8cbd62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -2491,7 +2491,7 @@ class DataFrameAggregateSuite extends QueryTest } } - test("plunk") { + test("SPARK-49261: Don't patch literals in aggregate expressions with group-by expressions") { val data = Seq((1, 1.001d, 2), (2, 3.001d, 4), (2, 3.001, 4)).toDF("a", "b", "c") withTempView("v1") { data.createOrReplaceTempView("v1") @@ -2500,12 +2500,12 @@ class DataFrameAggregateSuite extends QueryTest | round(sum(b), 6) as sum1, | count(distinct a) as count1, | count(distinct c) as count2 - | from ( - | select - | 6 as gb, - | * - | from v1 - | ) + |from ( + | select + | 6 as gb, + | * + | from v1 + |) |group by a, gb |""".stripMargin) checkAnswer(df, Row(1.001d, 1, 1) :: Row(6.002d, 1, 1) :: Nil) From 4de1c4893e7bbea72cffa85bcf4b92c151ac3be9 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Mon, 26 Aug 2024 15:35:09 -0700 Subject: [PATCH 6/7] Fix test names --- .../sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala | 2 +- .../scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala index 807cfa4781be1..4d31999ded655 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -110,7 +110,7 @@ class RewriteDistinctAggregatesSuite extends PlanTest { } } - test("SPARK-49261: Don't patch literals in aggregate expressions with group-by expressions") { + test("SPARK-49261: Literals in grouping expressions shouldn't result in unresolved aggregation") { val relation = testRelation2 .select(Literal(6).as("gb"), $"a", $"b", $"c", $"d") val input = relation diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 9ae0b1b8cbd62..411c0c389b595 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -2491,7 +2491,7 @@ class DataFrameAggregateSuite extends QueryTest } } - test("SPARK-49261: Don't patch literals in aggregate expressions with group-by expressions") { + test("SPARK-49261: Literals in grouping expressions shouldn't result in unresolved aggregation") { val data = Seq((1, 1.001d, 2), (2, 3.001d, 4), (2, 3.001, 4)).toDF("a", "b", "c") withTempView("v1") { data.createOrReplaceTempView("v1") From be0a975798292e04a0e15c21dd4e8b963f550060 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Wed, 11 Sep 2024 15:29:44 -0700 Subject: [PATCH 7/7] Review feedback --- .../spark/sql/DataFrameAggregateSuite.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 411c0c389b595..e80c3b23a7db3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -2496,17 +2496,17 @@ class DataFrameAggregateSuite extends QueryTest withTempView("v1") { data.createOrReplaceTempView("v1") val df = - sql("""select - | round(sum(b), 6) as sum1, - | count(distinct a) as count1, - | count(distinct c) as count2 - |from ( - | select - | 6 as gb, + sql("""SELECT + | ROUND(SUM(b), 6) AS sum1, + | COUNT(DISTINCT a) AS count1, + | COUNT(DISTINCT c) AS count2 + |FROM ( + | SELECT + | 6 AS gb, | * - | from v1 + | FROM v1 |) - |group by a, gb + |GROUP BY a, gb |""".stripMargin) checkAnswer(df, Row(1.001d, 1, 1) :: Row(6.002d, 1, 1) :: Nil) }