diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 1b72521270..e68fab7669 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -19,8 +19,6 @@ package org.apache.comet.serde -import java.util.Locale - import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer @@ -1487,30 +1485,6 @@ object QueryPlanSerde extends Logging with CometExprShim { }) } - def stringDecode( - expr: Expression, - charset: Expression, - bin: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - charset match { - case Literal(str, DataTypes.StringType) - if str.toString.toLowerCase(Locale.ROOT) == "utf-8" => - // decode(col, 'utf-8') can be treated as a cast with "try" eval mode that puts nulls - // for invalid strings. - // Left child is the binary expression. - castToProto( - expr, - None, - DataTypes.StringType, - exprToProtoInternal(bin, inputs, binding).get, - CometEvalMode.TRY) - case _ => - withInfo(expr, "Comet only supports decoding with 'utf-8'.") - None - } - } - /** * Creates a UnaryExpr by calling exprToProtoInternal for the provided child expression and then * invokes the supplied function to wrap this UnaryExpr in a top-level Expr. diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 75e7e8bd4c..39547c5666 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -19,13 +19,16 @@ package org.apache.comet.serde +import java.util.Locale + import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, Like, Literal, RLike, StringRPad, Substring} import org.apache.spark.sql.types.{DataTypes, LongType, StringType} import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.expressions.CometEvalMode import org.apache.comet.serde.ExprOuterClass.Expr -import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} +import org.apache.comet.serde.QueryPlanSerde.{castToProto, createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} object CometStringRepeat extends CometExpressionSerde { @@ -180,3 +183,30 @@ object CometStringRPad extends CometExpressionSerde { } } } + +trait CommonStringExprs { + + def stringDecode( + expr: Expression, + charset: Expression, + bin: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + charset match { + case Literal(str, DataTypes.StringType) + if str.toString.toLowerCase(Locale.ROOT) == "utf-8" => + // decode(col, 'utf-8') can be treated as a cast with "try" eval mode that puts nulls + // for invalid strings. + // Left child is the binary expression. + castToProto( + expr, + None, + DataTypes.StringType, + exprToProtoInternal(bin, inputs, binding).get, + CometEvalMode.TRY) + case _ => + withInfo(expr, "Comet only supports decoding with 'utf-8'.") + None + } + } +} diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala index 2a302d8d41..ca53efc3db 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala @@ -19,14 +19,14 @@ package org.apache.comet.shims import org.apache.comet.expressions.CometEvalMode +import org.apache.comet.serde.CommonStringExprs import org.apache.comet.serde.ExprOuterClass.Expr -import org.apache.comet.serde.QueryPlanSerde.stringDecode import org.apache.spark.sql.catalyst.expressions._ /** * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. */ -trait CometExprShim { +trait CometExprShim extends CommonStringExprs { /** * Returns a tuple of expressions for the `unhex` function. */ diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala index 2a302d8d41..ca53efc3db 100644 --- a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala @@ -19,14 +19,14 @@ package org.apache.comet.shims import org.apache.comet.expressions.CometEvalMode +import org.apache.comet.serde.CommonStringExprs import org.apache.comet.serde.ExprOuterClass.Expr -import org.apache.comet.serde.QueryPlanSerde.stringDecode import org.apache.spark.sql.catalyst.expressions._ /** * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. */ -trait CometExprShim { +trait CometExprShim extends CommonStringExprs { /** * Returns a tuple of expressions for the `unhex` function. */ diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 1b8e5aaa04..ddd53d6d8d 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -19,8 +19,8 @@ package org.apache.comet.shims import org.apache.comet.expressions.CometEvalMode +import org.apache.comet.serde.CommonStringExprs import org.apache.comet.serde.ExprOuterClass.Expr -import org.apache.comet.serde.QueryPlanSerde.stringDecode import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.internal.types.StringTypeWithCollation @@ -29,7 +29,7 @@ import org.apache.spark.sql.types.{BinaryType, BooleanType, StringType} /** * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. */ -trait CometExprShim { +trait CometExprShim extends CommonStringExprs { /** * Returns a tuple of expressions for the `unhex` function. */