diff --git a/native/Cargo.lock b/native/Cargo.lock index 340f0fe0cd..2161ff9e41 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1966,6 +1966,8 @@ dependencies = [ "rand 0.10.0", "regex", "serde_json", + "sha2", + "thiserror 2.0.18", "tokio", "twox-hash", ] diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index e0a395ebbf..7d547c2a7f 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -393,7 +393,6 @@ fn prepare_datafusion_session_context( // register UDFs from datafusion-spark crate fn register_datafusion_spark_function(session_ctx: &SessionContext) { session_ctx.register_udf(ScalarUDF::new_from_impl(SparkExpm1::default())); - session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha2::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(CharFunc::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitGet::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateAdd::default())); @@ -401,6 +400,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) { session_ctx.register_udf(ScalarUDF::new_from_impl(SparkLastDay::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkNextDay::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha1::default())); + session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha2::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkConcat::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkHex::default())); diff --git a/spark/src/main/scala/org/apache/comet/serde/hash.scala b/spark/src/main/scala/org/apache/comet/serde/hash.scala index b059199735..a138ea023d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/hash.scala +++ b/spark/src/main/scala/org/apache/comet/serde/hash.scala @@ -19,7 +19,7 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Murmur3Hash, Sha1, Sha2, XxHash64} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal, Murmur3Hash, Sha1, Sha2, XxHash64} import org.apache.spark.sql.types.{ArrayType, DataType, DecimalType, IntegerType, LongType, MapType, StringType, StructType} import org.apache.comet.CometSparkSessionExtensions.withInfo @@ -76,6 +76,12 @@ object CometSha2 extends CometExpressionSerde[Sha2] { return None } + // Fall back to Spark for literal input to avoid native engine crash (#3340) + if (expr.left.isInstanceOf[Literal]) { + withInfo(expr, "Sha2 with literal input falls back to Spark") + return None + } + // It's possible for spark to dynamically compute the number of bits from input // expression, however DataFusion does not support that yet. if (!expr.right.foldable) { diff --git a/spark/src/test/resources/sql-tests/expressions/hash/hash.sql b/spark/src/test/resources/sql-tests/expressions/hash/hash.sql index 35031ea7e4..550f34b13a 100644 --- a/spark/src/test/resources/sql-tests/expressions/hash/hash.sql +++ b/spark/src/test/resources/sql-tests/expressions/hash/hash.sql @@ -25,8 +25,10 @@ statement INSERT INTO test VALUES ('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999), ('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999) query -SELECT md5(col), md5(cast(a as string)), md5(cast(b as string)), hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), xxhash64(col), xxhash64(col, 1), xxhash64(col, 0), xxhash64(col, a, b), xxhash64(b, a, col), sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1), sha1(col), sha1(cast(a as string)), sha1(cast(b as string)) FROM test +SELECT md5(col), md5(cast(a as string)), md5(cast(b as string)), hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), xxhash64(col), xxhash64(col, 1), xxhash64(col, 0), xxhash64(col, a, b), xxhash64(b, a, col), sha1(col), sha1(cast(a as string)), sha1(cast(b as string)), sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1) FROM test -- literal arguments -query ignore(https://github.com/apache/datafusion-comet/issues/3340) +-- sha2 with literal input falls back to Spark to avoid native engine crash (#3340) +query expect_fallback(Sha2 with literal input falls back to Spark) SELECT md5('Spark SQL'), sha1('test'), sha2('test', 256), hash('test'), xxhash64('test') + diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 68c1a82f14..57fea4ceaf 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1891,8 +1891,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), |xxhash64(col), xxhash64(col, 1), xxhash64(col, 0), xxhash64(col, a, b), xxhash64(b, a, col), |crc32(col), crc32(cast(a as string)), crc32(cast(b as string)), - |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1), - |sha1(col), sha1(cast(a as string)), sha1(cast(b as string)) + |sha1(col), sha1(cast(a as string)), sha1(cast(b as string)), + |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1) |from test |""".stripMargin) } @@ -2002,8 +2002,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), |xxhash64(col), xxhash64(col, 1), xxhash64(col, 0), xxhash64(col, a, b), xxhash64(b, a, col), |crc32(col), crc32(cast(a as string)), - |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1), - |sha1(col), sha1(cast(a as string)) + |sha1(col), sha1(cast(a as string)), + |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1) |from test |""".stripMargin) }