From 6a9f8aaff31e41ef398a11a600b6eb06e9bae9d6 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Fri, 29 Nov 2024 16:53:19 +0900 Subject: [PATCH 1/2] [SPARK-50357][PYTHON] Support Interrupt(Tag|All) APIs for PySpark --- python/pyspark/sql/session.py | 34 +++++++++++++------ .../connect/test_parity_job_cancellation.py | 22 ------------ .../sql/tests/test_connect_compatibility.py | 2 -- .../sql/tests/test_job_cancellation.py | 22 ++++++++++++ python/pyspark/sql/tests/test_session.py | 1 - 5 files changed, 46 insertions(+), 35 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index f3a1639fddafa..fc434cd16bfbd 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -2197,13 +2197,15 @@ def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None: messageParameters={"feature": "SparkSession.copyFromLocalToFs"}, ) - @remote_only def interruptAll(self) -> List[str]: """ Interrupt all operations of this session currently running on the connected server. .. versionadded:: 3.5.0 + .. versionchanged:: 4.0.0 + Supports Spark Classic. + Returns ------- list of str @@ -2213,18 +2215,25 @@ def interruptAll(self) -> List[str]: ----- There is still a possibility of operation finishing just as it is interrupted. """ - raise PySparkRuntimeError( - errorClass="ONLY_SUPPORTED_WITH_SPARK_CONNECT", - messageParameters={"feature": "SparkSession.interruptAll"}, - ) + java_list = self._jsparkSession.interruptAll() + python_list = list() + + # Use iterator to manually iterate through Java list + java_iterator = java_list.iterator() + while java_iterator.hasNext(): + python_list.append(str(java_iterator.next())) + + return python_list - @remote_only def interruptTag(self, tag: str) -> List[str]: """ Interrupt all operations of this session with the given operation tag. .. versionadded:: 3.5.0 + .. versionchanged:: 4.0.0 + Supports Spark Classic. + Returns ------- list of str @@ -2234,10 +2243,15 @@ def interruptTag(self, tag: str) -> List[str]: ----- There is still a possibility of operation finishing just as it is interrupted. """ - raise PySparkRuntimeError( - errorClass="ONLY_SUPPORTED_WITH_SPARK_CONNECT", - messageParameters={"feature": "SparkSession.interruptTag"}, - ) + java_list = self._jsparkSession.interruptTag(tag) + python_list = list() + + # Use iterator to manually iterate through Java list + java_iterator = java_list.iterator() + while java_iterator.hasNext(): + python_list.append(str(java_iterator.next())) + + return python_list @remote_only def interruptOperation(self, op_id: str) -> List[str]: diff --git a/python/pyspark/sql/tests/connect/test_parity_job_cancellation.py b/python/pyspark/sql/tests/connect/test_parity_job_cancellation.py index c5184b04d6aa5..ddb4554afa55a 100644 --- a/python/pyspark/sql/tests/connect/test_parity_job_cancellation.py +++ b/python/pyspark/sql/tests/connect/test_parity_job_cancellation.py @@ -32,28 +32,6 @@ def func(target): create_thread=lambda target, session: threading.Thread(target=func, args=(target,)) ) - def test_interrupt_tag(self): - thread_ids = range(4) - self.check_job_cancellation( - lambda job_group: self.spark.addTag(job_group), - lambda job_group: self.spark.interruptTag(job_group), - thread_ids, - [i for i in thread_ids if i % 2 == 0], - [i for i in thread_ids if i % 2 != 0], - ) - self.spark.clearTags() - - def test_interrupt_all(self): - thread_ids = range(4) - self.check_job_cancellation( - lambda job_group: None, - lambda job_group: self.spark.interruptAll(), - thread_ids, - thread_ids, - [], - ) - self.spark.clearTags() - if __name__ == "__main__": import unittest diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py index ef83dc3834d0c..25b8be1f9ac7a 100644 --- a/python/pyspark/sql/tests/test_connect_compatibility.py +++ b/python/pyspark/sql/tests/test_connect_compatibility.py @@ -266,9 +266,7 @@ def test_spark_session_compatibility(self): "addArtifacts", "clearProgressHandlers", "copyFromLocalToFs", - "interruptAll", "interruptOperation", - "interruptTag", "newSession", "registerProgressHandler", "removeProgressHandler", diff --git a/python/pyspark/sql/tests/test_job_cancellation.py b/python/pyspark/sql/tests/test_job_cancellation.py index a046c9c01811b..3f30f78808892 100644 --- a/python/pyspark/sql/tests/test_job_cancellation.py +++ b/python/pyspark/sql/tests/test_job_cancellation.py @@ -166,6 +166,28 @@ def get_outer_local_prop(): self.assertEqual(first, {"a", "b"}) self.assertEqual(second, {"a", "b", "c"}) + def test_interrupt_tag(self): + thread_ids = range(4) + self.check_job_cancellation( + lambda job_group: self.spark.addTag(job_group), + lambda job_group: self.spark.interruptTag(job_group), + thread_ids, + [i for i in thread_ids if i % 2 == 0], + [i for i in thread_ids if i % 2 != 0], + ) + self.spark.clearTags() + + def test_interrupt_all(self): + thread_ids = range(4) + self.check_job_cancellation( + lambda job_group: None, + lambda job_group: self.spark.interruptAll(), + thread_ids, + thread_ids, + [], + ) + self.spark.clearTags() + class JobCancellationTests(JobCancellationTestsMixin, ReusedSQLTestCase): pass diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py index 3fbc0be943e45..a22fe777e3c9a 100644 --- a/python/pyspark/sql/tests/test_session.py +++ b/python/pyspark/sql/tests/test_session.py @@ -227,7 +227,6 @@ def test_unsupported_api(self): (lambda: session.client, "client"), (session.addArtifacts, "addArtifact(s)"), (lambda: session.copyFromLocalToFs("", ""), "copyFromLocalToFs"), - (lambda: session.interruptTag(""), "interruptTag"), (lambda: session.interruptOperation(""), "interruptOperation"), ] From ae0f059f69951a32f8e4fa988f2320f045120b60 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Mon, 6 Jan 2025 14:41:29 +0900 Subject: [PATCH 2/2] Update docs --- python/docs/source/reference/pyspark.sql/spark_session.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst b/python/docs/source/reference/pyspark.sql/spark_session.rst index 1677d3e8e0209..a35fccbcffe99 100644 --- a/python/docs/source/reference/pyspark.sql/spark_session.rst +++ b/python/docs/source/reference/pyspark.sql/spark_session.rst @@ -52,6 +52,8 @@ See also :class:`SparkSession`. SparkSession.dataSource SparkSession.getActiveSession SparkSession.getTags + SparkSession.interruptAll + SparkSession.interruptTag SparkSession.newSession SparkSession.profile SparkSession.removeTag @@ -86,8 +88,6 @@ Spark Connect Only SparkSession.clearProgressHandlers SparkSession.client SparkSession.copyFromLocalToFs - SparkSession.interruptAll SparkSession.interruptOperation - SparkSession.interruptTag SparkSession.registerProgressHandler SparkSession.removeProgressHandler