Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException}

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete, Partial}
import org.apache.spark.sql.execution.aggregate._
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.types.{ByteType, DateType, IntegerType, LongType, ShortType}
Expand Down Expand Up @@ -137,16 +137,20 @@ trait PullOutProjectHelper {
protected def rewriteAggregateExpression(
ae: AggregateExpression,
expressionMap: mutable.HashMap[Expression, NamedExpression]): AggregateExpression = {
val newAggFuncChildren = ae.aggregateFunction.children.map {
case literal: Literal => literal
case other => replaceExpressionWithAttribute(other, expressionMap)
ae.mode match {
case Partial | Complete =>
val newAggFuncChildren = ae.aggregateFunction.children.map {
case literal: Literal => literal
case other => replaceExpressionWithAttribute(other, expressionMap)
}
val newAggFunc = ae.aggregateFunction
.withNewChildren(newAggFuncChildren)
.asInstanceOf[AggregateFunction]
val newFilter =
ae.filter.map(replaceExpressionWithAttribute(_, expressionMap))
ae.copy(aggregateFunction = newAggFunc, filter = newFilter)
case _ => ae
}
val newAggFunc = ae.aggregateFunction
.withNewChildren(newAggFuncChildren)
.asInstanceOf[AggregateFunction]
val newFilter =
ae.filter.map(replaceExpressionWithAttribute(_, expressionMap))
ae.copy(aggregateFunction = newAggFunc, filter = newFilter)
}

private def needPreComputeRangeFrameBoundary(bound: Expression): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,28 @@ abstract class GlutenQueryTest extends PlanTest {
val executedPlan = getExecutedPlan(df)
assert(executedPlan.exists(plan => tag.runtimeClass.isInstance(plan)))
}

/**
* Check whether the executed plan of a dataframe contains expected number of expected plans.
*
* @param df:
* the input dataframe.
* @param count:
* expected number of expected plan.
* @param tag:
* class of the expected plan.
* @tparam T:
* type of the expected plan.
*/
def checkGlutenOperatorCount[T <: GlutenPlan](df: DataFrame, count: Int)(implicit
tag: ClassTag[T]): Unit = {
val executedPlan = getExecutedPlan(df)
assert(
executedPlan.count(plan => tag.runtimeClass.isInstance(plan)) == count,
s"Expect $count ${tag.runtimeClass.getSimpleName} " +
s"in executedPlan:\n ${executedPlan.last}"
)
}
}

object GlutenQueryTest extends Assertions {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package org.apache.gluten.extension

import org.apache.gluten.execution.{ProjectExecTransformer, WholeStageTransformerSuite}
import org.apache.gluten.execution.{HashAggregateExecBaseTransformer, ProjectExecTransformer, WholeStageTransformerSuite}
import org.apache.gluten.utils.BackendTestUtils

import org.apache.spark.SparkConf
Expand Down Expand Up @@ -62,4 +62,31 @@ class GlutenExtensionRewriteRuleSuite extends WholeStageTransformerSuite {
}
)
}

test("GLUTEN-9279 - Not Pull out expression to avoid invalid reference binding") {
withTable("t") {
sql("CREATE TABLE t(f1 String, f2 String, f3 String, f4 String) USING PARQUET")
sql("INSERT INTO t values ('1', '2', '3', '4'), ('11' ,'22', '33', '4')")
var expectedProjectCount = 3
var noFallback = false
if (BackendTestUtils.isCHBackendLoaded()) {
// The `RewriteMultiChildrenCount` rule in the Velox-backend is the root cause of the
// additional ProjectExecTransformer, which leads to the invalid reference binding issue.
// We still conduct tests on the CH-backend here to ensure that the introduced modification
// in `PullOutPreProject` has no side effect on the CH-backend.
expectedProjectCount = 2
noFallback = true
}
runQueryAndCompare(
"""
|SELECT SUM(f1) / COUNT(DISTINCT f2, f3) FROM t GROUP BY f4;
|""".stripMargin,
noFallBack = noFallback
)(
df => {
checkGlutenOperatorCount[ProjectExecTransformer](df, expectedProjectCount)
checkGlutenOperatorCount[HashAggregateExecBaseTransformer](df, 4)
})
}
}
}