diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index eecf0588d427..8136488b6afa 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -25,6 +25,7 @@ import org.apache.gluten.expression.ExpressionNames.MONOTONICALLY_INCREASING_ID import org.apache.gluten.extension.ExpressionExtensionTrait import org.apache.gluten.extension.columnar.heuristic.HeuristicTransform import org.apache.gluten.sql.shims.SparkShimLoader +import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode, WindowFunctionNode} import org.apache.gluten.utils.{CHJoinValidateUtil, UnknownJoinStrategy} import org.apache.gluten.vectorized.{BlockOutputStream, CHColumnarBatchSerializer, CHNativeBlock, CHStreamReader} @@ -59,8 +60,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.commons.lang3.ClassUtils import java.io.{ObjectInputStream, ObjectOutputStream} -import java.lang.{Long => JLong} -import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} +import java.util.{ArrayList => JArrayList, List => JList} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -709,7 +709,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging { windowExpression: Seq[NamedExpression], windowExpressionNodes: JList[WindowFunctionNode], originalInputAttributes: Seq[Attribute], - args: JMap[String, JLong]): Unit = { + context: SubstraitContext): Unit = { windowExpression.map { windowExpr => @@ -721,7 +721,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging { val aggWindowFunc = wf.asInstanceOf[AggregateWindowFunction] val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame] val windowFunctionNode = ExpressionBuilder.makeWindowFunction( - WindowFunctionsBuilder.create(args, aggWindowFunc).toInt, + WindowFunctionsBuilder.create(context, aggWindowFunc).toInt, new JArrayList[ExpressionNode](), columnName, ConverterUtils.getTypeNode(aggWindowFunc.dataType, aggWindowFunc.nullable), @@ -745,10 +745,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging { childrenNodeList.add( ExpressionConverter .replaceWithExpressionTransformer(expr, originalInputAttributes) - .doTransform(args))) + .doTransform(context))) val windowFunctionNode = ExpressionBuilder.makeWindowFunction( - CHExpressions.createAggregateFunction(args, aggExpression.aggregateFunction).toInt, + CHExpressions.createAggregateFunction(context, aggExpression.aggregateFunction).toInt, childrenNodeList, columnName, ConverterUtils.getTypeNode(aggExpression.dataType, aggExpression.nullable), @@ -784,21 +784,21 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging { .replaceWithExpressionTransformer( offsetWf.input, attributeSeq = originalInputAttributes) - .doTransform(args)) + .doTransform(context)) childrenNodeList.add( ExpressionConverter .replaceWithExpressionTransformer( offsetWf.offset, attributeSeq = originalInputAttributes) - .doTransform(args)) + .doTransform(context)) childrenNodeList.add( ExpressionConverter .replaceWithExpressionTransformer( offsetWf.default, attributeSeq = originalInputAttributes) - .doTransform(args)) + .doTransform(context)) val windowFunctionNode = ExpressionBuilder.makeWindowFunction( - WindowFunctionsBuilder.create(args, offsetWf).toInt, + WindowFunctionsBuilder.create(context, offsetWf).toInt, childrenNodeList, columnName, ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable), @@ -812,9 +812,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging { val frame = wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] val childrenNodeList = new JArrayList[ExpressionNode]() val literal = buckets.asInstanceOf[Literal] - childrenNodeList.add(LiteralTransformer(literal).doTransform(args)) + childrenNodeList.add(LiteralTransformer(literal).doTransform(context)) val windowFunctionNode = ExpressionBuilder.makeWindowFunction( - WindowFunctionsBuilder.create(args, wf).toInt, + WindowFunctionsBuilder.create(context, wf).toInt, childrenNodeList, columnName, ConverterUtils.getTypeNode(wf.dataType, wf.nullable), diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala index bec88d13d6f3..906d6d9ef7ff 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala @@ -19,6 +19,7 @@ package org.apache.gluten.backendsapi.clickhouse import org.apache.gluten.backendsapi.TransformerApi import org.apache.gluten.execution.{CHHashAggregateExecTransformer, WriteFilesExecTransformer} import org.apache.gluten.expression.ConverterUtils +import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.{BooleanLiteralNode, ExpressionBuilder, ExpressionNode} import org.apache.gluten.utils.{CHInputPartitionsUtil, ExpressionDocUtil} @@ -211,16 +212,14 @@ class CHTransformerApi extends TransformerApi with Logging { } override def createCheckOverflowExprNode( - args: java.lang.Object, + context: SubstraitContext, substraitExprName: String, childNode: ExpressionNode, childResultType: DataType, dataType: DecimalType, nullable: Boolean, nullOnOverflow: Boolean): ExpressionNode = { - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] - val functionId = ExpressionBuilder.newScalarFunction( - functionMap, + val functionId = context.registerFunction( ConverterUtils.makeFuncName( substraitExprName, Seq(dataType, BooleanType), diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHValidatorApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHValidatorApi.scala index c2b52d591978..49efc676c33f 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHValidatorApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHValidatorApi.scala @@ -85,7 +85,7 @@ class CHValidatorApi extends ValidatorApi with AdaptiveSparkPlanHelper with Logg expr => val node = ExpressionConverter .replaceWithExpressionTransformer(expr, outputAttributes) - .doTransform(substraitContext.registeredFunction) + .doTransform(substraitContext) node.isInstanceOf[SelectionNode] } if (allSelectionNodes || supportShuffleWithProject(outputPartitioning, child)) { diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHAggregateGroupLimitExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHAggregateGroupLimitExecTransformer.scala index 83bb33bfa225..fab7f41e0170 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHAggregateGroupLimitExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHAggregateGroupLimitExecTransformer.scala @@ -86,13 +86,12 @@ case class CHAggregateGroupLimitExecTransformer( operatorId: Long, input: RelNode, validation: Boolean): RelNode = { - val args = context.registeredFunction // Partition By Expressions val partitionsExpressions = partitionSpec .map( ExpressionConverter .replaceWithExpressionTransformer(_, attributeSeq = child.output) - .doTransform(args)) + .doTransform(context)) .asJava // Sort By Expressions @@ -102,7 +101,7 @@ case class CHAggregateGroupLimitExecTransformer( val builder = SortField.newBuilder() val exprNode = ExpressionConverter .replaceWithExpressionTransformer(order.child, attributeSeq = child.output) - .doTransform(args) + .doTransform(context) builder.setExpr(exprNode.toProtobuf) builder.setDirectionValue(SortExecTransformer.transformSortDirection(order)) builder.build() diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala index 1cabb5bc75ba..f941df18ebbe 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala @@ -238,7 +238,6 @@ case class CHHashAggregateExecTransformer( operatorId: Long, input: RelNode = null, validation: Boolean): RelNode = { - val args = context.registeredFunction // Get the grouping nodes. val groupingList = new util.ArrayList[ExpressionNode]() groupingExpressions.foreach( @@ -247,7 +246,7 @@ case class CHHashAggregateExecTransformer( // may be different for each backend. val exprNode = ExpressionConverter .replaceWithExpressionTransformer(expr, childOutput) - .doTransform(args) + .doTransform(context) groupingList.add(exprNode) }) // Get the aggregate function nodes. @@ -267,7 +266,7 @@ case class CHHashAggregateExecTransformer( if (aggExpr.filter.isDefined) { val exprNode = ExpressionConverter .replaceWithExpressionTransformer(aggExpr.filter.get, childOutput) - .doTransform(args) + .doTransform(context) aggFilterList.add(exprNode) } else { aggFilterList.add(null) @@ -281,7 +280,7 @@ case class CHHashAggregateExecTransformer( expr => { ExpressionConverter .replaceWithExpressionTransformer(expr, childOutput) - .doTransform(args) + .doTransform(context) }) val extraNodes = aggregateFunc match { @@ -290,7 +289,7 @@ case class CHHashAggregateExecTransformer( Seq( ExpressionConverter .replaceWithExpressionTransformer(relativeSDLiteral, child.output) - .doTransform(args)) + .doTransform(context)) case _ => Seq.empty } @@ -311,12 +310,12 @@ case class CHHashAggregateExecTransformer( child.asInstanceOf[BaseAggregateExec].groupingExpressions, child.asInstanceOf[BaseAggregateExec].aggregateExpressions) ) - Seq(aggTypesExpr.doTransform(args)) + Seq(aggTypesExpr.doTransform(context)) case Final | PartialMerge => Seq( ExpressionConverter .replaceWithExpressionTransformer(aggExpr.resultAttribute, originalInputAttributes) - .doTransform(args)) + .doTransform(context)) case other => throw new GlutenNotSupportException(s"$other not supported.") } @@ -324,7 +323,7 @@ case class CHHashAggregateExecTransformer( childrenNodeList.add(node) } val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( - CHExpressions.createAggregateFunction(args, aggregateFunc), + CHExpressions.createAggregateFunction(context, aggregateFunc), childrenNodeList, modeToKeyWord(aggExpr.mode), ConverterUtils.getTypeNode(aggregateFunc.dataType, aggregateFunc.nullable) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala index 793d733abf96..1111102e892f 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala @@ -96,13 +96,12 @@ case class CHWindowGroupLimitExecTransformer( operatorId: Long, input: RelNode, validation: Boolean): RelNode = { - val args = context.registeredFunction // Partition By Expressions val partitionsExpressions = partitionSpec .map( ExpressionConverter .replaceWithExpressionTransformer(_, attributeSeq = child.output) - .doTransform(args)) + .doTransform(context)) .asJava // Sort By Expressions @@ -112,7 +111,7 @@ case class CHWindowGroupLimitExecTransformer( val builder = SortField.newBuilder() val exprNode = ExpressionConverter .replaceWithExpressionTransformer(order.child, attributeSeq = child.output) - .doTransform(args) + .doTransform(context) builder.setExpr(exprNode.toProtobuf) builder.setDirectionValue(SortExecTransformer.transformSortDirection(order)) builder.build() diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala index 0f6c4e05a0e9..c5111fef83b3 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressionTransformer.scala @@ -19,6 +19,7 @@ package org.apache.gluten.expression import org.apache.gluten.backendsapi.clickhouse.CHConfig import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.expression.ConverterUtils.FunctionConfig +import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression._ import org.apache.spark.sql.catalyst.expressions._ @@ -39,7 +40,7 @@ case class CHTruncTimestampTransformer( extends ExpressionTransformer { override def children: Seq[ExpressionTransformer] = format :: timestamp :: Nil - override def doTransform(args: java.lang.Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { // The format must be constant string in the function date_trunc of ch. if (!original.format.foldable) { throw new GlutenNotSupportException(s"The format ${original.format} must be constant string.") @@ -78,20 +79,17 @@ case class CHTruncTimestampTransformer( s"${timeZoneId.get}.") } - val timestampNode = timestamp.doTransform(args) + val timestampNode = timestamp.doTransform(context) val lowerFormatNode = ExpressionBuilder.makeStringLiteral(newFormatStr) - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] - val dataTypes = if (timeZoneId.nonEmpty) { Seq(original.format.dataType, original.timestamp.dataType, StringType) } else { Seq(original.format.dataType, original.timestamp.dataType) } - val functionId = ExpressionBuilder.newScalarFunction( - functionMap, - ConverterUtils.makeFuncName(substraitExprName, dataTypes)) + val functionId = + context.registerFunction(ConverterUtils.makeFuncName(substraitExprName, dataTypes)) val expressionNodes = new java.util.ArrayList[ExpressionNode]() expressionNodes.add(lowerFormatNode) @@ -114,10 +112,10 @@ case class CHStringTranslateTransformer( extends ExpressionTransformer { override def children: Seq[ExpressionTransformer] = srcExpr :: matchingExpr :: replaceExpr :: Nil - override def doTransform(args: java.lang.Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { // In CH, translateUTF8 requires matchingExpr and replaceExpr argument have the same length - val matchingNode = matchingExpr.doTransform(args) - val replaceNode = replaceExpr.doTransform(args) + val matchingNode = matchingExpr.doTransform(context) + val replaceNode = replaceExpr.doTransform(context) if ( !matchingNode.isInstanceOf[StringLiteralNode] || !replaceNode.isInstanceOf[StringLiteralNode] @@ -125,7 +123,7 @@ case class CHStringTranslateTransformer( throw new GlutenNotSupportException(s"$original not supported yet.") } - super.doTransform(args) + super.doTransform(context) } } @@ -136,15 +134,11 @@ case class CHPosExplodeTransformer( attributeSeq: Seq[Attribute]) extends UnaryExpressionTransformer { - override def doTransform(args: java.lang.Object): ExpressionNode = { - val childNode: ExpressionNode = child.doTransform(args) - val funcMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] - val funcId = ExpressionBuilder.newScalarFunction( - funcMap, - ConverterUtils.makeFuncName( - ExpressionNames.POSEXPLODE, - Seq(original.child.dataType), - FunctionConfig.OPT)) + override def doTransform(context: SubstraitContext): ExpressionNode = { + val childNode: ExpressionNode = child.doTransform(context) + val funcId = context.registerFunction( + ConverterUtils + .makeFuncName(ExpressionNames.POSEXPLODE, Seq(original.child.dataType), FunctionConfig.OPT)) val childType = original.child.dataType childType match { case a: ArrayType => @@ -181,10 +175,10 @@ case class CHRegExpReplaceTransformer( extends ExpressionTransformer { override def children: Seq[ExpressionTransformer] = childrenWithPos.dropRight(1) - override def doTransform(args: java.lang.Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { // In CH: replaceRegexpAll(subject, regexp, rep), which is equivalent // In Spark: regexp_replace(subject, regexp, rep, pos=1) - val posNode = childrenWithPos(3).doTransform(args) + val posNode = childrenWithPos(3).doTransform(context) if ( !posNode.isInstanceOf[IntLiteralNode] || posNode.asInstanceOf[IntLiteralNode].getValue != 1 @@ -192,7 +186,7 @@ case class CHRegExpReplaceTransformer( throw new UnsupportedOperationException(s"$original dose not supported position yet.") } // Replace $num in rep with \num used in CH - val repNode = childrenWithPos(2).doTransform(args) + val repNode = childrenWithPos(2).doTransform(context) repNode match { case node: StringLiteralNode => val strValue = node.getValue @@ -204,19 +198,18 @@ case class CHRegExpReplaceTransformer( FunctionConfig.OPT) val replacedRepNode = ExpressionBuilder.makeLiteral(replacedValue, StringType, false) val exprNodes = Lists.newArrayList( - childrenWithPos(0).doTransform(args), - childrenWithPos(1).doTransform(args), + childrenWithPos(0).doTransform(context), + childrenWithPos(1).doTransform(context), replacedRepNode) - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] return ExpressionBuilder.makeScalarFunction( - ExpressionBuilder.newScalarFunction(functionMap, functionName), + context.registerFunction(functionName), exprNodes, ConverterUtils.getTypeNode(original.dataType, original.nullable)) } case _ => } - super.doTransform(args) + super.doTransform(context) } } @@ -227,11 +220,10 @@ case class GetArrayItemTransformer( original: Expression) extends BinaryExpressionTransformer { - override def doTransform(args: java.lang.Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { // Ignore failOnError for clickhouse backend - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] - val leftNode = left.doTransform(args) - var rightNode = right.doTransform(args) + val leftNode = left.doTransform(context) + var rightNode = right.doTransform(context) val getArrayItem = original.asInstanceOf[GetArrayItem] @@ -242,7 +234,7 @@ case class GetArrayItemTransformer( ExpressionNames.ADD, Seq(IntegerType, getArrayItem.right.dataType), FunctionConfig.OPT) - val addFunctionId = ExpressionBuilder.newScalarFunction(functionMap, addFunctionName) + val addFunctionId = context.registerFunction(addFunctionName) val literalNode = ExpressionBuilder.makeLiteral(1, IntegerType, false) rightNode = ExpressionBuilder.makeScalarFunction( addFunctionId, @@ -255,7 +247,7 @@ case class GetArrayItemTransformer( FunctionConfig.OPT) val exprNodes = Lists.newArrayList(leftNode, rightNode) ExpressionBuilder.makeScalarFunction( - ExpressionBuilder.newScalarFunction(functionMap, functionName), + context.registerFunction(functionName), exprNodes, ConverterUtils.getTypeNode(getArrayItem.dataType, getArrayItem.nullable)) } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala index fa8a5763a681..70d45a4e5232 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/CHExpressions.scala @@ -18,26 +18,24 @@ package org.apache.gluten.expression import org.apache.gluten.expression.ConverterUtils.FunctionConfig import org.apache.gluten.extension.ExpressionExtensionTrait -import org.apache.gluten.substrait.expression.ExpressionBuilder +import org.apache.gluten.substrait.SubstraitContext import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction // Static helper object for handling expressions that are specifically used in CH backend. object CHExpressions { // Since https://github.com/apache/incubator-gluten/pull/1937. - def createAggregateFunction(args: java.lang.Object, aggregateFunc: AggregateFunction): Long = { - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + def createAggregateFunction(context: SubstraitContext, aggregateFunc: AggregateFunction): Long = { val expressionExtensionTransformer = ExpressionExtensionTrait.findExpressionExtension(aggregateFunc.getClass) if (expressionExtensionTransformer.nonEmpty) { val (substraitAggFuncName, inputTypes) = expressionExtensionTransformer.get.buildCustomAggregateFunction(aggregateFunc) assert(substraitAggFuncName.isDefined) - return ExpressionBuilder.newScalarFunction( - functionMap, + return context.registerFunction( ConverterUtils.makeFuncName(substraitAggFuncName.get, inputTypes, FunctionConfig.REQ)) } - AggregateFunctionsBuilder.create(args, aggregateFunc) + AggregateFunctionsBuilder.create(context, aggregateFunc) } } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala index d6511f7a4a29..5026cfa4f2aa 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala @@ -42,13 +42,12 @@ object PlanNodesUtil { // project operatorId = context.nextOperatorId("ClickHouseBuildSideRelationProjection") - val args = context.registeredFunction val columnarProjExpr = ExpressionConverter .replaceWithExpressionTransformer(key, attributeSeq = output) val projExprNodeList = new java.util.ArrayList[ExpressionNode]() - columnarProjExpr.foreach(e => projExprNodeList.add(e.doTransform(args))) + columnarProjExpr.foreach(e => projExprNodeList.add(e.doTransform(context))) PlanBuilder.makePlan( context, diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/RangePartitionerBoundsGenerator.scala b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/RangePartitionerBoundsGenerator.scala index 694035b878a5..adba216b4319 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/RangePartitionerBoundsGenerator.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/RangePartitionerBoundsGenerator.scala @@ -120,10 +120,9 @@ class RangePartitionerBoundsGenerator[K: Ordering: ClassTag, V]( context: SubstraitContext, ordering: SortOrder, attributes: Seq[Attribute]): Int = { - val funcs = context.registeredFunction val projExprNode = ExpressionConverter .replaceWithExpressionTransformer(ordering.child, attributes) - .doTransform(funcs) + .doTransform(context) val pb = projExprNode.toProtobuf if (!pb.hasSelection) { throw new IllegalArgumentException(s"A sorting field should be an attribute") @@ -135,7 +134,6 @@ class RangePartitionerBoundsGenerator[K: Ordering: ClassTag, V]( private def buildProjectionPlan( context: SubstraitContext, sortExpressions: Seq[NamedExpression]): PlanNode = { - val args = context.registeredFunction val columnarProjExprs = sortExpressions.map( expr => { ExpressionConverter @@ -143,7 +141,7 @@ class RangePartitionerBoundsGenerator[K: Ordering: ClassTag, V]( }) val projExprNodeList = new java.util.ArrayList[ExpressionNode]() for (expr <- columnarProjExprs) { - projExprNodeList.add(expr.doTransform(args)) + projExprNodeList.add(expr.doTransform(context)) } val projectRel = RelBuilder.makeProjectRel(null, projExprNodeList, context, 0) val outNames = new util.ArrayList[String] diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/clickhouse/utils/MergeTreePartsPartitionsUtil.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/clickhouse/utils/MergeTreePartsPartitionsUtil.scala index 096212b80ba8..c6511a1d7091 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/clickhouse/utils/MergeTreePartsPartitionsUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/clickhouse/utils/MergeTreePartsPartitionsUtil.scala @@ -628,7 +628,7 @@ object MergeTreePartsPartitionsUtil extends Logging { typeNodes, nameList, columnTypeNodes, - transformer.map(_.doTransform(substraitContext.registeredFunction)).orNull, + transformer.map(_.doTransform(substraitContext)).orNull, extensionNode, substraitContext, substraitContext.nextOperatorId("readRel") diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala index 9949f8822a85..bfc4a27b511d 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxTransformerApi.scala @@ -23,6 +23,7 @@ import org.apache.gluten.execution.datasource.GlutenFormatFactory import org.apache.gluten.expression.ConverterUtils import org.apache.gluten.proto.ConfigMap import org.apache.gluten.runtime.Runtimes +import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode} import org.apache.gluten.utils.InputPartitionsUtil import org.apache.gluten.vectorized.PlanEvaluatorJniWrapper @@ -73,7 +74,7 @@ class VeloxTransformerApi extends TransformerApi with Logging { } override def createCheckOverflowExprNode( - args: java.lang.Object, + context: SubstraitContext, substraitExprName: String, childNode: ExpressionNode, childResultType: DataType, diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala index 70a81e1cdb8e..fbf32cc45e8c 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala @@ -37,8 +37,7 @@ import org.apache.spark.sql.types._ import com.google.protobuf.StringValue -import java.lang.{Long => JLong} -import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList} +import java.util.{ArrayList => JArrayList, List => JList} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -201,7 +200,7 @@ abstract class HashAggregateExecTransformer( // Create aggregate function node and add to list. private def addFunctionNode( - args: java.lang.Object, + context: SubstraitContext, aggregateFunction: AggregateFunction, childrenNodeList: JList[ExpressionNode], aggregateMode: AggregateMode, @@ -212,7 +211,7 @@ abstract class HashAggregateExecTransformer( aggregateMode match { case Partial | PartialMerge => val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( - VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, aggregateMode), + VeloxAggregateFunctionsBuilder.create(context, aggregateFunction, aggregateMode), childrenNodeList, modeKeyWord, VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction) @@ -220,7 +219,7 @@ abstract class HashAggregateExecTransformer( aggregateNodeList.add(aggFunctionNode) case Final | Complete => val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( - VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, aggregateMode), + VeloxAggregateFunctionsBuilder.create(context, aggregateFunction, aggregateMode), childrenNodeList, modeKeyWord, ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable) @@ -238,7 +237,7 @@ abstract class HashAggregateExecTransformer( aggregateMode match { case Partial | PartialMerge => val partialNode = ExpressionBuilder.makeAggregateFunction( - VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, aggregateMode), + VeloxAggregateFunctionsBuilder.create(context, aggregateFunction, aggregateMode), childrenNodeList, modeKeyWord, ConverterUtils.getTypeNode( @@ -248,7 +247,7 @@ abstract class HashAggregateExecTransformer( aggregateNodeList.add(partialNode) case Final | Complete => val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( - VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, aggregateMode), + VeloxAggregateFunctionsBuilder.create(context, aggregateFunction, aggregateMode), childrenNodeList, modeKeyWord, ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable) @@ -296,15 +295,14 @@ abstract class HashAggregateExecTransformer( // Return a scalar function node representing row construct function in Velox. private def getRowConstructNode( - args: java.lang.Object, + context: SubstraitContext, childNodes: JList[ExpressionNode], rowConstructAttributes: Seq[Attribute], aggFunc: AggregateFunction): ScalarFunctionNode = { - val functionMap = args.asInstanceOf[JHashMap[String, JLong]] val functionName = ConverterUtils.makeFuncName( VeloxIntermediateData.getRowConstructFuncName(aggFunc), rowConstructAttributes.map(attr => attr.dataType)) - val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName) + val functionId = context.registerFunction(functionName) // Use struct type to represent Velox RowType. val structTypeNodes = rowConstructAttributes @@ -326,7 +324,6 @@ abstract class HashAggregateExecTransformer( operatorId: Long, inputRel: RelNode, validation: Boolean): RelNode = { - val args = context.registeredFunction // Create a projection for row construct. val exprNodes = new JArrayList[ExpressionNode]() groupingExpressions.foreach( @@ -334,7 +331,7 @@ abstract class HashAggregateExecTransformer( exprNodes.add( ExpressionConverter .replaceWithExpressionTransformer(expr, originalInputAttributes) - .doTransform(args)) + .doTransform(context)) }) for (aggregateExpression <- aggregateExpressions) { @@ -346,7 +343,7 @@ abstract class HashAggregateExecTransformer( .map( ExpressionConverter .replaceWithExpressionTransformer(_, originalInputAttributes) - .doTransform(args) + .doTransform(context) ) .asJava exprNodes.addAll(childNodes) @@ -387,7 +384,7 @@ abstract class HashAggregateExecTransformer( val attr = rewrittenInputAttributes(adjustedIdx) val aggFuncInputAttrNode = ExpressionConverter .replaceWithExpressionTransformer(attr, originalInputAttributes) - .doTransform(args) + .doTransform(context) val expressionNode = if (sparkType != veloxType) { newInputAttributes += attr.copy(dataType = veloxType)(attr.exprId, attr.qualifier) @@ -403,7 +400,7 @@ abstract class HashAggregateExecTransformer( } } exprNodes.add( - getRowConstructNode(args, childNodes, newInputAttributes.toSeq, aggFunc)) + getRowConstructNode(context, childNodes, newInputAttributes.toSeq, aggFunc)) case other => throw new GlutenNotSupportException(s"$other is not supported.") } @@ -415,7 +412,7 @@ abstract class HashAggregateExecTransformer( .map( ExpressionConverter .replaceWithExpressionTransformer(_, originalInputAttributes) - .doTransform(args) + .doTransform(context) ) .asJava exprNodes.addAll(childNodes) @@ -469,7 +466,7 @@ abstract class HashAggregateExecTransformer( throw new GlutenNotSupportException( s"$aggFunc of ${aggExpr.mode.toString} is not supported.") } - addFunctionNode(args, aggFunc, childrenNodes, aggExpr.mode, aggregateFunctionList) + addFunctionNode(context, aggFunc, childrenNodes, aggExpr.mode, aggregateFunctionList) }) val extensionNode = getAdvancedExtension() @@ -566,7 +563,6 @@ abstract class HashAggregateExecTransformer( operatorId: Long, input: RelNode = null, validation: Boolean): RelNode = { - val args = context.registeredFunction // Get the grouping nodes. // Use 'child.output' as based Seq[Attribute], the originalInputAttributes // may be different for each backend. @@ -574,7 +570,7 @@ abstract class HashAggregateExecTransformer( .map( ExpressionConverter .replaceWithExpressionTransformer(_, child.output) - .doTransform(args)) + .doTransform(context)) .asJava // Get the aggregate function nodes. val aggFilterList = new JArrayList[ExpressionNode]() @@ -584,7 +580,7 @@ abstract class HashAggregateExecTransformer( if (aggExpr.filter.isDefined) { val exprNode = ExpressionConverter .replaceWithExpressionTransformer(aggExpr.filter.get, child.output) - .doTransform(args) + .doTransform(context) aggFilterList.add(exprNode) } else { // The number of filters should be aligned with that of aggregate functions. @@ -597,7 +593,7 @@ abstract class HashAggregateExecTransformer( expr => { ExpressionConverter .replaceWithExpressionTransformer(expr, originalInputAttributes) - .doTransform(args) + .doTransform(context) }) case PartialMerge | Final => rewriteAggBufferAttributes( @@ -606,13 +602,13 @@ abstract class HashAggregateExecTransformer( attr => ExpressionConverter .replaceWithExpressionTransformer(attr, originalInputAttributes) - .doTransform(args) + .doTransform(context) } case other => throw new GlutenNotSupportException(s"$other not supported.") } addFunctionNode( - args, + context, aggregateFunc, childrenNodes.asJava, aggExpr.mode, @@ -662,8 +658,8 @@ object VeloxAggregateFunctionsBuilder { /** * Create a scalar function for the input aggregate function. - * @param args: - * the function map. + * @param context: + * the SubstraitContext. * @param aggregateFunc: * the input aggregate function. * @param mode: @@ -671,10 +667,9 @@ object VeloxAggregateFunctionsBuilder { * @return */ def create( - args: java.lang.Object, + context: SubstraitContext, aggregateFunc: AggregateFunction, mode: AggregateMode): Long = { - val functionMap = args.asInstanceOf[JHashMap[String, JLong]] val (sigName, aggFunc) = try { (AggregateFunctionsBuilder.getSubstraitFunctionName(aggregateFunc), aggregateFunc) @@ -688,8 +683,7 @@ object VeloxAggregateFunctionsBuilder { case e: Throwable => throw e } - ExpressionBuilder.newScalarFunction( - functionMap, + context.registerFunction( ConverterUtils.makeFuncName( // Substrait-to-Velox procedure will choose appropriate companion function if needed. sigName, diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala index f3bc929d7eb5..50e7cf9c5192 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala @@ -87,13 +87,12 @@ case class TopNTransformer( inputAttributes: Seq[Attribute], input: RelNode, validation: Boolean): RelNode = { - val args = context.registeredFunction val sortFieldList = sortOrder.map { order => val builder = SortField.newBuilder() val exprNode = ExpressionConverter .replaceWithExpressionTransformer(order.child, attributeSeq = child.output) - .doTransform(args) + .doTransform(context) builder.setExpr(exprNode.toProtobuf) builder.setDirectionValue(SortExecTransformer.transformSortDirection(order)) diff --git a/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala index 1c47ff2a1d21..a5e77920e485 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala @@ -19,13 +19,14 @@ package org.apache.gluten.expression import org.apache.gluten.expression.ConverterUtils.FunctionConfig import org.apache.gluten.expression.ExpressionConverter.replaceWithExpressionTransformer import org.apache.gluten.substrait.`type`.StructNode +import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{IntegerType, LongType} -import java.lang.{Integer => JInteger, Long => JLong} -import java.util.{ArrayList => JArrayList, HashMap => JHashMap} +import java.lang.{Integer => JInteger} +import java.util.{ArrayList => JArrayList} import scala.language.existentials @@ -35,8 +36,8 @@ case class VeloxAliasTransformer( original: Expression) extends UnaryExpressionTransformer { - override def doTransform(args: java.lang.Object): ExpressionNode = { - child.doTransform(args) + override def doTransform(context: SubstraitContext): ExpressionNode = { + child.doTransform(context) } } @@ -58,8 +59,8 @@ case class VeloxGetStructFieldTransformer( extends BinaryExpressionTransformer { override def left: ExpressionTransformer = child override def right: ExpressionTransformer = LiteralTransformer(ordinal) - override def doTransform(args: Object): ExpressionNode = { - val childNode = child.doTransform(args) + override def doTransform(context: SubstraitContext): ExpressionNode = { + val childNode = child.doTransform(context) childNode match { case node: StructLiteralNode => node.getFieldLiteral(ordinal) @@ -71,7 +72,7 @@ case class VeloxGetStructFieldTransformer( node.getTypeNode.asInstanceOf[StructNode].getFieldTypes.get(ordinal) ExpressionBuilder.makeNullLiteral(nodeType) case _ => - super.doTransform(args) + super.doTransform(context) } } } @@ -82,7 +83,7 @@ case class VeloxHashExpressionTransformer( original: HashExpression[_]) extends ExpressionTransformer { - override def doTransform(args: java.lang.Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { // As of Spark 3.3, there are 3 kinds of HashExpression. // HiveHash is not supported in native backend and will fail native validation. val (seedNode, seedType) = original match { @@ -98,13 +99,12 @@ case class VeloxHashExpressionTransformer( nodes.add(seedNode) children.foreach( expression => { - nodes.add(expression.doTransform(args)) + nodes.add(expression.doTransform(context)) }) val childrenTypes = seedType +: original.children.map(child => child.dataType) - val functionMap = args.asInstanceOf[JHashMap[String, JLong]] val functionName = ConverterUtils.makeFuncName(substraitExprName, childrenTypes, FunctionConfig.OPT) - val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName) + val functionId = context.registerFunction(functionName) val typeNode = ConverterUtils.getTypeNode(original.dataType, original.nullable) ExpressionBuilder.makeScalarFunction(functionId, nodes, typeNode) } diff --git a/gluten-delta/src-delta-32/main/scala/org/apache/gluten/execution/DeltaFilterExecTransformer.scala b/gluten-delta/src-delta-32/main/scala/org/apache/gluten/execution/DeltaFilterExecTransformer.scala index 0c8cd54902c2..b71ff4ca4bd8 100644 --- a/gluten-delta/src-delta-32/main/scala/org/apache/gluten/execution/DeltaFilterExecTransformer.scala +++ b/gluten-delta/src-delta-32/main/scala/org/apache/gluten/execution/DeltaFilterExecTransformer.scala @@ -49,17 +49,16 @@ case class DeltaFilterExecTransformer(condition: Expression, child: SparkPlan) input: RelNode, validation: Boolean): RelNode = { assert(condExpr != null) - val args = context.registeredFunction val condExprNode = condExpr match { case IncrementMetric(child, metric) => extraMetrics :+= (condExpr.prettyName, metric) ExpressionConverter .replaceWithExpressionTransformer(child, attributeSeq = originalInputAttributes) - .doTransform(args) + .doTransform(context) case _ => ExpressionConverter .replaceWithExpressionTransformer(condExpr, attributeSeq = originalInputAttributes) - .doTransform(args) + .doTransform(context) } if (!validation) { diff --git a/gluten-delta/src-delta-32/main/scala/org/apache/gluten/execution/DeltaProjectExecTransformer.scala b/gluten-delta/src-delta-32/main/scala/org/apache/gluten/execution/DeltaProjectExecTransformer.scala index a2be01a1f024..39e8d5bfa972 100644 --- a/gluten-delta/src-delta-32/main/scala/org/apache/gluten/execution/DeltaProjectExecTransformer.scala +++ b/gluten-delta/src-delta-32/main/scala/org/apache/gluten/execution/DeltaProjectExecTransformer.scala @@ -49,11 +49,10 @@ case class DeltaProjectExecTransformer(projectList: Seq[NamedExpression], child: operatorId: Long, input: RelNode, validation: Boolean): RelNode = { - val args = context.registeredFunction val newProjectList = genNewProjectList(projectList) val columnarProjExprs: Seq[ExpressionTransformer] = ExpressionConverter .replaceWithExpressionTransformer(newProjectList, attributeSeq = originalInputAttributes) - val projExprNodeList = columnarProjExprs.map(_.doTransform(args)).asJava + val projExprNodeList = columnarProjExprs.map(_.doTransform(context)).asJava val emitStartIndex = originalInputAttributes.size if (!validation) { RelBuilder.makeProjectRel(input, projExprNodeList, context, operatorId, emitStartIndex) diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java index 1e6c58f682ca..07c73452bdcf 100644 --- a/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java +++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java @@ -36,16 +36,6 @@ public class ExpressionBuilder { private ExpressionBuilder() {} - public static Long newScalarFunction(Map functionMap, String functionName) { - if (!functionMap.containsKey(functionName)) { - Long functionId = (long) functionMap.size(); - functionMap.put(functionName, functionId); - return functionId; - } else { - return functionMap.get(functionName); - } - } - public static NullLiteralNode makeNullLiteral(TypeNode typeNode) { return new NullLiteralNode(typeNode); } diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java index b9f1fbc126cc..a114c6050a17 100644 --- a/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java +++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/WindowFunctionNode.java @@ -18,6 +18,7 @@ import org.apache.gluten.exception.GlutenException; import org.apache.gluten.expression.ExpressionConverter; +import org.apache.gluten.substrait.SubstraitContext; import org.apache.gluten.substrait.type.TypeNode; import io.substrait.proto.Expression; @@ -29,7 +30,6 @@ import java.io.Serializable; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import scala.collection.JavaConverters; @@ -104,7 +104,7 @@ private Expression.WindowFunction.Bound.Builder setBound( JavaConverters.asScalaIteratorConverter(originalInputAttributes.iterator()) .asScala() .toSeq()) - .doTransform(new HashMap()); + .doTransform(new SubstraitContext()); Long offset = Long.valueOf(boundType.eval(null).toString()); if (offset < 0) { Expression.WindowFunction.Bound.Preceding.Builder refPrecedingBuilder = diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index 1bb5a255f5e1..a798053f6f36 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -20,6 +20,7 @@ import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.execution._ import org.apache.gluten.expression._ import org.apache.gluten.sql.shims.SparkShimLoader +import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode, WindowFunctionNode} import org.apache.spark.ShuffleDependency @@ -46,8 +47,7 @@ import org.apache.spark.sql.types.{DecimalType, LongType, NullType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch import java.io.{ObjectInputStream, ObjectOutputStream} -import java.lang.{Long => JLong} -import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} +import java.util.{ArrayList => JArrayList, List => JList} import scala.collection.JavaConverters._ @@ -483,7 +483,7 @@ trait SparkPlanExecApi { windowExpression: Seq[NamedExpression], windowExpressionNodes: JList[WindowFunctionNode], originalInputAttributes: Seq[Attribute], - args: JMap[String, JLong]): Unit = { + context: SubstraitContext): Unit = { windowExpression.map { windowExpr => val aliasExpr = windowExpr.asInstanceOf[Alias] @@ -494,7 +494,7 @@ trait SparkPlanExecApi { val aggWindowFunc = wf.asInstanceOf[AggregateWindowFunction] val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame] val windowFunctionNode = ExpressionBuilder.makeWindowFunction( - WindowFunctionsBuilder.create(args, aggWindowFunc).toInt, + WindowFunctionsBuilder.create(context, aggWindowFunc).toInt, new JArrayList[ExpressionNode](), columnName, ConverterUtils.getTypeNode(aggWindowFunc.dataType, aggWindowFunc.nullable), @@ -516,11 +516,11 @@ trait SparkPlanExecApi { .map( ExpressionConverter .replaceWithExpressionTransformer(_, originalInputAttributes) - .doTransform(args)) + .doTransform(context)) .asJava val windowFunctionNode = ExpressionBuilder.makeWindowFunction( - AggregateFunctionsBuilder.create(args, aggExpression.aggregateFunction).toInt, + AggregateFunctionsBuilder.create(context, aggExpression.aggregateFunction).toInt, childrenNodeList, columnName, ConverterUtils.getTypeNode(aggExpression.dataType, aggExpression.nullable), @@ -539,7 +539,7 @@ trait SparkPlanExecApi { .replaceWithExpressionTransformer( offsetWf.input, attributeSeq = originalInputAttributes) - .doTransform(args)) + .doTransform(context)) // Spark only accepts foldable offset. Converts it to LongType literal. val offset = offsetWf.offset.eval(EmptyRow).asInstanceOf[Int] // Velox only allows negative offset. WindowFunctionsBuilder#create converts @@ -554,10 +554,10 @@ trait SparkPlanExecApi { .replaceWithExpressionTransformer( offsetWf.default, attributeSeq = originalInputAttributes) - .doTransform(args)) + .doTransform(context)) } val windowFunctionNode = ExpressionBuilder.makeWindowFunction( - WindowFunctionsBuilder.create(args, offsetWf).toInt, + WindowFunctionsBuilder.create(context, offsetWf).toInt, childrenNodeList, columnName, ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable), @@ -574,10 +574,10 @@ trait SparkPlanExecApi { childrenNodeList.add( ExpressionConverter .replaceWithExpressionTransformer(input, attributeSeq = originalInputAttributes) - .doTransform(args)) - childrenNodeList.add(LiteralTransformer(offset).doTransform(args)) + .doTransform(context)) + childrenNodeList.add(LiteralTransformer(offset).doTransform(context)) val windowFunctionNode = ExpressionBuilder.makeWindowFunction( - WindowFunctionsBuilder.create(args, wf).toInt, + WindowFunctionsBuilder.create(context, wf).toInt, childrenNodeList, columnName, ConverterUtils.getTypeNode(wf.dataType, wf.nullable), @@ -592,9 +592,9 @@ trait SparkPlanExecApi { val frame = wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] val childrenNodeList = new JArrayList[ExpressionNode]() val literal = buckets.asInstanceOf[Literal] - childrenNodeList.add(LiteralTransformer(literal).doTransform(args)) + childrenNodeList.add(LiteralTransformer(literal).doTransform(context)) val windowFunctionNode = ExpressionBuilder.makeWindowFunction( - WindowFunctionsBuilder.create(args, wf).toInt, + WindowFunctionsBuilder.create(context, wf).toInt, childrenNodeList, columnName, ConverterUtils.getTypeNode(wf.dataType, wf.nullable), diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala index 984450bf164e..92d6ebd32574 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/TransformerApi.scala @@ -17,6 +17,7 @@ package org.apache.gluten.backendsapi import org.apache.gluten.execution.WriteFilesExecTransformer +import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.ExpressionNode import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} @@ -61,7 +62,7 @@ trait TransformerApi { } def createCheckOverflowExprNode( - args: java.lang.Object, + context: SubstraitContext, substraitExprName: String, childNode: ExpressionNode, childResultType: DataType, diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala index d1f3462564f1..b9a7ac2e8323 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala @@ -79,7 +79,7 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkP assert(condExpr != null) val condExprNode = ExpressionConverter .replaceWithExpressionTransformer(condExpr, originalInputAttributes) - .doTransform(context.registeredFunction) + .doTransform(context) RelBuilder.makeFilterRel( context, condExprNode, @@ -222,10 +222,9 @@ abstract class ProjectExecTransformerBase(val list: Seq[NamedExpression], val in operatorId: Long, input: RelNode, validation: Boolean): RelNode = { - val args = context.registeredFunction val columnarProjExprs: Seq[ExpressionTransformer] = ExpressionConverter .replaceWithExpressionTransformer(projectList, originalInputAttributes) - val projExprNodeList = columnarProjExprs.map(_.doTransform(args)).asJava + val projExprNodeList = columnarProjExprs.map(_.doTransform(context)).asJava RelBuilder.makeProjectRel( originalInputAttributes.asJava, input, diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala index 056c35a527cb..b1ba7820d5dd 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala @@ -139,7 +139,7 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource .map(ExpressionConverter.replaceAttributeReference) .reduceLeftOption(And) .map(ExpressionConverter.replaceWithExpressionTransformer(_, output)) - val filterNodes = transformer.map(_.doTransform(context.registeredFunction)) + val filterNodes = transformer.map(_.doTransform(context)) val exprNode = filterNodes.orNull // used by CH backend diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala index 9e2f12bcf8ed..ec71afe03c4f 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala @@ -127,7 +127,7 @@ case class CartesianProductExecTransformer( expr => ExpressionConverter .replaceWithExpressionTransformer(expr, left.output ++ right.output) - .doTransform(substraitContext.registeredFunction) + .doTransform(substraitContext) } val extensionNode = JoinUtils.createExtensionNode(left.output ++ right.output, validation = true) diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala index c6936daaffe5..8a40fc041220 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala @@ -67,7 +67,6 @@ case class ExpandExecTransformer( operatorId: Long, input: RelNode, validation: Boolean): RelNode = { - val args = context.registeredFunction val projectSetExprNodes = new JArrayList[JList[ExpressionNode]]() projections.foreach { projectSet => @@ -76,7 +75,7 @@ case class ExpandExecTransformer( project => val projectExprNode = ExpressionConverter .replaceWithExpressionTransformer(project, originalInputAttributes) - .doTransform(args) + .doTransform(context) projectExprNodes.add(projectExprNode) } projectSetExprNodes.add(projectExprNodes) diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala index 698d1f14c5b9..20c1e088b5d7 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/GenerateExecTransformerBase.scala @@ -94,5 +94,5 @@ abstract class GenerateExecTransformerBase( private def getGeneratorNode(context: SubstraitContext): ExpressionNode = ExpressionConverter .replaceWithExpressionTransformer(generator, child.output) - .doTransform(context.registeredFunction) + .doTransform(context) } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala index 86e6c1f41265..439b5689e88e 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala @@ -41,9 +41,6 @@ import com.google.common.collect.Lists import com.google.protobuf.{Any, StringValue} import io.substrait.proto.JoinRel -import java.lang.{Long => JLong} -import java.util.{Map => JMap} - trait ColumnarShuffledJoin extends BaseJoinExec { def isSkewJoin: Boolean @@ -324,9 +321,8 @@ object HashJoinLikeExecTransformer { leftType: DataType, rightNode: ExpressionNode, rightType: DataType, - functionMap: JMap[String, JLong]): ExpressionNode = { - val functionId = ExpressionBuilder.newScalarFunction( - functionMap, + context: SubstraitContext): ExpressionNode = { + val functionId = context.registerFunction( ConverterUtils.makeFuncName(ExpressionNames.EQUAL, Seq(leftType, rightType))) val expressionNodes = Lists.newArrayList(leftNode, rightNode) @@ -338,9 +334,8 @@ object HashJoinLikeExecTransformer { def makeAndExpression( leftNode: ExpressionNode, rightNode: ExpressionNode, - functionMap: JMap[String, JLong]): ExpressionNode = { - val functionId = ExpressionBuilder.newScalarFunction( - functionMap, + context: SubstraitContext): ExpressionNode = { + val functionId = context.registerFunction( ConverterUtils.makeFuncName(ExpressionNames.AND, Seq(BooleanType, BooleanType))) val expressionNodes = Lists.newArrayList(leftNode, rightNode) diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala index 303c9e818f56..12b544de90fd 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala @@ -64,7 +64,7 @@ object JoinUtils { ExpressionConverter .replaceWithExpressionTransformer(expr, partialConstructedJoinOutput) .asInstanceOf[AttributeReferenceTransformer] - .doTransform(substraitContext.registeredFunction), + .doTransform(substraitContext), expr.dataType) } (keys, inputNode, inputNodeOutput) @@ -78,7 +78,7 @@ object JoinUtils { ( ExpressionConverter .replaceWithExpressionTransformer(expr, inputNodeOutput) - .doTransform(substraitContext.registeredFunction), + .doTransform(substraitContext), expr.dataType)) } val preProjectNode = RelBuilder.makeProjectRel( @@ -100,7 +100,7 @@ object JoinUtils { ExpressionConverter .replaceWithExpressionTransformer(a, partialConstructedJoinOutput) .asInstanceOf[AttributeReferenceTransformer] - .doTransform(substraitContext.registeredFunction), + .doTransform(substraitContext), a.dataType) case _ => val (key, idx) = appendedKeysAndIndices.next() @@ -207,11 +207,9 @@ object JoinUtils { leftType, rightKey, rightType, - substraitContext.registeredFunction) + substraitContext) } - .reduce( - (l, r) => - HashJoinLikeExecTransformer.makeAndExpression(l, r, substraitContext.registeredFunction)) + .reduce((l, r) => HashJoinLikeExecTransformer.makeAndExpression(l, r, substraitContext)) // Create post-join filter, which will be computed in hash join. val postJoinFilter = diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala index 4fed8b36e90d..c3a70bb81a26 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala @@ -78,7 +78,7 @@ case class SampleExecTransformer( assert(condExpr != null) val condExprNode = ExpressionConverter .replaceWithExpressionTransformer(condExpr, originalInputAttributes) - .doTransform(context.registeredFunction) + .doTransform(context) RelBuilder.makeFilterRel( context, condExprNode, diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala index 6f9564e6d54f..5ee9e3f3818b 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/SortExecTransformer.scala @@ -68,13 +68,12 @@ case class SortExecTransformer( operatorId: Long, input: RelNode, validation: Boolean): RelNode = { - val args = context.registeredFunction val sortFieldList = sortOrder.map { order => val builder = SortField.newBuilder() val exprNode = ExpressionConverter .replaceWithExpressionTransformer(order.child, attributeSeq = child.output) - .doTransform(args) + .doTransform(context) builder.setExpr(exprNode.toProtobuf) builder.setDirectionValue(SortExecTransformer.transformSortDirection(order)) diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala index 792885ef2f54..2c934b1b5c60 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala @@ -105,14 +105,13 @@ case class WindowExecTransformer( operatorId: Long, input: RelNode, validation: Boolean): RelNode = { - val args = context.registeredFunction // WindowFunction Expressions val windowExpressions = new JArrayList[WindowFunctionNode]() BackendsApiManager.getSparkPlanExecApiInstance.genWindowFunctionsNode( windowExpression, windowExpressions, originalInputAttributes, - args + context ) // Partition By Expressions @@ -120,7 +119,7 @@ case class WindowExecTransformer( .map( ExpressionConverter .replaceWithExpressionTransformer(_, attributeSeq = child.output) - .doTransform(args)) + .doTransform(context)) .asJava // Sort By Expressions @@ -130,7 +129,7 @@ case class WindowExecTransformer( val builder = SortField.newBuilder() val exprNode = ExpressionConverter .replaceWithExpressionTransformer(order.child, attributeSeq = child.output) - .doTransform(args) + .doTransform(context) builder.setExpr(exprNode.toProtobuf) builder.setDirectionValue(SortExecTransformer.transformSortDirection(order)) builder.build() diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala index d96d04dfadb7..5ee7bdd684f4 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala @@ -91,13 +91,12 @@ case class WindowGroupLimitExecTransformer( operatorId: Long, input: RelNode, validation: Boolean): RelNode = { - val args = context.registeredFunction // Partition By Expressions val partitionsExpressions = partitionSpec .map( ExpressionConverter .replaceWithExpressionTransformer(_, attributeSeq = child.output) - .doTransform(args)) + .doTransform(context)) .asJava // Sort By Expressions @@ -107,7 +106,7 @@ case class WindowGroupLimitExecTransformer( val builder = SortField.newBuilder() val exprNode = ExpressionConverter .replaceWithExpressionTransformer(order.child, attributeSeq = child.output) - .doTransform(args) + .doTransform(context) builder.setExpr(exprNode.toProtobuf) builder.setDirectionValue(SortExecTransformer.transformSortDirection(order)) builder.build() diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala index 15de4a734d53..a567903edadc 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/AggregateFunctionsBuilder.scala @@ -19,15 +19,13 @@ package org.apache.gluten.expression import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.expression.ConverterUtils.FunctionConfig -import org.apache.gluten.substrait.expression.ExpressionBuilder +import org.apache.gluten.substrait.SubstraitContext import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.types.DataType object AggregateFunctionsBuilder { - def create(args: java.lang.Object, aggregateFunc: AggregateFunction): Long = { - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] - + def create(context: SubstraitContext, aggregateFunc: AggregateFunction): Long = { // First handle the custom aggregate functions val substraitAggFuncName = getSubstraitFunctionName(aggregateFunc) @@ -42,8 +40,7 @@ object AggregateFunctionsBuilder { val inputTypes: Seq[DataType] = aggregateFunc.children.map(child => child.dataType) - ExpressionBuilder.newScalarFunction( - functionMap, + context.registerFunction( ConverterUtils.makeFuncName(substraitAggFuncName, inputTypes, FunctionConfig.REQ)) } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala index 2a09e039e52c..765f2dafeacd 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ArrayExpressionTransformer.scala @@ -17,6 +17,7 @@ package org.apache.gluten.expression import org.apache.gluten.exception.GlutenNotSupportException +import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.ExpressionNode import org.apache.spark.sql.catalyst.expressions._ @@ -27,7 +28,7 @@ case class CreateArrayTransformer( original: CreateArray) extends ExpressionTransformer { - override def doTransform(args: java.lang.Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { // If children is empty, // transformation is only supported when useStringTypeWhenEmpty is false // because ClickHouse and Velox currently doesn't support this config. @@ -35,6 +36,6 @@ case class CreateArrayTransformer( throw new GlutenNotSupportException(s"$original not supported yet.") } - super.doTransform(args) + super.doTransform(context) } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala index 1dffd390639e..1e3df12901f1 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ConditionalTransformer.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.expression +import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode, IfThenNode} import org.apache.spark.sql.catalyst.expressions._ @@ -32,19 +33,19 @@ case class CaseWhenTransformer( override def children: Seq[ExpressionTransformer] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue - override def doTransform(args: java.lang.Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { // generate branches nodes val ifNodes = new JArrayList[ExpressionNode] val thenNodes = new JArrayList[ExpressionNode] branches.foreach( branch => { - ifNodes.add(branch._1.doTransform(args)) - thenNodes.add(branch._2.doTransform(args)) + ifNodes.add(branch._1.doTransform(context)) + thenNodes.add(branch._2.doTransform(context)) }) val branchDataType = original.asInstanceOf[CaseWhen].inputTypesForMerging(0) // generate else value node, maybe null val elseValueNode = elseValue - .map(_.doTransform(args)) + .map(_.doTransform(context)) .getOrElse(ExpressionBuilder.makeLiteral(null, branchDataType, true)) new IfThenNode(ifNodes, thenNodes, elseValueNode) } @@ -59,14 +60,14 @@ case class IfTransformer( extends ExpressionTransformer { override def children: Seq[ExpressionTransformer] = predicate :: trueValue :: falseValue :: Nil - override def doTransform(args: java.lang.Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { val ifNodes = new JArrayList[ExpressionNode] - ifNodes.add(predicate.doTransform(args)) + ifNodes.add(predicate.doTransform(context)) val thenNodes = new JArrayList[ExpressionNode] - thenNodes.add(trueValue.doTransform(args)) + thenNodes.add(trueValue.doTransform(context)) - val elseValueNode = falseValue.doTransform(args) + val elseValueNode = falseValue.doTransform(context) new IfThenNode(ifNodes, thenNodes, elseValueNode) } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala index ebb9db3e824f..cfda2d8782e3 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.expression +import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} @@ -32,13 +33,12 @@ trait ExpressionTransformer { def dataType: DataType = original.dataType def nullable: Boolean = original.nullable - def doTransform(args: java.lang.Object): ExpressionNode = { - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + def doTransform(context: SubstraitContext): ExpressionNode = { // TODO: the funcName seems can be simplified to `substraitExprName` val funcName: String = ConverterUtils.makeFuncName(substraitExprName, original.children.map(_.dataType)) - val functionId = ExpressionBuilder.newScalarFunction(functionMap, funcName) - val childNodes = children.map(_.doTransform(args)).asJava + val functionId = context.registerFunction(funcName) + val childNodes = children.map(_.doTransform(context)).asJava val typeNode = ConverterUtils.getTypeNode(dataType, nullable) ExpressionBuilder.makeScalarFunction(functionId, childNodes, typeNode) } @@ -78,7 +78,7 @@ object GenericExpressionTransformer { case class LiteralTransformer(original: Literal) extends LeafExpressionTransformer { override def substraitExprName: String = "literal" - override def doTransform(args: java.lang.Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { ExpressionBuilder.makeLiteral(original.value, original.dataType, original.nullable) } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala index 25e3e12a53de..c5978f714b7a 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/JsonTupleExpressionTransformer.scala @@ -18,6 +18,7 @@ package org.apache.gluten.expression import org.apache.gluten.expression.ConverterUtils.FunctionConfig import org.apache.gluten.substrait.`type`.ListNode +import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode} import org.apache.spark.sql.catalyst.expressions.Expression @@ -30,19 +31,18 @@ case class JsonTupleExpressionTransformer( original: Expression) extends ExpressionTransformer { - override def doTransform(args: Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { val jsonExpr = children.head val fields = children.tail - val jsonExprNode = jsonExpr.doTransform(args) + val jsonExprNode = jsonExpr.doTransform(context) val expressNodes = Lists.newArrayList(jsonExprNode) - fields.foreach(f => expressNodes.add(f.doTransform(args))) - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + fields.foreach(f => expressNodes.add(f.doTransform(context))) val functionName = ConverterUtils.makeFuncName( substraitExprName, original.children.map(_.dataType), FunctionConfig.REQ) - val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName) + val functionId = context.registerFunction(functionName) val typeNode = ConverterUtils.getTypeNode(original.dataType, original.nullable) typeNode match { case node: ListNode => diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala index 9e7285ac3a17..ba20e9737dfd 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/LambdaFunctionTransformer.scala @@ -17,6 +17,7 @@ package org.apache.gluten.expression import org.apache.gluten.exception.GlutenNotSupportException +import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.ExpressionNode import org.apache.spark.sql.catalyst.expressions.LambdaFunction @@ -29,11 +30,11 @@ case class LambdaFunctionTransformer( extends ExpressionTransformer { override def children: Seq[ExpressionTransformer] = function +: arguments - override def doTransform(args: Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { // Need to fallback when hidden be true as it's not supported in Velox if (original.hidden) { throw new GlutenNotSupportException(s"Unsupported LambdaFunction with hidden be true.") } - super.doTransform(args) + super.doTransform(context) } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala index fe715979b1a7..c9f0f19c4ed4 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/MapExpressionTransformer.scala @@ -18,6 +18,7 @@ package org.apache.gluten.expression import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.exception.GlutenNotSupportException +import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.ExpressionNode import org.apache.spark.sql.catalyst.expressions._ @@ -28,7 +29,7 @@ case class CreateMapTransformer( original: CreateMap) extends ExpressionTransformer { - override def doTransform(args: java.lang.Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { // If children is empty, // transformation is only supported when useStringTypeWhenEmpty is false // because ClickHouse and Velox currently doesn't support this config. @@ -36,7 +37,7 @@ case class CreateMapTransformer( throw new GlutenNotSupportException(s"$original not supported yet.") } - super.doTransform(args) + super.doTransform(context) } } @@ -48,7 +49,7 @@ case class GetMapValueTransformer( original: GetMapValue) extends BinaryExpressionTransformer { - override def doTransform(args: java.lang.Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { if (BackendsApiManager.getSettings.alwaysFailOnMapExpression()) { throw new GlutenNotSupportException(s"$original not supported yet.") } @@ -57,6 +58,6 @@ case class GetMapValueTransformer( throw new GlutenNotSupportException(s"$original not supported yet.") } - super.doTransform(args) + super.doTransform(context) } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala index f4c703d88ef0..76437d0c3e2d 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/NamedExpressionsTransformer.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.expression +import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode} import org.apache.spark.sql.catalyst.expressions._ @@ -31,14 +32,14 @@ case class AttributeReferenceTransformer( original: AttributeReference, bound: BoundReference) extends LeafExpressionTransformer { - override def doTransform(args: java.lang.Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { ExpressionBuilder.makeSelection(bound.ordinal.asInstanceOf[java.lang.Integer]) } } case class BoundReferenceTransformer(substraitExprName: String, original: BoundReference) extends LeafExpressionTransformer { - override def doTransform(args: java.lang.Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { ExpressionBuilder.makeSelection(original.ordinal.asInstanceOf[java.lang.Integer]) } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala index d13c61d64af3..9f443973a9a0 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/PredicateExpressionTransformer.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.expression +import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode} import org.apache.spark.sql.catalyst.expressions._ @@ -25,11 +26,11 @@ import scala.collection.JavaConverters._ case class InTransformer(substraitExprName: String, child: ExpressionTransformer, original: In) extends UnaryExpressionTransformer { - override def doTransform(args: java.lang.Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { assert(original.list.forall(_.foldable)) // Stores the values in a List Literal. val values: Set[Any] = original.list.map(_.eval()).toSet - InExpressionTransformer.toTransformer(child.doTransform(args), values, child.dataType) + InExpressionTransformer.toTransformer(child.doTransform(context), values, child.dataType) } } @@ -38,9 +39,9 @@ case class InSetTransformer( child: ExpressionTransformer, original: InSet) extends UnaryExpressionTransformer { - override def doTransform(args: java.lang.Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { InExpressionTransformer.toTransformer( - child.doTransform(args), + child.doTransform(context), original.hset, original.child.dataType) } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala index 9508d27df73b..a1c6e9b71524 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ScalarSubqueryTransformer.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.expression +import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode} import org.apache.spark.sql.catalyst.InternalRow @@ -26,7 +27,7 @@ case class ScalarSubqueryTransformer(substraitExprName: String, query: ScalarSub extends LeafExpressionTransformer { override def original: Expression = query - override def doTransform(args: java.lang.Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { // don't trigger collect when in validation phase if (TransformerState.underValidationState) { return ExpressionBuilder.makeLiteral(null, query.dataType, true) diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala index f9eb1e8eab42..f309621cced9 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/UnaryExpressionTransformer.scala @@ -21,6 +21,7 @@ import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.substrait.`type`.ListNode import org.apache.gluten.substrait.`type`.MapNode +import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode, StructLiteralNode} import org.apache.spark.sql.catalyst.expressions._ @@ -35,18 +36,18 @@ case class ChildTransformer( extends UnaryExpressionTransformer { override def dataType: DataType = child.dataType - override def doTransform(args: java.lang.Object): ExpressionNode = { - child.doTransform(args) + override def doTransform(context: SubstraitContext): ExpressionNode = { + child.doTransform(context) } } case class CastTransformer(substraitExprName: String, child: ExpressionTransformer, original: Cast) extends UnaryExpressionTransformer { - override def doTransform(args: java.lang.Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { val typeNode = ConverterUtils.getTypeNode(dataType, original.nullable) ExpressionBuilder.makeCast( typeNode, - child.doTransform(args), + child.doTransform(context), SparkShimLoader.getSparkShims.withAnsiEvalMode(original)) } } @@ -57,12 +58,10 @@ case class ExplodeTransformer( original: Explode) extends UnaryExpressionTransformer { - override def doTransform(args: java.lang.Object): ExpressionNode = { - val childNode: ExpressionNode = child.doTransform(args) + override def doTransform(context: SubstraitContext): ExpressionNode = { + val childNode: ExpressionNode = child.doTransform(context) - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] - val functionId = ExpressionBuilder.newScalarFunction( - functionMap, + val functionId = context.registerFunction( ConverterUtils.makeFuncName(substraitExprName, Seq(original.child.dataType))) val expressionNodes = Lists.newArrayList(childNode) @@ -83,11 +82,11 @@ case class CheckOverflowTransformer( child: ExpressionTransformer, original: CheckOverflow) extends UnaryExpressionTransformer { - override def doTransform(args: java.lang.Object): ExpressionNode = { + override def doTransform(context: SubstraitContext): ExpressionNode = { BackendsApiManager.getTransformerApiInstance.createCheckOverflowExprNode( - args, + context, substraitExprName, - child.doTransform(args), + child.doTransform(context), original.child.dataType, original.dataType, original.nullable, @@ -103,13 +102,13 @@ case class GetStructFieldTransformer( override def left: ExpressionTransformer = child override def right: ExpressionTransformer = LiteralTransformer(original.ordinal) - override def doTransform(args: java.lang.Object): ExpressionNode = { - val childNode = child.doTransform(args) + override def doTransform(context: SubstraitContext): ExpressionNode = { + val childNode = child.doTransform(context) childNode match { case node: StructLiteralNode => node.getFieldLiteral(original.ordinal) case _ => - super.doTransform(args) + super.doTransform(context) } } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/WindowFunctionsBuilder.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/WindowFunctionsBuilder.scala index 831e3199733d..5873c29f2092 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/WindowFunctionsBuilder.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/WindowFunctionsBuilder.scala @@ -19,15 +19,14 @@ package org.apache.gluten.expression import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.expression.ConverterUtils.FunctionConfig import org.apache.gluten.expression.ExpressionNames.{LAG, LEAD} -import org.apache.gluten.substrait.expression.ExpressionBuilder +import org.apache.gluten.substrait.SubstraitContext import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, Lag, Lead, WindowExpression, WindowFunction} import scala.util.control.Breaks.{break, breakable} object WindowFunctionsBuilder { - def create(args: java.lang.Object, windowFunc: WindowFunction): Long = { - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + def create(context: SubstraitContext, windowFunc: WindowFunction): Long = { val substraitFunc = windowFunc match { // Handle lag with negative inputOffset, e.g., converts lag(c1, -1) to lead(c1, 1). // Spark uses `-inputOffset` as `offset` for Lag function. @@ -46,7 +45,7 @@ object WindowFunctionsBuilder { val functionName = ConverterUtils.makeFuncName(substraitFunc.get, Seq(windowFunc.dataType), FunctionConfig.OPT) - ExpressionBuilder.newScalarFunction(functionMap, functionName) + context.registerFunction(functionName) } def extractWindowExpression(expr: Expression): WindowExpression = { diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala b/gluten-substrait/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala index 41d5b18c3809..80b3a365b85a 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala @@ -84,7 +84,7 @@ object SubstraitUtil { context: SubstraitContext): ExpressionNode = { ExpressionConverter .replaceWithExpressionTransformer(expr, attributeSeq) - .doTransform(context.registeredFunction) + .doTransform(context) } def createNameStructBuilder( diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala index 7b4c09d4f9e5..553fc15faca6 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala @@ -67,7 +67,6 @@ case class EvalPythonExecTransformer( } val context = new SubstraitContext - val args = context.registeredFunction val operatorId = context.nextOperatorId(this.nodeName) val expressionNodes = new JArrayList[ExpressionNode] @@ -76,7 +75,9 @@ case class EvalPythonExecTransformer( udfs.foreach( udf => { expressionNodes.add( - ExpressionConverter.replaceWithExpressionTransformer(udf, child.output).doTransform(args)) + ExpressionConverter + .replaceWithExpressionTransformer(udf, child.output) + .doTransform(context)) }) val relNode = RelBuilder.makeProjectRel(null, expressionNodes, context, operatorId) @@ -86,7 +87,6 @@ case class EvalPythonExecTransformer( override protected def doTransform(context: SubstraitContext): TransformContext = { val childCtx = child.asInstanceOf[TransformSupport].transform(context) - val args = context.registeredFunction val operatorId = context.nextOperatorId(this.nodeName) val expressionNodes = new JArrayList[ExpressionNode] child.output.zipWithIndex.foreach( @@ -94,7 +94,9 @@ case class EvalPythonExecTransformer( udfs.foreach( udf => { expressionNodes.add( - ExpressionConverter.replaceWithExpressionTransformer(udf, child.output).doTransform(args)) + ExpressionConverter + .replaceWithExpressionTransformer(udf, child.output) + .doTransform(context)) }) val relNode =