diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala index e10e2bba001c..7e674367f176 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala @@ -89,16 +89,16 @@ class CommonSubexpressionEliminateRule(spark: SparkSession) extends Rule[Logical private def replaceAggCommonExprWithAttribute( expr: Expression, - commonExprMap: mutable.HashMap[ExpressionEquals, AliasAndAttribute]): Expression = { + commonExprMap: mutable.HashMap[ExpressionEquals, AliasAndAttribute], + inAgg: Boolean = false): Expression = { val exprEquals = commonExprMap.get(ExpressionEquals(expr)) - if (expr.isInstanceOf[AggregateExpression]) { - if (exprEquals.isDefined) { + expr match { + case _ if exprEquals.isDefined && inAgg => exprEquals.get.attribute - } else { - expr - } - } else { - expr.mapChildren(replaceAggCommonExprWithAttribute(_, commonExprMap)) + case _: AggregateExpression => + expr.mapChildren(replaceAggCommonExprWithAttribute(_, commonExprMap, true)) + case _ => + expr.mapChildren(replaceAggCommonExprWithAttribute(_, commonExprMap, inAgg)) } } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala index 7bbcc3c363f9..02e5b34a372d 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.{DataFrame, GlutenTestUtils, Row} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, NullPropagation} -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig import org.apache.spark.sql.internal.SQLConf @@ -825,8 +825,28 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS |FROM (select id, cast(id as string) name from range(10)) |GROUP BY name |""".stripMargin) { - df => checkOperatorCount[ProjectExecTransformer](3)(df) + df => checkOperatorCount[ProjectExecTransformer](4)(df) } + + runQueryAndCompare( + s""" + |select id % 2, max(hash(id)), min(hash(id)) from range(10) group by id % 2 + |""".stripMargin)( + df => { + df.queryExecution.optimizedPlan.collect { + case Aggregate(_, aggregateExpressions, _) => + val result = + aggregateExpressions + .map(a => a.asInstanceOf[Alias].child) + .filter(_.isInstanceOf[AggregateExpression]) + .map(expr => expr.asInstanceOf[AggregateExpression].aggregateFunction) + .filter(aggFunc => aggFunc.children.head.isInstanceOf[AttributeReference]) + .map(aggFunc => aggFunc.children.head.asInstanceOf[AttributeReference].name) + .distinct + assertResult(1)(result.size) + } + checkOperatorCount[ProjectExecTransformer](1)(df) + }) } } @@ -1339,7 +1359,7 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS test("Test rewrite aggregate if to aggregate with filter") { val sql = "select sum(if(id % 2=0, id, null)), count(if(id % 2 = 0, 1, null)), " + - "avg(if(id % 2 = 0, id, null)), sum(if(id % 3 = 0, id, 0)) from range(10)" + "avg(if(id % 4 = 0, id, null)), sum(if(id % 3 = 0, id, 0)) from range(10)" def checkAggregateWithFilter(df: DataFrame): Unit = { val aggregates = collectWithSubqueries(df.queryExecution.executedPlan) {