diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index 06def8902bc0..d71d41bf2d4c 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -160,6 +160,8 @@ object CHBackendSettings extends BackendSettingsApi with Logging { CHConfig.prefixOf("enable.coalesce.aggregation.union") val GLUTEN_ENABLE_COALESCE_PROJECT_UNION: String = CHConfig.prefixOf("enable.coalesce.project.union") + val GLUTEN_JOIN_AGGREGATE_TO_AGGREGATE_UNION: String = + CHConfig.prefixOf("join.aggregate.to.aggregate.union") def affinityMode: String = { SparkEnv.get.conf diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index 695deaddbf82..374e223ea9c9 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -60,6 +60,9 @@ object CHRuleApi { (spark, parserInterface) => new GlutenCacheFilesSqlParser(spark, parserInterface)) injector.injectParser( (spark, parserInterface) => new GlutenClickhouseSqlParser(spark, parserInterface)) + injector.injectResolutionRule(spark => new JoinAggregateToAggregateUnion(spark)) + // CoalesceAggregationUnion and CoalesceProjectionUnion should follow + // JoinAggregateToAggregateUnion injector.injectResolutionRule(spark => new CoalesceAggregationUnion(spark)) injector.injectResolutionRule(spark => new CoalesceProjectionUnion(spark)) injector.injectResolutionRule(spark => new RewriteToDateExpresstionRule(spark)) diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala new file mode 100644 index 000000000000..3fcc2d5369da --- /dev/null +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala @@ -0,0 +1,996 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.extension + +import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings +import org.apache.gluten.exception._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types._ + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ListBuffer + +private case class JoinKeys( + leftKeys: Seq[AttributeReference], + rightKeys: Seq[AttributeReference]) {} + +private object RulePlanHelper { + def transformDistinctToAggregate(distinct: Distinct): Aggregate = { + Aggregate(distinct.child.output, distinct.child.output, distinct.child) + } + + def extractDirectAggregate(plan: LogicalPlan): Option[Aggregate] = { + plan match { + case _ @SubqueryAlias(_, aggregate: Aggregate) => Some(aggregate) + case project @ Project(projectList, aggregate: Aggregate) => + if (projectList.forall(_.isInstanceOf[AttributeReference])) { + // Just a column prune projection, ignore it + Some(aggregate) + } else { + None + } + case _ @SubqueryAlias(_, distinct: Distinct) => + Some(transformDistinctToAggregate(distinct)) + case distinct: Distinct => + Some(transformDistinctToAggregate(distinct)) + case aggregate: Aggregate => Some(aggregate) + case _ => None + } + } +} + +private object RuleExpressionHelper { + + def extractLiteral(e: Expression): Option[Literal] = { + e match { + case literal: Literal => Some(literal) + case _ @Alias(literal: Literal, _) => Some(literal) + case _ => None + } + } + + def makeFlagLiteral(): Literal = { + Literal.create(true, BooleanType) + } + + def makeNullLiteral(dataType: DataType): Literal = { + Literal.create(null, dataType) + } + + def makeNamedExpression(e: Expression, name: String): 
NamedExpression = { + e match { + case alias: Alias => + Alias(alias.child, name)() + case _ => Alias(e, name)() + } + } + + def makeFirstAggregateExpression(e: Expression): AggregateExpression = { + val aggregateFunction = First(e, true) + AggregateExpression(aggregateFunction, Complete, false) + } + + def extractJoinKeys( + joinCondition: Option[Expression], + leftAttributes: AttributeSet, + rightAttributes: AttributeSet): Option[JoinKeys] = { + val leftKeys = ArrayBuffer[AttributeReference]() + val rightKeys = ArrayBuffer[AttributeReference]() + def visitJoinExpression(e: Expression): Unit = { + e match { + case and: And => + visitJoinExpression(and.left) + visitJoinExpression(and.right) + case equalTo @ EqualTo(left: AttributeReference, right: AttributeReference) => + if (leftAttributes.contains(left)) { + leftKeys += left + } + if (leftAttributes.contains(right)) { + leftKeys += right + } + if (rightAttributes.contains(left)) { + rightKeys += left + } + if (rightAttributes.contains(right)) { + rightKeys += right + } + case _ => + throw new GlutenException(s"Unsupported join condition $e") + } + } + + joinCondition match { + case Some(condition) => + try { + visitJoinExpression(condition) + if (leftKeys.length != rightKeys.length) { + return None + } + if (!leftKeys.forall(k => k.qualifier.equals(leftKeys.head.qualifier))) { + return None + } + // They must be the same sets or total different sets + if ( + leftKeys.length != rightKeys.length || + !(leftKeys.forall(key => rightKeys.exists(_.equals(key))) || + leftKeys.forall(key => !rightKeys.exists(_.equals(key)))) + ) { + return None + } + Some(JoinKeys(leftKeys.toSeq, rightKeys.toSeq)) + } catch { + case e: GlutenException => + return None + } + case None => + return None + } + } +} + +trait AggregateFunctionAnalyzer { + def doValidate(): Boolean + def getArgumentExpressions(): Option[Seq[Expression]] + def ignoreNulls(): Boolean + def buildUnionAggregateExpression( + arguments: Seq[Expression], + flag: 
Expression): AggregateExpression +} + +case class DefaultAggregateFunctionAnalyzer() extends AggregateFunctionAnalyzer { + override def doValidate(): Boolean = false + override def getArgumentExpressions(): Option[Seq[Expression]] = None + override def ignoreNulls(): Boolean = false + override def buildUnionAggregateExpression( + arguments: Seq[Expression], + flag: Expression): AggregateExpression = { + throw new GlutenException("Unsupported aggregate function") + } +} + +case class SumAnalyzer(aggregateExpression: AggregateExpression) extends AggregateFunctionAnalyzer { + val sum = aggregateExpression.aggregateFunction.asInstanceOf[Sum] + override def doValidate(): Boolean = { + aggregateExpression.filter.isEmpty + } + + override def getArgumentExpressions(): Option[Seq[Expression]] = { + Some(Seq(sum.child)) + } + + override def ignoreNulls(): Boolean = true + + def buildUnionAggregateExpression( + arguments: Seq[Expression], + flag: Expression): AggregateExpression = { + val newSum = sum.copy(child = arguments.head) + aggregateExpression.copy(aggregateFunction = newSum) + } +} + +case class AverageAnalyzer(aggregateExpression: AggregateExpression) + extends AggregateFunctionAnalyzer { + val avg = aggregateExpression.aggregateFunction.asInstanceOf[Average] + override def doValidate(): Boolean = { + aggregateExpression.filter.isEmpty + } + + override def getArgumentExpressions(): Option[Seq[Expression]] = { + Some(Seq(avg.child)) + } + + override def ignoreNulls(): Boolean = true + def buildUnionAggregateExpression( + arguments: Seq[Expression], + flag: Expression): AggregateExpression = { + val newAvg = avg.copy(child = arguments.head) + aggregateExpression.copy(aggregateFunction = newAvg) + } +} + +case class CountAnalyzer(aggregateExpression: AggregateExpression) + extends AggregateFunctionAnalyzer { + val count = aggregateExpression.aggregateFunction.asInstanceOf[Count] + override def doValidate(): Boolean = { + count.children.length == 1 && 
aggregateExpression.filter.isEmpty + } + + override def getArgumentExpressions(): Option[Seq[Expression]] = { + Some(count.children) + } + + override def ignoreNulls(): Boolean = false + def buildUnionAggregateExpression( + arguments: Seq[Expression], + flag: Expression): AggregateExpression = { + val newCount = count.copy(children = arguments) + aggregateExpression.copy(aggregateFunction = newCount) + } +} + +case class MinAnalyzer(aggregateExpression: AggregateExpression) extends AggregateFunctionAnalyzer { + val min = aggregateExpression.aggregateFunction.asInstanceOf[Min] + override def doValidate(): Boolean = { + aggregateExpression.filter.isEmpty + } + + override def getArgumentExpressions(): Option[Seq[Expression]] = { + Some(Seq(min.child)) + } + + override def ignoreNulls(): Boolean = false + def buildUnionAggregateExpression( + arguments: Seq[Expression], + flag: Expression): AggregateExpression = { + val newMin = min.copy(child = arguments.head) + aggregateExpression.copy(aggregateFunction = newMin) + } +} + +case class FirstAnalyzer(aggregateExpression: AggregateExpression) + extends AggregateFunctionAnalyzer { + val first = aggregateExpression.aggregateFunction.asInstanceOf[First] + override def doValidate(): Boolean = { + aggregateExpression.filter.isEmpty && first.ignoreNulls + } + + override def getArgumentExpressions(): Option[Seq[Expression]] = { + Some(Seq(first.child)) + } + + override def ignoreNulls(): Boolean = true + def buildUnionAggregateExpression( + arguments: Seq[Expression], + flag: Expression): AggregateExpression = { + val newFirst = first.copy(child = arguments.head) + aggregateExpression.copy(aggregateFunction = newFirst) + } +} + +case class MaxAnalyzer(aggregateExpression: AggregateExpression) extends AggregateFunctionAnalyzer { + val max = aggregateExpression.aggregateFunction.asInstanceOf[Max] + override def doValidate(): Boolean = { + aggregateExpression.filter.isEmpty + } + + override def getArgumentExpressions(): 
Option[Seq[Expression]] = { + Some(Seq(max.child)) + } + + override def ignoreNulls(): Boolean = false + def buildUnionAggregateExpression( + arguments: Seq[Expression], + flag: Expression): AggregateExpression = { + val newMax = max.copy(child = arguments.head) + aggregateExpression.copy(aggregateFunction = newMax) + } +} + +object AggregateFunctionAnalyzer { + def apply(e: Expression): AggregateFunctionAnalyzer = { + val aggExpr = e match { + case alias @ Alias(aggregate: AggregateExpression, _) => + Some(aggregate) + case aggregate: AggregateExpression => + Some(aggregate) + case _ => None + } + + aggExpr match { + case Some(agg) => + agg.aggregateFunction match { + case sum: Sum => SumAnalyzer(agg) + case avg: Average => AverageAnalyzer(agg) + case count: Count => CountAnalyzer(agg) + case max: Max => MaxAnalyzer(agg) + case min: Min => MinAnalyzer(agg) + case first: First => FirstAnalyzer(agg) + case _ => DefaultAggregateFunctionAnalyzer() + } + case _ => DefaultAggregateFunctionAnalyzer() + } + } +} + +case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Logging { + def analysis(): Boolean = { + if (!extractAggregateQuery(subquery)) { + logDebug(s"xxx Not found aggregate query") + return false + } + + if (!extractGroupingKeys()) { + logDebug(s"xxx Not found grouping keys") + return false + } + + if (!extractJoinKeys()) { + logDebug(s"xxx Not found join keys") + return false + } + + if ( + joinKeys.length != aggregate.groupingExpressions.length || + !joinKeys.forall(k => outputGroupingKeys.exists(_.semanticEquals(k))) + ) { + logError( + s"xxx Join keys and grouping keys are not matched. 
joinKeys: $joinKeys" + + s" outputGroupingKeys: $outputGroupingKeys") + return false + } + + aggregateExpressions = + aggregate.aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isDefined) + aggregateFunctionAnalyzer = aggregateExpressions.map(AggregateFunctionAnalyzer(_)) + + // If there is any const value in the aggregate expressions, return false + if (aggregateExpressions.length + joinKeys.length != aggregate.aggregateExpressions.length) { + logDebug(s"xxx Have const expression in aggregate expressions") + return false + } + if ( + aggregateExpressions.zipWithIndex.exists { + case (e, i) => !aggregateFunctionAnalyzer(i).doValidate + } + ) { + logDebug(s"xxx Have invalid aggregate function in aggregate expressions") + return false + } + + val arguments = + aggregateExpressions.zipWithIndex.map { + case (_, i) => aggregateFunctionAnalyzer(i).getArgumentExpressions + } + if (arguments.exists(_.isEmpty)) { + logDebug(s"xxx Get aggregate function arguments failed") + return false + } + aggregateExpressionArguments = arguments.map(_.get) + + true + } + + def extractAggregateQuery(plan: LogicalPlan): Boolean = { + plan match { + case _ @SubqueryAlias(_, agg: Aggregate) => + aggregate = agg + true + case agg: Aggregate => + aggregate = agg + true + case _ => false + } + } + + def getPrimeJoinKeys(): Seq[AttributeReference] = primeJoinKeys + def getJoinKeys(): Seq[AttributeReference] = joinKeys + def getAggregate(): Aggregate = aggregate + def getGroupingKeys(): Seq[Attribute] = outputGroupingKeys + def getGroupingExpressions(): Seq[NamedExpression] = groupingExpressions + def getAggregateExpressions(): Seq[NamedExpression] = aggregateExpressions + def getAggregateExpressionArguments(): Seq[Seq[Expression]] = aggregateExpressionArguments + def getAggregateFunctionAnalyzers(): Seq[AggregateFunctionAnalyzer] = aggregateFunctionAnalyzer + + private var primeJoinKeys: Seq[AttributeReference] = Seq.empty + private var joinKeys: Seq[AttributeReference] = 
Seq.empty + private var aggregate: Aggregate = null + private var groupingExpressions: Seq[NamedExpression] = null + private var outputGroupingKeys: Seq[Attribute] = null + private var aggregateExpressions: Seq[NamedExpression] = Seq.empty + private var aggregateExpressionArguments: Seq[Seq[Expression]] = Seq.empty + private var aggregateFunctionAnalyzer = Seq.empty[AggregateFunctionAnalyzer] + + private def extractJoinKeys(): Boolean = { + val leftJoinKeys = ArrayBuffer[AttributeReference]() + val subqueryKeys = ArrayBuffer[AttributeReference]() + val leftOutputSet = join.left.outputSet + val subqueryOutputSet = subquery.outputSet + val joinKeysPair = + RuleExpressionHelper.extractJoinKeys(join.condition, leftOutputSet, subqueryOutputSet) + if (joinKeysPair.isEmpty) { + false + } else { + primeJoinKeys = joinKeysPair.get.leftKeys + joinKeys = joinKeysPair.get.rightKeys + true + } + } + + def extractGroupingKeys(): Boolean = { + val outputGroupingKeysBuffer = ArrayBuffer[Attribute]() + val groupingExpressionsBuffer = ArrayBuffer[NamedExpression]() + val indexedAggregateExpressions = aggregate.aggregateExpressions.zipWithIndex + val aggregateOuput = aggregate.output + + def findMatchedAggregateExpression(e: Expression): Option[Tuple2[NamedExpression, Int]] = { + indexedAggregateExpressions.find { + case (aggExpr, i) => + aggExpr match { + case alias: Alias => + alias.child.semanticEquals(e) || alias.semanticEquals(e) + case _ => aggExpr.semanticEquals(e) + } + } + } + aggregate.groupingExpressions.map { + e => + e match { + case nameExpression: NamedExpression => + groupingExpressionsBuffer += nameExpression + findMatchedAggregateExpression(nameExpression) match { + case Some((aggExpr, i)) => + outputGroupingKeysBuffer += aggregateOuput(i) + case None => + return false + } + case other => + findMatchedAggregateExpression(other) match { + case Some((aggExpr, i)) => + outputGroupingKeysBuffer += aggregateOuput(i) + groupingExpressionsBuffer += 
aggregate.aggregateExpressions(i) + case None => + return false + } + } + } + outputGroupingKeys = outputGroupingKeysBuffer.toSeq + groupingExpressions = groupingExpressionsBuffer.toSeq + true + } +} + +object JoinedAggregateAnalyzer extends Logging { + def build(join: Join, subquery: LogicalPlan): Option[JoinedAggregateAnalyzer] = { + val analyzer = JoinedAggregateAnalyzer(join, subquery) + if (analyzer.analysis()) { + Some(analyzer) + } else { + None + } + } + + def haveSamePrimeJoinKeys(analzyers: Seq[JoinedAggregateAnalyzer]): Boolean = { + val primeKeys = analzyers.map(_.getPrimeJoinKeys()).map(AttributeSet(_)) + primeKeys + .slice(1, primeKeys.length) + .forall(keys => keys.equals(primeKeys.head)) + } +} + +/** + * ReorderJoinSubqueries is step before JoinAggregateToAggregateUnion. It will reorder the join + * subqueries to move the queries with the same prime keys, join keys and grouping keys together. + * For example + * ``` + * select t1.k1, t1.k2, s1, s2 from ( + * select k1, k2, count(v1) s1 from t1 group by k1, k2 + * ) t1 left join ( + * select * from t2 + * ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 + * left join ( + * select k1, k2, count(v2) s2 from t3 group by k1, k2 + * ) t3 on t1.k1 = t3.k1 and t1.k2 = t3.k2 + * ``` + * is rewritten into + * ``` + * select t1.k1, t1.k2, s1, s2 from ( + * select k1, k2, count(v1) s1 from t1 group by k1, k2 + * ) t1 left join ( + * select k1, k2, count(v2) s2 from t3 group by k1, k2 + * ) t3 on t1.k1 = t3.k1 and t1.k2 = t3.k2 + * ) left join ( + * select * from t2 + * ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 + * ``` + */ +case class ReorderJoinSubqueries() extends Logging { + case class SameJoinKeysPlans( + primeKeys: AttributeSet, + plans: ListBuffer[Tuple2[LogicalPlan, Join]]) {} + def apply(plan: LogicalPlan): LogicalPlan = { + visitPlan(plan) + } + + def visitPlan(plan: LogicalPlan): LogicalPlan = { + plan match { + case join: Join if join.joinType == LeftOuter && join.condition.isDefined => + val 
sameJoinKeysPlansList = ListBuffer[SameJoinKeysPlans]() + val plan = visitJoin(join, sameJoinKeysPlansList) + finishReorderJoinPlans(plan, sameJoinKeysPlansList) + case _ => + plan.withNewChildren(plan.children.map(visitPlan)) + } + } + + def findSameJoinKeysPlansIndex()( + sameJoinKeysPlans: ListBuffer[SameJoinKeysPlans], + primeKeys: Seq[AttributeReference]): Int = { + val keysSet = AttributeSet(primeKeys) + for (i <- 0 until sameJoinKeysPlans.length) { + if (sameJoinKeysPlans(i).primeKeys.equals(keysSet)) { + return i + } + } + -1 + } + + def finishReorderJoinPlans( + left: LogicalPlan, + sameJoinKeysPlansList: ListBuffer[SameJoinKeysPlans]): LogicalPlan = { + var finalPlan = left + for (j <- sameJoinKeysPlansList.length - 1 to 0 by -1) { + for (i <- sameJoinKeysPlansList(j).plans.length - 1 to 0 by -1) { + val plan = sameJoinKeysPlansList(j).plans(i)._1 + val originalJoin = sameJoinKeysPlansList(j).plans(i)._2 + finalPlan = originalJoin.copy(left = finalPlan, right = plan) + } + } + finalPlan + } + + def visitJoin( + plan: LogicalPlan, + sameJoinKeysPlansList: ListBuffer[SameJoinKeysPlans]): LogicalPlan = { + plan match { + case join: Join if join.joinType == LeftOuter && join.condition.isDefined => + val joinKeys = RuleExpressionHelper.extractJoinKeys( + join.condition, + join.left.outputSet, + join.right.outputSet) + joinKeys match { + case Some(keys) => + val index = findSameJoinKeysPlansIndex()(sameJoinKeysPlansList, keys.leftKeys) + val newRight = visitPlan(join.right) + if (index == -1) { + sameJoinKeysPlansList += SameJoinKeysPlans( + AttributeSet(keys.leftKeys), + ListBuffer(Tuple2(newRight, join))) + } else { + if (index != sameJoinKeysPlansList.length - 1) {} + val sameJoinKeysPlans = sameJoinKeysPlansList.remove(index) + if ( + RulePlanHelper.extractDirectAggregate(newRight).isDefined || + RulePlanHelper.extractDirectAggregate(sameJoinKeysPlans.plans.last._1).isEmpty + ) { + sameJoinKeysPlans.plans += Tuple2(newRight, join) + } else { + 
sameJoinKeysPlans.plans.insert(0, Tuple2(newRight, join)) + } + sameJoinKeysPlansList += sameJoinKeysPlans + } + visitJoin(join.left, sameJoinKeysPlansList) + case None => + val joinLeft = visitPlan(join.left) + val joinRight = visitPlan(join.right) + join.copy(left = joinLeft, right = joinRight) + } + case subquery: SubqueryAlias if RulePlanHelper.extractDirectAggregate(subquery).isDefined => + val newAggregate = visitPlan(subquery.child) + val groupingKeys = RulePlanHelper.extractDirectAggregate(subquery).get.groupingExpressions + if (groupingKeys.forall(_.isInstanceOf[AttributeReference])) { + val keys = groupingKeys.map(_.asInstanceOf[AttributeReference]) + val index = findSameJoinKeysPlansIndex()(sameJoinKeysPlansList, keys) + if (index != -1 && index != sameJoinKeysPlansList.length - 1) { + val sameJoinKeysPlans = sameJoinKeysPlansList.remove(index) + sameJoinKeysPlansList += sameJoinKeysPlans + } + } + subquery.copy(child = newAggregate) + case aggregate: Aggregate => + val newAggregate = aggregate.withNewChildren(aggregate.children.map(visitPlan)) + val groupingKeys = aggregate.groupingExpressions + if (groupingKeys.forall(_.isInstanceOf[AttributeReference])) { + val keys = groupingKeys.map(_.asInstanceOf[AttributeReference]) + val index = findSameJoinKeysPlansIndex()(sameJoinKeysPlansList, keys) + if (index != -1 && index != sameJoinKeysPlansList.length - 1) { + val sameJoinKeysPlans = sameJoinKeysPlansList.remove(index) + sameJoinKeysPlansList += sameJoinKeysPlans + } + } + newAggregate + case _ => + plan.withNewChildren(plan.children.map(visitPlan)) + } + } + +} + +case class JoinAggregateToAggregateUnion(spark: SparkSession) + extends Rule[LogicalPlan] + with Logging { + def isResolvedPlan(plan: LogicalPlan): Boolean = { + plan match { + case insert: InsertIntoStatement => insert.query.resolved + case _ => plan.resolved + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if ( + spark.conf + 
.get(CHBackendSettings.GLUTEN_JOIN_AGGREGATE_TO_AGGREGATE_UNION, "true") + .toBoolean && isResolvedPlan(plan) + ) { + val reorderedPlan = ReorderJoinSubqueries().apply(plan) + val newPlan = visitPlan(reorderedPlan) + logDebug(s"Rewrite plan from \n$plan to \n$newPlan") + newPlan + } else { + plan + } + } + + /** + * Use wide-table aggregation to eliminate multi-table joins. For example + * ``` + * select k1, k2, s1, s2 from (select k1, sum(a) as s1 from t1 group by k1) as t1 + * left join (select k2, sum(b) as s2 from t2 group by k2) as t2 + * on t1.k1 = t2.k2 + * ``` + * It's rewritten into + * ``` + * select k1, k2, s1, s2 from( + * select k1, k2, s1, if(isNull(flag2), null, s2) as s2 from ( + * select k1, first(k2) as k2, sum(a) as s1, sum(b) as s2, first(flag1) as flag1, + * first(flag2) as flag2 + * from ( + * select k1, null as k2, a as a, null as b, true as flag1, null as flag2 from t1 + * union all + * select k2, k2, null as a, b as b, null as flag1, true as flag2 from t2 + * ) group by k1 + * ) where flag1 is not null + * ) + * ``` + * + * The first query is easier to write, but not as efficient as the second one. 
+ */ + def visitPlan(plan: LogicalPlan): LogicalPlan = { + plan match { + case join: Join => + if (join.joinType == LeftOuter && join.condition.isDefined) { + val analyzedAggregates = ArrayBuffer[JoinedAggregateAnalyzer]() + val remainedPlan = collectSameKeysJoinedAggregates(join, analyzedAggregates) + if (analyzedAggregates.length == 0) { + join.copy(left = visitPlan(join.left), right = visitPlan(join.right)) + } else if (analyzedAggregates.length == 1) { + join.copy(left = visitPlan(join.left)) + } else { + val unionedAggregates = unionAllJoinedAggregates(analyzedAggregates.toSeq) + if (remainedPlan.isDefined) { + val lastJoin = analyzedAggregates.head.join + lastJoin.copy(left = visitPlan(lastJoin.left), right = unionedAggregates) + } else { + buildPrimeJoinKeysFilterOnAggregateUnion(unionedAggregates, analyzedAggregates.toSeq) + } + } + } else { + plan.withNewChildren(plan.children.map(visitPlan)) + } + case _ => plan.withNewChildren(plan.children.map(visitPlan)) + } + } + + def buildAggregateExpressionWithNewChildren( + ne: NamedExpression, + inputs: Seq[Attribute]): NamedExpression = { + val aggregateExpression = + ne.asInstanceOf[Alias].child.asInstanceOf[AggregateExpression] + val newAggregateFunction = aggregateExpression.aggregateFunction + .withNewChildren(inputs) + .asInstanceOf[AggregateFunction] + val newAggregateExpression = aggregateExpression.copy(aggregateFunction = newAggregateFunction) + RuleExpressionHelper.makeNamedExpression(newAggregateExpression, ne.name) + } + + def unionAllJoinedAggregates(analyzedAggregates: Seq[JoinedAggregateAnalyzer]): LogicalPlan = { + val extendProjects = buildExtendProjects(analyzedAggregates) + val union = buildUnionOnExtendedProjects(extendProjects) + val aggregateUnion = buildAggregateOnUnion(union, analyzedAggregates) + logDebug(s"xxx aggregateUnion $aggregateUnion") + val setNullsProject = + buildMakeNotMatchedRowsNullProject(aggregateUnion, analyzedAggregates, Set()) + buildRenameProject(setNullsProject, 
analyzedAggregates) + } + + def buildAggregateOnUnion( + union: LogicalPlan, + analyzedAggregates: Seq[JoinedAggregateAnalyzer]): LogicalPlan = { + val unionOutput = union.output + val keysNumber = analyzedAggregates.head.getGroupingKeys().length + val aggregateExpressions = ArrayBuffer[NamedExpression]() + + val groupingKeys = + unionOutput.slice(0, keysNumber).zip(analyzedAggregates.head.getGroupingKeys()).map { + case (e, a) => + RuleExpressionHelper.makeNamedExpression(e, a.name) + } + aggregateExpressions ++= groupingKeys + + for (i <- 1 until analyzedAggregates.length) { + val keys = analyzedAggregates(i).getGroupingKeys() + val fieldIndex = i * keysNumber + for (j <- 0 until keys.length) { + val key = keys(j) + val valueExpr = unionOutput(fieldIndex + j) + val firstValue = RuleExpressionHelper.makeFirstAggregateExpression(valueExpr) + aggregateExpressions += RuleExpressionHelper.makeNamedExpression(firstValue, key.name) + } + } + + var fieldIndex = keysNumber * analyzedAggregates.length + for (i <- 0 until analyzedAggregates.length) { + val partialAggregateExpressions = analyzedAggregates(i).getAggregateExpressions() + val partialAggregateArguments = analyzedAggregates(i).getAggregateExpressionArguments() + val aggregateFunctionAnalyzers = analyzedAggregates(i).getAggregateFunctionAnalyzers() + for (j <- 0 until partialAggregateExpressions.length) { + val aggregateExpression = partialAggregateExpressions(j) + val arguments = partialAggregateArguments(j) + val aggregateFunctionAnalyzer = aggregateFunctionAnalyzers(j) + val newArguments = unionOutput.slice(fieldIndex, fieldIndex + arguments.length) + val flagExpr = unionOutput(unionOutput.length - analyzedAggregates.length + i) + val newAggregateExpression = aggregateFunctionAnalyzer + .buildUnionAggregateExpression(newArguments, flagExpr) + aggregateExpressions += RuleExpressionHelper.makeNamedExpression( + newAggregateExpression, + aggregateExpression.name) + fieldIndex += arguments.length + } + } + + for (i 
<- fieldIndex until unionOutput.length) { + val valueExpr = unionOutput(i) + val firstValue = RuleExpressionHelper.makeFirstAggregateExpression(valueExpr) + aggregateExpressions += RuleExpressionHelper.makeNamedExpression(firstValue, valueExpr.name) + } + + Aggregate(groupingKeys, aggregateExpressions.toSeq, union) + } + + /** + * Some rows may come from the right tables which grouping keys are not in the prime keys set. We + * should remove them. + */ + def buildPrimeJoinKeysFilterOnAggregateUnion( + plan: LogicalPlan, + analyzedAggregates: Seq[JoinedAggregateAnalyzer]): LogicalPlan = { + val flagExpressions = plan.output(plan.output.length - analyzedAggregates.length) + val notNullExpr = IsNotNull(flagExpressions) + Filter(notNullExpr, plan) + } + + /** + * When a row's grouping keys are not present in the related table, the related aggregate values + * are replaced with nulls. + */ + def buildMakeNotMatchedRowsNullProject( + plan: LogicalPlan, + analyzedAggregates: Seq[JoinedAggregateAnalyzer], + ignoreAggregates: Set[Int]): LogicalPlan = { + val input = plan.output + val flagExpressions = + input.slice(plan.output.length - analyzedAggregates.length, plan.output.length) + val aggregateExprsStart = + analyzedAggregates.length * analyzedAggregates.head.getGroupingKeys.length + var fieldIndex = aggregateExprsStart + val aggregatesIfNullExpressions = analyzedAggregates.zipWithIndex.map { + case (analyzedAggregate, i) => + val flagExpr = flagExpressions(i) + val aggregateExpressions = analyzedAggregate.getAggregateExpressions() + val aggregateFunctionAnalyzers = analyzedAggregate.getAggregateFunctionAnalyzers() + aggregateExpressions.zipWithIndex.map { + case (e, i) => + val valueExpr = input(fieldIndex) + fieldIndex += 1 + if (ignoreAggregates(i) || aggregateFunctionAnalyzers(i).ignoreNulls()) { + valueExpr.asInstanceOf[NamedExpression] + } else { + val clearExpr = If( + IsNull(flagExpr), + RuleExpressionHelper.makeNullLiteral(valueExpr.dataType), + valueExpr) + 
RuleExpressionHelper.makeNamedExpression(clearExpr, valueExpr.name) + } + } + } + val ifNullExpressions = aggregatesIfNullExpressions.flatten + val projectList = input.slice(0, aggregateExprsStart) ++ ifNullExpressions ++ flagExpressions + Project(projectList, plan) + } + + /** + * A final step, ensure the output attributes have the same name and exprId as the original join + */ + def buildRenameProject( + plan: LogicalPlan, + analyzedAggregates: Seq[JoinedAggregateAnalyzer]): LogicalPlan = { + val input = plan.output + val projectList = ArrayBuffer[NamedExpression]() + val keysNum = analyzedAggregates.head.getGroupingKeys.length + var fieldIndex = 0 + for (i <- 0 until analyzedAggregates.length) { + val keys = analyzedAggregates(i).getGroupingKeys() + for (j <- 0 until keys.length) { + val key = keys(j) + projectList += Alias(input(fieldIndex), key.name)( + key.exprId, + key.qualifier, + None, + Seq.empty) + fieldIndex += 1 + } + } + + for (i <- 0 until analyzedAggregates.length) { + val aggregateExpressions = analyzedAggregates(i).getAggregateExpressions() + for (j <- 0 until aggregateExpressions.length) { + val e = aggregateExpressions(j) + val valueExpr = input(fieldIndex) + projectList += Alias(valueExpr, e.name)(e.exprId, e.qualifier, None, Seq.empty) + fieldIndex += 1 + } + } + + // Keep the flag columns + projectList ++= input.slice(input.length - analyzedAggregates.length, input.length) + Project(projectList.toSeq, plan) + } + + /** + * Build a extended project list, which contains three parts. + * - The grouping keys of all tables. + * - All required columns as arguments for the aggregate functions in every table. + * - Flags for each table to indicate whether the row is in the table. 
+ */ + def buildExtendProjectList( + analyzedAggregates: Seq[JoinedAggregateAnalyzer], + index: Int): Seq[NamedExpression] = { + val projectList = ArrayBuffer[NamedExpression]() + projectList ++= analyzedAggregates(index).getGroupingExpressions().zipWithIndex.map { + case (e, i) => + RuleExpressionHelper.makeNamedExpression(e, s"key_0_$i") + } + for (i <- 1 until analyzedAggregates.length) { + val groupingKeys = analyzedAggregates(i).getGroupingExpressions() + projectList ++= groupingKeys.zipWithIndex.map { + case (e, j) => + if (i == index) { + RuleExpressionHelper.makeNamedExpression(e, s"key_${i}_$j") + } else { + RuleExpressionHelper.makeNamedExpression( + RuleExpressionHelper.makeNullLiteral(e.dataType), + s"key_${i}_$j") + } + } + } + + for (i <- 0 until analyzedAggregates.length) { + val argsList = analyzedAggregates(i).getAggregateExpressionArguments() + argsList.zipWithIndex.foreach { + case (args, j) => + projectList ++= args.zipWithIndex.map { + case (arg, k) => + if (i == index) { + RuleExpressionHelper.makeNamedExpression(arg, s"arg_${i}_${j}_$k") + } else { + RuleExpressionHelper.makeNamedExpression( + RuleExpressionHelper.makeNullLiteral(arg.dataType), + s"arg_${i}_${j}_$k") + } + } + } + } + + for (i <- 0 until analyzedAggregates.length) { + if (i == index) { + projectList += RuleExpressionHelper.makeNamedExpression( + RuleExpressionHelper.makeFlagLiteral(), + s"flag_$i") + } else { + projectList += RuleExpressionHelper.makeNamedExpression( + RuleExpressionHelper.makeNullLiteral(BooleanType), + s"flag_$i") + } + } + + projectList.toSeq + } + + def buildExtendProjects(analyzedAggregates: Seq[JoinedAggregateAnalyzer]): Seq[LogicalPlan] = { + val projects = ArrayBuffer[LogicalPlan]() + for (i <- 0 until analyzedAggregates.length) { + val projectList = buildExtendProjectList(analyzedAggregates, i) + projects += Project(projectList, analyzedAggregates(i).getAggregate().child) + } + projects.toSeq + } + + def buildUnionOnExtendedProjects(plans: 
Seq[LogicalPlan]): LogicalPlan = { + val union = Union(plans) + logDebug(s"xxx build union: $union") + union + } + + def collectSameKeysJoinedAggregates( + plan: LogicalPlan, + analyzedAggregates: ArrayBuffer[JoinedAggregateAnalyzer]): Option[LogicalPlan] = { + plan match { + case join: Join if join.joinType == LeftOuter && join.condition.isDefined => + val optionAggregate = RulePlanHelper.extractDirectAggregate(join.right) + if (optionAggregate.isEmpty) { + return Some(plan) + } + val rightAggregateAnalyzer = JoinedAggregateAnalyzer.build(join, optionAggregate.get) + if (rightAggregateAnalyzer.isEmpty) { + logDebug(s"xxx Not a valid aggregate query") + return Some(plan) + } + + if ( + analyzedAggregates.isEmpty || + JoinedAggregateAnalyzer.haveSamePrimeJoinKeys( + Seq(analyzedAggregates.head, rightAggregateAnalyzer.get)) + ) { + // left plan is pushed in front + analyzedAggregates.insert(0, rightAggregateAnalyzer.get) + collectSameKeysJoinedAggregates(join.left, analyzedAggregates) + } else { + logError( + s"xxx Not have same keys. join keys:" + + s"${analyzedAggregates.head.getPrimeJoinKeys()} vs. 
" + + s"${rightAggregateAnalyzer.get.getPrimeJoinKeys()}") + Some(plan) + } + case _ if RulePlanHelper.extractDirectAggregate(plan).isDefined => + val aggregate = RulePlanHelper.extractDirectAggregate(plan).get + val lastJoin = analyzedAggregates.head.join + assert(lastJoin.left.equals(plan), "The node should be last join's left child") + val leftAggregateAnalyzer = JoinedAggregateAnalyzer.build(lastJoin, aggregate) + if (leftAggregateAnalyzer.isEmpty) { + return Some(plan) + } + if ( + JoinedAggregateAnalyzer.haveSamePrimeJoinKeys( + Seq(analyzedAggregates.head, leftAggregateAnalyzer.get)) + ) { + analyzedAggregates.insert(0, leftAggregateAnalyzer.get) + None + } else { + Some(plan) + } + case _ => Some(plan) + } + } + + def collectSameKeysJoinSubqueries( + plan: LogicalPlan, + groupedAggregates: ArrayBuffer[ArrayBuffer[JoinedAggregateAnalyzer]], + nonAggregates: ArrayBuffer[LogicalPlan]): Option[LogicalPlan] = { + plan match { + case join: Join if join.joinType == LeftOuter && join.condition.isDefined => + case _ => Some(plan) + } + None + } + +} diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala new file mode 100644 index 000000000000..080fc9b5c9a2 --- /dev/null +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala @@ -0,0 +1,472 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.execution + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, NullPropagation} +import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +import java.nio.file.Files + +class GlutenEliminateJoinSuite extends GlutenClickHouseWholeStageTransformerSuite with Logging { + + protected val tablesPath: String = basePath + "/tpch-data" + protected val tpchQueries: String = + rootPath + "../../../../tools/gluten-it/common/src/main/resources/tpch-queries" + protected val queriesResults: String = rootPath + "queries-output" + + private var parquetPath: String = _ + + override protected def sparkConf: SparkConf = { + super.sparkConf + .set("spark.sql.files.maxPartitionBytes", "1g") + .set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + .set("spark.sql.shuffle.partitions", "5") + .set("spark.sql.adaptive.enabled", "false") + .set("spark.sql.files.minPartitionNum", "1") + .set( + "spark.sql.catalog.spark_catalog", + "org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseSparkCatalog") + .set("spark.databricks.delta.maxSnapshotLineageLength", "20") + .set("spark.databricks.delta.snapshotPartitions", "1") + .set("spark.databricks.delta.properties.defaults.checkpointInterval", "5") + .set("spark.databricks.delta.stalenessLimit", "3600000") + 
.set(ClickHouseConfig.CLICKHOUSE_WORKER_ID, "1") + .set("spark.gluten.sql.columnar.iterator", "true") + .set("spark.gluten.sql.columnar.hashagg.enablefinal", "true") + .set("spark.gluten.sql.enable.native.validation", "false") + .set("spark.sql.warehouse.dir", warehouse) + .set("spark.shuffle.manager", "sort") + .set("spark.io.compression.codec", "snappy") + .set("spark.sql.shuffle.partitions", "5") + .set("spark.sql.autoBroadcastJoinThreshold", "-1") + .set("spark.gluten.supported.scala.udfs", "compare_substrings:compare_substrings") + .set( + SQLConf.OPTIMIZER_EXCLUDED_RULES.key, + ConstantFolding.ruleName + "," + NullPropagation.ruleName) + } + + def createTable(tableName: String, rows: Seq[Row], schema: StructType): Unit = { + val file = Files.createTempFile(tableName, ".parquet").toFile + file.deleteOnExit() + val data = sparkContext.parallelize(rows) + val dataframe = spark.createDataFrame(data, schema) + dataframe.coalesce(1).write.format("parquet").mode("overwrite").parquet(file.getAbsolutePath) + spark.catalog.createTable(tableName, file.getAbsolutePath, "parquet") + } + + override def beforeAll(): Unit = { + super.beforeAll() + + val schema1 = StructType( + Array( + StructField("k1", IntegerType, true), + StructField("k2", IntegerType, true), + StructField("v1", IntegerType, true), + StructField("v2", IntegerType, true), + StructField("v3", IntegerType, true) + ) + ) + + val data1 = Seq( + Row(1, 1, 1, 1, 1), + Row(1, 1, null, 2, 1), + Row(1, 2, 1, 1, null), + Row(2, 1, 1, null, 1), + Row(2, 2, 1, 1, 1), + Row(2, 2, 1, null, 1), + Row(2, 2, 2, null, 3), + Row(2, 3, 0, null, 1), + Row(2, 4, 1, 2, 3), + Row(3, 1, 4, 5, 6), + Row(4, 2, 7, 8, 9), + Row(5, 3, 10, 11, 12) + ) + createTable("t1", data1, schema1) + createTable("t2", data1, schema1) + createTable("t3", data1, schema1) + } + + test("Eliminate two aggregates join") { + val sql = """ + select t1.k1, t1.k2, t2.k1, t2.k2, s1, s2 from ( + select k1, k2, sum(v1) s1 from ( + select * from t1 where k1 != 1 
+ )group by k1, k2 + ) t1 left join ( + select k1, k2, count(v1) s2 from ( + select * from t2 where k1 != 3 + )group by k1, k2 + ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 + order by t1.k1, t1.k2, s1, s2 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 0) + }) + } + + test("Elimiate three aggreages join") { + val sql = """ + select t1.k1, t1.k2, t2.k1, t3.k2, s1, s2, s3 from ( + select k1, k2, sum(v1) s1 from ( + select * from t1 where k1 != 1 + )group by k1, k2 + ) t1 left join ( + select k1, k2, count(v1) s2 from ( + select * from t2 where k1 != 3 + )group by k1, k2 + ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 + left join ( + select k1, k2, count(v2) s3 from ( + select * from t3 where k1 != 3 + )group by k1, k2 + ) t3 on t1.k1 = t3.k1 and t1.k2 = t3.k2 + order by t1.k1, t1.k2, s1, s2, s3 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 0) + }) + } + + test("Left one join uneliminable") { + val sql = """ + select t1.k1, t1.k2, s1, s2 from ( + select * from t1 where k1 != 1 + ) t1 left join ( + select k1, k2, count(v1) s1 from ( + select * from t2 where k1 != 1 + )group by k1, k2 + ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 + left join ( + select k1, k2, count(v2) s2 from ( + select * from t3 where k1 != 3 + )group by k1, k2 + ) t3 on t1.k1 = t3.k1 and t1.k2 = t3.k2 + order by t1.k1, t1.k2, s1, s2 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 1) + }) + } + + test("reorder join orders 1") { + val sql = """ + select t1.k1, t1.k2, t2.k1, s1, s2 from ( + select 
k1, k2, count(v1) s1 from ( + select * from t2 where k1 != 1 + )group by k1, k2 + ) t1 left join ( + select * from t1 where k1 != 1 + ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 + left join ( + select k1, k2, count(v2) s2 from ( + select * from t3 where k1 != 3 + )group by k1, k2 + ) t3 on t1.k1 = t3.k1 and t1.k2 = t3.k2 + order by t1.k1, t1.k2, t2.k1, s1, s2 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 1) + }) + } + + test("reorder join orders 2") { + val sql = """ + select t1.k1, t2.k1, s1, s2 from ( + select k1, k2, count(v1) s1 from ( + select * from t2 where k1 != 1 + )group by k1, k2 + ) t1 left join ( + select k1, count(v2) as s3 from t1 where k1 != 1 + group by k1 + ) t2 on t1.k1 = t2.k1 + left join ( + select k1, k2, count(v2) s2 from ( + select * from t3 where k1 != 3 + )group by k1, k2 + ) t3 on t1.k1 = t3.k1 and t1.k2 = t3.k2 + order by t1.k1, t2.k1, s1, s2 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 1) + } + ) + } + + test("aggregate literal") { + val sql = """ + select t1.k1, t1.k2, t2.k1, t2.k2, s1, s2 from ( + select k1, k2, sum(2) s1 from ( + select * from t1 where k1 != 1 + )group by k1, k2 + ) t1 left join ( + select k1, k2, count(1) s2 from ( + select * from t2 where k1 != 3 + )group by k1, k2 + ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 + order by t1.k1, t1.k2, s1, s2 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 0) + }) + } + + test("aggregate avg") { + val sql = """ + select t1.k1, t1.k2, t2.k1, t2.k2, s1, s2 from ( + select 
k1, k2, avg(2) s1 from ( + select * from t1 where k1 != 1 + )group by k1, k2 + ) t1 left join ( + select k1, k2, avg(1) s2 from ( + select * from t2 where k1 != 3 + )group by k1, k2 + ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 + order by t1.k1, t1.k2, s1, s2 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 0) + }) + } + test("aggregate min/max") { + val sql = """ + select t1.k1, t1.k2, t2.k1, t2.k2, s1, s2 from ( + select k1, k2, min(2) s1 from ( + select * from t1 where k1 != 1 + )group by k1, k2 + ) t1 left join ( + select k1, k2, max(1) s2 from ( + select * from t2 where k1 != 3 + )group by k1, k2 + ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 + order by t1.k1, t1.k2, s1, s2 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 0) + }) + } + + test("aggregate count distinct") { + val sql = """ + select t1.k1, t2.k1, s1, s2, s3, s4 from ( + select k1, count(distinct v1) s1, count(distinct v2) as s2 from ( + select * from t1 where k1 != 1 + )group by k1 + ) t1 left join ( + select k1, count(distinct v1) s3, count(distinct v3) as s4 from ( + select * from t2 where k1 != 3 + )group by k1 + ) t2 on t1.k1 = t2.k1 + order by t1.k1, s1, s2, s3, s4 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 0) + }) + } + + test("distinct") { + val sql = """ + select t1.k1, t2.k1, s1 from ( + select k1, count(v1) as s1 from ( + select * from t1 where k1 != 1 + )group by k1 + ) t1 left join ( + select distinct k1 from ( + select * from t2 where k1 != 3 + ) t2 + ) t2 on t1.k1 = t2.k1 + 
order by t1.k1, s1 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 0) + }) + } + + test("different join keys and grouping keys 1") { + val sql = """ + select t1.k1, t2.k1, s1 from ( + select k1, count(v1) as s1 from ( + select * from t1 where k1 != 1 + )group by k1 + ) t1 left join ( + select distinct k1, k2 from ( + select * from t2 where k1 != 3 + ) t2 + ) t2 on t1.k1 = t2.k1 + order by t1.k1, s1 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 1) + }) + } + + test("not attribute grouping keys 1") { + val sql = """ + select t1.k1, t2.k2, s1 from ( + select k1 + 1 as k1, count(v1) as s1 from ( + select * from t1 where k1 != 1 + )group by k1 + 1 + ) t1 left join ( + select distinct k2 from ( + select * from t2 where k1 != 3 + ) t2 + ) t2 on t1.k1 = t2.k2 + order by t1.k1, t2.k2, s1 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 0) + }) + } + + test("not attribute grouping keys 2") { + val sql = """ + select t1.k1, t2.k2, s1 from ( + select k1 + 1 as k1, count(v1) as s1 from ( + select * from t1 where k1 != 1 + )group by 1 + ) t1 left join ( + select distinct k2 from ( + select * from t2 where k1 != 3 + ) t2 + ) t2 on t1.k1 = t2.k2 + order by t1.k1, t2.k2, s1 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 0) + }) + } + + test("const expression aggregate expression") { + val sql 
= """ + select t1.k1, t2.k1, s1 from ( + select k1 + 1 as k1, 1 as s1 from ( + select * from t1 where k1 != 1 + )group by 1 + ) t1 left join ( + select distinct k1 from ( + select * from t2 where k1 != 3 + ) t2 + ) t2 on t1.k1 = t2.k1 + order by t1.k1, t2.k1, s1 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 1) + }) + } +}