From 5c52344e5dbcec3dfeeab797eb9b4a2270398e65 Mon Sep 17 00:00:00 2001 From: zml1206 Date: Thu, 16 Jan 2025 18:07:26 +0800 Subject: [PATCH 1/2] [CORE] Optimize duplicated code for create rel node --- .../HashAggregateExecTransformer.scala | 28 ++---- .../gluten/execution/TopNTransformer.scala | 20 ++--- .../BasicPhysicalOperatorTransformer.scala | 60 ++++--------- .../execution/LimitExecTransformer.scala | 20 ++--- .../execution/SampleExecTransformer.scala | 33 ++----- .../execution/WindowExecTransformer.scala | 13 +-- .../WindowGroupLimitExecTransformer.scala | 16 +--- .../gluten/substrait/rel/RelBuilderUtil.scala | 89 +++++++++++++++++++ 8 files changed, 146 insertions(+), 133 deletions(-) create mode 100644 gluten-substrait/src/main/scala/org/apache/gluten/substrait/rel/RelBuilderUtil.scala diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala index 9a6ba3a220f1..e596c7d0ea21 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala @@ -24,7 +24,7 @@ import org.apache.gluten.substrait.`type`.{TypeBuilder, TypeNode} import org.apache.gluten.substrait.{AggregationParams, SubstraitContext} import org.apache.gluten.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode, ScalarFunctionNode} import org.apache.gluten.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder} -import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} +import org.apache.gluten.substrait.rel.{RelBuilder, RelBuilderUtil, RelNode} import org.apache.gluten.utils.VeloxIntermediateData import org.apache.spark.sql.catalyst.expressions._ @@ -419,25 +419,13 @@ abstract class HashAggregateExecTransformer( } // Create a project rel. - val emitStartIndex = originalInputAttributes.size - val projectRel = if (!validation) { - RelBuilder.makeProjectRel(inputRel, exprNodes, context, operatorId, emitStartIndex) - } else { - // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = originalInputAttributes - .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - .asJava - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - RelBuilder.makeProjectRel( - inputRel, - exprNodes, - extensionNode, - context, - operatorId, - emitStartIndex) - } + val projectRel = RelBuilderUtil.createProjectRel( + originalInputAttributes, + inputRel, + exprNodes, + context, + operatorId, + validation) // Create aggregation rel. val groupingList = new JArrayList[ExpressionNode]() diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala index f3adbe351aa4..20af15f91677 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala @@ -16,14 +16,11 @@ */ package org.apache.gluten.execution -import org.apache.gluten.backendsapi.BackendsApiManager -import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter} +import org.apache.gluten.expression.ExpressionConverter import org.apache.gluten.extension.ValidationResult import org.apache.gluten.metrics.MetricsUpdater -import org.apache.gluten.substrait.`type`.TypeBuilder import org.apache.gluten.substrait.SubstraitContext -import org.apache.gluten.substrait.extensions.ExtensionBuilder -import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} +import org.apache.gluten.substrait.rel.{RelBuilder, RelBuilderUtil, RelNode} import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning, UnspecifiedDistribution} @@ -105,12 +102,13 @@ case class TopNTransformer( if (!validation) { RelBuilder.makeTopNRel(input, count, sortFieldList.asJava, context, operatorId) } else { - val inputTypeNodes = - inputAttributes.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)).asJava - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodes).toProtobuf)) - RelBuilder.makeTopNRel(input, count, sortFieldList.asJava, extensionNode, context, operatorId) + RelBuilder.makeTopNRel( + input, + count, + sortFieldList.asJava, + RelBuilderUtil.createExtensionNode(inputAttributes), + context, + operatorId) } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala index ac8e610956dc..9f52782b1935 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala @@ -18,14 +18,12 @@ package org.apache.gluten.execution import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.exception.GlutenNotSupportException -import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter, ExpressionTransformer} +import org.apache.gluten.expression.{ExpressionConverter, ExpressionTransformer} import org.apache.gluten.extension.ValidationResult import org.apache.gluten.extension.columnar.transition.Convention import org.apache.gluten.metrics.MetricsUpdater -import org.apache.gluten.substrait.`type`.TypeBuilder import org.apache.gluten.substrait.SubstraitContext -import org.apache.gluten.substrait.extensions.ExtensionBuilder -import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} +import org.apache.gluten.substrait.rel.{RelBuilderUtil, RelNode} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -77,24 +75,14 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkP operatorId: Long, input: RelNode, validation: Boolean): RelNode = { - assert(condExpr != null) - val args = context.registeredFunction - val condExprNode = ExpressionConverter - .replaceWithExpressionTransformer(condExpr, attributeSeq = originalInputAttributes) - .doTransform(args) - - if (!validation) { - RelBuilder.makeFilterRel(input, condExprNode, context, operatorId) - } else { - // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = originalInputAttributes - .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - .asJava - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - RelBuilder.makeFilterRel(input, condExprNode, extensionNode, context, operatorId) - } + RelBuilderUtil.createFilterRel( + context, + condExpr, + originalInputAttributes, + operatorId, + input, + validation + ) } override def output: Seq[Attribute] = { @@ -229,27 +217,15 @@ abstract class ProjectExecTransformerBase(val list: Seq[NamedExpression], val in validation: Boolean): RelNode = { val args = context.registeredFunction val columnarProjExprs: Seq[ExpressionTransformer] = ExpressionConverter - .replaceWithExpressionTransformer(projectList, attributeSeq = originalInputAttributes) + .replaceWithExpressionTransformer(projectList, originalInputAttributes) val projExprNodeList = columnarProjExprs.map(_.doTransform(args)).asJava - val emitStartIndex = originalInputAttributes.size - if (!validation) { - RelBuilder.makeProjectRel(input, projExprNodeList, context, operatorId, emitStartIndex) - } else { - // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = originalInputAttributes - .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - .asJava - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - RelBuilder.makeProjectRel( - input, - projExprNodeList, - extensionNode, - context, - operatorId, - emitStartIndex) - } + RelBuilderUtil.createProjectRel( + originalInputAttributes, + input, + projExprNodeList, + context, + operatorId, + validation) } override def verboseStringWithOperatorId(): String = { diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/LimitExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/LimitExecTransformer.scala index 0d49acd30f0d..6674ed36045e 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/LimitExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/LimitExecTransformer.scala @@ -17,19 +17,14 @@ package org.apache.gluten.execution import org.apache.gluten.backendsapi.BackendsApiManager -import org.apache.gluten.expression.ConverterUtils import org.apache.gluten.extension.ValidationResult import org.apache.gluten.metrics.MetricsUpdater -import org.apache.gluten.substrait.`type`.TypeBuilder import org.apache.gluten.substrait.SubstraitContext -import org.apache.gluten.substrait.extensions.ExtensionBuilder -import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} +import org.apache.gluten.substrait.rel.{RelBuilder, RelBuilderUtil, RelNode} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.SparkPlan -import scala.collection.JavaConverters._ - case class LimitExecTransformer(child: SparkPlan, offset: Long, count: Long) extends UnaryTransformSupport { @@ -71,12 +66,13 @@ case class LimitExecTransformer(child: SparkPlan, offset: Long, count: Long) if (!validation) { RelBuilder.makeFetchRel(input, offset, count, context, operatorId) } else { - val inputTypeNodes = - inputAttributes.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)).asJava - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodes).toProtobuf)) - RelBuilder.makeFetchRel(input, offset, count, extensionNode, context, operatorId) + RelBuilder.makeFetchRel( + input, + offset, + count, + RelBuilderUtil.createExtensionNode(inputAttributes), + context, + operatorId) } } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala index 8e664a3b6ebc..87789a6046a1 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala @@ -17,21 +17,16 @@ package org.apache.gluten.execution import org.apache.gluten.backendsapi.BackendsApiManager -import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter} import org.apache.gluten.extension.ValidationResult import org.apache.gluten.metrics.MetricsUpdater -import org.apache.gluten.substrait.`type`.TypeBuilder import org.apache.gluten.substrait.SubstraitContext -import org.apache.gluten.substrait.extensions.ExtensionBuilder -import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} +import org.apache.gluten.substrait.rel.{RelBuilderUtil, RelNode} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, LessThan, Literal, Rand} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.DoubleType -import scala.collection.JavaConverters._ - /** * SampleExec supports two sampling methods: with replacement and without replacement. This * transformer currently supports only sampling without replacement. For sampling without @@ -77,24 +72,14 @@ case class SampleExecTransformer( operatorId: Long, input: RelNode, validation: Boolean): RelNode = { - assert(condExpr != null) - val args = context.registeredFunction - val condExprNode = ExpressionConverter - .replaceWithExpressionTransformer(condExpr, attributeSeq = originalInputAttributes) - .doTransform(args) - - if (!validation) { - RelBuilder.makeFilterRel(input, condExprNode, context, operatorId) - } else { - // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = originalInputAttributes - .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - .asJava - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - RelBuilder.makeFilterRel(input, condExprNode, extensionNode, context, operatorId) - } + RelBuilderUtil.createFilterRel( + context, + condExpr, + originalInputAttributes, + operatorId, + input, + validation + ) } override protected def doValidateInternal(): ValidationResult = { diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala index 7bdb29f0ee95..1e1c41743e4d 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala @@ -21,11 +21,10 @@ import org.apache.gluten.config.GlutenConfig import org.apache.gluten.expression._ import org.apache.gluten.extension.ValidationResult import org.apache.gluten.metrics.MetricsUpdater -import org.apache.gluten.substrait.`type`.TypeBuilder import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.WindowFunctionNode import org.apache.gluten.substrait.extensions.ExtensionBuilder -import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} +import org.apache.gluten.substrait.rel.{RelBuilder, RelBuilderUtil, RelNode} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} @@ -148,20 +147,12 @@ case class WindowExecTransformer( context, operatorId) } else { - // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = originalInputAttributes - .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - .asJava - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - RelBuilder.makeWindowRel( input, windowExpressions, partitionsExpressions, sortFieldList, - extensionNode, + RelBuilderUtil.createExtensionNode(originalInputAttributes), context, operatorId) } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala index 5d8a18b11164..b09fbda4735b 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala @@ -17,13 +17,11 @@ package org.apache.gluten.execution import org.apache.gluten.backendsapi.BackendsApiManager -import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter} +import org.apache.gluten.expression.ExpressionConverter import org.apache.gluten.extension.ValidationResult import org.apache.gluten.metrics.MetricsUpdater -import org.apache.gluten.substrait.`type`.TypeBuilder import org.apache.gluten.substrait.SubstraitContext -import org.apache.gluten.substrait.extensions.ExtensionBuilder -import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} +import org.apache.gluten.substrait.rel.{RelBuilder, RelBuilderUtil, RelNode} import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder} import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} @@ -123,20 +121,12 @@ case class WindowGroupLimitExecTransformer( context, operatorId) } else { - // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = originalInputAttributes - .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - .asJava - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - RelBuilder.makeWindowGroupLimitRel( input, partitionsExpressions, sortFieldList, limit, - extensionNode, + RelBuilderUtil.createExtensionNode(originalInputAttributes), context, operatorId) } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/substrait/rel/RelBuilderUtil.scala b/gluten-substrait/src/main/scala/org/apache/gluten/substrait/rel/RelBuilderUtil.scala new file mode 100644 index 000000000000..82eb76f13bf7 --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/gluten/substrait/rel/RelBuilderUtil.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.substrait.rel + +import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter} +import org.apache.gluten.substrait.`type`.TypeBuilder +import org.apache.gluten.substrait.SubstraitContext +import org.apache.gluten.substrait.expression.ExpressionNode +import org.apache.gluten.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder} + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} + +import java.util.{List => JList} + +import scala.collection.JavaConverters._ + +object RelBuilderUtil { + + def createProjectRel( + inputAttributes: Seq[Attribute], + input: RelNode, + projExprNodeList: JList[ExpressionNode], + context: SubstraitContext, + operatorId: Long, + validation: Boolean): RelNode = { + val emitStartIndex = inputAttributes.size + if (!validation) { + RelBuilder.makeProjectRel(input, projExprNodeList, context, operatorId, emitStartIndex) + } else { + RelBuilder.makeProjectRel( + input, + projExprNodeList, + createExtensionNode(inputAttributes), + context, + operatorId, + emitStartIndex) + } + } + + def createFilterRel( + context: SubstraitContext, + condExpr: Expression, + inputAttributes: Seq[Attribute], + operatorId: Long, + input: RelNode, + validation: Boolean): RelNode = { + assert(condExpr != null) + val args = context.registeredFunction + val condExprNode = ExpressionConverter + .replaceWithExpressionTransformer(condExpr, inputAttributes) + .doTransform(args) + + if (!validation) { + RelBuilder.makeFilterRel(input, condExprNode, context, operatorId) + } else { + RelBuilder.makeFilterRel( + input, + condExprNode, + createExtensionNode(inputAttributes), + context, + operatorId) + } + } + + def createExtensionNode(inputAttributes: Seq[Attribute]): AdvancedExtensionNode = { + // Use an extension node to send the input types through Substrait plan for validation. + val inputTypeNodeList = inputAttributes + .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) + .asJava + ExtensionBuilder.makeAdvancedExtension( + BackendsApiManager.getTransformerApiInstance.packPBMessage( + TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) + } +} From 60e59af78301883b5f86dc2311fd2099dcbd6416 Mon Sep 17 00:00:00 2001 From: zml1206 Date: Fri, 17 Jan 2025 13:10:35 +0800 Subject: [PATCH 2/2] update --- .../HashAggregateExecTransformer.scala | 6 +- .../gluten/execution/TopNTransformer.scala | 4 +- .../gluten/substrait/rel/RelBuilder.java | 53 +++++++++++ .../BasicPhysicalOperatorTransformer.scala | 16 ++-- .../execution/LimitExecTransformer.scala | 6 +- .../execution/SampleExecTransformer.scala | 15 +++- .../execution/WindowExecTransformer.scala | 4 +- .../WindowGroupLimitExecTransformer.scala | 4 +- .../gluten/substrait/rel/RelBuilderUtil.scala | 89 ------------------- 9 files changed, 87 insertions(+), 110 deletions(-) delete mode 100644 gluten-substrait/src/main/scala/org/apache/gluten/substrait/rel/RelBuilderUtil.scala diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala index e596c7d0ea21..7b0508de87cb 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala @@ -24,7 +24,7 @@ import org.apache.gluten.substrait.`type`.{TypeBuilder, TypeNode} import org.apache.gluten.substrait.{AggregationParams, SubstraitContext} import org.apache.gluten.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode, ScalarFunctionNode} import org.apache.gluten.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder} -import org.apache.gluten.substrait.rel.{RelBuilder, RelBuilderUtil, RelNode} +import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} import org.apache.gluten.utils.VeloxIntermediateData import org.apache.spark.sql.catalyst.expressions._ @@ -419,8 +419,8 @@ abstract class HashAggregateExecTransformer( } // Create a project rel. - val projectRel = RelBuilderUtil.createProjectRel( - originalInputAttributes, + val projectRel = RelBuilder.makeProjectRel( + originalInputAttributes.asJava, inputRel, exprNodes, context, diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala index 20af15f91677..f3bc929d7eb5 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/TopNTransformer.scala @@ -20,7 +20,7 @@ import org.apache.gluten.expression.ExpressionConverter import org.apache.gluten.extension.ValidationResult import org.apache.gluten.metrics.MetricsUpdater import org.apache.gluten.substrait.SubstraitContext -import org.apache.gluten.substrait.rel.{RelBuilder, RelBuilderUtil, RelNode} +import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning, UnspecifiedDistribution} @@ -106,7 +106,7 @@ case class TopNTransformer( input, count, sortFieldList.asJava, - RelBuilderUtil.createExtensionNode(inputAttributes), + RelBuilder.createExtensionNode(inputAttributes.asJava), context, operatorId) } diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java index 86b27353183c..c8a028d0be4d 100644 --- a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java +++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java @@ -16,24 +16,55 @@ */ package org.apache.gluten.substrait.rel; +import org.apache.gluten.backendsapi.BackendsApiManager; import org.apache.gluten.expression.ConverterUtils; import org.apache.gluten.substrait.SubstraitContext; import org.apache.gluten.substrait.expression.AggregateFunctionNode; import org.apache.gluten.substrait.expression.ExpressionNode; import org.apache.gluten.substrait.expression.WindowFunctionNode; import org.apache.gluten.substrait.extensions.AdvancedExtensionNode; +import org.apache.gluten.substrait.extensions.ExtensionBuilder; import org.apache.gluten.substrait.type.ColumnTypeNode; +import org.apache.gluten.substrait.type.TypeBuilder; import org.apache.gluten.substrait.type.TypeNode; import io.substrait.proto.*; import org.apache.spark.sql.catalyst.expressions.Attribute; import java.util.List; +import java.util.stream.Collectors; /** Contains helper functions for constructing substrait relations. */ public class RelBuilder { private RelBuilder() {} + public static AdvancedExtensionNode createExtensionNode(List inputAttributes) { + // Use an extension node to send the input types through Substrait plan for validation. + List inputTypeNodeList = + inputAttributes.stream() + .map(attr -> ConverterUtils.getTypeNode(attr.dataType(), attr.nullable())) + .collect(Collectors.toList()); + + return ExtensionBuilder.makeAdvancedExtension( + BackendsApiManager.getTransformerApiInstance() + .packPBMessage(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf())); + } + + public static RelNode makeFilterRel( + SubstraitContext context, + ExpressionNode condExprNode, + List inputAttributes, + Long operatorId, + RelNode input, + Boolean validation) { + if (!validation) { + return RelBuilder.makeFilterRel(input, condExprNode, context, operatorId); + } else { + return RelBuilder.makeFilterRel( + input, condExprNode, createExtensionNode(inputAttributes), context, operatorId); + } + } + public static RelNode makeFilterRel( RelNode input, ExpressionNode condition, SubstraitContext context, Long operatorId) { context.registerRelToOperator(operatorId); @@ -50,6 +81,28 @@ public static RelNode makeFilterRel( return new FilterRelNode(input, condition, extensionNode); } + public static RelNode makeProjectRel( + List inputAttributes, + RelNode input, + List projExprNodeList, + SubstraitContext context, + Long operatorId, + Boolean validation) { + int emitStartIndex = inputAttributes.size(); + if (!validation) { + return RelBuilder.makeProjectRel( + input, projExprNodeList, context, operatorId, emitStartIndex); + } else { + return RelBuilder.makeProjectRel( + input, + projExprNodeList, + createExtensionNode(inputAttributes), + context, + operatorId, + emitStartIndex); + } + } + public static RelNode makeProjectRel( RelNode input, List expressionNodes, diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala index 9f52782b1935..fe4898dbaa7b 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala @@ -23,7 +23,7 @@ import org.apache.gluten.extension.ValidationResult import org.apache.gluten.extension.columnar.transition.Convention import org.apache.gluten.metrics.MetricsUpdater import org.apache.gluten.substrait.SubstraitContext -import org.apache.gluten.substrait.rel.{RelBuilderUtil, RelNode} +import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -75,10 +75,14 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkP operatorId: Long, input: RelNode, validation: Boolean): RelNode = { - RelBuilderUtil.createFilterRel( + assert(condExpr != null) + val condExprNode = ExpressionConverter + .replaceWithExpressionTransformer(condExpr, originalInputAttributes) + .doTransform(context.registeredFunction) + RelBuilder.makeFilterRel( context, - condExpr, - originalInputAttributes, + condExprNode, + originalInputAttributes.asJava, operatorId, input, validation @@ -219,8 +223,8 @@ abstract class ProjectExecTransformerBase(val list: Seq[NamedExpression], val in val columnarProjExprs: Seq[ExpressionTransformer] = ExpressionConverter .replaceWithExpressionTransformer(projectList, originalInputAttributes) val projExprNodeList = columnarProjExprs.map(_.doTransform(args)).asJava - RelBuilderUtil.createProjectRel( - originalInputAttributes, + RelBuilder.makeProjectRel( + originalInputAttributes.asJava, input, projExprNodeList, context, diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/LimitExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/LimitExecTransformer.scala index 6674ed36045e..73cc1a15a9a4 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/LimitExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/LimitExecTransformer.scala @@ -20,11 +20,13 @@ import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.extension.ValidationResult import org.apache.gluten.metrics.MetricsUpdater import org.apache.gluten.substrait.SubstraitContext -import org.apache.gluten.substrait.rel.{RelBuilder, RelBuilderUtil, RelNode} +import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.SparkPlan +import scala.collection.JavaConverters._ + case class LimitExecTransformer(child: SparkPlan, offset: Long, count: Long) extends UnaryTransformSupport { @@ -70,7 +72,7 @@ case class LimitExecTransformer(child: SparkPlan, offset: Long, count: Long) input, offset, count, - RelBuilderUtil.createExtensionNode(inputAttributes), + RelBuilder.createExtensionNode(inputAttributes.asJava), context, operatorId) } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala index 87789a6046a1..4fed8b36e90d 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/SampleExecTransformer.scala @@ -17,16 +17,19 @@ package org.apache.gluten.execution import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.expression.ExpressionConverter import org.apache.gluten.extension.ValidationResult import org.apache.gluten.metrics.MetricsUpdater import org.apache.gluten.substrait.SubstraitContext -import org.apache.gluten.substrait.rel.{RelBuilderUtil, RelNode} +import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, LessThan, Literal, Rand} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.DoubleType +import scala.collection.JavaConverters._ + /** * SampleExec supports two sampling methods: with replacement and without replacement. This * transformer currently supports only sampling without replacement. For sampling without @@ -72,10 +75,14 @@ case class SampleExecTransformer( operatorId: Long, input: RelNode, validation: Boolean): RelNode = { - RelBuilderUtil.createFilterRel( + assert(condExpr != null) + val condExprNode = ExpressionConverter + .replaceWithExpressionTransformer(condExpr, originalInputAttributes) + .doTransform(context.registeredFunction) + RelBuilder.makeFilterRel( context, - condExpr, - originalInputAttributes, + condExprNode, + originalInputAttributes.asJava, operatorId, input, validation diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala index 1e1c41743e4d..068cee1742ee 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowExecTransformer.scala @@ -24,7 +24,7 @@ import org.apache.gluten.metrics.MetricsUpdater import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.WindowFunctionNode import org.apache.gluten.substrait.extensions.ExtensionBuilder -import org.apache.gluten.substrait.rel.{RelBuilder, RelBuilderUtil, RelNode} +import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} @@ -152,7 +152,7 @@ case class WindowExecTransformer( windowExpressions, partitionsExpressions, sortFieldList, - RelBuilderUtil.createExtensionNode(originalInputAttributes), + RelBuilder.createExtensionNode(originalInputAttributes.asJava), context, operatorId) } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala index b09fbda4735b..d96d04dfadb7 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala @@ -21,7 +21,7 @@ import org.apache.gluten.expression.ExpressionConverter import org.apache.gluten.extension.ValidationResult import org.apache.gluten.metrics.MetricsUpdater import org.apache.gluten.substrait.SubstraitContext -import org.apache.gluten.substrait.rel.{RelBuilder, RelBuilderUtil, RelNode} +import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder} import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} @@ -126,7 +126,7 @@ case class WindowGroupLimitExecTransformer( partitionsExpressions, sortFieldList, limit, - RelBuilderUtil.createExtensionNode(originalInputAttributes), + RelBuilder.createExtensionNode(originalInputAttributes.asJava), context, operatorId) } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/substrait/rel/RelBuilderUtil.scala b/gluten-substrait/src/main/scala/org/apache/gluten/substrait/rel/RelBuilderUtil.scala deleted file mode 100644 index 82eb76f13bf7..000000000000 --- a/gluten-substrait/src/main/scala/org/apache/gluten/substrait/rel/RelBuilderUtil.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.gluten.substrait.rel - -import org.apache.gluten.backendsapi.BackendsApiManager -import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter} -import org.apache.gluten.substrait.`type`.TypeBuilder -import org.apache.gluten.substrait.SubstraitContext -import org.apache.gluten.substrait.expression.ExpressionNode -import org.apache.gluten.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder} - -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} - -import java.util.{List => JList} - -import scala.collection.JavaConverters._ - -object RelBuilderUtil { - - def createProjectRel( - inputAttributes: Seq[Attribute], - input: RelNode, - projExprNodeList: JList[ExpressionNode], - context: SubstraitContext, - operatorId: Long, - validation: Boolean): RelNode = { - val emitStartIndex = inputAttributes.size - if (!validation) { - RelBuilder.makeProjectRel(input, projExprNodeList, context, operatorId, emitStartIndex) - } else { - RelBuilder.makeProjectRel( - input, - projExprNodeList, - createExtensionNode(inputAttributes), - context, - operatorId, - emitStartIndex) - } - } - - def createFilterRel( - context: SubstraitContext, - condExpr: Expression, - inputAttributes: Seq[Attribute], - operatorId: Long, - input: RelNode, - validation: Boolean): RelNode = { - assert(condExpr != null) - val args = context.registeredFunction - val condExprNode = ExpressionConverter - .replaceWithExpressionTransformer(condExpr, inputAttributes) - .doTransform(args) - - if (!validation) { - RelBuilder.makeFilterRel(input, condExprNode, context, operatorId) - } else { - RelBuilder.makeFilterRel( - input, - condExprNode, - createExtensionNode(inputAttributes), - context, - operatorId) - } - } - - def createExtensionNode(inputAttributes: Seq[Attribute]): AdvancedExtensionNode = { - // Use an extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = inputAttributes - .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - .asJava - ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - } -}