diff --git a/sdks/java/build-tools/src/main/resources/beam/suppressions.xml b/sdks/java/build-tools/src/main/resources/beam/suppressions.xml index a7116b788c5b..a7028677ab66 100644 --- a/sdks/java/build-tools/src/main/resources/beam/suppressions.xml +++ b/sdks/java/build-tools/src/main/resources/beam/suppressions.xml @@ -86,8 +86,7 @@ - - + diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCatalog.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCatalog.java new file mode 100644 index 000000000000..a02687933aa7 --- /dev/null +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCatalog.java @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.sql.zetasql; + +import com.google.common.collect.ImmutableList; +import com.google.zetasql.Analyzer; +import com.google.zetasql.AnalyzerOptions; +import com.google.zetasql.Function; +import com.google.zetasql.FunctionArgumentType; +import com.google.zetasql.FunctionSignature; +import com.google.zetasql.SimpleCatalog; +import com.google.zetasql.TVFRelation; +import com.google.zetasql.TableValuedFunction; +import com.google.zetasql.TypeFactory; +import com.google.zetasql.ZetaSQLBuiltinFunctionOptions; +import com.google.zetasql.ZetaSQLFunctions; +import com.google.zetasql.ZetaSQLType; +import com.google.zetasql.resolvedast.ResolvedNode; +import com.google.zetasql.resolvedast.ResolvedNodes; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; +import org.apache.beam.sdk.extensions.sql.impl.JavaUdfLoader; +import org.apache.beam.sdk.extensions.sql.impl.SqlConversionException; +import org.apache.beam.sdk.extensions.sql.impl.utils.TVFStreamingUtils; +import org.apache.beam.sdk.extensions.sql.udf.ScalarFn; +import org.apache.beam.sdk.extensions.sql.zetasql.translation.UserFunctionDefinitions; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataType; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.SchemaPlus; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; + +/** + * Catalog for registering tables and functions. Populates a {@link SimpleCatalog} based on a {@link + * SchemaPlus}. + */ +public class BeamZetaSqlCatalog { + // ZetaSQL function group identifiers. Different function groups may have divergent translation + // paths. + public static final String PRE_DEFINED_WINDOW_FUNCTIONS = "pre_defined_window_functions"; + public static final String USER_DEFINED_SQL_FUNCTIONS = "user_defined_functions"; + public static final String USER_DEFINED_JAVA_SCALAR_FUNCTIONS = + "user_defined_java_scalar_functions"; + /** + * Same as {@link Function}.ZETASQL_FUNCTION_GROUP_NAME. Identifies built-in ZetaSQL functions. + */ + public static final String ZETASQL_FUNCTION_GROUP_NAME = "ZetaSQL"; + + private static final ImmutableList PRE_DEFINED_WINDOW_FUNCTION_DECLARATIONS = + ImmutableList.of( + // TODO: support optional function argument (for window_offset). + "CREATE FUNCTION TUMBLE(ts TIMESTAMP, window_size STRING) AS (1);", + "CREATE FUNCTION TUMBLE_START(window_size STRING) RETURNS TIMESTAMP AS (null);", + "CREATE FUNCTION TUMBLE_END(window_size STRING) RETURNS TIMESTAMP AS (null);", + "CREATE FUNCTION HOP(ts TIMESTAMP, emit_frequency STRING, window_size STRING) AS (1);", + "CREATE FUNCTION HOP_START(emit_frequency STRING, window_size STRING) " + + "RETURNS TIMESTAMP AS (null);", + "CREATE FUNCTION HOP_END(emit_frequency STRING, window_size STRING) " + + "RETURNS TIMESTAMP AS (null);", + "CREATE FUNCTION SESSION(ts TIMESTAMP, session_gap STRING) AS (1);", + "CREATE FUNCTION SESSION_START(session_gap STRING) RETURNS TIMESTAMP AS (null);", + "CREATE FUNCTION SESSION_END(session_gap STRING) RETURNS TIMESTAMP AS (null);"); + + /** The top-level Calcite schema, which may contain sub-schemas. */ + private final SchemaPlus calciteSchema; + /** + * The top-level ZetaSQL catalog, which may contain nested catalogs for qualified table and + * function references. + */ + private final SimpleCatalog zetaSqlCatalog; + + private final JavaTypeFactory typeFactory; + + private final JavaUdfLoader javaUdfLoader = new JavaUdfLoader(); + private final Map, ResolvedNodes.ResolvedCreateFunctionStmt> sqlScalarUdfs = + new HashMap<>(); + /** User-defined table valued functions. */ + private final Map, ResolvedNode> sqlUdtvfs = new HashMap<>(); + + private final Map, UserFunctionDefinitions.JavaScalarFunction> javaScalarUdfs = + new HashMap<>(); + + private BeamZetaSqlCatalog( + SchemaPlus calciteSchema, SimpleCatalog zetaSqlCatalog, JavaTypeFactory typeFactory) { + this.calciteSchema = calciteSchema; + this.zetaSqlCatalog = zetaSqlCatalog; + this.typeFactory = typeFactory; + } + + /** Return catalog pre-populated with builtin functions. */ + static BeamZetaSqlCatalog create( + SchemaPlus calciteSchema, JavaTypeFactory typeFactory, AnalyzerOptions options) { + BeamZetaSqlCatalog catalog = + new BeamZetaSqlCatalog( + calciteSchema, new SimpleCatalog(calciteSchema.getName()), typeFactory); + catalog.addBuiltinFunctionsToCatalog(options); + return catalog; + } + + SimpleCatalog getZetaSqlCatalog() { + return zetaSqlCatalog; + } + + void addTables(List> tables, QueryTrait queryTrait) { + tables.forEach(table -> addTableToLeafCatalog(table, queryTrait)); + } + + void addFunction(ResolvedNodes.ResolvedCreateFunctionStmt createFunctionStmt) { + String functionGroup = getFunctionGroup(createFunctionStmt); + switch (functionGroup) { + case USER_DEFINED_SQL_FUNCTIONS: + sqlScalarUdfs.put(createFunctionStmt.getNamePath(), createFunctionStmt); + break; + case USER_DEFINED_JAVA_SCALAR_FUNCTIONS: + String jarPath = getJarPath(createFunctionStmt); + ScalarFn scalarFn = + javaUdfLoader.loadScalarFunction(createFunctionStmt.getNamePath(), jarPath); + javaScalarUdfs.put( + createFunctionStmt.getNamePath(), + UserFunctionDefinitions.JavaScalarFunction.create(scalarFn, jarPath)); + break; + default: + throw new IllegalArgumentException( + String.format("Encountered unrecognized function group %s.", functionGroup)); + } + zetaSqlCatalog.addFunction( + new Function( + createFunctionStmt.getNamePath(), + functionGroup, + createFunctionStmt.getIsAggregate() + ? ZetaSQLFunctions.FunctionEnums.Mode.AGGREGATE + : ZetaSQLFunctions.FunctionEnums.Mode.SCALAR, + ImmutableList.of(createFunctionStmt.getSignature()))); + } + + void addTableValuedFunction( + ResolvedNodes.ResolvedCreateTableFunctionStmt createTableFunctionStmt) { + zetaSqlCatalog.addTableValuedFunction( + new TableValuedFunction.FixedOutputSchemaTVF( + createTableFunctionStmt.getNamePath(), + createTableFunctionStmt.getSignature(), + TVFRelation.createColumnBased( + createTableFunctionStmt.getQuery().getColumnList().stream() + .map(c -> TVFRelation.Column.create(c.getName(), c.getType())) + .collect(Collectors.toList())))); + sqlUdtvfs.put(createTableFunctionStmt.getNamePath(), createTableFunctionStmt.getQuery()); + } + + UserFunctionDefinitions getUserFunctionDefinitions() { + return UserFunctionDefinitions.newBuilder() + .setSqlScalarFunctions(ImmutableMap.copyOf(sqlScalarUdfs)) + .setSqlTableValuedFunctions(ImmutableMap.copyOf(sqlUdtvfs)) + .setJavaScalarFunctions(ImmutableMap.copyOf(javaScalarUdfs)) + .build(); + } + + private void addBuiltinFunctionsToCatalog(AnalyzerOptions options) { + // Enable ZetaSQL builtin functions. + ZetaSQLBuiltinFunctionOptions zetasqlBuiltinFunctionOptions = + new ZetaSQLBuiltinFunctionOptions(options.getLanguageOptions()); + SupportedZetaSqlBuiltinFunctions.ALLOWLIST.forEach( + zetasqlBuiltinFunctionOptions::includeFunctionSignatureId); + zetaSqlCatalog.addZetaSQLFunctions(zetasqlBuiltinFunctionOptions); + + // Enable Beam SQL's builtin windowing functions. + addWindowScalarFunctions(options); + addWindowTvfs(); + } + + private void addWindowScalarFunctions(AnalyzerOptions options) { + PRE_DEFINED_WINDOW_FUNCTION_DECLARATIONS.stream() + .map( + func -> + (ResolvedNodes.ResolvedCreateFunctionStmt) + Analyzer.analyzeStatement(func, options, zetaSqlCatalog)) + .map( + resolvedFunc -> + new Function( + String.join(".", resolvedFunc.getNamePath()), + PRE_DEFINED_WINDOW_FUNCTIONS, + ZetaSQLFunctions.FunctionEnums.Mode.SCALAR, + ImmutableList.of(resolvedFunc.getSignature()))) + .forEach(zetaSqlCatalog::addFunction); + } + + @SuppressWarnings({ + "nullness" // customContext and volatility are in fact nullable, but they are missing the + // annotation upstream. TODO Unsuppress when this is fixed in ZetaSQL. + }) + private void addWindowTvfs() { + FunctionArgumentType retType = + new FunctionArgumentType(ZetaSQLFunctions.SignatureArgumentKind.ARG_TYPE_RELATION); + + FunctionArgumentType inputTableType = + new FunctionArgumentType(ZetaSQLFunctions.SignatureArgumentKind.ARG_TYPE_RELATION); + + FunctionArgumentType descriptorType = + new FunctionArgumentType( + ZetaSQLFunctions.SignatureArgumentKind.ARG_TYPE_DESCRIPTOR, + FunctionArgumentType.FunctionArgumentTypeOptions.builder() + .setDescriptorResolutionTableOffset(0) + .build(), + 1); + + FunctionArgumentType stringType = + new FunctionArgumentType(TypeFactory.createSimpleType(ZetaSQLType.TypeKind.TYPE_STRING)); + + // TUMBLE + zetaSqlCatalog.addTableValuedFunction( + new TableValuedFunction.ForwardInputSchemaToOutputSchemaWithAppendedColumnTVF( + ImmutableList.of(TVFStreamingUtils.FIXED_WINDOW_TVF), + new FunctionSignature( + retType, ImmutableList.of(inputTableType, descriptorType, stringType), -1), + ImmutableList.of( + TVFRelation.Column.create( + TVFStreamingUtils.WINDOW_START, + TypeFactory.createSimpleType(ZetaSQLType.TypeKind.TYPE_TIMESTAMP)), + TVFRelation.Column.create( + TVFStreamingUtils.WINDOW_END, + TypeFactory.createSimpleType(ZetaSQLType.TypeKind.TYPE_TIMESTAMP))), + null, + null)); + + // HOP + zetaSqlCatalog.addTableValuedFunction( + new TableValuedFunction.ForwardInputSchemaToOutputSchemaWithAppendedColumnTVF( + ImmutableList.of(TVFStreamingUtils.SLIDING_WINDOW_TVF), + new FunctionSignature( + retType, + ImmutableList.of(inputTableType, descriptorType, stringType, stringType), + -1), + ImmutableList.of( + TVFRelation.Column.create( + TVFStreamingUtils.WINDOW_START, + TypeFactory.createSimpleType(ZetaSQLType.TypeKind.TYPE_TIMESTAMP)), + TVFRelation.Column.create( + TVFStreamingUtils.WINDOW_END, + TypeFactory.createSimpleType(ZetaSQLType.TypeKind.TYPE_TIMESTAMP))), + null, + null)); + + // SESSION + zetaSqlCatalog.addTableValuedFunction( + new TableValuedFunction.ForwardInputSchemaToOutputSchemaWithAppendedColumnTVF( + ImmutableList.of(TVFStreamingUtils.SESSION_WINDOW_TVF), + new FunctionSignature( + retType, + ImmutableList.of(inputTableType, descriptorType, descriptorType, stringType), + -1), + ImmutableList.of( + TVFRelation.Column.create( + TVFStreamingUtils.WINDOW_START, + TypeFactory.createSimpleType(ZetaSQLType.TypeKind.TYPE_TIMESTAMP)), + TVFRelation.Column.create( + TVFStreamingUtils.WINDOW_END, + TypeFactory.createSimpleType(ZetaSQLType.TypeKind.TYPE_TIMESTAMP))), + null, + null)); + } + + private String getFunctionGroup(ResolvedNodes.ResolvedCreateFunctionStmt createFunctionStmt) { + switch (createFunctionStmt.getLanguage().toUpperCase()) { + case "JAVA": + if (createFunctionStmt.getIsAggregate()) { + throw new UnsupportedOperationException( + "Java SQL aggregate functions are not supported (BEAM-10925)."); + } + return USER_DEFINED_JAVA_SCALAR_FUNCTIONS; + case "SQL": + if (createFunctionStmt.getIsAggregate()) { + throw new UnsupportedOperationException( + "Native SQL aggregate functions are not supported (BEAM-9954)."); + } + return USER_DEFINED_SQL_FUNCTIONS; + case "PY": + case "PYTHON": + case "JS": + case "JAVASCRIPT": + throw new UnsupportedOperationException( + String.format( + "Function %s uses unsupported language %s.", + String.join(".", createFunctionStmt.getNamePath()), + createFunctionStmt.getLanguage())); + default: + throw new IllegalArgumentException( + String.format( + "Function %s uses unrecognized language %s.", + String.join(".", createFunctionStmt.getNamePath()), + createFunctionStmt.getLanguage())); + } + } + + /** + * Assume last element in tablePath is a table name, and everything before is catalogs. So the + * logic is to create nested catalogs until the last level, then add a table at the last level. + * + *

Table schema is extracted from Calcite schema based on the table name resolution strategy, + * e.g. either by drilling down the schema.getSubschema() path or joining the table name with dots + * to construct a single compound identifier (e.g. Data Catalog use case). + */ + private void addTableToLeafCatalog(List tablePath, QueryTrait queryTrait) { + + SimpleCatalog leafCatalog = createNestedCatalogs(zetaSqlCatalog, tablePath); + + org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Table calciteTable = + TableResolution.resolveCalciteTable(calciteSchema, tablePath); + + if (calciteTable == null) { + throw new SqlConversionException( + "Wasn't able to resolve the path " + + tablePath + + " in schema: " + + calciteSchema.getName()); + } + + RelDataType rowType = calciteTable.getRowType(typeFactory); + + TableResolution.SimpleTableWithPath tableWithPath = + TableResolution.SimpleTableWithPath.of(tablePath); + queryTrait.addResolvedTable(tableWithPath); + + addFieldsToTable(tableWithPath, rowType); + leafCatalog.addSimpleTable(tableWithPath.getTable()); + } + + private static void addFieldsToTable( + TableResolution.SimpleTableWithPath tableWithPath, RelDataType rowType) { + for (RelDataTypeField field : rowType.getFieldList()) { + tableWithPath + .getTable() + .addSimpleColumn( + field.getName(), ZetaSqlCalciteTranslationUtils.toZetaSqlType(field.getType())); + } + } + + /** For table path like a.b.c we assume c is the table and a.b are the nested catalogs/schemas. */ + private static SimpleCatalog createNestedCatalogs(SimpleCatalog catalog, List tablePath) { + SimpleCatalog currentCatalog = catalog; + for (int i = 0; i < tablePath.size() - 1; i++) { + String nextCatalogName = tablePath.get(i); + + Optional existing = tryGetExisting(currentCatalog, nextCatalogName); + + currentCatalog = + existing.isPresent() ? existing.get() : addNewCatalog(currentCatalog, nextCatalogName); + } + return currentCatalog; + } + + private static Optional tryGetExisting( + SimpleCatalog currentCatalog, String nextCatalogName) { + return currentCatalog.getCatalogList().stream() + .filter(c -> nextCatalogName.equals(c.getFullName())) + .findFirst(); + } + + private static SimpleCatalog addNewCatalog(SimpleCatalog currentCatalog, String nextCatalogName) { + SimpleCatalog nextCatalog = new SimpleCatalog(nextCatalogName); + currentCatalog.addSimpleCatalog(nextCatalog); + return nextCatalog; + } + + private static String getJarPath(ResolvedNodes.ResolvedCreateFunctionStmt createFunctionStmt) { + String jarPath = getOptionStringValue(createFunctionStmt, "path"); + if (jarPath.isEmpty()) { + throw new IllegalArgumentException( + String.format( + "No jar was provided to define function %s. Add 'OPTIONS (path=)' to the CREATE FUNCTION statement.", + String.join(".", createFunctionStmt.getNamePath()))); + } + return jarPath; + } + + private static String getOptionStringValue( + ResolvedNodes.ResolvedCreateFunctionStmt createFunctionStmt, String optionName) { + for (ResolvedNodes.ResolvedOption option : createFunctionStmt.getOptionList()) { + if (optionName.equals(option.getName())) { + if (option.getValue() == null) { + throw new IllegalArgumentException( + String.format( + "Option '%s' has null value (expected %s).", + optionName, ZetaSQLType.TypeKind.TYPE_STRING)); + } + if (option.getValue().getType().getKind() != ZetaSQLType.TypeKind.TYPE_STRING) { + throw new IllegalArgumentException( + String.format( + "Option '%s' has type %s (expected %s).", + optionName, + option.getValue().getType().getKind(), + ZetaSQLType.TypeKind.TYPE_STRING)); + } + return ((ResolvedNodes.ResolvedLiteral) option.getValue()).getValue().getStringValue(); + } + } + return ""; + } +} diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java index 0b65757c7b7e..14a65749e411 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java @@ -22,28 +22,16 @@ import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_QUERY_STMT; import static java.nio.charset.StandardCharsets.UTF_8; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import com.google.zetasql.Analyzer; import com.google.zetasql.AnalyzerOptions; -import com.google.zetasql.Function; -import com.google.zetasql.FunctionArgumentType; -import com.google.zetasql.FunctionSignature; import com.google.zetasql.ParseResumeLocation; -import com.google.zetasql.SimpleCatalog; -import com.google.zetasql.TVFRelation; -import com.google.zetasql.TableValuedFunction; -import com.google.zetasql.TypeFactory; import com.google.zetasql.Value; -import com.google.zetasql.ZetaSQLBuiltinFunctionOptions; -import com.google.zetasql.ZetaSQLFunctions.FunctionEnums.Mode; -import com.google.zetasql.ZetaSQLFunctions.SignatureArgumentKind; import com.google.zetasql.ZetaSQLOptions.ErrorMessageMode; import com.google.zetasql.ZetaSQLOptions.LanguageFeature; import com.google.zetasql.ZetaSQLOptions.ParameterMode; import com.google.zetasql.ZetaSQLOptions.ProductMode; import com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind; -import com.google.zetasql.ZetaSQLType.TypeKind; +import com.google.zetasql.resolvedast.ResolvedNodes; import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedCreateFunctionStmt; import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedCreateTableFunctionStmt; import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedStatement; @@ -51,69 +39,21 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; -import org.apache.beam.sdk.extensions.sql.impl.ParseException; import org.apache.beam.sdk.extensions.sql.impl.QueryPlanner.QueryParameters; import org.apache.beam.sdk.extensions.sql.impl.QueryPlanner.QueryParameters.Kind; -import org.apache.beam.sdk.extensions.sql.impl.SqlConversionException; -import org.apache.beam.sdk.extensions.sql.impl.utils.TVFStreamingUtils; -import org.apache.beam.sdk.extensions.sql.zetasql.TableResolution.SimpleTableWithPath; -import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.adapter.java.JavaTypeFactory; -import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataType; -import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataTypeField; -import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.SchemaPlus; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet; /** Adapter for {@link Analyzer} to simplify the API for parsing the query and resolving the AST. */ @SuppressWarnings({ "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402) }) public class SqlAnalyzer { - // ZetaSQL function group identifiers. Different function groups may have divergent translation - // paths. - public static final String PRE_DEFINED_WINDOW_FUNCTIONS = "pre_defined_window_functions"; - public static final String USER_DEFINED_SQL_FUNCTIONS = "user_defined_functions"; - /** - * Same as {@link Function}.ZETASQL_FUNCTION_GROUP_NAME. Identifies built-in ZetaSQL functions. - */ - public static final String ZETASQL_FUNCTION_GROUP_NAME = "ZetaSQL"; - - public static final String USER_DEFINED_JAVA_SCALAR_FUNCTIONS = - "user_defined_java_scalar_functions"; - private static final ImmutableSet SUPPORTED_STATEMENT_KINDS = ImmutableSet.of( RESOLVED_QUERY_STMT, RESOLVED_CREATE_FUNCTION_STMT, RESOLVED_CREATE_TABLE_FUNCTION_STMT); - private static final ImmutableList FUNCTION_LIST = - ImmutableList.of( - // TODO: support optional function argument (for window_offset). - "CREATE FUNCTION TUMBLE(ts TIMESTAMP, window_size STRING) AS (1);", - "CREATE FUNCTION TUMBLE_START(window_size STRING) RETURNS TIMESTAMP AS (null);", - "CREATE FUNCTION TUMBLE_END(window_size STRING) RETURNS TIMESTAMP AS (null);", - "CREATE FUNCTION HOP(ts TIMESTAMP, emit_frequency STRING, window_size STRING) AS (1);", - "CREATE FUNCTION HOP_START(emit_frequency STRING, window_size STRING) " - + "RETURNS TIMESTAMP AS (null);", - "CREATE FUNCTION HOP_END(emit_frequency STRING, window_size STRING) " - + "RETURNS TIMESTAMP AS (null);", - "CREATE FUNCTION SESSION(ts TIMESTAMP, session_gap STRING) AS (1);", - "CREATE FUNCTION SESSION_START(session_gap STRING) RETURNS TIMESTAMP AS (null);", - "CREATE FUNCTION SESSION_END(session_gap STRING) RETURNS TIMESTAMP AS (null);"); - - private final QueryTrait queryTrait; - private final SchemaPlus topLevelSchema; - private final JavaTypeFactory typeFactory; - - SqlAnalyzer(QueryTrait queryTrait, SchemaPlus topLevelSchema, JavaTypeFactory typeFactory) { - this.queryTrait = queryTrait; - this.topLevelSchema = topLevelSchema; - this.typeFactory = typeFactory; - } - - static boolean isEndOfInput(ParseResumeLocation parseResumeLocation) { - return parseResumeLocation.getBytePosition() - >= parseResumeLocation.getInput().getBytes(UTF_8).length; - } + SqlAnalyzer() {} /** Returns table names from all statements in the SQL string. */ List> extractTableNames(String sql, AnalyzerOptions options) { @@ -127,44 +67,37 @@ List> extractTableNames(String sql, AnalyzerOptions options) { return tables.build(); } - static String getFunctionGroup(ResolvedCreateFunctionStmt createFunctionStmt) { - switch (createFunctionStmt.getLanguage().toUpperCase()) { - case "JAVA": - if (createFunctionStmt.getIsAggregate()) { - throw new UnsupportedOperationException( - "Java SQL aggregate functions are not supported (BEAM-10925)."); - } - return USER_DEFINED_JAVA_SCALAR_FUNCTIONS; - case "SQL": - if (createFunctionStmt.getIsAggregate()) { + /** + * Analyzes the entire SQL code block (which may consist of multiple statements) and returns the + * resolved query. + * + *

Assumes there is exactly one SELECT statement in the input, and it must be the last + * statement in the input. + */ + ResolvedNodes.ResolvedQueryStmt analyzeQuery( + String sql, AnalyzerOptions options, BeamZetaSqlCatalog catalog) { + ParseResumeLocation parseResumeLocation = new ParseResumeLocation(sql); + ResolvedStatement statement; + do { + statement = analyzeNextStatement(parseResumeLocation, options, catalog); + if (statement.nodeKind() == RESOLVED_QUERY_STMT) { + if (!SqlAnalyzer.isEndOfInput(parseResumeLocation)) { throw new UnsupportedOperationException( - "Native SQL aggregate functions are not supported (BEAM-9954)."); + "No additional statements are allowed after a SELECT statement."); } - return USER_DEFINED_SQL_FUNCTIONS; - case "PY": - case "PYTHON": - case "JS": - case "JAVASCRIPT": - throw new UnsupportedOperationException( - String.format( - "Function %s uses unsupported language %s.", - String.join(".", createFunctionStmt.getNamePath()), - createFunctionStmt.getLanguage())); - default: - throw new IllegalArgumentException( - String.format( - "Function %s uses unrecognized language %s.", - String.join(".", createFunctionStmt.getNamePath()), - createFunctionStmt.getLanguage())); + } + } while (!SqlAnalyzer.isEndOfInput(parseResumeLocation)); + + if (!(statement instanceof ResolvedNodes.ResolvedQueryStmt)) { + throw new UnsupportedOperationException( + "Statement list must end in a SELECT statement, not " + statement.nodeKindString()); } + return (ResolvedNodes.ResolvedQueryStmt) statement; } - private Function createFunction(ResolvedCreateFunctionStmt createFunctionStmt) { - return new Function( - createFunctionStmt.getNamePath(), - getFunctionGroup(createFunctionStmt), - createFunctionStmt.getIsAggregate() ? Mode.AGGREGATE : Mode.SCALAR, - ImmutableList.of(createFunctionStmt.getSignature())); + private static boolean isEndOfInput(ParseResumeLocation parseResumeLocation) { + return parseResumeLocation.getBytePosition() + >= parseResumeLocation.getInput().getBytes(UTF_8).length; } /** @@ -172,33 +105,28 @@ private Function createFunction(ResolvedCreateFunctionStmt createFunctionStmt) { * ParseResumeLocation to the start of the next statement. Adds user-defined functions to the * catalog for use in following statements. Returns the resolved AST. */ - ResolvedStatement analyzeNextStatement( - ParseResumeLocation parseResumeLocation, AnalyzerOptions options, SimpleCatalog catalog) { + private ResolvedStatement analyzeNextStatement( + ParseResumeLocation parseResumeLocation, + AnalyzerOptions options, + BeamZetaSqlCatalog catalog) { ResolvedStatement resolvedStatement = - Analyzer.analyzeNextStatement(parseResumeLocation, options, catalog); + Analyzer.analyzeNextStatement(parseResumeLocation, options, catalog.getZetaSqlCatalog()); if (resolvedStatement.nodeKind() == RESOLVED_CREATE_FUNCTION_STMT) { ResolvedCreateFunctionStmt createFunctionStmt = (ResolvedCreateFunctionStmt) resolvedStatement; - Function userFunction = createFunction(createFunctionStmt); try { - catalog.addFunction(userFunction); + catalog.addFunction(createFunctionStmt); } catch (IllegalArgumentException e) { - throw new ParseException( + throw new RuntimeException( String.format( - "Failed to define function %s", String.join(".", createFunctionStmt.getNamePath())), + "Failed to define function '%s'", + String.join(".", createFunctionStmt.getNamePath())), e); } } else if (resolvedStatement.nodeKind() == RESOLVED_CREATE_TABLE_FUNCTION_STMT) { ResolvedCreateTableFunctionStmt createTableFunctionStmt = (ResolvedCreateTableFunctionStmt) resolvedStatement; - catalog.addTableValuedFunction( - new TableValuedFunction.FixedOutputSchemaTVF( - createTableFunctionStmt.getNamePath(), - createTableFunctionStmt.getSignature(), - TVFRelation.createColumnBased( - createTableFunctionStmt.getQuery().getColumnList().stream() - .map(c -> TVFRelation.Column.create(c.getName(), c.getType())) - .collect(Collectors.toList())))); + catalog.addTableValuedFunction(createTableFunctionStmt); } else if (!SUPPORTED_STATEMENT_KINDS.contains(resolvedStatement.nodeKind())) { throw new UnsupportedOperationException( "Unrecognized statement type " + resolvedStatement.nodeKindString()); @@ -248,179 +176,4 @@ static AnalyzerOptions getAnalyzerOptions(QueryParameters queryParams, String de return options; } - - /** - * Creates a SimpleCatalog which represents the top-level schema, populates it with tables, - * built-in functions. - */ - SimpleCatalog createPopulatedCatalog( - String catalogName, AnalyzerOptions options, List> tables) { - - SimpleCatalog catalog = new SimpleCatalog(catalogName); - addBuiltinFunctionsToCatalog(catalog, options); - - tables.forEach(table -> addTableToLeafCatalog(queryTrait, catalog, table)); - - return catalog; - } - - private void addBuiltinFunctionsToCatalog(SimpleCatalog catalog, AnalyzerOptions options) { - // Enable ZetaSQL builtin functions. - ZetaSQLBuiltinFunctionOptions zetasqlBuiltinFunctionOptions = - new ZetaSQLBuiltinFunctionOptions(options.getLanguageOptions()); - - SupportedZetaSqlBuiltinFunctions.ALLOWLIST.forEach( - zetasqlBuiltinFunctionOptions::includeFunctionSignatureId); - - catalog.addZetaSQLFunctions(zetasqlBuiltinFunctionOptions); - - FUNCTION_LIST.stream() - .map(func -> (ResolvedCreateFunctionStmt) Analyzer.analyzeStatement(func, options, catalog)) - .map( - resolvedFunc -> - new Function( - String.join(".", resolvedFunc.getNamePath()), - PRE_DEFINED_WINDOW_FUNCTIONS, - Mode.SCALAR, - ImmutableList.of(resolvedFunc.getSignature()))) - .forEach(catalog::addFunction); - - FunctionArgumentType retType = - new FunctionArgumentType(SignatureArgumentKind.ARG_TYPE_RELATION); - - FunctionArgumentType inputTableType = - new FunctionArgumentType(SignatureArgumentKind.ARG_TYPE_RELATION); - - FunctionArgumentType descriptorType = - new FunctionArgumentType( - SignatureArgumentKind.ARG_TYPE_DESCRIPTOR, - FunctionArgumentType.FunctionArgumentTypeOptions.builder() - .setDescriptorResolutionTableOffset(0) - .build(), - 1); - - FunctionArgumentType stringType = - new FunctionArgumentType(TypeFactory.createSimpleType(TypeKind.TYPE_STRING)); - - // TUMBLE - catalog.addTableValuedFunction( - new TableValuedFunction.ForwardInputSchemaToOutputSchemaWithAppendedColumnTVF( - ImmutableList.of(TVFStreamingUtils.FIXED_WINDOW_TVF), - new FunctionSignature( - retType, ImmutableList.of(inputTableType, descriptorType, stringType), -1), - ImmutableList.of( - TVFRelation.Column.create( - TVFStreamingUtils.WINDOW_START, - TypeFactory.createSimpleType(TypeKind.TYPE_TIMESTAMP)), - TVFRelation.Column.create( - TVFStreamingUtils.WINDOW_END, - TypeFactory.createSimpleType(TypeKind.TYPE_TIMESTAMP))), - null, - null)); - - // HOP - catalog.addTableValuedFunction( - new TableValuedFunction.ForwardInputSchemaToOutputSchemaWithAppendedColumnTVF( - ImmutableList.of(TVFStreamingUtils.SLIDING_WINDOW_TVF), - new FunctionSignature( - retType, - ImmutableList.of(inputTableType, descriptorType, stringType, stringType), - -1), - ImmutableList.of( - TVFRelation.Column.create( - TVFStreamingUtils.WINDOW_START, - TypeFactory.createSimpleType(TypeKind.TYPE_TIMESTAMP)), - TVFRelation.Column.create( - TVFStreamingUtils.WINDOW_END, - TypeFactory.createSimpleType(TypeKind.TYPE_TIMESTAMP))), - null, - null)); - - // SESSION - catalog.addTableValuedFunction( - new TableValuedFunction.ForwardInputSchemaToOutputSchemaWithAppendedColumnTVF( - ImmutableList.of(TVFStreamingUtils.SESSION_WINDOW_TVF), - new FunctionSignature( - retType, - ImmutableList.of(inputTableType, descriptorType, descriptorType, stringType), - -1), - ImmutableList.of( - TVFRelation.Column.create( - TVFStreamingUtils.WINDOW_START, - TypeFactory.createSimpleType(TypeKind.TYPE_TIMESTAMP)), - TVFRelation.Column.create( - TVFStreamingUtils.WINDOW_END, - TypeFactory.createSimpleType(TypeKind.TYPE_TIMESTAMP))), - null, - null)); - } - - /** - * Assume last element in tablePath is a table name, and everything before is catalogs. So the - * logic is to create nested catalogs until the last level, then add a table at the last level. - * - *

Table schema is extracted from Calcite schema based on the table name resultion strategy, - * e.g. either by drilling down the schema.getSubschema() path or joining the table name with dots - * to construct a single compound identifier (e.g. Data Catalog use case). - */ - private void addTableToLeafCatalog( - QueryTrait trait, SimpleCatalog topLevelCatalog, List tablePath) { - - SimpleCatalog leafCatalog = createNestedCatalogs(topLevelCatalog, tablePath); - - org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Table calciteTable = - TableResolution.resolveCalciteTable(topLevelSchema, tablePath); - - if (calciteTable == null) { - throw new SqlConversionException( - "Wasn't able to resolve the path " - + tablePath - + " in schema: " - + topLevelSchema.getName()); - } - - RelDataType rowType = calciteTable.getRowType(typeFactory); - - SimpleTableWithPath tableWithPath = SimpleTableWithPath.of(tablePath); - trait.addResolvedTable(tableWithPath); - - addFieldsToTable(tableWithPath, rowType); - leafCatalog.addSimpleTable(tableWithPath.getTable()); - } - - private void addFieldsToTable(SimpleTableWithPath tableWithPath, RelDataType rowType) { - for (RelDataTypeField field : rowType.getFieldList()) { - tableWithPath - .getTable() - .addSimpleColumn( - field.getName(), ZetaSqlCalciteTranslationUtils.toZetaSqlType(field.getType())); - } - } - - /** For table path like a.b.c we assume c is the table and a.b are the nested catalogs/schemas. */ - private SimpleCatalog createNestedCatalogs(SimpleCatalog catalog, List tablePath) { - SimpleCatalog currentCatalog = catalog; - for (int i = 0; i < tablePath.size() - 1; i++) { - String nextCatalogName = tablePath.get(i); - - Optional existing = tryGetExisting(currentCatalog, nextCatalogName); - - currentCatalog = - existing.isPresent() ? existing.get() : addNewCatalog(currentCatalog, nextCatalogName); - } - return currentCatalog; - } - - private Optional tryGetExisting( - SimpleCatalog currentCatalog, String nextCatalogName) { - return currentCatalog.getCatalogList().stream() - .filter(c -> nextCatalogName.equals(c.getFullName())) - .findFirst(); - } - - private SimpleCatalog addNewCatalog(SimpleCatalog currentCatalog, String nextCatalogName) { - SimpleCatalog nextCatalog = new SimpleCatalog(nextCatalogName); - currentCatalog.addSimpleCatalog(nextCatalog); - return nextCatalog; - } } diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLPlannerImpl.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLPlannerImpl.java index 028a697c144d..7004b0eb62a9 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLPlannerImpl.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLPlannerImpl.java @@ -17,29 +17,14 @@ */ package org.apache.beam.sdk.extensions.sql.zetasql; -import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_CREATE_FUNCTION_STMT; -import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_CREATE_TABLE_FUNCTION_STMT; -import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_QUERY_STMT; - import com.google.zetasql.AnalyzerOptions; import com.google.zetasql.LanguageOptions; -import com.google.zetasql.ParseResumeLocation; -import com.google.zetasql.SimpleCatalog; -import com.google.zetasql.ZetaSQLType; -import com.google.zetasql.resolvedast.ResolvedNode; -import com.google.zetasql.resolvedast.ResolvedNodes; -import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedCreateFunctionStmt; -import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedCreateTableFunctionStmt; import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedQueryStmt; -import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedStatement; import java.util.List; -import org.apache.beam.sdk.extensions.sql.impl.JavaUdfLoader; import org.apache.beam.sdk.extensions.sql.impl.QueryPlanner.QueryParameters; -import org.apache.beam.sdk.extensions.sql.udf.ScalarFn; import org.apache.beam.sdk.extensions.sql.zetasql.translation.ConversionContext; import org.apache.beam.sdk.extensions.sql.zetasql.translation.ExpressionConverter; import org.apache.beam.sdk.extensions.sql.zetasql.translation.QueryStatementConverter; -import org.apache.beam.sdk.extensions.sql.zetasql.translation.UserFunctionDefinitions; import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.adapter.java.JavaTypeFactory; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptCluster; @@ -55,7 +40,6 @@ import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.Frameworks; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.Program; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.util.Util; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; /** ZetaSQLPlannerImpl. */ @SuppressWarnings({ @@ -93,79 +77,25 @@ class ZetaSQLPlannerImpl { public RelRoot rel(String sql, QueryParameters params) { RelOptCluster cluster = RelOptCluster.create(planner, new RexBuilder(typeFactory)); - QueryTrait trait = new QueryTrait(); - SqlAnalyzer analyzer = - new SqlAnalyzer(trait, defaultSchemaPlus, (JavaTypeFactory) cluster.getTypeFactory()); - AnalyzerOptions options = SqlAnalyzer.getAnalyzerOptions(params, defaultTimezone); + BeamZetaSqlCatalog catalog = + BeamZetaSqlCatalog.create( + defaultSchemaPlus, (JavaTypeFactory) cluster.getTypeFactory(), options); // Set up table providers that need to be pre-registered + SqlAnalyzer analyzer = new SqlAnalyzer(); List> tables = analyzer.extractTableNames(sql, options); TableResolution.registerTables(this.defaultSchemaPlus, tables); - SimpleCatalog catalog = - analyzer.createPopulatedCatalog(defaultSchemaPlus.getName(), options, tables); - - ImmutableMap.Builder, ResolvedCreateFunctionStmt> udfBuilder = - ImmutableMap.builder(); - ImmutableMap.Builder, ResolvedNode> udtvfBuilder = ImmutableMap.builder(); - ImmutableMap.Builder, UserFunctionDefinitions.JavaScalarFunction> - javaScalarFunctionBuilder = ImmutableMap.builder(); - JavaUdfLoader javaUdfLoader = new JavaUdfLoader(); - - ResolvedStatement statement; - ParseResumeLocation parseResumeLocation = new ParseResumeLocation(sql); - do { - statement = analyzer.analyzeNextStatement(parseResumeLocation, options, catalog); - if (statement.nodeKind() == RESOLVED_CREATE_FUNCTION_STMT) { - ResolvedCreateFunctionStmt createFunctionStmt = (ResolvedCreateFunctionStmt) statement; - String functionGroup = SqlAnalyzer.getFunctionGroup(createFunctionStmt); - switch (functionGroup) { - case SqlAnalyzer.USER_DEFINED_SQL_FUNCTIONS: - udfBuilder.put(createFunctionStmt.getNamePath(), createFunctionStmt); - break; - case SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS: - String jarPath = getJarPath(createFunctionStmt); - ScalarFn scalarFn = - javaUdfLoader.loadScalarFunction(createFunctionStmt.getNamePath(), jarPath); - javaScalarFunctionBuilder.put( - createFunctionStmt.getNamePath(), - UserFunctionDefinitions.JavaScalarFunction.create(scalarFn, jarPath)); - break; - default: - throw new IllegalArgumentException( - String.format("Encountered unrecognized function group %s.", functionGroup)); - } - } else if (statement.nodeKind() == RESOLVED_CREATE_TABLE_FUNCTION_STMT) { - ResolvedCreateTableFunctionStmt createTableFunctionStmt = - (ResolvedCreateTableFunctionStmt) statement; - udtvfBuilder.put(createTableFunctionStmt.getNamePath(), createTableFunctionStmt.getQuery()); - } else if (statement.nodeKind() == RESOLVED_QUERY_STMT) { - if (!SqlAnalyzer.isEndOfInput(parseResumeLocation)) { - throw new UnsupportedOperationException( - "No additional statements are allowed after a SELECT statement."); - } - break; - } - } while (!SqlAnalyzer.isEndOfInput(parseResumeLocation)); - - if (!(statement instanceof ResolvedQueryStmt)) { - throw new UnsupportedOperationException( - "Statement list must end in a SELECT statement, not " + statement.nodeKindString()); - } - - UserFunctionDefinitions userFunctionDefinitions = - UserFunctionDefinitions.newBuilder() - .setSqlScalarFunctions(udfBuilder.build()) - .setSqlTableValuedFunctions(udtvfBuilder.build()) - .setJavaScalarFunctions(javaScalarFunctionBuilder.build()) - .build(); + QueryTrait trait = new QueryTrait(); + catalog.addTables(tables, trait); + + ResolvedQueryStmt statement = analyzer.analyzeQuery(sql, options, catalog); ExpressionConverter expressionConverter = - new ExpressionConverter(cluster, params, userFunctionDefinitions); + new ExpressionConverter(cluster, params, catalog.getUserFunctionDefinitions()); ConversionContext context = ConversionContext.of(config, expressionConverter, cluster, trait); - RelNode convertedNode = - QueryStatementConverter.convertRootQuery(context, (ResolvedQueryStmt) statement); + RelNode convertedNode = QueryStatementConverter.convertRootQuery(context, statement); return RelRoot.of(convertedNode, SqlKind.ALL); } @@ -185,39 +115,4 @@ void setDefaultTimezone(String timezone) { static LanguageOptions getLanguageOptions() { return SqlAnalyzer.baseAnalyzerOptions().getLanguageOptions(); } - - private static String getJarPath(ResolvedCreateFunctionStmt createFunctionStmt) { - String jarPath = getOptionStringValue(createFunctionStmt, "path"); - if (jarPath.isEmpty()) { - throw new IllegalArgumentException( - String.format( - "No jar was provided to define function %s. Add 'OPTIONS (path=)' to the CREATE FUNCTION statement.", - String.join(".", createFunctionStmt.getNamePath()))); - } - return jarPath; - } - - private static String getOptionStringValue( - ResolvedCreateFunctionStmt createFunctionStmt, String optionName) { - for (ResolvedNodes.ResolvedOption option : createFunctionStmt.getOptionList()) { - if (optionName.equals(option.getName())) { - if (option.getValue() == null) { - throw new IllegalArgumentException( - String.format( - "Option '%s' has null value (expected %s).", - optionName, ZetaSQLType.TypeKind.TYPE_STRING)); - } - if (option.getValue().getType().getKind() != ZetaSQLType.TypeKind.TYPE_STRING) { - throw new IllegalArgumentException( - String.format( - "Option '%s' has type %s (expected %s).", - optionName, - option.getValue().getType().getKind(), - ZetaSQLType.TypeKind.TYPE_STRING)); - } - return ((ResolvedNodes.ResolvedLiteral) option.getValue()).getValue().getStringValue(); - } - } - return ""; - } } diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java index 07984d7eefb4..33824c0483d7 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java @@ -159,7 +159,7 @@ static boolean hasOnlyJavaUdfInProjects(RelOptRuleCall x) { if (udf.function instanceof ZetaSqlScalarFunctionImpl) { ZetaSqlScalarFunctionImpl scalarFunction = (ZetaSqlScalarFunctionImpl) udf.function; if (!scalarFunction.functionGroup.equals( - SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS)) { + BeamZetaSqlCatalog.USER_DEFINED_JAVA_SCALAR_FUNCTIONS)) { // Reject ZetaSQL Builtin Scalar Functions return false; } @@ -224,7 +224,7 @@ static boolean hasNoJavaUdfInProjects(RelOptRuleCall x) { if (udf.function instanceof ZetaSqlScalarFunctionImpl) { ZetaSqlScalarFunctionImpl scalarFunction = (ZetaSqlScalarFunctionImpl) udf.function; if (scalarFunction.functionGroup.equals( - SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS)) { + BeamZetaSqlCatalog.USER_DEFINED_JAVA_SCALAR_FUNCTIONS)) { return false; } } diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java index 2a2c08fa5c64..29a37cf6097e 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java @@ -24,10 +24,10 @@ import static com.google.zetasql.ZetaSQLType.TypeKind.TYPE_INT64; import static com.google.zetasql.ZetaSQLType.TypeKind.TYPE_STRING; import static com.google.zetasql.ZetaSQLType.TypeKind.TYPE_TIMESTAMP; -import static org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer.PRE_DEFINED_WINDOW_FUNCTIONS; -import static org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS; -import static org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer.USER_DEFINED_SQL_FUNCTIONS; -import static org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME; +import static org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog.PRE_DEFINED_WINDOW_FUNCTIONS; +import static org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog.USER_DEFINED_JAVA_SCALAR_FUNCTIONS; +import static org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog.USER_DEFINED_SQL_FUNCTIONS; +import static org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog.ZETASQL_FUNCTION_GROUP_NAME; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; import com.google.common.base.Ascii; diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java index 584c727a3fd2..ea7e24f4c76c 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.extensions.sql.zetasql.translation; +import static org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog.ZETASQL_FUNCTION_GROUP_NAME; + import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; @@ -28,7 +30,6 @@ import org.apache.beam.sdk.extensions.sql.impl.transform.agg.CountIf; import org.apache.beam.sdk.extensions.sql.impl.udaf.StringAgg; import org.apache.beam.sdk.extensions.sql.zetasql.DateTimeUtils; -import org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer; import org.apache.beam.sdk.extensions.sql.zetasql.translation.impl.BeamBuiltinMethods; import org.apache.beam.sdk.extensions.sql.zetasql.translation.impl.CastFunctionImpl; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.jdbc.JavaTypeFactoryImpl; @@ -86,56 +87,43 @@ public class SqlOperators { public static final SqlOperator START_WITHS = createUdfOperator( - "STARTS_WITH", - BeamBuiltinMethods.STARTS_WITH_METHOD, - SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); + "STARTS_WITH", BeamBuiltinMethods.STARTS_WITH_METHOD, ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator CONCAT = - createUdfOperator( - "CONCAT", BeamBuiltinMethods.CONCAT_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); + createUdfOperator("CONCAT", BeamBuiltinMethods.CONCAT_METHOD, ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator REPLACE = - createUdfOperator( - "REPLACE", BeamBuiltinMethods.REPLACE_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); + createUdfOperator("REPLACE", BeamBuiltinMethods.REPLACE_METHOD, ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator TRIM = - createUdfOperator( - "TRIM", BeamBuiltinMethods.TRIM_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); + createUdfOperator("TRIM", BeamBuiltinMethods.TRIM_METHOD, ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator LTRIM = - createUdfOperator( - "LTRIM", BeamBuiltinMethods.LTRIM_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); + createUdfOperator("LTRIM", BeamBuiltinMethods.LTRIM_METHOD, ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator RTRIM = - createUdfOperator( - "RTRIM", BeamBuiltinMethods.RTRIM_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); + createUdfOperator("RTRIM", BeamBuiltinMethods.RTRIM_METHOD, ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator SUBSTR = - createUdfOperator( - "SUBSTR", BeamBuiltinMethods.SUBSTR_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); + createUdfOperator("SUBSTR", BeamBuiltinMethods.SUBSTR_METHOD, ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator REVERSE = - createUdfOperator( - "REVERSE", BeamBuiltinMethods.REVERSE_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); + createUdfOperator("REVERSE", BeamBuiltinMethods.REVERSE_METHOD, ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator CHAR_LENGTH = createUdfOperator( - "CHAR_LENGTH", - BeamBuiltinMethods.CHAR_LENGTH_METHOD, - SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); + "CHAR_LENGTH", BeamBuiltinMethods.CHAR_LENGTH_METHOD, ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator ENDS_WITH = createUdfOperator( - "ENDS_WITH", - BeamBuiltinMethods.ENDS_WITH_METHOD, - SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); + "ENDS_WITH", BeamBuiltinMethods.ENDS_WITH_METHOD, ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator LIKE = createUdfOperator( "LIKE", BeamBuiltinMethods.LIKE_METHOD, SqlSyntax.BINARY, - SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME, + ZETASQL_FUNCTION_GROUP_NAME, ""); public static final SqlOperator VALIDATE_TIMESTAMP = @@ -145,7 +133,7 @@ public class SqlOperators { "validateTimestamp", x -> NULLABLE_TIMESTAMP, ImmutableList.of(TIMESTAMP), - SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); + ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator VALIDATE_TIME_INTERVAL = createUdfOperator( @@ -154,17 +142,14 @@ public class SqlOperators { "validateTimeInterval", x -> NULLABLE_BIGINT, ImmutableList.of(BIGINT, OTHER), - SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); + ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator TIMESTAMP_OP = createUdfOperator( - "TIMESTAMP", - BeamBuiltinMethods.TIMESTAMP_METHOD, - SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); + "TIMESTAMP", BeamBuiltinMethods.TIMESTAMP_METHOD, ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator DATE_OP = - createUdfOperator( - "DATE", BeamBuiltinMethods.DATE_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); + createUdfOperator("DATE", BeamBuiltinMethods.DATE_METHOD, ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator BIT_XOR = createUdafOperator( diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java index 41a55bbbadfe..2fa564648fe6 100644 --- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java +++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java @@ -244,8 +244,12 @@ public void testJavaUdfEmptyPath() { String sql = "CREATE FUNCTION foo() RETURNS STRING LANGUAGE java OPTIONS (path=''); SELECT foo();"; ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("No jar was provided to define function foo."); + thrown.expect(RuntimeException.class); + thrown.expectMessage("Failed to define function 'foo'"); + thrown.expectCause( + allOf( + isA(IllegalArgumentException.class), + hasProperty("message", containsString("No jar was provided to define function foo.")))); zetaSQLQueryPlanner.convertToBeamRel(sql); } @@ -253,8 +257,12 @@ public void testJavaUdfEmptyPath() { public void testJavaUdfNoJarProvided() { String sql = "CREATE FUNCTION foo() RETURNS STRING LANGUAGE java; SELECT foo();"; ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("No jar was provided to define function foo."); + thrown.expect(RuntimeException.class); + thrown.expectMessage("Failed to define function 'foo'"); + thrown.expectCause( + allOf( + isA(IllegalArgumentException.class), + hasProperty("message", containsString("No jar was provided to define function foo.")))); zetaSQLQueryPlanner.convertToBeamRel(sql); } @@ -263,8 +271,14 @@ public void testPathOptionNotString() { String sql = "CREATE FUNCTION foo() RETURNS STRING LANGUAGE java OPTIONS (path=23); SELECT foo();"; ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Option 'path' has type TYPE_INT64 (expected TYPE_STRING)."); + thrown.expect(RuntimeException.class); + thrown.expectMessage("Failed to define function 'foo'"); + thrown.expectCause( + allOf( + isA(IllegalArgumentException.class), + hasProperty( + "message", + containsString("Option 'path' has type TYPE_INT64 (expected TYPE_STRING).")))); zetaSQLQueryPlanner.convertToBeamRel(sql); } } diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlUdfTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlNativeUdfTest.java similarity index 96% rename from sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlUdfTest.java rename to sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlNativeUdfTest.java index 53b309a10b95..0357242e7481 100644 --- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlUdfTest.java +++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlNativeUdfTest.java @@ -17,8 +17,9 @@ */ package org.apache.beam.sdk.extensions.sql.zetasql; +import static org.hamcrest.Matchers.isA; + import com.google.zetasql.SqlException; -import org.apache.beam.sdk.extensions.sql.impl.ParseException; import org.apache.beam.sdk.extensions.sql.impl.SqlConversionException; import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode; import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSqlRelUtils; @@ -36,9 +37,9 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** Tests for user defined functions in the ZetaSQL dialect. */ +/** Tests for SQL-native user defined functions in the ZetaSQL dialect. */ @RunWith(JUnit4.class) -public class ZetaSqlUdfTest extends ZetaSqlTestBase { +public class ZetaSqlNativeUdfTest extends ZetaSqlTestBase { @Rule public transient TestPipeline pipeline = TestPipeline.create(); @Rule public ExpectedException thrown = ExpectedException.none(); @@ -51,8 +52,9 @@ public void setUp() { public void testAlreadyDefinedUDFThrowsException() { String sql = "CREATE FUNCTION foo() AS (0); CREATE FUNCTION foo() AS (1); SELECT foo();"; ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config); - thrown.expect(ParseException.class); - thrown.expectMessage("Failed to define function foo"); + thrown.expect(RuntimeException.class); + thrown.expectMessage("Failed to define function 'foo'"); + thrown.expectCause(isA(IllegalArgumentException.class)); zetaSQLQueryPlanner.convertToBeamRel(sql); }