diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java index e4a13522d982..908418c465eb 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java @@ -39,6 +39,8 @@ import org.apache.beam.sdk.extensions.sql.zetasql.translation.ZetaSqlScalarFunctionImpl; import org.apache.beam.sdk.extensions.sql.zetasql.unnest.BeamZetaSqlUncollectRule; import org.apache.beam.sdk.extensions.sql.zetasql.unnest.BeamZetaSqlUnnestRule; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.adapter.enumerable.CallImplementor; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.adapter.enumerable.RexImpTable; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.config.CalciteConnectionConfig; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.jdbc.CalciteSchema; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.ConventionTraitDef; @@ -58,13 +60,17 @@ import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.rules.JoinCommuteRule; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.rules.ProjectCalcMergeRule; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexCall; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexInputRef; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexLiteral; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexNode; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.SchemaPlus; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlNode; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlOperator; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlOperatorTable; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.parser.SqlParser; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.parser.SqlParserImplFactory; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.type.SqlTypeName; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.util.ChainedSqlOperatorTable; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.validate.SqlUserDefinedFunction; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.FrameworkConfig; @@ -128,9 +134,11 @@ public static Collection getZetaSqlRuleSets(Collection calc } /** - * Returns true if the arguments only contain user-defined Java functions, otherwise returns - * false. User-defined java functions are in the category whose function group is equal to {@code - * SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS} + * Returns true if all the following are true: All RexCalls can be implemented by codegen, All + * RexCalls only contain ZetaSQL user-defined Java functions, All RexLiterals pass ZetaSQL + * compliance tests, All RexInputRefs pass ZetaSQL compliance tests, No other RexNode types + * Otherwise returns false. ZetaSQL user-defined Java functions are in the category whose function + * group is equal to {@code SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS} */ static boolean hasOnlyJavaUdfInProjects(RelOptRuleCall x) { List resList = x.getRelList(); @@ -140,20 +148,61 @@ static boolean hasOnlyJavaUdfInProjects(RelOptRuleCall x) { for (RexNode rexNode : logicalCalc.getProgram().getExprList()) { if (rexNode instanceof RexCall) { RexCall call = (RexCall) rexNode; - if (call.getOperator() instanceof SqlUserDefinedFunction) { + final SqlOperator operator = call.getOperator(); + + CallImplementor implementor = RexImpTable.INSTANCE.get(operator); + if (implementor == null) { + // Reject methods with no implementation + return false; + } + + if (operator instanceof SqlUserDefinedFunction) { SqlUserDefinedFunction udf = (SqlUserDefinedFunction) call.op; if (udf.function instanceof ZetaSqlScalarFunctionImpl) { ZetaSqlScalarFunctionImpl scalarFunction = (ZetaSqlScalarFunctionImpl) udf.function; if (!scalarFunction.functionGroup.equals( SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS)) { + // Reject ZetaSQL Builtin Scalar Functions return false; } } else { + // Reject other UDFs return false; } } else { + // Reject Calcite implementations return false; } + } else if (rexNode instanceof RexLiteral) { + SqlTypeName typeName = ((RexLiteral) rexNode).getTypeName(); + switch (typeName) { + case NULL: + case BOOLEAN: + case CHAR: + case BINARY: + case DECIMAL: + break; + default: + // Reject unsupported literals + return false; + } + } else if (rexNode instanceof RexInputRef) { + SqlTypeName typeName = ((RexInputRef) rexNode).getType().getSqlTypeName(); + switch (typeName) { + case BIGINT: + case DOUBLE: + case BOOLEAN: + case VARCHAR: + case VARBINARY: + case DECIMAL: + break; + default: + // Reject unsupported input ref + return false; + } + } else { + // Reject everything else + return false; } } } diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java index a5d2a652eae1..a7e8932697c8 100644 --- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java +++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java @@ -101,7 +101,7 @@ public void testJavaUdfColumnReference() { String.format( "CREATE FUNCTION increment(i INT64) RETURNS INT64 LANGUAGE java " + "OPTIONS (path='%s'); " - + "SELECT increment(Key) FROM KeyValue;", + + "SELECT increment(int64_col) FROM table_all_types;", jarPath); ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config); BeamRelNode beamRelNode = zetaSQLQueryPlanner.convertToBeamRel(sql); @@ -111,8 +111,11 @@ public void testJavaUdfColumnReference() { PAssert.that(stream) .containsInAnyOrder( - Row.withSchema(singleField).addValues(15L).build(), - Row.withSchema(singleField).addValues(16L).build()); + Row.withSchema(singleField).addValues(0L).build(), + Row.withSchema(singleField).addValues(-1L).build(), + Row.withSchema(singleField).addValues(-2L).build(), + Row.withSchema(singleField).addValues(-3L).build(), + Row.withSchema(singleField).addValues(-4L).build()); pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES)); }