Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.gluten.expression.ExpressionNames.MONOTONICALLY_INCREASING_ID
import org.apache.gluten.extension.ExpressionExtensionTrait
import org.apache.gluten.extension.columnar.heuristic.HeuristicTransform
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode, WindowFunctionNode}
import org.apache.gluten.utils.{CHJoinValidateUtil, UnknownJoinStrategy}
import org.apache.gluten.vectorized.{BlockOutputStream, CHColumnarBatchSerializer, CHNativeBlock, CHStreamReader}
Expand Down Expand Up @@ -59,8 +60,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.commons.lang3.ClassUtils

import java.io.{ObjectInputStream, ObjectOutputStream}
import java.lang.{Long => JLong}
import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
import java.util.{ArrayList => JArrayList, List => JList}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
Expand Down Expand Up @@ -709,7 +709,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
windowExpression: Seq[NamedExpression],
windowExpressionNodes: JList[WindowFunctionNode],
originalInputAttributes: Seq[Attribute],
args: JMap[String, JLong]): Unit = {
context: SubstraitContext): Unit = {

windowExpression.map {
windowExpr =>
Expand All @@ -721,7 +721,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
val aggWindowFunc = wf.asInstanceOf[AggregateWindowFunction]
val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame]
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
WindowFunctionsBuilder.create(args, aggWindowFunc).toInt,
WindowFunctionsBuilder.create(context, aggWindowFunc).toInt,
new JArrayList[ExpressionNode](),
columnName,
ConverterUtils.getTypeNode(aggWindowFunc.dataType, aggWindowFunc.nullable),
Expand All @@ -745,10 +745,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
childrenNodeList.add(
ExpressionConverter
.replaceWithExpressionTransformer(expr, originalInputAttributes)
.doTransform(args)))
.doTransform(context)))

val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
CHExpressions.createAggregateFunction(args, aggExpression.aggregateFunction).toInt,
CHExpressions.createAggregateFunction(context, aggExpression.aggregateFunction).toInt,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(aggExpression.dataType, aggExpression.nullable),
Expand Down Expand Up @@ -784,21 +784,21 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
.replaceWithExpressionTransformer(
offsetWf.input,
attributeSeq = originalInputAttributes)
.doTransform(args))
.doTransform(context))
childrenNodeList.add(
ExpressionConverter
.replaceWithExpressionTransformer(
offsetWf.offset,
attributeSeq = originalInputAttributes)
.doTransform(args))
.doTransform(context))
childrenNodeList.add(
ExpressionConverter
.replaceWithExpressionTransformer(
offsetWf.default,
attributeSeq = originalInputAttributes)
.doTransform(args))
.doTransform(context))
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
WindowFunctionsBuilder.create(args, offsetWf).toInt,
WindowFunctionsBuilder.create(context, offsetWf).toInt,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable),
Expand All @@ -812,9 +812,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
val frame = wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
val childrenNodeList = new JArrayList[ExpressionNode]()
val literal = buckets.asInstanceOf[Literal]
childrenNodeList.add(LiteralTransformer(literal).doTransform(args))
childrenNodeList.add(LiteralTransformer(literal).doTransform(context))
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
WindowFunctionsBuilder.create(args, wf).toInt,
WindowFunctionsBuilder.create(context, wf).toInt,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.gluten.backendsapi.clickhouse
import org.apache.gluten.backendsapi.TransformerApi
import org.apache.gluten.execution.{CHHashAggregateExecTransformer, WriteFilesExecTransformer}
import org.apache.gluten.expression.ConverterUtils
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.{BooleanLiteralNode, ExpressionBuilder, ExpressionNode}
import org.apache.gluten.utils.{CHInputPartitionsUtil, ExpressionDocUtil}

Expand Down Expand Up @@ -211,16 +212,14 @@ class CHTransformerApi extends TransformerApi with Logging {
}

