From be53ff6b2fa829ffa3f0f4cc2996aa74fd7056c0 Mon Sep 17 00:00:00 2001 From: younggyu chun Date: Thu, 18 Jun 2020 09:14:50 -0400 Subject: [PATCH] [SPARK-7101][SQL] support java.sql.Time --- .../expressions/SpecializedGettersReader.java | 3 + .../sql/catalyst/expressions/UnsafeRow.java | 3 +- .../org/apache/spark/sql/types/DataTypes.java | 5 + .../main/scala/org/apache/spark/sql/Row.scala | 9 + .../sql/catalyst/CatalystTypeConverters.scala | 13 +- .../catalyst/DeserializerBuildHelper.scala | 9 + .../spark/sql/catalyst/InternalRow.scala | 5 +- .../sql/catalyst/JavaTypeInference.scala | 2 + .../sql/catalyst/SerializerBuildHelper.scala | 9 + .../sql/catalyst/encoders/RowEncoder.scala | 20 ++ .../spark/sql/catalyst/expressions/Cast.scala | 106 ++++++--- .../InterpretedUnsafeProjection.scala | 2 +- .../expressions/codegen/CodeGenerator.scala | 5 +- .../sql/catalyst/expressions/literals.scala | 12 +- .../sql/catalyst/parser/AstBuilder.scala | 1 + .../sql/catalyst/util/DateTimeUtils.scala | 21 +- .../org/apache/spark/sql/types/DataType.scala | 2 +- .../org/apache/spark/sql/types/TimeType.scala | 63 ++++++ .../spark/sql/RandomDataGenerator.scala | 16 ++ .../catalyst/encoders/RowEncoderSuite.scala | 49 +++-- .../sql/catalyst/expressions/CastSuite.scala | 202 +++++++++++++++++- .../spark/sql/execution/HiveResult.scala | 6 +- .../spark/sql/execution/HiveResultSuite.scala | 14 ++ 23 files changed, 516 insertions(+), 61 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimeType.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java index ea0648a6cb909..1fed1da47203c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java @@ -65,6 +65,9 @@ public static Object read( if (dataType instanceof TimestampType) { return obj.getLong(ordinal); } + if (dataType instanceof TimeType) { + return obj.getLong(ordinal); + } if (dataType instanceof CalendarIntervalType) { return obj.getInterval(ordinal); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 23e7d1f07e4a3..305d173cf5198 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -90,7 +90,8 @@ public static int calculateBitSetWidthInBytes(int numFields) { FloatType, DoubleType, DateType, - TimestampType + TimestampType, + TimeType }))); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java index d786374f69e20..04911b6a2bb81 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java @@ -54,6 +54,11 @@ public class DataTypes { */ public static final DataType TimestampType = TimestampType$.MODULE$; + /** + * Gets the TimeType object. + */ + public static final DataType TimeType = TimeType$.MODULE$; + /** * Gets the CalendarIntervalType object. 
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 9a7e077b658df..ae9eba52de823 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -167,6 +167,7 @@ trait Row extends Serializable {
  *
  *   DateType -> java.sql.Date
  *   TimestampType -> java.sql.Timestamp
+ *   TimeType -> java.sql.Time
  *
  *   BinaryType -> byte array
  *   ArrayType -> scala.collection.Seq (use getList for java.util.List)
@@ -192,6 +193,7 @@ trait Row extends Serializable {
  *
  *   DateType -> java.sql.Date
  *   TimestampType -> java.sql.Timestamp
+ *   TimeType -> java.sql.Time
  *
  *   BinaryType -> byte array
  *   ArrayType -> scala.collection.Seq (use getList for java.util.List)
@@ -296,6 +298,13 @@ trait Row extends Serializable {
    */
   def getTimestamp(i: Int): java.sql.Timestamp = getAs[java.sql.Timestamp](i)
 
+  /**
+   * Returns the value at position i of time type as java.sql.Time.
+   *
+   * @throws ClassCastException when data type does not match.
+   */
+  def getTime(i: Int): java.sql.Time = getAs[java.sql.Time](i)
+
   /**
    * Returns the value at position i of date type as java.time.Instant.
    *
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 34d2f45e715e9..4f889281f3354 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
 import java.lang.{Iterable => JavaIterable}
 import java.math.{BigDecimal => JavaBigDecimal}
 import java.math.{BigInteger => JavaBigInteger}
-import java.sql.{Date, Timestamp}
+import java.sql.{Date, Time, Timestamp}
 import java.time.{Instant, LocalDate}
 import java.util.{Map => JavaMap}
 import javax.annotation.Nullable
@@ -331,6 +331,16 @@ object CatalystTypeConverters {
       DateTimeUtils.toJavaTimestamp(row.getLong(column))
   }
 
+  private object TimeConverter extends CatalystTypeConverter[Time, Time, Any] {
+    override def toCatalystImpl(scalaValue: Time): Long =
+      DateTimeUtils.fromJavaTime(scalaValue)
+    override def toScala(catalystValue: Any): Time =
+      if (catalystValue == null) null
+      else DateTimeUtils.toJavaTime(catalystValue.asInstanceOf[Long])
+    override def toScalaImpl(row: InternalRow, column: Int): Time =
+      DateTimeUtils.toJavaTime(row.getLong(column))
+  }
+
   private object InstantConverter extends CatalystTypeConverter[Instant, Instant, Any] {
     override def toCatalystImpl(scalaValue: Instant): Long =
       DateTimeUtils.instantToMicros(scalaValue)
@@ -451,6 +461,7 @@ object CatalystTypeConverters {
     case d: Date => DateConverter.toCatalyst(d)
     case ld: LocalDate => LocalDateConverter.toCatalyst(ld)
     case t: Timestamp => TimestampConverter.toCatalyst(t)
+    case ti: Time => TimeConverter.toCatalyst(ti)
     case i: Instant => InstantConverter.toCatalyst(i)
     case d: BigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
     case d: JavaBigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index e55c25c4b0c54..f637bbdc49cec 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -118,6 +118,15 @@ object DeserializerBuildHelper { returnNullable = false) } + def createDeserializerForSqlTime(path: Expression): Expression = { + StaticInvoke( + DateTimeUtils.getClass, + ObjectType(classOf[java.sql.Time]), + "toJavaTime", + path :: Nil, + returnNullable = false) + } + def createDeserializerForJavaBigDecimal( path: Expression, returnNullable: Boolean): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index f98b59edd4226..f4b34e64b9217 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -133,7 +133,7 @@ object InternalRow { case ByteType => (input, ordinal) => input.getByte(ordinal) case ShortType => (input, ordinal) => input.getShort(ordinal) case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal) - case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal) + case LongType | TimestampType | TimeType => (input, ordinal) => input.getLong(ordinal) case FloatType => (input, ordinal) => input.getFloat(ordinal) case DoubleType => (input, ordinal) => input.getDouble(ordinal) case StringType => (input, ordinal) => input.getUTF8String(ordinal) @@ -168,7 +168,8 @@ object InternalRow { case ByteType => (input, v) => input.setByte(ordinal, v.asInstanceOf[Byte]) case ShortType => (input, v) => input.setShort(ordinal, v.asInstanceOf[Short]) case IntegerType | DateType => (input, v) => input.setInt(ordinal, v.asInstanceOf[Int]) - case LongType | TimestampType => (input, v) => input.setLong(ordinal, v.asInstanceOf[Long]) + case LongType | TimestampType | TimeType => (input, v) => + input.setLong(ordinal, v.asInstanceOf[Long]) case FloatType => (input, v) => input.setFloat(ordinal, v.asInstanceOf[Float]) case DoubleType => (input, v) => input.setDouble(ordinal, v.asInstanceOf[Double]) case CalendarIntervalType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 91ddf0f28ad80..21498007f17ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -397,6 +397,8 @@ object JavaTypeInference { case c if c == classOf[java.sql.Timestamp] => createSerializerForSqlTimestamp(inputObject) + case c if c == classOf[java.sql.Time] => createSerializerForSqlTime(inputObject) + case c if c == classOf[java.time.LocalDate] => createSerializerForJavaLocalDate(inputObject) case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala index 026ff6f2983fb..44516a51b065f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala @@ -83,6 +83,15 @@ object SerializerBuildHelper { returnNullable = false) } + def createSerializerForSqlTime(inputObject: 
Expression): Expression = { + StaticInvoke( + DateTimeUtils.getClass, + TimeType, + "fromJavaTime", + inputObject :: Nil, + returnNullable = false) + } + def createSerializerForJavaLocalDate(inputObject: Expression): Expression = { StaticInvoke( DateTimeUtils.getClass, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 765018f07d87a..df6db0000bcf1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -101,6 +101,13 @@ object RowEncoder { createSerializerForSqlTimestamp(inputObject) } + case TimeType => + if (SQLConf.get.datetimeJava8ApiEnabled) { + createSerializerForJavaInstant(inputObject) + } else { + createSerializerForSqlTime(inputObject) + } + case DateType => if (SQLConf.get.datetimeJava8ApiEnabled) { createSerializerForJavaLocalDate(inputObject) @@ -220,6 +227,12 @@ object RowEncoder { } else { ObjectType(classOf[java.sql.Timestamp]) } + case TimeType => + if (SQLConf.get.datetimeJava8ApiEnabled) { + ObjectType(classOf[java.time.Instant]) + } else { + ObjectType(classOf[java.sql.Time]) + } case DateType => if (SQLConf.get.datetimeJava8ApiEnabled) { ObjectType(classOf[java.time.LocalDate]) @@ -274,6 +287,13 @@ object RowEncoder { createDeserializerForSqlTimestamp(input) } + case TimeType => + if (SQLConf.get.datetimeJava8ApiEnabled) { + createDeserializerForInstant(input) + } else { + createDeserializerForSqlTime(input) + } + case DateType => if (SQLConf.get.datetimeJava8ApiEnabled) { createDeserializerForLocalDate(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 7c4316fe08433..c68084596d0d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -54,6 +54,7 @@ object Cast { case (StringType, BooleanType) => true case (DateType, BooleanType) => true case (TimestampType, BooleanType) => true + case (TimeType, BooleanType) => true case (_: NumericType, BooleanType) => true case (StringType, TimestampType) => true @@ -61,8 +62,14 @@ object Cast { case (DateType, TimestampType) => true case (_: NumericType, TimestampType) => true + case (StringType, TimeType) => true + case (BooleanType, TimeType) => true + case (DateType, TimeType) => true + case (_: NumericType, TimeType) => true + case (StringType, DateType) => true case (TimestampType, DateType) => true + case (TimeType, DateType) => true case (StringType, CalendarIntervalType) => true @@ -70,6 +77,7 @@ object Cast { case (BooleanType, _: NumericType) => true case (DateType, _: NumericType) => true case (TimestampType, _: NumericType) => true + case (TimeType, _: NumericType) => true case (_: NumericType, _: NumericType) => true case (ArrayType(fromType, fn), ArrayType(toType, tn)) => @@ -107,10 +115,10 @@ object Cast { * * Cast.castToTimestamp */ def needsTimeZone(from: DataType, to: DataType): Boolean = (from, to) match { - case (StringType, TimestampType | DateType) => true - case (DateType, TimestampType) => true - case (TimestampType, StringType) => true - case (TimestampType, DateType) => true + case (StringType, TimestampType | TimeType | DateType) => true + case (DateType, TimestampType | 
TimeType) => true
+    case (TimestampType | TimeType, StringType) => true
+    case (TimestampType | TimeType, DateType) => true
     case (ArrayType(fromType, _), ArrayType(toType, _)) => needsTimeZone(fromType, toType)
     case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) =>
       needsTimeZone(fromKey, toKey) || needsTimeZone(fromValue, toValue)
@@ -134,6 +142,7 @@ object Cast {
     case (from: DecimalType, to: NumericType) if from.isTighterThan(to) => true
     case (f, t) if legalNumericPrecedence(f, t) => true
     case (DateType, TimestampType) => true
+    case (DateType, TimeType) => true
     case (_: AtomicType, StringType) => true
     case (_: CalendarIntervalType, StringType) => true
     case (NullType, _) => true
@@ -143,6 +152,9 @@ object Cast {
     case (TimestampType, LongType) => true
     case (LongType, TimestampType) => true
 
+    case (TimeType, LongType) => true
+    case (LongType, TimeType) => true
+
     case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
       resolvableNullability(fn, tn) && canUpCast(fromType, toType)
 
@@ -172,6 +184,8 @@ object Cast {
     case (_: CalendarIntervalType, StringType) => true
     case (DateType, TimestampType) => true
     case (TimestampType, DateType) => true
+    case (DateType, TimeType) => true
+    case (TimeType, DateType) => true
 
     case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
       resolvableNullability(fn, tn) && canANSIStoreAssign(fromType, toType)
@@ -214,10 +228,10 @@ object Cast {
     case (StringType, _) => true
     case (_, StringType) => false
 
-    case (FloatType | DoubleType, TimestampType) => true
-    case (TimestampType, DateType) => false
+    case (FloatType | DoubleType, TimestampType | TimeType) => true
+    case (TimestampType | TimeType, DateType) => false
     case (_, DateType) => true
-    case (DateType, TimestampType) => false
+    case (DateType, TimestampType | TimeType) => false
     case (DateType, _) => true
     case (_, CalendarIntervalType) => true
@@ -288,7 +302,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
       buildCast[CalendarInterval](_, i => UTF8String.fromString(i.toString))
     case BinaryType => buildCast[Array[Byte]](_, UTF8String.fromBytes)
     case DateType => buildCast[Int](_, d => UTF8String.fromString(dateFormatter.format(d)))
-    case TimestampType => buildCast[Long](_,
+    case TimestampType | TimeType => buildCast[Long](_,
       t => UTF8String.fromString(DateTimeUtils.timestampToString(timestampFormatter, t)))
     case ArrayType(et, _) =>
       buildCast[ArrayData](_, array => {
@@ -393,7 +407,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
           null
         }
       })
-    case TimestampType =>
+    case TimestampType | TimeType =>
       buildCast[Long](_, t => t != 0)
     case DateType =>
       // Hive would return null when cast from date to boolean
@@ -441,6 +455,33 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
       buildCast[Float](_, f => doubleToTimestamp(f.toDouble))
   }
 
+  // Casts to TimeType reuse the timestamp conversions, since time values share the
+  // timestamp representation (microseconds since the epoch).
+  private[this] def castToTime(from: DataType): Any => Any = from match {
+    case StringType =>
+      buildCast[UTF8String](_, utfs => DateTimeUtils.stringToTimestamp(utfs, zoneId).orNull)
+    case BooleanType =>
+      buildCast[Boolean](_, b => if (b) 1L else 0)
+    case LongType =>
+      buildCast[Long](_, l => longToTimestamp(l))
+    case IntegerType =>
+      buildCast[Int](_, i => longToTimestamp(i.toLong))
+    case
ShortType => + buildCast[Short](_, s => longToTimestamp(s.toLong)) + case ByteType => + buildCast[Byte](_, b => longToTimestamp(b.toLong)) + case DateType => + buildCast[Int](_, d => epochDaysToMicros(d, zoneId)) + // TimestampWritable.decimalToTimestamp + case DecimalType() => + buildCast[Decimal](_, d => decimalToTimestamp(d)) + // TimestampWritable.doubleToTimestamp + case DoubleType => + buildCast[Double](_, d => doubleToTimestamp(d)) + // TimestampWritable.floatToTimestamp + case FloatType => + buildCast[Float](_, f => doubleToTimestamp(f.toDouble)) + } + private[this] def decimalToTimestamp(d: Decimal): Long = { (d.toBigDecimal * MICROS_PER_SECOND).longValue } @@ -463,7 +504,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => DateTimeUtils.stringToDate(s, zoneId).orNull) - case TimestampType => + case TimestampType | TimeType => // throw valid precision more than seconds, according to Hive. // Timestamp.nanos is in 0 to 999,999,999, no more than a second. buildCast[Long](_, t => microsToEpochDays(t, zoneId)) @@ -486,7 +527,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit buildCast[Boolean](_, b => if (b) 1L else 0L) case DateType => buildCast[Int](_, d => null) - case TimestampType => + case TimestampType | TimeType => buildCast[Long](_, t => timestampToLong(t)) case x: NumericType if ansiEnabled => b => x.exactNumeric.asInstanceOf[Numeric[Any]].toLong(b) @@ -505,9 +546,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit buildCast[Boolean](_, b => if (b) 1 else 0) case DateType => buildCast[Int](_, d => null) - case TimestampType if ansiEnabled => + case TimestampType | TimeType if ansiEnabled => buildCast[Long](_, t => LongExactNumeric.toInt(timestampToLong(t))) - case TimestampType => + case TimestampType | TimeType => buildCast[Long](_, t => timestampToLong(t).toInt) case x: NumericType if ansiEnabled => b => x.exactNumeric.asInstanceOf[Numeric[Any]].toInt(b) @@ -530,7 +571,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) case DateType => buildCast[Int](_, d => null) - case TimestampType if ansiEnabled => + case TimestampType | TimeType if ansiEnabled => buildCast[Long](_, t => { val longValue = timestampToLong(t) if (longValue == longValue.toShort) { @@ -539,7 +580,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit throw new ArithmeticException(s"Casting $t to short causes overflow") } }) - case TimestampType => + case TimestampType | TimeType => buildCast[Long](_, t => timestampToLong(t).toShort) case x: NumericType if ansiEnabled => b => @@ -573,7 +614,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) case DateType => buildCast[Int](_, d => null) - case TimestampType if ansiEnabled => + case TimestampType | TimeType if ansiEnabled => buildCast[Long](_, t => { val longValue = timestampToLong(t) if (longValue == longValue.toByte) { @@ -582,7 +623,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit throw new ArithmeticException(s"Casting $t to byte causes overflow") } }) - case TimestampType => + case TimestampType | TimeType => buildCast[Long](_, t => timestampToLong(t).toByte) case x: NumericType if 
ansiEnabled => b => @@ -650,7 +691,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target)) case DateType => buildCast[Int](_, d => null) // date can't cast to decimal in Hive - case TimestampType => + case TimestampType | TimeType => // Note that we lose precision here. buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) case dt: DecimalType => @@ -684,7 +725,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit buildCast[Boolean](_, b => if (b) 1d else 0d) case DateType => buildCast[Int](_, d => null) - case TimestampType => + case TimestampType | TimeType => buildCast[Long](_, t => timestampToDouble(t)) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) @@ -709,7 +750,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit buildCast[Boolean](_, b => if (b) 1f else 0f) case DateType => buildCast[Int](_, d => null) - case TimestampType => + case TimestampType | TimeType => buildCast[Long](_, t => timestampToDouble(t).toFloat) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) @@ -777,6 +818,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case DateType => castToDate(from) case decimal: DecimalType => castToDecimal(from, decimal) case TimestampType => castToTimestamp(from) + case TimeType => castToTime(from) case CalendarIntervalType => castToInterval(from) case BooleanType => castToBoolean(from) case ByteType => castToByte(from) @@ -836,6 +878,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case DateType => castToDateCode(from, ctx) case decimal: DecimalType => castToDecimalCode(from, decimal, ctx) case TimestampType => castToTimestampCode(from, ctx) + case TimeType => castToTimestampCode(from, ctx) case CalendarIntervalType => castToIntervalCode(from) case BooleanType => castToBooleanCode(from) case ByteType => castToByteCode(from, ctx) @@ -1009,7 +1052,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit ctx.addReferenceObj("dateFormatter", dateFormatter), dateFormatter.getClass) (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString(${df}.format($c));""" - case TimestampType => + case TimestampType | TimeType => val tf = JavaCode.global( ctx.addReferenceObj("timestampFormatter", timestampFormatter), timestampFormatter.getClass) @@ -1095,7 +1138,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit $evNull = true; } """ - case TimestampType => + case TimestampType | TimeType => val zid = getZoneId() (c, evPrim, evNull) => code"""$evPrim = @@ -1162,7 +1205,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case DateType => // date can't cast to decimal in Hive (c, evPrim, evNull) => code"$evNull = true;" - case TimestampType => + case TimestampType | TimeType => // Note that we lose precision here. 
(c, evPrim, evNull) => code""" @@ -1249,6 +1292,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit """ } + private[this] def castToIntervalCode(from: DataType): CastFunction = from match { case StringType => val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") @@ -1284,7 +1328,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit $evNull = true; } """ - case TimestampType => + case TimestampType | TimeType => (c, evPrim, evNull) => code"$evPrim = $c != 0;" case DateType => // Hive would return null when cast from date to boolean @@ -1390,7 +1434,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = $c ? (byte) 1 : (byte) 0;" case DateType => (c, evPrim, evNull) => code"$evNull = true;" - case TimestampType => castTimestampToIntegralTypeCode(ctx, "byte") + case TimestampType | TimeType => castTimestampToIntegralTypeCode(ctx, "byte") case DecimalType() => castDecimalToIntegralTypeCode(ctx, "byte") case _: ShortType | _: IntegerType | _: LongType if ansiEnabled => castIntegralTypeToIntegralTypeExactCode("byte") @@ -1423,7 +1467,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = $c ? (short) 1 : (short) 0;" case DateType => (c, evPrim, evNull) => code"$evNull = true;" - case TimestampType => castTimestampToIntegralTypeCode(ctx, "short") + case TimestampType | TimeType => castTimestampToIntegralTypeCode(ctx, "short") case DecimalType() => castDecimalToIntegralTypeCode(ctx, "short") case _: IntegerType | _: LongType if ansiEnabled => castIntegralTypeToIntegralTypeExactCode("short") @@ -1454,7 +1498,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = $c ? 1 : 0;" case DateType => (c, evPrim, evNull) => code"$evNull = true;" - case TimestampType => castTimestampToIntegralTypeCode(ctx, "int") + case TimestampType | TimeType => castTimestampToIntegralTypeCode(ctx, "int") case DecimalType() => castDecimalToIntegralTypeCode(ctx, "int") case _: LongType if ansiEnabled => castIntegralTypeToIntegralTypeExactCode("int") case _: FloatType if ansiEnabled => @@ -1484,7 +1528,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = $c ? 1L : 0L;" case DateType => (c, evPrim, evNull) => code"$evNull = true;" - case TimestampType => + case TimestampType | TimeType => (c, evPrim, evNull) => code"$evPrim = (long) ${timestampToLongCode(c)};" case DecimalType() => castDecimalToIntegralTypeCode(ctx, "long") case _: FloatType if ansiEnabled => @@ -1522,7 +1566,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;" case DateType => (c, evPrim, evNull) => code"$evNull = true;" - case TimestampType => + case TimestampType | TimeType => (c, evPrim, evNull) => code"$evPrim = (float) (${timestampToDoubleCode(c)});" case DecimalType() => (c, evPrim, evNull) => code"$evPrim = $c.toFloat();" @@ -1558,7 +1602,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => code"$evPrim = $c ? 
1.0d : 0.0d;" case DateType => (c, evPrim, evNull) => code"$evNull = true;" - case TimestampType => + case TimestampType | TimeType => (c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};" case DecimalType() => (c, evPrim, evNull) => code"$evPrim = $c.toDouble();" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 39a16e917c4a5..3e030af32592c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -146,7 +146,7 @@ object InterpretedUnsafeProjection { case IntegerType | DateType => (v, i) => writer.write(i, v.getInt(i)) - case LongType | TimestampType => + case LongType | TimestampType | TimeType => (v, i) => writer.write(i, v.getLong(i)) case FloatType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 58c95c94ba198..fc69de4c5e0de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -40,7 +40,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS -import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -1767,7 +1766,7 @@ object CodeGenerator extends Logging { case ByteType => JAVA_BYTE case ShortType => JAVA_SHORT case IntegerType | DateType => JAVA_INT - case LongType | TimestampType => JAVA_LONG + case LongType | TimestampType | TimeType => JAVA_LONG case FloatType => JAVA_FLOAT case DoubleType => JAVA_DOUBLE case _: DecimalType => "Decimal" @@ -1788,7 +1787,7 @@ object CodeGenerator extends Logging { case ByteType => java.lang.Byte.TYPE case ShortType => java.lang.Short.TYPE case IntegerType | DateType => java.lang.Integer.TYPE - case LongType | TimestampType => java.lang.Long.TYPE + case LongType | TimestampType | TimeType => java.lang.Long.TYPE case FloatType => java.lang.Float.TYPE case DoubleType => java.lang.Double.TYPE case _: DecimalType => classOf[Decimal] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 213a58a3244e2..3a46761e4677c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -26,7 +26,7 @@ import java.lang.{Long => JavaLong} import java.lang.{Short => JavaShort} import java.math.{BigDecimal => JavaBigDecimal} import java.nio.charset.StandardCharsets -import java.sql.{Date, Timestamp} +import java.sql.{Date, Time, Timestamp} import java.time.{Instant, LocalDate} import java.util import java.util.Objects @@ -72,6 +72,7 @@ object Literal { case d: Decimal => 
Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale))
     case i: Instant => Literal(instantToMicros(i), TimestampType)
     case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
+    case t: Time => Literal(DateTimeUtils.fromJavaTime(t), TimeType)
     case ld: LocalDate => Literal(ld.toEpochDay.toInt, DateType)
     case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
     case a: Array[Byte] => Literal(a, BinaryType)
@@ -163,6 +164,7 @@ object Literal {
     case dt: DecimalType => Literal(Decimal(0, dt.precision, dt.scale))
     case DateType => create(0, DateType)
     case TimestampType => create(0L, TimestampType)
+    case TimeType => create(0L, TimeType)
     case StringType => Literal("")
     case BinaryType => Literal("".getBytes(StandardCharsets.UTF_8))
     case CalendarIntervalType => Literal(new CalendarInterval(0, 0, 0))
@@ -182,7 +184,7 @@ object Literal {
       case ByteType => v.isInstanceOf[Byte]
       case ShortType => v.isInstanceOf[Short]
       case IntegerType | DateType => v.isInstanceOf[Int]
-      case LongType | TimestampType => v.isInstanceOf[Long]
+      case LongType | TimestampType | TimeType => v.isInstanceOf[Long]
       case FloatType => v.isInstanceOf[Float]
       case DoubleType => v.isInstanceOf[Double]
       case _: DecimalType => v.isInstanceOf[Decimal]
@@ -369,7 +371,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
         }
       case ByteType | ShortType =>
         ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType))
-      case TimestampType | LongType =>
+      case TimestampType | TimeType | LongType =>
         toExprCode(s"${value}L")
       case _ =>
         val constRef = ctx.addReferenceObj("literal", value, javaType)
@@ -411,6 +413,10 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
       val formatter = TimestampFormatter.getFractionFormatter(
         DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone))
       s"TIMESTAMP '${formatter.format(v)}'"
+    case (v: Long, TimeType) =>
+      val formatter = TimestampFormatter.getFractionFormatter(
+        DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone))
+      s"TIME '${formatter.format(v)}'"
     case (i: CalendarInterval, CalendarIntervalType) =>
       s"INTERVAL '${i.toString}'"
     case (v: Array[Byte], BinaryType) => s"X'${DatatypeConverter.printHexBinary(v)}'"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index b3541a7f7374d..6d695a3cbde46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -2175,6 +2175,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
       case ("double", Nil) => DoubleType
       case ("date", Nil) => DateType
       case ("timestamp", Nil) => TimestampType
+      case ("time", Nil) => TimeType
       case ("string", Nil) => StringType
       case ("character" | "char", length :: Nil) => CharType(length.getText.toInt)
       case ("varchar", length :: Nil) => VarcharType(length.getText.toInt)
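Note on the AstBuilder hunk above: the original diff mapped the "time" keyword to TimestampType, which would have made CAST(... AS TIME) silently produce timestamps; it is corrected to TimeType here. A minimal sketch of the expected behavior, assuming the patch is applied (StructType.fromDDL is existing Spark API; the assertion reflects this patch, not current upstream behavior):

    import org.apache.spark.sql.types._

    object TimeDdlSketch {
      def main(args: Array[String]): Unit = {
        // "TIME" in a DDL string goes through AstBuilder's primitive-type rule,
        // which this patch maps to TimeType.
        val schema = StructType.fromDDL("id INT, t TIME")
        assert(schema("t").dataType == TimeType)
      }
    }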
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 593bd18f3de9c..d152a93a12729 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.util
 
 import java.nio.charset.StandardCharsets
-import java.sql.{Date, Timestamp}
+import java.sql.{Date, Time, Timestamp}
 import java.time._
 import java.time.temporal.{ChronoField, ChronoUnit, IsoFields}
 import java.util.{Locale, TimeZone}
@@ -105,6 +105,13 @@ object DateTimeUtils {
     Timestamp.from(microsToInstant(us))
   }
 
+  /**
+   * Returns a java.sql.Time from the number of micros since epoch.
+   */
+  def toJavaTime(us: SQLTimestamp): Time = {
+    new Time(Math.floorDiv(us, MICROS_PER_MILLIS))
+  }
+
   /**
    * Returns the number of micros since epoch from java.sql.Timestamp.
    */
@@ -112,6 +119,13 @@ object DateTimeUtils {
     instantToMicros(t.toInstant)
   }
 
+  /**
+   * Returns the number of micros since epoch from java.sql.Time.
+   */
+  def fromJavaTime(t: Time): SQLTimestamp = {
+    timeToMicros(t)
+  }
+
   /**
    * Returns the number of microseconds since epoch from Julian day
    * and nanoseconds in a day
@@ -327,6 +341,10 @@ object DateTimeUtils {
     result
   }
 
+  def timeToMicros(t: Time): Long = {
+    Math.multiplyExact(t.getTime, MICROS_PER_MILLIS)
+  }
+
   def microsToInstant(us: Long): Instant = {
     val secs = Math.floorDiv(us, MICROS_PER_SECOND)
     // Unfolded Math.floorMod(us, MICROS_PER_SECOND) to reuse the result of
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 8a8cea194bf2c..f46a46924ec53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -138,7 +138,7 @@ object DataType {
   def fromJson(json: String): DataType = parseDataType(parse(json))
 
   private val nonDecimalNameToType = {
-    Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType,
+    Seq(NullType, DateType, TimestampType, TimeType, BinaryType, IntegerType, BooleanType, LongType,
       DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType)
       .map(t => t.typeName -> t).toMap
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimeType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimeType.scala
new file mode 100644
index 0000000000000..ff1d10aa2ac89
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimeType.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.math.Ordering
+import scala.reflect.runtime.universe.typeTag
+
+import org.apache.spark.annotation.Stable
+
+/**
+ * The time type represents a time value in microsecond precision. It reuses the
+ * timestamp representation: valid range is [0001-01-01T00:00:00.000000Z,
+ * 9999-12-31T23:59:59.999999Z] where the left/right-bound is a date and time of the
+ * proleptic Gregorian calendar in UTC+00:00.
+ *
+ * Please use the singleton `DataTypes.TimeType` to refer the type.
+ * @since 3.1.0
+ */
+@Stable
+class TimeType private() extends AtomicType {
+  /**
+   * Internally, a time value is stored as the number of microseconds from
+   * the epoch of 1970-01-01T00:00:00.000000Z (UTC+00:00)
+   */
+  private[sql] type InternalType = Long
+
+  @transient private[sql] lazy val tag = typeTag[InternalType]
+
+  private[sql] val ordering = implicitly[Ordering[InternalType]]
+
+  /**
+   * The default size of a value of the TimeType is 8 bytes.
+   */
+  override def defaultSize: Int = 8
+
+  private[spark] override def asNullable: TimeType = this
+}
+
+/**
+ * The companion case object and its class is separated so the companion object also subclasses
+ * the TimeType class. Otherwise, the companion object would be of type "TimeType$" in
+ * byte code. Defined with a private constructor so the companion object is the only possible
+ * instantiation.
+ *
+ * @since 3.1.0
+ */
+@Stable
+case object TimeType extends TimeType
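For reference, the Time/microseconds mapping implemented by the DateTimeUtils helpers above (with the millisecond/microsecond conversion fixed), as a self-contained sketch; the constant is inlined here and the names only mirror the patch, this is not the patch code:

    import java.sql.Time

    object TimeConversionSketch {
      private val MicrosPerMillis = 1000L

      // java.sql.Time carries milliseconds since the epoch; Catalyst stores
      // TimeType values as microseconds since the epoch.
      def fromJavaTime(t: Time): Long = Math.multiplyExact(t.getTime, MicrosPerMillis)
      def toJavaTime(us: Long): Time = new Time(Math.floorDiv(us, MicrosPerMillis))

      def main(args: Array[String]): Unit = {
        val t = Time.valueOf("12:03:17")
        // Round trip: millis -> micros -> millis preserves the value.
        assert(toJavaTime(fromJavaTime(t)).toString == "12:03:17")
      }
    }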
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index a7c20c34d78bc..10e1256fc9baa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -191,6 +191,22 @@ object RandomDataGenerator {
           DateTimeUtils.toJavaTimestamp(milliseconds * 1000)
         }
       Some(generator)
+    case TimeType =>
+      val generator =
+        () => {
+          var milliseconds = rand.nextLong() % 253402329599999L
+          // -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT
+          // for "0001-01-01 00:00:00.000000". We need to find a
+          // number that is greater or equals to this number as a valid timestamp value.
+          while (milliseconds < -62135740800000L) {
+            // 253402329599999L is the number of milliseconds since
+            // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999".
+            milliseconds = rand.nextLong() % 253402329599999L
+          }
+          // DateTimeUtils.toJavaTime takes microseconds.
+          DateTimeUtils.toJavaTime(milliseconds * 1000)
+        }
+      Some(generator)
     case CalendarIntervalType => Some(() => {
       val months = rand.nextInt(1000)
       val days = rand.nextInt(10000)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 1a1cab823d4f3..3483521e37fac 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -83,20 +83,21 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
 
   encodeDecodeTest(
     new StructType()
       .add("null", NullType)
       .add("boolean", BooleanType)
       .add("byte", ByteType)
       .add("short", ShortType)
       .add("int", IntegerType)
       .add("long", LongType)
       .add("float", FloatType)
       .add("double", DoubleType)
       .add("decimal", DecimalType.SYSTEM_DEFAULT)
       .add("string", StringType)
       .add("binary", BinaryType)
       .add("date", DateType)
       .add("timestamp", TimestampType)
+      .add("time", TimeType)
       .add("udt", new ExamplePointUDT))
 
   encodeDecodeTest(
     new StructType()
@@ -298,6 +299,13 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
       encoder.toRow(Row(Array("a")))
     }
     assert(e4.getMessage.contains("java.lang.String is not a valid external type"))
+
+    val e5 = intercept[RuntimeException] {
+      val schema = new StructType().add("a", ArrayType(TimeType))
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(Array("a")))
+    }
+    assert(e5.getMessage.contains("java.lang.String is not a valid external type"))
   }
 
   test("SPARK-25791: Datatype of serializers should be accessible") {
@@ -320,6 +328,18 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
     }
   }
 
+  test("encoding/decoding TimeType to/from java.time.Instant") {
+    withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
+      val schema = new StructType().add("t", TimeType)
+      val encoder = RowEncoder(schema).resolveAndBind()
+      val instant = java.time.Instant.parse("2019-02-26T16:56:00Z")
+      val row = encoder.toRow(Row(instant))
+      assert(row.getLong(0) === DateTimeUtils.instantToMicros(instant))
+      val readback = encoder.fromRow(row)
+      assert(readback.get(0) === instant)
+    }
+  }
+
   test("encoding/decoding DateType to/from java.time.LocalDate") {
     withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
       val schema = new StructType().add("d", DateType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index ad66873c02518..6509a5aab040e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import java.sql.{Date, Timestamp}
+import java.sql.{Date, Time, Timestamp}
 import java.util.{Calendar, TimeZone}
 import java.util.concurrent.TimeUnit._
 
@@ -68,6 +68,7
@@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkNullCast(StringType, BooleanType) checkNullCast(DateType, BooleanType) checkNullCast(TimestampType, BooleanType) + checkNullCast(TimeType, BooleanType) numericTypes.foreach(dt => checkNullCast(dt, BooleanType)) checkNullCast(StringType, TimestampType) @@ -75,14 +76,21 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkNullCast(DateType, TimestampType) numericTypes.foreach(dt => checkNullCast(dt, TimestampType)) + checkNullCast(StringType, TimeType) + checkNullCast(BooleanType, TimeType) + checkNullCast(DateType, TimeType) + numericTypes.foreach(dt => checkNullCast(dt, TimeType)) + checkNullCast(StringType, DateType) checkNullCast(TimestampType, DateType) + checkNullCast(TimeType, DateType) checkNullCast(StringType, CalendarIntervalType) numericTypes.foreach(dt => checkNullCast(StringType, dt)) numericTypes.foreach(dt => checkNullCast(BooleanType, dt)) numericTypes.foreach(dt => checkNullCast(DateType, dt)) numericTypes.foreach(dt => checkNullCast(TimestampType, dt)) + numericTypes.foreach(dt => checkNullCast(TimeType, dt)) for (from <- numericTypes; to <- numericTypes) checkNullCast(from, to) } @@ -206,6 +214,101 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { } } + test("cast string to time") { + new ParVector(ALL_TIMEZONES.toVector).foreach { tz => + def checkCastStringToTime(str: String, expected: Time): Unit = { + checkEvaluation(cast(Literal(str), TimeType, Option(tz.getID)), expected) + } + + checkCastStringToTime("123", null) + + var c = Calendar.getInstance(tz) + c.set(2015, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkCastStringToTime("2015", new Time(c.getTimeInMillis)) + c = Calendar.getInstance(tz) + c.set(2015, 2, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkCastStringToTime("2015-03", new Time(c.getTimeInMillis)) + c = Calendar.getInstance(tz) + c.set(2015, 2, 18, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkCastStringToTime("2015-03-18", new Time(c.getTimeInMillis)) + checkCastStringToTime("2015-03-18 ", new Time(c.getTimeInMillis)) + checkCastStringToTime("2015-03-18T", new Time(c.getTimeInMillis)) + + c = Calendar.getInstance(tz) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkCastStringToTime("2015-03-18 12:03:17", new Time(c.getTimeInMillis)) + checkCastStringToTime("2015-03-18T12:03:17", new Time(c.getTimeInMillis)) + + // If the string value includes timezone string, it represents the timestamp string + // in the timezone regardless of the timeZoneId parameter. 
+ c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkCastStringToTime("2015-03-18T12:03:17Z", new Time(c.getTimeInMillis)) + checkCastStringToTime("2015-03-18 12:03:17Z", new Time(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkCastStringToTime("2015-03-18T12:03:17-1:0", new Time(c.getTimeInMillis)) + checkCastStringToTime("2015-03-18T12:03:17-01:00", new Time(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkCastStringToTime("2015-03-18T12:03:17+07:30", new Time(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkCastStringToTime("2015-03-18T12:03:17+7:3", new Time(c.getTimeInMillis)) + + // tests for the string including milliseconds. + c = Calendar.getInstance(tz) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + checkCastStringToTime("2015-03-18 12:03:17.123", new Time(c.getTimeInMillis)) + checkCastStringToTime("2015-03-18T12:03:17.123", new Time(c.getTimeInMillis)) + + // If the string value includes timezone string, it represents the timestamp string + // in the timezone regardless of the timeZoneId parameter. + c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 456) + checkCastStringToTime("2015-03-18T12:03:17.456Z", new Time(c.getTimeInMillis)) + checkCastStringToTime("2015-03-18 12:03:17.456Z", new Time(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + checkCastStringToTime("2015-03-18T12:03:17.123-1:0", new Time(c.getTimeInMillis)) + checkCastStringToTime("2015-03-18T12:03:17.123-01:00", new Time(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + checkCastStringToTime("2015-03-18T12:03:17.123+07:30", new Time(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + checkCastStringToTime("2015-03-18T12:03:17.123+7:3", new Time(c.getTimeInMillis)) + + checkCastStringToTime("2015-03-18 123142", null) + checkCastStringToTime("2015-03-18T123123", null) + checkCastStringToTime("2015-03-18X", null) + checkCastStringToTime("2015/03/18", null) + checkCastStringToTime("2015.03.18", null) + checkCastStringToTime("20150318", null) + checkCastStringToTime("2015-031-8", null) + checkCastStringToTime("2015-03-18T12:03:17-0:70", null) + } + } + test("cast from boolean") { checkEvaluation(cast(true, IntegerType), 1) checkEvaluation(cast(false, IntegerType), 0) @@ -240,6 +343,9 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) + + checkEvaluation(cast(cast(1.toDouble, TimeType), DoubleType), 1.toDouble) + checkEvaluation(cast(cast(1.toDouble, TimeType), DoubleType), 1.toDouble) } test("cast from string") { @@ -247,6 +353,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { assert(cast("abcdef", BinaryType).nullable 
=== false) assert(cast("abcdef", BooleanType).nullable) assert(cast("abcdef", TimestampType).nullable) + assert(cast("abcdef", TimeType).nullable) assert(cast("abcdef", LongType).nullable) assert(cast("abcdef", IntegerType).nullable) assert(cast("abcdef", ShortType).nullable) @@ -263,7 +370,10 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { val zts = sd + " 00:00:00" val sts = sd + " 00:00:02" val nts = sts + ".1" + val zt = "00:00:00" + val t = "00:00:02" val ts = withDefaultTimeZone(TimeZoneGMT)(Timestamp.valueOf(nts)) + val ts2 = withDefaultTimeZone(TimeZoneGMT)(Time.valueOf(t)) for (tz <- ALL_TIMEZONES) { val timeZoneId = Option(tz.getID) @@ -273,32 +383,49 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { cast(cast(new Timestamp(c.getTimeInMillis), StringType, timeZoneId), TimestampType, timeZoneId), MILLISECONDS.toMicros(c.getTimeInMillis)) + checkEvaluation( + cast(cast(new Time(c.getTimeInMillis), StringType, timeZoneId), + TimeType, timeZoneId), + MILLISECONDS.toMicros(c.getTimeInMillis)) c = Calendar.getInstance(TimeZoneGMT) c.set(2015, 10, 1, 2, 30, 0) checkEvaluation( cast(cast(new Timestamp(c.getTimeInMillis), StringType, timeZoneId), TimestampType, timeZoneId), MILLISECONDS.toMicros(c.getTimeInMillis)) + checkEvaluation( + cast(cast(new Time(c.getTimeInMillis), StringType, timeZoneId), + TimeType, timeZoneId), + MILLISECONDS.toMicros(c.getTimeInMillis)) } val gmtId = Option("GMT") checkEvaluation(cast("abdef", StringType), "abdef") checkEvaluation(cast("abdef", TimestampType, gmtId), null) + checkEvaluation(cast("abdef", TimeType, gmtId), null) checkEvaluation(cast("12.65", DecimalType.SYSTEM_DEFAULT), Decimal(12.65)) checkEvaluation(cast(cast(sd, DateType), StringType), sd) checkEvaluation(cast(cast(d, StringType), DateType), 0) checkEvaluation(cast(cast(nts, TimestampType, gmtId), StringType, gmtId), nts) + checkEvaluation(cast(cast(nts, TimeType, gmtId), StringType, gmtId), nts) checkEvaluation( cast(cast(ts, StringType, gmtId), TimestampType, gmtId), DateTimeUtils.fromJavaTimestamp(ts)) + checkEvaluation( + cast(cast(ts2, StringType, gmtId), TimeType, gmtId), + DateTimeUtils.fromJavaTime(ts2)) // all convert to string type to check checkEvaluation(cast(cast(cast(nts, TimestampType, gmtId), DateType, gmtId), StringType), sd) checkEvaluation( cast(cast(cast(ts, DateType, gmtId), TimestampType, gmtId), StringType, gmtId), zts) + checkEvaluation(cast(cast(cast(nts, TimeType, gmtId), DateType, gmtId), StringType), sd) + checkEvaluation( + cast(cast(cast(ts2, DateType, gmtId), TimeType, gmtId), StringType, gmtId), + zts) checkEvaluation(cast(cast("abdef", BinaryType), StringType), "abdef") @@ -309,14 +436,26 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType), DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), 5.toShort) + checkEvaluation( + cast(cast(cast(cast(cast(cast("5", ByteType), TimeType), + DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), + 5.toShort) checkEvaluation( cast(cast(cast(cast(cast(cast("5", TimestampType, gmtId), ByteType), DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), null) + checkEvaluation( + cast(cast(cast(cast(cast(cast("5", TimeType, gmtId), ByteType), + DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), + null) checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT), ByteType), TimestampType), LongType), 
StringType), ShortType), 5.toShort)
+    checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT),
+      ByteType), TimeType), LongType), StringType), ShortType),
+      5.toShort)
+
     checkEvaluation(cast("23", DoubleType), 23d)
     checkEvaluation(cast("23", IntegerType), 23)
     checkEvaluation(cast("23", FloatType), 23f)
@@ -363,6 +502,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
 
     val gmtId = Option("GMT")
     checkEvaluation(cast(cast(d, TimestampType, gmtId), StringType, gmtId), "1970-01-01 00:00:00")
+    checkEvaluation(cast(cast(d, TimeType, gmtId), StringType, gmtId), "1970-01-01 00:00:00")
   }
 
   test("cast from timestamp") {
@@ -400,6 +540,41 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(cast(1.0f / 0.0f, TimestampType), null)
   }
 
+  test("cast from time") {
+    val millis = 15 * 1000 + 3
+    val seconds = millis * 1000 + 3
+    val ts = new Time(millis)
+    val tss = new Time(seconds)
+    checkEvaluation(cast(ts, ShortType), 15.toShort)
+    checkEvaluation(cast(ts, IntegerType), 15)
+    checkEvaluation(cast(ts, LongType), 15.toLong)
+    checkEvaluation(cast(ts, FloatType), 15.003f)
+    checkEvaluation(cast(ts, DoubleType), 15.003)
+    checkEvaluation(cast(cast(tss, ShortType), TimeType),
+      DateTimeUtils.fromJavaTime(ts) * MILLIS_PER_SECOND)
+    checkEvaluation(cast(cast(tss, IntegerType), TimeType),
+      DateTimeUtils.fromJavaTime(ts) * MILLIS_PER_SECOND)
+    checkEvaluation(cast(cast(tss, LongType), TimeType),
+      DateTimeUtils.fromJavaTime(ts) * MILLIS_PER_SECOND)
+    checkEvaluation(
+      cast(cast(millis.toFloat / MILLIS_PER_SECOND, TimeType), FloatType),
+      millis.toFloat / MILLIS_PER_SECOND)
+    checkEvaluation(
+      cast(cast(millis.toDouble / MILLIS_PER_SECOND, TimeType), DoubleType),
+      millis.toDouble / MILLIS_PER_SECOND)
+    checkEvaluation(
+      cast(cast(Decimal(1), TimeType), DecimalType.SYSTEM_DEFAULT),
+      Decimal(1))
+
+    // A test for higher precision than millis
+    checkEvaluation(cast(cast(0.000001, TimeType), DoubleType), 0.000001)
+
+    checkEvaluation(cast(Double.NaN, TimeType), null)
+    checkEvaluation(cast(1.0 / 0.0, TimeType), null)
+    checkEvaluation(cast(Float.NaN, TimeType), null)
+    checkEvaluation(cast(1.0f / 0.0f, TimeType), null)
+  }
+
   test("cast from array") {
     val array = Literal.create(Seq("123", "true", "f", null),
       ArrayType(StringType, containsNull = true))
@@ -564,6 +739,16 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(cast(inp, targetSchema), expected)
   }
 
+  test("cast struct with a time field") {
+    val originalSchema = new StructType().add("timeField", TimeType, nullable = false)
+    // nine out of ten times I'm casting a struct, it's to normalize its fields nullability
+    val targetSchema = new StructType().add("timeField", TimeType, nullable = true)
+
+    val inp = Literal.create(InternalRow(0L), originalSchema)
+    val expected = InternalRow(0L)
+    checkEvaluation(cast(inp, targetSchema), expected)
+  }
+
   test("complex casting") {
     val complex = Literal.create(
       Row(
@@ -859,6 +1044,8 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
     checkExceptionInExpression[ArithmeticException](cast(Decimal(value.toString), dt), "overflow")
     checkExceptionInExpression[ArithmeticException](
      cast(Literal(value * MICROS_PER_SECOND, TimestampType), dt), "overflow")
+    checkExceptionInExpression[ArithmeticException](
+      cast(Literal(value * MICROS_PER_SECOND, TimeType), dt), "overflow")
     checkExceptionInExpression[ArithmeticException](
cast(Literal(value * 1.5f, FloatType), dt), "overflow") checkExceptionInExpression[ArithmeticException]( @@ -885,6 +1073,8 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkExceptionInExpression[ArithmeticException](cast(value, ByteType), "overflow") checkExceptionInExpression[ArithmeticException]( cast(Literal(value * MICROS_PER_SECOND, TimestampType), ByteType), "overflow") + checkExceptionInExpression[ArithmeticException]( + cast(Literal(value * MICROS_PER_SECOND, TimeType), ByteType), "overflow") checkExceptionInExpression[ArithmeticException]( cast(Literal(value.toFloat, FloatType), ByteType), "overflow") checkExceptionInExpression[ArithmeticException]( @@ -896,6 +1086,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(value.toString, ByteType), value) checkEvaluation(cast(Decimal(value.toString), ByteType), value) checkEvaluation(cast(Literal(value * MICROS_PER_SECOND, TimestampType), ByteType), value) + checkEvaluation(cast(Literal(value * MICROS_PER_SECOND, TimeType), ByteType), value) checkEvaluation(cast(Literal(value.toInt, DateType), ByteType), null) checkEvaluation(cast(Literal(value.toFloat, FloatType), ByteType), value) checkEvaluation(cast(Literal(value.toDouble, DoubleType), ByteType), value) @@ -910,6 +1101,8 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkExceptionInExpression[ArithmeticException](cast(value, ShortType), "overflow") checkExceptionInExpression[ArithmeticException]( cast(Literal(value * MICROS_PER_SECOND, TimestampType), ShortType), "overflow") + checkExceptionInExpression[ArithmeticException]( + cast(Literal(value * MICROS_PER_SECOND, TimeType), ShortType), "overflow") checkExceptionInExpression[ArithmeticException]( cast(Literal(value.toFloat, FloatType), ShortType), "overflow") checkExceptionInExpression[ArithmeticException]( @@ -921,6 +1114,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(value.toString, ShortType), value) checkEvaluation(cast(Decimal(value.toString), ShortType), value) checkEvaluation(cast(Literal(value * MICROS_PER_SECOND, TimestampType), ShortType), value) + checkEvaluation(cast(Literal(value * MICROS_PER_SECOND, TimeType), ShortType), value) checkEvaluation(cast(Literal(value.toInt, DateType), ShortType), null) checkEvaluation(cast(Literal(value.toFloat, FloatType), ShortType), value) checkEvaluation(cast(Literal(value.toDouble, DoubleType), ShortType), value) @@ -938,6 +1132,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(value.toString, IntegerType), value) checkEvaluation(cast(Decimal(value.toString), IntegerType), value) checkEvaluation(cast(Literal(value * MICROS_PER_SECOND, TimestampType), IntegerType), value) + checkEvaluation(cast(Literal(value * MICROS_PER_SECOND, TimeType), IntegerType), value) checkEvaluation(cast(Literal(value * 1.0, DoubleType), IntegerType), value) } checkEvaluation(cast(Int.MaxValue + 0.9D, IntegerType), Int.MaxValue) @@ -955,6 +1150,8 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(Decimal(value.toString), LongType), value) checkEvaluation(cast(Literal(value, TimestampType), LongType), Math.floorDiv(value, MICROS_PER_SECOND)) + checkEvaluation(cast(Literal(value, TimeType), LongType), + Math.floorDiv(value, MICROS_PER_SECOND)) } checkEvaluation(cast(Long.MaxValue + 0.9F, LongType), Long.MaxValue) 
checkEvaluation(cast(Long.MinValue - 0.9F, LongType), Long.MinValue) @@ -1020,6 +1217,9 @@ class CastSuite extends CastSuiteBase { checkEvaluation(cast(cast(1000, TimestampType), LongType), 1000.toLong) checkEvaluation(cast(cast(-1200, TimestampType), LongType), -1200.toLong) + checkEvaluation(cast(cast(1000, TimeType), LongType), 1000.toLong) + checkEvaluation(cast(cast(-1200, TimeType), LongType), -1200.toLong) + checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 1)), null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index b19184055268a..ca42bfa13f57a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import java.nio.charset.StandardCharsets -import java.sql.{Date, Timestamp} +import java.sql.{Date, Time, Timestamp} import java.time.{Instant, LocalDate} import org.apache.spark.sql.{Dataset, Row} @@ -89,7 +89,9 @@ object HiveResult { dateFormatter.format(DateTimeUtils.localDateToDays(ld)) case (t: Timestamp, TimestampType) => timestampFormatter.format(DateTimeUtils.fromJavaTimestamp(t)) - case (i: Instant, TimestampType) => + case (t: Time, TimeType) => + timestampFormatter.format(DateTimeUtils.fromJavaTime(t)) + case (i: Instant, TimestampType | TimeType) => timestampFormatter.format(DateTimeUtils.instantToMicros(i)) case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) case (decimal: java.math.BigDecimal, DecimalType()) => decimal.toPlainString diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala index bddd15c6e25d6..54e90a1f04d4e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala @@ -46,6 +46,20 @@ class HiveResultSuite extends SharedSparkSession { assert(result2 == timestamps.map(x => s"[$x]")) } + test("time formatting in hive result") { + val time = Seq( + "2018-12-28 01:02:03", + "1582-10-13 01:02:03", + "1582-10-14 01:02:03", + "1582-10-15 01:02:03") + val df = time.toDF("a").selectExpr("cast(a as time) as b") + val result = HiveResult.hiveResultString(df) + assert(result == time) + val df2 = df.selectExpr("array(b)") + val result2 = HiveResult.hiveResultString(df2) + assert(result2 == time.map(x => s"[$x]")) + } + test("toHiveString correctly handles UDTs") { val point = new ExamplePoint(50.0, 50.0) val tpe = new ExamplePointUDT()
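Finally, an end-to-end usage sketch, assuming the whole patch is applied: the CAST and the new Row.getTime accessor are the pieces added above; the SparkSession setup is standard Spark and nothing here is upstream behavior without this patch:

    import java.sql.Time
    import org.apache.spark.sql.SparkSession

    object TimeUsageSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").getOrCreate()
        import spark.implicits._

        // Cast a string column to TIME and read it back through the new accessor.
        val df = Seq("2018-12-28 01:02:03").toDF("a").selectExpr("cast(a as time) as b")
        val t: Time = df.collect().head.getTime(0)
        assert(t != null)
        spark.stop()
      }
    }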