@@ -62,7 +62,7 @@ object ScalaReflection extends ScalaReflection {
def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T])

private def dataTypeFor(tpe: `Type`): DataType = {
@viirya (Member, Author), Aug 8, 2017

dataTypeFor is called like this in many places:

val TypeRef(_, _, Seq(optType)) = t
val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject)

So we need to dealias it too.
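
A minimal sketch of why a pattern that destructures type arguments needs a dealiased type, assuming scala-reflect is on the classpath. The object name `DealiasSketch` and the helper `typeArgsOf` are hypothetical and not part of Spark; `TwoInt` mirrors the alias used in this PR's test. The `TypeRef` extractor sees the alias symbol itself, which carries no type arguments, so the arguments only become visible after `dealias`.

```scala
import scala.reflect.runtime.universe._

object DealiasSketch {
  // Same shape as the alias in TestForTypeAlias.
  type TwoInt = (Int, Int)

  // Hypothetical helper: the type arguments as the TypeRef extractor sees them.
  def typeArgsOf(tpe: Type): List[Type] = tpe match {
    case TypeRef(_, _, args) => args
    case _ => Nil
  }

  def main(args: Array[String]): Unit = {
    val aliased = typeOf[TwoInt]
    // The alias is a TypeRef to the symbol TwoInt and has no type arguments of its own.
    println(s"alias args:     ${typeArgsOf(aliased)}")
    // After dealiasing, the TypeRef points at Tuple2 and exposes both Int arguments.
    println(s"dealiased args: ${typeArgsOf(aliased.dealias)}")
  }
}
```

Dealiasing inside the callee, as this change does, means call sites that pass along an extracted type argument do not each have to remember to do it.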

- tpe match {
+ tpe.dealias match {
case t if t <:< definitions.IntTpe => IntegerType
case t if t <:< definitions.LongTpe => LongType
case t if t <:< definitions.DoubleTpe => DoubleType
@@ -94,7 +94,7 @@ object ScalaReflection extends ScalaReflection {
* JVM form instead of the Scala Array that handles auto boxing.
*/
private def arrayClassFor(tpe: `Type`): ObjectType = {
Member Author

arrayClassFor is called in many places. The typical calling pattern looks like:

        val TypeRef(_, _, Seq(elementType)) = tpe
        arrayClassFor(elementType)

So instead of dealiasing at every call site, we dealias it here.

- val cls = tpe match {
+ val cls = tpe.dealias match {
case t if t <:< definitions.IntTpe => classOf[Array[Int]]
case t if t <:< definitions.LongTpe => classOf[Array[Long]]
case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
@@ -193,7 +193,7 @@ object ScalaReflection extends ScalaReflection {
case _ => UpCast(expr, expected, walkedTypePath)
}

- tpe match {
+ tpe.dealias match {
@viirya (Member, Author), Aug 8, 2017

This is for deserializerFor. It calls itself recursively and has many entry points, so we need to dealias the type it is given.

case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath

case t if t <:< localTypeOf[Option[_]] =>
@@ -469,7 +469,7 @@ object ScalaReflection extends ScalaReflection {
}
}

- tpe match {
+ tpe.dealias match {
Member Author

This is for serializerFor, for the same reason as deserializerFor.

case _ if !inputObject.dataType.isInstanceOf[ObjectType] => inputObject

case t if t <:< localTypeOf[Option[_]] =>
@@ -643,7 +643,7 @@ object ScalaReflection extends ScalaReflection {
* we also treat [[DefinedByConstructorParams]] as product type.
*/
def optionOfProductType(tpe: `Type`): Boolean = {
- tpe match {
+ tpe.dealias match {
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
definedByConstructorParams(optType)
@@ -690,7 +690,7 @@ object ScalaReflection extends ScalaReflection {
/*
* Retrieves the runtime class corresponding to the provided type.
*/
- def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.typeSymbol.asClass)
+ def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.dealias.typeSymbol.asClass)
Member Author

The dealias here is unnecessary. I'll remove it.


case class Schema(dataType: DataType, nullable: Boolean)

@@ -705,7 +705,7 @@ object ScalaReflection extends ScalaReflection {

/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = {
- tpe match {
+ tpe.dealias match {
@viirya (Member, Author), Aug 8, 2017

The dealias here can't be skipped.

case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
Schema(udt, nullable = true)
@@ -775,7 +775,7 @@ object ScalaReflection extends ScalaReflection {
* Whether the fields of the given type is defined entirely by its constructor parameters.
*/
def definedByConstructorParams(tpe: Type): Boolean = {
- tpe <:< localTypeOf[Product] || tpe <:< localTypeOf[DefinedByConstructorParams]
+ tpe.dealias <:< localTypeOf[Product] || tpe.dealias <:< localTypeOf[DefinedByConstructorParams]
Member Author

The dealias here is unnecessary. I'll remove it.

Member Author

Oh, no. definedByConstructorParams is also called in ExpressionEncoder, so we should dealias here.

}

private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch",
@@ -829,7 +829,7 @@ trait ScalaReflection {
* synthetic classes, emulating behaviour in Java bytecode.
*/
def getClassNameFromType(tpe: `Type`): String = {
- tpe.erasure.typeSymbol.asClass.fullName
+ tpe.dealias.erasure.typeSymbol.asClass.fullName
Member Author

This is needed.

}

/**
@@ -848,9 +848,10 @@ trait ScalaReflection {
* support inner class.
*/
def getConstructorParameters(tpe: Type): Seq[(String, Type)] = {
- val formalTypeArgs = tpe.typeSymbol.asClass.typeParams
- val TypeRef(_, _, actualTypeArgs) = tpe
- val params = constructParams(tpe)
+ val dealiasedTpe = tpe.dealias
+ val formalTypeArgs = dealiasedTpe.typeSymbol.asClass.typeParams
+ val TypeRef(_, _, actualTypeArgs) = dealiasedTpe
+ val params = constructParams(dealiasedTpe)
Member Author

This is needed.

// if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int])
if (actualTypeArgs.nonEmpty) {
params.map { p =>
@@ -864,7 +865,7 @@ trait ScalaReflection {
}

protected def constructParams(tpe: Type): Seq[Symbol] = {
Member Author

This is called from several places, so the dealias is needed here too.

- val constructorSymbol = tpe.member(termNames.CONSTRUCTOR)
+ val constructorSymbol = tpe.dealias.member(termNames.CONSTRUCTOR)
val params = if (constructorSymbol.isMethod) {
constructorSymbol.asMethod.paramLists
} else {
24 changes: 24 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -34,6 +34,16 @@ import org.apache.spark.sql.types._
case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
case class TestDataPoint2(x: Int, s: String)

+ object TestForTypeAlias {
+ type TwoInt = (Int, Int)
Contributor

Let's also test a nested type alias.

Member Author

Sorry, what does "nested type alias" mean?

Member Author

Something like type TwoIntSeq = Seq[TwoInt]?

Contributor

Like:

type TwoInt = (Int, Int)
type ThreeInt = (TwoInt, Int)

Member Author

OK. Added another test for this case.
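
A small sketch of what the nested case exercises, assuming scala-reflect is on the classpath; `NestedAliasSketch` is a hypothetical name and not part of the patch. As far as I understand it, `dealias` expands only the outermost alias, so the first type argument of the dealiased `ThreeInt` is still the alias `TwoInt` and has to be dealiased again when it is visited.

```scala
import scala.reflect.runtime.universe._

object NestedAliasSketch {
  type TwoInt = (Int, Int)
  type ThreeInt = (TwoInt, Int)

  // Dealias only the outer alias: ThreeInt expands to a Tuple2 type.
  val outer: Type = typeOf[ThreeInt].dealias
  // First type argument of that tuple; expected to still be the alias TwoInt.
  val inner: Type = outer.typeArgs.head

  def main(args: Array[String]): Unit = {
    println(s"outer args: ${outer.typeArgs}")
    println(s"inner args before dealias: ${inner.typeArgs}")
    println(s"inner args after dealias:  ${inner.dealias.typeArgs}")
  }
}
```

This is consistent with dealiasing inside each recursive function rather than once at the entry point.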

+ type ThreeInt = (TwoInt, Int)
+ type SeqOfTwoInt = Seq[TwoInt]
+
+ def tupleTypeAlias: TwoInt = (1, 1)
+ def nestedTupleTypeAlias: ThreeInt = ((1, 1), 2)
+ def seqOfTupleTypeAlias: SeqOfTwoInt = Seq((1, 1), (2, 2))
+ }

class DatasetSuite extends QueryTest with SharedSQLContext {
import testImplicits._

@@ -1317,6 +1327,20 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkAnswer(df.orderBy($"id"), expected)
checkAnswer(df.orderBy('id), expected)
}

+ test("SPARK-21567: Dataset should work with type alias") {
+ checkDataset(
+ Seq(1).toDS().map(_ => ("", TestForTypeAlias.tupleTypeAlias)),
+ ("", (1, 1)))
+
+ checkDataset(
+ Seq(1).toDS().map(_ => ("", TestForTypeAlias.nestedTupleTypeAlias)),
+ ("", ((1, 1), 2)))
+
+ checkDataset(
+ Seq(1).toDS().map(_ => ("", TestForTypeAlias.seqOfTupleTypeAlias)),
+ ("", Seq((1, 1), (2, 2))))
+ }
}

case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])