From 5127734d44e83d82776f49219bb662656e388b5b Mon Sep 17 00:00:00 2001 From: Kyle Weaver Date: Fri, 29 Jan 2021 13:12:23 -0800 Subject: [PATCH] [BEAM-10925] Add rule to replace Calc with BeamCalcRel for ZetaSQL UDFs. --- .../sql/zetasql/BeamJavaUdfCalcRule.java | 55 +++++++++++++++++++ .../sql/zetasql/BeamZetaSqlCalcRule.java | 2 +- .../extensions/sql/zetasql/SqlAnalyzer.java | 3 + .../sql/zetasql/ZetaSQLQueryPlanner.java | 41 +++++++++++++- 4 files changed, 99 insertions(+), 2 deletions(-) create mode 100644 sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamJavaUdfCalcRule.java diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamJavaUdfCalcRule.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamJavaUdfCalcRule.java new file mode 100644 index 000000000000..23d0f76a300a --- /dev/null +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamJavaUdfCalcRule.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.sql.zetasql; + +import org.apache.beam.sdk.extensions.sql.impl.rel.BeamCalcRel; +import org.apache.beam.sdk.extensions.sql.impl.rel.BeamLogicalConvention; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.Convention; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptRule; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptRuleCall; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelNode; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.convert.ConverterRule; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.core.Calc; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.logical.LogicalCalc; + +/** {@link ConverterRule} to replace {@link Calc} with {@link BeamCalcRel}. */ +public class BeamJavaUdfCalcRule extends ConverterRule { + public static final BeamJavaUdfCalcRule INSTANCE = new BeamJavaUdfCalcRule(); + + private BeamJavaUdfCalcRule() { + super( + LogicalCalc.class, Convention.NONE, BeamLogicalConvention.INSTANCE, "BeamJavaUdfCalcRule"); + } + + @Override + public boolean matches(RelOptRuleCall x) { + return ZetaSQLQueryPlanner.hasUdfInProjects(x); + } + + @Override + public RelNode convert(RelNode rel) { + final Calc calc = (Calc) rel; + final RelNode input = calc.getInput(); + + return new BeamCalcRel( + calc.getCluster(), + calc.getTraitSet().replace(BeamLogicalConvention.INSTANCE), + RelOptRule.convert(input, input.getTraitSet().replace(BeamLogicalConvention.INSTANCE)), + calc.getProgram()); + } +} diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRule.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRule.java index 2e7ea0f7a2cc..2f6c60d60cfd 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRule.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRule.java @@ -37,7 +37,7 @@ private BeamZetaSqlCalcRule() { @Override public boolean matches(RelOptRuleCall x) { - return true; + return !ZetaSQLQueryPlanner.hasUdfInProjects(x); } @Override diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java index f4db1f194a94..4889183974be 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java @@ -78,6 +78,9 @@ public class SqlAnalyzer { */ public static final String ZETASQL_FUNCTION_GROUP_NAME = "ZetaSQL"; + public static final String USER_DEFINED_JAVA_SCALAR_FUNCTIONS = + "user_defined_java_scalar_functions"; + private static final ImmutableSet SUPPORTED_STATEMENT_KINDS = ImmutableSet.of( RESOLVED_QUERY_STMT, RESOLVED_CREATE_FUNCTION_STMT, RESOLVED_CREATE_TABLE_FUNCTION_STMT); diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java index b943ab3d2b74..9ca5e8313d68 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java @@ -36,22 +36,29 @@ import org.apache.beam.sdk.extensions.sql.impl.rule.BeamCalcRule; import org.apache.beam.sdk.extensions.sql.impl.rule.BeamUncollectRule; import org.apache.beam.sdk.extensions.sql.impl.rule.BeamUnnestRule; +import org.apache.beam.sdk.extensions.sql.zetasql.translation.ZetaSqlScalarFunctionImpl; import org.apache.beam.sdk.extensions.sql.zetasql.unnest.BeamZetaSqlUncollectRule; import org.apache.beam.sdk.extensions.sql.zetasql.unnest.BeamZetaSqlUnnestRule; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.config.CalciteConnectionConfig; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.jdbc.CalciteSchema; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.ConventionTraitDef; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptRule; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptRuleCall; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptUtil; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelTraitDef; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelTraitSet; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.prepare.CalciteCatalogReader; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelNode; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelRoot; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.logical.LogicalCalc; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.metadata.ChainedRelMetadataProvider; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.metadata.JaninoRelMetadataProvider; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.rules.FilterCalcMergeRule; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.rules.JoinCommuteRule; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.rules.ProjectCalcMergeRule; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexCall; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexNode; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.SchemaPlus; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlNode; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlOperatorTable; @@ -59,11 +66,14 @@ import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.parser.SqlParser; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.parser.SqlParserImplFactory; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.util.ChainedSqlOperatorTable; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.validate.SqlUserDefinedFunction; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.FrameworkConfig; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.Frameworks; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.RuleSet; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.RuleSets; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** ZetaSQLQueryPlanner. */ @SuppressWarnings({ @@ -71,6 +81,8 @@ "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402) }) public class ZetaSQLQueryPlanner implements QueryPlanner { + private static final Logger LOG = LoggerFactory.getLogger(ZetaSQLQueryPlanner.class); + private final ZetaSQLPlannerImpl plannerImpl; public ZetaSQLQueryPlanner(FrameworkConfig config) { @@ -104,6 +116,30 @@ public static Collection getZetaSqlRuleSets() { return modifyRuleSetsForZetaSql(BeamRuleSets.getRuleSets()); } + /** Returns true if the argument contains any user-defined Java functions. */ + static boolean hasUdfInProjects(RelOptRuleCall x) { + List resList = x.getRelList(); + for (RelNode relNode : resList) { + if (relNode instanceof LogicalCalc) { + LogicalCalc logicalCalc = (LogicalCalc) relNode; + for (RexNode rexNode : logicalCalc.getProgram().getExprList()) { + if (rexNode instanceof RexCall) { + RexCall call = (RexCall) rexNode; + if (call.getOperator() instanceof SqlUserDefinedFunction) { + SqlUserDefinedFunction udf = (SqlUserDefinedFunction) call.op; + if (udf.function instanceof ZetaSqlScalarFunctionImpl) { + ZetaSqlScalarFunctionImpl scalarFunction = (ZetaSqlScalarFunctionImpl) udf.function; + return scalarFunction.functionGroup.equals( + SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS); + } + } + } + } + } + } + return false; + } + private static Collection modifyRuleSetsForZetaSql(Collection ruleSets) { ImmutableList.Builder ret = ImmutableList.builder(); for (RuleSet ruleSet : ruleSets) { @@ -123,6 +159,7 @@ private static Collection modifyRuleSetsForZetaSql(Collection continue; } else if (rule instanceof BeamCalcRule) { bd.add(BeamZetaSqlCalcRule.INSTANCE); + bd.add(BeamJavaUdfCalcRule.INSTANCE); } else if (rule instanceof BeamUnnestRule) { bd.add(BeamZetaSqlUnnestRule.INSTANCE); } else if (rule instanceof BeamUncollectRule) { @@ -196,7 +233,9 @@ private BeamRelNode convertToBeamRelInternal(String sql, QueryParameters queryPa RelMetadataQuery.THREAD_PROVIDERS.set( JaninoRelMetadataProvider.of(root.rel.getCluster().getMetadataProvider())); root.rel.getCluster().invalidateMetadataQuery(); - return (BeamRelNode) plannerImpl.transform(0, desiredTraits, root.rel); + BeamRelNode beamRelNode = (BeamRelNode) plannerImpl.transform(0, desiredTraits, root.rel); + LOG.info("BEAMPlan>\n" + RelOptUtil.toString(beamRelNode)); + return beamRelNode; } private static FrameworkConfig defaultConfig(