Merged. Changes from all commits.
PlanNodesUtil.scala
@@ -22,55 +22,54 @@ import org.apache.gluten.substrait.expression.ExpressionNode
 import org.apache.gluten.substrait.plan.{PlanBuilder, PlanNode}
 import org.apache.gluten.substrait.rel.RelBuilder

-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression}

 import com.google.common.collect.Lists

+import java.util
+
+import scala.collection.JavaConverters._

 object PlanNodesUtil {

-  def genProjectionsPlanNode(key: Expression, output: Seq[Attribute]): PlanNode = {
+  def genProjectionsPlanNode(key: Seq[Expression], output: Seq[Attribute]): PlanNode = {
     val context = new SubstraitContext

     var operatorId = context.nextOperatorId("ClickHouseBuildSideRelationReadIter")
     val typeList = ConverterUtils.collectAttributeTypeNodes(output)
     val nameList = ConverterUtils.collectAttributeNamesWithExprId(output)
     val readRel = RelBuilder.makeReadRelForInputIterator(typeList, nameList, context, operatorId)

-    // replace each attribute with a BoundReference according to the output
-    val newBoundRefKey = key.transformDown {
-      case expression: AttributeReference =>
-        val columnInOutput = output.zipWithIndex.filter {
-          p: (Attribute, Int) => p._1.exprId == expression.exprId || p._1.name == expression.name
-        }
-        if (columnInOutput.isEmpty) {
-          throw new IllegalStateException(
-            s"Key $expression not found from build side relation output: $output")
-        }
-        if (columnInOutput.size != 1) {
-          throw new IllegalStateException(
-            s"More than one key $expression found from build side relation output: $output")
-        }
-        val boundReference = columnInOutput.head
-        BoundReference(boundReference._2, boundReference._1.dataType, boundReference._1.nullable)
-      case other => other
-    }
-
     // project
     operatorId = context.nextOperatorId("ClickHouseBuildSideRelationProjection")
     val args = context.registeredFunction

     val columnarProjExpr = ExpressionConverter
-      .replaceWithExpressionTransformer(newBoundRefKey, attributeSeq = output)
+      .replaceWithExpressionTransformer(key, attributeSeq = output)

     val projExprNodeList = new java.util.ArrayList[ExpressionNode]()
-    projExprNodeList.add(columnarProjExpr.doTransform(args))
+    columnarProjExpr.foreach(e => projExprNodeList.add(e.doTransform(args)))

     PlanBuilder.makePlan(
       context,
       Lists.newArrayList(
         RelBuilder.makeProjectRel(readRel, projExprNodeList, context, operatorId, output.size)),
-      Lists.newArrayList(
-        ConverterUtils.genColumnNameWithExprId(ConverterUtils.getAttrFromExpr(key)))
+      Lists.newArrayList(genColumnNameWithExprId(key, output))
     )
   }

+  private def genColumnNameWithExprId(
+      key: Seq[Expression],
+      output: Seq[Attribute]): util.List[String] = {
+    key
+      .map {
+        k =>
+          val reference = k.collectFirst { case BoundReference(ordinal, _, _) => output(ordinal) }
+          assert(reference.isDefined)
+          reference.get
+      }
+      .map(ConverterUtils.genColumnNameWithExprId)
+      .toList
+      .asJava
+  }
 }
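
Note on the new helper: genColumnNameWithExprId assumes every key contains a BoundReference whose ordinal indexes into `output` (the assert guards this). A minimal sketch of that contract, with illustrative attribute names and types that are not from the patch:

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference}
    import org.apache.spark.sql.types.{IntegerType, LongType}

    val output = Seq(
      AttributeReference("a", IntegerType)(),
      AttributeReference("b", LongType)())

    // Two pre-bound keys selecting columns 1 and 0 of the build-side output.
    val keys = Seq(
      BoundReference(1, LongType, nullable = true),
      BoundReference(0, IntegerType, nullable = true))

    // collectFirst on each key finds its BoundReference ordinal, resolves
    // output(1) = "b" and output(0) = "a", and the Substrait column names are
    // then derived from each attribute's name together with its exprId.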
ClickHouseBuildSideRelation.scala
@@ -22,8 +22,8 @@ import org.apache.gluten.vectorized._

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, IdentityBroadcastMode}
 import org.apache.spark.sql.execution.utils.CHExecUtil
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.storage.CHShuffleReadStreamFactory
@@ -72,18 +72,26 @@ case class ClickHouseBuildSideRelation(
   }

   /**
-   * Transform the columnar broadcast value to an Array[InternalRow] by key and distinct.
+   * Transform the columnar broadcast value to an Array[InternalRow] by key.
    *
    * @return
    */
   override def transform(key: Expression): Array[InternalRow] = {
     // native block reader
     val blockReader = new CHStreamReader(CHShuffleReadStreamFactory.create(batches, true))
     val broadCastIter: Iterator[ColumnarBatch] = IteratorUtil.createBatchIterator(blockReader)

+    val transformProjections = mode match {
+      case HashedRelationBroadcastMode(k, _) => k
+      case IdentityBroadcastMode => output
+    }
+
     // Expression compute, return block iterator
     val expressionEval = new SimpleExpressionEval(
       new ColumnarNativeIterator(broadCastIter.asJava),
-      PlanNodesUtil.genProjectionsPlanNode(key, output))
+      PlanNodesUtil.genProjectionsPlanNode(transformProjections, output))
+
+    val proj = UnsafeProjection.create(Seq(key))

     try {
       // convert columnar to row
@@ -95,6 +103,7 @@
         } else {
           CHExecUtil
             .getRowIterFromSparkRowInfo(block, batch.numColumns(), batch.numRows())
+            .map(proj)
             .map(row => row.copy())
         }
       }.toArray
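
Note on the branch added above: under HashedRelationBroadcastMode the native projection evaluates the broadcast mode's own key expressions, while under IdentityBroadcastMode the rows pass through unchanged and the caller-supplied bound `key` does the extraction afterwards. A driver-side sketch of that second step, with illustrative values and no native evaluation involved:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection}
    import org.apache.spark.sql.types.LongType
    import org.apache.spark.unsafe.types.UTF8String

    // One pass-through broadcast row: (id: Long, name: String).
    val row = InternalRow(42L, UTF8String.fromString("x"))

    // The caller's key is already bound, e.g. to column 0 of the projected output.
    val proj = UnsafeProjection.create(Seq(BoundReference(0, LongType, nullable = false)))
    val keyRow = proj(row).copy() // an UnsafeRow holding just the 42L key value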
VeloxSparkPlanExecApi.scala
@@ -632,7 +632,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
     }
     numOutputRows += serialized.map(_.getNumRows).sum
     dataSize += rawSize
-    ColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized))
+    ColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized), mode)
   }

   override def doCanonicalizeForBroadcastMode(mode: BroadcastMode): BroadcastMode = {
BroadcastUtils.scala
@@ -105,7 +105,8 @@ object BroadcastUtils {
       }
       ColumnarBuildSideRelation(
         SparkShimLoader.getSparkShims.attributesFromStruct(schema),
-        serialized)
+        serialized,
+        mode)
     }
     // Rebroadcast Velox relation.
     context.broadcast(toRelation).asInstanceOf[Broadcast[T]]
@@ -123,7 +124,8 @@
       }
       ColumnarBuildSideRelation(
         SparkShimLoader.getSparkShims.attributesFromStruct(schema),
-        serialized)
+        serialized,
+        mode)
     }
     // Rebroadcast Velox relation.
     context.broadcast(toRelation).asInstanceOf[Broadcast[T]]
ColumnarBuildSideRelation.scala
@@ -25,8 +25,11 @@ import org.apache.gluten.utils.ArrowAbiUtil
 import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowJniWrapper}

 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
+import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode
 import org.apache.spark.sql.execution.joins.BuildSideRelation
+import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.utils.SparkArrowUtil
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -36,9 +39,19 @@ import org.apache.arrow.c.ArrowSchema

 import scala.collection.JavaConverters.asScalaIteratorConverter

-case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Array[Byte]])
+case class ColumnarBuildSideRelation(
+    output: Seq[Attribute],
+    batches: Array[Array[Byte]],
+    mode: BroadcastMode)
   extends BuildSideRelation {

+  private def transformProjection: UnsafeProjection = {
+    mode match {
+      case HashedRelationBroadcastMode(k, _) => UnsafeProjection.create(k)
+      case IdentityBroadcastMode => UnsafeProjection.create(output, output)
+    }
+  }
+
   override def deserialized: Iterator[ColumnarBatch] = {
     val runtime = Runtimes.contextInstance("BuildSideRelation#deserialized")
     val jniWrapper = ColumnarBatchSerializerJniWrapper.create(runtime)
@@ -82,8 +95,11 @@ case class ColumnarBuildSideRelation(
   override def asReadOnlyCopy(): ColumnarBuildSideRelation = this

   /**
-   * Transform the columnar broadcast value to an Array[InternalRow] by key and distinct. NOTE:
-   * This method is called on the Spark driver, so manage resources carefully.
+   * Transform the columnar broadcast value to an Array[InternalRow] by key.
+   *
+   * NOTE:
+   *   - This method is called on the Spark driver, so manage resources carefully.
+   *   - The "key" must already be a bound reference.
    */
   override def transform(key: Expression): Array[InternalRow] = TaskResources.runUnsafe {
     val runtime = Runtimes.contextInstance("BuildSideRelation#transform")
@@ -103,17 +119,7 @@

     var closed = false

-    val exprIds = output.map(_.exprId)
-    val projExpr = key.transformDown {
-      case attr: AttributeReference if !exprIds.contains(attr.exprId) =>
-        val i = output.count(_.name == attr.name)
-        if (i != 1) {
-          throw new IllegalArgumentException(s"Only one attr with the same name is supported: $key")
-        } else {
-          output.find(_.name == attr.name).get
-        }
-    }
-    val proj = UnsafeProjection.create(Seq(projExpr), output)
+    val proj = UnsafeProjection.create(Seq(key))

     // Convert columnar to Row.
     val jniWrapper = NativeColumnarToRowJniWrapper.create(runtime)
@@ -175,7 +181,7 @@
           rowId += 1
           row
         }
-      }.map(proj).map(_.copy())
+      }.map(transformProjection).map(proj).map(_.copy())
     }
   }
 }
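
One detail worth noting in the chain above: the trailing `.map(_.copy())` is load-bearing. An UnsafeProjection reuses a single output buffer across calls (a general Spark fact, not specific to this PR), so collecting projected rows into an array without copying would leave every element aliasing the last row. A standalone illustration with assumed values:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection}
    import org.apache.spark.sql.types.LongType

    val proj = UnsafeProjection.create(Seq(BoundReference(0, LongType, nullable = false)))
    val input = Seq(InternalRow(1L), InternalRow(2L))

    val aliased = input.map(proj).toArray               // both entries alias proj's buffer: (2), (2)
    val correct = input.map(proj).map(_.copy()).toArray // materialized copies: (1), (2)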
ColumnarSubqueryBroadcastExec.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelation, HashJoin, LongHashedRelation}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.types.IntegralType
 import org.apache.spark.util.ThreadUtils

 import scala.concurrent.Future
@@ -64,6 +65,14 @@ case class ColumnarSubqueryBroadcastExec(
     copy(name = "native-dpp", buildKeys = keys, child = child.canonicalized)
   }

+  // Copied from org.apache.spark.sql.execution.joins.HashJoin#canRewriteAsLongType;
+  // keep this consistent with Spark's version to correctly identify a LongHashedRelation.
+  private def canRewriteAsLongType(keys: Seq[Expression]): Boolean = {
+    // TODO: support BooleanType, DateType and TimestampType
+    keys.forall(_.dataType.isInstanceOf[IntegralType]) &&
+    keys.map(_.dataType.defaultSize).sum <= 8
+  }
+

[Review thread on canRewriteAsLongType]
philo-he (Member), Dec 9, 2024: @yikf, if Spark changes this logic on its side, e.g., by adding support for other types, can unexpected failures happen in Gluten? Seems there is no other way.
yikf (Contributor, Author): There doesn't seem to be any other way at the moment.
philo-he (Member): @yikf, ok, let's merge this PR first. Thanks for your efforts!
yikf (Contributor, Author): thanks @philo-he

   @transient
   private lazy val relationFuture: Future[Array[InternalRow]] = {
     // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
@@ -78,7 +87,13 @@
         relation match {
           case b: BuildSideRelation =>
             // Transform columnar broadcast value to Array[InternalRow] by key.
-            b.transform(buildKeys(index)).distinct
+            if (canRewriteAsLongType(buildKeys)) {
+              b.transform(HashJoin.extractKeyExprAt(buildKeys, index)).distinct
+            } else {
+              b.transform(
+                  BoundReference(index, buildKeys(index).dataType, buildKeys(index).nullable))
+                .distinct
+            }
           case h: HashedRelation =>
             val (iter, expr) = if (h.isInstanceOf[LongHashedRelation]) {
               (h.keys(), HashJoin.extractKeyExprAt(buildKeys, index))
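
To make the copied predicate concrete: canRewriteAsLongType holds only when all keys are integral and their widths pack into a single 8-byte long, which is exactly when Spark builds a LongHashedRelation and per-key lookups must unpack that long. A worked example with illustrative attributes, not from the patch:

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.types.{IntegerType, LongType}

    val a = AttributeReference("a", IntegerType)() // defaultSize = 4
    val b = AttributeReference("b", IntegerType)() // defaultSize = 4
    val c = AttributeReference("c", LongType)()    // defaultSize = 8

    // canRewriteAsLongType(Seq(a, b)) == true: 4 + 4 = 8 <= 8, so the keys were
    // packed into one long and HashJoin.extractKeyExprAt(buildKeys, index) must
    // unpack key `index` from it.
    // canRewriteAsLongType(Seq(a, c)) == false: 4 + 8 = 12 > 8, so each key keeps
    // its own column and a plain BoundReference(index, ...) addresses it directly.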
BuildSideRelation.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.joins

 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
 import org.apache.spark.sql.vectorized.ColumnarBatch

 trait BuildSideRelation extends Serializable {
@@ -26,11 +27,19 @@ trait BuildSideRelation extends Serializable {

   def deserialized: Iterator[ColumnarBatch]

   /**
-   * Transform the columnar broadcasted value to an Array[InternalRow] by key and distinct.
+   * Transform the columnar broadcasted value to an Array[InternalRow] by key.
    * @return
    */
   def transform(key: Expression): Array[InternalRow]

   /** Returns a read-only copy of this, to be safely used in the current thread. */
   def asReadOnlyCopy(): BuildSideRelation

+  /**
+   * The broadcast mode associated with this relation. In Gluten the original relation is
+   * broadcast directly, so transforming a relation is a post-processing step; post-processing
+   * transforms can consult this mode to produce the desired output format.
+   */
+  val mode: BroadcastMode
 }
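
For illustration, a consumer of the new `mode` member would dispatch the same way both relation implementations in this PR do. This is a sketch with an assumed helper name (`keyProjection`), not an API added by the patch:

    import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
    import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, IdentityBroadcastMode}
    import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode

    // Hypothetical helper mirroring ColumnarBuildSideRelation#transformProjection.
    def keyProjection(mode: BroadcastMode, output: Seq[Attribute]): UnsafeProjection =
      mode match {
        // Hash mode already carries the build-side key expressions.
        case HashedRelationBroadcastMode(keys, _) => UnsafeProjection.create(keys)
        // Identity mode broadcasts rows as-is: project the output onto itself.
        case IdentityBroadcastMode => UnsafeProjection.create(output, output)
      }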