diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index bdd40f340235b..052cc486e8cb2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -568,7 +568,7 @@ object ScalaReflection extends ScalaReflection { udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) - Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) + Invoke(obj, "serialize", udt, inputObject :: Nil) case t if UDTRegistration.exists(getClassNameFromType(t)) => val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() @@ -577,7 +577,7 @@ object ScalaReflection extends ScalaReflection { udt.getClass, Nil, dataType = ObjectType(udt.getClass)) - Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) + Invoke(obj, "serialize", udt, inputObject :: Nil) case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 2f8ba33f3520f..0de9166aa29aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -88,7 +88,7 @@ object RowEncoder { udtClass, Nil, dataType = ObjectType(udtClass), false) - Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) + Invoke(obj, "serialize", udt, inputObject :: Nil) case TimestampType => StaticInvoke( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 7d7b486530de9..474f17ff7afbc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -188,6 +188,10 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT val toCatalystConverter = CatalystTypeConverters.createToCatalystConverter(udt) assert(toCatalystConverter(null) === null) + } + test("SPARK-15658: Analysis exception if Dataset.map returns UDT object") { + // call `collect` to make sure this query can pass analysis. + pointsRDD.as[MyLabeledPoint].map(_.copy(label = 2.0)).collect() } }