diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala b/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala index e4fc11441031..4bb09bfcab6a 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala @@ -20,7 +20,7 @@ import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete, Partial} import org.apache.spark.sql.execution.aggregate._ import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.types.{ByteType, DateType, IntegerType, LongType, ShortType} @@ -137,16 +137,20 @@ trait PullOutProjectHelper { protected def rewriteAggregateExpression( ae: AggregateExpression, expressionMap: mutable.HashMap[Expression, NamedExpression]): AggregateExpression = { - val newAggFuncChildren = ae.aggregateFunction.children.map { - case literal: Literal => literal - case other => replaceExpressionWithAttribute(other, expressionMap) + ae.mode match { + case Partial | Complete => + val newAggFuncChildren = ae.aggregateFunction.children.map { + case literal: Literal => literal + case other => replaceExpressionWithAttribute(other, expressionMap) + } + val newAggFunc = ae.aggregateFunction + .withNewChildren(newAggFuncChildren) + .asInstanceOf[AggregateFunction] + val newFilter = + ae.filter.map(replaceExpressionWithAttribute(_, expressionMap)) + ae.copy(aggregateFunction = newAggFunc, filter = newFilter) + case _ => ae } - val newAggFunc = ae.aggregateFunction - .withNewChildren(newAggFuncChildren) - .asInstanceOf[AggregateFunction] 
- val newFilter = - ae.filter.map(replaceExpressionWithAttribute(_, expressionMap)) - ae.copy(aggregateFunction = newAggFunc, filter = newFilter) } private def needPreComputeRangeFrameBoundary(bound: Expression): Boolean = { diff --git a/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala b/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala index 1d4a63bdbc28..8852fdb137b7 100644 --- a/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala +++ b/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala @@ -397,6 +397,28 @@ abstract class GlutenQueryTest extends PlanTest { val executedPlan = getExecutedPlan(df) assert(executedPlan.exists(plan => tag.runtimeClass.isInstance(plan))) } + + /** + * Assert that the executed plan of a dataframe contains exactly `count` instances of T. + * + * @param df: + * the input dataframe. + * @param count: + * expected number of plans that are instances of T. + * @param tag: + * class tag of the expected plan type. + * @tparam T: + * type of the expected plan. 
+ */ + def checkGlutenOperatorCount[T <: GlutenPlan](df: DataFrame, count: Int)(implicit + tag: ClassTag[T]): Unit = { + val executedPlan = getExecutedPlan(df) + assert( + executedPlan.count(plan => tag.runtimeClass.isInstance(plan)) == count, + s"Expect $count ${tag.runtimeClass.getSimpleName} " + + s"in executedPlan:\n ${executedPlan.last}" + ) + } } object GlutenQueryTest extends Assertions { diff --git a/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala b/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala index a29550815041..837d37236b65 100644 --- a/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala +++ b/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.gluten.extension -import org.apache.gluten.execution.{ProjectExecTransformer, WholeStageTransformerSuite} +import org.apache.gluten.execution.{HashAggregateExecBaseTransformer, ProjectExecTransformer, WholeStageTransformerSuite} import org.apache.gluten.utils.BackendTestUtils import org.apache.spark.SparkConf @@ -62,4 +62,31 @@ class GlutenExtensionRewriteRuleSuite extends WholeStageTransformerSuite { } ) } + + test("GLUTEN-9279 - Not Pull out expression to avoid invalid reference binding") { + withTable("t") { + sql("CREATE TABLE t(f1 String, f2 String, f3 String, f4 String) USING PARQUET") + sql("INSERT INTO t values ('1', '2', '3', '4'), ('11' ,'22', '33', '4')") + var expectedProjectCount = 3 + var noFallback = false + if (BackendTestUtils.isCHBackendLoaded()) { + // The `RewriteMultiChildrenCount` rule in the Velox-backend is the root cause of the + // additional ProjectExecTransformer, which leads to the invalid reference binding issue. 
+ // We still conduct tests on the CH-backend here to ensure that the introduced modification + // in `PullOutPreProject` has no side effect on the CH-backend. + expectedProjectCount = 2 + noFallback = true + } + runQueryAndCompare( + """ + |SELECT SUM(f1) / COUNT(DISTINCT f2, f3) FROM t GROUP BY f4; + |""".stripMargin, + noFallBack = noFallback + )( + df => { + checkGlutenOperatorCount[ProjectExecTransformer](df, expectedProjectCount) + checkGlutenOperatorCount[HashAggregateExecBaseTransformer](df, 4) + }) + } + } }