diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index a184bc94ce..6aa161280a 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -630,7 +630,7 @@ object QueryPlanSerde extends Logging with CometExprShim { } } - expr match { + versionSpecificExprToProtoInternal(expr, inputs, binding).orElse(expr match { case a @ Alias(_, _) => val r = exprToProtoInternal(a.child, inputs, binding) if (r.isEmpty) { @@ -1284,25 +1284,6 @@ object QueryPlanSerde extends Logging with CometExprShim { optExprWithInfo(optExpr, expr, r.child) } - case s: StringDecode => - // Right child is the encoding expression. - s.charset match { - case Literal(str, DataTypes.StringType) - if str.toString.toLowerCase(Locale.ROOT) == "utf-8" => - // decode(col, 'utf-8') can be treated as a cast with "try" eval mode that puts nulls - // for invalid strings. - // Left child is the binary expression. 
/**
 * Serializes a string-decode expression as a cast from binary to string evaluated in "try"
 * mode, so that byte sequences that are not valid UTF-8 produce nulls.
 *
 * Shared by the version-specific shims, which extract `charset` and `bin` from whatever
 * expression shape their Spark version uses for `decode`.
 *
 * @param expr
 *   the original decode expression; used for the cast metadata and for fallback info
 * @param charset
 *   the encoding expression; only a non-null string literal equal to "utf-8"
 *   (case-insensitive) is supported
 * @param bin
 *   the binary input expression to decode
 * @param inputs
 *   input attributes, forwarded unchanged to `exprToProtoInternal`
 * @param binding
 *   forwarded unchanged to `exprToProtoInternal`
 * @return
 *   the serialized cast, or `None` when the charset is unsupported or the binary child cannot
 *   be serialized
 */
def stringDecode(
    expr: Expression,
    charset: Expression,
    bin: Expression,
    inputs: Seq[Attribute],
    binding: Boolean): Option[Expr] = {
  charset match {
    // A null charset literal carries a null value, so check it before calling toString;
    // it then falls through to the unsupported branch instead of throwing.
    case Literal(str, DataTypes.StringType)
        if str != null && str.toString.toLowerCase(Locale.ROOT) == "utf-8" =>
      // decode(col, 'utf-8') can be treated as a cast with "try" eval mode that puts nulls
      // for invalid strings.
      exprToProtoInternal(bin, inputs, binding) match {
        case Some(binExpr) =>
          castToProto(expr, None, DataTypes.StringType, binExpr, CometEvalMode.TRY)
        case None =>
          // The binary child itself is unsupported: report it and return None rather than
          // calling `.get` on an empty Option.
          withInfo(expr, "Comet cannot serialize the binary input of decode.", bin)
          None
      }
    case _ =>
      withInfo(expr, "Comet only supports decoding with 'utf-8'.")
      None
  }
}
/**
 * Serialization hook for expressions whose Catalyst representation depends on the Spark
 * version this shim is compiled against.
 *
 * In this version the only expression handled is [[StringDecode]], which is delegated to
 * `QueryPlanSerde.stringDecode`.
 *
 * @param expr
 *   the expression to serialize
 * @param inputs
 *   input attributes, forwarded unchanged to `stringDecode`
 * @param binding
 *   forwarded unchanged to `stringDecode`
 * @return
 *   the result of `stringDecode` for a [[StringDecode]]; `None` for every other expression
 */
def versionSpecificExprToProtoInternal(
    expr: Expression,
    inputs: Seq[Attribute],
    binding: Boolean): Option[Expr] =
  Some(expr)
    // Keep only the expression shape this shim knows how to handle.
    .collect { case decode: StringDecode => decode }
    // `charset` is the encoding expression, `bin` the binary input.
    .flatMap(decode => stringDecode(expr, decode.charset, decode.bin, inputs, binding))
/**
 * Serialization hook for expressions whose Catalyst representation depends on the Spark
 * version this shim is compiled against.
 *
 * This shim recognizes a [[StaticInvoke]] of the `decode` function on [[StringDecode]] and
 * delegates it to `QueryPlanSerde.stringDecode`.
 *
 * @param expr
 *   the expression to serialize
 * @param inputs
 *   input attributes, forwarded unchanged to `stringDecode`
 * @param binding
 *   forwarded unchanged to `stringDecode`
 * @return
 *   the result of `stringDecode` for a matching invocation; `None` for every other expression
 */
def versionSpecificExprToProtoInternal(
    expr: Expression,
    inputs: Seq[Attribute],
    binding: Boolean): Option[Expr] = {

  // True when `invoke` is the four-argument `decode` call on StringDecode that returns a
  // string and takes (binary, string, boolean, boolean) inputs.
  def isDecodeInvoke(invoke: StaticInvoke): Boolean =
    invoke.staticObject == classOf[StringDecode] &&
      invoke.dataType.isInstanceOf[StringType] &&
      invoke.functionName == "decode" &&
      invoke.arguments.size == 4 &&
      invoke.inputTypes == Seq(
        BinaryType,
        StringTypeWithCollation(supportsTrimCollation = true),
        BooleanType,
        BooleanType)

  expr match {
    case invoke: StaticInvoke if isDecodeInvoke(invoke) =>
      // The first two arguments are the binary input and the charset; the two trailing
      // boolean arguments are ignored here (NOTE(review): their meaning is not visible in
      // this file — confirm they are safe to drop).
      val Seq(binaryArg, charsetArg, _, _) = invoke.arguments
      stringDecode(expr, charsetArg, binaryArg, inputs, binding)

    case _ =>
      None
  }
}