From f7706c5152461f469ca97fb8450ad992de17001a Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 15 Sep 2022 11:13:12 +0900 Subject: [PATCH 1/8] [SPARK-40432][SS][PYTHON] Introduce GroupStateImpl and GroupStateTimeout in PySpark --- .../scala/org/apache/spark/util/Utils.scala | 2 +- .../sql/streaming/GroupStateTimeout.java | 5 ++ .../execution/streaming/GroupStateImpl.scala | 54 +++++++++++++++++++ 3 files changed, 60 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index aef79c7882ca1..484a07c18ed0e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -3051,7 +3051,7 @@ private[spark] object Utils extends Logging { * and return the trailing part after the last dollar sign in the middle */ @scala.annotation.tailrec - private def stripDollars(s: String): String = { + def stripDollars(s: String): String = { val lastDollarIndex = s.lastIndexOf('$') if (lastDollarIndex < s.length - 1) { // The last char is not a dollar sign diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java index a814525f870c9..be435a892c974 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java @@ -32,6 +32,11 @@ @Experimental @Evolving public class GroupStateTimeout { + // scalastyle:off line.size.limit + // NOTE: if you're adding new type of timeout, you should also fix the places below: + // - Scala: org.apache.spark.sql.execution.streaming.GroupStateImpl.getGroupStateTimeoutFromString + // - Python: pyspark.sql.streaming.state.GroupStateTimeout + // scalastyle:on line.size.limit /** * Timeout based on processing time. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index b4f37125f4fa9..3af9c9aebf33d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -20,6 +20,9 @@ package org.apache.spark.sql.execution.streaming import java.sql.Date import java.util.concurrent.TimeUnit +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.api.java.Optional import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, NoTimeout, ProcessingTimeTimeout} import org.apache.spark.sql.catalyst.util.IntervalUtils @@ -27,6 +30,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.streaming.GroupStateImpl._ import org.apache.spark.sql.streaming.{GroupStateTimeout, TestGroupState} import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils /** * Internal implementation of the [[TestGroupState]] interface. Methods are not thread-safe. @@ -46,6 +50,9 @@ private[sql] class GroupStateImpl[S] private( timeoutConf: GroupStateTimeout, override val hasTimedOut: Boolean, watermarkPresent: Boolean) extends TestGroupState[S] { + // NOTE: if you're adding new properties here, fix: + // - `json` and `fromJson` methods of this class in Scala + // - pyspark.sql.streaming.state.GroupStateImpl in Python private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) private var defined: Boolean = optionalValue.isDefined @@ -173,6 +180,22 @@ private[sql] class GroupStateImpl[S] private( throw QueryExecutionErrors.cannotSetTimeoutTimestampError() } } + + private[sql] def json(): String = compact(render(new JObject( + // Constructor + "optionalValue" -> JNull :: // Note that optionalValue will be manually serialized. 
+ "batchProcessingTimeMs" -> JLong(batchProcessingTimeMs) :: + "eventTimeWatermarkMs" -> JLong(eventTimeWatermarkMs) :: + "timeoutConf" -> JString(Utils.stripDollars(Utils.getSimpleName(timeoutConf.getClass))) :: + "hasTimedOut" -> JBool(hasTimedOut) :: + "watermarkPresent" -> JBool(watermarkPresent) :: + + // Internal state + "defined" -> JBool(defined) :: + "updated" -> JBool(updated) :: + "removed" -> JBool(removed) :: + "timeoutTimestamp" -> JLong(timeoutTimestamp) :: Nil + ))) } @@ -214,4 +237,35 @@ private[sql] object GroupStateImpl { hasTimedOut = false, watermarkPresent) } + + def groupStateTimeoutFromString(clazz: String): GroupStateTimeout = clazz match { + case "ProcessingTimeTimeout" => GroupStateTimeout.ProcessingTimeTimeout + case "EventTimeTimeout" => GroupStateTimeout.EventTimeTimeout + case "NoTimeout" => GroupStateTimeout.NoTimeout + case _ => throw new IllegalStateException("Invalid string for GroupStateTimeout: " + clazz) + } + + def fromJson[S](value: Option[S], json: JValue): GroupStateImpl[S] = { + implicit val formats = org.json4s.DefaultFormats + + val hmap = json.extract[Map[String, Any]] + + // Constructor + val newGroupState = new GroupStateImpl[S]( + value, + hmap("batchProcessingTimeMs").asInstanceOf[Number].longValue(), + hmap("eventTimeWatermarkMs").asInstanceOf[Number].longValue(), + groupStateTimeoutFromString(hmap("timeoutConf").asInstanceOf[String]), + hmap("hasTimedOut").asInstanceOf[Boolean], + hmap("watermarkPresent").asInstanceOf[Boolean]) + + // Internal state + newGroupState.defined = hmap("defined").asInstanceOf[Boolean] + newGroupState.updated = hmap("updated").asInstanceOf[Boolean] + newGroupState.removed = hmap("removed").asInstanceOf[Boolean] + newGroupState.timeoutTimestamp = + hmap("timeoutTimestamp").asInstanceOf[Number].longValue() + + newGroupState + } } From a28cc538383aa3340d01016795db99afc65d49fd Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 15 Sep 2022 11:26:20 +0900 Subject: [PATCH 2/8] meta-commit 
to credit properly on co-authorship From 64ebd207c5aea0ffc856d1610bdb4969f421ca10 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 15 Sep 2022 11:40:19 +0900 Subject: [PATCH 3/8] Update sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java Co-authored-by: Hyukjin Kwon --- .../org/apache/spark/sql/streaming/GroupStateTimeout.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java index be435a892c974..ee51ddb0e1ef5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java @@ -32,11 +32,10 @@ @Experimental @Evolving public class GroupStateTimeout { - // scalastyle:off line.size.limit // NOTE: if you're adding new type of timeout, you should also fix the places below: - // - Scala: org.apache.spark.sql.execution.streaming.GroupStateImpl.getGroupStateTimeoutFromString + // - Scala: + // org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString // - Python: pyspark.sql.streaming.state.GroupStateTimeout - // scalastyle:on line.size.limit /** * Timeout based on processing time. 
From 4b85557ea7648e8ba768e4cf55e3a3b9f6ea215b Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 15 Sep 2022 12:13:48 +0900 Subject: [PATCH 4/8] add missed file --- python/pyspark/sql/streaming/state.py | 192 ++++++++++++++++++++++++++ 1 file changed, 192 insertions(+) create mode 100644 python/pyspark/sql/streaming/state.py diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py new file mode 100644 index 0000000000000..a776c521c9355 --- /dev/null +++ b/python/pyspark/sql/streaming/state.py @@ -0,0 +1,192 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import datetime +import json +from typing import Tuple, Optional + +from pyspark.sql.types import DateType, Row, StructType + +__all__ = ["GroupStateImpl", "GroupStateTimeout"] + + +class GroupStateTimeout: + NoTimeout: str = "NoTimeout" + ProcessingTimeTimeout: str = "ProcessingTimeTimeout" + EventTimeTimeout: str = "EventTimeTimeout" + + +class GroupStateImpl: + NO_TIMESTAMP: int = -1 + + def __init__( + self, + # JVM Constructor + optionalValue: Row, + batchProcessingTimeMs: int, + eventTimeWatermarkMs: int, + timeoutConf: str, + hasTimedOut: bool, + watermarkPresent: bool, + # JVM internal state. 
+ defined: bool, + updated: bool, + removed: bool, + timeoutTimestamp: int, + # Python internal state. + keyAsUnsafe: bytes, + valueSchema: StructType, + ) -> None: + self._keyAsUnsafe = keyAsUnsafe + self._value = optionalValue + self._batch_processing_time_ms = batchProcessingTimeMs + self._event_time_watermark_ms = eventTimeWatermarkMs + + assert timeoutConf in [ + GroupStateTimeout.NoTimeout, + GroupStateTimeout.ProcessingTimeTimeout, + GroupStateTimeout.EventTimeTimeout, + ] + self._timeout_conf = timeoutConf + + self._has_timed_out = hasTimedOut + self._watermark_present = watermarkPresent + + self._defined = defined + self._updated = updated + self._removed = removed + self._timeout_timestamp = timeoutTimestamp + # Python internal state. + self._old_timeout_timestamp = timeoutTimestamp + + self._value_schema = valueSchema + + @property + def exists(self) -> bool: + return self._defined + + @property + def get(self) -> Tuple: + if self.exists: + return tuple(self._value) + else: + raise ValueError("State is either not defined or has already been removed") + + @property + def getOption(self) -> Optional[Tuple]: + if self.exists: + return tuple(self._value) + else: + return None + + @property + def hasTimedOut(self) -> bool: + return self._has_timed_out + + # NOTE: this function is only available to PySpark implementation due to underlying + # implementation, do not port to Scala implementation! 
+ @property + def oldTimeoutTimestamp(self) -> int: + return self._old_timeout_timestamp + + def update(self, newValue: Tuple) -> None: + if newValue is None: + raise ValueError("'None' is not a valid state value") + + self._value = Row(*newValue) + self._defined = True + self._updated = True + self._removed = False + + def remove(self) -> None: + self._defined = False + self._updated = False + self._removed = True + + def setTimeoutDuration(self, durationMs: int) -> None: + if isinstance(durationMs, str): + # TODO(SPARK-XXXXX): Support string representation of durationMs. + raise ValueError("durationMs should be int but get :%s" % type(durationMs)) + + if self._timeout_conf != GroupStateTimeout.ProcessingTimeTimeout: + raise RuntimeError( + "Cannot set timeout duration without enabling processing time timeout in " + "applyInPandasWithState" + ) + + if durationMs <= 0: + raise ValueError("Timeout duration must be positive") + self._timeout_timestamp = durationMs + self._batch_processing_time_ms + + # TODO(SPARK-XXXXX): Implement additionalDuration parameter. 
+ def setTimeoutTimestamp(self, timestampMs: int) -> None: + if self._timeout_conf != GroupStateTimeout.EventTimeTimeout: + raise RuntimeError( + "Cannot set timeout duration without enabling processing time timeout in " + "applyInPandasWithState" + ) + + if isinstance(timestampMs, datetime.datetime): + timestampMs = DateType().toInternal(timestampMs) + + if timestampMs <= 0: + raise ValueError("Timeout timestamp must be positive") + + if ( + self._event_time_watermark_ms != GroupStateImpl.NO_TIMESTAMP + and timestampMs < self._event_time_watermark_ms + ): + raise ValueError( + "Timeout timestamp (%s) cannot be earlier than the " + "current watermark (%s)" % (timestampMs, self._event_time_watermark_ms) + ) + + self._timeout_timestamp = timestampMs + + def getCurrentWatermarkMs(self) -> int: + if not self._watermark_present: + raise RuntimeError( + "Cannot get event time watermark timestamp without setting watermark before " + "applyInPandasWithState" + ) + return self._event_time_watermark_ms + + def getCurrentProcessingTimeMs(self) -> int: + return self._batch_processing_time_ms + + def __str__(self) -> str: + if self.exists: + return "GroupState(%s)" % (self.get, ) + else: + return "GroupState()" + + def json(self) -> str: + return json.dumps( + { + # Constructor + "optionalValue": None, # Note that optionalValue will be manually serialized. + "batchProcessingTimeMs": self._batch_processing_time_ms, + "eventTimeWatermarkMs": self._event_time_watermark_ms, + "timeoutConf": self._timeout_conf, + "hasTimedOut": self._has_timed_out, + "watermarkPresent": self._watermark_present, + # JVM internal state. 
+ "defined": self._defined, + "updated": self._updated, + "removed": self._removed, + "timeoutTimestamp": self._timeout_timestamp, + } + ) From 0c63198c8b75bd61230f5df08b279331d5d28d6e Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 15 Sep 2022 12:23:00 +0900 Subject: [PATCH 5/8] update SPARK-XXXXX to SPARK-40437 --- python/pyspark/sql/streaming/state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py index a776c521c9355..b4f2be2256eee 100644 --- a/python/pyspark/sql/streaming/state.py +++ b/python/pyspark/sql/streaming/state.py @@ -118,7 +118,7 @@ def remove(self) -> None: def setTimeoutDuration(self, durationMs: int) -> None: if isinstance(durationMs, str): - # TODO(SPARK-XXXXX): Support string representation of durationMs. + # TODO(SPARK-40437): Support string representation of durationMs. raise ValueError("durationMs should be int but get :%s" % type(durationMs)) if self._timeout_conf != GroupStateTimeout.ProcessingTimeTimeout: @@ -131,7 +131,7 @@ def setTimeoutDuration(self, durationMs: int) -> None: raise ValueError("Timeout duration must be positive") self._timeout_timestamp = durationMs + self._batch_processing_time_ms - # TODO(SPARK-XXXXX): Implement additionalDuration parameter. + # TODO(SPARK-40437): Implement additionalDuration parameter. 
def setTimeoutTimestamp(self, timestampMs: int) -> None: if self._timeout_conf != GroupStateTimeout.EventTimeTimeout: raise RuntimeError( From 59862ca93e3e4370324ea1f620067f1237a22786 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 15 Sep 2022 12:24:57 +0900 Subject: [PATCH 6/8] fix --- python/pyspark/sql/streaming/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py index b4f2be2256eee..c548182e004fa 100644 --- a/python/pyspark/sql/streaming/state.py +++ b/python/pyspark/sql/streaming/state.py @@ -131,7 +131,7 @@ def setTimeoutDuration(self, durationMs: int) -> None: raise ValueError("Timeout duration must be positive") self._timeout_timestamp = durationMs + self._batch_processing_time_ms - # TODO(SPARK-40437): Implement additionalDuration parameter. + # TODO(SPARK-40438): Implement additionalDuration parameter. def setTimeoutTimestamp(self, timestampMs: int) -> None: if self._timeout_conf != GroupStateTimeout.EventTimeTimeout: raise RuntimeError( From 9c97c6bf1c89420db1d60ae4bae56521bc315c50 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 15 Sep 2022 19:14:05 +0900 Subject: [PATCH 7/8] trigger From dd967831ff36142f4013ed910a068c8ad47b330e Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 15 Sep 2022 22:15:11 +0900 Subject: [PATCH 8/8] fix lint --- python/pyspark/sql/streaming/state.py | 38 +++++++++++++-------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py index c548182e004fa..842eff3223308 100644 --- a/python/pyspark/sql/streaming/state.py +++ b/python/pyspark/sql/streaming/state.py @@ -33,22 +33,22 @@ class GroupStateImpl: NO_TIMESTAMP: int = -1 def __init__( - self, - # JVM Constructor - optionalValue: Row, - batchProcessingTimeMs: int, - eventTimeWatermarkMs: int, - timeoutConf: str, - hasTimedOut: bool, - watermarkPresent: bool, - # JVM 
internal state. - defined: bool, - updated: bool, - removed: bool, - timeoutTimestamp: int, - # Python internal state. - keyAsUnsafe: bytes, - valueSchema: StructType, + self, + # JVM Constructor + optionalValue: Row, + batchProcessingTimeMs: int, + eventTimeWatermarkMs: int, + timeoutConf: str, + hasTimedOut: bool, + watermarkPresent: bool, + # JVM internal state. + defined: bool, + updated: bool, + removed: bool, + timeoutTimestamp: int, + # Python internal state. + keyAsUnsafe: bytes, + valueSchema: StructType, ) -> None: self._keyAsUnsafe = keyAsUnsafe self._value = optionalValue @@ -146,8 +146,8 @@ def setTimeoutTimestamp(self, timestampMs: int) -> None: raise ValueError("Timeout timestamp must be positive") if ( - self._event_time_watermark_ms != GroupStateImpl.NO_TIMESTAMP - and timestampMs < self._event_time_watermark_ms + self._event_time_watermark_ms != GroupStateImpl.NO_TIMESTAMP + and timestampMs < self._event_time_watermark_ms ): raise ValueError( "Timeout timestamp (%s) cannot be earlier than the " @@ -169,7 +169,7 @@ def getCurrentProcessingTimeMs(self) -> int: def __str__(self) -> str: if self.exists: - return "GroupState(%s)" % (self.get, ) + return "GroupState(%s)" % (self.get,) else: return "GroupState()"