Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions common/src/main/scala/org/apache/comet/CometConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,16 @@ object CometConf extends ShimCometConf {
.booleanConf
.createWithDefault(false)

val COMET_ASSERT_VALID_PLAN_TRANSITIONS: ConfigEntry[Boolean] =
conf("spark.comet.assertValidPlanTransitions.enabled")
.category(CATEGORY_EXEC_EXPLAIN)
.doc(
"When enabled, Comet asserts that every columnar-to-row transition in the " +
"post-rule plan has a columnar child. Intended for debugging intermittent " +
"bad-plan shapes; off by default.")
.booleanConf
.createWithDefault(true)

val COMET_LOG_FALLBACK_REASONS: ConfigEntry[Boolean] =
conf("spark.comet.logFallbackReasons.enabled")
.category(CATEGORY_EXEC_EXPLAIN)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,14 @@ import org.apache.comet.CometConf
case class EliminateRedundantTransitions(session: SparkSession) extends Rule[SparkPlan] {

private lazy val showTransformations = CometConf.COMET_EXPLAIN_TRANSFORMATIONS.get()
private lazy val assertValidPlanTransitions =
CometConf.COMET_ASSERT_VALID_PLAN_TRANSITIONS.get()

override def apply(plan: SparkPlan): SparkPlan = {
val newPlan = _apply(plan)
if (assertValidPlanTransitions) {
checkTransitionInvariant(newPlan)
}
if (showTransformations && !newPlan.fastEquals(plan)) {
logInfo(s"""
|=== Applying Rule $ruleName ===
Expand All @@ -66,6 +71,28 @@ case class EliminateRedundantTransitions(session: SparkSession) extends Rule[Spa
newPlan
}

// Gated by spark.comet.assertValidPlanTransitions.enabled.
// Every columnar-to-row transition must have a columnar child; violations
// indicate a bad plan produced by an earlier rule.
private def checkTransitionInvariant(plan: SparkPlan): Unit = {
plan.foreach {
case c: ColumnarToRowExec if !c.child.supportsColumnar =>
val cls = c.child.getClass.getName
throw new IllegalStateException(
s"ColumnarToRowExec wraps non-columnar child ($cls):\n" + plan.treeString)
case c: CometColumnarToRowExec if !c.child.supportsColumnar =>
val cls = c.child.getClass.getName
throw new IllegalStateException(
s"CometColumnarToRowExec wraps non-columnar child ($cls):\n" + plan.treeString)
case c: CometNativeColumnarToRowExec if !c.child.supportsColumnar =>
val cls = c.child.getClass.getName
throw new IllegalStateException(
s"CometNativeColumnarToRowExec wraps non-columnar child ($cls):\n" +
plan.treeString)
case _ =>
}
}

private def _apply(plan: SparkPlan): SparkPlan = {
val eliminatedPlan = plan transformUp {
case ColumnarToRowExec(shuffleExchangeExec: CometShuffleExchangeExec)
Expand Down
Loading