Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -419,25 +419,13 @@ abstract class HashAggregateExecTransformer(
}

// Create a project rel.
val emitStartIndex = originalInputAttributes.size
val projectRel = if (!validation) {
RelBuilder.makeProjectRel(inputRel, exprNodes, context, operatorId, emitStartIndex)
} else {
// Use a extension node to send the input types through Substrait plan for validation.
val inputTypeNodeList = originalInputAttributes
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeProjectRel(
inputRel,
exprNodes,
extensionNode,
context,
operatorId,
emitStartIndex)
}
val projectRel = RelBuilder.makeProjectRel(
originalInputAttributes.asJava,
inputRel,
exprNodes,
context,
operatorId,
validation)

// Create aggregation rel.
val groupingList = new JArrayList[ExpressionNode]()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,10 @@
*/
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
import org.apache.gluten.expression.ExpressionConverter
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.metrics.MetricsUpdater
import org.apache.gluten.substrait.`type`.TypeBuilder
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.extensions.ExtensionBuilder
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}

import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
Expand Down Expand Up @@ -105,12 +102,13 @@ case class TopNTransformer(
if (!validation) {
RelBuilder.makeTopNRel(input, count, sortFieldList.asJava, context, operatorId)
} else {
val inputTypeNodes =
inputAttributes.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)).asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodes).toProtobuf))
RelBuilder.makeTopNRel(input, count, sortFieldList.asJava, extensionNode, context, operatorId)
RelBuilder.makeTopNRel(
input,
count,
sortFieldList.asJava,
RelBuilder.createExtensionNode(inputAttributes.asJava),
context,
operatorId)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,55 @@
*/
package org.apache.gluten.substrait.rel;

import org.apache.gluten.backendsapi.BackendsApiManager;
import org.apache.gluten.expression.ConverterUtils;
import org.apache.gluten.substrait.SubstraitContext;
import org.apache.gluten.substrait.expression.AggregateFunctionNode;
import org.apache.gluten.substrait.expression.ExpressionNode;
import org.apache.gluten.substrait.expression.WindowFunctionNode;
import org.apache.gluten.substrait.extensions.AdvancedExtensionNode;
import org.apache.gluten.substrait.extensions.ExtensionBuilder;
import org.apache.gluten.substrait.type.ColumnTypeNode;
import org.apache.gluten.substrait.type.TypeBuilder;
import org.apache.gluten.substrait.type.TypeNode;

import io.substrait.proto.*;
import org.apache.spark.sql.catalyst.expressions.Attribute;

import java.util.List;
import java.util.stream.Collectors;

/** Contains helper functions for constructing substrait relations. */
public class RelBuilder {
private RelBuilder() {}

public static AdvancedExtensionNode createExtensionNode(List<Attribute> inputAttributes) {
// Use an extension node to send the input types through Substrait plan for validation.
List<TypeNode> inputTypeNodeList =
inputAttributes.stream()
.map(attr -> ConverterUtils.getTypeNode(attr.dataType(), attr.nullable()))
.collect(Collectors.toList());

return ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance()
.packPBMessage(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf()));
}

public static RelNode makeFilterRel(
SubstraitContext context,
ExpressionNode condExprNode,
List<Attribute> inputAttributes,
Long operatorId,
RelNode input,
Boolean validation) {
if (!validation) {
return RelBuilder.makeFilterRel(input, condExprNode, context, operatorId);
} else {
return RelBuilder.makeFilterRel(
input, condExprNode, createExtensionNode(inputAttributes), context, operatorId);
}
}

public static RelNode makeFilterRel(
RelNode input, ExpressionNode condition, SubstraitContext context, Long operatorId) {
context.registerRelToOperator(operatorId);
Expand All @@ -50,6 +81,28 @@ public static RelNode makeFilterRel(
return new FilterRelNode(input, condition, extensionNode);
}

public static RelNode makeProjectRel(
List<Attribute> inputAttributes,
RelNode input,
List<ExpressionNode> projExprNodeList,
SubstraitContext context,
Long operatorId,
Boolean validation) {
int emitStartIndex = inputAttributes.size();
if (!validation) {
return RelBuilder.makeProjectRel(
input, projExprNodeList, context, operatorId, emitStartIndex);
} else {
return RelBuilder.makeProjectRel(
input,
projExprNodeList,
createExtensionNode(inputAttributes),
context,
operatorId,
emitStartIndex);
}
}

public static RelNode makeProjectRel(
RelNode input,
List<ExpressionNode> expressionNodes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@ package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter, ExpressionTransformer}
import org.apache.gluten.expression.{ExpressionConverter, ExpressionTransformer}
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.extension.columnar.transition.Convention
import org.apache.gluten.metrics.MetricsUpdater
import org.apache.gluten.substrait.`type`.TypeBuilder
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.extensions.ExtensionBuilder
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}

