From 3354763e6e75039507738d154f03126419f988b6 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Wed, 9 Oct 2024 15:36:50 +0800 Subject: [PATCH 1/4] [VL] Prepare shim API for breaking change in SPARK-48610 --- .../sql/execution/GlutenExplainUtils.scala | 9 ++++---- .../spark/sql/execution/GlutenImplicits.scala | 22 +++++++++++-------- .../apache/gluten/sql/shims/SparkShims.scala | 10 +++++++++ .../sql/shims/spark32/Spark32Shims.scala | 13 +++++++++++ .../sql/shims/spark33/Spark33Shims.scala | 13 +++++++++++ .../sql/shims/spark34/Spark34Shims.scala | 13 +++++++++++ .../sql/shims/spark35/Spark35Shims.scala | 13 +++++++++++ 7 files changed, 80 insertions(+), 13 deletions(-) diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala index 43b74c883671..11c12171f8ca 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.gluten.execution.WholeStageTransformer import org.apache.gluten.extension.GlutenPlan import org.apache.gluten.extension.columnar.FallbackTags +import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.utils.PlanUtil import org.apache.spark.sql.AnalysisException @@ -49,7 +50,7 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { p: SparkPlan, reason: String, fallbackNodeToReason: mutable.HashMap[String, String]): Unit = { - p.getTagValue(QueryPlan.OP_ID_TAG).foreach { + SparkShimLoader.getSparkShims.getOperatorId(p).foreach { opId => // e.g., 002 project, it is used to help analysis by `substring(4)` val formattedNodeName = f"$opId%03d ${p.nodeName}" @@ -288,7 +289,7 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { } visited.add(plan) currentOperationID += 1 - plan.setTagValue(QueryPlan.OP_ID_TAG, currentOperationID) + SparkShimLoader.getSparkShims.setOperatorId(plan, currentOperationID) } plan.foreachUp { @@ -358,12 +359,12 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { * value. 
*/ private def getOpId(plan: QueryPlan[_]): String = { - plan.getTagValue(QueryPlan.OP_ID_TAG).map(v => s"$v").getOrElse("unknown") + SparkShimLoader.getSparkShims.getOperatorId(plan).map(v => s"$v").getOrElse("unknown") } private def removeTags(plan: QueryPlan[_]): Unit = { def remove(p: QueryPlan[_], children: Seq[QueryPlan[_]]): Unit = { - p.unsetTagValue(QueryPlan.OP_ID_TAG) + SparkShimLoader.getSparkShims.unsetOperatorId(p) children.foreach(removeTags) } diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenImplicits.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenImplicits.scala index 7f5f9215015f..4ecc674d4b13 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenImplicits.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenImplicits.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, LogicalPlan} import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat import org.apache.spark.sql.execution.ColumnarWriteFilesExec.NoopLeaf -import org.apache.spark.sql.execution.GlutenExplainUtils._ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, QueryStageExec} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} @@ -42,8 +41,8 @@ import scala.collection.mutable.ArrayBuffer * A helper class to get the Gluten fallback summary from a Spark [[Dataset]]. * * Note that, if AQE is enabled, but the query is not materialized, then this method will re-plan - * the query execution with disabled AQE. It is a workaround to get the final plan, and it may - * cause the inconsistent results with a materialized query. However, we have no choice. + * the query execution with disabled AQE. It is a workaround to get the final plan, and it may cause + * the inconsistent results with a materialized query. However, we have no choice. * * For example: * @@ -96,7 +95,9 @@ object GlutenImplicits { args.substring(index + "isFinalPlan=".length).trim.toBoolean } - private def collectFallbackNodes(spark: SparkSession, plan: QueryPlan[_]): FallbackInfo = { + private def collectFallbackNodes( + spark: SparkSession, + plan: QueryPlan[_]): GlutenExplainUtils.FallbackInfo = { var numGlutenNodes = 0 val fallbackNodeToReason = new mutable.HashMap[String, String] @@ -131,7 +132,7 @@ object GlutenImplicits { spark, newSparkPlan ) - processPlan( + GlutenExplainUtils.processPlan( newExecutedPlan, new PlanStringConcat().append, Some(plan => collectFallbackNodes(spark, plan))) @@ -146,12 +147,15 @@ object GlutenImplicits { if (PlanUtil.isGlutenTableCache(i)) { numGlutenNodes += 1 } else { - addFallbackNodeWithReason(i, "Columnar table cache is disabled", fallbackNodeToReason) + GlutenExplainUtils.addFallbackNodeWithReason( + i, + "Columnar table cache is disabled", + fallbackNodeToReason) } collect(i.relation.cachedPlan) case _: AQEShuffleReadExec => // Ignore case p: SparkPlan => - handleVanillaSparkPlan(p, fallbackNodeToReason) + GlutenExplainUtils.handleVanillaSparkPlan(p, fallbackNodeToReason) p.innerChildren.foreach(collect) case _ => } @@ -181,10 +185,10 @@ object GlutenImplicits { // AQE is not materialized, so the columnar rules are not applied. // For this case, We apply columnar rules manually with disable AQE. 
val qe = spark.sessionState.executePlan(logicalPlan, CommandExecutionMode.SKIP) - processPlan(qe.executedPlan, concat.append, collectFallbackFunc) + GlutenExplainUtils.processPlan(qe.executedPlan, concat.append, collectFallbackFunc) } } else { - processPlan(plan, concat.append, collectFallbackFunc) + GlutenExplainUtils.processPlan(plan, concat.append, collectFallbackFunc) } totalNumGlutenNodes += numGlutenNodes totalNumFallbackNodes += fallbackNodeToReason.size diff --git a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala index 7671f236c917..fa2991604616 100644 --- a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala +++ b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryExpression, Expression} import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning} import org.apache.spark.sql.catalyst.rules.Rule @@ -270,4 +271,13 @@ trait SparkShims { def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = { throw new UnsupportedOperationException("ArrayInsert not supported.") } + + /** Shim method for GlutenExplainUtils.scala. */ + def getOperatorId(plan: QueryPlan[_]): Option[Int] + + /** Shim method for GlutenExplainUtils.scala. */ + def setOperatorId(plan: QueryPlan[_], opId: Int): Unit + + /** Shim method for GlutenExplainUtils.scala. 
*/ + def unsetOperatorId(plan: QueryPlan[_]): Unit } diff --git a/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala b/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala index f62f9031ce9e..973e675fa98f 100644 --- a/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala +++ b/shims/spark32/src/main/scala/org/apache/gluten/sql/shims/spark32/Spark32Shims.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName} import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution} import org.apache.spark.sql.catalyst.rules.Rule @@ -283,4 +284,16 @@ class Spark32Shims extends SparkShims { val s = decimalType.scale DecimalType(p, if (toScale > s) s else toScale) } + + override def getOperatorId(plan: QueryPlan[_]): Option[Int] = { + plan.getTagValue(QueryPlan.OP_ID_TAG) + } + + override def setOperatorId(plan: QueryPlan[_], opId: Int): Unit = { + plan.setTagValue(QueryPlan.OP_ID_TAG, opId) + } + + override def unsetOperatorId(plan: QueryPlan[_]): Unit = { + plan.unsetTagValue(QueryPlan.OP_ID_TAG) + } } diff --git a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala index 168b88275cc1..5cf7c5505a10 100644 --- a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala +++ b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, RegrR2, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.catalyst.rules.Rule @@ -364,4 +365,16 @@ class Spark33Shims extends SparkShims { RebaseSpec(LegacyBehaviorPolicy.CORRECTED) ) } + + override def getOperatorId(plan: QueryPlan[_]): Option[Int] = { + plan.getTagValue(QueryPlan.OP_ID_TAG) + } + + override def setOperatorId(plan: QueryPlan[_], opId: Int): Unit = { + plan.setTagValue(QueryPlan.OP_ID_TAG, opId) + } + + override def unsetOperatorId(plan: QueryPlan[_]): Unit = { + plan.unsetTagValue(QueryPlan.OP_ID_TAG) + } } diff --git a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala index 558d7f60d5eb..bedad4c01741 100644 --- a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala +++ b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.csv.CSVOptions import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, KeyGroupedPartitioning, Partitioning} import org.apache.spark.sql.catalyst.rules.Rule @@ -499,4 +500,16 @@ class Spark34Shims extends SparkShims { val expr = arrayInsert.asInstanceOf[ArrayInsert] Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex)) } + + override def getOperatorId(plan: QueryPlan[_]): Option[Int] = { + plan.getTagValue(QueryPlan.OP_ID_TAG) + } + + override def setOperatorId(plan: QueryPlan[_], opId: Int): Unit = { + plan.setTagValue(QueryPlan.OP_ID_TAG, opId) + } + + override def unsetOperatorId(plan: QueryPlan[_]): Unit = { + plan.unsetTagValue(QueryPlan.OP_ID_TAG) + } } diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala index 4a6590161c4a..d130864a9fed 100644 --- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala +++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, KeyGroupedPartitioning, Partitioning} import org.apache.spark.sql.catalyst.rules.Rule @@ -524,4 +525,16 @@ class Spark35Shims extends SparkShims { val expr = arrayInsert.asInstanceOf[ArrayInsert] Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex)) } + + override def getOperatorId(plan: QueryPlan[_]): Option[Int] = { + plan.getTagValue(QueryPlan.OP_ID_TAG) + } + + override def setOperatorId(plan: QueryPlan[_], opId: Int): Unit = { + plan.setTagValue(QueryPlan.OP_ID_TAG, opId) + } + + override def unsetOperatorId(plan: QueryPlan[_]): Unit = { + plan.unsetTagValue(QueryPlan.OP_ID_TAG) + } } From 60c46e82d9b17435849842c6c568f50a648d0c9a Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Wed, 9 Oct 2024 15:45:29 +0800 Subject: [PATCH 2/4] fixup fixup fixup fixup --- .../sql/execution/GlutenExplainUtils.scala | 141 +++++++++--------- .../apache/gluten/sql/shims/SparkShims.scala | 11 +- .../sql/shims/spark35/Spark35Shims.scala | 18 ++- 3 files changed, 96 insertions(+), 74 deletions(-) diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala index 11c12171f8ca..ec529a7b12e2 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala @@ -154,91 +154,96 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { * 1. Generates the explain output for the input plan excluding the subquery plans. * 2. Generates the explain output for each subquery referenced in the plan. 
*/ + // scalastyle:on + // spotless:on def processPlan[T <: QueryPlan[T]]( plan: T, append: String => Unit, - collectFallbackFunc: Option[QueryPlan[_] => FallbackInfo] = None): FallbackInfo = synchronized { - try { - // Initialize a reference-unique set of Operators to avoid accdiental overwrites and to allow - // intentional overwriting of IDs generated in previous AQE iteration - val operators = newSetFromMap[QueryPlan[_]](new util.IdentityHashMap()) - // Initialize an array of ReusedExchanges to help find Adaptively Optimized Out - // Exchanges as part of SPARK-42753 - val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec] + collectFallbackFunc: Option[QueryPlan[_] => FallbackInfo] = None): FallbackInfo = + synchronized { + SparkShimLoader.getSparkShims.withOperatorIdMap( + new java.util.IdentityHashMap[QueryPlan[_], Int]()) { + try { + // Initialize a reference-unique set of Operators to avoid accdiental overwrites and to allow + // intentional overwriting of IDs generated in previous AQE iteration + val operators = newSetFromMap[QueryPlan[_]](new util.IdentityHashMap()) + // Initialize an array of ReusedExchanges to help find Adaptively Optimized Out + // Exchanges as part of SPARK-42753 + val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec] - var currentOperatorID = 0 - currentOperatorID = - generateOperatorIDs(plan, currentOperatorID, operators, reusedExchanges, true) + var currentOperatorID = 0 + currentOperatorID = + generateOperatorIDs(plan, currentOperatorID, operators, reusedExchanges, true) - val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)] - getSubqueries(plan, subqueries) + val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, BaseSubqueryExec)] + getSubqueries(plan, subqueries) - currentOperatorID = subqueries.foldLeft(currentOperatorID) { - (curId, plan) => generateOperatorIDs(plan._3.child, curId, operators, reusedExchanges, true) - } + currentOperatorID = subqueries.foldLeft(currentOperatorID) { + (curId, plan) => + generateOperatorIDs(plan._3.child, curId, operators, reusedExchanges, true) + } - // SPARK-42753: Process subtree for a ReusedExchange with unknown child - val optimizedOutExchanges = ArrayBuffer.empty[Exchange] - reusedExchanges.foreach { - reused => - val child = reused.child - if (!operators.contains(child)) { - optimizedOutExchanges.append(child) - currentOperatorID = - generateOperatorIDs(child, currentOperatorID, operators, reusedExchanges, false) + // SPARK-42753: Process subtree for a ReusedExchange with unknown child + val optimizedOutExchanges = ArrayBuffer.empty[Exchange] + reusedExchanges.foreach { + reused => + val child = reused.child + if (!operators.contains(child)) { + optimizedOutExchanges.append(child) + currentOperatorID = + generateOperatorIDs(child, currentOperatorID, operators, reusedExchanges, false) + } } - } - val collectedOperators = BitSet.empty - processPlanSkippingSubqueries(plan, append, collectedOperators) + val collectedOperators = BitSet.empty + processPlanSkippingSubqueries(plan, append, collectedOperators) - var i = 0 - for (sub <- subqueries) { - if (i == 0) { - append("\n===== Subqueries =====\n\n") - } - i = i + 1 - append( - s"Subquery:$i Hosting operator id = " + - s"${getOpId(sub._1)} Hosting Expression = ${sub._2}\n") + var i = 0 + for (sub <- subqueries) { + if (i == 0) { + append("\n===== Subqueries =====\n\n") + } + i = i + 1 + append( + s"Subquery:$i Hosting operator id = " + + s"${getOpId(sub._1)} Hosting Expression = ${sub._2}\n") - // For each subquery expression 
in the parent plan, process its child plan to compute - // the explain output. In case of subquery reuse, we don't print subquery plan more - // than once. So we skip [[ReusedSubqueryExec]] here. - if (!sub._3.isInstanceOf[ReusedSubqueryExec]) { - processPlanSkippingSubqueries(sub._3.child, append, collectedOperators) - } - append("\n") - } + // For each subquery expression in the parent plan, process its child plan to compute + // the explain output. In case of subquery reuse, we don't print subquery plan more + // than once. So we skip [[ReusedSubqueryExec]] here. + if (!sub._3.isInstanceOf[ReusedSubqueryExec]) { + processPlanSkippingSubqueries(sub._3.child, append, collectedOperators) + } + append("\n") + } - i = 0 - optimizedOutExchanges.foreach { - exchange => - if (i == 0) { - append("\n===== Adaptively Optimized Out Exchanges =====\n\n") + i = 0 + optimizedOutExchanges.foreach { + exchange => + if (i == 0) { + append("\n===== Adaptively Optimized Out Exchanges =====\n\n") + } + i = i + 1 + append(s"Subplan:$i\n") + processPlanSkippingSubqueries[SparkPlan](exchange, append, collectedOperators) + append("\n") } - i = i + 1 - append(s"Subplan:$i\n") - processPlanSkippingSubqueries[SparkPlan](exchange, append, collectedOperators) - append("\n") - } - (subqueries.filter(!_._3.isInstanceOf[ReusedSubqueryExec]).map(_._3.child) :+ plan) - .map { - plan => - if (collectFallbackFunc.isEmpty) { - collectFallbackNodes(plan) - } else { - collectFallbackFunc.get.apply(plan) + (subqueries.filter(!_._3.isInstanceOf[ReusedSubqueryExec]).map(_._3.child) :+ plan) + .map { + plan => + if (collectFallbackFunc.isEmpty) { + collectFallbackNodes(plan) + } else { + collectFallbackFunc.get.apply(plan) + } } + .reduce((a, b) => (a._1 + b._1, a._2 ++ b._2)) + } finally { + removeTags(plan) } - .reduce((a, b) => (a._1 + b._1, a._2 ++ b._2)) - } finally { - removeTags(plan) + } } - } - // scalastyle:on - // spotless:on /** * Traverses the supplied input plan in a bottom-up fashion and records the operator id via diff --git a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala index fa2991604616..fba6a4a5a48a 100644 --- a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala +++ b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala @@ -272,12 +272,17 @@ trait SparkShims { throw new UnsupportedOperationException("ArrayInsert not supported.") } - /** Shim method for GlutenExplainUtils.scala. */ + /** Shim method for usages from GlutenExplainUtils.scala. */ + def withOperatorIdMap[T](idMap: java.util.Map[QueryPlan[_], Int])(body: => T): T = { + body + } + + /** Shim method for usages from GlutenExplainUtils.scala. */ def getOperatorId(plan: QueryPlan[_]): Option[Int] - /** Shim method for GlutenExplainUtils.scala. */ + /** Shim method for usages from GlutenExplainUtils.scala. */ def setOperatorId(plan: QueryPlan[_], opId: Int): Unit - /** Shim method for GlutenExplainUtils.scala. */ + /** Shim method for usages from GlutenExplainUtils.scala. 
*/ def unsetOperatorId(plan: QueryPlan[_]): Unit } diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala index d130864a9fed..43ed51579a1b 100644 --- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala +++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala @@ -526,15 +526,27 @@ class Spark35Shims extends SparkShims { Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex)) } + override def withOperatorIdMap[T](idMap: java.util.Map[QueryPlan[_], Int])(body: => T): T = { + val prevIdMap = QueryPlan.localIdMap.get() + try { + QueryPlan.localIdMap.set(idMap) + body + } finally { + QueryPlan.localIdMap.set(prevIdMap) + } + } + override def getOperatorId(plan: QueryPlan[_]): Option[Int] = { - plan.getTagValue(QueryPlan.OP_ID_TAG) + Option(QueryPlan.localIdMap.get().get(plan)) } override def setOperatorId(plan: QueryPlan[_], opId: Int): Unit = { - plan.setTagValue(QueryPlan.OP_ID_TAG, opId) + val map = QueryPlan.localIdMap.get() + assert(!map.containsKey(plan)) + map.put(plan, opId) } override def unsetOperatorId(plan: QueryPlan[_]): Unit = { - plan.unsetTagValue(QueryPlan.OP_ID_TAG) + QueryPlan.localIdMap.get().remove(plan) } } From 6edc04ca621ef4d040444887b3da8f04009f191b Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Wed, 9 Oct 2024 18:52:35 +0800 Subject: [PATCH 3/4] fixup --- .../sql/shims/spark35/Spark35Shims.scala | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala index 43ed51579a1b..d130864a9fed 100644 --- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala +++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala @@ -526,27 +526,15 @@ class Spark35Shims extends SparkShims { Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex)) } - override def withOperatorIdMap[T](idMap: java.util.Map[QueryPlan[_], Int])(body: => T): T = { - val prevIdMap = QueryPlan.localIdMap.get() - try { - QueryPlan.localIdMap.set(idMap) - body - } finally { - QueryPlan.localIdMap.set(prevIdMap) - } - } - override def getOperatorId(plan: QueryPlan[_]): Option[Int] = { - Option(QueryPlan.localIdMap.get().get(plan)) + plan.getTagValue(QueryPlan.OP_ID_TAG) } override def setOperatorId(plan: QueryPlan[_], opId: Int): Unit = { - val map = QueryPlan.localIdMap.get() - assert(!map.containsKey(plan)) - map.put(plan, opId) + plan.setTagValue(QueryPlan.OP_ID_TAG, opId) } override def unsetOperatorId(plan: QueryPlan[_]): Unit = { - QueryPlan.localIdMap.get().remove(plan) + plan.unsetTagValue(QueryPlan.OP_ID_TAG) } } From 46816119c87bb0d05feb8ba79c2be3bfe6abcde5 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Wed, 9 Oct 2024 19:47:30 +0800 Subject: [PATCH 4/4] fixup --- .../apache/spark/sql/execution/GlutenExplainUtils.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala index ec529a7b12e2..fa697789c8cf 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala 
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala @@ -151,8 +151,8 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { // scalastyle:off /** * Given a input physical plan, performs the following tasks. - * 1. Generates the explain output for the input plan excluding the subquery plans. - * 2. Generates the explain output for each subquery referenced in the plan. + * 1. Generates the explain output for the input plan excluding the subquery plans. 2. Generates + * the explain output for each subquery referenced in the plan. */ // scalastyle:on // spotless:on @@ -164,8 +164,8 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper { SparkShimLoader.getSparkShims.withOperatorIdMap( new java.util.IdentityHashMap[QueryPlan[_], Int]()) { try { - // Initialize a reference-unique set of Operators to avoid accdiental overwrites and to allow - // intentional overwriting of IDs generated in previous AQE iteration + // Initialize a reference-unique set of Operators to avoid accdiental overwrites and to + // allow intentional overwriting of IDs generated in previous AQE iteration val operators = newSetFromMap[QueryPlan[_]](new util.IdentityHashMap()) // Initialize an array of ReusedExchanges to help find Adaptively Optimized Out // Exchanges as part of SPARK-42753
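
For context on the end state this series prepares for: after SPARK-48610 (slated for Spark 4.0), operator IDs for explain output move from the `QueryPlan.OP_ID_TAG` tree tag into the thread-local `QueryPlan.localIdMap`, keyed by plan reference. Below is a minimal sketch of the four operator-id shim members for such a Spark version; the method bodies mirror the `localIdMap`-based code that patch 2 briefly carried in `Spark35Shims` before patch 3 reverted it (stock Spark 3.5 lacks `QueryPlan.localIdMap`). It is written as a free-standing object rather than a full `SparkShims` subclass to stay self-contained, and the `OperatorIdShims40` name is hypothetical.

```scala
import org.apache.spark.sql.catalyst.plans.QueryPlan

// Hypothetical sketch against a Spark version that includes SPARK-48610,
// i.e. one where QueryPlan.localIdMap (a ThreadLocal holding a
// java.util.Map[QueryPlan[_], Int]) exists. Not part of this series' final state.
object OperatorIdShims40 {

  // Install the caller-provided identity map for the duration of `body`,
  // restoring the previous thread-local value afterwards so nested
  // explain runs do not clobber each other's ids.
  def withOperatorIdMap[T](idMap: java.util.Map[QueryPlan[_], Int])(body: => T): T = {
    val prevIdMap = QueryPlan.localIdMap.get()
    try {
      QueryPlan.localIdMap.set(idMap)
      body
    } finally {
      QueryPlan.localIdMap.set(prevIdMap)
    }
  }

  // Look up the id assigned to this exact plan instance in the current thread's map.
  def getOperatorId(plan: QueryPlan[_]): Option[Int] =
    Option(QueryPlan.localIdMap.get().get(plan))

  // Assign an id, asserting the plan instance has not been tagged already.
  def setOperatorId(plan: QueryPlan[_], opId: Int): Unit = {
    val map = QueryPlan.localIdMap.get()
    assert(!map.containsKey(plan))
    map.put(plan, opId)
  }

  def unsetOperatorId(plan: QueryPlan[_]): Unit =
    QueryPlan.localIdMap.get().remove(plan)
}
```

Because the map is identity-keyed and thread-local, `GlutenExplainUtils.processPlan` wraps each run in `withOperatorIdMap(new java.util.IdentityHashMap[QueryPlan[_], Int]())`; on tag-based Spark versions, the default no-op `withOperatorIdMap` in the common trait keeps that wrapper harmless, so the shared code path stays version-agnostic.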
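
On the consumer side, these internals feed the per-query fallback summary that `GlutenImplicits` exposes on `Dataset`. A hedged usage sketch, assuming the `fallbackSummary` member described in Gluten's documentation (its exact name and return shape are an assumption here, not taken from this diff):

```scala
// Given an active SparkSession `spark` with Gluten enabled.
import org.apache.spark.sql.execution.GlutenImplicits._

// Summarizes how many operators were offloaded to Gluten and which
// fell back to vanilla Spark, with a reason recorded per fallback node.
val df = spark.sql("SELECT c1, count(*) FROM t GROUP BY c1")
println(df.fallbackSummary)
```

Note the caveat already spelled out in the `GlutenImplicits` class comment above: if AQE is enabled but the query has not been materialized, the summary is computed from a re-planned, AQE-disabled copy of the query, so it can differ from what a materialized run would report.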