Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
CHConfig.prefixOf("enable.coalesce.project.union")
val GLUTEN_JOIN_AGGREGATE_TO_AGGREGATE_UNION: String =
CHConfig.prefixOf("join.aggregate.to.aggregate.union")
val GLUTEN_ELIMINATE_DEDUPLICATE_AGGREGATE_WITH_ANY_JOIN: String =
CHConfig.prefixOf("eliminate_deduplicate_aggregate_with_any_join")

def affinityMode: String = {
SparkEnv.get.conf
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,8 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
condition,
left,
right,
isSkewJoin)
isSkewJoin,
false)
}

/** Generate BroadcastHashJoinExecTransformer. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ case class CHShuffledHashJoinExecTransformer(
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan,
isSkewJoin: Boolean)
isSkewJoin: Boolean,
isAnyJoin: Boolean)
extends ShuffledHashJoinExecTransformerBase(
leftKeys,
rightKeys,
Expand All @@ -100,8 +101,6 @@ case class CHShuffledHashJoinExecTransformer(
left,
right,
isSkewJoin) {
// `any join` is used to accelerate the case when the right table is the aggregate result.
var isAnyJoin = false
override protected def withNewChildrenInternal(
newLeft: SparkPlan,
newRight: SparkPlan): CHShuffledHashJoinExecTransformer =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.gluten.execution._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
Expand All @@ -30,21 +31,48 @@ case class EliminateDeduplicateAggregateWithAnyJoin(spark: SparkSession)
extends Rule[SparkPlan]
with Logging {
override def apply(plan: SparkPlan): SparkPlan = {
if (!CHBackendSettings.eliminateDeduplicateAggregateWithAnyJoin()) {

if (
!spark.conf
.get(CHBackendSettings.GLUTEN_ELIMINATE_DEDUPLICATE_AGGREGATE_WITH_ANY_JOIN, "true")
.toBoolean
) {
return plan
}

plan.transformUp {
case hashJoin: CHShuffledHashJoinExecTransformer =>
case hashJoin: CHShuffledHashJoinExecTransformer
if (hashJoin.buildSide == BuildRight && hashJoin.joinType == LeftOuter) =>
hashJoin.right match {
case aggregate: CHHashAggregateExecTransformer =>
if (
isDeduplicateAggregate(aggregate) && allGroupingKeysAreJoinKeys(hashJoin, aggregate)
) {
hashJoin.copy(right = aggregate.child, isAnyJoin = true)
} else {
hashJoin
}
case project @ ProjectExecTransformer(_, aggregate: CHHashAggregateExecTransformer) =>
if (
hashJoin.joinType == LeftOuter &&
isDeduplicateAggregate(aggregate) &&
allGroupingKeysAreJoinKeys(hashJoin, aggregate) && project.projectList.forall(
_.isInstanceOf[AttributeReference])
) {
hashJoin.copy(right = project.copy(child = aggregate.child), isAnyJoin = true)
} else {
hashJoin
}
case _ => hashJoin
}
case hashJoin: CHShuffledHashJoinExecTransformer
if (hashJoin.buildSide == BuildLeft && hashJoin.joinType == LeftOuter) =>
hashJoin.left match {
case aggregate: CHHashAggregateExecTransformer =>
if (
isDeduplicateAggregate(aggregate) && allGroupingKeysAreJoinKeys(hashJoin, aggregate)
) {
val newHashJoin = hashJoin.copy(right = aggregate.child)
newHashJoin.isAnyJoin = true
newHashJoin
hashJoin.copy(left = aggregate.child, isAnyJoin = true)
} else {
hashJoin
}
Expand All @@ -55,10 +83,7 @@ case class EliminateDeduplicateAggregateWithAnyJoin(spark: SparkSession)
allGroupingKeysAreJoinKeys(hashJoin, aggregate) && project.projectList.forall(
_.isInstanceOf[AttributeReference])
) {
val newHashJoin =
hashJoin.copy(right = project.copy(child = aggregate.child))
newHashJoin.isAnyJoin = true
newHashJoin
hashJoin.copy(left = project.copy(child = aggregate.child), isAnyJoin = true)
} else {
hashJoin
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo
joinKeys.length != aggregate.groupingExpressions.length ||
!joinKeys.forall(k => outputGroupingKeys.exists(_.semanticEquals(k)))
) {
logError(
logDebug(
s"xxx Join keys and grouping keys are not matched. joinKeys: $joinKeys" +
s" outputGroupingKeys: $outputGroupingKeys")
return false
Expand Down Expand Up @@ -955,7 +955,7 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession)
analyzedAggregates.insert(0, rightAggregateAnalyzer.get)
collectSameKeysJoinedAggregates(join.left, analyzedAggregates)
} else {
logError(
logDebug(
s"xxx Not have same keys. join keys:" +
s"${analyzedAggregates.head.getPrimeJoinKeys()} vs. " +
s"${rightAggregateAnalyzer.get.getPrimeJoinKeys()}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ case class RewriteSortMergeJoinToHashJoinRule(session: SparkSession)
smj.condition,
newLeft,
newRight,
smj.isSkewJoin)
smj.isSkewJoin,
false)
val validateResult = hashJoin.doValidate()
if (!validateResult.ok()) {
logError(s"Validation failed for ShuffledHashJoinExec: ${validateResult.reason()}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.clickhouse._

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
Expand Down Expand Up @@ -59,6 +61,8 @@ class GlutenEliminateJoinSuite extends GlutenClickHouseWholeStageTransformerSuit
.set("spark.sql.shuffle.partitions", "5")
.set("spark.sql.autoBroadcastJoinThreshold", "-1")
.set("spark.gluten.supported.scala.udfs", "compare_substrings:compare_substrings")
.set(CHConfig.runtimeSettings("max_memory_usage_ratio_for_streaming_aggregating"), "0.01")
.set(CHConfig.runtimeSettings("high_cardinality_threshold_for_streaming_aggregating"), "0.2")
.set(
SQLConf.OPTIMIZER_EXCLUDED_RULES.key,
ConstantFolding.ruleName + "," + NullPropagation.ruleName)
Expand Down Expand Up @@ -469,4 +473,72 @@ class GlutenEliminateJoinSuite extends GlutenClickHouseWholeStageTransformerSuit
assert(joins.length == 1)
})
}

// Ensure the isAnyJoin flag is never lost after other rules are applied.
test("lost any join setting") {
  // Recreate both tables from scratch. t_9267_2 receives the same range twice, so every
  // (a, b) pair in it appears two times before the group-by in the query below.
  Seq(
    "drop table if exists t_9267_1",
    "drop table if exists t_9267_2",
    "create table t_9267_1 (a bigint, b bigint) using parquet",
    "create table t_9267_2 (a bigint, b bigint) using parquet",
    "insert into t_9267_1 select id as a, id as b from range(20000000)",
    "insert into t_9267_2 select id as a, id as b from range(5000000)",
    "insert into t_9267_2 select id as a, id as b from range(5000000)"
  ).foreach(statement => spark.sql(statement))

  // Left join whose right side is a deduplicating aggregate grouped on exactly the join keys.
  val query =
    """
      |select count(1) as n1, count(a1, b1, a2) as n2 from(
      | select t1.a as a1, t1.b as b1, t2.a as a2 from (
      | select * from t_9267_1 where a >= 0 and b < 100000000 and b >= 0
      | ) t1 left join (
      | select a, b from t_9267_2 group by a, b
      | ) t2 on t1.a = t2.a and t1.b = t2.b
      |)""".stripMargin

  // Results are compared against vanilla Spark; additionally the executed plan must contain
  // exactly one shuffled hash join transformer.
  compareResultsAgainstVanillaSpark(
    query,
    true,
    resultDf => {
      val hashJoins = resultDf.queryExecution.executedPlan.collect {
        case shuffledJoin: ShuffledHashJoinExecTransformerBase => shuffledJoin
      }
      assert(hashJoins.length == 1)
    }
  )

  // Clean up the tables created by this test.
  Seq("drop table t_9267_1", "drop table t_9267_2").foreach(statement => spark.sql(statement))
}

test("build left side") {
  // Recreate both tables from scratch. t_9267_2 receives the same range twice, so every
  // (a, b) pair in it appears two times before the group-by in the query below.
  Seq(
    "drop table if exists t_9267_1",
    "drop table if exists t_9267_2",
    "create table t_9267_1 (a bigint, b bigint) using parquet",
    "create table t_9267_2 (a bigint, b bigint) using parquet",
    "insert into t_9267_1 select id as a, id as b from range(2000000)",
    "insert into t_9267_2 select id as a, id as b from range(500000)",
    "insert into t_9267_2 select id as a, id as b from range(500000)"
  ).foreach(statement => spark.sql(statement))

  // The deduplicating aggregate is on the left here; being the smaller input, it is
  // expected to be chosen as the build side.
  val query =
    """
      |select count(1) as n1, count(a1, b1, a2) as n2 from(
      | select t1.a as a1, t1.b as b1, t2.a as a2 from (
      | select a, b from t_9267_2 group by a, b
      | ) t1 left join (
      | select * from t_9267_1 where a >= 0 and b != 100000000 and b >= 0
      | ) t2 on t1.a = t2.a and t1.b = t2.b
      |)""".stripMargin

  // Results are compared against vanilla Spark; additionally the executed plan must contain
  // exactly one shuffled hash join transformer.
  compareResultsAgainstVanillaSpark(
    query,
    true,
    resultDf => {
      val hashJoins = resultDf.queryExecution.executedPlan.collect {
        case shuffledJoin: ShuffledHashJoinExecTransformerBase => shuffledJoin
      }
      assert(hashJoins.length == 1)
    }
  )

  // Clean up the tables created by this test.
  Seq("drop table t_9267_1", "drop table t_9267_2").foreach(statement => spark.sql(statement))
}
}