Merged
45 changes: 25 additions & 20 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -630,7 +630,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
}
}

expr match {
versionSpecificExprToProtoInternal(expr, inputs, binding).orElse(expr match {
case a @ Alias(_, _) =>
val r = exprToProtoInternal(a.child, inputs, binding)
if (r.isEmpty) {
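The key change in this hunk: `exprToProtoInternal` now consults a version-specific hook before falling back to the shared match. A minimal, self-contained sketch of the `Option.orElse` dispatch pattern (all names besides `orElse` are illustrative stand-ins, not the real Comet types):

```scala
// Minimal sketch of the Option.orElse dispatch added above.
// All names here are illustrative stand-ins, not the real Comet types.
object OrElseDispatchSketch {
  final case class Proto(name: String)

  // Per-version hook: returns Some(..) only for expressions that this Spark
  // version serializes differently ("decode" stands in for StringDecode).
  def versionSpecific(expr: String): Option[Proto] =
    if (expr == "decode") Some(Proto("string_decode")) else None

  // Shared fallback covering the version-independent cases.
  def shared(expr: String): Option[Proto] =
    if (expr == "alias") Some(Proto("alias")) else None

  // orElse takes its argument by name, so the shared match only runs when
  // the version-specific hook declines to handle the expression.
  def toProto(expr: String): Option[Proto] =
    versionSpecific(expr).orElse(shared(expr))

  def main(args: Array[String]): Unit = {
    println(toProto("decode")) // Some(Proto(string_decode)) -- handled by the shim
    println(toProto("alias"))  // Some(Proto(alias)) -- falls through to shared
    println(toProto("other"))  // None -- unsupported
  }
}
```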
@@ -1284,25 +1284,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
optExprWithInfo(optExpr, expr, r.child)
}

case s: StringDecode =>
// Right child is the encoding expression.
s.charset match {
case Literal(str, DataTypes.StringType)
if str.toString.toLowerCase(Locale.ROOT) == "utf-8" =>
// decode(col, 'utf-8') can be treated as a cast with "try" eval mode that puts nulls
// for invalid strings.
// Left child is the binary expression.
castToProto(
expr,
None,
DataTypes.StringType,
exprToProtoInternal(s.bin, inputs, binding).get,
CometEvalMode.TRY)
case _ =>
withInfo(expr, "Comet only supports decoding with 'utf-8'.")
None
}

case RegExpReplace(subject, pattern, replacement, startPosition) =>
if (!RegExp.isSupportedPattern(pattern.toString) &&
!CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) {
@@ -1679,6 +1660,30 @@ object QueryPlanSerde extends Logging with CometExprShim {
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
None
}
})
}

[Review thread on def stringDecode]
Contributor: @andygrove Is there somewhere outside of QueryPlanSerde that makes more sense for this, as you've recently been refactoring serde logic?
Contributor: maybe strings.scala?
Member: Yes, strings.scala makes sense. I think it would be fine to do this refactor as a separate PR.

def stringDecode(

expr: Expression,
charset: Expression,
bin: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[Expr] = {
charset match {
case Literal(str, DataTypes.StringType)
if str.toString.toLowerCase(Locale.ROOT) == "utf-8" =>
// decode(col, 'utf-8') can be treated as a cast with "try" eval mode that puts nulls
// for invalid strings.
// Left child is the binary expression.
castToProto(
expr,
None,
DataTypes.StringType,
exprToProtoInternal(bin, inputs, binding).get,
CometEvalMode.TRY)
case _ =>
withInfo(expr, "Comet only supports decoding with 'utf-8'.")
None
}
}
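As the comment above notes, `decode(col, 'utf-8')` is serialized as a cast with TRY eval mode, producing nulls for invalid input. A self-contained sketch of that "try" decode semantics in plain Scala (`tryDecodeUtf8` is a hypothetical helper, not Comet code):

```scala
import java.nio.ByteBuffer
import java.nio.charset.{CharacterCodingException, StandardCharsets}

// Sketch of the semantics the helper relies on: decoding invalid UTF-8
// yields null (None here) instead of failing, which is why
// decode(col, 'utf-8') can be serialized as a TRY-mode cast to string.
object TryDecodeSketch {
  def tryDecodeUtf8(bytes: Array[Byte]): Option[String] =
    try {
      // newDecoder() reports malformed input instead of silently replacing it.
      Some(StandardCharsets.UTF_8.newDecoder().decode(ByteBuffer.wrap(bytes)).toString)
    } catch {
      case _: CharacterCodingException => None // invalid UTF-8 -> null
    }

  def main(args: Array[String]): Unit = {
    println(tryDecodeUtf8("comet".getBytes(StandardCharsets.UTF_8))) // Some(comet)
    println(tryDecodeUtf8(Array(0xC3.toByte)))                       // None (truncated sequence)
  }
}
```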

@@ -19,6 +19,8 @@
package org.apache.comet.shims

import org.apache.comet.expressions.CometEvalMode
import org.apache.comet.serde.ExprOuterClass.Expr
import org.apache.comet.serde.QueryPlanSerde.stringDecode
import org.apache.spark.sql.catalyst.expressions._

/**
@@ -34,6 +36,19 @@ trait CometExprShim {

protected def evalMode(c: Cast): CometEvalMode.Value =
CometEvalModeUtil.fromSparkEvalMode(c.evalMode)

def versionSpecificExprToProtoInternal(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[Expr] = {
expr match {
case s: StringDecode =>
// Right child is the encoding expression.
stringDecode(expr, s.charset, s.bin, inputs, binding)

case _ => None
}
}
}

object CometEvalModeUtil {
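For context on how this trait is consumed: the build compiles one CometExprShim source file per Spark profile, and QueryPlanSerde mixes it in, so the dispatch is resolved at compile time. The identical shim in the next file presumably belongs to another pre-4.0 Spark profile (an assumption from context, since the filenames are not shown here). A simplified sketch of that wiring (illustrative names only):

```scala
// Simplified sketch of the per-version shim wiring (illustrative names only;
// in Comet the build selects one CometExprShim source file per Spark profile,
// so the mix-in below is resolved at compile time, not via runtime checks).
object ShimWiringSketch {
  trait ExprShim {
    def versionSpecificExprToProto(expr: String): Option[String]
  }

  // Stand-in for a pre-4.0 shim: StringDecode appears directly in the plan.
  trait Spark3Shim extends ExprShim {
    override def versionSpecificExprToProto(expr: String): Option[String] =
      expr match {
        case "StringDecode" => Some("string_decode proto")
        case _              => None
      }
  }

  // Stand-in for QueryPlanSerde: exactly one shim trait is mixed in.
  object Serde extends Spark3Shim {
    def exprToProto(expr: String): Option[String] =
      versionSpecificExprToProto(expr).orElse(Some(s"shared handling of $expr"))
  }

  def main(args: Array[String]): Unit = {
    println(Serde.exprToProto("StringDecode")) // Some(string_decode proto)
    println(Serde.exprToProto("Alias"))        // Some(shared handling of Alias)
  }
}
```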
@@ -19,6 +19,8 @@
package org.apache.comet.shims

import org.apache.comet.expressions.CometEvalMode
import org.apache.comet.serde.ExprOuterClass.Expr
import org.apache.comet.serde.QueryPlanSerde.stringDecode
import org.apache.spark.sql.catalyst.expressions._

/**
@@ -34,6 +36,19 @@ trait CometExprShim {

protected def evalMode(c: Cast): CometEvalMode.Value =
CometEvalModeUtil.fromSparkEvalMode(c.evalMode)

def versionSpecificExprToProtoInternal(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[Expr] = {
expr match {
case s: StringDecode =>
// Right child is the encoding expression.
stringDecode(expr, s.charset, s.bin, inputs, binding)

case _ => None
}
}
}

object CometEvalModeUtil {
@@ -19,7 +19,12 @@
package org.apache.comet.shims

import org.apache.comet.expressions.CometEvalMode
import org.apache.comet.serde.ExprOuterClass.Expr
import org.apache.comet.serde.QueryPlanSerde.stringDecode
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types.{BinaryType, BooleanType, StringType}

/**
* `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
@@ -34,6 +39,28 @@ trait CometExprShim {

protected def evalMode(c: Cast): CometEvalMode.Value =
CometEvalModeUtil.fromSparkEvalMode(c.evalMode)

def versionSpecificExprToProtoInternal(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[Expr] = {
expr match {
case s: StaticInvoke
if s.staticObject == classOf[StringDecode] &&
s.dataType.isInstanceOf[StringType] &&
s.functionName == "decode" &&
s.arguments.size == 4 &&
s.inputTypes == Seq(
BinaryType,
StringTypeWithCollation(supportsTrimCollation = true),
BooleanType,
BooleanType) =>
val Seq(bin, charset, _, _) = s.arguments
stringDecode(expr, charset, bin, inputs, binding)

case _ => None
}
}
}

object CometEvalModeUtil {
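On Spark 4.0, decode no longer shows up in the plan as a StringDecode expression; as the match above shows, it arrives as a StaticInvoke of StringDecode's decode method with four arguments, of which only the binary and charset arguments matter to Comet. A simplified model of that guarded match and destructure (stand-in types, not Spark's actual StaticInvoke):

```scala
// Simplified model of the StaticInvoke match above (hypothetical stand-ins;
// the real Spark 4.0 StaticInvoke carries expressions and input types, and the
// two trailing arguments are the legacyCharsets / legacyErrorAction flags).
object StaticInvokeMatchSketch {
  final class StringDecode // stand-in for Spark's StringDecode

  final case class StaticInvoke(
      staticObject: Class[_],
      functionName: String,
      arguments: Seq[String])

  def matchDecode(expr: StaticInvoke): Option[(String, String)] = expr match {
    // Guarded match mirroring the shim: right class, right method name, and
    // exactly four arguments (bin, charset, plus two flags we ignore).
    case StaticInvoke(obj, "decode", args)
        if obj == classOf[StringDecode] && args.size == 4 =>
      val Seq(bin, charset, _, _) = args
      Some((bin, charset))
    case _ => None
  }

  def main(args: Array[String]): Unit = {
    val si = StaticInvoke(
      classOf[StringDecode],
      "decode",
      Seq("binCol", "'utf-8'", "legacyCharsets", "legacyErrorAction"))
    println(matchDecode(si)) // Some((binCol,'utf-8'))
  }
}
```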
@@ -38,7 +38,6 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
import org.apache.spark.sql.types._

import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}

class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper {
@@ -272,8 +271,6 @@ class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

test("decode") {
// https://github.com/apache/datafusion-comet/issues/1942
assume(!isSpark40Plus)
val df = spark.read.parquet(filename)
df.createOrReplaceTempView("t1")
// We want to make sure that the schema generator wasn't modified to accidentally omit
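With the shims above in place, decode is supported on Spark 4.0, so the test drops the issue link and the assume(!isSpark40Plus) guard. A hedged sketch of the kind of query this test might run (the column name c1 and the cast are assumptions; the real suite derives columns from the generated Parquet schema):

```scala
// Inside a CometTestBase test, where `spark` and `filename` are in scope.
// c1 and the cast to binary are illustrative, not the suite's actual query.
val df = spark.read.parquet(filename)
df.createOrReplaceTempView("t1")
spark.sql("SELECT decode(cast(c1 AS binary), 'utf-8') FROM t1").show()
```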