Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,18 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
CHTruncTimestampTransformer(substraitExprName, format, timestamp, timeZoneId, original)
}

/**
 * Builds the ClickHouse-backend transformer for `to_unix_timestamp` / `unix_timestamp`.
 *
 * CH handles both the timestamp and the (time, format) argument shapes natively, so this
 * simply routes through the generic scalar-function path with both children forwarded.
 *
 * @param substraitExprName the Substrait function name to emit
 * @param timeExp transformer for the time/timestamp argument
 * @param format transformer for the format-string argument
 * @param original the original Spark expression being replaced
 */
override def genToUnixTimestampTransformer(
    substraitExprName: String,
    timeExp: ExpressionTransformer,
    format: ExpressionTransformer,
    original: Expression): ExpressionTransformer = {
  val childTransformers = Seq(timeExp, format)
  GenericExpressionTransformer(substraitExprName, childTransformers, original)
}

override def genDateDiffTransformer(
substraitExprName: String,
endDate: ExpressionTransformer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1048,4 +1048,12 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
}
TimestampDiffTransformer(substraitExprName, extract.get, left, right, original)
}

/**
 * Builds the Velox-backend transformer for `to_unix_timestamp` / `unix_timestamp`.
 *
 * Delegates to [[ToUnixTimestampTransformer]], which decides (based on the input's data
 * type) whether the format child is emitted at all.
 *
 * @param substraitExprName the Substrait function name to emit
 * @param timeExp transformer for the time/timestamp argument
 * @param format transformer for the format-string argument
 * @param original the original Spark expression being replaced
 */
override def genToUnixTimestampTransformer(
    substraitExprName: String,
    timeExp: ExpressionTransformer,
    format: ExpressionTransformer,
    original: Expression): ExpressionTransformer =
  ToUnixTimestampTransformer(substraitExprName, timeExp, format, original)
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression._

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{IntegerType, LongType}
import org.apache.spark.sql.types.{IntegerType, LongType, TimestampType}

import java.lang.{Integer => JInteger}
import java.util.{ArrayList => JArrayList}
Expand Down Expand Up @@ -109,3 +109,18 @@ case class VeloxHashExpressionTransformer(
ExpressionBuilder.makeScalarFunction(functionId, nodes, typeNode)
}
}

/**
 * Transformer for `to_unix_timestamp` / `unix_timestamp` on the Velox backend.
 *
 * When the first argument is already of [[TimestampType]], the format argument is not
 * emitted as a child — the related test in this change asserts that Spark ignores the
 * format for timestamp inputs, so only the timestamp child is passed to the backend.
 * For all other input types (string, date, …) both children are forwarded so the
 * backend can parse the value using the supplied pattern.
 *
 * @param substraitExprName the Substrait function name to emit
 * @param timeExpTransformer transformer for the time/timestamp argument
 * @param formatTransformer transformer for the format-string argument
 * @param original the original Spark expression being replaced
 */
