diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index 046ccf0b1c..d58dfa62da 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -618,6 +618,16 @@ object CometConf extends ShimCometConf {
       .booleanConf
       .createWithDefault(false)
 
+  val COMET_ASSERT_VALID_PLAN_TRANSITIONS: ConfigEntry[Boolean] =
+    conf("spark.comet.assertValidPlanTransitions.enabled")
+      .category(CATEGORY_EXEC_EXPLAIN)
+      .doc(
+        "When enabled, Comet asserts that every columnar-to-row transition in the " +
+          "post-rule plan has a columnar child. Intended for debugging intermittent " +
+          "bad-plan shapes; off by default.")
+      .booleanConf
+      .createWithDefault(false)
+
   val COMET_LOG_FALLBACK_REASONS: ConfigEntry[Boolean] =
     conf("spark.comet.logFallbackReasons.enabled")
       .category(CATEGORY_EXEC_EXPLAIN)
diff --git a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala
index 7402a83248..8ebe72785a 100644
--- a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala
@@ -54,9 +54,14 @@ import org.apache.comet.CometConf
 case class EliminateRedundantTransitions(session: SparkSession) extends Rule[SparkPlan] {
 
   private lazy val showTransformations = CometConf.COMET_EXPLAIN_TRANSFORMATIONS.get()
+  private lazy val assertValidPlanTransitions =
+    CometConf.COMET_ASSERT_VALID_PLAN_TRANSITIONS.get()
 
   override def apply(plan: SparkPlan): SparkPlan = {
     val newPlan = _apply(plan)
+    if (assertValidPlanTransitions) {
+      checkTransitionInvariant(newPlan)
+    }
     if (showTransformations && !newPlan.fastEquals(plan)) {
       logInfo(s"""
          |=== Applying Rule $ruleName ===
@@ -66,6 +71,37 @@ case class EliminateRedundantTransitions(session: SparkSession) extends Rule[Spa
     newPlan
   }
 
+  /**
+   * Debugging check, gated by spark.comet.assertValidPlanTransitions.enabled.
+   *
+   * Every columnar-to-row transition must have a columnar child; a violation indicates a bad
+   * plan produced by an earlier rule.
+   *
+   * @param plan
+   *   the post-rule plan to inspect; each node in the tree is checked
+   * @throws IllegalStateException
+   *   on the first transition whose child is not columnar; the message names the child class
+   *   and includes the full plan tree
+   */
+  private def checkTransitionInvariant(plan: SparkPlan): Unit = {
+    plan.foreach {
+      case c: ColumnarToRowExec if !c.child.supportsColumnar =>
+        val cls = c.child.getClass.getName
+        throw new IllegalStateException(
+          s"ColumnarToRowExec wraps non-columnar child ($cls):\n" + plan.treeString)
+      case c: CometColumnarToRowExec if !c.child.supportsColumnar =>
+        val cls = c.child.getClass.getName
+        throw new IllegalStateException(
+          s"CometColumnarToRowExec wraps non-columnar child ($cls):\n" + plan.treeString)
+      case c: CometNativeColumnarToRowExec if !c.child.supportsColumnar =>
+        val cls = c.child.getClass.getName
+        throw new IllegalStateException(
+          s"CometNativeColumnarToRowExec wraps non-columnar child ($cls):\n" +
+            plan.treeString)
+      case _ => // any other node is outside the scope of this check
+    }
+  }
+
   private def _apply(plan: SparkPlan): SparkPlan = {
     val eliminatedPlan = plan transformUp {
       case ColumnarToRowExec(shuffleExchangeExec: CometShuffleExchangeExec)