Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.gluten.execution.CHBroadcastBuildSideCache
import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.gluten.expression.UDFMappings
import org.apache.gluten.extension.ExpressionExtensionTrait
import org.apache.gluten.extension.columnar.transition.Convention
import org.apache.gluten.jni.JniLibLoader
import org.apache.gluten.vectorized.CHNativeExpressionEvaluator

Expand Down Expand Up @@ -70,7 +71,8 @@ class CHListenerApi extends ListenerApi with Logging {
override def onExecutorShutdown(): Unit = shutdown()

private def initialize(conf: SparkConf, isDriver: Boolean): Unit = {
// Force batch type initializations.
// Do row / batch type initializations.
Convention.ensureSparkRowAndBatchTypesRegistered()
CHBatch.ensureRegistered()
SparkDirectoryUtil.init(conf)
val libPath = conf.get(GlutenConfig.GLUTEN_LIB_PATH, StringUtils.EMPTY)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.extension._
import org.apache.gluten.extension.columnar._
import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast}
import org.apache.gluten.extension.columnar.enumerated.planner.cost.LegacyCoster
import org.apache.gluten.extension.columnar.heuristic.{ExpandFallbackPolicy, HeuristicTransform}
import org.apache.gluten.extension.columnar.offload.{OffloadExchange, OffloadJoin, OffloadOthers}
import org.apache.gluten.extension.columnar.rewrite._
Expand Down Expand Up @@ -142,6 +143,9 @@ object CHRuleApi {
}

private def injectRas(injector: RasInjector): Unit = {
// Register legacy coster for transition planner.
injector.injectCoster(_ => LegacyCoster)

// CH backend doesn't work with RAS at the moment. Inject a rule that aborts any
// execution calls.
injector.injectPreTransform(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.gluten.columnarbatch.VeloxBatch
import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.execution.datasource.GlutenFormatFactory
import org.apache.gluten.expression.UDFMappings
import org.apache.gluten.extension.columnar.transition.Convention
import org.apache.gluten.init.NativeBackendInitializer
import org.apache.gluten.jni.{JniLibLoader, JniWorkspace}
import org.apache.gluten.udf.UdfJniWrapper
Expand Down Expand Up @@ -126,10 +127,11 @@ class VeloxListenerApi extends ListenerApi with Logging {
override def onExecutorShutdown(): Unit = shutdown()

private def initialize(conf: SparkConf, isDriver: Boolean): Unit = {
// Force batch type initializations.
VeloxBatch.ensureRegistered()
// Do row / batch type initializations.
Convention.ensureSparkRowAndBatchTypesRegistered()
ArrowJavaBatch.ensureRegistered()
ArrowNativeBatch.ensureRegistered()
VeloxBatch.ensureRegistered()

// Register columnar shuffle so can be considered when
// `org.apache.spark.shuffle.GlutenShuffleManager` is set as Spark shuffle manager.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.gluten.extension._
import org.apache.gluten.extension.columnar._
import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast}
import org.apache.gluten.extension.columnar.enumerated.{RasOffload, RemoveSort}
import org.apache.gluten.extension.columnar.enumerated.planner.cost.{LegacyCoster, RoughCoster, RoughCoster2}
import org.apache.gluten.extension.columnar.enumerated.planner.cost.{LegacyCoster, RoughCoster}
import org.apache.gluten.extension.columnar.heuristic.{ExpandFallbackPolicy, HeuristicTransform}
import org.apache.gluten.extension.columnar.offload.{OffloadExchange, OffloadJoin, OffloadOthers}
import org.apache.gluten.extension.columnar.rewrite._
Expand Down Expand Up @@ -120,6 +120,10 @@ object VeloxRuleApi {
}

private def injectRas(injector: RasInjector): Unit = {
// Gluten RAS: Costers.
injector.injectCoster(_ => LegacyCoster)
injector.injectCoster(_ => RoughCoster)

// Gluten RAS: Pre rules.
injector.injectPreTransform(_ => RemoveTransitions)
injector.injectPreTransform(_ => PushDownInputFileExpression.PreOffload)
Expand All @@ -131,6 +135,7 @@ object VeloxRuleApi {

// Gluten RAS: The RAS rule.
val validatorBuilder: GlutenConfig => Validator = conf => Validators.newValidator(conf)
injector.injectRasRule(_ => RemoveSort)
val rewrites =
Seq(
RewriteIn,
Expand All @@ -139,10 +144,6 @@ object VeloxRuleApi {
PullOutPreProject,
PullOutPostProject,
ProjectColumnPruning)
injector.injectCoster(_ => LegacyCoster)
injector.injectCoster(_ => RoughCoster)
injector.injectCoster(_ => RoughCoster2)
injector.injectRasRule(_ => RemoveSort)
val offloads: Seq[RasOffload] = Seq(
RasOffload.from[Exchange](OffloadExchange()),
RasOffload.from[BaseJoinExec](OffloadJoin()),
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ package org.apache.gluten.extension.columnar.enumerated.planner

import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform
import org.apache.gluten.extension.columnar.enumerated.planner.cost.{LegacyCoster, LongCostModel}
import org.apache.gluten.extension.columnar.enumerated.planner.cost.{GlutenCostModel, LegacyCoster, LongCostModel}
import org.apache.gluten.extension.columnar.enumerated.planner.property.Conv
import org.apache.gluten.extension.columnar.transition.ConventionReq
import org.apache.gluten.ras.{Cost, CostModel, Ras}
import org.apache.gluten.extension.columnar.transition.{Convention, ConventionReq}
import org.apache.gluten.ras.{Cost, Ras}
import org.apache.gluten.ras.RasSuiteBase._
import org.apache.gluten.ras.path.RasPath
import org.apache.gluten.ras.property.PropertySet
Expand All @@ -37,6 +37,11 @@ import org.apache.spark.sql.types.StringType
class VeloxRasSuite extends SharedSparkSession {
import VeloxRasSuite._

override protected def beforeAll(): Unit = {
super.beforeAll()
Convention.ensureSparkRowAndBatchTypesRegistered()
}

test("C2R, R2C - basic") {
val in = RowUnary(RowLeaf(TRIVIAL_SCHEMA))
val planner = newRas().newPlanner(in)
Expand Down Expand Up @@ -153,14 +158,14 @@ object VeloxRasSuite {
.asInstanceOf[Ras[SparkPlan]]
}

private def legacyCostModel(): CostModel[SparkPlan] = {
private def legacyCostModel(): GlutenCostModel = {
val registry = LongCostModel.registry()
val coster = LegacyCoster
registry.register(coster)
registry.get(coster.kind())
}

private def sessionCostModel(): CostModel[SparkPlan] = {
private def sessionCostModel(): GlutenCostModel = {
val transform = EnumeratedTransform.static()
transform.costModel
}
Expand Down Expand Up @@ -198,23 +203,29 @@ object VeloxRasSuite {
override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
}

class UserCostModel1 extends CostModel[SparkPlan] {
class UserCostModel1 extends GlutenCostModel {
private val base = legacyCostModel()
override def costOf(node: SparkPlan): Cost = node match {
case _: RowUnary => base.makeInfCost()
case other => base.costOf(other)
}
override def costComparator(): Ordering[Cost] = base.costComparator()
override def makeInfCost(): Cost = base.makeInfCost()
override def sum(one: Cost, other: Cost): Cost = base.sum(one, other)
override def diff(one: Cost, other: Cost): Cost = base.diff(one, other)
override def makeZeroCost(): Cost = base.makeZeroCost()
}

class UserCostModel2 extends CostModel[SparkPlan] {
class UserCostModel2 extends GlutenCostModel {
private val base = legacyCostModel()
override def costOf(node: SparkPlan): Cost = node match {
case _: ColumnarUnary => base.makeInfCost()
case other => base.costOf(other)
}
override def costComparator(): Ordering[Cost] = base.costComparator()
override def makeInfCost(): Cost = base.makeInfCost()
override def sum(one: Cost, other: Cost): Cost = base.sum(one, other)
override def diff(one: Cost, other: Cost): Cost = base.diff(one, other)
override def makeZeroCost(): Cost = base.makeZeroCost()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,13 @@ class VeloxTransitionSuite extends SharedSparkSession {
}

override protected def beforeAll(): Unit = {
api.onExecutorStart(MockVeloxBackend.mockPluginContext())
super.beforeAll()
api.onExecutorStart(MockVeloxBackend.mockPluginContext())
}

override protected def afterAll(): Unit = {
super.afterAll()
api.onExecutorShutdown()
super.afterAll()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ import org.apache.gluten.component.Component
import org.apache.gluten.exception.GlutenException
import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall
import org.apache.gluten.extension.columnar.enumerated.planner.GlutenOptimization
import org.apache.gluten.extension.columnar.enumerated.planner.cost.GlutenCostModel
import org.apache.gluten.extension.columnar.enumerated.planner.property.Conv
import org.apache.gluten.extension.injector.Injector
import org.apache.gluten.extension.util.AdaptiveContext
import org.apache.gluten.logging.LogLevelUtil
import org.apache.gluten.ras.CostModel
import org.apache.gluten.ras.property.PropertySet
import org.apache.gluten.ras.rule.RasRule

Expand All @@ -44,7 +44,7 @@ import org.apache.spark.sql.execution._
*
* The feature requires enabling RAS to function.
*/
case class EnumeratedTransform(costModel: CostModel[SparkPlan], rules: Seq[RasRule[SparkPlan]])
case class EnumeratedTransform(costModel: GlutenCostModel, rules: Seq[RasRule[SparkPlan]])
extends Rule[SparkPlan]
with LogLevelUtil {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension.columnar.enumerated.planner.cost

import org.apache.gluten.ras.{Cost, CostModel}

import org.apache.spark.sql.execution.SparkPlan

trait GlutenCostModel extends CostModel[SparkPlan] {
// Returns cost value of one + other.
def sum(one: Cost, other: Cost): Cost
// Returns cost value of one - other.
def diff(one: Cost, other: Cost): Cost

def makeZeroCost(): Cost
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,46 @@ package org.apache.gluten.extension.columnar.enumerated.planner.cost

import org.apache.gluten.exception.GlutenException
import org.apache.gluten.extension.columnar.enumerated.planner.plan.GlutenPlanModel.GroupLeafExec
import org.apache.gluten.ras.{Cost, CostModel}
import org.apache.gluten.ras.Cost

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan

import scala.collection.mutable

abstract class LongCostModel extends CostModel[SparkPlan] {
abstract class LongCostModel extends GlutenCostModel {
private val infLongCost = Long.MaxValue
private val zeroLongCost = 0

override def costOf(node: SparkPlan): LongCost = node match {
case _: GroupLeafExec => throw new IllegalStateException()
case _ => LongCost(longCostOf(node))
}

// Sum with ceil to avoid overflow.
private def safeSum(a: Long, b: Long): Long = {
assert(a >= 0)
assert(b >= 0)
val sum = a + b
if (sum < a || sum < b) Long.MaxValue else sum
}

override def sum(one: Cost, other: Cost): LongCost = (one, other) match {
case (LongCost(value), LongCost(otherValue)) => LongCost(safeSum(value, otherValue))
}

// Returns cost value of one - other.
override def diff(one: Cost, other: Cost): Cost = (one, other) match {
case (LongCost(value), LongCost(otherValue)) =>
val d = Math.subtractExact(value, otherValue)
require(d >= zeroLongCost, s"Difference between cost $one and $other should not be negative")
LongCost(d)
}

private def longCostOf(node: SparkPlan): Long = node match {
case n =>
val selfCost = selfLongCostOf(n)

// Sum with ceil to avoid overflow.
def safeSum(a: Long, b: Long): Long = {
assert(a >= 0)
assert(b >= 0)
val sum = a + b
if (sum < a || sum < b) Long.MaxValue else sum
}

(n.children.map(longCostOf).toList :+ selfCost).reduce(safeSum)
(n.children.map(longCostOf).toSeq :+ selfCost).reduce[Long](safeSum)
}

def selfLongCostOf(node: SparkPlan): Long
Expand All @@ -56,6 +68,7 @@ abstract class LongCostModel extends CostModel[SparkPlan] {
}

override def makeInfCost(): Cost = LongCost(infLongCost)
override def makeZeroCost(): Cost = LongCost(zeroLongCost)
}

object LongCostModel extends Logging {
Expand Down Expand Up @@ -98,11 +111,6 @@ object LongCostModel extends Logging {
override def name(): String = "rough"
}

/** Compared with rough, rough2 can be more precise to avoid the costly r2c. */
case object Rough2 extends Kind {
override def name(): String = "rough2"
}

class Registry private[LongCostModel] {
private val lookup: mutable.Map[Kind, LongCosterChain.Builder] = mutable.Map()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ private class LongCosterChain private (costers: Seq[LongCoster]) extends LongCos
case (c @ Some(_), _) =>
c
}
.getOrElse(throw new GlutenException(s"Cost node found for node: $node"))
.getOrElse(throw new GlutenException(s"Cost not found for node: $node"))
}
}

Expand Down
Loading