From 8a4b63d2928ee93410764d2b8cec33c562f0eeb7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 2 Aug 2017 06:59:54 +0000 Subject: [PATCH 1/3] Dataset should work with type alias. --- .../spark/sql/catalyst/ScalaReflection.scala | 27 ++++++++++--------- .../org/apache/spark/sql/DatasetSuite.scala | 11 ++++++++ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 004b4ef8f69fe..17e595f9c5d8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -62,7 +62,7 @@ object ScalaReflection extends ScalaReflection { def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T]) private def dataTypeFor(tpe: `Type`): DataType = { - tpe match { + tpe.dealias match { case t if t <:< definitions.IntTpe => IntegerType case t if t <:< definitions.LongTpe => LongType case t if t <:< definitions.DoubleTpe => DoubleType @@ -94,7 +94,7 @@ object ScalaReflection extends ScalaReflection { * JVM form instead of the Scala Array that handles auto boxing. */ private def arrayClassFor(tpe: `Type`): ObjectType = { - val cls = tpe match { + val cls = tpe.dealias match { case t if t <:< definitions.IntTpe => classOf[Array[Int]] case t if t <:< definitions.LongTpe => classOf[Array[Long]] case t if t <:< definitions.DoubleTpe => classOf[Array[Double]] @@ -193,7 +193,7 @@ object ScalaReflection extends ScalaReflection { case _ => UpCast(expr, expected, walkedTypePath) } - tpe match { + tpe.dealias match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath case t if t <:< localTypeOf[Option[_]] => @@ -469,7 +469,7 @@ object ScalaReflection extends ScalaReflection { } } - tpe match { + tpe.dealias match { case _ if !inputObject.dataType.isInstanceOf[ObjectType] => inputObject case t if t <:< localTypeOf[Option[_]] => @@ -643,7 +643,7 @@ object ScalaReflection extends ScalaReflection { * we also treat [[DefinedByConstructorParams]] as product type. */ def optionOfProductType(tpe: `Type`): Boolean = { - tpe match { + tpe.dealias match { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t definedByConstructorParams(optType) @@ -690,7 +690,7 @@ object ScalaReflection extends ScalaReflection { /* * Retrieves the runtime class corresponding to the provided type. */ - def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.typeSymbol.asClass) + def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.dealias.typeSymbol.asClass) case class Schema(dataType: DataType, nullable: Boolean) @@ -705,7 +705,7 @@ object ScalaReflection extends ScalaReflection { /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = { - tpe match { + tpe.dealias match { case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() Schema(udt, nullable = true) @@ -775,7 +775,7 @@ object ScalaReflection extends ScalaReflection { * Whether the fields of the given type is defined entirely by its constructor parameters. */ def definedByConstructorParams(tpe: Type): Boolean = { - tpe <:< localTypeOf[Product] || tpe <:< localTypeOf[DefinedByConstructorParams] + tpe.dealias <:< localTypeOf[Product] || tpe.dealias <:< localTypeOf[DefinedByConstructorParams] } private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch", @@ -829,7 +829,7 @@ trait ScalaReflection { * synthetic classes, emulating behaviour in Java bytecode. */ def getClassNameFromType(tpe: `Type`): String = { - tpe.erasure.typeSymbol.asClass.fullName + tpe.dealias.erasure.typeSymbol.asClass.fullName } /** @@ -848,9 +848,10 @@ trait ScalaReflection { * support inner class. */ def getConstructorParameters(tpe: Type): Seq[(String, Type)] = { - val formalTypeArgs = tpe.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = tpe - val params = constructParams(tpe) + val dealiasedTpe = tpe.dealias + val formalTypeArgs = dealiasedTpe.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = dealiasedTpe + val params = constructParams(dealiasedTpe) // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int]) if (actualTypeArgs.nonEmpty) { params.map { p => @@ -864,7 +865,7 @@ trait ScalaReflection { } protected def constructParams(tpe: Type): Seq[Symbol] = { - val constructorSymbol = tpe.member(termNames.CONSTRUCTOR) + val constructorSymbol = tpe.dealias.member(termNames.CONSTRUCTOR) val params = if (constructorSymbol.isMethod) { constructorSymbol.asMethod.paramLists } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 40235e32d35da..144aad6c95c14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -34,6 +34,11 @@ import org.apache.spark.sql.types._ case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2) case class TestDataPoint2(x: Int, s: String) +object TestForTypeAlias { + type TwoInt = (Int, Int) + def tupleTypeAlias: TwoInt = (1, 1) +} + class DatasetSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -1317,6 +1322,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer(df.orderBy($"id"), expected) checkAnswer(df.orderBy('id), expected) } + + test("SPARK-21567: Dataset with Tuple of type alias") { + checkDataset( + Seq(1).toDS().map(_ => ("", TestForTypeAlias.tupleTypeAlias)), + ("", (1, 1))) + } } case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) From 312c7b0ee3f58c05d7b0116ae4263456f5c93de9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 Aug 2017 02:36:14 +0000 Subject: [PATCH 2/3] Add another test. --- .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 144aad6c95c14..636205df9def9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -37,6 +37,9 @@ case class TestDataPoint2(x: Int, s: String) object TestForTypeAlias { type TwoInt = (Int, Int) def tupleTypeAlias: TwoInt = (1, 1) + + type SeqOfTwoInt = Seq[TwoInt] + def seqOfTupleTypeAlias: SeqOfTwoInt = Seq((1, 1), (2, 2)) } class DatasetSuite extends QueryTest with SharedSQLContext { @@ -1323,10 +1326,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer(df.orderBy('id), expected) } - test("SPARK-21567: Dataset with Tuple of type alias") { + test("SPARK-21567: Dataset should work with type alias") { checkDataset( Seq(1).toDS().map(_ => ("", TestForTypeAlias.tupleTypeAlias)), ("", (1, 1))) + + checkDataset( + Seq(1).toDS().map(_ => ("", TestForTypeAlias.seqOfTupleTypeAlias)), + ("", Seq((1, 1), (2, 2)))) } } From 031d1d3018498ceee23ae9f32d619a7600396a41 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 Aug 2017 04:22:57 +0000 Subject: [PATCH 3/3] Add another test. --- .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 636205df9def9..6245b2eff9fa1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -36,9 +36,11 @@ case class TestDataPoint2(x: Int, s: String) object TestForTypeAlias { type TwoInt = (Int, Int) - def tupleTypeAlias: TwoInt = (1, 1) - + type ThreeInt = (TwoInt, Int) type SeqOfTwoInt = Seq[TwoInt] + + def tupleTypeAlias: TwoInt = (1, 1) + def nestedTupleTypeAlias: ThreeInt = ((1, 1), 2) def seqOfTupleTypeAlias: SeqOfTwoInt = Seq((1, 1), (2, 2)) } @@ -1331,6 +1333,10 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq(1).toDS().map(_ => ("", TestForTypeAlias.tupleTypeAlias)), ("", (1, 1))) + checkDataset( + Seq(1).toDS().map(_ => ("", TestForTypeAlias.nestedTupleTypeAlias)), + ("", ((1, 1), 2))) + checkDataset( Seq(1).toDS().map(_ => ("", TestForTypeAlias.seqOfTupleTypeAlias)), ("", Seq((1, 1), (2, 2))))