Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import scala.concurrent.{ExecutionContext, Promise}
import scala.concurrent.duration.NANOSECONDS
import scala.util.control.NonFatal

import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
import org.apache.spark.{broadcast, Partition, SparkContext, SparkException, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
Expand Down Expand Up @@ -102,14 +102,8 @@ case class CometBroadcastExchangeExec(
@transient
private lazy val maxBroadcastRows = 512000000

private var numPartitions: Option[Int] = None

def setNumPartitions(numPartitions: Int): CometBroadcastExchangeExec = {
this.numPartitions = Some(numPartitions)
this
}
def getNumPartitions(): Int = {
numPartitions.getOrElse(child.executeColumnar().getNumPartitions)
child.executeColumnar().getNumPartitions
}

@transient
Expand Down Expand Up @@ -224,6 +218,18 @@ case class CometBroadcastExchangeExec(
new CometBatchRDD(sparkContext, getNumPartitions(), broadcasted)
}

// After https://issues.apache.org/jira/browse/SPARK-48195, Spark plan will cache created RDD.
// Since we may change the number of partitions in CometBatchRDD,
// we need a method that always creates a new CometBatchRDD.
def executeColumnar(numPartitions: Int): RDD[ColumnarBatch] = {
if (isCanonicalizedPlan) {
throw SparkException.internalError("A canonicalized plan is not supposed to be executed.")
}

val broadcasted = executeBroadcast[Array[ChunkedByteBuffer]]()
new CometBatchRDD(sparkContext, numPartitions, broadcasted)
}

override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
try {
relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
Expand Down Expand Up @@ -276,7 +282,7 @@ object CometBroadcastExchangeExec {
*/
class CometBatchRDD(
sc: SparkContext,
numPartitions: Int,
val numPartitions: Int,
value: broadcast.Broadcast[Array[ChunkedByteBuffer]])
extends RDD[ColumnarBatch](sc, Nil) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,16 +272,16 @@ abstract class CometNativeExec extends CometExec {
sparkPlans.zipWithIndex.foreach { case (plan, idx) =>
plan match {
case c: CometBroadcastExchangeExec =>
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) =>
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) =>
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
case BroadcastQueryStageExec(
_,
ReusedExchangeExec(_, c: CometBroadcastExchangeExec),
_) =>
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
case _: CometNativeExec =>
// no-op
case _ if idx == firstNonBroadcastPlan.get._2 =>
Expand Down
Loading