From 64fcc12d1f0519a20b2384dffae1e4f660c43380 Mon Sep 17 00:00:00 2001 From: David Vogelbacher Date: Tue, 21 May 2019 10:09:23 -0400 Subject: [PATCH 1/5] write failing test --- python/pyspark/sql/tests/test_arrow.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index f5b5ad9cdf214..38df2015d914f 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -191,6 +191,18 @@ def test_no_partition_frame(self): self.assertEqual(pdf.columns[0], "field1") self.assertTrue(pdf.empty) + def test_propagates_spark_exception(self): + df = self.spark.range(3).toDF("i") + from pyspark.sql.functions import udf + + def raise_exception(): + raise Exception("My error") + exception_udf = udf(raise_exception, IntegerType()) + df = df.withColumn("error", exception_udf()) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'My error'): + df.toPandas() + def _createDataFrame_toggle(self, pdf, schema=None): with self.sql_conf({"spark.sql.execution.arrow.enabled": False}): df_no_arrow = self.spark.createDataFrame(pdf, schema=schema) From 3b4fcbe22f07df759fed5c55a5a6dfb7097982e6 Mon Sep 17 00:00:00 2001 From: David Vogelbacher Date: Tue, 21 May 2019 10:41:49 -0400 Subject: [PATCH 2/5] write success and optional error message --- python/pyspark/serializers.py | 20 +++++--- python/pyspark/sql/tests/test_arrow.py | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 46 ++++++++++++------- 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 6058e94d471e9..eb8672e3e09b9 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -206,13 +206,19 @@ def load_stream(self, stream): for batch in self.serializer.load_stream(stream): yield batch - # load the batch order indices - num = read_int(stream) - batch_order = [] - for i in xrange(num): - index = read_int(stream) - batch_order.append(index) - yield batch_order + # check success + success = read_bool(stream) + if success: + # load the batch order indices + num = read_int(stream) + batch_order = [] + for i in xrange(num): + index = read_int(stream) + batch_order.append(index) + yield batch_order + else: + error_msg = UTF8Deserializer().loads(stream) + raise RuntimeError("An error occurred while collecting: {}".format(error_msg)) def __repr__(self): return "ArrowCollectSerializer(%s)" % self.serializer diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 38df2015d914f..2c622e43bddca 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -200,7 +200,7 @@ def raise_exception(): exception_udf = udf(raise_exception, IntegerType()) df = df.withColumn("error", exception_udf()) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'My error'): + with self.assertRaisesRegexp(RuntimeError, 'My error'): df.toPandas() def _createDataFrame_toggle(self, pdf, schema=None): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d5f1edbfbcb71..90faa9c6fec21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -18,15 +18,14 @@ package org.apache.spark.sql import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream} +import 
java.nio.charset.StandardCharsets import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.language.implicitConversions import scala.util.control.NonFatal - import org.apache.commons.lang3.StringUtils - -import org.apache.spark.TaskContext +import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ @@ -40,7 +39,7 @@ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection -import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} +import org.apache.spark.sql.catalyst.json.{JSONOptions, JacksonGenerator} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans._ @@ -3313,20 +3312,35 @@ class Dataset[T] private[sql]( } } - val arrowBatchRdd = toArrowBatchRdd(plan) - sparkSession.sparkContext.runJob( - arrowBatchRdd, - (it: Iterator[Array[Byte]]) => it.toArray, - handlePartitionBatches) + var sparkException: Option[SparkException] = Option.empty + try { + val arrowBatchRdd = toArrowBatchRdd(plan) + sparkSession.sparkContext.runJob( + arrowBatchRdd, + (it: Iterator[Array[Byte]]) => it.toArray, + handlePartitionBatches) + } catch { + case e: SparkException => + sparkException = Option.apply(e) + } - // After processing all partitions, end the stream and write batch order indices + // After processing all partitions, end the batch stream batchWriter.end() - out.writeInt(batchOrder.length) - // Sort by (index of partition, batch index in that partition) tuple to get the - // overall_batch_index from 0 to N-1 batches, which can be used to put the - // transferred batches in the correct order - batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) => - out.writeInt(overallBatchIndex) + sparkException match { + case Some(exception) => + // Signal failure and write error message + out.writeBoolean(false) + PythonRDD.writeUTF(exception.getMessage, out) + case None => + // Signal success and write batch order indices + out.writeBoolean(true) + out.writeInt(batchOrder.length) + // Sort by (index of partition, batch index in that partition) tuple to get the + // overall_batch_index from 0 to N-1 batches, which can be used to put the + // transferred batches in the correct order + batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) => + out.writeInt(overallBatchIndex) + } } } } From 4f57b7d8e6950990566b6de867cdc2039644b574 Mon Sep 17 00:00:00 2001 From: David Vogelbacher Date: Tue, 21 May 2019 11:07:42 -0400 Subject: [PATCH 3/5] fix import --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 90faa9c6fec21..86953c32aff85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream} -import java.nio.charset.StandardCharsets import 
scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.language.implicitConversions import scala.util.control.NonFatal + import org.apache.commons.lang3.StringUtils + import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable} import org.apache.spark.api.java.JavaRDD @@ -39,7 +40,7 @@ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection -import org.apache.spark.sql.catalyst.json.{JSONOptions, JacksonGenerator} +import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans._ From d9936d5084b71f4a8736d035fbe1fa3afade6966 Mon Sep 17 00:00:00 2001 From: David Vogelbacher Date: Wed, 29 May 2019 09:51:52 -0400 Subject: [PATCH 4/5] cr --- python/pyspark/serializers.py | 22 +++++++++---------- python/pyspark/sql/tests/test_arrow.py | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 9 ++++---- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index eb8672e3e09b9..e3446febf74ac 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -206,19 +206,17 @@ def load_stream(self, stream): for batch in self.serializer.load_stream(stream): yield batch - # check success - success = read_bool(stream) - if success: - # load the batch order indices - num = read_int(stream) - batch_order = [] - for i in xrange(num): - index = read_int(stream) - batch_order.append(index) - yield batch_order - else: + # load the batch order indices + num = read_int(stream) + if num == -1: error_msg = UTF8Deserializer().loads(stream) - raise RuntimeError("An error occurred while collecting: {}".format(error_msg)) + raise RuntimeError("An error occurred while calling " + "ArrowCollectSerializer.load_stream: {}".format(error_msg)) + batch_order = [] + for i in xrange(num): + index = read_int(stream) + batch_order.append(index) + yield batch_order def __repr__(self): return "ArrowCollectSerializer(%s)" % self.serializer diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 2c622e43bddca..5fa1acc346d31 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -23,6 +23,7 @@ import warnings from pyspark.sql import Row +from pyspark.sql.functions import udf from pyspark.sql.types import * from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message @@ -193,7 +194,6 @@ def test_no_partition_frame(self): def test_propagates_spark_exception(self): df = self.spark.range(3).toDF("i") - from pyspark.sql.functions import udf def raise_exception(): raise Exception("My error") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 86953c32aff85..b27481631aa1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3313,7 +3313,7 @@ class Dataset[T] private[sql]( } } - var sparkException: Option[SparkException] = 
Option.empty + var sparkException: Option[SparkException] = None try { val arrowBatchRdd = toArrowBatchRdd(plan) sparkSession.sparkContext.runJob( @@ -3322,7 +3322,7 @@ class Dataset[T] private[sql]( handlePartitionBatches) } catch { case e: SparkException => - sparkException = Option.apply(e) + sparkException = Some(e) } // After processing all partitions, end the batch stream @@ -3330,11 +3330,10 @@ class Dataset[T] private[sql]( sparkException match { case Some(exception) => // Signal failure and write error message - out.writeBoolean(false) + out.writeInt(-1) PythonRDD.writeUTF(exception.getMessage, out) case None => - // Signal success and write batch order indices - out.writeBoolean(true) + // Write batch order indices out.writeInt(batchOrder.length) // Sort by (index of partition, batch index in that partition) tuple to get the // overall_batch_index from 0 to N-1 batches, which can be used to put the From ccfeb9e408b3fd804c0f68308da7ca0adf3094b5 Mon Sep 17 00:00:00 2001 From: David Vogelbacher Date: Wed, 29 May 2019 10:37:40 -0400 Subject: [PATCH 5/5] better comment --- python/pyspark/serializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index e3446febf74ac..516ee7e7b3084 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -206,7 +206,7 @@ def load_stream(self, stream): for batch in self.serializer.load_stream(stream): yield batch - # load the batch order indices + # load the batch order indices or propagate any error that occurred in the JVM num = read_int(stream) if num == -1: error_msg = UTF8Deserializer().loads(stream)
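
The net wire protocol after PATCH 5/5: once all Arrow batches have been streamed, the JVM writes either the number of batch-order indices followed by the indices (success), or -1 followed by a length-prefixed UTF-8 error message (written with PythonRDD.writeUTF and read back with UTF8Deserializer), and ArrowCollectSerializer.load_stream reads the same framing on the Python side. Below is a minimal, self-contained Python sketch of that framing only, assuming an in-memory stream and hand-rolled read_int/write_int helpers in place of the ones in pyspark.serializers; the names write_footer, read_footer, and ERROR_SENTINEL are illustrative and do not appear in the PR.

    import io
    import struct

    ERROR_SENTINEL = -1  # illustrative constant; the PR writes the literal -1

    def write_int(stream, value):
        # 4-byte big-endian signed int, the framing read_int in pyspark.serializers expects
        stream.write(struct.pack("!i", value))

    def read_int(stream):
        return struct.unpack("!i", stream.read(4))[0]

    def write_footer(stream, batch_order=None, error=None):
        # Mirrors the JVM side after PATCH 5/5: on failure write -1 and a
        # length-prefixed UTF-8 message, otherwise write the index count and indices.
        if error is not None:
            write_int(stream, ERROR_SENTINEL)
            data = error.encode("utf-8")
            write_int(stream, len(data))
            stream.write(data)
        else:
            write_int(stream, len(batch_order))
            for index in batch_order:
                write_int(stream, index)

    def read_footer(stream):
        # Mirrors ArrowCollectSerializer.load_stream after the Arrow batches are consumed.
        num = read_int(stream)
        if num == ERROR_SENTINEL:
            length = read_int(stream)
            error_msg = stream.read(length).decode("utf-8")
            raise RuntimeError("An error occurred while calling "
                               "ArrowCollectSerializer.load_stream: {}".format(error_msg))
        return [read_int(stream) for _ in range(num)]

    # Success case: the batch-order indices come back as written.
    buf = io.BytesIO()
    write_footer(buf, batch_order=[2, 0, 1])
    buf.seek(0)
    assert read_footer(buf) == [2, 0, 1]

    # Failure case: the Python side raises instead of yielding indices.
    buf = io.BytesIO()
    write_footer(buf, error="My error")
    buf.seek(0)
    try:
        read_footer(buf)
    except RuntimeError as e:
        print(e)

Reusing the index count as the failure signal (PATCH 4/5), instead of the separate boolean introduced in PATCH 2/5, leaves the success path's wire format exactly as it was before the change: -1 can never be a valid count, so the failure case costs nothing extra on success and the Python reader only needs one additional branch.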