diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/CustomMetrics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/CustomMetrics.java
new file mode 100644
index 0000000000000..7011a70e515e2
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/CustomMetrics.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * An interface for reporting custom metrics from streaming sources and sinks
+ */
+@InterfaceStability.Evolving
+public interface CustomMetrics {
+  /**
+   * Returns a JSON serialized representation of custom metrics
+   *
+   * @return JSON serialized representation of custom metrics
+   */
+  String json();
+}
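`CustomMetrics` keeps the contract to a single `json()` method. For reference, a minimal sketch of an implementation (hypothetical class and metric names, not part of this patch), using the same json4s `Serialization` helper that the `MemorySinkV2` changes below rely on:

```scala
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization

import org.apache.spark.sql.sources.v2.CustomMetrics

// Hypothetical metrics holder for a source that tracks buffered and dropped rows.
// Only json() is required by the interface.
class BufferMetrics(bufferedRows: Long, droppedRows: Long) extends CustomMetrics {
  private implicit val formats = Serialization.formats(NoTypeHints)

  // Emits a flat JSON object, e.g. {"bufferedRows":42,"droppedRows":0}
  override def json(): String =
    Serialization.write(Map("bufferedRows" -> bufferedRows, "droppedRows" -> droppedRows))
}
```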
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java
new file mode 100644
index 0000000000000..3b293d925c91d
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SupportsCustomReaderMetrics.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.sources.v2.reader.streaming;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.sources.v2.CustomMetrics;
+import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
+
+/**
+ * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this
+ * interface to report custom metrics that get reported under the
+ * {@link org.apache.spark.sql.streaming.SourceProgress}.
+ *
+ */
+@InterfaceStability.Evolving
+public interface SupportsCustomReaderMetrics extends DataSourceReader {
+  /**
+   * Returns custom metrics specific to this data source.
+   */
+  CustomMetrics getCustomMetrics();
+
+  /**
+   * Invoked if the custom metrics returned by {@link #getCustomMetrics()} are invalid
+   * (e.g. invalid data that cannot be parsed). Throwing an error here ensures that your
+   * custom metrics are computed correctly and that valid values are always reported. The
+   * default action on invalid metrics is to ignore them.
+   *
+   * @param ex the exception
+   */
+  default void onInvalidMetrics(Exception ex) {
+    // default is to ignore invalid custom metrics
+  }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java
new file mode 100644
index 0000000000000..0cd36501320fd
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsCustomWriterMetrics.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.sources.v2.writer.streaming;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.sources.v2.CustomMetrics;
+import org.apache.spark.sql.sources.v2.writer.DataSourceWriter;
+
+/**
+ * A mix-in interface for {@link DataSourceWriter}. Data source writers can implement this
+ * interface to report custom metrics that get reported under the
+ * {@link org.apache.spark.sql.streaming.SinkProgress}.
+ *
+ */
+@InterfaceStability.Evolving
+public interface SupportsCustomWriterMetrics extends DataSourceWriter {
+  /**
+   * Returns custom metrics specific to this data source.
+   */
+  CustomMetrics getCustomMetrics();
+
+  /**
+   * Invoked if the custom metrics returned by {@link #getCustomMetrics()} are invalid
+   * (e.g. invalid data that cannot be parsed). Throwing an error here ensures that your
+   * custom metrics are computed correctly and that valid values are always reported. The
+   * default action on invalid metrics is to ignore them.
+   *
+   * @param ex the exception
+   */
+  default void onInvalidMetrics(Exception ex) {
+    // default is to ignore invalid custom metrics
+  }
+}
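A source or sink opts in by mixing one of these interfaces into its reader or writer. The sketch below is not from this patch: the writer name and counter are hypothetical, the pre-2.4 `StreamWriter` API is assumed, and everything unrelated to metrics is elided with `???`. It also overrides `onInvalidMetrics` to rethrow, which is a reasonable choice while a source is still under development.

```scala
import org.apache.spark.sql.sources.v2.CustomMetrics
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
import org.apache.spark.sql.sources.v2.writer.streaming.{StreamWriter, SupportsCustomWriterMetrics}

// Hypothetical writer that counts commit messages per epoch and reports the total.
class CountingStreamWriter extends StreamWriter with SupportsCustomWriterMetrics {
  @volatile private var commitsSeen = 0L

  // Unrelated to metrics; elided in this sketch.
  override def createWriterFactory() = ???

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    commitsSeen += messages.length
  }

  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = ()

  override def getCustomMetrics: CustomMetrics = new CustomMetrics {
    override def json(): String = s"""{"commitsSeen":$commitsSeen}"""
  }

  // The default silently drops bad metrics; rethrowing surfaces them during development.
  override def onInvalidMetrics(ex: Exception): Unit = throw ex
}
```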
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index 47f4b52e6e34c..1e158323d2020 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -22,14 +22,22 @@ import java.util.{Date, UUID}
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.util.control.NonFatal
+
+import org.json4s.JsonAST.JValue
+import org.json4s.jackson.JsonMethods.parse
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
-import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec}
+import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter
+import org.apache.spark.sql.sources.v2.CustomMetrics
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, SupportsCustomReaderMetrics}
+import org.apache.spark.sql.sources.v2.writer.DataSourceWriter
+import org.apache.spark.sql.sources.v2.writer.streaming.SupportsCustomWriterMetrics
 import org.apache.spark.sql.streaming._
 import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent
 import org.apache.spark.util.Clock
 
@@ -156,7 +164,31 @@ trait ProgressReporter extends Logging {
     }
     logDebug(s"Execution stats: $executionStats")
 
+    // extracts and validates custom metrics from readers and writers
+    def extractMetrics(
+        getMetrics: () => Option[CustomMetrics],
+        onInvalidMetrics: (Exception) => Unit): Option[String] = {
+      try {
+        getMetrics().map(m => {
+          val json = m.json()
+          parse(json)
+          json
+        })
+      } catch {
+        case ex: Exception if NonFatal(ex) =>
+          onInvalidMetrics(ex)
+          None
+      }
+    }
+
     val sourceProgress = sources.distinct.map { source =>
+      val customReaderMetrics = source match {
+        case s: SupportsCustomReaderMetrics =>
+          extractMetrics(() => Option(s.getCustomMetrics), s.onInvalidMetrics)
+
+        case _ => None
+      }
+
       val numRecords = executionStats.inputRows.getOrElse(source, 0L)
       new SourceProgress(
         description = source.toString,
@@ -164,10 +196,19 @@ trait ProgressReporter extends Logging {
         endOffset = currentTriggerEndOffsets.get(source).orNull,
         numInputRows = numRecords,
         inputRowsPerSecond = numRecords / inputTimeSec,
-        processedRowsPerSecond = numRecords / processingTimeSec
+        processedRowsPerSecond = numRecords / processingTimeSec,
+        customReaderMetrics.orNull
       )
     }
-    val sinkProgress = new SinkProgress(sink.toString)
+
+    val customWriterMetrics = dataSourceWriter match {
+      case Some(s: SupportsCustomWriterMetrics) =>
+        extractMetrics(() => Option(s.getCustomMetrics), s.onInvalidMetrics)
+
+      case _ => None
+    }
+
+    val sinkProgress = new SinkProgress(sink.toString, customWriterMetrics.orNull)
 
     val newProgress = new StreamingQueryProgress(
       id = id,
@@ -196,6 +237,18 @@ trait ProgressReporter extends Logging {
     currentStatus = currentStatus.copy(isTriggerActive = false)
   }
 
+  /** Extract writer from the executed query plan. */
+  private def dataSourceWriter: Option[DataSourceWriter] = {
+    if (lastExecution == null) return None
+    lastExecution.executedPlan.collect {
+      case p if p.isInstanceOf[WriteToDataSourceV2Exec] =>
+        p.asInstanceOf[WriteToDataSourceV2Exec].writer
+    }.headOption match {
+      case Some(w: MicroBatchWriter) => Some(w.writer)
+      case _ => None
+    }
+  }
+
   /** Extract statistics about stateful operators from the executed query plan. */
   private def extractStateOperatorMetrics(hasNewData: Boolean): Seq[StateOperatorProgress] = {
     if (lastExecution == null) return Nil
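`extractMetrics` only checks that the reported string parses as JSON; on failure the metrics are dropped for that trigger and routed to `onInvalidMetrics` instead of failing the query. A stand-alone sketch of the same contract (names are illustrative, not Spark APIs):

```scala
import scala.util.control.NonFatal

import org.json4s.jackson.JsonMethods.parse

// Mirrors the validation above: keep the JSON only if it parses, otherwise report it.
def validated(json: String, onInvalidMetrics: Exception => Unit): Option[String] = {
  try {
    parse(json)  // json4s throws on malformed input
    Some(json)
  } catch {
    case ex: Exception if NonFatal(ex) =>
      onInvalidMetrics(ex)
      None
  }
}

validated("""{"numRows": 3}""", _ => ())                   // Some({"numRows": 3})
validated("not json", ex => println(s"bad metrics: $ex"))  // None
```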
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
index d023a35ea20b6..2d43a7bb77872 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
  * the non-streaming interface, forwarding the batch ID determined at construction to a wrapped
  * streaming writer.
  */
-class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWriter {
+class MicroBatchWriter(batchId: Long, val writer: StreamWriter) extends DataSourceWriter {
   override def commit(messages: Array[WriterCommitMessage]): Unit = {
     writer.commit(batchId, messages)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
index afacb2f72c926..2a5d21f330541 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
@@ -23,6 +23,9 @@ import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 
+import org.json4s.NoTypeHints
+import org.json4s.jackson.Serialization
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
@@ -32,9 +35,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update}
 import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink}
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
+import org.apache.spark.sql.sources.v2.{CustomMetrics, DataSourceOptions, DataSourceV2, StreamWriteSupport}
 import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamWriter, SupportsCustomWriterMetrics}
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 
@@ -114,14 +117,25 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
     batches.clear()
   }
 
+  def numRows: Int = synchronized {
+    batches.foldLeft(0)(_ + _.data.length)
+  }
+
   override def toString(): String = "MemorySinkV2"
 }
 
 case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends
   WriterCommitMessage {}
 
+class MemoryV2CustomMetrics(sink: MemorySinkV2) extends CustomMetrics {
+  private implicit val formats = Serialization.formats(NoTypeHints)
+  override def json(): String = Serialization.write(Map("numRows" -> sink.numRows))
+}
+
 class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode, schema: StructType)
-  extends DataSourceWriter with Logging {
+  extends DataSourceWriter with SupportsCustomWriterMetrics with Logging {
+
+  private val memoryV2CustomMetrics = new MemoryV2CustomMetrics(sink)
 
   override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema)
 
@@ -135,10 +149,16 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode, sc
   override def abort(messages: Array[WriterCommitMessage]): Unit = {
     // Don't accept any of the new input.
   }
+
+  override def getCustomMetrics: CustomMetrics = {
+    memoryV2CustomMetrics
+  }
 }
 
 class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType)
-  extends StreamWriter {
+  extends StreamWriter with SupportsCustomWriterMetrics {
+
+  private val customMemoryV2Metrics = new MemoryV2CustomMetrics(sink)
 
   override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema)
 
@@ -152,6 +172,10 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema:
   override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
     // Don't accept any of the new input.
   }
+
+  override def getCustomMetrics: CustomMetrics = {
+    customMemoryV2Metrics
+  }
 }
 
 case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType)
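`MemoryV2CustomMetrics` holds a reference to the sink and computes `numRows` inside `json()`, so the reported value reflects the sink's state at the moment the progress is assembled. A source whose counters are expensive or racy to read could snapshot them once instead; a hypothetical alternative, not part of this patch:

```scala
import org.apache.spark.sql.sources.v2.CustomMetrics

// Hypothetical variant: freeze the values at construction so repeated json() calls
// cannot observe a moving target.
class SnapshotMetrics(snapshot: Map[String, Long]) extends CustomMetrics {
  private val frozen: String =
    snapshot.map { case (k, v) => s""""$k":$v""" }.mkString("{", ",", "}")

  override def json(): String = frozen
}
```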
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
index 0dcb666e2c3e4..2fb87960ccb04 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
@@ -163,7 +163,27 @@ class SourceProgress protected[sql](
   val endOffset: String,
   val numInputRows: Long,
   val inputRowsPerSecond: Double,
-  val processedRowsPerSecond: Double) extends Serializable {
+  val processedRowsPerSecond: Double,
+  val customMetrics: String) extends Serializable {
+
+  /** SourceProgress without custom metrics. */
+  protected[sql] def this(
+      description: String,
+      startOffset: String,
+      endOffset: String,
+      numInputRows: Long,
+      inputRowsPerSecond: Double,
+      processedRowsPerSecond: Double) {
+
+    this(
+      description,
+      startOffset,
+      endOffset,
+      numInputRows,
+      inputRowsPerSecond,
+      processedRowsPerSecond,
+      null)
+  }
 
   /** The compact JSON representation of this progress. */
   def json: String = compact(render(jsonValue))
@@ -178,12 +198,18 @@ class SourceProgress protected[sql](
       if (value.isNaN || value.isInfinity) JNothing else JDouble(value)
     }
 
-    ("description" -> JString(description)) ~
+    val jsonVal = ("description" -> JString(description)) ~
       ("startOffset" -> tryParse(startOffset)) ~
       ("endOffset" -> tryParse(endOffset)) ~
       ("numInputRows" -> JInt(numInputRows)) ~
       ("inputRowsPerSecond" -> safeDoubleToJValue(inputRowsPerSecond)) ~
       ("processedRowsPerSecond" -> safeDoubleToJValue(processedRowsPerSecond))
+
+    if (customMetrics != null) {
+      jsonVal ~ ("customMetrics" -> parse(customMetrics))
+    } else {
+      jsonVal
+    }
   }
 
   private def tryParse(json: String) = try {
@@ -202,7 +228,13 @@ class SourceProgress protected[sql](
  */
 @InterfaceStability.Evolving
 class SinkProgress protected[sql](
-  val description: String) extends Serializable {
+  val description: String,
+  val customMetrics: String) extends Serializable {
+
+  /** SinkProgress without custom metrics. */
+  protected[sql] def this(description: String) {
+    this(description, null)
+  }
 
   /** The compact JSON representation of this progress. */
   def json: String = compact(render(jsonValue))
@@ -213,6 +245,12 @@ class SinkProgress protected[sql](
   override def toString: String = prettyJson
 
   private[sql] def jsonValue: JValue = {
-    ("description" -> JString(description))
+    val jsonVal = ("description" -> JString(description))
+
+    if (customMetrics != null) {
+      jsonVal ~ ("customMetrics" -> parse(customMetrics))
+    } else {
+      jsonVal
+    }
   }
 }
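Because `customMetrics` is spliced into the progress JSON as parsed JSON rather than an escaped string, it can be read back out structurally. A small sketch, assuming a sink that reports the `MemorySinkV2`-style `{"numRows": N}` payload:

```scala
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods.parse

implicit val formats: DefaultFormats.type = DefaultFormats

// Returns None when the sink reports no custom metrics (the field is simply absent).
def sinkNumRows(progressJson: String): Option[Long] =
  (parse(progressJson) \ "sink" \ "customMetrics" \ "numRows").extractOpt[Long]
```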
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
index b4d9b68c78152..1efaead0845db 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
@@ -84,4 +84,26 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
 
     assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33))
   }
+
+  test("writer metrics") {
+    val sink = new MemorySinkV2
+    val schema = new StructType().add("i", "int")
+    // batch 0
+    var writer = new MemoryWriter(sink, 0, OutputMode.Append(), schema)
+    writer.commit(
+      Array(
+        MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
+        MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))),
+        MemoryWriterCommitMessage(2, Seq(Row(5), Row(6)))
+      ))
+    assert(writer.getCustomMetrics.json() == "{\"numRows\":6}")
+    // batch 1
+    writer = new MemoryWriter(sink, 1, OutputMode.Append(), schema
+    )
+    writer.commit(
+      Array(
+        MemoryWriterCommitMessage(0, Seq(Row(7), Row(8)))
+      ))
+    assert(writer.getCustomMetrics.json() == "{\"numRows\":8}")
+  }
 }
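The string-equality assertions above rely on json4s emitting keys in a stable order. If the metrics payload ever grows more fields, comparing parsed values is less brittle; a possible helper for tests like this one (illustrative only):

```scala
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods.parse

import org.apache.spark.sql.sources.v2.CustomMetrics

implicit val formats: DefaultFormats.type = DefaultFormats

// Compare the parsed value rather than the raw string.
def assertNumRows(metrics: CustomMetrics, expected: Int): Unit =
  assert((parse(metrics.json()) \ "numRows").extract[Int] == expected)
```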
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index f37f3682b03b9..646c904786e68 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -24,6 +24,9 @@ import java.util.concurrent.CountDownLatch
 import scala.collection.mutable
 
 import org.apache.commons.lang3.RandomStringUtils
+import org.json4s.NoTypeHints
+import org.json4s.jackson.JsonMethods._
+import org.json4s.jackson.Serialization
 import org.scalactic.TolerantNumerics
 import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
@@ -472,6 +475,31 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
     }
   }
 
+  test("Check if custom metrics are reported") {
+    val streamInput = MemoryStream[Int]
+    implicit val formats = Serialization.formats(NoTypeHints)
+    testStream(streamInput.toDF(), useV2Sink = true)(
+      AddData(streamInput, 1, 2, 3),
+      CheckAnswer(1, 2, 3),
+      AssertOnQuery { q =>
+        val lastProgress = getLastProgressWithData(q)
+        assert(lastProgress.nonEmpty)
+        assert(lastProgress.get.numInputRows == 3)
+        assert(lastProgress.get.sink.customMetrics == "{\"numRows\":3}")
+        true
+      },
+      AddData(streamInput, 4, 5, 6, 7),
+      CheckAnswer(1, 2, 3, 4, 5, 6, 7),
+      AssertOnQuery { q =>
+        val lastProgress = getLastProgressWithData(q)
+        assert(lastProgress.nonEmpty)
+        assert(lastProgress.get.numInputRows == 4)
+        assert(lastProgress.get.sink.customMetrics == "{\"numRows\":7}")
+        true
+      }
+    )
+  }
+
   test("input row calculation with same V1 source used twice in self-join") {
     val streamingTriggerDF = spark.createDataset(1 to 10).toDF
     val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value")
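From the user's side the new fields simply ride along on `StreamingQuery.lastProgress` and the listener events, and stay `null` when the source or sink does not implement the mix-ins. A brief usage sketch:

```scala
import org.apache.spark.sql.streaming.StreamingQuery

// Print whatever custom metrics the most recent progress carried, if any.
def logCustomMetrics(query: StreamingQuery): Unit = {
  Option(query.lastProgress).foreach { progress =>
    Option(progress.sink.customMetrics).foreach(m => println(s"sink metrics: $m"))
    progress.sources.foreach { s =>
      Option(s.customMetrics).foreach(m => println(s"source ${s.description}: $m"))
    }
  }
}
```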