From ac12debe304720495392a6b53d03814ac650785f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 17 Apr 2026 10:19:15 -0600 Subject: [PATCH 1/2] test: add fuzz fallback + canonicalization suites for #3949 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a diagnostic harness that randomly vetoes Comet shuffle and operator conversions so the rule pipeline produces irregular Spark/Comet boundaries, plus tests that probe canonicalization of Comet plans. None of the suites reproduce #3949 yet — the bug requires post-stage-materialization plan shapes that don't appear in the initial physical plan — but the infrastructure is useful as a general regression guard and documents the hypotheses we've ruled out. Changes: - FuzzFallback: seeded, node-identity-keyed decision cache. - Inject points: CometShuffleExchangeExec.{nativeShuffleSupported, columnarShuffleSupported} and the generic operator branch of CometExecRule.convertNode. Four new testing-category configs, all default-off. - CometFuzzFallbackSuite: 129 TPC-DS queries x N seeds, plan-only (no data required). - CometFuzzDppSuite: small partitioned fact/dim data, actually executes DPP-flavored queries across SMJ/BHJ/coalesce variants with AQE on. - CometCanonicalizationSuite / CometCanonicalizationTpcdsSuite: walk each Comet plan node and assert p.canonicalized.supportsColumnar == true, and that ColumnarToRowExec(p).canonicalized does not throw (mirrors the #3949 stack trace). 
--- .../scala/org/apache/comet/CometConf.scala | 41 +++ .../scala/org/apache/comet/FuzzFallback.scala | 89 +++++++ .../apache/comet/rules/CometExecRule.scala | 6 +- .../shuffle/CometShuffleExchangeExec.scala | 12 +- .../comet/CometCanonicalizationSuite.scala | 242 ++++++++++++++++++ .../CometCanonicalizationTpcdsSuite.scala | 161 ++++++++++++ .../spark/sql/comet/CometFuzzDppSuite.scala | 159 ++++++++++++ .../sql/comet/CometFuzzFallbackSuite.scala | 168 ++++++++++++ 8 files changed, 876 insertions(+), 2 deletions(-) create mode 100644 spark/src/main/scala/org/apache/comet/FuzzFallback.scala create mode 100644 spark/src/test/scala/org/apache/spark/sql/comet/CometCanonicalizationSuite.scala create mode 100644 spark/src/test/scala/org/apache/spark/sql/comet/CometCanonicalizationTpcdsSuite.scala create mode 100644 spark/src/test/scala/org/apache/spark/sql/comet/CometFuzzDppSuite.scala create mode 100644 spark/src/test/scala/org/apache/spark/sql/comet/CometFuzzFallbackSuite.scala diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 32522bf51f..7eb14d6b23 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -561,6 +561,47 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(true) + val COMET_FUZZ_FALLBACK_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.fuzz.fallback.enabled") + .category(CATEGORY_TESTING) + .doc( + "Diagnostic: when enabled, Comet randomly vetoes converting shuffles/operators to " + + "Comet equivalents so the rule pipeline produces irregular Spark/Comet boundaries. " + + "Used to surface plan-shape bugs that are hard to trigger via normal queries. 
" + + "Decisions are deterministic given `spark.comet.fuzz.fallback.seed`.") + .booleanConf + .createWithDefault(false) + + val COMET_FUZZ_FALLBACK_SEED: ConfigEntry[Long] = + conf("spark.comet.fuzz.fallback.seed") + .category(CATEGORY_TESTING) + .doc("Seed for the fuzz fallback RNG. Same seed + same query reproduces the same pattern " + + "of forced fallbacks. Only used when `spark.comet.fuzz.fallback.enabled=true`.") + .longConf + .createWithDefault(0L) + + val COMET_FUZZ_FALLBACK_SHUFFLE_VETO_PROBABILITY: ConfigEntry[Double] = + conf("spark.comet.fuzz.fallback.shuffleVetoProbability") + .category(CATEGORY_TESTING) + .doc( + "Probability in [0.0, 1.0] that the fuzz fallback vetoes converting a given " + + "ShuffleExchangeExec to a CometShuffleExchangeExec. Only used when " + + "`spark.comet.fuzz.fallback.enabled=true`.") + .doubleConf + .checkValue(v => v >= 0.0 && v <= 1.0, "Probability must be in [0.0, 1.0]") + .createWithDefault(0.5) + + val COMET_FUZZ_FALLBACK_EXEC_VETO_PROBABILITY: ConfigEntry[Double] = + conf("spark.comet.fuzz.fallback.execVetoProbability") + .category(CATEGORY_TESTING) + .doc( + "Probability in [0.0, 1.0] that the fuzz fallback vetoes converting a given " + + "Spark operator (aggregate, join, project, etc.) to its Comet equivalent. " + + "Only used when `spark.comet.fuzz.fallback.enabled=true`.") + .doubleConf + .checkValue(v => v >= 0.0 && v <= 1.0, "Probability must be in [0.0, 1.0]") + .createWithDefault(0.0) + val COMET_DEBUG_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.debug.enabled") .category(CATEGORY_EXEC) diff --git a/spark/src/main/scala/org/apache/comet/FuzzFallback.scala b/spark/src/main/scala/org/apache/comet/FuzzFallback.scala new file mode 100644 index 0000000000..68cf4b8d64 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/FuzzFallback.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.sql.execution.SparkPlan + +/** + * Diagnostic utility that randomly vetoes Comet conversions so the rule pipeline produces + * irregular Spark/Comet boundaries. Used by fuzz tests to surface plan-shape bugs that arise when + * adjacent operators belong to different execution modes (e.g. the assertion failure described in + * issue #3949). + * + * Determinism: each decision is a pure function of the seed, the node identity hash, and the + * decision kind. This means repeated calls for the same node return the same answer (important + * because `getSupportLevel` and `createExec` are called at different times during rule + * application), and a failing seed can be reproduced by rerunning the test with the same + * configuration. + */ +object FuzzFallback { + + // Cache decisions per (kind, identityHashCode(plan)). The cache is cleared between queries via + // reset(); identity hash collisions within one query are astronomically unlikely. + private val decisions = new ConcurrentHashMap[(Int, Int), Boolean]() + + /** Reset cached decisions. Call this between queries so every query starts clean. 
*/ + def reset(): Unit = decisions.clear() + + private def decide(kind: Int, plan: SparkPlan, probability: Double): Boolean = { + if (probability <= 0.0) return false + val key = (kind, System.identityHashCode(plan)) + // Boolean values unbox a missing key's null to false, so gate on containsKey, not null. + if (decisions.containsKey(key)) return decisions.get(key) + val seed = CometConf.COMET_FUZZ_FALLBACK_SEED.get() + // Mix seed, kind, and node identity into a deterministic hash, then compare against the + // probability. Using SplitMix64-style avalanche gives a reasonable uniform distribution. + var h: Long = seed + h ^= kind.toLong * 0x9e3779b97f4a7c15L + h ^= System.identityHashCode(plan).toLong * 0xbf58476d1ce4e5b9L + h ^= h >>> 30 + h *= 0xbf58476d1ce4e5b9L + h ^= h >>> 27 + h *= 0x94d049bb133111ebL + h ^= h >>> 31 + // Map to [0.0, 1.0) + val u = (h >>> 11) * (1.0 / (1L << 53)) + val result = u < probability + decisions.put(key, result) + result + } + + /** + * Decide whether to veto converting this shuffle exchange to a Comet shuffle. Returns false + * unless fuzz fallback is enabled. When enabled, returns true with probability + * `spark.comet.fuzz.fallback.shuffleVetoProbability`. + */ + def shouldVetoShuffle(plan: SparkPlan): Boolean = { + if (!CometConf.COMET_FUZZ_FALLBACK_ENABLED.get()) false + else decide(1, plan, CometConf.COMET_FUZZ_FALLBACK_SHUFFLE_VETO_PROBABILITY.get()) + } + + /** + * Decide whether to veto converting this operator to a Comet equivalent. Returns false unless + * fuzz fallback is enabled. When enabled, returns true with probability + * `spark.comet.fuzz.fallback.execVetoProbability`. 
+ */ + def shouldVetoExec(plan: SparkPlan): Boolean = { + if (!CometConf.COMET_FUZZ_FALLBACK_ENABLED.get()) false + else decide(2, plan, CometConf.COMET_FUZZ_FALLBACK_EXEC_VETO_PROBABILITY.get()) + } +} diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index 70983b0599..0fc3f1dd4c 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.comet.{CometConf, CometExplainInfo, ExtendedExplainInfo} +import org.apache.comet.{CometConf, CometExplainInfo, ExtendedExplainInfo, FuzzFallback} import org.apache.comet.CometConf.{COMET_SPARK_TO_ARROW_ENABLED, COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST} import org.apache.comet.CometSparkSessionExtensions._ import org.apache.comet.rules.CometExecRule.allExecs @@ -269,6 +269,10 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { .map(_.asInstanceOf[CometOperatorSerde[SparkPlan]]) handler match { case Some(handler) => + if (FuzzFallback.shouldVetoExec(op)) { + withInfo(op, "Fuzz fallback vetoed operator conversion") + return op + } return convertToComet(op, handler).getOrElse(op) case _ => } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index df2dca0331..4779a5048c 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -48,7 +48,7 @@ import org.apache.spark.util.random.XORShiftRandom import com.google.common.base.Objects -import 
org.apache.comet.CometConf +import org.apache.comet.{CometConf, FuzzFallback} import org.apache.comet.CometConf.{COMET_EXEC_SHUFFLE_ENABLED, COMET_SHUFFLE_MODE} import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleManagerEnabled, withInfo} import org.apache.comet.serde.{Compatible, OperatorOuterClass, QueryPlanSerde, SupportLevel, Unsupported} @@ -342,6 +342,11 @@ object CometShuffleExchangeExec return false } + if (FuzzFallback.shouldVetoShuffle(s)) { + withInfo(s, "Fuzz fallback vetoed native shuffle") + return false + } + val inputs = s.child.output for (input <- inputs) { @@ -459,6 +464,11 @@ object CometShuffleExchangeExec return false } + if (FuzzFallback.shouldVetoShuffle(s)) { + withInfo(s, "Fuzz fallback vetoed columnar shuffle") + return false + } + if (!isCometJVMShuffleMode(s.conf)) { withInfo(s, "Comet columnar shuffle not enabled") return false diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometCanonicalizationSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometCanonicalizationSuite.scala new file mode 100644 index 0000000000..8113cd0934 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometCanonicalizationSuite.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec + +/** + * Targets issue #3949. The full stack trace shows the assertion fires from + * `QueryPlan.doCanonicalize` → `withNewChildren(canonicalizedChildren)` → + * `ColumnarToRowExec.copy`, meaning the canonical form of some Comet plan has `supportsColumnar + * \== false` even though the original has it `true`. That violates the implicit contract that + * canonicalization preserves `supportsColumnar`. + * + * This suite constructs every kind of CometPlan reachable from common queries, canonicalizes it, + * and asserts that `supportsColumnar` is preserved. Any failing case is the minimal repro. + */ +class CometCanonicalizationSuite extends CometTestBase { + + private def collectCometPlans(root: SparkPlan): Seq[SparkPlan] = { + // AdaptiveSparkPlanExec hides its children from collect; reach inside explicitly. 
+ def walk(p: SparkPlan): Seq[SparkPlan] = p match { + case a: AdaptiveSparkPlanExec => + Seq(a.initialPlan, a.executedPlan).flatMap(walk) + case other => + val self = + if (other.isInstanceOf[CometPlan] && other.supportsColumnar) Seq(other) else Nil + self ++ other.children.flatMap(walk) + } + walk(root) + } + + private def planOf(query: String): SparkPlan = { + spark.sql(query).queryExecution.executedPlan + } + + private def checkCanonicalized(tag: String, plans: Seq[SparkPlan]): Unit = { + assert(plans.nonEmpty, s"[$tag] produced no Comet plans with supportsColumnar=true") + val broken = plans.flatMap { p => + try { + val c = p.canonicalized + if (!c.supportsColumnar) Some((p, c, "supportsColumnar=false after canonicalization")) + else None + } catch { + case t: Throwable => + Some((p, null, s"canonicalization threw: ${t.getClass.getName}: ${t.getMessage}")) + } + } + if (broken.nonEmpty) { + val details = broken + .map { case (p, c, reason) => + s"""node: ${p.getClass.getName} + | reason: $reason + | original: + |${p.treeString} + | canonical: + |${Option(c).map(_.treeString).getOrElse("")} + |""".stripMargin + } + .mkString("\n") + fail( + s"[$tag] ${broken.size} node(s) lose supportsColumnar under canonicalization:\n$details") + } + } + + private def plansOf(query: String): Seq[SparkPlan] = { + collectCometPlans(planOf(query)) + } + + test("CometScanExec canonicalization preserves supportsColumnar") { + withTempDir { dir => + val path = s"${dir.getAbsolutePath}/t.parquet" + val sess = spark + import sess.implicits._ + (0 until 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) + spark.read.parquet(path).createOrReplaceTempView("t3949_scan") + checkCanonicalized("scan", plansOf("select a, b from t3949_scan")) + } + } + + test("CometScanExec + filter + project canonicalization") { + withTempDir { dir => + val path = s"${dir.getAbsolutePath}/t.parquet" + val sess = spark + import sess.implicits._ + (0 until 100).map(i => (i, i.toString, i * 
2L)).toDF("a", "b", "c").write.parquet(path) + spark.read.parquet(path).createOrReplaceTempView("t3949_fp") + checkCanonicalized("scan+filter+project", plansOf("select a, c from t3949_fp where a > 10")) + } + } + + test("Aggregate plan canonicalization") { + withTempDir { dir => + val path = s"${dir.getAbsolutePath}/t.parquet" + val sess = spark + import sess.implicits._ + (0 until 100).map(i => (i % 10, i)).toDF("k", "v").write.parquet(path) + spark.read.parquet(path).createOrReplaceTempView("t3949_agg") + checkCanonicalized( + "aggregate", + plansOf("select k, sum(v), count(*) from t3949_agg group by k")) + } + } + + test("Broadcast hash join canonicalization") { + withTempDir { dir => + val fact = s"${dir.getAbsolutePath}/fact.parquet" + val dim = s"${dir.getAbsolutePath}/dim.parquet" + val sess = spark + import sess.implicits._ + (0 until 200).map(i => (i, i % 20, i.toString)).toDF("id", "k", "v").write.parquet(fact) + (0 until 20).map(i => (i, i.toString)).toDF("k", "d").write.parquet(dim) + spark.read.parquet(fact).createOrReplaceTempView("t3949_f") + spark.read.parquet(dim).createOrReplaceTempView("t3949_d") + checkCanonicalized( + "bhj", + plansOf("select f.id, f.v, d.d from t3949_f f join t3949_d d on f.k = d.k")) + } + } + + /** + * Directly mirrors the stack path in #3949: wrap each Comet plan in `ColumnarToRowExec` and + * canonicalize the wrapper. `QueryPlan.doCanonicalize` does `withNewChildren( + * canonicalizedChildren)`, which triggers `ColumnarToRowExec.copy(child = cometPlan.canonical)` + * — the exact constructor call whose assertion fires in the issue. + */ + private def checkWrappedCanonicalization(tag: String, plans: Seq[SparkPlan]): Unit = { + assert(plans.nonEmpty, s"[$tag] produced no Comet plans with supportsColumnar=true") + val broken = plans.flatMap { p => + try { + // Constructing the wrapper on a non-canonicalized plan succeeds (supportsColumnar=true). 
+ val wrapper = ColumnarToRowExec(p) + // Canonicalizing the wrapper reproduces the AQE path: it calls withNewChildren( + // Seq(p.canonicalized)) which reinvokes the ColumnarToRowExec constructor. + wrapper.canonicalized + None + } catch { + case t: Throwable => + Some( + ( + p, + s"ColumnarToRowExec($tag).canonicalized threw: ${t.getClass.getName}: " + + s"${t.getMessage}")) + } + } + if (broken.nonEmpty) { + val details = broken + .map { case (p, reason) => + s"""node: ${p.getClass.getName} + | reason: $reason + | original: + |${p.treeString} + |""".stripMargin + } + .mkString("\n") + fail( + s"[$tag] ${broken.size} node(s) blow the ColumnarToRow assertion when canonicalized:" + + s"\n$details") + } + } + + test("ColumnarToRowExec(cometPlan).canonicalized reproduces the #3949 path") { + withTempDir { dir => + val fact = s"${dir.getAbsolutePath}/fact.parquet" + val dim = s"${dir.getAbsolutePath}/dim.parquet" + val sess = spark + import sess.implicits._ + val oneDay = 24L * 60L * 60000L + val now = System.currentTimeMillis() + (0 until 400) + .map(i => (i, new java.sql.Date(now + (i % 40) * oneDay), i.toString)) + .toDF("fact_id", "fact_date", "fact_str") + .write + .partitionBy("fact_date") + .parquet(fact) + (0 until 40) + .map(i => (i, new java.sql.Date(now + i * oneDay), i.toString)) + .toDF("dim_id", "dim_date", "dim_str") + .write + .parquet(dim) + spark.read.parquet(fact).createOrReplaceTempView("t3949_fact2") + spark.read.parquet(dim).createOrReplaceTempView("t3949_dim2") + + val queries = Seq( + "select a, b from t3949_fact2".replace("a, b", "fact_id, fact_str"), + "select fact_id, count(*) from t3949_fact2 group by fact_id", + "select * from t3949_fact2 f join t3949_dim2 d on f.fact_date = d.dim_date " + + "where d.dim_id > 35") + queries.zipWithIndex.foreach { case (q, i) => + checkWrappedCanonicalization(s"q$i", plansOf(q)) + } + } + } + + test("DPP-shaped plan canonicalization (mirrors #3949 setup)") { + withTempDir { dir => + val fact = 
s"${dir.getAbsolutePath}/fact.parquet" + val dim = s"${dir.getAbsolutePath}/dim.parquet" + val sess = spark + import sess.implicits._ + val oneDay = 24L * 60L * 60000L + val now = System.currentTimeMillis() + (0 until 400) + .map(i => (i, new java.sql.Date(now + (i % 40) * oneDay), i.toString)) + .toDF("fact_id", "fact_date", "fact_str") + .write + .partitionBy("fact_date") + .parquet(fact) + (0 until 40) + .map(i => (i, new java.sql.Date(now + i * oneDay), i.toString)) + .toDF("dim_id", "dim_date", "dim_str") + .toDF("dim_id", "dim_date", "dim_str") + .write + .parquet(dim) + spark.read.parquet(fact).createOrReplaceTempView("t3949_fact") + spark.read.parquet(dim).createOrReplaceTempView("t3949_dim") + checkCanonicalized( + "dpp", + plansOf( + "select * from t3949_fact f join t3949_dim d on f.fact_date = d.dim_date " + + "where d.dim_id > 35")) + } + } +} diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometCanonicalizationTpcdsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometCanonicalizationTpcdsSuite.scala new file mode 100644 index 0000000000..494ce7ce74 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometCanonicalizationTpcdsSuite.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import scala.collection.mutable + +import org.apache.spark.SparkContext +import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE} +import org.apache.spark.sql.TPCDSBase +import org.apache.spark.sql.catalyst.util.resourceToString +import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.TestSparkSession + +import org.apache.comet.CometConf + +/** + * Scans every TPC-DS query plan for Comet nodes whose canonicalization either throws or produces + * a non-columnar result. Targets issue #3949: the stack trace shows Spark's `ColumnarToRowExec` + * constructor assertion firing during `QueryPlan.doCanonicalize` → `withNewChildren( + * canonicalizedChildren)`. That can only happen if `cometPlan.canonicalized.supportsColumnar` is + * false. + * + * For each Comet node produced by planning a TPC-DS query, this suite: + * 1. Calls `p.canonicalized` and asserts `supportsColumnar == true` (the direct invariant). 2. + * Wraps the original `p` in `ColumnarToRowExec(p)` and canonicalizes the wrapper — exactly + * the call path the issue's stack trace takes. + * + * Plans are only compiled (no execution), so no TPC-DS data is required. 
+ */ +class CometCanonicalizationTpcdsSuite extends TPCDSBase { + + override protected val injectStats: Boolean = false + + override protected def sparkConf = { + val conf = super.sparkConf + conf.set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions") + conf.set( + "spark.shuffle.manager", + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") + conf.set(MEMORY_OFFHEAP_ENABLED.key, "true") + conf.set(MEMORY_OFFHEAP_SIZE.key, "2g") + conf.set(CometConf.COMET_ENABLED.key, "true") + conf.set(CometConf.COMET_EXEC_ENABLED.key, "true") + conf.set(CometConf.COMET_NATIVE_SCAN_ENABLED.key, "true") + conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true") + conf.set(CometConf.COMET_ONHEAP_MEMORY_OVERHEAD.key, "1g") + conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "4") + conf + } + + override protected def createSparkSession: TestSparkSession = { + new TestSparkSession(new SparkContext("local[2]", this.getClass.getCanonicalName, sparkConf)) + } + + private def collectCometPlans(root: SparkPlan): Seq[SparkPlan] = { + def walk(p: SparkPlan): Seq[SparkPlan] = p match { + case a: AdaptiveSparkPlanExec => + Seq(a.initialPlan, a.executedPlan).flatMap(walk) + case other => + val self = + if (other.isInstanceOf[CometPlan] && other.supportsColumnar) Seq(other) else Nil + self ++ other.children.flatMap(walk) + } + walk(root) + } + + case class Failure(query: String, node: SparkPlan, mode: String, cause: Throwable) { + def pretty: String = + s"""[$query] ${node.getClass.getSimpleName} — $mode — ${Option(cause.getMessage).getOrElse( + cause.toString)} + | original node: + |${node.treeString} + |""".stripMargin + } + + private def checkNode(query: String, p: SparkPlan, failures: mutable.Buffer[Failure]): Unit = { + // Direct canonicalization. 
+ try { + val c = p.canonicalized + if (!c.supportsColumnar) { + failures += Failure( + query, + p, + "canonicalized.supportsColumnar=false", + new IllegalStateException( + s"canonical form: ${c.getClass.getSimpleName}\n${c.treeString}")) + } + } catch { + case t: Throwable => + failures += Failure(query, p, "canonicalized threw", t) + } + + // Wrapped-in-ColumnarToRow canonicalization — mirrors the #3949 stack exactly. + try { + val wrapper = ColumnarToRowExec(p) + wrapper.canonicalized + } catch { + case t: Throwable => + failures += Failure(query, p, "ColumnarToRowExec(p).canonicalized threw", t) + } + } + + private def runQueryScan( + group: String, + query: String, + failures: mutable.Buffer[Failure]): Unit = { + val sql = resourceToString( + s"$group/$query.sql", + classLoader = Thread.currentThread().getContextClassLoader) + val plan = + try spark.sql(sql).queryExecution.executedPlan + catch { + case t: Throwable => + failures += Failure(s"$group/$query", null, "executedPlan threw", t) + return + } + collectCometPlans(plan).foreach(p => checkNode(s"$group/$query", p, failures)) + } + + private val perTestConf: Map[String, String] = Map( + CometConf.COMET_DPP_FALLBACK_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") + + test("canonicalization holds across every TPC-DS query") { + val failures = mutable.Buffer.empty[Failure] + withSQLConf(perTestConf.toSeq: _*) { + for (q <- tpcdsQueries) runQueryScan("tpcds", q, failures) + for (q <- tpcdsQueriesV2_7_0) runQueryScan("tpcds-v2.7.0", q, failures) + } + if (failures.nonEmpty) { + val distinctQueries = failures.map(_.query).distinct + val header = + s"Canonicalization broke in ${failures.size} Comet node(s) across " + + s"${distinctQueries.size} query/ies:\n${distinctQueries.mkString(", ")}\n" + // Limit output size on huge failure sets. 
+ val body = failures.take(10).map(_.pretty).mkString("\n") + fail( + header + body + (if (failures.size > 10) s"\n... (${failures.size - 10} more)" else "")) + } + } +} diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometFuzzDppSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometFuzzDppSuite.scala new file mode 100644 index 0000000000..f5271272b5 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometFuzzDppSuite.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import scala.collection.mutable + +import org.apache.hadoop.fs.Path +import org.apache.spark.SparkConf +import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE} +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.Row +import org.apache.spark.sql.internal.SQLConf + +import org.apache.comet.{CometConf, FuzzFallback} + +/** + * Execution-based fuzz suite that exercises the DPP fallback code path added in #3879. 
Builds + * small partitioned fact/dim tables, then for each seed: + * - randomly vetoes Comet shuffle/exec conversions, + * - actually executes a DPP-flavored query so AQE re-plans stages at runtime. + * + * If the rule pipeline ever produces an invalid plan, Spark's own driver-side assertions will + * throw at plan construction and the failure is surfaced with the seed, query, and variant for + * reproduction. + */ +class CometFuzzDppSuite extends CometTestBase { + + override protected def sparkConf: SparkConf = { + val conf = super.sparkConf + conf.set(MEMORY_OFFHEAP_ENABLED.key, "true") + conf.set(MEMORY_OFFHEAP_SIZE.key, "2g") + conf.set(CometConf.COMET_ONHEAP_MEMORY_OVERHEAD.key, "1g") + conf + } + + private val seedsPerVariant: Int = + sys.env.getOrElse("COMET_FUZZ_SEEDS", "40").toInt + + private def buildDppData(base: Path): (String, String) = { + val factPath = s"${base.toString}/fact.parquet" + val dimPath = s"${base.toString}/dim.parquet" + withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { + val sess = spark + import sess.implicits._ + val oneDay = 24L * 60L * 60000L + val now = System.currentTimeMillis() + val fact = (0 until 400) + .map(i => (i, new java.sql.Date(now + (i % 40) * oneDay), i.toString)) + .toDF("fact_id", "fact_date", "fact_str") + fact.write.partitionBy("fact_date").parquet(factPath) + val dim = (0 until 40) + .map(i => (i, new java.sql.Date(now + i * oneDay), i.toString)) + .toDF("dim_id", "dim_date", "dim_str") + dim.write.parquet(dimPath) + } + (factPath, dimPath) + } + + private val variants: Seq[(String, Map[String, String])] = Seq( + "smj+aqe" -> Map( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.PREFER_SORTMERGEJOIN.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true"), + "bhj+aqe" -> Map( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true"), + "smj+aqe+coalesce" -> Map( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + 
SQLConf.PREFER_SORTMERGEJOIN.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + "spark.sql.adaptive.coalescePartitions.enabled" -> "true", + "spark.sql.adaptive.coalescePartitions.minPartitionSize" -> "1b", + "spark.sql.adaptive.coalescePartitions.initialPartitionNum" -> "16")) + + private val queries = Seq( + // Classic DPP: join on partitioning column, filter on dim side. + "select * from dpp_fact f join dpp_dim d on f.fact_date = d.dim_date where d.dim_id > 35", + // Aggregated result with DPP. + """select f.fact_date, count(*) c + |from dpp_fact f join dpp_dim d on f.fact_date = d.dim_date + |where d.dim_id > 30 + |group by f.fact_date""".stripMargin, + // Two-stage query mixing row/columnar operators above a DPP scan. + """select cnt, count(*) from ( + | select f.fact_id, count(*) cnt + | from dpp_fact f join dpp_dim d on f.fact_date = d.dim_date + | where d.dim_id > 20 + | group by f.fact_id + |) group by cnt""".stripMargin) + + test("fuzz fallback on DPP execution") { + withTempDir { dir => + val (factPath, dimPath) = buildDppData(new Path(dir.getAbsolutePath)) + spark.read.parquet(factPath).createOrReplaceTempView("dpp_fact") + spark.read.parquet(dimPath).createOrReplaceTempView("dpp_dim") + + val failures = mutable.Buffer.empty[(String, String, Long, Throwable)] + for ((variantName, variantConf) <- variants; (q, idx) <- queries.zipWithIndex) { + for (i <- 0 until seedsPerVariant) { + val seed = (variantName.hashCode.toLong * 31 + idx) * 1000003L + i + val conf = variantConf ++ Map( + CometConf.COMET_FUZZ_FALLBACK_ENABLED.key -> "true", + CometConf.COMET_FUZZ_FALLBACK_SEED.key -> seed.toString, + CometConf.COMET_FUZZ_FALLBACK_SHUFFLE_VETO_PROBABILITY.key -> "0.5", + CometConf.COMET_FUZZ_FALLBACK_EXEC_VETO_PROBABILITY.key -> "0.3", + CometConf.COMET_DPP_FALLBACK_ENABLED.key -> "true", + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet") + try { + withSQLConf(conf.toSeq: _*) { + FuzzFallback.reset() + try { + val rows: Array[Row] = 
spark.sql(q).collect() + // Touch rows to force materialization. + rows.length + } finally FuzzFallback.reset() + } + } catch { + case t: Throwable => failures += ((variantName, s"q$idx", seed, t)) + } + } + } + + if (failures.nonEmpty) { + val grouped = failures.groupBy { case (v, q, _, _) => (v, q) } + val summary = grouped.toSeq + .map { case ((v, q), fs) => + val seeds = fs.map(_._3).mkString(", ") + s" $v/$q: ${fs.size} seed(s) failed: $seeds" + } + .mkString("\n") + val (fv, fq, fseed, ft) = failures.head + var cause: Throwable = ft + while (cause.getCause != null && cause.getCause != cause) cause = cause.getCause + val msg = Option(cause.getMessage).getOrElse(cause.toString) + throw new AssertionError( + s"Fuzz fallback produced ${failures.size} failure(s) across variants:\n$summary\n" + + s"First failure: $fv/$fq seed=$fseed\n$msg", + ft) + } + } + } +} diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometFuzzFallbackSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometFuzzFallbackSuite.scala new file mode 100644 index 0000000000..b2f197011d --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometFuzzFallbackSuite.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import scala.collection.mutable + +import org.apache.spark.SparkContext +import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE} +import org.apache.spark.sql.TPCDSBase +import org.apache.spark.sql.catalyst.util.resourceToString +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.TestSparkSession + +import org.apache.comet.{CometConf, FuzzFallback} + +/** + * Fuzz suite that randomly vetoes Comet conversions while planning every TPC-DS query. The goal + * is to stress the Spark/Comet boundary: if the rule pipeline ever constructs an invalid plan + * (e.g. a ColumnarToRow over a non-columnar child), Spark's own driver-side assertions will throw + * during plan construction and the failure is reported with the seed and query for reproduction. + * + * This suite only triggers Comet rule application via `queryExecution.executedPlan`; queries are + * not executed, so no TPC-DS data is required. Managed (location-less) tables are created from + * the TPC-DS schema so that `sql(...)` resolves. + * + * To reproduce a specific failure, rerun with `spark.comet.fuzz.fallback.seed=`. + * + * Run: + * {{{ + * ./mvnw test -Dsuites="org.apache.spark.sql.comet.CometFuzzFallbackSuite" + * }}} + */ +class CometFuzzFallbackSuite extends TPCDSBase { + + /** Number of random seeds to exercise per query. */ + private val seedsPerQuery: Int = + sys.env.getOrElse("COMET_FUZZ_SEEDS", "8").toInt + + /** Probability the fuzz layer vetoes converting a shuffle to Comet. */ + private val shuffleVetoProbability: Double = + sys.env.getOrElse("COMET_FUZZ_SHUFFLE_P", "0.5").toDouble + + /** Probability the fuzz layer vetoes converting an operator to Comet. 
*/ + private val execVetoProbability: Double = + sys.env.getOrElse("COMET_FUZZ_EXEC_P", "0.2").toDouble + + // SF=1 synthetic stats are fine; we never execute. + override protected val injectStats: Boolean = false + + override protected def sparkConf = { + val conf = super.sparkConf + conf.set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions") + conf.set( + "spark.shuffle.manager", + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") + conf.set(MEMORY_OFFHEAP_ENABLED.key, "true") + conf.set(MEMORY_OFFHEAP_SIZE.key, "2g") + conf.set(CometConf.COMET_ENABLED.key, "true") + conf.set(CometConf.COMET_EXEC_ENABLED.key, "true") + conf.set(CometConf.COMET_NATIVE_SCAN_ENABLED.key, "true") + conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true") + conf.set(CometConf.COMET_ONHEAP_MEMORY_OVERHEAD.key, "1g") + conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "4") + conf + } + + override protected def createSparkSession: TestSparkSession = { + new TestSparkSession(new SparkContext("local[2]", this.getClass.getCanonicalName, sparkConf)) + } + + private val queryConf: Map[String, String] = Map( + CometConf.COMET_FUZZ_FALLBACK_ENABLED.key -> "true", + CometConf.COMET_FUZZ_FALLBACK_SHUFFLE_VETO_PROBABILITY.key -> shuffleVetoProbability.toString, + CometConf.COMET_FUZZ_FALLBACK_EXEC_VETO_PROBABILITY.key -> execVetoProbability.toString, + // Keep the DPP fallback on so we exercise the #3879 code path alongside random vetoes. 
+ CometConf.COMET_DPP_FALLBACK_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") + + private def runFuzzedQuery(group: String, query: String, seed: Long): Unit = { + val sql = resourceToString( + s"$group/$query.sql", + classLoader = Thread.currentThread().getContextClassLoader) + + val perQueryConf = queryConf + (CometConf.COMET_FUZZ_FALLBACK_SEED.key -> seed.toString) + + withSQLConf(perQueryConf.toSeq: _*) { + FuzzFallback.reset() + try { + // Touch executedPlan to drive the Comet rule pipeline and the debug assertion. + spark.sql(sql).queryExecution.executedPlan + } catch { + case t: Throwable => + // Walk the cause chain for the most specific message. + var cause: Throwable = t + while (cause.getCause != null && cause.getCause != cause) cause = cause.getCause + val msg = Option(cause.getMessage).getOrElse(cause.toString) + throw new AssertionError( + s"Fuzz fallback produced a bad plan for $group/$query with seed=$seed " + + s"(shuffleVetoP=$shuffleVetoProbability, execVetoP=$execVetoProbability):\n$msg", + t) + } finally { + FuzzFallback.reset() + } + } + } + + // Queries known to be heavy or flaky for planning at SF=1 without execution; skip to keep + // the suite responsive. Extend as needed. + private val skip: Set[String] = Set.empty + + private def seedsFor(query: String): Seq[Long] = { + // Derive a stable per-query seed base so queries don't all share the same RNG stream. 
+ val base = query.hashCode.toLong + (0 until seedsPerQuery).map(i => base + i * 0x9e3779b97f4a7c15L) + } + + for (q <- tpcdsQueries if !skip.contains(q)) { + test(s"fuzz fallback on planning: tpcds/$q") { + val failures = mutable.Buffer.empty[(Long, Throwable)] + for (seed <- seedsFor(q)) { + try runFuzzedQuery("tpcds", q, seed) + catch { case t: Throwable => failures += ((seed, t)) } + } + if (failures.nonEmpty) { + val first = failures.head + val summary = failures.map { case (s, _) => s"seed=$s" }.mkString(", ") + val msg = s"Fuzz fallback produced bad plans for $q across ${failures.size} seed(s): " + + s"$summary\nFirst failure:\n${first._2.getMessage}" + throw new AssertionError(msg, first._2) + } + } + } + + for (q <- tpcdsQueriesV2_7_0 if !skip.contains(q)) { + test(s"fuzz fallback on planning: tpcds-v2.7.0/$q") { + val failures = mutable.Buffer.empty[(Long, Throwable)] + for (seed <- seedsFor(q)) { + try runFuzzedQuery("tpcds-v2.7.0", q, seed) + catch { case t: Throwable => failures += ((seed, t)) } + } + if (failures.nonEmpty) { + val first = failures.head + val summary = failures.map { case (s, _) => s"seed=$s" }.mkString(", ") + val msg = s"Fuzz fallback produced bad plans for $q across ${failures.size} seed(s): " + + s"$summary\nFirst failure:\n${first._2.getMessage}" + throw new AssertionError(msg, first._2) + } + } + } +} From 65be97e9011976bad065d98463765c3f33c2be7c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 17 Apr 2026 11:09:34 -0600 Subject: [PATCH 2/2] test: reproduce inconsistent DPP fallback across initial-plan vs AQE stage-prep MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit stageContainsDPPScan walks s.child.exists for a FileSourceScanExec with a PlanExpression partition filter. In AQE, once the inner stage materializes, its subtree is replaced by a ShuffleQueryStageExec whose children is Seq.empty — .exists no longer descends into it and the DPP scan becomes invisible. 
The same shuffle therefore falls back at initial planning (columnarShuffleSupported=false) but is converted to Comet at stage-prep (columnarShuffleSupported=true). That plan-shape flip is the suspected trigger for #3949. The test builds a DPP-shaped query, observes the initial decision, then swaps the shuffle's child for an opaque LeafExecNode stub that mimics how ShuffleQueryStageExec presents to tree walks, and observes the decision flip. A fix requires stageContainsDPPScan to descend into QueryStageExec.plan. --- .../CometDppFallbackConsistencySuite.scala | 182 ++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 spark/src/test/scala/org/apache/spark/sql/comet/CometDppFallbackConsistencySuite.scala diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometDppFallbackConsistencySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometDppFallbackConsistencySuite.scala new file mode 100644 index 0000000000..e045bcfab0 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometDppFallbackConsistencySuite.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.execution.{FileSourceScanExec, LeafExecNode, SparkPlan} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.internal.SQLConf + +import org.apache.comet.CometConf + +/** + * Demonstrates the suspected root cause of issue #3949. Comet's DPP fallback decision + * (`columnarShuffleSupported` → `stageContainsDPPScan`) walks `s.child.exists(...)` to look for a + * `FileSourceScanExec` with a `PlanExpression` partition filter. That walk is not stable across + * the two planning passes: + * + * - initial planning: the shuffle's child subtree includes the DPP scan, so `.exists` finds it + * and Comet falls back (keeps the shuffle as Spark). + * - AQE stage-prep (after the inner stage materializes): the DPP subtree is replaced by an + * opaque wrapper (`ShuffleQueryStageExec`, whose `children == Seq.empty`). `.exists` can no + * longer see the scan, `stageContainsDPPScan` returns false, and the same shuffle is + * converted to Comet. + * + * That decision flip changes the plan shape between passes — the suspected trigger for the + * canonicalization assertion in #3949. A fix needs `stageContainsDPPScan` to descend into + * `QueryStageExec.plan` when walking for DPP filters. 
+ */ +class CometDppFallbackConsistencySuite extends CometTestBase { + + private def buildDppTables(dir: java.io.File): Unit = { + val factPath = s"${dir.getAbsolutePath}/fact.parquet" + val dimPath = s"${dir.getAbsolutePath}/dim.parquet" + withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { + val sess = spark + import sess.implicits._ + val oneDay = 24L * 60L * 60000L + val now = System.currentTimeMillis() + (0 until 400) + .map(i => (i, new java.sql.Date(now + (i % 40) * oneDay), i.toString)) + .toDF("fact_id", "fact_date", "fact_str") + .write + .partitionBy("fact_date") + .parquet(factPath) + (0 until 40) + .map(i => (i, new java.sql.Date(now + i * oneDay), i.toString)) + .toDF("dim_id", "dim_date", "dim_str") + .write + .parquet(dimPath) + } + spark.read.parquet(factPath).createOrReplaceTempView("dpp_consistency_fact") + spark.read.parquet(dimPath).createOrReplaceTempView("dpp_consistency_dim") + } + + private def unwrapAqe(plan: SparkPlan): SparkPlan = plan match { + case a: AdaptiveSparkPlanExec => a.initialPlan + case other => other + } + + private def findFirstShuffle(plan: SparkPlan): Option[ShuffleExchangeExec] = { + var found: Option[ShuffleExchangeExec] = None + plan.foreach { + case s: ShuffleExchangeExec if found.isEmpty => found = Some(s) + case _ => + } + found + } + + test("columnarShuffleSupported decision flips when child is wrapped in ShuffleQueryStageExec") { + withTempDir { dir => + buildDppTables(dir) + withSQLConf( + CometConf.COMET_DPP_FALLBACK_ENABLED.key -> "true", + // Force SMJ so we get a shuffle above the DPP scan. + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.PREFER_SORTMERGEJOIN.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet") { + + // Aggregation guarantees a ShuffleExchangeExec; the join's DPP filter attaches below. 
+ val df = spark.sql( + "select f.fact_date, count(*) c " + + "from dpp_consistency_fact f " + + "join dpp_consistency_dim d on f.fact_date = d.dim_date " + + "where d.dim_id > 35 " + + "group by f.fact_date") + val initialPlan = unwrapAqe(df.queryExecution.executedPlan) + + val shuffle = findFirstShuffle(initialPlan).getOrElse { + fail(s"No ShuffleExchangeExec found in initial plan:\n${initialPlan.treeString}") + } + + // (1) initial-plan decision + val initialDecision = CometShuffleExchangeExec.columnarShuffleSupported(shuffle) + + // Prove the DPP scan is visible in the initial child subtree. + val initialDppVisible = shuffle.child.exists { + case scan: FileSourceScanExec => + scan.partitionFilters.exists(e => + e.exists( + _.isInstanceOf[org.apache.spark.sql.catalyst.expressions.PlanExpression[_]])) + case _ => false + } + + // (2) simulate AQE stage-prep: swap the child for a LeafExecNode, matching the + // tree-walking behavior of ShuffleQueryStageExec (whose children is Seq.empty from the + // perspective of `.exists`). This models what the shuffle "sees" after its child stage + // has materialized and been replaced by an opaque stage wrapper. 
+ val hiddenChild = OpaqueStageStub(shuffle.child.output) + val postAqeShuffle = + shuffle.withNewChildren(Seq(hiddenChild)).asInstanceOf[ShuffleExchangeExec] + val postAqeDecision = CometShuffleExchangeExec.columnarShuffleSupported(postAqeShuffle) + + val postAqeDppVisible = postAqeShuffle.child.exists { + case scan: FileSourceScanExec => + scan.partitionFilters.exists(e => + e.exists( + _.isInstanceOf[org.apache.spark.sql.catalyst.expressions.PlanExpression[_]])) + case _ => false + } + + // scalastyle:off println + println(s"=== DPP consistency check ===") + println(s"initial shuffle.child:\n${shuffle.child.treeString}") + println(s"initialDppVisible=$initialDppVisible, initialDecision=$initialDecision") + println(s"postAqeDppVisible=$postAqeDppVisible, postAqeDecision=$postAqeDecision") + // scalastyle:on println + + // The DPP scan is only visible while the child subtree is walkable; hiding it behind + // an opaque stage wrapper removes it from `.exists`. This is the mechanism. + assert(initialDppVisible, "sanity: initial child tree should expose DPP scan") + assert(!postAqeDppVisible, "sanity: stage-wrapped child should hide DPP scan") + + // The bug: Comet decides to fall back at initial planning but to convert to Comet + // once the stage has materialized. That is the same shuffle getting two different + // treatments across the two passes. + // + // TODO(#3949): once `stageContainsDPPScan` descends into `QueryStageExec.plan`, flip + // this assertion to `initialDecision == postAqeDecision`. + assert( + initialDecision == false, + s"expected Spark fallback initially, got $initialDecision") + assert( + postAqeDecision == true, + s"expected Comet conversion after stage wrap, got $postAqeDecision " + + "(if this now returns false, the bug has been fixed — invert these assertions)") + } + } + } +} + +/** + * LeafExecNode stub that mimics how `ShuffleQueryStageExec` presents itself to tree walks: + * `children == Seq.empty`, so `.exists(...)` cannot descend. 
Used to model what the parent + * shuffle's `.child` looks like after the inner stage has materialized. + */ +private case class OpaqueStageStub(output: Seq[Attribute]) extends LeafExecNode { + override protected def doExecute(): RDD[InternalRow] = + throw new UnsupportedOperationException("stub") +}