From 44fa8764884c9899bf2c08013b93cff68d5028a2 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Wed, 27 Feb 2019 22:02:56 +0900 Subject: [PATCH 01/14] [SPARK-27001][SQL] Refactor "serializerFor" method between ScalaReflection and JavaTypeInference --- .../sql/catalyst/JavaTypeInference.scala | 115 ++++------- .../spark/sql/catalyst/ScalaReflection.scala | 141 ++++--------- .../sql/catalyst/SerializerBuildHelper.scala | 190 ++++++++++++++++++ 3 files changed, 267 insertions(+), 179 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 87b2ae8cdf7e1..ba22f86dae27f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -27,6 +27,7 @@ import scala.language.existentials import com.google.common.reflect.TypeToken import org.apache.spark.sql.catalyst.DeserializerBuildHelper._ +import org.apache.spark.sql.catalyst.SerializerBuildHelper._ import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ @@ -367,12 +368,10 @@ object JavaTypeInference { def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = { val (dataType, nullable) = inferDataType(elementType) if (ScalaReflection.isNativeType(dataType)) { - NewInstance( - classOf[GenericArrayData], - input :: Nil, - dataType = ArrayType(dataType, nullable)) + createSerializerForGenericArray(input, dataType, nullable = nullable) } else { - MapObjects(serializerFor(_, elementType), input, ObjectType(elementType.getRawType)) + createSerializerForMapObjects(input, ObjectType(elementType.getRawType), + serializerFor(_, elementType)) } } @@ -380,60 +379,26 @@ object JavaTypeInference { inputObject } else { typeToken.getRawType match { - case c if c == classOf[String] => - StaticInvoke( - classOf[UTF8String], - StringType, - "fromString", - inputObject :: Nil, - returnNullable = false) - - case c if c == classOf[java.sql.Timestamp] => - StaticInvoke( - DateTimeUtils.getClass, - TimestampType, - "fromJavaTimestamp", - inputObject :: Nil, - returnNullable = false) - - case c if c == classOf[java.time.LocalDate] => - StaticInvoke( - DateTimeUtils.getClass, - DateType, - "localDateToDays", - inputObject :: Nil, - returnNullable = false) - - case c if c == classOf[java.sql.Date] => - StaticInvoke( - DateTimeUtils.getClass, - DateType, - "fromJavaDate", - inputObject :: Nil, - returnNullable = false) + case c if c == classOf[String] => createSerializerForString(inputObject) + + case c if c == classOf[java.time.Instant] => createSerializerForJavaInstant(inputObject) + + case c if c == classOf[java.sql.Timestamp] => createSerializerForSqlTimestamp(inputObject) + + case c if c == classOf[java.time.LocalDate] => createSerializerForJavaLocalDate(inputObject) + + case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject) case c if c == classOf[java.math.BigDecimal] => - StaticInvoke( - Decimal.getClass, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil, - returnNullable = false) - - case c if c == classOf[java.lang.Boolean] => - Invoke(inputObject, "booleanValue", BooleanType) - case c if c == 
classOf[java.lang.Byte] => - Invoke(inputObject, "byteValue", ByteType) - case c if c == classOf[java.lang.Short] => - Invoke(inputObject, "shortValue", ShortType) - case c if c == classOf[java.lang.Integer] => - Invoke(inputObject, "intValue", IntegerType) - case c if c == classOf[java.lang.Long] => - Invoke(inputObject, "longValue", LongType) - case c if c == classOf[java.lang.Float] => - Invoke(inputObject, "floatValue", FloatType) - case c if c == classOf[java.lang.Double] => - Invoke(inputObject, "doubleValue", DoubleType) + createSerializerForJavaBigDecimal(inputObject) + + case c if c == classOf[java.lang.Boolean] => createSerializerForBoolean(inputObject) + case c if c == classOf[java.lang.Byte] => createSerializerForByte(inputObject) + case c if c == classOf[java.lang.Short] => createSerializerForShort(inputObject) + case c if c == classOf[java.lang.Integer] => createSerializerForInteger(inputObject) + case c if c == classOf[java.lang.Long] => createSerializerForLong(inputObject) + case c if c == classOf[java.lang.Float] => createSerializerForFloat(inputObject) + case c if c == classOf[java.lang.Double] => createSerializerForDouble(inputObject) case _ if typeToken.isArray => toCatalystArray(inputObject, typeToken.getComponentType) @@ -444,38 +409,36 @@ object JavaTypeInference { case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) - ExternalMapToCatalyst( + createSerializerForMap( inputObject, - ObjectType(keyType.getRawType), - serializerFor(_, keyType), - keyNullable = true, - ObjectType(valueType.getRawType), - serializerFor(_, valueType), - valueNullable = true + MapElementInformation( + ObjectType(keyType.getRawType), + nullable = true, + serializerFor(_, keyType) + ), + MapElementInformation( + ObjectType(valueType.getRawType), + nullable = true, + serializerFor(_, valueType) + ), ) case other if other.isEnum => - StaticInvoke( - classOf[UTF8String], - StringType, - "fromString", - Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false) :: Nil, - returnNullable = false) + createSerializerForString( + Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false)) case other => val properties = getJavaBeanReadableAndWritableProperties(other) - val nonNullOutput = CreateNamedStruct(properties.flatMap { p => + val fields = properties.map { p => val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType val fieldValue = Invoke( inputObject, p.getReadMethod.getName, inferExternalType(fieldType.getRawType)) - expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil - }) - - val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) - expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) + (fieldName, fieldValue) + } + createSerializerForObject(inputObject, fields) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index bbddd3312a581..9f496baa43810 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -21,10 +21,11 @@ import org.apache.commons.lang3.reflect.ConstructorUtils import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.DeserializerBuildHelper._ +import org.apache.spark.sql.catalyst.SerializerBuildHelper._ import 
org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions.{Expression, _} import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -395,30 +396,20 @@ object ScalaReflection extends ScalaReflection { case dt: ObjectType => val clsName = getClassNameFromType(elementType) val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath - MapObjects(serializerFor(_, elementType, newPath, seenTypeSet), input, dt) + createSerializerForMapObjects(input, dt, + serializerFor(_, elementType, newPath, seenTypeSet)) case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType) => val cls = input.dataType.asInstanceOf[ObjectType].cls if (cls.isArray && cls.getComponentType.isPrimitive) { - StaticInvoke( - classOf[UnsafeArrayData], - ArrayType(dt, false), - "fromPrimitiveArray", - input :: Nil, - returnNullable = false) + createSerializerForPrimitiveArray(input, dt) } else { - NewInstance( - classOf[GenericArrayData], - input :: Nil, - dataType = ArrayType(dt, schemaFor(elementType).nullable)) + createSerializerForGenericArray(input, dt, nullable = schemaFor(elementType).nullable) } case dt => - NewInstance( - classOf[GenericArrayData], - input :: Nil, - dataType = ArrayType(dt, schemaFor(elementType).nullable)) + createSerializerForGenericArray(input, dt, nullable = schemaFor(elementType).nullable) } } @@ -450,14 +441,17 @@ object ScalaReflection extends ScalaReflection { val keyPath = s"""- map key class: "$keyClsName"""" +: walkedTypePath val valuePath = s"""- map value class: "$valueClsName"""" +: walkedTypePath - ExternalMapToCatalyst( + createSerializerForMap( inputObject, - dataTypeFor(keyType), - serializerFor(_, keyType, keyPath, seenTypeSet), - keyNullable = !keyType.typeSymbol.asClass.isPrimitive, - dataTypeFor(valueType), - serializerFor(_, valueType, valuePath, seenTypeSet), - valueNullable = !valueType.typeSymbol.asClass.isPrimitive) + MapElementInformation( + dataTypeFor(keyType), + nullable = !keyType.typeSymbol.asClass.isPrimitive, + serializerFor(_, keyType, keyPath, seenTypeSet)), + MapElementInformation( + dataTypeFor(valueType), + nullable = !valueType.typeSymbol.asClass.isPrimitive, + serializerFor(_, valueType, valuePath, seenTypeSet)) + ) case t if t <:< localTypeOf[scala.collection.Set[_]] => val TypeRef(_, _, Seq(elementType)) = t @@ -472,92 +466,35 @@ object ScalaReflection extends ScalaReflection { toCatalystArray(newInput, elementType) - case t if t <:< localTypeOf[String] => - StaticInvoke( - classOf[UTF8String], - StringType, - "fromString", - inputObject :: Nil, - returnNullable = false) + case t if t <:< localTypeOf[String] => createSerializerForString(inputObject) - case t if t <:< localTypeOf[java.time.Instant] => - StaticInvoke( - DateTimeUtils.getClass, - TimestampType, - "instantToMicros", - inputObject :: Nil, - returnNullable = false) + case t if t <:< localTypeOf[java.time.Instant] => createSerializerForJavaInstant(inputObject) case t if t <:< localTypeOf[java.sql.Timestamp] => - StaticInvoke( - DateTimeUtils.getClass, - TimestampType, - "fromJavaTimestamp", - inputObject :: Nil, - returnNullable = false) + createSerializerForSqlTimestamp(inputObject) case t if t <:< localTypeOf[java.time.LocalDate] => - 
StaticInvoke( - DateTimeUtils.getClass, - DateType, - "localDateToDays", - inputObject :: Nil, - returnNullable = false) + createSerializerForJavaLocalDate(inputObject) - case t if t <:< localTypeOf[java.sql.Date] => - StaticInvoke( - DateTimeUtils.getClass, - DateType, - "fromJavaDate", - inputObject :: Nil, - returnNullable = false) + case t if t <:< localTypeOf[java.sql.Date] => createSerializerForSqlDate(inputObject) - case t if t <:< localTypeOf[BigDecimal] => - StaticInvoke( - Decimal.getClass, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil, - returnNullable = false) + case t if t <:< localTypeOf[BigDecimal] => createSerializerForScalaBigDecimal(inputObject) case t if t <:< localTypeOf[java.math.BigDecimal] => - StaticInvoke( - Decimal.getClass, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil, - returnNullable = false) + createSerializerForJavaBigDecimal(inputObject) case t if t <:< localTypeOf[java.math.BigInteger] => - StaticInvoke( - Decimal.getClass, - DecimalType.BigIntDecimal, - "apply", - inputObject :: Nil, - returnNullable = false) + createSerializerForJavaBigInteger(inputObject) - case t if t <:< localTypeOf[scala.math.BigInt] => - StaticInvoke( - Decimal.getClass, - DecimalType.BigIntDecimal, - "apply", - inputObject :: Nil, - returnNullable = false) + case t if t <:< localTypeOf[scala.math.BigInt] => createSerializerForScalaBigInt(inputObject) - case t if t <:< localTypeOf[java.lang.Integer] => - Invoke(inputObject, "intValue", IntegerType) - case t if t <:< localTypeOf[java.lang.Long] => - Invoke(inputObject, "longValue", LongType) - case t if t <:< localTypeOf[java.lang.Double] => - Invoke(inputObject, "doubleValue", DoubleType) - case t if t <:< localTypeOf[java.lang.Float] => - Invoke(inputObject, "floatValue", FloatType) - case t if t <:< localTypeOf[java.lang.Short] => - Invoke(inputObject, "shortValue", ShortType) - case t if t <:< localTypeOf[java.lang.Byte] => - Invoke(inputObject, "byteValue", ByteType) - case t if t <:< localTypeOf[java.lang.Boolean] => - Invoke(inputObject, "booleanValue", BooleanType) + case t if t <:< localTypeOf[java.lang.Integer] => createSerializerForInteger(inputObject) + case t if t <:< localTypeOf[java.lang.Long] => createSerializerForLong(inputObject) + case t if t <:< localTypeOf[java.lang.Double] => createSerializerForDouble(inputObject) + case t if t <:< localTypeOf[java.lang.Float] => createSerializerForFloat(inputObject) + case t if t <:< localTypeOf[java.lang.Short] => createSerializerForShort(inputObject) + case t if t <:< localTypeOf[java.lang.Byte] => createSerializerForByte(inputObject) + case t if t <:< localTypeOf[java.lang.Boolean] => createSerializerForBoolean(inputObject) case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t) @@ -584,7 +521,7 @@ object ScalaReflection extends ScalaReflection { } val params = getConstructorParameters(t) - val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => + val fields = params.map { case (fieldName, fieldType) => if (javaKeywords.contains(fieldName)) { throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " + "cannot be used as field name\n" + walkedTypePath.mkString("\n")) @@ -598,13 +535,11 @@ object ScalaReflection extends ScalaReflection { returnNullable = !fieldType.typeSymbol.asClass.isPrimitive) val clsName = getClassNameFromType(fieldType) val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath - 
expressions.Literal(fieldName) :: - serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t) :: Nil - }) - val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) - expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) + (fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t)) + } + createSerializerForObject(inputObject, fields) - case other => + case _ => throw new UnsupportedOperationException( s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala new file mode 100644 index 0000000000000..79e56c2999a1b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, IsNull, UnsafeArrayData} +import org.apache.spark.sql.catalyst.expressions.objects._ +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +object SerializerBuildHelper { + + def createSerializerForBoolean(inputObject: Expression): Expression = { + Invoke(inputObject, "booleanValue", BooleanType) + } + + def createSerializerForByte(inputObject: Expression): Expression = { + Invoke(inputObject, "byteValue", ByteType) + } + + def createSerializerForShort(inputObject: Expression): Expression = { + Invoke(inputObject, "shortValue", ShortType) + } + + def createSerializerForInteger(inputObject: Expression): Expression = { + Invoke(inputObject, "intValue", IntegerType) + } + + def createSerializerForLong(inputObject: Expression): Expression = { + Invoke(inputObject, "longValue", LongType) + } + + def createSerializerForFloat(inputObject: Expression): Expression = { + Invoke(inputObject, "floatValue", FloatType) + } + + def createSerializerForDouble(inputObject: Expression): Expression = { + Invoke(inputObject, "doubleValue", DoubleType) + } + + def createSerializerForString(inputObject: Expression): Expression = { + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil, + returnNullable = false) + } + + def createSerializerForJavaInstant(inputObject: Expression): Expression = { + StaticInvoke( + DateTimeUtils.getClass, + TimestampType, + "instantToMicros", + inputObject :: Nil, + returnNullable = false) + } + + def createSerializerForSqlTimestamp(inputObject: Expression): Expression = { + StaticInvoke( + DateTimeUtils.getClass, + TimestampType, + "fromJavaTimestamp", 
+ inputObject :: Nil, + returnNullable = false) + } + + def createSerializerForJavaLocalDate(inputObject: Expression): Expression = { + StaticInvoke( + DateTimeUtils.getClass, + DateType, + "localDateToDays", + inputObject :: Nil, + returnNullable = false) + } + + def createSerializerForSqlDate(inputObject: Expression): Expression = { + StaticInvoke( + DateTimeUtils.getClass, + DateType, + "fromJavaDate", + inputObject :: Nil, + returnNullable = false) + } + + def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = { + StaticInvoke( + Decimal.getClass, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil, + returnNullable = false) + } + + def createSerializerForScalaBigDecimal(inputObject: Expression): Expression = { + createSerializerForJavaBigDecimal(inputObject) + } + + def createSerializerForJavaBigInteger(inputObject: Expression): Expression = { + StaticInvoke( + Decimal.getClass, + DecimalType.BigIntDecimal, + "apply", + inputObject :: Nil, + returnNullable = false) + } + + def createSerializerForScalaBigInt(inputObject: Expression): Expression = { + createSerializerForJavaBigInteger(inputObject) + } + + def createSerializerForPrimitiveArray( + inputObject: Expression, + dataType: DataType): Expression = { + StaticInvoke( + classOf[UnsafeArrayData], + ArrayType(dataType, false), + "fromPrimitiveArray", + inputObject :: Nil, + returnNullable = false) + } + + def createSerializerForGenericArray( + inputObject: Expression, + dataType: DataType, + nullable: Boolean): Expression = { + NewInstance( + classOf[GenericArrayData], + inputObject :: Nil, + dataType = ArrayType(dataType, nullable)) + } + + def createSerializerForMapObjects( + inputObject: Expression, + dataType: ObjectType, + funcForNewExpr: Expression => Expression): Expression = { + MapObjects(funcForNewExpr, inputObject, dataType) + } + + case class MapElementInformation( + dataType: DataType, + nullable: Boolean, + funcForNewExpr: Expression => Expression) + + def createSerializerForMap( + inputObject: Expression, + keyInformation: MapElementInformation, + valueInformation: MapElementInformation): Expression = { + ExternalMapToCatalyst( + inputObject, + keyInformation.dataType, + keyInformation.funcForNewExpr, + keyNullable = keyInformation.nullable, + valueInformation.dataType, + valueInformation.funcForNewExpr, + valueNullable = valueInformation.nullable + ) + } + + private def argumentsForFieldSerializer( + fieldName: String, + serializerForFieldValue: Expression): Seq[Expression] = { + expressions.Literal(fieldName) :: serializerForFieldValue :: Nil + } + + def createSerializerForObject( + inputObject: Expression, + fields: Seq[(String, Expression)]): Expression = { + val nonNullOutput = CreateNamedStruct(fields.flatMap { case(fieldName, fieldExpr) => + argumentsForFieldSerializer(fieldName, fieldExpr) + }) + val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) + expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) + } +} From d683d8022ab99a373dc77d0a0fa4801f959924c3 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Wed, 27 Feb 2019 22:45:03 +0900 Subject: [PATCH 02/14] Also extract the logic on recording walked type path --- .../sql/catalyst/JavaTypeInference.scala | 25 ++++------ .../spark/sql/catalyst/ScalaReflection.scala | 27 +++++----- .../sql/catalyst/WalkedTypePathRecorder.scala | 49 +++++++++++++++++++ 3 files changed, 72 insertions(+), 29 deletions(-) create mode 100644 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePathRecorder.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index ba22f86dae27f..12fead9deebc9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -28,12 +28,12 @@ import com.google.common.reflect.TypeToken import org.apache.spark.sql.catalyst.DeserializerBuildHelper._ import org.apache.spark.sql.catalyst.SerializerBuildHelper._ +import org.apache.spark.sql.catalyst.WalkedTypePathRecorder._ import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String /** * Type-inference utilities for POJOs and Java collections. @@ -196,7 +196,7 @@ object JavaTypeInference { */ def deserializerFor(beanClass: Class[_]): Expression = { val typeToken = TypeToken.of(beanClass) - val walkedTypePath = s"""- root class: "${beanClass.getCanonicalName}"""" :: Nil + val walkedTypePath = recordRoot(beanClass.getCanonicalName) val (dataType, nullable) = inferDataType(typeToken) // Assumes we are deserializing the first column of a row. @@ -245,8 +245,7 @@ object JavaTypeInference { case c if c.isArray => val elementType = c.getComponentType - val newTypePath = s"""- array element class: "${elementType.getCanonicalName}"""" +: - walkedTypePath + val newTypePath = recordArray(walkedTypePath, elementType.getCanonicalName) val (dataType, elementNullable) = inferDataType(elementType) val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. @@ -275,8 +274,7 @@ object JavaTypeInference { case c if listType.isAssignableFrom(typeToken) => val et = elementType(typeToken) - val newTypePath = s"""- array element class: "${et.getType.getTypeName}"""" +: - walkedTypePath + val newTypePath = recordArray(walkedTypePath, et.getType.getTypeName) val (dataType, elementNullable) = inferDataType(et) val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. 
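The recordRoot/recordArray/recordMap/recordField calls appearing in these hunks are thin wrappers around the string interpolations they replace; their bodies live in the new WalkedTypePathRecorder file added by this patch. A minimal standalone sketch of how they compose (this local object mirrors the patch's helpers rather than importing Spark, and the bean and element class names are invented for illustration):

    object WalkedTypePathRecorder {
      def recordRoot(className: String): Seq[String] =
        s"""- root class: "$className"""" :: Nil

      def recordArray(walkedTypePath: Seq[String], elementClassName: String): Seq[String] =
        s"""- array element class: "$elementClassName"""" +: walkedTypePath
    }

    // Each step prepends, so the most recently visited type prints first:
    val path = WalkedTypePathRecorder.recordArray(
      WalkedTypePathRecorder.recordRoot("com.example.Bean"), "java.lang.String")
    println(path.mkString("\n"))
    // - array element class: "java.lang.String"
    // - root class: "com.example.Bean"
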
@@ -292,8 +290,8 @@ object JavaTypeInference { case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) - val newTypePath = (s"""- map key class: "${keyType.getType.getTypeName}"""" + - s""", value class: "${valueType.getType.getTypeName}"""") +: walkedTypePath + val newTypePath = recordMap(walkedTypePath, keyType.getType.getTypeName, + valueType.getType.getTypeName) val keyData = Invoke( @@ -329,8 +327,7 @@ object JavaTypeInference { val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType val (dataType, nullable) = inferDataType(fieldType) - val newTypePath = (s"""- field (class: "${fieldType.getType.getTypeName}"""" + - s""", name: "$fieldName")""") +: walkedTypePath + val newTypePath = recordField(walkedTypePath, fieldType.getType.getTypeName, fieldName) val setter = deserializerForWithNullSafety( path, dataType, @@ -414,13 +411,11 @@ object JavaTypeInference { MapElementInformation( ObjectType(keyType.getRawType), nullable = true, - serializerFor(_, keyType) - ), + serializerFor(_, keyType)), MapElementInformation( ObjectType(valueType.getRawType), nullable = true, - serializerFor(_, valueType) - ), + serializerFor(_, valueType)) ) case other if other.isEnum => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 9f496baa43810..2f3c15e604bfb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -22,6 +22,7 @@ import org.apache.commons.lang3.reflect.ConstructorUtils import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.DeserializerBuildHelper._ import org.apache.spark.sql.catalyst.SerializerBuildHelper._ +import org.apache.spark.sql.catalyst.WalkedTypePathRecorder._ import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions.{Expression, _} import org.apache.spark.sql.catalyst.expressions.objects._ @@ -137,7 +138,7 @@ object ScalaReflection extends ScalaReflection { */ def deserializerForType(tpe: `Type`): Expression = { val clsName = getClassNameFromType(tpe) - val walkedTypePath = s"""- root class: "$clsName"""" :: Nil + val walkedTypePath = recordRoot(clsName) val Schema(dataType, nullable) = schemaFor(tpe) // Assumes we are deserializing the first column of a row. @@ -164,7 +165,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) - val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath + val newTypePath = recordOption(walkedTypePath, className) WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType)) case t if t <:< localTypeOf[java.lang.Integer] => @@ -226,7 +227,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) - val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath + val newTypePath = recordArray(walkedTypePath, className) val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. 
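These walked-type-path values only surface in error messages, so the extraction is behavior-preserving as long as the helper strings match the old inline ones exactly. For example, an encoder failure on an unsupported nested field would still render a trail like the following, per the "No Encoder found for $tpe" branch shown in patch 01 (the UUID-typed field and the event class are invented for illustration):

    No Encoder found for java.util.UUID
    - field (class: "java.util.UUID", name: "id")
    - root class: "com.example.Event"
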
@@ -261,7 +262,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) - val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath + val newTypePath = recordArray(walkedTypePath, className) val mapFunction: Expression => Expression = element => { deserializerForWithNullSafetyAndUpcast( @@ -287,8 +288,7 @@ object ScalaReflection extends ScalaReflection { val classNameForKey = getClassNameFromType(keyType) val classNameForValue = getClassNameFromType(valueType) - val newTypePath = (s"""- map key class: "${classNameForKey}"""" + - s""", value class: "${classNameForValue}"""") +: walkedTypePath + val newTypePath = recordMap(walkedTypePath, classNameForKey, classNameForValue) UnresolvedCatalystToExternalMap( path, @@ -323,8 +323,7 @@ object ScalaReflection extends ScalaReflection { val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) => val Schema(dataType, nullable) = schemaFor(fieldType) val clsName = getClassNameFromType(fieldType) - val newTypePath = (s"""- field (class: "$clsName", """ + - s"""name: "$fieldName")""") +: walkedTypePath + val newTypePath = recordField(walkedTypePath, clsName, fieldName) // For tuples, we based grab the inner fields by ordinal instead of name. deserializerForWithNullSafety( @@ -372,7 +371,7 @@ object ScalaReflection extends ScalaReflection { */ def serializerForType(tpe: `Type`): Expression = ScalaReflection.cleanUpReflectionObjects { val clsName = getClassNameFromType(tpe) - val walkedTypePath = s"""- root class: "$clsName"""" :: Nil + val walkedTypePath = recordRoot(clsName) // The input object to `ExpressionEncoder` is located at first column of an row. 
val isPrimitive = tpe.typeSymbol.asClass.isPrimitive @@ -395,7 +394,7 @@ object ScalaReflection extends ScalaReflection { dataTypeFor(elementType) match { case dt: ObjectType => val clsName = getClassNameFromType(elementType) - val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath + val newPath = recordArray(walkedTypePath, clsName) createSerializerForMapObjects(input, dt, serializerFor(_, elementType, newPath, seenTypeSet)) @@ -419,7 +418,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) - val newPath = s"""- option value class: "$className"""" +: walkedTypePath + val newPath = recordOption(walkedTypePath, className) val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject) serializerFor(unwrapped, optType, newPath, seenTypeSet) @@ -438,8 +437,8 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(keyType, valueType)) = t val keyClsName = getClassNameFromType(keyType) val valueClsName = getClassNameFromType(valueType) - val keyPath = s"""- map key class: "$keyClsName"""" +: walkedTypePath - val valuePath = s"""- map value class: "$valueClsName"""" +: walkedTypePath + val keyPath = recordKeyForMap(walkedTypePath, keyClsName) + val valuePath = recordValueForMap(walkedTypePath, valueClsName) createSerializerForMap( inputObject, @@ -534,7 +533,7 @@ object ScalaReflection extends ScalaReflection { val fieldValue = Invoke(KnownNotNull(inputObject), fieldName, dataTypeFor(fieldType), returnNullable = !fieldType.typeSymbol.asClass.isPrimitive) val clsName = getClassNameFromType(fieldType) - val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath + val newPath = recordField(walkedTypePath, clsName, fieldName) (fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t)) } createSerializerForObject(inputObject, fields) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePathRecorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePathRecorder.scala new file mode 100644 index 0000000000000..ca45168ffa7ee --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePathRecorder.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst + +object WalkedTypePathRecorder { + def recordRoot(className: String): Seq[String] = s"""- root class: "$className"""" :: Nil + + def recordOption(walkedTypePath: Seq[String], className: String): Seq[String] = + s"""- option value class: "$className"""" +: walkedTypePath + + def recordArray(walkedTypePath: Seq[String], elementClassName: String): Seq[String] = + s"""- array element class: "$elementClassName"""" +: walkedTypePath + + def recordMap( + walkedTypePath: Seq[String], + keyClassName: String, + valueClassName: String): Seq[String] = { + (s"""- map key class: "$keyClassName"""" + + s""", value class: "$valueClassName"""") +: walkedTypePath + } + + def recordKeyForMap(walkedTypePath: Seq[String], keyClassName: String): Seq[String] = + s"""- map key class: "$keyClassName"""" +: walkedTypePath + + def recordValueForMap(walkedTypePath: Seq[String], valueClassName: String): Seq[String] = + s"""- map value class: "$valueClassName"""" +: walkedTypePath + + def recordField( + walkedTypePath: Seq[String], + className: String, + fieldName: String): Seq[String] = { + s"""- field (class: "$className", name: "$fieldName")""" +: walkedTypePath + } +} From 43a69f0e63d497d3cff9c5eb9d023572c6ca5cfa Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Wed, 27 Feb 2019 23:20:02 +0900 Subject: [PATCH 03/14] Address UDT as well --- .../spark/sql/catalyst/ScalaReflection.scala | 14 ++++---------- .../spark/sql/catalyst/SerializerBuildHelper.scala | 8 ++++++++ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 2f3c15e604bfb..0b22d14a1ba07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -498,20 +498,14 @@ object ScalaReflection extends ScalaReflection { case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t) .getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance() - val obj = NewInstance( - udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), - Nil, - dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) - Invoke(obj, "serialize", udt, inputObject :: Nil) + val udtClass = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt() + createSerializerForUserDefinedType(inputObject, udt, udtClass) case t if UDTRegistration.exists(getClassNameFromType(t)) => val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor(). 
newInstance().asInstanceOf[UserDefinedType[_]] - val obj = NewInstance( - udt.getClass, - Nil, - dataType = ObjectType(udt.getClass)) - Invoke(obj, "serialize", udt, inputObject :: Nil) + val udtClass = udt.getClass + createSerializerForUserDefinedType(inputObject, udt, udtClass) case t if definedByConstructorParams(t) => if (seenTypeSet.contains(t)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala index 79e56c2999a1b..e035c4be97240 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala @@ -187,4 +187,12 @@ object SerializerBuildHelper { val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) } + + def createSerializerForUserDefinedType( + inputObject: Expression, + udt: UserDefinedType[_], + udtClass: Class[_]): Expression = { + val obj = NewInstance(udtClass, Nil, dataType = ObjectType(udtClass)) + Invoke(obj, "serialize", udt, inputObject :: Nil) + } } From 371a2d16a6089a79fcb938da38e90ec663403810 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Thu, 28 Feb 2019 04:51:03 +0900 Subject: [PATCH 04/14] Fix a silly mistake --- .../scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 12fead9deebc9..56043dd76ed17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -431,7 +431,7 @@ object JavaTypeInference { inputObject, p.getReadMethod.getName, inferExternalType(fieldType.getRawType)) - (fieldName, fieldValue) + (fieldName, serializerFor(fieldValue, fieldType)) } createSerializerForObject(inputObject, fields) } From 4dfe3c76f9a928e960b7f490ef8022ceca012ba9 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Thu, 28 Feb 2019 05:39:48 +0900 Subject: [PATCH 05/14] Introduce WalkedTypePath class to replace recording walked path in `Seq[String]` --- .../catalyst/DeserializerBuildHelper.scala | 16 +++--- .../sql/catalyst/JavaTypeInference.scala | 15 +++--- .../spark/sql/catalyst/ScalaReflection.scala | 33 ++++++------ .../spark/sql/catalyst/WalkedTypePath.scala | 50 +++++++++++++++++++ .../sql/catalyst/WalkedTypePathRecorder.scala | 49 ------------------ .../sql/catalyst/analysis/Analyzer.scala | 6 +-- .../catalyst/encoders/ExpressionEncoder.scala | 4 +- .../sql/catalyst/encoders/RowEncoder.scala | 2 +- .../spark/sql/catalyst/expressions/Cast.scala | 7 ++- .../expressions/objects/objects.scala | 6 +-- .../expressions/CodeGenerationSuite.scala | 2 +- .../expressions/NullExpressionsSuite.scala | 2 +- 12 files changed, 97 insertions(+), 95 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePathRecorder.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index 
d75d3ca918c49..5245caed6b516 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -29,7 +29,7 @@ object DeserializerBuildHelper { path: Expression, part: String, dataType: DataType, - walkedTypePath: Seq[String]): Expression = { + walkedTypePath: WalkedTypePath): Expression = { val newPath = UnresolvedExtractValue(path, expressions.Literal(part)) upCastToExpectedType(newPath, dataType, walkedTypePath) } @@ -39,7 +39,7 @@ object DeserializerBuildHelper { path: Expression, ordinal: Int, dataType: DataType, - walkedTypePath: Seq[String]): Expression = { + walkedTypePath: WalkedTypePath): Expression = { val newPath = GetStructField(path, ordinal) upCastToExpectedType(newPath, dataType, walkedTypePath) } @@ -48,8 +48,8 @@ object DeserializerBuildHelper { expr: Expression, dataType: DataType, nullable: Boolean, - walkedTypePath: Seq[String], - funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = { + walkedTypePath: WalkedTypePath, + funcForCreatingNewExpr: (Expression, WalkedTypePath) => Expression): Expression = { val newExpr = funcForCreatingNewExpr(expr, walkedTypePath) expressionWithNullSafety(newExpr, nullable, walkedTypePath) } @@ -58,8 +58,8 @@ object DeserializerBuildHelper { expr: Expression, dataType: DataType, nullable: Boolean, - walkedTypePath: Seq[String], - funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = { + walkedTypePath: WalkedTypePath, + funcForCreatingNewExpr: (Expression, WalkedTypePath) => Expression): Expression = { val casted = upCastToExpectedType(expr, dataType, walkedTypePath) deserializerForWithNullSafety(casted, dataType, nullable, walkedTypePath, funcForCreatingNewExpr) @@ -68,7 +68,7 @@ object DeserializerBuildHelper { private def expressionWithNullSafety( expr: Expression, nullable: Boolean, - walkedTypePath: Seq[String]): Expression = { + walkedTypePath: WalkedTypePath): Expression = { if (nullable) { expr } else { @@ -167,7 +167,7 @@ object DeserializerBuildHelper { private def upCastToExpectedType( expr: Expression, expected: DataType, - walkedTypePath: Seq[String]): Expression = expected match { + walkedTypePath: WalkedTypePath): Expression = expected match { case _: StructType => expr case _: ArrayType => expr case _: MapType => expr diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 56043dd76ed17..4865429bc23be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -28,7 +28,6 @@ import com.google.common.reflect.TypeToken import org.apache.spark.sql.catalyst.DeserializerBuildHelper._ import org.apache.spark.sql.catalyst.SerializerBuildHelper._ -import org.apache.spark.sql.catalyst.WalkedTypePathRecorder._ import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ @@ -196,7 +195,7 @@ object JavaTypeInference { */ def deserializerFor(beanClass: Class[_]): Expression = { val typeToken = TypeToken.of(beanClass) - val walkedTypePath = recordRoot(beanClass.getCanonicalName) + val walkedTypePath = new WalkedTypePath().recordRoot(beanClass.getCanonicalName) val (dataType, nullable) = 
inferDataType(typeToken) // Assumes we are deserializing the first column of a row. @@ -209,7 +208,7 @@ object JavaTypeInference { private def deserializerFor( typeToken: TypeToken[_], path: Expression, - walkedTypePath: Seq[String]): Expression = { + walkedTypePath: WalkedTypePath): Expression = { typeToken.getRawType match { case c if !inferExternalType(c).isInstanceOf[ObjectType] => path @@ -245,7 +244,7 @@ object JavaTypeInference { case c if c.isArray => val elementType = c.getComponentType - val newTypePath = recordArray(walkedTypePath, elementType.getCanonicalName) + val newTypePath = walkedTypePath.recordArray(elementType.getCanonicalName) val (dataType, elementNullable) = inferDataType(elementType) val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. @@ -274,7 +273,7 @@ object JavaTypeInference { case c if listType.isAssignableFrom(typeToken) => val et = elementType(typeToken) - val newTypePath = recordArray(walkedTypePath, et.getType.getTypeName) + val newTypePath = walkedTypePath.recordArray(et.getType.getTypeName) val (dataType, elementNullable) = inferDataType(et) val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. @@ -290,7 +289,7 @@ object JavaTypeInference { case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) - val newTypePath = recordMap(walkedTypePath, keyType.getType.getTypeName, + val newTypePath = walkedTypePath.recordMap(keyType.getType.getTypeName, valueType.getType.getTypeName) val keyData = @@ -327,7 +326,7 @@ object JavaTypeInference { val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType val (dataType, nullable) = inferDataType(fieldType) - val newTypePath = recordField(walkedTypePath, fieldType.getType.getTypeName, fieldName) + val newTypePath = walkedTypePath.recordField(fieldType.getType.getTypeName, fieldName) val setter = deserializerForWithNullSafety( path, dataType, @@ -356,7 +355,7 @@ object JavaTypeInference { */ def serializerFor(beanClass: Class[_]): Expression = { val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) - val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean")) + val nullSafeInput = AssertNotNull(inputObject, new WalkedTypePath(Seq("top level input bean"))) serializerFor(nullSafeInput, TypeToken.of(beanClass)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 0b22d14a1ba07..32c54af771585 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -22,7 +22,6 @@ import org.apache.commons.lang3.reflect.ConstructorUtils import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.DeserializerBuildHelper._ import org.apache.spark.sql.catalyst.SerializerBuildHelper._ -import org.apache.spark.sql.catalyst.WalkedTypePathRecorder._ import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions.{Expression, _} import org.apache.spark.sql.catalyst.expressions.objects._ @@ -138,7 +137,7 @@ object ScalaReflection extends ScalaReflection { */ def deserializerForType(tpe: `Type`): Expression = { val clsName = getClassNameFromType(tpe) - val walkedTypePath = 
recordRoot(clsName) + val walkedTypePath = new WalkedTypePath().recordRoot(clsName) val Schema(dataType, nullable) = schemaFor(tpe) // Assumes we are deserializing the first column of a row. @@ -158,14 +157,14 @@ object ScalaReflection extends ScalaReflection { private def deserializerFor( tpe: `Type`, path: Expression, - walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects { + walkedTypePath: WalkedTypePath): Expression = cleanUpReflectionObjects { tpe.dealias match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => path case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) - val newTypePath = recordOption(walkedTypePath, className) + val newTypePath = walkedTypePath.recordOption(className) WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType)) case t if t <:< localTypeOf[java.lang.Integer] => @@ -227,7 +226,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) - val newTypePath = recordArray(walkedTypePath, className) + val newTypePath = walkedTypePath.recordArray(className) val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. @@ -262,7 +261,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) - val newTypePath = recordArray(walkedTypePath, className) + val newTypePath = walkedTypePath.recordArray(className) val mapFunction: Expression => Expression = element => { deserializerForWithNullSafetyAndUpcast( @@ -288,7 +287,7 @@ object ScalaReflection extends ScalaReflection { val classNameForKey = getClassNameFromType(keyType) val classNameForValue = getClassNameFromType(valueType) - val newTypePath = recordMap(walkedTypePath, classNameForKey, classNameForValue) + val newTypePath = walkedTypePath.recordMap(classNameForKey, classNameForValue) UnresolvedCatalystToExternalMap( path, @@ -323,7 +322,7 @@ object ScalaReflection extends ScalaReflection { val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) => val Schema(dataType, nullable) = schemaFor(fieldType) val clsName = getClassNameFromType(fieldType) - val newTypePath = recordField(walkedTypePath, clsName, fieldName) + val newTypePath = walkedTypePath.recordField(clsName, fieldName) // For tuples, we based grab the inner fields by ordinal instead of name. deserializerForWithNullSafety( @@ -371,7 +370,7 @@ object ScalaReflection extends ScalaReflection { */ def serializerForType(tpe: `Type`): Expression = ScalaReflection.cleanUpReflectionObjects { val clsName = getClassNameFromType(tpe) - val walkedTypePath = recordRoot(clsName) + val walkedTypePath = new WalkedTypePath().recordRoot(clsName) // The input object to `ExpressionEncoder` is located at first column of an row. 
val isPrimitive = tpe.typeSymbol.asClass.isPrimitive @@ -387,14 +386,14 @@ object ScalaReflection extends ScalaReflection { private def serializerFor( inputObject: Expression, tpe: `Type`, - walkedTypePath: Seq[String], + walkedTypePath: WalkedTypePath, seenTypeSet: Set[`Type`] = Set.empty): Expression = cleanUpReflectionObjects { def toCatalystArray(input: Expression, elementType: `Type`): Expression = { dataTypeFor(elementType) match { case dt: ObjectType => val clsName = getClassNameFromType(elementType) - val newPath = recordArray(walkedTypePath, clsName) + val newPath = walkedTypePath.recordArray(clsName) createSerializerForMapObjects(input, dt, serializerFor(_, elementType, newPath, seenTypeSet)) @@ -418,7 +417,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) - val newPath = recordOption(walkedTypePath, className) + val newPath = walkedTypePath.recordOption(className) val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject) serializerFor(unwrapped, optType, newPath, seenTypeSet) @@ -437,8 +436,8 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(keyType, valueType)) = t val keyClsName = getClassNameFromType(keyType) val valueClsName = getClassNameFromType(valueType) - val keyPath = recordKeyForMap(walkedTypePath, keyClsName) - val valuePath = recordValueForMap(walkedTypePath, valueClsName) + val keyPath = walkedTypePath.recordKeyForMap(keyClsName) + val valuePath = walkedTypePath.recordValueForMap(valueClsName) createSerializerForMap( inputObject, @@ -517,7 +516,7 @@ object ScalaReflection extends ScalaReflection { val fields = params.map { case (fieldName, fieldType) => if (javaKeywords.contains(fieldName)) { throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " + - "cannot be used as field name\n" + walkedTypePath.mkString("\n")) + "cannot be used as field name\n" + walkedTypePath) } // SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNul @@ -527,14 +526,14 @@ object ScalaReflection extends ScalaReflection { val fieldValue = Invoke(KnownNotNull(inputObject), fieldName, dataTypeFor(fieldType), returnNullable = !fieldType.typeSymbol.asClass.isPrimitive) val clsName = getClassNameFromType(fieldType) - val newPath = recordField(walkedTypePath, clsName, fieldName) + val newPath = walkedTypePath.recordField(clsName, fieldName) (fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t)) } createSerializerForObject(inputObject, fields) case _ => throw new UnsupportedOperationException( - s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) + s"No Encoder found for $tpe\n" + walkedTypePath) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala new file mode 100644 index 0000000000000..c7cf05d434ca6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +class WalkedTypePath(val walkedPaths: Seq[String] = Nil) extends Serializable { + def recordRoot(className: String): WalkedTypePath = + newInstance(s"""- root class: "$className"""") + + def recordOption(className: String): WalkedTypePath = + newInstance(s"""- option value class: "$className"""") + + def recordArray(elementClassName: String): WalkedTypePath = + newInstance(s"""- array element class: "$elementClassName"""") + + def recordMap(keyClassName: String, valueClassName: String): WalkedTypePath = { + newInstance(s"""- map key class: "$keyClassName"""" + + s""", value class: "$valueClassName"""") + } + + def recordKeyForMap(keyClassName: String): WalkedTypePath = + newInstance(s"""- map key class: "$keyClassName"""") + + def recordValueForMap(valueClassName: String): WalkedTypePath = + newInstance(s"""- map value class: "$valueClassName"""") + + def recordField(className: String, fieldName: String): WalkedTypePath = + newInstance(s"""- field (class: "$className", name: "$fieldName")""") + + override def toString: String = { + walkedPaths.mkString("\n") + } + + private def newInstance(newRecord: String): WalkedTypePath = + new WalkedTypePath(newRecord +: walkedPaths) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePathRecorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePathRecorder.scala deleted file mode 100644 index ca45168ffa7ee..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePathRecorder.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst - -object WalkedTypePathRecorder { - def recordRoot(className: String): Seq[String] = s"""- root class: "$className"""" :: Nil - - def recordOption(walkedTypePath: Seq[String], className: String): Seq[String] = - s"""- option value class: "$className"""" +: walkedTypePath - - def recordArray(walkedTypePath: Seq[String], elementClassName: String): Seq[String] = - s"""- array element class: "$elementClassName"""" +: walkedTypePath - - def recordMap( - walkedTypePath: Seq[String], - keyClassName: String, - valueClassName: String): Seq[String] = { - (s"""- map key class: "$keyClassName"""" + - s""", value class: "$valueClassName"""") +: walkedTypePath - } - - def recordKeyForMap(walkedTypePath: Seq[String], keyClassName: String): Seq[String] = - s"""- map key class: "$keyClassName"""" +: walkedTypePath - - def recordValueForMap(walkedTypePath: Seq[String], valueClassName: String): Seq[String] = - s"""- map value class: "$valueClassName"""" +: walkedTypePath - - def recordField( - walkedTypePath: Seq[String], - className: String, - fieldName: String): Seq[String] = { - s"""- field (class: "$className", name: "$fieldName")""" +: walkedTypePath - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 42904c5c04c3c..68901e4ba8b6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2348,7 +2348,7 @@ class Analyzer( } else { // always add an UpCast. it will be removed in the optimizer if it is unnecessary. Some(Alias( - UpCast(queryExpr, tableAttr.dataType, Seq()), tableAttr.name + UpCast(queryExpr, tableAttr.dataType), tableAttr.name )( explicitMetadata = Option(tableAttr.metadata) )) @@ -2528,14 +2528,14 @@ class Analyzer( * Replace the [[UpCast]] expression by [[Cast]], and throw exceptions if the cast may truncate. 
*/ object ResolveUpCast extends Rule[LogicalPlan] { - private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { + private def fail(from: Expression, to: DataType, walkedTypePath: WalkedTypePath) = { val fromStr = from match { case l: LambdaVariable => "array element" case e => e.sql } throw new AnalysisException(s"Cannot up cast $fromStr from " + s"${from.dataType.catalogString} to ${to.catalogString} as it may truncate\n" + - "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + + "The type path of the target object is:\n\n" + walkedTypePath + "\n" + "You can either add an explicit cast to the input data or choose a higher precision " + "type of the field in the target object") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index da5c1fd0feb01..8701633711341 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -21,7 +21,7 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.sql.Encoder -import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection} +import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection, WalkedTypePath} import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} @@ -195,7 +195,7 @@ case class ExpressionEncoder[T]( case r: BoundReference => // For input object of Product type, we can't encode it to row if it's null, as Spark SQL // doesn't allow top-level row to be null, only its columns can be null. 
- AssertNotNull(r, Seq("top level Product or row object")) + AssertNotNull(r, new WalkedTypePath(Seq("top level Product or row object"))) } nullSafeSerializer match { case If(_: IsNull, _, s: CreateNamedStruct) => s diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 68a603b95ad50..97709bd6d1882 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -155,7 +155,7 @@ object RowEncoder { element => { val value = serializerFor(ValidateExternalType(element, et), et) if (!containsNull) { - AssertNotNull(value, Seq.empty) + AssertNotNull(value) } else { value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d591c588a95e3..f2e0319e3700e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,7 +21,7 @@ import java.math.{BigDecimal => JavaBigDecimal} import java.util.concurrent.TimeUnit._ import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{InternalRow, WalkedTypePath} import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ @@ -1378,7 +1378,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String * Cast the child expression to the target data type, but will throw error if the cast might * truncate, e.g. long -> int, timestamp -> data. */ -case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String]) +case class UpCast( + child: Expression, + dataType: DataType, + walkedTypePath: WalkedTypePath = new WalkedTypePath()) extends UnaryExpression with Unevaluable { override lazy val resolved = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 8182730feb4b4..ceebbd22fe3f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -28,7 +28,7 @@ import scala.util.Try import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection, WalkedTypePath} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} import org.apache.spark.sql.catalyst.encoders.RowEncoder @@ -1627,7 +1627,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all * non-null `s`, `s.i` can't be null. 
*/ -case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) +case class AssertNotNull(child: Expression, walkedTypePath: WalkedTypePath = new WalkedTypePath()) extends UnaryExpression with NonSQLExpression { override def dataType: DataType = child.dataType @@ -1637,7 +1637,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) override def flatArguments: Iterator[Any] = Iterator(child) private val errMsg = "Null value appeared in non-nullable field:" + - walkedTypePath.mkString("\n", "\n", "\n") + + s"\n$walkedTypePath\n" + "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + "please try to use scala.Option[_] or other nullable types " + "(e.g. java.lang.Integer instead of int/scala.Int)." diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 7e6fe5b4e2069..cf08d351801c3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -338,7 +338,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { test("should not apply common subexpression elimination on conditional expressions") { val row = InternalRow(null) val bound = BoundReference(0, IntegerType, true) - val assertNotNull = AssertNotNull(bound, Nil) + val assertNotNull = AssertNotNull(bound) val expr = If(IsNull(bound), Literal(1), Add(assertNotNull, assertNotNull)) val projection = GenerateUnsafeProjection.generate( Seq(expr), subexpressionEliminationEnabled = true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index b7ce367230810..49fd59c8694f1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -53,7 +53,7 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("AssertNotNUll") { val ex = intercept[RuntimeException] { - evaluateWithoutCodegen(AssertNotNull(Literal(null), Seq.empty[String])) + evaluateWithoutCodegen(AssertNotNull(Literal(null))) }.getMessage assert(ex.contains("Null value appeared in non-nullable field")) } From 1970e500b9698838617539e6a8b27f9e324d0f95 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Thu, 28 Feb 2019 08:22:55 +0900 Subject: [PATCH 06/14] Fixed string issue --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 68901e4ba8b6a..c88464d27cb37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2535,7 +2535,7 @@ class Analyzer( } throw new AnalysisException(s"Cannot up cast $fromStr from " + s"${from.dataType.catalogString} to ${to.catalogString} as it may truncate\n" + - "The type path of the target object is:\n\n" + walkedTypePath + "\n" + + "The type 
path of the target object is:\n" + walkedTypePath + "\n" + + "You can either add an explicit cast to the input data or choose a higher precision " + "type of the field in the target object") }
From 6b26513487b6620dcc773cbc2704a77049359f3e Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Thu, 28 Feb 2019 08:32:01 +0900 Subject: [PATCH 07/14] Change WalkedTypePath to case class so that `equals` is properly implemented --- .../scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala index c7cf05d434ca6..c632bc4eebfb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst -class WalkedTypePath(val walkedPaths: Seq[String] = Nil) extends Serializable { +case class WalkedTypePath(walkedPaths: Seq[String] = Nil) extends Serializable { def recordRoot(className: String): WalkedTypePath = newInstance(s"""- root class: "$className"""") @@ -46,5 +46,5 @@ class WalkedTypePath(val walkedPaths: Seq[String] = Nil) extends Serializable { } private def newInstance(newRecord: String): WalkedTypePath = - new WalkedTypePath(newRecord +: walkedPaths) + WalkedTypePath(newRecord +: walkedPaths) }
From c67826afb738bfe5b74447ce91e3e6ade9f9a484 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Thu, 28 Feb 2019 13:42:57 +0900 Subject: [PATCH 08/14] address review comments from cloud-fan - replace function parameters which take a function creating an instance with the actual instance where the indirection isn't really necessary - use recordRoot for initializing WalkedTypePath - use a mutable list for WalkedTypePath --- .../catalyst/DeserializerBuildHelper.scala | 14 ++-- .../sql/catalyst/JavaTypeInference.scala | 35 +++++---- .../spark/sql/catalyst/ScalaReflection.scala | 78 ++++++++++--------- .../spark/sql/catalyst/WalkedTypePath.scala | 49 +++++++----- .../catalyst/encoders/ExpressionEncoder.scala | 5 +- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../expressions/objects/objects.scala | 2 +- 7 files changed, 107 insertions(+), 78 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index 5245caed6b516..287a1c3e35d8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -45,12 +45,10 @@ object DeserializerBuildHelper { } def deserializerForWithNullSafety( - expr: Expression, + newExpr: Expression, dataType: DataType, nullable: Boolean, - walkedTypePath: WalkedTypePath, - funcForCreatingNewExpr: (Expression, WalkedTypePath) => Expression): Expression = { - val newExpr = funcForCreatingNewExpr(expr, walkedTypePath) + walkedTypePath: WalkedTypePath): Expression = { expressionWithNullSafety(newExpr, nullable, walkedTypePath) } @@ -61,8 +59,8 @@ object DeserializerBuildHelper { walkedTypePath: WalkedTypePath, funcForCreatingNewExpr: (Expression, WalkedTypePath) => Expression): Expression = { val casted = upCastToExpectedType(expr, dataType, walkedTypePath) - deserializerForWithNullSafety(casted, dataType,
walkedTypePath, - funcForCreatingNewExpr) + deserializerForWithNullSafety(funcForCreatingNewExpr(casted, walkedTypePath), dataType, + nullable, walkedTypePath) } private def expressionWithNullSafety( @@ -72,7 +70,7 @@ object DeserializerBuildHelper { if (nullable) { expr } else { - AssertNotNull(expr, walkedTypePath) + AssertNotNull(expr, walkedTypePath.copy()) } } @@ -171,6 +169,6 @@ object DeserializerBuildHelper { case _: StructType => expr case _: ArrayType => expr case _: MapType => expr - case _ => UpCast(expr, expected, walkedTypePath) + case _ => UpCast(expr, expected, walkedTypePath.copy()) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 4865429bc23be..970a8b23f6b92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -195,7 +195,8 @@ object JavaTypeInference { */ def deserializerFor(beanClass: Class[_]): Expression = { val typeToken = TypeToken.of(beanClass) - val walkedTypePath = new WalkedTypePath().recordRoot(beanClass.getCanonicalName) + val walkedTypePath = WalkedTypePath() + walkedTypePath.recordRoot(beanClass.getCanonicalName) val (dataType, nullable) = inferDataType(typeToken) // Assumes we are deserializing the first column of a row. @@ -244,7 +245,7 @@ object JavaTypeInference { case c if c.isArray => val elementType = c.getComponentType - val newTypePath = walkedTypePath.recordArray(elementType.getCanonicalName) + walkedTypePath.recordArray(elementType.getCanonicalName) val (dataType, elementNullable) = inferDataType(elementType) val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. @@ -252,7 +253,7 @@ object JavaTypeInference { element, dataType, nullable = elementNullable, - newTypePath, + walkedTypePath, (casted, typePath) => deserializerFor(typeToken.getComponentType, casted, typePath)) } @@ -273,7 +274,7 @@ object JavaTypeInference { case c if listType.isAssignableFrom(typeToken) => val et = elementType(typeToken) - val newTypePath = walkedTypePath.recordArray(et.getType.getTypeName) + walkedTypePath.recordArray(et.getType.getTypeName) val (dataType, elementNullable) = inferDataType(et) val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. 
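// ---------------------------------------------------------------------------
// Editor's aside -- an illustrative sketch, not part of the patch. With the
// mutable WalkedTypePath this commit introduces, every record* call appends
// to a shared buffer in place, which is why the surrounding hunks take a
// copy() before recording at each branch point (map key vs. value, one bean
// field vs. the next). The class and field names below are made up.
// ---------------------------------------------------------------------------
import org.apache.spark.sql.catalyst.WalkedTypePath

object MutableWalkedTypePathSketch {
  def main(args: Array[String]): Unit = {
    val root = WalkedTypePath()
    root.recordRoot("com.example.Person")

    // Without copy(), keyPath and valuePath would share one buffer and each
    // would end up containing both records.
    val keyPath = root.copy()
    keyPath.recordKeyForMap("java.lang.String")
    val valuePath = root.copy()
    valuePath.recordValueForMap("java.lang.Integer")

    // toString reverses the buffer, so the most recent step prints first:
    //   - map key class: "java.lang.String"
    //   - root class: "com.example.Person"
    println(keyPath)
  }
}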
@@ -281,7 +282,7 @@ object JavaTypeInference { element, dataType, nullable = elementNullable, - newTypePath, + walkedTypePath, (casted, typePath) => deserializerFor(et, casted, typePath)) } @@ -289,13 +290,16 @@ object JavaTypeInference { case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) - val newTypePath = walkedTypePath.recordMap(keyType.getType.getTypeName, + walkedTypePath.recordMap(keyType.getType.getTypeName, valueType.getType.getTypeName) + val newTypePathForKey = walkedTypePath.copy() + val newTypePathForValue = walkedTypePath.copy() + val keyData = Invoke( UnresolvedMapObjects( - p => deserializerFor(keyType, p, newTypePath), + p => deserializerFor(keyType, p, newTypePathForKey), MapKeys(path)), "array", ObjectType(classOf[Array[Any]])) @@ -303,7 +307,7 @@ object JavaTypeInference { val valueData = Invoke( UnresolvedMapObjects( - p => deserializerFor(valueType, p, newTypePath), + p => deserializerFor(valueType, p, newTypePathForValue), MapValues(path)), "array", ObjectType(classOf[Array[Any]])) @@ -326,14 +330,14 @@ object JavaTypeInference { val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType val (dataType, nullable) = inferDataType(fieldType) - val newTypePath = walkedTypePath.recordField(fieldType.getType.getTypeName, fieldName) + val newTypePathForField = walkedTypePath.copy() + newTypePathForField.recordField(fieldType.getType.getTypeName, fieldName) val setter = deserializerForWithNullSafety( - path, + deserializerFor(fieldType, addToPath(path, fieldName, dataType, newTypePathForField), + newTypePathForField), dataType, nullable = nullable, - newTypePath, - (expr, typePath) => deserializerFor(fieldType, - addToPath(expr, fieldName, dataType, typePath), typePath)) + newTypePathForField) p.getWriteMethod.getName -> setter }.toMap @@ -355,7 +359,10 @@ object JavaTypeInference { */ def serializerFor(beanClass: Class[_]): Expression = { val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) - val nullSafeInput = AssertNotNull(inputObject, new WalkedTypePath(Seq("top level input bean"))) + val walkedTypePath = WalkedTypePath() + walkedTypePath.recordRoot("top level input bean") + // not copying walkedTypePath since the instance will be only used here + val nullSafeInput = AssertNotNull(inputObject, walkedTypePath) serializerFor(nullSafeInput, TypeToken.of(beanClass)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 32c54af771585..dc2586effb980 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -137,7 +137,8 @@ object ScalaReflection extends ScalaReflection { */ def deserializerForType(tpe: `Type`): Expression = { val clsName = getClassNameFromType(tpe) - val walkedTypePath = new WalkedTypePath().recordRoot(clsName) + val walkedTypePath = WalkedTypePath() + walkedTypePath.recordRoot(clsName) val Schema(dataType, nullable) = schemaFor(tpe) // Assumes we are deserializing the first column of a row. 
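// ---------------------------------------------------------------------------
// Editor's aside -- an illustrative sketch, not part of the patch. The reason
// walkedTypePath is threaded through every one of these deserializer hunks is
// the runtime diagnostic: when AssertNotNull hits a null in a non-nullable
// field, the recorded path is what the user sees. The class and field names
// below are made up.
// ---------------------------------------------------------------------------
import org.apache.spark.sql.catalyst.WalkedTypePath

object WalkedTypePathDiagnosticSketch {
  def main(args: Array[String]): Unit = {
    val path = WalkedTypePath()
    path.recordRoot("com.example.Person")
    path.recordField("scala.Int", "age")

    // An AssertNotNull built with path.copy() would report roughly:
    //   Null value appeared in non-nullable field:
    //   - field (class: "scala.Int", name: "age")
    //   - root class: "com.example.Person"
    //   If the schema is inferred from a Scala tuple/case class, or a Java
    //   bean, please try to use scala.Option[_] or other nullable types ...
    println(s"Null value appeared in non-nullable field:\n$path\n")
  }
}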
@@ -164,8 +165,8 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) - val newTypePath = walkedTypePath.recordOption(className) - WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType)) + walkedTypePath.recordOption(className) + WrapOption(deserializerFor(optType, path, walkedTypePath), dataTypeFor(optType)) case t if t <:< localTypeOf[java.lang.Integer] => createDeserializerForTypesSupportValueOf(path, @@ -226,7 +227,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) - val newTypePath = walkedTypePath.recordArray(className) + walkedTypePath.recordArray(className) val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. @@ -234,7 +235,7 @@ object ScalaReflection extends ScalaReflection { element, dataType, nullable = elementNullable, - newTypePath, + walkedTypePath, (casted, typePath) => deserializerFor(elementType, casted, typePath)) } @@ -261,14 +262,14 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) - val newTypePath = walkedTypePath.recordArray(className) + walkedTypePath.recordArray(className) val mapFunction: Expression => Expression = element => { deserializerForWithNullSafetyAndUpcast( element, dataType, nullable = elementNullable, - newTypePath, + walkedTypePath, (casted, typePath) => deserializerFor(elementType, casted, typePath)) } @@ -287,12 +288,15 @@ object ScalaReflection extends ScalaReflection { val classNameForKey = getClassNameFromType(keyType) val classNameForValue = getClassNameFromType(valueType) - val newTypePath = walkedTypePath.recordMap(classNameForKey, classNameForValue) + val newPathForKey = walkedTypePath.copy() + newPathForKey.recordMap(classNameForKey, classNameForValue) + val newPathForValue = walkedTypePath.copy() + newPathForValue.recordMap(classNameForKey, classNameForValue) UnresolvedCatalystToExternalMap( path, - p => deserializerFor(keyType, p, newTypePath), - p => deserializerFor(valueType, p, newTypePath), + p => deserializerFor(keyType, p, newPathForKey), + p => deserializerFor(valueType, p, newPathForValue), mirror.runtimeClass(t.typeSymbol.asClass) ) @@ -322,27 +326,26 @@ object ScalaReflection extends ScalaReflection { val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) => val Schema(dataType, nullable) = schemaFor(fieldType) val clsName = getClassNameFromType(fieldType) - val newTypePath = walkedTypePath.recordField(clsName, fieldName) + val newPathForField = walkedTypePath.copy() + newPathForField.recordField(clsName, fieldName) // For tuples, we based grab the inner fields by ordinal instead of name. 
+ val newPath = if (cls.getName startsWith "scala.Tuple") { + deserializerFor( + fieldType, + addToPathOrdinal(path, i, dataType, newPathForField), + newPathForField) + } else { + deserializerFor( + fieldType, + addToPath(path, fieldName, dataType, newPathForField), + newPathForField) + } deserializerForWithNullSafety( - path, + newPath, dataType, nullable = nullable, - newTypePath, - (expr, typePath) => { - if (cls.getName startsWith "scala.Tuple") { - deserializerFor( - fieldType, - addToPathOrdinal(expr, i, dataType, typePath), - newTypePath) - } else { - deserializerFor( - fieldType, - addToPath(expr, fieldName, dataType, typePath), - newTypePath) - } - }) + newPathForField) } val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) @@ -370,7 +373,8 @@ object ScalaReflection extends ScalaReflection { */ def serializerForType(tpe: `Type`): Expression = ScalaReflection.cleanUpReflectionObjects { val clsName = getClassNameFromType(tpe) - val walkedTypePath = new WalkedTypePath().recordRoot(clsName) + val walkedTypePath = WalkedTypePath() + walkedTypePath.recordRoot(clsName) // The input object to `ExpressionEncoder` is located at first column of an row. val isPrimitive = tpe.typeSymbol.asClass.isPrimitive @@ -393,11 +397,11 @@ object ScalaReflection extends ScalaReflection { dataTypeFor(elementType) match { case dt: ObjectType => val clsName = getClassNameFromType(elementType) - val newPath = walkedTypePath.recordArray(clsName) + walkedTypePath.recordArray(clsName) createSerializerForMapObjects(input, dt, - serializerFor(_, elementType, newPath, seenTypeSet)) + serializerFor(_, elementType, walkedTypePath, seenTypeSet)) - case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | + case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType) => val cls = input.dataType.asInstanceOf[ObjectType].cls if (cls.isArray && cls.getComponentType.isPrimitive) { @@ -417,9 +421,9 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) - val newPath = walkedTypePath.recordOption(className) + walkedTypePath.recordOption(className) val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject) - serializerFor(unwrapped, optType, newPath, seenTypeSet) + serializerFor(unwrapped, optType, walkedTypePath, seenTypeSet) // Since List[_] also belongs to localTypeOf[Product], we put this case before // "case t if definedByConstructorParams(t)" to make sure it will match to the @@ -436,8 +440,11 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(keyType, valueType)) = t val keyClsName = getClassNameFromType(keyType) val valueClsName = getClassNameFromType(valueType) - val keyPath = walkedTypePath.recordKeyForMap(keyClsName) - val valuePath = walkedTypePath.recordValueForMap(valueClsName) + + val keyPath = walkedTypePath.copy() + keyPath.recordKeyForMap(keyClsName) + val valuePath = walkedTypePath.copy() + valuePath.recordValueForMap(valueClsName) createSerializerForMap( inputObject, @@ -526,7 +533,8 @@ object ScalaReflection extends ScalaReflection { val fieldValue = Invoke(KnownNotNull(inputObject), fieldName, dataTypeFor(fieldType), returnNullable = !fieldType.typeSymbol.asClass.isPrimitive) val clsName = getClassNameFromType(fieldType) - val newPath = walkedTypePath.recordField(clsName, fieldName) + val newPath = walkedTypePath.copy() + newPath.recordField(clsName, fieldName) (fieldName, 
serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t)) } createSerializerForObject(inputObject, fields) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala index c632bc4eebfb4..1a0b3a498fbb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala @@ -17,34 +17,47 @@ package org.apache.spark.sql.catalyst -case class WalkedTypePath(walkedPaths: Seq[String] = Nil) extends Serializable { - def recordRoot(className: String): WalkedTypePath = - newInstance(s"""- root class: "$className"""") +import scala.collection.mutable - def recordOption(className: String): WalkedTypePath = - newInstance(s"""- option value class: "$className"""") +case class WalkedTypePath() extends Serializable { + val walkedPaths: mutable.ArrayBuffer[String] = new mutable.ArrayBuffer[String]() - def recordArray(elementClassName: String): WalkedTypePath = - newInstance(s"""- array element class: "$elementClassName"""") + def recordRoot(className: String): Unit = + record(s"""- root class: "$className"""") - def recordMap(keyClassName: String, valueClassName: String): WalkedTypePath = { - newInstance(s"""- map key class: "$keyClassName"""" + + def recordOption(className: String): Unit = + record(s"""- option value class: "$className"""") + + def recordArray(elementClassName: String): Unit = + record(s"""- array element class: "$elementClassName"""") + + def recordMap(keyClassName: String, valueClassName: String): Unit = { + record(s"""- map key class: "$keyClassName"""" + s""", value class: "$valueClassName"""") } - def recordKeyForMap(keyClassName: String): WalkedTypePath = - newInstance(s"""- map key class: "$keyClassName"""") + def recordKeyForMap(keyClassName: String): Unit = + record(s"""- map key class: "$keyClassName"""") - def recordValueForMap(valueClassName: String): WalkedTypePath = - newInstance(s"""- map value class: "$valueClassName"""") + def recordValueForMap(valueClassName: String): Unit = + record(s"""- map value class: "$valueClassName"""") - def recordField(className: String, fieldName: String): WalkedTypePath = - newInstance(s"""- field (class: "$className", name: "$fieldName")""") + def recordField(className: String, fieldName: String): Unit = + record(s"""- field (class: "$className", name: "$fieldName")""") + + def copy(): WalkedTypePath = { + val copied = WalkedTypePath() + copied.walkedPaths ++= walkedPaths + copied + } override def toString: String = { - walkedPaths.mkString("\n") + // to speed up appending element we are adding element at last and apply reverse + // just before printing it out + walkedPaths.reverse.mkString("\n") } - private def newInstance(newRecord: String): WalkedTypePath = - WalkedTypePath(newRecord +: walkedPaths) + private def record(newRecord: String): Unit = { + walkedPaths += newRecord + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 8701633711341..48e8e18ba63cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -195,7 +195,10 @@ case class ExpressionEncoder[T]( case r: BoundReference => // For input object of Product 
type, we can't encode it to row if it's null, as Spark SQL // doesn't allow top-level row to be null, only its columns can be null. - AssertNotNull(r, new WalkedTypePath(Seq("top level Product or row object"))) + val walkedTypePath = WalkedTypePath() + walkedTypePath.recordRoot("top level Product or row object") + // not copying walkedTypePath since the instance will be only used here + AssertNotNull(r, walkedTypePath) } nullSafeSerializer match { case If(_: IsNull, _, s: CreateNamedStruct) => s diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f2e0319e3700e..9d2726c80046a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1381,7 +1381,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case class UpCast( child: Expression, dataType: DataType, - walkedTypePath: WalkedTypePath = new WalkedTypePath()) + walkedTypePath: WalkedTypePath = WalkedTypePath()) extends UnaryExpression with Unevaluable { override lazy val resolved = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index ceebbd22fe3f2..32ab2fac58737 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1627,7 +1627,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all * non-null `s`, `s.i` can't be null. 
*/ -case class AssertNotNull(child: Expression, walkedTypePath: WalkedTypePath = new WalkedTypePath()) +case class AssertNotNull(child: Expression, walkedTypePath: WalkedTypePath = WalkedTypePath()) extends UnaryExpression with NonSQLExpression { override def dataType: DataType = child.dataType From 852debda5c634f12ddacda42bab4d0f69596b753 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Fri, 1 Mar 2019 08:37:07 +0900 Subject: [PATCH 09/14] Remove unnecessary method --- .../sql/catalyst/DeserializerBuildHelper.scala | 14 +++----------- .../spark/sql/catalyst/JavaTypeInference.scala | 3 +-- .../spark/sql/catalyst/ScalaReflection.scala | 3 +-- 3 files changed, 5 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index 287a1c3e35d8b..ed5533a60f523 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -44,26 +44,18 @@ object DeserializerBuildHelper { upCastToExpectedType(newPath, dataType, walkedTypePath) } - def deserializerForWithNullSafety( - newExpr: Expression, - dataType: DataType, - nullable: Boolean, - walkedTypePath: WalkedTypePath): Expression = { - expressionWithNullSafety(newExpr, nullable, walkedTypePath) - } - def deserializerForWithNullSafetyAndUpcast( expr: Expression, dataType: DataType, nullable: Boolean, walkedTypePath: WalkedTypePath, - funcForCreatingNewExpr: (Expression, WalkedTypePath) => Expression): Expression = { + funcForCreatingDeserializer: (Expression, WalkedTypePath) => Expression): Expression = { val casted = upCastToExpectedType(expr, dataType, walkedTypePath) - deserializerForWithNullSafety(funcForCreatingNewExpr(casted, walkedTypePath), dataType, + expressionWithNullSafety(funcForCreatingDeserializer(casted, walkedTypePath), nullable, walkedTypePath) } - private def expressionWithNullSafety( + def expressionWithNullSafety( expr: Expression, nullable: Boolean, walkedTypePath: WalkedTypePath): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 970a8b23f6b92..959369129caac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -332,10 +332,9 @@ object JavaTypeInference { val (dataType, nullable) = inferDataType(fieldType) val newTypePathForField = walkedTypePath.copy() newTypePathForField.recordField(fieldType.getType.getTypeName, fieldName) - val setter = deserializerForWithNullSafety( + val setter = expressionWithNullSafety( deserializerFor(fieldType, addToPath(path, fieldName, dataType, newTypePathForField), newTypePathForField), - dataType, nullable = nullable, newTypePathForField) p.getWriteMethod.getName -> setter diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index dc2586effb980..9fa9e589fb99b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -341,9 +341,8 @@ object ScalaReflection extends 
ScalaReflection { addToPath(path, fieldName, dataType, newPathForField), newPathForField) } - deserializerForWithNullSafety( + expressionWithNullSafety( newPath, - dataType, nullable = nullable, newPathForField) }
From e0d749543c131247a9bebcbc5d1c3de573b0f9f8 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Fri, 1 Mar 2019 16:55:49 +0900 Subject: [PATCH 10/14] Revert "address review comments from cloud-fan" This reverts commit c67826afb738bfe5b74447ce91e3e6ade9f9a484. NOTE: there's a conflict that keeps this revert commit from being a clean reversal of the original, but WalkedTypePath itself is cleanly reverted --- .../catalyst/DeserializerBuildHelper.scala | 4 +- .../sql/catalyst/JavaTypeInference.scala | 34 ++++------- .../spark/sql/catalyst/ScalaReflection.scala | 60 ++++++++----------- .../spark/sql/catalyst/WalkedTypePath.scala | 49 ++++++--------- .../catalyst/encoders/ExpressionEncoder.scala | 5 +- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../expressions/objects/objects.scala | 2 +- 7 files changed, 61 insertions(+), 95 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index ed5533a60f523..11db0dfb42188 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -62,7 +62,7 @@ object DeserializerBuildHelper { if (nullable) { expr } else { - AssertNotNull(expr, walkedTypePath.copy()) + AssertNotNull(expr, walkedTypePath) } } @@ -161,6 +161,6 @@ object DeserializerBuildHelper { case _: StructType => expr case _: ArrayType => expr case _: MapType => expr - case _ => UpCast(expr, expected, walkedTypePath.copy()) + case _ => UpCast(expr, expected, walkedTypePath) } }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 959369129caac..a2cdfa2f74e8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -195,8 +195,7 @@ object JavaTypeInference { */ def deserializerFor(beanClass: Class[_]): Expression = { val typeToken = TypeToken.of(beanClass) - val walkedTypePath = WalkedTypePath() - walkedTypePath.recordRoot(beanClass.getCanonicalName) + val walkedTypePath = new WalkedTypePath().recordRoot(beanClass.getCanonicalName) val (dataType, nullable) = inferDataType(typeToken) // Assumes we are deserializing the first column of a row. @@ -245,7 +244,7 @@ object JavaTypeInference { case c if c.isArray => val elementType = c.getComponentType - walkedTypePath.recordArray(elementType.getCanonicalName) + val newTypePath = walkedTypePath.recordArray(elementType.getCanonicalName) val (dataType, elementNullable) = inferDataType(elementType) val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected.
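// ---------------------------------------------------------------------------
// Editor's aside -- an illustrative sketch, not part of the patch. With the
// immutable WalkedTypePath restored by this revert, each record* call returns
// a fresh instance, so sibling branches just keep their own return values and
// the copy() calls of the reverted commit become unnecessary. The class and
// field names below are made up.
// ---------------------------------------------------------------------------
import org.apache.spark.sql.catalyst.WalkedTypePath

object ImmutableWalkedTypePathSketch {
  def main(args: Array[String]): Unit = {
    val root = new WalkedTypePath().recordRoot("com.example.Person")

    // root stays untouched; each branch prepends only its own step.
    val keyPath = root.recordKeyForMap("java.lang.String")
    val valuePath = root.recordValueForMap("java.lang.Integer")

    // toString prints the prepended entries as-is: most recent step first,
    // root last.
    println(valuePath)
  }
}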
@@ -253,7 +252,7 @@ object JavaTypeInference { element, dataType, nullable = elementNullable, - walkedTypePath, + newTypePath, (casted, typePath) => deserializerFor(typeToken.getComponentType, casted, typePath)) } @@ -274,7 +273,7 @@ object JavaTypeInference { case c if listType.isAssignableFrom(typeToken) => val et = elementType(typeToken) - walkedTypePath.recordArray(et.getType.getTypeName) + val newTypePath = walkedTypePath.recordArray(et.getType.getTypeName) val (dataType, elementNullable) = inferDataType(et) val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. @@ -282,7 +281,7 @@ object JavaTypeInference { element, dataType, nullable = elementNullable, - walkedTypePath, + newTypePath, (casted, typePath) => deserializerFor(et, casted, typePath)) } @@ -290,16 +289,13 @@ object JavaTypeInference { case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) - walkedTypePath.recordMap(keyType.getType.getTypeName, + val newTypePath = walkedTypePath.recordMap(keyType.getType.getTypeName, valueType.getType.getTypeName) - val newTypePathForKey = walkedTypePath.copy() - val newTypePathForValue = walkedTypePath.copy() - val keyData = Invoke( UnresolvedMapObjects( - p => deserializerFor(keyType, p, newTypePathForKey), + p => deserializerFor(keyType, p, newTypePath), MapKeys(path)), "array", ObjectType(classOf[Array[Any]])) @@ -307,7 +303,7 @@ object JavaTypeInference { val valueData = Invoke( UnresolvedMapObjects( - p => deserializerFor(valueType, p, newTypePathForValue), + p => deserializerFor(valueType, p, newTypePath), MapValues(path)), "array", ObjectType(classOf[Array[Any]])) @@ -330,13 +326,12 @@ object JavaTypeInference { val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType val (dataType, nullable) = inferDataType(fieldType) - val newTypePathForField = walkedTypePath.copy() - newTypePathForField.recordField(fieldType.getType.getTypeName, fieldName) + val newTypePath = walkedTypePath.recordField(fieldType.getType.getTypeName, fieldName) val setter = expressionWithNullSafety( - deserializerFor(fieldType, addToPath(path, fieldName, dataType, newTypePathForField), - newTypePathForField), + deserializerFor(fieldType, addToPath(path, fieldName, dataType, newTypePath), + newTypePath), nullable = nullable, - newTypePathForField) + newTypePath) p.getWriteMethod.getName -> setter }.toMap @@ -358,10 +353,7 @@ object JavaTypeInference { */ def serializerFor(beanClass: Class[_]): Expression = { val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) - val walkedTypePath = WalkedTypePath() - walkedTypePath.recordRoot("top level input bean") - // not copying walkedTypePath since the instance will be only used here - val nullSafeInput = AssertNotNull(inputObject, walkedTypePath) + val nullSafeInput = AssertNotNull(inputObject, new WalkedTypePath(Seq("top level input bean"))) serializerFor(nullSafeInput, TypeToken.of(beanClass)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 9fa9e589fb99b..5b3109af6a53a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -137,8 +137,7 @@ object ScalaReflection extends ScalaReflection { */ def deserializerForType(tpe: `Type`): Expression = { val 
clsName = getClassNameFromType(tpe) - val walkedTypePath = WalkedTypePath() - walkedTypePath.recordRoot(clsName) + val walkedTypePath = new WalkedTypePath().recordRoot(clsName) val Schema(dataType, nullable) = schemaFor(tpe) // Assumes we are deserializing the first column of a row. @@ -165,8 +164,8 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) - walkedTypePath.recordOption(className) - WrapOption(deserializerFor(optType, path, walkedTypePath), dataTypeFor(optType)) + val newTypePath = walkedTypePath.recordOption(className) + WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType)) case t if t <:< localTypeOf[java.lang.Integer] => createDeserializerForTypesSupportValueOf(path, @@ -227,7 +226,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) - walkedTypePath.recordArray(className) + val newTypePath = walkedTypePath.recordArray(className) val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. @@ -235,7 +234,7 @@ object ScalaReflection extends ScalaReflection { element, dataType, nullable = elementNullable, - walkedTypePath, + newTypePath, (casted, typePath) => deserializerFor(elementType, casted, typePath)) } @@ -262,14 +261,14 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) - walkedTypePath.recordArray(className) + val newTypePath = walkedTypePath.recordArray(className) val mapFunction: Expression => Expression = element => { deserializerForWithNullSafetyAndUpcast( element, dataType, nullable = elementNullable, - walkedTypePath, + newTypePath, (casted, typePath) => deserializerFor(elementType, casted, typePath)) } @@ -288,15 +287,12 @@ object ScalaReflection extends ScalaReflection { val classNameForKey = getClassNameFromType(keyType) val classNameForValue = getClassNameFromType(valueType) - val newPathForKey = walkedTypePath.copy() - newPathForKey.recordMap(classNameForKey, classNameForValue) - val newPathForValue = walkedTypePath.copy() - newPathForValue.recordMap(classNameForKey, classNameForValue) + val newTypePath = walkedTypePath.recordMap(classNameForKey, classNameForValue) UnresolvedCatalystToExternalMap( path, - p => deserializerFor(keyType, p, newPathForKey), - p => deserializerFor(valueType, p, newPathForValue), + p => deserializerFor(keyType, p, newTypePath), + p => deserializerFor(valueType, p, newTypePath), mirror.runtimeClass(t.typeSymbol.asClass) ) @@ -326,25 +322,24 @@ object ScalaReflection extends ScalaReflection { val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) => val Schema(dataType, nullable) = schemaFor(fieldType) val clsName = getClassNameFromType(fieldType) - val newPathForField = walkedTypePath.copy() - newPathForField.recordField(clsName, fieldName) + val newTypePath = walkedTypePath.recordField(clsName, fieldName) // For tuples, we based grab the inner fields by ordinal instead of name. 
val newPath = if (cls.getName startsWith "scala.Tuple") { deserializerFor( fieldType, - addToPathOrdinal(path, i, dataType, newPathForField), - newPathForField) + addToPathOrdinal(path, i, dataType, newTypePath), + newTypePath) } else { deserializerFor( fieldType, - addToPath(path, fieldName, dataType, newPathForField), - newPathForField) + addToPath(path, fieldName, dataType, newTypePath), + newTypePath) } expressionWithNullSafety( newPath, nullable = nullable, - newPathForField) + newTypePath) } val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) @@ -372,8 +367,7 @@ object ScalaReflection extends ScalaReflection { */ def serializerForType(tpe: `Type`): Expression = ScalaReflection.cleanUpReflectionObjects { val clsName = getClassNameFromType(tpe) - val walkedTypePath = WalkedTypePath() - walkedTypePath.recordRoot(clsName) + val walkedTypePath = new WalkedTypePath().recordRoot(clsName) // The input object to `ExpressionEncoder` is located at first column of an row. val isPrimitive = tpe.typeSymbol.asClass.isPrimitive @@ -396,11 +390,11 @@ object ScalaReflection extends ScalaReflection { dataTypeFor(elementType) match { case dt: ObjectType => val clsName = getClassNameFromType(elementType) - walkedTypePath.recordArray(clsName) + val newPath = walkedTypePath.recordArray(clsName) createSerializerForMapObjects(input, dt, - serializerFor(_, elementType, walkedTypePath, seenTypeSet)) + serializerFor(_, elementType, newPath, seenTypeSet)) - case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | + case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType) => val cls = input.dataType.asInstanceOf[ObjectType].cls if (cls.isArray && cls.getComponentType.isPrimitive) { @@ -420,9 +414,9 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) - walkedTypePath.recordOption(className) + val newPath = walkedTypePath.recordOption(className) val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject) - serializerFor(unwrapped, optType, walkedTypePath, seenTypeSet) + serializerFor(unwrapped, optType, newPath, seenTypeSet) // Since List[_] also belongs to localTypeOf[Product], we put this case before // "case t if definedByConstructorParams(t)" to make sure it will match to the @@ -439,11 +433,8 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(keyType, valueType)) = t val keyClsName = getClassNameFromType(keyType) val valueClsName = getClassNameFromType(valueType) - - val keyPath = walkedTypePath.copy() - keyPath.recordKeyForMap(keyClsName) - val valuePath = walkedTypePath.copy() - valuePath.recordValueForMap(valueClsName) + val keyPath = walkedTypePath.recordKeyForMap(keyClsName) + val valuePath = walkedTypePath.recordValueForMap(valueClsName) createSerializerForMap( inputObject, @@ -532,8 +523,7 @@ object ScalaReflection extends ScalaReflection { val fieldValue = Invoke(KnownNotNull(inputObject), fieldName, dataTypeFor(fieldType), returnNullable = !fieldType.typeSymbol.asClass.isPrimitive) val clsName = getClassNameFromType(fieldType) - val newPath = walkedTypePath.copy() - newPath.recordField(clsName, fieldName) + val newPath = walkedTypePath.recordField(clsName, fieldName) (fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t)) } createSerializerForObject(inputObject, fields) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala index 1a0b3a498fbb6..c632bc4eebfb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala @@ -17,47 +17,34 @@ package org.apache.spark.sql.catalyst -import scala.collection.mutable +case class WalkedTypePath(walkedPaths: Seq[String] = Nil) extends Serializable { + def recordRoot(className: String): WalkedTypePath = + newInstance(s"""- root class: "$className"""") -case class WalkedTypePath() extends Serializable { - val walkedPaths: mutable.ArrayBuffer[String] = new mutable.ArrayBuffer[String]() + def recordOption(className: String): WalkedTypePath = + newInstance(s"""- option value class: "$className"""") - def recordRoot(className: String): Unit = - record(s"""- root class: "$className"""") + def recordArray(elementClassName: String): WalkedTypePath = + newInstance(s"""- array element class: "$elementClassName"""") - def recordOption(className: String): Unit = - record(s"""- option value class: "$className"""") - - def recordArray(elementClassName: String): Unit = - record(s"""- array element class: "$elementClassName"""") - - def recordMap(keyClassName: String, valueClassName: String): Unit = { - record(s"""- map key class: "$keyClassName"""" + + def recordMap(keyClassName: String, valueClassName: String): WalkedTypePath = { + newInstance(s"""- map key class: "$keyClassName"""" + s""", value class: "$valueClassName"""") } - def recordKeyForMap(keyClassName: String): Unit = - record(s"""- map key class: "$keyClassName"""") + def recordKeyForMap(keyClassName: String): WalkedTypePath = + newInstance(s"""- map key class: "$keyClassName"""") - def recordValueForMap(valueClassName: String): Unit = - record(s"""- map value class: "$valueClassName"""") + def recordValueForMap(valueClassName: String): WalkedTypePath = + newInstance(s"""- map value class: "$valueClassName"""") - def recordField(className: String, fieldName: String): Unit = - record(s"""- field (class: "$className", name: "$fieldName")""") - - def copy(): WalkedTypePath = { - val copied = WalkedTypePath() - copied.walkedPaths ++= walkedPaths - copied - } + def recordField(className: String, fieldName: String): WalkedTypePath = + newInstance(s"""- field (class: "$className", name: "$fieldName")""") override def toString: String = { - // to speed up appending element we are adding element at last and apply reverse - // just before printing it out - walkedPaths.reverse.mkString("\n") + walkedPaths.mkString("\n") } - private def record(newRecord: String): Unit = { - walkedPaths += newRecord - } + private def newInstance(newRecord: String): WalkedTypePath = + WalkedTypePath(newRecord +: walkedPaths) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 48e8e18ba63cc..8701633711341 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -195,10 +195,7 @@ case class ExpressionEncoder[T]( case r: BoundReference => // For input object of Product type, we can't encode it to row if it's null, as Spark SQL // doesn't allow top-level row to be null, only its columns can 
be null. - val walkedTypePath = WalkedTypePath() - walkedTypePath.recordRoot("top level Product or row object") - // not copying walkedTypePath since the instance will be only used here - AssertNotNull(r, walkedTypePath) + AssertNotNull(r, new WalkedTypePath(Seq("top level Product or row object"))) } nullSafeSerializer match { case If(_: IsNull, _, s: CreateNamedStruct) => s diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 9d2726c80046a..f2e0319e3700e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1381,7 +1381,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case class UpCast( child: Expression, dataType: DataType, - walkedTypePath: WalkedTypePath = WalkedTypePath()) + walkedTypePath: WalkedTypePath = new WalkedTypePath()) extends UnaryExpression with Unevaluable { override lazy val resolved = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 32ab2fac58737..ceebbd22fe3f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1627,7 +1627,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all * non-null `s`, `s.i` can't be null. 
*/ -case class AssertNotNull(child: Expression, walkedTypePath: WalkedTypePath = WalkedTypePath()) +case class AssertNotNull(child: Expression, walkedTypePath: WalkedTypePath = new WalkedTypePath()) extends UnaryExpression with NonSQLExpression { override def dataType: DataType = child.dataType From 65d2079638e1fa16888c73bc30e05ce28e353075 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Fri, 1 Mar 2019 17:00:02 +0900 Subject: [PATCH 11/14] Address review comment from cloud-fan (partially reapplying 90df8a3ebf3a228e4ff6d47bfd6f0ed98ad2b964) --- .../org/apache/spark/sql/catalyst/JavaTypeInference.scala | 3 ++- .../apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index a2cdfa2f74e8f..1ea9ed3f36fd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -353,7 +353,8 @@ object JavaTypeInference { */ def serializerFor(beanClass: Class[_]): Expression = { val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) - val nullSafeInput = AssertNotNull(inputObject, new WalkedTypePath(Seq("top level input bean"))) + val nullSafeInput = AssertNotNull(inputObject, + WalkedTypePath().recordRoot("top level input bean")) serializerFor(nullSafeInput, TypeToken.of(beanClass)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 8701633711341..8acdf38e2fc34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -195,7 +195,7 @@ case class ExpressionEncoder[T]( case r: BoundReference => // For input object of Product type, we can't encode it to row if it's null, as Spark SQL // doesn't allow top-level row to be null, only its columns can be null. - AssertNotNull(r, new WalkedTypePath(Seq("top level Product or row object"))) + AssertNotNull(r, WalkedTypePath().recordRoot("top level Product or row object")) } nullSafeSerializer match { case If(_: IsNull, _, s: CreateNamedStruct) => s From 578d8feff26565f7ee6a0be5ff85a244ff94d873 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Fri, 1 Mar 2019 17:07:14 +0900 Subject: [PATCH 12/14] Address review comments from maropu --- .../org/apache/spark/sql/catalyst/WalkedTypePath.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala index c632bc4eebfb4..c653236bfd63c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala @@ -17,7 +17,12 @@ package org.apache.spark.sql.catalyst -case class WalkedTypePath(walkedPaths: Seq[String] = Nil) extends Serializable { +/** + * This class records the paths the serializer and deserializer walk through to reach current path. 
+ * Note that this class prepends each new path to the recorded paths, so it maintains + * the paths in reverse order. + */ +case class WalkedTypePath(private val walkedPaths: Seq[String] = Nil) extends Serializable { def recordRoot(className: String): WalkedTypePath = newInstance(s"""- root class: "$className"""")
From 20e8d5a0b4608d3b0ccd4fe00553065d20b2b6fe Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Fri, 1 Mar 2019 18:04:37 +0900 Subject: [PATCH 13/14] Address review comments --- .../apache/spark/sql/catalyst/DeserializerBuildHelper.scala | 4 ++-- .../org/apache/spark/sql/catalyst/JavaTypeInference.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala | 2 ++ .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 4 ++-- .../spark/sql/catalyst/encoders/ExpressionEncoder.scala | 2 +- .../org/apache/spark/sql/catalyst/expressions/Cast.scala | 2 +- .../spark/sql/catalyst/expressions/objects/objects.scala | 4 ++-- 7 files changed, 11 insertions(+), 9 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index 11db0dfb42188..e55c25c4b0c54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -62,7 +62,7 @@ object DeserializerBuildHelper { if (nullable) { expr } else { - AssertNotNull(expr, walkedTypePath) + AssertNotNull(expr, walkedTypePath.getPaths) } } @@ -161,6 +161,6 @@ object DeserializerBuildHelper { case _: StructType => expr case _: ArrayType => expr case _: MapType => expr - case _ => UpCast(expr, expected, walkedTypePath) + case _ => UpCast(expr, expected, walkedTypePath.getPaths) } }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 1ea9ed3f36fd8..786cf10d3c127 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -354,7 +354,7 @@ object JavaTypeInference { def serializerFor(beanClass: Class[_]): Expression = { val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) val nullSafeInput = AssertNotNull(inputObject, - WalkedTypePath().recordRoot("top level input bean")) + WalkedTypePath().recordRoot("top level input bean").getPaths) serializerFor(nullSafeInput, TypeToken.of(beanClass)) }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala index c653236bfd63c..cdb55b8f6fc49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala @@ -50,6 +50,8 @@ case class WalkedTypePath(private val walkedPaths: Seq[String] = Nil) extends Se walkedPaths.mkString("\n") } + def getPaths: Seq[String] = walkedPaths + private def newInstance(newRecord: String): WalkedTypePath = WalkedTypePath(newRecord +: walkedPaths) }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c88464d27cb37..ab9cedc1306c8 100644 ---
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2528,14 +2528,14 @@ class Analyzer(
    * Replace the [[UpCast]] expression by [[Cast]], and throw exceptions if the cast may truncate.
    */
   object ResolveUpCast extends Rule[LogicalPlan] {
-    private def fail(from: Expression, to: DataType, walkedTypePath: WalkedTypePath) = {
+    private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
       val fromStr = from match {
         case l: LambdaVariable => "array element"
         case e => e.sql
       }
       throw new AnalysisException(s"Cannot up cast $fromStr from " +
         s"${from.dataType.catalogString} to ${to.catalogString} as it may truncate\n" +
-        "The type path of the target object is:\n" + walkedTypePath + "\n" +
+        "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
         "You can either add an explicit cast to the input data or choose a higher precision " +
         "type of the field in the target object")
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 8acdf38e2fc34..ae296a20e06d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -195,7 +195,7 @@ case class ExpressionEncoder[T](
     case r: BoundReference =>
       // For input object of Product type, we can't encode it to row if it's null, as Spark SQL
      // doesn't allow top-level row to be null, only its columns can be null.
-      AssertNotNull(r, WalkedTypePath().recordRoot("top level Product or row object"))
+      AssertNotNull(r, WalkedTypePath().recordRoot("top level Product or row object").getPaths)
   }
   nullSafeSerializer match {
     case If(_: IsNull, _, s: CreateNamedStruct) => s
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index f2e0319e3700e..5c8e0a876e761 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -1381,7 +1381,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
 case class UpCast(
     child: Expression,
     dataType: DataType,
-    walkedTypePath: WalkedTypePath = new WalkedTypePath())
+    walkedTypePath: Seq[String] = Nil)
   extends UnaryExpression with Unevaluable {
   override lazy val resolved = false
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index ceebbd22fe3f2..f946ebf3cfd3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -1627,7 +1627,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
  * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all
  * non-null `s`, `s.i` can't be null.
  */
-case class AssertNotNull(child: Expression, walkedTypePath: WalkedTypePath = new WalkedTypePath())
+case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
   extends UnaryExpression with NonSQLExpression {
 
   override def dataType: DataType = child.dataType
@@ -1637,7 +1637,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: WalkedTypePath = new
   override def flatArguments: Iterator[Any] = Iterator(child)
 
   private val errMsg = "Null value appeared in non-nullable field:" +
-    s"\n$walkedTypePath\n" +
+    walkedTypePath.mkString("\n", "\n", "\n") +
     "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
     "please try to use scala.Option[_] or other nullable types " +
     "(e.g. java.lang.Integer instead of int/scala.Int)."

From 50c2ddce1e527b0d46d5ac1fbc2de28575d4059e Mon Sep 17 00:00:00 2001
From: "Jungtaek Lim (HeartSaVioR)"
Date: Sat, 2 Mar 2019 16:57:42 +0900
Subject: [PATCH 14/14] Address review comments

---
 .../org/apache/spark/sql/catalyst/JavaTypeInference.scala | 3 +--
 .../spark/sql/catalyst/encoders/ExpressionEncoder.scala   | 2 +-
 .../org/apache/spark/sql/catalyst/expressions/Cast.scala  | 5 +----
 .../spark/sql/catalyst/expressions/objects/objects.scala  | 2 +-
 4 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 786cf10d3c127..933a6dbeb7059 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -353,8 +353,7 @@ object JavaTypeInference {
    */
   def serializerFor(beanClass: Class[_]): Expression = {
     val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
-    val nullSafeInput = AssertNotNull(inputObject,
-      WalkedTypePath().recordRoot("top level input bean").getPaths)
+    val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean"))
     serializerFor(nullSafeInput, TypeToken.of(beanClass))
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index ae296a20e06d8..abffda7127b07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -195,7 +195,7 @@ case class ExpressionEncoder[T](
     case r: BoundReference =>
       // For input object of Product type, we can't encode it to row if it's null, as Spark SQL
       // doesn't allow top-level row to be null, only its columns can be null.
-      AssertNotNull(r, WalkedTypePath().recordRoot("top level Product or row object").getPaths)
+      AssertNotNull(r, Seq("top level Product or row object"))
   }
   nullSafeSerializer match {
     case If(_: IsNull, _, s: CreateNamedStruct) => s
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 5c8e0a876e761..84087ae510327 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -1378,10 +1378,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
  * Cast the child expression to the target data type, but will throw error if the cast might
  * truncate, e.g. long -> int, timestamp -> date.
  */
-case class UpCast(
-    child: Expression,
-    dataType: DataType,
-    walkedTypePath: Seq[String] = Nil)
+case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String] = Nil)
   extends UnaryExpression with Unevaluable {
   override lazy val resolved = false
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index f946ebf3cfd3f..8182730feb4b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -28,7 +28,7 @@ import scala.util.Try
 import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection, WalkedTypePath}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
 import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException}
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
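
Illustration (not part of the patch series): a minimal, self-contained Scala sketch of the prepend-ordered path recording that patches 12-14 converge on. `WalkedTypePathSketch` mirrors only the `recordRoot`, `getPaths`, `toString`, and `newInstance` shapes shown in the diffs above; the generic `record` helper, the demo class name, and the field-path string are hypothetical stand-ins for the other record* methods that this series does not show.

// Minimal sketch (not Spark source): prepend-ordered type-path recording.
case class WalkedTypePathSketch(private val walkedPaths: Seq[String] = Nil) {
  // Same shape as WalkedTypePath.recordRoot in patch 12.
  def recordRoot(className: String): WalkedTypePathSketch =
    newInstance(s"""- root class: "$className"""")

  // Hypothetical stand-in for the other record* helpers (illustration only).
  def record(description: String): WalkedTypePathSketch =
    newInstance(description)

  // Same as the getPaths accessor added in patch 13.
  def getPaths: Seq[String] = walkedPaths

  override def toString: String = walkedPaths.mkString("\n")

  // New records are prepended, so the newest (innermost) path comes first.
  private def newInstance(newRecord: String): WalkedTypePathSketch =
    WalkedTypePathSketch(newRecord +: walkedPaths)
}

object WalkedTypePathSketchDemo extends App {
  val path = WalkedTypePathSketch()
    .recordRoot("com.example.Person")                        // outermost
    .record("""- field (class: "scala.Int", name: "age")""") // innermost

  // Prints in reverse (innermost-first) order, per the class doc above:
  // - field (class: "scala.Int", name: "age")
  // - root class: "com.example.Person"
  println(path)
  assert(path.getPaths.head.startsWith("- field"))
}

Because recording prepends, `getPaths` yields the most recently walked path first; patch 13 then passes that plain `Seq[String]` into `AssertNotNull` and `UpCast` (joined with newlines in their error messages), which is why those expressions drop their `WalkedTypePath` parameters entirely.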