diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index aa9e3e553c17..962698759320 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -24,6 +24,7 @@ import org.apache.gluten.config.GlutenConfig import org.apache.gluten.execution.WriteFilesExecTransformer import org.apache.gluten.expression.WindowFunctionsBuilder import org.apache.gluten.extension.ValidationResult +import org.apache.gluten.extension.columnar.cost.{LegacyCoster, LongCoster} import org.apache.gluten.extension.columnar.transition.{Convention, ConventionFunc} import org.apache.gluten.substrait.rel.LocalFilesNode import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat @@ -54,7 +55,6 @@ class CHBackend extends SubstraitBackend { override def name(): String = CHConf.BACKEND_NAME override def buildInfo(): BuildInfo = BuildInfo("ClickHouse", CH_BRANCH, CH_COMMIT, "UNKNOWN") - override def convFuncOverride(): ConventionFunc.Override = new ConvFunc() override def iteratorApi(): IteratorApi = new CHIteratorApi override def sparkPlanExecApi(): SparkPlanExecApi = new CHSparkPlanExecApi override def transformerApi(): TransformerApi = new CHTransformerApi @@ -63,6 +63,8 @@ class CHBackend extends SubstraitBackend { override def listenerApi(): ListenerApi = new CHListenerApi override def ruleApi(): RuleApi = new CHRuleApi override def settings(): BackendSettingsApi = CHBackendSettings + override def convFuncOverride(): ConventionFunc.Override = new ConvFunc() + override def costers(): Seq[LongCoster] = Seq(LegacyCoster) } object CHBackend { diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index 426c88c9073f..21ae342a2263 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -22,7 +22,6 @@ import org.apache.gluten.config.GlutenConfig import org.apache.gluten.extension._ import org.apache.gluten.extension.columnar._ import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast} -import org.apache.gluten.extension.columnar.enumerated.planner.cost.LegacyCoster import org.apache.gluten.extension.columnar.heuristic.{ExpandFallbackPolicy, HeuristicTransform} import org.apache.gluten.extension.columnar.offload.{OffloadExchange, OffloadJoin, OffloadOthers} import org.apache.gluten.extension.columnar.rewrite._ @@ -143,9 +142,6 @@ object CHRuleApi { } private def injectRas(injector: RasInjector): Unit = { - // Register legacy coster for transition planner. - injector.injectCoster(_ => LegacyCoster) - // CH backend doesn't work with RAS at the moment. Inject a rule that aborts any // execution calls. injector.injectPreTransform( diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala index 519e98c5d459..677d8792c7ba 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala @@ -25,6 +25,7 @@ import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.execution.WriteFilesExecTransformer import org.apache.gluten.expression.WindowFunctionsBuilder import org.apache.gluten.extension.ValidationResult +import org.apache.gluten.extension.columnar.cost.{LegacyCoster, LongCoster, RoughCoster} import org.apache.gluten.extension.columnar.transition.{Convention, ConventionFunc} import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.substrait.rel.LocalFilesNode @@ -61,7 +62,6 @@ class VeloxBackend extends SubstraitBackend { override def name(): String = VeloxBackend.BACKEND_NAME override def buildInfo(): BuildInfo = BuildInfo("Velox", VELOX_BRANCH, VELOX_REVISION, VELOX_REVISION_TIME) - override def convFuncOverride(): ConventionFunc.Override = new ConvFunc() override def iteratorApi(): IteratorApi = new VeloxIteratorApi override def sparkPlanExecApi(): SparkPlanExecApi = new VeloxSparkPlanExecApi override def transformerApi(): TransformerApi = new VeloxTransformerApi @@ -70,6 +70,8 @@ class VeloxBackend extends SubstraitBackend { override def listenerApi(): ListenerApi = new VeloxListenerApi override def ruleApi(): RuleApi = new VeloxRuleApi override def settings(): BackendSettingsApi = VeloxBackendSettings + override def convFuncOverride(): ConventionFunc.Override = new ConvFunc() + override def costers(): Seq[LongCoster] = Seq(LegacyCoster, RoughCoster) } object VeloxBackend { diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala index 6c60ab7d537f..0cf6ac671309 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -23,7 +23,6 @@ import org.apache.gluten.extension._ import org.apache.gluten.extension.columnar._ import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast} import org.apache.gluten.extension.columnar.enumerated.{RasOffload, RemoveSort} -import org.apache.gluten.extension.columnar.enumerated.planner.cost.{LegacyCoster, RoughCoster} import org.apache.gluten.extension.columnar.heuristic.{ExpandFallbackPolicy, HeuristicTransform} import org.apache.gluten.extension.columnar.offload.{OffloadExchange, OffloadJoin, OffloadOthers} import org.apache.gluten.extension.columnar.rewrite._ @@ -120,10 +119,6 @@ object VeloxRuleApi { } private def injectRas(injector: RasInjector): Unit = { - // Gluten RAS: Costers. - injector.injectCoster(_ => LegacyCoster) - injector.injectCoster(_ => RoughCoster) - // Gluten RAS: Pre rules. injector.injectPreTransform(_ => RemoveTransitions) injector.injectPreTransform(_ => PushDownInputFileExpression.PreOffload) diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala index 5b6b8a30edfa..f0be15f07db8 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala @@ -34,7 +34,6 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConverters class MiscOperatorSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper { - protected val rootPath: String = getClass.getResource("/").getPath override protected val resourcePath: String = "/tpch-data-parquet" override protected val fileFormat: String = "parquet" diff --git a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala index e7de629b39e3..050a881394c1 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala @@ -17,11 +17,11 @@ package org.apache.gluten.extension.columnar.enumerated.planner import org.apache.gluten.config.GlutenConfig +import org.apache.gluten.extension.columnar.cost.{GlutenCost, GlutenCostModel, LegacyCoster, LongCostModel} import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform -import org.apache.gluten.extension.columnar.enumerated.planner.cost.{GlutenCostModel, LegacyCoster, LongCostModel} import org.apache.gluten.extension.columnar.enumerated.planner.property.Conv import org.apache.gluten.extension.columnar.transition.{Convention, ConventionReq} -import org.apache.gluten.ras.{Cost, Ras} +import org.apache.gluten.ras.Ras import org.apache.gluten.ras.RasSuiteBase._ import org.apache.gluten.ras.path.RasPath import org.apache.gluten.ras.property.PropertySet @@ -152,7 +152,7 @@ object VeloxRasSuite { def newRas(rasRules: Seq[RasRule[SparkPlan]]): Ras[SparkPlan] = { GlutenOptimization .builder() - .costModel(sessionCostModel()) + .costModel(EnumeratedTransform.asRasCostModel(sessionCostModel())) .addRules(rasRules) .create() .asInstanceOf[Ras[SparkPlan]] @@ -205,27 +205,27 @@ object VeloxRasSuite { class UserCostModel1 extends GlutenCostModel { private val base = legacyCostModel() - override def costOf(node: SparkPlan): Cost = node match { + override def costOf(node: SparkPlan): GlutenCost = node match { case _: RowUnary => base.makeInfCost() case other => base.costOf(other) } - override def costComparator(): Ordering[Cost] = base.costComparator() - override def makeInfCost(): Cost = base.makeInfCost() - override def sum(one: Cost, other: Cost): Cost = base.sum(one, other) - override def diff(one: Cost, other: Cost): Cost = base.diff(one, other) - override def makeZeroCost(): Cost = base.makeZeroCost() + override def costComparator(): Ordering[GlutenCost] = base.costComparator() + override def makeInfCost(): GlutenCost = base.makeInfCost() + override def sum(one: GlutenCost, other: GlutenCost): GlutenCost = base.sum(one, other) + override def diff(one: GlutenCost, other: GlutenCost): GlutenCost = base.diff(one, other) + override def makeZeroCost(): GlutenCost = base.makeZeroCost() } class UserCostModel2 extends GlutenCostModel { private val base = legacyCostModel() - override def costOf(node: SparkPlan): Cost = node match { + override def costOf(node: SparkPlan): GlutenCost = node match { case _: ColumnarUnary => base.makeInfCost() case other => base.costOf(other) } - override def costComparator(): Ordering[Cost] = base.costComparator() - override def makeInfCost(): Cost = base.makeInfCost() - override def sum(one: Cost, other: Cost): Cost = base.sum(one, other) - override def diff(one: Cost, other: Cost): Cost = base.diff(one, other) - override def makeZeroCost(): Cost = base.makeZeroCost() + override def costComparator(): Ordering[GlutenCost] = base.costComparator() + override def makeInfCost(): GlutenCost = base.makeInfCost() + override def sum(one: GlutenCost, other: GlutenCost): GlutenCost = base.sum(one, other) + override def diff(one: GlutenCost, other: GlutenCost): GlutenCost = base.diff(one, other) + override def makeZeroCost(): GlutenCost = base.makeZeroCost() } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala b/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala index 4a066e1484c8..bc8640a6bc46 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.component +import org.apache.gluten.extension.columnar.cost.LongCoster import org.apache.gluten.extension.columnar.transition.ConventionFunc import org.apache.gluten.extension.injector.Injector @@ -69,6 +70,12 @@ trait Component { */ def convFuncOverride(): ConventionFunc.Override = ConventionFunc.Override.Empty + /** + * A sequence of [[org.apache.gluten.extension.columnar.cost.LongCoster]] Gluten is using for cost + * evaluation. + */ + def costers(): Seq[LongCoster] = Nil + /** Query planner rules. */ def injectRules(injector: Injector): Unit } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/GlutenCostModel.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/GlutenCost.scala similarity index 66% rename from gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/GlutenCostModel.scala rename to gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/GlutenCost.scala index 41e5529d2eba..08a21549a0fe 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/GlutenCostModel.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/GlutenCost.scala @@ -14,17 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.gluten.extension.columnar.enumerated.planner.cost +package org.apache.gluten.extension.columnar.cost -import org.apache.gluten.ras.{Cost, CostModel} - -import org.apache.spark.sql.execution.SparkPlan - -trait GlutenCostModel extends CostModel[SparkPlan] { - // Returns cost value of one + other. - def sum(one: Cost, other: Cost): Cost - // Returns cost value of one - other. - def diff(one: Cost, other: Cost): Cost - - def makeZeroCost(): Cost -} +trait GlutenCost diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/GlutenCostModel.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/GlutenCostModel.scala new file mode 100644 index 000000000000..80edf8919f00 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/GlutenCostModel.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension.columnar.cost + +import org.apache.gluten.component.Component + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.util.SparkReflectionUtil + +/** + * The cost model API of Gluten. Used by: + * 1. RAS planner for cost-based optimization; 2. Transition graph for choosing transition paths. + */ +trait GlutenCostModel { + def costOf(node: SparkPlan): GlutenCost + def costComparator(): Ordering[GlutenCost] + def makeZeroCost(): GlutenCost + def makeInfCost(): GlutenCost + // Returns cost value of one + other. + def sum(one: GlutenCost, other: GlutenCost): GlutenCost + // Returns cost value of one - other. + def diff(one: GlutenCost, other: GlutenCost): GlutenCost +} + +object GlutenCostModel extends Logging { + def find(aliasOrClass: String): GlutenCostModel = { + val costModelRegistry = LongCostModel.registry() + // Components should override Backend's costers. Hence, reversed registration order is applied. + Component + .sorted() + .reverse + .flatMap(_.costers()) + .foreach(coster => costModelRegistry.register(coster)) + val costModel = find(costModelRegistry, aliasOrClass) + costModel + } + + private def find(registry: LongCostModel.Registry, aliasOrClass: String): GlutenCostModel = { + if (LongCostModel.Kind.values().contains(aliasOrClass)) { + val kind = LongCostModel.Kind.values()(aliasOrClass) + val model = registry.get(kind) + return model + } + val clazz = SparkReflectionUtil.classForName(aliasOrClass) + logInfo(s"Using user cost model: $aliasOrClass") + val ctor = clazz.getDeclaredConstructor() + ctor.setAccessible(true) + val model: GlutenCostModel = ctor.newInstance().asInstanceOf[GlutenCostModel] + model + } +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCost.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCost.scala similarity index 84% rename from gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCost.scala rename to gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCost.scala index aa74f7736fbd..7de8407ffe6f 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCost.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCost.scala @@ -14,8 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.gluten.extension.columnar.enumerated.planner.cost +package org.apache.gluten.extension.columnar.cost -import org.apache.gluten.ras.Cost - -case class LongCost(value: Long) extends Cost +case class LongCost(value: Long) extends GlutenCost diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCostModel.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCostModel.scala similarity index 88% rename from gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCostModel.scala rename to gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCostModel.scala index 0d11541b73dd..2cdf86e6af55 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCostModel.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCostModel.scala @@ -14,11 +14,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.gluten.extension.columnar.enumerated.planner.cost +package org.apache.gluten.extension.columnar.cost import org.apache.gluten.exception.GlutenException import org.apache.gluten.extension.columnar.enumerated.planner.plan.GlutenPlanModel.GroupLeafExec -import org.apache.gluten.ras.Cost import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.SparkPlan @@ -39,15 +38,15 @@ abstract class LongCostModel extends GlutenCostModel { assert(a >= 0) assert(b >= 0) val sum = a + b - if (sum < a || sum < b) Long.MaxValue else sum + if (sum < a || sum < b) infLongCost else sum } - override def sum(one: Cost, other: Cost): LongCost = (one, other) match { + override def sum(one: GlutenCost, other: GlutenCost): LongCost = (one, other) match { case (LongCost(value), LongCost(otherValue)) => LongCost(safeSum(value, otherValue)) } // Returns cost value of one - other. - override def diff(one: Cost, other: Cost): Cost = (one, other) match { + override def diff(one: GlutenCost, other: GlutenCost): GlutenCost = (one, other) match { case (LongCost(value), LongCost(otherValue)) => val d = Math.subtractExact(value, otherValue) require(d >= zeroLongCost, s"Difference between cost $one and $other should not be negative") @@ -62,13 +61,13 @@ abstract class LongCostModel extends GlutenCostModel { def selfLongCostOf(node: SparkPlan): Long - override def costComparator(): Ordering[Cost] = Ordering.Long.on { + override def costComparator(): Ordering[GlutenCost] = Ordering.Long.on { case LongCost(value) => value case _ => throw new IllegalStateException("Unexpected cost type") } - override def makeInfCost(): Cost = LongCost(infLongCost) - override def makeZeroCost(): Cost = LongCost(zeroLongCost) + override def makeInfCost(): GlutenCost = LongCost(infLongCost) + override def makeZeroCost(): GlutenCost = LongCost(zeroLongCost) } object LongCostModel extends Logging { diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCoster.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCoster.scala similarity index 95% rename from gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCoster.scala rename to gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCoster.scala index f06d1a4db829..8346f8987c8a 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCoster.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCoster.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.gluten.extension.columnar.enumerated.planner.cost +package org.apache.gluten.extension.columnar.cost import org.apache.spark.sql.execution.SparkPlan diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCosterChain.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCosterChain.scala similarity index 96% rename from gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCosterChain.scala rename to gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCosterChain.scala index 00980e7712a4..c7fe616a4d7b 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCosterChain.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCosterChain.scala @@ -14,7 +14,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.gluten.extension.columnar.enumerated.planner.cost +package org.apache.gluten.extension.columnar.cost + import org.apache.gluten.exception.GlutenException import org.apache.spark.sql.execution.SparkPlan diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala index 59e829e17936..72926407c522 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala @@ -19,12 +19,13 @@ package org.apache.gluten.extension.columnar.enumerated import org.apache.gluten.component.Component import org.apache.gluten.exception.GlutenException import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall +import org.apache.gluten.extension.columnar.cost.{GlutenCost, GlutenCostModel} import org.apache.gluten.extension.columnar.enumerated.planner.GlutenOptimization -import org.apache.gluten.extension.columnar.enumerated.planner.cost.GlutenCostModel import org.apache.gluten.extension.columnar.enumerated.planner.property.Conv import org.apache.gluten.extension.injector.Injector import org.apache.gluten.extension.util.AdaptiveContext import org.apache.gluten.logging.LogLevelUtil +import org.apache.gluten.ras.{Cost, CostModel} import org.apache.gluten.ras.property.PropertySet import org.apache.gluten.ras.rule.RasRule @@ -47,11 +48,12 @@ import org.apache.spark.sql.execution._ case class EnumeratedTransform(costModel: GlutenCostModel, rules: Seq[RasRule[SparkPlan]]) extends Rule[SparkPlan] with LogLevelUtil { + import EnumeratedTransform._ private val optimization = { GlutenOptimization .builder() - .costModel(costModel) + .costModel(asRasCostModel(costModel)) .addRules(rules) .create() } @@ -82,4 +84,18 @@ object EnumeratedTransform { val call = new ColumnarRuleCall(session, AdaptiveContext(session), false) dummyInjector.gluten.ras.createEnumeratedTransform(call) } + + def asRasCostModel(gcm: GlutenCostModel): CostModel[SparkPlan] = { + new CostModelAdapter(gcm) + } + + /** The adapter to make GlutenCostModel comply with RAS cost model. */ + private class CostModelAdapter(gcm: GlutenCostModel) extends CostModel[SparkPlan] { + override def costOf(node: SparkPlan): Cost = CostAdapter(gcm.costOf(node)) + override def costComparator(): Ordering[Cost] = + gcm.costComparator().on[Cost] { case CostAdapter(gc) => gc } + override def makeInfCost(): Cost = CostAdapter(gcm.makeInfCost()) + } + + private case class CostAdapter(gc: GlutenCost) extends Cost } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala index 9fa0a839a4f9..ff530d49bced 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala @@ -46,7 +46,7 @@ sealed trait Conv extends Property[SparkPlan] { return true } val prop = this.asInstanceOf[Prop] - val out = Transition.factory().satisfies(prop.prop, req.req) + val out = Transition.factory.satisfies(prop.prop, req.req) out } } @@ -64,7 +64,7 @@ object Conv { def findTransition(from: Conv, to: Conv): Transition = { val prop = from.asInstanceOf[Prop] val req = to.asInstanceOf[Req] - val out = Transition.factory().findTransition(prop.prop, req.req, new IllegalStateException()) + val out = Transition.factory.findTransition(prop.prop, req.req, new IllegalStateException()) out } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala index ff0f29585299..cd341410cf31 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala @@ -116,21 +116,21 @@ object Convention { protected[this] def registerTransitions(): Unit final protected[this] def fromRow(transition: Transition): Unit = { - Transition.graph.addEdge(RowType.VanillaRow, this, transition) + Transition.factory.update(graph => graph.addEdge(RowType.VanillaRow, this, transition)) } final protected[this] def toRow(transition: Transition): Unit = { - Transition.graph.addEdge(this, RowType.VanillaRow, transition) + Transition.factory.update(graph => graph.addEdge(this, RowType.VanillaRow, transition)) } final protected[this] def fromBatch(from: BatchType, transition: Transition): Unit = { assert(from != this) - Transition.graph.addEdge(from, this, transition) + Transition.factory.update(graph => graph.addEdge(from, this, transition)) } final protected[this] def toBatch(to: BatchType, transition: Transition): Unit = { assert(to != this) - Transition.graph.addEdge(this, to, transition) + Transition.factory.update(graph => graph.addEdge(this, to, transition)) } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala index b05e93968711..00c687d3b6e8 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala @@ -44,8 +44,8 @@ object FloydWarshallGraph { def cost(costModel: CostModel[E]): Cost } - def builder[V <: AnyRef, E <: AnyRef](costModelFactory: () => CostModel[E]): Builder[V, E] = { - Builder.create(costModelFactory) + def builder[V <: AnyRef, E <: AnyRef](): Builder[V, E] = { + Builder.create() } private object Path { @@ -83,24 +83,22 @@ object FloydWarshallGraph { trait Builder[V <: AnyRef, E <: AnyRef] { def addVertex(v: V): Builder[V, E] def addEdge(from: V, to: V, edge: E): Builder[V, E] - def build(): FloydWarshallGraph[V, E] + def build(costModel: CostModel[E]): FloydWarshallGraph[V, E] } private object Builder { - // Thread safe. - private class Impl[V <: AnyRef, E <: AnyRef](costModelFactory: () => CostModel[E]) - extends Builder[V, E] { + private class Impl[V <: AnyRef, E <: AnyRef]() extends Builder[V, E] { private val pathTable: mutable.Map[V, mutable.Map[V, Path[E]]] = mutable.Map() private var graph: Option[FloydWarshallGraph[V, E]] = None - override def addVertex(v: V): Builder[V, E] = synchronized { + override def addVertex(v: V): Builder[V, E] = { assert(!pathTable.contains(v), s"Vertex $v already exists in graph") pathTable.getOrElseUpdate(v, mutable.Map()).getOrElseUpdate(v, Path(Nil)) graph = None this } - override def addEdge(from: V, to: V, edge: E): Builder[V, E] = synchronized { + override def addEdge(from: V, to: V, edge: E): Builder[V, E] = { assert(from != to, s"Input vertices $from and $to should be different") assert(pathTable.contains(from), s"Vertex $from not exists in graph") assert(pathTable.contains(to), s"Vertex $to not exists in graph") @@ -110,26 +108,7 @@ object FloydWarshallGraph { this } - override def build(): FloydWarshallGraph[V, E] = synchronized { - if (graph.isEmpty) { - graph = Some(compile()) - } - return graph.get - } - - private def hasPath(from: V, to: V): Boolean = { - if (!pathTable.contains(from)) { - return false - } - val vec = pathTable(from) - if (!vec.contains(to)) { - return false - } - true - } - - private def compile(): FloydWarshallGraph[V, E] = { - val costModel = costModelFactory() + override def build(costModel: CostModel[E]): FloydWarshallGraph[V, E] = { val vertices = pathTable.keys for (k <- vertices) { for (i <- vertices) { @@ -156,10 +135,21 @@ object FloydWarshallGraph { } new FloydWarshallGraph.Impl(pathTable.map { case (k, m) => (k, m.toMap) }.toMap) } + + private def hasPath(from: V, to: V): Boolean = { + if (!pathTable.contains(from)) { + return false + } + val vec = pathTable(from) + if (!vec.contains(to)) { + return false + } + true + } } - def create[V <: AnyRef, E <: AnyRef](costModelFactory: () => CostModel[E]): Builder[V, E] = { - new Impl(costModelFactory) + def create[V <: AnyRef, E <: AnyRef](): Builder[V, E] = { + new Impl() } } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala index e7a073d9ad16..41951d6da910 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala @@ -16,10 +16,14 @@ */ package org.apache.gluten.extension.columnar.transition +import org.apache.gluten.config.GlutenConfig import org.apache.gluten.exception.GlutenException +import org.apache.gluten.extension.columnar.cost.GlutenCostModel import org.apache.spark.sql.execution.SparkPlan +import scala.collection.mutable + /** * Transition is a simple function to convert a query plan to interested [[ConventionReq]]. * @@ -47,9 +51,7 @@ trait Transition { object Transition { val empty: Transition = (plan: SparkPlan) => plan private val abort: Transition = (_: SparkPlan) => throw new UnsupportedOperationException("Abort") - private[transition] val graph: TransitionGraph.Builder = TransitionGraph.builder() - - def factory(): Factory = Factory.newBuiltin(graph.build()) + val factory = Factory.newBuiltin() def notFound(plan: SparkPlan): GlutenException = { new GlutenException(s"No viable transition found from plan's child to itself: $plan") @@ -74,16 +76,32 @@ object Transition { transition.isEmpty } - protected def findTransition(from: Convention, to: ConventionReq)( + def update(body: TransitionGraph.Builder => Unit): Unit + + protected[Factory] def findTransition(from: Convention, to: ConventionReq)( orElse: => Transition): Transition } private object Factory { - def newBuiltin(graph: TransitionGraph): Factory = { - new BuiltinFactory(graph) + def newBuiltin(): Factory = { + new BuiltinFactory() } - private class BuiltinFactory(graph: TransitionGraph) extends Factory { + private class BuiltinFactory() extends Factory { + private val graphBuilder: TransitionGraph.Builder = TransitionGraph.builder() + // Use of this cache allows user to set a new cost model in the same Spark session, + // then the new cost model will take effect for new transition-finding requests. + private val graphCache = mutable.Map[String, TransitionGraph]() + + private def graph(): TransitionGraph = synchronized { + val aliasOrClass = GlutenConfig.get.rasCostModel + graphCache.getOrElseUpdate( + aliasOrClass, { + val base = GlutenCostModel.find(aliasOrClass) + graphBuilder.build(TransitionGraph.asTransitionCostModel(base)) + }) + } + override def findTransition(from: Convention, to: ConventionReq)( orElse: => Transition): Transition = { assert( @@ -104,7 +122,7 @@ object Transition { case Convention.RowType.None => // Input query plan doesn't have recognizable row-based output, // find columnar-to-row transition. - graph.transitionOfOption(from.batchType, toRowType).getOrElse(orElse) + graph().transitionOfOption(from.batchType, toRowType).getOrElse(orElse) case fromRowType if toRowType == fromRowType => // We have only one single built-in row type. Transition.empty @@ -117,12 +135,12 @@ object Transition { case Convention.BatchType.None => // Input query plan doesn't have recognizable columnar output, // find row-to-columnar transition. - graph.transitionOfOption(from.rowType, toBatchType).getOrElse(orElse) + graph().transitionOfOption(from.rowType, toBatchType).getOrElse(orElse) case fromBatchType if toBatchType == fromBatchType => Transition.empty case fromBatchType => // Find columnar-to-columnar transition. - graph.transitionOfOption(fromBatchType, toBatchType).getOrElse(orElse) + graph().transitionOfOption(fromBatchType, toBatchType).getOrElse(orElse) } case (ConventionReq.RowType.Any, ConventionReq.BatchType.Any) => Transition.empty @@ -132,6 +150,11 @@ object Transition { } out } + + override def update(func: TransitionGraph.Builder => Unit): Unit = synchronized { + func(graphBuilder) + graphCache.clear() + } } } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala index 8e9744383107..7dece0b3f579 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala @@ -16,9 +16,8 @@ */ package org.apache.gluten.extension.columnar.transition -import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform +import org.apache.gluten.extension.columnar.cost.{GlutenCost, GlutenCostModel} import org.apache.gluten.extension.columnar.transition.Convention.BatchType -import org.apache.gluten.ras.Cost import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.SparkReflectionUtil @@ -40,7 +39,7 @@ object TransitionGraph { } final private def register(): Unit = BatchType.synchronized { - Transition.graph.addVertex(this) + Transition.factory.update(graph => graph.addVertex(this)) register0() } @@ -51,8 +50,13 @@ object TransitionGraph { type Builder = FloydWarshallGraph.Builder[TransitionGraph.Vertex, Transition] - def builder(): Builder = { - FloydWarshallGraph.builder(() => new TransitionCostModel()) + private[transition] def builder(): Builder = { + FloydWarshallGraph.builder() + } + + private[transition] def asTransitionCostModel( + base: GlutenCostModel): FloydWarshallGraph.CostModel[Transition] = { + new TransitionCostModel(base) } implicit class TransitionGraphOps(val graph: TransitionGraph) { @@ -93,21 +97,22 @@ object TransitionGraph { } /** Reuse RAS cost to represent transition cost. */ - private case class TransitionCost(value: Cost, nodeNames: Seq[String]) + private case class TransitionCost(value: GlutenCost, nodeNames: Seq[String]) extends FloydWarshallGraph.Cost /** - * The cost model reuses RAS's cost model to evaluate cost of transitions. + * The transition cost model relies on the registered Gluten cost model internally to evaluate + * cost of transitions. * * Note the transition graph is built once for all subsequent Spark sessions created on the same - * driver, so any access to Spark dynamic SQL config in RAS cost model will not take effect for + * driver, so any access to Spark dynamic SQL config in Gluten cost model will not take effect for * the transition cost evaluation. Hence, it's not recommended to access Spark dynamic - * configurations in RAS cost model as well. + * configurations in Gluten cost model as well. */ - private class TransitionCostModel() extends FloydWarshallGraph.CostModel[Transition] { - private val rasCostModel = EnumeratedTransform.static().costModel + private class TransitionCostModel(base: GlutenCostModel) + extends FloydWarshallGraph.CostModel[Transition] { - override def zero(): TransitionCost = TransitionCost(rasCostModel.makeZeroCost(), Nil) + override def zero(): TransitionCost = TransitionCost(base.makeZeroCost(), Nil) override def costOf(transition: Transition): TransitionCost = { costOf0(transition) } @@ -115,13 +120,13 @@ object TransitionGraph { one: FloydWarshallGraph.Cost, other: FloydWarshallGraph.Cost): FloydWarshallGraph.Cost = (one, other) match { case (TransitionCost(c1, p1), TransitionCost(c2, p2)) => - TransitionCost(rasCostModel.sum(c1, c2), p1 ++ p2) + TransitionCost(base.sum(c1, c2), p1 ++ p2) } override def costComparator(): Ordering[FloydWarshallGraph.Cost] = { (x: FloydWarshallGraph.Cost, y: FloydWarshallGraph.Cost) => (x, y) match { case (TransitionCost(v1, nodeNames1), TransitionCost(v2, nodeNames2)) => - val diff = rasCostModel.costComparator().compare(v1, v2) + val diff = base.costComparator().compare(v1, v2) if (diff != 0) { diff } else { @@ -139,14 +144,14 @@ object TransitionGraph { * The calculation considers C2C's cost as half of C2R / R2C's cost. So query planner prefers * C2C than C2R / R2C. */ - def rasCostOfPlan(plan: SparkPlan): Cost = rasCostModel.costOf(plan) + def rasCostOfPlan(plan: SparkPlan): GlutenCost = base.costOf(plan) def nodeNamesOfPlan(plan: SparkPlan): Seq[String] = { plan.map(_.nodeName).reverse } val leafCost = rasCostOfPlan(leaf) val accumulatedCost = rasCostOfPlan(transited) - val costDiff = rasCostModel.diff(accumulatedCost, leafCost) + val costDiff = base.diff(accumulatedCost, leafCost) val leafNodeNames = nodeNamesOfPlan(leaf) val accumulatedNodeNames = nodeNamesOfPlan(transited) diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala index 297485d84419..6ac847d19e1f 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala @@ -52,7 +52,7 @@ case class InsertTransitions(convReq: ConventionReq) extends Rule[SparkPlan] { child } else { val transition = - Transition.factory().findTransition(from, convReq, Transition.notFound(node)) + Transition.factory.findTransition(from, convReq, Transition.notFound(node)) val newChild = transition.apply(child) newChild } @@ -100,8 +100,7 @@ object Transitions { def enforceReq(plan: SparkPlan, req: ConventionReq): SparkPlan = { val convFunc = ConventionFunc.create() val removed = RemoveTransitions.removeForNode(plan) - val transition = Transition - .factory() + val transition = Transition.factory .findTransition(convFunc.conventionOf(removed), req, Transition.notFound(removed, req)) val out = transition.apply(removed) out diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala index 23db1c436da8..a208db2c9631 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala @@ -20,8 +20,8 @@ import org.apache.gluten.config.GlutenConfig import org.apache.gluten.extension.GlutenColumnarRule import org.apache.gluten.extension.columnar.ColumnarRuleApplier import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall +import org.apache.gluten.extension.columnar.cost.GlutenCostModel import org.apache.gluten.extension.columnar.enumerated.{EnumeratedApplier, EnumeratedTransform} -import org.apache.gluten.extension.columnar.enumerated.planner.cost.{GlutenCostModel, LongCoster, LongCostModel} import org.apache.gluten.extension.columnar.heuristic.{HeuristicApplier, HeuristicTransform} import org.apache.gluten.ras.rule.RasRule @@ -29,7 +29,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.util.SparkReflectionUtil import scala.collection.mutable @@ -106,7 +105,6 @@ object GlutenInjector { class RasInjector extends Logging { private val preTransformBuilders = mutable.Buffer.empty[ColumnarRuleCall => Rule[SparkPlan]] private val rasRuleBuilders = mutable.Buffer.empty[ColumnarRuleCall => RasRule[SparkPlan]] - private val costerBuilders = mutable.Buffer.empty[ColumnarRuleCall => LongCoster] private val postTransformBuilders = mutable.Buffer.empty[ColumnarRuleCall => Rule[SparkPlan]] def injectPreTransform(builder: ColumnarRuleCall => Rule[SparkPlan]): Unit = { @@ -117,10 +115,6 @@ object GlutenInjector { rasRuleBuilders += builder } - def injectCoster(builder: ColumnarRuleCall => LongCoster): Unit = { - costerBuilders += builder - } - def injectPostTransform(builder: ColumnarRuleCall => Rule[SparkPlan]): Unit = { postTransformBuilders += builder } @@ -135,31 +129,9 @@ object GlutenInjector { def createEnumeratedTransform(call: ColumnarRuleCall): EnumeratedTransform = { // Build RAS rules. val rules = rasRuleBuilders.map(_(call)) - - // Build the cost model. - val costModelRegistry = LongCostModel.registry() - costerBuilders.foreach(cb => costModelRegistry.register(cb(call))) - val aliasOrClass = call.glutenConf.rasCostModel - val costModel = findCostModel(costModelRegistry, aliasOrClass) - + val costModel = GlutenCostModel.find(call.glutenConf.rasCostModel) // Create transform. EnumeratedTransform(costModel, rules.toSeq) } - - private def findCostModel( - registry: LongCostModel.Registry, - aliasOrClass: String): GlutenCostModel = { - if (LongCostModel.Kind.values().contains(aliasOrClass)) { - val kind = LongCostModel.Kind.values()(aliasOrClass) - val model = registry.get(kind) - return model - } - val clazz = SparkReflectionUtil.classForName(aliasOrClass) - logInfo(s"Using user cost model: $aliasOrClass") - val ctor = clazz.getDeclaredConstructor() - ctor.setAccessible(true) - val model: GlutenCostModel = ctor.newInstance().asInstanceOf[GlutenCostModel] - model - } } } diff --git a/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala b/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala index 7b60940a1ae2..7d78df45c126 100644 --- a/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala +++ b/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala @@ -36,7 +36,7 @@ class FloydWarshallGraphSuite extends AnyFunSuite { val e42 = Edge(3) val graph = FloydWarshallGraph - .builder(() => CostModel) + .builder() .addVertex(v0) .addVertex(v1) .addVertex(v2) @@ -47,7 +47,7 @@ class FloydWarshallGraphSuite extends AnyFunSuite { .addEdge(v0, v3, e03) .addEdge(v3, v4, e34) .addEdge(v4, v2, e42) - .build() + .build(CostModel) assert(graph.hasPath(v0, v1)) assert(graph.hasPath(v0, v2)) diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LegacyCoster.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/cost/LegacyCoster.scala similarity index 96% rename from gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LegacyCoster.scala rename to gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/cost/LegacyCoster.scala index bb89d0035bf8..a8e1524fc997 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LegacyCoster.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/cost/LegacyCoster.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.gluten.extension.columnar.enumerated.planner.cost +package org.apache.gluten.extension.columnar.cost import org.apache.gluten.extension.columnar.transition.{ColumnarToColumnarLike, ColumnarToRowLike, RowToColumnarLike} import org.apache.gluten.utils.PlanUtil diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/RoughCoster.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/cost/RoughCoster.scala similarity index 97% rename from gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/RoughCoster.scala rename to gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/cost/RoughCoster.scala index ab893265ec42..caee696df640 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/RoughCoster.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/cost/RoughCoster.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.gluten.extension.columnar.enumerated.planner.cost +package org.apache.gluten.extension.columnar.cost import org.apache.gluten.execution.RowToColumnarExecBase import org.apache.gluten.extension.columnar.transition.{ColumnarToColumnarLike, ColumnarToRowLike, RowToColumnarLike} diff --git a/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala b/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala index 2c423783fdcc..03c9c0f55975 100644 --- a/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala +++ b/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala @@ -20,7 +20,7 @@ import org.apache.gluten.backend.Backend import org.apache.gluten.component.Component import org.apache.gluten.exception.GlutenException import org.apache.gluten.execution.{ColumnarToColumnarExec, GlutenPlan} -import org.apache.gluten.extension.columnar.enumerated.planner.cost.LegacyCoster +import org.apache.gluten.extension.columnar.cost.{LegacyCoster, LongCoster} import org.apache.gluten.extension.injector.Injector import org.apache.spark.rdd.RDD @@ -152,8 +152,7 @@ object TransitionSuite extends TransitionSuiteBase { override def name(): String = "dummy-backend" override def buildInfo(): Component.BuildInfo = Component.BuildInfo("DUMMY_BACKEND", "N/A", "N/A", "N/A") - override def injectRules(injector: Injector): Unit = { - injector.gluten.ras.injectCoster(_ => LegacyCoster) - } + override def injectRules(injector: Injector): Unit = {} + override def costers(): Seq[LongCoster] = Seq(LegacyCoster) } } diff --git a/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala index eb9071badb14..635b334e2806 100644 --- a/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala +++ b/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala @@ -1450,12 +1450,15 @@ object GlutenConfig { .booleanConf .createWithDefault(false) + // FIXME: This option is no longer only used by RAS. Should change key to + // `spark.gluten.costModel` or something similar. val RAS_COST_MODEL = buildConf("spark.gluten.ras.costModel") .doc( "The class name of user-defined cost model that will be used by Gluten's transition " + - "planner as well as by RAS. If not specified, a legacy built-in cost model that " + - "exhaustively offloads computations will be used.") + "planner as well as by RAS. If not specified, a legacy built-in cost model will be " + + "used. The legacy cost model helps RAS planner exhaustively offload computations, and " + + "helps transition planner choose columnar-to-columnar transition over others.") .stringConf .createWithDefaultString("legacy")