From 315b8de0fb3e7277b895b98769e52da7aaae32d6 Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Tue, 16 Jan 2018 11:19:12 -0800
Subject: [PATCH 1/3] added __from_dict__ to Row pickling, fixed existing
 tests and added new test

---
 python/pyspark/sql/tests.py | 35 ++++++++++++++++++++++++++---------
 python/pyspark/sql/types.py |  7 +++++--
 2 files changed, 31 insertions(+), 11 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 80a94a91a87b3..8de708336a60c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2306,18 +2306,20 @@ def test_toDF_with_schema_string(self):
         self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>")
         self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)])
 
-        # field names can differ.
-        df = rdd.toDF(" a: int, b: string ")
-        self.assertEqual(df.schema.simpleString(), "struct<a:int,b:string>")
-        self.assertEqual(df.collect(), data)
+        # field order can differ since Rows have kwargs.
+        df = rdd.toDF(" value: string, key: int ")
+        self.assertEqual(df.schema.simpleString(), "struct<value:string,key:int>")
+        self.assertEqual(df.select("key", "value").collect(), data)
 
-        # number of fields must match.
-        self.assertRaisesRegexp(Exception, "Length of object",
-                                lambda: rdd.toDF("key: int").collect())
+        # schema field must be a kwarg of the row.
+        with QuietTest(self.sc):
+            self.assertRaisesRegexp(Exception, "ValueError: foo",
+                                    lambda: rdd.toDF("foo: int").collect())
 
         # field types mismatch will cause exception at runtime.
-        self.assertRaisesRegexp(Exception, "FloatType can not accept",
-                                lambda: rdd.toDF("key: float, value: string").collect())
+        with QuietTest(self.sc):
+            self.assertRaisesRegexp(Exception, "FloatType can not accept",
+                                    lambda: rdd.toDF("key: float, value: string").collect())
 
         # flat schema values will be wrapped into row.
         df = rdd.map(lambda row: row.key).toDF("int")
@@ -2329,6 +2331,21 @@ def test_toDF_with_schema_string(self):
         self.assertEqual(df.schema.simpleString(), "struct<key:int>")
         self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
 
+    def test_toDF_with_positional_Row_class(self):
+        TestRow = Row("b", "a")
+        data = [TestRow(i, str(i)) for i in range(100)]
+        rdd = self.sc.parallelize(data, 5)
+
+        # field names can differ as long as types are in expected position.
+        df = rdd.toDF(" key: int, value: string ")
+        self.assertEqual(df.schema.simpleString(), "struct<key:int,value:string>")
+        self.assertEqual(df.collect(), data)
+
+        # number of fields must match.
+        with QuietTest(self.sc):
+            self.assertRaisesRegexp(Exception, "Length of object",
+                                    lambda: rdd.toDF("key: int").collect())
+
     def test_join_without_on(self):
         df1 = self.spark.range(1).toDF("a")
         df2 = self.spark.range(1).toDF("b")
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 0dc5823f72a3c..935c86f12f25c 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1431,9 +1431,11 @@ def _create_row_inbound_converter(dataType):
     return lambda *a: dataType.fromInternal(a)
 
 
-def _create_row(fields, values):
+def _create_row(fields, values, from_dict=False):
     row = Row(*values)
     row.__fields__ = fields
+    if from_dict:
+        row.__from_dict__ = True
     return row
 
 
@@ -1571,7 +1573,8 @@ def __setattr__(self, key, value):
 
     def __reduce__(self):
         """Returns a tuple so Python knows how to pickle Row."""
         if hasattr(self, "__fields__"):
-            return (_create_row, (self.__fields__, tuple(self)))
+            from_dict = getattr(self, "__from_dict__", False)
+            return (_create_row, (self.__fields__, tuple(self), from_dict))
         else:
             return tuple.__reduce__(self)

From a7d339624d3ddf80af63fd3710fdc1e0742ecc6c Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Tue, 16 Jan 2018 11:35:13 -0800
Subject: [PATCH 2/3] shortened test size

---
 python/pyspark/sql/tests.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 8de708336a60c..5ea6e5a797b92 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2333,8 +2333,8 @@ def test_toDF_with_schema_string(self):
 
     def test_toDF_with_positional_Row_class(self):
         TestRow = Row("b", "a")
-        data = [TestRow(i, str(i)) for i in range(100)]
-        rdd = self.sc.parallelize(data, 5)
+        data = [TestRow(i, str(i)) for i in range(10)]
+        rdd = self.sc.parallelize(data, 2)
 
         # field names can differ as long as types are in expected position.
         df = rdd.toDF(" key: int, value: string ")

From 10bf2d094b29b4e8ef7a38693f3956f96c0e9f7e Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Wed, 18 Apr 2018 13:42:19 -0700
Subject: [PATCH 3/3] avoid serializing __from_dict__ if not set, improved
 comments

---
 python/pyspark/sql/tests.py |  4 ++--
 python/pyspark/sql/types.py | 12 +++++++-----
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 0623743964239..6e821321caede 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2482,12 +2482,12 @@ def test_toDF_with_schema_string(self):
         self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>")
         self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)])
 
-        # field order can differ since Rows have kwargs.
+        # field order can differ since Rows are created with named arguments.
         df = rdd.toDF(" value: string, key: int ")
         self.assertEqual(df.schema.simpleString(), "struct<value:string,key:int>")
         self.assertEqual(df.select("key", "value").collect(), data)
 
-        # schema field must be a kwarg of the row.
+        # schema field must also be a field in the row.
         with QuietTest(self.sc):
             self.assertRaisesRegexp(Exception, "ValueError: foo",
                                     lambda: rdd.toDF("foo: int").collect())
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 1a0eb260288ad..d4b8005ea4e59 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1447,15 +1447,15 @@ def __new__(self, *args, **kwargs):
             raise ValueError("Can not use both args "
                              "and kwargs to create Row")
         if kwargs:
-            # create row objects
+            # create row object from named arguments; kwarg order is not guaranteed, so fields are sorted by name
            names = sorted(kwargs.keys())
             row = tuple.__new__(self, [kwargs[n] for n in names])
             row.__fields__ = names
-            row.__from_dict__ = True
+            row.__from_dict__ = True  # Row elements will be accessed by field name, not position
             return row
         else:
-            # create row class or objects
+            # create a row class for generating objects, or a tuple-like row object
             return tuple.__new__(self, args)
 
     def asDict(self, recursive=False):
@@ -1534,8 +1534,10 @@ def __setattr__(self, key, value):
 
     def __reduce__(self):
         """Returns a tuple so Python knows how to pickle Row."""
         if hasattr(self, "__fields__"):
-            from_dict = getattr(self, "__from_dict__", False)
-            return (_create_row, (self.__fields__, tuple(self), from_dict))
+            if hasattr(self, "__from_dict__"):
+                return (_create_row, (self.__fields__, tuple(self), True))
+            else:
+                return (_create_row, (self.__fields__, tuple(self)))
         else:
             return tuple.__reduce__(self)
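
Note: a minimal sketch of the behavior these patches implement, run against a
PySpark build with the patches applied. This is an illustration only, not part
of the patch series:

    import pickle
    from pyspark.sql import Row

    # Rows created from keyword arguments sort their fields by name, so the
    # stored order is not the order the arguments were passed in.
    row = Row(value="1", key=1)
    print(row.__fields__)                 # ['key', 'value']
    print(hasattr(row, "__from_dict__"))  # True: match values to schema by name

    # __reduce__ now round-trips through _create_row(fields, values, True),
    # so the __from_dict__ flag survives pickling to the executors.
    restored = pickle.loads(pickle.dumps(row))
    print(hasattr(restored, "__from_dict__"))  # True (was lost before the fix)

    # A Row class created positionally never sets __from_dict__; its fields
    # keep their declared order and conversion stays positional.
    TestRow = Row("b", "a")
    copied = pickle.loads(pickle.dumps(TestRow(1, "x")))
    print(hasattr(copied, "__from_dict__"))    # False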