From 020e701f13792f31fc6dd95b2598b1d4ca65ee02 Mon Sep 17 00:00:00 2001
From: Fokko Driesprong
Date: Mon, 8 Jun 2020 13:48:54 +0200
Subject: [PATCH 1/3] [SPARK-10520][SQL] Allow average out of DateType

This allows taking an average of date types. Under the hood we take the
average of the days since the epoch and convert the result back to a date.
This requires the date to be cast to a double so the average can be computed.

Before this change, averaging a date column failed with:

Error in invokeJava(isStatic = FALSE, objId$id, methodName, ...) :
  org.apache.spark.sql.AnalysisException: cannot resolve 'avg(date)' due to data type mismatch: function average requires numeric types, not DateType;
	at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
	at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:61)
	at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:53)
	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:293)
	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:293)
	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:51)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:292)
	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:290)
	at org.apache.spark.sql.
---
 .../spark/sql/catalyst/expressions/Cast.scala | 11 ++++++++---
 .../expressions/aggregate/Average.scala       | 18 ++++++++++++++----
 .../sql/catalyst/util/DateTimeUtils.scala     | 19 +++++++++++++++++++
 .../sql/catalyst/expressions/CastSuite.scala  | 17 ++++++++++++++++-
 .../spark/sql/DataFrameAggregateSuite.scala   | 12 +++++++++++-
 .../apache/spark/sql/test/SQLTestData.scala   | 12 ++++++++++++
 6 files changed, 80 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index ef70915a5c969..b4bae890b27ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import java.math.{BigDecimal => JavaBigDecimal}
+import java.sql.Date
 import java.time.ZoneId
 import java.util.Locale
 import java.util.concurrent.TimeUnit._
@@ -63,6 +64,7 @@ object Cast {
 
     case (StringType, DateType) => true
     case (TimestampType, DateType) => true
+    case (DoubleType, DateType) => true
 
     case (StringType, CalendarIntervalType) => true
@@ -468,6 +470,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
   private[this] def castToDate(from: DataType): Any => Any = from match {
     case StringType =>
       buildCast[UTF8String](_, s => DateTimeUtils.stringToDate(s, zoneId).orNull)
+    case DoubleType =>
+      buildCast[Double](_, daysSinceEpoch => convertTz(daysSinceEpoch.toInt))
     case TimestampType =>
       // throw valid precision more than seconds, according to Hive.
       // Timestamp.nanos is in 0 to 999,999,999, no more than a second.
@@ -694,8 +698,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
       })
     case BooleanType =>
       buildCast[Boolean](_, b => if (b) 1d else 0d)
-    case DateType =>
-      buildCast[Int](_, d => null)
+    case DateType => _.asInstanceOf[Int].toDouble
     case TimestampType =>
       buildCast[Long](_, t => timestampToDouble(t))
     case x: NumericType =>
@@ -1107,6 +1110,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
             $evNull = true;
           }
         """
+    case DoubleType => (c, evPrim, evNull) =>
+      code"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.convertTz((int)$c);"
     case TimestampType =>
       val zid = getZoneId()
       (c, evPrim, evNull) =>
@@ -1569,7 +1574,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
     case BooleanType =>
       (c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;"
     case DateType =>
-      (c, evPrim, evNull) => code"$evNull = true;"
+      (c, evPrim, evNull) => code"$evPrim = (double) $c;"
     case TimestampType =>
       (c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};"
     case DecimalType() =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index d3ce1f8d331ab..713fe01dcb093 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -40,10 +40,17 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
 
   override def children: Seq[Expression] = child :: Nil
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, DateType)
 
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "function average")
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val isNumeric = TypeUtils.checkForNumericExpr(child.dataType, "function average")
+
+    if (isNumeric.isFailure && child.dataType == DateType) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      isNumeric
+    }
+  }
 
   override def nullable: Boolean = true
@@ -53,6 +60,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(p, s) =>
       DecimalType.bounded(p + 4, s + 4)
+    case DateType => DateType
     case _ => DoubleType
   }
@@ -77,9 +85,11 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
   )
 
   // If all input are nulls, count will be 0 and we will get null after the division.
-  override lazy val evaluateExpression = child.dataType match {
+  override lazy val evaluateExpression: Expression = child.dataType match {
     case _: DecimalType =>
       DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType)
+    case _: DateType =>
+      (sum / count).cast(resultType)
     case _ =>
       sum.cast(resultType) / count.cast(resultType)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 41a271b95e83c..8ddd8ca3282ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -422,6 +422,10 @@ object DateTimeUtils {
     Instant.ofEpochSecond(secs, mos * NANOS_PER_MICROS)
   }
 
+  def daysToInstant(daysSinceEpoch: SQLDate): Instant = {
+    Instant.ofEpochSecond(daysSinceEpoch * SECONDS_PER_DAY)
+  }
+
   def instantToDays(instant: Instant): Int = {
     val seconds = instant.getEpochSecond
     val days = Math.floorDiv(seconds, SECONDS_PER_DAY)
@@ -821,6 +825,21 @@ object DateTimeUtils {
     instantToMicros(rebasedDateTime.toInstant)
   }
 
+  /**
+   * Convert the date `ts` from one time zone to another.
+   *
+   * TODO: Because of DST, the conversion between UTC and human time is not exactly one-to-one
+   * mapping, the conversion here may return wrong result, we should make the timestamp
+   * timezone-aware.
+   */
+  def convertTz(ts: SQLDate, fromZone: ZoneId, toZone: ZoneId): SQLDate = {
+    val rebasedDateTime = daysToInstant(ts).atZone(toZone).toLocalDateTime.atZone(fromZone)
+    instantToDays(rebasedDateTime.toInstant)
+  }
+
+  // Convenience method for making it easier to only pass the first argument in Java codegen
+  def convertTz(ts: SQLDate): SQLDate = convertTz(ts, ZoneOffset.UTC, defaultTimeZone().toZoneId)
+
   /**
    * Returns a timestamp of given timezone from utc timestamp, with the same string
    * representation in their timezone.
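For reference, the epoch-day round trip behind the daysToInstant/instantToDays pair above can be
sketched outside Spark with plain java.time. This is a minimal sketch only: the object name and
the 86400-second constant below are illustrative stand-ins for DateTimeUtils.SECONDS_PER_DAY,
not Spark code.

    import java.time.{Instant, LocalDate, ZoneOffset}

    object EpochDaySketch {
      // Stand-in for DateTimeUtils.SECONDS_PER_DAY.
      private val SecondsPerDay = 86400L

      // Days since 1970-01-01 -> the Instant at midnight UTC of that day.
      def daysToInstant(days: Int): Instant = Instant.ofEpochSecond(days * SecondsPerDay)

      // Instant -> days since the epoch; floorDiv keeps pre-1970 instants on the
      // correct (negative) day number instead of rounding toward zero.
      def instantToDays(instant: Instant): Int =
        Math.floorDiv(instant.getEpochSecond, SecondsPerDay).toInt

      def main(args: Array[String]): Unit = {
        val days = LocalDate.of(2020, 5, 7).toEpochDay.toInt // 18389, the constant used in CastSuite below
        assert(instantToDays(daysToInstant(days)) == days)
        println(daysToInstant(days).atZone(ZoneOffset.UTC).toLocalDate) // 2020-05-07
      }
    }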
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 6af995cab64fe..1eaeb3aefdc44 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -356,7 +356,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(cast(d, IntegerType), null)
     checkEvaluation(cast(d, LongType), null)
     checkEvaluation(cast(d, FloatType), null)
-    checkEvaluation(cast(d, DoubleType), null)
+    checkEvaluation(cast(d, DoubleType), 0.0d)
     checkEvaluation(cast(d, DecimalType.SYSTEM_DEFAULT), null)
     checkEvaluation(cast(d, DecimalType(10, 2)), null)
     checkEvaluation(cast(d, StringType), "1970-01-01")
@@ -1299,6 +1299,21 @@ class CastSuite extends CastSuiteBase {
     }
   }
 
+  private val dateDaysSinceEpoch = 18389.0 // Days since epoch (1970-01-01)
+  private val date = Date.valueOf("2020-05-07")
+
+  test("SPARK-10520: Cast a Date to Double") {
+    withDefaultTimeZone(UTC) {
+      checkEvaluation(cast(Literal(date), DoubleType), dateDaysSinceEpoch)
+    }
+  }
+
+  test("SPARK-10520: Cast a Double to Date") {
+    withDefaultTimeZone(UTC) {
+      checkEvaluation(cast(Literal(dateDaysSinceEpoch), DateType), date)
+    }
+  }
+
   test("cast a timestamp before the epoch 1970-01-01 00:00:00Z") {
     withDefaultTimeZone(UTC) {
       val negativeTs = Timestamp.valueOf("1900-05-05 18:34:56.1")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 2293d4ae61aff..2604cf7e6a515 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.sql
 
+import java.sql.Date
+
 import scala.util.Random
 
 import org.scalatest.Matchers.the
 
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -31,7 +34,6 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.test.SQLTestData.DecimalData
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
 
 case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)
@@ -319,6 +321,14 @@ class DataFrameAggregateSuite extends QueryTest
       Row(new java.math.BigDecimal(2), new java.math.BigDecimal(6)) :: Nil)
   }
 
+  test("SPARK-10520: date average") {
+    withDefaultTimeZone(UTC) {
+      checkAnswer(
+        testDataDates.agg(avg($"a")),
+        Row(new Date(2011, 4, 3)))
+    }
+  }
+
   test("null average") {
     checkAnswer(
       testData3.agg(avg($"b")),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index c51faaf10f5dd..9dd5a416265d0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.test
 
 import java.nio.charset.StandardCharsets
+import java.sql.Date
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits}
@@ -73,6 +74,16 @@ private[sql] trait SQLTestData { self =>
     df
   }
 
+  protected lazy val testDataDates: DataFrame = {
+    val df = spark.sparkContext.parallelize(
+      TestDataDate(new Date(2000, 1, 1)) ::
+      TestDataDate(new Date(2010, 1, 1)) ::
+      TestDataDate(new Date(2015, 1, 1)) ::
+      TestDataDate(new Date(2020, 1, 1)) :: Nil, 2).toDF()
+    df.createOrReplaceTempView("testDates")
+    df
+  }
+
   protected lazy val negativeData: DataFrame = {
     val df = spark.sparkContext.parallelize(
       (1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
@@ -326,6 +337,7 @@ private[sql] object SQLTestData {
   case class TestData(key: Int, value: String)
   case class TestData2(a: Int, b: Int)
   case class TestData3(a: Int, b: Option[Int])
+  case class TestDataDate(a: Date)
   case class LargeAndSmallInts(a: Int, b: Int)
   case class DecimalData(a: BigDecimal, b: BigDecimal)
   case class BinaryData(a: Array[Byte], b: Int)

From 533dd8de1b8c84c4ab6fe50152d527469e09deba Mon Sep 17 00:00:00 2001
From: Fokko Driesprong
Date: Mon, 8 Jun 2020 17:11:36 +0200
Subject: [PATCH 2/3] Remove the timezone conversions

---
 .../spark/sql/catalyst/expressions/Cast.scala |  5 ++---
 .../sql/catalyst/util/DateTimeUtils.scala     | 19 -------------------
 2 files changed, 2 insertions(+), 22 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index b4bae890b27ac..c9e774360e344 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -471,7 +471,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
     case StringType =>
       buildCast[UTF8String](_, s => DateTimeUtils.stringToDate(s, zoneId).orNull)
     case DoubleType =>
-      buildCast[Double](_, daysSinceEpoch => convertTz(daysSinceEpoch.toInt))
+      buildCast[Double](_, daysSinceEpoch => daysSinceEpoch.toInt)
     case TimestampType =>
       // throw valid precision more than seconds, according to Hive.
       // Timestamp.nanos is in 0 to 999,999,999, no more than a second.
@@ -1110,8 +1110,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
             $evNull = true;
           }
         """
-    case DoubleType => (c, evPrim, evNull) =>
-      code"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.convertTz((int)$c);"
+    case DoubleType => (c, evPrim, evNull) => code"$evPrim = (int) $c;"
     case TimestampType =>
       val zid = getZoneId()
       (c, evPrim, evNull) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 8ddd8ca3282ae..41a271b95e83c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -422,10 +422,6 @@ object DateTimeUtils {
     Instant.ofEpochSecond(secs, mos * NANOS_PER_MICROS)
   }
 
-  def daysToInstant(daysSinceEpoch: SQLDate): Instant = {
-    Instant.ofEpochSecond(daysSinceEpoch * SECONDS_PER_DAY)
-  }
-
   def instantToDays(instant: Instant): Int = {
     val seconds = instant.getEpochSecond
     val days = Math.floorDiv(seconds, SECONDS_PER_DAY)
@@ -825,21 +821,6 @@ object DateTimeUtils {
     instantToMicros(rebasedDateTime.toInstant)
   }
 
-  /**
-   * Convert the date `ts` from one time zone to another.
-   *
-   * TODO: Because of DST, the conversion between UTC and human time is not exactly one-to-one
-   * mapping, the conversion here may return wrong result, we should make the timestamp
-   * timezone-aware.
-   */
-  def convertTz(ts: SQLDate, fromZone: ZoneId, toZone: ZoneId): SQLDate = {
-    val rebasedDateTime = daysToInstant(ts).atZone(toZone).toLocalDateTime.atZone(fromZone)
-    instantToDays(rebasedDateTime.toInstant)
-  }
-
-  // Convenience method for making it easier to only pass the first argument in Java codegen
-  def convertTz(ts: SQLDate): SQLDate = convertTz(ts, ZoneOffset.UTC, defaultTimeZone().toZoneId)
-
   /**
    * Returns a timestamp of given timezone from utc timestamp, with the same string
    * representation in their timezone.

From bbf72c4a271d56c1a92770aba0a573179b57b765 Mon Sep 17 00:00:00 2001
From: Fokko Driesprong
Date: Sun, 28 Jun 2020 08:19:51 +0200
Subject: [PATCH 3/3] We can cast a double to a date

The double is interpreted as days since the epoch.

---
 .../org/apache/spark/sql/catalyst/expressions/CastSuite.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index fe42c6b786293..727fb28a67855 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -652,7 +652,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
     assert(cast(1, DateType).checkInputDataTypes().isFailure)
     assert(cast(1L, DateType).checkInputDataTypes().isFailure)
     assert(cast(1.0.toFloat, DateType).checkInputDataTypes().isFailure)
-    assert(cast(1.0, DateType).checkInputDataTypes().isFailure)
   }
 
   test("SPARK-20302 cast with same structure") {
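Taken together, the series boils down to this: a DateType value casts to a Double holding its
days since the epoch, avg() works on those doubles, and the mean casts back to a date by
truncating to a whole day. Below is a minimal sketch of that arithmetic in plain java.time,
assuming the final behaviour of the three patches; the object and method names are illustrative
only, not Spark APIs.

    import java.time.LocalDate

    object DateAverageSketch {
      // Mean of the days-since-epoch values (the Date -> Double cast), truncated
      // back to a whole day (the Double -> Date cast). None mirrors avg() over
      // zero rows, where count is 0 and the result is null.
      def averageDates(dates: Seq[LocalDate]): Option[LocalDate] =
        if (dates.isEmpty) None
        else {
          val meanDays = dates.map(_.toEpochDay.toDouble).sum / dates.size
          Some(LocalDate.ofEpochDay(meanDays.toInt))
        }

      def main(args: Array[String]): Unit = {
        val dates = Seq(LocalDate.of(2020, 5, 1), LocalDate.of(2020, 5, 3), LocalDate.of(2020, 5, 9))
        println(averageDates(dates)) // Some(2020-05-04): mean of epoch days 18383, 18385, 18391
      }
    }

This is the same arithmetic that the new "SPARK-10520: date average" test exercises end to end
through avg($"a") on testDataDates.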