Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1716,6 +1716,12 @@ trait CometHashJoin {
return None
}

val joinKeys = join.leftKeys ++ join.rightKeys
if (joinKeys.exists(key => isStringCollationType(key.dataType))) {
withInfo(join, "unsupported non-default collated string join keys")
return None
}

val condition = join.condition.map { cond =>
val condProto = exprToProto(cond, join.left.output ++ join.right.output)
if (condProto.isEmpty) {
Expand Down Expand Up @@ -1757,7 +1763,7 @@ trait CometHashJoin {
condition.foreach(joinBuilder.setCondition)
Some(builder.setHashJoin(joinBuilder).build())
} else {
val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
val allExprs: Seq[Expression] = joinKeys
withInfo(join, allExprs: _*)
None
}
Expand Down Expand Up @@ -2078,8 +2084,14 @@ object CometSortMergeJoinExec extends CometOperatorSerde[SortMergeJoinExec] {
}
}

val joinKeys = join.leftKeys ++ join.rightKeys
if (joinKeys.exists(key => isStringCollationType(key.dataType))) {
withInfo(join, "unsupported non-default collated string join keys")
return None
}

// Checks if the join keys are supported by DataFusion SortMergeJoin.
val errorMsgs = join.leftKeys.flatMap { key =>
val errorMsgs = joinKeys.flatMap { key =>
if (!supportedSortMergeJoinEqualType(key.dataType)) {
Some(s"Unsupported join key type ${key.dataType} on key: ${key.sql}")
} else {
Expand Down Expand Up @@ -2111,7 +2123,7 @@ object CometSortMergeJoinExec extends CometOperatorSerde[SortMergeJoinExec] {
condition.map(joinBuilder.setCondition)
Some(builder.setSortMergeJoin(joinBuilder).build())
} else {
val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
val allExprs: Seq[Expression] = joinKeys
withInfo(join, allExprs: _*)
None
}
Expand All @@ -2136,6 +2148,7 @@ object CometSortMergeJoinExec extends CometOperatorSerde[SortMergeJoinExec] {
* Returns true if given datatype is supported as a key in DataFusion sort merge join.
*/
private def supportedSortMergeJoinEqualType(dataType: DataType): Boolean = dataType match {
case st: StringType if isStringCollationType(st) => false
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType =>
true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.comet.{CometBroadcastHashJoinExec, CometHashJoinExec, CometSortMergeJoinExec}
import org.apache.spark.sql.execution.LocalTableScanExec
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.types.StringType

import org.apache.comet.{CometConf, CometExplainInfo}
import org.apache.comet.serde.OperatorOuterClass

class CometCollationSuite extends CometTestBase {

// Queries that group, sort, or shuffle on a non-default collated string must fall back to
Expand All @@ -29,6 +41,8 @@ class CometCollationSuite extends CometTestBase {
"unsupported hash partitioning data type for columnar shuffle"
private val rangeShuffleCollationReason =
"unsupported range partitioning data type for columnar shuffle"
private val joinKeyCollationReason =
"unsupported non-default collated string join keys"

test("listagg DISTINCT with utf8_lcase collation (issue #1947)") {
checkSparkAnswerAndFallbackReason(
Expand Down Expand Up @@ -66,4 +80,171 @@ class CometCollationSuite extends CometTestBase {
checkSparkAnswerAndOperator("SELECT DISTINCT _1 FROM tbl ORDER BY _1")
}
}

// ---- Join collation guards (issue #4051) ----------------------------------------
//
// Comet's native join compares keys byte-by-byte, so 'a' and 'A' would not match
// under utf8_lcase, producing wrong results. The converters must reject any join
// whose keys carry a non-default collation.
//
// End-to-end SQL cannot reach the join converter today: higher-level guards
// (CometScanRule, Collate-expression serialization, #4035 shuffle guard) short-circuit
// first. The tests below bypass those guards by constructing physical-plan operators
// directly and calling convert() — the contract is that convert() returns None for
// collated keys.

/** Builds a non-nullable string attribute whose type carries the UTF8_LCASE collation. */
private def collatedKey(name: String): AttributeReference = {
  val lcaseStringType = StringType("UTF8_LCASE")
  AttributeReference(name, lcaseStringType, nullable = false)()
}

/** An empty native operator proto used as a stand-in child when invoking convert() directly. */
private def placeholderChildOp(): OperatorOuterClass.Operator = {
  val emptyBuilder = OperatorOuterClass.Operator.newBuilder()
  emptyBuilder.build()
}

// Ensure converters are on so that None from convert() means the collation guard fired,
// not that the join type is disabled.
private def withJoinConvertersEnabled(f: => Unit): Unit = {
  val converterFlags = Seq(
    CometConf.COMET_EXEC_HASH_JOIN_ENABLED,
    CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED,
    CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED)
  withSQLConf(converterFlags.map(_.key -> "true"): _*)(f)
}

/** Asserts that `plan` was tagged with the given Comet fallback reason via EXTENSION_INFO. */
private def assertFallbackReason(plan: SparkPlan, expectedReason: String): Unit = {
  val reasons =
    plan.getTagValue(CometExplainInfo.EXTENSION_INFO) match {
      case Some(tagged) => tagged
      case None => Set.empty[String]
    }
  val found = reasons.contains(expectedReason)
  assert(found, s"Expected fallback reason '$expectedReason' on ${plan.nodeName}, got: $reasons")
}

test("CometBroadcastHashJoinExec rejects non-default collated join keys") {
  withJoinConvertersEnabled {
    // Both sides join on a UTF8_LCASE string key; convert() must refuse to serialize.
    val lKey = collatedKey("l")
    val rKey = collatedKey("r")
    val join = BroadcastHashJoinExec(
      leftKeys = Seq(lKey),
      rightKeys = Seq(rKey),
      joinType = Inner,
      buildSide = BuildRight,
      condition = None,
      left = LocalTableScanExec(Seq(lKey), Nil, None),
      right = LocalTableScanExec(Seq(rKey), Nil, None))

    val result = CometBroadcastHashJoinExec.convert(
      join,
      OperatorOuterClass.Operator.newBuilder(),
      placeholderChildOp(),
      placeholderChildOp())

    assert(
      result.isEmpty,
      "CometBroadcastHashJoinExec.convert must reject non-default collated join keys " +
        "(issue #4051): native byte equality cannot match values that compare equal " +
        "under utf8_lcase. Got a non-empty proto: " + result)
    assertFallbackReason(join, joinKeyCollationReason)
  }
}

test("CometHashJoinExec rejects non-default collated join keys") {
  withJoinConvertersEnabled {
    // Shuffled hash join variant of the same guard: collated keys must fall back.
    val lKey = collatedKey("l")
    val rKey = collatedKey("r")
    val join = ShuffledHashJoinExec(
      leftKeys = Seq(lKey),
      rightKeys = Seq(rKey),
      joinType = Inner,
      buildSide = BuildLeft,
      condition = None,
      left = LocalTableScanExec(Seq(lKey), Nil, None),
      right = LocalTableScanExec(Seq(rKey), Nil, None))

    val result = CometHashJoinExec.convert(
      join,
      OperatorOuterClass.Operator.newBuilder(),
      placeholderChildOp(),
      placeholderChildOp())

    assert(
      result.isEmpty,
      "CometHashJoinExec.convert must reject non-default collated join keys (issue " +
        "#4051): native byte equality cannot match values that compare equal under " +
        "utf8_lcase. Got a non-empty proto: " + result)
    assertFallbackReason(join, joinKeyCollationReason)
  }
}

test("CometBroadcastHashJoinExec still accepts default UTF8_BINARY string keys") {
  withJoinConvertersEnabled {
    // Negative control: plain StringType (default UTF8_BINARY) keys must still serialize.
    val lKey = AttributeReference("l", StringType, nullable = false)()
    val rKey = AttributeReference("r", StringType, nullable = false)()
    val join = BroadcastHashJoinExec(
      leftKeys = Seq(lKey),
      rightKeys = Seq(rKey),
      joinType = Inner,
      buildSide = BuildRight,
      condition = None,
      left = LocalTableScanExec(Seq(lKey), Nil, None),
      right = LocalTableScanExec(Seq(rKey), Nil, None))

    val result = CometBroadcastHashJoinExec.convert(
      join,
      OperatorOuterClass.Operator.newBuilder(),
      placeholderChildOp(),
      placeholderChildOp())

    assert(
      result.isDefined,
      "CometBroadcastHashJoinExec.convert must continue to accept default UTF8_BINARY " +
        "string keys; the collation guard for #4051 must not over-block.")
  }
}

test("CometSortMergeJoinExec rejects non-default collated join keys") {
  withJoinConvertersEnabled {
    // Sort-merge join variant: the key-type check must also account for collation.
    val lKey = collatedKey("l")
    val rKey = collatedKey("r")
    val join = SortMergeJoinExec(
      leftKeys = Seq(lKey),
      rightKeys = Seq(rKey),
      joinType = Inner,
      condition = None,
      left = LocalTableScanExec(Seq(lKey), Nil, None),
      right = LocalTableScanExec(Seq(rKey), Nil, None))

    val result = CometSortMergeJoinExec.convert(
      join,
      OperatorOuterClass.Operator.newBuilder(),
      placeholderChildOp(),
      placeholderChildOp())

    assert(
      result.isEmpty,
      "CometSortMergeJoinExec.convert must reject non-default collated join keys " +
        "(issue #4051): supportedSortMergeJoinEqualType must check collation. Got a " +
        "non-empty proto: " + result)
    assertFallbackReason(join, joinKeyCollationReason)
  }
}

test("CometSortMergeJoinExec still accepts default UTF8_BINARY string keys") {
  withJoinConvertersEnabled {
    // Negative control for SMJ: default-collation string keys remain convertible.
    val lKey = AttributeReference("l", StringType, nullable = false)()
    val rKey = AttributeReference("r", StringType, nullable = false)()
    val join = SortMergeJoinExec(
      leftKeys = Seq(lKey),
      rightKeys = Seq(rKey),
      joinType = Inner,
      condition = None,
      left = LocalTableScanExec(Seq(lKey), Nil, None),
      right = LocalTableScanExec(Seq(rKey), Nil, None))

    val result = CometSortMergeJoinExec.convert(
      join,
      OperatorOuterClass.Operator.newBuilder(),
      placeholderChildOp(),
      placeholderChildOp())

    assert(
      result.isDefined,
      "CometSortMergeJoinExec.convert must continue to accept default UTF8_BINARY " +
        "string keys; the collation guard for #4051 must not over-block.")
  }
}
}
Loading