diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index 21b395982b..95770592fd 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -26,7 +26,7 @@ import scala.concurrent.{ExecutionContext, Promise}
 import scala.concurrent.duration.NANOSECONDS
 import scala.util.control.NonFatal
 
-import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
+import org.apache.spark.{broadcast, Partition, SparkContext, SparkException, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -102,14 +102,8 @@ case class CometBroadcastExchangeExec(
   @transient
   private lazy val maxBroadcastRows = 512000000
 
-  private var numPartitions: Option[Int] = None
-
-  def setNumPartitions(numPartitions: Int): CometBroadcastExchangeExec = {
-    this.numPartitions = Some(numPartitions)
-    this
-  }
   def getNumPartitions(): Int = {
-    numPartitions.getOrElse(child.executeColumnar().getNumPartitions)
+    child.executeColumnar().getNumPartitions
   }
 
   @transient
@@ -224,6 +218,18 @@ case class CometBroadcastExchangeExec(
     new CometBatchRDD(sparkContext, getNumPartitions(), broadcasted)
   }
 
+  // After https://issues.apache.org/jira/browse/SPARK-48195, the Spark plan caches the created RDD.
+  // Since we may change the number of partitions in CometBatchRDD,
+  // we need a method that always creates a new CometBatchRDD.
+  def executeColumnar(numPartitions: Int): RDD[ColumnarBatch] = {
+    if (isCanonicalizedPlan) {
+      throw SparkException.internalError("A canonicalized plan is not supposed to be executed.")
+    }
+
+    val broadcasted = executeBroadcast[Array[ChunkedByteBuffer]]()
+    new CometBatchRDD(sparkContext, numPartitions, broadcasted)
+  }
+
   override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
     try {
       relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
@@ -276,7 +282,7 @@ object CometBroadcastExchangeExec {
  */
 class CometBatchRDD(
     sc: SparkContext,
-    numPartitions: Int,
+    val numPartitions: Int,
     value: broadcast.Broadcast[Array[ChunkedByteBuffer]])
     extends RDD[ColumnarBatch](sc, Nil) {
 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 593f4f3a45..40da23a18e 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -272,16 +272,16 @@ abstract class CometNativeExec extends CometExec {
       sparkPlans.zipWithIndex.foreach { case (plan, idx) =>
         plan match {
           case c: CometBroadcastExchangeExec =>
-            inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
+            inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
           case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) =>
-            inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
+            inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
           case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) =>
-            inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
+            inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
           case BroadcastQueryStageExec(
                 _,
                 ReusedExchangeExec(_, c: CometBroadcastExchangeExec),
                 _) =>
-            inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
+            inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
           case _: CometNativeExec =>
           // no-op
           case _ if idx == firstNonBroadcastPlan.get._2 =>
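
Note on the rationale in the comment above: once SPARK-48195 makes the plan memoize the RDD returned by executeColumnar(), the removed setNumPartitions()/executeColumnar() pair can silently return a result built with a stale partition count. The following is a minimal self-contained sketch of that hazard, using hypothetical stand-in types (CachedExec, Batches) rather than the real Spark/Comet classes:

// Hypothetical model of the caching hazard; Batches stands in for RDD[ColumnarBatch].
final case class Batches(numPartitions: Int)

class CachedExec {
  // The removed design: a mutable field consulted when the result is first built.
  private var numPartitions: Option[Int] = None
  def setNumPartitions(n: Int): CachedExec = { numPartitions = Some(n); this }

  // SPARK-48195-style behavior: the first result is memoized and reused thereafter.
  private lazy val cached: Batches = Batches(numPartitions.getOrElse(1))
  def executeColumnar(): Batches = cached

  // The fix: take the partition count as a parameter and always build a fresh result.
  def executeColumnar(n: Int): Batches = Batches(n)
}

object Demo extends App {
  val exec = new CachedExec
  println(exec.executeColumnar())                     // Batches(1): built and cached with the default
  println(exec.setNumPartitions(8).executeColumnar()) // still Batches(1): the cache ignores the setter
  println(exec.executeColumnar(8))                    // Batches(8): a fresh value on every call
}

Passing the partition count as an argument, as the new executeColumnar(numPartitions) overload does, sidesteps the cached RDD entirely because each call constructs a new CometBatchRDD.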