Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution

import org.apache.gluten.execution.{DeltaFilterExecTransformer, DeltaProjectExecTransformer, GlutenClickHouseTPCHAbstractSuite}

import org.apache.spark.SparkConf
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.delta.metric.IncrementMetric
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.metric.SQLMetrics.createMetric

// Some sqls' line length exceeds 100
// scalastyle:off line.size.limit

/**
 * Verifies that Delta's IncrementMetric expression does not force a fallback:
 * project and filter nodes containing it must still be offloaded to the
 * ClickHouse backend (DeltaProjectExecTransformer / DeltaFilterExecTransformer),
 * while the wrapped SQLMetrics keep being updated with native row counts.
 */
class GlutenDeltaExpressionSuite
  extends GlutenClickHouseTPCHAbstractSuite
  with AdaptiveSparkPlanHelper {

  override protected val needCopyParquetToTablePath = true

  override protected val tablesPath: String = basePath + "/tpch-data"
  override protected val tpchQueries: String = rootPath + "queries/tpch-queries-ch"
  override protected val queriesResults: String = rootPath + "mergetree-queries-output"

  /** Run Gluten + ClickHouse Backend with SortShuffleManager */
  override protected def sparkConf: SparkConf = {
    super.sparkConf
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
      .set("spark.io.compression.codec", "LZ4")
      .set("spark.sql.shuffle.partitions", "5")
      .set("spark.sql.autoBroadcastJoinThreshold", "10MB")
      .set("spark.sql.adaptive.enabled", "true")
      .set("spark.sql.files.maxPartitionBytes", "20000000")
      .set("spark.sql.storeAssignmentPolicy", "legacy")
      .set("spark.databricks.delta.retentionDurationCheck.enabled", "false")
  }

  override protected def createTPCHNotNullTables(): Unit = {
    createNotNullTPCHTablesInParquet(tablesPath)
  }

  test("test project IncrementMetric not fallback") {
    val tableName = "project_increment_metric"
    withTable(tableName) {
      spark.sql(s"""
                   |CREATE TABLE IF NOT EXISTS $tableName
                   |($lineitemNullableSchema)
                   |USING delta
                   |TBLPROPERTIES (delta.enableDeletionVectors='true')
                   |LOCATION '$basePath/$tableName'
                   |""".stripMargin)

      spark.sql(s"""insert into table $tableName select * from lineitem""".stripMargin)

      // Two IncrementMetric wrappers: one injected through a projected column,
      // one inside a filter condition. Both should be stripped by the Delta
      // transformers and their metrics fed from native row counts.
      val metric = createMetric(sparkContext, "number of source rows")
      val metricFilter = createMetric(sparkContext, "number of source rows (during repeated scan)")
      val df = sql(s"select l_orderkey,l_shipdate from $tableName")
        .withColumn("im", Column(IncrementMetric(Literal(true), metric)))
        .filter("im")
        .filter(Column(IncrementMetric(Literal(true), metricFilter)))
        .drop("im")
      df.collect()

      // Both the project and the filter must have been offloaded (no fallback).
      val cnt = df.queryExecution.executedPlan.collect {
        case _: DeltaProjectExecTransformer => true
        case _: DeltaFilterExecTransformer => true
      }

      assertResult(2)(cnt.size)
      // 600572 = total rows copied from the lineitem table above; each wrapped
      // metric must have counted every row.
      assertResult(600572)(metric.value)
      assertResult(600572)(metricFilter.value)
    }
  }
}
// scalastyle:on line.size.limit
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
package org.apache.gluten.component

import org.apache.gluten.backendsapi.clickhouse.CHBackend
import org.apache.gluten.execution.OffloadDeltaNode
import org.apache.gluten.execution.{OffloadDeltaFilter, OffloadDeltaNode, OffloadDeltaProject}
import org.apache.gluten.extension.DeltaPostTransformRules
import org.apache.gluten.extension.columnar.enumerated.RasOffload
import org.apache.gluten.extension.columnar.heuristic.HeuristicTransform
import org.apache.gluten.extension.columnar.validator.Validators
import org.apache.gluten.extension.injector.Injector
import org.apache.gluten.sql.shims.DeltaShimLoader

