diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala index 9a6ba3a220f1..7b0508de87cb 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala @@ -419,25 +419,13 @@ abstract class HashAggregateExecTransformer( } // Create a project rel. - val emitStartIndex = originalInputAttributes.size - val projectRel = if (!validation) { - RelBuilder.makeProjectRel(inputRel, exprNodes, context, operatorId, emitStartIndex) - } else { - // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = originalInputAttributes - .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - .asJava - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - RelBuilder.makeProjectRel( - inputRel, - exprNodes, - extensionNode, - context, - operatorId, - emitStartIndex) - } + val projectRel = RelBuilder.makeProjectRel( + originalInputAttributes.asJava, + inputRel, + exprNodes, + context, + operatorId, + validation) // Create aggregation rel. val groupingList = new JArrayList[ExpressionNode]() diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala index f3adbe351aa4..f3bc929d7eb5 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala @@ -16,13 +16,10 @@ */ package org.apache.gluten.execution -import org.apache.gluten.backendsapi.BackendsApiManager -import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter} +import org.apache.gluten.expression.ExpressionConverter import org.apache.gluten.extension.ValidationResult import org.apache.gluten.metrics.MetricsUpdater -import org.apache.gluten.substrait.`type`.TypeBuilder import org.apache.gluten.substrait.SubstraitContext -import org.apache.gluten.substrait.extensions.ExtensionBuilder import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} @@ -105,12 +102,13 @@ case class TopNTransformer( if (!validation) { RelBuilder.makeTopNRel(input, count, sortFieldList.asJava, context, operatorId) } else { - val inputTypeNodes = - inputAttributes.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)).asJava - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodes).toProtobuf)) - RelBuilder.makeTopNRel(input, count, sortFieldList.asJava, extensionNode, context, operatorId) + RelBuilder.makeTopNRel( + input, + count, + sortFieldList.asJava, + RelBuilder.createExtensionNode(inputAttributes.asJava), + context, + operatorId) } } diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java index 86b27353183c..c8a028d0be4d 100644 --- a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java +++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java @@ -16,24 +16,55 @@ */ package org.apache.gluten.substrait.rel; +import org.apache.gluten.backendsapi.BackendsApiManager; import org.apache.gluten.expression.ConverterUtils; import org.apache.gluten.substrait.SubstraitContext; import org.apache.gluten.substrait.expression.AggregateFunctionNode; import org.apache.gluten.substrait.expression.ExpressionNode; import org.apache.gluten.substrait.expression.WindowFunctionNode; import org.apache.gluten.substrait.extensions.AdvancedExtensionNode; +import org.apache.gluten.substrait.extensions.ExtensionBuilder; import org.apache.gluten.substrait.type.ColumnTypeNode; +import org.apache.gluten.substrait.type.TypeBuilder; import org.apache.gluten.substrait.type.TypeNode; import io.substrait.proto.*; import org.apache.spark.sql.catalyst.expressions.Attribute; import java.util.List; +import java.util.stream.Collectors; /** Contains helper functions for constructing substrait relations. */ public class RelBuilder { private RelBuilder() {} + public static AdvancedExtensionNode createExtensionNode(List inputAttributes) { + // Use an extension node to send the input types through Substrait plan for validation. + List inputTypeNodeList = + inputAttributes.stream() + .map(attr -> ConverterUtils.getTypeNode(attr.dataType(), attr.nullable())) + .collect(Collectors.toList()); + + return ExtensionBuilder.makeAdvancedExtension( + BackendsApiManager.getTransformerApiInstance() + .packPBMessage(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf())); + } + + public static RelNode makeFilterRel( + SubstraitContext context, + ExpressionNode condExprNode, + List inputAttributes, + Long operatorId, + RelNode input, + Boolean validation) { + if (!validation) { + return RelBuilder.makeFilterRel(input, condExprNode, context, operatorId); + } else { + return RelBuilder.makeFilterRel( + input, condExprNode, createExtensionNode(inputAttributes), context, operatorId); + } + } + public static RelNode makeFilterRel( RelNode input, ExpressionNode condition, SubstraitContext context, Long operatorId) { context.registerRelToOperator(operatorId); @@ -50,6 +81,28 @@ public static RelNode makeFilterRel( return new FilterRelNode(input, condition, extensionNode); } + public static RelNode makeProjectRel( + List inputAttributes, + RelNode input, + List projExprNodeList, + SubstraitContext context, + Long operatorId, + Boolean validation) { + int emitStartIndex = inputAttributes.size(); + if (!validation) { + return RelBuilder.makeProjectRel( + input, projExprNodeList, context, operatorId, emitStartIndex); + } else { + return RelBuilder.makeProjectRel( + input, + projExprNodeList, + createExtensionNode(inputAttributes), + context, + operatorId, + emitStartIndex); + } + } + public static RelNode makeProjectRel( RelNode input, List expressionNodes, diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala index ac8e610956dc..fe4898dbaa7b 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala @@ -18,13 +18,11 @@ package org.apache.gluten.execution import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.exception.GlutenNotSupportException -import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter, ExpressionTransformer} +import org.apache.gluten.expression.{ExpressionConverter, ExpressionTransformer} import org.apache.gluten.extension.ValidationResult import org.apache.gluten.extension.columnar.transition.Convention import org.apache.gluten.metrics.MetricsUpdater -import org.apache.gluten.substrait.`type`.TypeBuilder import org.apache.gluten.substrait.SubstraitContext -import org.apache.gluten.substrait.extensions.ExtensionBuilder import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} import org.apache.spark.internal.Logging @@ -78,23 +76,17 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkP input: RelNode, validation: Boolean): RelNode = { assert(condExpr != null) - val args = context.registeredFunction val condExprNode = ExpressionConverter - .replaceWithExpressionTransformer(condExpr, attributeSeq = originalInputAttributes) - .doTransform(args) - - if (!validation) { - RelBuilder.makeFilterRel(input, condExprNode, context, operatorId) - } else { - // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = originalInputAttributes - .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - .asJava - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - RelBuilder.makeFilterRel(input, condExprNode, extensionNode, context, operatorId) - } + .replaceWithExpressionTransformer(condExpr, originalInputAttributes) + .doTransform(context.registeredFunction) + RelBuilder.makeFilterRel( + context, + condExprNode, + originalInputAttributes.asJava, + operatorId, + input, + validation + ) } override def output: Seq[Attribute] = { @@ -229,27 +221,15 @@ abstract class ProjectExecTransformerBase(val list: Seq[NamedExpression], val in validation: Boolean): RelNode = { val args = context.registeredFunction val columnarProjExprs: Seq[ExpressionTransformer] = ExpressionConverter - .replaceWithExpressionTransformer(projectList, attributeSeq = originalInputAttributes) + .replaceWithExpressionTransformer(projectList, originalInputAttributes) val projExprNodeList = columnarProjExprs.map(_.doTransform(args)).asJava - val emitStartIndex = originalInputAttributes.size - if (!validation) { - RelBuilder.makeProjectRel(input, projExprNodeList, context, operatorId, emitStartIndex) - } else { - // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = originalInputAttributes - .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - .asJava - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - RelBuilder.makeProjectRel( - input, - projExprNodeList, - extensionNode, - context, - operatorId, - emitStartIndex) - } + RelBuilder.makeProjectRel( + originalInputAttributes.asJava, + input, + projExprNodeList, + context, + operatorId, + validation) } override def verboseStringWithOperatorId(): String = { diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/LimitExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/LimitExecTransformer.scala index 0d49acd30f0d..73cc1a15a9a4 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/LimitExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/LimitExecTransformer.scala @@ -17,12 +17,9 @@ package org.apache.gluten.execution import org.apache.gluten.backendsapi.BackendsApiManager -import org.apache.gluten.expression.ConverterUtils import org.apache.gluten.extension.ValidationResult import org.apache.gluten.metrics.MetricsUpdater -import org.apache.gluten.substrait.`type`.TypeBuilder import org.apache.gluten.substrait.SubstraitContext -import org.apache.gluten.substrait.extensions.ExtensionBuilder import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} import org.apache.spark.sql.catalyst.expressions.Attribute @@ -71,12 +68,13 @@ case class LimitExecTransformer(child: SparkPlan, offset: Long, count: Long) if (!validation) { RelBuilder.makeFetchRel(input, offset, count, context, operatorId) } else { - val inputTypeNodes = - inputAttributes.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)).asJava - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodes).toProtobuf)) - RelBuilder.makeFetchRel(input, offset, count, extensionNode, context, operatorId) + RelBuilder.makeFetchRel( + input, + offset, + count, + RelBuilder.createExtensionNode(inputAttributes.asJava), + context, + operatorId) } } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala index 8e664a3b6ebc..4fed8b36e90d 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala @@ -17,12 +17,10 @@ package org.apache.gluten.execution import org.apache.gluten.backendsapi.BackendsApiManager -import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter} +import org.apache.gluten.expression.ExpressionConverter import org.apache.gluten.extension.ValidationResult import org.apache.gluten.metrics.MetricsUpdater -import org.apache.gluten.substrait.`type`.TypeBuilder import org.apache.gluten.substrait.SubstraitContext -import org.apache.gluten.substrait.extensions.ExtensionBuilder import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} import org.apache.spark.internal.Logging @@ -78,23 +76,17 @@ case class SampleExecTransformer( input: RelNode, validation: Boolean): RelNode = { assert(condExpr != null) - val args = context.registeredFunction val condExprNode = ExpressionConverter - .replaceWithExpressionTransformer(condExpr, attributeSeq = originalInputAttributes) - .doTransform(args) - - if (!validation) { - RelBuilder.makeFilterRel(input, condExprNode, context, operatorId) - } else { - // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = originalInputAttributes - .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - .asJava - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - RelBuilder.makeFilterRel(input, condExprNode, extensionNode, context, operatorId) - } + .replaceWithExpressionTransformer(condExpr, originalInputAttributes) + .doTransform(context.registeredFunction) + RelBuilder.makeFilterRel( + context, + condExprNode, + originalInputAttributes.asJava, + operatorId, + input, + validation + ) } override protected def doValidateInternal(): ValidationResult = { diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala index 7bdb29f0ee95..068cee1742ee 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala @@ -21,7 +21,6 @@ import org.apache.gluten.config.GlutenConfig import org.apache.gluten.expression._ import org.apache.gluten.extension.ValidationResult import org.apache.gluten.metrics.MetricsUpdater -import org.apache.gluten.substrait.`type`.TypeBuilder import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.WindowFunctionNode import org.apache.gluten.substrait.extensions.ExtensionBuilder @@ -148,20 +147,12 @@ case class WindowExecTransformer( context, operatorId) } else { - // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = originalInputAttributes - .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - .asJava - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - RelBuilder.makeWindowRel( input, windowExpressions, partitionsExpressions, sortFieldList, - extensionNode, + RelBuilder.createExtensionNode(originalInputAttributes.asJava), context, operatorId) } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala index 5d8a18b11164..d96d04dfadb7 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala @@ -17,12 +17,10 @@ package org.apache.gluten.execution import org.apache.gluten.backendsapi.BackendsApiManager -import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter} +import org.apache.gluten.expression.ExpressionConverter import org.apache.gluten.extension.ValidationResult import org.apache.gluten.metrics.MetricsUpdater -import org.apache.gluten.substrait.`type`.TypeBuilder import org.apache.gluten.substrait.SubstraitContext -import org.apache.gluten.substrait.extensions.ExtensionBuilder import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder} @@ -123,20 +121,12 @@ case class WindowGroupLimitExecTransformer( context, operatorId) } else { - // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = originalInputAttributes - .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - .asJava - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - RelBuilder.makeWindowGroupLimitRel( input, partitionsExpressions, sortFieldList, limit, - extensionNode, + RelBuilder.createExtensionNode(originalInputAttributes.asJava), context, operatorId) }