From d02c2555d67e987f3420485ed8e5915b05d32e93 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Wed, 26 Mar 2025 08:52:22 +0800 Subject: [PATCH 01/10] wip --- .../backendsapi/clickhouse/CHBackend.scala | 2 + .../backendsapi/clickhouse/CHRuleApi.scala | 1 + .../JoinAggregateToAggregateUnion.scala | 519 ++++++++++++++++++ 3 files changed, 522 insertions(+) create mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index 06def8902bc0..d71d41bf2d4c 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -160,6 +160,8 @@ object CHBackendSettings extends BackendSettingsApi with Logging { CHConfig.prefixOf("enable.coalesce.aggregation.union") val GLUTEN_ENABLE_COALESCE_PROJECT_UNION: String = CHConfig.prefixOf("enable.coalesce.project.union") + val GLUTEN_JOIN_AGGREGATE_TO_AGGREGATE_UNION: String = + CHConfig.prefixOf("join.aggregate.to.aggregate.union") def affinityMode: String = { SparkEnv.get.conf diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index 695deaddbf82..f78c4f483aff 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -62,6 +62,7 @@ object CHRuleApi { (spark, parserInterface) => new GlutenClickhouseSqlParser(spark, parserInterface)) injector.injectResolutionRule(spark => new CoalesceAggregationUnion(spark)) injector.injectResolutionRule(spark => new 
CoalesceProjectionUnion(spark)) + injector.injectResolutionRule(spark => new JoinAggregateToAggregateUnion(spark)) injector.injectResolutionRule(spark => new RewriteToDateExpresstionRule(spark)) injector.injectResolutionRule(spark => new RewriteDateTimestampComparisonRule(spark)) injector.injectResolutionRule(spark => new CollapseGetJsonObjectExpressionRule(spark)) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala new file mode 100644 index 000000000000..48815bc46645 --- /dev/null +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala @@ -0,0 +1,519 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.extension + +import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings +import org.apache.gluten.exception._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types._ + +import scala.collection.mutable.ArrayBuffer + +trait AggregateExpressionValidator { + def doValidate(): Boolean + def getArgumentExpressions(): Option[Seq[Expression]] +} + +case class DefaultValidator() extends AggregateExpressionValidator { + override def doValidate(): Boolean = false + override def getArgumentExpressions(): Option[Seq[Expression]] = None +} + +case class SumValidator(aggExpr: AggregateExpression) extends AggregateExpressionValidator { + val sum = aggExpr.aggregateFunction.asInstanceOf[Sum] + override def doValidate(): Boolean = { + !sum.child.isInstanceOf[Literal] + } + + override def getArgumentExpressions(): Option[Seq[Expression]] = { + Some(Seq(sum.child)) + } +} + +case class CountValidator(aggExpr: AggregateExpression) extends AggregateExpressionValidator { + val count = aggExpr.aggregateFunction.asInstanceOf[Count] + override def doValidate(): Boolean = { + count.children.length == 1 && !count.children.head.isInstanceOf[Literal] + } + + override def getArgumentExpressions(): Option[Seq[Expression]] = { + Some(count.children) + } +} + +object AggregateExpressionValidator { + def apply(e: Expression): AggregateExpressionValidator = { + val aggExpr = e match { + case alias @ Alias(aggregate: AggregateExpression, _) => + Some(aggregate) + case aggregate: AggregateExpression => + Some(aggregate) + case _ => None + } + + aggExpr match { + case Some(agg) => + agg.aggregateFunction match { + case sum: Sum => 
SumValidator(agg) + case count: Count => CountValidator(agg) + case _ => DefaultValidator() + } + case _ => DefaultValidator() + } + } +} + +case class JoinAggregateToAggregateUnion(spark: SparkSession) + extends Rule[LogicalPlan] + with Logging { + val JOIN_FILTER_FLAG_NAME = "_left_flag_" + def isResolvedPlan(plan: LogicalPlan): Boolean = { + plan match { + case insert: InsertIntoStatement => insert.query.resolved + case _ => plan.resolved + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if ( + spark.conf + .get(CHBackendSettings.GLUTEN_JOIN_AGGREGATE_TO_AGGREGATE_UNION, "true") + .toBoolean && isResolvedPlan(plan) + ) { + val newPlan = visitPlan(plan) + logDebug(s"old plan\n$plan\nnew plan\n$newPlan") + newPlan + } else { + plan + } + } + + def extractJoinKeys( + join: Join): (Option[Seq[AttributeReference]], Option[Seq[AttributeReference]]) = { + val leftKeys = ArrayBuffer[AttributeReference]() + val rightKeys = ArrayBuffer[AttributeReference]() + val leftOutputSet = join.left.outputSet + val rightOutputSet = join.right.outputSet + def visitJoinExpression(e: Expression): Unit = { + e match { + case and: And => + visitJoinExpression(and.left) + visitJoinExpression(and.right) + case equalTo @ EqualTo(left: AttributeReference, right: AttributeReference) => + if (leftOutputSet.contains(left) && rightOutputSet.contains(right)) { + leftKeys += left + rightKeys += right + } else if (leftOutputSet.contains(right) && rightOutputSet.contains(left)) { + leftKeys += right + rightKeys += left + } else { + throw new GlutenException(s"Invalid join condition $equalTo") + } + case _ => + throw new GlutenException(s"Unsupported join condition $e") + } + } + join.condition match { + case Some(condition) => + try { + visitJoinExpression(condition) + } catch { + case e: GlutenException => + logDebug(s"xxx invalid join condition $condition") + return (None, None) + } + (Some(leftKeys), Some(rightKeys)) + case _ => (None, None) + } + } + + def 
extracAggregateExpressions(aggregate: Aggregate): Seq[NamedExpression] = { + aggregate.aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isDefined) + } + + def validateAggregateExpressions(aggregateExpressions: Seq[NamedExpression]): Boolean = { + aggregateExpressions.forall(AggregateExpressionValidator(_).doValidate) + } + + def collectAggregateExpressionsArguments( + expressions: Seq[NamedExpression]): Option[Seq[Expression]] = { + val arguments = expressions.map(AggregateExpressionValidator(_).getArgumentExpressions) + if (arguments.exists(_.isEmpty)) { + None + } else { + Some(arguments.map(_.get).flatten) + } + } + + def areSameKeys(leftKeys: Seq[Expression], rightKeys: Seq[Expression]): Boolean = { + var isValid = leftKeys.length == rightKeys.length + isValid = isValid && leftKeys.forall(_.isInstanceOf[AttributeReference]) && + rightKeys.forall(_.isInstanceOf[AttributeReference]) + isValid && leftKeys.forall(e => rightKeys.exists(_.semanticEquals(e))) + } + + // Only two tables join, both are aggregate + def rewriteOnlyTwoAggregatesJoin( + leftAggregate: Aggregate, + rightAggregate: Aggregate, + join: Join): Option[LogicalPlan] = { + val (leftKeys, rightKeys) = extractJoinKeys(join) + if (!leftKeys.isDefined || !rightKeys.isDefined) { + logDebug(s"xxx not valid join keys") + return None + } + + val leftAggregateExpressions = extracAggregateExpressions(leftAggregate) + val rightAggregateExpressions = extracAggregateExpressions(rightAggregate) + val firstAggExpr = + leftAggregateExpressions.head.asInstanceOf[Alias].child.asInstanceOf[AggregateExpression] + logDebug(s"xxx aggregate mode ${firstAggExpr.mode}") + if ( + !validateAggregateExpressions(leftAggregateExpressions) || + !validateAggregateExpressions(rightAggregateExpressions) + ) { + logDebug(s"xxx not valid aggregate expressions") + return None + } + + if ( + !areSameKeys( + leftKeys.get.map(_.asInstanceOf[Expression]), + 
leftAggregate.groupingExpressions.map(_.asInstanceOf[Expression])) || + !areSameKeys( + rightKeys.get.map(_.asInstanceOf[Expression]), + rightAggregate.groupingExpressions.map(_.asInstanceOf[Expression])) + ) { + logDebug(s"xxx not same keys") + return None + } + + val leftAggregateInputs = collectAggregateExpressionsArguments(leftAggregateExpressions) + val rightAggregateInputs = collectAggregateExpressionsArguments(rightAggregateExpressions) + if (!leftAggregateInputs.isDefined || !rightAggregateInputs.isDefined) { + logDebug(s"xxx not valid aggregate inputs") + return None + } + logDebug(s"xxx left inputs\n$leftAggregateInputs\nright inputs\n$rightAggregateInputs") + logDebug(s"xxx all check passed. $join") + + val aggregates = Seq(leftAggregate, rightAggregate) + + val leftProject = buildExtendedProject( + leftAggregate.groupingExpressions, + leftAggregate.groupingExpressions.map(_.dataType).map(makeNullLiteral), + leftAggregateInputs.get, + rightAggregateInputs.get.map(_.dataType).map(makeNullLiteral), + Seq(makeFlagLiteral(), makeNullLiteral(BooleanType)), + leftAggregate.child + ) + val rightProject = buildExtendedProject( + rightAggregate.groupingExpressions, + rightAggregate.groupingExpressions, + leftAggregateInputs.get.map(_.dataType).map(makeNullLiteral), + rightAggregateInputs.get, + Seq(makeNullLiteral(BooleanType), makeFlagLiteral()), + rightAggregate.child + ) + val union = buildUnionOnExtendedProjects(leftProject, rightProject) + val unionOutput = union.output + logDebug(s"xxx join outputs ${join.output}\nunion outputs ${union.output}") + val aggregateUnion = + buildAggregateOnUnion(union, leftAggregate, leftKeys.get, rightAggregate, rightKeys.get) + logDebug(s"xxx aggregate union\n$aggregateUnion") + logDebug(s"xxx aggregate outputs ${aggregateUnion.output}") + val filtAggregateUnion = buildPrimeKeysFilterOnAggregateUnion(aggregateUnion, aggregates) + logDebug(s"xxx filt aggregate union\n$filtAggregateUnion") + logDebug(s"xxx filt outputs 
${filtAggregateUnion.output}") + val fieldStartOffset = + leftAggregate.groupingExpressions.length * aggregates.length + leftAggregateExpressions.length + val setNullsProject = buildMakeNotMatchedRowsNullProject( + filtAggregateUnion, + fieldStartOffset, + aggregates.slice(1, aggregates.length)) + logDebug(s"xxx set nulls project\n$setNullsProject") + val renameProject = + buildRenameProject(setNullsProject, aggregates, Seq(leftKeys.get, rightKeys.get)) + logDebug(s"xxx rename project\n$renameProject") + logDebug(s"xxx rename outputs ${renameProject.output}") + Some(renameProject) + } + + def rebuildAggregateExpression( + ne: NamedExpression, + input: Seq[Attribute], + inputOffset: Integer): (Integer, NamedExpression) = { + val alias = ne.asInstanceOf[Alias] + val aggregateExpression = alias.child.asInstanceOf[AggregateExpression] + val validator = AggregateExpressionValidator(aggregateExpression) + val arguments = validator.getArgumentExpressions().get + val newArguments = input.slice(inputOffset, inputOffset + arguments.length) + val newAggregateFunction = aggregateExpression.aggregateFunction + .withNewChildren(newArguments) + .asInstanceOf[AggregateFunction] + val newAggregateExpression = aggregateExpression.copy(aggregateFunction = newAggregateFunction) + // val newAlias = + // Alias(newAggregateExpression, alias.name)(alias.exprId, alias.qualifier, None, Seq.empty) + val newAlias = makeNamedExpression(newAggregateExpression, alias.name) + (inputOffset + arguments.length, newAlias) + } + + def buildAggregateOnUnion( + union: LogicalPlan, + leftAggregate: Aggregate, + leftKeys: Seq[AttributeReference], + rightAggregate: Aggregate, + rightKeys: Seq[AttributeReference]): LogicalPlan = { + val unionOutput = union.output + val groupingKeysNumber = leftKeys.length + val aggregateExpressions = ArrayBuffer[NamedExpression]() + val groupingKeys = unionOutput.slice(0, groupingKeysNumber).zip(leftKeys).map { + case (e, a) => + // Alias(e, a.name)(a.exprId, a.qualifier, 
Some(a.metadata), Seq.empty) + makeNamedExpression(e, a.name) + } + aggregateExpressions ++= groupingKeys + + var fieldIndex = groupingKeysNumber + aggregateExpressions ++= unionOutput + .slice(groupingKeysNumber, groupingKeysNumber * 2) + .map(makeFirstAggregateExpression) + .zip(rightKeys) + .map { + case (e, a) => + fieldIndex += 1 + // Alias(e, a.name)(a.exprId, a.qualifier, Some(a.metadata), Seq.empty) + makeNamedExpression(e, a.name) + } + + aggregateExpressions ++= + extracAggregateExpressions(leftAggregate).map { + e => + val (nextFieldIndex, newAggregateExpression) = + rebuildAggregateExpression(e, unionOutput, fieldIndex) + fieldIndex = nextFieldIndex + newAggregateExpression + } + aggregateExpressions ++= + extracAggregateExpressions(rightAggregate).map { + e => + val (nextFieldIndex, newAggregateExpression) = + rebuildAggregateExpression(e, unionOutput, fieldIndex) + fieldIndex = nextFieldIndex + newAggregateExpression + } + + aggregateExpressions ++= unionOutput.slice(fieldIndex, unionOutput.length).map { + attr => + val firstValue = makeFirstAggregateExpression(attr) + makeNamedExpression(firstValue, attr.name) + } + + Aggregate(groupingKeys, aggregateExpressions.toSeq, union) + } + + def buildPrimeKeysFilterOnAggregateUnion( + plan: LogicalPlan, + aggregates: Seq[Aggregate]): LogicalPlan = { + val flagExpressions = plan.output(plan.output.length - aggregates.length) + val notNullExpr = IsNotNull(flagExpressions) + Filter(notNullExpr, plan) + } + + def buildMakeNotMatchedRowsNullProject( + plan: LogicalPlan, + fieldStartOffset: Integer, + aggregates: Seq[Aggregate]): LogicalPlan = { + val input = plan.output + val flagExpressions = input.slice(plan.output.length - aggregates.length, plan.output.length) + var fieldIndex = fieldStartOffset + val aggregatesIfNullExpressions = aggregates.zipWithIndex.map { + case (aggregate, i) => + val flagExpr = flagExpressions(i) + val aggregateExpressions = extracAggregateExpressions(aggregate) + 
aggregateExpressions.map { + e => + val valueExpr = input(fieldIndex) + fieldIndex += 1 + val clearExpr = If(IsNull(flagExpr), makeNullLiteral(valueExpr.dataType), valueExpr) + makeNamedExpression(clearExpr, valueExpr.name) + } + } + val ifNullExpressions = aggregatesIfNullExpressions.flatten + val projectList = input.slice(0, fieldStartOffset) ++ ifNullExpressions + Project(projectList, plan) + } + + def buildRenameProject( + plan: LogicalPlan, + aggregates: Seq[Aggregate], + joinKeys: Seq[Seq[AttributeReference]]): LogicalPlan = { + val input = plan.output + var fieldIndex = 0 + val projectList = ArrayBuffer[NamedExpression]() + joinKeys.foreach { + keys => + keys.foreach { + key => + projectList += Alias(input(fieldIndex), key.name)( + key.exprId, + key.qualifier, + None, + Seq.empty) + fieldIndex += 1 + } + } + aggregates.foreach { + aggregate => + val aggregateExpressions = extracAggregateExpressions(aggregate) + aggregateExpressions.foreach { + e => + val valueExpr = input(fieldIndex) + projectList += Alias(valueExpr, e.name)(e.exprId, e.qualifier, None, Seq.empty) + fieldIndex += 1 + } + } + Project(projectList.toSeq, plan) + } + + def makeNamedExpression(e: Expression, name: String): NamedExpression = { + Alias(e, name)() + } + + def makeNullLiteral(dataType: DataType): Literal = { + Literal.create(null, dataType) + } + + def makeFlagLiteral(): Literal = { + Literal.create(true, BooleanType) + } + + def makeFirstAggregateExpression(e: Expression): AggregateExpression = { + val aggregateFunction = First(e, true) + AggregateExpression(aggregateFunction, Complete, false) + } + + def buildExtendProjectList( + leftGroupingkeys: Seq[Expression], + rightGroupingkeys: Seq[Expression], + leftInputs: Seq[Expression], + rightInputs: Seq[Expression], + flags: Seq[Expression]): Seq[NamedExpression] = { + var fieldIndex = 0 + val part1 = leftGroupingkeys.map { + e => + fieldIndex += 1 + makeNamedExpression(e, s"field$fieldIndex") + } + val part2 = rightGroupingkeys.map { 
+ e => + fieldIndex += 1 + makeNamedExpression(e, s"field$fieldIndex") + } + val part3 = leftInputs.map { + e => + fieldIndex += 1 + makeNamedExpression(e, s"field$fieldIndex") + } + val part4 = rightInputs.map { + e => + fieldIndex += 1 + makeNamedExpression(e, s"field$fieldIndex") + } + val part5 = flags.map { + e => + fieldIndex += 1 + makeNamedExpression(e, s"field$fieldIndex") + } + part1 ++ part2 ++ part3 ++ part4 ++ part5 + } + + def buildExtendedProject( + leftGroupingkeys: Seq[Expression], + rightGroupingkeys: Seq[Expression], + leftInputs: Seq[Expression], + rightInputs: Seq[Expression], + flags: Seq[Expression], + child: LogicalPlan): LogicalPlan = { + val projectList = + buildExtendProjectList(leftGroupingkeys, rightGroupingkeys, leftInputs, rightInputs, flags) + Project(projectList, child) + } + + + def buildUnionOnExtendedProjects1(plans: Seq[LogicalPlan]): LogicalPlan = { + val union = Union(plans) + logDebug(s"xxx union\n$union") + union + } + def buildUnionOnExtendedProjects( + leftProject: LogicalPlan, + rightProject: LogicalPlan): LogicalPlan = { + val union = Union(Seq(leftProject, rightProject)) + logDebug(s"xxx union\n$union") + union + } + + def isDirectAggregate(plan: LogicalPlan): Boolean = { + plan match { + case _ @SubqueryAlias(_, aggregate: Aggregate) => true + case _: Aggregate => true + case _ => false + } + } + + def extractDirectAggregate(plan: LogicalPlan): Option[Aggregate] = { + plan match { + case _ @SubqueryAlias(_, aggregate: Aggregate) => Some(aggregate) + case aggregate: Aggregate => Some(aggregate) + case _ => None + } + } + + def visitPlan(plan: LogicalPlan): LogicalPlan = { + plan match { + case join: Join => + if (join.joinType == LeftOuter && join.condition.isDefined) { + (join.left, join.right) match { + case (left, right) if isDirectAggregate(left) && isDirectAggregate(right) => + val leftAggregate = extractDirectAggregate(left).get + val rightAggregate = extractDirectAggregate(right).get + logDebug(s"xxx case 1. 
left agg:\n$leftAggregate,\nright agg:\n$rightAggregate") + rewriteOnlyTwoAggregatesJoin(leftAggregate, rightAggregate, join) match { + case Some(newPlan) => + newPlan.withNewChildren(newPlan.children.map(visitPlan)) + case _ => + plan.withNewChildren(plan.children.map(visitPlan)) + } + case _ => + logDebug(s"xxx case 2.left\n{join.left}\nright\n{join.right}") + plan.withNewChildren(plan.children.map(visitPlan)) + } + } else { + plan.withNewChildren(plan.children.map(visitPlan)) + } + case _ => plan.withNewChildren(plan.children.map(visitPlan)) + } + } +} From aecad64e0a08cac758316d30ab1eebdfe1f0b2fb Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Wed, 26 Mar 2025 18:33:09 +0800 Subject: [PATCH 02/10] wip --- .../JoinAggregateToAggregateUnion.scala | 249 ++++++++++-------- 1 file changed, 144 insertions(+), 105 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala index 48815bc46645..6cf7c7453f71 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala @@ -173,86 +173,77 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) isValid && leftKeys.forall(e => rightKeys.exists(_.semanticEquals(e))) } - // Only two tables join, both are aggregate + /** + * Rewrite the plan if it is a join between two aggregate tables. 
For example + * ``` + * select k1, k2, s1, s2 from (select k1, sum(a) as s1 from t1 group by k1) as t1 + * left join (select k2, sum(b) as s2 from t2 group by k2) as t2 + * on t1.k1 = t2.k2 + * ``` + * It's rewritten into + * ``` + * select k1, k2, s1, s2 from( + * select k1, k2, s1, if(isNull(flag2), null, s2) as s2 from ( + * select k1, first(k2) as k2, sum(a) as s1, sum(b) as s2, first(flag1) as flag1, + * first(flag2) as flag2 + * from ( + * select k1, null as k2, a as a, null as b, true as flag1, null as flag2 from t1 + * union all + * select k2, k2, null as a, b as b, null as flag1, true as flag2 from t2 + * ) group by k1 + * ) where flag1 is not null + * ) + * ``` + * + * The first query is easier to write, but not as efficient as the second one. + */ def rewriteOnlyTwoAggregatesJoin( leftAggregate: Aggregate, rightAggregate: Aggregate, join: Join): Option[LogicalPlan] = { val (leftKeys, rightKeys) = extractJoinKeys(join) if (!leftKeys.isDefined || !rightKeys.isDefined) { - logDebug(s"xxx not valid join keys") return None } - - val leftAggregateExpressions = extracAggregateExpressions(leftAggregate) - val rightAggregateExpressions = extracAggregateExpressions(rightAggregate) - val firstAggExpr = - leftAggregateExpressions.head.asInstanceOf[Alias].child.asInstanceOf[AggregateExpression] - logDebug(s"xxx aggregate mode ${firstAggExpr.mode}") - if ( - !validateAggregateExpressions(leftAggregateExpressions) || - !validateAggregateExpressions(rightAggregateExpressions) - ) { - logDebug(s"xxx not valid aggregate expressions") + val aggregates = Seq(leftAggregate, rightAggregate) + val joinKeys = Seq(leftKeys.get, rightKeys.get) + val aggregateExpressions = aggregates.map(extracAggregateExpressions) + if (aggregateExpressions.exists(!validateAggregateExpressions(_))) { return None } - if ( - !areSameKeys( - leftKeys.get.map(_.asInstanceOf[Expression]), - leftAggregate.groupingExpressions.map(_.asInstanceOf[Expression])) || - !areSameKeys( - 
rightKeys.get.map(_.asInstanceOf[Expression]), - rightAggregate.groupingExpressions.map(_.asInstanceOf[Expression])) - ) { - logDebug(s"xxx not same keys") + val keyPairs = joinKeys.zip(aggregates).map { + case (keys, aggregate) => + ( + keys.map(_.asInstanceOf[Expression]), + aggregate.groupingExpressions.map(_.asInstanceOf[Expression])) + } + if (keyPairs.exists(pair => !areSameKeys(pair._1, pair._2))) { return None } - val leftAggregateInputs = collectAggregateExpressionsArguments(leftAggregateExpressions) - val rightAggregateInputs = collectAggregateExpressionsArguments(rightAggregateExpressions) - if (!leftAggregateInputs.isDefined || !rightAggregateInputs.isDefined) { - logDebug(s"xxx not valid aggregate inputs") + val optionAggregateInputs = aggregateExpressions.map(collectAggregateExpressionsArguments) + if (optionAggregateInputs.exists(_.isEmpty)) { return None } - logDebug(s"xxx left inputs\n$leftAggregateInputs\nright inputs\n$rightAggregateInputs") - logDebug(s"xxx all check passed. 
$join") + val aggregateInputs = optionAggregateInputs.map(_.get) - val aggregates = Seq(leftAggregate, rightAggregate) - - val leftProject = buildExtendedProject( - leftAggregate.groupingExpressions, - leftAggregate.groupingExpressions.map(_.dataType).map(makeNullLiteral), - leftAggregateInputs.get, - rightAggregateInputs.get.map(_.dataType).map(makeNullLiteral), - Seq(makeFlagLiteral(), makeNullLiteral(BooleanType)), - leftAggregate.child - ) - val rightProject = buildExtendedProject( - rightAggregate.groupingExpressions, - rightAggregate.groupingExpressions, - leftAggregateInputs.get.map(_.dataType).map(makeNullLiteral), - rightAggregateInputs.get, - Seq(makeNullLiteral(BooleanType), makeFlagLiteral()), - rightAggregate.child + val extendProjects = buildExtendProjects( + aggregates.map(_.groupingExpressions), + aggregateInputs, + aggregates.map(_.child) ) - val union = buildUnionOnExtendedProjects(leftProject, rightProject) + + val union = buildUnionOnExtendedProjects(extendProjects) val unionOutput = union.output - logDebug(s"xxx join outputs ${join.output}\nunion outputs ${union.output}") - val aggregateUnion = - buildAggregateOnUnion(union, leftAggregate, leftKeys.get, rightAggregate, rightKeys.get) - logDebug(s"xxx aggregate union\n$aggregateUnion") - logDebug(s"xxx aggregate outputs ${aggregateUnion.output}") + val aggregateUnion = buildAggregateOnUnion(union, aggregates, joinKeys) val filtAggregateUnion = buildPrimeKeysFilterOnAggregateUnion(aggregateUnion, aggregates) - logDebug(s"xxx filt aggregate union\n$filtAggregateUnion") - logDebug(s"xxx filt outputs ${filtAggregateUnion.output}") val fieldStartOffset = - leftAggregate.groupingExpressions.length * aggregates.length + leftAggregateExpressions.length + aggregates.map(_.groupingExpressions.length).sum + aggregateExpressions.head.length val setNullsProject = buildMakeNotMatchedRowsNullProject( filtAggregateUnion, fieldStartOffset, aggregates.slice(1, aggregates.length)) - logDebug(s"xxx set nulls 
project\n$setNullsProject") val renameProject = buildRenameProject(setNullsProject, aggregates, Seq(leftKeys.get, rightKeys.get)) logDebug(s"xxx rename project\n$renameProject") @@ -273,12 +264,56 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) .withNewChildren(newArguments) .asInstanceOf[AggregateFunction] val newAggregateExpression = aggregateExpression.copy(aggregateFunction = newAggregateFunction) - // val newAlias = - // Alias(newAggregateExpression, alias.name)(alias.exprId, alias.qualifier, None, Seq.empty) val newAlias = makeNamedExpression(newAggregateExpression, alias.name) (inputOffset + arguments.length, newAlias) } + def buildAggregateOnUnion( + union: LogicalPlan, + aggregates: Seq[Aggregate], + joinKeys: Seq[Seq[AttributeReference]]): LogicalPlan = { + val unionOutput = union.output + val keysNumber = joinKeys.head.length + val aggregateExpressions = ArrayBuffer[NamedExpression]() + val groupingKeys = unionOutput.slice(0, keysNumber).zip(joinKeys.head).map { + case (e, a) => + makeNamedExpression(e, a.name) + } + aggregateExpressions ++= groupingKeys + + var fieldIndex = keysNumber + + for (i <- 1 until joinKeys.length) { + val keys = joinKeys(i) + keys.foreach { + key => + val valueExpr = unionOutput(fieldIndex) + val firstValue = makeFirstAggregateExpression(valueExpr) + aggregateExpressions += makeNamedExpression(firstValue, key.name) + fieldIndex += 1 + } + } + + aggregates.foreach { + aggregate => + val partialAggregateExpressions = extracAggregateExpressions(aggregate) + partialAggregateExpressions.foreach { + e => + val (nextFieldIndex, newAggregateExpression) = + rebuildAggregateExpression(e, unionOutput, fieldIndex) + fieldIndex = nextFieldIndex + aggregateExpressions += newAggregateExpression + } + } + for (i <- fieldIndex until unionOutput.length) { + val valueExpr = unionOutput(i) + val firstValue = makeFirstAggregateExpression(valueExpr) + aggregateExpressions += makeNamedExpression(firstValue, valueExpr.name) + } + + 
Aggregate(groupingKeys, aggregateExpressions.toSeq, union) + } + def buildAggregateOnUnion( union: LogicalPlan, leftAggregate: Aggregate, @@ -290,7 +325,6 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) val aggregateExpressions = ArrayBuffer[NamedExpression]() val groupingKeys = unionOutput.slice(0, groupingKeysNumber).zip(leftKeys).map { case (e, a) => - // Alias(e, a.name)(a.exprId, a.qualifier, Some(a.metadata), Seq.empty) makeNamedExpression(e, a.name) } aggregateExpressions ++= groupingKeys @@ -303,7 +337,6 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) .map { case (e, a) => fieldIndex += 1 - // Alias(e, a.name)(a.exprId, a.qualifier, Some(a.metadata), Seq.empty) makeNamedExpression(e, a.name) } @@ -415,65 +448,71 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) } def buildExtendProjectList( - leftGroupingkeys: Seq[Expression], - rightGroupingkeys: Seq[Expression], - leftInputs: Seq[Expression], - rightInputs: Seq[Expression], - flags: Seq[Expression]): Seq[NamedExpression] = { + groupingKeys: Seq[Seq[Expression]], + aggregateInputs: Seq[Seq[Expression]], + index: Int): Seq[NamedExpression] = { var fieldIndex = 0 - val part1 = leftGroupingkeys.map { - e => - fieldIndex += 1 - makeNamedExpression(e, s"field$fieldIndex") - } - val part2 = rightGroupingkeys.map { - e => - fieldIndex += 1 - makeNamedExpression(e, s"field$fieldIndex") + def nextFieldIndex(): Int = { + val res = fieldIndex + fieldIndex += 1 + res } - val part3 = leftInputs.map { - e => - fieldIndex += 1 - makeNamedExpression(e, s"field$fieldIndex") + + val projectList = ArrayBuffer[NamedExpression]() + val keys = groupingKeys(index) + projectList ++= keys.map(makeNamedExpression(_, s"key$nextFieldIndex")) + for (i <- 1 until groupingKeys.length) { + if (i == index) { + projectList ++= keys.map(makeNamedExpression(_, s"key$nextFieldIndex")) + } else { + projectList ++= keys + .map(_.dataType) + .map(makeNullLiteral) + .map(makeNamedExpression(_, 
s"key$nextFieldIndex")) + } } - val part4 = rightInputs.map { - e => - fieldIndex += 1 - makeNamedExpression(e, s"field$fieldIndex") + + aggregateInputs.zipWithIndex.foreach { + case (inputs, i) => + if (i == index) { + projectList ++= inputs.map(makeNamedExpression(_, s"agg$nextFieldIndex")) + } else { + projectList ++= inputs + .map(_.dataType) + .map(makeNullLiteral) + .map(makeNamedExpression(_, s"agg$nextFieldIndex")) + } } - val part5 = flags.map { - e => - fieldIndex += 1 - makeNamedExpression(e, s"field$fieldIndex") + + for (i <- 0 until groupingKeys.length) { + if (i == index) { + projectList += makeNamedExpression(makeFlagLiteral(), s"flag$nextFieldIndex") + } else { + projectList += makeNamedExpression(makeNullLiteral(BooleanType), s"flag$nextFieldIndex") + } } - part1 ++ part2 ++ part3 ++ part4 ++ part5 - } - def buildExtendedProject( - leftGroupingkeys: Seq[Expression], - rightGroupingkeys: Seq[Expression], - leftInputs: Seq[Expression], - rightInputs: Seq[Expression], - flags: Seq[Expression], - child: LogicalPlan): LogicalPlan = { - val projectList = - buildExtendProjectList(leftGroupingkeys, rightGroupingkeys, leftInputs, rightInputs, flags) - Project(projectList, child) + projectList.toSeq } + def buildExtendProjects( + groupingKeys: Seq[Seq[Expression]], + aggregateInputs: Seq[Seq[Expression]], + children: Seq[LogicalPlan] + ): Seq[LogicalPlan] = { + val projects = ArrayBuffer[LogicalPlan]() + for (i <- 0 until children.length) { + val projectList = buildExtendProjectList(groupingKeys, aggregateInputs, i) + projects += Project(projectList, children(i)) + } + projects.toSeq + } - def buildUnionOnExtendedProjects1(plans: Seq[LogicalPlan]): LogicalPlan = { + def buildUnionOnExtendedProjects(plans: Seq[LogicalPlan]): LogicalPlan = { val union = Union(plans) logDebug(s"xxx union\n$union") union } - def buildUnionOnExtendedProjects( - leftProject: LogicalPlan, - rightProject: LogicalPlan): LogicalPlan = { - val union = Union(Seq(leftProject, 
rightProject)) - logDebug(s"xxx union\n$union") - union - } def isDirectAggregate(plan: LogicalPlan): Boolean = { plan match { From fc8c4071fe4799958a2e69cfb7ff7eef0a970d5a Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Thu, 27 Mar 2025 10:28:11 +0800 Subject: [PATCH 03/10] wip --- .../JoinAggregateToAggregateUnion.scala | 518 +++++++++--------- 1 file changed, 256 insertions(+), 262 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala index 6cf7c7453f71..d83b6ee8c166 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala @@ -30,17 +30,17 @@ import org.apache.spark.sql.types._ import scala.collection.mutable.ArrayBuffer -trait AggregateExpressionValidator { +trait AggregateFunctionAnalyzer { def doValidate(): Boolean def getArgumentExpressions(): Option[Seq[Expression]] } -case class DefaultValidator() extends AggregateExpressionValidator { +case class DefaultAggregateFunctionAnalyzer() extends AggregateFunctionAnalyzer { override def doValidate(): Boolean = false override def getArgumentExpressions(): Option[Seq[Expression]] = None } -case class SumValidator(aggExpr: AggregateExpression) extends AggregateExpressionValidator { +case class SumAnalyzer(aggExpr: AggregateExpression) extends AggregateFunctionAnalyzer { val sum = aggExpr.aggregateFunction.asInstanceOf[Sum] override def doValidate(): Boolean = { !sum.child.isInstanceOf[Literal] @@ -51,7 +51,7 @@ case class SumValidator(aggExpr: AggregateExpression) extends AggregateExpressio } } -case class CountValidator(aggExpr: AggregateExpression) extends AggregateExpressionValidator { +case class CountAnalyzer(aggExpr: AggregateExpression) extends AggregateFunctionAnalyzer { val count = 
aggExpr.aggregateFunction.asInstanceOf[Count] override def doValidate(): Boolean = { count.children.length == 1 && !count.children.head.isInstanceOf[Literal] @@ -62,8 +62,8 @@ case class CountValidator(aggExpr: AggregateExpression) extends AggregateExpress } } -object AggregateExpressionValidator { - def apply(e: Expression): AggregateExpressionValidator = { +object AggregateFunctionAnalyzer { + def apply(e: Expression): AggregateFunctionAnalyzer = { val aggExpr = e match { case alias @ Alias(aggregate: AggregateExpression, _) => Some(aggregate) @@ -75,102 +75,158 @@ object AggregateExpressionValidator { aggExpr match { case Some(agg) => agg.aggregateFunction match { - case sum: Sum => SumValidator(agg) - case count: Count => CountValidator(agg) - case _ => DefaultValidator() + case sum: Sum => SumAnalyzer(agg) + case count: Count => CountAnalyzer(agg) + case _ => DefaultAggregateFunctionAnalyzer() } - case _ => DefaultValidator() + case _ => DefaultAggregateFunctionAnalyzer() } } } -case class JoinAggregateToAggregateUnion(spark: SparkSession) - extends Rule[LogicalPlan] - with Logging { - val JOIN_FILTER_FLAG_NAME = "_left_flag_" - def isResolvedPlan(plan: LogicalPlan): Boolean = { - plan match { - case insert: InsertIntoStatement => insert.query.resolved - case _ => plan.resolved +case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Logging { + def analysis(): Boolean = { + if (!extractAggregateQuery(subquery)) { + return false } + + if (!extractJoinKeys()) { + return false + } + + aggregateExpressions = + aggregate.aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isDefined) + if (aggregateExpressions.exists(!AggregateFunctionAnalyzer(_).doValidate)) { + return false + } + + val arguments = + aggregateExpressions.map(AggregateFunctionAnalyzer(_).getArgumentExpressions) + if (arguments.exists(_.isEmpty)) { + return false + } + aggregateExpressionArguments = arguments.map(_.get) + + true } - override def apply(plan: 
LogicalPlan): LogicalPlan = { - if ( - spark.conf - .get(CHBackendSettings.GLUTEN_JOIN_AGGREGATE_TO_AGGREGATE_UNION, "true") - .toBoolean && isResolvedPlan(plan) - ) { - val newPlan = visitPlan(plan) - logDebug(s"old plan\n$plan\nnew plan\n$newPlan") - newPlan - } else { - plan + def extractAggregateQuery(plan: LogicalPlan): Boolean = { + plan match { + case _ @SubqueryAlias(_, agg: Aggregate) => + aggregate = agg + true + case agg: Aggregate => + aggregate = agg + true + case _ => false } } - def extractJoinKeys( - join: Join): (Option[Seq[AttributeReference]], Option[Seq[AttributeReference]]) = { - val leftKeys = ArrayBuffer[AttributeReference]() - val rightKeys = ArrayBuffer[AttributeReference]() + def getPrimeKeys(): Seq[AttributeReference] = primeKeys + def getKeys(): Seq[AttributeReference] = keys + def getAggregate(): Aggregate = aggregate + def getAggregateExpressions(): Seq[NamedExpression] = aggregateExpressions + def getAggregateExpressionArguments(): Seq[Seq[Expression]] = aggregateExpressionArguments + + private var primeKeys: Seq[AttributeReference] = Seq.empty + private var keys: Seq[AttributeReference] = Seq.empty + private var aggregate: Aggregate = null + private var aggregateExpressions: Seq[NamedExpression] = Seq.empty + private var aggregateExpressionArguments: Seq[Seq[Expression]] = Seq.empty + + private def extractJoinKeys(): Boolean = { + val leftJoinKeys = ArrayBuffer[AttributeReference]() + val subqueryKeys = ArrayBuffer[AttributeReference]() val leftOutputSet = join.left.outputSet - val rightOutputSet = join.right.outputSet + val subqueryOutputSet = subquery.outputSet def visitJoinExpression(e: Expression): Unit = { e match { case and: And => visitJoinExpression(and.left) visitJoinExpression(and.right) case equalTo @ EqualTo(left: AttributeReference, right: AttributeReference) => - if (leftOutputSet.contains(left) && rightOutputSet.contains(right)) { - leftKeys += left - rightKeys += right - } else if (leftOutputSet.contains(right) && 
rightOutputSet.contains(left)) { - leftKeys += right - rightKeys += left - } else { - throw new GlutenException(s"Invalid join condition $equalTo") + if (leftOutputSet.contains(left)) { + leftJoinKeys += left + } + if (subqueryOutputSet.contains(left)) { + subqueryKeys += left + } + if (leftOutputSet.contains(right)) { + leftJoinKeys += right + } + if (subqueryOutputSet.contains(right)) { + subqueryKeys += right } case _ => throw new GlutenException(s"Unsupported join condition $e") } } + + // They must be the same sets or total different sets + if ( + leftJoinKeys.length != subqueryKeys.length || + !(leftJoinKeys.forall(key => subqueryKeys.exists(_.equals(key))) || + leftJoinKeys.forall(key => !subqueryKeys.exists(_.equals(key)))) + ) { + return false + } join.condition match { case Some(condition) => try { visitJoinExpression(condition) } catch { case e: GlutenException => - logDebug(s"xxx invalid join condition $condition") - return (None, None) + return false } - (Some(leftKeys), Some(rightKeys)) - case _ => (None, None) + keys = subqueryKeys.toSeq + primeKeys = leftJoinKeys.toSeq + true + case _ => false } } +} - def extracAggregateExpressions(aggregate: Aggregate): Seq[NamedExpression] = { - aggregate.aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isDefined) +object JoinedAggregateAnalyzer extends Logging { + def build(join: Join, subquery: LogicalPlan): Option[JoinedAggregateAnalyzer] = { + val analyzer = JoinedAggregateAnalyzer(join, subquery) + if (analyzer.analysis()) { + Some(analyzer) + } else { + None + } } - def validateAggregateExpressions(aggregateExpressions: Seq[NamedExpression]): Boolean = { - aggregateExpressions.forall(AggregateExpressionValidator(_).doValidate) + def haveSamePrimeKeys(analzyers: Seq[JoinedAggregateAnalyzer]): Boolean = { + val primeKeys = analzyers.map(_.getPrimeKeys()).map(AttributeSet(_)) + primeKeys + .slice(1, primeKeys.length) + .forall(keys => keys.equals(primeKeys.head)) } +} - def 
collectAggregateExpressionsArguments( - expressions: Seq[NamedExpression]): Option[Seq[Expression]] = { - val arguments = expressions.map(AggregateExpressionValidator(_).getArgumentExpressions) - if (arguments.exists(_.isEmpty)) { - None - } else { - Some(arguments.map(_.get).flatten) +case class JoinAggregateToAggregateUnion(spark: SparkSession) + extends Rule[LogicalPlan] + with Logging { + val JOIN_FILTER_FLAG_NAME = "_left_flag_" + def isResolvedPlan(plan: LogicalPlan): Boolean = { + plan match { + case insert: InsertIntoStatement => insert.query.resolved + case _ => plan.resolved } } - def areSameKeys(leftKeys: Seq[Expression], rightKeys: Seq[Expression]): Boolean = { - var isValid = leftKeys.length == rightKeys.length - isValid = isValid && leftKeys.forall(_.isInstanceOf[AttributeReference]) && - rightKeys.forall(_.isInstanceOf[AttributeReference]) - isValid && leftKeys.forall(e => rightKeys.exists(_.semanticEquals(e))) + override def apply(plan: LogicalPlan): LogicalPlan = { + if ( + spark.conf + .get(CHBackendSettings.GLUTEN_JOIN_AGGREGATE_TO_AGGREGATE_UNION, "true") + .toBoolean && isResolvedPlan(plan) + ) { + val newPlan = visitPlan(plan) + logDebug(s"old plan\n$plan\nnew plan\n$newPlan") + newPlan + } else { + plan + } } /** @@ -201,110 +257,84 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) leftAggregate: Aggregate, rightAggregate: Aggregate, join: Join): Option[LogicalPlan] = { - val (leftKeys, rightKeys) = extractJoinKeys(join) - if (!leftKeys.isDefined || !rightKeys.isDefined) { + val optionAnalyzers = Seq(leftAggregate, rightAggregate).map( + agg => JoinedAggregateAnalyzer.build(join, agg.asInstanceOf[LogicalPlan])) + if (optionAnalyzers.exists(_.isEmpty)) { return None } - val aggregates = Seq(leftAggregate, rightAggregate) - val joinKeys = Seq(leftKeys.get, rightKeys.get) - val aggregateExpressions = aggregates.map(extracAggregateExpressions) - if (aggregateExpressions.exists(!validateAggregateExpressions(_))) { + val analyzers = 
optionAnalyzers.map(_.get) + if (!JoinedAggregateAnalyzer.haveSamePrimeKeys(analyzers)) { return None } - val keyPairs = joinKeys.zip(aggregates).map { - case (keys, aggregate) => - ( - keys.map(_.asInstanceOf[Expression]), - aggregate.groupingExpressions.map(_.asInstanceOf[Expression])) - } - if (keyPairs.exists(pair => !areSameKeys(pair._1, pair._2))) { - return None - } + val aggregates = analyzers.map(_.getAggregate) + val joinKeys = analyzers.map(_.getKeys()) + val aggregateExpressions = analyzers.map(_.getAggregateExpressions()) + val aggregateExpressionArguments = analyzers.map(_.getAggregateExpressionArguments()).flatten - val optionAggregateInputs = aggregateExpressions.map(collectAggregateExpressionsArguments) - if (optionAggregateInputs.exists(_.isEmpty)) { - return None - } - val aggregateInputs = optionAggregateInputs.map(_.get) - - val extendProjects = buildExtendProjects( - aggregates.map(_.groupingExpressions), - aggregateInputs, - aggregates.map(_.child) - ) + val extendProjects = buildExtendProjects(analyzers) val union = buildUnionOnExtendedProjects(extendProjects) - val unionOutput = union.output - val aggregateUnion = buildAggregateOnUnion(union, aggregates, joinKeys) - val filtAggregateUnion = buildPrimeKeysFilterOnAggregateUnion(aggregateUnion, aggregates) - val fieldStartOffset = - aggregates.map(_.groupingExpressions.length).sum + aggregateExpressions.head.length - val setNullsProject = buildMakeNotMatchedRowsNullProject( - filtAggregateUnion, - fieldStartOffset, - aggregates.slice(1, aggregates.length)) - val renameProject = - buildRenameProject(setNullsProject, aggregates, Seq(leftKeys.get, rightKeys.get)) + val aggregateUnion = buildAggregateOnUnion(union, analyzers) + val filtAggregateUnion = buildPrimeKeysFilterOnAggregateUnion(aggregateUnion, analyzers) + val setNullsProject = buildMakeNotMatchedRowsNullProject(filtAggregateUnion, analyzers, Set(0)) + val renameProject = buildRenameProject(setNullsProject, analyzers) logDebug(s"xxx 
rename project\n$renameProject") logDebug(s"xxx rename outputs ${renameProject.output}") Some(renameProject) } - def rebuildAggregateExpression( + def buildAggregateExpressionWithNewChildren( ne: NamedExpression, - input: Seq[Attribute], - inputOffset: Integer): (Integer, NamedExpression) = { - val alias = ne.asInstanceOf[Alias] - val aggregateExpression = alias.child.asInstanceOf[AggregateExpression] - val validator = AggregateExpressionValidator(aggregateExpression) - val arguments = validator.getArgumentExpressions().get - val newArguments = input.slice(inputOffset, inputOffset + arguments.length) + inputs: Seq[Attribute]): NamedExpression = { + val aggregateExpression = + ne.asInstanceOf[Alias].child.asInstanceOf[AggregateExpression] val newAggregateFunction = aggregateExpression.aggregateFunction - .withNewChildren(newArguments) + .withNewChildren(inputs) .asInstanceOf[AggregateFunction] val newAggregateExpression = aggregateExpression.copy(aggregateFunction = newAggregateFunction) - val newAlias = makeNamedExpression(newAggregateExpression, alias.name) - (inputOffset + arguments.length, newAlias) + makeNamedExpression(newAggregateExpression, ne.name) } def buildAggregateOnUnion( union: LogicalPlan, - aggregates: Seq[Aggregate], - joinKeys: Seq[Seq[AttributeReference]]): LogicalPlan = { + analyzedAggregates: Seq[JoinedAggregateAnalyzer]): LogicalPlan = { val unionOutput = union.output - val keysNumber = joinKeys.head.length + val keysNumber = analyzedAggregates.head.getKeys().length val aggregateExpressions = ArrayBuffer[NamedExpression]() - val groupingKeys = unionOutput.slice(0, keysNumber).zip(joinKeys.head).map { + + val groupingKeys = unionOutput.slice(0, keysNumber).zip(analyzedAggregates.head.getKeys()).map { case (e, a) => makeNamedExpression(e, a.name) } aggregateExpressions ++= groupingKeys - var fieldIndex = keysNumber - - for (i <- 1 until joinKeys.length) { - val keys = joinKeys(i) - keys.foreach { - key => - val valueExpr = 
unionOutput(fieldIndex) - val firstValue = makeFirstAggregateExpression(valueExpr) - aggregateExpressions += makeNamedExpression(firstValue, key.name) - fieldIndex += 1 + for (i <- 1 until analyzedAggregates.length) { + val keys = analyzedAggregates(i).getKeys() + val fieldIndex = i * keysNumber + for (j <- 0 until keys.length) { + val key = keys(j) + val valueExpr = unionOutput(fieldIndex + j) + val firstValue = makeFirstAggregateExpression(valueExpr) + aggregateExpressions += makeNamedExpression(firstValue, key.name) } } - aggregates.foreach { - aggregate => - val partialAggregateExpressions = extracAggregateExpressions(aggregate) - partialAggregateExpressions.foreach { - e => - val (nextFieldIndex, newAggregateExpression) = - rebuildAggregateExpression(e, unionOutput, fieldIndex) - fieldIndex = nextFieldIndex - aggregateExpressions += newAggregateExpression - } + var fieldIndex = keysNumber * analyzedAggregates.length + for (i <- 0 until analyzedAggregates.length) { + val partialAggregateExpressions = analyzedAggregates(i).getAggregateExpressions() + val partialAggregateArguments = analyzedAggregates(i).getAggregateExpressionArguments() + for (j <- 0 until partialAggregateExpressions.length) { + val aggregateExpression = partialAggregateExpressions(j) + val arguments = partialAggregateArguments(j) + val newArguments = unionOutput.slice(fieldIndex, fieldIndex + arguments.length) + aggregateExpressions += buildAggregateExpressionWithNewChildren( + aggregateExpression, + newArguments) + fieldIndex += arguments.length + } } + for (i <- fieldIndex until unionOutput.length) { val valueExpr = unionOutput(i) val firstValue = makeFirstAggregateExpression(valueExpr) @@ -314,119 +344,84 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) Aggregate(groupingKeys, aggregateExpressions.toSeq, union) } - def buildAggregateOnUnion( - union: LogicalPlan, - leftAggregate: Aggregate, - leftKeys: Seq[AttributeReference], - rightAggregate: Aggregate, - rightKeys: 
Seq[AttributeReference]): LogicalPlan = { - val unionOutput = union.output - val groupingKeysNumber = leftKeys.length - val aggregateExpressions = ArrayBuffer[NamedExpression]() - val groupingKeys = unionOutput.slice(0, groupingKeysNumber).zip(leftKeys).map { - case (e, a) => - makeNamedExpression(e, a.name) - } - aggregateExpressions ++= groupingKeys - - var fieldIndex = groupingKeysNumber - aggregateExpressions ++= unionOutput - .slice(groupingKeysNumber, groupingKeysNumber * 2) - .map(makeFirstAggregateExpression) - .zip(rightKeys) - .map { - case (e, a) => - fieldIndex += 1 - makeNamedExpression(e, a.name) - } - - aggregateExpressions ++= - extracAggregateExpressions(leftAggregate).map { - e => - val (nextFieldIndex, newAggregateExpression) = - rebuildAggregateExpression(e, unionOutput, fieldIndex) - fieldIndex = nextFieldIndex - newAggregateExpression - } - aggregateExpressions ++= - extracAggregateExpressions(rightAggregate).map { - e => - val (nextFieldIndex, newAggregateExpression) = - rebuildAggregateExpression(e, unionOutput, fieldIndex) - fieldIndex = nextFieldIndex - newAggregateExpression - } - - aggregateExpressions ++= unionOutput.slice(fieldIndex, unionOutput.length).map { - attr => - val firstValue = makeFirstAggregateExpression(attr) - makeNamedExpression(firstValue, attr.name) - } - - Aggregate(groupingKeys, aggregateExpressions.toSeq, union) - } - + /** + * If the grouping keys is in the right table but not in the left table, remove the row from the + * result. 
+ */ def buildPrimeKeysFilterOnAggregateUnion( plan: LogicalPlan, - aggregates: Seq[Aggregate]): LogicalPlan = { - val flagExpressions = plan.output(plan.output.length - aggregates.length) + analyzedAggregates: Seq[JoinedAggregateAnalyzer]): LogicalPlan = { + val flagExpressions = plan.output(plan.output.length - analyzedAggregates.length) val notNullExpr = IsNotNull(flagExpressions) Filter(notNullExpr, plan) } + /** Make the aggregate result be null if the grouping keys is not in the most left table. */ def buildMakeNotMatchedRowsNullProject( plan: LogicalPlan, - fieldStartOffset: Integer, - aggregates: Seq[Aggregate]): LogicalPlan = { + analyzedAggregates: Seq[JoinedAggregateAnalyzer], + ignoreAggregates: Set[Int]): LogicalPlan = { val input = plan.output - val flagExpressions = input.slice(plan.output.length - aggregates.length, plan.output.length) - var fieldIndex = fieldStartOffset - val aggregatesIfNullExpressions = aggregates.zipWithIndex.map { - case (aggregate, i) => + val flagExpressions = + input.slice(plan.output.length - analyzedAggregates.length, plan.output.length) + val aggregateExprsStart = analyzedAggregates.length * analyzedAggregates.head.getKeys().length + var fieldIndex = aggregateExprsStart + val aggregatesIfNullExpressions = analyzedAggregates.zipWithIndex.map { + case (analyzedAggregate, i) => val flagExpr = flagExpressions(i) - val aggregateExpressions = extracAggregateExpressions(aggregate) + val aggregateExpressions = analyzedAggregate.getAggregateExpressions() aggregateExpressions.map { e => val valueExpr = input(fieldIndex) fieldIndex += 1 - val clearExpr = If(IsNull(flagExpr), makeNullLiteral(valueExpr.dataType), valueExpr) - makeNamedExpression(clearExpr, valueExpr.name) + if (ignoreAggregates(i)) { + valueExpr.asInstanceOf[NamedExpression] + } else { + val clearExpr = If(IsNull(flagExpr), makeNullLiteral(valueExpr.dataType), valueExpr) + makeNamedExpression(clearExpr, valueExpr.name) + } } } val ifNullExpressions = 
aggregatesIfNullExpressions.flatten - val projectList = input.slice(0, fieldStartOffset) ++ ifNullExpressions + val projectList = input.slice(0, aggregateExprsStart) ++ ifNullExpressions Project(projectList, plan) } + /** + * Build a project to make the output attributes have the same name and exprId as the original + * join output attributes. + */ def buildRenameProject( plan: LogicalPlan, - aggregates: Seq[Aggregate], - joinKeys: Seq[Seq[AttributeReference]]): LogicalPlan = { + analyzedAggregates: Seq[JoinedAggregateAnalyzer]): LogicalPlan = { val input = plan.output - var fieldIndex = 0 + val joinKeys = analyzedAggregates.map(_.getKeys()) val projectList = ArrayBuffer[NamedExpression]() - joinKeys.foreach { - keys => - keys.foreach { - key => - projectList += Alias(input(fieldIndex), key.name)( - key.exprId, - key.qualifier, - None, - Seq.empty) - fieldIndex += 1 - } + val keysNum = analyzedAggregates.head.getKeys().length + var fieldIndex = 0 + for (i <- 0 until analyzedAggregates.length) { + val keys = analyzedAggregates(i).getKeys() + for (j <- 0 until keys.length) { + val key = keys(j) + projectList += Alias(input(fieldIndex), key.name)( + key.exprId, + key.qualifier, + None, + Seq.empty) + fieldIndex += 1 + } } - aggregates.foreach { - aggregate => - val aggregateExpressions = extracAggregateExpressions(aggregate) - aggregateExpressions.foreach { - e => - val valueExpr = input(fieldIndex) - projectList += Alias(valueExpr, e.name)(e.exprId, e.qualifier, None, Seq.empty) - fieldIndex += 1 - } + + for (i <- 0 until analyzedAggregates.length) { + val aggregateExpressions = analyzedAggregates(i).getAggregateExpressions() + for (j <- 0 until aggregateExpressions.length) { + val e = aggregateExpressions(j) + val valueExpr = input(fieldIndex) + projectList += Alias(valueExpr, e.name)(e.exprId, e.qualifier, None, Seq.empty) + fieldIndex += 1 + } } + Project(projectList.toSeq, plan) } @@ -447,70 +442,69 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) 
AggregateExpression(aggregateFunction, Complete, false) } + /** + * Build a extended project list, which contains three parts. + * - The grouping keys of all tables. + * - All required columns for aggregate functions in every table. + * - Flags for each table to indicate whether the row is in the table. + */ def buildExtendProjectList( - groupingKeys: Seq[Seq[Expression]], - aggregateInputs: Seq[Seq[Expression]], + analyzedAggregates: Seq[JoinedAggregateAnalyzer], index: Int): Seq[NamedExpression] = { - var fieldIndex = 0 - def nextFieldIndex(): Int = { - val res = fieldIndex - fieldIndex += 1 - res - } - val projectList = ArrayBuffer[NamedExpression]() - val keys = groupingKeys(index) - projectList ++= keys.map(makeNamedExpression(_, s"key$nextFieldIndex")) - for (i <- 1 until groupingKeys.length) { - if (i == index) { - projectList ++= keys.map(makeNamedExpression(_, s"key$nextFieldIndex")) - } else { - projectList ++= keys - .map(_.dataType) - .map(makeNullLiteral) - .map(makeNamedExpression(_, s"key$nextFieldIndex")) + projectList ++= analyzedAggregates(index).getKeys().zipWithIndex.map { + case (e, i) => + makeNamedExpression(e, s"key_0_$i") + } + for (i <- 1 until analyzedAggregates.length) { + val groupingKeys = analyzedAggregates(i).getKeys() + projectList ++= groupingKeys.zipWithIndex.map { + case (e, j) => + if (i == index) { + makeNamedExpression(e, s"key_${i}_$j") + } else { + makeNamedExpression(makeNullLiteral(e.dataType), s"key_${i}_$j") + } } } - aggregateInputs.zipWithIndex.foreach { - case (inputs, i) => - if (i == index) { - projectList ++= inputs.map(makeNamedExpression(_, s"agg$nextFieldIndex")) - } else { - projectList ++= inputs - .map(_.dataType) - .map(makeNullLiteral) - .map(makeNamedExpression(_, s"agg$nextFieldIndex")) - } + for (i <- 0 until analyzedAggregates.length) { + val argsList = analyzedAggregates(i).getAggregateExpressionArguments() + argsList.zipWithIndex.foreach { + case (args, j) => + projectList ++= args.zipWithIndex.map { 
+ case (arg, k) => + if (i == index) { + makeNamedExpression(arg, s"arg_${i}_${j}_$k") + } else { + makeNamedExpression(makeNullLiteral(arg.dataType), s"arg_${i}_${j}_$k") + } + } + } } - for (i <- 0 until groupingKeys.length) { + for (i <- 0 until analyzedAggregates.length) { if (i == index) { - projectList += makeNamedExpression(makeFlagLiteral(), s"flag$nextFieldIndex") + projectList += makeNamedExpression(makeFlagLiteral(), s"flag_$i") } else { - projectList += makeNamedExpression(makeNullLiteral(BooleanType), s"flag$nextFieldIndex") + projectList += makeNamedExpression(makeNullLiteral(BooleanType), s"flag_$i") } } projectList.toSeq } - def buildExtendProjects( - groupingKeys: Seq[Seq[Expression]], - aggregateInputs: Seq[Seq[Expression]], - children: Seq[LogicalPlan] - ): Seq[LogicalPlan] = { + def buildExtendProjects(analyzedAggregates: Seq[JoinedAggregateAnalyzer]): Seq[LogicalPlan] = { val projects = ArrayBuffer[LogicalPlan]() - for (i <- 0 until children.length) { - val projectList = buildExtendProjectList(groupingKeys, aggregateInputs, i) - projects += Project(projectList, children(i)) + for (i <- 0 until analyzedAggregates.length) { + val projectList = buildExtendProjectList(analyzedAggregates, i) + projects += Project(projectList, analyzedAggregates(i).getAggregate().child) } projects.toSeq } def buildUnionOnExtendedProjects(plans: Seq[LogicalPlan]): LogicalPlan = { val union = Union(plans) - logDebug(s"xxx union\n$union") union } From 82ffc166a66097f6089b522ed0712b928c8a2012 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Thu, 27 Mar 2025 16:09:55 +0800 Subject: [PATCH 04/10] wip --- .../JoinAggregateToAggregateUnion.scala | 138 ++++++++++++++---- 1 file changed, 112 insertions(+), 26 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala index d83b6ee8c166..9a6339e15dd5 100644 --- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala @@ -229,6 +229,52 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) } } + def visitPlan(plan: LogicalPlan): LogicalPlan = { + plan match { + case join: Join => + if (join.joinType == LeftOuter && join.condition.isDefined) { + (join.left, join.right) match { + case (left, right) if isDirectAggregate(left) && isDirectAggregate(right) => + val leftAggregate = extractDirectAggregate(left).get + val rightAggregate = extractDirectAggregate(right).get + logDebug(s"xxx case 1. left agg:\n$leftAggregate,\nright agg:\n$rightAggregate") + rewriteOnlyTwoAggregatesJoin(leftAggregate, rightAggregate, join) match { + case Some(newPlan) => + newPlan.withNewChildren(newPlan.children.map(visitPlan)) + case _ => + plan.withNewChildren(plan.children.map(visitPlan)) + } + case (left, right) if isDirectJoin(left) && isDirectAggregate(right) => + logDebug(s"xxx case 2") + val analyzedAggregates = ArrayBuffer[JoinedAggregateAnalyzer]() + val remainedPlan = collectSameKeysJoinedAggregates(join, analyzedAggregates) + logDebug(s"xxx join left\n$remainedPlan") + logDebug(s"xxx analyzed aggregates number ${analyzedAggregates.length}") + val unionedAggregates = unionAllJoinedAggregates(analyzedAggregates.toSeq) + val finalPlan = + if (analyzedAggregates.length == 0) { + join.copy(left = visitPlan(join.left), right = visitPlan(join.right)) + } else if (analyzedAggregates.length == 1) { + join.copy(left = visitPlan(join.left)) + } else if (remainedPlan.isDefined) { + val lastJoin = analyzedAggregates.last.join + lastJoin.copy(left = visitPlan(lastJoin.left), right = unionedAggregates) + } else { + buildPrimeKeysFilterOnAggregateUnion(unionedAggregates, analyzedAggregates) + } + logDebug(s"xxx final rewritten plan\n$finalPlan") + finalPlan + case _ => + logDebug(s"xxx case 
3.left\n${join.left}\nright\n${join.right}") + plan.withNewChildren(plan.children.map(visitPlan)) + } + } else { + plan.withNewChildren(plan.children.map(visitPlan)) + } + case _ => plan.withNewChildren(plan.children.map(visitPlan)) + } + } + /** * Rewrite the plan if it is a join between two aggregate tables. For example * ``` @@ -267,11 +313,6 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) return None } - val aggregates = analyzers.map(_.getAggregate) - val joinKeys = analyzers.map(_.getKeys()) - val aggregateExpressions = analyzers.map(_.getAggregateExpressions()) - val aggregateExpressionArguments = analyzers.map(_.getAggregateExpressionArguments()).flatten - val extendProjects = buildExtendProjects(analyzers) val union = buildUnionOnExtendedProjects(extendProjects) @@ -296,6 +337,15 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) makeNamedExpression(newAggregateExpression, ne.name) } + def unionAllJoinedAggregates(analyzedAggregates: Seq[JoinedAggregateAnalyzer]): LogicalPlan = { + val extendProjects = buildExtendProjects(analyzedAggregates) + val union = buildUnionOnExtendedProjects(extendProjects) + val aggregateUnion = buildAggregateOnUnion(union, analyzedAggregates) + val setNullsProject = + buildMakeNotMatchedRowsNullProject(aggregateUnion, analyzedAggregates, Set()) + buildRenameProject(setNullsProject, analyzedAggregates) + } + def buildAggregateOnUnion( union: LogicalPlan, analyzedAggregates: Seq[JoinedAggregateAnalyzer]): LogicalPlan = { @@ -383,7 +433,7 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) } } val ifNullExpressions = aggregatesIfNullExpressions.flatten - val projectList = input.slice(0, aggregateExprsStart) ++ ifNullExpressions + val projectList = input.slice(0, aggregateExprsStart) ++ ifNullExpressions ++ flagExpressions Project(projectList, plan) } @@ -422,6 +472,8 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) } } + // Keep the flag columns + projectList ++= 
input.slice(input.length - analyzedAggregates.length, input.length) Project(projectList.toSeq, plan) } @@ -524,29 +576,63 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) } } - def visitPlan(plan: LogicalPlan): LogicalPlan = { + def isDirectJoin(plan: LogicalPlan): Boolean = { plan match { - case join: Join => - if (join.joinType == LeftOuter && join.condition.isDefined) { - (join.left, join.right) match { - case (left, right) if isDirectAggregate(left) && isDirectAggregate(right) => - val leftAggregate = extractDirectAggregate(left).get - val rightAggregate = extractDirectAggregate(right).get - logDebug(s"xxx case 1. left agg:\n$leftAggregate,\nright agg:\n$rightAggregate") - rewriteOnlyTwoAggregatesJoin(leftAggregate, rightAggregate, join) match { - case Some(newPlan) => - newPlan.withNewChildren(newPlan.children.map(visitPlan)) - case _ => - plan.withNewChildren(plan.children.map(visitPlan)) - } - case _ => - logDebug(s"xxx case 2.left\n{join.left}\nright\n{join.right}") - plan.withNewChildren(plan.children.map(visitPlan)) - } + case _: Join => true + case _ => false + } + } + + def extractDirectJoin(plan: LogicalPlan): Option[Join] = { + plan match { + case join: Join => Some(join) + case _ => None + } + } + + def collectSameKeysJoinedAggregates( + plan: LogicalPlan, + analyzedAggregates: ArrayBuffer[JoinedAggregateAnalyzer]): Option[LogicalPlan] = { + plan match { + case join: Join if join.joinType == LeftOuter && join.condition.isDefined => + val optionAggregate = extractDirectAggregate(join.right) + if (optionAggregate.isEmpty) { + return Some(plan) + } + val rightAggregateAnalyzer = JoinedAggregateAnalyzer.build(join, optionAggregate.get) + if (rightAggregateAnalyzer.isEmpty) { + return Some(plan) + } + + if ( + analyzedAggregates.isEmpty || + JoinedAggregateAnalyzer.haveSamePrimeKeys( + Seq(analyzedAggregates.head, rightAggregateAnalyzer.get)) + ) { + analyzedAggregates += rightAggregateAnalyzer.get + 
collectSameKeysJoinedAggregates(join.left, analyzedAggregates) } else { - plan.withNewChildren(plan.children.map(visitPlan)) + Some(plan) } - case _ => plan.withNewChildren(plan.children.map(visitPlan)) + case _ if isDirectAggregate(plan) => + val aggregate = extractDirectAggregate(plan).get + val lastJoin = analyzedAggregates.last.join + assert(lastJoin.left.equals(plan), "The node should be last join's left child") + val leftAggregateAnalyzer = JoinedAggregateAnalyzer.build(lastJoin, aggregate) + if (leftAggregateAnalyzer.isEmpty) { + return Some(plan) + } + if ( + JoinedAggregateAnalyzer.haveSamePrimeKeys( + Seq(analyzedAggregates.head, leftAggregateAnalyzer.get)) + ) { + analyzedAggregates += leftAggregateAnalyzer.get + None + } else { + Some(plan) + } + case _ => Some(plan) } } + } From 2ab4e47005ed62756ed2183deb56e7aee933c6f7 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Thu, 27 Mar 2025 19:49:16 +0800 Subject: [PATCH 05/10] add test --- .../JoinAggregateToAggregateUnion.scala | 144 ++++++------ .../execution/GlutenEliminateJoinSuite.scala | 217 ++++++++++++++++++ 2 files changed, 283 insertions(+), 78 deletions(-) create mode 100644 backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala index 9a6339e15dd5..95f10a2cfaaa 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala @@ -33,11 +33,13 @@ import scala.collection.mutable.ArrayBuffer trait AggregateFunctionAnalyzer { def doValidate(): Boolean def getArgumentExpressions(): Option[Seq[Expression]] + def ignoreNulls(): Boolean } case class DefaultAggregateFunctionAnalyzer() extends 
AggregateFunctionAnalyzer { override def doValidate(): Boolean = false override def getArgumentExpressions(): Option[Seq[Expression]] = None + override def ignoreNulls(): Boolean = false } case class SumAnalyzer(aggExpr: AggregateExpression) extends AggregateFunctionAnalyzer { @@ -49,6 +51,21 @@ case class SumAnalyzer(aggExpr: AggregateExpression) extends AggregateFunctionAn override def getArgumentExpressions(): Option[Seq[Expression]] = { Some(Seq(sum.child)) } + + override def ignoreNulls(): Boolean = true +} + +case class AverageAnalyzer(aggExpr: AggregateExpression) extends AggregateFunctionAnalyzer { + val sum = aggExpr.aggregateFunction.asInstanceOf[Sum] + override def doValidate(): Boolean = { + !sum.child.isInstanceOf[Literal] + } + + override def getArgumentExpressions(): Option[Seq[Expression]] = { + Some(Seq(sum.child)) + } + + override def ignoreNulls(): Boolean = true } case class CountAnalyzer(aggExpr: AggregateExpression) extends AggregateFunctionAnalyzer { @@ -60,6 +77,9 @@ case class CountAnalyzer(aggExpr: AggregateExpression) extends AggregateFunction override def getArgumentExpressions(): Option[Seq[Expression]] = { Some(count.children) } + + override def ignoreNulls(): Boolean = false + } object AggregateFunctionAnalyzer { @@ -76,6 +96,7 @@ object AggregateFunctionAnalyzer { case Some(agg) => agg.aggregateFunction match { case sum: Sum => SumAnalyzer(agg) + case avg: Average => AverageAnalyzer(agg) case count: Count => CountAnalyzer(agg) case _ => DefaultAggregateFunctionAnalyzer() } @@ -96,12 +117,19 @@ case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo aggregateExpressions = aggregate.aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isDefined) - if (aggregateExpressions.exists(!AggregateFunctionAnalyzer(_).doValidate)) { + aggregateFunctionAnalyzer = aggregateExpressions.map(AggregateFunctionAnalyzer(_)) + if ( + aggregateExpressions.zipWithIndex.exists { + case (_, i) => 
!aggregateFunctionAnalyzer(i).doValidate + } + ) { return false } val arguments = - aggregateExpressions.map(AggregateFunctionAnalyzer(_).getArgumentExpressions) + aggregateExpressions.zipWithIndex.map { + case (_, i) => aggregateFunctionAnalyzer(i).getArgumentExpressions + } if (arguments.exists(_.isEmpty)) { return false } @@ -127,12 +155,14 @@ case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo def getAggregate(): Aggregate = aggregate def getAggregateExpressions(): Seq[NamedExpression] = aggregateExpressions def getAggregateExpressionArguments(): Seq[Seq[Expression]] = aggregateExpressionArguments + def getAggregateFunctionAnalyzers(): Seq[AggregateFunctionAnalyzer] = aggregateFunctionAnalyzer private var primeKeys: Seq[AttributeReference] = Seq.empty private var keys: Seq[AttributeReference] = Seq.empty private var aggregate: Aggregate = null private var aggregateExpressions: Seq[NamedExpression] = Seq.empty private var aggregateExpressionArguments: Seq[Seq[Expression]] = Seq.empty + private var aggregateFunctionAnalyzer = Seq.empty[AggregateFunctionAnalyzer] private def extractJoinKeys(): Boolean = { val leftJoinKeys = ArrayBuffer[AttributeReference]() @@ -229,54 +259,8 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) } } - def visitPlan(plan: LogicalPlan): LogicalPlan = { - plan match { - case join: Join => - if (join.joinType == LeftOuter && join.condition.isDefined) { - (join.left, join.right) match { - case (left, right) if isDirectAggregate(left) && isDirectAggregate(right) => - val leftAggregate = extractDirectAggregate(left).get - val rightAggregate = extractDirectAggregate(right).get - logDebug(s"xxx case 1. 
left agg:\n$leftAggregate,\nright agg:\n$rightAggregate") - rewriteOnlyTwoAggregatesJoin(leftAggregate, rightAggregate, join) match { - case Some(newPlan) => - newPlan.withNewChildren(newPlan.children.map(visitPlan)) - case _ => - plan.withNewChildren(plan.children.map(visitPlan)) - } - case (left, right) if isDirectJoin(left) && isDirectAggregate(right) => - logDebug(s"xxx case 2") - val analyzedAggregates = ArrayBuffer[JoinedAggregateAnalyzer]() - val remainedPlan = collectSameKeysJoinedAggregates(join, analyzedAggregates) - logDebug(s"xxx join left\n$remainedPlan") - logDebug(s"xxx analyzed aggregates number ${analyzedAggregates.length}") - val unionedAggregates = unionAllJoinedAggregates(analyzedAggregates.toSeq) - val finalPlan = - if (analyzedAggregates.length == 0) { - join.copy(left = visitPlan(join.left), right = visitPlan(join.right)) - } else if (analyzedAggregates.length == 1) { - join.copy(left = visitPlan(join.left)) - } else if (remainedPlan.isDefined) { - val lastJoin = analyzedAggregates.last.join - lastJoin.copy(left = visitPlan(lastJoin.left), right = unionedAggregates) - } else { - buildPrimeKeysFilterOnAggregateUnion(unionedAggregates, analyzedAggregates) - } - logDebug(s"xxx final rewritten plan\n$finalPlan") - finalPlan - case _ => - logDebug(s"xxx case 3.left\n${join.left}\nright\n${join.right}") - plan.withNewChildren(plan.children.map(visitPlan)) - } - } else { - plan.withNewChildren(plan.children.map(visitPlan)) - } - case _ => plan.withNewChildren(plan.children.map(visitPlan)) - } - } - /** - * Rewrite the plan if it is a join between two aggregate tables. For example + * Use wide-table aggregation to eliminate multi-table joins. 
For example * ``` * select k1, k2, s1, s2 from (select k1, sum(a) as s1 from t1 group by k1) as t1 * left join (select k2, sum(b) as s2 from t2 group by k2) as t2 @@ -299,30 +283,32 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) * * The first query is easier to write, but not as efficient as the second one. */ - def rewriteOnlyTwoAggregatesJoin( - leftAggregate: Aggregate, - rightAggregate: Aggregate, - join: Join): Option[LogicalPlan] = { - val optionAnalyzers = Seq(leftAggregate, rightAggregate).map( - agg => JoinedAggregateAnalyzer.build(join, agg.asInstanceOf[LogicalPlan])) - if (optionAnalyzers.exists(_.isEmpty)) { - return None - } - val analyzers = optionAnalyzers.map(_.get) - if (!JoinedAggregateAnalyzer.haveSamePrimeKeys(analyzers)) { - return None + def visitPlan(plan: LogicalPlan): LogicalPlan = { + plan match { + case join: Join => + if (join.joinType == LeftOuter && join.condition.isDefined) { + val analyzedAggregates = ArrayBuffer[JoinedAggregateAnalyzer]() + val remainedPlan = collectSameKeysJoinedAggregates(join, analyzedAggregates) + logDebug(s"xxx join left\n$remainedPlan") + logDebug(s"xxx analyzed aggregates number ${analyzedAggregates.length}") + if (analyzedAggregates.length == 0) { + join.copy(left = visitPlan(join.left), right = visitPlan(join.right)) + } else if (analyzedAggregates.length == 1) { + join.copy(left = visitPlan(join.left)) + } else { + val unionedAggregates = unionAllJoinedAggregates(analyzedAggregates.toSeq) + if (remainedPlan.isDefined) { + val lastJoin = analyzedAggregates.head.join + lastJoin.copy(left = visitPlan(lastJoin.left), right = unionedAggregates) + } else { + buildPrimeKeysFilterOnAggregateUnion(unionedAggregates, analyzedAggregates) + } + } + } else { + plan.withNewChildren(plan.children.map(visitPlan)) + } + case _ => plan.withNewChildren(plan.children.map(visitPlan)) } - - val extendProjects = buildExtendProjects(analyzers) - - val union = buildUnionOnExtendedProjects(extendProjects) - val 
aggregateUnion = buildAggregateOnUnion(union, analyzers) - val filtAggregateUnion = buildPrimeKeysFilterOnAggregateUnion(aggregateUnion, analyzers) - val setNullsProject = buildMakeNotMatchedRowsNullProject(filtAggregateUnion, analyzers, Set(0)) - val renameProject = buildRenameProject(setNullsProject, analyzers) - logDebug(s"xxx rename project\n$renameProject") - logDebug(s"xxx rename outputs ${renameProject.output}") - Some(renameProject) } def buildAggregateExpressionWithNewChildren( @@ -420,11 +406,12 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) case (analyzedAggregate, i) => val flagExpr = flagExpressions(i) val aggregateExpressions = analyzedAggregate.getAggregateExpressions() - aggregateExpressions.map { - e => + val aggregateFunctionAnalyzers = analyzedAggregate.getAggregateFunctionAnalyzers() + aggregateExpressions.zipWithIndex.map { + case (e, i) => val valueExpr = input(fieldIndex) fieldIndex += 1 - if (ignoreAggregates(i)) { + if (ignoreAggregates(i) || aggregateFunctionAnalyzers(i).ignoreNulls()) { valueExpr.asInstanceOf[NamedExpression] } else { val clearExpr = If(IsNull(flagExpr), makeNullLiteral(valueExpr.dataType), valueExpr) @@ -609,14 +596,15 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) JoinedAggregateAnalyzer.haveSamePrimeKeys( Seq(analyzedAggregates.head, rightAggregateAnalyzer.get)) ) { - analyzedAggregates += rightAggregateAnalyzer.get + // left plan is pushed in front + analyzedAggregates.insert(0, rightAggregateAnalyzer.get) collectSameKeysJoinedAggregates(join.left, analyzedAggregates) } else { Some(plan) } case _ if isDirectAggregate(plan) => val aggregate = extractDirectAggregate(plan).get - val lastJoin = analyzedAggregates.last.join + val lastJoin = analyzedAggregates.head.join assert(lastJoin.left.equals(plan), "The node should be last join's left child") val leftAggregateAnalyzer = JoinedAggregateAnalyzer.build(lastJoin, aggregate) if (leftAggregateAnalyzer.isEmpty) { @@ -626,7 +614,7 @@ case 
class JoinAggregateToAggregateUnion(spark: SparkSession) JoinedAggregateAnalyzer.haveSamePrimeKeys( Seq(analyzedAggregates.head, leftAggregateAnalyzer.get)) ) { - analyzedAggregates += leftAggregateAnalyzer.get + analyzedAggregates.insert(0, leftAggregateAnalyzer.get) None } else { Some(plan) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala new file mode 100644 index 000000000000..e878247a5b33 --- /dev/null +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.execution + +import org.apache.spark.SparkConf +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, NullPropagation} +import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +import java.nio.file.Files + +class GlutenEliminateJoinSuite extends GlutenClickHouseWholeStageTransformerSuite { + + protected val tablesPath: String = basePath + "/tpch-data" + protected val tpchQueries: String = + rootPath + "../../../../tools/gluten-it/common/src/main/resources/tpch-queries" + protected val queriesResults: String = rootPath + "queries-output" + + private var parquetPath: String = _ + + override protected def sparkConf: SparkConf = { + super.sparkConf + .set("spark.sql.files.maxPartitionBytes", "1g") + .set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + .set("spark.sql.shuffle.partitions", "5") + .set("spark.sql.adaptive.enabled", "false") + .set("spark.sql.files.minPartitionNum", "1") + .set( + "spark.sql.catalog.spark_catalog", + "org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseSparkCatalog") + .set("spark.databricks.delta.maxSnapshotLineageLength", "20") + .set("spark.databricks.delta.snapshotPartitions", "1") + .set("spark.databricks.delta.properties.defaults.checkpointInterval", "5") + .set("spark.databricks.delta.stalenessLimit", "3600000") + .set(ClickHouseConfig.CLICKHOUSE_WORKER_ID, "1") + .set("spark.gluten.sql.columnar.iterator", "true") + .set("spark.gluten.sql.columnar.hashagg.enablefinal", "true") + .set("spark.gluten.sql.enable.native.validation", "false") + .set("spark.sql.warehouse.dir", warehouse) + .set("spark.shuffle.manager", "sort") + .set("spark.io.compression.codec", "snappy") + .set("spark.sql.shuffle.partitions", "5") + .set("spark.sql.autoBroadcastJoinThreshold", "-1") + .set("spark.gluten.supported.scala.udfs", 
"compare_substrings:compare_substrings") + .set( + SQLConf.OPTIMIZER_EXCLUDED_RULES.key, + ConstantFolding.ruleName + "," + NullPropagation.ruleName) + } + + def createTable(tableName: String, rows: Seq[Row], schema: StructType): Unit = { + val file = Files.createTempFile(tableName, ".parquet").toFile + file.deleteOnExit() + val data = sparkContext.parallelize(rows) + val dataframe = spark.createDataFrame(data, schema) + dataframe.coalesce(1).write.format("parquet").mode("overwrite").parquet(file.getAbsolutePath) + spark.catalog.createTable(tableName, file.getAbsolutePath, "parquet") + } + + override def beforeAll(): Unit = { + super.beforeAll() + + val schema1 = StructType( + Array( + StructField("k1", IntegerType, true), + StructField("k2", IntegerType, true), + StructField("v1", IntegerType, true), + StructField("v2", IntegerType, true), + StructField("v3", IntegerType, true) + ) + ) + + val data1 = Seq( + Row(1, 1, 1, 1, 1), + Row(1, 1, null, 2, 1), + Row(1, 2, 1, 1, null), + Row(2, 1, 1, null, 1), + Row(2, 2, 1, 1, 1), + Row(2, 2, 1, null, 1), + Row(2, 2, 2, null, 3), + Row(2, 3, 0, null, 1), + Row(2, 4, 1, 2, 3), + Row(3, 1, 4, 5, 6), + Row(4, 2, 7, 8, 9), + Row(5, 3, 10, 11, 12) + ) + createTable("t1", data1, schema1) + createTable("t2", data1, schema1) + createTable("t3", data1, schema1) + } + + test("Eliminate two aggregates join") { + val sql = """ + select t1.k1, t1.k2, t2.k1, t2.k2, s1, s2 from ( + select k1, k2, sum(v1) s1 from ( + select * from t1 where k1 != 1 + )group by k1, k2 + ) t1 left join ( + select k1, k2, count(v1) s2 from ( + select * from t2 where k1 != 3 + )group by k1, k2 + ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 + order by t1.k1, t1.k2, s1, s2 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 0) + }) + } + + test("Elimiate three aggreages join") { + val sql = """ + 
select t1.k1, t1.k2, t2.k1, t3.k2, s1, s2, s3 from ( + select k1, k2, sum(v1) s1 from ( + select * from t1 where k1 != 1 + )group by k1, k2 + ) t1 left join ( + select k1, k2, count(v1) s2 from ( + select * from t2 where k1 != 3 + )group by k1, k2 + ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 + left join ( + select k1, k2, count(v2) s3 from ( + select * from t3 where k1 != 3 + )group by k1, k2 + ) t3 on t1.k1 = t3.k1 and t1.k2 = t3.k2 + order by t1.k1, t1.k2, s1, s2, s3 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 0) + }) + } + + test("Left one join uneliminable") { + val sql = """ + select t1.k1, t1.k2, s1, s2 from ( + select * from t1 where k1 != 1 + ) t1 left join ( + select k1, k2, count(v1) s1 from ( + select * from t2 where k1 != 1 + )group by k1, k2 + ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 + left join ( + select k1, k2, count(v2) s2 from ( + select * from t3 where k1 != 3 + )group by k1, k2 + ) t3 on t1.k1 = t3.k1 and t1.k2 = t3.k2 + order by t1.k1, t1.k2, s1, s2 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 1) + }) + } + + test("Left two joins uneliminable") { + val sql = """ + select t1.k1, t1.k2, s1, s2 from ( + select k1, k2, count(v1) s1 from ( + select * from t2 where k1 != 1 + )group by k1, k2 + ) t1 left join ( + select * from t1 where k1 != 1 + ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 + left join ( + select k1, k2, count(v2) s2 from ( + select * from t3 where k1 != 3 + )group by k1, k2 + ) t3 on t1.k1 = t3.k1 and t1.k2 = t3.k2 + order by t1.k1, t1.k2, s1, s2 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: 
ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 2) + }) + } +} From c2161f4f8a71fc1b655e1a4441841bb752c1bef9 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Fri, 28 Mar 2025 10:47:25 +0800 Subject: [PATCH 06/10] support more aggregate functions --- .../backendsapi/clickhouse/CHRuleApi.scala | 9 +- .../JoinAggregateToAggregateUnion.scala | 129 ++++++++++++++---- .../execution/GlutenEliminateJoinSuite.scala | 98 +++++++++++++ 3 files changed, 210 insertions(+), 26 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index f78c4f483aff..46ee43b289fa 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -60,9 +60,6 @@ object CHRuleApi { (spark, parserInterface) => new GlutenCacheFilesSqlParser(spark, parserInterface)) injector.injectParser( (spark, parserInterface) => new GlutenClickhouseSqlParser(spark, parserInterface)) - injector.injectResolutionRule(spark => new CoalesceAggregationUnion(spark)) - injector.injectResolutionRule(spark => new CoalesceProjectionUnion(spark)) - injector.injectResolutionRule(spark => new JoinAggregateToAggregateUnion(spark)) injector.injectResolutionRule(spark => new RewriteToDateExpresstionRule(spark)) injector.injectResolutionRule(spark => new RewriteDateTimestampComparisonRule(spark)) injector.injectResolutionRule(spark => new CollapseGetJsonObjectExpressionRule(spark)) @@ -72,6 +69,12 @@ object CHRuleApi { injector.injectOptimizerRule(spark => new SimplifySumRule(spark)) injector.injectOptimizerRule(spark => new ExtendedColumnPruning(spark)) injector.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark)) + // JoinAggregateToAggregateUnion need to be applied on optimized plans + 
injector.injectOptimizerRule(spark => new JoinAggregateToAggregateUnion(spark)) + // CoalesceAggregationUnion and CoalesceProjectionUnion should follows + // JoinAggregateToAggregateUnion + injector.injectOptimizerRule(spark => new CoalesceAggregationUnion(spark)) + injector.injectOptimizerRule(spark => new CoalesceProjectionUnion(spark)) injector.injectOptimizerRule(_ => CountDistinctWithoutExpand) injector.injectOptimizerRule(_ => EqualToRewrite) injector.injectPreCBORule(spark => new CHOptimizeMetadataOnlyDeltaQuery(spark)) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala index 95f10a2cfaaa..e2831da79165 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala @@ -30,22 +30,41 @@ import org.apache.spark.sql.types._ import scala.collection.mutable.ArrayBuffer +private object RuleExpressionHelper { + + def extractLiteral(e: Expression): Option[Literal] = { + e match { + case literal: Literal => Some(literal) + case _ @Alias(literal: Literal, _) => Some(literal) + case _ => None + } + } +} + trait AggregateFunctionAnalyzer { def doValidate(): Boolean def getArgumentExpressions(): Option[Seq[Expression]] def ignoreNulls(): Boolean + def buildUnionAggregateExpression( + arguments: Seq[Expression], + flag: Expression): AggregateExpression } case class DefaultAggregateFunctionAnalyzer() extends AggregateFunctionAnalyzer { override def doValidate(): Boolean = false override def getArgumentExpressions(): Option[Seq[Expression]] = None override def ignoreNulls(): Boolean = false + override def buildUnionAggregateExpression( + arguments: Seq[Expression], + flag: Expression): AggregateExpression = { + throw new GlutenException("Unsupported aggregate 
function") + } } -case class SumAnalyzer(aggExpr: AggregateExpression) extends AggregateFunctionAnalyzer { - val sum = aggExpr.aggregateFunction.asInstanceOf[Sum] +case class SumAnalyzer(aggregateExpression: AggregateExpression) extends AggregateFunctionAnalyzer { + val sum = aggregateExpression.aggregateFunction.asInstanceOf[Sum] override def doValidate(): Boolean = { - !sum.child.isInstanceOf[Literal] + aggregateExpression.filter.isEmpty } override def getArgumentExpressions(): Option[Seq[Expression]] = { @@ -53,25 +72,40 @@ case class SumAnalyzer(aggExpr: AggregateExpression) extends AggregateFunctionAn } override def ignoreNulls(): Boolean = true + + def buildUnionAggregateExpression( + arguments: Seq[Expression], + flag: Expression): AggregateExpression = { + val newSum = sum.copy(child = arguments.head) + aggregateExpression.copy(aggregateFunction = newSum) + } } -case class AverageAnalyzer(aggExpr: AggregateExpression) extends AggregateFunctionAnalyzer { - val sum = aggExpr.aggregateFunction.asInstanceOf[Sum] +case class AverageAnalyzer(aggregateExpression: AggregateExpression) + extends AggregateFunctionAnalyzer { + val avg = aggregateExpression.aggregateFunction.asInstanceOf[Average] override def doValidate(): Boolean = { - !sum.child.isInstanceOf[Literal] + aggregateExpression.filter.isEmpty } override def getArgumentExpressions(): Option[Seq[Expression]] = { - Some(Seq(sum.child)) + Some(Seq(avg.child)) } override def ignoreNulls(): Boolean = true + def buildUnionAggregateExpression( + arguments: Seq[Expression], + flag: Expression): AggregateExpression = { + val newAvg = avg.copy(child = arguments.head) + aggregateExpression.copy(aggregateFunction = newAvg) + } } -case class CountAnalyzer(aggExpr: AggregateExpression) extends AggregateFunctionAnalyzer { - val count = aggExpr.aggregateFunction.asInstanceOf[Count] +case class CountAnalyzer(aggregateExpression: AggregateExpression) + extends AggregateFunctionAnalyzer { + val count = 
aggregateExpression.aggregateFunction.asInstanceOf[Count] override def doValidate(): Boolean = { - count.children.length == 1 && !count.children.head.isInstanceOf[Literal] + count.children.length == 1 && aggregateExpression.filter.isEmpty } override def getArgumentExpressions(): Option[Seq[Expression]] = { @@ -79,7 +113,50 @@ case class CountAnalyzer(aggExpr: AggregateExpression) extends AggregateFunction } override def ignoreNulls(): Boolean = false + def buildUnionAggregateExpression( + arguments: Seq[Expression], + flag: Expression): AggregateExpression = { + val newCount = count.copy(children = arguments) + aggregateExpression.copy(aggregateFunction = newCount) + } +} +case class MinAnalyzer(aggregateExpression: AggregateExpression) extends AggregateFunctionAnalyzer { + val min = aggregateExpression.aggregateFunction.asInstanceOf[Min] + override def doValidate(): Boolean = { + aggregateExpression.filter.isEmpty + } + + override def getArgumentExpressions(): Option[Seq[Expression]] = { + Some(Seq(min.child)) + } + + override def ignoreNulls(): Boolean = false + def buildUnionAggregateExpression( + arguments: Seq[Expression], + flag: Expression): AggregateExpression = { + val newMin = min.copy(child = arguments.head) + aggregateExpression.copy(aggregateFunction = newMin) + } +} + +case class MaxAnalyzer(aggregateExpression: AggregateExpression) extends AggregateFunctionAnalyzer { + val max = aggregateExpression.aggregateFunction.asInstanceOf[Max] + override def doValidate(): Boolean = { + aggregateExpression.filter.isEmpty + } + + override def getArgumentExpressions(): Option[Seq[Expression]] = { + Some(Seq(max.child)) + } + + override def ignoreNulls(): Boolean = false + def buildUnionAggregateExpression( + arguments: Seq[Expression], + flag: Expression): AggregateExpression = { + val newMax = max.copy(child = arguments.head) + aggregateExpression.copy(aggregateFunction = newMax) + } } object AggregateFunctionAnalyzer { @@ -98,6 +175,8 @@ object 
AggregateFunctionAnalyzer { case sum: Sum => SumAnalyzer(agg) case avg: Average => AverageAnalyzer(agg) case count: Count => CountAnalyzer(agg) + case max: Max => MaxAnalyzer(agg) + case min: Min => MinAnalyzer(agg) case _ => DefaultAggregateFunctionAnalyzer() } case _ => DefaultAggregateFunctionAnalyzer() @@ -237,7 +316,6 @@ object JoinedAggregateAnalyzer extends Logging { case class JoinAggregateToAggregateUnion(spark: SparkSession) extends Rule[LogicalPlan] with Logging { - val JOIN_FILTER_FLAG_NAME = "_left_flag_" def isResolvedPlan(plan: LogicalPlan): Boolean = { plan match { case insert: InsertIntoStatement => insert.query.resolved @@ -289,8 +367,6 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) if (join.joinType == LeftOuter && join.condition.isDefined) { val analyzedAggregates = ArrayBuffer[JoinedAggregateAnalyzer]() val remainedPlan = collectSameKeysJoinedAggregates(join, analyzedAggregates) - logDebug(s"xxx join left\n$remainedPlan") - logDebug(s"xxx analyzed aggregates number ${analyzedAggregates.length}") if (analyzedAggregates.length == 0) { join.copy(left = visitPlan(join.left), right = visitPlan(join.right)) } else if (analyzedAggregates.length == 1) { @@ -301,7 +377,7 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) val lastJoin = analyzedAggregates.head.join lastJoin.copy(left = visitPlan(lastJoin.left), right = unionedAggregates) } else { - buildPrimeKeysFilterOnAggregateUnion(unionedAggregates, analyzedAggregates) + buildPrimeKeysFilterOnAggregateUnion(unionedAggregates, analyzedAggregates.toSeq) } } } else { @@ -360,13 +436,18 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) for (i <- 0 until analyzedAggregates.length) { val partialAggregateExpressions = analyzedAggregates(i).getAggregateExpressions() val partialAggregateArguments = analyzedAggregates(i).getAggregateExpressionArguments() + val aggregateFunctionAnalyzers = analyzedAggregates(i).getAggregateFunctionAnalyzers() for (j <- 0 until 
partialAggregateExpressions.length) { val aggregateExpression = partialAggregateExpressions(j) val arguments = partialAggregateArguments(j) + val aggregateFunctionAnalyzer = aggregateFunctionAnalyzers(j) val newArguments = unionOutput.slice(fieldIndex, fieldIndex + arguments.length) - aggregateExpressions += buildAggregateExpressionWithNewChildren( - aggregateExpression, - newArguments) + val flagExpr = unionOutput(unionOutput.length - analyzedAggregates.length + i) + val newAggregateExpression = aggregateFunctionAnalyzer + .buildUnionAggregateExpression(newArguments, flagExpr) + aggregateExpressions += makeNamedExpression( + newAggregateExpression, + aggregateExpression.name) fieldIndex += arguments.length } } @@ -381,8 +462,8 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) } /** - * If the grouping keys is in the right table but not in the left table, remove the row from the - * result. + * Some rows may come from the right tables which grouping keys are not in the prime keys set. We + * should remove them. */ def buildPrimeKeysFilterOnAggregateUnion( plan: LogicalPlan, @@ -392,7 +473,10 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) Filter(notNullExpr, plan) } - /** Make the aggregate result be null if the grouping keys is not in the most left table. */ + /** + * When a row's grouping keys are not present in the related table, the related aggregate values + * are replaced with nulls. + */ def buildMakeNotMatchedRowsNullProject( plan: LogicalPlan, analyzedAggregates: Seq[JoinedAggregateAnalyzer], @@ -425,8 +509,7 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) } /** - * Build a project to make the output attributes have the same name and exprId as the original - * join output attributes. 
+ * A final step, ensure the output attributes have the same name and exprId as the original join */ def buildRenameProject( plan: LogicalPlan, @@ -484,7 +567,7 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) /** * Build a extended project list, which contains three parts. * - The grouping keys of all tables. - * - All required columns for aggregate functions in every table. + * - All required columns as arguments for the aggregate functions in every table. * - Flags for each table to indicate whether the row is in the table. */ def buildExtendProjectList( diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala index e878247a5b33..ab06d94e9194 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala @@ -214,4 +214,102 @@ class GlutenEliminateJoinSuite extends GlutenClickHouseWholeStageTransformerSuit assert(joins.length == 2) }) } + + test("aggregate literal") { + val sql = """ + select t1.k1, t1.k2, t2.k1, t2.k2, s1, s2 from ( + select k1, k2, sum(2) s1 from ( + select * from t1 where k1 != 1 + )group by k1, k2 + ) t1 left join ( + select k1, k2, count(1) s2 from ( + select * from t2 where k1 != 3 + )group by k1, k2 + ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 + order by t1.k1, t1.k2, s1, s2 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 0) + }) + } + + test("aggregate avg") { + val sql = """ + select t1.k1, t1.k2, t2.k1, t2.k2, s1, s2 from ( + select k1, k2, avg(2) s1 from ( + select * from t1 where k1 != 1 + )group by k1, k2 + ) t1 left join ( + select k1, k2, avg(1) s2 from ( + 
select * from t2 where k1 != 3 + )group by k1, k2 + ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 + order by t1.k1, t1.k2, s1, s2 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 0) + }) + } + test("aggregate min/max") { + val sql = """ + select t1.k1, t1.k2, t2.k1, t2.k2, s1, s2 from ( + select k1, k2, min(2) s1 from ( + select * from t1 where k1 != 1 + )group by k1, k2 + ) t1 left join ( + select k1, k2, max(1) s2 from ( + select * from t2 where k1 != 3 + )group by k1, k2 + ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 + order by t1.k1, t1.k2, s1, s2 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 0) + }) + } + test("aggregate count distinct") { + val sql = """ + select t1.k1, t2.k1, s1, s2, s3, s4 from ( + select k1, k2, count(distinct v1) s1, count(distinct v2) as s2 from ( + select * from t1 where k1 != 1 + )group by k1 + ) t1 left join ( + select k1, count(distinct v1) s3, count(distinct v3) as s4 from ( + select * from t2 where k1 != 3 + )group by k1 + ) t2 on t1.k1 = t2.k1 + order by t1.k1, s1, s2, s3, s4 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 0) + }) + } } From 58364491d1e69c26ce5bef8a61d84b50486b4c42 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Fri, 28 Mar 2025 15:10:39 +0800 Subject: [PATCH 07/10] fix bugs --- .../backendsapi/clickhouse/CHRuleApi.scala | 11 +- .../JoinAggregateToAggregateUnion.scala | 195 +++++++++++++++--- .../execution/GlutenEliminateJoinSuite.scala | 54 ++++- 3 files changed, 220 insertions(+), 40 deletions(-) diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index 46ee43b289fa..374e223ea9c9 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -60,6 +60,11 @@ object CHRuleApi { (spark, parserInterface) => new GlutenCacheFilesSqlParser(spark, parserInterface)) injector.injectParser( (spark, parserInterface) => new GlutenClickhouseSqlParser(spark, parserInterface)) + injector.injectResolutionRule(spark => new JoinAggregateToAggregateUnion(spark)) + // CoalesceAggregationUnion and CoalesceProjectionUnion should follow + // JoinAggregateToAggregateUnion + injector.injectResolutionRule(spark => new CoalesceAggregationUnion(spark)) + injector.injectResolutionRule(spark => new CoalesceProjectionUnion(spark)) injector.injectResolutionRule(spark => new RewriteToDateExpresstionRule(spark)) injector.injectResolutionRule(spark => new RewriteDateTimestampComparisonRule(spark)) injector.injectResolutionRule(spark => new CollapseGetJsonObjectExpressionRule(spark)) @@ -69,12 +74,6 @@ object CHRuleApi { injector.injectOptimizerRule(spark => new SimplifySumRule(spark)) injector.injectOptimizerRule(spark => new ExtendedColumnPruning(spark)) injector.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark)) - // JoinAggregateToAggregateUnion need to be applied on optimized plans - injector.injectOptimizerRule(spark => new JoinAggregateToAggregateUnion(spark)) - // CoalesceAggregationUnion and CoalesceProjectionUnion should follows - // JoinAggregateToAggregateUnion - injector.injectOptimizerRule(spark => new CoalesceAggregationUnion(spark)) - injector.injectOptimizerRule(spark => new CoalesceProjectionUnion(spark)) injector.injectOptimizerRule(_ => CountDistinctWithoutExpand)
injector.injectOptimizerRule(_ => EqualToRewrite) injector.injectPreCBORule(spark => new CHOptimizeMetadataOnlyDeltaQuery(spark)) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala index e2831da79165..f25bd7eb6b4c 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala @@ -30,6 +30,9 @@ import org.apache.spark.sql.types._ import scala.collection.mutable.ArrayBuffer +private case class JoinKeys( + leftKeys: Seq[AttributeReference], + rightKeys: Seq[AttributeReference]) {} private object RuleExpressionHelper { def extractLiteral(e: Expression): Option[Literal] = { @@ -39,6 +42,72 @@ private object RuleExpressionHelper { case _ => None } } + + def makeFlagLiteral(): Literal = { + Literal.create(true, BooleanType) + } + + def makeNullLiteral(dataType: DataType): Literal = { + Literal.create(null, dataType) + } + + def makeNamedExpression(e: Expression, name: String): NamedExpression = { + Alias(e, name)() + } + + def makeFirstAggregateExpression(e: Expression): AggregateExpression = { + val aggregateFunction = First(e, true) + AggregateExpression(aggregateFunction, Complete, false) + } + + def extractJoinKeys( + joinCondition: Option[Expression], + leftAttributes: AttributeSet, + rightAttributes: AttributeSet): Option[JoinKeys] = { + val leftKeys = ArrayBuffer[AttributeReference]() + val rightKeys = ArrayBuffer[AttributeReference]() + def visitJoinExpression(e: Expression): Unit = { + e match { + case and: And => + visitJoinExpression(and.left) + visitJoinExpression(and.right) + case equalTo @ EqualTo(left: AttributeReference, right: AttributeReference) => + if (leftAttributes.contains(left)) { + leftKeys += left + } + if 
(leftAttributes.contains(right)) { + leftKeys += right + } + if (rightAttributes.contains(left)) { + rightKeys += left + } + if (rightAttributes.contains(right)) { + rightKeys += right + } + case _ => + throw new GlutenException(s"Unsupported join condition $e") + } + } + + joinCondition match { + case Some(condition) => + try { + visitJoinExpression(condition) + if (leftKeys.length != rightKeys.length) { + return None + } + if (!leftKeys.forall(k => k.qualifier.equals(leftKeys.head.qualifier))) { + return None + } + Some(JoinKeys(leftKeys.toSeq, rightKeys.toSeq)) + } catch { + case e: GlutenException => + return None + } + case None => + return None + } + } } trait AggregateFunctionAnalyzer { @@ -190,10 +259,21 @@ case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo return false } + if (!aggregate.groupingExpressions.forall(_.isInstanceOf[AttributeReference])) { + return false + } + if (!extractJoinKeys()) { return false } + if ( + keys.length != aggregate.groupingExpressions.length || + !keys.forall(k => aggregate.groupingExpressions.exists(_.semanticEquals(k))) + ) { + return false + } + aggregateExpressions = aggregate.aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isDefined) aggregateFunctionAnalyzer = aggregateExpressions.map(AggregateFunctionAnalyzer(_)) @@ -254,6 +334,7 @@ case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo visitJoinExpression(and.left) visitJoinExpression(and.right) case equalTo @ EqualTo(left: AttributeReference, right: AttributeReference) => + logDebug(s"xxx qualifier. left: ${left.qualifier}, right: ${right.qualifier}") if (leftOutputSet.contains(left)) { leftJoinKeys += left } @@ -313,6 +394,34 @@ object JoinedAggregateAnalyzer extends Logging { } } +/** + * ReorderJoinSubqueries is step before JoinAggregateToAggregateUnion. It will reorder the join + * subqueries to move the queries with the same prime keys, join keys and grouping keys together. 
+ * For example + * ``` + * select t1.k1, t1.k2, s1, s2 from ( + * select k1, k2, count(v1) s1 from t1 group by k1, k2 + * ) t1 left join ( + * select * from t2 + * ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 + * left join ( + * select k1, k2, count(v2) s2 from t3 group by k1, k2 + * ) t3 on t1.k1 = t3.k1 and t1.k2 = t3.k2 + * ``` + * is rewritten into + * ``` + * select t1.k1, t1.k2, s1, s2 from ( + * select k1, k2, count(v1) s1 from t1 group by k1, k2 + * ) t1 left join ( + * select k1, k2, count(v2) s2 from t3 group by k1, k2 + * ) t3 on t1.k1 = t3.k1 and t1.k2 = t3.k2 + * ) left join ( + * select * from t2 + * ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 + * ``` + */ +case class ReorderJoinSubqueries() extends Logging {} + case class JoinAggregateToAggregateUnion(spark: SparkSession) extends Rule[LogicalPlan] with Logging { @@ -396,7 +505,7 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) .withNewChildren(inputs) .asInstanceOf[AggregateFunction] val newAggregateExpression = aggregateExpression.copy(aggregateFunction = newAggregateFunction) - makeNamedExpression(newAggregateExpression, ne.name) + RuleExpressionHelper.makeNamedExpression(newAggregateExpression, ne.name) } def unionAllJoinedAggregates(analyzedAggregates: Seq[JoinedAggregateAnalyzer]): LogicalPlan = { @@ -417,7 +526,7 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) val groupingKeys = unionOutput.slice(0, keysNumber).zip(analyzedAggregates.head.getKeys()).map { case (e, a) => - makeNamedExpression(e, a.name) + RuleExpressionHelper.makeNamedExpression(e, a.name) } aggregateExpressions ++= groupingKeys @@ -427,8 +536,8 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) for (j <- 0 until keys.length) { val key = keys(j) val valueExpr = unionOutput(fieldIndex + j) - val firstValue = makeFirstAggregateExpression(valueExpr) - aggregateExpressions += makeNamedExpression(firstValue, key.name) + val firstValue = 
RuleExpressionHelper.makeFirstAggregateExpression(valueExpr) + aggregateExpressions += RuleExpressionHelper.makeNamedExpression(firstValue, key.name) } } @@ -445,7 +554,7 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) val flagExpr = unionOutput(unionOutput.length - analyzedAggregates.length + i) val newAggregateExpression = aggregateFunctionAnalyzer .buildUnionAggregateExpression(newArguments, flagExpr) - aggregateExpressions += makeNamedExpression( + aggregateExpressions += RuleExpressionHelper.makeNamedExpression( newAggregateExpression, aggregateExpression.name) fieldIndex += arguments.length @@ -454,8 +563,8 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) for (i <- fieldIndex until unionOutput.length) { val valueExpr = unionOutput(i) - val firstValue = makeFirstAggregateExpression(valueExpr) - aggregateExpressions += makeNamedExpression(firstValue, valueExpr.name) + val firstValue = RuleExpressionHelper.makeFirstAggregateExpression(valueExpr) + aggregateExpressions += RuleExpressionHelper.makeNamedExpression(firstValue, valueExpr.name) } Aggregate(groupingKeys, aggregateExpressions.toSeq, union) @@ -498,8 +607,11 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) if (ignoreAggregates(i) || aggregateFunctionAnalyzers(i).ignoreNulls()) { valueExpr.asInstanceOf[NamedExpression] } else { - val clearExpr = If(IsNull(flagExpr), makeNullLiteral(valueExpr.dataType), valueExpr) - makeNamedExpression(clearExpr, valueExpr.name) + val clearExpr = If( + IsNull(flagExpr), + RuleExpressionHelper.makeNullLiteral(valueExpr.dataType), + valueExpr) + RuleExpressionHelper.makeNamedExpression(clearExpr, valueExpr.name) } } } @@ -547,23 +659,6 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) Project(projectList.toSeq, plan) } - def makeNamedExpression(e: Expression, name: String): NamedExpression = { - Alias(e, name)() - } - - def makeNullLiteral(dataType: DataType): Literal = { - Literal.create(null, dataType) - } - - 
def makeFlagLiteral(): Literal = { - Literal.create(true, BooleanType) - } - - def makeFirstAggregateExpression(e: Expression): AggregateExpression = { - val aggregateFunction = First(e, true) - AggregateExpression(aggregateFunction, Complete, false) - } - /** * Build a extended project list, which contains three parts. * - The grouping keys of all tables. @@ -576,16 +671,18 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) val projectList = ArrayBuffer[NamedExpression]() projectList ++= analyzedAggregates(index).getKeys().zipWithIndex.map { case (e, i) => - makeNamedExpression(e, s"key_0_$i") + RuleExpressionHelper.makeNamedExpression(e, s"key_0_$i") } for (i <- 1 until analyzedAggregates.length) { val groupingKeys = analyzedAggregates(i).getKeys() projectList ++= groupingKeys.zipWithIndex.map { case (e, j) => if (i == index) { - makeNamedExpression(e, s"key_${i}_$j") + RuleExpressionHelper.makeNamedExpression(e, s"key_${i}_$j") } else { - makeNamedExpression(makeNullLiteral(e.dataType), s"key_${i}_$j") + RuleExpressionHelper.makeNamedExpression( + RuleExpressionHelper.makeNullLiteral(e.dataType), + s"key_${i}_$j") } } } @@ -597,9 +694,11 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) projectList ++= args.zipWithIndex.map { case (arg, k) => if (i == index) { - makeNamedExpression(arg, s"arg_${i}_${j}_$k") + RuleExpressionHelper.makeNamedExpression(arg, s"arg_${i}_${j}_$k") } else { - makeNamedExpression(makeNullLiteral(arg.dataType), s"arg_${i}_${j}_$k") + RuleExpressionHelper.makeNamedExpression( + RuleExpressionHelper.makeNullLiteral(arg.dataType), + s"arg_${i}_${j}_$k") } } } @@ -607,9 +706,13 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) for (i <- 0 until analyzedAggregates.length) { if (i == index) { - projectList += makeNamedExpression(makeFlagLiteral(), s"flag_$i") + projectList += RuleExpressionHelper.makeNamedExpression( + RuleExpressionHelper.makeFlagLiteral(), + s"flag_$i") } else { - projectList += 
makeNamedExpression(makeNullLiteral(BooleanType), s"flag_$i") + projectList += RuleExpressionHelper.makeNamedExpression( + RuleExpressionHelper.makeNullLiteral(BooleanType), + s"flag_$i") } } @@ -638,9 +741,24 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) } } + def transformDistinctToAggregate(distinct: Distinct): Aggregate = { + Aggregate(distinct.child.output, distinct.child.output, distinct.child) + } + def extractDirectAggregate(plan: LogicalPlan): Option[Aggregate] = { plan match { case _ @SubqueryAlias(_, aggregate: Aggregate) => Some(aggregate) + case project @ Project(projectList, aggregate: Aggregate) => + if (projectList.forall(_.isInstanceOf[AttributeReference])) { + // Just a column prune projection, ignore it + Some(aggregate) + } else { + None + } + case _ @SubqueryAlias(_, distinct: Distinct) => + Some(transformDistinctToAggregate(distinct)) + case distinct: Distinct => + Some(transformDistinctToAggregate(distinct)) case aggregate: Aggregate => Some(aggregate) case _ => None } @@ -706,4 +824,15 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) } } + def collectSameKeysJoinSubqueries( + plan: LogicalPlan, + groupedAggregates: ArrayBuffer[ArrayBuffer[JoinedAggregateAnalyzer]], + nonAggregates: ArrayBuffer[LogicalPlan]): Option[LogicalPlan] = { + plan match { + case join: Join if join.joinType == LeftOuter && join.condition.isDefined => + case _ => Some(plan) + } + None + } + } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala index ab06d94e9194..28007a64baf4 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala @@ -288,10 +288,11 @@ class GlutenEliminateJoinSuite extends GlutenClickHouseWholeStageTransformerSuit 
assert(joins.length == 0) }) } + test("aggregate count distinct") { val sql = """ select t1.k1, t2.k1, s1, s2, s3, s4 from ( - select k1, k2, count(distinct v1) s1, count(distinct v2) as s2 from ( + select k1, count(distinct v1) s1, count(distinct v2) as s2 from ( select * from t1 where k1 != 1 )group by k1 ) t1 left join ( @@ -312,4 +313,55 @@ class GlutenEliminateJoinSuite extends GlutenClickHouseWholeStageTransformerSuit assert(joins.length == 0) }) } + + test("distinct") { + val sql = """ + select t1.k1, t2.k1, s1 from ( + select k1, count(v1) as s1 from ( + select * from t1 where k1 != 1 + )group by k1 + ) t1 left join ( + select distinct k1 from ( + select * from t2 where k1 != 3 + ) t2 + ) t2 on t1.k1 = t2.k1 + order by t1.k1, s1 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 0) + }) + } + + test("different join keys and grouping keys 1") { + val sql = """ + select t1.k1, t2.k1, s1 from ( + select k1, count(v1) as s1 from ( + select * from t1 where k1 != 1 + )group by k1 + ) t1 left join ( + select distinct k1, k2 from ( + select * from t2 where k1 != 3 + ) t2 + ) t2 on t1.k1 = t2.k1 + order by t1.k1, s1 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 1) + }) + + } } From a020029de0e68334ef37b0bcb430b5846b413ae0 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Mon, 31 Mar 2025 15:18:49 +0800 Subject: [PATCH 08/10] reorder joins --- .../JoinAggregateToAggregateUnion.scala | 359 ++++++++++++------ .../execution/GlutenEliminateJoinSuite.scala | 115 +++++- 2 files changed, 362 insertions(+), 112 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala 
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala index f25bd7eb6b4c..d6675327d8bd 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala @@ -28,11 +28,39 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ +// import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ListBuffer private case class JoinKeys( leftKeys: Seq[AttributeReference], rightKeys: Seq[AttributeReference]) {} + +private object RulePlanHeler { + def transformDistinctToAggregate(distinct: Distinct): Aggregate = { + Aggregate(distinct.child.output, distinct.child.output, distinct.child) + } + + def extractDirectAggregate(plan: LogicalPlan): Option[Aggregate] = { + plan match { + case _ @SubqueryAlias(_, aggregate: Aggregate) => Some(aggregate) + case project @ Project(projectList, aggregate: Aggregate) => + if (projectList.forall(_.isInstanceOf[AttributeReference])) { + // Just a column prune projection, ignore it + Some(aggregate) + } else { + None + } + case _ @SubqueryAlias(_, distinct: Distinct) => + Some(transformDistinctToAggregate(distinct)) + case distinct: Distinct => + Some(transformDistinctToAggregate(distinct)) + case aggregate: Aggregate => Some(aggregate) + case _ => None + } + } +} + private object RuleExpressionHelper { def extractLiteral(e: Expression): Option[Literal] = { @@ -52,7 +80,11 @@ private object RuleExpressionHelper { } def makeNamedExpression(e: Expression, name: String): NamedExpression = { - Alias(e, name)() + e match { + case alias: Alias => + Alias(alias.child, name)() + case _ => Alias(e, name)() + } } def makeFirstAggregateExpression(e: Expression): AggregateExpression = { @@ -99,6 +131,14 @@ private object 
RuleExpressionHelper { if (!leftKeys.forall(k => k.qualifier.equals(leftKeys.head.qualifier))) { return None } + // They must be the same sets or total different sets + if ( + leftKeys.length != rightKeys.length || + !(leftKeys.forall(key => rightKeys.exists(_.equals(key))) || + leftKeys.forall(key => !rightKeys.exists(_.equals(key)))) + ) { + return None + } Some(JoinKeys(leftKeys.toSeq, rightKeys.toSeq)) } catch { case e: GlutenException => @@ -209,6 +249,26 @@ case class MinAnalyzer(aggregateExpression: AggregateExpression) extends Aggrega } } +case class FirstAnalyzer(aggregateExpression: AggregateExpression) + extends AggregateFunctionAnalyzer { + val first = aggregateExpression.aggregateFunction.asInstanceOf[First] + override def doValidate(): Boolean = { + aggregateExpression.filter.isEmpty && first.ignoreNulls + } + + override def getArgumentExpressions(): Option[Seq[Expression]] = { + Some(Seq(first.child)) + } + + override def ignoreNulls(): Boolean = true + def buildUnionAggregateExpression( + arguments: Seq[Expression], + flag: Expression): AggregateExpression = { + val newFirst = first.copy(child = arguments.head) + aggregateExpression.copy(aggregateFunction = newFirst) + } +} + case class MaxAnalyzer(aggregateExpression: AggregateExpression) extends AggregateFunctionAnalyzer { val max = aggregateExpression.aggregateFunction.asInstanceOf[Max] override def doValidate(): Boolean = { @@ -246,6 +306,7 @@ object AggregateFunctionAnalyzer { case count: Count => CountAnalyzer(agg) case max: Max => MaxAnalyzer(agg) case min: Min => MinAnalyzer(agg) + case first: First => FirstAnalyzer(agg) case _ => DefaultAggregateFunctionAnalyzer() } case _ => DefaultAggregateFunctionAnalyzer() @@ -259,7 +320,7 @@ case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo return false } - if (!aggregate.groupingExpressions.forall(_.isInstanceOf[AttributeReference])) { + if (!extractGroupingKeys()) { return false } @@ -269,7 +330,7 @@ case class 
JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo if ( keys.length != aggregate.groupingExpressions.length || - !keys.forall(k => aggregate.groupingExpressions.exists(_.semanticEquals(k))) + !keys.forall(k => outputGroupingKeys.exists(_.semanticEquals(k))) ) { return false } @@ -277,9 +338,14 @@ case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo aggregateExpressions = aggregate.aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isDefined) aggregateFunctionAnalyzer = aggregateExpressions.map(AggregateFunctionAnalyzer(_)) + + // If there is any const value in the aggregate expressions, return false + if (aggregateExpressions.length + keys.length != aggregate.aggregateExpressions.length) { + return false + } if ( aggregateExpressions.zipWithIndex.exists { - case (_, i) => !aggregateFunctionAnalyzer(i).doValidate + case (e, i) => !aggregateFunctionAnalyzer(i).doValidate } ) { return false @@ -312,6 +378,8 @@ case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo def getPrimeKeys(): Seq[AttributeReference] = primeKeys def getKeys(): Seq[AttributeReference] = keys def getAggregate(): Aggregate = aggregate + def getGroupingKeys(): Seq[Attribute] = outputGroupingKeys + def getGroupingExpressions(): Seq[NamedExpression] = groupingExpressions def getAggregateExpressions(): Seq[NamedExpression] = aggregateExpressions def getAggregateExpressionArguments(): Seq[Seq[Expression]] = aggregateExpressionArguments def getAggregateFunctionAnalyzers(): Seq[AggregateFunctionAnalyzer] = aggregateFunctionAnalyzer @@ -319,6 +387,8 @@ case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo private var primeKeys: Seq[AttributeReference] = Seq.empty private var keys: Seq[AttributeReference] = Seq.empty private var aggregate: Aggregate = null + private var groupingExpressions: Seq[NamedExpression] = null + private var outputGroupingKeys: Seq[Attribute] = null private var 
aggregateExpressions: Seq[NamedExpression] = Seq.empty private var aggregateExpressionArguments: Seq[Seq[Expression]] = Seq.empty private var aggregateFunctionAnalyzer = Seq.empty[AggregateFunctionAnalyzer] @@ -328,51 +398,57 @@ case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo val subqueryKeys = ArrayBuffer[AttributeReference]() val leftOutputSet = join.left.outputSet val subqueryOutputSet = subquery.outputSet - def visitJoinExpression(e: Expression): Unit = { - e match { - case and: And => - visitJoinExpression(and.left) - visitJoinExpression(and.right) - case equalTo @ EqualTo(left: AttributeReference, right: AttributeReference) => - logDebug(s"xxx qualifier. left: ${left.qualifier}, right: ${right.qualifier}") - if (leftOutputSet.contains(left)) { - leftJoinKeys += left - } - if (subqueryOutputSet.contains(left)) { - subqueryKeys += left - } - if (leftOutputSet.contains(right)) { - leftJoinKeys += right - } - if (subqueryOutputSet.contains(right)) { - subqueryKeys += right - } - case _ => - throw new GlutenException(s"Unsupported join condition $e") - } + val joinKeys = + RuleExpressionHelper.extractJoinKeys(join.condition, leftOutputSet, subqueryOutputSet) + if (joinKeys.isEmpty) { + false + } else { + primeKeys = joinKeys.get.leftKeys + keys = joinKeys.get.rightKeys + true } + } - // They must be the same sets or total different sets - if ( - leftJoinKeys.length != subqueryKeys.length || - !(leftJoinKeys.forall(key => subqueryKeys.exists(_.equals(key))) || - leftJoinKeys.forall(key => !subqueryKeys.exists(_.equals(key)))) - ) { - return false + def extractGroupingKeys(): Boolean = { + val outputGroupingKeysBuffer = ArrayBuffer[Attribute]() + val groupingExpressionsBuffer = ArrayBuffer[NamedExpression]() + val indexedAggregateExpressions = aggregate.aggregateExpressions.zipWithIndex + val aggregateOuput = aggregate.output + + def findMatchedAggregateExpression(e: Expression): Option[Tuple2[NamedExpression, Int]] = { + 
indexedAggregateExpressions.find { + case (aggExpr, i) => + aggExpr match { + case alias: Alias => + alias.child.semanticEquals(e) || alias.semanticEquals(e) + case _ => aggExpr.semanticEquals(e) + } + } } - join.condition match { - case Some(condition) => - try { - visitJoinExpression(condition) - } catch { - case e: GlutenException => - return false + aggregate.groupingExpressions.map { + e => + e match { + case nameExpression: NamedExpression => + groupingExpressionsBuffer += nameExpression + findMatchedAggregateExpression(nameExpression) match { + case Some((aggExpr, i)) => + outputGroupingKeysBuffer += aggregateOuput(i) + case None => + return false + } + case other => + findMatchedAggregateExpression(other) match { + case Some((aggExpr, i)) => + outputGroupingKeysBuffer += aggregateOuput(i) + groupingExpressionsBuffer += aggregate.aggregateExpressions(i) + case None => + return false + } } - keys = subqueryKeys.toSeq - primeKeys = leftJoinKeys.toSeq - true - case _ => false } + outputGroupingKeys = outputGroupingKeysBuffer.toSeq + groupingExpressions = groupingExpressionsBuffer.toSeq + true } } @@ -420,7 +496,117 @@ object JoinedAggregateAnalyzer extends Logging { * ) t2 on t1.k1 = t2.k1 and t1.k2 = t2.k2 * ``` */ -case class ReorderJoinSubqueries() extends Logging {} +case class ReorderJoinSubqueries() extends Logging { + case class SameJoinKeysPlans( + primeKeys: AttributeSet, + plans: ListBuffer[Tuple2[LogicalPlan, Join]]) {} + def apply(plan: LogicalPlan): LogicalPlan = { + visitPlan(plan) + } + + def visitPlan(plan: LogicalPlan): LogicalPlan = { + plan match { + case join: Join if join.joinType == LeftOuter && join.condition.isDefined => + val sameJoinKeysPlansList = ListBuffer[SameJoinKeysPlans]() + val plan = visitJoin(join, sameJoinKeysPlansList) + finishReorderJoinPlans(plan, sameJoinKeysPlansList) + case _ => + plan.withNewChildren(plan.children.map(visitPlan)) + } + } + + def findSameJoinKeysPlansIndex()( + sameJoinKeysPlans: 
ListBuffer[SameJoinKeysPlans], + primeKeys: Seq[AttributeReference]): Int = { + val keysSet = AttributeSet(primeKeys) + for (i <- 0 until sameJoinKeysPlans.length) { + if (sameJoinKeysPlans(i).primeKeys.equals(keysSet)) { + return i + } + } + -1 + } + + def finishReorderJoinPlans( + left: LogicalPlan, + sameJoinKeysPlansList: ListBuffer[SameJoinKeysPlans]): LogicalPlan = { + var finalPlan = left + for (j <- sameJoinKeysPlansList.length - 1 to 0 by -1) { + for (i <- sameJoinKeysPlansList(j).plans.length - 1 to 0 by -1) { + val plan = sameJoinKeysPlansList(j).plans(i)._1 + val originalJoin = sameJoinKeysPlansList(j).plans(i)._2 + finalPlan = originalJoin.copy(left = finalPlan, right = plan) + } + } + finalPlan + } + + def visitJoin( + plan: LogicalPlan, + sameJoinKeysPlansList: ListBuffer[SameJoinKeysPlans]): LogicalPlan = { + plan match { + case join: Join if join.joinType == LeftOuter && join.condition.isDefined => + val joinKeys = RuleExpressionHelper.extractJoinKeys( + join.condition, + join.left.outputSet, + join.right.outputSet) + joinKeys match { + case Some(keys) => + val index = findSameJoinKeysPlansIndex()(sameJoinKeysPlansList, keys.leftKeys) + val newRight = visitPlan(join.right) + if (index == -1) { + sameJoinKeysPlansList += SameJoinKeysPlans( + AttributeSet(keys.leftKeys), + ListBuffer(Tuple2(newRight, join))) + } else { + if (index != sameJoinKeysPlansList.length - 1) {} + val sameJoinKeysPlans = sameJoinKeysPlansList.remove(index) + if ( + RulePlanHeler.extractDirectAggregate(newRight).isDefined || + RulePlanHeler.extractDirectAggregate(sameJoinKeysPlans.plans.last._1).isEmpty + ) { + sameJoinKeysPlans.plans += Tuple2(newRight, join) + } else { + sameJoinKeysPlans.plans.insert(0, Tuple2(newRight, join)) + } + sameJoinKeysPlansList += sameJoinKeysPlans + } + visitJoin(join.left, sameJoinKeysPlansList) + case None => + val joinLeft = visitPlan(join.left) + val joinRight = visitPlan(join.right) + join.copy(left = joinLeft, right = joinRight) + } + case 
subquery: SubqueryAlias if RulePlanHeler.extractDirectAggregate(subquery).isDefined => + val newAggregate = visitPlan(subquery.child) + val groupingKeys = RulePlanHeler.extractDirectAggregate(subquery).get.groupingExpressions + if (groupingKeys.forall(_.isInstanceOf[AttributeReference])) { + val keys = groupingKeys.map(_.asInstanceOf[AttributeReference]) + val index = findSameJoinKeysPlansIndex()(sameJoinKeysPlansList, keys) + if (index != -1 && index != sameJoinKeysPlansList.length - 1) { + val sameJoinKeysPlans = sameJoinKeysPlansList.remove(index) + sameJoinKeysPlansList += sameJoinKeysPlans + } + } + subquery.copy(child = newAggregate) + case aggregate: Aggregate => + val newAggregate = aggregate.withNewChildren(aggregate.children.map(visitPlan)) + val groupingKeys = aggregate.groupingExpressions + if (groupingKeys.forall(_.isInstanceOf[AttributeReference])) { + val keys = groupingKeys.map(_.asInstanceOf[AttributeReference]) + val index = findSameJoinKeysPlansIndex()(sameJoinKeysPlansList, keys) + if (index != -1 && index != sameJoinKeysPlansList.length - 1) { + val sameJoinKeysPlans = sameJoinKeysPlansList.remove(index) + sameJoinKeysPlansList += sameJoinKeysPlans + } + } + newAggregate + case _ => + plan.withNewChildren(plan.children.map(visitPlan)) + } + } + +} case class JoinAggregateToAggregateUnion(spark: SparkSession) extends Rule[LogicalPlan] @@ -438,8 +624,9 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) .get(CHBackendSettings.GLUTEN_JOIN_AGGREGATE_TO_AGGREGATE_UNION, "true") .toBoolean && isResolvedPlan(plan) ) { - val newPlan = visitPlan(plan) - logDebug(s"old plan\n$plan\nnew plan\n$newPlan") + val reorderedPlan = ReorderJoinSubqueries().apply(plan) + val newPlan = visitPlan(reorderedPlan) + logDebug(s"Rewrite plan from \n$plan to \n$newPlan") newPlan } else { plan @@ -512,6 +699,7 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) val extendProjects = buildExtendProjects(analyzedAggregates) val union = 
buildUnionOnExtendedProjects(extendProjects) val aggregateUnion = buildAggregateOnUnion(union, analyzedAggregates) + logDebug(s"xxx aggregateUnion $aggregateUnion") val setNullsProject = buildMakeNotMatchedRowsNullProject(aggregateUnion, analyzedAggregates, Set()) buildRenameProject(setNullsProject, analyzedAggregates) @@ -521,17 +709,18 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) union: LogicalPlan, analyzedAggregates: Seq[JoinedAggregateAnalyzer]): LogicalPlan = { val unionOutput = union.output - val keysNumber = analyzedAggregates.head.getKeys().length + val keysNumber = analyzedAggregates.head.getGroupingKeys().length val aggregateExpressions = ArrayBuffer[NamedExpression]() - val groupingKeys = unionOutput.slice(0, keysNumber).zip(analyzedAggregates.head.getKeys()).map { - case (e, a) => - RuleExpressionHelper.makeNamedExpression(e, a.name) - } + val groupingKeys = + unionOutput.slice(0, keysNumber).zip(analyzedAggregates.head.getGroupingKeys()).map { + case (e, a) => + RuleExpressionHelper.makeNamedExpression(e, a.name) + } aggregateExpressions ++= groupingKeys for (i <- 1 until analyzedAggregates.length) { - val keys = analyzedAggregates(i).getKeys() + val keys = analyzedAggregates(i).getGroupingKeys() val fieldIndex = i * keysNumber for (j <- 0 until keys.length) { val key = keys(j) @@ -593,7 +782,8 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) val input = plan.output val flagExpressions = input.slice(plan.output.length - analyzedAggregates.length, plan.output.length) - val aggregateExprsStart = analyzedAggregates.length * analyzedAggregates.head.getKeys().length + val aggregateExprsStart = + analyzedAggregates.length * analyzedAggregates.head.getGroupingKeys.length var fieldIndex = aggregateExprsStart val aggregatesIfNullExpressions = analyzedAggregates.zipWithIndex.map { case (analyzedAggregate, i) => @@ -627,12 +817,11 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) plan: LogicalPlan, 
analyzedAggregates: Seq[JoinedAggregateAnalyzer]): LogicalPlan = { val input = plan.output - val joinKeys = analyzedAggregates.map(_.getKeys()) val projectList = ArrayBuffer[NamedExpression]() - val keysNum = analyzedAggregates.head.getKeys().length + val keysNum = analyzedAggregates.head.getGroupingKeys.length var fieldIndex = 0 for (i <- 0 until analyzedAggregates.length) { - val keys = analyzedAggregates(i).getKeys() + val keys = analyzedAggregates(i).getGroupingKeys() for (j <- 0 until keys.length) { val key = keys(j) projectList += Alias(input(fieldIndex), key.name)( @@ -669,12 +858,12 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) analyzedAggregates: Seq[JoinedAggregateAnalyzer], index: Int): Seq[NamedExpression] = { val projectList = ArrayBuffer[NamedExpression]() - projectList ++= analyzedAggregates(index).getKeys().zipWithIndex.map { + projectList ++= analyzedAggregates(index).getGroupingExpressions().zipWithIndex.map { case (e, i) => RuleExpressionHelper.makeNamedExpression(e, s"key_0_$i") } for (i <- 1 until analyzedAggregates.length) { - val groupingKeys = analyzedAggregates(i).getKeys() + val groupingKeys = analyzedAggregates(i).getGroupingExpressions() projectList ++= groupingKeys.zipWithIndex.map { case (e, j) => if (i == index) { @@ -730,60 +919,16 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) def buildUnionOnExtendedProjects(plans: Seq[LogicalPlan]): LogicalPlan = { val union = Union(plans) + logDebug(s"xxx build union: $union") union } - def isDirectAggregate(plan: LogicalPlan): Boolean = { - plan match { - case _ @SubqueryAlias(_, aggregate: Aggregate) => true - case _: Aggregate => true - case _ => false - } - } - - def transformDistinctToAggregate(distinct: Distinct): Aggregate = { - Aggregate(distinct.child.output, distinct.child.output, distinct.child) - } - - def extractDirectAggregate(plan: LogicalPlan): Option[Aggregate] = { - plan match { - case _ @SubqueryAlias(_, aggregate: Aggregate) => 
Some(aggregate) - case project @ Project(projectList, aggregate: Aggregate) => - if (projectList.forall(_.isInstanceOf[AttributeReference])) { - // Just a column prune projection, ignore it - Some(aggregate) - } else { - None - } - case _ @SubqueryAlias(_, distinct: Distinct) => - Some(transformDistinctToAggregate(distinct)) - case distinct: Distinct => - Some(transformDistinctToAggregate(distinct)) - case aggregate: Aggregate => Some(aggregate) - case _ => None - } - } - - def isDirectJoin(plan: LogicalPlan): Boolean = { - plan match { - case _: Join => true - case _ => false - } - } - - def extractDirectJoin(plan: LogicalPlan): Option[Join] = { - plan match { - case join: Join => Some(join) - case _ => None - } - } - def collectSameKeysJoinedAggregates( plan: LogicalPlan, analyzedAggregates: ArrayBuffer[JoinedAggregateAnalyzer]): Option[LogicalPlan] = { plan match { case join: Join if join.joinType == LeftOuter && join.condition.isDefined => - val optionAggregate = extractDirectAggregate(join.right) + val optionAggregate = RulePlanHeler.extractDirectAggregate(join.right) if (optionAggregate.isEmpty) { return Some(plan) } @@ -803,8 +948,8 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) } else { Some(plan) } - case _ if isDirectAggregate(plan) => - val aggregate = extractDirectAggregate(plan).get + case _ if RulePlanHeler.extractDirectAggregate(plan).isDefined => + val aggregate = RulePlanHeler.extractDirectAggregate(plan).get val lastJoin = analyzedAggregates.head.join assert(lastJoin.left.equals(plan), "The node should be last join's left child") val leftAggregateAnalyzer = JoinedAggregateAnalyzer.build(lastJoin, aggregate) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala index 28007a64baf4..080fc9b5c9a2 100644 --- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenEliminateJoinSuite.scala @@ -17,6 +17,7 @@ package org.apache.gluten.execution import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, NullPropagation} import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig @@ -25,7 +26,7 @@ import org.apache.spark.sql.types._ import java.nio.file.Files -class GlutenEliminateJoinSuite extends GlutenClickHouseWholeStageTransformerSuite { +class GlutenEliminateJoinSuite extends GlutenClickHouseWholeStageTransformerSuite with Logging { protected val tablesPath: String = basePath + "/tpch-data" protected val tpchQueries: String = @@ -187,9 +188,9 @@ class GlutenEliminateJoinSuite extends GlutenClickHouseWholeStageTransformerSuit }) } - test("Left two joins uneliminable") { + test("reorder join orders 1") { val sql = """ - select t1.k1, t1.k2, s1, s2 from ( + select t1.k1, t1.k2, t2.k1, s1, s2 from ( select k1, k2, count(v1) s1 from ( select * from t2 where k1 != 1 )group by k1, k2 @@ -201,7 +202,7 @@ class GlutenEliminateJoinSuite extends GlutenClickHouseWholeStageTransformerSuit select * from t3 where k1 != 3 )group by k1, k2 ) t3 on t1.k1 = t3.k1 and t1.k2 = t3.k2 - order by t1.k1, t1.k2, s1, s2 + order by t1.k1, t1.k2, t2.k1, s1, s2 """.stripMargin compareResultsAgainstVanillaSpark( sql, @@ -211,10 +212,40 @@ class GlutenEliminateJoinSuite extends GlutenClickHouseWholeStageTransformerSuit val joins = df.queryExecution.executedPlan.collect { case join: ShuffledHashJoinExecTransformerBase => join } - assert(joins.length == 2) + assert(joins.length == 1) }) } + test("reorder join orders 2") { + val sql = """ + select t1.k1, t2.k1, s1, s2 from ( + select k1, k2, count(v1) s1 from ( + select * from t2 where k1 != 1 + )group by 
k1, k2 + ) t1 left join ( + select k1, count(v2) as s3 from t1 where k1 != 1 + group by k1 + ) t2 on t1.k1 = t2.k1 + left join ( + select k1, k2, count(v2) s2 from ( + select * from t3 where k1 != 3 + )group by k1, k2 + ) t3 on t1.k1 = t3.k1 and t1.k2 = t3.k2 + order by t1.k1, t2.k1, s1, s2 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 1) + } + ) + } + test("aggregate literal") { val sql = """ select t1.k1, t1.k2, t2.k1, t2.k2, s1, s2 from ( @@ -362,6 +393,80 @@ class GlutenEliminateJoinSuite extends GlutenClickHouseWholeStageTransformerSuit } assert(joins.length == 1) }) + } + test("not attribute grouping keys 1") { + val sql = """ + select t1.k1, t2.k2, s1 from ( + select k1 + 1 as k1, count(v1) as s1 from ( + select * from t1 where k1 != 1 + )group by k1 + 1 + ) t1 left join ( + select distinct k2 from ( + select * from t2 where k1 != 3 + ) t2 + ) t2 on t1.k1 = t2.k2 + order by t1.k1, t2.k2, s1 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 0) + }) + } + + test("not attribute grouping keys 2") { + val sql = """ + select t1.k1, t2.k2, s1 from ( + select k1 + 1 as k1, count(v1) as s1 from ( + select * from t1 where k1 != 1 + )group by 1 + ) t1 left join ( + select distinct k2 from ( + select * from t2 where k1 != 3 + ) t2 + ) t2 on t1.k1 = t2.k2 + order by t1.k1, t2.k2, s1 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 0) + }) + } + + test("const expression aggregate expression") { + val sql = """ + select t1.k1, t2.k1, s1 from ( + select k1 + 1 
as k1, 1 as s1 from ( + select * from t1 where k1 != 1 + )group by 1 + ) t1 left join ( + select distinct k1 from ( + select * from t2 where k1 != 3 + ) t2 + ) t2 on t1.k1 = t2.k1 + order by t1.k1, t2.k1, s1 + """.stripMargin + compareResultsAgainstVanillaSpark( + sql, + true, + { + df => + val joins = df.queryExecution.executedPlan.collect { + case join: ShuffledHashJoinExecTransformerBase => join + } + assert(joins.length == 1) + }) } } From 90c12015a9495ac1443fe30526ac0be3425d9456 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Thu, 3 Apr 2025 15:32:14 +0800 Subject: [PATCH 09/10] fix typo --- .../JoinAggregateToAggregateUnion.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala index d6675327d8bd..7958e4a0dfc3 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala @@ -36,7 +36,7 @@ private case class JoinKeys( leftKeys: Seq[AttributeReference], rightKeys: Seq[AttributeReference]) {} -private object RulePlanHeler { +private object RulePlanHelper { def transformDistinctToAggregate(distinct: Distinct): Aggregate = { Aggregate(distinct.child.output, distinct.child.output, distinct.child) } @@ -562,8 +562,8 @@ case class ReorderJoinSubqueries() extends Logging { if (index != sameJoinKeysPlansList.length - 1) {} val sameJoinKeysPlans = sameJoinKeysPlansList.remove(index) if ( - RulePlanHeler.extractDirectAggregate(newRight).isDefined || - RulePlanHeler.extractDirectAggregate(sameJoinKeysPlans.plans.last._1).isEmpty + RulePlanHelper.extractDirectAggregate(newRight).isDefined || + RulePlanHelper.extractDirectAggregate(sameJoinKeysPlans.plans.last._1).isEmpty ) { 
sameJoinKeysPlans.plans += Tuple2(newRight, join) } else { @@ -577,9 +577,9 @@ case class ReorderJoinSubqueries() extends Logging { val joinRight = visitPlan(join.right) join.copy(left = joinLeft, right = joinRight) } - case subquery: SubqueryAlias if RulePlanHeler.extractDirectAggregate(subquery).isDefined => + case subquery: SubqueryAlias if RulePlanHelper.extractDirectAggregate(subquery).isDefined => val newAggregate = visitPlan(subquery.child) - val groupingKeys = RulePlanHeler.extractDirectAggregate(subquery).get.groupingExpressions + val groupingKeys = RulePlanHelper.extractDirectAggregate(subquery).get.groupingExpressions if (groupingKeys.forall(_.isInstanceOf[AttributeReference])) { val keys = groupingKeys.map(_.asInstanceOf[AttributeReference]) val index = findSameJoinKeysPlansIndex()(sameJoinKeysPlansList, keys) @@ -928,7 +928,7 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) analyzedAggregates: ArrayBuffer[JoinedAggregateAnalyzer]): Option[LogicalPlan] = { plan match { case join: Join if join.joinType == LeftOuter && join.condition.isDefined => - val optionAggregate = RulePlanHeler.extractDirectAggregate(join.right) + val optionAggregate = RulePlanHelper.extractDirectAggregate(join.right) if (optionAggregate.isEmpty) { return Some(plan) } @@ -948,8 +948,8 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) } else { Some(plan) } - case _ if RulePlanHeler.extractDirectAggregate(plan).isDefined => - val aggregate = RulePlanHeler.extractDirectAggregate(plan).get + case _ if RulePlanHelper.extractDirectAggregate(plan).isDefined => + val aggregate = RulePlanHelper.extractDirectAggregate(plan).get val lastJoin = analyzedAggregates.head.join assert(lastJoin.left.equals(plan), "The node should be last join's left child") val leftAggregateAnalyzer = JoinedAggregateAnalyzer.build(lastJoin, aggregate) From 22cde293de5850324e04ac37f8cda0c0b94c6d45 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Thu, 3 Apr 2025 15:57:15 +0800 Subject: 
[PATCH 10/10] rename --- .../JoinAggregateToAggregateUnion.scala | 49 ++++++++++++------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala index 7958e4a0dfc3..3fcc2d5369da 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/JoinAggregateToAggregateUnion.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ -// import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.ListBuffer @@ -317,21 +316,27 @@ object AggregateFunctionAnalyzer { case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Logging { def analysis(): Boolean = { if (!extractAggregateQuery(subquery)) { + logDebug(s"xxx Not found aggregate query") return false } if (!extractGroupingKeys()) { + logDebug(s"xxx Not found grouping keys") return false } if (!extractJoinKeys()) { + logDebug(s"xxx Not found join keys") return false } if ( - keys.length != aggregate.groupingExpressions.length || - !keys.forall(k => outputGroupingKeys.exists(_.semanticEquals(k))) + joinKeys.length != aggregate.groupingExpressions.length || + !joinKeys.forall(k => outputGroupingKeys.exists(_.semanticEquals(k))) ) { + logError( + s"xxx Join keys and grouping keys are not matched. 
joinKeys: $joinKeys" + + s" outputGroupingKeys: $outputGroupingKeys") return false } @@ -340,7 +345,8 @@ case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo aggregateFunctionAnalyzer = aggregateExpressions.map(AggregateFunctionAnalyzer(_)) // If there is any const value in the aggregate expressions, return false - if (aggregateExpressions.length + keys.length != aggregate.aggregateExpressions.length) { + if (aggregateExpressions.length + joinKeys.length != aggregate.aggregateExpressions.length) { + logDebug(s"xxx Have const expression in aggregate expressions") return false } if ( @@ -348,6 +354,7 @@ case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo case (e, i) => !aggregateFunctionAnalyzer(i).doValidate } ) { + logDebug(s"xxx Have invalid aggregate function in aggregate expressions") return false } @@ -356,6 +363,7 @@ case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo case (_, i) => aggregateFunctionAnalyzer(i).getArgumentExpressions } if (arguments.exists(_.isEmpty)) { + logDebug(s"xxx Get aggregate function arguments failed") return false } aggregateExpressionArguments = arguments.map(_.get) @@ -375,8 +383,8 @@ case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo } } - def getPrimeKeys(): Seq[AttributeReference] = primeKeys - def getKeys(): Seq[AttributeReference] = keys + def getPrimeJoinKeys(): Seq[AttributeReference] = primeJoinKeys + def getJoinKeys(): Seq[AttributeReference] = joinKeys def getAggregate(): Aggregate = aggregate def getGroupingKeys(): Seq[Attribute] = outputGroupingKeys def getGroupingExpressions(): Seq[NamedExpression] = groupingExpressions @@ -384,8 +392,8 @@ case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo def getAggregateExpressionArguments(): Seq[Seq[Expression]] = aggregateExpressionArguments def getAggregateFunctionAnalyzers(): Seq[AggregateFunctionAnalyzer] = aggregateFunctionAnalyzer - 
private var primeKeys: Seq[AttributeReference] = Seq.empty - private var keys: Seq[AttributeReference] = Seq.empty + private var primeJoinKeys: Seq[AttributeReference] = Seq.empty + private var joinKeys: Seq[AttributeReference] = Seq.empty private var aggregate: Aggregate = null private var groupingExpressions: Seq[NamedExpression] = null private var outputGroupingKeys: Seq[Attribute] = null @@ -398,13 +406,13 @@ case class JoinedAggregateAnalyzer(join: Join, subquery: LogicalPlan) extends Lo val subqueryKeys = ArrayBuffer[AttributeReference]() val leftOutputSet = join.left.outputSet val subqueryOutputSet = subquery.outputSet - val joinKeys = + val joinKeysPair = RuleExpressionHelper.extractJoinKeys(join.condition, leftOutputSet, subqueryOutputSet) - if (joinKeys.isEmpty) { + if (joinKeysPair.isEmpty) { false } else { - primeKeys = joinKeys.get.leftKeys - keys = joinKeys.get.rightKeys + primeJoinKeys = joinKeysPair.get.leftKeys + joinKeys = joinKeysPair.get.rightKeys true } } @@ -462,8 +470,8 @@ object JoinedAggregateAnalyzer extends Logging { } } - def haveSamePrimeKeys(analzyers: Seq[JoinedAggregateAnalyzer]): Boolean = { - val primeKeys = analzyers.map(_.getPrimeKeys()).map(AttributeSet(_)) + def haveSamePrimeJoinKeys(analzyers: Seq[JoinedAggregateAnalyzer]): Boolean = { + val primeKeys = analzyers.map(_.getPrimeJoinKeys()).map(AttributeSet(_)) primeKeys .slice(1, primeKeys.length) .forall(keys => keys.equals(primeKeys.head)) @@ -673,7 +681,7 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) val lastJoin = analyzedAggregates.head.join lastJoin.copy(left = visitPlan(lastJoin.left), right = unionedAggregates) } else { - buildPrimeKeysFilterOnAggregateUnion(unionedAggregates, analyzedAggregates.toSeq) + buildPrimeJoinKeysFilterOnAggregateUnion(unionedAggregates, analyzedAggregates.toSeq) } } } else { @@ -763,7 +771,7 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) * Some rows may come from the right tables which grouping keys are not 
in the prime keys set. We * should remove them. */ - def buildPrimeKeysFilterOnAggregateUnion( + def buildPrimeJoinKeysFilterOnAggregateUnion( plan: LogicalPlan, analyzedAggregates: Seq[JoinedAggregateAnalyzer]): LogicalPlan = { val flagExpressions = plan.output(plan.output.length - analyzedAggregates.length) @@ -934,18 +942,23 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) } val rightAggregateAnalyzer = JoinedAggregateAnalyzer.build(join, optionAggregate.get) if (rightAggregateAnalyzer.isEmpty) { + logDebug(s"xxx Not a valid aggregate query") return Some(plan) } if ( analyzedAggregates.isEmpty || - JoinedAggregateAnalyzer.haveSamePrimeKeys( + JoinedAggregateAnalyzer.haveSamePrimeJoinKeys( Seq(analyzedAggregates.head, rightAggregateAnalyzer.get)) ) { // left plan is pushed in front analyzedAggregates.insert(0, rightAggregateAnalyzer.get) collectSameKeysJoinedAggregates(join.left, analyzedAggregates) } else { + logError( + s"xxx Not have same keys. join keys:" + + s"${analyzedAggregates.head.getPrimeJoinKeys()} vs. " + + s"${rightAggregateAnalyzer.get.getPrimeJoinKeys()}") Some(plan) } case _ if RulePlanHelper.extractDirectAggregate(plan).isDefined => @@ -957,7 +970,7 @@ case class JoinAggregateToAggregateUnion(spark: SparkSession) return Some(plan) } if ( - JoinedAggregateAnalyzer.haveSamePrimeKeys( + JoinedAggregateAnalyzer.haveSamePrimeJoinKeys( Seq(analyzedAggregates.head, leftAggregateAnalyzer.get)) ) { analyzedAggregates.insert(0, leftAggregateAnalyzer.get)