diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index a71eed6c80..091f70fdc2 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -154,7 +154,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { operator2Proto(op).map(fun).getOrElse(op) } - plan.transformUp { + def convertNode(op: SparkPlan): SparkPlan = op match { // Fully native scan for V1 case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION => val nativeOp = QueryPlanSerde.operator2Proto(scan).get @@ -446,7 +446,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { case other => other } if (!newChildren.exists(_.isInstanceOf[BroadcastExchangeExec])) { - val newPlan = apply(plan.withNewChildren(newChildren)) + val newPlan = convertNode(plan.withNewChildren(newChildren)) if (isCometNative(newPlan) || isCometBroadCastForceEnabled(conf)) { newPlan } else { @@ -554,6 +554,10 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { } } } + + plan.transformUp { case op => + convertNode(op) + } } private def normalizePlan(plan: SparkPlan): SparkPlan = {