case class ToUnixTimestampTransformer(
    substraitExprName: String,
    timeExpTransformer: ExpressionTransformer,
    formatTransformer: ExpressionTransformer,
    original: Expression)
  extends ExpressionTransformer {

  override def children: Seq[ExpressionTransformer] = {
    val inputIsTimestamp = timeExpTransformer.dataType match {
      case _: TimestampType => true
      case _ => false
    }
    if (inputIsTimestamp) {
      Seq(timeExpTransformer)
    } else {
      Seq(timeExpTransformer, formatTransformer)
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -471,4 +471,21 @@ abstract class DateFunctionsValidateSuite extends FunctionsValidateSuite {
}
}
}

// Verifies that unix_timestamp(timestamp, format) stays fully offloaded to the native
// backend (no fallback to vanilla Spark), by checking the plan contains a
// ProjectExecTransformer.
test("unix_timestamp with timestamp and format - no fallback") {
  withTempPath {
    dir =>
      val rows = Seq(
        (Timestamp.valueOf("2016-04-08 13:10:15"), "yyyy-MM-dd"),
        (Timestamp.valueOf("2017-05-19 18:25:30"), "MM/dd/yyyy"))
      rows.toDF("ts", "fmt").write.parquet(dir.getCanonicalPath)

      spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("unix_timestamp_test")

      // The projection must be executed natively, without falling back.
      runQueryAndCompare("SELECT unix_timestamp(ts, fmt) FROM unix_timestamp_test") {
        checkGlutenOperatorMatch[ProjectExecTransformer]
      }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,12 @@ trait SparkPlanExecApi {
TruncTimestampTransformer(substraitExprName, format, timestamp, original)
}

/**
 * Builds the backend-specific transformer for `to_unix_timestamp` / `unix_timestamp`.
 *
 * Implementations decide how the two arguments are mapped to the native function
 * (e.g. whether the format child is forwarded at all — backend-dependent).
 *
 * @param substraitExprName the Substrait function name to emit
 * @param timeExp transformer for the time/timestamp argument
 * @param format transformer for the format-string argument
 * @param original the original Spark expression being replaced
 * @return the transformer that will produce the Substrait expression node
 */
def genToUnixTimestampTransformer(
substraitExprName: String,
timeExp: ExpressionTransformer,
format: ExpressionTransformer,
original: Expression): ExpressionTransformer

def genDateDiffTransformer(
substraitExprName: String,
endDate: ExpressionTransformer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,24 +277,18 @@ object ExpressionConverter extends SQLConfHelper with Logging {
replaceWithExpressionTransformer0(r.child, attributeSeq, expressionsMap),
r)
case t: ToUnixTimestamp =>
// The failOnError depends on the config for ANSI. ANSI is not supported currently.
// And timeZoneId is passed to backend config.
GenericExpressionTransformer(
BackendsApiManager.getSparkPlanExecApiInstance.genToUnixTimestampTransformer(
substraitExprName,
Seq(
replaceWithExpressionTransformer0(t.timeExp, attributeSeq, expressionsMap),
replaceWithExpressionTransformer0(t.format, attributeSeq, expressionsMap)
),
replaceWithExpressionTransformer0(t.timeExp, attributeSeq, expressionsMap),
replaceWithExpressionTransformer0(t.format, attributeSeq, expressionsMap),
t
)
case u: UnixTimestamp =>
GenericExpressionTransformer(
BackendsApiManager.getSparkPlanExecApiInstance.genToUnixTimestampTransformer(
substraitExprName,
Seq(
replaceWithExpressionTransformer0(u.timeExp, attributeSeq, expressionsMap),
replaceWithExpressionTransformer0(u.format, attributeSeq, expressionsMap)
),
ToUnixTimestamp(u.timeExp, u.format, u.timeZoneId, u.failOnError)
replaceWithExpressionTransformer0(u.timeExp, attributeSeq, expressionsMap),
replaceWithExpressionTransformer0(u.format, attributeSeq, expressionsMap),
u
)
case t: TruncTimestamp =>
BackendsApiManager.getSparkPlanExecApiInstance.genTruncTimestampTransformer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,14 @@ class GlutenDateFunctionsSuite extends DateFunctionsSuite with GlutenSQLTestsTra
df2.select(unix_timestamp(col("y"), "yyyy-MM-dd")),
Seq(Row(secs(ts5.getTime)), Row(null)))

// Test unix_timestamp(timestamp, format) - format should be ignored
checkAnswer(
df.select(unix_timestamp(col("ts"), "yyyy-MM-dd")),
Seq(Row(secs(ts1.getTime)), Row(secs(ts2.getTime))))
checkAnswer(
df.selectExpr("unix_timestamp(ts, 'invalid-format')"),
Seq(Row(secs(ts1.getTime)), Row(secs(ts2.getTime))))

val now = sql("select unix_timestamp()").collect().head.getLong(0)
checkAnswer(
sql(s"select timestamp_seconds($now)"),
Expand Down Expand Up @@ -187,6 +195,14 @@ class GlutenDateFunctionsSuite extends DateFunctionsSuite with GlutenSQLTestsTra
df2.select(unix_timestamp(col("y"), "yyyy-MM-dd")),
Seq(Row(secs(ts5.getTime)), Row(null)))

// Test to_unix_timestamp(timestamp, format) - format should be ignored
checkAnswer(
df.selectExpr("to_unix_timestamp(ts, 'yyyy-MM-dd')"),
Seq(Row(secs(ts1.getTime)), Row(secs(ts2.getTime))))
checkAnswer(
df.selectExpr("to_unix_timestamp(ts, 'invalid-format')"),
Seq(Row(secs(ts1.getTime)), Row(secs(ts2.getTime))))

val invalid = df1.selectExpr(s"to_unix_timestamp(x, 'yyyy-MM-dd bb:HH:ss')")
checkAnswer(invalid, Seq(Row(null), Row(null), Row(null), Row(null)))
}
Expand Down