Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
import org.apache.beam.sdk.extensions.sql.zetasql.translation.ZetaSqlScalarFunctionImpl;
import org.apache.beam.sdk.extensions.sql.zetasql.unnest.BeamZetaSqlUncollectRule;
import org.apache.beam.sdk.extensions.sql.zetasql.unnest.BeamZetaSqlUnnestRule;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.adapter.enumerable.CallImplementor;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.adapter.enumerable.RexImpTable;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.config.CalciteConnectionConfig;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.jdbc.CalciteSchema;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.ConventionTraitDef;
Expand All @@ -58,13 +60,17 @@
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.rules.JoinCommuteRule;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.rules.ProjectCalcMergeRule;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexCall;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexInputRef;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexLiteral;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.SchemaPlus;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlOperator;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlOperatorTable;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.parser.SqlParser;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.parser.SqlParserImplFactory;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.type.SqlTypeName;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.util.ChainedSqlOperatorTable;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.validate.SqlUserDefinedFunction;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.FrameworkConfig;
Expand Down Expand Up @@ -128,9 +134,11 @@ public static Collection<RuleSet> getZetaSqlRuleSets(Collection<RelOptRule> calc
}

/**
* Returns true if the arguments only contain user-defined Java functions, otherwise returns
* false. User-defined java functions are in the category whose function group is equal to {@code
* SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS}
* Returns true if all the following are true: All RexCalls can be implemented by codegen, All
* RexCalls only contain ZetaSQL user-defined Java functions, All RexLiterals pass ZetaSQL
* compliance tests, All RexInputRefs pass ZetaSQL compliance tests, No other RexNode types
* Otherwise returns false. ZetaSQL user-defined Java functions are in the category whose function
* group is equal to {@code SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS}
*/
static boolean hasOnlyJavaUdfInProjects(RelOptRuleCall x) {
List<RelNode> resList = x.getRelList();
Expand All @@ -140,20 +148,61 @@ static boolean hasOnlyJavaUdfInProjects(RelOptRuleCall x) {
for (RexNode rexNode : logicalCalc.getProgram().getExprList()) {
if (rexNode instanceof RexCall) {
RexCall call = (RexCall) rexNode;
if (call.getOperator() instanceof SqlUserDefinedFunction) {
final SqlOperator operator = call.getOperator();

CallImplementor implementor = RexImpTable.INSTANCE.get(operator);
if (implementor == null) {
// Reject methods with no implementation
return false;
}

if (operator instanceof SqlUserDefinedFunction) {
SqlUserDefinedFunction udf = (SqlUserDefinedFunction) call.op;
if (udf.function instanceof ZetaSqlScalarFunctionImpl) {
ZetaSqlScalarFunctionImpl scalarFunction = (ZetaSqlScalarFunctionImpl) udf.function;
if (!scalarFunction.functionGroup.equals(
SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS)) {
// Reject ZetaSQL Builtin Scalar Functions
return false;
}
} else {
// Reject other UDFs
return false;
}
} else {
// Reject Calcite implementations
return false;
}
} else if (rexNode instanceof RexLiteral) {
SqlTypeName typeName = ((RexLiteral) rexNode).getTypeName();
switch (typeName) {
case NULL:
case BOOLEAN:
case CHAR:
case BINARY:
case DECIMAL:
break;
default:
// Reject unsupported literals
return false;
}
} else if (rexNode instanceof RexInputRef) {
SqlTypeName typeName = ((RexInputRef) rexNode).getType().getSqlTypeName();
switch (typeName) {
case BIGINT:
case DOUBLE:
case BOOLEAN:
case VARCHAR:
case VARBINARY:
case DECIMAL:
break;
default:
// Reject unsupported input ref
return false;
}
} else {
// Reject everything else
return false;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ public void testJavaUdfColumnReference() {
String.format(
"CREATE FUNCTION increment(i INT64) RETURNS INT64 LANGUAGE java "
+ "OPTIONS (path='%s'); "
+ "SELECT increment(Key) FROM KeyValue;",
+ "SELECT increment(int64_col) FROM table_all_types;",
jarPath);
ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config);
BeamRelNode beamRelNode = zetaSQLQueryPlanner.convertToBeamRel(sql);
Expand All @@ -111,8 +111,11 @@ public void testJavaUdfColumnReference() {

PAssert.that(stream)
.containsInAnyOrder(
Row.withSchema(singleField).addValues(15L).build(),
Row.withSchema(singleField).addValues(16L).build());
Row.withSchema(singleField).addValues(0L).build(),
Row.withSchema(singleField).addValues(-1L).build(),
Row.withSchema(singleField).addValues(-2L).build(),
Row.withSchema(singleField).addValues(-3L).build(),
Row.withSchema(singleField).addValues(-4L).build());
pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
}

Expand Down