From 90df6ee3e75223334657fa67604d5570b5d4c2a2 Mon Sep 17 00:00:00 2001 From: wuziyi Date: Thu, 10 Apr 2025 00:28:42 +0800 Subject: [PATCH 1/7] [fix] not pull out from PartialMerge aggregate function to avoid invalid reference binding in ProjectExec --- .../gluten/utils/PullOutProjectHelper.scala | 24 +++++++++++-------- .../GlutenExtensionRewriteRuleSuite.scala | 13 +++++++++- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala b/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala index e4fc11441031..89d479500ab2 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala @@ -20,7 +20,7 @@ import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction,Complete, Partial} import org.apache.spark.sql.execution.aggregate._ import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.types.{ByteType, DateType, IntegerType, LongType, ShortType} @@ -137,16 +137,20 @@ trait PullOutProjectHelper { protected def rewriteAggregateExpression( ae: AggregateExpression, expressionMap: mutable.HashMap[Expression, NamedExpression]): AggregateExpression = { - val newAggFuncChildren = ae.aggregateFunction.children.map { - case literal: Literal => literal - case other => replaceExpressionWithAttribute(other, expressionMap) + ae.mode match { + case Partial | Complete => + val newAggFuncChildren = ae.aggregateFunction.children.map { + case literal: Literal => literal + case other => replaceExpressionWithAttribute(other, expressionMap) + } + val newAggFunc = ae.aggregateFunction + .withNewChildren(newAggFuncChildren) + .asInstanceOf[AggregateFunction] + val newFilter = + ae.filter.map(replaceExpressionWithAttribute(_, expressionMap)) + ae.copy(aggregateFunction = newAggFunc, filter = newFilter) + case _ => ae } - val newAggFunc = ae.aggregateFunction - .withNewChildren(newAggFuncChildren) - .asInstanceOf[AggregateFunction] - val newFilter = - ae.filter.map(replaceExpressionWithAttribute(_, expressionMap)) - ae.copy(aggregateFunction = newAggFunc, filter = newFilter) } private def needPreComputeRangeFrameBoundary(bound: Expression): Boolean = { diff --git a/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala b/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala index a29550815041..e4e230663f6a 100644 --- a/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala +++ b/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala @@ -18,7 +18,6 @@ package org.apache.gluten.extension import org.apache.gluten.execution.{ProjectExecTransformer, WholeStageTransformerSuite} import org.apache.gluten.utils.BackendTestUtils - import org.apache.spark.SparkConf class GlutenExtensionRewriteRuleSuite extends WholeStageTransformerSuite { @@ -62,4 +61,16 @@ class GlutenExtensionRewriteRuleSuite extends WholeStageTransformerSuite { } ) } + + test("GLUTEN-XXXX - Pull out project avoid invalid reference binding") { + withTable("t") { + sql("CREATE TABLE t(f1 String, f2 String, f3 String, f4 String) USING CSV") + sql("INSERT INTO t values ('1', '2', '3', '4'), ('11' ,'22', '33', '4')") + val df = sql( + """ + |SELECT SUM(f1) / COUNT(DISTINCT f2, f3) FROM t GROUP BY f4; + |""".stripMargin) + assert(df.collect()(0).getDouble(0) == 6) + } + } } From 447bfa077443167aadc634f84d5690b75a45d942 Mon Sep 17 00:00:00 2001 From: wuziyi Date: Thu, 10 Apr 2025 00:51:15 +0800 Subject: [PATCH 2/7] [fix] format --- .../gluten/extension/GlutenExtensionRewriteRuleSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala b/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala index e4e230663f6a..48b469cd5acb 100644 --- a/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala +++ b/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala @@ -18,6 +18,7 @@ package org.apache.gluten.extension import org.apache.gluten.execution.{ProjectExecTransformer, WholeStageTransformerSuite} import org.apache.gluten.utils.BackendTestUtils + import org.apache.spark.SparkConf class GlutenExtensionRewriteRuleSuite extends WholeStageTransformerSuite { @@ -62,7 +63,7 @@ class GlutenExtensionRewriteRuleSuite extends WholeStageTransformerSuite { ) } - test("GLUTEN-XXXX - Pull out project avoid invalid reference binding") { + test("GLUTEN-9279 - Not Pull out expression to avoid invalid reference binding") { withTable("t") { sql("CREATE TABLE t(f1 String, f2 String, f3 String, f4 String) USING CSV") sql("INSERT INTO t values ('1', '2', '3', '4'), ('11' ,'22', '33', '4')") From c338583c0f4b4cf39a51cd6721e4583919910435 Mon Sep 17 00:00:00 2001 From: wuziyi Date: Thu, 10 Apr 2025 01:08:33 +0800 Subject: [PATCH 3/7] [fix] format --- .../org/apache/gluten/utils/PullOutProjectHelper.scala | 2 +- .../gluten/extension/GlutenExtensionRewriteRuleSuite.scala | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala b/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala index 89d479500ab2..4bb09bfcab6a 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala @@ -20,7 +20,7 @@ import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction,Complete, Partial} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete, Partial} import org.apache.spark.sql.execution.aggregate._ import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.types.{ByteType, DateType, IntegerType, LongType, ShortType} diff --git a/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala b/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala index 48b469cd5acb..367e495b8d77 100644 --- a/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala +++ b/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala @@ -67,10 +67,9 @@ class GlutenExtensionRewriteRuleSuite extends WholeStageTransformerSuite { withTable("t") { sql("CREATE TABLE t(f1 String, f2 String, f3 String, f4 String) USING CSV") sql("INSERT INTO t values ('1', '2', '3', '4'), ('11' ,'22', '33', '4')") - val df = sql( - """ - |SELECT SUM(f1) / COUNT(DISTINCT f2, f3) FROM t GROUP BY f4; - |""".stripMargin) + val df = sql(""" + |SELECT SUM(f1) / COUNT(DISTINCT f2, f3) FROM t GROUP BY f4; + |""".stripMargin) assert(df.collect()(0).getDouble(0) == 6) } } From 58831574cea1404d21b8a10156a4438226ef39f4 Mon Sep 17 00:00:00 2001 From: wuziyi Date: Fri, 11 Apr 2025 22:54:39 +0800 Subject: [PATCH 4/7] [fix] ut --- .../extension/GlutenExtensionRewriteRuleSuite.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala b/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala index 367e495b8d77..eb831bb996b2 100644 --- a/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala +++ b/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala @@ -67,10 +67,12 @@ class GlutenExtensionRewriteRuleSuite extends WholeStageTransformerSuite { withTable("t") { sql("CREATE TABLE t(f1 String, f2 String, f3 String, f4 String) USING CSV") sql("INSERT INTO t values ('1', '2', '3', '4'), ('11' ,'22', '33', '4')") - val df = sql(""" - |SELECT SUM(f1) / COUNT(DISTINCT f2, f3) FROM t GROUP BY f4; - |""".stripMargin) - assert(df.collect()(0).getDouble(0) == 6) + runQueryAndCompare( + """ + |SELECT SUM(f1) / COUNT(DISTINCT f2, f3) FROM t GROUP BY f4; + |""".stripMargin, + noFallBack = false + )(_ => {}) } } } From 8dbdb112e8470e6b147e726a86bd8de39478a516 Mon Sep 17 00:00:00 2001 From: wuziyi Date: Fri, 11 Apr 2025 23:04:16 +0800 Subject: [PATCH 5/7] [fix] ut --- .../extension/GlutenExtensionRewriteRuleSuite.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala b/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala index eb831bb996b2..758d41060780 100644 --- a/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala +++ b/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.gluten.extension -import org.apache.gluten.execution.{ProjectExecTransformer, WholeStageTransformerSuite} +import org.apache.gluten.execution.{HashAggregateExecBaseTransformer, ProjectExecTransformer, WholeStageTransformerSuite} import org.apache.gluten.utils.BackendTestUtils import org.apache.spark.SparkConf @@ -72,7 +72,11 @@ class GlutenExtensionRewriteRuleSuite extends WholeStageTransformerSuite { |SELECT SUM(f1) / COUNT(DISTINCT f2, f3) FROM t GROUP BY f4; |""".stripMargin, noFallBack = false - )(_ => {}) + )( + df => { + checkGlutenOperatorMatch[ProjectExecTransformer](df) + checkGlutenOperatorMatch[HashAggregateExecBaseTransformer](df) + }) } } } From 0457794243afc6ae0fb2932bbaa3f863e64b69f1 Mon Sep 17 00:00:00 2001 From: wuziyi Date: Sat, 12 Apr 2025 00:58:27 +0800 Subject: [PATCH 6/7] [fix] ut --- .../apache/spark/sql/GlutenQueryTest.scala | 22 +++++++++++++++++++ .../GlutenExtensionRewriteRuleSuite.scala | 4 ++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala b/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala index 1d4a63bdbc28..8852fdb137b7 100644 --- a/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala +++ b/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala @@ -397,6 +397,28 @@ abstract class GlutenQueryTest extends PlanTest { val executedPlan = getExecutedPlan(df) assert(executedPlan.exists(plan => tag.runtimeClass.isInstance(plan))) } + + /** + * Check whether the executed plan of a dataframe contains expected number of expected plans. + * + * @param df: + * the input dataframe. + * @param count: + * expected number of expected plan. + * @param tag: + * class of the expected plan. + * @tparam T: + * type of the expected plan. + */ + def checkGlutenOperatorCount[T <: GlutenPlan](df: DataFrame, count: Int)(implicit + tag: ClassTag[T]): Unit = { + val executedPlan = getExecutedPlan(df) + assert( + executedPlan.count(plan => tag.runtimeClass.isInstance(plan)) == count, + s"Expect $count ${tag.runtimeClass.getSimpleName} " + + s"in executedPlan:\n ${executedPlan.last}" + ) + } } object GlutenQueryTest extends Assertions { diff --git a/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala b/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala index 758d41060780..79b010bdd321 100644 --- a/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala +++ b/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala @@ -74,8 +74,8 @@ class GlutenExtensionRewriteRuleSuite extends WholeStageTransformerSuite { noFallBack = false )( df => { - checkGlutenOperatorMatch[ProjectExecTransformer](df) - checkGlutenOperatorMatch[HashAggregateExecBaseTransformer](df) + checkGlutenOperatorCount[ProjectExecTransformer](df, 3) + checkGlutenOperatorCount[HashAggregateExecBaseTransformer](df, 4) }) } } From 41b01be0f191bc89b5ef6bd67d2485eb8d6e2a6f Mon Sep 17 00:00:00 2001 From: wuziyi Date: Fri, 18 Apr 2025 04:59:28 +0800 Subject: [PATCH 7/7] fix failed in ck-backend ut --- .../GlutenExtensionRewriteRuleSuite.scala | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala b/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala index 79b010bdd321..837d37236b65 100644 --- a/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala +++ b/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala @@ -65,16 +65,26 @@ class GlutenExtensionRewriteRuleSuite extends WholeStageTransformerSuite { test("GLUTEN-9279 - Not Pull out expression to avoid invalid reference binding") { withTable("t") { - sql("CREATE TABLE t(f1 String, f2 String, f3 String, f4 String) USING CSV") + sql("CREATE TABLE t(f1 String, f2 String, f3 String, f4 String) USING PARQUET") sql("INSERT INTO t values ('1', '2', '3', '4'), ('11' ,'22', '33', '4')") + var expectedProjectCount = 3 + var noFallback = false + if (BackendTestUtils.isCHBackendLoaded()) { + // The `RewriteMultiChildrenCount` rule in the Velox-backend is the root cause of the + // additional ProjectExecTransformer, which leads to the invalid reference binding issue. + // We still conduct tests on the CH-backend here to ensure that the introduced modification + // in `PullOutPreProject` has no side effect on the CH-backend. + expectedProjectCount = 2 + noFallback = true + } runQueryAndCompare( """ |SELECT SUM(f1) / COUNT(DISTINCT f2, f3) FROM t GROUP BY f4; |""".stripMargin, - noFallBack = false + noFallBack = noFallback )( df => { - checkGlutenOperatorCount[ProjectExecTransformer](df, 3) + checkGlutenOperatorCount[ProjectExecTransformer](df, expectedProjectCount) checkGlutenOperatorCount[HashAggregateExecBaseTransformer](df, 4) }) }