From 30a325a9c7531abab361200431a3f1529c334693 Mon Sep 17 00:00:00 2001
From: Kristin Cowalcijk
Date: Thu, 3 Apr 2025 11:54:51 +0800
Subject: [PATCH 1/6] Override outputPartitioning in CometBroadcastExchangeExec to make AQE capable of converting Comet shuffled joins to Comet broadcast hash joins

---
 .../arrow}/ArrowReaderIterator.scala          |  2 +-
 .../apache/spark/sql/comet/util/Utils.scala   | 25 ++++-
 .../comet/CometSparkSessionExtensions.scala   | 14 ++-
 .../apache/comet/serde/QueryPlanSerde.scala   |  2 +
 .../comet/CometBroadcastExchangeExec.scala    | 33 ++-----
 .../sql/comet/CometColumnarToRowExec.scala    | 93 ++++++++++++++++++-
 6 files changed, 140 insertions(+), 29 deletions(-)
 rename {spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle => common/src/main/scala/org/apache/spark/sql/comet/execution/arrow}/ArrowReaderIterator.scala (97%)

diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowReaderIterator.scala
similarity index 97%
rename from spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
rename to common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowReaderIterator.scala
index 933e0b6618..0d0093a107 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowReaderIterator.scala
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-package org.apache.spark.sql.comet.execution.shuffle
+package org.apache.spark.sql.comet.execution.arrow
 
 import java.nio.channels.ReadableByteChannel
 
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
index 207474286e..1b42751802 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
@@ -19,7 +19,7 @@
 
 package org.apache.spark.sql.comet.util
 
-import java.io.{DataOutputStream, File}
+import java.io.{DataInputStream, DataOutputStream, File}
 import java.nio.ByteBuffer
 import java.nio.channels.Channels
 
@@ -35,6 +35,7 @@ import org.apache.arrow.vector.types._
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
 import org.apache.spark.{SparkEnv, SparkException}
 import org.apache.spark.io.CompressionCodec
+import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
@@ -226,6 +227,28 @@ object Utils {
     }
   }
 
