@@ -17,7 +17,7 @@
* under the License.
*/

package org.apache.spark.sql.comet.execution.shuffle
package org.apache.spark.sql.comet.execution.arrow

import java.nio.channels.ReadableByteChannel

@@ -19,7 +19,7 @@

package org.apache.spark.sql.comet.util

import java.io.{DataOutputStream, File}
import java.io.{DataInputStream, DataOutputStream, File}
import java.nio.ByteBuffer
import java.nio.channels.Channels

@@ -35,6 +35,7 @@ import org.apache.arrow.vector.types._
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
import org.apache.spark.{SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
@@ -226,6 +227,28 @@ object Utils {
}
}

/**
* Decodes the serialized bytes back into ColumnarBatches.
* @param bytes
* the serialized batches
* @param source
* the class that calls this method
* @return
* an iterator of ColumnarBatch
*/
def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = {
if (bytes.size == 0) {
return Iterator.empty
}

// use Spark's compression codec (LZ4 by default) and not Comet's compression
val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val cbbis = bytes.toInputStream()
val ins = new DataInputStream(codec.compressedInputStream(cbbis))
// batches are in Arrow IPC format
new ArrowReaderIterator(Channels.newChannel(ins), source)
}

def getBatchFieldVectors(
batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
var provider: Option[DictionaryProvider] = None
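For orientation, a minimal usage sketch of the relocated helper; the readAll wrapper, the pieces parameter, and the caller label are illustrative assumptions rather than part of this change, and the call mirrors the sites added further down in CometBatchRDD and CometColumnarToRowExec:

import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.io.ChunkedByteBuffer

// Each ChunkedByteBuffer is expected to hold Arrow IPC data compressed with
// Spark's codec (LZ4 by default), as written by the Comet broadcast exchange.
def readAll(pieces: Array[ChunkedByteBuffer]): Iterator[ColumnarBatch] =
  pieces.iterator.flatMap(Utils.decodeBatches(_, "ExampleCaller"))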
@@ -21,6 +21,7 @@ package org.apache.comet

import java.nio.ByteOrder

import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

import org.apache.spark.SparkConf
@@ -37,7 +38,7 @@ import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec, CometShuffleManager}
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
@@ -728,6 +729,18 @@ class CometSparkSessionExtensions
s
}

case s @ BroadcastQueryStageExec(
_,
ReusedExchangeExec(_, _: CometBroadcastExchangeExec),
_) =>
val newOp = transform1(s)
newOp match {
case Some(nativeOp) =>
CometSinkPlaceHolder(nativeOp, s, s)
case None =>
s
}

// `CometBroadcastExchangeExec`'s broadcast output is not compatible with Spark's broadcast
// exchange. It is only used for Comet native execution. We only transform Spark broadcast
// exchange to Comet broadcast exchange if its downstream is a Comet native plan or if the
@@ -739,7 +752,7 @@ class CometSparkSessionExtensions
CometConf.COMET_EXEC_BROADCAST_EXCHANGE_ENABLED.get(conf) =>
QueryPlanSerde.operator2Proto(b) match {
case Some(nativeOp) =>
val cometOp = CometBroadcastExchangeExec(b, b.output, b.child)
val cometOp = CometBroadcastExchangeExec(b, b.output, b.mode, b.child)
CometSinkPlaceHolder(nativeOp, b, cometOp)
case None => b
}
@@ -1136,7 +1149,7 @@ class CometSparkSessionExtensions
// and CometSparkToColumnarExec
sparkToColumnar.child
}
case c @ ColumnarToRowExec(child) if child.exists(_.isInstanceOf[CometPlan]) =>
case c @ ColumnarToRowExec(child) if hasCometNativeChild(child) =>
val op = CometColumnarToRowExec(child)
if (c.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
@@ -1170,6 +1183,14 @@ class CometSparkSessionExtensions
}
}
}

@tailrec
private def hasCometNativeChild(op: SparkPlan): Boolean = {
op match {
case c: QueryStageExec => hasCometNativeChild(c.plan)
case _ => op.exists(_.isInstanceOf[CometPlan])
}
}
}

object CometSparkSessionExtensions extends Logging {
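Two notes on the rule changes above. The new BroadcastQueryStageExec case covers AQE plans in which an already-materialized Comet broadcast is wrapped in a ReusedExchangeExec; the QueryPlanSerde change below allows the same shape as a native sink. The hasCometNativeChild helper is needed because a QueryStageExec is a leaf node, so child.exists(_.isInstanceOf[CometPlan]) cannot see the plan the stage wraps, and the helper unwraps stages first. As a hedged repro sketch, joining the same broadcastable table twice is one way such a reused broadcast stage can appear (the session setup, table names, and sizes below are invented for illustration):

// Assumes an active SparkSession named spark with Comet loaded and AQE enabled.
spark.conf.set("spark.sql.adaptive.enabled", "true")

val dim  = spark.range(100).withColumnRenamed("id", "k")      // small, broadcastable side
val fact = spark.range(1000000).withColumnRenamed("id", "k")

// Both joins broadcast dim; Spark's exchange-reuse rule can replace the second
// broadcast with a ReusedExchangeExec, which AQE then wraps in a
// BroadcastQueryStageExec, the shape matched by the new case above.
fact.join(dim, "k").union(fact.join(dim, "k")).count()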
@@ -2765,6 +2765,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true
case _: TakeOrderedAndProjectExec => true
case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
case BroadcastQueryStageExec(_, ReusedExchangeExec(_, _: CometBroadcastExchangeExec), _) =>
true
case _: BroadcastExchangeExec => true
case _: WindowExec => true
case _ => false
@@ -19,22 +19,20 @@

package org.apache.spark.sql.comet

import java.io.DataInputStream
import java.nio.channels.Channels
import java.util.UUID
import java.util.concurrent.{Future, TimeoutException, TimeUnit}

import scala.concurrent.{ExecutionContext, Promise}
import scala.concurrent.duration.NANOSECONDS
import scala.util.control.NonFatal

import org.apache.spark.{broadcast, Partition, SparkContext, SparkEnv, TaskContext}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
@@ -63,8 +61,10 @@ import org.apache.comet.CometRuntimeException
case class CometBroadcastExchangeExec(
originalPlan: SparkPlan,
override val output: Seq[Attribute],
mode: BroadcastMode,
override val child: SparkPlan)
extends BroadcastExchangeLike {
extends BroadcastExchangeLike
with CometPlan {
import CometBroadcastExchangeExec._

override val runId: UUID = UUID.randomUUID
@@ -77,7 +77,7 @@ case class CometBroadcastExchangeExec(
"broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast"))

override def doCanonicalize(): SparkPlan = {
CometBroadcastExchangeExec(null, null, child.canonicalized)
CometBroadcastExchangeExec(null, null, mode, child.canonicalized)
}

override def runtimeStatistics: Statistics = {
@@ -86,6 +86,8 @@ case class CometBroadcastExchangeExec(
Statistics(dataSize, Some(rowCount))
}

override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)

@transient
private lazy val promise = Promise[broadcast.Broadcast[Any]]()

@@ -289,23 +291,7 @@ class CometBatchRDD(
override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
val partition = split.asInstanceOf[CometBatchPartition]
partition.value.value.toIterator
.flatMap(decodeBatches(_, this.getClass.getSimpleName))
}

/**
* Decodes the byte arrays back to ColumnarBatchs and put them into buffer.
*/
private def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = {
if (bytes.size == 0) {
return Iterator.empty
}

// use Spark's compression codec (LZ4 by default) and not Comet's compression
val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val cbbis = bytes.toInputStream()
val ins = new DataInputStream(codec.compressedInputStream(cbbis))
// batches are in Arrow IPC format
new ArrowReaderIterator(Channels.newChannel(ins), source)
.flatMap(Utils.decodeBatches(_, this.getClass.getSimpleName))
}
}

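A note on the new mode field: besides letting the operator report BroadcastPartitioning(mode), carrying the BroadcastMode is what allows a downstream consumer to rebuild the row-based relation Spark's joins expect. A minimal sketch of that use, mirroring the code added to CometColumnarToRowExec below (rows and rowCount are placeholders for the decoded unsafe rows and their count):

// mode is typically HashedRelationBroadcastMode for a broadcast hash join or
// IdentityBroadcastMode for a broadcast nested-loop join.
// Assumed to be in scope: rows: Iterator[InternalRow], rowCount: Long.
val relation: Any = mode.transform(rows, Some(rowCount))
val broadcasted = sparkContext.broadcastInternal(relation, serializedOnly = true)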
@@ -19,20 +19,30 @@

package org.apache.spark.sql.comet

import java.util.UUID
import java.util.concurrent.{Future, TimeoutException, TimeUnit}

import scala.collection.JavaConverters._
import scala.concurrent.Promise
import scala.util.control.NonFatal

import org.apache.spark.{broadcast, SparkException}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{CodegenSupport, ColumnarToRowTransition, SparkPlan}
import org.apache.spark.sql.comet.util.{Utils => CometUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{CodegenSupport, ColumnarToRowTransition, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.vectorized.{ConstantColumnVector, WritableColumnVector}
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.util.Utils
import org.apache.spark.util.{SparkFatalException, Utils}
import org.apache.spark.util.io.ChunkedByteBuffer

import org.apache.comet.vector.CometPlainVector

@@ -76,6 +86,96 @@ case class CometColumnarToRowExec(child: SparkPlan)
}
}

@transient
private lazy val promise = Promise[broadcast.Broadcast[Any]]()

@transient
private val timeout: Long = conf.broadcastTimeout

private val runId: UUID = UUID.randomUUID

private lazy val cometBroadcastExchange = findCometBroadcastExchange(child)

@transient
lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
session,
CometBroadcastExchangeExec.executionContext) {
try {
// Setup a job group here so later it may get cancelled by groupId if necessary.
sparkContext.setJobGroup(
runId.toString,
s"CometColumnarToRow broadcast exchange (runId $runId)",
interruptOnCancel = true)

val numOutputRows = longMetric("numOutputRows")
val numInputBatches = longMetric("numInputBatches")
val localOutput = this.output
val broadcastColumnar = child.executeBroadcast()
val serializedBatches = broadcastColumnar.value.asInstanceOf[Array[ChunkedByteBuffer]]
val toUnsafe = UnsafeProjection.create(localOutput, localOutput)
val rows = serializedBatches.iterator
.flatMap(CometUtils.decodeBatches(_, this.getClass.getSimpleName))
.flatMap { batch =>
numInputBatches += 1
numOutputRows += batch.numRows()
batch.rowIterator().asScala.map(toUnsafe)
}

val mode = cometBroadcastExchange.get.mode
val relation = mode.transform(rows, Some(numOutputRows.value))
val broadcasted = sparkContext.broadcastInternal(relation, serializedOnly = true)
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
promise.trySuccess(broadcasted)
broadcasted
} catch {
// SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
// SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
// will catch this exception and re-throw the wrapped fatal throwable.
case oe: OutOfMemoryError =>
val ex = new SparkFatalException(oe)
promise.tryFailure(ex)
throw ex
case e if !NonFatal(e) =>
val ex = new SparkFatalException(e)
promise.tryFailure(ex)
throw ex
case e: Throwable =>
promise.tryFailure(e)
throw e
}
}
}

override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
if (cometBroadcastExchange.isEmpty) {
throw new SparkException(
"ColumnarToRowExec only supports doExecuteBroadcast when child contains a " +
"CometBroadcastExchange, but got " + child)
}

try {
relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
} catch {
case ex: TimeoutException =>
logError(s"Could not execute broadcast in $timeout secs.", ex)
if (!relationFuture.isDone) {
sparkContext.cancelJobGroup(runId.toString)
relationFuture.cancel(true)
}
throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(ex))
}
}

private def findCometBroadcastExchange(op: SparkPlan): Option[CometBroadcastExchangeExec] = {
op match {
case b: CometBroadcastExchangeExec => Some(b)
case b: BroadcastQueryStageExec => findCometBroadcastExchange(b.plan)
case _ => op.children.collectFirst(Function.unlift(findCometBroadcastExchange))
}
}

/**
* Generate [[ColumnVector]] expressions for our parent to consume as rows. This is called once
* per [[ColumnVector]] in the batch.
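To summarize the broadcast path added above: relationFuture decodes the broadcast Arrow batches with Utils.decodeBatches, converts every row to unsafe format, applies the originating exchange's BroadcastMode, and re-broadcasts the resulting relation, while doExecuteBroadcast simply awaits that future with the usual timeout and job-group cancellation. A hedged sketch of the consumer side follows; buildSide stands for whatever plan sits on top of this node, for example the build side of a broadcast hash join:

// Spark's broadcast joins obtain their build-side relation through
// executeBroadcast. SparkPlan's default doExecuteBroadcast throws, so a
// CometColumnarToRowExec over a Comet broadcast exchange can now answer the call.
val relation: Any = buildSide.executeBroadcast[Any]().value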