import org.apache.spark.SparkContext
import org.apache.spark.api.plugin.PluginContext
import org.apache.spark.sql.execution.{FilterExec, ProjectExec}

class CHDeltaComponent extends Component {
override def name(): String = "ch-delta"
Expand All @@ -37,11 +40,25 @@ class CHDeltaComponent extends Component {

/**
 * Registers the Delta-specific rules with both Gluten planners: the legacy
 * (heuristic) transform gets the offloads directly, the RAS planner gets the
 * project/filter offloads wrapped as RAS rules, and the shared post-transform
 * rules are injected into both. Finally registers the Delta expression
 * extensions for the active shim.
 */
override def injectRules(injector: Injector): Unit = {
  val legacy = injector.gluten.legacy
  val ras = injector.gluten.ras
  // Legacy planner: offload Delta scan, project and filter nodes heuristically.
  legacy.injectTransform {
    c =>
      val offload = Seq(OffloadDeltaNode(), OffloadDeltaProject(), OffloadDeltaFilter())
      HeuristicTransform.Simple(Validators.newValidator(c.glutenConf, offload), offload)
  }
  // RAS planner: the same project/filter offloads, one rule per physical node type.
  val offloads: Seq[RasOffload] = Seq(
    RasOffload.from[ProjectExec](OffloadDeltaProject()),
    RasOffload.from[FilterExec](OffloadDeltaFilter())
  )
  offloads.foreach(
    offload =>
      ras.injectRasRule(
        c => RasOffload.Rule(offload, Validators.newValidator(c.glutenConf), Nil)))
  // Post-transform rules apply identically to both planners.
  DeltaPostTransformRules.rules.foreach {
    r =>
      legacy.injectPostTransform(_ => r)
      ras.injectPostTransform(_ => r)
  }

  DeltaShimLoader.getDeltaShims.registerExpressionExtension()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil {
/**
 * Builds the metrics updater for filter transformers. The extra metrics
 * (name/SQLMetric pairs collected from stripped expressions such as Delta's
 * IncrementMetric) are forwarded so the updater can feed them from native stats.
 */
override def genFilterTransformerMetricsUpdater(
    metrics: Map[String, SQLMetric],
    extraMetrics: Seq[(String, SQLMetric)] = Seq.empty): MetricsUpdater =
  new FilterMetricsUpdater(metrics, extraMetrics)

override def genProjectTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] =
Map(
Expand All @@ -206,7 +206,7 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil {
/**
 * Builds the metrics updater for project transformers, forwarding the extra
 * (name, SQLMetric) pairs so they are updated alongside the standard metrics.
 */
override def genProjectTransformerMetricsUpdater(
    metrics: Map[String, SQLMetric],
    extraMetrics: Seq[(String, SQLMetric)] = Seq.empty): MetricsUpdater =
  new ProjectMetricsUpdater(metrics, extraMetrics)

override def genHashAggregateTransformerMetrics(
sparkContext: SparkContext): Map[String, SQLMetric] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,41 @@
*/
package org.apache.gluten.metrics

import org.apache.gluten.metrics.ProjectMetricsUpdater.{DELTA_INPUT_ROW_METRIC_NAMES, UNSUPPORTED_METRIC_NAMES}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.metric.SQLMetric

class FilterMetricsUpdater(val metrics: Map[String, SQLMetric]) extends MetricsUpdater {
class FilterMetricsUpdater(
val metrics: Map[String, SQLMetric],
val extraMetrics: Seq[(String, SQLMetric)]
) extends MetricsUpdater
with Logging {

override def updateNativeMetrics(opMetrics: IOperatorMetrics): Unit = {
if (opMetrics != null) {
val operatorMetrics = opMetrics.asInstanceOf[OperatorMetrics]
if (!operatorMetrics.metricsList.isEmpty) {
var numInputRows = Seq(metrics("numInputRows"))
extraMetrics.foreach {
case (name, metric) =>
name match {
case "increment_metric" =>
metric.name match {
case Some(input) if DELTA_INPUT_ROW_METRIC_NAMES.contains(input) =>
numInputRows = numInputRows :+ metric
case Some(unSupport) if UNSUPPORTED_METRIC_NAMES.contains(unSupport) =>
logTrace(s"Unsupported metric name: $unSupport")
case Some(other) =>
logTrace(s"Unknown metric name: $other")
case _ => // do nothing
}
case o: String =>
logTrace(s"Unknown metric name: $o")
case _ => // do nothing
}
}

val metricsData = operatorMetrics.metricsList.get(0)
metrics("totalTime") += (metricsData.time / 1000L).toLong
metrics("inputWaitTime") += (metricsData.inputWaitTime / 1000L).toLong
Expand All @@ -35,7 +62,7 @@ class FilterMetricsUpdater(val metrics: Map[String, SQLMetric]) extends MetricsU
metrics("extraTime"),
metrics("numOutputRows"),
metrics("outputBytes"),
metrics("numInputRows"),
numInputRows,
metrics("inputBytes"),
FilterMetricsUpdater.INCLUDING_PROCESSORS,
FilterMetricsUpdater.INCLUDING_PROCESSORS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ object MetricsUtil extends Logging {
extraTime: SQLMetric,
outputRows: SQLMetric,
outputBytes: SQLMetric,
inputRows: SQLMetric,
inputRows: Seq[SQLMetric],
inputBytes: SQLMetric,
includingMetrics: Array[String],
planNodeNames: Array[String]): Unit = {
Expand All @@ -207,9 +207,29 @@ object MetricsUtil extends Logging {
if (planNodeNames.exists(processor.name.startsWith(_))) {
outputRows += processor.outputRows
outputBytes += processor.outputBytes
inputRows += processor.inputRows
inputRows.foreach(inputRow => inputRow += processor.inputRows)
inputBytes += processor.inputBytes
}
})
}

/**
 * Backward-compatible overload taking a single input-row metric.
 * Delegates to the Seq-based variant by wrapping inputRows in a
 * one-element Seq, so existing callers keep working unchanged.
 */
def updateExtraTimeMetric(
metricData: MetricsData,
extraTime: SQLMetric,
outputRows: SQLMetric,
outputBytes: SQLMetric,
inputRows: SQLMetric,
inputBytes: SQLMetric,
includingMetrics: Array[String],
planNodeNames: Array[String]): Unit = {
updateExtraTimeMetric(
metricData,
extraTime,
outputRows,
outputBytes,
Seq(inputRows),
inputBytes,
includingMetrics,
planNodeNames)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,41 @@
*/
package org.apache.gluten.metrics

import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.gluten.metrics.ProjectMetricsUpdater.{DELTA_INPUT_ROW_METRIC_NAMES, UNSUPPORTED_METRIC_NAMES}

class ProjectMetricsUpdater(val metrics: Map[String, SQLMetric]) extends MetricsUpdater {
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.metric.SQLMetric

class ProjectMetricsUpdater(
val metrics: Map[String, SQLMetric],
val extraMetrics: Seq[(String, SQLMetric)])
extends MetricsUpdater
with Logging {
override def updateNativeMetrics(opMetrics: IOperatorMetrics): Unit = {
if (opMetrics != null) {
val operatorMetrics = opMetrics.asInstanceOf[OperatorMetrics]
if (!operatorMetrics.metricsList.isEmpty) {
var numInputRows = Seq(metrics("numInputRows"))

extraMetrics.foreach {
case (name, metric) =>
name match {
case "increment_metric" =>
metric.name match {
case Some(input) if DELTA_INPUT_ROW_METRIC_NAMES.contains(input) =>
numInputRows = numInputRows :+ metric
case Some(unSupport) if UNSUPPORTED_METRIC_NAMES.contains(unSupport) =>
logTrace(s"Unsupported metric name: $unSupport")
case Some(other) =>
logTrace(s"Unknown metric name: $other")
case _ => // do nothing
}
case o: String =>
logTrace(s"Unknown metric name: $o")
case _ => // do nothing
}
}

val metricsData = operatorMetrics.metricsList.get(0)
metrics("totalTime") += (metricsData.time / 1000L).toLong
metrics("inputWaitTime") += (metricsData.inputWaitTime / 1000L).toLong
Expand All @@ -35,7 +62,7 @@ class ProjectMetricsUpdater(val metrics: Map[String, SQLMetric]) extends Metrics
metrics("extraTime"),
metrics("numOutputRows"),
metrics("outputBytes"),
metrics("numInputRows"),
numInputRows,
metrics("inputBytes"),
ProjectMetricsUpdater.INCLUDING_PROCESSORS,
ProjectMetricsUpdater.CH_PLAN_NODE_NAME
Expand All @@ -46,6 +73,21 @@ class ProjectMetricsUpdater(val metrics: Map[String, SQLMetric]) extends Metrics
}

object ProjectMetricsUpdater {
  // Native processor names whose stats are folded into this operator's metrics.
  val INCLUDING_PROCESSORS: Array[String] = Array("ExpressionTransform")
  val CH_PLAN_NODE_NAME: Array[String] = Array("ExpressionTransform")

  // IncrementMetric names that are recognized but intentionally NOT mapped to
  // the operator's input-row count (logged at trace level instead).
  val UNSUPPORTED_METRIC_NAMES: Set[String] =
    Set(
      "number of updated rows",
      "number of deleted rows",
      "number of inserted rows",
      "number of rows updated by a matched clause",
      "number of rows deleted by a matched clause"
    )

  // IncrementMetric names that count rows flowing into the operator and are
  // therefore updated together with the "numInputRows" metric.
  val DELTA_INPUT_ROW_METRIC_NAMES: Set[String] = Set(
    "number of source rows",
    "number of target rows rewritten unmodified",
    "number of source rows (during repeated scan)"
  )
}
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/RelParsers/FilterRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ FilterRelParser::parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel,
auto input_header = query_plan->getCurrentHeader();
DB::ActionsDAG actions_dag{input_header.getColumnsWithTypeAndName()};
const auto condition_node = parseExpression(actions_dag, filter_rel.condition());
if (filter_rel.condition().has_scalar_function())
if (filter_rel.condition().has_scalar_function() || filter_rel.condition().has_literal())
{
actions_dag.addOrReplaceInOutputs(*condition_node);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.extensions.ExtensionBuilder
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CaseWhen, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression}
import org.apache.spark.sql.delta.metric.IncrementMetric
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.SQLMetric
Expand Down Expand Up @@ -81,35 +81,12 @@ case class DeltaProjectExecTransformer(projectList: Seq[NamedExpression], child:
/**
 * Rewrites the project list by stripping every IncrementMetric wrapper found
 * under an Alias (at any nesting depth, via transformUp), replacing it with
 * its child expression. Each stripped (prettyName, metric) pair is appended
 * to extraMetrics so the metric can still be driven by native row counts.
 * The alias name and exprId are preserved so downstream attribute references
 * remain valid.
 */
def genNewProjectList(projectList: Seq[NamedExpression]): Seq[NamedExpression] = {
  projectList.map {
    case alias: Alias =>
      // Bottom-up rewrite: every IncrementMetric node collapses to its child
      // while its metric is captured for later updates.
      val newChild = alias.child.transformUp {
        case im @ IncrementMetric(child, metric) =>
          extraMetrics :+= (im.prettyName, metric)
          child
      }
      Alias(child = newChild, name = alias.name)(alias.exprId)
    case other => other
  }
}
Expand Down