+  /**
+   * Decodes the byte arrays back to ColumnarBatches and puts them into a buffer.
+ * @param bytes + * the serialized batches + * @param source + * the class that calls this method + * @return + * an iterator of ColumnarBatch + */ + def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = { + if (bytes.size == 0) { + return Iterator.empty + } + + // use Spark's compression codec (LZ4 by default) and not Comet's compression + val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + val cbbis = bytes.toInputStream() + val ins = new DataInputStream(codec.compressedInputStream(cbbis)) + // batches are in Arrow IPC format + new ArrowReaderIterator(Channels.newChannel(ins), source) + } + def getBatchFieldVectors( batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = { var provider: Option[DictionaryProvider] = None diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index f3d031795d..3153d14ab5 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -728,6 +728,18 @@ class CometSparkSessionExtensions s } + case s @ BroadcastQueryStageExec( + _, + ReusedExchangeExec(_, _: CometBroadcastExchangeExec), + _) => + val newOp = transform1(s) + newOp match { + case Some(nativeOp) => + CometSinkPlaceHolder(nativeOp, s, s) + case None => + s + } + // `CometBroadcastExchangeExec`'s broadcast output is not compatible with Spark's broadcast // exchange. It is only used for Comet native execution. We only transform Spark broadcast // exchange to Comet broadcast exchange if its downstream is a Comet native plan or if the @@ -739,7 +751,7 @@ class CometSparkSessionExtensions CometConf.COMET_EXEC_BROADCAST_EXCHANGE_ENABLED.get(conf) => QueryPlanSerde.operator2Proto(b) match { case Some(nativeOp) => - val cometOp = CometBroadcastExchangeExec(b, b.output, b.child) + val cometOp = CometBroadcastExchangeExec(b, b.output, b.mode, b.child) CometSinkPlaceHolder(nativeOp, b, cometOp) case None => b } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 4f8f8ee863..50b4187371 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2765,6 +2765,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true case _: TakeOrderedAndProjectExec => true case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true + case BroadcastQueryStageExec(_, ReusedExchangeExec(_, _: CometBroadcastExchangeExec), _) => + true case _: BroadcastExchangeExec => true case _: WindowExec => true case _ => false diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index 3285159be3..c957e17081 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.comet -import java.io.DataInputStream -import java.nio.channels.Channels import java.util.UUID import java.util.concurrent.{Future, TimeoutException, TimeUnit} @@ -28,13 +26,13 @@ import 
scala.concurrent.{ExecutionContext, Promise} import scala.concurrent.duration.NANOSECONDS import scala.util.control.NonFatal -import org.apache.spark.{broadcast, Partition, SparkContext, SparkEnv, TaskContext} -import org.apache.spark.io.CompressionCodec +import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.Statistics -import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} +import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec} @@ -63,6 +61,7 @@ import org.apache.comet.CometRuntimeException case class CometBroadcastExchangeExec( originalPlan: SparkPlan, override val output: Seq[Attribute], + mode: BroadcastMode, override val child: SparkPlan) extends BroadcastExchangeLike { import CometBroadcastExchangeExec._ @@ -77,7 +76,7 @@ case class CometBroadcastExchangeExec( "broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast")) override def doCanonicalize(): SparkPlan = { - CometBroadcastExchangeExec(null, null, child.canonicalized) + CometBroadcastExchangeExec(null, null, mode, child.canonicalized) } override def runtimeStatistics: Statistics = { @@ -86,6 +85,8 @@ case class CometBroadcastExchangeExec( Statistics(dataSize, Some(rowCount)) } + override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) + @transient private lazy val promise = Promise[broadcast.Broadcast[Any]]() @@ -250,7 +251,7 @@ case class CometBroadcastExchangeExec( override def hashCode(): Int = Objects.hashCode(child) - override def stringArgs: Iterator[Any] = Iterator(output, child) + override def stringArgs: Iterator[Any] = Iterator(output, mode, child) override protected def withNewChildInternal(newChild: SparkPlan): CometBroadcastExchangeExec = copy(child = newChild) @@ -289,23 +290,7 @@ class CometBatchRDD( override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val partition = split.asInstanceOf[CometBatchPartition] partition.value.value.toIterator - .flatMap(decodeBatches(_, this.getClass.getSimpleName)) - } - - /** - * Decodes the byte arrays back to ColumnarBatchs and put them into buffer. 
- */ - private def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = { - if (bytes.size == 0) { - return Iterator.empty - } - - // use Spark's compression codec (LZ4 by default) and not Comet's compression - val codec = CompressionCodec.createCodec(SparkEnv.get.conf) - val cbbis = bytes.toInputStream() - val ins = new DataInputStream(codec.compressedInputStream(cbbis)) - // batches are in Arrow IPC format - new ArrowReaderIterator(Channels.newChannel(ins), source) + .flatMap(Utils.decodeBatches(_, this.getClass.getSimpleName)) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala index 18d95a473c..cc150ef535 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala @@ -19,20 +19,29 @@ package org.apache.spark.sql.comet +import java.util.UUID +import java.util.concurrent.{Future, TimeoutException, TimeUnit} + import scala.collection.JavaConverters._ +import scala.concurrent.Promise +import scala.util.control.NonFatal +import org.apache.spark.{broadcast, SparkException} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, UnsafeProjection} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{CodegenSupport, ColumnarToRowTransition, SparkPlan} +import org.apache.spark.sql.comet.util.{Utils => CometUtils} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.{CodegenSupport, ColumnarToRowTransition, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.vectorized.{ConstantColumnVector, WritableColumnVector} import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} -import org.apache.spark.util.Utils +import org.apache.spark.util.{SparkFatalException, Utils} +import org.apache.spark.util.io.ChunkedByteBuffer import org.apache.comet.vector.CometPlainVector @@ -76,6 +85,86 @@ case class CometColumnarToRowExec(child: SparkPlan) } } + @transient + private lazy val promise = Promise[broadcast.Broadcast[Any]]() + + @transient + private val timeout: Long = conf.broadcastTimeout + + private val runId: UUID = UUID.randomUUID + + @transient + lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { + SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]]( + session, + CometBroadcastExchangeExec.executionContext) { + try { + // Setup a job group here so later it may get cancelled by groupId if necessary. 
+ sparkContext.setJobGroup( + runId.toString, + s"ColumnarToRow broadcast exchange (runId $runId)", + interruptOnCancel = true) + + val numOutputRows = longMetric("numOutputRows") + val numInputBatches = longMetric("numInputBatches") + val localOutput = this.output + val broadcastColumnar = child.executeBroadcast() + val serializedBatches = broadcastColumnar.value.asInstanceOf[Array[ChunkedByteBuffer]] + val toUnsafe = UnsafeProjection.create(localOutput, localOutput) + val rows = serializedBatches.iterator + .flatMap(CometUtils.decodeBatches(_, this.getClass.getSimpleName)) + .flatMap { batch => + numInputBatches += 1 + numOutputRows += batch.numRows() + batch.rowIterator().asScala.map(toUnsafe) + } + + val mode = child.asInstanceOf[CometBroadcastExchangeExec].mode + val relation = mode.transform(rows, Some(numOutputRows.value)) + val broadcasted = sparkContext.broadcastInternal(relation, serializedOnly = true) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) + promise.trySuccess(broadcasted) + broadcasted + } catch { + // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw + // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult + // will catch this exception and re-throw the wrapped fatal throwable. + case oe: OutOfMemoryError => + val ex = new SparkFatalException(oe) + promise.tryFailure(ex) + throw ex + case e if !NonFatal(e) => + val ex = new SparkFatalException(e) + promise.tryFailure(ex) + throw ex + case e: Throwable => + promise.tryFailure(e) + throw e + } + } + } + + override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + if (!child.isInstanceOf[CometBroadcastExchangeExec]) { + throw new SparkException( + "ColumnarToRowExec only supports doExecuteBroadcast when child is " + + "CometBroadcastExchange, but got " + child.nodeName) + } + + try { + relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]] + } catch { + case ex: TimeoutException => + logError(s"Could not execute broadcast in $timeout secs.", ex) + if (!relationFuture.isDone) { + sparkContext.cancelJobGroup(runId.toString) + relationFuture.cancel(true) + } + throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(ex)) + } + } + /** * Generate [[ColumnVector]] expressions for our parent to consume as rows. This is called once * per [[ColumnVector]] in the batch. 
From d20468f8d4b2c4fd48f11be116f71ac23420dd30 Mon Sep 17 00:00:00 2001 From: Kristin Cowalcijk Date: Thu, 3 Apr 2025 13:53:27 +0800 Subject: [PATCH 2/6] Support executeBroadcast for CometColumnarToRow --- .../comet/CometSparkSessionExtensions.scala | 13 ++++++++++-- .../comet/CometBroadcastExchangeExec.scala | 3 ++- .../sql/comet/CometColumnarToRowExec.scala | 21 ++++++++++++++----- 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 3153d14ab5..daa1f19ea3 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -21,6 +21,7 @@ package org.apache.comet import java.nio.ByteOrder +import scala.annotation.tailrec import scala.collection.mutable.ListBuffer import org.apache.spark.SparkConf @@ -37,7 +38,7 @@ import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec, CometShuffleManager} import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat @@ -1148,7 +1149,7 @@ class CometSparkSessionExtensions // and CometSparkToColumnarExec sparkToColumnar.child } - case c @ ColumnarToRowExec(child) if child.exists(_.isInstanceOf[CometPlan]) => + case c @ ColumnarToRowExec(child) if hasCometNativeChild(child) => val op = CometColumnarToRowExec(child) if (c.logicalLink.isEmpty) { op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG) @@ -1182,6 +1183,14 @@ class CometSparkSessionExtensions } } } + + @tailrec + private def hasCometNativeChild(op: SparkPlan): Boolean = { + op match { + case c: QueryStageExec => hasCometNativeChild(c.plan) + case _ => op.exists(_.isInstanceOf[CometPlan]) + } + } } object CometSparkSessionExtensions extends Logging { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index c957e17081..a991964b6c 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -63,7 +63,8 @@ case class CometBroadcastExchangeExec( override val output: Seq[Attribute], mode: BroadcastMode, override val child: SparkPlan) - extends BroadcastExchangeLike { + extends BroadcastExchangeLike + with CometPlan { import CometBroadcastExchangeExec._ override val runId: UUID = UUID.randomUUID diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala index cc150ef535..0391a1c3b3 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala @@ -36,6 +36,7 @@ import 
org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.comet.util.{Utils => CometUtils} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{CodegenSupport, ColumnarToRowTransition, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.vectorized.{ConstantColumnVector, WritableColumnVector} import org.apache.spark.sql.types._ @@ -93,6 +94,8 @@ case class CometColumnarToRowExec(child: SparkPlan) private val runId: UUID = UUID.randomUUID + private lazy val cometBroadcastExchange = findCometBroadcastExchange(child) + @transient lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]]( @@ -102,7 +105,7 @@ case class CometColumnarToRowExec(child: SparkPlan) // Setup a job group here so later it may get cancelled by groupId if necessary. sparkContext.setJobGroup( runId.toString, - s"ColumnarToRow broadcast exchange (runId $runId)", + s"CometColumnarToRow broadcast exchange (runId $runId)", interruptOnCancel = true) val numOutputRows = longMetric("numOutputRows") @@ -119,7 +122,7 @@ case class CometColumnarToRowExec(child: SparkPlan) batch.rowIterator().asScala.map(toUnsafe) } - val mode = child.asInstanceOf[CometBroadcastExchangeExec].mode + val mode = cometBroadcastExchange.get.mode val relation = mode.transform(rows, Some(numOutputRows.value)) val broadcasted = sparkContext.broadcastInternal(relation, serializedOnly = true) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) @@ -146,10 +149,10 @@ case class CometColumnarToRowExec(child: SparkPlan) } override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - if (!child.isInstanceOf[CometBroadcastExchangeExec]) { + if (cometBroadcastExchange.isEmpty) { throw new SparkException( - "ColumnarToRowExec only supports doExecuteBroadcast when child is " + - "CometBroadcastExchange, but got " + child.nodeName) + "ColumnarToRowExec only supports doExecuteBroadcast when child contains a " + + "CometBroadcastExchange, but got " + child) } try { @@ -165,6 +168,14 @@ case class CometColumnarToRowExec(child: SparkPlan) } } + private def findCometBroadcastExchange(op: SparkPlan): Option[CometBroadcastExchangeExec] = { + op match { + case b: CometBroadcastExchangeExec => Some(b) + case b: BroadcastQueryStageExec => findCometBroadcastExchange(b.plan) + case _ => op.children.collectFirst(Function.unlift(findCometBroadcastExchange)) + } + } + /** * Generate [[ColumnVector]] expressions for our parent to consume as rows. This is called once * per [[ColumnVector]] in the batch. 
From 79919c46ff449f7315d1c27c910871622aff34ee Mon Sep 17 00:00:00 2001 From: Kristin Cowalcijk Date: Thu, 3 Apr 2025 17:46:08 +0800 Subject: [PATCH 3/6] Add tests, revert my changes to how CometBroadcastExchange is displayed --- .../comet/CometBroadcastExchangeExec.scala | 2 +- .../apache/comet/exec/CometExecSuite.scala | 87 ++++++++++++++++++- 2 files changed, 87 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index a991964b6c..c17b2f7856 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -252,7 +252,7 @@ case class CometBroadcastExchangeExec( override def hashCode(): Int = Objects.hashCode(child) - override def stringArgs: Iterator[Any] = Iterator(output, mode, child) + override def stringArgs: Iterator[Any] = Iterator(output, child) override protected def withNewChildInternal(newChild: SparkPlan): CometBroadcastExchangeExec = copy(child = newChild) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index f8c1cf90f8..b755971a73 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Bloom import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec} -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec} import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec} @@ -750,6 +750,91 @@ class CometExecSuite extends CometTestBase { } } + test("Comet Shuffled Join should be optimized to CometBroadcastHashJoin by AQE") { + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "10485760", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_MODE.key -> "native") { + withParquetTable((0 until 100).map(i => (i, i + 1)), "tbl_a") { + withParquetTable((0 until 100).map(i => (i, i + 2)), "tbl_b") { + withParquetTable((0 until 100).map(i => (i, i + 3)), "tbl_c") { + val df = sql("""SELECT /*+ BROADCAST(c) */ a1, sum_b2, c._2 FROM ( + | SELECT a._1 a1, SUM(b._2) sum_b2 FROM tbl_a a + | JOIN tbl_b b ON a._1 = b._1 + | GROUP BY a._1) t + |JOIN tbl_c c ON t.a1 = c._1 + |""".stripMargin) + checkSparkAnswerAndOperator(df) + + // Before AQE: 1 broadcast join + var broadcastHashJoinExec = stripAQEPlan(df.queryExecution.executedPlan).collect { + case s: CometBroadcastHashJoinExec => s + } + assert(broadcastHashJoinExec.length == 1) + + // After AQE: shuffled join optimized to broadcast join + df.collect() + broadcastHashJoinExec = stripAQEPlan(df.queryExecution.executedPlan).collect { + case s: CometBroadcastHashJoinExec => s + } + 
assert(broadcastHashJoinExec.length == 2) + } + } + } + } + } + + test("CometBroadcastExchange could be converted to rows using CometColumnarToRow") { + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "10485760", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_MODE.key -> "auto") { + withParquetTable((0 until 100).map(i => (i, i + 1)), "tbl_a") { + withParquetTable((0 until 100).map(i => (i, i + 2)), "tbl_b") { + withParquetTable((0 until 100).map(i => (i, i + 3)), "tbl_c") { + val df = sql("""SELECT /*+ BROADCAST(c) */ a1, sum_b2, c._2 FROM ( + | SELECT a._1 a1, SUM(b._2) sum_b2 FROM tbl_a a + | JOIN tbl_b b ON a._1 = b._1 + | GROUP BY a._1) t + |JOIN tbl_c c ON t.a1 = c._1 + |""".stripMargin) + checkSparkAnswer(df) + + // Before AQE: one CometBroadcastExchange, no CometColumnarToRow + var columnarToRowExec = stripAQEPlan(df.queryExecution.executedPlan).collect { + case s: CometColumnarToRowExec => s + } + assert(columnarToRowExec.isEmpty) + + // Disable CometExecRule after the initial plan is generated. The CometSortMergeJoin and + // CometBroadcastHashJoin nodes in the initial plan will be converted to Spark BroadcastHashJoin + // during AQE. This will make CometBroadcastExchangeExec being converted to rows to be used by + // Spark BroadcastHashJoin. + spark.conf.set(CometConf.COMET_EXEC_ENABLED.key, "false") + df.collect() + + // After AQE: CometBroadcastExchange has to be converted to rows to conform to Spark + // BroadcastHashJoin. + columnarToRowExec = stripAQEPlan(df.queryExecution.executedPlan).collect { + case s: CometColumnarToRowExec => s + } + assert(columnarToRowExec.length == 1) + val broadcastQueryStage = + columnarToRowExec.head.find(_.isInstanceOf[BroadcastQueryStageExec]) + assert(broadcastQueryStage.isDefined) + assert( + broadcastQueryStage.get + .asInstanceOf[BroadcastQueryStageExec] + .broadcast + .isInstanceOf[CometBroadcastExchangeExec]) + } + } + } + } + } + test("expand operator") { val data1 = (0 until 1000) .map(_ % 5) // reduce value space to trigger dictionary encoding From 1819c791b9f6702d48f69178935d74edb3eccb22 Mon Sep 17 00:00:00 2001 From: Kristin Cowalcijk Date: Thu, 3 Apr 2025 18:22:45 +0800 Subject: [PATCH 4/6] Making newly added test fail before applying this fix --- .../org/apache/comet/exec/CometExecSuite.scala | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index b755971a73..3743a4bb05 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecuti import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec} import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec, HashJoin, SortMergeJoinExec} import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery import 
org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.expressions.Window @@ -817,10 +817,17 @@ class CometExecSuite extends CometTestBase { // After AQE: CometBroadcastExchange has to be converted to rows to conform to Spark // BroadcastHashJoin. - columnarToRowExec = stripAQEPlan(df.queryExecution.executedPlan).collect { - case s: CometColumnarToRowExec => s + val plan = stripAQEPlan(df.queryExecution.executedPlan) + columnarToRowExec = plan.collect { case s: CometColumnarToRowExec => + s } assert(columnarToRowExec.length == 1) + + // This ColumnarToRowExec should be the immediate child of BroadcastHashJoinExec + val parent = plan.find(_.children.contains(columnarToRowExec.head)) + assert(parent.get.isInstanceOf[BroadcastHashJoinExec]) + + // There should be a CometBroadcastExchangeExec under CometColumnarToRowExec val broadcastQueryStage = columnarToRowExec.head.find(_.isInstanceOf[BroadcastQueryStageExec]) assert(broadcastQueryStage.isDefined) From afb44905b0ae0b32501c5b376920e1179e4cfc67 Mon Sep 17 00:00:00 2001 From: Kristin Cowalcijk Date: Thu, 3 Apr 2025 19:38:40 +0800 Subject: [PATCH 5/6] Remove unused imports --- spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 3743a4bb05..4fdd38223f 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecuti import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec} import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec, HashJoin, SortMergeJoinExec} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.expressions.Window From f73c7503932876f945a7e3a11bd96d3e239bd7d9 Mon Sep 17 00:00:00 2001 From: Kristin Cowalcijk Date: Thu, 3 Apr 2025 21:57:26 +0800 Subject: [PATCH 6/6] Fix test failure caused by spark conf pollution --- .../test/scala/org/apache/comet/exec/CometExecSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 4fdd38223f..3fe300e973 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -812,8 +812,9 @@ class CometExecSuite extends CometTestBase { // CometBroadcastHashJoin nodes in the initial plan will be converted to Spark BroadcastHashJoin // during AQE. This will make CometBroadcastExchangeExec being converted to rows to be used by // Spark BroadcastHashJoin. 
- spark.conf.set(CometConf.COMET_EXEC_ENABLED.key, "false") - df.collect() + withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { + df.collect() + } // After AQE: CometBroadcastExchange has to be converted to rows to conform to Spark // BroadcastHashJoin.