+ * This class provides static methods that can be called from native code via JNI, delegating to
+ * the Scala Native object singleton.
+ */
+public class NativeJNIBridge {
+
+  /**
+   * Gets Iceberg partition tasks for the current thread (JNI-accessible static method).
+   *
+   * This method is called by native Rust code via JNI to retrieve partition-specific tasks
+   * during Iceberg scan execution.
+   *
+   * @return Serialized protobuf bytes containing IcebergFileScanTask messages, or null if not set
+   */
+  public static byte[] getIcebergPartitionTasksInternal() {
+    return Native$.MODULE$.getIcebergPartitionTasksInternal();
+  }
+}
diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala
index 55e0c70e72..3f93b05b86 100644
--- a/spark/src/main/scala/org/apache/comet/Native.scala
+++ b/spark/src/main/scala/org/apache/comet/Native.scala
@@ -248,4 +248,62 @@ class Native extends NativeBase {
    */
  @native def columnarToRowClose(c2rHandle: Long): Unit
 
+  /**
+   * Get Iceberg partition tasks for the current partition.
+   *
+   * Used by the native Iceberg scan to retrieve partition-specific FileScanTask bytes that were
+   * set via setIcebergPartitionTasks. This enables on-demand task retrieval instead of
+   * broadcasting all partition tasks to all executors.
+   *
+   * @return
+   *   Serialized protobuf bytes containing IcebergFileScanTask messages for the current
+   *   partition, or null if no tasks are set
+   */
+  @native def getIcebergPartitionTasks(): Array[Byte]
+
+}
+
+object Native {
+
+  /**
+   * Thread-local storage for Iceberg partition tasks.
+   *
+   * Stores serialized FileScanTask bytes for the current partition. Set by
+   * CometIcebergNativeScanExec before executing the native plan, and retrieved by native code
+   * via JNI during execution.
+   *
+   * Using a thread-local ensures task isolation between concurrent tasks on the same executor.
+   */
+  private val icebergPartitionTasks = new ThreadLocal[Array[Byte]]()
+
+  /**
+   * Sets Iceberg partition tasks for the current thread.
+   *
+   * @param taskBytes
+   *   Serialized protobuf bytes containing IcebergFileScanTask messages for this partition
+   */
+  def setIcebergPartitionTasks(taskBytes: Array[Byte]): Unit = {
+    icebergPartitionTasks.set(taskBytes)
+  }
+
+  /**
+   * Gets Iceberg partition tasks for the current thread.
+   *
+   * Called from JNI by native code during Iceberg scan execution.
+   *
+   * @return
+   *   Serialized protobuf bytes, or null if not set
+   */
+  def getIcebergPartitionTasksInternal(): Array[Byte] = {
+    icebergPartitionTasks.get()
+  }
+
+  /**
+   * Clears Iceberg partition tasks for the current thread.
+   *
+   * Should be called after native execution completes to prevent memory leaks.
+   */
+  def clearIcebergPartitionTasks(): Unit = {
+    icebergPartitionTasks.remove()
+  }
 }
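The set/get/clear trio above is meant to bracket native execution on each executor task thread: set the bytes, run the native plan, then clear. A minimal sketch of that lifecycle (runPartition and runNativePlan are illustrative stand-ins, not APIs added by this change; only the Native.* calls come from this patch):

    import org.apache.comet.Native

    def runPartition(taskBytes: Array[Byte], runNativePlan: () => Unit): Unit = {
      // Publish this partition's serialized FileScanTasks to native code via the thread-local.
      Native.setIcebergPartitionTasks(taskBytes)
      try {
        // While this runs, native code can call back through NativeJNIBridge to read the bytes.
        runNativePlan()
      } finally {
        // Always clear, so a reused executor thread never leaks or serves stale task bytes.
        Native.clearIcebergPartitionTasks()
      }
    }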
diff --git a/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala b/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala
index 2d772063e4..16c7681dd9 100644
--- a/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala
+++ b/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala
@@ -739,7 +739,8 @@ case class CometIcebergNativeScanMetadata(
     tableSchema: Any,
     globalFieldIdMapping: Map[String, Int],
     catalogProperties: Map[String, String],
-    fileFormat: String)
+    fileFormat: String,
+    partitionTasks: Map[Int, Array[Byte]] = Map.empty)
 
 object CometIcebergNativeScanMetadata extends Logging {
 
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala
index 7238f8ae8c..f34c0c54f6 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala
@@ -41,6 +41,16 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit
 
   override def enabledConfig: Option[ConfigEntry[Boolean]] = None
 
+  /**
+   * Temporary storage for partition tasks extracted during serialization.
+   *
+   * Maps plan ID to the per-partition task map. Used to pass partition-specific task bytes from
+   * convert() to createExec() without embedding them in the protobuf plan. Cleared after
+   * createExec() completes to prevent memory leaks.
+   */
+  private val partitionTasksCache =
+    new java.util.concurrent.ConcurrentHashMap[Long, Map[Int, Array[Byte]]]()
+
   /**
    * Constants specific to Iceberg expression conversion (not in shared IcebergReflection).
    */
@@ -696,6 +706,9 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit
 
     var totalTasks = 0
 
+    // Map to store serialized task bytes per partition (for optimized execution)
+    val partitionTasksMap = mutable.HashMap[Int, Array[Byte]]()
+
     // Get pre-extracted metadata from planning phase
     // If metadata is None, this is a programming error - metadata should have been extracted
     // in CometScanRule before creating CometBatchScanExec
@@ -728,7 +741,7 @@
     scan.wrapped.inputRDD match {
       case rdd: org.apache.spark.sql.execution.datasources.v2.DataSourceRDD =>
         val partitions = rdd.partitions
-        partitions.foreach { partition =>
+        partitions.zipWithIndex.foreach { case (partition, partitionIndex) =>
           val partitionBuilder = OperatorOuterClass.IcebergFilePartition.newBuilder()
 
           val inputPartitions = partition
@@ -972,7 +985,12 @@
             }
           }
 
+          // Serialize this partition's tasks to bytes and store in map
           val builtPartition = partitionBuilder.build()
+          val partitionBytes = builtPartition.toByteArray
+          partitionTasksMap.put(partitionIndex, partitionBytes)
+
+          // Add to protobuf for standard execution path
           icebergScanBuilder.addFilePartitions(builtPartition)
         }
       case _ =>
@@ -1022,6 +1040,19 @@
         s"$partitionDataPoolBytes bytes (protobuf)")
     }
 
+    // Store partition tasks in cache for retrieval in createExec()
+    // Using plan ID as key since it's unique per operator
+    // IMPORTANT: Normalize to Long to avoid Integer/Long mismatch in cache lookups
+    val planId = builder.getPlanId.toLong
+
+    if (partitionTasksMap.nonEmpty) {
+      partitionTasksCache.put(planId, partitionTasksMap.toMap)
+      val avgBytes = partitionTasksMap.values.map(_.length).sum / partitionTasksMap.size
+      logInfo(
+        s"IcebergScan: Cached ${partitionTasksMap.size} partitions " +
+          s"(avg $avgBytes bytes/partition)")
+    }
+
     builder.clearChildren()
     Some(builder.setIcebergScan(icebergScanBuilder).build())
   }
@@ -1039,7 +1070,20 @@
     // Extract metadataLocation from the native operator
     val metadataLocation = nativeOp.getIcebergScan.getMetadataLocation
 
+    // Retrieve partition tasks from cache (if available)
+    // IMPORTANT: Normalize to Long to match storage side (avoid Integer/Long mismatch)
+    val planId = nativeOp.getPlanId.toLong
+
+    val partitionTasks =
+      Option(partitionTasksCache.remove(planId)).getOrElse(Map.empty[Int, Array[Byte]])
+
     // Create the CometIcebergNativeScanExec using the companion object's apply method
-    CometIcebergNativeScanExec(nativeOp, op.wrapped, op.session, metadataLocation, metadata)
+    CometIcebergNativeScanExec(
+      nativeOp,
+      op.wrapped,
+      op.session,
+      metadataLocation,
+      metadata,
+      partitionTasks)
   }
 }
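The two "Normalize to Long" notes above guard against a boxed-key pitfall: ConcurrentHashMap.get/remove accept Object and compare keys with equals(), and a boxed java.lang.Integer never equals a boxed java.lang.Long. A small standalone illustration (not part of the patch), assuming the plan ID can surface as Int on one side and Long on the other:

    import java.util.concurrent.ConcurrentHashMap

    val cache = new ConcurrentHashMap[Long, String]()
    cache.put(42L, "tasks")         // stored under a boxed java.lang.Long key
    val planId: Int = 42
    cache.remove(planId)            // remove(Object) boxes to java.lang.Integer -> no match, returns null
    cache.remove(planId.toLong)     // boxes to java.lang.Long -> matches, returns "tasks"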
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala
index 223ae4fbb7..0abfacf743 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala
@@ -21,12 +21,15 @@ package org.apache.spark.sql.comet
 
 import scala.jdk.CollectionConverters._
 
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.AccumulatorV2
 
 import com.google.common.base.Objects
@@ -49,7 +52,8 @@ case class CometIcebergNativeScanExec(
     override val serializedPlanOpt: SerializedPlan,
     metadataLocation: String,
     numPartitions: Int,
-    @transient nativeIcebergScanMetadata: CometIcebergNativeScanMetadata)
+    @transient nativeIcebergScanMetadata: CometIcebergNativeScanMetadata,
+    @transient partitionTasks: Map[Int, Array[Byte]] = Map.empty)
     extends CometLeafExec {
 
   override val supportsColumnar: Boolean = true
@@ -146,6 +150,66 @@ case class CometIcebergNativeScanExec(
     baseMetrics ++ icebergMetrics + ("num_splits" -> numSplitsMetric)
   }
 
+  override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    // Check if we should use JNI-based task retrieval (flag set in protobuf during planning)
+    // Note: the partitionTasks field is @transient, so it is empty after serialization
+    val useJniTaskRetrieval = nativeOp.getIcebergScan.getUseJniTaskRetrieval
+
+    // If JNI task retrieval is enabled, use the optimized approach with a custom RDD
+    if (useJniTaskRetrieval) {
+      import org.apache.comet.{CometExecIterator, Native}
+
+      // Extract serialized plan bytes
+      val serializedPlan = serializedPlanOpt.plan.getOrElse {
+        throw new IllegalStateException(
+          "CometIcebergNativeScanExec must have serialized plan for execution")
+      }
+
+      val serializedPlanCopy = serializedPlan
+      val nativeMetrics = CometMetricNode.fromCometPlan(this)
+      val outputLength = output.length
+
+      def createIterator(partitionIndex: Int, taskBytes: Array[Byte]): CometExecIterator = {
+        // Set partition tasks in thread-local before creating native plan
+        Native.setIcebergPartitionTasks(taskBytes)
+
+        try {
+          val it = new CometExecIterator(
+            CometExec.newIterId,
+            Seq.empty, // No input iterators for leaf scan
+            outputLength,
+            serializedPlanCopy,
+            nativeMetrics,
+            numPartitions,
+            partitionIndex,
+            None, // No encryption for Iceberg
+            Seq.empty
+          ) // No encrypted files
+
+          Option(TaskContext.get()).foreach { context =>
+            context.addTaskCompletionListener[Unit] { _ =>
+              it.close()
+              // Clear thread-local to prevent memory leaks
+              Native.clearIcebergPartitionTasks()
+            }
+          }
+
+          it
+        } catch {
+          case e: Throwable =>
+            // Ensure cleanup on error
+            Native.clearIcebergPartitionTasks()
+            throw e
+        }
+      }
+
+      new IcebergScanRDD(sparkContext, numPartitions, partitionTasks, createIterator)
+    } else {
+      // Fall back to default implementation (tasks embedded in protobuf)
+      super.doExecuteColumnar()
+    }
+  }
+
   override protected def doCanonicalize(): CometIcebergNativeScanExec = {
     CometIcebergNativeScanExec(
       nativeOp,
@@ -154,7 +218,9 @@
       SerializedPlan(None),
       metadataLocation,
       numPartitions,
-      nativeIcebergScanMetadata)
+      nativeIcebergScanMetadata,
+      Map.empty
+    ) // partitionTasks is transient, cleared during canonicalization
   }
 
   override def stringArgs: Iterator[Any] =
@@ -178,6 +244,32 @@
       output.asJava,
       serializedPlanOpt,
       numPartitions: java.lang.Integer)
+
+  /**
+   * Override convertBlock to preserve partitionTasks when creating the serialized copy.
+   *
+   * The parent CometNativeExec.convertBlock() uses makeCopy() which loses @transient fields. We
+   * need to explicitly pass partitionTasks to the copy constructor.
+   */
+  override def convertBlock(): CometNativeExec = {
+    if (serializedPlanOpt.isDefined) {
+      // Already serialized, just return this
+      this
+    } else {
+      // Serialize the plan
+      val size = nativeOp.getSerializedSize
+      val bytes = new Array[Byte](size)
+      val codedOutput = com.google.protobuf.CodedOutputStream.newInstance(bytes)
+      nativeOp.writeTo(codedOutput)
+      codedOutput.checkNoSpaceLeft()
+
+      // Create copy with serialized plan AND preserved partitionTasks
+      copy(
+        serializedPlanOpt = SerializedPlan(Some(bytes)),
+        partitionTasks = partitionTasks // Explicitly preserve transient field
+      )
+    }
+  }
 }
 
 object CometIcebergNativeScanExec {
@@ -199,6 +291,8 @@
    *   Path to table metadata file
    * @param nativeIcebergScanMetadata
    *   Pre-extracted Iceberg metadata from planning phase
+   * @param partitionTasks
+   *   Map of partition index to serialized task bytes (for optimized execution)
    * @return
    *   A new CometIcebergNativeScanExec
    */
@@ -207,7 +301,8 @@
       scanExec: BatchScanExec,
       session: SparkSession,
       metadataLocation: String,
-      nativeIcebergScanMetadata: CometIcebergNativeScanMetadata): CometIcebergNativeScanExec = {
+      nativeIcebergScanMetadata: CometIcebergNativeScanMetadata,
+      partitionTasks: Map[Int, Array[Byte]]): CometIcebergNativeScanExec = {
 
     // Determine number of partitions from Iceberg's output partitioning
     val numParts = scanExec.outputPartitioning match {
@@ -224,7 +319,8 @@
         SerializedPlan(None),
         metadataLocation,
         numParts,
-        nativeIcebergScanMetadata)
+        nativeIcebergScanMetadata,
+        partitionTasks)
 
     scanExec.logicalLink.foreach(exec.setLogicalLink)
     exec
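To make the saving concrete with purely illustrative numbers (not measurements from this change): with 2,000 partitions at roughly 8 KB of serialized FileScanTask bytes each, a plan that embeds every partition's tasks carries about 2,000 × 8 KB ≈ 16 MB to every task, on the order of 30 GB of mostly redundant transfer across all 2,000 tasks, while the IcebergScanRDD introduced below ships each task only the ~8 KB for the partition it computes.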
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/IcebergScanRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/IcebergScanRDD.scala
new file mode 100644
index 0000000000..c806a234f9
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/IcebergScanRDD.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import org.apache.comet.CometExecIterator
+
+/**
+ * Partition for IcebergScanRDD that carries partition-specific task bytes.
+ *
+ * Each partition knows its own Iceberg FileScanTasks (serialized as protobuf bytes), avoiding the
+ * need to broadcast all tasks to all executors.
+ *
+ * @param index
+ *   Partition index
+ * @param taskBytes
+ *   Serialized IcebergFileScanTask protobuf bytes for this partition only
+ */
+private[comet] class IcebergScanPartition(override val index: Int, val taskBytes: Array[Byte])
+    extends Partition
+
+/**
+ * RDD for native Iceberg scans that avoids broadcasting all partition tasks to all executors.
+ *
+ * Traditional approach: Driver serializes ALL partition tasks into protobuf -> broadcast to ALL
+ * executors -> each executor extracts its own partition tasks. For N partitions, each executor
+ * receives N*task_size bytes but only uses 1*task_size bytes (99% waste for large N).
+ *
+ * Optimized approach: Each IcebergScanPartition carries only its own tasks. Spark's RDD
+ * serialization ensures each executor only receives the partitions it needs. For N partitions,
+ * each executor receives O(1) task data instead of O(N).
+ *
+ * @param sc
+ *   SparkContext
+ * @param numPartitions
+ *   Number of partitions
+ * @param partitionTasks
+ *   Map from partition index to serialized task bytes for that partition
+ * @param createIterator
+ *   Function to create CometExecIterator for each partition
+ */
+private[comet] class IcebergScanRDD(
+    @transient private val sc: SparkContext,
+    numPartitions: Int,
+    partitionTasks: Map[Int, Array[Byte]],
+    createIterator: (Int, Array[Byte]) => CometExecIterator)
+    extends RDD[ColumnarBatch](sc, Nil) {
+
+  override def getPartitions: Array[Partition] = {
+    (0 until numPartitions).map { i =>
+      val taskBytes = partitionTasks.getOrElse(
+        i,
+        throw new IllegalStateException(
+          s"No tasks found for partition $i. " +
+            s"Available partitions: ${partitionTasks.keys.mkString(", ")}"))
+      new IcebergScanPartition(i, taskBytes)
+    }.toArray
+  }
+
+  override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
+    val partition = split.asInstanceOf[IcebergScanPartition]
+    // Create iterator with partition-specific task bytes
+    createIterator(partition.index, partition.taskBytes)
+  }
+
+  override def getPreferredLocations(split: Partition): Seq[String] = {
+    // Iceberg handles data locality through its own planning
+    Nil
+  }
+}
diff --git a/spark/src/test/scala/org/apache/comet/IcebergReadFromS3Suite.scala b/spark/src/test/scala/org/apache/comet/IcebergReadFromS3Suite.scala
index 00955e6291..1d44ef8739 100644
--- a/spark/src/test/scala/org/apache/comet/IcebergReadFromS3Suite.scala
+++ b/spark/src/test/scala/org/apache/comet/IcebergReadFromS3Suite.scala
@@ -184,7 +184,7 @@ class IcebergReadFromS3Suite extends CometS3TestBase {
           id,
           CONCAT('data_', CAST(id AS STRING)) as data,
           (id % 100) as partition_id
-        FROM range(500000)
+        FROM range(5000000)
       """)
 
     checkIcebergNativeScan(