-
Notifications
You must be signed in to change notification settings - Fork 306
fix: CometExec's outputPartitioning might not be same as Spark expects after AQE interferes #299
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
|
|
||
| package org.apache.comet.shims | ||
|
|
||
| import org.apache.spark.sql.catalyst.expressions.Expression | ||
| import org.apache.spark.sql.catalyst.plans.physical.Partitioning | ||
|
|
||
| trait ShimCometBroadcastHashJoinExec { | ||
|
|
||
| /** | ||
| * Returns the expressions that are used for hash partitioning including `HashPartitioning` and | ||
| * `CoalescedHashPartitioning`. They shares same trait `HashPartitioningLike` since Spark 3.4, | ||
| * but Spark 3.2/3.3 doesn't have `HashPartitioningLike` and `CoalescedHashPartitioning`. | ||
| * | ||
| * TODO: remove after dropping Spark 3.2 and 3.3 support. | ||
| */ | ||
| def getHashPartitioningLikeExpressions(partitioning: Partitioning): Seq[Expression] = { | ||
| partitioning.getClass.getDeclaredMethods | ||
| .filter(_.getName == "expressions") | ||
| .flatMap(_.invoke(partitioning).asInstanceOf[Seq[Expression]]) | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ package org.apache.spark.sql.comet | |
| import java.io.{ByteArrayOutputStream, DataInputStream} | ||
| import java.nio.channels.Channels | ||
|
|
||
| import scala.collection.mutable | ||
| import scala.collection.mutable.ArrayBuffer | ||
|
|
||
| import org.apache.spark.{SparkEnv, TaskContext} | ||
|
|
@@ -30,13 +31,15 @@ import org.apache.spark.rdd.RDD | |
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder} | ||
| import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode} | ||
| import org.apache.spark.sql.catalyst.optimizer.BuildSide | ||
| import org.apache.spark.sql.catalyst.plans.JoinType | ||
| import org.apache.spark.sql.catalyst.plans.physical.Partitioning | ||
| import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} | ||
| import org.apache.spark.sql.catalyst.plans._ | ||
| import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning} | ||
| import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec} | ||
| import org.apache.spark.sql.comet.plans.PartitioningPreservingUnaryExecNode | ||
| import org.apache.spark.sql.comet.util.Utils | ||
| import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode} | ||
| import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} | ||
| import org.apache.spark.sql.execution.aggregate.HashAggregateExec | ||
| import org.apache.spark.sql.execution.exchange.ReusedExchangeExec | ||
| import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} | ||
| import org.apache.spark.sql.internal.SQLConf | ||
|
|
@@ -47,6 +50,7 @@ import com.google.common.base.Objects | |
|
|
||
| import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException} | ||
| import org.apache.comet.serde.OperatorOuterClass.Operator | ||
| import org.apache.comet.shims.ShimCometBroadcastHashJoinExec | ||
|
|
||
| /** | ||
| * A Comet physical operator | ||
|
|
@@ -69,6 +73,10 @@ abstract class CometExec extends CometPlan { | |
|
|
||
| override def outputOrdering: Seq[SortOrder] = originalPlan.outputOrdering | ||
|
|
||
| // `CometExec` reuses the outputPartitioning of the original SparkPlan. | ||
| // Note that if the outputPartitioning of the original SparkPlan depends on its children, | ||
| // we should override this method in the specific CometExec, because Spark AQE may change the | ||
| // outputPartitioning of SparkPlan, e.g., AQEShuffleReadExec. | ||
| override def outputPartitioning: Partitioning = originalPlan.outputPartitioning | ||
|
|
||
| /** | ||
|
|
@@ -377,7 +385,8 @@ case class CometProjectExec( | |
| override val output: Seq[Attribute], | ||
| child: SparkPlan, | ||
| override val serializedPlanOpt: SerializedPlan) | ||
| extends CometUnaryExec { | ||
| extends CometUnaryExec | ||
| with PartitioningPreservingUnaryExecNode { | ||
| override def producedAttributes: AttributeSet = outputSet | ||
| override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = | ||
| this.copy(child = newChild) | ||
|
|
@@ -396,6 +405,8 @@ case class CometProjectExec( | |
| } | ||
|
|
||
| override def hashCode(): Int = Objects.hashCode(projectList, output, child) | ||
|
|
||
| override protected def outputExpressions: Seq[NamedExpression] = projectList | ||
| } | ||
|
|
||
| case class CometFilterExec( | ||
|
|
@@ -405,6 +416,9 @@ case class CometFilterExec( | |
| child: SparkPlan, | ||
| override val serializedPlanOpt: SerializedPlan) | ||
| extends CometUnaryExec { | ||
|
|
||
| override def outputPartitioning: Partitioning = child.outputPartitioning | ||
|
|
||
| override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = | ||
| this.copy(child = newChild) | ||
|
|
||
|
|
@@ -439,6 +453,9 @@ case class CometSortExec( | |
| child: SparkPlan, | ||
| override val serializedPlanOpt: SerializedPlan) | ||
| extends CometUnaryExec { | ||
|
|
||
| override def outputPartitioning: Partitioning = child.outputPartitioning | ||
|
|
||
| override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = | ||
| this.copy(child = newChild) | ||
|
|
||
|
|
@@ -471,6 +488,9 @@ case class CometLocalLimitExec( | |
| child: SparkPlan, | ||
| override val serializedPlanOpt: SerializedPlan) | ||
| extends CometUnaryExec { | ||
|
|
||
| override def outputPartitioning: Partitioning = child.outputPartitioning | ||
|
Comment on lines
490
to
+492
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it make sense to add this
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because Spark's default Currently |
||
|
|
||
| override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = | ||
| this.copy(child = newChild) | ||
|
|
||
|
|
@@ -498,6 +518,9 @@ case class CometGlobalLimitExec( | |
| child: SparkPlan, | ||
| override val serializedPlanOpt: SerializedPlan) | ||
| extends CometUnaryExec { | ||
|
|
||
| override def outputPartitioning: Partitioning = child.outputPartitioning | ||
|
|
||
| override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = | ||
| this.copy(child = newChild) | ||
|
|
||
|
|
@@ -586,7 +609,8 @@ case class CometHashAggregateExec( | |
| mode: Option[AggregateMode], | ||
| child: SparkPlan, | ||
| override val serializedPlanOpt: SerializedPlan) | ||
| extends CometUnaryExec { | ||
| extends CometUnaryExec | ||
| with PartitioningPreservingUnaryExecNode { | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. PartitioningPreservingUnaryExecNode implements outputPartitioning for |
||
| override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = | ||
| this.copy(child = newChild) | ||
|
|
||
|
|
@@ -618,6 +642,9 @@ case class CometHashAggregateExec( | |
|
|
||
| override def hashCode(): Int = | ||
| Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode, child) | ||
|
|
||
| override protected def outputExpressions: Seq[NamedExpression] = | ||
| originalPlan.asInstanceOf[HashAggregateExec].resultExpressions | ||
| } | ||
|
|
||
| case class CometHashJoinExec( | ||
|
|
@@ -632,6 +659,18 @@ case class CometHashJoinExec( | |
| override val right: SparkPlan, | ||
| override val serializedPlanOpt: SerializedPlan) | ||
| extends CometBinaryExec { | ||
|
|
||
| override def outputPartitioning: Partitioning = joinType match { | ||
| case _: InnerLike => | ||
| PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) | ||
| case LeftOuter => left.outputPartitioning | ||
| case RightOuter => right.outputPartitioning | ||
| case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) | ||
| case LeftExistence(_) => left.outputPartitioning | ||
| case x => | ||
| throw new IllegalArgumentException(s"ShuffledJoin should not take $x as the JoinType") | ||
| } | ||
|
|
||
| override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = | ||
| this.copy(left = newLeft, right = newRight) | ||
|
|
||
|
|
@@ -668,7 +707,101 @@ case class CometBroadcastHashJoinExec( | |
| override val left: SparkPlan, | ||
| override val right: SparkPlan, | ||
| override val serializedPlanOpt: SerializedPlan) | ||
| extends CometBinaryExec { | ||
| extends CometBinaryExec | ||
| with ShimCometBroadcastHashJoinExec { | ||
|
|
||
| // The following logic of `outputPartitioning` is copied from Spark `BroadcastHashJoinExec`. | ||
| protected lazy val streamedPlan: SparkPlan = buildSide match { | ||
| case BuildLeft => right | ||
| case BuildRight => left | ||
| } | ||
|
|
||
| override lazy val outputPartitioning: Partitioning = { | ||
| joinType match { | ||
| case _: InnerLike if conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 => | ||
| streamedPlan.outputPartitioning match { | ||
| case h: HashPartitioning => expandOutputPartitioning(h) | ||
| case h: Expression if h.getClass.getName.contains("CoalescedHashPartitioning") => | ||
| expandOutputPartitioning(h) | ||
| case c: PartitioningCollection => expandOutputPartitioning(c) | ||
| case other => other | ||
| } | ||
| case _ => streamedPlan.outputPartitioning | ||
| } | ||
| } | ||
|
|
||
| protected lazy val (buildKeys, streamedKeys) = { | ||
| require( | ||
| leftKeys.length == rightKeys.length && | ||
| leftKeys | ||
| .map(_.dataType) | ||
| .zip(rightKeys.map(_.dataType)) | ||
| .forall(types => types._1.sameType(types._2)), | ||
| "Join keys from two sides should have same length and types") | ||
| buildSide match { | ||
| case BuildLeft => (leftKeys, rightKeys) | ||
| case BuildRight => (rightKeys, leftKeys) | ||
| } | ||
| } | ||
|
|
||
| // An one-to-many mapping from a streamed key to build keys. | ||
| private lazy val streamedKeyToBuildKeyMapping = { | ||
| val mapping = mutable.Map.empty[Expression, Seq[Expression]] | ||
| streamedKeys.zip(buildKeys).foreach { case (streamedKey, buildKey) => | ||
| val key = streamedKey.canonicalized | ||
| mapping.get(key) match { | ||
| case Some(v) => mapping.put(key, v :+ buildKey) | ||
| case None => mapping.put(key, Seq(buildKey)) | ||
| } | ||
| } | ||
| mapping.toMap | ||
| } | ||
|
|
||
| // Expands the given partitioning collection recursively. | ||
| private def expandOutputPartitioning( | ||
| partitioning: PartitioningCollection): PartitioningCollection = { | ||
| PartitioningCollection(partitioning.partitionings.flatMap { | ||
| case h: HashPartitioning => expandOutputPartitioning(h).partitionings | ||
| case h: Expression if h.getClass.getName.contains("CoalescedHashPartitioning") => | ||
| expandOutputPartitioning(h).partitionings | ||
| case c: PartitioningCollection => Seq(expandOutputPartitioning(c)) | ||
| case other => Seq(other) | ||
| }) | ||
| } | ||
|
|
||
| // Expands the given hash partitioning by substituting streamed keys with build keys. | ||
| // For example, if the expressions for the given partitioning are Seq("a", "b", "c") | ||
| // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"), | ||
| // the expanded partitioning will have the following expressions: | ||
| // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y"). | ||
| // The expanded expressions are returned as PartitioningCollection. | ||
| private def expandOutputPartitioning( | ||
| partitioning: Partitioning with Expression): PartitioningCollection = { | ||
| val maxNumCombinations = conf.broadcastHashJoinOutputPartitioningExpandLimit | ||
| var currentNumCombinations = 0 | ||
|
|
||
| def generateExprCombinations( | ||
| current: Seq[Expression], | ||
| accumulated: Seq[Expression]): Seq[Seq[Expression]] = { | ||
| if (currentNumCombinations >= maxNumCombinations) { | ||
| Nil | ||
| } else if (current.isEmpty) { | ||
| currentNumCombinations += 1 | ||
| Seq(accumulated) | ||
| } else { | ||
| val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized) | ||
| generateExprCombinations(current.tail, accumulated :+ current.head) ++ | ||
| buildKeysOpt | ||
| .map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b))) | ||
| .getOrElse(Nil) | ||
| } | ||
| } | ||
|
|
||
| PartitioningCollection( | ||
| generateExprCombinations(getHashPartitioningLikeExpressions(partitioning), Nil) | ||
| .map(exprs => partitioning.withNewChildren(exprs).asInstanceOf[Partitioning])) | ||
| } | ||
|
|
||
| override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = | ||
| this.copy(left = newLeft, right = newRight) | ||
|
|
||
|
|
@@ -705,6 +838,18 @@ case class CometSortMergeJoinExec( | |
| override val right: SparkPlan, | ||
| override val serializedPlanOpt: SerializedPlan) | ||
| extends CometBinaryExec { | ||
|
|
||
| override def outputPartitioning: Partitioning = joinType match { | ||
| case _: InnerLike => | ||
| PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) | ||
| case LeftOuter => left.outputPartitioning | ||
| case RightOuter => right.outputPartitioning | ||
| case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) | ||
| case LeftExistence(_) => left.outputPartitioning | ||
| case x => | ||
| throw new IllegalArgumentException(s"ShuffledJoin should not take $x as the JoinType") | ||
| } | ||
|
|
||
| override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = | ||
| this.copy(left = newLeft, right = newRight) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PartitioningPreservingUnaryExecNodeimplementsoutputPartitioningfor ProjectExec.