From c1c192c126039fe3b0ab5677b54c267cc5f44d71 Mon Sep 17 00:00:00 2001 From: loneylee Date: Fri, 11 Apr 2025 20:26:48 +0800 Subject: [PATCH 1/2] [GLUTEN-9178][CH] Fix cse in aggregate operator not working --- .../CommonSubexpressionEliminateRule.scala | 16 ++++++------- .../GlutenFunctionValidateSuite.scala | 24 +++++++++++++++++-- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala index e10e2bba001c..7e674367f176 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala @@ -89,16 +89,16 @@ class CommonSubexpressionEliminateRule(spark: SparkSession) extends Rule[Logical private def replaceAggCommonExprWithAttribute( expr: Expression, - commonExprMap: mutable.HashMap[ExpressionEquals, AliasAndAttribute]): Expression = { + commonExprMap: mutable.HashMap[ExpressionEquals, AliasAndAttribute], + inAgg: Boolean = false): Expression = { val exprEquals = commonExprMap.get(ExpressionEquals(expr)) - if (expr.isInstanceOf[AggregateExpression]) { - if (exprEquals.isDefined) { + expr match { + case _ if exprEquals.isDefined && inAgg => exprEquals.get.attribute - } else { - expr - } - } else { - expr.mapChildren(replaceAggCommonExprWithAttribute(_, commonExprMap)) + case _: AggregateExpression => + expr.mapChildren(replaceAggCommonExprWithAttribute(_, commonExprMap, true)) + case _ => + expr.mapChildren(replaceAggCommonExprWithAttribute(_, commonExprMap, inAgg)) } } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala index 7bbcc3c363f9..3f93d47c6cc9 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.{DataFrame, GlutenTestUtils, Row} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, NullPropagation} -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig import org.apache.spark.sql.internal.SQLConf @@ -825,8 +825,28 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS |FROM (select id, cast(id as string) name from range(10)) |GROUP BY name |""".stripMargin) { - df => checkOperatorCount[ProjectExecTransformer](3)(df) + df => checkOperatorCount[ProjectExecTransformer](4)(df) } + + runQueryAndCompare( + s""" + |select id % 2, max(hash(id)), min(hash(id)) from range(10) group by id % 2 + |""".stripMargin)( + df => { + df.queryExecution.optimizedPlan.collect { + case Aggregate(_, aggregateExpressions, _) => + val result = + aggregateExpressions + .map(a => a.asInstanceOf[Alias].child) + .filter(_.isInstanceOf[AggregateExpression]) + .map(expr => expr.asInstanceOf[AggregateExpression].aggregateFunction) + .filter(aggFunc => aggFunc.children.head.isInstanceOf[AttributeReference]) + .map(aggFunc => aggFunc.children.head.asInstanceOf[AttributeReference].name) + .distinct + assertResult(1)(result.size) + } + checkOperatorCount[ProjectExecTransformer](1)(df) + }) } } From 3f9799eebbc2c582f95404ae1697fdee0d03a3b3 Mon Sep 17 00:00:00 2001 From: loneylee Date: Sat, 12 Apr 2025 21:07:15 +0800 Subject: [PATCH 2/2] fix ci --- .../apache/gluten/execution/GlutenFunctionValidateSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala index 3f93d47c6cc9..02e5b34a372d 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala @@ -1359,7 +1359,7 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS test("Test rewrite aggregate if to aggregate with filter") { val sql = "select sum(if(id % 2=0, id, null)), count(if(id % 2 = 0, 1, null)), " + - "avg(if(id % 2 = 0, id, null)), sum(if(id % 3 = 0, id, 0)) from range(10)" + "avg(if(id % 4 = 0, id, null)), sum(if(id % 3 = 0, id, 0)) from range(10)" def checkAggregateWithFilter(df: DataFrame): Unit = { val aggregates = collectWithSubqueries(df.queryExecution.executedPlan) {