From 00c631e107ce36724f8d8b102858124f7fe5d49e Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Fri, 19 Sep 2025 14:32:58 +0800 Subject: [PATCH 1/2] fix: Avoid spark plan execution cache preventing CometBatchRDD numPartitions change --- .../spark/sql/comet/CometBroadcastExchangeExec.scala | 12 +++++++++++- .../scala/org/apache/spark/sql/comet/operators.scala | 8 ++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index 9114caf6e3..4209591911 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -26,7 +26,7 @@ import scala.concurrent.{ExecutionContext, Promise} import scala.concurrent.duration.NANOSECONDS import scala.util.control.NonFatal -import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext} +import org.apache.spark.{broadcast, Partition, SparkContext, SparkException, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute @@ -224,6 +224,16 @@ case class CometBroadcastExchangeExec( new CometBatchRDD(sparkContext, getNumPartitions(), broadcasted) } + // After https://issues.apache.org/jira/browse/SPARK-48195, Spark plan will cache created RDD. + // Since we may change the number of partitions in CometBatchRDD, + // we need a method that always creates a new CometBatchRDD. + def executeColumnarWithoutCache(): RDD[ColumnarBatch] = { + if (isCanonicalizedPlan) { + throw SparkException.internalError("A canonicalized plan is not supposed to be executed.") + } + doExecuteColumnar() + } + override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { try { relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]] diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index aa0ecdcb61..655319b6e4 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -275,17 +275,17 @@ abstract class CometNativeExec extends CometExec { plan match { case c: CometBroadcastExchangeExec => inputs += c - .executeColumnar() + .executeColumnarWithoutCache() .asInstanceOf[CometBatchRDD] .withNumPartitions(firstNonBroadcastPlanNumPartitions) case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) => inputs += c - .executeColumnar() + .executeColumnarWithoutCache() .asInstanceOf[CometBatchRDD] .withNumPartitions(firstNonBroadcastPlanNumPartitions) case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) => inputs += c - .executeColumnar() + .executeColumnarWithoutCache() .asInstanceOf[CometBatchRDD] .withNumPartitions(firstNonBroadcastPlanNumPartitions) case BroadcastQueryStageExec( @@ -293,7 +293,7 @@ abstract class CometNativeExec extends CometExec { ReusedExchangeExec(_, c: CometBroadcastExchangeExec), _) => inputs += c - .executeColumnar() + .executeColumnarWithoutCache() .asInstanceOf[CometBatchRDD] .withNumPartitions(firstNonBroadcastPlanNumPartitions) case _: CometNativeExec => From bd05a0be4a5d20e2f595f75dd3918ddeebde47da Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Fri, 19 Sep 2025 20:04:31 +0800 Subject: [PATCH 2/2] refactor --- .../comet/CometBroadcastExchangeExec.scala | 22 +++++-------------- .../apache/spark/sql/comet/operators.scala | 20 ++++------------- 2 files changed, 10 insertions(+), 32 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index 4209591911..95770592fd 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -102,14 +102,8 @@ case class CometBroadcastExchangeExec( @transient private lazy val maxBroadcastRows = 512000000 - private var numPartitions: Option[Int] = None - - def setNumPartitions(numPartitions: Int): CometBroadcastExchangeExec = { - this.numPartitions = Some(numPartitions) - this - } def getNumPartitions(): Int = { - numPartitions.getOrElse(child.executeColumnar().getNumPartitions) + child.executeColumnar().getNumPartitions } @transient @@ -227,11 +221,13 @@ case class CometBroadcastExchangeExec( // After https://issues.apache.org/jira/browse/SPARK-48195, Spark plan will cache created RDD. // Since we may change the number of partitions in CometBatchRDD, // we need a method that always creates a new CometBatchRDD. - def executeColumnarWithoutCache(): RDD[ColumnarBatch] = { + def executeColumnar(numPartitions: Int): RDD[ColumnarBatch] = { if (isCanonicalizedPlan) { throw SparkException.internalError("A canonicalized plan is not supposed to be executed.") } - doExecuteColumnar() + + val broadcasted = executeBroadcast[Array[ChunkedByteBuffer]]() + new CometBatchRDD(sparkContext, numPartitions, broadcasted) } override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { @@ -286,7 +282,7 @@ object CometBroadcastExchangeExec { */ class CometBatchRDD( sc: SparkContext, - @volatile var numPartitions: Int, + val numPartitions: Int, value: broadcast.Broadcast[Array[ChunkedByteBuffer]]) extends RDD[ColumnarBatch](sc, Nil) { @@ -299,12 +295,6 @@ class CometBatchRDD( partition.value.value.toIterator .flatMap(Utils.decodeBatches(_, this.getClass.getSimpleName)) } - - def withNumPartitions(numPartitions: Int): CometBatchRDD = { - this.numPartitions = numPartitions - this - } - } class CometBatchPartition( diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 655319b6e4..a7cfacc475 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -274,28 +274,16 @@ abstract class CometNativeExec extends CometExec { sparkPlans.zipWithIndex.foreach { case (plan, idx) => plan match { case c: CometBroadcastExchangeExec => - inputs += c - .executeColumnarWithoutCache() - .asInstanceOf[CometBatchRDD] - .withNumPartitions(firstNonBroadcastPlanNumPartitions) + inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) => - inputs += c - .executeColumnarWithoutCache() - .asInstanceOf[CometBatchRDD] - .withNumPartitions(firstNonBroadcastPlanNumPartitions) + inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) => - inputs += c - .executeColumnarWithoutCache() - .asInstanceOf[CometBatchRDD] - .withNumPartitions(firstNonBroadcastPlanNumPartitions) + inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) case BroadcastQueryStageExec( _, ReusedExchangeExec(_, c: CometBroadcastExchangeExec), _) => - inputs += c - .executeColumnarWithoutCache() - .asInstanceOf[CometBatchRDD] - .withNumPartitions(firstNonBroadcastPlanNumPartitions) + inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) case _: CometNativeExec => // no-op case _ if idx == firstNonBroadcastPlan.get._2 =>