From 43e760ca5d3758da8cd0b2b8b5e847352761d313 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Fri, 9 Jun 2023 12:02:33 -0700 Subject: [PATCH 01/14] done --- python/pyspark/sql/streaming/listener.py | 438 ++++++++++++++---- .../streaming/test_streaming_listener.py | 166 ++++++- .../streaming/StreamingQueryListener.scala | 37 +- 3 files changed, 550 insertions(+), 91 deletions(-) diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index 33482664a7b02..8bc749c9be1a1 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -15,7 +15,8 @@ # limitations under the License. # import uuid -from typing import Optional, Dict, List +import json +from typing import Any, Dict, List, Optional from abc import ABC, abstractmethod from py4j.java_gateway import JavaObject @@ -129,16 +130,16 @@ def __init__(self, pylistener: StreamingQueryListener) -> None: self.pylistener = pylistener def onQueryStarted(self, jevent: JavaObject) -> None: - self.pylistener.onQueryStarted(QueryStartedEvent(jevent)) + self.pylistener.onQueryStarted(QueryStartedEvent.fromJObject(jevent)) def onQueryProgress(self, jevent: JavaObject) -> None: - self.pylistener.onQueryProgress(QueryProgressEvent(jevent)) + self.pylistener.onQueryProgress(QueryProgressEvent.fromJObject(jevent)) def onQueryIdle(self, jevent: JavaObject) -> None: - self.pylistener.onQueryIdle(QueryIdleEvent(jevent)) + self.pylistener.onQueryIdle(QueryIdleEvent.fromJObject(jevent)) def onQueryTerminated(self, jevent: JavaObject) -> None: - self.pylistener.onQueryTerminated(QueryTerminatedEvent(jevent)) + self.pylistener.onQueryTerminated(QueryTerminatedEvent.fromJObject(jevent)) class Java: implements = ["org.apache.spark.sql.streaming.PythonStreamingQueryListener"] @@ -149,17 +150,39 @@ class QueryStartedEvent: Event representing the start of a query. .. versionadded:: 3.4.0 + .. versionchanged:: 3.5.0 + Add fromJson constructor to support Spark Connect. Notes ----- This API is evolving. """ - def __init__(self, jevent: JavaObject) -> None: - self._id: uuid.UUID = uuid.UUID(jevent.id().toString()) - self._runId: uuid.UUID = uuid.UUID(jevent.runId().toString()) - self._name: Optional[str] = jevent.name() - self._timestamp: str = jevent.timestamp() + def __init__( + self, id: uuid.UUID, runId: uuid.UUID, name: Optional[str], timestamp: str + ) -> None: + self._id: uuid.UUID = id + self._runId: uuid.UUID = runId + self._name: Optional[str] = name + self._timestamp: str = timestamp + + @classmethod + def fromJObject(cls, jevent: JavaObject) -> "QueryStartedEvent": + return cls( + id=uuid.UUID(jevent.id().toString()), + runId=uuid.UUID(jevent.runId().toString()), + name=jevent.name(), + timestamp=jevent.timestamp(), + ) + + @classmethod + def fromJson(cls, j: Dict[str, Any]) -> "QueryStartedEvent": + return cls( + id=uuid.UUID(j["id"]), + runId=uuid.UUID(j["runId"]), + name=j["name"], + timestamp=j["timestamp"], + ) @property def id(self) -> uuid.UUID: @@ -197,14 +220,24 @@ class QueryProgressEvent: Event representing any progress updates in a query. .. versionadded:: 3.4.0 + .. versionchanged:: 3.5.0 + Add fromJson constructor to support Spark Connect. Notes ----- This API is evolving. 
""" - def __init__(self, jevent: JavaObject) -> None: - self._progress: StreamingQueryProgress = StreamingQueryProgress(jevent.progress()) + def __init__(self, progress: "StreamingQueryProgress") -> None: + self._progress: StreamingQueryProgress = progress + + @classmethod + def fromJObject(cls, jevent: JavaObject) -> "QueryProgressEvent": + return cls(progress=StreamingQueryProgress.fromJObject(jevent.progress())) + + @classmethod + def fromJson(cls, j: Dict[str, Any]) -> "QueryProgressEvent": + return cls(progress=StreamingQueryProgress.fromJson(j["progress"])) @property def progress(self) -> "StreamingQueryProgress": @@ -225,10 +258,22 @@ class QueryIdleEvent: This API is evolving. """ - def __init__(self, jevent: JavaObject) -> None: - self._id: uuid.UUID = uuid.UUID(jevent.id().toString()) - self._runId: uuid.UUID = uuid.UUID(jevent.runId().toString()) - self._timestamp: str = jevent.timestamp() + def __init__(self, id: uuid.UUID, runId: uuid.UUID, timestamp: str) -> None: + self._id: uuid.UUID = id + self._runId: uuid.UUID = runId + self._timestamp: str = timestamp + + @classmethod + def fromJObject(cls, jevent: JavaObject) -> "QueryIdleEvent": + return cls( + id=uuid.UUID(jevent.id().toString()), + runId=uuid.UUID(jevent.runId().toString()), + timestamp=jevent.timestamp(), + ) + + @classmethod + def fromJson(cls, j: Dict[str, Any]) -> "QueryIdleEvent": + return cls(id=uuid.UUID(j["id"]), runId=uuid.UUID(j["runId"]), timestamp=j["timestamp"]) @property def id(self) -> uuid.UUID: @@ -259,20 +304,44 @@ class QueryTerminatedEvent: Event representing that termination of a query. .. versionadded:: 3.4.0 + .. versionchanged:: 3.5.0 + Add fromJson constructor to support Spark Connect. Notes ----- This API is evolving. """ - def __init__(self, jevent: JavaObject) -> None: - self._id: uuid.UUID = uuid.UUID(jevent.id().toString()) - self._runId: uuid.UUID = uuid.UUID(jevent.runId().toString()) + def __init__( + self, + id: uuid.UUID, + runId: uuid.UUID, + exception: Optional[str], + errorClassOnException: Optional[str], + ) -> None: + self._id: uuid.UUID = id + self._runId: uuid.UUID = runId + self._exception: Optional[str] = exception + self._errorClassOnException: Optional[str] = errorClassOnException + + @classmethod + def fromJObject(cls, jevent: JavaObject) -> "QueryTerminatedEvent": jexception = jevent.exception() - self._exception: Optional[str] = jexception.get() if jexception.isDefined() else None jerrorclass = jevent.errorClassOnException() - self._errorClassOnException: Optional[str] = ( - jerrorclass.get() if jerrorclass.isDefined() else None + return cls( + id=uuid.UUID(jevent.id().toString()), + runId=uuid.UUID(jevent.runId().toString()), + exception=jexception.get() if jexception.isDefined() else None, + errorClassOnException=jerrorclass.get() if jerrorclass.isDefined() else None, + ) + + @classmethod + def fromJson(cls, j: Dict[str, Any]) -> "QueryTerminatedEvent": + return cls( + id=uuid.UUID(j["id"]), + runId=uuid.UUID(j["runId"]), + exception=j["exception"], + errorClassOnException=j["errorClassOnException"], ) @property @@ -316,38 +385,105 @@ def errorClassOnException(self) -> Optional[str]: class StreamingQueryProgress: """ .. versionadded:: 3.4.0 + .. versionchanged:: 3.5.0 + Add fromJson constructor to support Spark Connect. Notes ----- This API is evolving. 
""" - def __init__(self, jprogress: JavaObject) -> None: + def __init__( + self, + json: str, + prettyJson: str, + id: uuid.UUID, + runId: uuid.UUID, + name: Optional[str], + timestamp: str, + batchId: int, + batchDuration: int, + durationMs: Dict[str, int], + eventTime: Dict[str, str], + stateOperators: List["StateOperatorProgress"], + sources: List["SourceProgress"], + sink: "SinkProgress", + numInputRows: Optional[str], + inputRowsPerSecond: float, + processedRowsPerSecond: float, + observedMetrics: Dict[str, Row] + ): + self._json: str = json + self._prettyJson: str = prettyJson + self._id: uuid.UUID = id + self._runId: uuid.UUID = runId + self._name: Optional[str] = name + self._timestamp: str = timestamp + self._batchId: int = batchId + self._batchDuration: int = batchDuration + self._durationMs: Dict[str, int] = durationMs + self._eventTime: Dict[str, str] = eventTime + self._stateOperators: List[StateOperatorProgress] = stateOperators + self._sources: List[SourceProgress] = sources + self._sink: SinkProgress = sink + self._numInputRows: Optional[str] = numInputRows + self._inputRowsPerSecond: float = inputRowsPerSecond + self._processedRowsPerSecond: float = processedRowsPerSecond + self._observedMetrics: Dict[str, Row] = observedMetrics + + @classmethod + def fromJObject(cls, jprogress: JavaObject) -> "StreamingQueryProgress": from pyspark import SparkContext - self._jprogress: JavaObject = jprogress - self._id: uuid.UUID = uuid.UUID(jprogress.id().toString()) - self._runId: uuid.UUID = uuid.UUID(jprogress.runId().toString()) - self._name: Optional[str] = jprogress.name() - self._timestamp: str = jprogress.timestamp() - self._batchId: int = jprogress.batchId() - self._inputRowsPerSecond: float = jprogress.inputRowsPerSecond() - self._processedRowsPerSecond: float = jprogress.processedRowsPerSecond() - self._batchDuration: int = jprogress.batchDuration() - self._durationMs: Dict[str, int] = dict(jprogress.durationMs()) - self._eventTime: Dict[str, str] = dict(jprogress.eventTime()) - self._stateOperators: List[StateOperatorProgress] = [ - StateOperatorProgress(js) for js in jprogress.stateOperators() - ] - self._sources: List[SourceProgress] = [SourceProgress(js) for js in jprogress.sources()] - self._sink: SinkProgress = SinkProgress(jprogress.sink()) - - self._observedMetrics: Dict[str, Row] = { - k: cloudpickle.loads( - SparkContext._jvm.PythonSQLUtils.toPyRow(jr) # type: ignore[union-attr] - ) - for k, jr in dict(jprogress.observedMetrics()).items() - } + return cls( + json=jprogress.json(), + prettyJson=jprogress.prettyJson(), + id=uuid.UUID(jprogress.id().toString()), + runId=uuid.UUID(jprogress.runId().toString()), + name=jprogress.name(), + timestamp=jprogress.timestamp(), + batchId=jprogress.batchId(), + batchDuration=jprogress.batchDuration(), + durationMs=dict(jprogress.durationMs()), + eventTime=dict(jprogress.eventTime()), + stateOperators=[StateOperatorProgress.fromJObject(js) for js in jprogress.stateOperators()], + sources=[SourceProgress.fromJObject(js) for js in jprogress.sources()], + sink=SinkProgress.fromJObject(jprogress.sink()), + numInputRows=jprogress.numInputRows(), + inputRowsPerSecond=jprogress.inputRowsPerSecond(), + processedRowsPerSecond=jprogress.processedRowsPerSecond(), + observedMetrics={ + k: cloudpickle.loads( + SparkContext._jvm.PythonSQLUtils.toPyRow(jr) # type: ignore[union-attr] + ) + for k, jr in dict(jprogress.observedMetrics()).items() + }, + ) + + @classmethod + def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress": + return 
cls( + json=json.dumps(j), + prettyJson=json.dumps(j, indent=4), + id=uuid.UUID(j["id"]), + runId=uuid.UUID(j["runId"]), + name=j["name"], + timestamp=j["timestamp"], + batchId=j["batchId"], + batchDuration=j["batchDuration"], + durationMs=dict(j["durationMs"]), + eventTime=dict(j["eventTime"]), + stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]], + sources=[SourceProgress.fromJson(s) for s in j["sources"]], + sink=SinkProgress.fromJson(j["sink"]), + numInputRows=j["numInputRows"], + inputRowsPerSecond=j["inputRowsPerSecond"], + processedRowsPerSecond=j["processedRowsPerSecond"], + observedMetrics={ + k: Row(*row_dict.keys())(*row_dict.values()) # Assume no nested rows + for k, row_dict in j["observedMetrics"].items() + }, + ) @property def id(self) -> uuid.UUID: @@ -452,7 +588,7 @@ def numInputRows(self) -> Optional[str]: """ The aggregate (across all sources) number of records processed in a trigger. """ - return self._jprogress.numInputRows() + return self._numInputRows @property def inputRowsPerSecond(self) -> float: @@ -464,7 +600,7 @@ def inputRowsPerSecond(self) -> float: @property def processedRowsPerSecond(self) -> float: """ - The aggregate (across all sources) rate at which Spark is processing data.. + The aggregate (across all sources) rate at which Spark is processing data. """ return self._processedRowsPerSecond @@ -473,14 +609,14 @@ def json(self) -> str: """ The compact JSON representation of this progress. """ - return self._jprogress.json() + return self._json @property def prettyJson(self) -> str: """ The pretty (i.e. indented) JSON representation of this progress. """ - return self._jprogress.prettyJson() + return self._prettyJson def __str__(self) -> str: return self.prettyJson @@ -489,26 +625,83 @@ def __str__(self) -> str: class StateOperatorProgress: """ .. versionadded:: 3.4.0 + .. versionchanged:: 3.5.0 + Add fromJson constructor to support Spark Connect. Notes ----- This API is evolving. 
""" - def __init__(self, jprogress: JavaObject) -> None: - self._jprogress: JavaObject = jprogress - self._operatorName: str = jprogress.operatorName() - self._numRowsTotal: int = jprogress.numRowsTotal() - self._numRowsUpdated: int = jprogress.numRowsUpdated() - self._allUpdatesTimeMs: int = jprogress.allUpdatesTimeMs() - self._numRowsRemoved: int = jprogress.numRowsRemoved() - self._allRemovalsTimeMs: int = jprogress.allRemovalsTimeMs() - self._commitTimeMs: int = jprogress.commitTimeMs() - self._memoryUsedBytes: int = jprogress.memoryUsedBytes() - self._numRowsDroppedByWatermark: int = jprogress.numRowsDroppedByWatermark() - self._numShufflePartitions: int = jprogress.numShufflePartitions() - self._numStateStoreInstances: int = jprogress.numStateStoreInstances() - self._customMetrics: Dict[str, int] = dict(jprogress.customMetrics()) + def __init__( + self, + json: str, + prettyJson: str, + operatorName: str, + numRowsTotal: int, + numRowsUpdated: int, + numRowsRemoved: int, + allUpdatesTimeMs: int, + allRemovalsTimeMs: int, + commitTimeMs: int, + memoryUsedBytes: int, + numRowsDroppedByWatermark: int, + numShufflePartitions: int, + numStateStoreInstances: int, + customMetrics: Dict[str, int], + ): + self._json: str = json + self._prettyJson: str = prettyJson + self._operatorName: str = operatorName + self._numRowsTotal: int = numRowsTotal + self._numRowsUpdated: int = numRowsUpdated + self._numRowsRemoved: int = numRowsRemoved + self._allUpdatesTimeMs: int = allUpdatesTimeMs + self._allRemovalsTimeMs: int = allRemovalsTimeMs + self._commitTimeMs: int = commitTimeMs + self._memoryUsedBytes: int = memoryUsedBytes + self._numRowsDroppedByWatermark: int = numRowsDroppedByWatermark + self._numShufflePartitions: int = numShufflePartitions + self._numStateStoreInstances: int = numStateStoreInstances + self._customMetrics: Dict[str, int] = customMetrics + + @classmethod + def fromJObject(cls, jprogress: JavaObject) -> "StateOperatorProgress": + return cls( + json=jprogress.json(), + prettyJson=jprogress.prettyJson(), + operatorName=jprogress.operatorName(), + numRowsTotal=jprogress.numRowsTotal(), + numRowsUpdated=jprogress.numRowsUpdated(), + allUpdatesTimeMs=jprogress.allUpdatesTimeMs(), + numRowsRemoved=jprogress.numRowsRemoved(), + allRemovalsTimeMs=jprogress.allRemovalsTimeMs(), + commitTimeMs=jprogress.commitTimeMs(), + memoryUsedBytes=jprogress.memoryUsedBytes(), + numRowsDroppedByWatermark=jprogress.numRowsDroppedByWatermark(), + numShufflePartitions=jprogress.numShufflePartitions(), + numStateStoreInstances=jprogress.numStateStoreInstances(), + customMetrics=dict(jprogress.customMetrics()), + ) + + @classmethod + def fromJson(cls, j: Dict[str, Any]) -> "StateOperatorProgress": + return cls( + json=json.dumps(j), + prettyJson=json.dumps(j, indent=4), + operatorName=j["operatorName"], + numRowsTotal=j["numRowsTotal"], + numRowsUpdated=j["numRowsUpdated"], + numRowsRemoved=j["numRowsRemoved"], + allUpdatesTimeMs=j["allUpdatesTimeMs"], + allRemovalsTimeMs=j["allRemovalsTimeMs"], + commitTimeMs=j["commitTimeMs"], + memoryUsedBytes=j["memoryUsedBytes"], + numRowsDroppedByWatermark=j["numRowsDroppedByWatermark"], + numShufflePartitions=j["numShufflePartitions"], + numStateStoreInstances=j["numStateStoreInstances"], + customMetrics=dict(j["customMetrics"]), + ) @property def operatorName(self) -> str: @@ -563,14 +756,14 @@ def json(self) -> str: """ The compact JSON representation of this progress. 
""" - return self._jprogress.json() + return self._json @property def prettyJson(self) -> str: """ The pretty (i.e. indented) JSON representation of this progress. """ - return self._jprogress.prettyJson() + return self._prettyJson def __str__(self) -> str: return self.prettyJson @@ -579,22 +772,67 @@ def __str__(self) -> str: class SourceProgress: """ .. versionadded:: 3.4.0 + .. versionchanged:: 3.5.0 + Add fromJson constructor to support Spark Connect. Notes ----- This API is evolving. """ - def __init__(self, jprogress: JavaObject) -> None: - self._jprogress: JavaObject = jprogress - self._description: str = jprogress.description() - self._startOffset: str = jprogress.startOffset() - self._endOffset: str = jprogress.endOffset() - self._latestOffset: str = jprogress.latestOffset() - self._numInputRows: int = jprogress.numInputRows() - self._inputRowsPerSecond: float = jprogress.inputRowsPerSecond() - self._processedRowsPerSecond: float = jprogress.processedRowsPerSecond() - self._metrics: Dict[str, str] = dict(jprogress.metrics()) + def __init__( + self, + json: str, + prettyJson: str, + description: str, + startOffset: str, + endOffset: str, + latestOffset: str, + numInputRows: int, + inputRowsPerSecond: float, + processedRowsPerSecond: float, + metrics: Dict[str, str], + ) -> None: + self._json: str = json + self._prettyJson: str = prettyJson + self._description: str = description + self._startOffset: str = startOffset + self._endOffset: str = endOffset + self._latestOffset: str = latestOffset + self._numInputRows: int = numInputRows + self._inputRowsPerSecond: float = inputRowsPerSecond + self._processedRowsPerSecond: float = processedRowsPerSecond + self._metrics: Dict[str, str] = metrics + + @classmethod + def fromJObject(cls, jprogress: JavaObject) -> "SourceProgress": + return cls( + json=jprogress.json(), + prettyJson=jprogress.prettyJson(), + description=jprogress.description(), + startOffset=str(jprogress.startOffset()), + endOffset=str(jprogress.endOffset()), + latestOffset=str(jprogress.latestOffset()), + numInputRows=jprogress.numInputRows(), + inputRowsPerSecond=jprogress.inputRowsPerSecond(), + processedRowsPerSecond=jprogress.processedRowsPerSecond(), + metrics=dict(jprogress.metrics()) + ) + + @classmethod + def fromJson(cls, j: Dict[str, Any]) -> "SourceProgress": + return cls( + json=json.dumps(j), + prettyJson=json.dumps(j, indent=4), + description=j["description"], + startOffset=str(j["startOffset"]), + endOffset=str(j["endOffset"]), + latestOffset=str(j["latestOffset"]), + numInputRows=j["numInputRows"], + inputRowsPerSecond=j["inputRowsPerSecond"], + processedRowsPerSecond=j["processedRowsPerSecond"], + metrics=dict(j["metrics"]), + ) @property def description(self) -> str: @@ -654,14 +892,14 @@ def json(self) -> str: """ The compact JSON representation of this progress. """ - return self._jprogress.json() + return self._json @property def prettyJson(self) -> str: """ The pretty (i.e. indented) JSON representation of this progress. """ - return self._jprogress.prettyJson() + return self._prettyJson def __str__(self) -> str: return self.prettyJson @@ -670,17 +908,47 @@ def __str__(self) -> str: class SinkProgress: """ .. versionadded:: 3.4.0 + .. versionchanged:: 3.5.0 + Add fromJson constructor to support Spark Connect. Notes ----- This API is evolving. 
""" - def __init__(self, jprogress: JavaObject) -> None: - self._jprogress: JavaObject = jprogress - self._description: str = jprogress.description() - self._numOutputRows: int = jprogress.numOutputRows() - self._metrics: Dict[str, str] = dict(jprogress.metrics()) + def __init__( + self, + json: str, + prettyJson: str, + description: str, + numOutputRows: int, + metrics: Dict[str, str], + ) -> None: + self._json: str = json + self._prettyJson: str = prettyJson + self._description: str = description + self._numOutputRows: int = numOutputRows + self._metrics: Dict[str, str] = metrics + + @classmethod + def fromJObject(cls, jprogress: JavaObject) -> "SinkProgress": + return cls( + json=jprogress.json(), + prettyJson=jprogress.prettyJson(), + description=jprogress.description(), + numOutputRows=jprogress.numOutputRows(), + metrics=dict(jprogress.metrics()), + ) + + @classmethod + def fromJson(cls, j: Dict[str, Any]) -> "SinkProgress": + return cls( + json=json.dumps(j), + prettyJson=json.dumps(j, indent=4), + description=j["description"], + numOutputRows=j["numOutputRows"], + metrics=j["metrics"], + ) @property def description(self) -> str: @@ -706,14 +974,14 @@ def json(self) -> str: """ The compact JSON representation of this progress. """ - return self._jprogress.json() + return self._json @property def prettyJson(self) -> str: """ The pretty (i.e. indented) JSON representation of this progress. """ - return self._jprogress.prettyJson() + return self._prettyJson def __str__(self) -> str: return self.prettyJson diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index 71d76bc4e8d52..494e5319c3ab9 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -19,6 +19,7 @@ import uuid from datetime import datetime +from pyspark import Row from pyspark.sql.streaming import StreamingQueryListener from pyspark.sql.streaming.listener import ( QueryStartedEvent, @@ -51,7 +52,7 @@ def get_number_of_public_methods(clz): get_number_of_public_methods( "org.apache.spark.sql.streaming.StreamingQueryListener$QueryStartedEvent" ), - 14, + 15, msg, ) self.assertEquals( @@ -65,7 +66,7 @@ def get_number_of_public_methods(clz): get_number_of_public_methods( "org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminatedEvent" ), - 14, + 17, msg, ) self.assertEquals( @@ -112,7 +113,8 @@ def onQueryTerminated(self, event): self.spark.streams.addListener(test_listener) df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() - q = df.writeStream.format("noop").queryName("test").start() + df = df.groupBy().count() # make query stateful + q = df.writeStream.format("noop").queryName("test").outputMode("complete").start() self.assertTrue(q.isActive) time.sleep(10) q.stop() @@ -131,7 +133,7 @@ def check_start_event(self, event): self.assertTrue(isinstance(event, QueryStartedEvent)) self.assertTrue(isinstance(event.id, uuid.UUID)) self.assertTrue(isinstance(event.runId, uuid.UUID)) - self.assertEquals(event.name, "test") + self.assertTrue(event.name is None or event.name == "test") try: datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S.%fZ") except ValueError: @@ -194,10 +196,12 @@ def check_streaming_query_progress(self, progress): self.assertEquals(progress.eventTime, {}) self.assertTrue(isinstance(progress.stateOperators, list)) + self.assertTrue(len(progress.stateOperators) >= 1) for so in progress.stateOperators: 
self.check_state_operator_progress(so) self.assertTrue(isinstance(progress.sources, list)) + self.assertTrue(len(progress.sources) >= 1) for so in progress.sources: self.check_source_progress(so) @@ -299,6 +303,160 @@ def onQueryTerminated(self, event): self.spark.streams.removeListener(test_listener) self.assertEqual(num_listeners, len(self.spark.streams._jsqm.listListeners())) + def test_query_started_event_fromJson(self): + start_event = """ + { + "id" : "78923ec2-8f4d-4266-876e-1f50cf3c283b", + "runId" : "55a95d45-e932-4e08-9caa-0a8ecd9391e8", + "name" : null, + "timestamp" : "2023-06-09T18:13:29.741Z" + } + """ + start_event = QueryStartedEvent.fromJson(json.loads(start_event)) + self.check_start_event(start_event) + self.assertTrue(start_event.id == uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b")) + self.assertTrue(start_event.runId == uuid.UUID("55a95d45-e932-4e08-9caa-0a8ecd9391e8")) + self.assertTrue(start_event.name is None) + self.assertTrue(start_event.timestamp == "2023-06-09T18:13:29.741Z") + + def test_query_terminated_event_fromJson(self): + terminated_json = """ + { + "id" : "78923ec2-8f4d-4266-876e-1f50cf3c283b", + "runId" : "55a95d45-e932-4e08-9caa-0a8ecd9391e8", + "exception" : null, + "errorClassOnException" : null} + """ + terminated_event = QueryTerminatedEvent.fromJson(json.loads(terminated_json)) + self.check_terminated_event(terminated_event) + self.assertTrue(terminated_event.id == uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b")) + self.assertTrue(terminated_event.runId == uuid.UUID("55a95d45-e932-4e08-9caa-0a8ecd9391e8")) + self.assertTrue(terminated_event.exception is None) + self.assertTrue(terminated_event.errorClassOnException is None) + + def test_streaming_query_progress_fromJson(self): + progress_json = """ + { + "id" : "00000000-0000-0001-0000-000000000001", + "runId" : "00000000-0000-0001-0000-000000000002", + "name" : "test", + "timestamp" : "2016-12-05T20:54:20.827Z", + "batchId" : 2, + "numInputRows" : 678, + "inputRowsPerSecond" : 10.0, + "processedRowsPerSecond" : 5.4, + "batchDuration": 5, + "durationMs" : { + "getBatch" : 0 + }, + "eventTime" : {}, + "stateOperators" : [ { + "operatorName" : "op1", + "numRowsTotal" : 0, + "numRowsUpdated" : 1, + "allUpdatesTimeMs" : 1, + "numRowsRemoved" : 2, + "allRemovalsTimeMs" : 34, + "commitTimeMs" : 23, + "memoryUsedBytes" : 3, + "numRowsDroppedByWatermark" : 0, + "numShufflePartitions" : 2, + "numStateStoreInstances" : 2, + "customMetrics" : { + "loadedMapCacheHitCount" : 1, + "loadedMapCacheMissCount" : 0, + "stateOnCurrentVersionSizeBytes" : 2 + } + } ], + "sources" : [ { + "description" : "source", + "startOffset" : 123, + "endOffset" : 456, + "latestOffset" : 789, + "numInputRows" : 678, + "inputRowsPerSecond" : 10.0, + "processedRowsPerSecond" : 5.4, + "metrics": {} + } ], + "sink" : { + "description" : "sink", + "numOutputRows" : -1, + "metrics": {} + }, + "observedMetrics" : { + "event1" : { + "c1" : 1, + "c2" : 3.0 + }, + "event2" : { + "rc" : 1, + "min_q" : "hello", + "max_q" : "world" + } + } + } + """ + progress = StreamingQueryProgress.fromJson(json.loads(progress_json)) + + self.check_streaming_query_progress(progress) + + # checks for progress + self.assertTrue(progress.id == uuid.UUID("00000000-0000-0001-0000-000000000001")) + self.assertTrue(progress.runId == uuid.UUID("00000000-0000-0001-0000-000000000002")) + self.assertTrue(progress.name == "test") + self.assertTrue(progress.timestamp == "2016-12-05T20:54:20.827Z") + self.assertTrue(progress.batchId == 2) + 
self.assertTrue(progress.numInputRows == 678) + self.assertTrue(progress.inputRowsPerSecond == 10.0) + self.assertTrue(progress.batchDuration == 5) + self.assertTrue(progress.durationMs == {"getBatch": 0}) + self.assertTrue(progress.eventTime == {}) + self.assertTrue(progress.observedMetrics == { + "event1": Row("c1", "c2")(1, 3.0), + "event2": Row("rc", "min_q", "max_q")(1, "hello", "world") + }) + + # Check stateOperators list + self.assertTrue(len(progress.stateOperators) == 1) + state_operator = progress.stateOperators[0] + self.assertTrue(isinstance(state_operator, StateOperatorProgress)) + self.assertTrue(state_operator.operatorName == "op1") + self.assertTrue(state_operator.numRowsTotal == 0) + self.assertTrue(state_operator.numRowsUpdated == 1) + self.assertTrue(state_operator.allUpdatesTimeMs == 1) + self.assertTrue(state_operator.numRowsRemoved == 2) + self.assertTrue(state_operator.allRemovalsTimeMs == 34) + self.assertTrue(state_operator.commitTimeMs == 23) + self.assertTrue(state_operator.memoryUsedBytes == 3) + self.assertTrue(state_operator.numRowsDroppedByWatermark == 0) + self.assertTrue(state_operator.numShufflePartitions == 2) + self.assertTrue(state_operator.numStateStoreInstances == 2) + self.assertTrue(state_operator.customMetrics == { + "loadedMapCacheHitCount": 1, + "loadedMapCacheMissCount": 0, + "stateOnCurrentVersionSizeBytes": 2 + }) + + # Check sources list + self.assertTrue(len(progress.sources) == 1) + source = progress.sources[0] + self.assertTrue(isinstance(source, SourceProgress)) + self.assertTrue(source.description == "source") + self.assertTrue(source.startOffset == "123") + self.assertTrue(source.endOffset == "456") + self.assertTrue(source.latestOffset == "789") + self.assertTrue(source.numInputRows == 678) + self.assertTrue(source.inputRowsPerSecond == 10.0) + self.assertTrue(source.processedRowsPerSecond == 5.4) + self.assertTrue(source.metrics == {}) + + # Check sink + sink = progress.sink + self.assertTrue(isinstance(sink, SinkProgress)) + self.assertTrue(sink.description == "sink") + self.assertTrue(sink.numOutputRows == -1) + self.assertTrue(sink.metrics == {}) + if __name__ == "__main__": import unittest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index 61a0ef1b98e54..8f99fa2f9d391 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -19,6 +19,11 @@ package org.apache.spark.sql.streaming import java.util.UUID +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL.{jobject2assoc, pair2Assoc} +import org.json4s.JString +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark.annotation.Evolving import org.apache.spark.scheduler.SparkListenerEvent @@ -123,7 +128,17 @@ object StreamingQueryListener { val id: UUID, val runId: UUID, val name: String, - val timestamp: String) extends Event + val timestamp: String) extends Event { + + def json: String = compact(render(jsonValue)) + + private def jsonValue: JValue = { + ("id" -> JString(id.toString)) ~ + ("runId" -> JString(runId.toString)) ~ + ("name" -> JString(name)) ~ + ("timestamp" -> JString(timestamp)) + } + } /** * Event representing any progress updates in a query. 
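// Illustration only, not part of the patch: given the jsonValue above, the new
// QueryStartedEvent.json emits a compact object of the shape that the Python test
// test_query_started_event_fromJson (added in this same patch) parses back, e.g.
// {"id":"78923ec2-8f4d-4266-876e-1f50cf3c283b","runId":"55a95d45-e932-4e08-9caa-0a8ecd9391e8","name":null,"timestamp":"2023-06-09T18:13:29.741Z"}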
@@ -145,7 +160,16 @@ object StreamingQueryListener { class QueryIdleEvent private[sql]( val id: UUID, val runId: UUID, - val timestamp: String) extends Event + val timestamp: String) extends Event { + + def json: String = compact(render(jsonValue)) + + private def jsonValue: JValue = { + ("id" -> JString(id.toString)) ~ + ("runId" -> JString(runId.toString)) ~ + ("timestamp" -> JString(timestamp)) + } + } /** * Event representing that termination of a query. @@ -171,5 +195,14 @@ object StreamingQueryListener { def this(id: UUID, runId: UUID, exception: Option[String]) = { this(id, runId, exception, None) } + + def json: String = compact(render(jsonValue)) + + private def jsonValue: JValue = { + ("id" -> JString(id.toString)) ~ + ("runId" -> JString(runId.toString)) ~ + ("exception" -> JString(exception.orNull)) ~ + ("errorClassOnException" -> JString(errorClassOnException.orNull)) + } } } From 4d52eecd31c75f68e62a71ea51e608d90722ee42 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Fri, 9 Jun 2023 12:04:21 -0700 Subject: [PATCH 02/14] fmt --- python/pyspark/sql/streaming/listener.py | 8 ++++--- .../streaming/test_streaming_listener.py | 24 ++++++++++++------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index 8bc749c9be1a1..68bfeeda55c47 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -411,7 +411,7 @@ def __init__( numInputRows: Optional[str], inputRowsPerSecond: float, processedRowsPerSecond: float, - observedMetrics: Dict[str, Row] + observedMetrics: Dict[str, Row], ): self._json: str = json self._prettyJson: str = prettyJson @@ -446,7 +446,9 @@ def fromJObject(cls, jprogress: JavaObject) -> "StreamingQueryProgress": batchDuration=jprogress.batchDuration(), durationMs=dict(jprogress.durationMs()), eventTime=dict(jprogress.eventTime()), - stateOperators=[StateOperatorProgress.fromJObject(js) for js in jprogress.stateOperators()], + stateOperators=[ + StateOperatorProgress.fromJObject(js) for js in jprogress.stateOperators() + ], sources=[SourceProgress.fromJObject(js) for js in jprogress.sources()], sink=SinkProgress.fromJObject(jprogress.sink()), numInputRows=jprogress.numInputRows(), @@ -816,7 +818,7 @@ def fromJObject(cls, jprogress: JavaObject) -> "SourceProgress": numInputRows=jprogress.numInputRows(), inputRowsPerSecond=jprogress.inputRowsPerSecond(), processedRowsPerSecond=jprogress.processedRowsPerSecond(), - metrics=dict(jprogress.metrics()) + metrics=dict(jprogress.metrics()), ) @classmethod diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index 494e5319c3ab9..2afd74b57de07 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -411,10 +411,13 @@ def test_streaming_query_progress_fromJson(self): self.assertTrue(progress.batchDuration == 5) self.assertTrue(progress.durationMs == {"getBatch": 0}) self.assertTrue(progress.eventTime == {}) - self.assertTrue(progress.observedMetrics == { - "event1": Row("c1", "c2")(1, 3.0), - "event2": Row("rc", "min_q", "max_q")(1, "hello", "world") - }) + self.assertTrue( + progress.observedMetrics + == { + "event1": Row("c1", "c2")(1, 3.0), + "event2": Row("rc", "min_q", "max_q")(1, "hello", "world"), + } + ) # Check stateOperators list self.assertTrue(len(progress.stateOperators) == 1) @@ -431,11 +434,14 @@ def 
test_streaming_query_progress_fromJson(self): self.assertTrue(state_operator.numRowsDroppedByWatermark == 0) self.assertTrue(state_operator.numShufflePartitions == 2) self.assertTrue(state_operator.numStateStoreInstances == 2) - self.assertTrue(state_operator.customMetrics == { - "loadedMapCacheHitCount": 1, - "loadedMapCacheMissCount": 0, - "stateOnCurrentVersionSizeBytes": 2 - }) + self.assertTrue( + state_operator.customMetrics + == { + "loadedMapCacheHitCount": 1, + "loadedMapCacheMissCount": 0, + "stateOnCurrentVersionSizeBytes": 2, + } + ) # Check sources list self.assertTrue(len(progress.sources) == 1) From 80bb4c48e95937368e01a589bd7ff60f843aa4da Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Fri, 9 Jun 2023 13:15:30 -0700 Subject: [PATCH 03/14] also add test to exception in QueryTerminatedEvent --- .../streaming/test_streaming_listener.py | 44 ++++++++++++++----- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index 2afd74b57de07..e97ddadf9b898 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -66,7 +66,7 @@ def get_number_of_public_methods(clz): get_number_of_public_methods( "org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminatedEvent" ), - 17, + 15, msg, ) self.assertEquals( @@ -113,8 +113,15 @@ def onQueryTerminated(self, event): self.spark.streams.addListener(test_listener) df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() - df = df.groupBy().count() # make query stateful - q = df.writeStream.format("noop").queryName("test").outputMode("complete").start() + + # check successful stateful query + df_stateful = df.groupBy().count() # make query stateful + q = ( + df_stateful.writeStream.format("noop") + .queryName("test") + .outputMode("complete") + .start() + ) self.assertTrue(q.isActive) time.sleep(10) q.stop() @@ -125,6 +132,17 @@ def onQueryTerminated(self, event): self.check_start_event(start_event) self.check_progress_event(progress_event) self.check_terminated_event(terminated_event) + + # Check query terminated with exception + from pyspark.sql.functions import col, udf + + bad_udf = udf(lambda x: 1 / 0) + q = df.select(bad_udf(col("value"))).writeStream.format("noop").start() + time.sleep(5) + q.stop() + self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty() + self.check_terminated_event(terminated_event, "ZeroDivisionError") + finally: self.spark.streams.removeListener(test_listener) @@ -144,14 +162,20 @@ def check_progress_event(self, event): self.assertTrue(isinstance(event, QueryProgressEvent)) self.check_streaming_query_progress(event.progress) - def check_terminated_event(self, event): + def check_terminated_event(self, event, exception=None, error_class=None): """Check QueryTerminatedEvent""" self.assertTrue(isinstance(event, QueryTerminatedEvent)) self.assertTrue(isinstance(event.id, uuid.UUID)) self.assertTrue(isinstance(event.runId, uuid.UUID)) - # TODO: Needs a test for exception. 
- self.assertEquals(event.exception, None) - self.assertEquals(event.errorClassOnException, None) + if exception: + self.assertTrue(exception in event.exception) + else: + self.assertEquals(event.exception, None) + + if error_class: + self.assertTrue(error_class in event.errorClassOnException) + else: + self.assertEquals(event.errorClassOnException, None) def check_streaming_query_progress(self, progress): """Check StreamingQueryProgress""" @@ -324,14 +348,14 @@ def test_query_terminated_event_fromJson(self): { "id" : "78923ec2-8f4d-4266-876e-1f50cf3c283b", "runId" : "55a95d45-e932-4e08-9caa-0a8ecd9391e8", - "exception" : null, + "exception" : "org.apache.spark.SparkException: Job aborted due to stage failure...", "errorClassOnException" : null} """ terminated_event = QueryTerminatedEvent.fromJson(json.loads(terminated_json)) - self.check_terminated_event(terminated_event) + self.check_terminated_event(terminated_event, "SparkException") self.assertTrue(terminated_event.id == uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b")) self.assertTrue(terminated_event.runId == uuid.UUID("55a95d45-e932-4e08-9caa-0a8ecd9391e8")) - self.assertTrue(terminated_event.exception is None) + self.assertTrue("SparkException" in terminated_event.exception) self.assertTrue(terminated_event.errorClassOnException is None) def test_streaming_query_progress_fromJson(self): From ce574378db146413ff284e967648ab79bb510bc5 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Fri, 9 Jun 2023 13:22:04 -0700 Subject: [PATCH 04/14] type fix for StreamingQueryProgress.numInputRows --- python/pyspark/sql/streaming/listener.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index 68bfeeda55c47..b0594380286b7 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -408,7 +408,7 @@ def __init__( stateOperators: List["StateOperatorProgress"], sources: List["SourceProgress"], sink: "SinkProgress", - numInputRows: Optional[str], + numInputRows: int, inputRowsPerSecond: float, processedRowsPerSecond: float, observedMetrics: Dict[str, Row], @@ -586,7 +586,7 @@ def observedMetrics(self) -> Dict[str, Row]: return self._observedMetrics @property - def numInputRows(self) -> Optional[str]: + def numInputRows(self) -> int: """ The aggregate (across all sources) number of records processed in a trigger. 
""" From 021446436517894829e1c6f0c81c8e4f6feb66c3 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Mon, 12 Jun 2023 10:14:53 -0700 Subject: [PATCH 05/14] lint, line too long --- python/pyspark/sql/tests/streaming/test_streaming_listener.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index e97ddadf9b898..14b9945be8110 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -348,7 +348,7 @@ def test_query_terminated_event_fromJson(self): { "id" : "78923ec2-8f4d-4266-876e-1f50cf3c283b", "runId" : "55a95d45-e932-4e08-9caa-0a8ecd9391e8", - "exception" : "org.apache.spark.SparkException: Job aborted due to stage failure...", + "exception" : "org.apache.spark.SparkException: Job aborted due to stage failure", "errorClassOnException" : null} """ terminated_event = QueryTerminatedEvent.fromJson(json.loads(terminated_json)) From 28869401eff8fcd94aa68d662fdd4d3dbe2f58af Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Mon, 12 Jun 2023 12:08:04 -0700 Subject: [PATCH 06/14] further type fix on numInputRows --- python/pyspark/sql/streaming/listener.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index b0594380286b7..5855a5cf32d33 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -426,7 +426,7 @@ def __init__( self._stateOperators: List[StateOperatorProgress] = stateOperators self._sources: List[SourceProgress] = sources self._sink: SinkProgress = sink - self._numInputRows: Optional[str] = numInputRows + self._numInputRows: int = numInputRows self._inputRowsPerSecond: float = inputRowsPerSecond self._processedRowsPerSecond: float = processedRowsPerSecond self._observedMetrics: Dict[str, Row] = observedMetrics From 97c08c4460f55d51a53395efff3e695db190006f Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Tue, 13 Jun 2023 10:47:07 -0700 Subject: [PATCH 07/14] add watermark to test json --- .../tests/streaming/test_streaming_listener.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index 14b9945be8110..b0f5a7cfd2078 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -373,7 +373,12 @@ def test_streaming_query_progress_fromJson(self): "durationMs" : { "getBatch" : 0 }, - "eventTime" : {}, + "eventTime" : { + "min" : "2016-12-05T20:54:20.827Z", + "avg" : "2016-12-05T20:54:20.827Z", + "watermark" : "2016-12-05T20:54:20.827Z", + "max" : "2016-12-05T20:54:20.827Z" + }, "stateOperators" : [ { "operatorName" : "op1", "numRowsTotal" : 0, @@ -434,7 +439,15 @@ def test_streaming_query_progress_fromJson(self): self.assertTrue(progress.inputRowsPerSecond == 10.0) self.assertTrue(progress.batchDuration == 5) self.assertTrue(progress.durationMs == {"getBatch": 0}) - self.assertTrue(progress.eventTime == {}) + self.assertTrue( + progress.eventTime + == { + "min": "2016-12-05T20:54:20.827Z", + "avg": "2016-12-05T20:54:20.827Z", + "watermark": "2016-12-05T20:54:20.827Z", + "max": "2016-12-05T20:54:20.827Z", + } + ) self.assertTrue( progress.observedMetrics == { From 
d3041ae2ed2e3d1e3849f87668e36d6aac130777 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Tue, 13 Jun 2023 13:57:28 -0700 Subject: [PATCH 08/14] fix test failure --- python/pyspark/sql/tests/streaming/test_streaming_listener.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index b0f5a7cfd2078..e9ba73b42a0bc 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -217,7 +217,7 @@ def check_streaming_query_progress(self, progress): ) self.assertTrue(all(map(lambda v: isinstance(v, int), progress.durationMs.values()))) - self.assertEquals(progress.eventTime, {}) + self.assertTrue(all(map(lambda v: isinstance(v, str), progress.eventTime.values()))) self.assertTrue(isinstance(progress.stateOperators, list)) self.assertTrue(len(progress.stateOperators) >= 1) From c5fe43734db89b371f6dcaa3204472aa978c8e24 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Tue, 13 Jun 2023 16:27:49 -0700 Subject: [PATCH 09/14] use different assert function --- .../streaming/test_streaming_listener.py | 121 ++++++++---------- 1 file changed, 56 insertions(+), 65 deletions(-) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index e9ba73b42a0bc..21ca34d179f36 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -338,10 +338,10 @@ def test_query_started_event_fromJson(self): """ start_event = QueryStartedEvent.fromJson(json.loads(start_event)) self.check_start_event(start_event) - self.assertTrue(start_event.id == uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b")) - self.assertTrue(start_event.runId == uuid.UUID("55a95d45-e932-4e08-9caa-0a8ecd9391e8")) - self.assertTrue(start_event.name is None) - self.assertTrue(start_event.timestamp == "2023-06-09T18:13:29.741Z") + self.assertEqual(start_event.id, uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b")) + self.assertEqual(start_event.runId, uuid.UUID("55a95d45-e932-4e08-9caa-0a8ecd9391e8")) + self.assertIsNone(start_event.name) + self.assertEqual(start_event.timestamp, "2023-06-09T18:13:29.741Z") def test_query_terminated_event_fromJson(self): terminated_json = """ @@ -353,10 +353,10 @@ def test_query_terminated_event_fromJson(self): """ terminated_event = QueryTerminatedEvent.fromJson(json.loads(terminated_json)) self.check_terminated_event(terminated_event, "SparkException") - self.assertTrue(terminated_event.id == uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b")) - self.assertTrue(terminated_event.runId == uuid.UUID("55a95d45-e932-4e08-9caa-0a8ecd9391e8")) - self.assertTrue("SparkException" in terminated_event.exception) - self.assertTrue(terminated_event.errorClassOnException is None) + self.assertEqual(terminated_event.id, uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b")) + self.assertEqual(terminated_event.runId, uuid.UUID("55a95d45-e932-4e08-9caa-0a8ecd9391e8")) + self.assertIn("SparkException", terminated_event.exception) + self.assertIsNone(terminated_event.errorClassOnException) def test_streaming_query_progress_fromJson(self): progress_json = """ @@ -430,75 +430,66 @@ def test_streaming_query_progress_fromJson(self): self.check_streaming_query_progress(progress) # checks for progress - self.assertTrue(progress.id == 
uuid.UUID("00000000-0000-0001-0000-000000000001")) - self.assertTrue(progress.runId == uuid.UUID("00000000-0000-0001-0000-000000000002")) - self.assertTrue(progress.name == "test") - self.assertTrue(progress.timestamp == "2016-12-05T20:54:20.827Z") - self.assertTrue(progress.batchId == 2) - self.assertTrue(progress.numInputRows == 678) - self.assertTrue(progress.inputRowsPerSecond == 10.0) - self.assertTrue(progress.batchDuration == 5) - self.assertTrue(progress.durationMs == {"getBatch": 0}) - self.assertTrue( - progress.eventTime - == { - "min": "2016-12-05T20:54:20.827Z", - "avg": "2016-12-05T20:54:20.827Z", - "watermark": "2016-12-05T20:54:20.827Z", - "max": "2016-12-05T20:54:20.827Z", - } - ) - self.assertTrue( - progress.observedMetrics - == { - "event1": Row("c1", "c2")(1, 3.0), - "event2": Row("rc", "min_q", "max_q")(1, "hello", "world"), - } - ) + self.assertEqual(progress.id, uuid.UUID("00000000-0000-0001-0000-000000000001")) + self.assertEqual(progress.runId, uuid.UUID("00000000-0000-0001-0000-000000000002")) + self.assertEqual(progress.name, "test") + self.assertEqual(progress.timestamp, "2016-12-05T20:54:20.827Z") + self.assertEqual(progress.batchId, 2) + self.assertEqual(progress.numInputRows, 678) + self.assertEqual(progress.inputRowsPerSecond, 10.0) + self.assertEqual(progress.batchDuration, 5) + self.assertEqual(progress.durationMs, {"getBatch": 0}) + self.assertEqual(progress.eventTime, { + "min": "2016-12-05T20:54:20.827Z", + "avg": "2016-12-05T20:54:20.827Z", + "watermark": "2016-12-05T20:54:20.827Z", + "max": "2016-12-05T20:54:20.827Z", + }) + self.assertEqual(progress.observedMetrics, { + "event1": Row("c1", "c2")(1, 3.0), + "event2": Row("rc", "min_q", "max_q")(1, "hello", "world"), + }) # Check stateOperators list - self.assertTrue(len(progress.stateOperators) == 1) + self.assertEqual(len(progress.stateOperators), 1) state_operator = progress.stateOperators[0] self.assertTrue(isinstance(state_operator, StateOperatorProgress)) - self.assertTrue(state_operator.operatorName == "op1") - self.assertTrue(state_operator.numRowsTotal == 0) - self.assertTrue(state_operator.numRowsUpdated == 1) - self.assertTrue(state_operator.allUpdatesTimeMs == 1) - self.assertTrue(state_operator.numRowsRemoved == 2) - self.assertTrue(state_operator.allRemovalsTimeMs == 34) - self.assertTrue(state_operator.commitTimeMs == 23) - self.assertTrue(state_operator.memoryUsedBytes == 3) - self.assertTrue(state_operator.numRowsDroppedByWatermark == 0) - self.assertTrue(state_operator.numShufflePartitions == 2) - self.assertTrue(state_operator.numStateStoreInstances == 2) - self.assertTrue( - state_operator.customMetrics - == { - "loadedMapCacheHitCount": 1, - "loadedMapCacheMissCount": 0, - "stateOnCurrentVersionSizeBytes": 2, - } - ) + self.assertEqual(state_operator.operatorName, "op1") + self.assertEqual(state_operator.numRowsTotal, 0) + self.assertEqual(state_operator.numRowsUpdated, 1) + self.assertEqual(state_operator.allUpdatesTimeMs, 1) + self.assertEqual(state_operator.numRowsRemoved, 2) + self.assertEqual(state_operator.allRemovalsTimeMs, 34) + self.assertEqual(state_operator.commitTimeMs, 23) + self.assertEqual(state_operator.memoryUsedBytes, 3) + self.assertEqual(state_operator.numRowsDroppedByWatermark, 0) + self.assertEqual(state_operator.numShufflePartitions, 2) + self.assertEqual(state_operator.numStateStoreInstances, 2) + self.assertEqual(state_operator.customMetrics, { + "loadedMapCacheHitCount": 1, + "loadedMapCacheMissCount": 0, + "stateOnCurrentVersionSizeBytes": 2, + }) # 
Check sources list - self.assertTrue(len(progress.sources) == 1) + self.assertEqual(len(progress.sources), 1) source = progress.sources[0] self.assertTrue(isinstance(source, SourceProgress)) - self.assertTrue(source.description == "source") - self.assertTrue(source.startOffset == "123") - self.assertTrue(source.endOffset == "456") - self.assertTrue(source.latestOffset == "789") - self.assertTrue(source.numInputRows == 678) - self.assertTrue(source.inputRowsPerSecond == 10.0) - self.assertTrue(source.processedRowsPerSecond == 5.4) - self.assertTrue(source.metrics == {}) + self.assertEqual(source.description, "source") + self.assertEqual(source.startOffset, "123") + self.assertEqual(source.endOffset, "456") + self.assertEqual(source.latestOffset, "789") + self.assertEqual(source.numInputRows, 678) + self.assertEqual(source.inputRowsPerSecond, 10.0) + self.assertEqual(source.processedRowsPerSecond, 5.4) + self.assertEqual(source.metrics, {}) # Check sink sink = progress.sink self.assertTrue(isinstance(sink, SinkProgress)) - self.assertTrue(sink.description == "sink") - self.assertTrue(sink.numOutputRows == -1) - self.assertTrue(sink.metrics == {}) + self.assertEqual(sink.description, "sink") + self.assertEqual(sink.numOutputRows, -1) + self.assertEqual(sink.metrics, {}) if __name__ == "__main__": From f20d7f8bf270ab12c81569456b4b00304672ecd6 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Tue, 13 Jun 2023 22:05:21 -0700 Subject: [PATCH 10/14] fmt --- .../streaming/test_streaming_listener.py | 39 ++++++++++++------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index 21ca34d179f36..4df4d9a00ec53 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -439,16 +439,22 @@ def test_streaming_query_progress_fromJson(self): self.assertEqual(progress.inputRowsPerSecond, 10.0) self.assertEqual(progress.batchDuration, 5) self.assertEqual(progress.durationMs, {"getBatch": 0}) - self.assertEqual(progress.eventTime, { - "min": "2016-12-05T20:54:20.827Z", - "avg": "2016-12-05T20:54:20.827Z", - "watermark": "2016-12-05T20:54:20.827Z", - "max": "2016-12-05T20:54:20.827Z", - }) - self.assertEqual(progress.observedMetrics, { - "event1": Row("c1", "c2")(1, 3.0), - "event2": Row("rc", "min_q", "max_q")(1, "hello", "world"), - }) + self.assertEqual( + progress.eventTime, + { + "min": "2016-12-05T20:54:20.827Z", + "avg": "2016-12-05T20:54:20.827Z", + "watermark": "2016-12-05T20:54:20.827Z", + "max": "2016-12-05T20:54:20.827Z", + }, + ) + self.assertEqual( + progress.observedMetrics, + { + "event1": Row("c1", "c2")(1, 3.0), + "event2": Row("rc", "min_q", "max_q")(1, "hello", "world"), + }, + ) # Check stateOperators list self.assertEqual(len(progress.stateOperators), 1) @@ -465,11 +471,14 @@ def test_streaming_query_progress_fromJson(self): self.assertEqual(state_operator.numRowsDroppedByWatermark, 0) self.assertEqual(state_operator.numShufflePartitions, 2) self.assertEqual(state_operator.numStateStoreInstances, 2) - self.assertEqual(state_operator.customMetrics, { - "loadedMapCacheHitCount": 1, - "loadedMapCacheMissCount": 0, - "stateOnCurrentVersionSizeBytes": 2, - }) + self.assertEqual( + state_operator.customMetrics, + { + "loadedMapCacheHitCount": 1, + "loadedMapCacheMissCount": 0, + "stateOnCurrentVersionSizeBytes": 2, + }, + ) # Check sources list 
self.assertEqual(len(progress.sources), 1) From d725d6d2b2f73243b57f151e998319e4511ac225 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Tue, 13 Jun 2023 22:54:29 -0700 Subject: [PATCH 11/14] remove versionchanged --- python/pyspark/sql/streaming/listener.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index 5855a5cf32d33..d657a19157850 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -150,8 +150,6 @@ class QueryStartedEvent: Event representing the start of a query. .. versionadded:: 3.4.0 - .. versionchanged:: 3.5.0 - Add fromJson constructor to support Spark Connect. Notes ----- @@ -220,8 +218,6 @@ class QueryProgressEvent: Event representing any progress updates in a query. .. versionadded:: 3.4.0 - .. versionchanged:: 3.5.0 - Add fromJson constructor to support Spark Connect. Notes ----- @@ -304,8 +300,6 @@ class QueryTerminatedEvent: Event representing that termination of a query. .. versionadded:: 3.4.0 - .. versionchanged:: 3.5.0 - Add fromJson constructor to support Spark Connect. Notes ----- @@ -385,8 +379,6 @@ def errorClassOnException(self) -> Optional[str]: class StreamingQueryProgress: """ .. versionadded:: 3.4.0 - .. versionchanged:: 3.5.0 - Add fromJson constructor to support Spark Connect. Notes ----- @@ -627,8 +619,6 @@ def __str__(self) -> str: class StateOperatorProgress: """ .. versionadded:: 3.4.0 - .. versionchanged:: 3.5.0 - Add fromJson constructor to support Spark Connect. Notes ----- @@ -774,8 +764,6 @@ def __str__(self) -> str: class SourceProgress: """ .. versionadded:: 3.4.0 - .. versionchanged:: 3.5.0 - Add fromJson constructor to support Spark Connect. Notes ----- @@ -910,8 +898,6 @@ def __str__(self) -> str: class SinkProgress: """ .. versionadded:: 3.4.0 - .. versionchanged:: 3.5.0 - Add fromJson constructor to support Spark Connect. Notes ----- From d378973efbaa579f67e280a56f8a6534b11dfa1b Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Wed, 14 Jun 2023 11:42:34 -0700 Subject: [PATCH 12/14] oops forget ser method for QueryProgressEvent --- .../spark/sql/streaming/StreamingQueryListener.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index 8f99fa2f9d391..5c0027895cda6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.streaming import java.util.UUID +import org.json4s.{JObject, JString} import org.json4s.JsonAST.JValue import org.json4s.JsonDSL.{jobject2assoc, pair2Assoc} -import org.json4s.JString import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.annotation.Evolving @@ -146,7 +146,12 @@ object StreamingQueryListener { * @since 2.1.0 */ @Evolving - class QueryProgressEvent private[sql](val progress: StreamingQueryProgress) extends Event + class QueryProgressEvent private[sql](val progress: StreamingQueryProgress) extends Event { + + def json: String = compact(render(jsonValue)) + + private def jsonValue: JValue = JObject("progress" -> progress.jsonValue) + } /** * Event representing that query is idle and waiting for new data to process. 
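Taken together, the patches so far give listener events a JSON path on both sides: the Scala events gain json/jsonValue methods, and the Python event classes gain fromJson constructors, which is what lets Spark Connect rebuild events without a py4j handle. A minimal sketch of the intended round trip on the Python side, reusing the payload shape from test_query_started_event_fromJson above (illustrative only; assumes a PySpark build with this series applied):

    import json
    import uuid

    from pyspark.sql.streaming.listener import QueryStartedEvent

    # JSON of the shape emitted by the Scala QueryStartedEvent.json added in PATCH 01.
    payload = """
    {
      "id" : "78923ec2-8f4d-4266-876e-1f50cf3c283b",
      "runId" : "55a95d45-e932-4e08-9caa-0a8ecd9391e8",
      "name" : null,
      "timestamp" : "2023-06-09T18:13:29.741Z"
    }
    """

    event = QueryStartedEvent.fromJson(json.loads(payload))
    assert event.id == uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b")
    assert event.name is None  # JSON null round-trips to Python None
    print(event.timestamp)     # 2023-06-09T18:13:29.741Z
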
From a1da6de0a185ba833e6fd420a3853e0793d52821 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Wed, 14 Jun 2023 12:09:01 -0700 Subject: [PATCH 13/14] address raghu's comments --- python/pyspark/sql/streaming/listener.py | 104 +++++++++++------- .../streaming/test_streaming_listener.py | 2 +- 2 files changed, 65 insertions(+), 41 deletions(-) diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index d657a19157850..9ead70abb3f62 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -387,8 +387,6 @@ class StreamingQueryProgress: def __init__( self, - json: str, - prettyJson: str, id: uuid.UUID, runId: uuid.UUID, name: Optional[str], @@ -404,9 +402,11 @@ def __init__( inputRowsPerSecond: float, processedRowsPerSecond: float, observedMetrics: Dict[str, Row], + jprogress: JavaObject = None, + jdict: Dict[str, Any] = None, ): - self._json: str = json - self._prettyJson: str = prettyJson + self._jprogress: JavaObject = jprogress + self._jdict: Dict[str, Any] = jdict self._id: uuid.UUID = id self._runId: uuid.UUID = runId self._name: Optional[str] = name @@ -428,8 +428,7 @@ def fromJObject(cls, jprogress: JavaObject) -> "StreamingQueryProgress": from pyspark import SparkContext return cls( - json=jprogress.json(), - prettyJson=jprogress.prettyJson(), + jprogress=jprogress, id=uuid.UUID(jprogress.id().toString()), runId=uuid.UUID(jprogress.runId().toString()), name=jprogress.name(), @@ -457,8 +456,7 @@ def fromJObject(cls, jprogress: JavaObject) -> "StreamingQueryProgress": @classmethod def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress": return cls( - json=json.dumps(j), - prettyJson=json.dumps(j, indent=4), + jdict=j, id=uuid.UUID(j["id"]), runId=uuid.UUID(j["runId"]), name=j["name"], @@ -603,14 +601,22 @@ def json(self) -> str: """ The compact JSON representation of this progress. """ - return self._json + assert self._jdict is not None or self._jprogress is not None + if self._jprogress: + return self._jprogress.json() + else: + return json.dumps(self._jdict) @property def prettyJson(self) -> str: """ The pretty (i.e. indented) JSON representation of this progress. 
""" - return self._prettyJson + assert self._jdict is not None or self._jprogress is not None + if self._jprogress: + return self._jprogress.prettyJson() + else: + return json.dumps(self._jdict, indent=4) def __str__(self) -> str: return self.prettyJson @@ -627,8 +633,6 @@ class StateOperatorProgress: def __init__( self, - json: str, - prettyJson: str, operatorName: str, numRowsTotal: int, numRowsUpdated: int, @@ -641,9 +645,11 @@ def __init__( numShufflePartitions: int, numStateStoreInstances: int, customMetrics: Dict[str, int], + jprogress: JavaObject = None, + jdict: Dict[str, Any] = None, ): - self._json: str = json - self._prettyJson: str = prettyJson + self._jprogress: JavaObject = jprogress + self._jdict: Dict[str, Any] = jdict self._operatorName: str = operatorName self._numRowsTotal: int = numRowsTotal self._numRowsUpdated: int = numRowsUpdated @@ -660,8 +666,7 @@ def __init__( @classmethod def fromJObject(cls, jprogress: JavaObject) -> "StateOperatorProgress": return cls( - json=jprogress.json(), - prettyJson=jprogress.prettyJson(), + jprogress=jprogress, operatorName=jprogress.operatorName(), numRowsTotal=jprogress.numRowsTotal(), numRowsUpdated=jprogress.numRowsUpdated(), @@ -679,8 +684,7 @@ def fromJObject(cls, jprogress: JavaObject) -> "StateOperatorProgress": @classmethod def fromJson(cls, j: Dict[str, Any]) -> "StateOperatorProgress": return cls( - json=json.dumps(j), - prettyJson=json.dumps(j, indent=4), + jdict=j, operatorName=j["operatorName"], numRowsTotal=j["numRowsTotal"], numRowsUpdated=j["numRowsUpdated"], @@ -748,14 +752,22 @@ def json(self) -> str: """ The compact JSON representation of this progress. """ - return self._json + assert self._jdict is not None or self._jprogress is not None + if self._jprogress: + return self._jprogress.json() + else: + return json.dumps(self._jdict) @property def prettyJson(self) -> str: """ The pretty (i.e. indented) JSON representation of this progress. """ - return self._prettyJson + assert self._jdict is not None or self._jprogress is not None + if self._jprogress: + return self._jprogress.prettyJson() + else: + return json.dumps(self._jdict, indent=4) def __str__(self) -> str: return self.prettyJson @@ -772,8 +784,6 @@ class SourceProgress: def __init__( self, - json: str, - prettyJson: str, description: str, startOffset: str, endOffset: str, @@ -782,9 +792,11 @@ def __init__( inputRowsPerSecond: float, processedRowsPerSecond: float, metrics: Dict[str, str], + jprogress: JavaObject = None, + jdict: Dict[str, Any] = None, ) -> None: - self._json: str = json - self._prettyJson: str = prettyJson + self._jprogress: JavaObject = jprogress + self._jdict: Dict[str, Any] = jdict self._description: str = description self._startOffset: str = startOffset self._endOffset: str = endOffset @@ -797,8 +809,7 @@ def __init__( @classmethod def fromJObject(cls, jprogress: JavaObject) -> "SourceProgress": return cls( - json=jprogress.json(), - prettyJson=jprogress.prettyJson(), + jprogress=jprogress, description=jprogress.description(), startOffset=str(jprogress.startOffset()), endOffset=str(jprogress.endOffset()), @@ -812,8 +823,7 @@ def fromJObject(cls, jprogress: JavaObject) -> "SourceProgress": @classmethod def fromJson(cls, j: Dict[str, Any]) -> "SourceProgress": return cls( - json=json.dumps(j), - prettyJson=json.dumps(j, indent=4), + jdict=j, description=j["description"], startOffset=str(j["startOffset"]), endOffset=str(j["endOffset"]), @@ -882,14 +892,22 @@ def json(self) -> str: """ The compact JSON representation of this progress. 
""" - return self._json + assert self._jdict is not None or self._jprogress is not None + if self._jprogress: + return self._jprogress.json() + else: + return json.dumps(self._jdict) @property def prettyJson(self) -> str: """ The pretty (i.e. indented) JSON representation of this progress. """ - return self._prettyJson + assert self._jdict is not None or self._jprogress is not None + if self._jprogress: + return self._jprogress.prettyJson() + else: + return json.dumps(self._jdict, indent=4) def __str__(self) -> str: return self.prettyJson @@ -906,14 +924,14 @@ class SinkProgress: def __init__( self, - json: str, - prettyJson: str, description: str, numOutputRows: int, metrics: Dict[str, str], + jprogress: JavaObject = None, + jdict: Dict[str, Any] = None, ) -> None: - self._json: str = json - self._prettyJson: str = prettyJson + self._jprogress: JavaObject = jprogress + self._jdict: Dict[str, Any] = jdict self._description: str = description self._numOutputRows: int = numOutputRows self._metrics: Dict[str, str] = metrics @@ -921,8 +939,7 @@ def __init__( @classmethod def fromJObject(cls, jprogress: JavaObject) -> "SinkProgress": return cls( - json=jprogress.json(), - prettyJson=jprogress.prettyJson(), + jprogress=jprogress, description=jprogress.description(), numOutputRows=jprogress.numOutputRows(), metrics=dict(jprogress.metrics()), @@ -931,8 +948,7 @@ def fromJObject(cls, jprogress: JavaObject) -> "SinkProgress": @classmethod def fromJson(cls, j: Dict[str, Any]) -> "SinkProgress": return cls( - json=json.dumps(j), - prettyJson=json.dumps(j, indent=4), + jdict=j, description=j["description"], numOutputRows=j["numOutputRows"], metrics=j["metrics"], @@ -962,14 +978,22 @@ def json(self) -> str: """ The compact JSON representation of this progress. """ - return self._json + assert self._jdict is not None or self._jprogress is not None + if self._jprogress: + return self._jprogress.json() + else: + return json.dumps(self._jdict) @property def prettyJson(self) -> str: """ The pretty (i.e. indented) JSON representation of this progress. 
""" - return self._prettyJson + assert self._jdict is not None or self._jprogress is not None + if self._jprogress: + return self._jprogress.prettyJson() + else: + return json.dumps(self._jdict, indent=4) def __str__(self) -> str: return self.prettyJson diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index 4df4d9a00ec53..2bd6d2c666837 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -59,7 +59,7 @@ def get_number_of_public_methods(clz): get_number_of_public_methods( "org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgressEvent" ), - 11, + 12, msg, ) self.assertEquals( From f08c6748f10be85762768718f2f10679956327a9 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Thu, 15 Jun 2023 12:05:34 -0700 Subject: [PATCH 14/14] lint --- python/pyspark/sql/streaming/listener.py | 32 ++++++++++++------------ 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index 9ead70abb3f62..198af0c9cbeb5 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -402,11 +402,11 @@ def __init__( inputRowsPerSecond: float, processedRowsPerSecond: float, observedMetrics: Dict[str, Row], - jprogress: JavaObject = None, - jdict: Dict[str, Any] = None, + jprogress: Optional[JavaObject] = None, + jdict: Optional[Dict[str, Any]] = None, ): - self._jprogress: JavaObject = jprogress - self._jdict: Dict[str, Any] = jdict + self._jprogress: Optional[JavaObject] = jprogress + self._jdict: Optional[Dict[str, Any]] = jdict self._id: uuid.UUID = id self._runId: uuid.UUID = runId self._name: Optional[str] = name @@ -645,11 +645,11 @@ def __init__( numShufflePartitions: int, numStateStoreInstances: int, customMetrics: Dict[str, int], - jprogress: JavaObject = None, - jdict: Dict[str, Any] = None, + jprogress: Optional[JavaObject] = None, + jdict: Optional[Dict[str, Any]] = None, ): - self._jprogress: JavaObject = jprogress - self._jdict: Dict[str, Any] = jdict + self._jprogress: Optional[JavaObject] = jprogress + self._jdict: Optional[Dict[str, Any]] = jdict self._operatorName: str = operatorName self._numRowsTotal: int = numRowsTotal self._numRowsUpdated: int = numRowsUpdated @@ -792,11 +792,11 @@ def __init__( inputRowsPerSecond: float, processedRowsPerSecond: float, metrics: Dict[str, str], - jprogress: JavaObject = None, - jdict: Dict[str, Any] = None, + jprogress: Optional[JavaObject] = None, + jdict: Optional[Dict[str, Any]] = None, ) -> None: - self._jprogress: JavaObject = jprogress - self._jdict: Dict[str, Any] = jdict + self._jprogress: Optional[JavaObject] = jprogress + self._jdict: Optional[Dict[str, Any]] = jdict self._description: str = description self._startOffset: str = startOffset self._endOffset: str = endOffset @@ -927,11 +927,11 @@ def __init__( description: str, numOutputRows: int, metrics: Dict[str, str], - jprogress: JavaObject = None, - jdict: Dict[str, Any] = None, + jprogress: Optional[JavaObject] = None, + jdict: Optional[Dict[str, Any]] = None, ) -> None: - self._jprogress: JavaObject = jprogress - self._jdict: Dict[str, Any] = jdict + self._jprogress: Optional[JavaObject] = jprogress + self._jdict: Optional[Dict[str, Any]] = jdict self._description: str = description self._numOutputRows: int = numOutputRows self._metrics: Dict[str, str] = metrics