override def createCheckOverflowExprNode(
args: java.lang.Object,
context: SubstraitContext,
substraitExprName: String,
childNode: ExpressionNode,
childResultType: DataType,
dataType: DecimalType,
nullable: Boolean,
nullOnOverflow: Boolean): ExpressionNode = {
val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
val functionId = ExpressionBuilder.newScalarFunction(
functionMap,
val functionId = context.registerFunction(
ConverterUtils.makeFuncName(
substraitExprName,
Seq(dataType, BooleanType),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class CHValidatorApi extends ValidatorApi with AdaptiveSparkPlanHelper with Logg
expr =>
val node = ExpressionConverter
.replaceWithExpressionTransformer(expr, outputAttributes)
.doTransform(substraitContext.registeredFunction)
.doTransform(substraitContext)
node.isInstanceOf[SelectionNode]
}
if (allSelectionNodes || supportShuffleWithProject(outputPartitioning, child)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,12 @@ case class CHAggregateGroupLimitExecTransformer(
operatorId: Long,
input: RelNode,
validation: Boolean): RelNode = {
val args = context.registeredFunction
// Partition By Expressions
val partitionsExpressions = partitionSpec
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, attributeSeq = child.output)
.doTransform(args))
.doTransform(context))
.asJava

// Sort By Expressions
Expand All @@ -102,7 +101,7 @@ case class CHAggregateGroupLimitExecTransformer(
val builder = SortField.newBuilder()
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(order.child, attributeSeq = child.output)
.doTransform(args)
.doTransform(context)
builder.setExpr(exprNode.toProtobuf)
builder.setDirectionValue(SortExecTransformer.transformSortDirection(order))
builder.build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ case class CHHashAggregateExecTransformer(
operatorId: Long,
input: RelNode = null,
validation: Boolean): RelNode = {
val args = context.registeredFunction
// Get the grouping nodes.
val groupingList = new util.ArrayList[ExpressionNode]()
groupingExpressions.foreach(
Expand All @@ -247,7 +246,7 @@ case class CHHashAggregateExecTransformer(
// may be different for each backend.
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(expr, childOutput)
.doTransform(args)
.doTransform(context)
groupingList.add(exprNode)
})
// Get the aggregate function nodes.
Expand All @@ -267,7 +266,7 @@ case class CHHashAggregateExecTransformer(
if (aggExpr.filter.isDefined) {
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(aggExpr.filter.get, childOutput)
.doTransform(args)
.doTransform(context)
aggFilterList.add(exprNode)
} else {
aggFilterList.add(null)
Expand All @@ -281,7 +280,7 @@ case class CHHashAggregateExecTransformer(
expr => {
ExpressionConverter
.replaceWithExpressionTransformer(expr, childOutput)
.doTransform(args)
.doTransform(context)
})

val extraNodes = aggregateFunc match {
Expand All @@ -290,7 +289,7 @@ case class CHHashAggregateExecTransformer(
Seq(
ExpressionConverter
.replaceWithExpressionTransformer(relativeSDLiteral, child.output)
.doTransform(args))
.doTransform(context))
case _ => Seq.empty
}

Expand All @@ -311,20 +310,20 @@ case class CHHashAggregateExecTransformer(
child.asInstanceOf[BaseAggregateExec].groupingExpressions,
child.asInstanceOf[BaseAggregateExec].aggregateExpressions)
)
Seq(aggTypesExpr.doTransform(args))
Seq(aggTypesExpr.doTransform(context))
case Final | PartialMerge =>
Seq(
ExpressionConverter
.replaceWithExpressionTransformer(aggExpr.resultAttribute, originalInputAttributes)
.doTransform(args))
.doTransform(context))
case other =>
throw new GlutenNotSupportException(s"$other not supported.")
}
for (node <- childrenNodes) {
childrenNodeList.add(node)
}
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
CHExpressions.createAggregateFunction(args, aggregateFunc),
CHExpressions.createAggregateFunction(context, aggregateFunc),
childrenNodeList,
modeToKeyWord(aggExpr.mode),
ConverterUtils.getTypeNode(aggregateFunc.dataType, aggregateFunc.nullable)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,12 @@ case class CHWindowGroupLimitExecTransformer(
operatorId: Long,
input: RelNode,
validation: Boolean): RelNode = {
val args = context.registeredFunction
// Partition By Expressions
val partitionsExpressions = partitionSpec
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, attributeSeq = child.output)
.doTransform(args))
.doTransform(context))
.asJava

// Sort By Expressions
Expand All @@ -112,7 +111,7 @@ case class CHWindowGroupLimitExecTransformer(
val builder = SortField.newBuilder()
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(order.child, attributeSeq = child.output)
.doTransform(args)
.doTransform(context)
builder.setExpr(exprNode.toProtobuf)
builder.setDirectionValue(SortExecTransformer.transformSortDirection(order))
builder.build()
Expand Down
Loading