From c27729a95a7d33ab5118fdbc43d0b81cf5b859f4 Mon Sep 17 00:00:00 2001 From: 0lai0 Date: Sun, 26 Apr 2026 22:17:27 +0800 Subject: [PATCH] reject non-default collated string join keys in Comet hash join and sort-merge join --- .../apache/spark/sql/comet/operators.scala | 19 +- .../spark/sql/CometCollationSuite.scala | 181 ++++++++++++++++++ 2 files changed, 197 insertions(+), 3 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 5f7e91529d..85434b8674 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1716,6 +1716,12 @@ trait CometHashJoin { return None } + val joinKeys = join.leftKeys ++ join.rightKeys + if (joinKeys.exists(key => isStringCollationType(key.dataType))) { + withInfo(join, "unsupported non-default collated string join keys") + return None + } + val condition = join.condition.map { cond => val condProto = exprToProto(cond, join.left.output ++ join.right.output) if (condProto.isEmpty) { @@ -1757,7 +1763,7 @@ trait CometHashJoin { condition.foreach(joinBuilder.setCondition) Some(builder.setHashJoin(joinBuilder).build()) } else { - val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys + val allExprs: Seq[Expression] = joinKeys withInfo(join, allExprs: _*) None } @@ -2078,8 +2084,14 @@ object CometSortMergeJoinExec extends CometOperatorSerde[SortMergeJoinExec] { } } + val joinKeys = join.leftKeys ++ join.rightKeys + if (joinKeys.exists(key => isStringCollationType(key.dataType))) { + withInfo(join, "unsupported non-default collated string join keys") + return None + } + // Checks if the join keys are supported by DataFusion SortMergeJoin. - val errorMsgs = join.leftKeys.flatMap { key => + val errorMsgs = joinKeys.flatMap { key => if (!supportedSortMergeJoinEqualType(key.dataType)) { Some(s"Unsupported join key type ${key.dataType} on key: ${key.sql}") } else { @@ -2111,7 +2123,7 @@ object CometSortMergeJoinExec extends CometOperatorSerde[SortMergeJoinExec] { condition.map(joinBuilder.setCondition) Some(builder.setSortMergeJoin(joinBuilder).build()) } else { - val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys + val allExprs: Seq[Expression] = joinKeys withInfo(join, allExprs: _*) None } @@ -2136,6 +2148,7 @@ object CometSortMergeJoinExec extends CometOperatorSerde[SortMergeJoinExec] { * Returns true if given datatype is supported as a key in DataFusion sort merge join. */ private def supportedSortMergeJoinEqualType(dataType: DataType): Boolean = dataType match { + case st: StringType if isStringCollationType(st) => false case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType => true diff --git a/spark/src/test/spark-4.0/org/apache/spark/sql/CometCollationSuite.scala b/spark/src/test/spark-4.0/org/apache/spark/sql/CometCollationSuite.scala index 463e169b66..4623f4591e 100644 --- a/spark/src/test/spark-4.0/org/apache/spark/sql/CometCollationSuite.scala +++ b/spark/src/test/spark-4.0/org/apache/spark/sql/CometCollationSuite.scala @@ -19,6 +19,18 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.comet.{CometBroadcastHashJoinExec, CometHashJoinExec, CometSortMergeJoinExec} +import org.apache.spark.sql.execution.LocalTableScanExec +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.types.StringType + +import org.apache.comet.{CometConf, CometExplainInfo} +import org.apache.comet.serde.OperatorOuterClass + class CometCollationSuite extends CometTestBase { // Queries that group, sort, or shuffle on a non-default collated string must fall back to @@ -29,6 +41,8 @@ class CometCollationSuite extends CometTestBase { "unsupported hash partitioning data type for columnar shuffle" private val rangeShuffleCollationReason = "unsupported range partitioning data type for columnar shuffle" + private val joinKeyCollationReason = + "unsupported non-default collated string join keys" test("listagg DISTINCT with utf8_lcase collation (issue #1947)") { checkSparkAnswerAndFallbackReason( @@ -66,4 +80,171 @@ class CometCollationSuite extends CometTestBase { checkSparkAnswerAndOperator("SELECT DISTINCT _1 FROM tbl ORDER BY _1") } } + + // ---- Join collation guards (issue #4051) ---------------------------------------- + // + // Comet's native join compares keys byte-by-byte, so 'a' and 'A' would not match + // under utf8_lcase, producing wrong results. The converters must reject any join + // whose keys carry a non-default collation. + // + // End-to-end SQL cannot reach the join converter today: higher-level guards + // (CometScanRule, Collate-expression serialization, #4035 shuffle guard) short-circuit + // first. The tests below bypass those guards by constructing physical-plan operators + // directly and calling convert() — the contract is that convert() returns None for + // collated keys. + + private def collatedKey(name: String): AttributeReference = + AttributeReference(name, StringType("UTF8_LCASE"), nullable = false)() + + private def placeholderChildOp(): OperatorOuterClass.Operator = + OperatorOuterClass.Operator.newBuilder().build() + + // Ensure converters are on so that None from convert() means the collation guard fired, + // not that the join type is disabled. + private def withJoinConvertersEnabled(f: => Unit): Unit = + withSQLConf( + CometConf.COMET_EXEC_HASH_JOIN_ENABLED.key -> "true", + CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.key -> "true", + CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.key -> "true") { + f + } + + private def assertFallbackReason(plan: SparkPlan, expectedReason: String): Unit = { + val reasons = plan.getTagValue(CometExplainInfo.EXTENSION_INFO).getOrElse(Set.empty[String]) + assert( + reasons.contains(expectedReason), + s"Expected fallback reason '$expectedReason' on ${plan.nodeName}, got: $reasons") + } + + test("CometBroadcastHashJoinExec rejects non-default collated join keys") { + withJoinConvertersEnabled { + val left = collatedKey("l") + val right = collatedKey("r") + val join = BroadcastHashJoinExec( + leftKeys = Seq(left), + rightKeys = Seq(right), + joinType = Inner, + buildSide = BuildRight, + condition = None, + left = LocalTableScanExec(Seq(left), Nil, None), + right = LocalTableScanExec(Seq(right), Nil, None)) + + val builder = OperatorOuterClass.Operator.newBuilder() + val result = + CometBroadcastHashJoinExec.convert( + join, + builder, + placeholderChildOp(), + placeholderChildOp()) + + assert( + result.isEmpty, + "CometBroadcastHashJoinExec.convert must reject non-default collated join keys " + + "(issue #4051): native byte equality cannot match values that compare equal " + + "under utf8_lcase. Got a non-empty proto: " + result) + assertFallbackReason(join, joinKeyCollationReason) + } + } + + test("CometHashJoinExec rejects non-default collated join keys") { + withJoinConvertersEnabled { + val left = collatedKey("l") + val right = collatedKey("r") + val join = ShuffledHashJoinExec( + leftKeys = Seq(left), + rightKeys = Seq(right), + joinType = Inner, + buildSide = BuildLeft, + condition = None, + left = LocalTableScanExec(Seq(left), Nil, None), + right = LocalTableScanExec(Seq(right), Nil, None)) + + val builder = OperatorOuterClass.Operator.newBuilder() + val result = + CometHashJoinExec.convert(join, builder, placeholderChildOp(), placeholderChildOp()) + + assert( + result.isEmpty, + "CometHashJoinExec.convert must reject non-default collated join keys (issue " + + "#4051): native byte equality cannot match values that compare equal under " + + "utf8_lcase. Got a non-empty proto: " + result) + assertFallbackReason(join, joinKeyCollationReason) + } + } + + test("CometBroadcastHashJoinExec still accepts default UTF8_BINARY string keys") { + withJoinConvertersEnabled { + val left = AttributeReference("l", StringType, nullable = false)() + val right = AttributeReference("r", StringType, nullable = false)() + val join = BroadcastHashJoinExec( + leftKeys = Seq(left), + rightKeys = Seq(right), + joinType = Inner, + buildSide = BuildRight, + condition = None, + left = LocalTableScanExec(Seq(left), Nil, None), + right = LocalTableScanExec(Seq(right), Nil, None)) + + val builder = OperatorOuterClass.Operator.newBuilder() + val result = + CometBroadcastHashJoinExec.convert( + join, + builder, + placeholderChildOp(), + placeholderChildOp()) + + assert( + result.isDefined, + "CometBroadcastHashJoinExec.convert must continue to accept default UTF8_BINARY " + + "string keys; the collation guard for #4051 must not over-block.") + } + } + + test("CometSortMergeJoinExec rejects non-default collated join keys") { + withJoinConvertersEnabled { + val left = collatedKey("l") + val right = collatedKey("r") + val join = SortMergeJoinExec( + leftKeys = Seq(left), + rightKeys = Seq(right), + joinType = Inner, + condition = None, + left = LocalTableScanExec(Seq(left), Nil, None), + right = LocalTableScanExec(Seq(right), Nil, None)) + + val builder = OperatorOuterClass.Operator.newBuilder() + val result = + CometSortMergeJoinExec.convert(join, builder, placeholderChildOp(), placeholderChildOp()) + + assert( + result.isEmpty, + "CometSortMergeJoinExec.convert must reject non-default collated join keys " + + "(issue #4051): supportedSortMergeJoinEqualType must check collation. Got a " + + "non-empty proto: " + result) + assertFallbackReason(join, joinKeyCollationReason) + } + } + + test("CometSortMergeJoinExec still accepts default UTF8_BINARY string keys") { + withJoinConvertersEnabled { + val left = AttributeReference("l", StringType, nullable = false)() + val right = AttributeReference("r", StringType, nullable = false)() + val join = SortMergeJoinExec( + leftKeys = Seq(left), + rightKeys = Seq(right), + joinType = Inner, + condition = None, + left = LocalTableScanExec(Seq(left), Nil, None), + right = LocalTableScanExec(Seq(right), Nil, None)) + + val builder = OperatorOuterClass.Operator.newBuilder() + val result = + CometSortMergeJoinExec.convert(join, builder, placeholderChildOp(), placeholderChildOp()) + + assert( + result.isDefined, + "CometSortMergeJoinExec.convert must continue to accept default UTF8_BINARY " + + "string keys; the collation guard for #4051 must not over-block.") + } + } }