@@ -465,6 +465,20 @@ class VeloxParquetDataTypeValidationSuite extends VeloxWholeStageTransformerSuite
}
}

testWithMinSparkVersion("Fallback for TimestampNTZ type scan", "3.4") {
withTempDir {
dir =>
val path = new File(dir, "ntz_data").toURI.getPath
val inputDf =
spark.sql("SELECT CAST('2024-01-01 00:00:00' AS TIMESTAMP_NTZ) AS ts_ntz")
inputDf.write.format("parquet").save(path)
val df = spark.read.format("parquet").load(path)
val executedPlan = getExecutedPlan(df)
assert(!executedPlan.exists(plan => plan.isInstanceOf[BatchScanExecTransformer]))
checkAnswer(df, inputDf)
}
}

test("Velox Parquet Write") {
withSQLConf((GlutenConfig.NATIVE_WRITER_ENABLED.key, "true")) {
withTempDir {
@@ -50,6 +50,8 @@ object SparkArrowUtil
} else {
new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC")
}
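      // Matched on catalogString so this compiles against Spark versions (< 3.4)
      // where TimestampNTZType does not exist; Arrow encodes a timezone-naive
      // timestamp as a Timestamp type with a null timezone.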
case dt if dt.catalogString == "timestamp_ntz" =>
new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)
case YearMonthIntervalType.DEFAULT =>
new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
case _: ArrayType => ArrowType.List.INSTANCE
@@ -72,7 +74,17 @@ object SparkArrowUtil
case ArrowType.Binary.INSTANCE => BinaryType
case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null =>
      // TimestampNTZType is only available in Spark 3.4+, so resolve it via
      // reflection and fall back to TimestampType on older Spark versions.
try {
Class
.forName("org.apache.spark.sql.types.TimestampNTZType$")
.getField("MODULE$")
.get(null)
.asInstanceOf[DataType]
} catch {
case _: ClassNotFoundException => TimestampType
}
case _: ArrowType.Timestamp => TimestampType
case interval: ArrowType.Interval if interval.getUnit == IntervalUnit.YEAR_MONTH =>
YearMonthIntervalType.DEFAULT
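
A note on the reflective lookup above: a Scala `object` compiles to a JVM class whose name ends in `$`, holding its singleton instance in a static `MODULE$` field, which is how `TimestampNTZType` can be obtained at runtime without a compile-time reference that would break builds against Spark < 3.4. A minimal sketch of the same pattern, using `scala.None` only as an always-present stand-in singleton:

    // Resolve the None singleton reflectively, exactly as the diff does for
    // TimestampNTZType$; equivalent to writing `None` directly.
    val none = Class.forName("scala.None$").getField("MODULE$").get(null)
    assert(none == None)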
@@ -156,7 +168,8 @@ object SparkArrowUtil
}.asJava)
}

// TimestampNTZ is not supported for native computation, but the Arrow type mapping is needed
// for row-to-columnar transitions when the fallback validator tags NTZ operators.
def checkSchema(schema: StructType): Boolean = {
try {
SparkSchemaUtil.toArrowSchema(schema)
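
To illustrate the comment above checkSchema, a hedged sketch (assuming Spark 3.4+, where TimestampNTZType can be named directly, plus the relevant Gluten import for SparkArrowUtil): with the null-timezone mapping in place, a schema containing TIMESTAMP_NTZ should now pass checkSchema, which is what lets row-to-columnar transitions wrap fallen-back NTZ operators.

    import org.apache.spark.sql.types._

    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("ts", TimestampNTZType)))
    // Expected to return true after this change, since toArrowSchema no longer throws:
    assert(SparkArrowUtil.checkSchema(schema))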
@@ -340,4 +340,66 @@ abstract class DeltaSuite extends WholeStageTransformerSuite
}
}
}

// TIMESTAMP_NTZ was introduced in Spark 3.4 / Delta 2.4
testWithMinSparkVersion(
"delta: create table with TIMESTAMP_NTZ should fallback and return correct results",
"3.4") {
withTable("delta_ntz") {
spark.sql("CREATE TABLE delta_ntz(c1 STRING, c2 TIMESTAMP, c3 TIMESTAMP_NTZ) USING DELTA")
spark.sql("""INSERT INTO delta_ntz VALUES
|('foo','2022-01-02 03:04:05.123456','2022-01-02 03:04:05.123456')""".stripMargin)
val df = runQueryAndCompare("select * from delta_ntz", noFallBack = false) { _ => }
checkAnswer(
df,
Row(
"foo",
java.sql.Timestamp.valueOf("2022-01-02 03:04:05.123456"),
java.time.LocalDateTime.of(2022, 1, 2, 3, 4, 5, 123456000)))
}
}

testWithMinSparkVersion(
"delta: TIMESTAMP_NTZ as partition column should fallback and return correct results",
"3.4") {
withTable("delta_ntz_part") {
spark.sql("""CREATE TABLE delta_ntz_part(c1 STRING, c2 TIMESTAMP, c3 TIMESTAMP_NTZ)
|USING DELTA PARTITIONED BY (c3)""".stripMargin)
spark.sql("""INSERT INTO delta_ntz_part VALUES
|('foo','2022-01-02 03:04:05.123456','2022-01-02 03:04:05.123456'),
|('bar','2023-06-15 10:30:00.000000','2023-06-15 10:30:00.000000')""".stripMargin)
val df = runQueryAndCompare("select * from delta_ntz_part order by c1", noFallBack = false) {
_ =>
}
checkAnswer(
df,
Seq(
Row(
"bar",
java.sql.Timestamp.valueOf("2023-06-15 10:30:00"),
java.time.LocalDateTime.of(2023, 6, 15, 10, 30, 0, 0)),
Row(
"foo",
java.sql.Timestamp.valueOf("2022-01-02 03:04:05.123456"),
java.time.LocalDateTime.of(2022, 1, 2, 3, 4, 5, 123456000))
)
)
}
}

testWithMinSparkVersion(
"delta: filter on TIMESTAMP_NTZ column should fallback and return correct results",
"3.4") {
withTable("delta_ntz_filter") {
spark.sql("CREATE TABLE delta_ntz_filter(id INT, ts TIMESTAMP_NTZ) USING DELTA")
spark.sql("""INSERT INTO delta_ntz_filter VALUES
|(1, '2022-01-01 00:00:00'),
|(2, '2023-01-01 00:00:00'),
|(3, '2024-01-01 00:00:00')""".stripMargin)
val df = runQueryAndCompare(
"select id from delta_ntz_filter where ts > '2022-06-01 00:00:00'",
noFallBack = false) { _ => }
checkAnswer(df, Seq(Row(2), Row(3)))
}
}
}
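
One detail behind the expected rows above: Spark's external JVM type for TIMESTAMP is java.sql.Timestamp, while TIMESTAMP_NTZ surfaces as java.time.LocalDateTime, so fractional seconds must be supplied as nanoseconds in the LocalDateTime constructor. A quick sanity check of that unit conversion:

    // 123456 microseconds expressed as nanoseconds, matching the Rows above.
    val ntz = java.time.LocalDateTime.of(2022, 1, 2, 3, 4, 5, 123456000)
    assert(ntz.getNano == 123456000)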
@@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.hive.HiveTableScanExecTransformer
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}

object Validators {
implicit class ValidatorBuilderImplicits(builder: Validator.Builder) {
@@ -78,6 +79,11 @@ object Validators {
builder.add(new FallbackByTestInjects())
}

/** Fails validation if a plan node's input or output schema contains TimestampNTZType. */
def fallbackByTimestampNTZ(): Validator.Builder = {
builder.add(new FallbackByTimestampNTZ())
}

/**
* Fails validation on non-scan plan nodes if Gluten is running as scan-only mode. Also, passes
* validation on filter for the exception that filter + scan is detected. Because filters can be
@@ -212,6 +218,25 @@ object Validators {
}
}

private class FallbackByTimestampNTZ() extends Validator {
override def validate(plan: SparkPlan): Validator.OutCome = {
def containsNTZ(dataType: DataType): Boolean = dataType match {
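        // Compared via catalogString rather than matching TimestampNTZType
        // directly, keeping this compilable on Spark versions (< 3.4) that
        // lack the type.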
case dt if dt.catalogString == "timestamp_ntz" => true
case st: StructType => st.exists(f => containsNTZ(f.dataType))
case at: ArrayType => containsNTZ(at.elementType)
case mt: MapType => containsNTZ(mt.keyType) || containsNTZ(mt.valueType)
case _ => false
}
val hasNTZ = plan.output.exists(a => containsNTZ(a.dataType)) ||
plan.children.exists(_.output.exists(a => containsNTZ(a.dataType)))
if (hasNTZ) {
fail(s"${plan.nodeName} has TimestampNTZType in input/output schema")
} else {
pass()
}
}
}

private class FallbackIfScanOnlyWithFilterPushed(scanOnly: Boolean) extends Validator {
override def validate(plan: SparkPlan): Validator.OutCome = {
if (!scanOnly) {
@@ -292,6 +317,7 @@ object Validators {
.fallbackComplexExpressions()
.fallbackByBackendSettings()
.fallbackByUserOptions()
.fallbackByTimestampNTZ()
.fallbackByTestInjects()
.build()
}
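
A standalone restatement of the recursive schema check, runnable in a Spark 3.4+ shell (the nested example names TimestampNTZType directly, which the production code deliberately avoids):

    import org.apache.spark.sql.types._

    def containsNTZ(dataType: DataType): Boolean = dataType match {
      case dt if dt.catalogString == "timestamp_ntz" => true
      case st: StructType => st.exists(f => containsNTZ(f.dataType))
      case at: ArrayType => containsNTZ(at.elementType)
      case mt: MapType => containsNTZ(mt.keyType) || containsNTZ(mt.valueType)
      case _ => false
    }

    // NTZ nested inside an array of structs is still detected:
    val nested = ArrayType(StructType(Seq(StructField("at", TimestampNTZType))))
    assert(containsNTZ(nested))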