From d7fd8cd6062e96c3545cf960e111967d74a5dc85 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Wed, 20 Dec 2017 01:52:00 +0900
Subject: [PATCH 01/12] Fix a bug

---
 .../sql/catalyst/encoders/RowEncoder.scala    |  5 +-
 .../spark/sql/catalyst/expressions/Cast.scala | 63 ++++++++++++++-----
 .../sql/catalyst/expressions/CastSuite.scala  | 11 ++++
 .../org/apache/spark/sql/SQLQuerySuite.scala  | 38 +++++++++++
 4 files changed, 100 insertions(+), 17 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 789750fd408f2..3c2815e405e24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -250,7 +250,8 @@ object RowEncoder {
     deserializerFor(input, input.dataType)
   }
 
-  private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match {
+  private[catalyst] def deserializerFor(input: Expression, dataType: DataType)
+    : Expression = dataType match {
     case dt if ScalaReflection.isNativeType(dt) => input
 
     case p: PythonUserDefinedType => deserializerFor(input, p.sqlType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 274d8813f16db..8da215f5768e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -19,10 +19,14 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.math.{BigDecimal => JavaBigDecimal}
 
+import scala.collection.mutable.WrappedArray
+
 import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -206,6 +210,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d)))
     case TimestampType => buildCast[Long](_,
       t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone)))
+    case ar: ArrayType =>
+      buildCast[ArrayData](_, a => {
+        val arrayData = CatalystTypeConverters.convertToScala(a, ar).asInstanceOf[WrappedArray[_]]
+        UTF8String.fromString(arrayData.mkString("[", ", ", "]"))
+      })
     case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
   }
 
@@ -543,7 +552,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val eval = child.genCode(ctx)
-    val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
+    val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx, eval)
     ev.copy(code = eval.code +
       castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast))
   }
 
@@ -555,11 +564,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   private[this] def nullSafeCastFunction(
       from: DataType,
       to: DataType,
-      ctx: CodegenContext): CastFunction = to match {
+      ctx: CodegenContext,
+      ev: ExprCode): CastFunction = to match {
 
     case _ if from == NullType => (c, evPrim, evNull) => s"$evNull = true;"
     case _ if to == from => (c, evPrim, evNull) => s"$evPrim = $c;"
-    case StringType => castToStringCode(from, ctx)
+    case StringType => castToStringCode(from, ctx, ev)
     case BinaryType => castToBinaryCode(from)
     case DateType => castToDateCode(from, ctx)
     case decimal: DecimalType => castToDecimalCode(from, decimal, ctx)
@@ -574,9 +584,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     case DoubleType => castToDoubleCode(from)
 
     case array: ArrayType =>
-      castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx)
-    case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
-    case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx)
+      castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx, ev)
+    case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx, ev)
+    case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx, ev)
     case udt: UserDefinedType[_]
       if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
       (c, evPrim, evNull) => s"$evPrim = $c;"
@@ -597,7 +607,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     """
   }
 
-  private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = {
+  private[this] def castToStringCode(from: DataType, ctx: CodegenContext, ev: ExprCode)
+    : CastFunction = {
     from match {
       case BinaryType =>
         (c, evPrim, evNull) => s"$evPrim = UTF8String.fromBytes($c);"
@@ -608,6 +619,26 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
         val tz = ctx.addReferenceObj("timeZone", timeZone)
         (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
           org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));"""
+      case ar: ArrayType =>
+        // Generate code to recursively convert a catalyst array type `ArrayData` into
+        // a Scala array type by using an array encoder.
+        val staticInvoke = RowEncoder.deserializerFor(child, ar).asInstanceOf[StaticInvoke]
+        val arVal = ctx.freshName("arVal")
+        val arNull = ctx.freshName("arNull")
+        val inputExprCode = ExprCode(
+          code =
+            s"""${ctx.javaType(ar)} $arVal = ${ev.value}
+               |boolean $arNull = ${ev.isNull}
+             """.stripMargin,
+          isNull = arNull,
+          value = arVal
+        )
+        val expr = staticInvoke.doGenCode(ctx, inputExprCode)
+        ev.code = expr.code
+        ev.isNull = expr.isNull
+        ev.value = expr.value
+        (c, evPrim, evNull) =>
+          s"""$evPrim = UTF8String.fromString($c.mkString("[", ", ", "]"));"""
       case _ =>
         (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));"
     }
@@ -945,8 +976,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   }
 
   private[this] def castArrayCode(
-      fromType: DataType, toType: DataType, ctx: CodegenContext): CastFunction = {
-    val elementCast = nullSafeCastFunction(fromType, toType, ctx)
+      fromType: DataType, toType: DataType, ctx: CodegenContext, ev: ExprCode): CastFunction = {
+    val elementCast = nullSafeCastFunction(fromType, toType, ctx, ev)
     val arrayClass = classOf[GenericArrayData].getName
     val fromElementNull = ctx.freshName("feNull")
     val fromElementPrim = ctx.freshName("fePrim")
@@ -980,9 +1011,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     """
   }
 
-  private[this] def castMapCode(from: MapType, to: MapType, ctx: CodegenContext): CastFunction = {
-    val keysCast = castArrayCode(from.keyType, to.keyType, ctx)
-    val valuesCast = castArrayCode(from.valueType, to.valueType, ctx)
+  private[this] def castMapCode(from: MapType, to: MapType, ctx: CodegenContext, ev: ExprCode)
+    : CastFunction = {
+    val keysCast = castArrayCode(from.keyType, to.keyType, ctx, ev)
+    val valuesCast = castArrayCode(from.valueType, to.valueType, ctx, ev)
 
     val mapClass = classOf[ArrayBasedMapData].getName
 
@@ -1008,10 +1040,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   }
 
   private[this] def castStructCode(
-      from: StructType, to: StructType, ctx: CodegenContext): CastFunction = {
+      from: StructType, to: StructType, ctx: CodegenContext, ev: ExprCode): CastFunction = {
 
     val fieldsCasts = from.fields.zip(to.fields).map {
-      case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx)
+      case (fromField, toField) =>
+        nullSafeCastFunction(fromField.dataType, toField.dataType, ctx, ev)
     }
     val rowClass = classOf[GenericInternalRow].getName
     val tmpResult = ctx.freshName("tmpResult")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 1dd040e4696a1..85de400733e34 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -853,4 +853,15 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     cast("2", LongType).genCode(ctx)
     assert(ctx.inlinedMutableStates.length == 0)
   }
+
+  test("SPARK-22825 Cast array to string") {
+    val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType)
+    checkEvaluation(ret1, "[1, 2, 3, 4, 5]")
+    val ret2 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType)
+    checkEvaluation(ret2, "[WrappedArray(1, 2, 3), WrappedArray(4, 5)]")
+    val ret3 = cast(Literal.create(Array(Map(1 -> "a"), Map(2 -> "b", 3 -> "c"))), StringType)
+    checkEvaluation(ret3, "[Map(1 -> a), Map(2 -> b, 3 -> c)]")
+    val ret4 = cast(Literal.create(Array((1, 3.0, "a"), (3, 1.0, "b"))), StringType)
+    checkEvaluation(ret4, "[[1,3.0,a], [3,1.0,b]]")
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5e077285ade55..0fba3b79271ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2775,4 +2775,42 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-22825 Cast array to string") {
+    // Check non-codegen path
+    val df1 = sql("SELECT CAST(ARRAY(1, 2, 3, 4) AS STRING)")
+    checkAnswer(df1, Row("[1, 2, 3, 4]"))
+    val df2 = sql("SELECT CAST(ARRAY(ARRAY(1, 2), ARRAY(3, 4, 5), ARRAY(6, 7)) AS STRING)")
+    checkAnswer(df2, Row("[WrappedArray(1, 2), WrappedArray(3, 4, 5), WrappedArray(6, 7)]"))
+    val df3 = sql("SELECT CAST(ARRAY(MAP(1, 'a', 2, 'b'), MAP(3, 'c')) AS STRING)")
+    checkAnswer(df3, Row("[Map(1 -> a, 2 -> b), Map(3 -> c)]"))
+    val df4 = sql("SELECT CAST(ARRAY(STRUCT(1, 0.3, 'a'), STRUCT(2, 0.5, 'b')) AS STRING)")
+    checkAnswer(df4, Row("[[1,0.3,a], [2,0.5,b]]"))
+
+    // Check codegen path
+    withTable("t") {
+      Seq(Seq(0, 1, 2, 3, 4)).toDF("a").write.saveAsTable("t")
+      val df = sql("SELECT CAST(a AS STRING) FROM t")
+      checkAnswer(df, Row("[0, 1, 2, 3, 4]"))
+    }
+
+    withTable("t") {
+      Seq(Seq(Seq(1, 2), Seq(3), Seq(4, 5, 6))).toDF("a").write.saveAsTable("t")
+      val df = sql("SELECT CAST(a AS STRING) FROM t")
+      checkAnswer(df, Row("[WrappedArray(1, 2), WrappedArray(3), WrappedArray(4, 5, 6)]"))
+    }
+
+    withTable("t") {
+      Seq(Seq(Map(1 -> "a", 2 -> "b"), Map(3 -> "c"), Map(4 -> "d", 5 -> "e"))).toDF("a")
+        .write.saveAsTable("t")
+      val df = sql("SELECT CAST(a AS STRING) FROM t")
+      checkAnswer(df, Row("[Map(1 -> a, 2 -> b), Map(3 -> c), Map(4 -> d, 5 -> e)]"))
+    }
+
+    withTable("t") {
+      Seq(Seq((1, "a"), (2, "b")), Seq((3, "c"))).toDF("a").write.saveAsTable("t")
+      val df = sql("SELECT CAST(a AS STRING) FROM t")
+      checkAnswer(df, Row("[[1,a], [2,b]]") :: Row("[[3,c]]") :: Nil)
+    }
+  }
 }
From 91df078a99b6ec5f2063b2e73170336e3fe812d1 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Thu, 4 Jan 2018 00:18:48 +0900
Subject: [PATCH 02/12] Fix

---
 .../expressions/codegen/BufferHolder.java     |   4 +
 .../codegen/StringWriterBuffer.java           |  61 ++++++++
 .../sql/catalyst/encoders/RowEncoder.scala    |   5 +-
 .../spark/sql/catalyst/expressions/Cast.scala | 139 +++++++++++++++---
 .../sql/catalyst/expressions/CastSuite.scala  |   8 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala  |   8 +-
 6 files changed, 193 insertions(+), 32 deletions(-)
 create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/StringWriterBuffer.java

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
index 259976118c12f..e39205c4a0691 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
@@ -93,4 +93,8 @@ public void reset() {
   public int totalSize() {
     return cursor - Platform.BYTE_ARRAY_OFFSET;
   }
+
+  public int fixedSize() {
+    return fixedSize;
+  }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/StringWriterBuffer.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/StringWriterBuffer.java
new file mode 100644
index 0000000000000..555961291051a
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/StringWriterBuffer.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import java.nio.charset.StandardCharsets;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.unsafe.Platform;
+
+/**
+ * A helper class to write array elements into a string buffer using `BufferHolder`.
+ */
+public class StringWriterBuffer {
+
+  private BufferHolder buffer;
+
+  public StringWriterBuffer() {
+    this.buffer = new BufferHolder(new UnsafeRow(1), 256);
+  }
+
+  public void reset() {
+    buffer.reset();
+  }
+
+  public void append(String value) {
+    append(value.getBytes(StandardCharsets.UTF_8));
+  }
+
+  public void append(byte[] value) {
+    final int numBytes = value.length;
+    buffer.grow(numBytes);
+    Platform.copyMemory(
+      value, Platform.BYTE_ARRAY_OFFSET, buffer.buffer, buffer.cursor, numBytes);
+    buffer.cursor += numBytes;
+  }
+
+  public byte[] getBytes() {
+    // Compute a length of strings written in this buffer
+    final int strlen = buffer.totalSize() - buffer.fixedSize();
+    final byte[] bytes = new byte[strlen];
+    Platform.copyMemory(
+      buffer.buffer, Platform.BYTE_ARRAY_OFFSET + buffer.fixedSize(),
+      bytes, Platform.BYTE_ARRAY_OFFSET, strlen);
+    return bytes;
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 3c2815e405e24..789750fd408f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -250,8 +250,7 @@ object RowEncoder {
     deserializerFor(input, input.dataType)
   }
 
-  private[catalyst] def deserializerFor(input: Expression, dataType: DataType)
-    : Expression = dataType match {
+  private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match {
     case dt if ScalaReflection.isNativeType(dt) => input
 
     case p: PythonUserDefinedType => deserializerFor(input, p.sqlType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8da215f5768e9..721d7d4d851fd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -24,9 +24,7 @@ import scala.collection.mutable.WrappedArray
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -607,6 +605,112 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     """
   }
 
+  private[this] def writeElemToBufferCode(
+      dataType: DataType,
+      buffer: String,
+      elemTerm: String,
+      ctx: CodegenContext): String = dataType match {
+    case BinaryType => s"$buffer.append($elemTerm)"
+    case StringType => s"$buffer.append($elemTerm.getBytes())"
+    case DateType => s"""$buffer.append(
+      org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($elemTerm))"""
+    case TimestampType => s"""$buffer.append(
+      org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($elemTerm))"""
+    case map: MapType => s"${codegenWriteMapToBuffer(map, buffer, ctx)}($elemTerm)"
+    case ar: ArrayType => s"${codegenWriteArrayToBuffer(ar, buffer, ctx)}($elemTerm)"
+    case st: StructType => s"${codegenWriteStructToBuffer(st, buffer, ctx)}($elemTerm)"
+    case _ => s"$buffer.append(String.valueOf($elemTerm))"
+  }
+
+  private[this] def codegenWriteStructToBuffer(
+      st: StructType, buffer: String, ctx: CodegenContext): String = {
+    val writeStructToBuffer = ctx.freshName("writeStructToBuffer")
+    val rowTerm = ctx.freshName("rowTerm")
+    val writeToBufferCode = st.zipWithIndex.map { case (f, i) =>
+      val fieldTerm = ctx.freshName("fieldTerm")
+      val writeFieldCode = writeElemToBufferCode(f.dataType, buffer, fieldTerm, ctx)
+      s"""
+         |${ctx.javaType(st(i).dataType)} $fieldTerm = ${ctx.getValue(rowTerm, f.dataType, s"$i")};
+         |$writeFieldCode;
+       """.stripMargin
+    }
+    ctx.addNewFunction(writeStructToBuffer,
+      s"""
+         |private void $writeStructToBuffer(InternalRow $rowTerm) {
+         |  $buffer.append("[");
+         |  ${writeToBufferCode.mkString(s"""$buffer.append(", ");""" + "\n")}
+         |  $buffer.append("]");
+         |}
+       """.stripMargin)
+  }
+
+  private[this] def codegenWriteMapToBuffer(
+      map: MapType, buffer: String, ctx: CodegenContext): String = {
+    val loopIndex = ctx.freshName("loopIndex")
+    val writeMapToBuffer = ctx.freshName("writeMapToBuffer")
+    val mapTerm = ctx.freshName("mapTerm")
+    val keyTerm = ctx.freshName("keyTerm")
+    val valueTerm = ctx.freshName("valueTerm")
+    val writeKeyCode = writeElemToBufferCode(map.keyType, buffer, keyTerm, ctx)
+    val writeValueCode = writeElemToBufferCode(map.valueType, buffer, valueTerm, ctx)
+    def writeToBufferCode(i: String) = {
+      s"""
+         |${ctx.javaType(map.keyType)} $keyTerm =
+         |  ${ctx.getValue(s"$mapTerm.keyArray()", map.keyType, i)};
+         |${ctx.javaType(map.valueType)} $valueTerm =
+         |  ${ctx.getValue(s"$mapTerm.valueArray()", map.valueType, i)};
+         |
+         |// Write a key-value pair in the buffer
+         |$writeKeyCode;
+         |$buffer.append(" -> ");
+         |$writeValueCode;
+       """.stripMargin
+    }
+    ctx.addNewFunction(writeMapToBuffer,
+      s"""
+         |private void $writeMapToBuffer(MapData $mapTerm) {
+         |  $buffer.append("[");
+         |  if ($mapTerm.numElements() > 0) {
+         |    ${writeToBufferCode("0")}
+         |  }
+         |  for (int $loopIndex = 1; $loopIndex < $mapTerm.numElements(); $loopIndex++) {
+         |    $buffer.append(", ");
+         |    ${writeToBufferCode(loopIndex)}
+         |  }
+         |  $buffer.append("]");
+         |}
+       """.stripMargin)
+  }
+
+  private[this] def codegenWriteArrayToBuffer(
+      ar: ArrayType, buffer: String, ctx: CodegenContext): String = {
+    val loopIndex = ctx.freshName("loopIndex")
+    val writeArrayToBuffer = ctx.freshName("writeArrayToBuffer")
+    val arTerm = ctx.freshName("arTerm")
+    val elemTerm = ctx.freshName("elemTerm")
+    val writeElemCode = writeElemToBufferCode(ar.elementType, buffer, elemTerm, ctx)
+    def writeToBufferCode(i: String) = {
+      s"""
+         |${ctx.javaType(ar.elementType)} $elemTerm = ${ctx.getValue(arTerm, ar.elementType, i)};
+         |$writeElemCode;
+       """.stripMargin
+    }
+    ctx.addNewFunction(writeArrayToBuffer,
+      s"""
+         |private void $writeArrayToBuffer(ArrayData $arTerm) {
+         |  $buffer.append("[");
+         |  if ($arTerm.numElements() > 0) {
+         |    ${writeToBufferCode("0")}
+         |  }
+         |  for (int $loopIndex = 1; $loopIndex < $arTerm.numElements(); $loopIndex++) {
+         |    $buffer.append(", ");
+         |    ${writeToBufferCode(loopIndex)}
+         |  }
+         |  $buffer.append("]");
+         |}
+       """.stripMargin)
+  }
+
   private[this] def castToStringCode(from: DataType, ctx: CodegenContext, ev: ExprCode)
     : CastFunction = {
     from match {
       case BinaryType =>
@@ -620,25 +724,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
         (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
           org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));"""
       case ar: ArrayType =>
-        // Generate code to recursively convert a catalyst array type `ArrayData` into
-        // a Scala array type by using an array encoder.
-        val staticInvoke = RowEncoder.deserializerFor(child, ar).asInstanceOf[StaticInvoke]
-        val arVal = ctx.freshName("arVal")
-        val arNull = ctx.freshName("arNull")
-        val inputExprCode = ExprCode(
-          code =
-            s"""${ctx.javaType(ar)} $arVal = ${ev.value}
-               |boolean $arNull = ${ev.isNull}
-             """.stripMargin,
-          isNull = arNull,
-          value = arVal
-        )
-        val expr = staticInvoke.doGenCode(ctx, inputExprCode)
-        ev.code = expr.code
-        ev.isNull = expr.isNull
-        ev.value = expr.value
+        val bufferClass = classOf[StringWriterBuffer].getName
+        val buffer = ctx.addMutableState(bufferClass, "buffer", v => s"$v = new $bufferClass();")
+        val writeArrayToBuffer = codegenWriteArrayToBuffer(ar, buffer, ctx)
+        val arrayToStringCode =
+          s"""
+             |if (!${ev.isNull}) {
+             |  $buffer.reset();
+             |  $writeArrayToBuffer(${ev.value});
+             |}
+           """.stripMargin
+        ev.code = ev.code ++ arrayToStringCode
+        ev.value = buffer
         (c, evPrim, evNull) =>
-          s"""$evPrim = UTF8String.fromString($c.mkString("[", ", ", "]"));"""
+          s"""$evPrim = UTF8String.fromBytes($buffer.getBytes());"""
       case _ =>
         (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));"
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 85de400733e34..c6a829143e879 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -858,10 +858,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType)
     checkEvaluation(ret1, "[1, 2, 3, 4, 5]")
     val ret2 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType)
-    checkEvaluation(ret2, "[WrappedArray(1, 2, 3), WrappedArray(4, 5)]")
+    checkEvaluation(ret2, "[[1, 2, 3], [4, 5]]")
     val ret3 = cast(Literal.create(Array(Map(1 -> "a"), Map(2 -> "b", 3 -> "c"))), StringType)
-    checkEvaluation(ret3, "[Map(1 -> a), Map(2 -> b, 3 -> c)]")
-    val ret4 = cast(Literal.create(Array((1, 3.0, "a"), (3, 1.0, "b"))), StringType)
-    checkEvaluation(ret4, "[[1,3.0,a], [3,1.0,b]]")
+    checkEvaluation(ret3, "[[1 -> a], [2 -> b, 3 -> c]]")
+    // val ret4 = cast(Literal.create(Array((1, 3.0, "a"), (3, 1.0, "b"))), StringType)
+    // checkEvaluation(ret4, "[[1, 3.0, a], [3, 1.0, b]]")
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 0fba3b79271ae..9157921ab1518 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -28,8 +28,6 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
-import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -2797,20 +2795,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     withTable("t") {
       Seq(Seq(Seq(1, 2), Seq(3), Seq(4, 5, 6))).toDF("a").write.saveAsTable("t")
       val df = sql("SELECT CAST(a AS STRING) FROM t")
-      checkAnswer(df, Row("[WrappedArray(1, 2), WrappedArray(3), WrappedArray(4, 5, 6)]"))
+      checkAnswer(df, Row("[[1, 2], [3], [4, 5, 6]]"))
     }
 
     withTable("t") {
       Seq(Seq(Map(1 -> "a", 2 -> "b"), Map(3 -> "c"), Map(4 -> "d", 5 -> "e"))).toDF("a")
         .write.saveAsTable("t")
       val df = sql("SELECT CAST(a AS STRING) FROM t")
-      checkAnswer(df, Row("[Map(1 -> a, 2 -> b), Map(3 -> c), Map(4 -> d, 5 -> e)]"))
+      checkAnswer(df, Row("[[1 -> a, 2 -> b], [3 -> c], [4 -> d, 5 -> e]]"))
     }
 
     withTable("t") {
       Seq(Seq((1, "a"), (2, "b")), Seq((3, "c"))).toDF("a").write.saveAsTable("t")
       val df = sql("SELECT CAST(a AS STRING) FROM t")
-      checkAnswer(df, Row("[[1,a], [2,b]]") :: Row("[[3,c]]") :: Nil)
+      checkAnswer(df, Row("[[1, a], [2, b]]") :: Row("[[3, c]]") :: Nil)
     }
   }
 }
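The generated Java emitted by patch 02's codegenWrite*ToBuffer helpers is easier to follow as an interpreted analogue. A rough plain-Scala sketch (illustrative only; writeValue is not a Spark API, and the bracket/arrow conventions simply mirror the strings the generated code appends):

object WriterSketch {
  // One writer handles a value of any nesting depth, sharing a single buffer,
  // much as each generated helper function writes into the shared StringWriterBuffer.
  def writeValue(v: Any, buf: java.lang.StringBuilder): Unit = v match {
    case seq: Seq[_] =>                    // array element -> "[a, b, c]"
      buf.append("[")
      var i = 0
      while (i < seq.length) {
        if (i > 0) buf.append(", ")
        writeValue(seq(i), buf)
        i += 1
      }
      buf.append("]")
    case map: Map[_, _] =>                 // map element -> "[k -> v, ...]"
      buf.append("[")
      var first = true
      for ((k, value) <- map) {
        if (!first) buf.append(", ")
        first = false
        writeValue(k, buf)
        buf.append(" -> ")
        writeValue(value, buf)
      }
      buf.append("]")
    case other =>
      buf.append(String.valueOf(other))
  }

  def main(args: Array[String]): Unit = {
    val buf = new java.lang.StringBuilder
    writeValue(Seq(Map(1 -> "a"), Map(2 -> "b", 3 -> "c")), buf)
    println(buf.toString)                  // [[1 -> a], [2 -> b, 3 -> c]]
  }
}

This is exactly the output format the updated CastSuite expectations above assert.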
From 8705f84982929056cc60cdac4e6b069208eb9f09 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Thu, 4 Jan 2018 11:04:02 +0900
Subject: [PATCH 03/12] Drop StringWriterBuffer

---
 .../expressions/codegen/BufferHolder.java     |  4 --
 .../codegen/StringWriterBuffer.java           | 61 -------------------
 .../spark/sql/catalyst/expressions/Cast.scala | 13 ++--
 .../sql/catalyst/expressions/CastSuite.scala  |  6 +-
 4 files changed, 10 insertions(+), 74 deletions(-)
 delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/StringWriterBuffer.java

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
index e39205c4a0691..259976118c12f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
@@ -93,8 +93,4 @@ public void reset() {
   public int totalSize() {
     return cursor - Platform.BYTE_ARRAY_OFFSET;
   }
-
-  public int fixedSize() {
-    return fixedSize;
-  }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/StringWriterBuffer.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/StringWriterBuffer.java
deleted file mode 100644
index 555961291051a..0000000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/StringWriterBuffer.java
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.codegen;
-
-import java.nio.charset.StandardCharsets;
-
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
-import org.apache.spark.unsafe.Platform;
-
-/**
- * A helper class to write array elements into a string buffer using `BufferHolder`.
- */
-public class StringWriterBuffer {
-
-  private BufferHolder buffer;
-
-  public StringWriterBuffer() {
-    this.buffer = new BufferHolder(new UnsafeRow(1), 256);
-  }
-
-  public void reset() {
-    buffer.reset();
-  }
-
-  public void append(String value) {
-    append(value.getBytes(StandardCharsets.UTF_8));
-  }
-
-  public void append(byte[] value) {
-    final int numBytes = value.length;
-    buffer.grow(numBytes);
-    Platform.copyMemory(
-      value, Platform.BYTE_ARRAY_OFFSET, buffer.buffer, buffer.cursor, numBytes);
-    buffer.cursor += numBytes;
-  }
-
-  public byte[] getBytes() {
-    // Compute a length of strings written in this buffer
-    final int strlen = buffer.totalSize() - buffer.fixedSize();
-    final byte[] bytes = new byte[strlen];
-    Platform.copyMemory(
-      buffer.buffer, Platform.BYTE_ARRAY_OFFSET + buffer.fixedSize(),
-      bytes, Platform.BYTE_ARRAY_OFFSET, strlen);
-    return bytes;
-  }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 721d7d4d851fd..f2db94e2e47eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -610,8 +610,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
       buffer: String,
       elemTerm: String,
       ctx: CodegenContext): String = dataType match {
-    case BinaryType => s"$buffer.append($elemTerm)"
-    case StringType => s"$buffer.append($elemTerm.getBytes())"
+    case BinaryType => s"$buffer.append(new String($elemTerm))"
+    case StringType => s"$buffer.append(new String($elemTerm.getBytes()))"
     case DateType => s"""$buffer.append(
       org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($elemTerm))"""
     case TimestampType => s"""$buffer.append(
@@ -619,7 +619,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     case map: MapType => s"${codegenWriteMapToBuffer(map, buffer, ctx)}($elemTerm)"
     case ar: ArrayType => s"${codegenWriteArrayToBuffer(ar, buffer, ctx)}($elemTerm)"
     case st: StructType => s"${codegenWriteStructToBuffer(st, buffer, ctx)}($elemTerm)"
-    case _ => s"$buffer.append(String.valueOf($elemTerm))"
+    case _ => s"$buffer.append($elemTerm)"
   }
@@ -724,21 +724,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
         (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
           org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));"""
       case ar: ArrayType =>
-        val bufferClass = classOf[StringWriterBuffer].getName
+        val bufferClass = classOf[StringBuffer].getName
         val buffer = ctx.addMutableState(bufferClass, "buffer", v => s"$v = new $bufferClass();")
         val writeArrayToBuffer = codegenWriteArrayToBuffer(ar, buffer, ctx)
         val arrayToStringCode =
           s"""
             |if (!${ev.isNull}) {
-            |  $buffer.reset();
+            |  // Reset buffer first
+            |  $buffer.delete(0, $buffer.length());
             |  $writeArrayToBuffer(${ev.value});
             |}
           """.stripMargin
         ev.code = ev.code ++ arrayToStringCode
         ev.value = buffer
         (c, evPrim, evNull) =>
-          s"""$evPrim = UTF8String.fromBytes($buffer.getBytes());"""
+          s"""$evPrim = UTF8String.fromString($buffer.toString());"""
       case _ =>
         (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));"
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index c6a829143e879..356cb42620d78 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -854,14 +854,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(ctx.inlinedMutableStates.length == 0)
   }
 
-  test("SPARK-22825 Cast array to string") {
+  ignore("SPARK-22825 Cast array to string") {
     val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType)
     checkEvaluation(ret1, "[1, 2, 3, 4, 5]")
     val ret2 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType)
     checkEvaluation(ret2, "[[1, 2, 3], [4, 5]]")
     val ret3 = cast(Literal.create(Array(Map(1 -> "a"), Map(2 -> "b", 3 -> "c"))), StringType)
     checkEvaluation(ret3, "[[1 -> a], [2 -> b, 3 -> c]]")
-    // val ret4 = cast(Literal.create(Array((1, 3.0, "a"), (3, 1.0, "b"))), StringType)
-    // checkEvaluation(ret4, "[[1, 3.0, a], [3, 1.0, b]]")
+    val ret4 = cast(Literal.create(Array((1, 3.0, "a"), (3, 1.0, "b"))), StringType)
+    checkEvaluation(ret4, "[[1, 3.0, a], [3, 1.0, b]]")
   }
 }

From a46a9a76487151677d011ded7c06bee9d46c35d1 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Thu, 4 Jan 2018 11:17:56 +0900
Subject: [PATCH 04/12] Fix minor issues

---
 .../org/apache/spark/sql/catalyst/expressions/Cast.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index f2db94e2e47eb..02240928f7cf3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -638,7 +638,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
       s"""
          |private void $writeStructToBuffer(InternalRow $rowTerm) {
          |  $buffer.append("[");
-         |  ${writeToBufferCode.mkString(s"""$buffer.append(", ");""" + "\n")}
+         |  ${writeToBufferCode.mkString(s"""$buffer.append(\", \");\n""")}
          |  $buffer.append("]");
          |}
       """.stripMargin)
@@ -660,7 +660,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
          |${ctx.javaType(map.valueType)} $valueTerm =
          |  ${ctx.getValue(s"$mapTerm.valueArray()", map.valueType, i)};
          |
-         |// Write a key-value pair in the buffer
+         |// Write a key-value pair in buffer
          |$writeKeyCode;
          |$buffer.append(" -> ");
          |$writeValueCode;
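Patch 05 below fixes the interpreted (non-codegen) path using the usual write-the-first-element-then-loop idiom, which avoids both a trailing separator and a per-element branch. The pattern in isolation (a generic sketch, not the Spark code itself):

object SeparatorPattern {
  def arrayToString[T](a: Array[T], toStr: T => String): String = {
    val sb = new java.lang.StringBuilder
    sb.append("[")
    if (a.length > 0) {
      sb.append(toStr(a(0)))   // first element carries no separator
      var i = 1
      while (i < a.length) {
        sb.append(", ")        // separator before every later element
        sb.append(toStr(a(i)))
        i += 1
      }
    }
    sb.append("]")
    sb.toString
  }

  def main(args: Array[String]): Unit = {
    println(arrayToString(Array(1, 2, 3), (i: Int) => i.toString))   // [1, 2, 3]
    println(arrayToString(Array.empty[Int], (i: Int) => i.toString)) // []
  }
}

Patch 06 briefly trades this for an if (i != 0) check during review, and patch 07 switches back; both shapes produce identical strings.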
From 9e8390573644351bddae3836a0da254cfc7c32e9 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Thu, 4 Jan 2018 12:53:19 +0900
Subject: [PATCH 05/12] Fix non-codegen path

---
 .../spark/sql/catalyst/expressions/Cast.scala | 18 +++++++++++++---
 .../sql/catalyst/expressions/CastSuite.scala  |  2 +-
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 02240928f7cf3..9356081d1befe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -209,9 +209,21 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     case TimestampType => buildCast[Long](_,
       t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone)))
     case ar: ArrayType =>
-      buildCast[ArrayData](_, a => {
-        val arrayData = CatalystTypeConverters.convertToScala(a, ar).asInstanceOf[WrappedArray[_]]
-        UTF8String.fromString(arrayData.mkString("[", ", ", "]"))
+      buildCast[ArrayData](_, array => {
+        val res = new StringBuilder
+        res.append("[")
+        if (array.numElements > 0) {
+          val toStringFunc = castToString(ar.elementType)
+          res.append(toStringFunc(array.get(0, ar.elementType)))
+          var i = 1
+          while (i < array.numElements) {
+            res.append(", ")
+            res.append(toStringFunc(array.get(i, ar.elementType)))
+            i += 1
+          }
+        }
+        res.append("]")
+        UTF8String.fromString(res.toString())
       })
     case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 356cb42620d78..8f9600b078774 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -854,7 +854,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(ctx.inlinedMutableStates.length == 0)
   }
 
-  ignore("SPARK-22825 Cast array to string") {
+  test("SPARK-22825 Cast array to string") {
     val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType)
     checkEvaluation(ret1, "[1, 2, 3, 4, 5]")
     val ret2 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType)

From 1d623f8b0a53d2152b942ce3adf631a8c3dd1a12 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Thu, 4 Jan 2018 13:41:23 +0900
Subject: [PATCH 06/12] Apply review comments

---
 .../spark/sql/catalyst/expressions/Cast.scala | 148 +++++-------------
 .../sql/catalyst/expressions/CastSuite.scala  |  24 ++-
 .../org/apache/spark/sql/SQLQuerySuite.scala  |  76 ++++-----
 3 files changed, 97 insertions(+), 151 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 9356081d1befe..2f2ede07912f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -19,10 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.math.{BigDecimal => JavaBigDecimal}
 
-import scala.collection.mutable.WrappedArray
-
 import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util._
@@ -205,6 +203,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   // UDFToString
   private[this] def castToString(from: DataType): Any => Any = from match {
     case BinaryType => buildCast[Array[Byte]](_, UTF8String.fromBytes)
+    case StringType => buildCast[UTF8String](_, identity)
     case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d)))
     case TimestampType => buildCast[Long](_,
       t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone)))
@@ -214,10 +213,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
         res.append("[")
         if (array.numElements > 0) {
           val toStringFunc = castToString(ar.elementType)
-          res.append(toStringFunc(array.get(0, ar.elementType)))
-          var i = 1
+          var i = 0
           while (i < array.numElements) {
-            res.append(", ")
+            if (i != 0) res.append(", ")
             res.append(toStringFunc(array.get(i, ar.elementType)))
             i += 1
           }
@@ -562,7 +560,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val eval = child.genCode(ctx)
-    val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx, eval)
+    val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
     ev.copy(code = eval.code +
       castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast))
   }
 
@@ -574,12 +572,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   private[this] def nullSafeCastFunction(
       from: DataType,
       to: DataType,
-      ctx: CodegenContext,
-      ev: ExprCode): CastFunction = to match {
+      ctx: CodegenContext): CastFunction = to match {
 
     case _ if from == NullType => (c, evPrim, evNull) => s"$evNull = true;"
     case _ if to == from => (c, evPrim, evNull) => s"$evPrim = $c;"
-    case StringType => castToStringCode(from, ctx, ev)
+    case StringType => castToStringCode(from, ctx)
     case BinaryType => castToBinaryCode(from)
     case DateType => castToDateCode(from, ctx)
     case decimal: DecimalType => castToDecimalCode(from, decimal, ctx)
@@ -594,9 +591,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     case DoubleType => castToDoubleCode(from)
 
     case array: ArrayType =>
-      castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx, ev)
-    case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx, ev)
-    case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx, ev)
+      castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx)
+    case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
+    case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx)
     case udt: UserDefinedType[_]
       if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
       (c, evPrim, evNull) => s"$evPrim = $c;"
@@ -628,79 +625,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
       org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($elemTerm))"""
     case TimestampType => s"""$buffer.append(
       org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($elemTerm))"""
-    case map: MapType => s"${codegenWriteMapToBuffer(map, buffer, ctx)}($elemTerm)"
-    case ar: ArrayType => s"${codegenWriteArrayToBuffer(ar, buffer, ctx)}($elemTerm)"
-    case st: StructType => s"${codegenWriteStructToBuffer(st, buffer, ctx)}($elemTerm)"
+    case ar: ArrayType => s"${codegenWriteArrayToBuffer(ar, ctx)}($elemTerm, $buffer)"
     case _ => s"$buffer.append($elemTerm)"
   }
 
-  private[this] def codegenWriteStructToBuffer(
-      st: StructType, buffer: String, ctx: CodegenContext): String = {
-    val writeStructToBuffer = ctx.freshName("writeStructToBuffer")
-    val rowTerm = ctx.freshName("rowTerm")
-    val writeToBufferCode = st.zipWithIndex.map { case (f, i) =>
-      val fieldTerm = ctx.freshName("fieldTerm")
-      val writeFieldCode = writeElemToBufferCode(f.dataType, buffer, fieldTerm, ctx)
-      s"""
-         |${ctx.javaType(st(i).dataType)} $fieldTerm = ${ctx.getValue(rowTerm, f.dataType, s"$i")};
-         |$writeFieldCode;
-       """.stripMargin
-    }
-    ctx.addNewFunction(writeStructToBuffer,
-      s"""
-         |private void $writeStructToBuffer(InternalRow $rowTerm) {
-         |  $buffer.append("[");
-         |  ${writeToBufferCode.mkString(s"""$buffer.append(\", \");\n""")}
-         |  $buffer.append("]");
-         |}
-       """.stripMargin)
-  }
-
-  private[this] def codegenWriteMapToBuffer(
-      map: MapType, buffer: String, ctx: CodegenContext): String = {
-    val loopIndex = ctx.freshName("loopIndex")
-    val writeMapToBuffer = ctx.freshName("writeMapToBuffer")
-    val mapTerm = ctx.freshName("mapTerm")
-    val keyTerm = ctx.freshName("keyTerm")
-    val valueTerm = ctx.freshName("valueTerm")
-    val writeKeyCode = writeElemToBufferCode(map.keyType, buffer, keyTerm, ctx)
-    val writeValueCode = writeElemToBufferCode(map.valueType, buffer, valueTerm, ctx)
-    def writeToBufferCode(i: String) = {
-      s"""
-         |${ctx.javaType(map.keyType)} $keyTerm =
-         |  ${ctx.getValue(s"$mapTerm.keyArray()", map.keyType, i)};
-         |${ctx.javaType(map.valueType)} $valueTerm =
-         |  ${ctx.getValue(s"$mapTerm.valueArray()", map.valueType, i)};
-         |
-         |// Write a key-value pair in buffer
-         |$writeKeyCode;
-         |$buffer.append(" -> ");
-         |$writeValueCode;
-       """.stripMargin
-    }
-    ctx.addNewFunction(writeMapToBuffer,
-      s"""
-         |private void $writeMapToBuffer(MapData $mapTerm) {
-         |  $buffer.append("[");
-         |  if ($mapTerm.numElements() > 0) {
-         |    ${writeToBufferCode("0")}
-         |  }
-         |  for (int $loopIndex = 1; $loopIndex < $mapTerm.numElements(); $loopIndex++) {
-         |    $buffer.append(", ");
-         |    ${writeToBufferCode(loopIndex)}
-         |  }
-         |  $buffer.append("]");
-         |}
-       """.stripMargin)
-  }
-
-  private[this] def codegenWriteArrayToBuffer(
-      ar: ArrayType, buffer: String, ctx: CodegenContext): String = {
+  private[this] def codegenWriteArrayToBuffer(ar: ArrayType, ctx: CodegenContext): String = {
     val loopIndex = ctx.freshName("loopIndex")
     val writeArrayToBuffer = ctx.freshName("writeArrayToBuffer")
     val arTerm = ctx.freshName("arTerm")
+    val bufferClass = classOf[StringBuffer].getName
+    val bufferTerm = ctx.freshName("bufferTerm")
     val elemTerm = ctx.freshName("elemTerm")
-    val writeElemCode = writeElemToBufferCode(ar.elementType, buffer, elemTerm, ctx)
+    val writeElemCode = writeElemToBufferCode(ar.elementType, bufferTerm, elemTerm, ctx)
     def writeToBufferCode(i: String) = {
       s"""
          |${ctx.javaType(ar.elementType)} $elemTerm = ${ctx.getValue(arTerm, ar.elementType, i)};
         |$writeElemCode;
       """.stripMargin
     }
     ctx.addNewFunction(writeArrayToBuffer,
       s"""
-         |private void $writeArrayToBuffer(ArrayData $arTerm) {
-         |  $buffer.append("[");
-         |  if ($arTerm.numElements() > 0) {
-         |    ${writeToBufferCode("0")}
-         |  }
-         |  for (int $loopIndex = 1; $loopIndex < $arTerm.numElements(); $loopIndex++) {
-         |    $buffer.append(", ");
-         |    ${writeToBufferCode(loopIndex)}
-         |  }
-         |  $buffer.append("]");
+         |private void $writeArrayToBuffer(ArrayData $arTerm, $bufferClass $bufferTerm) {
+         |  $bufferTerm.append("[");
+         |  for (int $loopIndex = 0; $loopIndex < $arTerm.numElements(); $loopIndex++) {
+         |    if ($loopIndex != 0) $bufferTerm.append(", ");
+         |    ${writeToBufferCode(loopIndex)}
+         |  }
+         |  $bufferTerm.append("]");
          |}
       """.stripMargin)
   }
 
-  private[this] def castToStringCode(from: DataType, ctx: CodegenContext, ev: ExprCode)
-    : CastFunction = {
+  private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = {
     from match {
       case BinaryType =>
         (c, evPrim, evNull) => s"$evPrim = UTF8String.fromBytes($c);"
@@ -736,21 +668,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
         (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
           org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));"""
       case ar: ArrayType =>
-        val bufferClass = classOf[StringBuffer].getName
-        val buffer = ctx.addMutableState(bufferClass, "buffer", v => s"$v = new $bufferClass();")
-        val writeArrayToBuffer = codegenWriteArrayToBuffer(ar, buffer, ctx)
-        val arrayToStringCode =
+        (c, evPrim, evNull) => {
+          val bufferTerm = ctx.freshName("bufferTerm")
+          val bufferClass = classOf[StringBuffer].getName
+          val writeArrayToBuffer = codegenWriteArrayToBuffer(ar, ctx)
           s"""
-            |if (!${ev.isNull}) {
-            |  // Reset buffer first
-            |  $buffer.delete(0, $buffer.length());
-            |  $writeArrayToBuffer(${ev.value});
+            |$bufferClass $bufferTerm = new $bufferClass();
+            |if (!$evNull) {
+            |  $writeArrayToBuffer($c, $bufferTerm);
             |}
+            |
+            |$evPrim = UTF8String.fromString($bufferTerm.toString());
           """.stripMargin
-        ev.code = ev.code ++ arrayToStringCode
-        ev.value = buffer
-        (c, evPrim, evNull) =>
-          s"""$evPrim = UTF8String.fromString($buffer.toString());"""
+        }
       case _ =>
         (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));"
     }
@@ -1088,8 +1018,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   }
 
   private[this] def castArrayCode(
-      fromType: DataType, toType: DataType, ctx: CodegenContext, ev: ExprCode): CastFunction = {
-    val elementCast = nullSafeCastFunction(fromType, toType, ctx, ev)
+      fromType: DataType, toType: DataType, ctx: CodegenContext): CastFunction = {
+    val elementCast = nullSafeCastFunction(fromType, toType, ctx)
     val arrayClass = classOf[GenericArrayData].getName
     val fromElementNull = ctx.freshName("feNull")
     val fromElementPrim = ctx.freshName("fePrim")
@@ -1123,10 +1053,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     """
   }
 
-  private[this] def castMapCode(from: MapType, to: MapType, ctx: CodegenContext, ev: ExprCode)
-    : CastFunction = {
-    val keysCast = castArrayCode(from.keyType, to.keyType, ctx, ev)
-    val valuesCast = castArrayCode(from.valueType, to.valueType, ctx, ev)
+  private[this] def castMapCode(from: MapType, to: MapType, ctx: CodegenContext): CastFunction = {
+    val keysCast = castArrayCode(from.keyType, to.keyType, ctx)
+    val valuesCast = castArrayCode(from.valueType, to.valueType, ctx)
 
     val mapClass = classOf[ArrayBasedMapData].getName
 
@@ -1152,11 +1081,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   }
 
   private[this] def castStructCode(
-      from: StructType, to: StructType, ctx: CodegenContext, ev: ExprCode): CastFunction = {
+      from: StructType, to: StructType, ctx: CodegenContext): CastFunction = {
 
     val fieldsCasts = from.fields.zip(to.fields).map {
-      case (fromField, toField) =>
-        nullSafeCastFunction(fromField.dataType, toField.dataType, ctx, ev)
+      case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx)
     }
     val rowClass = classOf[GenericInternalRow].getName
     val tmpResult = ctx.freshName("tmpResult")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 8f9600b078774..4bda0e2fe2657 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -857,11 +857,23 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("SPARK-22825 Cast array to string") {
     val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType)
     checkEvaluation(ret1, "[1, 2, 3, 4, 5]")
-    val ret2 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType)
-    checkEvaluation(ret2, "[[1, 2, 3], [4, 5]]")
-    val ret3 = cast(Literal.create(Array(Map(1 -> "a"), Map(2 -> "b", 3 -> "c"))), StringType)
-    checkEvaluation(ret3, "[[1 -> a], [2 -> b, 3 -> c]]")
-    val ret4 = cast(Literal.create(Array((1, 3.0, "a"), (3, 1.0, "b"))), StringType)
-    checkEvaluation(ret4, "[[1, 3.0, a], [3, 1.0, b]]")
+    val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType)
+    checkEvaluation(ret2, "[ab, cde, f]")
+    val ret3 = cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType)
+    checkEvaluation(ret3, "[ab, cde, f]")
+    val ret4 = cast(
+      Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)),
+      StringType)
+    checkEvaluation(ret4, "[2014-12-03, 2014-12-04, 2014-12-06]")
+    val ret5 = cast(
+      Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf)),
+      StringType)
+    checkEvaluation(ret5, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]")
+    val ret6 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType)
+    checkEvaluation(ret6, "[[1, 2, 3], [4, 5]]")
+    val ret7 = cast(
+      Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))),
+      StringType)
+    checkEvaluation(ret7, "[[[a], [b, c]], [[d]]]")
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 9157921ab1518..9da854c87063f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
 import java.io.File
 import java.math.MathContext
 import java.net.{MalformedURLException, URL}
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
 import java.util.concurrent.atomic.AtomicBoolean
 
 import org.apache.spark.{AccumulatorSuite, SparkException}
@@ -2775,40 +2775,46 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("SPARK-22825 Cast array to string") {
-    // Check non-codegen path
-    val df1 = sql("SELECT CAST(ARRAY(1, 2, 3, 4) AS STRING)")
-    checkAnswer(df1, Row("[1, 2, 3, 4]"))
-    val df2 = sql("SELECT CAST(ARRAY(ARRAY(1, 2), ARRAY(3, 4, 5), ARRAY(6, 7)) AS STRING)")
-    checkAnswer(df2, Row("[WrappedArray(1, 2), WrappedArray(3, 4, 5), WrappedArray(6, 7)]"))
-    val df3 = sql("SELECT CAST(ARRAY(MAP(1, 'a', 2, 'b'), MAP(3, 'c')) AS STRING)")
-    checkAnswer(df3, Row("[Map(1 -> a, 2 -> b), Map(3 -> c)]"))
-    val df4 = sql("SELECT CAST(ARRAY(STRUCT(1, 0.3, 'a'), STRUCT(2, 0.5, 'b')) AS STRING)")
-    checkAnswer(df4, Row("[[1,0.3,a], [2,0.5,b]]"))
-
-    // Check codegen path
-    withTable("t") {
-      Seq(Seq(0, 1, 2, 3, 4)).toDF("a").write.saveAsTable("t")
-      val df = sql("SELECT CAST(a AS STRING) FROM t")
-      checkAnswer(df, Row("[0, 1, 2, 3, 4]"))
-    }
-
-    withTable("t") {
-      Seq(Seq(Seq(1, 2), Seq(3), Seq(4, 5, 6))).toDF("a").write.saveAsTable("t")
-      val df = sql("SELECT CAST(a AS STRING) FROM t")
-      checkAnswer(df, Row("[[1, 2], [3], [4, 5, 6]]"))
-    }
-
-    withTable("t") {
-      Seq(Seq(Map(1 -> "a", 2 -> "b"), Map(3 -> "c"), Map(4 -> "d", 5 -> "e"))).toDF("a")
-        .write.saveAsTable("t")
-      val df = sql("SELECT CAST(a AS STRING) FROM t")
-      checkAnswer(df, Row("[[1 -> a, 2 -> b], [3 -> c], [4 -> d, 5 -> e]]"))
-    }
-
-    withTable("t") {
-      Seq(Seq((1, "a"), (2, "b")), Seq((3, "c"))).toDF("a").write.saveAsTable("t")
-      val df = sql("SELECT CAST(a AS STRING) FROM t")
-      checkAnswer(df, Row("[[1, a], [2, b]]") :: Row("[[3, c]]") :: Nil)
+    Seq("true", "false").foreach { codegen =>
+      withSQLConf("spark.sql.codegen.wholeStage" -> codegen) {
+        withTable("t") {
+          Seq(Seq(0, 1, 2, 3, 4)).toDF("a").write.saveAsTable("t")
+          val df = sql("SELECT CAST(a AS STRING) FROM t")
+          checkAnswer(df, Row("[0, 1, 2, 3, 4]"))
+        }
+        withTable("t") {
+          Seq(Seq("ab", "cde", "f")).toDF("a").write.saveAsTable("t")
+          val df = sql("SELECT CAST(a AS STRING) FROM t")
+          checkAnswer(df, Row("[ab, cde, f]"))
+        }
+        withTable("t") {
+          Seq(Seq("ab".getBytes, "cde".getBytes, "f".getBytes)).toDF("a").write.saveAsTable("t")
+          val df = sql("SELECT CAST(a AS STRING) FROM t")
+          checkAnswer(df, Row("[ab, cde, f]"))
+        }
+        withTable("t") {
+          Seq(Seq("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf))
+            .toDF("a").write.saveAsTable("t")
+          val df = sql("SELECT CAST(a AS STRING) FROM t")
+          checkAnswer(df, Row("[2014-12-03, 2014-12-04, 2014-12-06]"))
+        }
+        withTable("t") {
+          Seq(Seq("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf))
+            .toDF("a").write.saveAsTable("t")
+          val df = sql("SELECT CAST(a AS STRING) FROM t")
+          checkAnswer(df, Row("[2014-12-03 13:01:00, 2014-12-04 15:05:00]"))
+        }
+        withTable("t") {
+          Seq(Seq(Seq(1, 2), Seq(3), Seq(4, 5, 6))).toDF("a").write.saveAsTable("t")
+          val df = sql("SELECT CAST(a AS STRING) FROM t")
+          checkAnswer(df, Row("[[1, 2], [3], [4, 5, 6]]"))
+        }
+        withTable("t") {
+          Seq(Seq(Seq(Seq("a"), Seq("b", "c")), Seq(Seq("d")))).toDF("a").write.saveAsTable("t")
+          val df = sql("SELECT CAST(a AS STRING) FROM t")
+          checkAnswer(df, Row("[[[a], [b, c]], [[d]]]"))
+        }
+      }
     }
   }
 }
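A note on the buffer handling that patches 06 and 07 converge on: rather than a buffer held in generated mutable state (shared across rows, and across nested casts), the buffer is allocated per evaluation and threaded through every generated writer as a parameter. A hand-written Scala analogue of the resulting call shape (illustrative names, not the generated Java):

object BufferThreading {
  // Writers receive the buffer, so nested arrays simply re-enter the same
  // functions with the same buffer -- no shared field, no reset logic.
  def writeArray(ar: Seq[Any], buffer: java.lang.StringBuilder): Unit = {
    buffer.append("[")
    if (ar.nonEmpty) {
      writeElem(ar.head, buffer)
      var i = 1
      while (i < ar.length) {
        buffer.append(", ")
        writeElem(ar(i), buffer)
        i += 1
      }
    }
    buffer.append("]")
  }

  def writeElem(e: Any, buffer: java.lang.StringBuilder): Unit = e match {
    case s: Seq[_] => writeArray(s, buffer) // nested array reuses the caller's buffer
    case other => buffer.append(String.valueOf(other))
  }

  def main(args: Array[String]): Unit = {
    val buffer = new java.lang.StringBuilder // fresh per cast, as in patch 06
    writeArray(Seq(Seq(1, 2), Seq(3)), buffer)
    println(buffer.toString)                 // [[1, 2], [3]]
  }
}

Patch 07 also swaps java.lang.StringBuffer for java.lang.StringBuilder: the buffer never escapes a single evaluation, so the unsynchronized builder is sufficient and cheaper.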
ctx) + def writeElemCode(elemTerm: String) = { + writeElemToBufferCode(ar.elementType, bufferTerm, elemTerm, ctx) + } def writeToBufferCode(i: String) = { + val elemTerm = ctx.freshName("elemTerm") s""" |${ctx.javaType(ar.elementType)} $elemTerm = ${ctx.getValue(arTerm, ar.elementType, i)}; - |$writeElemCode; + |${writeElemCode(elemTerm)}; """.stripMargin } ctx.addNewFunction(writeArrayToBuffer, s""" |private void $writeArrayToBuffer(ArrayData $arTerm, $bufferClass $bufferTerm) { | $bufferTerm.append("["); - | for (int $loopIndex = 0; $loopIndex < $arTerm.numElements(); $loopIndex++) { - | if ($loopIndex != 0) $bufferTerm.append(", "); - | ${writeToBufferCode(loopIndex)} + | if ($arTerm.numElements() > 0) { + | ${writeToBufferCode("0")} + | for (int $loopIndex = 1; $loopIndex < $arTerm.numElements(); $loopIndex++) { + | $bufferTerm.append(", "); + | ${writeToBufferCode(loopIndex)} + | } | } | $bufferTerm.append("]"); |} @@ -670,14 +676,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case ar: ArrayType => (c, evPrim, evNull) => { val bufferTerm = ctx.freshName("bufferTerm") - val bufferClass = classOf[StringBuffer].getName + val bufferClass = "java.lang.StringBuilder" val writeArrayToBuffer = codegenWriteArrayToBuffer(ar, ctx) s""" |$bufferClass $bufferTerm = new $bufferClass(); - |if (!$evNull) { - | $writeArrayToBuffer($c, $bufferTerm); - |} - | + |$writeArrayToBuffer($c, $bufferTerm); |$evPrim = UTF8String.fromString($bufferTerm.toString()); """.stripMargin } From b5b5e35860bcf19c621813073991cf2776d4a388 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 4 Jan 2018 18:08:05 +0900 Subject: [PATCH 08/12] Add UTF8StringBuilder --- .../codegen/UTF8StringBuilder.java | 52 +++++++++++++++++++ .../spark/sql/catalyst/expressions/Cast.scala | 50 ++++++++++++++---- 2 files changed, 91 insertions(+), 11 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java new file mode 100644 index 0000000000000..1e75fb87fdcc5 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen; + +import java.nio.charset.StandardCharsets; + +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A helper class to write `UTF8String`, `String`, and `byte[]` data into an internal buffer + * and get a final concatenated string. 
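Note: both the interpreted loop and the generated loop above now append the first element before entering the loop, so the separator append needs no per-iteration branch. A standalone sketch of that pattern in plain Scala (illustrative only, not code from the patch):

    // Separator hoisting: emit element 0 up front, then ", " unconditionally.
    def render(elems: IndexedSeq[String]): String = {
      val sb = new StringBuilder("[")
      if (elems.nonEmpty) {
        sb.append(elems(0))
        var i = 1
        while (i < elems.length) {
          sb.append(", ").append(elems(i))  // no `if (i != 0)` check here
          i += 1
        }
      }
      sb.append("]").toString
    }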
From b5b5e35860bcf19c621813073991cf2776d4a388 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Thu, 4 Jan 2018 18:08:05 +0900
Subject: [PATCH 08/12] Add UTF8StringBuilder

---
 .../codegen/UTF8StringBuilder.java            | 52 +++++++++++++++++++
 .../spark/sql/catalyst/expressions/Cast.scala | 50 ++++++++++++++----
 2 files changed, 91 insertions(+), 11 deletions(-)
 create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
new file mode 100644
index 0000000000000..1e75fb87fdcc5
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import java.nio.charset.StandardCharsets;
+
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A helper class to write `UTF8String`, `String`, and `byte[]` data into an internal buffer
+ * and get a final concatenated string.
+ */
+public class UTF8StringBuilder {
+
+  private StringBuilder buffer;
+
+  public UTF8StringBuilder() {
+    this.buffer = new StringBuilder();
+  }
+
+  public void append(UTF8String value) {
+    buffer.append(value);
+  }
+
+  public void append(String value) {
+    buffer.append(value);
+  }
+
+  public void append(byte[] value) {
+    buffer.append(new String(value, StandardCharsets.UTF_8));
+  }
+
+  @Override
+  public String toString() {
+    return buffer.toString();
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 662a26e0692b4..37b5b306d26c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -199,30 +199,59 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   // [[func]] assumes the input is no longer null because eval already does the null check.
   @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T])
 
+  @inline private[this] def buildWriter[T](
+      a: Any, buffer: UTF8StringBuilder, writer: (T, UTF8StringBuilder) => Unit): Unit = {
+    writer(a.asInstanceOf[T], buffer)
+  }
+
+  private[this] def buildElemWriter(
+      from: DataType): (Any, UTF8StringBuilder) => Unit = from match {
+    case BinaryType => buildWriter[Array[Byte]](_, _, (b, buf) => buf.append(b))
+    case StringType => buildWriter[UTF8String](_, _, (b, buf) => buf.append(b))
+    case DateType => buildWriter[Int](_, _,
+      (d, buf) => buf.append(DateTimeUtils.dateToString(d)))
+    case TimestampType => buildWriter[Long](_, _,
+      (t, buf) => buf.append(DateTimeUtils.timestampToString(t)))
+    case ar: ArrayType =>
+      buildWriter[ArrayData](_, _, (array, buf) => {
+        buf.append("[")
+        if (array.numElements > 0) {
+          val writeElemToBuffer = buildElemWriter(ar.elementType)
+          writeElemToBuffer(array.get(0, ar.elementType), buf)
+          var i = 1
+          while (i < array.numElements) {
+            buf.append(", ")
+            writeElemToBuffer(array.get(i, ar.elementType), buf)
+            i += 1
+          }
+        }
+        buf.append("]")
+      })
+    case _ => buildWriter[Any](_, _, (o, buf) => buf.append(String.valueOf(o)))
+  }
 
   // UDFToString
   private[this] def castToString(from: DataType): Any => Any = from match {
     case BinaryType => buildCast[Array[Byte]](_, UTF8String.fromBytes)
-    case StringType => buildCast[UTF8String](_, identity)
     case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d)))
     case TimestampType => buildCast[Long](_,
       t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone)))
     case ar: ArrayType =>
       buildCast[ArrayData](_, array => {
-        val res = new StringBuilder
+        val res = new UTF8StringBuilder
         res.append("[")
         if (array.numElements > 0) {
-          val toStringFunc = castToString(ar.elementType)
-          res.append(toStringFunc(array.get(0, ar.elementType)))
+          val writeElemToBuffer = buildElemWriter(ar.elementType)
+          writeElemToBuffer(array.get(0, ar.elementType), res)
           var i = 1
           while (i < array.numElements) {
             res.append(", ")
-            res.append(toStringFunc(array.get(i, ar.elementType)))
+            writeElemToBuffer(array.get(i, ar.elementType), res)
             i += 1
           }
         }
         res.append("]")
-        UTF8String.fromString(res.toString())
+        UTF8String.fromString(res.toString)
       })
     case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
   }
@@ -620,21 +649,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
       buffer: String,
       elemTerm: String,
       ctx: CodegenContext): String = dataType match {
-    case BinaryType => s"$buffer.append(new String($elemTerm))"
-    case StringType => s"$buffer.append(new String($elemTerm.getBytes()))"
+    case BinaryType | StringType => s"$buffer.append($elemTerm)"
     case DateType => s"""$buffer.append(
       org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($elemTerm))"""
     case TimestampType => s"""$buffer.append(
       org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($elemTerm))"""
     case ar: ArrayType => s"${codegenWriteArrayToBuffer(ar, ctx)}($elemTerm, $buffer)"
-    case _ => s"$buffer.append($elemTerm)"
+    case _ => s"$buffer.append(String.valueOf($elemTerm))"
   }
 
   private[this] def codegenWriteArrayToBuffer(ar: ArrayType, ctx: CodegenContext): String = {
     val loopIndex = ctx.freshName("loopIndex")
     val writeArrayToBuffer = ctx.freshName("writeArrayToBuffer")
     val arTerm = ctx.freshName("arTerm")
-    val bufferClass = "java.lang.StringBuilder"
+    val bufferClass = classOf[UTF8StringBuilder].getName
     val bufferTerm = ctx.freshName("bufferTerm")
     def writeElemCode(elemTerm: String) = {
       writeElemToBufferCode(ar.elementType, bufferTerm, elemTerm, ctx)
@@ -676,7 +704,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     case ar: ArrayType =>
       (c, evPrim, evNull) => {
         val bufferTerm = ctx.freshName("bufferTerm")
-        val bufferClass = "java.lang.StringBuilder"
+        val bufferClass = classOf[UTF8StringBuilder].getName
         val writeArrayToBuffer = codegenWriteArrayToBuffer(ar, ctx)
         s"""
           |$bufferClass $bufferTerm = new $bufferClass();
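Note: at this stage of the series the builder is still backed by java.lang.StringBuilder and callers read the result back through toString. A hedged usage sketch of the call sequence the generated code produces for an ARRAY<STRING> cast (names illustrative, not code from the patch):

    import org.apache.spark.sql.catalyst.expressions.codegen.UTF8StringBuilder
    import org.apache.spark.unsafe.types.UTF8String

    val buf = new UTF8StringBuilder
    buf.append("[")
    buf.append(UTF8String.fromString("ab"))
    buf.append(", ")
    buf.append(UTF8String.fromString("cde"))
    buf.append("]")
    assert(buf.toString == "[ab, cde]")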
From b0b3cd6fce4fc6fc0eadced0890e9635f9adaf07 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Thu, 4 Jan 2018 21:08:54 +0900
Subject: [PATCH 09/12] Brush up UTF8StringBuilder

---
 .../codegen/UTF8StringBuilder.java            | 57 +++++++++++++++----
 .../spark/sql/catalyst/expressions/Cast.scala | 40 ++-----------
 2 files changed, 52 insertions(+), 45 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
index 1e75fb87fdcc5..37ef2deb2a8b0 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
@@ -19,34 +19,71 @@
 
 import java.nio.charset.StandardCharsets;
 
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.types.UTF8String;
 
 /**
- * A helper class to write `UTF8String`, `String`, and `byte[]` data into an internal buffer
- * and get a final concatenated string.
+ * A helper class to write `UTF8String`, `String`, and `byte[]` data into an internal byte buffer
+ * and get written data as `UTF8String`.
  */
 public class UTF8StringBuilder {
 
-  private StringBuilder buffer;
+  private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH;
+
+  private byte[] buffer;
+  private int cursor = Platform.BYTE_ARRAY_OFFSET;
 
   public UTF8StringBuilder() {
-    this.buffer = new StringBuilder();
+    // Since initial buffer size is 16 in `StringBuilder`, we set the same size here
+    this.buffer = new byte[16];
+  }
+
+  // Grows the buffer by at least `neededSize`
+  private void grow(int neededSize) {
+    if (neededSize > ARRAY_MAX - totalSize()) {
+      throw new UnsupportedOperationException(
+        "Cannot grow internal buffer by size " + neededSize + " because the size after growing " +
+          "exceeds size limitation " + ARRAY_MAX);
+    }
+    final int length = totalSize() + neededSize;
+    if (buffer.length < length) {
+      int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX;
+      final byte[] tmp = new byte[newLength];
+      Platform.copyMemory(
+        buffer,
+        Platform.BYTE_ARRAY_OFFSET,
+        tmp,
+        Platform.BYTE_ARRAY_OFFSET,
+        totalSize());
+      buffer = tmp;
+    }
   }
 
   public void append(UTF8String value) {
-    buffer.append(value);
+    grow(value.numBytes());
+    value.writeToMemory(buffer, cursor);
+    cursor += value.numBytes();
   }
 
   public void append(String value) {
-    buffer.append(value);
+    append(value.getBytes(StandardCharsets.UTF_8));
   }
 
   public void append(byte[] value) {
-    buffer.append(new String(value, StandardCharsets.UTF_8));
+    grow(value.length);
+    Platform.copyMemory(value, Platform.BYTE_ARRAY_OFFSET, buffer, cursor, value.length);
+    cursor += value.length;
+  }
+
+  public UTF8String toUTF8String() {
+    final int len = totalSize();
+    final byte[] bytes = new byte[len];
+    Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, bytes, Platform.BYTE_ARRAY_OFFSET, len);
+    return UTF8String.fromBytes(bytes);
   }
 
-  @Override
-  public String toString() {
-    return buffer.toString();
+  public int totalSize() {
+    return cursor - Platform.BYTE_ARRAY_OFFSET;
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 37b5b306d26c5..fb7e40d908d48 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -199,36 +199,6 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   // [[func]] assumes the input is no longer null because eval already does the null check.
   @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T])
 
-  @inline private[this] def buildWriter[T](
-      a: Any, buffer: UTF8StringBuilder, writer: (T, UTF8StringBuilder) => Unit): Unit = {
-    writer(a.asInstanceOf[T], buffer)
-  }
-
-  private[this] def buildElemWriter(
-      from: DataType): (Any, UTF8StringBuilder) => Unit = from match {
-    case BinaryType => buildWriter[Array[Byte]](_, _, (b, buf) => buf.append(b))
-    case StringType => buildWriter[UTF8String](_, _, (b, buf) => buf.append(b))
-    case DateType => buildWriter[Int](_, _,
-      (d, buf) => buf.append(DateTimeUtils.dateToString(d)))
-    case TimestampType => buildWriter[Long](_, _,
-      (t, buf) => buf.append(DateTimeUtils.timestampToString(t)))
-    case ar: ArrayType =>
-      buildWriter[ArrayData](_, _, (array, buf) => {
-        buf.append("[")
-        if (array.numElements > 0) {
-          val writeElemToBuffer = buildElemWriter(ar.elementType)
-          writeElemToBuffer(array.get(0, ar.elementType), buf)
-          var i = 1
-          while (i < array.numElements) {
-            buf.append(", ")
-            writeElemToBuffer(array.get(i, ar.elementType), buf)
-            i += 1
-          }
-        }
-        buf.append("]")
-      })
-    case _ => buildWriter[Any](_, _, (o, buf) => buf.append(String.valueOf(o)))
-  }
 
   // UDFToString
   private[this] def castToString(from: DataType): Any => Any = from match {
@@ -241,17 +211,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
         val res = new UTF8StringBuilder
         res.append("[")
         if (array.numElements > 0) {
-          val writeElemToBuffer = buildElemWriter(ar.elementType)
-          writeElemToBuffer(array.get(0, ar.elementType), res)
+          val toUTF8String = castToString(ar.elementType)
+          res.append(toUTF8String(array.get(0, ar.elementType)).asInstanceOf[UTF8String])
           var i = 1
           while (i < array.numElements) {
             res.append(", ")
-            writeElemToBuffer(array.get(i, ar.elementType), res)
+            res.append(toUTF8String(array.get(i, ar.elementType)).asInstanceOf[UTF8String])
             i += 1
           }
         }
         res.append("]")
-        UTF8String.fromString(res.toString)
+        res.toUTF8String
       })
     case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
   }
@@ -709,7 +679,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
         s"""
           |$bufferClass $bufferTerm = new $bufferClass();
           |$writeArrayToBuffer($c, $bufferTerm);
-          |$evPrim = UTF8String.fromString($bufferTerm.toString());
+          |$evPrim = $bufferTerm.toUTF8String();
         """.stripMargin
       }
     case _ =>
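Note: the buffer now grows geometrically, doubling until the requested length would pass the MAX_ROUNDED_ARRAY_LENGTH cap. The capacity rule in grow() can be restated on its own; a small sketch under the same cap assumption (illustrative, not the patch's code):

    // New capacity after writing `needed` bytes on top of `used` bytes.
    def newCapacity(used: Int, needed: Int, arrayMax: Int): Int = {
      require(needed <= arrayMax - used, s"cannot grow by $needed bytes")
      val length = used + needed
      if (length < arrayMax / 2) length * 2 else arrayMax
    }
    // e.g. newCapacity(16, 20, Int.MaxValue - 15) == 72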
From 09fd22ead2ec0267915bef2c169f90b08da7b63e Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Thu, 4 Jan 2018 23:59:35 +0900
Subject: [PATCH 10/12] Handle null array elements in cast to string

---
 .../codegen/UTF8StringBuilder.java            | 23 ++----
 .../spark/sql/catalyst/expressions/Cast.scala | 77 +++++++++----------
 .../sql/catalyst/expressions/CastSuite.scala  | 22 +++---
 .../org/apache/spark/sql/SQLQuerySuite.scala  |  5 ++
 4 files changed, 61 insertions(+), 66 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
index 37ef2deb2a8b0..827d45422e295 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
@@ -60,6 +60,10 @@ private void grow(int neededSize) {
     }
   }
 
+  private int totalSize() {
+    return cursor - Platform.BYTE_ARRAY_OFFSET;
+  }
+
   public void append(UTF8String value) {
     grow(value.numBytes());
     value.writeToMemory(buffer, cursor);
@@ -67,23 +71,10 @@ public void append(UTF8String value) {
   }
 
   public void append(String value) {
-    append(value.getBytes(StandardCharsets.UTF_8));
-  }
-
-  public void append(byte[] value) {
-    grow(value.length);
-    Platform.copyMemory(value, Platform.BYTE_ARRAY_OFFSET, buffer, cursor, value.length);
-    cursor += value.length;
-  }
-
-  public UTF8String toUTF8String() {
-    final int len = totalSize();
-    final byte[] bytes = new byte[len];
-    Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, bytes, Platform.BYTE_ARRAY_OFFSET, len);
-    return UTF8String.fromBytes(bytes);
+    append(UTF8String.fromString(value));
   }
 
-  public int totalSize() {
-    return cursor - Platform.BYTE_ARRAY_OFFSET;
+  public UTF8String build() {
+    return UTF8String.fromBytes(buffer, 0, totalSize());
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index fb7e40d908d48..7e98c10ff6643 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -206,22 +206,27 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d)))
     case TimestampType => buildCast[Long](_,
       t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone)))
-    case ar: ArrayType =>
+    case ArrayType(et, _) =>
       buildCast[ArrayData](_, array => {
-        val res = new UTF8StringBuilder
-        res.append("[")
+        val builder = new UTF8StringBuilder
+        builder.append("[")
         if (array.numElements > 0) {
-          val toUTF8String = castToString(ar.elementType)
-          res.append(toUTF8String(array.get(0, ar.elementType)).asInstanceOf[UTF8String])
+          val toUTF8String = castToString(et)
+          if (!array.isNullAt(0)) {
+            builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String])
+          }
           var i = 1
           while (i < array.numElements) {
-            res.append(", ")
-            res.append(toUTF8String(array.get(i, ar.elementType)).asInstanceOf[UTF8String])
+            builder.append(",")
+            if (!array.isNullAt(i)) {
+              builder.append(" ")
+              builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String])
+            }
             i += 1
           }
         }
-        res.append("]")
-        res.toUTF8String
+        builder.append("]")
+        builder.build()
       })
     case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
   }
@@ -614,45 +619,37 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     """
   }
 
-  private[this] def writeElemToBufferCode(
-      dataType: DataType,
-      buffer: String,
-      elemTerm: String,
-      ctx: CodegenContext): String = dataType match {
-    case BinaryType | StringType => s"$buffer.append($elemTerm)"
-    case DateType => s"""$buffer.append(
-      org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($elemTerm))"""
-    case TimestampType => s"""$buffer.append(
-      org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($elemTerm))"""
-    case ar: ArrayType => s"${codegenWriteArrayToBuffer(ar, ctx)}($elemTerm, $buffer)"
-    case _ => s"$buffer.append(String.valueOf($elemTerm))"
-  }
+  private[this] def codegenWriteArrayElemCode(et: DataType, ctx: CodegenContext): String = {
+    val elementToStringCode = castToStringCode(et, ctx)
+    val funcName = ctx.freshName("elementToString")
+    val elementToStringFunc = ctx.addNewFunction(funcName,
+      s"""
+        |private UTF8String $funcName(${ctx.javaType(et)} element) {
+        |  UTF8String elementStr = null;
+        |  ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)}
+        |  return elementStr;
+        |}
+      """.stripMargin)
 
-  private[this] def codegenWriteArrayToBuffer(ar: ArrayType, ctx: CodegenContext): String = {
     val loopIndex = ctx.freshName("loopIndex")
     val writeArrayToBuffer = ctx.freshName("writeArrayToBuffer")
     val arTerm = ctx.freshName("arTerm")
     val bufferClass = classOf[UTF8StringBuilder].getName
     val bufferTerm = ctx.freshName("bufferTerm")
-    def writeElemCode(elemTerm: String) = {
-      writeElemToBufferCode(ar.elementType, bufferTerm, elemTerm, ctx)
-    }
-    def writeToBufferCode(i: String) = {
-      val elemTerm = ctx.freshName("elemTerm")
-      s"""
-        |${ctx.javaType(ar.elementType)} $elemTerm = ${ctx.getValue(arTerm, ar.elementType, i)};
-        |${writeElemCode(elemTerm)};
-      """.stripMargin
-    }
     ctx.addNewFunction(writeArrayToBuffer,
       s"""
        |private void $writeArrayToBuffer(ArrayData $arTerm, $bufferClass $bufferTerm) {
        |  $bufferTerm.append("[");
        |  if ($arTerm.numElements() > 0) {
-       |    ${writeToBufferCode("0")}
+       |    if (!$arTerm.isNullAt(0)) {
+       |      $bufferTerm.append($elementToStringFunc(${ctx.getValue(arTerm, et, "0")}));
+       |    }
        |    for (int $loopIndex = 1; $loopIndex < $arTerm.numElements(); $loopIndex++) {
-       |      $bufferTerm.append(", ");
-       |      ${writeToBufferCode(loopIndex)}
+       |      $bufferTerm.append(",");
+       |      if (!$arTerm.isNullAt($loopIndex)) {
+       |        $bufferTerm.append(" ");
+       |        $bufferTerm.append($elementToStringFunc(${ctx.getValue(arTerm, et, loopIndex)}));
+       |      }
        |    }
        |  }
        |  $bufferTerm.append("]");
@@ -671,15 +668,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
       val tz = ctx.addReferenceObj("timeZone", timeZone)
       (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
         org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));"""
-    case ar: ArrayType =>
+    case ArrayType(et, _) =>
       (c, evPrim, evNull) => {
         val bufferTerm = ctx.freshName("bufferTerm")
         val bufferClass = classOf[UTF8StringBuilder].getName
-        val writeArrayToBuffer = codegenWriteArrayToBuffer(ar, ctx)
+        val writeArrayElemCode = codegenWriteArrayElemCode(et, ctx)
         s"""
           |$bufferClass $bufferTerm = new $bufferClass();
-          |$writeArrayToBuffer($c, $bufferTerm);
-          |$evPrim = $bufferTerm.toUTF8String();
+          |$writeArrayElemCode($c, $bufferTerm);
+          |$evPrim = $bufferTerm.build();
         """.stripMargin
       }
     case _ =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 4bda0e2fe2657..e3ed7171defd8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -859,21 +859,23 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(ret1, "[1, 2, 3, 4, 5]")
     val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType)
     checkEvaluation(ret2, "[ab, cde, f]")
-    val ret3 = cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType)
-    checkEvaluation(ret3, "[ab, cde, f]")
-    val ret4 = cast(
+    val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType)
+    checkEvaluation(ret3, "[ab,, c]")
+    val ret4 = cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType)
+    checkEvaluation(ret4, "[ab, cde, f]")
+    val ret5 = cast(
       Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)),
       StringType)
-    checkEvaluation(ret4, "[2014-12-03, 2014-12-04, 2014-12-06]")
-    val ret5 = cast(
+    checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]")
+    val ret6 = cast(
       Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf)),
       StringType)
-    checkEvaluation(ret5, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]")
-    val ret6 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType)
-    checkEvaluation(ret6, "[[1, 2, 3], [4, 5]]")
-    val ret7 = cast(
+    checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]")
+    val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType)
+    checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]")
+    val ret8 = cast(
       Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))),
       StringType)
-    checkEvaluation(ret7, "[[[a], [b, c]], [[d]]]")
+    checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]")
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 9da854c87063f..851618ee03046 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2787,6 +2787,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
           val df = sql("SELECT CAST(a AS STRING) FROM t")
           checkAnswer(df, Row("[ab, cde, f]"))
         }
+        withTable("t") {
+          Seq(Seq("ab", null, "c")).toDF("a").write.saveAsTable("t")
+          val df = sql("SELECT CAST(a AS STRING) FROM t")
+          checkAnswer(df, Row("[ab,, c]"))
+        }
         withTable("t") {
           Seq(Seq("ab".getBytes, "cde".getBytes, "f".getBytes)).toDF("a").write.saveAsTable("t")
           val df = sql("SELECT CAST(a AS STRING) FROM t")
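Note: a null element now renders as an empty slot, with the trailing space emitted only for non-null elements, which is how the "[ab,, c]" expectation in the tests above comes about. The interpreted loop's effect, restated as a standalone sketch over an Option-encoded array (illustrative only, not code from the patch):

    // IndexedSeq(Some("ab"), None, Some("c")) renders as "[ab,, c]".
    def render(elems: IndexedSeq[Option[String]]): String = {
      val sb = new StringBuilder("[")
      if (elems.nonEmpty) {
        elems(0).foreach(e => sb.append(e))
        var i = 1
        while (i < elems.length) {
          sb.append(",")
          elems(i).foreach(e => sb.append(" ").append(e))  // space only if non-null
          i += 1
        }
      }
      sb.append("]").toString
    }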
From 449e2c9c8c5c48a14a9b2efec728b350463188bf Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Fri, 5 Jan 2018 07:54:21 +0900
Subject: [PATCH 11/12] Inline the array-to-string codegen and clean up tests

---
 .../codegen/UTF8StringBuilder.java            |  6 +--
 .../spark/sql/catalyst/expressions/Cast.scala | 49 +++++++++---------
 .../org/apache/spark/sql/SQLQuerySuite.scala  | 51 +------------------
 3 files changed, 26 insertions(+), 80 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
index 827d45422e295..f0f66bae245fd 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
@@ -17,15 +17,13 @@
 
 package org.apache.spark.sql.catalyst.expressions.codegen;
 
-import java.nio.charset.StandardCharsets;
-
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.types.UTF8String;
 
 /**
- * A helper class to write `UTF8String`, `String`, and `byte[]` data into an internal byte buffer
- * and get written data as `UTF8String`.
+ * A helper class to write {@link UTF8String}s to an internal buffer and build the concatenated
+ * {@link UTF8String} at the end.
  */
 public class UTF8StringBuilder {
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 7e98c10ff6643..02d27d3914bb0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -619,7 +619,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     """
   }
 
-  private[this] def codegenWriteArrayElemCode(et: DataType, ctx: CodegenContext): String = {
+  private def writeArrayToStringBuilder(
+      et: DataType,
+      arTerm: String,
+      bufferTerm: String,
+      ctx: CodegenContext): String = {
     val elementToStringCode = castToStringCode(et, ctx)
     val funcName = ctx.freshName("elementToString")
     val elementToStringFunc = ctx.addNewFunction(funcName,
@@ -632,29 +636,22 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
       """.stripMargin)
 
     val loopIndex = ctx.freshName("loopIndex")
-    val writeArrayToBuffer = ctx.freshName("writeArrayToBuffer")
-    val arTerm = ctx.freshName("arTerm")
-    val bufferClass = classOf[UTF8StringBuilder].getName
-    val bufferTerm = ctx.freshName("bufferTerm")
-    ctx.addNewFunction(writeArrayToBuffer,
-      s"""
-       |private void $writeArrayToBuffer(ArrayData $arTerm, $bufferClass $bufferTerm) {
-       |  $bufferTerm.append("[");
-       |  if ($arTerm.numElements() > 0) {
-       |    if (!$arTerm.isNullAt(0)) {
-       |      $bufferTerm.append($elementToStringFunc(${ctx.getValue(arTerm, et, "0")}));
-       |    }
-       |    for (int $loopIndex = 1; $loopIndex < $arTerm.numElements(); $loopIndex++) {
-       |      $bufferTerm.append(",");
-       |      if (!$arTerm.isNullAt($loopIndex)) {
-       |        $bufferTerm.append(" ");
-       |        $bufferTerm.append($elementToStringFunc(${ctx.getValue(arTerm, et, loopIndex)}));
-       |      }
-       |    }
-       |  }
-       |  $bufferTerm.append("]");
-       |}
-      """.stripMargin)
+    s"""
+       |$bufferTerm.append("[");
+       |if ($arTerm.numElements() > 0) {
+       |  if (!$arTerm.isNullAt(0)) {
+       |    $bufferTerm.append($elementToStringFunc(${ctx.getValue(arTerm, et, "0")}));
+       |  }
+       |  for (int $loopIndex = 1; $loopIndex < $arTerm.numElements(); $loopIndex++) {
+       |    $bufferTerm.append(",");
+       |    if (!$arTerm.isNullAt($loopIndex)) {
+       |      $bufferTerm.append(" ");
+       |      $bufferTerm.append($elementToStringFunc(${ctx.getValue(arTerm, et, loopIndex)}));
+       |    }
+       |  }
+       |}
+       |$bufferTerm.append("]");
     """.stripMargin
   }
 
@@ -672,10 +669,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     case ArrayType(et, _) =>
       (c, evPrim, evNull) => {
         val bufferTerm = ctx.freshName("bufferTerm")
         val bufferClass = classOf[UTF8StringBuilder].getName
-        val writeArrayElemCode = codegenWriteArrayElemCode(et, ctx)
+        val writeArrayElemCode = writeArrayToStringBuilder(et, c, bufferTerm, ctx)
         s"""
           |$bufferClass $bufferTerm = new $bufferClass();
-          |$writeArrayElemCode($c, $bufferTerm);
+          |$writeArrayElemCode;
           |$evPrim = $bufferTerm.build();
         """.stripMargin
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 851618ee03046..96bf65fce9c4a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
 import java.io.File
 import java.math.MathContext
 import java.net.{MalformedURLException, URL}
-import java.sql.{Date, Timestamp}
+import java.sql.Timestamp
 import java.util.concurrent.atomic.AtomicBoolean
 
 import org.apache.spark.{AccumulatorSuite, SparkException}
@@ -2773,53 +2773,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
-
-  test("SPARK-22825 Cast array to string") {
-    Seq("true", "false").foreach { codegen =>
-      withSQLConf("spark.sql.codegen.wholeStage" -> codegen) {
-        withTable("t") {
-          Seq(Seq(0, 1, 2, 3, 4)).toDF("a").write.saveAsTable("t")
-          val df = sql("SELECT CAST(a AS STRING) FROM t")
-          checkAnswer(df, Row("[0, 1, 2, 3, 4]"))
-        }
-        withTable("t") {
-          Seq(Seq("ab", "cde", "f")).toDF("a").write.saveAsTable("t")
-          val df = sql("SELECT CAST(a AS STRING) FROM t")
-          checkAnswer(df, Row("[ab, cde, f]"))
-        }
-        withTable("t") {
-          Seq(Seq("ab", null, "c")).toDF("a").write.saveAsTable("t")
-          val df = sql("SELECT CAST(a AS STRING) FROM t")
-          checkAnswer(df, Row("[ab,, c]"))
-        }
-        withTable("t") {
-          Seq(Seq("ab".getBytes, "cde".getBytes, "f".getBytes)).toDF("a").write.saveAsTable("t")
-          val df = sql("SELECT CAST(a AS STRING) FROM t")
-          checkAnswer(df, Row("[ab, cde, f]"))
-        }
-        withTable("t") {
-          Seq(Seq("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf))
-            .toDF("a").write.saveAsTable("t")
-          val df = sql("SELECT CAST(a AS STRING) FROM t")
-          checkAnswer(df, Row("[2014-12-03, 2014-12-04, 2014-12-06]"))
-        }
-        withTable("t") {
-          Seq(Seq("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf))
-            .toDF("a").write.saveAsTable("t")
-          val df = sql("SELECT CAST(a AS STRING) FROM t")
-          checkAnswer(df, Row("[2014-12-03 13:01:00, 2014-12-04 15:05:00]"))
-        }
-        withTable("t") {
-          Seq(Seq(Seq(1, 2), Seq(3), Seq(4, 5, 6))).toDF("a").write.saveAsTable("t")
-          val df = sql("SELECT CAST(a AS STRING) FROM t")
-          checkAnswer(df, Row("[[1, 2], [3], [4, 5, 6]]"))
-        }
-        withTable("t") {
-          Seq(Seq(Seq(Seq("a"), Seq("b", "c")), Seq(Seq("d")))).toDF("a").write.saveAsTable("t")
-          val df = sql("SELECT CAST(a AS STRING) FROM t")
-          checkAnswer(df, Row("[[[a], [b, c]], [[d]]]"))
-        }
-      }
-    }
-  }
 }
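Note: after this patch the array writer is emitted inline at each use site, and only the per-element converter stays behind ctx.addNewFunction; the end-to-end SQLQuerySuite test is dropped because checkEvaluation in CastSuite already runs each expression through both the interpreted and the codegen evaluators. The final builder API reduces to appends plus one build(); a short sketch, assuming the patched UTF8StringBuilder on the classpath (illustrative only):

    import org.apache.spark.sql.catalyst.expressions.codegen.UTF8StringBuilder
    import org.apache.spark.unsafe.types.UTF8String

    val buf = new UTF8StringBuilder
    buf.append("[")
    buf.append(UTF8String.fromString("2014-12-03"))
    buf.append("]")
    val out: UTF8String = buf.build()  // wraps the written bytes directly
    assert(out.toString == "[2014-12-03]")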
From dc15b93fe76a675136dd1bf08ce25ad3c55959b3 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Fri, 5 Jan 2018 11:33:45 +0900
Subject: [PATCH 12/12] Drop the Term suffix from codegen variable names

---
 .../spark/sql/catalyst/expressions/Cast.scala | 32 +++++++++----------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 02d27d3914bb0..d4fc5e0f168a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -621,8 +621,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   private def writeArrayToStringBuilder(
       et: DataType,
-      arTerm: String,
-      bufferTerm: String,
+      array: String,
+      buffer: String,
       ctx: CodegenContext): String = {
     val elementToStringCode = castToStringCode(et, ctx)
     val funcName = ctx.freshName("elementToString")
@@ -637,20 +637,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     val loopIndex = ctx.freshName("loopIndex")
     s"""
-       |$bufferTerm.append("[");
-       |if ($arTerm.numElements() > 0) {
-       |  if (!$arTerm.isNullAt(0)) {
-       |    $bufferTerm.append($elementToStringFunc(${ctx.getValue(arTerm, et, "0")}));
+       |$buffer.append("[");
+       |if ($array.numElements() > 0) {
+       |  if (!$array.isNullAt(0)) {
+       |    $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")}));
        |  }
-       |  for (int $loopIndex = 1; $loopIndex < $arTerm.numElements(); $loopIndex++) {
-       |    $bufferTerm.append(",");
-       |    if (!$arTerm.isNullAt($loopIndex)) {
-       |      $bufferTerm.append(" ");
-       |      $bufferTerm.append($elementToStringFunc(${ctx.getValue(arTerm, et, loopIndex)}));
+       |  for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) {
+       |    $buffer.append(",");
+       |    if (!$array.isNullAt($loopIndex)) {
+       |      $buffer.append(" ");
+       |      $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)}));
        |    }
        |  }
        |}
-       |$bufferTerm.append("]");
+       |$buffer.append("]");
     """.stripMargin
   }
 
@@ -667,13 +667,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
         org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));"""
     case ArrayType(et, _) =>
       (c, evPrim, evNull) => {
-        val bufferTerm = ctx.freshName("bufferTerm")
+        val buffer = ctx.freshName("buffer")
         val bufferClass = classOf[UTF8StringBuilder].getName
-        val writeArrayElemCode = writeArrayToStringBuilder(et, c, bufferTerm, ctx)
+        val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx)
         s"""
-          |$bufferClass $bufferTerm = new $bufferClass();
+          |$bufferClass $buffer = new $bufferClass();
           |$writeArrayElemCode;
-          |$evPrim = $bufferTerm.build();
+          |$evPrim = $buffer.build();
         """.stripMargin
       }
     case _ =>