diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index e9f2d6523d..1f913f65b6 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -153,6 +153,51 @@ struct ExecutionContext {
     pub tracing_enabled: bool,
 }
 
+/// Deserializes per-partition IcebergScan data from protobuf.
+/// Returns a map of plan_id to IcebergFilePartitionData containing only the
+/// per-partition FileScanTasks and deduplication pools.
+fn deserialize_iceberg_replacements(
+    bytes: &[u8],
+) -> CometResult<HashMap<u32, IcebergFilePartitionData>> {
+    use datafusion_comet_proto::spark_operator::IcebergScanReplacements;
+    use prost::Message;
+
+    let replacements = IcebergScanReplacements::decode(bytes).map_err(|e| {
+        CometError::Internal(format!("Failed to decode IcebergScanReplacements: {}", e))
+    })?;
+    Ok(replacements.replacements)
+}
+
+/// Merges per-partition IcebergScan data into placeholder IcebergScans by plan_id.
+/// The placeholder IcebergScan already contains common metadata (required_schema,
+/// catalog_properties, metadata_location). We just merge in the per-partition data.
+fn merge_iceberg_partition_data(
+    op: &mut Operator,
+    replacements: &HashMap<u32, IcebergFilePartitionData>,
+) {
+    use datafusion_comet_proto::spark_operator::operator::OpStruct;
+
+    if let Some(OpStruct::IcebergScan(scan)) = &mut op.op_struct {
+        if let Some(partition_data) = replacements.get(&op.plan_id) {
+            // Merge per-partition data into the existing IcebergScan
+            scan.file_partition = partition_data.file_partition.clone();
+            scan.schema_pool = partition_data.schema_pool.clone();
+            scan.partition_type_pool = partition_data.partition_type_pool.clone();
+            scan.partition_spec_pool = partition_data.partition_spec_pool.clone();
+            scan.name_mapping_pool = partition_data.name_mapping_pool.clone();
+            scan.project_field_ids_pool = partition_data.project_field_ids_pool.clone();
+            scan.partition_data_pool = partition_data.partition_data_pool.clone();
+            scan.delete_files_pool = partition_data.delete_files_pool.clone();
+            scan.residual_pool = partition_data.residual_pool.clone();
+        }
+    }
+
+    // Continue traversing to handle multiple IcebergScans
+    for child in op.children.iter_mut() {
+        merge_iceberg_partition_data(child, replacements);
+    }
+}
+
 /// Accept serialized query plan and return the address of the native query plan.
 /// # Safety
 /// This function is inherently unsafe since it deals with raw pointers passed from JNI.
@@ -177,6 +222,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
     task_attempt_id: jlong,
     task_cpus: jlong,
     key_unwrapper_obj: JObject,
+    iceberg_task_bytes: JByteArray,
 ) -> jlong {
     try_unwrap_or_throw(&e, |mut env| {
         // Deserialize Spark configs
@@ -199,7 +245,15 @@
         // Deserialize query plan
         let bytes = env.convert_byte_array(serialized_query)?;
-        let spark_plan = serde::deserialize_op(bytes.as_slice())?;
+        let mut spark_plan = serde::deserialize_op(bytes.as_slice())?;
+
+        // Merge per-partition IcebergScan data into placeholders to avoid sending
+        // all tasks to all executors. Each executor receives only its partition's tasks.
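+        // A null `iceberg_task_bytes` means the plan contains no IcebergScan placeholders,
+        // in which case there is nothing to merge.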
+        if !iceberg_task_bytes.is_null() {
+            let iceberg_bytes = env.convert_byte_array(iceberg_task_bytes)?;
+            let replacements = deserialize_iceberg_replacements(&iceberg_bytes)?;
+            merge_iceberg_partition_data(&mut spark_plan, &replacements);
+        }
 
         let metrics = Arc::new(jni_new_global_ref!(env, metrics_node)?);
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 44ff20a44f..1dae4075bc 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -1144,13 +1144,13 @@ impl PhysicalPlanner {
                 let metadata_location = scan.metadata_location.clone();
 
                 debug_assert!(
-                    !scan.file_partitions.is_empty(),
-                    "IcebergScan must have at least one file partition. This indicates a bug in Scala serialization."
+                    scan.file_partition.is_some(),
+                    "IcebergScan must have file_partition populated. This indicates a bug in per-partition merge logic."
                 );
 
                 let tasks = parse_file_scan_tasks(
                     scan,
-                    &scan.file_partitions[self.partition as usize].file_scan_tasks,
+                    &scan.file_partition.as_ref().unwrap().file_scan_tasks,
                 )?;
 
                 let file_task_groups = vec![tasks];
diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto
index 73c087cf36..5f92fb9de2 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -163,8 +163,8 @@ message IcebergScan {
   // Catalog-specific configuration for FileIO (credentials, S3/GCS config, etc.)
   map<string, string> catalog_properties = 2;
 
-  // Pre-planned file scan tasks grouped by Spark partition
-  repeated IcebergFilePartition file_partitions = 3;
+  // File scan tasks for this partition (exactly one partition per IcebergScan after merging)
+  IcebergFilePartition file_partition = 3;
 
   // Table metadata file path for FileIO initialization
   string metadata_location = 4;
@@ -406,3 +406,24 @@ message Window {
   repeated spark.spark_expression.Expr partition_by_list = 3;
   Operator child = 4;
 }
+
+// Per-partition file scan tasks and deduplication pools for IcebergScan.
+// Contains only the data that varies per partition, while common metadata
+// (required_schema, catalog_properties, metadata_location) stays in the placeholder.
+message IcebergFilePartitionData {
+  IcebergFilePartition file_partition = 1;
+  repeated string schema_pool = 2;
+  repeated string partition_type_pool = 3;
+  repeated string partition_spec_pool = 4;
+  repeated string name_mapping_pool = 5;
+  repeated ProjectFieldIdList project_field_ids_pool = 6;
+  repeated PartitionData partition_data_pool = 7;
+  repeated DeleteFileList delete_files_pool = 8;
+  repeated spark.spark_expression.Expr residual_pool = 9;
+}
+
+// Per-partition IcebergScan replacements for multiple tables in joins/unions.
+// Maps plan_id to just the per-partition data, avoiding duplicate common metadata.
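+// Built on the driver once per Spark partition (see CometIcebergNativeScan.buildMultiScanBytesForPartition)
+// and passed to Native.createPlan as the iceberg_task_bytes argument.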
+message IcebergScanReplacements {
+  map<uint32, IcebergFilePartitionData> replacements = 1;
+}
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index 3156eb3873..9ad26e14d5 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -69,7 +69,8 @@ class CometExecIterator(
     numParts: Int,
     partitionIndex: Int,
     broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None,
-    encryptedFilePaths: Seq[String] = Seq.empty)
+    encryptedFilePaths: Seq[String] = Seq.empty,
+    icebergTaskBytes: Option[Array[Byte]] = None)
     extends Iterator[ColumnarBatch]
     with Logging {
 
@@ -123,7 +124,8 @@ class CometExecIterator(
       memoryConfig.memoryLimitPerTask,
       taskAttemptId,
       taskCPUs,
-      keyUnwrapper)
+      keyUnwrapper,
+      icebergTaskBytes = icebergTaskBytes.orNull)
   }
 
   private var nextBatch: Option[ColumnarBatch] = None
diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala
index 55e0c70e72..c25fa426ce 100644
--- a/spark/src/main/scala/org/apache/comet/Native.scala
+++ b/spark/src/main/scala/org/apache/comet/Native.scala
@@ -69,7 +69,8 @@ class Native extends NativeBase {
       memoryLimitPerTask: Long,
       taskAttemptId: Long,
       taskCPUs: Long,
-      keyUnwrapper: CometFileKeyUnwrapper): Long
+      keyUnwrapper: CometFileKeyUnwrapper,
+      icebergTaskBytes: Array[Byte]): Long
   // scalastyle:on
 
   /**
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala
index 7238f8ae8c..3015400f3d 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.comet.{CometBatchScanExec, CometNativeExec}
 import org.apache.spark.sql.types._
 
 import org.apache.comet.ConfigEntry
-import org.apache.comet.iceberg.IcebergReflection
+import org.apache.comet.iceberg.{CometIcebergNativeScanMetadata, IcebergReflection}
 import org.apache.comet.serde.{CometOperatorSerde, OperatorOuterClass}
 import org.apache.comet.serde.ExprOuterClass.Expr
 import org.apache.comet.serde.OperatorOuterClass.{Operator, SparkStructField}
@@ -70,9 +70,14 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit
   }
 
   /**
-   * Converts an Iceberg partition value to protobuf format. Protobuf is less verbose than JSON.
-   * The following types are also serialized as integer values instead of as strings - Timestamps,
-   * Dates, Decimals, FieldIDs
+   * Converts an Iceberg partition value to JSON format expected by iceberg-rust.
+   *
+   * iceberg-rust's Literal::try_from_json() expects specific formats for certain types:
+   *   - Timestamps: ISO string format "yyyy-MM-dd'T'HH:mm:ss.SSSSSS"
+   *   - Dates: ISO string format "YYYY-MM-DD"
+   *   - Decimals: String representation
+   *
+   * See: iceberg-rust/crates/iceberg/src/spec/values/literal.rs
    */
   private def partitionValueToProto(
       fieldId: Int,
@@ -217,6 +222,216 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit
     }
   }
 
+  /**
+   * Serializes a single FileScanTask to protobuf. Reusable helper for both convert() and
+   * serializePartitionInputs().
+ */ + // scalastyle:off parameter.number argcount + private def serializeFileScanTask( + task: Any, + output: Seq[Attribute], + metadata: CometIcebergNativeScanMetadata, + icebergScanBuilder: OperatorOuterClass.IcebergScan.Builder, + schemaToPoolIndex: mutable.HashMap[AnyRef, Int], + partitionTypeToPoolIndex: mutable.HashMap[String, Int], + partitionSpecToPoolIndex: mutable.HashMap[String, Int], + nameMappingToPoolIndex: mutable.HashMap[String, Int], + projectFieldIdsToPoolIndex: mutable.HashMap[Seq[Int], Int], + partitionDataToPoolIndex: mutable.HashMap[String, Int], + deleteFilesToPoolIndex: mutable.HashMap[Seq[OperatorOuterClass.IcebergDeleteFile], Int], + residualToPoolIndex: mutable.HashMap[Option[Expr], Int]) + : OperatorOuterClass.IcebergFileScanTask = { + // scalastyle:on parameter.number argcount + + val taskBuilder = OperatorOuterClass.IcebergFileScanTask.newBuilder() + + // scalastyle:off classforname + val contentScanTaskClass = + Class.forName(IcebergReflection.ClassNames.CONTENT_SCAN_TASK) + val fileScanTaskClass = + Class.forName(IcebergReflection.ClassNames.FILE_SCAN_TASK) + val contentFileClass = + Class.forName(IcebergReflection.ClassNames.CONTENT_FILE) + // scalastyle:on classforname + + val fileMethod = contentScanTaskClass.getMethod("file") + val dataFile = fileMethod.invoke(task) + + val filePathOpt = + IcebergReflection.extractFileLocation(contentFileClass, dataFile) + + filePathOpt match { + case Some(filePath) => + taskBuilder.setDataFilePath(filePath) + case None => + val msg = + "Iceberg reflection failure: Cannot extract file path from data file" + logError(msg) + throw new RuntimeException(msg) + } + + val startMethod = contentScanTaskClass.getMethod("start") + val start = startMethod.invoke(task).asInstanceOf[Long] + taskBuilder.setStart(start) + + val lengthMethod = contentScanTaskClass.getMethod("length") + val length = lengthMethod.invoke(task).asInstanceOf[Long] + taskBuilder.setLength(length) + + try { + // scalastyle:off classforname + val schemaParserClass = + Class.forName(IcebergReflection.ClassNames.SCHEMA_PARSER) + val schemaClass = Class.forName(IcebergReflection.ClassNames.SCHEMA) + // scalastyle:on classforname + + val taskSchemaMethod = fileScanTaskClass.getMethod("schema") + val taskSchema = taskSchemaMethod.invoke(task) + + val deletes = + IcebergReflection.getDeleteFilesFromTask(task, fileScanTaskClass) + val hasDeletes = !deletes.isEmpty + + // Use taskSchema for deletes (MOR requires exact schema matching), + // scanSchema for historical queries (VERSION AS OF with dropped columns), + // or tableSchema for regular queries (provides partition column lookups). 
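+ // The schema serialized here is only used by the native reader for field type lookups and
+ // default values; the actual column projection is controlled by project_field_ids.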
+ val schema: AnyRef = + if (hasDeletes) { + taskSchema + } else { + val scanSchemaFieldIds = IcebergReflection + .buildFieldIdMapping(metadata.scanSchema) + .values + .toSet + val tableSchemaFieldIds = IcebergReflection + .buildFieldIdMapping(metadata.tableSchema) + .values + .toSet + val hasHistoricalColumns = + scanSchemaFieldIds.exists(id => !tableSchemaFieldIds.contains(id)) + + if (hasHistoricalColumns) { + metadata.scanSchema.asInstanceOf[AnyRef] + } else { + metadata.tableSchema.asInstanceOf[AnyRef] + } + } + + val toJsonMethod = schemaParserClass.getMethod("toJson", schemaClass) + toJsonMethod.setAccessible(true) + + val schemaIdx = schemaToPoolIndex.getOrElseUpdate( + schema, { + val idx = schemaToPoolIndex.size + val schemaJson = toJsonMethod.invoke(null, schema).asInstanceOf[String] + icebergScanBuilder.addSchemaPool(schemaJson) + idx + }) + taskBuilder.setSchemaIdx(schemaIdx) + + // Build field ID mapping + val nameToFieldId = IcebergReflection.buildFieldIdMapping(schema) + + val projectFieldIds = output.flatMap { attr => + nameToFieldId + .get(attr.name) + .orElse(metadata.globalFieldIdMapping.get(attr.name)) + .orElse { + logWarning( + s"Column '${attr.name}' not found in task or scan schema," + + "skipping projection") + None + } + } + + val projectFieldIdsIdx = projectFieldIdsToPoolIndex.getOrElseUpdate( + projectFieldIds, { + val idx = projectFieldIdsToPoolIndex.size + val listBuilder = OperatorOuterClass.ProjectFieldIdList.newBuilder() + projectFieldIds.foreach(id => listBuilder.addFieldIds(id)) + icebergScanBuilder.addProjectFieldIdsPool(listBuilder.build()) + idx + }) + taskBuilder.setProjectFieldIdsIdx(projectFieldIdsIdx) + } catch { + case e: Exception => + val msg = + "Iceberg reflection failure: " + + "Failed to extract schema from FileScanTask: " + + s"${e.getMessage}" + logError(msg) + throw new RuntimeException(msg, e) + } + + // Deduplicate delete files + val deleteFilesList = + extractDeleteFilesList(task, contentFileClass, fileScanTaskClass) + if (deleteFilesList.nonEmpty) { + val deleteFilesIdx = deleteFilesToPoolIndex.getOrElseUpdate( + deleteFilesList, { + val idx = deleteFilesToPoolIndex.size + val listBuilder = OperatorOuterClass.DeleteFileList.newBuilder() + deleteFilesList.foreach(df => listBuilder.addDeleteFiles(df)) + icebergScanBuilder.addDeleteFilesPool(listBuilder.build()) + idx + }) + taskBuilder.setDeleteFilesIdx(deleteFilesIdx) + } + + // Extract and deduplicate residual expression + val residualExprOpt = + try { + val residualMethod = contentScanTaskClass.getMethod("residual") + val residualExpr = residualMethod.invoke(task) + + val catalystExpr = convertIcebergExpression(residualExpr, output) + + catalystExpr.flatMap { expr => + exprToProto(expr, output, binding = false) + } + } catch { + case e: Exception => + logWarning( + "Failed to extract residual expression from FileScanTask: " + + s"${e.getMessage}") + None + } + + residualExprOpt.foreach { residualExpr => + val residualIdx = residualToPoolIndex.getOrElseUpdate( + Some(residualExpr), { + val idx = residualToPoolIndex.size + icebergScanBuilder.addResidualPool(residualExpr) + idx + }) + taskBuilder.setResidualIdx(residualIdx) + } + + // Serialize partition spec and data + serializePartitionData( + task, + contentScanTaskClass, + fileScanTaskClass, + taskBuilder, + icebergScanBuilder, + partitionTypeToPoolIndex, + partitionSpecToPoolIndex, + partitionDataToPoolIndex) + + // Deduplicate name mapping + metadata.nameMapping.foreach { nm => + val nmIdx = 
nameMappingToPoolIndex.getOrElseUpdate( + nm, { + val idx = nameMappingToPoolIndex.size + icebergScanBuilder.addNameMappingPool(nm) + idx + }) + taskBuilder.setNameMappingIdx(nmIdx) + } + + taskBuilder.build() + } + /** * Extracts delete files from an Iceberg FileScanTask as a list (for deduplication). */ @@ -426,14 +641,12 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit } } - // Serialize partition data to protobuf for native execution. - // The native execution engine uses partition_data protobuf messages to - // build a constants_map, which provides partition values to identity- - // transformed partition columns. Non-identity transforms (bucket, truncate, - // days, etc.) read values from data files. - // - // IMPORTANT: Use partition field IDs (not source field IDs) to match - // the schema. + // Serialize partition data as protobuf for iceberg-rust's constants_map. + // The native execution engine uses partition_data + + // partition_type to build a constants_map, which is the primary + // mechanism for providing partition values to identity-transformed + // partition columns. Non-identity transforms (bucket, truncate, days, + // etc.) read values from data files. // Filter out fields with unknown type (same as partition type filtering) val partitionValues: Seq[OperatorOuterClass.PartitionValue] = @@ -444,7 +657,6 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit if (fieldTypeStr == IcebergReflection.TypeNames.UNKNOWN) { None } else { - // Use the partition type's field ID (same as in partition_type_json) val fieldIdMethod = field.getClass.getMethod("fieldId") val fieldId = fieldIdMethod.invoke(field).asInstanceOf[Int] @@ -683,19 +895,6 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit childOp: Operator*): Option[OperatorOuterClass.Operator] = { val icebergScanBuilder = OperatorOuterClass.IcebergScan.newBuilder() - // Deduplication structures - map unique values to pool indices - val schemaToPoolIndex = mutable.HashMap[AnyRef, Int]() - val partitionTypeToPoolIndex = mutable.HashMap[String, Int]() - val partitionSpecToPoolIndex = mutable.HashMap[String, Int]() - val nameMappingToPoolIndex = mutable.HashMap[String, Int]() - val projectFieldIdsToPoolIndex = mutable.HashMap[Seq[Int], Int]() - val partitionDataToPoolIndex = mutable.HashMap[String, Int]() // Base64 bytes -> pool index - val deleteFilesToPoolIndex = - mutable.HashMap[Seq[OperatorOuterClass.IcebergDeleteFile], Int]() - val residualToPoolIndex = mutable.HashMap[Option[Expr], Int]() - - var totalTasks = 0 - // Get pre-extracted metadata from planning phase // If metadata is None, this is a programming error - metadata should have been extracted // in CometScanRule before creating CometBatchScanExec @@ -723,304 +922,12 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit icebergScanBuilder.addRequiredSchema(field.build()) } - // Extract FileScanTasks from the InputPartitions in the RDD - try { - scan.wrapped.inputRDD match { - case rdd: org.apache.spark.sql.execution.datasources.v2.DataSourceRDD => - val partitions = rdd.partitions - partitions.foreach { partition => - val partitionBuilder = OperatorOuterClass.IcebergFilePartition.newBuilder() - - val inputPartitions = partition - .asInstanceOf[org.apache.spark.sql.execution.datasources.v2.DataSourceRDDPartition] - .inputPartitions - - inputPartitions.foreach { inputPartition => - val inputPartClass = inputPartition.getClass - - try { - val 
taskGroupMethod = inputPartClass.getDeclaredMethod("taskGroup") - taskGroupMethod.setAccessible(true) - val taskGroup = taskGroupMethod.invoke(inputPartition) - - val taskGroupClass = taskGroup.getClass - val tasksMethod = taskGroupClass.getMethod("tasks") - val tasksCollection = - tasksMethod.invoke(taskGroup).asInstanceOf[java.util.Collection[_]] - - tasksCollection.asScala.foreach { task => - totalTasks += 1 - - try { - val taskBuilder = OperatorOuterClass.IcebergFileScanTask.newBuilder() - - // scalastyle:off classforname - val contentScanTaskClass = - Class.forName(IcebergReflection.ClassNames.CONTENT_SCAN_TASK) - val fileScanTaskClass = - Class.forName(IcebergReflection.ClassNames.FILE_SCAN_TASK) - val contentFileClass = - Class.forName(IcebergReflection.ClassNames.CONTENT_FILE) - // scalastyle:on classforname - - val fileMethod = contentScanTaskClass.getMethod("file") - val dataFile = fileMethod.invoke(task) - - val filePathOpt = - IcebergReflection.extractFileLocation(contentFileClass, dataFile) - - filePathOpt match { - case Some(filePath) => - taskBuilder.setDataFilePath(filePath) - case None => - val msg = - "Iceberg reflection failure: Cannot extract file path from data file" - logError(msg) - throw new RuntimeException(msg) - } - - val startMethod = contentScanTaskClass.getMethod("start") - val start = startMethod.invoke(task).asInstanceOf[Long] - taskBuilder.setStart(start) - - val lengthMethod = contentScanTaskClass.getMethod("length") - val length = lengthMethod.invoke(task).asInstanceOf[Long] - taskBuilder.setLength(length) - - try { - // Equality deletes require the full table schema to resolve field IDs, - // even for columns not in the projection. Schema evolution requires - // using the snapshot's schema to correctly read old data files. - // These requirements conflict, so we choose based on delete presence. - - val taskSchemaMethod = fileScanTaskClass.getMethod("schema") - val taskSchema = taskSchemaMethod.invoke(task) - - val deletes = - IcebergReflection.getDeleteFilesFromTask(task, fileScanTaskClass) - val hasDeletes = !deletes.isEmpty - - // Schema to pass to iceberg-rust's FileScanTask. - // This is used by RecordBatchTransformer for field type lookups (e.g., in - // constants_map) and default value generation. The actual projection is - // controlled by project_field_ids. - // - // Schema selection logic: - // 1. If hasDeletes=true: Use taskSchema (file-specific schema) because - // delete files reference specific schema versions and we need exact schema - // matching for MOR. - // 2. Else if scanSchema contains columns not in tableSchema: Use scanSchema - // because this is a VERSION AS OF query reading a historical snapshot with - // different schema (e.g., after column drop, scanSchema has old columns - // that tableSchema doesn't) - // 3. 
Else: Use tableSchema because scanSchema is the query OUTPUT schema - // (e.g., for aggregates like "SELECT count(*)", scanSchema only has - // aggregate fields and doesn't contain partition columns needed by - // constants_map) - val schema: AnyRef = - if (hasDeletes) { - taskSchema - } else { - // Check if scanSchema has columns that tableSchema doesn't have - // (VERSION AS OF case) - val scanSchemaFieldIds = IcebergReflection - .buildFieldIdMapping(metadata.scanSchema) - .values - .toSet - val tableSchemaFieldIds = IcebergReflection - .buildFieldIdMapping(metadata.tableSchema) - .values - .toSet - val hasHistoricalColumns = - scanSchemaFieldIds.exists(id => !tableSchemaFieldIds.contains(id)) - - if (hasHistoricalColumns) { - // VERSION AS OF: scanSchema has columns that current table doesn't have - metadata.scanSchema.asInstanceOf[AnyRef] - } else { - // Regular query: use tableSchema for partition field lookups - metadata.tableSchema.asInstanceOf[AnyRef] - } - } - - // scalastyle:off classforname - val schemaParserClass = - Class.forName(IcebergReflection.ClassNames.SCHEMA_PARSER) - val schemaClass = Class.forName(IcebergReflection.ClassNames.SCHEMA) - // scalastyle:on classforname - val toJsonMethod = schemaParserClass.getMethod("toJson", schemaClass) - toJsonMethod.setAccessible(true) - - // Use object identity for deduplication: Iceberg Schema objects are immutable - // and reused across tasks, making identity-based deduplication safe - val schemaIdx = schemaToPoolIndex.getOrElseUpdate( - schema, { - val idx = schemaToPoolIndex.size - val schemaJson = toJsonMethod.invoke(null, schema).asInstanceOf[String] - icebergScanBuilder.addSchemaPool(schemaJson) - idx - }) - taskBuilder.setSchemaIdx(schemaIdx) - - // Build field ID mapping from the schema we're using - val nameToFieldId = IcebergReflection.buildFieldIdMapping(schema) - - // Extract project_field_ids for scan.output columns. - // For schema evolution: try task schema first, then fall back to - // global scan schema (pre-extracted in metadata). 
- val projectFieldIds = scan.output.flatMap { attr => - nameToFieldId - .get(attr.name) - .orElse(metadata.globalFieldIdMapping.get(attr.name)) - .orElse { - logWarning( - s"Column '${attr.name}' not found in task or scan schema," + - "skipping projection") - None - } - } - - // Deduplicate project field IDs - val projectFieldIdsIdx = projectFieldIdsToPoolIndex.getOrElseUpdate( - projectFieldIds, { - val idx = projectFieldIdsToPoolIndex.size - val listBuilder = OperatorOuterClass.ProjectFieldIdList.newBuilder() - projectFieldIds.foreach(id => listBuilder.addFieldIds(id)) - icebergScanBuilder.addProjectFieldIdsPool(listBuilder.build()) - idx - }) - taskBuilder.setProjectFieldIdsIdx(projectFieldIdsIdx) - } catch { - case e: Exception => - val msg = - "Iceberg reflection failure: " + - "Failed to extract schema from FileScanTask: " + - s"${e.getMessage}" - logError(msg) - throw new RuntimeException(msg, e) - } - - // Deduplicate delete files - val deleteFilesList = - extractDeleteFilesList(task, contentFileClass, fileScanTaskClass) - if (deleteFilesList.nonEmpty) { - val deleteFilesIdx = deleteFilesToPoolIndex.getOrElseUpdate( - deleteFilesList, { - val idx = deleteFilesToPoolIndex.size - val listBuilder = OperatorOuterClass.DeleteFileList.newBuilder() - deleteFilesList.foreach(df => listBuilder.addDeleteFiles(df)) - icebergScanBuilder.addDeleteFilesPool(listBuilder.build()) - idx - }) - taskBuilder.setDeleteFilesIdx(deleteFilesIdx) - } - - // Extract and deduplicate residual expression - val residualExprOpt = - try { - val residualMethod = contentScanTaskClass.getMethod("residual") - val residualExpr = residualMethod.invoke(task) - - val catalystExpr = convertIcebergExpression(residualExpr, scan.output) - - catalystExpr.flatMap { expr => - exprToProto(expr, scan.output, binding = false) - } - } catch { - case e: Exception => - logWarning( - "Failed to extract residual expression from FileScanTask: " + - s"${e.getMessage}") - None - } - - residualExprOpt.foreach { residualExpr => - val residualIdx = residualToPoolIndex.getOrElseUpdate( - Some(residualExpr), { - val idx = residualToPoolIndex.size - icebergScanBuilder.addResidualPool(residualExpr) - idx - }) - taskBuilder.setResidualIdx(residualIdx) - } - - // Serialize partition spec and data (field definitions, transforms, values) - serializePartitionData( - task, - contentScanTaskClass, - fileScanTaskClass, - taskBuilder, - icebergScanBuilder, - partitionTypeToPoolIndex, - partitionSpecToPoolIndex, - partitionDataToPoolIndex) - - // Deduplicate name mapping - metadata.nameMapping.foreach { nm => - val nmIdx = nameMappingToPoolIndex.getOrElseUpdate( - nm, { - val idx = nameMappingToPoolIndex.size - icebergScanBuilder.addNameMappingPool(nm) - idx - }) - taskBuilder.setNameMappingIdx(nmIdx) - } - - partitionBuilder.addFileScanTasks(taskBuilder.build()) - } - } - } - } - - val builtPartition = partitionBuilder.build() - icebergScanBuilder.addFilePartitions(builtPartition) - } - case _ => - } - } catch { - case e: Exception => - val msg = - "Iceberg reflection failure: Failed to extract FileScanTasks from Iceberg scan RDD: " + - s"${e.getMessage}" - logError(msg, e) - return None - } - - // Log deduplication summary - val allPoolSizes = Seq( - schemaToPoolIndex.size, - partitionTypeToPoolIndex.size, - partitionSpecToPoolIndex.size, - nameMappingToPoolIndex.size, - projectFieldIdsToPoolIndex.size, - partitionDataToPoolIndex.size, - deleteFilesToPoolIndex.size, - residualToPoolIndex.size) - - val avgDedup = if (totalTasks == 0) { - "0.0" - } else { 
- // Filter out empty pools - they shouldn't count as 100% dedup - val nonEmptyPools = allPoolSizes.filter(_ > 0) - if (nonEmptyPools.isEmpty) { - "0.0" - } else { - val avgUnique = nonEmptyPools.sum.toDouble / nonEmptyPools.length - f"${(1.0 - avgUnique / totalTasks) * 100}%.1f" - } - } - - // Calculate partition data pool size in bytes (protobuf format) - val partitionDataPoolBytes = icebergScanBuilder.getPartitionDataPoolList.asScala - .map(_.getSerializedSize) - .sum - - logInfo(s"IcebergScan: $totalTasks tasks, ${allPoolSizes.size} pools ($avgDedup% avg dedup)") - if (partitionDataToPoolIndex.nonEmpty) { - logInfo( - s" Partition data pool: ${partitionDataToPoolIndex.size} unique values, " + - s"$partitionDataPoolBytes bytes (protobuf)") - } + // Create a minimal placeholder IcebergScan. + // The actual FileScanTasks will be serialized per-partition at execution time + // in buildPerPartitionIcebergBytes() and passed via icebergTaskBytes. + // This placeholder is only used for Rust to identify the IcebergScan node to replace. + logInfo( + "IcebergScan: Creating placeholder (tasks will be provided per-partition at execution)") builder.clearChildren() Some(builder.setIcebergScan(icebergScanBuilder).build()) @@ -1042,4 +949,163 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit // Create the CometIcebergNativeScanExec using the companion object's apply method CometIcebergNativeScanExec(nativeOp, op.wrapped, op.session, metadataLocation, metadata) } + + /** + * Extracts InputPartitions for each partition separately to avoid serializing all tasks to all + * executors. Each executor receives only its partition's tasks instead of all N x M tasks (N + * partitions x M tasks/partition). + */ + def buildPerPartitionBytes( + icebergScan: org.apache.spark.sql.comet.CometIcebergNativeScanExec, + numPartitions: Int): Array[Array[Byte]] = { + + // Get the original BatchScanExec + val batchScan = icebergScan.originalPlan + val inputRDD = batchScan.inputRDD + .asInstanceOf[org.apache.spark.sql.execution.datasources.v2.DataSourceRDD] + val allPartitions = inputRDD.partitions + + if (allPartitions.length != numPartitions) { + throw new org.apache.comet.CometRuntimeException( + s"Partition count mismatch: expected $numPartitions, got ${allPartitions.length}") + } + + allPartitions.zipWithIndex.map { case (partition, idx) => + try { + // Extract just this partition's InputPartitions + val inputPartitions = partition + .asInstanceOf[org.apache.spark.sql.execution.datasources.v2.DataSourceRDDPartition] + .inputPartitions + + // Serialize using the same logic as convert(), but only for this partition + val bytes = serializePartitionInputs( + icebergScan.output, + icebergScan.nativeIcebergScanMetadata, + inputPartitions, + idx) + + bytes + } catch { + case e: Exception => + throw new org.apache.comet.CometRuntimeException( + s"Failed to serialize partition $idx: ${e.getMessage}", + e) + } + }.toArray + } + + /** + * Serializes per-partition IcebergScan data for multiple scans using protobuf. Each entry maps + * plan_id to just the per-partition data (FileScanTasks and pools), avoiding duplication of + * common metadata across partitions. 
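+ * The resulting bytes are sent to the native side, where deserialize_iceberg_replacements decodes
+ * them and merge_iceberg_partition_data merges each entry into its placeholder IcebergScan.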
+ */ + def buildMultiScanBytesForPartition( + perScanBytes: Map[Int, Array[Array[Byte]]], + partitionIdx: Int): Array[Byte] = { + + val replacementsBuilder = OperatorOuterClass.IcebergScanReplacements.newBuilder() + + perScanBytes.foreach { case (planId, allPartitionBytes) => + val partitionBytes = allPartitionBytes(partitionIdx) + val partitionData = OperatorOuterClass.IcebergFilePartitionData.parseFrom(partitionBytes) + replacementsBuilder.putReplacements(planId, partitionData) + } + + val replacements = replacementsBuilder.build() + replacements.toByteArray + } + + /** + * Serializes a single partition's InputPartitions to IcebergFilePartitionData protobuf. + * Contains only per-partition data (file_partitions and deduplication pools), not common + * metadata which stays in the placeholder IcebergScan. + */ + def serializePartitionInputs( + output: Seq[Attribute], + metadata: CometIcebergNativeScanMetadata, + inputPartitions: Seq[org.apache.spark.sql.connector.read.InputPartition], + partitionIndex: Int): Array[Byte] = { + + val icebergScanBuilder = OperatorOuterClass.IcebergScan.newBuilder() + + // Per-partition deduplication structures + val schemaToPoolIndex = mutable.HashMap[AnyRef, Int]() + val partitionTypeToPoolIndex = mutable.HashMap[String, Int]() + val partitionSpecToPoolIndex = mutable.HashMap[String, Int]() + val nameMappingToPoolIndex = mutable.HashMap[String, Int]() + val projectFieldIdsToPoolIndex = mutable.HashMap[Seq[Int], Int]() + val partitionDataToPoolIndex = mutable.HashMap[String, Int]() + val deleteFilesToPoolIndex = + mutable.HashMap[Seq[OperatorOuterClass.IcebergDeleteFile], Int]() + val residualToPoolIndex = mutable.HashMap[Option[Expr], Int]() + + // Process this partition's InputPartitions + val partitionBuilder = OperatorOuterClass.IcebergFilePartition.newBuilder() + + inputPartitions.foreach { inputPartition => + val inputPartClass = inputPartition.getClass + + try { + val taskGroupMethod = inputPartClass.getDeclaredMethod("taskGroup") + taskGroupMethod.setAccessible(true) + val taskGroup = taskGroupMethod.invoke(inputPartition) + + val taskGroupClass = taskGroup.getClass + val tasksMethod = taskGroupClass.getMethod("tasks") + val tasksCollection = + tasksMethod.invoke(taskGroup).asInstanceOf[java.util.Collection[_]] + + tasksCollection.asScala.foreach { task => + try { + val fileScanTask = serializeFileScanTask( + task, + output, + metadata, + icebergScanBuilder, + schemaToPoolIndex, + partitionTypeToPoolIndex, + partitionSpecToPoolIndex, + nameMappingToPoolIndex, + projectFieldIdsToPoolIndex, + partitionDataToPoolIndex, + deleteFilesToPoolIndex, + residualToPoolIndex) + + partitionBuilder.addFileScanTasks(fileScanTask) + } catch { + case e: Exception => + logWarning( + s"Failed to serialize task in partition $partitionIndex: ${e.getMessage}") + } + } + } catch { + case e: Exception => + logWarning( + "Failed to extract tasks from InputPartition in partition " + + s"$partitionIndex: ${e.getMessage}") + } + } + + // Build IcebergFilePartitionData with only per-partition data + val partitionData = OperatorOuterClass.IcebergFilePartitionData + .newBuilder() + .setFilePartition(partitionBuilder.build()) + .addAllSchemaPool(icebergScanBuilder.getSchemaPoolList) + .addAllPartitionTypePool(icebergScanBuilder.getPartitionTypePoolList) + .addAllPartitionSpecPool(icebergScanBuilder.getPartitionSpecPoolList) + .addAllNameMappingPool(icebergScanBuilder.getNameMappingPoolList) + .addAllProjectFieldIdsPool(icebergScanBuilder.getProjectFieldIdsPoolList) + 
.addAllPartitionDataPool(icebergScanBuilder.getPartitionDataPoolList) + .addAllDeleteFilesPool(icebergScanBuilder.getDeleteFilesPoolList) + .addAllResidualPool(icebergScanBuilder.getResidualPoolList) + .build() + + // Serialize IcebergFilePartitionData to bytes + val size = partitionData.getSerializedSize + val bytes = new Array[Byte](size) + val codedOutput = com.google.protobuf.CodedOutputStream.newInstance(bytes) + partitionData.writeTo(codedOutput) + codedOutput.checkNoSpaceLeft() + bytes + } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index 2fd7f12c24..cd451a1bff 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -23,35 +23,155 @@ import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.vectorized.ColumnarBatch +/** + * Partition that carries per-partition Iceberg FileScanTask bytes. This ensures only relevant + * bytes are sent to each executor (not all partition bytes). + */ +private[spark] class CometIcebergPartition(override val index: Int, val taskBytes: Array[Byte]) + extends Partition { + override def hashCode(): Int = index + override def equals(obj: Any): Boolean = obj match { + case other: CometIcebergPartition => other.index == index + case _ => false + } +} + /** * A RDD that executes Spark SQL query in Comet native execution to generate ColumnarBatch. */ private[spark] class CometExecRDD( sc: SparkContext, - partitionNum: Int, - var f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch]) + customPartitions: Array[Partition], + var f: ( + Seq[Iterator[ColumnarBatch]], + Int, + Int, + Option[Array[Byte]]) => Iterator[ColumnarBatch]) extends RDD[ColumnarBatch](sc, Nil) { override def compute(s: Partition, context: TaskContext): Iterator[ColumnarBatch] = { - f(Seq.empty, partitionNum, s.index) + val taskBytes = s match { + case p: CometIcebergPartition => Some(p.taskBytes) + case _ => None + } + f(Seq.empty, customPartitions.length, s.index, taskBytes) } - override protected def getPartitions: Array[Partition] = { - Array.tabulate(partitionNum)(i => - new Partition { - override def index: Int = i - }) + override protected def getPartitions: Array[Partition] = customPartitions +} + +/** + * Partition class for ZippedPartitionsWithIcebergRDD that combines multiple input RDD partitions + * with per-partition Iceberg task data. + */ +private class ZippedPartitionsWithIcebergPartition( + idx: Int, + @transient private val rdds: Seq[RDD[_]], + @transient private val preferredLocationsFunc: Seq[RDD[_]] => Seq[Partition]) + extends Partition { + + override val index: Int = idx + var partitions: Seq[Partition] = preferredLocationsFunc(rdds) +} + +/** + * RDD that zips multiple input RDDs while also passing per-partition IcebergScan bytes. Used for + * joins where one side is an IcebergScan and the other is broadcast/shuffle. 
+ * + * This RDD combines the functionality of: + * - ZippedPartitionsRDD (zipping multiple input RDDs) + * - CometIcebergPartition (passing per-partition taskBytes) + * + * Each partition receives: + * - Iterators from all input RDDs (for broadcast/shuffle sides) + * - Per-partition taskBytes for IcebergScan(s) + */ +private[spark] class ZippedPartitionsWithIcebergRDD( + sc: SparkContext, + var rdds: Seq[RDD[ColumnarBatch]], + perPartitionBytes: Array[Array[Byte]])( + var f: (Seq[Iterator[ColumnarBatch]], Int, Int, Option[Array[Byte]]) => Iterator[ + ColumnarBatch]) + extends RDD[ColumnarBatch](sc, rdds.flatMap(_.dependencies)) { + + require( + rdds.forall(_.getNumPartitions == perPartitionBytes.length), + "All inputs and perPartitionBytes must have same partition count. " + + s"Input partitions: ${rdds.map(_.getNumPartitions).mkString(", ")}, " + + s"perPartitionBytes length: ${perPartitionBytes.length}") + + override def getPartitions: Array[Partition] = { + val numParts = perPartitionBytes.length + Array.tabulate(numParts) { idx => + new ZippedPartitionsWithIcebergPartition(idx, rdds, rdds => rdds.map(_.partitions(idx))) + } + } + + override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + val partition = split.asInstanceOf[ZippedPartitionsWithIcebergPartition] + val numPartitions = getNumPartitions + + // Get iterators from all input RDDs (like ZippedPartitionsRDD) + val inputIterators = rdds.zipWithIndex.map { case (rdd, rddsIndex) => + rdd.iterator(partition.partitions(rddsIndex), context) + } + + // Get per-partition taskBytes for this partition + val taskBytes = perPartitionBytes(partition.index) + + // Create iterator with both inputs AND taskBytes + f(inputIterators, numPartitions, partition.index, Some(taskBytes)) + } + + override def clearDependencies(): Unit = { + super.clearDependencies() + rdds = null + f = null + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + val partition = split.asInstanceOf[ZippedPartitionsWithIcebergPartition] + // Prefer locations from first input (usually the larger side) + rdds.head.preferredLocations(partition.partitions.head) } } object CometExecRDD { + // For regular execution without per-partition bytes def apply(sc: SparkContext, partitionNum: Int)( f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch]) : RDD[ColumnarBatch] = withScope(sc) { - new CometExecRDD(sc, partitionNum, f) + val partitions = Array.tabulate(partitionNum)(i => + new Partition { + override def index: Int = i + }) + new CometExecRDD(sc, partitions, (inputs, numParts, idx, _) => f(inputs, numParts, idx)) + } + + // For execution with per-partition bytes (Iceberg scans) + def apply(sc: SparkContext, partitionNum: Int, perPartitionBytes: Array[Array[Byte]])( + f: (Seq[Iterator[ColumnarBatch]], Int, Int, Option[Array[Byte]]) => Iterator[ColumnarBatch]) + : RDD[ColumnarBatch] = + withScope(sc) { + val partitions: Array[Partition] = perPartitionBytes.zipWithIndex.map { case (bytes, idx) => + new CometIcebergPartition(idx, bytes): Partition + } + new CometExecRDD(sc, partitions, f) } private[spark] def withScope[U](sc: SparkContext)(body: => U): U = RDDOperationScope.withScope[U](sc)(body) } + +object ZippedPartitionsWithIcebergRDD { + def apply( + sc: SparkContext, + rdds: Seq[RDD[ColumnarBatch]], + perPartitionBytes: Array[Array[Byte]])( + f: (Seq[Iterator[ColumnarBatch]], Int, Int, Option[Array[Byte]]) => Iterator[ColumnarBatch]) + : RDD[ColumnarBatch] = + CometExecRDD.withScope(sc) { + new 
ZippedPartitionsWithIcebergRDD(sc, rdds, perPartitionBytes)(f) + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 6f33467efe..8fca67e4cc 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -293,7 +293,8 @@ abstract class CometNativeExec extends CometExec { def createCometExecIter( inputs: Seq[Iterator[ColumnarBatch]], numParts: Int, - partitionIndex: Int): CometExecIterator = { + partitionIndex: Int, + taskBytes: Option[Array[Byte]] = None): CometExecIterator = { val it = new CometExecIterator( CometExec.newIterId, inputs, @@ -303,7 +304,8 @@ abstract class CometNativeExec extends CometExec { numParts, partitionIndex, broadcastedHadoopConfForEncryption, - encryptedFilePaths) + encryptedFilePaths, + taskBytes) setSubqueries(it.id, this) @@ -395,11 +397,60 @@ abstract class CometNativeExec extends CometExec { throw new CometRuntimeException(s"No input for CometNativeExec:\n $this") } - if (inputs.nonEmpty) { - ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter) - } else { + // Check if this plan has IcebergScan children - if so, use per-partition serialization + // Only collect IcebergScans at this boundary level, not nested inside other CometNativeExec + // operators with their own serializedPlanOpt (those handle their own extraction) + def collectIcebergScansAtBoundary(plan: SparkPlan): Seq[CometIcebergNativeScanExec] = { + plan match { + case icebergScan: CometIcebergNativeScanExec => + Seq(icebergScan) + case nativeExec: CometNativeExec + if nativeExec != this && nativeExec.serializedPlanOpt.isDefined => + // Stop at nested boundary nodes - they handle their own IcebergScans + Seq.empty + case other => + other.children.flatMap(collectIcebergScansAtBoundary) + } + } + + val icebergScans = collectIcebergScansAtBoundary(this) + + if (icebergScans.nonEmpty) { val partitionNum = firstNonBroadcastPlanNumPartitions - CometExecRDD(sparkContext, partitionNum)(createCometExecIter) + + // Build per-partition bytes for ALL IcebergScans + val perScanBytes: Map[Int, Array[Array[Byte]]] = icebergScans.map { scan => + scan.nativeOp.getPlanId -> org.apache.comet.serde.operator.CometIcebergNativeScan + .buildPerPartitionBytes(scan, partitionNum) + }.toMap + + // For each partition, build map of all scans' bytes + val perPartitionMultiScanBytes = Array.tabulate(partitionNum) { partitionIdx => + org.apache.comet.serde.operator.CometIcebergNativeScan + .buildMultiScanBytesForPartition(perScanBytes, partitionIdx) + } + + if (inputs.nonEmpty) { + // Mixed-input: IcebergScan(s) + broadcast/shuffle + ZippedPartitionsWithIcebergRDD( + sparkContext, + inputs.toSeq, + perPartitionMultiScanBytes)(createCometExecIter) + } else { + // Pure IcebergScan(s) - single or multiple + CometExecRDD(sparkContext, partitionNum, perPartitionMultiScanBytes)( + createCometExecIter) + } + } else { + // No IcebergScans - regular execution + if (inputs.nonEmpty) { + ZippedPartitionsRDD(sparkContext, inputs.toSeq)((inputs, numParts, idx) => + createCometExecIter(inputs, numParts, idx)) + } else { + val partitionNum = firstNonBroadcastPlanNumPartitions + CometExecRDD(sparkContext, partitionNum)((inputs, numParts, idx) => + createCometExecIter(inputs, numParts, idx)) + } } } }