diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 4e99c8e3c6b10..6e821321caede 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2482,18 +2482,20 @@ def test_toDF_with_schema_string(self):
         self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>")
         self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)])
 
-        # field names can differ.
-        df = rdd.toDF(" a: int, b: string ")
-        self.assertEqual(df.schema.simpleString(), "struct<a:int,b:string>")
-        self.assertEqual(df.collect(), data)
+        # field order can differ since the Rows were created with named arguments.
+        df = rdd.toDF(" value: string, key: int ")
+        self.assertEqual(df.schema.simpleString(), "struct<value:string,key:int>")
+        self.assertEqual(df.select("key", "value").collect(), data)
 
-        # number of fields must match.
-        self.assertRaisesRegexp(Exception, "Length of object",
-                                lambda: rdd.toDF("key: int").collect())
+        # each schema field must also be a field in the row.
+        with QuietTest(self.sc):
+            self.assertRaisesRegexp(Exception, "ValueError: foo",
+                                    lambda: rdd.toDF("foo: int").collect())
 
         # field types mismatch will cause exception at runtime.
-        self.assertRaisesRegexp(Exception, "FloatType can not accept",
-                                lambda: rdd.toDF("key: float, value: string").collect())
+        with QuietTest(self.sc):
+            self.assertRaisesRegexp(Exception, "FloatType can not accept",
+                                    lambda: rdd.toDF("key: float, value: string").collect())
 
         # flat schema values will be wrapped into row.
         df = rdd.map(lambda row: row.key).toDF("int")
@@ -2505,6 +2507,21 @@ def test_toDF_with_schema_string(self):
         self.assertEqual(df.schema.simpleString(), "struct<value:int>")
         self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
 
+    def test_toDF_with_positional_Row_class(self):
+        TestRow = Row("b", "a")
+        data = [TestRow(i, str(i)) for i in range(10)]
+        rdd = self.sc.parallelize(data, 2)
+
+        # field names can differ as long as the types are in the expected positions.
+        df = rdd.toDF(" key: int, value: string ")
+        self.assertEqual(df.schema.simpleString(), "struct<key:int,value:string>")
+        self.assertEqual(df.collect(), data)
+
+        # number of fields must match.
+        with QuietTest(self.sc):
+            self.assertRaisesRegexp(Exception, "Length of object",
+                                    lambda: rdd.toDF("key: int").collect())
+
     def test_join_without_on(self):
         df1 = self.spark.range(1).toDF("a")
         df2 = self.spark.range(1).toDF("b")
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 1f6534836d64a..d4b8005ea4e59 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1392,9 +1392,11 @@ def _create_row_inbound_converter(dataType):
     return lambda *a: dataType.fromInternal(a)
 
 
-def _create_row(fields, values):
+def _create_row(fields, values, from_dict=False):
     row = Row(*values)
     row.__fields__ = fields
+    if from_dict:
+        row.__from_dict__ = True
     return row
 
 
@@ -1445,15 +1447,15 @@ def __new__(self, *args, **kwargs):
             raise ValueError("Can not use both args "
                              "and kwargs to create Row")
         if kwargs:
-            # create row objects
+            # create a row object from named arguments; order is not guaranteed, so sort by name
            names = sorted(kwargs.keys())
             row = tuple.__new__(self, [kwargs[n] for n in names])
             row.__fields__ = names
-            row.__from_dict__ = True
+            row.__from_dict__ = True  # Row elements will be accessed by field name, not position
             return row
         else:
-            # create row class or objects
+            # create a row class for generating objects, or a tuple-like row object
             return tuple.__new__(self, args)
 
     def asDict(self, recursive=False):
@@ -1532,7 +1534,10 @@ def __setattr__(self, key, value):
     def __reduce__(self):
         """Returns a tuple so Python knows how to pickle Row."""
         if hasattr(self, "__fields__"):
-            return (_create_row, (self.__fields__, tuple(self)))
+            if hasattr(self, "__from_dict__"):
+                return (_create_row, (self.__fields__, tuple(self), True))
+            else:
+                return (_create_row, (self.__fields__, tuple(self)))
         else:
             return tuple.__reduce__(self)
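
Illustrative sketch (not part of the patch): the Row behavior the new toDF assertions rely on.
Rows built with keyword arguments sort their fields by name, so a schema string passed to
toDF() can list those fields in a different order and they are matched by name rather than by
position. This assumes an active SparkSession bound to a variable named `spark`; the data
mirrors the test above.

    from pyspark.sql import Row

    # keyword arguments are sorted alphabetically, regardless of the call order
    r = Row(value="0", key=0)
    assert r.__fields__ == ["key", "value"]

    data = [Row(key=i, value=str(i)) for i in range(10)]
    rdd = spark.sparkContext.parallelize(data, 2)

    # schema order differs from the Row field order; fields are matched by name
    df = rdd.toDF(" value: string, key: int ")
    assert df.schema.simpleString() == "struct<value:string,key:int>"
    assert df.select("key", "value").collect() == data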
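
A second sketch of what the __reduce__ change preserves: a Row created from keyword arguments
carries a __from_dict__ marker, and the extended _create_row(fields, values, from_dict=False)
signature lets that marker survive pickling, so code that later converts the row can still
tell that its values should be looked up by field name rather than by position. Plain Python,
assuming the patched pyspark.sql.types is importable.

    import pickle
    from pyspark.sql import Row

    before = Row(key=1, value="1")   # created from kwargs, so __from_dict__ is set
    after = pickle.loads(pickle.dumps(before))

    assert after == before
    assert after.__fields__ == before.__fields__
    # the marker is re-applied by _create_row(..., from_dict=True) during unpickling
    assert hasattr(after, "__from_dict__")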