sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -1212,6 +1212,77 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
}
}

/**
* Adjusts a timestamp so that it displays the same wall-clock time after the reader's timezone
* changes from `from` to `to`; the underlying instant shifts by the difference between the two
* zones' offsets.
*
* We intentionally do not provide an ExpressionDescription, as this expression is not meant to
* be exposed to users; it is only used for internal conversions.
*/
private[spark] case class TimestampTimezoneCorrection(
time: Expression,
from: Expression,
to: Expression)
extends TernaryExpression with ImplicitCastInputTypes {

// convertTz() does the *opposite* of the conversion we want, which is why `from` and `to`
// appear reversed in every call to convertTz below. convertTz answers "for the same instant,
// how does the display change between timezones?"; we want the SQLTimestamp value to change
// so that the display does *not* change, despite going from one TZ to another.

override def children: Seq[Expression] = Seq(time, from, to)
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType, StringType)
override def dataType: DataType = TimestampType
override def prettyName: String = "timestamp_timezone_correction"

override def nullSafeEval(time: Any, from: Any, to: Any): Any = {
DateTimeUtils.convertTz(
time.asInstanceOf[Long],
DateTimeUtils.getTimeZone(to.asInstanceOf[UTF8String].toString()),
DateTimeUtils.getTimeZone(from.asInstanceOf[UTF8String].toString()))
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
if (from.foldable && to.foldable) {
val fromTz = from.eval()
val toTz = to.eval()
if (fromTz == null || toTz == null) {
ev.copy(code = s"""
|boolean ${ev.isNull} = true;
|long ${ev.value} = 0;
""".stripMargin)
} else {
val fromTerm = ctx.freshName("from")
val toTerm = ctx.freshName("to")
val tzClass = classOf[TimeZone].getName
ctx.addMutableState(tzClass, fromTerm, s"""$fromTerm = $dtu.getTimeZone("$fromTz");""")
ctx.addMutableState(tzClass, toTerm, s"""$toTerm = $dtu.getTimeZone("$toTz");""")

val eval = time.genCode(ctx)
ev.copy(code = s"""
|${eval.code}
|boolean ${ev.isNull} = ${eval.isNull};
|long ${ev.value} = 0;
|if (!${ev.isNull}) {
| ${ev.value} = $dtu.convertTz(${eval.value}, $toTerm, $fromTerm);
|}
""".stripMargin)
}
} else {
nullSafeCodeGen(ctx, ev, (time, from, to) =>
s"""
|${ev.value} = $dtu.convertTz(
| $time,
| $dtu.getTimeZone($to.toString()),
| $dtu.getTimeZone($from.toString()));
""".stripMargin
)
}
}
}
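A hedged illustration of the semantics, not part of the patch: it assumes code running inside the org.apache.spark packages (the expression is private[spark]) and a JVM whose default time zone is PST.

    import java.sql.Timestamp
    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.util.DateTimeUtils
    import org.apache.spark.sql.types.{StringType, TimestampType}

    // Store the digits "2015-07-24 00:00:00" as read off a PST clock.
    val micros = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-07-24 00:00:00"))
    val corrected = TimestampTimezoneCorrection(
      Literal.create(micros, TimestampType),
      Literal.create("PST", StringType),
      Literal.create("UTC", StringType)).eval()
    // `corrected` is the instant whose UTC rendering shows the same digits; on the
    // (PST) session clock it renders as 2015-07-23 17:00:00.
    DateTimeUtils.toJavaTimestamp(corrected.asInstanceOf[Long])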

/**
* Parses a column to a date based on the given format.
*/
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -26,6 +26,8 @@ import javax.xml.bind.DatatypeConverter

import scala.annotation.tailrec

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.unsafe.types.UTF8String

/**
@@ -65,6 +67,13 @@ object DateTimeUtils

val TIMEZONE_OPTION = "timeZone"

/**
* Table property that holds the time zone used for adjusting "timestamp without time zone"
* columns to the session's time zone. See SPARK-12297 for more details (the property name
* is the one specified there).
*/
val TIMEZONE_PROPERTY = "table.timezone-adjustment"

def defaultTimeZone(): TimeZone = TimeZone.getDefault()

// Reuse the Calendar object in each thread as it is expensive to create in each method call.
@@ -109,6 +118,12 @@ object DateTimeUtils
computedTimeZones.computeIfAbsent(timeZoneId, computeTimeZone)
}

private lazy val validTimezones = TimeZone.getAvailableIDs().toSet

/**
* Checks a timezone ID against the set the JVM knows about. This is needed because
* TimeZone.getTimeZone() silently falls back to GMT for IDs it does not recognize.
*/
def isValidTimezone(timezoneId: String): Boolean = {
validTimezones.contains(timezoneId)
}

def newDateFormat(formatString: String, timeZone: TimeZone): DateFormat = {
val sdf = new SimpleDateFormat(formatString, Locale.US)
sdf.setTimeZone(timeZone)
@@ -1065,4 +1080,24 @@ object DateTimeUtils
threadLocalTimestampFormat.remove()
threadLocalDateFormat.remove()
}

/**
* Throw an AnalysisException if we're trying to set an invalid timezone for this table.
*/
def checkTableTz(table: TableIdentifier, properties: Map[String, String]): Unit = {
checkTableTz(s"in table ${table.toString}", properties)
}

/**
* Throw an AnalysisException if the given properties set an invalid timezone. The `dest`
* string is appended to the error message to describe where the property was being set.
*/
def checkTableTz(dest: String, properties: Map[String, String]): Unit = {
properties.get(TIMEZONE_PROPERTY).foreach { tz =>
if (!DateTimeUtils.isValidTimezone(tz)) {
throw new AnalysisException(s"Cannot set $TIMEZONE_PROPERTY to invalid " +
s"timezone $tz $dest")
}
}
}

}
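A usage sketch for the validation helpers (illustration only, not part of the patch):

    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    DateTimeUtils.isValidTimezone("America/Los_Angeles")            // true
    DateTimeUtils.isValidTimezone("Mars/Olympus_Mons")              // false
    // checkTableTz throws only when the property is present *and* invalid:
    DateTimeUtils.checkTableTz("for path /tmp/t",
      Map(DateTimeUtils.TIMEZONE_PROPERTY -> "UTC"))                // no-op
    DateTimeUtils.checkTableTz("for path /tmp/t",
      Map(DateTimeUtils.TIMEZONE_PROPERTY -> "Mars/Olympus_Mons"))  // AnalysisException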
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -741,4 +741,32 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("2015-07-24 00:00:00", null, null)
test(null, null, null)
}

test("timestamp_timezone_correction") {
def test(t: String, fromTz: String, toTz: String, expected: String): Unit = {
checkEvaluation(
TimestampTimezoneCorrection(
Literal.create(if (t != null) Timestamp.valueOf(t) else null, TimestampType),
Literal.create(fromTz, StringType),
Literal.create(toTz, StringType)),
if (expected != null) Timestamp.valueOf(expected) else null)
checkEvaluation(
TimestampTimezoneCorrection(
Literal.create(if (t != null) Timestamp.valueOf(t) else null, TimestampType),
NonFoldableLiteral.create(fromTz, StringType),
NonFoldableLiteral.create(toTz, StringType)),
if (expected != null) Timestamp.valueOf(expected) else null)
}
// These conversions may look backwards -- but this is *NOT* asking:
// when the clock says 2015-07-24 00:00:00 in PST, what would it say to somebody in UTC?
// Instead, it's asking: suppose somebody stored "2015-07-24 00:00:00" while in PST, but
// as millis-since-epoch. What millis-since-epoch would also show "2015-07-24 00:00:00"
// on a clock in UTC? Purely for testing convenience, the expected value is expressed as
// "what would my clock in PST say for that final millis-since-epoch?"
test("2015-07-24 00:00:00", "PST", "UTC", "2015-07-23 17:00:00")
test("2015-01-24 00:00:00", "PST", "UTC", "2015-01-23 16:00:00")
test(null, "UTC", "UTC", null)
test("2015-07-24 00:00:00", null, null, null)
test(null, null, null, null)
}
}
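The offset arithmetic behind the expected values above, as a runnable aside (illustration only; the numbers come from the JVM's tz database):

    import java.sql.Timestamp
    import java.util.TimeZone

    val pst = TimeZone.getTimeZone("PST")
    val julyHours = pst.getOffset(Timestamp.valueOf("2015-07-24 00:00:00").getTime) / 3600000
    val janHours = pst.getOffset(Timestamp.valueOf("2015-01-24 00:00:00").getTime) / 3600000
    // julyHours == -7 (DST in effect), janHours == -8: shifting the stored
    // millis-since-epoch by exactly that offset keeps the displayed digits fixed,
    // which is where 17:00 (July) and 16:00 (January) come from.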
sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -27,6 +27,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{DataSource, FailureSafeParser}
import org.apache.spark.sql.execution.datasources.csv._
@@ -179,6 +180,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
throw new AnalysisException("Hive data source can only be used with tables, you can not " +
"read files of Hive data source directly.")
}
DateTimeUtils.checkTableTz("", extraOptions.toMap)

sparkSession.baseRelationToDataFrame(
DataSource.apply(
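A reader-side sketch (hypothetical session and path; the check fails fast in load(), before any data source is resolved):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    val spark = SparkSession.builder().master("local").getOrCreate()
    spark.read
      .option(DateTimeUtils.TIMEZONE_PROPERTY, "Not/AZone")
      .parquet("/tmp/t")   // AnalysisException: ... invalid timezone Not/AZone ...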
sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
@@ -230,6 +231,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}

assertNotBucketed("save")
val dest = extraOptions.get("path") match {
case Some(path) => s"for path $path"
case _ => s"with format $source"
}
DateTimeUtils.checkTableTz(dest, extraOptions.toMap)

runCommand(df.sparkSession, "save") {
DataSource(
@@ -266,6 +272,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* @since 1.4.0
*/
def insertInto(tableName: String): Unit = {
extraOptions.get(DateTimeUtils.TIMEZONE_PROPERTY).foreach { tz =>
throw new AnalysisException("Cannot provide a table timezone on insert; tried to insert " +
s"$tableName with ${DateTimeUtils.TIMEZONE_PROPERTY}=$tz")
}
insertInto(df.sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName))
}

@@ -406,6 +416,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
} else {
CatalogTableType.MANAGED
}
val props = extraOptions.filterKeys(_ == DateTimeUtils.TIMEZONE_PROPERTY).toMap
DateTimeUtils.checkTableTz(tableIdent, props)

val tableDesc = CatalogTable(
identifier = tableIdent,
@@ -414,7 +426,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
schema = new StructType,
provider = Some(source),
partitionColumnNames = partitioningColumns.getOrElse(Nil),
bucketSpec = getBucketSpec,
properties = props)

runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan)))
}
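An end-to-end writer sketch (hypothetical DataFrame and table names, mirroring the checks added above): the property may be set when creating a table, but never supplied per-insert.

    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    df.write
      .option(DateTimeUtils.TIMEZONE_PROPERTY, "UTC")   // validated and stored by saveAsTable
      .saveAsTable("t_utc")

    df.write
      .option(DateTimeUtils.TIMEZONE_PROPERTY, "PST")
      .insertInto("t_utc")   // AnalysisException: Cannot provide a table timezone on insert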
sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, PartitioningUtils}
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
@@ -230,6 +231,13 @@ case class AlterTableSetPropertiesCommand(
isView: Boolean)
extends RunnableCommand {

if (isView) {
properties.get(DateTimeUtils.TIMEZONE_PROPERTY).foreach { _ =>
throw new AnalysisException("Timezone cannot be set for view")
}
}
DateTimeUtils.checkTableTz(tableName, properties)

override def run(sparkSession: SparkSession): Seq[Row] = {
val catalog = sparkSession.sessionState.catalog
val table = catalog.getTableMetadata(tableName)
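A SQL-level sketch of the guard (hypothetical session, table t, and view v):

    spark.sql("ALTER TABLE t SET TBLPROPERTIES ('table.timezone-adjustment'='UTC')")      // ok
    spark.sql("ALTER TABLE t SET TBLPROPERTIES ('table.timezone-adjustment'='Bad/Zone')")
    //   => AnalysisException: Cannot set table.timezone-adjustment to invalid timezone ...
    spark.sql("ALTER VIEW v SET TBLPROPERTIES ('table.timezone-adjustment'='UTC')")
    //   => AnalysisException: Timezone cannot be set for a view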
sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.catalog.CatalogTableType._
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, DateTimeUtils}
import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils}
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
@@ -76,6 +76,8 @@ case class CreateTableLikeCommand(
// If the location is specified, we create an external table internally.
// Otherwise create a managed table.
val tblType = if (location.isEmpty) CatalogTableType.MANAGED else CatalogTableType.EXTERNAL
val properties =
sourceTableDesc.properties.filterKeys(_ == DateTimeUtils.TIMEZONE_PROPERTY)

val newTableDesc =
CatalogTable(
@@ -86,7 +88,8 @@ case class CreateTableLikeCommand(
schema = sourceTableDesc.schema,
provider = newProvider,
partitionColumnNames = sourceTableDesc.partitionColumnNames,
bucketSpec = sourceTableDesc.bucketSpec,
properties = properties)

catalog.createTable(newTableDesc, ifNotExists)
Seq.empty[Row]
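A propagation sketch (hypothetical tables): because only TIMEZONE_PROPERTY survives the filterKeys above, CREATE TABLE ... LIKE carries the timezone adjustment over without copying any other source-table properties.

    spark.sql("CREATE TABLE t2 LIKE t_utc")
    spark.sql("SHOW TBLPROPERTIES t2").show()   // includes table.timezone-adjustment=UTC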
@@ -126,6 +129,8 @@ case class CreateTableCommand(
sparkSession.sessionState.catalog.createTable(table, ignoreIfExists)
Seq.empty[Row]
}

DateTimeUtils.checkTableTz(table.identifier, table.properties)
}


sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable
import org.apache.spark.sql.catalyst.expressions.{Alias, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.MetadataBuilder
import org.apache.spark.sql.util.SchemaUtils

@@ -123,6 +124,10 @@ case class CreateViewCommand(
s"It is not allowed to add database prefix `$database` for the TEMPORARY view name.")
}

properties.get(DateTimeUtils.TIMEZONE_PROPERTY).foreach { _ =>
throw new AnalysisException("Timezone cannot be set for view")
}

override def run(sparkSession: SparkSession): Seq[Row] = {
// If the plan cannot be analyzed, throw an exception and don't proceed.
val qe = sparkSession.sessionState.executePlan(child)