diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamJavaUdfCalcRule.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamJavaUdfCalcRule.java index 23d0f76a300a..13c8657f3d05 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamJavaUdfCalcRule.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamJavaUdfCalcRule.java @@ -38,7 +38,7 @@ private BeamJavaUdfCalcRule() { @Override public boolean matches(RelOptRuleCall x) { - return ZetaSQLQueryPlanner.hasUdfInProjects(x); + return ZetaSQLQueryPlanner.hasOnlyJavaUdfInProjects(x); } @Override diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRule.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRule.java index 2f6c60d60cfd..0437fbb15434 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRule.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRule.java @@ -37,7 +37,7 @@ private BeamZetaSqlCalcRule() { @Override public boolean matches(RelOptRuleCall x) { - return !ZetaSQLQueryPlanner.hasUdfInProjects(x); + return ZetaSQLQueryPlanner.hasNoJavaUdfInProjects(x); } @Override diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java index 9ca5e8313d68..76b0f528e1f8 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java @@ -116,8 +116,44 @@ public static Collection getZetaSqlRuleSets() { return modifyRuleSetsForZetaSql(BeamRuleSets.getRuleSets()); } - /** Returns true if the argument contains any user-defined Java functions. */ - static boolean hasUdfInProjects(RelOptRuleCall x) { + /** + * Returns true if the arguments only contain user-defined Java functions, otherwise returns + * false. User-defined java functions are in the category whose function group is equal to {@code + * SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS} + */ + static boolean hasOnlyJavaUdfInProjects(RelOptRuleCall x) { + List resList = x.getRelList(); + for (RelNode relNode : resList) { + if (relNode instanceof LogicalCalc) { + LogicalCalc logicalCalc = (LogicalCalc) relNode; + for (RexNode rexNode : logicalCalc.getProgram().getExprList()) { + if (rexNode instanceof RexCall) { + RexCall call = (RexCall) rexNode; + if (call.getOperator() instanceof SqlUserDefinedFunction) { + SqlUserDefinedFunction udf = (SqlUserDefinedFunction) call.op; + if (udf.function instanceof ZetaSqlScalarFunctionImpl) { + ZetaSqlScalarFunctionImpl scalarFunction = (ZetaSqlScalarFunctionImpl) udf.function; + if (!scalarFunction.functionGroup.equals( + SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS)) { + return false; + } + } else { + return false; + } + } else { + return false; + } + } + } + } + } + return true; + } + + /** + * Returns false if the argument contains any user-defined Java functions, otherwise returns true. + */ + static boolean hasNoJavaUdfInProjects(RelOptRuleCall x) { List resList = x.getRelList(); for (RelNode relNode : resList) { if (relNode instanceof LogicalCalc) { @@ -129,15 +165,17 @@ static boolean hasUdfInProjects(RelOptRuleCall x) { SqlUserDefinedFunction udf = (SqlUserDefinedFunction) call.op; if (udf.function instanceof ZetaSqlScalarFunctionImpl) { ZetaSqlScalarFunctionImpl scalarFunction = (ZetaSqlScalarFunctionImpl) udf.function; - return scalarFunction.functionGroup.equals( - SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS); + if (scalarFunction.functionGroup.equals( + SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS)) { + return false; + } } } } } } } - return false; + return true; } private static Collection modifyRuleSetsForZetaSql(Collection ruleSets) { diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java index 34b1d235ed97..31f61dbdc4ae 100644 --- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java +++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java @@ -31,6 +31,7 @@ import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptPlanner.CannotPlanException; import org.checkerframework.checker.nullness.qual.Nullable; import org.codehaus.commons.compiler.CompileException; import org.joda.time.Duration; @@ -203,15 +204,10 @@ public void testBinaryJavaUdf() { + "SELECT matches(\"a\", \"a\"), 'apple'='beta'", jarPath); ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config); - BeamRelNode beamRelNode = zetaSQLQueryPlanner.convertToBeamRel(sql); - PCollection stream = BeamSqlRelUtils.toPCollection(pipeline, beamRelNode); - - Schema singleField = - Schema.builder().addBooleanField("field1").addBooleanField("field2").build(); - - PAssert.that(stream) - .containsInAnyOrder(Row.withSchema(singleField).addValues(true, false).build()); - pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES)); + thrown.expect(CannotPlanException.class); + thrown.expectMessage( + "There are not enough rules to produce a node with desired properties: convention=BEAM_LOGICAL."); + zetaSQLQueryPlanner.convertToBeamRel(sql); } // TODO(BEAM-11747) Add tests that mix UDFs and builtin functions that rely on the ZetaSQL C++