import org.apache.spark.internal.Logging
Expand Down Expand Up @@ -78,23 +76,17 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkP
input: RelNode,
validation: Boolean): RelNode = {
assert(condExpr != null)
val args = context.registeredFunction
val condExprNode = ExpressionConverter
.replaceWithExpressionTransformer(condExpr, attributeSeq = originalInputAttributes)
.doTransform(args)

if (!validation) {
RelBuilder.makeFilterRel(input, condExprNode, context, operatorId)
} else {
// Use a extension node to send the input types through Substrait plan for validation.
val inputTypeNodeList = originalInputAttributes
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeFilterRel(input, condExprNode, extensionNode, context, operatorId)
}
.replaceWithExpressionTransformer(condExpr, originalInputAttributes)
.doTransform(context.registeredFunction)
RelBuilder.makeFilterRel(
context,
condExprNode,
originalInputAttributes.asJava,
operatorId,
input,
validation
)
}

override def output: Seq[Attribute] = {
Expand Down Expand Up @@ -229,27 +221,15 @@ abstract class ProjectExecTransformerBase(val list: Seq[NamedExpression], val in
validation: Boolean): RelNode = {
val args = context.registeredFunction
val columnarProjExprs: Seq[ExpressionTransformer] = ExpressionConverter
.replaceWithExpressionTransformer(projectList, attributeSeq = originalInputAttributes)
.replaceWithExpressionTransformer(projectList, originalInputAttributes)
val projExprNodeList = columnarProjExprs.map(_.doTransform(args)).asJava
val emitStartIndex = originalInputAttributes.size
if (!validation) {
RelBuilder.makeProjectRel(input, projExprNodeList, context, operatorId, emitStartIndex)
} else {
// Use a extension node to send the input types through Substrait plan for validation.
val inputTypeNodeList = originalInputAttributes
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeProjectRel(
input,
projExprNodeList,
extensionNode,
context,
operatorId,
emitStartIndex)
}
RelBuilder.makeProjectRel(
originalInputAttributes.asJava,
input,
projExprNodeList,
context,
operatorId,
validation)
}

override def verboseStringWithOperatorId(): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,9 @@
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.expression.ConverterUtils
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.metrics.MetricsUpdater
import org.apache.gluten.substrait.`type`.TypeBuilder
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.extensions.ExtensionBuilder
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}

import org.apache.spark.sql.catalyst.expressions.Attribute
Expand Down Expand Up @@ -71,12 +68,13 @@ case class LimitExecTransformer(child: SparkPlan, offset: Long, count: Long)
if (!validation) {
RelBuilder.makeFetchRel(input, offset, count, context, operatorId)
} else {
val inputTypeNodes =
inputAttributes.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)).asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodes).toProtobuf))
RelBuilder.makeFetchRel(input, offset, count, extensionNode, context, operatorId)
RelBuilder.makeFetchRel(
input,
offset,
count,
RelBuilder.createExtensionNode(inputAttributes.asJava),
context,
operatorId)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
import org.apache.gluten.expression.ExpressionConverter
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.metrics.MetricsUpdater
import org.apache.gluten.substrait.`type`.TypeBuilder
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.extensions.ExtensionBuilder
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}

import org.apache.spark.internal.Logging
Expand Down Expand Up @@ -78,23 +76,17 @@ case class SampleExecTransformer(
input: RelNode,
validation: Boolean): RelNode = {
assert(condExpr != null)
val args = context.registeredFunction
val condExprNode = ExpressionConverter
.replaceWithExpressionTransformer(condExpr, attributeSeq = originalInputAttributes)
.doTransform(args)

if (!validation) {
RelBuilder.makeFilterRel(input, condExprNode, context, operatorId)
} else {
// Use a extension node to send the input types through Substrait plan for validation.
val inputTypeNodeList = originalInputAttributes
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeFilterRel(input, condExprNode, extensionNode, context, operatorId)
}
.replaceWithExpressionTransformer(condExpr, originalInputAttributes)
.doTransform(context.registeredFunction)
RelBuilder.makeFilterRel(
context,
condExprNode,
originalInputAttributes.asJava,
operatorId,
input,
validation
)
}

override protected def doValidateInternal(): ValidationResult = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.expression._
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.metrics.MetricsUpdater
import org.apache.gluten.substrait.`type`.TypeBuilder
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.WindowFunctionNode
import org.apache.gluten.substrait.extensions.ExtensionBuilder
Expand Down Expand Up @@ -148,20 +147,12 @@ case class WindowExecTransformer(
context,
operatorId)
} else {
// Use a extension node to send the input types through Substrait plan for validation.
val inputTypeNodeList = originalInputAttributes
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))

RelBuilder.makeWindowRel(
input,
windowExpressions,
partitionsExpressions,
sortFieldList,
extensionNode,
RelBuilder.createExtensionNode(originalInputAttributes.asJava),
context,
operatorId)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
import org.apache.gluten.expression.ExpressionConverter
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.metrics.MetricsUpdater
import org.apache.gluten.substrait.`type`.TypeBuilder
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.extensions.ExtensionBuilder
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}

import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder}
Expand Down Expand Up @@ -123,20 +121,12 @@ case class WindowGroupLimitExecTransformer(
context,
operatorId)
} else {
// Use a extension node to send the input types through Substrait plan for validation.
val inputTypeNodeList = originalInputAttributes
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))

RelBuilder.makeWindowGroupLimitRel(
input,
partitionsExpressions,
sortFieldList,
limit,
extensionNode,
RelBuilder.createExtensionNode(originalInputAttributes.asJava),
context,
operatorId)
}
Expand Down