From e5daf3aa90574ff8a64a72a9daa0fbc20333bccd Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Mon, 13 Jan 2025 14:44:28 +0800 Subject: [PATCH 1/7] fixup --- .../clickhouse/CHListenerApi.scala | 4 +- .../backendsapi/clickhouse/CHRuleApi.scala | 6 +- .../backendsapi/velox/VeloxListenerApi.scala | 6 +- .../backendsapi/velox/VeloxRuleApi.scala | 11 ++- .../execution/VeloxRoughCostModel2Suite.scala | 65 ------------ .../transition/VeloxTransitionSuite.scala | 4 +- .../enumerated/EnumeratedTransform.scala | 4 +- .../planner/cost/GlutenCostModel.scala | 30 ++++++ .../planner/cost/LongCostModel.scala | 33 +++---- .../planner/cost/LongCosterChain.scala | 2 +- .../columnar/transition/Convention.scala | 32 +++--- .../transition/FloydWarshallGraph.scala | 10 +- .../columnar/transition/TransitionGraph.scala | 99 ++++++++++++++----- .../extension/injector/GlutenInjector.scala | 7 +- .../transition/FloydWarshallGraphSuite.scala | 11 ++- .../planner/cost/LegacyCoster.scala | 4 +- .../enumerated/planner/cost/RoughCoster.scala | 2 - .../planner/cost/RoughCoster2.scala | 83 ---------------- .../apache/gluten/config/GlutenConfig.scala | 32 +----- 19 files changed, 177 insertions(+), 268 deletions(-) delete mode 100644 backends-velox/src/test/scala/org/apache/gluten/execution/VeloxRoughCostModel2Suite.scala create mode 100644 gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/GlutenCostModel.scala delete mode 100644 gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/RoughCoster2.scala diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala index fbaf9e37c15f..48ef66ca74a8 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala @@ -23,6 +23,7 @@ import org.apache.gluten.execution.CHBroadcastBuildSideCache import org.apache.gluten.execution.datasource.GlutenFormatFactory import org.apache.gluten.expression.UDFMappings import org.apache.gluten.extension.ExpressionExtensionTrait +import org.apache.gluten.extension.columnar.transition.Convention import org.apache.gluten.jni.JniLibLoader import org.apache.gluten.vectorized.CHNativeExpressionEvaluator @@ -70,7 +71,8 @@ class CHListenerApi extends ListenerApi with Logging { override def onExecutorShutdown(): Unit = shutdown() private def initialize(conf: SparkConf, isDriver: Boolean): Unit = { - // Force batch type initializations. + // Do row / batch type initializations. + Convention.ensureSparkRowAndBatchTypesRegistered() CHBatch.ensureRegistered() SparkDirectoryUtil.init(conf) val libPath = conf.get(GlutenConfig.GLUTEN_LIB_PATH, StringUtils.EMPTY) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index 21ae342a2263..45c360978397 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -22,6 +22,8 @@ import org.apache.gluten.config.GlutenConfig import org.apache.gluten.extension._ import org.apache.gluten.extension.columnar._ import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast} +import org.apache.gluten.extension.columnar.enumerated.RemoveSort +import org.apache.gluten.extension.columnar.enumerated.planner.cost.{LegacyCoster, RoughCoster} import org.apache.gluten.extension.columnar.heuristic.{ExpandFallbackPolicy, HeuristicTransform} import org.apache.gluten.extension.columnar.offload.{OffloadExchange, OffloadJoin, OffloadOthers} import org.apache.gluten.extension.columnar.rewrite._ @@ -31,7 +33,6 @@ import org.apache.gluten.extension.injector.{Injector, SparkInjector} import org.apache.gluten.extension.injector.GlutenInjector.{LegacyInjector, RasInjector} import org.apache.gluten.parser.{GlutenCacheFilesSqlParser, GlutenClickhouseSqlParser} import org.apache.gluten.sql.shims.SparkShimLoader - import org.apache.spark.sql.catalyst.{CHAggregateFunctionRewriteRule, EqualToRewrite} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.delta.DeltaLogFileIndex @@ -142,6 +143,9 @@ object CHRuleApi { } private def injectRas(injector: RasInjector): Unit = { + // Register legacy coster for transition planner. + injector.injectCoster(_ => LegacyCoster) + // CH backend doesn't work with RAS at the moment. Inject a rule that aborts any // execution calls. injector.injectPreTransform( diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala index 5d75521b8473..0453558d1af7 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala @@ -22,6 +22,7 @@ import org.apache.gluten.columnarbatch.VeloxBatch import org.apache.gluten.config.GlutenConfig import org.apache.gluten.execution.datasource.GlutenFormatFactory import org.apache.gluten.expression.UDFMappings +import org.apache.gluten.extension.columnar.transition.Convention import org.apache.gluten.init.NativeBackendInitializer import org.apache.gluten.jni.{JniLibLoader, JniWorkspace} import org.apache.gluten.udf.UdfJniWrapper @@ -126,10 +127,11 @@ class VeloxListenerApi extends ListenerApi with Logging { override def onExecutorShutdown(): Unit = shutdown() private def initialize(conf: SparkConf, isDriver: Boolean): Unit = { - // Force batch type initializations. - VeloxBatch.ensureRegistered() + // Do row / batch type initializations. + Convention.ensureSparkRowAndBatchTypesRegistered() ArrowJavaBatch.ensureRegistered() ArrowNativeBatch.ensureRegistered() + VeloxBatch.ensureRegistered() // Register columnar shuffle so can be considered when // `org.apache.spark.shuffle.GlutenShuffleManager` is set as Spark shuffle manager. diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala index f3c75cd98318..6c60ab7d537f 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -23,7 +23,7 @@ import org.apache.gluten.extension._ import org.apache.gluten.extension.columnar._ import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast} import org.apache.gluten.extension.columnar.enumerated.{RasOffload, RemoveSort} -import org.apache.gluten.extension.columnar.enumerated.planner.cost.{LegacyCoster, RoughCoster, RoughCoster2} +import org.apache.gluten.extension.columnar.enumerated.planner.cost.{LegacyCoster, RoughCoster} import org.apache.gluten.extension.columnar.heuristic.{ExpandFallbackPolicy, HeuristicTransform} import org.apache.gluten.extension.columnar.offload.{OffloadExchange, OffloadJoin, OffloadOthers} import org.apache.gluten.extension.columnar.rewrite._ @@ -120,6 +120,10 @@ object VeloxRuleApi { } private def injectRas(injector: RasInjector): Unit = { + // Gluten RAS: Costers. + injector.injectCoster(_ => LegacyCoster) + injector.injectCoster(_ => RoughCoster) + // Gluten RAS: Pre rules. injector.injectPreTransform(_ => RemoveTransitions) injector.injectPreTransform(_ => PushDownInputFileExpression.PreOffload) @@ -131,6 +135,7 @@ object VeloxRuleApi { // Gluten RAS: The RAS rule. val validatorBuilder: GlutenConfig => Validator = conf => Validators.newValidator(conf) + injector.injectRasRule(_ => RemoveSort) val rewrites = Seq( RewriteIn, @@ -139,10 +144,6 @@ object VeloxRuleApi { PullOutPreProject, PullOutPostProject, ProjectColumnPruning) - injector.injectCoster(_ => LegacyCoster) - injector.injectCoster(_ => RoughCoster) - injector.injectCoster(_ => RoughCoster2) - injector.injectRasRule(_ => RemoveSort) val offloads: Seq[RasOffload] = Seq( RasOffload.from[Exchange](OffloadExchange()), RasOffload.from[BaseJoinExec](OffloadJoin()), diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxRoughCostModel2Suite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxRoughCostModel2Suite.scala deleted file mode 100644 index cf61a7323665..000000000000 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxRoughCostModel2Suite.scala +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.gluten.execution - -import org.apache.gluten.config.GlutenConfig - -import org.apache.spark.SparkConf -import org.apache.spark.sql.execution.ProjectExec - -class VeloxRoughCostModel2Suite extends VeloxWholeStageTransformerSuite { - override protected val resourcePath: String = "/tpch-data-parquet-velox" - override protected val fileFormat: String = "parquet" - - override def beforeAll(): Unit = { - super.beforeAll() - spark - .range(100) - .selectExpr("cast(id % 3 as int) as c1", "id as c2", "array(id, id + 1) as c3") - .write - .format("parquet") - .saveAsTable("tmp1") - } - - override protected def afterAll(): Unit = { - spark.sql("drop table tmp1") - super.afterAll() - } - - override protected def sparkConf: SparkConf = super.sparkConf - .set(GlutenConfig.RAS_ENABLED.key, "true") - .set(GlutenConfig.RAS_COST_MODEL.key, "rough2") - .set(GlutenConfig.VANILLA_VECTORIZED_READERS_ENABLED.key, "false") - - test("fallback trivial project if its neighbor nodes fell back") { - withSQLConf(GlutenConfig.COLUMNAR_FILESCAN_ENABLED.key -> "false") { - runQueryAndCompare("select c1 as c3 from tmp1") { - checkSparkOperatorMatch[ProjectExec] - } - } - } - - test("avoid adding r2c if r2c cost greater than native") { - withSQLConf( - GlutenConfig.COLUMNAR_FILESCAN_ENABLED.key -> "false", - GlutenConfig.RAS_ROUGH2_SIZEBYTES_THRESHOLD.key -> "1") { - runQueryAndCompare("select array_contains(c3, 0) as list from tmp1") { - checkSparkOperatorMatch[ProjectExec] - } - } - } -} diff --git a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/transition/VeloxTransitionSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/transition/VeloxTransitionSuite.scala index e14ffd43d82d..335844782a44 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/transition/VeloxTransitionSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/transition/VeloxTransitionSuite.scala @@ -200,13 +200,13 @@ class VeloxTransitionSuite extends SharedSparkSession { } override protected def beforeAll(): Unit = { - api.onExecutorStart(MockVeloxBackend.mockPluginContext()) super.beforeAll() + api.onExecutorStart(MockVeloxBackend.mockPluginContext()) } override protected def afterAll(): Unit = { - super.afterAll() api.onExecutorShutdown() + super.afterAll() } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala index 34b4005a756d..59e829e17936 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala @@ -20,11 +20,11 @@ import org.apache.gluten.component.Component import org.apache.gluten.exception.GlutenException import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall import org.apache.gluten.extension.columnar.enumerated.planner.GlutenOptimization +import org.apache.gluten.extension.columnar.enumerated.planner.cost.GlutenCostModel import org.apache.gluten.extension.columnar.enumerated.planner.property.Conv import org.apache.gluten.extension.injector.Injector import org.apache.gluten.extension.util.AdaptiveContext import org.apache.gluten.logging.LogLevelUtil -import org.apache.gluten.ras.CostModel import org.apache.gluten.ras.property.PropertySet import org.apache.gluten.ras.rule.RasRule @@ -44,7 +44,7 @@ import org.apache.spark.sql.execution._ * * The feature requires enabling RAS to function. */ -case class EnumeratedTransform(costModel: CostModel[SparkPlan], rules: Seq[RasRule[SparkPlan]]) +case class EnumeratedTransform(costModel: GlutenCostModel, rules: Seq[RasRule[SparkPlan]]) extends Rule[SparkPlan] with LogLevelUtil { diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/GlutenCostModel.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/GlutenCostModel.scala new file mode 100644 index 000000000000..41e5529d2eba --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/GlutenCostModel.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension.columnar.enumerated.planner.cost + +import org.apache.gluten.ras.{Cost, CostModel} + +import org.apache.spark.sql.execution.SparkPlan + +trait GlutenCostModel extends CostModel[SparkPlan] { + // Returns cost value of one + other. + def sum(one: Cost, other: Cost): Cost + // Returns cost value of one - other. + def diff(one: Cost, other: Cost): Cost + + def makeZeroCost(): Cost +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCostModel.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCostModel.scala index 393ac35de42f..9d543ea272ee 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCostModel.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCostModel.scala @@ -18,34 +18,37 @@ package org.apache.gluten.extension.columnar.enumerated.planner.cost import org.apache.gluten.exception.GlutenException import org.apache.gluten.extension.columnar.enumerated.planner.plan.GlutenPlanModel.GroupLeafExec -import org.apache.gluten.ras.{Cost, CostModel} +import org.apache.gluten.ras.Cost import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.SparkPlan import scala.collection.mutable -abstract class LongCostModel extends CostModel[SparkPlan] { +abstract class LongCostModel extends GlutenCostModel { private val infLongCost = Long.MaxValue + private val zeroLongCost = 0 override def costOf(node: SparkPlan): LongCost = node match { case _: GroupLeafExec => throw new IllegalStateException() case _ => LongCost(longCostOf(node)) } + override def sum(one: Cost, other: Cost): LongCost = (one, other) match { + case (LongCost(value), LongCost(otherValue)) => LongCost(Math.addExact(value, otherValue)) + } + // Returns cost value of one - other. + override def diff(one: Cost, other: Cost): Cost = (one, other) match { + case (LongCost(value), LongCost(otherValue)) => + val d = Math.subtractExact(value, otherValue) + require(d >= zeroLongCost, s"Difference between cost $one and $other should not be negative") + LongCost(d) + } + private def longCostOf(node: SparkPlan): Long = node match { case n => val selfCost = selfLongCostOf(n) - - // Sum with ceil to avoid overflow. - def safeSum(a: Long, b: Long): Long = { - assert(a >= 0) - assert(b >= 0) - val sum = a + b - if (sum < a || sum < b) Long.MaxValue else sum - } - - (n.children.map(longCostOf).toList :+ selfCost).reduce(safeSum) + (n.children.map(longCostOf).toSeq :+ selfCost).reduce[Long](Math.addExact) } def selfLongCostOf(node: SparkPlan): Long @@ -56,6 +59,7 @@ abstract class LongCostModel extends CostModel[SparkPlan] { } override def makeInfCost(): Cost = LongCost(infLongCost) + override def makeZeroCost(): Cost = LongCost(zeroLongCost) } object LongCostModel extends Logging { @@ -98,11 +102,6 @@ object LongCostModel extends Logging { override def name(): String = "rough" } - /** Compared with rough, rough2 can be more precise to avoid the costly r2c. */ - case object Rough2 extends Kind { - override def name(): String = "rough2" - } - class Registry private[LongCostModel] { private val lookup: mutable.Map[Kind, LongCosterChain.Builder] = mutable.Map() diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCosterChain.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCosterChain.scala index 8b0c8b9f2d8a..00980e7712a4 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCosterChain.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCosterChain.scala @@ -37,7 +37,7 @@ private class LongCosterChain private (costers: Seq[LongCoster]) extends LongCos case (c @ Some(_), _) => c } - .getOrElse(throw new GlutenException(s"Cost node found for node: $node")) + .getOrElse(throw new GlutenException(s"Cost not found for node: $node")) } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala index 0e5387559674..ff0f29585299 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala @@ -19,8 +19,6 @@ package org.apache.gluten.extension.columnar.transition import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan} import org.apache.spark.util.SparkVersionUtil -import java.util.concurrent.atomic.AtomicBoolean - import scala.collection.mutable /** @@ -33,6 +31,13 @@ sealed trait Convention { } object Convention { + def ensureSparkRowAndBatchTypesRegistered(): Unit = { + RowType.None.ensureRegistered() + RowType.VanillaRow.ensureRegistered() + BatchType.None.ensureRegistered() + BatchType.VanillaBatch.ensureRegistered() + } + implicit class ConventionOps(val conv: Convention) extends AnyVal { def isNone: Boolean = { conv.rowType == RowType.None && conv.batchType == BatchType.None @@ -80,10 +85,17 @@ object Convention { } sealed trait RowType extends TransitionGraph.Vertex with Serializable { - Transition.graph.addVertex(this) + import RowType._ + + final protected[this] def register0(): Unit = BatchType.synchronized { + assert(all.add(this)) + } } object RowType { + private val all: mutable.Set[RowType] = mutable.Set() + def values(): Set[RowType] = all.toSet + // None indicates that the plan doesn't support row-based processing. final case object None extends RowType final case object VanillaRow extends RowType @@ -91,24 +103,12 @@ object Convention { trait BatchType extends TransitionGraph.Vertex with Serializable { import BatchType._ - private val initialized: AtomicBoolean = new AtomicBoolean(false) - final def ensureRegistered(): Unit = { - if (!initialized.compareAndSet(false, true)) { - // Already registered. - return - } - register() - } - - final private def register(): Unit = BatchType.synchronized { + final protected[this] def register0(): Unit = BatchType.synchronized { assert(all.add(this)) - Transition.graph.addVertex(this) registerTransitions() } - ensureRegistered() - /** * User batch type could override this method to define transitions from/to this batch type by * calling the subsequent protected APIs. diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala index 2a4e1f422517..f497ebc88531 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala @@ -30,12 +30,11 @@ trait FloydWarshallGraph[V <: AnyRef, E <: AnyRef] { } object FloydWarshallGraph { - trait Cost { - def +(other: Cost): Cost - } + trait Cost trait CostModel[E <: AnyRef] { def zero(): Cost + def sum(one: Cost, other: Cost): Cost def costOf(edge: E): Cost def costComparator(): Ordering[Cost] } @@ -54,7 +53,10 @@ object FloydWarshallGraph { private case class Impl[E <: AnyRef](override val edges: Seq[E])(costModel: CostModel[E]) extends Path[E] { override val cost: Cost = { - edges.map(costModel.costOf).reduceOption(_ + _).getOrElse(costModel.zero()) + edges + .map(costModel.costOf) + .reduceOption((c1, c2) => costModel.sum(c1, c2)) + .getOrElse(costModel.zero()) } } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala index ef08a34d5615..1a1c97850d76 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala @@ -16,11 +16,36 @@ */ package org.apache.gluten.extension.columnar.transition +import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform +import org.apache.gluten.extension.columnar.transition.Convention.BatchType +import org.apache.gluten.ras.Cost + import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.SparkReflectionUtil +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.mutable + object TransitionGraph { trait Vertex { + private val initialized: AtomicBoolean = new AtomicBoolean(false) + + final def ensureRegistered(): Unit = { + if (!initialized.compareAndSet(false, true)) { + // Already registered. + return + } + register() + } + + final private def register(): Unit = BatchType.synchronized { + Transition.graph.addVertex(this) + register0() + } + + protected[this] def register0(): Unit + override def toString: String = SparkReflectionUtil.getSimpleClassName(this.getClass) } @@ -67,54 +92,76 @@ object TransitionGraph { } } - private case class TransitionCost(count: Int, nodeNames: Seq[String]) - extends FloydWarshallGraph.Cost { - override def +(other: FloydWarshallGraph.Cost): TransitionCost = { - other match { - case TransitionCost(otherCount, otherNodeNames) => - TransitionCost(count + otherCount, nodeNames ++ otherNodeNames) - } - } - } + /** Reuse RAS cost to represent transition cost. */ + private case class TransitionCost(value: Cost, nodeNames: Seq[String]) + extends FloydWarshallGraph.Cost - // TODO: Consolidate transition graph's cost model with RAS cost model. + /** + * The cost model reuses RAS's cost model to evaluate cost of transitions. + * + * Note the transition graph is built once for all subsequent Spark sessions created on the same + * driver, so any access to Spark dynamic SQL config in RAS cost model will not take effect for + * the transition cost evaluation. Hence, it's not recommended to access Spark dynamic + * configurations in RAS cost model as well. + */ private object TransitionCostModel extends FloydWarshallGraph.CostModel[Transition] { - override def zero(): TransitionCost = TransitionCost(0, Nil) + private val rasCostModel = EnumeratedTransform.static().costModel + + override def zero(): TransitionCost = TransitionCost(rasCostModel.makeZeroCost(), Nil) override def costOf(transition: Transition): TransitionCost = { costOf0(transition) } + override def sum( + one: FloydWarshallGraph.Cost, + other: FloydWarshallGraph.Cost): FloydWarshallGraph.Cost = (one, other) match { + case (TransitionCost(c1, p1), TransitionCost(c2, p2)) => + TransitionCost(rasCostModel.sum(c1, c2), p1 ++ p2) + } override def costComparator(): Ordering[FloydWarshallGraph.Cost] = { (x: FloydWarshallGraph.Cost, y: FloydWarshallGraph.Cost) => (x, y) match { - case (TransitionCost(count, nodeNames), TransitionCost(otherCount, otherNodeNames)) => - if (count != otherCount) { - count - otherCount + case (TransitionCost(v1, nodeNames1), TransitionCost(v2, nodeNames2)) => + val diff = rasCostModel.costComparator().compare(v1, v2) + if (diff != 0) { + diff } else { // To make the output order stable. - nodeNames.mkString.hashCode - otherNodeNames.mkString.hashCode + nodeNames1.mkString.hashCode - nodeNames2.mkString.hashCode } } } private def costOf0(transition: Transition): TransitionCost = { val leaf = DummySparkPlan() + val transited = transition.apply(leaf) /** * The calculation considers C2C's cost as half of C2R / R2C's cost. So query planner prefers * C2C than C2R / R2C. */ - def costOfPlan(plan: SparkPlan): TransitionCost = plan - .map { - case p if p == leaf => TransitionCost(0, Nil) - case node @ RowToColumnarLike(_) => TransitionCost(2, Seq(node.nodeName)) - case node @ ColumnarToRowLike(_) => TransitionCost(2, Seq(node.nodeName)) - case node @ ColumnarToColumnarLike(_) => TransitionCost(1, Seq(node.nodeName)) - } - .reduce((l, r) => l + r) + def rasCostOfPlan(plan: SparkPlan): Cost = rasCostModel.costOf(plan) + def nodeNamesOfPlan(plan: SparkPlan): Seq[String] = { + plan.map(_.nodeName).reverse + } + + val leafCost = rasCostOfPlan(leaf) + val accumulatedCost = rasCostOfPlan(transited) + val costDiff = rasCostModel.diff(accumulatedCost, leafCost) + + val leafNodeNames = nodeNamesOfPlan(leaf) + val accumulatedNodeNames = nodeNamesOfPlan(transited) + require( + accumulatedNodeNames.startsWith(leafNodeNames), + s"Transition should only add unary nodes on the input plan or leave it unchanged. Before: $leaf, after: $transited" + ) + val nodeNamesDiff = mutable.ListBuffer[String]() + nodeNamesDiff ++= accumulatedNodeNames + leafNodeNames.foreach(n => assert(nodeNamesDiff.remove(0) == n)) + assert( + nodeNamesDiff.size == accumulatedNodeNames.size - leafNodeNames.size, + s"Dummy leaf node not found in the transited plan: $transited") - val plan = transition.apply(leaf) - val cost = costOfPlan(plan) - cost + TransitionCost(costDiff, nodeNamesDiff.toSeq) } } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala index 11172a9b3636..23db1c436da8 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala @@ -21,9 +21,8 @@ import org.apache.gluten.extension.GlutenColumnarRule import org.apache.gluten.extension.columnar.ColumnarRuleApplier import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall import org.apache.gluten.extension.columnar.enumerated.{EnumeratedApplier, EnumeratedTransform} -import org.apache.gluten.extension.columnar.enumerated.planner.cost.{LongCoster, LongCostModel} +import org.apache.gluten.extension.columnar.enumerated.planner.cost.{GlutenCostModel, LongCoster, LongCostModel} import org.apache.gluten.extension.columnar.heuristic.{HeuristicApplier, HeuristicTransform} -import org.apache.gluten.ras.CostModel import org.apache.gluten.ras.rule.RasRule import org.apache.spark.internal.Logging @@ -149,7 +148,7 @@ object GlutenInjector { private def findCostModel( registry: LongCostModel.Registry, - aliasOrClass: String): CostModel[SparkPlan] = { + aliasOrClass: String): GlutenCostModel = { if (LongCostModel.Kind.values().contains(aliasOrClass)) { val kind = LongCostModel.Kind.values()(aliasOrClass) val model = registry.get(kind) @@ -159,7 +158,7 @@ object GlutenInjector { logInfo(s"Using user cost model: $aliasOrClass") val ctor = clazz.getDeclaredConstructor() ctor.setAccessible(true) - val model: CostModel[SparkPlan] = ctor.newInstance() + val model: GlutenCostModel = ctor.newInstance().asInstanceOf[GlutenCostModel] model } } diff --git a/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala b/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala index 6bc4ab804f1d..76ef86d5986b 100644 --- a/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala +++ b/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala @@ -87,14 +87,15 @@ private object FloydWarshallGraphSuite { } } - private case class LongCost(c: Long) extends FloydWarshallGraph.Cost { - override def +(other: FloydWarshallGraph.Cost): FloydWarshallGraph.Cost = other match { - case LongCost(o) => LongCost(c + o) - } - } + private case class LongCost(c: Long) extends FloydWarshallGraph.Cost private object CostModel extends FloydWarshallGraph.CostModel[Edge] { override def zero(): FloydWarshallGraph.Cost = LongCost(0) + override def sum( + one: FloydWarshallGraph.Cost, + other: FloydWarshallGraph.Cost): FloydWarshallGraph.Cost = { + LongCost(one.asInstanceOf[LongCost].c + other.asInstanceOf[LongCost].c) + } override def costOf(edge: Edge): FloydWarshallGraph.Cost = LongCost(edge.distance * 10) override def costComparator(): Ordering[FloydWarshallGraph.Cost] = Ordering.Long.on { case LongCost(c) => c diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LegacyCoster.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LegacyCoster.scala index 5cf9b87f2ac1..bb89d0035bf8 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LegacyCoster.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LegacyCoster.scala @@ -19,7 +19,7 @@ package org.apache.gluten.extension.columnar.enumerated.planner.cost import org.apache.gluten.extension.columnar.transition.{ColumnarToColumnarLike, ColumnarToRowLike, RowToColumnarLike} import org.apache.gluten.utils.PlanUtil -import org.apache.spark.sql.execution.{ColumnarToRowExec, ColumnarWriteFilesExec, ProjectExec, RowToColumnarExec, SparkPlan} +import org.apache.spark.sql.execution.{ColumnarWriteFilesExec, ProjectExec, SparkPlan} object LegacyCoster extends LongCoster { override def kind(): LongCostModel.Kind = LongCostModel.Legacy @@ -34,8 +34,6 @@ object LegacyCoster extends LongCoster { private def selfCostOf0(node: SparkPlan): Long = { node match { case ColumnarWriteFilesExec.OnNoopLeafPath(_) => 0 - case ColumnarToRowExec(_) => 10L - case RowToColumnarExec(_) => 10L case ColumnarToRowLike(_) => 10L case RowToColumnarLike(_) => 10L case ColumnarToColumnarLike(_) => 5L diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/RoughCoster.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/RoughCoster.scala index d2959d46a13c..ab893265ec42 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/RoughCoster.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/RoughCoster.scala @@ -42,8 +42,6 @@ object RoughCoster extends LongCoster { // Avoid moving computation back to native when transition has complex types in schema. // Such transitions are observed to be extremely expensive as of now. Long.MaxValue - case ColumnarToRowExec(_) => 10L - case RowToColumnarExec(_) => 10L case ColumnarToRowLike(_) => 10L case RowToColumnarLike(_) => 10L case ColumnarToColumnarLike(_) => 5L diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/RoughCoster2.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/RoughCoster2.scala deleted file mode 100644 index e46274a79f69..000000000000 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/RoughCoster2.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.gluten.extension.columnar.enumerated.planner.cost - -import org.apache.gluten.config.GlutenConfig -import org.apache.gluten.extension.columnar.transition.{ColumnarToRowLike, RowToColumnarLike} -import org.apache.gluten.utils.PlanUtil - -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression} -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExecBase - -// Since https://github.com/apache/incubator-gluten/pull/7686. -object RoughCoster2 extends LongCoster { - override def kind(): LongCostModel.Kind = LongCostModel.Rough2 - - override def selfCostOf(node: SparkPlan): Option[Long] = { - Some(selfCostOf0(node)) - } - - private def selfCostOf0(node: SparkPlan): Long = { - val sizeFactor = getSizeFactor(node) - val opCost = node match { - case ProjectExec(projectList, _) if projectList.forall(isCheapExpression) => - // Make trivial ProjectExec has the same cost as ProjectExecTransform to reduce unnecessary - // c2r and r2c. - 1L - case ColumnarToRowExec(_) => 1L - case RowToColumnarExec(_) => 1L - case ColumnarToRowLike(_) => 1L - case RowToColumnarLike(_) => - // If sizeBytes is less than the threshold, the cost of RowToColumnarLike is ignored. - if (sizeFactor == 0) 1L else GlutenConfig.get.rasRough2R2cCost - case p if PlanUtil.isGlutenColumnarOp(p) => 1L - case p if PlanUtil.isVanillaColumnarOp(p) => GlutenConfig.get.rasRough2VanillaCost - // Other row ops. Usually a vanilla row op. - case _ => GlutenConfig.get.rasRough2VanillaCost - } - opCost * Math.max(1, sizeFactor) - } - - private def getSizeFactor(plan: SparkPlan): Long = { - // Get the bytes size that the plan needs to consume. - val sizeBytes = plan match { - case _: DataSourceScanExec | _: DataSourceV2ScanExecBase => getStatSizeBytes(plan) - case _: LeafExecNode => 0L - case p => p.children.map(getStatSizeBytes).sum - } - sizeBytes / GlutenConfig.get.rasRough2SizeBytesThreshold - } - - private def getStatSizeBytes(plan: SparkPlan): Long = { - plan match { - case a: AdaptiveSparkPlanExec => getStatSizeBytes(a.inputPlan) - case _ => - plan.logicalLink match { - case Some(logicalPlan) => logicalPlan.stats.sizeInBytes.toLong - case _ => plan.children.map(getStatSizeBytes).sum - } - } - } - - private def isCheapExpression(ne: NamedExpression): Boolean = ne match { - case Alias(_: Attribute, _) => true - case _: Attribute => true - case _ => false - } -} diff --git a/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala index d4083d5896eb..eb9071badb14 100644 --- a/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala +++ b/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala @@ -295,12 +295,6 @@ class GlutenConfig(conf: SQLConf) extends Logging { def rasCostModel: String = getConf(RAS_COST_MODEL) - def rasRough2SizeBytesThreshold: Long = getConf(RAS_ROUGH2_SIZEBYTES_THRESHOLD) - - def rasRough2R2cCost: Long = getConf(RAS_ROUGH2_R2C_COST) - - def rasRough2VanillaCost: Long = getConf(RAS_ROUGH2_VANILLA_COST) - def enableVeloxCache: Boolean = getConf(COLUMNAR_VELOX_CACHE_ENABLED) def veloxMemCacheSize: Long = getConf(COLUMNAR_VELOX_MEM_CACHE_SIZE) @@ -1459,32 +1453,12 @@ object GlutenConfig { val RAS_COST_MODEL = buildConf("spark.gluten.ras.costModel") .doc( - "Experimental: The class name of user-defined cost model that will be used by RAS. If " + - "not specified, a legacy built-in cost model that exhaustively offloads computations " + - "will be used.") + "The class name of user-defined cost model that will be used by Gluten's transition " + + "planner as well as by RAS. If not specified, a legacy built-in cost model that " + + "exhaustively offloads computations will be used.") .stringConf .createWithDefaultString("legacy") - val RAS_ROUGH2_SIZEBYTES_THRESHOLD = - buildConf("spark.gluten.ras.rough2.sizeBytesThreshold") - .doc( - "Experimental: Threshold of the byte size consumed by sparkPlan, coefficient used " + - "to calculate cost in RAS rough2 model") - .longConf - .createWithDefault(1073741824L) - - val RAS_ROUGH2_R2C_COST = - buildConf("spark.gluten.ras.rough2.r2c.cost") - .doc("Experimental: Cost of RowToVeloxColumnarExec in RAS rough2 model") - .longConf - .createWithDefault(100L) - - val RAS_ROUGH2_VANILLA_COST = - buildConf("spark.gluten.ras.rough2.vanilla.cost") - .doc("Experimental: Cost of vanilla spark operater in RAS rough model") - .longConf - .createWithDefault(20L) - // velox caching options. val COLUMNAR_VELOX_CACHE_ENABLED = buildStaticConf("spark.gluten.sql.columnar.backend.velox.cacheEnabled") From e900097390fe408addab5f2e7057a31e6f9ac269 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Tue, 14 Jan 2025 12:10:00 +0800 Subject: [PATCH 2/7] fixup --- .../org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala | 1 + .../gluten/extension/columnar/transition/TransitionGraph.scala | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index 45c360978397..debc89f6114f 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -33,6 +33,7 @@ import org.apache.gluten.extension.injector.{Injector, SparkInjector} import org.apache.gluten.extension.injector.GlutenInjector.{LegacyInjector, RasInjector} import org.apache.gluten.parser.{GlutenCacheFilesSqlParser, GlutenClickhouseSqlParser} import org.apache.gluten.sql.shims.SparkShimLoader + import org.apache.spark.sql.catalyst.{CHAggregateFunctionRewriteRule, EqualToRewrite} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.delta.DeltaLogFileIndex diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala index 1a1c97850d76..7d881ced1c51 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala @@ -152,7 +152,8 @@ object TransitionGraph { val accumulatedNodeNames = nodeNamesOfPlan(transited) require( accumulatedNodeNames.startsWith(leafNodeNames), - s"Transition should only add unary nodes on the input plan or leave it unchanged. Before: $leaf, after: $transited" + s"Transition should only add unary nodes on the input plan or leave it unchanged. " + + s"Before: $leaf, after: $transited" ) val nodeNamesDiff = mutable.ListBuffer[String]() nodeNamesDiff ++= accumulatedNodeNames From da3c6ee85c035af04ed7c031959600cc6275dd8c Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Tue, 14 Jan 2025 14:29:53 +0800 Subject: [PATCH 3/7] fixup --- .../backendsapi/clickhouse/CHRuleApi.scala | 3 +- .../columnar/transition/ConventionFunc.scala | 22 --------- .../columnar/transition/TransitionSuite.scala | 47 ++++++++++++------- 3 files changed, 30 insertions(+), 42 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index debc89f6114f..426c88c9073f 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -22,8 +22,7 @@ import org.apache.gluten.config.GlutenConfig import org.apache.gluten.extension._ import org.apache.gluten.extension.columnar._ import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast} -import org.apache.gluten.extension.columnar.enumerated.RemoveSort -import org.apache.gluten.extension.columnar.enumerated.planner.cost.{LegacyCoster, RoughCoster} +import org.apache.gluten.extension.columnar.enumerated.planner.cost.LegacyCoster import org.apache.gluten.extension.columnar.heuristic.{ExpandFallbackPolicy, HeuristicTransform} import org.apache.gluten.extension.columnar.offload.{OffloadExchange, OffloadJoin, OffloadOthers} import org.apache.gluten.extension.columnar.rewrite._ diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala index c4405aeb8d0a..3105713d989d 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnionExec} import org.apache.spark.sql.execution.adaptive.QueryStageExec import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.exchange.ReusedExchangeExec -import org.apache.spark.util.SparkTestUtil /** ConventionFunc is a utility to derive [[Convention]] or [[ConventionReq]] from a query plan. */ sealed trait ConventionFunc { @@ -43,33 +42,12 @@ object ConventionFunc { object Empty extends Override } - // For testing, to make things work without a backend loaded. - private var ignoreBackend: Boolean = false - - // Visible for testing. - def ignoreBackend[T](body: => T): T = synchronized { - assert(SparkTestUtil.isTesting) - assert(!ignoreBackend) - ignoreBackend = true - try { - body - } finally { - ignoreBackend = false - } - } - def create(): ConventionFunc = { val batchOverride = newOverride() new BuiltinFunc(batchOverride) } private def newOverride(): Override = { - synchronized { - if (ignoreBackend) { - // For testing - return Override.Empty - } - } // Components should override Backend's convention function. Hence, reversed injection order // is applied. val overrides = Component.sorted().reverse.map(_.convFuncOverride()) diff --git a/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala b/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala index fec36ac1acfa..2c423783fdcc 100644 --- a/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala +++ b/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala @@ -16,8 +16,12 @@ */ package org.apache.gluten.extension.columnar.transition +import org.apache.gluten.backend.Backend +import org.apache.gluten.component.Component import org.apache.gluten.exception.GlutenException import org.apache.gluten.execution.{ColumnarToColumnarExec, GlutenPlan} +import org.apache.gluten.extension.columnar.enumerated.planner.cost.LegacyCoster +import org.apache.gluten.extension.injector.Injector import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -28,35 +32,38 @@ import org.apache.spark.sql.vectorized.ColumnarBatch class TransitionSuite extends SharedSparkSession { import TransitionSuite._ + + override protected def beforeAll(): Unit = { + super.beforeAll() + new DummyBackend().ensureRegistered() + Convention.ensureSparkRowAndBatchTypesRegistered() + TypeA.ensureRegistered() + TypeB.ensureRegistered() + TypeC.ensureRegistered() + TypeD.ensureRegistered() + } + test("Trivial C2R") { val in = BatchLeaf(TypeA) - val out = ConventionFunc.ignoreBackend { - Transitions.insert(in, outputsColumnar = false) - } + val out = Transitions.insert(in, outputsColumnar = false) assert(out == BatchToRow(TypeA, BatchLeaf(TypeA))) } test("Insert C2R") { val in = RowUnary(BatchLeaf(TypeA)) - val out = ConventionFunc.ignoreBackend { - Transitions.insert(in, outputsColumnar = false) - } + val out = Transitions.insert(in, outputsColumnar = false) assert(out == RowUnary(BatchToRow(TypeA, BatchLeaf(TypeA)))) } test("Insert R2C") { val in = BatchUnary(TypeA, RowLeaf()) - val out = ConventionFunc.ignoreBackend { - Transitions.insert(in, outputsColumnar = false) - } + val out = Transitions.insert(in, outputsColumnar = false) assert(out == BatchToRow(TypeA, BatchUnary(TypeA, RowToBatch(TypeA, RowLeaf())))) } test("Insert C2R2C") { val in = BatchUnary(TypeA, BatchLeaf(TypeB)) - val out = ConventionFunc.ignoreBackend { - Transitions.insert(in, outputsColumnar = false) - } + val out = Transitions.insert(in, outputsColumnar = false) assert( out == BatchToRow( TypeA, @@ -65,9 +72,7 @@ class TransitionSuite extends SharedSparkSession { test("Insert C2C") { val in = BatchUnary(TypeA, BatchLeaf(TypeC)) - val out = ConventionFunc.ignoreBackend { - Transitions.insert(in, outputsColumnar = false) - } + val out = Transitions.insert(in, outputsColumnar = false) assert( out == BatchToRow( TypeA, @@ -77,9 +82,7 @@ class TransitionSuite extends SharedSparkSession { test("No transitions found") { val in = BatchUnary(TypeA, BatchLeaf(TypeD)) assertThrows[GlutenException] { - ConventionFunc.ignoreBackend { - Transitions.insert(in, outputsColumnar = false) - } + Transitions.insert(in, outputsColumnar = false) } } } @@ -145,4 +148,12 @@ object TransitionSuite extends TransitionSuiteBase { throw new UnsupportedOperationException() } + class DummyBackend extends Backend { + override def name(): String = "dummy-backend" + override def buildInfo(): Component.BuildInfo = + Component.BuildInfo("DUMMY_BACKEND", "N/A", "N/A", "N/A") + override def injectRules(injector: Injector): Unit = { + injector.gluten.ras.injectCoster(_ => LegacyCoster) + } + } } From 41e70ec316e02c3f9018faf72d322b8bfa5f67ab Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Tue, 14 Jan 2025 15:36:17 +0800 Subject: [PATCH 4/7] fixup --- .../transition/FloydWarshallGraph.scala | 32 +++++++++++-------- .../columnar/transition/TransitionGraph.scala | 2 +- .../transition/FloydWarshallGraphSuite.scala | 2 +- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala index f497ebc88531..307fdb510507 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala @@ -41,18 +41,17 @@ object FloydWarshallGraph { trait Path[E <: AnyRef] { def edges(): Seq[E] - def cost(): Cost + def cost(costModel: CostModel[E]): Cost } - def builder[V <: AnyRef, E <: AnyRef](costModel: CostModel[E]): Builder[V, E] = { - Builder.create(costModel) + def builder[V <: AnyRef, E <: AnyRef](costModelFactory: () => CostModel[E]): Builder[V, E] = { + Builder.create(costModelFactory) } private object Path { - def apply[E <: AnyRef](costModel: CostModel[E], edges: Seq[E]): Path[E] = Impl(edges)(costModel) - private case class Impl[E <: AnyRef](override val edges: Seq[E])(costModel: CostModel[E]) - extends Path[E] { - override val cost: Cost = { + def apply[E <: AnyRef](edges: Seq[E]): Path[E] = Impl(edges) + private case class Impl[E <: AnyRef](override val edges: Seq[E]) extends Path[E] { + override def cost(costModel: CostModel[E]): Cost = { edges .map(costModel.costOf) .reduceOption((c1, c2) => costModel.sum(c1, c2)) @@ -89,13 +88,13 @@ object FloydWarshallGraph { private object Builder { // Thread safe. - private class Impl[V <: AnyRef, E <: AnyRef](costModel: CostModel[E]) extends Builder[V, E] { + private class Impl[V <: AnyRef, E <: AnyRef](costModelFactory: () => CostModel[E]) extends Builder[V, E] { private val pathTable: mutable.Map[V, mutable.Map[V, Path[E]]] = mutable.Map() private var graph: Option[FloydWarshallGraph[V, E]] = None override def addVertex(v: V): Builder[V, E] = synchronized { assert(!pathTable.contains(v), s"Vertex $v already exists in graph") - pathTable.getOrElseUpdate(v, mutable.Map()).getOrElseUpdate(v, Path(costModel, Nil)) + pathTable.getOrElseUpdate(v, mutable.Map()).getOrElseUpdate(v, Path(Nil)) graph = None this } @@ -105,7 +104,7 @@ object FloydWarshallGraph { assert(pathTable.contains(from), s"Vertex $from not exists in graph") assert(pathTable.contains(to), s"Vertex $to not exists in graph") assert(!hasPath(from, to), s"Path from $from to $to already exists in graph") - pathTable(from) += to -> Path(costModel, Seq(edge)) + pathTable(from) += to -> Path(Seq(edge)) graph = None this } @@ -129,6 +128,7 @@ object FloydWarshallGraph { } private def compile(): FloydWarshallGraph[V, E] = { + val costModel = costModelFactory() val vertices = pathTable.keys for (k <- vertices) { for (i <- vertices) { @@ -136,12 +136,16 @@ object FloydWarshallGraph { if (hasPath(i, k) && hasPath(k, j)) { val pathIk = pathTable(i)(k) val pathKj = pathTable(k)(j) - val newPath = Path(costModel, pathIk.edges() ++ pathKj.edges()) + val newPath = Path(pathIk.edges() ++ pathKj.edges()) if (!hasPath(i, j)) { pathTable(i) += j -> newPath } else { val path = pathTable(i)(j) - if (costModel.costComparator().compare(newPath.cost(), path.cost()) < 0) { + if ( + costModel + .costComparator() + .compare(newPath.cost(costModel), path.cost(costModel)) < 0 + ) { pathTable(i) += j -> newPath } } @@ -153,8 +157,8 @@ object FloydWarshallGraph { } } - def create[V <: AnyRef, E <: AnyRef](costModel: CostModel[E]): Builder[V, E] = { - new Impl(costModel) + def create[V <: AnyRef, E <: AnyRef](costModelFactory: () => CostModel[E]): Builder[V, E] = { + new Impl(costModelFactory) } } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala index 7d881ced1c51..0aa20ce66d42 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala @@ -52,7 +52,7 @@ object TransitionGraph { type Builder = FloydWarshallGraph.Builder[TransitionGraph.Vertex, Transition] def builder(): Builder = { - FloydWarshallGraph.builder(TransitionCostModel) + FloydWarshallGraph.builder(() => TransitionCostModel) } implicit class TransitionGraphOps(val graph: TransitionGraph) { diff --git a/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala b/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala index 76ef86d5986b..7b60940a1ae2 100644 --- a/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala +++ b/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala @@ -36,7 +36,7 @@ class FloydWarshallGraphSuite extends AnyFunSuite { val e42 = Edge(3) val graph = FloydWarshallGraph - .builder(CostModel) + .builder(() => CostModel) .addVertex(v0) .addVertex(v1) .addVertex(v2) From 8ab345eb83dfa0611ec1adbde5f8e9d40a1d5190 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Tue, 14 Jan 2025 15:42:31 +0800 Subject: [PATCH 5/7] fixup --- .../extension/columnar/transition/FloydWarshallGraph.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala index 307fdb510507..b05e93968711 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala @@ -88,7 +88,8 @@ object FloydWarshallGraph { private object Builder { // Thread safe. - private class Impl[V <: AnyRef, E <: AnyRef](costModelFactory: () => CostModel[E]) extends Builder[V, E] { + private class Impl[V <: AnyRef, E <: AnyRef](costModelFactory: () => CostModel[E]) + extends Builder[V, E] { private val pathTable: mutable.Map[V, mutable.Map[V, Path[E]]] = mutable.Map() private var graph: Option[FloydWarshallGraph[V, E]] = None From fe44748a06166496fa571d77d7935fafb8ca4a46 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Tue, 14 Jan 2025 15:45:31 +0800 Subject: [PATCH 6/7] fixup --- .../extension/columnar/transition/TransitionGraph.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala index 0aa20ce66d42..8e9744383107 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala @@ -52,7 +52,7 @@ object TransitionGraph { type Builder = FloydWarshallGraph.Builder[TransitionGraph.Vertex, Transition] def builder(): Builder = { - FloydWarshallGraph.builder(() => TransitionCostModel) + FloydWarshallGraph.builder(() => new TransitionCostModel()) } implicit class TransitionGraphOps(val graph: TransitionGraph) { @@ -104,7 +104,7 @@ object TransitionGraph { * the transition cost evaluation. Hence, it's not recommended to access Spark dynamic * configurations in RAS cost model as well. */ - private object TransitionCostModel extends FloydWarshallGraph.CostModel[Transition] { + private class TransitionCostModel() extends FloydWarshallGraph.CostModel[Transition] { private val rasCostModel = EnumeratedTransform.static().costModel override def zero(): TransitionCost = TransitionCost(rasCostModel.makeZeroCost(), Nil) From 94073c80370bb3dc7552811765d0e79e4ca53bac Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Tue, 14 Jan 2025 16:59:41 +0800 Subject: [PATCH 7/7] fixup --- .../enumerated/planner/VeloxRasSuite.scala | 25 +++++++++++++------ .../planner/cost/LongCostModel.scala | 13 ++++++++-- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala index 65d32ebf6162..e7de629b39e3 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala @@ -18,10 +18,10 @@ package org.apache.gluten.extension.columnar.enumerated.planner import org.apache.gluten.config.GlutenConfig import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform -import org.apache.gluten.extension.columnar.enumerated.planner.cost.{LegacyCoster, LongCostModel} +import org.apache.gluten.extension.columnar.enumerated.planner.cost.{GlutenCostModel, LegacyCoster, LongCostModel} import org.apache.gluten.extension.columnar.enumerated.planner.property.Conv -import org.apache.gluten.extension.columnar.transition.ConventionReq -import org.apache.gluten.ras.{Cost, CostModel, Ras} +import org.apache.gluten.extension.columnar.transition.{Convention, ConventionReq} +import org.apache.gluten.ras.{Cost, Ras} import org.apache.gluten.ras.RasSuiteBase._ import org.apache.gluten.ras.path.RasPath import org.apache.gluten.ras.property.PropertySet @@ -37,6 +37,11 @@ import org.apache.spark.sql.types.StringType class VeloxRasSuite extends SharedSparkSession { import VeloxRasSuite._ + override protected def beforeAll(): Unit = { + super.beforeAll() + Convention.ensureSparkRowAndBatchTypesRegistered() + } + test("C2R, R2C - basic") { val in = RowUnary(RowLeaf(TRIVIAL_SCHEMA)) val planner = newRas().newPlanner(in) @@ -153,14 +158,14 @@ object VeloxRasSuite { .asInstanceOf[Ras[SparkPlan]] } - private def legacyCostModel(): CostModel[SparkPlan] = { + private def legacyCostModel(): GlutenCostModel = { val registry = LongCostModel.registry() val coster = LegacyCoster registry.register(coster) registry.get(coster.kind()) } - private def sessionCostModel(): CostModel[SparkPlan] = { + private def sessionCostModel(): GlutenCostModel = { val transform = EnumeratedTransform.static() transform.costModel } @@ -198,7 +203,7 @@ object VeloxRasSuite { override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1) } - class UserCostModel1 extends CostModel[SparkPlan] { + class UserCostModel1 extends GlutenCostModel { private val base = legacyCostModel() override def costOf(node: SparkPlan): Cost = node match { case _: RowUnary => base.makeInfCost() @@ -206,9 +211,12 @@ object VeloxRasSuite { } override def costComparator(): Ordering[Cost] = base.costComparator() override def makeInfCost(): Cost = base.makeInfCost() + override def sum(one: Cost, other: Cost): Cost = base.sum(one, other) + override def diff(one: Cost, other: Cost): Cost = base.diff(one, other) + override def makeZeroCost(): Cost = base.makeZeroCost() } - class UserCostModel2 extends CostModel[SparkPlan] { + class UserCostModel2 extends GlutenCostModel { private val base = legacyCostModel() override def costOf(node: SparkPlan): Cost = node match { case _: ColumnarUnary => base.makeInfCost() @@ -216,5 +224,8 @@ object VeloxRasSuite { } override def costComparator(): Ordering[Cost] = base.costComparator() override def makeInfCost(): Cost = base.makeInfCost() + override def sum(one: Cost, other: Cost): Cost = base.sum(one, other) + override def diff(one: Cost, other: Cost): Cost = base.diff(one, other) + override def makeZeroCost(): Cost = base.makeZeroCost() } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCostModel.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCostModel.scala index 9d543ea272ee..0d11541b73dd 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCostModel.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCostModel.scala @@ -34,9 +34,18 @@ abstract class LongCostModel extends GlutenCostModel { case _ => LongCost(longCostOf(node)) } + // Sum with ceil to avoid overflow. + private def safeSum(a: Long, b: Long): Long = { + assert(a >= 0) + assert(b >= 0) + val sum = a + b + if (sum < a || sum < b) Long.MaxValue else sum + } + override def sum(one: Cost, other: Cost): LongCost = (one, other) match { - case (LongCost(value), LongCost(otherValue)) => LongCost(Math.addExact(value, otherValue)) + case (LongCost(value), LongCost(otherValue)) => LongCost(safeSum(value, otherValue)) } + // Returns cost value of one - other. override def diff(one: Cost, other: Cost): Cost = (one, other) match { case (LongCost(value), LongCost(otherValue)) => @@ -48,7 +57,7 @@ abstract class LongCostModel extends GlutenCostModel { private def longCostOf(node: SparkPlan): Long = node match { case n => val selfCost = selfLongCostOf(n) - (n.children.map(longCostOf).toSeq :+ selfCost).reduce[Long](Math.addExact) + (n.children.map(longCostOf).toSeq :+ selfCost).reduce[Long](safeSum) } def selfLongCostOf(node: SparkPlan): Long