@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.shims

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.Partitioning

trait ShimCometBroadcastHashJoinExec {

/**
* Returns the expressions that are used for hash partitioning, covering both `HashPartitioning`
* and `CoalescedHashPartitioning`. They share the common trait `HashPartitioningLike` since
* Spark 3.4, but Spark 3.2/3.3 have neither `HashPartitioningLike` nor `CoalescedHashPartitioning`.
*
* TODO: remove after dropping Spark 3.2 and 3.3 support.
*/
def getHashPartitioningLikeExpressions(partitioning: Partitioning): Seq[Expression] = {
partitioning.getClass.getDeclaredMethods
.filter(_.getName == "expressions")
.flatMap(_.invoke(partitioning).asInstanceOf[Seq[Expression]])
}
}
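
A minimal sketch of the shim in use (a hypothetical test object, assuming Spark catalyst classes on the classpath; the attribute name and partition count are illustrative):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

object ShimExample extends ShimCometBroadcastHashJoinExec {
  def main(args: Array[String]): Unit = {
    val key = AttributeReference("a", IntegerType)()
    val partitioning = HashPartitioning(Seq(key), 10)
    // The reflective lookup of the `expressions` accessor works for
    // HashPartitioning on all supported Spark versions, and for
    // CoalescedHashPartitioning on Spark 3.4+, without a compile-time
    // dependency on the 3.4-only classes.
    assert(getHashPartitioningLikeExpressions(partitioning) == Seq(key))
  }
}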
157 changes: 151 additions & 6 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -22,6 +22,7 @@ package org.apache.spark.sql.comet
import java.io.{ByteArrayOutputStream, DataInputStream}
import java.nio.channels.Channels

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkEnv, TaskContext}
@@ -30,13 +31,15 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode}
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning}
import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec}
import org.apache.spark.sql.comet.plans.PartitioningPreservingUnaryExecNode
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
@@ -47,6 +50,7 @@ import com.google.common.base.Objects

import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.shims.ShimCometBroadcastHashJoinExec

/**
* A Comet physical operator
@@ -69,6 +73,10 @@ abstract class CometExec extends CometPlan {

override def outputOrdering: Seq[SortOrder] = originalPlan.outputOrdering

// `CometExec` reuses the outputPartitioning of the original SparkPlan.
// Note that if the outputPartitioning of the original SparkPlan depends on its children,
// the specific CometExec should override this method, because Spark AQE may change a
// child's outputPartitioning at runtime, e.g., via AQEShuffleReadExec.
override def outputPartitioning: Partitioning = originalPlan.outputPartitioning
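
// For example, under AQE a shuffle child may be replaced at runtime by an
// AQEShuffleReadExec that coalesces partitions, so a value cached from
// `originalPlan.outputPartitioning` would report a stale partition count;
// the child-dependent operators below delegate to the live child instead.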

/**
@@ -377,7 +385,8 @@ case class CometProjectExec(
override val output: Seq[Attribute],
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {
extends CometUnaryExec
with PartitioningPreservingUnaryExecNode {
Member Author:
PartitioningPreservingUnaryExecNode implements outputPartitioning for ProjectExec.
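
As a rough, standalone illustration (hypothetical attributes, not from this PR) of what the alias-aware trait provides: a projection that merely renames the partitioning key can keep reporting a hash partitioning expressed over its own output.

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

object AliasAwareSketch {
  def main(args: Array[String]): Unit = {
    val a = AttributeReference("a", IntegerType)()
    // The child is hash-partitioned on `a`; the project computes `a AS x`.
    val childPartitioning = HashPartitioning(Seq(a), 10)
    val x = Alias(a, "x")()
    // An alias-aware node can re-express the same physical layout over its
    // own output attributes instead of discarding the partitioning info.
    val projectPartitioning = HashPartitioning(Seq(x.toAttribute), 10)
    println(s"$childPartitioning -> $projectPartitioning")
  }
}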

override def producedAttributes: AttributeSet = outputSet
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)
@@ -396,6 +405,8 @@ case class CometProjectExec(
}

override def hashCode(): Int = Objects.hashCode(projectList, output, child)

override protected def outputExpressions: Seq[NamedExpression] = projectList
}

case class CometFilterExec(
@@ -405,6 +416,9 @@ case class CometFilterExec(
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {

override def outputPartitioning: Partitioning = child.outputPartitioning

override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)

@@ -439,6 +453,9 @@ case class CometSortExec(
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {

override def outputPartitioning: Partitioning = child.outputPartitioning

override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)

@@ -471,6 +488,9 @@ case class CometLocalLimitExec(
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {

override def outputPartitioning: Partitioning = child.outputPartitioning
Comment on lines 490 to +492
Member:
Would it make sense to add this outputPartitioning to CometUnaryExec as the default?

Member Author:
Because Spark's default outputPartitioning is UnknownPartitioning, adding child.outputPartitioning as the default in CometUnaryExec could silently change an operator's outputPartitioning without us noticing.

Currently CometExec uses the original Spark plan's outputPartitioning as the default, which I think is the safer option. Except for cases where Spark dynamically changes the output partitioning during execution (like AQE), it should be correct, because Comet doesn't change the output partitioning from what Spark planned.
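
A hypothetical sketch of that concern (invented mini-hierarchy, not Comet code): with a blanket child-forwarding default, an operator whose Spark counterpart deliberately reports UnknownPartitioning would silently start claiming its child's partitioning unless someone remembers to override it.

import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}

trait Node { def outputPartitioning: Partitioning }

// A blanket default that forwards the child's partitioning...
abstract class UnaryNode(val child: Node) extends Node {
  override def outputPartitioning: Partitioning = child.outputPartitioning
}

// ...is only correct if every subclass that scrambles row placement
// remembers to override it back to UnknownPartitioning:
class ScramblingNode(child: Node) extends UnaryNode(child) {
  override def outputPartitioning: Partitioning =
    UnknownPartitioning(child.outputPartitioning.numPartitions)
}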


override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)

@@ -498,6 +518,9 @@ case class CometGlobalLimitExec(
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {

override def outputPartitioning: Partitioning = child.outputPartitioning

override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)

@@ -586,7 +609,8 @@ case class CometHashAggregateExec(
mode: Option[AggregateMode],
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {
extends CometUnaryExec
with PartitioningPreservingUnaryExecNode {
Member Author:
PartitioningPreservingUnaryExecNode implements outputPartitioning for HashAggregateExec.

override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)

@@ -618,6 +642,9 @@ case class CometHashAggregateExec(

override def hashCode(): Int =
Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode, child)

override protected def outputExpressions: Seq[NamedExpression] =
originalPlan.asInstanceOf[HashAggregateExec].resultExpressions
}

case class CometHashJoinExec(
@@ -632,6 +659,18 @@ case class CometHashJoinExec(
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometBinaryExec {

// This `outputPartitioning` logic is copied from Spark's `ShuffledJoin`.
override def outputPartitioning: Partitioning = joinType match {
case _: InnerLike =>
PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
case LeftExistence(_) => left.outputPartitioning
case x =>
throw new IllegalArgumentException(s"ShuffledJoin should not take $x as the JoinType")
}
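
A standalone sketch of the inner-join case (illustrative keys and partition counts): advertising both sides' partitionings lets a downstream operator that needs either key's distribution avoid an extra exchange.

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
import org.apache.spark.sql.types.LongType

object InnerJoinPartitioningSketch {
  def main(args: Array[String]): Unit = {
    val leftKey = AttributeReference("l_id", LongType)()
    val rightKey = AttributeReference("r_id", LongType)()
    // Both inputs to a shuffled inner join arrive hash-partitioned on their
    // join keys with the same partition count, so the join output satisfies
    // a clustered distribution on either key.
    val joined = PartitioningCollection(
      Seq(HashPartitioning(Seq(leftKey), 8), HashPartitioning(Seq(rightKey), 8)))
    println(joined)
  }
}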

override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
this.copy(left = newLeft, right = newRight)

@@ -668,7 +707,101 @@ case class CometBroadcastHashJoinExec(
override val left: SparkPlan,
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometBinaryExec {
extends CometBinaryExec
with ShimCometBroadcastHashJoinExec {

// The following `outputPartitioning` logic is copied from Spark's `BroadcastHashJoinExec`.
protected lazy val streamedPlan: SparkPlan = buildSide match {
case BuildLeft => right
case BuildRight => left
}

override lazy val outputPartitioning: Partitioning = {
joinType match {
case _: InnerLike if conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
streamedPlan.outputPartitioning match {
case h: HashPartitioning => expandOutputPartitioning(h)
// Matched by class name to avoid a compile-time dependency on Spark 3.4's
// `CoalescedHashPartitioning` (see `ShimCometBroadcastHashJoinExec`).
case h: Expression if h.getClass.getName.contains("CoalescedHashPartitioning") =>
expandOutputPartitioning(h)
case c: PartitioningCollection => expandOutputPartitioning(c)
case other => other
}
case _ => streamedPlan.outputPartitioning
}
}

protected lazy val (buildKeys, streamedKeys) = {
require(
leftKeys.length == rightKeys.length &&
leftKeys
.map(_.dataType)
.zip(rightKeys.map(_.dataType))
.forall(types => types._1.sameType(types._2)),
"Join keys from two sides should have same length and types")
buildSide match {
case BuildLeft => (leftKeys, rightKeys)
case BuildRight => (rightKeys, leftKeys)
}
}

// A one-to-many mapping from a streamed key to build keys.
private lazy val streamedKeyToBuildKeyMapping = {
val mapping = mutable.Map.empty[Expression, Seq[Expression]]
streamedKeys.zip(buildKeys).foreach { case (streamedKey, buildKey) =>
val key = streamedKey.canonicalized
mapping.get(key) match {
case Some(v) => mapping.put(key, v :+ buildKey)
case None => mapping.put(key, Seq(buildKey))
}
}
mapping.toMap
}
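
// For example, with streamedKeys = Seq(s1, s2, s1) and buildKeys =
// Seq(b1, b2, b3), the mapping is s1 -> Seq(b1, b3) and s2 -> Seq(b2):
// a streamed key compared against several build keys collects all of them.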

// Expands the given partitioning collection recursively.
private def expandOutputPartitioning(
partitioning: PartitioningCollection): PartitioningCollection = {
PartitioningCollection(partitioning.partitionings.flatMap {
case h: HashPartitioning => expandOutputPartitioning(h).partitionings
case h: Expression if h.getClass.getName.contains("CoalescedHashPartitioning") =>
expandOutputPartitioning(h).partitionings
case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
case other => Seq(other)
})
}

// Expands the given hash partitioning by substituting streamed keys with build keys.
// For example, if the expressions for the given partitioning are Seq("a", "b", "c")
// where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
// the expanded partitioning will have the following expressions:
// Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
// The expanded expressions are returned as PartitioningCollection.
private def expandOutputPartitioning(
partitioning: Partitioning with Expression): PartitioningCollection = {
val maxNumCombinations = conf.broadcastHashJoinOutputPartitioningExpandLimit
var currentNumCombinations = 0

def generateExprCombinations(
current: Seq[Expression],
accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
if (currentNumCombinations >= maxNumCombinations) {
Nil
} else if (current.isEmpty) {
currentNumCombinations += 1
Seq(accumulated)
} else {
val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
generateExprCombinations(current.tail, accumulated :+ current.head) ++
buildKeysOpt
.map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b)))
.getOrElse(Nil)
}
}

PartitioningCollection(
generateExprCombinations(getHashPartitioningLikeExpressions(partitioning), Nil)
.map(exprs => partitioning.withNewChildren(exprs).asInstanceOf[Partitioning]))
}

override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
this.copy(left = newLeft, right = newRight)
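
A standalone rendering of the combination generator above (plain strings instead of catalyst expressions, one build key per streamed key, the expand-limit guard omitted), to make the expansion order concrete:

object ExpandSketch {
  // Mirrors generateExprCombinations: keep the streamed key, then branch
  // into its mapped build key, depth-first from the left.
  def combos(
      current: List[String],
      acc: List[String],
      mapping: Map[String, String]): List[List[String]] =
    current match {
      case Nil => List(acc)
      case head :: tail =>
        combos(tail, acc :+ head, mapping) ++
          mapping.get(head).toList.flatMap(b => combos(tail, acc :+ b, mapping))
    }

  def main(args: Array[String]): Unit = {
    // Streamed keys "b" and "c" map to build keys "x" and "y"; prints
    // List(a, b, c), List(a, b, y), List(a, x, c), List(a, x, y), matching
    // the Seq("a", "b", "c") example in the comment above.
    combos(List("a", "b", "c"), Nil, Map("b" -> "x", "c" -> "y")).foreach(println)
  }
}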

@@ -705,6 +838,18 @@ case class CometSortMergeJoinExec(
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometBinaryExec {

// This `outputPartitioning` logic is copied from Spark's `ShuffledJoin`.
override def outputPartitioning: Partitioning = joinType match {
case _: InnerLike =>
PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
case LeftExistence(_) => left.outputPartitioning
case x =>
throw new IllegalArgumentException(s"ShuffledJoin should not take $x as the JoinType")
}

override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
this.copy(left = newLeft, right = newRight)
