diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala index 8c75df1d45..b0ded47580 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala @@ -21,7 +21,7 @@ package org.apache.spark.sql.comet import scala.jdk.CollectionConverters._ -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -41,6 +41,49 @@ import org.apache.comet.serde.Metric case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometMetricNode]) extends Logging { + /** + * Returns the leaf node (deepest single-child descendant). For a native scan plan like + * FilterExec -> DataSourceExec, this returns the DataSourceExec node which has the + * bytes_scanned and output_rows metrics from the Parquet reader. + */ + def leafNode: CometMetricNode = { + if (children.isEmpty) this + else children.head.leafNode + } + + /** + * Returns all leaf nodes (nodes with no children) in the metric tree. Unlike [[leafNode]] which + * only follows the first child, this finds all leaves, which is needed for plans with multiple + * scans (e.g., joins, unions). + */ + def leafNodes: Seq[CometMetricNode] = { + if (children.isEmpty) Seq(this) + else children.flatMap(_.leafNodes) + } + + /** + * Reports aggregated scan input metrics (bytesRead, recordsRead) to Spark's task metrics. + * Aggregates across all scan leaf nodes to handle plans with multiple scans (e.g., joins). + * Registers a TaskCompletionListener on the given context, so the metrics are read and reported only after the iterator is fully consumed. 
+ */ + def reportScanInputMetrics(ctx: TaskContext): Unit = { + ctx.addTaskCompletionListener[Unit] { _ => + val scanLeaves = leafNodes.filter(_.metrics.contains("bytes_scanned")) + if (scanLeaves.nonEmpty) { + val totalBytes = scanLeaves.map(_.metrics("bytes_scanned").value).sum + val totalRows = scanLeaves.map { leaf => + val outputRows = + leaf.metrics.get("output_rows").map(_.value).getOrElse(0L) + val prunedRows = + leaf.metrics.get("pushdown_rows_pruned").map(_.value).getOrElse(0L) + outputRows + prunedRows + }.sum + ctx.taskMetrics().inputMetrics.setBytesRead(totalBytes) + ctx.taskMetrics().inputMetrics.setRecordsRead(totalRows) + } + } + } + /** * Gets a child node. Called from native. */ @@ -79,6 +122,7 @@ case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometM } } + // Called via JNI from `comet_metric_node.rs` def set_all_from_bytes(bytes: Array[Byte]): Unit = { val metricNode = Metric.NativeMetricNode.parseFrom(bytes) set_all(metricNode) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala index dcb975ac7a..ae2d873ef7 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala @@ -21,6 +21,7 @@ package org.apache.spark.sql.comet import scala.reflect.ClassTag +import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst._ @@ -180,18 +181,27 @@ case class CometNativeScanExec( (None, Seq.empty) } - CometExecRDD( + new CometExecRDD( sparkContext, - inputRDDs = Seq.empty, - commonByKey = Map(sourceKey -> commonData), - perPartitionByKey = Map(sourceKey -> perPartitionData), - serializedPlan = serializedPlan, - numPartitions = perPartitionData.length, - numOutputCols = output.length, - nativeMetrics = nativeMetrics, - 
subqueries = Seq.empty, - broadcastedHadoopConfForEncryption = broadcastedHadoopConfForEncryption, - encryptedFilePaths = encryptedFilePaths) + Seq.empty, + Map(sourceKey -> commonData), + Map(sourceKey -> perPartitionData), + serializedPlan, + perPartitionData.length, + output.length, + nativeMetrics, + Seq.empty, + broadcastedHadoopConfForEncryption, + encryptedFilePaths) { + override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + val res = super.compute(split, context) + + // Report scan input metrics after the iterator is fully consumed. + Option(context).foreach(nativeMetrics.reportScanInputMetrics) + + res + } + } } override def doCanonicalize(): CometNativeScanExec = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 1f5e7b6677..21cbdab974 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ +import org.apache.spark.{Partition, TaskContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -558,7 +559,8 @@ abstract class CometNativeExec extends CometExec { // Unified RDD creation - CometExecRDD handles all cases val subqueries = collectSubqueries(this) - CometExecRDD( + val hasScanInput = sparkPlans.exists(_.isInstanceOf[CometNativeScanExec]) + new CometExecRDD( sparkContext, inputs.toSeq, commonByKey, @@ -570,7 +572,20 @@ abstract class CometNativeExec extends CometExec { subqueries, broadcastedHadoopConfForEncryption, encryptedFilePaths, - shuffleScanIndices) + shuffleScanIndices) { + override def compute( + split: Partition, + context: TaskContext): Iterator[ColumnarBatch] = { + val res = super.compute(split, context) + + 
// Report scan input metrics only when the native plan contains a scan. + if (hasScanInput) { + Option(context).foreach(nativeMetrics.reportScanInputMetrics) + } + + res + } + } } } diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala index 3946aab184..cc02551bf8 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala @@ -21,6 +21,7 @@ package org.apache.spark.sql.comet import scala.collection.mutable +import org.apache.spark.SparkConf import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.SparkListener @@ -28,10 +29,17 @@ import org.apache.spark.scheduler.SparkListenerTaskEnd import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.comet.execution.shuffle.CometNativeShuffle import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.comet.CometConf + class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { + override protected def sparkConf: SparkConf = { + super.sparkConf.set("spark.ui.enabled", "true") + } + import testImplicits._ test("per-task native shuffle metrics") { @@ -91,4 +99,189 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + + test("native_datafusion scan reports task-level input metrics matching Spark") { + val totalRows = 10000 + withTempPath { dir => + spark + .createDataFrame((0 until totalRows).map(i => (i, s"elem_$i"))) + .repartition(5) + .write + .parquet(dir.getAbsolutePath) + spark.read.parquet(dir.getAbsolutePath).createOrReplaceTempView("tbl") + // Collect baseline input metrics from vanilla Spark 
(Comet disabled) + val (sparkBytes, sparkRecords, _) = + collectInputMetrics( + "SELECT * FROM tbl where _1 > 2000", + CometConf.COMET_ENABLED.key -> "false") + + // Collect input metrics from Comet native_datafusion scan. + val (cometBytes, cometRecords, cometPlan) = collectInputMetrics( + "SELECT * FROM tbl where _1 > 2000", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) + + // Verify the plan actually used CometNativeScanExec + assert( + find(cometPlan)(_.isInstanceOf[CometNativeScanExec]).isDefined, + s"Expected CometNativeScanExec in plan:\n${cometPlan.treeString}") + + assert(sparkRecords > 0, s"Spark outputRecords should be > 0, got $sparkRecords") + assert(cometRecords > 0, s"Comet outputRecords should be > 0, got $cometRecords") + + assert( + cometRecords == sparkRecords, + s"recordsRead mismatch: comet=$cometRecords, sparkRecords=$sparkRecords") + + // Bytes should be in the same ballpark -- both read the same Parquet file(s), + // but the exact byte count can differ due to reader implementation details + // (e.g. footer reads, page headers, buffering granularity). 
+ assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes") + assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes") + val ratio = cometBytes.toDouble / sparkBytes.toDouble + assert( + ratio >= 0.7 && ratio <= 1.3, + s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio") + } + } + + test("input metrics aggregate across multiple native scans in a join") { + withTempPath { dir1 => + withTempPath { dir2 => + // Create two separate parquet tables + spark + .createDataFrame((0 until 5000).map(i => (i, s"left_$i"))) + .repartition(3) + .write + .parquet(dir1.getAbsolutePath) + spark + .createDataFrame((0 until 5000).map(i => (i, s"right_$i"))) + .repartition(3) + .write + .parquet(dir2.getAbsolutePath) + + spark.read.parquet(dir1.getAbsolutePath).createOrReplaceTempView("left_tbl") + spark.read.parquet(dir2.getAbsolutePath).createOrReplaceTempView("right_tbl") + + val joinQuery = "SELECT * FROM left_tbl JOIN right_tbl ON left_tbl._1 = right_tbl._1" + + // Collect baseline from vanilla Spark + val (sparkBytes, sparkRecords, _) = + collectInputMetrics(joinQuery, CometConf.COMET_ENABLED.key -> "false") + + // Collect from Comet native scan + val (cometBytes, cometRecords, cometPlan) = collectInputMetrics( + joinQuery, + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) + + // Verify the plan has multiple CometNativeScanExec nodes + val scanCount = collect(cometPlan) { case s: CometNativeScanExec => + s + }.size + assert( + scanCount >= 2, + s"Expected at least 2 CometNativeScanExec in plan, found $scanCount:\n" + + cometPlan.treeString) + + assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes") + assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes") + assert(sparkRecords > 0, s"Spark recordsRead should be > 0, got $sparkRecords") + assert(cometRecords > 0, s"Comet recordsRead should be > 0, got 
$cometRecords") + + // Both sides should contribute to the total bytes + val ratio = cometBytes.toDouble / sparkBytes.toDouble + assert( + ratio >= 0.7 && ratio <= 1.3, + s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio") + } + } + } + + test("input metrics aggregate across multiple native scans in a union") { + withTempPath { dir1 => + withTempPath { dir2 => + spark + .createDataFrame((0 until 5000).map(i => (i, s"left_$i"))) + .repartition(3) + .write + .parquet(dir1.getAbsolutePath) + spark + .createDataFrame((5000 until 10000).map(i => (i, s"right_$i"))) + .repartition(3) + .write + .parquet(dir2.getAbsolutePath) + + spark.read.parquet(dir1.getAbsolutePath).createOrReplaceTempView("union_left") + spark.read.parquet(dir2.getAbsolutePath).createOrReplaceTempView("union_right") + + val unionQuery = "SELECT * FROM union_left UNION ALL SELECT * FROM union_right" + + // Collect baseline from vanilla Spark + val (sparkBytes, sparkRecords, _) = + collectInputMetrics(unionQuery, CometConf.COMET_ENABLED.key -> "false") + + // Collect from Comet native scan + val (cometBytes, cometRecords, cometPlan) = collectInputMetrics( + unionQuery, + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) + + // Verify the plan has multiple CometNativeScanExec nodes + val scanCount = collect(cometPlan) { case s: CometNativeScanExec => + s + }.size + assert( + scanCount >= 2, + s"Expected at least 2 CometNativeScanExec in plan, found $scanCount:\n" + + cometPlan.treeString) + + assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes") + assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes") + assert(sparkRecords > 0, s"Spark recordsRead should be > 0, got $sparkRecords") + assert(cometRecords > 0, s"Comet recordsRead should be > 0, got $cometRecords") + + val ratio = cometBytes.toDouble / sparkBytes.toDouble + assert( + ratio >= 0.7 && ratio <= 1.3, + 
s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio") + } + } + } + + /** + * Runs the given query with the given SQL config overrides and returns the aggregated + * (bytesRead, recordsRead) across all tasks, along with the executed plan. + * + * Uses AppStatusStore (same source as Spark UI) to read task-level input metrics. + * AppStatusStore stores immutable snapshots of metric values, unlike SparkListener's + * InputMetrics which are backed by mutable accumulators that can be reset. + */ + private def collectInputMetrics( + query: String, + confs: (String, String)*): (Long, Long, SparkPlan) = { + val store = spark.sparkContext.statusStore + + // Record existing stage IDs so we only look at stages from our query + val stagesBefore = store.stageList(null).map(_.stageId).toSet + + var plan: SparkPlan = null + withSQLConf(confs: _*) { + val df = sql(query) + df.collect() + plan = stripAQEPlan(df.queryExecution.executedPlan) + } + + // Wait for listener bus to flush all events into the status store + spark.sparkContext.listenerBus.waitUntilEmpty() + + // Sum input metrics from stages created by our query + val newStages = store.stageList(null).filterNot(s => stagesBefore.contains(s.stageId)) + assert(newStages.nonEmpty, s"No new stages found for confs=$confs") + + val totalBytes = newStages.map(_.inputBytes).sum + val totalRecords = newStages.map(_.inputRecords).sum + + (totalBytes, totalRecords, plan) + } }