From 96db384f9eba821cad803ee80e3e00e1dea50085 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Fri, 25 Jul 2014 23:59:39 -0700
Subject: [PATCH 1/4] support datetime type for SchemaRDD

---
 .../apache/spark/api/python/PythonRDD.scala |  4 +-
 python/pyspark/sql.py                       |  9 +++--
 .../org/apache/spark/sql/SQLContext.scala   | 40 ++++++++++++++++++-
 .../org/apache/spark/sql/SchemaRDD.scala    |  5 +++
 4 files changed, 50 insertions(+), 8 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index d6b0988641a97..1ec9051f7a494 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -550,11 +550,11 @@ private[spark] object PythonRDD extends Logging {
   def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
     pyRDD.rdd.mapPartitions { iter =>
       val unpickle = new Unpickler
-      // TODO: Figure out why flatMap is necessay for pyspark
       iter.flatMap { row =>
         unpickle.loads(row) match {
+          // in case the objects are pickled in batch mode
          case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
-          // Incase the partition doesn't have a collection
+          // not in batch mode
           case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
         }
       }
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index cb83e89176823..63c53c6f607d6 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -47,12 +47,13 @@ def __init__(self, sparkContext, sqlContext=None):
         ...
         ValueError:...

-        >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
-        ... "boolean" : True}])
+        >>> from datetime import datetime
+        >>> allTypes = sc.parallelize([{"int": 1, "string": "string", "double": 1.0, "long": 1L,
+        ... "boolean": True, "time": datetime(2010, 1, 1, 1, 1, 1)}])
         >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
-        ... x.boolean))
+        ... x.boolean, x.time))
         >>> srdd.collect()[0]
-        (1, u'string', 1.0, 1, True)
+        (1, u'string', 1.0, 1, True, datetime.datetime(2010, 1, 1, 1, 1, 1))
         """
         self._sc = sparkContext
         self._jsc = self._sc._jsc
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 4abd89955bd27..aae757da5bf6b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -357,16 +357,52 @@ class SQLContext(@transient val sparkContext: SparkContext)
       case c: java.util.Map[_, _] =>
         val (key, value) = c.head
         MapType(typeFor(key), typeFor(value))
+      case c: java.util.Calendar => TimestampType
       case c if c.getClass.isArray =>
         val elem = c.asInstanceOf[Array[_]].head
         ArrayType(typeFor(elem))
       case c => throw new Exception(s"Object of type $c cannot be used")
     }
-    val schema = rdd.first().map { case (fieldName, obj) =>
+    val firstRow = rdd.first()
+    val schema = firstRow.map { case (fieldName, obj) =>
       AttributeReference(fieldName, typeFor(obj), true)()
     }.toSeq

-    val rowRdd = rdd.mapPartitions { iter =>
+    def needTransform(obj: Any): Boolean = obj match {
+      case c: java.util.List[_] => c.exists(needTransform)
+      case c: java.util.Set[_] => c.exists(needTransform)
+      case c: java.util.Map[_, _] => c.exists {
+        case (key, value) => needTransform(key) || needTransform(value)
+      }
+      case c if c.getClass.isArray =>
+        c.asInstanceOf[Array[_]].exists(needTransform)
+      case c: java.util.Calendar => true
+      case c => false
+    }
+
+    def transform(obj: Any): Any = obj match {
+      case c: java.util.List[_] => c.map(transform)
+      case c: java.util.Set[_] => c.map(transform)
+      case c: java.util.Map[_, _] => c.map {
+        case (key, value) => (transform(key), transform(value))
+      }
+      case c if c.getClass.isArray =>
+        c.asInstanceOf[Array[_]].map(transform)
+      case c: java.util.Calendar =>
+        new java.sql.Timestamp(c.getTime().getTime())
+      case c => c
+    }
+
+    val need = firstRow.exists {case (key, value) => needTransform(value)}
+    val transformed = if (need) {
+      rdd.mapPartitions { iter =>
+        iter.map {
+          m => m.map {case (key, value) => (key, transform(value))}
+        }
+      }
+    } else rdd
+
+    val rowRdd = transformed.mapPartitions { iter =>
       iter.map { map =>
         new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 31d27bb4f0571..23e026d2c99c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -395,6 +395,11 @@ class SchemaRDD(
                 arr.asInstanceOf[Array[Any]].map {
                   element => rowToMap(element.asInstanceOf[Row], struct)
                 }
+              case t: java.sql.Timestamp => {
+                val c = java.util.Calendar.getInstance()
+                c.setTimeInMillis(t.getTime())
+                c
+              }
               case other => other
             }
             map.put(attrName, arrayValues)
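For reference, the behaviour this patch adds can be exercised from PySpark roughly as follows. This is a sketch in the style of the doctests above, assuming a running SparkContext `sc` and a SQLContext `sqlCtx`; the field names are only illustrative:

    >>> from datetime import datetime
    >>> rdd = sc.parallelize([{"id": 1, "time": datetime(2010, 1, 1, 1, 1, 1)}])
    >>> srdd = sqlCtx.inferSchema(rdd)        # "time" is inferred as TimestampType
    >>> srdd.map(lambda x: x.time).collect()  # and round-trips back to datetime.datetime
    [datetime.datetime(2010, 1, 1, 1, 1, 1)]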
From 709d40d583b0dddcf5d0f471a268cf50bf655e59 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Sat, 26 Jul 2014 18:58:11 -0700
Subject: [PATCH 2/4] remove brackets

---
 sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 23e026d2c99c2..12f67622cd7e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -395,11 +395,10 @@ class SchemaRDD(
                 arr.asInstanceOf[Array[Any]].map {
                   element => rowToMap(element.asInstanceOf[Row], struct)
                 }
-              case t: java.sql.Timestamp => {
+              case t: java.sql.Timestamp =>
                 val c = java.util.Calendar.getInstance()
                 c.setTimeInMillis(t.getTime())
                 c
-              }
               case other => other
             }
             map.put(attrName, arrayValues)

From c9d607a2d6979f95a3a67f3c60e249243c791e5f Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 28 Jul 2014 00:26:29 -0700
Subject: [PATCH 3/4] convert data types at runtime

Convert at runtime:

  java.util.{List,Set} => Seq
  java.util.Map        => Map

However, a Seq cannot be converted back into a java.util.Set, so set(),
tuple() and array() cannot be handled gracefully (they do not come back
as their original Python types). Items in an ArrayType can only be
accessed by position, and position is not defined for set(). Do we still
want to support set()/tuple()/array()?
---
 python/pyspark/sql.py                       | 19 +++++++--------
 .../org/apache/spark/sql/SQLContext.scala   | 23 +++++++++----------
 .../org/apache/spark/sql/SchemaRDD.scala    | 17 ++++++++------
 3 files changed, 31 insertions(+), 28 deletions(-)

diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 63c53c6f607d6..5662f500e74fa 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -49,11 +49,12 @@ def __init__(self, sparkContext, sqlContext=None):

         >>> from datetime import datetime
         >>> allTypes = sc.parallelize([{"int": 1, "string": "string", "double": 1.0, "long": 1L,
-        ... "boolean": True, "time": datetime(2010, 1, 1, 1, 1, 1)}])
+        ... "boolean": True, "time": datetime(2010, 1, 1, 1, 1, 1), "dict": {"a": 1},
+        ... "list": [1, 2, 3]}])
         >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
-        ... x.boolean, x.time))
+        ... x.boolean, x.time, x.dict["a"], x.list))
         >>> srdd.collect()[0]
-        (1, u'string', 1.0, 1, True, datetime.datetime(2010, 1, 1, 1, 1, 1))
+        (1, u'string', 1.0, 1, True, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, [1, 2, 3])
         """
         self._sc = sparkContext
         self._jsc = self._sc._jsc
@@ -89,13 +90,13 @@ def inferSchema(self, rdd):

         >>> from array import array
         >>> srdd = sqlCtx.inferSchema(nestedRdd1)
-        >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
-        ... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
+        >>> srdd.collect() == [{"f1" : [1, 2], "f2" : {"row1" : 1.0}},
+        ... {"f1" : [2, 3], "f2" : {"row2" : 2.0}}]
         True

         >>> srdd = sqlCtx.inferSchema(nestedRdd2)
-        >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
-        ... {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
+        >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : [1, 2], "f3" : [1, 2]},
+        ... {"f1" : [[2, 3], [3, 4]], "f2" : [2, 3], "f3" : [2, 3]}]
         True
         """
         if (rdd.__class__ is SchemaRDD):
@@ -510,8 +511,8 @@ def _test():
         {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}},
         {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}])
     globs['nestedRdd2'] = sc.parallelize([
-        {"f1": [[1, 2], [2, 3]], "f2": set([1, 2]), "f3": (1, 2)},
-        {"f1": [[2, 3], [3, 4]], "f2": set([2, 3]), "f3": (2, 3)}])
+        {"f1": [[1, 2], [2, 3]], "f2": set([1, 2]), "f3": [1, 2]},
+        {"f1": [[2, 3], [3, 4]], "f2": set([2, 3]), "f3": [2, 3]}])
     (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
     if failure_count:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index aae757da5bf6b..5b95e8bea6442 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -369,25 +369,24 @@ class SQLContext(@transient val sparkContext: SparkContext)
     }.toSeq

     def needTransform(obj: Any): Boolean = obj match {
-      case c: java.util.List[_] => c.exists(needTransform)
-      case c: java.util.Set[_] => c.exists(needTransform)
-      case c: java.util.Map[_, _] => c.exists {
-        case (key, value) => needTransform(key) || needTransform(value)
-      }
-      case c if c.getClass.isArray =>
-        c.asInstanceOf[Array[_]].exists(needTransform)
+      case c: java.util.List[_] => true
+      case c: java.util.Set[_] => true
+      case c: java.util.Map[_, _] => true
+      case c if c.getClass.isArray => true
       case c: java.util.Calendar => true
       case c => false
     }

+    // convert JList, JSet into Seq, convert JMap into Map
+    // convert Calendar into Timestamp
     def transform(obj: Any): Any = obj match {
-      case c: java.util.List[_] => c.map(transform)
-      case c: java.util.Set[_] => c.map(transform)
+      case c: java.util.List[_] => c.map(transform).toSeq
+      case c: java.util.Set[_] => c.map(transform).toSet.toSeq
       case c: java.util.Map[_, _] => c.map {
-        case (key, value) => (transform(key), transform(value))
-      }
+        case (key, value) => (key, transform(value))
+      }.toMap
       case c if c.getClass.isArray =>
-        c.asInstanceOf[Array[_]].map(transform)
+        c.asInstanceOf[Array[_]].map(transform).toSeq
       case c: java.util.Calendar =>
         new java.sql.Timestamp(c.getTime().getTime())
       case c => c
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 12f67622cd7e3..e5e9514cf6b52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.catalyst.types.{ArrayType, BooleanType, StructType}
+import org.apache.spark.sql.catalyst.types.{ArrayType, BooleanType, StructType, MapType}
 import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
 import org.apache.spark.api.java.JavaRDD

@@ -388,20 +388,21 @@ class SchemaRDD(
               case seq: Seq[Any] =>
                 seq.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava
               case list: JList[_] =>
-                list.map(element => rowToMap(element.asInstanceOf[Row], struct))
+                list.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava
               case set: JSet[_] =>
-                set.map(element => rowToMap(element.asInstanceOf[Row], struct))
+                set.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava
               case arr if arr != null && arr.getClass.isArray =>
                 arr.asInstanceOf[Array[Any]].map {
                   element => rowToMap(element.asInstanceOf[Row], struct)
                 }
-              case t: java.sql.Timestamp =>
-                val c = java.util.Calendar.getInstance()
-                c.setTimeInMillis(t.getTime())
-                c
              case other => other
             }
             map.put(attrName, arrayValues)
+          case m @ MapType(_, struct: StructType) =>
+            val nm = obj.asInstanceOf[Map[_,_]].map {
+              case (k, v) => (k, rowToMap(v.asInstanceOf[Row], struct))
+            }.asJava
+            map.put(attrName, nm)
           case array: ArrayType => {
             val arrayValues = obj match {
               case seq: Seq[Any] => seq.asJava
@@ -409,6 +410,8 @@ class SchemaRDD(
             }
             map.put(attrName, arrayValues)
           }
+          case m: MapType => map.put(attrName, obj.asInstanceOf[Map[_,_]].asJava)
+          // Pyrolite can handle Timestamp
           case other => map.put(attrName, obj)
         }
       }
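The limitation called out in the PATCH 3/4 commit message shows up directly in a PySpark session. A rough sketch, with the same assumptions as the doctests (`sc` and `sqlCtx` already exist) and mirroring the updated nestedRdd2 test: Python sets and tuples are inferred as ArrayType and come back as plain lists rather than their original container types:

    >>> rdd = sc.parallelize([{"f2": set([1, 2]), "f3": (1, 2)}])
    >>> srdd = sqlCtx.inferSchema(rdd)
    >>> srdd.collect() == [{"f2": [1, 2], "f3": [1, 2]}]   # set and tuple degrade to list
    True

The next patch in the series answers the open question by dropping the set tests rather than trying to restore the original container types.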
From f0599b0f58a0e26064a62747158dac621ed7ede7 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 28 Jul 2014 12:00:51 -0700
Subject: [PATCH 4/4] remove tests for sets and tuple in sql, fix list of list

---
 python/pyspark/sql.py                       |  8 +--
 .../org/apache/spark/sql/SQLContext.scala   |  9 ++--
 .../org/apache/spark/sql/SchemaRDD.scala    | 53 ++++++-------------
 3 files changed, 25 insertions(+), 45 deletions(-)

diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 5662f500e74fa..a6b3277db3266 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -95,8 +95,8 @@ def inferSchema(self, rdd):
         True

         >>> srdd = sqlCtx.inferSchema(nestedRdd2)
-        >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : [1, 2], "f3" : [1, 2]},
-        ... {"f1" : [[2, 3], [3, 4]], "f2" : [2, 3], "f3" : [2, 3]}]
+        >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : [1, 2]},
+        ... {"f1" : [[2, 3], [3, 4]], "f2" : [2, 3]}]
         True
         """
         if (rdd.__class__ is SchemaRDD):
@@ -511,8 +511,8 @@ def _test():
         {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}},
         {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}])
     globs['nestedRdd2'] = sc.parallelize([
-        {"f1": [[1, 2], [2, 3]], "f2": set([1, 2]), "f3": [1, 2]},
-        {"f1": [[2, 3], [3, 4]], "f2": set([2, 3]), "f3": [2, 3]}])
+        {"f1": [[1, 2], [2, 3]], "f2": [1, 2]},
+        {"f1": [[2, 3], [3, 4]], "f2": [2, 3]}])
     (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
     if failure_count:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 5b95e8bea6442..c178dad662532 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -352,12 +352,13 @@ class SQLContext(@transient val sparkContext: SparkContext)
       case c: java.lang.Long => LongType
       case c: java.lang.Double => DoubleType
       case c: java.lang.Boolean => BooleanType
+      case c: java.math.BigDecimal => DecimalType
+      case c: java.sql.Timestamp => TimestampType
+      case c: java.util.Calendar => TimestampType
       case c: java.util.List[_] => ArrayType(typeFor(c.head))
-      case c: java.util.Set[_] => ArrayType(typeFor(c.head))
       case c: java.util.Map[_, _] =>
         val (key, value) = c.head
         MapType(typeFor(key), typeFor(value))
-      case c: java.util.Calendar => TimestampType
       case c if c.getClass.isArray =>
         val elem = c.asInstanceOf[Array[_]].head
         ArrayType(typeFor(elem))
@@ -370,18 +371,16 @@ class SQLContext(@transient val sparkContext: SparkContext)

     def needTransform(obj: Any): Boolean = obj match {
       case c: java.util.List[_] => true
-      case c: java.util.Set[_] => true
       case c: java.util.Map[_, _] => true
       case c if c.getClass.isArray => true
       case c: java.util.Calendar => true
       case c => false
     }

-    // convert JList, JSet into Seq, convert JMap into Map
+    // convert JList, JArray into Seq, convert JMap into Map
     // convert Calendar into Timestamp
     def transform(obj: Any): Any = obj match {
       case c: java.util.List[_] => c.map(transform).toSeq
-      case c: java.util.Set[_] => c.map(transform).toSet.toSeq
       case c: java.util.Map[_, _] => c.map {
         case (key, value) => (key, transform(value))
       }.toMap
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index e5e9514cf6b52..019ff9d300a18 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.catalyst.types.{ArrayType, BooleanType, StructType, MapType}
+import org.apache.spark.sql.catalyst.types.{DataType, ArrayType, BooleanType, StructType, MapType}
 import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
 import org.apache.spark.api.java.JavaRDD

@@ -376,46 +376,27 @@ class SchemaRDD(
    * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
    */
  private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+    def toJava(obj: Any, dataType: DataType): Any = dataType match {
+      case struct: StructType => rowToMap(obj.asInstanceOf[Row], struct)
+      case array: ArrayType => obj match {
+        case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava
+        case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava
+        case arr if arr != null && arr.getClass.isArray =>
+          arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
+        case other => other
+      }
+      case mt: MapType => obj.asInstanceOf[Map[_, _]].map {
+        case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
+      }.asJava
+      // Pyrolite can handle Timestamp
+      case other => obj
+    }
     def rowToMap(row: Row, structType: StructType): JMap[String, Any] = {
       val fields = structType.fields.map(field => (field.name, field.dataType))
       val map: JMap[String, Any] = new java.util.HashMap
       row.zip(fields).foreach {
-        case (obj, (attrName, dataType)) =>
-          dataType match {
-            case struct: StructType => map.put(attrName, rowToMap(obj.asInstanceOf[Row], struct))
-            case array @ ArrayType(struct: StructType) =>
-              val arrayValues = obj match {
-                case seq: Seq[Any] =>
-                  seq.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava
-                case list: JList[_] =>
-                  list.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava
-                case set: JSet[_] =>
-                  set.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava
-                case arr if arr != null && arr.getClass.isArray =>
-                  arr.asInstanceOf[Array[Any]].map {
-                    element => rowToMap(element.asInstanceOf[Row], struct)
-                  }
-                case other => other
-              }
-              map.put(attrName, arrayValues)
-            case m @ MapType(_, struct: StructType) =>
-              val nm = obj.asInstanceOf[Map[_,_]].map {
-                case (k, v) => (k, rowToMap(v.asInstanceOf[Row], struct))
-              }.asJava
-              map.put(attrName, nm)
-            case array: ArrayType => {
-              val arrayValues = obj match {
-                case seq: Seq[Any] => seq.asJava
-                case other => other
-              }
-              map.put(attrName, arrayValues)
-            }
-            case m: MapType => map.put(attrName, obj.asInstanceOf[Map[_,_]].asJava)
-            // Pyrolite can handle Timestamp
-            case other => map.put(attrName, obj)
-          }
+        case (obj, (attrName, dataType)) => map.put(attrName, toJava(obj, dataType))
       }
-
       map
     }
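With the final patch applied, nested lists and dicts should still round-trip between Python and the JVM through the new recursive toJava conversion. A last sketch, again assuming `sc` and `sqlCtx` as in the doctests and combining the two remaining nestedRdd doctests into one hypothetical row:

    >>> rdd = sc.parallelize([{"f1": [[1, 2], [2, 3]], "f2": {"row1": 1.0}}])
    >>> srdd = sqlCtx.inferSchema(rdd)   # f1 -> ArrayType(ArrayType), f2 -> MapType
    >>> srdd.collect() == [{"f1": [[1, 2], [2, 3]], "f2": {"row1": 1.0}}]
    True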