From 1de047849b4f3cd34e392be72070e5729b71081b Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Fri, 3 May 2019 10:23:39 +0900 Subject: [PATCH 1/2] Use Python's default protocol instead of highest protocol --- python/pyspark/serializers.py | 5 ++++- python/pyspark/sql/tests/test_serde.py | 9 ++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 6058e94d471e9..69e99a8709b73 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -435,7 +435,10 @@ def _batched(self, iterator): yield items def dump_stream(self, iterator, stream): - self.serializer.dump_stream(self._batched(iterator), stream) + a = list(self._batched(iterator)) + for i in a: + print(i) + self.serializer.dump_stream(a, stream) def load_stream(self, stream): return chain.from_iterable(self._load_stream_without_unbatching(stream)) diff --git a/python/pyspark/sql/tests/test_serde.py b/python/pyspark/sql/tests/test_serde.py index 8707f46b6a25a..673042e955db8 100644 --- a/python/pyspark/sql/tests/test_serde.py +++ b/python/pyspark/sql/tests/test_serde.py @@ -19,8 +19,9 @@ import shutil import tempfile import time +import unittest -from pyspark.sql import Row +from pyspark.sql import Row, SparkSession from pyspark.sql.functions import lit from pyspark.sql.types import * from pyspark.testing.sqlutils import ReusedSQLTestCase, UTCOffsetTimezone @@ -126,6 +127,12 @@ def test_BinaryType_serialization(self): df = self.spark.createDataFrame(data, schema=schema) df.collect() + def test_int_array_serialization(self): + # Note that this test seems dependent on parallelism. 
+ data = self.spark.sparkContext.parallelize([[1, 2, 3, 4]] * 100, numSlices=12) + df = self.spark.createDataFrame(data, "array<integer>") + self.assertEqual(len(list(filter(lambda r: None in r.value, df.collect()))), 0) + if __name__ == "__main__": import unittest From 3ab6eb27614a888295b82b1c3e3b0b914f9f99a3 Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Fri, 3 May 2019 12:08:02 +0900 Subject: [PATCH 2/2] Revert test only codes --- python/pyspark/serializers.py | 8 +++----- python/pyspark/sql/tests/test_serde.py | 3 +-- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 69e99a8709b73..531108738f6c9 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -62,11 +62,12 @@ if sys.version < '3': import cPickle as pickle from itertools import izip as zip, imap as map + pickle_protocol = 2 else: import pickle basestring = unicode = str xrange = range -pickle_protocol = pickle.HIGHEST_PROTOCOL + pickle_protocol = 3 from pyspark import cloudpickle from pyspark.util import _exception_message @@ -435,10 +436,7 @@ def _batched(self, iterator): yield items def dump_stream(self, iterator, stream): - a = list(self._batched(iterator)) - for i in a: - print(i) - self.serializer.dump_stream(a, stream) + self.serializer.dump_stream(self._batched(iterator), stream) def load_stream(self, stream): return chain.from_iterable(self._load_stream_without_unbatching(stream)) diff --git a/python/pyspark/sql/tests/test_serde.py b/python/pyspark/sql/tests/test_serde.py index 673042e955db8..1c18e930eb91d 100644 --- a/python/pyspark/sql/tests/test_serde.py +++ b/python/pyspark/sql/tests/test_serde.py @@ -19,9 +19,8 @@ import shutil import tempfile import time -import unittest -from pyspark.sql import Row, SparkSession +from pyspark.sql import Row from pyspark.sql.functions import lit from pyspark.sql.types import * from pyspark.testing.sqlutils import ReusedSQLTestCase, UTCOffsetTimezone