From 2ad3b6f72fde7ab220dd0bcdb37d89233205247c Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 30 Mar 2026 09:31:52 -0700 Subject: [PATCH 01/15] chore: `native_datafusion` to report scan task input metrics --- .../apache/spark/sql/comet/CometExecRDD.scala | 12 ++++ .../spark/sql/comet/CometMetricNode.scala | 11 ++++ .../sql/comet/CometTaskMetricsSuite.scala | 65 +++++++++++++++++++ 3 files changed, 88 insertions(+) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index c5014818c4..5310b7c84a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -139,6 +139,18 @@ private[spark] class CometExecRDD( ctx.addTaskCompletionListener[Unit] { _ => it.close() subqueries.foreach(sub => CometScalarSubquery.removeSubquery(it.id, sub)) + + // Propagate native scan metrics (bytes_scanned, output_rows) to Spark's task-level + // inputMetrics so they appear in the Spark UI "Input" column and are reported via + // the listener infrastructure. The native reader bypasses Hadoop's Java FileSystem, + // so thread-local FS statistics are never updated -- we bridge the gap here. 
+ val bytesScannedMetric = nativeMetrics.findMetric("bytes_scanned") + val outputRowsMetric = nativeMetrics.findMetric("output_rows") + if (bytesScannedMetric.isDefined || outputRowsMetric.isDefined) { + val inputMetrics = ctx.taskMetrics().inputMetrics + bytesScannedMetric.foreach(m => inputMetrics.setBytesRead(m.value)) + outputRowsMetric.foreach(m => inputMetrics.setRecordsRead(m.value)) + } } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala index 8c75df1d45..2867c54a45 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala @@ -79,10 +79,21 @@ case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometM } } + // Called via JNI from `comet_metric_node.rs` def set_all_from_bytes(bytes: Array[Byte]): Unit = { val metricNode = Metric.NativeMetricNode.parseFrom(bytes) set_all(metricNode) } + + /** + * Finds a metric by name in this node or any descendant node. Returns the first match found via + * depth-first search. 
+ */ + def findMetric(name: String): Option[SQLMetric] = { + metrics.get(name).orElse { + children.iterator.map(_.findMetric(name)).collectFirst { case Some(m) => m } + } + } } object CometMetricNode { diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala index 3946aab184..ec77dbc1fc 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala @@ -21,6 +21,7 @@ package org.apache.spark.sql.comet import scala.collection.mutable +import org.apache.spark.executor.InputMetrics import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.SparkListener @@ -30,6 +31,8 @@ import org.apache.spark.sql.comet.execution.shuffle.CometNativeShuffle import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.comet.CometConf + class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { import testImplicits._ @@ -91,4 +94,66 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + + test("native_datafusion scan reports task-level input metrics matching Spark") { + withParquetTable((0 until 10000).map(i => (i, (i + 1).toLong)), "tbl") { + // Collect baseline input metrics from vanilla Spark (Comet disabled) + val (sparkBytes, sparkRecords) = collectInputMetrics(CometConf.COMET_ENABLED.key -> "false") + + // Collect input metrics from Comet native_datafusion scan + val (cometBytes, cometRecords) = collectInputMetrics( + CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) + + // Records must match exactly + assert( + cometRecords == sparkRecords, + s"recordsRead mismatch: comet=$cometRecords, 
spark=$sparkRecords") + + // Bytes should be in the same ballpark -- both read the same Parquet file(s), + // but the exact byte count can differ due to reader implementation details + // (e.g. footer reads, page headers, buffering granularity). + assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes") + assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes") + val ratio = cometBytes.toDouble / sparkBytes.toDouble + assert( + ratio >= 0.85 && ratio <= 1.15, + s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio") + } + } + + /** + * Runs `SELECT * FROM tbl` with the given SQL config overrides and returns the aggregated + * (bytesRead, recordsRead) across all tasks. + */ + private def collectInputMetrics(confs: (String, String)*): (Long, Long) = { + val inputMetricsList = mutable.ArrayBuffer.empty[InputMetrics] + + val listener = new SparkListener { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + val im = taskEnd.taskMetrics.inputMetrics + inputMetricsList.synchronized { + inputMetricsList += im + } + } + } + + spark.sparkContext.addSparkListener(listener) + try { + // Drain any earlier events + spark.sparkContext.listenerBus.waitUntilEmpty() + + withSQLConf(confs: _*) { + sql("SELECT * FROM tbl").collect() + } + + spark.sparkContext.listenerBus.waitUntilEmpty() + + assert(inputMetricsList.nonEmpty, s"No input metrics found for confs=$confs") + val totalBytes = inputMetricsList.map(_.bytesRead).sum + val totalRecords = inputMetricsList.map(_.recordsRead).sum + (totalBytes, totalRecords) + } finally { + spark.sparkContext.removeSparkListener(listener) + } + } } From e2c6093a3abca023a05df6c5aba850f806da634b Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 30 Mar 2026 09:54:18 -0700 Subject: [PATCH 02/15] chore: `native_datafusion` to report scan task input metrics --- .../org/apache/spark/sql/comet/CometTaskMetricsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala index ec77dbc1fc..59d02512a0 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala @@ -116,7 +116,7 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes") val ratio = cometBytes.toDouble / sparkBytes.toDouble assert( - ratio >= 0.85 && ratio <= 1.15, + ratio >= 0.8 && ratio <= 1.2, s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio") } } From ac6b869680ed29f5212a4abef4e1bed0581320cb Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 30 Mar 2026 16:53:48 -0700 Subject: [PATCH 03/15] chore: `native_datafusion` to report scan task input metrics --- .../apache/spark/sql/comet/CometExecRDD.scala | 17 ++++++----------- .../spark/sql/comet/CometMetricNode.scala | 10 ---------- 2 files changed, 6 insertions(+), 21 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index 5310b7c84a..c547b43d48 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -140,17 +140,12 @@ private[spark] class CometExecRDD( it.close() subqueries.foreach(sub => CometScalarSubquery.removeSubquery(it.id, sub)) - // Propagate native scan metrics (bytes_scanned, output_rows) to Spark's task-level - // inputMetrics so they appear in the Spark UI "Input" column and are reported via - // the listener infrastructure. The native reader bypasses Hadoop's Java FileSystem, - // so thread-local FS statistics are never updated -- we bridge the gap here. 
- val bytesScannedMetric = nativeMetrics.findMetric("bytes_scanned") - val outputRowsMetric = nativeMetrics.findMetric("output_rows") - if (bytesScannedMetric.isDefined || outputRowsMetric.isDefined) { - val inputMetrics = ctx.taskMetrics().inputMetrics - bytesScannedMetric.foreach(m => inputMetrics.setBytesRead(m.value)) - outputRowsMetric.foreach(m => inputMetrics.setRecordsRead(m.value)) - } + nativeMetrics.metrics + .get("bytes_scanned") + .foreach(m => ctx.taskMetrics().inputMetrics.setBytesRead(m.value)) + nativeMetrics.metrics + .get("output_rows") + .foreach(m => ctx.taskMetrics().inputMetrics.setRecordsRead(m.value)) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala index 2867c54a45..7883775c80 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala @@ -84,16 +84,6 @@ case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometM val metricNode = Metric.NativeMetricNode.parseFrom(bytes) set_all(metricNode) } - - /** - * Finds a metric by name in this node or any descendant node. Returns the first match found via - * depth-first search. 
- */ - def findMetric(name: String): Option[SQLMetric] = { - metrics.get(name).orElse { - children.iterator.map(_.findMetric(name)).collectFirst { case Some(m) => m } - } - } } object CometMetricNode { From 7c8377d78ae68a52a98a9f52f7a953c6d5384ec1 Mon Sep 17 00:00:00 2001 From: comphead Date: Tue, 31 Mar 2026 08:34:14 -0700 Subject: [PATCH 04/15] chore: `native_datafusion` to report scan task input metrics --- .../spark/sql/comet/CometTaskMetricsSuite.scala | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala index 59d02512a0..57a28e97fb 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala @@ -100,8 +100,21 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { // Collect baseline input metrics from vanilla Spark (Comet disabled) val (sparkBytes, sparkRecords) = collectInputMetrics(CometConf.COMET_ENABLED.key -> "false") + // Verify the plan actually uses CometNativeScanExec before collecting metrics + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) { + val df = sql("SELECT * FROM tbl") + df.collect() // force plan materialization for AQE + val plan = stripAQEPlan(df.queryExecution.executedPlan) + assert( + find(plan)(_.isInstanceOf[CometNativeScanExec]).isDefined, + s"Expected CometNativeScanExec in plan:\n${plan.treeString}") + } + // Collect input metrics from Comet native_datafusion scan val (cometBytes, cometRecords) = collectInputMetrics( + CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) // Records must match exactly From 0ec7f55531a10299d7fe7c7cc2862555fe7d6f24 Mon Sep 17 00:00:00 2001 From: comphead Date: 
Tue, 31 Mar 2026 14:07:05 -0700 Subject: [PATCH 05/15] chore: `native_datafusion` to report scan task input metrics --- .../sql/comet/CometTaskMetricsSuite.scala | 34 +++++++++---------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala index 57a28e97fb..f436a43d0a 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.scheduler.SparkListenerTaskEnd import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.comet.execution.shuffle.CometNativeShuffle import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.comet.CometConf @@ -98,25 +99,19 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("native_datafusion scan reports task-level input metrics matching Spark") { withParquetTable((0 until 10000).map(i => (i, (i + 1).toLong)), "tbl") { // Collect baseline input metrics from vanilla Spark (Comet disabled) - val (sparkBytes, sparkRecords) = collectInputMetrics(CometConf.COMET_ENABLED.key -> "false") - - // Verify the plan actually uses CometNativeScanExec before collecting metrics - withSQLConf( - CometConf.COMET_ENABLED.key -> "true", - CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) { - val df = sql("SELECT * FROM tbl") - df.collect() // force plan materialization for AQE - val plan = stripAQEPlan(df.queryExecution.executedPlan) - assert( - find(plan)(_.isInstanceOf[CometNativeScanExec]).isDefined, - s"Expected CometNativeScanExec in plan:\n${plan.treeString}") - } + val (sparkBytes, sparkRecords, _) = + 
collectInputMetrics(CometConf.COMET_ENABLED.key -> "false") // Collect input metrics from Comet native_datafusion scan - val (cometBytes, cometRecords) = collectInputMetrics( + val (cometBytes, cometRecords, cometPlan) = collectInputMetrics( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) + // Verify the plan actually used CometNativeScanExec + assert( + find(cometPlan)(_.isInstanceOf[CometNativeScanExec]).isDefined, + s"Expected CometNativeScanExec in plan:\n${cometPlan.treeString}") + // Records must match exactly assert( cometRecords == sparkRecords, @@ -136,9 +131,9 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { /** * Runs `SELECT * FROM tbl` with the given SQL config overrides and returns the aggregated - * (bytesRead, recordsRead) across all tasks. + * (bytesRead, recordsRead) across all tasks, along with the executed plan. */ - private def collectInputMetrics(confs: (String, String)*): (Long, Long) = { + private def collectInputMetrics(confs: (String, String)*): (Long, Long, SparkPlan) = { val inputMetricsList = mutable.ArrayBuffer.empty[InputMetrics] val listener = new SparkListener { @@ -155,8 +150,11 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { // Drain any earlier events spark.sparkContext.listenerBus.waitUntilEmpty() + var plan: SparkPlan = null withSQLConf(confs: _*) { - sql("SELECT * FROM tbl").collect() + val df = sql("SELECT * FROM tbl where _1 > 2000") + df.collect() + plan = stripAQEPlan(df.queryExecution.executedPlan) } spark.sparkContext.listenerBus.waitUntilEmpty() @@ -164,7 +162,7 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { assert(inputMetricsList.nonEmpty, s"No input metrics found for confs=$confs") val totalBytes = inputMetricsList.map(_.bytesRead).sum val totalRecords = inputMetricsList.map(_.recordsRead).sum - (totalBytes, totalRecords) + (totalBytes, 
totalRecords, plan) } finally { spark.sparkContext.removeSparkListener(listener) } From b1a0d67501084e226b8220b194034f13fb0b3609 Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 1 Apr 2026 10:14:34 -0700 Subject: [PATCH 06/15] chore: `native_datafusion` to report scan task input metrics --- .../apache/spark/sql/comet/CometExecRDD.scala | 24 ++++++++++++------- .../comet/CometIcebergNativeScanExec.scala | 3 ++- .../spark/sql/comet/CometMetricNode.scala | 10 ++++++++ .../spark/sql/comet/CometNativeScanExec.scala | 3 ++- 4 files changed, 29 insertions(+), 11 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index c547b43d48..fee8a6fcae 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -66,7 +66,8 @@ private[spark] class CometExecRDD( subqueries: Seq[ScalarSubquery], broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, encryptedFilePaths: Seq[String] = Seq.empty, - shuffleScanIndices: Set[Int] = Set.empty) + shuffleScanIndices: Set[Int] = Set.empty, + hasNativeScan: Boolean = false) extends RDD[ColumnarBatch](sc, inputRDDs.map(rdd => new OneToOneDependency(rdd))) { // Determine partition count: from inputs if available, otherwise from parameter @@ -140,12 +141,15 @@ private[spark] class CometExecRDD( it.close() subqueries.foreach(sub => CometScalarSubquery.removeSubquery(it.id, sub)) - nativeMetrics.metrics - .get("bytes_scanned") - .foreach(m => ctx.taskMetrics().inputMetrics.setBytesRead(m.value)) - nativeMetrics.metrics - .get("output_rows") - .foreach(m => ctx.taskMetrics().inputMetrics.setRecordsRead(m.value)) + if (hasNativeScan) { + val leaf = nativeMetrics.leafNode + leaf.metrics.get("bytes_scanned").foreach { bs => + ctx.taskMetrics().inputMetrics.setBytesRead(bs.value) + leaf.metrics + .get("output_rows") + 
.foreach(m => ctx.taskMetrics().inputMetrics.setRecordsRead(m.value)) + } + } } } @@ -187,7 +191,8 @@ object CometExecRDD { subqueries: Seq[ScalarSubquery], broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, encryptedFilePaths: Seq[String] = Seq.empty, - shuffleScanIndices: Set[Int] = Set.empty): CometExecRDD = { + shuffleScanIndices: Set[Int] = Set.empty, + hasNativeScan: Boolean = false): CometExecRDD = { // scalastyle:on new CometExecRDD( @@ -202,6 +207,7 @@ object CometExecRDD { subqueries, broadcastedHadoopConfForEncryption, encryptedFilePaths, - shuffleScanIndices) + shuffleScanIndices, + hasNativeScan) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala index 36085b6329..b4a408eeda 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala @@ -305,7 +305,8 @@ case class CometIcebergNativeScanExec( numPartitions = perPartitionData.length, numOutputCols = output.length, nativeMetrics = nativeMetrics, - subqueries = Seq.empty) + subqueries = Seq.empty, + hasNativeScan = true) } /** diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala index 7883775c80..c60202a96f 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala @@ -41,6 +41,16 @@ import org.apache.comet.serde.Metric case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometMetricNode]) extends Logging { + /** + * Returns the leaf node (deepest single-child descendant). 
For a native scan plan like + * FilterExec -> DataSourceExec, this returns the DataSourceExec node which has the + * bytes_scanned and output_rows metrics from the Parquet reader. + */ + def leafNode: CometMetricNode = { + if (children.isEmpty) this + else children.head.leafNode + } + /** * Gets a child node. Called from native. */ diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala index dcb975ac7a..f5e6495724 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala @@ -191,7 +191,8 @@ case class CometNativeScanExec( nativeMetrics = nativeMetrics, subqueries = Seq.empty, broadcastedHadoopConfForEncryption = broadcastedHadoopConfForEncryption, - encryptedFilePaths = encryptedFilePaths) + encryptedFilePaths = encryptedFilePaths, + hasNativeScan = true) } override def doCanonicalize(): CometNativeScanExec = { From 70252c2e699ab16c2d4141677722fc1c02c288ba Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 1 Apr 2026 14:55:36 -0700 Subject: [PATCH 07/15] chore: `native_datafusion` to report scan task input metrics --- .../apache/spark/sql/comet/CometExecRDD.scala | 11 ++- .../sql/comet/CometTaskMetricsSuite.scala | 72 ++++++++++--------- 2 files changed, 47 insertions(+), 36 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index fee8a6fcae..1ad7526d70 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -145,9 +145,14 @@ private[spark] class CometExecRDD( val leaf = nativeMetrics.leafNode leaf.metrics.get("bytes_scanned").foreach { bs => ctx.taskMetrics().inputMetrics.setBytesRead(bs.value) - leaf.metrics - .get("output_rows") - .foreach(m 
=> ctx.taskMetrics().inputMetrics.setRecordsRead(m.value)) + // Total rows read from Parquet = rows output after pushdown (output_rows) + // + rows pruned by pushdown filters (pushdown_rows_pruned). + // This matches Spark's recordsRead which counts rows before filtering. + val outputRows = + leaf.metrics.get("output_rows").map(_.value).getOrElse(0L) + val prunedRows = + leaf.metrics.get("pushdown_rows_pruned").map(_.value).getOrElse(0L) + ctx.taskMetrics().inputMetrics.setRecordsRead(outputRows + prunedRows) } } } diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala index f436a43d0a..3107440ea3 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala @@ -21,7 +21,7 @@ package org.apache.spark.sql.comet import scala.collection.mutable -import org.apache.spark.executor.InputMetrics +import org.apache.spark.SparkConf import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.SparkListener @@ -36,6 +36,10 @@ import org.apache.comet.CometConf class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { + override protected def sparkConf: SparkConf = { + super.sparkConf.set("spark.ui.enabled", "true") + } + import testImplicits._ test("per-task native shuffle metrics") { @@ -97,12 +101,20 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("native_datafusion scan reports task-level input metrics matching Spark") { - withParquetTable((0 until 10000).map(i => (i, (i + 1).toLong)), "tbl") { + val totalRows = 10000 + withTempPath { dir => + val rng = new scala.util.Random(42) + spark + .createDataFrame((0 until totalRows).map(_ => (rng.nextInt(), rng.nextLong()))) + .repartition(5) + .write + 
.parquet(dir.getAbsolutePath) + spark.read.parquet(dir.getAbsolutePath).createOrReplaceTempView("tbl") // Collect baseline input metrics from vanilla Spark (Comet disabled) val (sparkBytes, sparkRecords, _) = collectInputMetrics(CometConf.COMET_ENABLED.key -> "false") - // Collect input metrics from Comet native_datafusion scan + // Collect input metrics from Comet native_datafusion scan. val (cometBytes, cometRecords, cometPlan) = collectInputMetrics( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) @@ -112,10 +124,9 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { find(cometPlan)(_.isInstanceOf[CometNativeScanExec]).isDefined, s"Expected CometNativeScanExec in plan:\n${cometPlan.treeString}") - // Records must match exactly assert( cometRecords == sparkRecords, - s"recordsRead mismatch: comet=$cometRecords, spark=$sparkRecords") + s"recordsRead mismatch: comet=$cometRecords, sparkRecords=$sparkRecords") // Bytes should be in the same ballpark -- both read the same Parquet file(s), // but the exact byte count can differ due to reader implementation details @@ -130,41 +141,36 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { } /** - * Runs `SELECT * FROM tbl` with the given SQL config overrides and returns the aggregated - * (bytesRead, recordsRead) across all tasks, along with the executed plan. + * Runs `SELECT * FROM tbl WHERE _1 > 2000` with the given SQL config overrides and returns the + * aggregated (bytesRead, recordsRead) across all tasks, along with the executed plan. + * + * Uses AppStatusStore (same source as Spark UI) to read task-level input metrics. + * AppStatusStore stores immutable snapshots of metric values, unlike SparkListener's + * InputMetrics which are backed by mutable accumulators that can be reset. 
*/ private def collectInputMetrics(confs: (String, String)*): (Long, Long, SparkPlan) = { - val inputMetricsList = mutable.ArrayBuffer.empty[InputMetrics] + val store = spark.sparkContext.statusStore - val listener = new SparkListener { - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { - val im = taskEnd.taskMetrics.inputMetrics - inputMetricsList.synchronized { - inputMetricsList += im - } - } + // Record existing stage IDs so we only look at stages from our query + val stagesBefore = store.stageList(null).map(_.stageId).toSet + + var plan: SparkPlan = null + withSQLConf(confs: _*) { + val df = sql("SELECT * FROM tbl where _1 > 2000") + df.collect() + plan = stripAQEPlan(df.queryExecution.executedPlan) } - spark.sparkContext.addSparkListener(listener) - try { - // Drain any earlier events - spark.sparkContext.listenerBus.waitUntilEmpty() + // Wait for listener bus to flush all events into the status store + spark.sparkContext.listenerBus.waitUntilEmpty() - var plan: SparkPlan = null - withSQLConf(confs: _*) { - val df = sql("SELECT * FROM tbl where _1 > 2000") - df.collect() - plan = stripAQEPlan(df.queryExecution.executedPlan) - } + // Sum input metrics from stages created by our query + val newStages = store.stageList(null).filterNot(s => stagesBefore.contains(s.stageId)) + assert(newStages.nonEmpty, s"No new stages found for confs=$confs") - spark.sparkContext.listenerBus.waitUntilEmpty() + val totalBytes = newStages.map(_.inputBytes).sum + val totalRecords = newStages.map(_.inputRecords).sum - assert(inputMetricsList.nonEmpty, s"No input metrics found for confs=$confs") - val totalBytes = inputMetricsList.map(_.bytesRead).sum - val totalRecords = inputMetricsList.map(_.recordsRead).sum - (totalBytes, totalRecords, plan) - } finally { - spark.sparkContext.removeSparkListener(listener) - } + (totalBytes, totalRecords, plan) } } From 4137783e05364acfa6c17123737c0701d0eea26b Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 1 Apr 2026 16:22:58 
-0700 Subject: [PATCH 08/15] chore: `native_datafusion` to report scan task input metrics --- .../apache/spark/sql/comet/CometExecRDD.scala | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index 1ad7526d70..b680b114e6 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -141,19 +141,20 @@ private[spark] class CometExecRDD( it.close() subqueries.foreach(sub => CometScalarSubquery.removeSubquery(it.id, sub)) - if (hasNativeScan) { - val leaf = nativeMetrics.leafNode - leaf.metrics.get("bytes_scanned").foreach { bs => - ctx.taskMetrics().inputMetrics.setBytesRead(bs.value) - // Total rows read from Parquet = rows output after pushdown (output_rows) - // + rows pruned by pushdown filters (pushdown_rows_pruned). - // This matches Spark's recordsRead which counts rows before filtering. - val outputRows = - leaf.metrics.get("output_rows").map(_.value).getOrElse(0L) - val prunedRows = - leaf.metrics.get("pushdown_rows_pruned").map(_.value).getOrElse(0L) - ctx.taskMetrics().inputMetrics.setRecordsRead(outputRows + prunedRows) - } + // Report scan input metrics when the leaf node has scan metrics. + // The bytes_scanned key only exists in nativeScanMetrics, so this + // naturally skips non-scan CometExecRDD instances. + val leaf = nativeMetrics.leafNode + leaf.metrics.get("bytes_scanned").foreach { bs => + ctx.taskMetrics().inputMetrics.setBytesRead(bs.value) + // Total rows read from Parquet = rows output after pushdown (output_rows) + // + rows pruned by pushdown filters (pushdown_rows_pruned). + // This matches Spark's recordsRead which counts rows before filtering. 
+ val outputRows = + leaf.metrics.get("output_rows").map(_.value).getOrElse(0L) + val prunedRows = + leaf.metrics.get("pushdown_rows_pruned").map(_.value).getOrElse(0L) + ctx.taskMetrics().inputMetrics.setRecordsRead(outputRows + prunedRows) } } } From aa24794acde49266d8d3a48e70156ac12b164eac Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 2 Apr 2026 09:23:43 -0700 Subject: [PATCH 09/15] chore: `native_datafusion` fails on repartition + count --- .../scala/org/apache/spark/sql/comet/CometExecRDD.scala | 9 +++------ .../spark/sql/comet/CometIcebergNativeScanExec.scala | 3 +-- .../org/apache/spark/sql/comet/CometNativeScanExec.scala | 3 +-- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index b680b114e6..4e4dbc4b7d 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -66,8 +66,7 @@ private[spark] class CometExecRDD( subqueries: Seq[ScalarSubquery], broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, encryptedFilePaths: Seq[String] = Seq.empty, - shuffleScanIndices: Set[Int] = Set.empty, - hasNativeScan: Boolean = false) + shuffleScanIndices: Set[Int] = Set.empty) extends RDD[ColumnarBatch](sc, inputRDDs.map(rdd => new OneToOneDependency(rdd))) { // Determine partition count: from inputs if available, otherwise from parameter @@ -197,8 +196,7 @@ object CometExecRDD { subqueries: Seq[ScalarSubquery], broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, encryptedFilePaths: Seq[String] = Seq.empty, - shuffleScanIndices: Set[Int] = Set.empty, - hasNativeScan: Boolean = false): CometExecRDD = { + shuffleScanIndices: Set[Int] = Set.empty): CometExecRDD = { // scalastyle:on new CometExecRDD( @@ -213,7 +211,6 @@ object CometExecRDD { subqueries, 
broadcastedHadoopConfForEncryption, encryptedFilePaths, - shuffleScanIndices, - hasNativeScan) + shuffleScanIndices) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala index b4a408eeda..36085b6329 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala @@ -305,8 +305,7 @@ case class CometIcebergNativeScanExec( numPartitions = perPartitionData.length, numOutputCols = output.length, nativeMetrics = nativeMetrics, - subqueries = Seq.empty, - hasNativeScan = true) + subqueries = Seq.empty) } /** diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala index f5e6495724..dcb975ac7a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala @@ -191,8 +191,7 @@ case class CometNativeScanExec( nativeMetrics = nativeMetrics, subqueries = Seq.empty, broadcastedHadoopConfForEncryption = broadcastedHadoopConfForEncryption, - encryptedFilePaths = encryptedFilePaths, - hasNativeScan = true) + encryptedFilePaths = encryptedFilePaths) } override def doCanonicalize(): CometNativeScanExec = { From 23b0e3c946a00c327828f3ff30bfd6f42a9ee413 Mon Sep 17 00:00:00 2001 From: comphead Date: Fri, 3 Apr 2026 15:08:26 -0700 Subject: [PATCH 10/15] chore: `native_datafusion` to report scan task input metrics --- .../apache/spark/sql/comet/CometExecRDD.scala | 16 ---------- .../apache/spark/sql/comet/operators.scala | 30 +++++++++++++++++-- .../sql/comet/CometTaskMetricsSuite.scala | 2 +- 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index 4e4dbc4b7d..c5014818c4 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -139,22 +139,6 @@ private[spark] class CometExecRDD( ctx.addTaskCompletionListener[Unit] { _ => it.close() subqueries.foreach(sub => CometScalarSubquery.removeSubquery(it.id, sub)) - - // Report scan input metrics when the leaf node has scan metrics. - // The bytes_scanned key only exists in nativeScanMetrics, so this - // naturally skips non-scan CometExecRDD instances. - val leaf = nativeMetrics.leafNode - leaf.metrics.get("bytes_scanned").foreach { bs => - ctx.taskMetrics().inputMetrics.setBytesRead(bs.value) - // Total rows read from Parquet = rows output after pushdown (output_rows) - // + rows pruned by pushdown filters (pushdown_rows_pruned). - // This matches Spark's recordsRead which counts rows before filtering. - val outputRows = - leaf.metrics.get("output_rows").map(_.value).getOrElse(0L) - val prunedRows = - leaf.metrics.get("pushdown_rows_pruned").map(_.value).getOrElse(0L) - ctx.taskMetrics().inputMetrics.setRecordsRead(outputRows + prunedRows) - } } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 2965e46988..4ab101fea8 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ +import org.apache.spark.{Partition, TaskContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -558,7 +559,7 @@ abstract class CometNativeExec extends CometExec { // Unified RDD creation - CometExecRDD handles all cases val 
subqueries = collectSubqueries(this) - CometExecRDD( + new CometExecRDD( sparkContext, inputs.toSeq, commonByKey, @@ -570,7 +571,32 @@ abstract class CometNativeExec extends CometExec { subqueries, broadcastedHadoopConfForEncryption, encryptedFilePaths, - shuffleScanIndices) + shuffleScanIndices) { + override def compute( + split: Partition, + context: TaskContext): Iterator[ColumnarBatch] = { + val res = super.compute(split, context) + + // Report scan input metrics only when the native plan contains a scan. + if (sparkPlans.exists(_.isInstanceOf[CometNativeScanExec])) { + Option(context).foreach { ctx => + ctx.addTaskCompletionListener[Unit] { _ => + val leaf = nativeMetrics.leafNode + leaf.metrics.get("bytes_scanned").foreach { bs => + ctx.taskMetrics().inputMetrics.setBytesRead(bs.value) + val outputRows = + leaf.metrics.get("output_rows").map(_.value).getOrElse(0L) + val prunedRows = + leaf.metrics.get("pushdown_rows_pruned").map(_.value).getOrElse(0L) + ctx.taskMetrics().inputMetrics.setRecordsRead(outputRows + prunedRows) + } + } + } + } + + res + } + } } } diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala index 3107440ea3..a4f12dc8f7 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala @@ -105,7 +105,7 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { withTempPath { dir => val rng = new scala.util.Random(42) spark - .createDataFrame((0 until totalRows).map(_ => (rng.nextInt(), rng.nextLong()))) + .createDataFrame((0 until totalRows).map(i => (i, s"elem_$i"))) .repartition(5) .write .parquet(dir.getAbsolutePath) From 44d012d5d6346ff76c0ed51c1602d2c8b51375d4 Mon Sep 17 00:00:00 2001 From: comphead Date: Fri, 3 Apr 2026 15:27:23 -0700 Subject: [PATCH 11/15] chore: `native_datafusion` to report 
scan task input metrics --- .../scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala index a4f12dc8f7..a053c57293 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala @@ -103,7 +103,6 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("native_datafusion scan reports task-level input metrics matching Spark") { val totalRows = 10000 withTempPath { dir => - val rng = new scala.util.Random(42) spark .createDataFrame((0 until totalRows).map(i => (i, s"elem_$i"))) .repartition(5) From 13d93fefd882ecadca0c2cc0abc17acd95fb5087 Mon Sep 17 00:00:00 2001 From: comphead Date: Fri, 3 Apr 2026 22:57:02 -0700 Subject: [PATCH 12/15] chore: `native_datafusion` to report scan task input metrics --- .../org/apache/spark/sql/comet/CometTaskMetricsSuite.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala index a053c57293..5b6225b720 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala @@ -123,6 +123,9 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { find(cometPlan)(_.isInstanceOf[CometNativeScanExec]).isDefined, s"Expected CometNativeScanExec in plan:\n${cometPlan.treeString}") + assert(sparkRecords > 0, s"Spark outputRecords should be > 0, got $sparkRecords") + assert(cometRecords > 0, s"Comet outputRecords should be > 0, got $cometRecords") + assert( cometRecords == sparkRecords, s"recordsRead mismatch: 
comet=$cometRecords, sparkRecords=$sparkRecords") From e2299fe77ccafd45eb685b8a420948e556983d9a Mon Sep 17 00:00:00 2001 From: comphead Date: Tue, 7 Apr 2026 09:39:34 -0700 Subject: [PATCH 13/15] chore: `native_datafusion` to report scan task input metrics --- .../spark/sql/comet/CometMetricNode.scala | 10 ++ .../apache/spark/sql/comet/operators.scala | 23 ++-- .../sql/comet/CometTaskMetricsSuite.scala | 119 +++++++++++++++++- 3 files changed, 139 insertions(+), 13 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala index c60202a96f..ea476252d4 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala @@ -51,6 +51,16 @@ case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometM else children.head.leafNode } + /** + * Returns all leaf nodes (nodes with no children) in the metric tree. Unlike [[leafNode]] which + * only follows the first child, this finds all leaves, which is needed for plans with multiple + * scans (e.g., joins, unions). + */ + def leafNodes: Seq[CometMetricNode] = { + if (children.isEmpty) Seq(this) + else children.flatMap(_.leafNodes) + } + /** * Gets a child node. Called from native. */ diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 3624044eca..77c87fcc74 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -578,17 +578,24 @@ abstract class CometNativeExec extends CometExec { val res = super.compute(split, context) // Report scan input metrics only when the native plan contains a scan. + // Aggregate across all scan leaf nodes to handle plans with multiple + // scans (e.g., joins over different tables, unions). 
if (sparkPlans.exists(_.isInstanceOf[CometNativeScanExec])) { Option(context).foreach { ctx => ctx.addTaskCompletionListener[Unit] { _ => - val leaf = nativeMetrics.leafNode - leaf.metrics.get("bytes_scanned").foreach { bs => - ctx.taskMetrics().inputMetrics.setBytesRead(bs.value) - val outputRows = - leaf.metrics.get("output_rows").map(_.value).getOrElse(0L) - val prunedRows = - leaf.metrics.get("pushdown_rows_pruned").map(_.value).getOrElse(0L) - ctx.taskMetrics().inputMetrics.setRecordsRead(outputRows + prunedRows) + val scanLeaves = nativeMetrics.leafNodes + .filter(_.metrics.contains("bytes_scanned")) + if (scanLeaves.nonEmpty) { + val totalBytes = scanLeaves.map(_.metrics("bytes_scanned").value).sum + val totalRows = scanLeaves.map { leaf => + val outputRows = + leaf.metrics.get("output_rows").map(_.value).getOrElse(0L) + val prunedRows = + leaf.metrics.get("pushdown_rows_pruned").map(_.value).getOrElse(0L) + outputRows + prunedRows + }.sum + ctx.taskMetrics().inputMetrics.setBytesRead(totalBytes) + ctx.taskMetrics().inputMetrics.setRecordsRead(totalRows) } } } diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala index 5b6225b720..2c34307cc0 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala @@ -111,10 +111,13 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { spark.read.parquet(dir.getAbsolutePath).createOrReplaceTempView("tbl") // Collect baseline input metrics from vanilla Spark (Comet disabled) val (sparkBytes, sparkRecords, _) = - collectInputMetrics(CometConf.COMET_ENABLED.key -> "false") + collectInputMetrics( + "SELECT * FROM tbl where _1 > 2000", + CometConf.COMET_ENABLED.key -> "false") // Collect input metrics from Comet native_datafusion scan. 
val (cometBytes, cometRecords, cometPlan) = collectInputMetrics( + "SELECT * FROM tbl where _1 > 2000", CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) @@ -142,15 +145,121 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("input metrics aggregate across multiple native scans in a join") { + withTempPath { dir1 => + withTempPath { dir2 => + // Create two separate parquet tables + spark + .createDataFrame((0 until 5000).map(i => (i, s"left_$i"))) + .repartition(3) + .write + .parquet(dir1.getAbsolutePath) + spark + .createDataFrame((0 until 5000).map(i => (i, s"right_$i"))) + .repartition(3) + .write + .parquet(dir2.getAbsolutePath) + + spark.read.parquet(dir1.getAbsolutePath).createOrReplaceTempView("left_tbl") + spark.read.parquet(dir2.getAbsolutePath).createOrReplaceTempView("right_tbl") + + val joinQuery = "SELECT * FROM left_tbl JOIN right_tbl ON left_tbl._1 = right_tbl._1" + + // Collect baseline from vanilla Spark + val (sparkBytes, sparkRecords, _) = + collectInputMetrics(joinQuery, CometConf.COMET_ENABLED.key -> "false") + + // Collect from Comet native scan + val (cometBytes, cometRecords, cometPlan) = collectInputMetrics( + joinQuery, + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) + + // Verify the plan has multiple CometNativeScanExec nodes + val scanCount = collect(cometPlan) { case s: CometNativeScanExec => + s + }.size + assert( + scanCount >= 2, + s"Expected at least 2 CometNativeScanExec in plan, found $scanCount:\n" + + cometPlan.treeString) + + assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes") + assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes") + assert(sparkRecords > 0, s"Spark recordsRead should be > 0, got $sparkRecords") + assert(cometRecords > 0, s"Comet recordsRead should be > 0, got $cometRecords") + + // Both sides 
should contribute to the total bytes + val ratio = cometBytes.toDouble / sparkBytes.toDouble + assert( + ratio >= 0.8 && ratio <= 1.2, + s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio") + } + } + } + + test("input metrics aggregate across multiple native scans in a union") { + withTempPath { dir1 => + withTempPath { dir2 => + spark + .createDataFrame((0 until 5000).map(i => (i, s"left_$i"))) + .repartition(3) + .write + .parquet(dir1.getAbsolutePath) + spark + .createDataFrame((5000 until 10000).map(i => (i, s"right_$i"))) + .repartition(3) + .write + .parquet(dir2.getAbsolutePath) + + spark.read.parquet(dir1.getAbsolutePath).createOrReplaceTempView("union_left") + spark.read.parquet(dir2.getAbsolutePath).createOrReplaceTempView("union_right") + + val unionQuery = "SELECT * FROM union_left UNION ALL SELECT * FROM union_right" + + // Collect baseline from vanilla Spark + val (sparkBytes, sparkRecords, _) = + collectInputMetrics(unionQuery, CometConf.COMET_ENABLED.key -> "false") + + // Collect from Comet native scan + val (cometBytes, cometRecords, cometPlan) = collectInputMetrics( + unionQuery, + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) + + // Verify the plan has multiple CometNativeScanExec nodes + val scanCount = collect(cometPlan) { case s: CometNativeScanExec => + s + }.size + assert( + scanCount >= 2, + s"Expected at least 2 CometNativeScanExec in plan, found $scanCount:\n" + + cometPlan.treeString) + + assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes") + assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes") + assert(sparkRecords > 0, s"Spark recordsRead should be > 0, got $sparkRecords") + assert(cometRecords > 0, s"Comet recordsRead should be > 0, got $cometRecords") + + val ratio = cometBytes.toDouble / sparkBytes.toDouble + assert( + ratio >= 0.8 && ratio <= 1.2, + s"bytesRead ratio out of range: 
comet=$cometBytes, spark=$sparkBytes, ratio=$ratio") + } + } + } + /** - * Runs `SELECT * FROM tbl WHERE _1 > 2000` with the given SQL config overrides and returns the - * aggregated (bytesRead, recordsRead) across all tasks, along with the executed plan. + * Runs the given query with the given SQL config overrides and returns the aggregated + * (bytesRead, recordsRead) across all tasks, along with the executed plan. * * Uses AppStatusStore (same source as Spark UI) to read task-level input metrics. * AppStatusStore stores immutable snapshots of metric values, unlike SparkListener's * InputMetrics which are backed by mutable accumulators that can be reset. */ - private def collectInputMetrics(confs: (String, String)*): (Long, Long, SparkPlan) = { + private def collectInputMetrics( + query: String, + confs: (String, String)*): (Long, Long, SparkPlan) = { val store = spark.sparkContext.statusStore // Record existing stage IDs so we only look at stages from our query @@ -158,7 +267,7 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { var plan: SparkPlan = null withSQLConf(confs: _*) { - val df = sql("SELECT * FROM tbl where _1 > 2000") + val df = sql(query) df.collect() plan = stripAQEPlan(df.queryExecution.executedPlan) } From 9b8557c08b132636574d49bd9293a1cddb7434fe Mon Sep 17 00:00:00 2001 From: comphead Date: Tue, 7 Apr 2026 11:12:41 -0700 Subject: [PATCH 14/15] chore: `native_datafusion` to report scan task input metrics --- .../spark/sql/comet/CometMetricNode.scala | 25 ++++++++++++++- .../spark/sql/comet/CometNativeScanExec.scala | 32 ++++++++++++------- .../apache/spark/sql/comet/operators.scala | 24 ++------------ .../sql/comet/CometTaskMetricsSuite.scala | 2 +- 4 files changed, 49 insertions(+), 34 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala index ea476252d4..b0ded47580 100644 --- 
a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala @@ -21,7 +21,7 @@ package org.apache.spark.sql.comet import scala.jdk.CollectionConverters._ -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -61,6 +61,29 @@ case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometM else children.flatMap(_.leafNodes) } + /** + * Reports aggregated scan input metrics (bytesRead, recordsRead) to Spark's task metrics. + * Aggregates across all scan leaf nodes to handle plans with multiple scans (e.g., joins). + * Registers a TaskCompletionListener so the metrics are reported after the iterator is consumed. + */ + def reportScanInputMetrics(ctx: TaskContext): Unit = { + ctx.addTaskCompletionListener[Unit] { _ => + val scanLeaves = leafNodes.filter(_.metrics.contains("bytes_scanned")) + if (scanLeaves.nonEmpty) { + val totalBytes = scanLeaves.map(_.metrics("bytes_scanned").value).sum + val totalRows = scanLeaves.map { leaf => + val outputRows = + leaf.metrics.get("output_rows").map(_.value).getOrElse(0L) + val prunedRows = + leaf.metrics.get("pushdown_rows_pruned").map(_.value).getOrElse(0L) + outputRows + prunedRows + }.sum + ctx.taskMetrics().inputMetrics.setBytesRead(totalBytes) + ctx.taskMetrics().inputMetrics.setRecordsRead(totalRows) + } + } + } + /** + * Gets a child node. Called from native.
*/ diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala index dcb975ac7a..ae2d873ef7 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala @@ -21,6 +21,7 @@ package org.apache.spark.sql.comet import scala.reflect.ClassTag +import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst._ @@ -180,18 +181,27 @@ case class CometNativeScanExec( (None, Seq.empty) } - CometExecRDD( + new CometExecRDD( sparkContext, - inputRDDs = Seq.empty, - commonByKey = Map(sourceKey -> commonData), - perPartitionByKey = Map(sourceKey -> perPartitionData), - serializedPlan = serializedPlan, - numPartitions = perPartitionData.length, - numOutputCols = output.length, - nativeMetrics = nativeMetrics, - subqueries = Seq.empty, - broadcastedHadoopConfForEncryption = broadcastedHadoopConfForEncryption, - encryptedFilePaths = encryptedFilePaths) + Seq.empty, + Map(sourceKey -> commonData), + Map(sourceKey -> perPartitionData), + serializedPlan, + perPartitionData.length, + output.length, + nativeMetrics, + Seq.empty, + broadcastedHadoopConfForEncryption, + encryptedFilePaths) { + override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + val res = super.compute(split, context) + + // Report scan input metrics after the iterator is fully consumed. 
+ Option(context).foreach(nativeMetrics.reportScanInputMetrics) + + res + } + } } override def doCanonicalize(): CometNativeScanExec = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 77c87fcc74..21cbdab974 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -559,6 +559,7 @@ abstract class CometNativeExec extends CometExec { // Unified RDD creation - CometExecRDD handles all cases val subqueries = collectSubqueries(this) + val hasScanInput = sparkPlans.exists(_.isInstanceOf[CometNativeScanExec]) new CometExecRDD( sparkContext, inputs.toSeq, @@ -578,27 +579,8 @@ abstract class CometNativeExec extends CometExec { val res = super.compute(split, context) // Report scan input metrics only when the native plan contains a scan. - // Aggregate across all scan leaf nodes to handle plans with multiple - // scans (e.g., joins over different tables, unions). 
- if (sparkPlans.exists(_.isInstanceOf[CometNativeScanExec])) { - Option(context).foreach { ctx => - ctx.addTaskCompletionListener[Unit] { _ => - val scanLeaves = nativeMetrics.leafNodes - .filter(_.metrics.contains("bytes_scanned")) - if (scanLeaves.nonEmpty) { - val totalBytes = scanLeaves.map(_.metrics("bytes_scanned").value).sum - val totalRows = scanLeaves.map { leaf => - val outputRows = - leaf.metrics.get("output_rows").map(_.value).getOrElse(0L) - val prunedRows = - leaf.metrics.get("pushdown_rows_pruned").map(_.value).getOrElse(0L) - outputRows + prunedRows - }.sum - ctx.taskMetrics().inputMetrics.setBytesRead(totalBytes) - ctx.taskMetrics().inputMetrics.setRecordsRead(totalRows) - } - } - } + if (hasScanInput) { + Option(context).foreach(nativeMetrics.reportScanInputMetrics) } res diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala index 2c34307cc0..219464c561 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala @@ -140,7 +140,7 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes") val ratio = cometBytes.toDouble / sparkBytes.toDouble assert( - ratio >= 0.8 && ratio <= 1.2, + ratio >= 0.7 && ratio <= 1.3, s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio") } } From b9925a8bdaa0262b5e15f96f812923643334fd8f Mon Sep 17 00:00:00 2001 From: comphead Date: Tue, 7 Apr 2026 11:31:38 -0700 Subject: [PATCH 15/15] chore: `native_datafusion` to report scan task input metrics --- .../org/apache/spark/sql/comet/CometTaskMetricsSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala 
b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala index 219464c561..cc02551bf8 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala @@ -192,7 +192,7 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { // Both sides should contribute to the total bytes val ratio = cometBytes.toDouble / sparkBytes.toDouble assert( - ratio >= 0.8 && ratio <= 1.2, + ratio >= 0.7 && ratio <= 1.3, s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio") } } @@ -243,7 +243,7 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { val ratio = cometBytes.toDouble / sparkBytes.toDouble assert( - ratio >= 0.8 && ratio <= 1.2, + ratio >= 0.7 && ratio <= 1.3, s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio") } }