> tables, QueryTrait queryTrait) {
+ tables.forEach(table -> addTableToLeafCatalog(table, queryTrait));
+ }
+
+ void addFunction(ResolvedNodes.ResolvedCreateFunctionStmt createFunctionStmt) {
+ String functionGroup = getFunctionGroup(createFunctionStmt);
+ switch (functionGroup) {
+ case USER_DEFINED_SQL_FUNCTIONS:
+ sqlScalarUdfs.put(createFunctionStmt.getNamePath(), createFunctionStmt);
+ break;
+ case USER_DEFINED_JAVA_SCALAR_FUNCTIONS:
+ String jarPath = getJarPath(createFunctionStmt);
+ ScalarFn scalarFn =
+ javaUdfLoader.loadScalarFunction(createFunctionStmt.getNamePath(), jarPath);
+ javaScalarUdfs.put(
+ createFunctionStmt.getNamePath(),
+ UserFunctionDefinitions.JavaScalarFunction.create(scalarFn, jarPath));
+ break;
+ default:
+ throw new IllegalArgumentException(
+ String.format("Encountered unrecognized function group %s.", functionGroup));
+ }
+ zetaSqlCatalog.addFunction(
+ new Function(
+ createFunctionStmt.getNamePath(),
+ functionGroup,
+ createFunctionStmt.getIsAggregate()
+ ? ZetaSQLFunctions.FunctionEnums.Mode.AGGREGATE
+ : ZetaSQLFunctions.FunctionEnums.Mode.SCALAR,
+ ImmutableList.of(createFunctionStmt.getSignature())));
+ }
+
+ void addTableValuedFunction(
+ ResolvedNodes.ResolvedCreateTableFunctionStmt createTableFunctionStmt) {
+ zetaSqlCatalog.addTableValuedFunction(
+ new TableValuedFunction.FixedOutputSchemaTVF(
+ createTableFunctionStmt.getNamePath(),
+ createTableFunctionStmt.getSignature(),
+ TVFRelation.createColumnBased(
+ createTableFunctionStmt.getQuery().getColumnList().stream()
+ .map(c -> TVFRelation.Column.create(c.getName(), c.getType()))
+ .collect(Collectors.toList()))));
+ sqlUdtvfs.put(createTableFunctionStmt.getNamePath(), createTableFunctionStmt.getQuery());
+ }
+
+ UserFunctionDefinitions getUserFunctionDefinitions() {
+ return UserFunctionDefinitions.newBuilder()
+ .setSqlScalarFunctions(ImmutableMap.copyOf(sqlScalarUdfs))
+ .setSqlTableValuedFunctions(ImmutableMap.copyOf(sqlUdtvfs))
+ .setJavaScalarFunctions(ImmutableMap.copyOf(javaScalarUdfs))
+ .build();
+ }
+
+ private void addBuiltinFunctionsToCatalog(AnalyzerOptions options) {
+ // Enable ZetaSQL builtin functions.
+ ZetaSQLBuiltinFunctionOptions zetasqlBuiltinFunctionOptions =
+ new ZetaSQLBuiltinFunctionOptions(options.getLanguageOptions());
+ SupportedZetaSqlBuiltinFunctions.ALLOWLIST.forEach(
+ zetasqlBuiltinFunctionOptions::includeFunctionSignatureId);
+ zetaSqlCatalog.addZetaSQLFunctions(zetasqlBuiltinFunctionOptions);
+
+ // Enable Beam SQL's builtin windowing functions.
+ addWindowScalarFunctions(options);
+ addWindowTvfs();
+ }
+
+ private void addWindowScalarFunctions(AnalyzerOptions options) {
+ PRE_DEFINED_WINDOW_FUNCTION_DECLARATIONS.stream()
+ .map(
+ func ->
+ (ResolvedNodes.ResolvedCreateFunctionStmt)
+ Analyzer.analyzeStatement(func, options, zetaSqlCatalog))
+ .map(
+ resolvedFunc ->
+ new Function(
+ String.join(".", resolvedFunc.getNamePath()),
+ PRE_DEFINED_WINDOW_FUNCTIONS,
+ ZetaSQLFunctions.FunctionEnums.Mode.SCALAR,
+ ImmutableList.of(resolvedFunc.getSignature())))
+ .forEach(zetaSqlCatalog::addFunction);
+ }
+
+ @SuppressWarnings({
+ "nullness" // customContext and volatility are in fact nullable, but they are missing the
+ // annotation upstream. TODO Unsuppress when this is fixed in ZetaSQL.
+ })
+ private void addWindowTvfs() {
+ FunctionArgumentType retType =
+ new FunctionArgumentType(ZetaSQLFunctions.SignatureArgumentKind.ARG_TYPE_RELATION);
+
+ FunctionArgumentType inputTableType =
+ new FunctionArgumentType(ZetaSQLFunctions.SignatureArgumentKind.ARG_TYPE_RELATION);
+
+ FunctionArgumentType descriptorType =
+ new FunctionArgumentType(
+ ZetaSQLFunctions.SignatureArgumentKind.ARG_TYPE_DESCRIPTOR,
+ FunctionArgumentType.FunctionArgumentTypeOptions.builder()
+ .setDescriptorResolutionTableOffset(0)
+ .build(),
+ 1);
+
+ FunctionArgumentType stringType =
+ new FunctionArgumentType(TypeFactory.createSimpleType(ZetaSQLType.TypeKind.TYPE_STRING));
+
+ // TUMBLE
+ zetaSqlCatalog.addTableValuedFunction(
+ new TableValuedFunction.ForwardInputSchemaToOutputSchemaWithAppendedColumnTVF(
+ ImmutableList.of(TVFStreamingUtils.FIXED_WINDOW_TVF),
+ new FunctionSignature(
+ retType, ImmutableList.of(inputTableType, descriptorType, stringType), -1),
+ ImmutableList.of(
+ TVFRelation.Column.create(
+ TVFStreamingUtils.WINDOW_START,
+ TypeFactory.createSimpleType(ZetaSQLType.TypeKind.TYPE_TIMESTAMP)),
+ TVFRelation.Column.create(
+ TVFStreamingUtils.WINDOW_END,
+ TypeFactory.createSimpleType(ZetaSQLType.TypeKind.TYPE_TIMESTAMP))),
+ null,
+ null));
+
+ // HOP
+ zetaSqlCatalog.addTableValuedFunction(
+ new TableValuedFunction.ForwardInputSchemaToOutputSchemaWithAppendedColumnTVF(
+ ImmutableList.of(TVFStreamingUtils.SLIDING_WINDOW_TVF),
+ new FunctionSignature(
+ retType,
+ ImmutableList.of(inputTableType, descriptorType, stringType, stringType),
+ -1),
+ ImmutableList.of(
+ TVFRelation.Column.create(
+ TVFStreamingUtils.WINDOW_START,
+ TypeFactory.createSimpleType(ZetaSQLType.TypeKind.TYPE_TIMESTAMP)),
+ TVFRelation.Column.create(
+ TVFStreamingUtils.WINDOW_END,
+ TypeFactory.createSimpleType(ZetaSQLType.TypeKind.TYPE_TIMESTAMP))),
+ null,
+ null));
+
+ // SESSION
+ zetaSqlCatalog.addTableValuedFunction(
+ new TableValuedFunction.ForwardInputSchemaToOutputSchemaWithAppendedColumnTVF(
+ ImmutableList.of(TVFStreamingUtils.SESSION_WINDOW_TVF),
+ new FunctionSignature(
+ retType,
+ ImmutableList.of(inputTableType, descriptorType, descriptorType, stringType),
+ -1),
+ ImmutableList.of(
+ TVFRelation.Column.create(
+ TVFStreamingUtils.WINDOW_START,
+ TypeFactory.createSimpleType(ZetaSQLType.TypeKind.TYPE_TIMESTAMP)),
+ TVFRelation.Column.create(
+ TVFStreamingUtils.WINDOW_END,
+ TypeFactory.createSimpleType(ZetaSQLType.TypeKind.TYPE_TIMESTAMP))),
+ null,
+ null));
+ }
+
+ private String getFunctionGroup(ResolvedNodes.ResolvedCreateFunctionStmt createFunctionStmt) {
+ switch (createFunctionStmt.getLanguage().toUpperCase()) {
+ case "JAVA":
+ if (createFunctionStmt.getIsAggregate()) {
+ throw new UnsupportedOperationException(
+ "Java SQL aggregate functions are not supported (BEAM-10925).");
+ }
+ return USER_DEFINED_JAVA_SCALAR_FUNCTIONS;
+ case "SQL":
+ if (createFunctionStmt.getIsAggregate()) {
+ throw new UnsupportedOperationException(
+ "Native SQL aggregate functions are not supported (BEAM-9954).");
+ }
+ return USER_DEFINED_SQL_FUNCTIONS;
+ case "PY":
+ case "PYTHON":
+ case "JS":
+ case "JAVASCRIPT":
+ throw new UnsupportedOperationException(
+ String.format(
+ "Function %s uses unsupported language %s.",
+ String.join(".", createFunctionStmt.getNamePath()),
+ createFunctionStmt.getLanguage()));
+ default:
+ throw new IllegalArgumentException(
+ String.format(
+ "Function %s uses unrecognized language %s.",
+ String.join(".", createFunctionStmt.getNamePath()),
+ createFunctionStmt.getLanguage()));
+ }
+ }
+
+ /**
+ * Assume last element in tablePath is a table name, and everything before is catalogs. So the
+ * logic is to create nested catalogs until the last level, then add a table at the last level.
+ *
+ * Table schema is extracted from Calcite schema based on the table name resolution strategy,
+ * e.g. either by drilling down the schema.getSubschema() path or joining the table name with dots
+ * to construct a single compound identifier (e.g. Data Catalog use case).
+ */
+ private void addTableToLeafCatalog(List tablePath, QueryTrait queryTrait) {
+
+ SimpleCatalog leafCatalog = createNestedCatalogs(zetaSqlCatalog, tablePath);
+
+ org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Table calciteTable =
+ TableResolution.resolveCalciteTable(calciteSchema, tablePath);
+
+ if (calciteTable == null) {
+ throw new SqlConversionException(
+ "Wasn't able to resolve the path "
+ + tablePath
+ + " in schema: "
+ + calciteSchema.getName());
+ }
+
+ RelDataType rowType = calciteTable.getRowType(typeFactory);
+
+ TableResolution.SimpleTableWithPath tableWithPath =
+ TableResolution.SimpleTableWithPath.of(tablePath);
+ queryTrait.addResolvedTable(tableWithPath);
+
+ addFieldsToTable(tableWithPath, rowType);
+ leafCatalog.addSimpleTable(tableWithPath.getTable());
+ }
+
+ private static void addFieldsToTable(
+ TableResolution.SimpleTableWithPath tableWithPath, RelDataType rowType) {
+ for (RelDataTypeField field : rowType.getFieldList()) {
+ tableWithPath
+ .getTable()
+ .addSimpleColumn(
+ field.getName(), ZetaSqlCalciteTranslationUtils.toZetaSqlType(field.getType()));
+ }
+ }
+
+ /** For table path like a.b.c we assume c is the table and a.b are the nested catalogs/schemas. */
+ private static SimpleCatalog createNestedCatalogs(SimpleCatalog catalog, List tablePath) {
+ SimpleCatalog currentCatalog = catalog;
+ for (int i = 0; i < tablePath.size() - 1; i++) {
+ String nextCatalogName = tablePath.get(i);
+
+ Optional existing = tryGetExisting(currentCatalog, nextCatalogName);
+
+ currentCatalog =
+ existing.isPresent() ? existing.get() : addNewCatalog(currentCatalog, nextCatalogName);
+ }
+ return currentCatalog;
+ }
+
+ private static Optional tryGetExisting(
+ SimpleCatalog currentCatalog, String nextCatalogName) {
+ return currentCatalog.getCatalogList().stream()
+ .filter(c -> nextCatalogName.equals(c.getFullName()))
+ .findFirst();
+ }
+
+ private static SimpleCatalog addNewCatalog(SimpleCatalog currentCatalog, String nextCatalogName) {
+ SimpleCatalog nextCatalog = new SimpleCatalog(nextCatalogName);
+ currentCatalog.addSimpleCatalog(nextCatalog);
+ return nextCatalog;
+ }
+
+ private static String getJarPath(ResolvedNodes.ResolvedCreateFunctionStmt createFunctionStmt) {
+ String jarPath = getOptionStringValue(createFunctionStmt, "path");
+ if (jarPath.isEmpty()) {
+ throw new IllegalArgumentException(
+ String.format(
+ "No jar was provided to define function %s. Add 'OPTIONS (path=)' to the CREATE FUNCTION statement.",
+ String.join(".", createFunctionStmt.getNamePath())));
+ }
+ return jarPath;
+ }
+
+ private static String getOptionStringValue(
+ ResolvedNodes.ResolvedCreateFunctionStmt createFunctionStmt, String optionName) {
+ for (ResolvedNodes.ResolvedOption option : createFunctionStmt.getOptionList()) {
+ if (optionName.equals(option.getName())) {
+ if (option.getValue() == null) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Option '%s' has null value (expected %s).",
+ optionName, ZetaSQLType.TypeKind.TYPE_STRING));
+ }
+ if (option.getValue().getType().getKind() != ZetaSQLType.TypeKind.TYPE_STRING) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Option '%s' has type %s (expected %s).",
+ optionName,
+ option.getValue().getType().getKind(),
+ ZetaSQLType.TypeKind.TYPE_STRING));
+ }
+ return ((ResolvedNodes.ResolvedLiteral) option.getValue()).getValue().getStringValue();
+ }
+ }
+ return "";
+ }
+}
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java
index 0b65757c7b7e..14a65749e411 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java
@@ -22,28 +22,16 @@
import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_QUERY_STMT;
import static java.nio.charset.StandardCharsets.UTF_8;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableSet;
import com.google.zetasql.Analyzer;
import com.google.zetasql.AnalyzerOptions;
-import com.google.zetasql.Function;
-import com.google.zetasql.FunctionArgumentType;
-import com.google.zetasql.FunctionSignature;
import com.google.zetasql.ParseResumeLocation;
-import com.google.zetasql.SimpleCatalog;
-import com.google.zetasql.TVFRelation;
-import com.google.zetasql.TableValuedFunction;
-import com.google.zetasql.TypeFactory;
import com.google.zetasql.Value;
-import com.google.zetasql.ZetaSQLBuiltinFunctionOptions;
-import com.google.zetasql.ZetaSQLFunctions.FunctionEnums.Mode;
-import com.google.zetasql.ZetaSQLFunctions.SignatureArgumentKind;
import com.google.zetasql.ZetaSQLOptions.ErrorMessageMode;
import com.google.zetasql.ZetaSQLOptions.LanguageFeature;
import com.google.zetasql.ZetaSQLOptions.ParameterMode;
import com.google.zetasql.ZetaSQLOptions.ProductMode;
import com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind;
-import com.google.zetasql.ZetaSQLType.TypeKind;
+import com.google.zetasql.resolvedast.ResolvedNodes;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedCreateFunctionStmt;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedCreateTableFunctionStmt;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedStatement;
@@ -51,69 +39,21 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
-import java.util.Optional;
-import java.util.stream.Collectors;
-import org.apache.beam.sdk.extensions.sql.impl.ParseException;
import org.apache.beam.sdk.extensions.sql.impl.QueryPlanner.QueryParameters;
import org.apache.beam.sdk.extensions.sql.impl.QueryPlanner.QueryParameters.Kind;
-import org.apache.beam.sdk.extensions.sql.impl.SqlConversionException;
-import org.apache.beam.sdk.extensions.sql.impl.utils.TVFStreamingUtils;
-import org.apache.beam.sdk.extensions.sql.zetasql.TableResolution.SimpleTableWithPath;
-import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.adapter.java.JavaTypeFactory;
-import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataType;
-import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataTypeField;
-import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.SchemaPlus;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
/** Adapter for {@link Analyzer} to simplify the API for parsing the query and resolving the AST. */
@SuppressWarnings({
"nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
})
public class SqlAnalyzer {
- // ZetaSQL function group identifiers. Different function groups may have divergent translation
- // paths.
- public static final String PRE_DEFINED_WINDOW_FUNCTIONS = "pre_defined_window_functions";
- public static final String USER_DEFINED_SQL_FUNCTIONS = "user_defined_functions";
- /**
- * Same as {@link Function}.ZETASQL_FUNCTION_GROUP_NAME. Identifies built-in ZetaSQL functions.
- */
- public static final String ZETASQL_FUNCTION_GROUP_NAME = "ZetaSQL";
-
- public static final String USER_DEFINED_JAVA_SCALAR_FUNCTIONS =
- "user_defined_java_scalar_functions";
-
private static final ImmutableSet SUPPORTED_STATEMENT_KINDS =
ImmutableSet.of(
RESOLVED_QUERY_STMT, RESOLVED_CREATE_FUNCTION_STMT, RESOLVED_CREATE_TABLE_FUNCTION_STMT);
- private static final ImmutableList FUNCTION_LIST =
- ImmutableList.of(
- // TODO: support optional function argument (for window_offset).
- "CREATE FUNCTION TUMBLE(ts TIMESTAMP, window_size STRING) AS (1);",
- "CREATE FUNCTION TUMBLE_START(window_size STRING) RETURNS TIMESTAMP AS (null);",
- "CREATE FUNCTION TUMBLE_END(window_size STRING) RETURNS TIMESTAMP AS (null);",
- "CREATE FUNCTION HOP(ts TIMESTAMP, emit_frequency STRING, window_size STRING) AS (1);",
- "CREATE FUNCTION HOP_START(emit_frequency STRING, window_size STRING) "
- + "RETURNS TIMESTAMP AS (null);",
- "CREATE FUNCTION HOP_END(emit_frequency STRING, window_size STRING) "
- + "RETURNS TIMESTAMP AS (null);",
- "CREATE FUNCTION SESSION(ts TIMESTAMP, session_gap STRING) AS (1);",
- "CREATE FUNCTION SESSION_START(session_gap STRING) RETURNS TIMESTAMP AS (null);",
- "CREATE FUNCTION SESSION_END(session_gap STRING) RETURNS TIMESTAMP AS (null);");
-
- private final QueryTrait queryTrait;
- private final SchemaPlus topLevelSchema;
- private final JavaTypeFactory typeFactory;
-
- SqlAnalyzer(QueryTrait queryTrait, SchemaPlus topLevelSchema, JavaTypeFactory typeFactory) {
- this.queryTrait = queryTrait;
- this.topLevelSchema = topLevelSchema;
- this.typeFactory = typeFactory;
- }
-
- static boolean isEndOfInput(ParseResumeLocation parseResumeLocation) {
- return parseResumeLocation.getBytePosition()
- >= parseResumeLocation.getInput().getBytes(UTF_8).length;
- }
+ SqlAnalyzer() {}
/** Returns table names from all statements in the SQL string. */
List> extractTableNames(String sql, AnalyzerOptions options) {
@@ -127,44 +67,37 @@ List> extractTableNames(String sql, AnalyzerOptions options) {
return tables.build();
}
- static String getFunctionGroup(ResolvedCreateFunctionStmt createFunctionStmt) {
- switch (createFunctionStmt.getLanguage().toUpperCase()) {
- case "JAVA":
- if (createFunctionStmt.getIsAggregate()) {
- throw new UnsupportedOperationException(
- "Java SQL aggregate functions are not supported (BEAM-10925).");
- }
- return USER_DEFINED_JAVA_SCALAR_FUNCTIONS;
- case "SQL":
- if (createFunctionStmt.getIsAggregate()) {
+ /**
+ * Analyzes the entire SQL code block (which may consist of multiple statements) and returns the
+ * resolved query.
+ *
+ * Assumes there is exactly one SELECT statement in the input, and it must be the last
+ * statement in the input.
+ */
+ ResolvedNodes.ResolvedQueryStmt analyzeQuery(
+ String sql, AnalyzerOptions options, BeamZetaSqlCatalog catalog) {
+ ParseResumeLocation parseResumeLocation = new ParseResumeLocation(sql);
+ ResolvedStatement statement;
+ do {
+ statement = analyzeNextStatement(parseResumeLocation, options, catalog);
+ if (statement.nodeKind() == RESOLVED_QUERY_STMT) {
+ if (!SqlAnalyzer.isEndOfInput(parseResumeLocation)) {
throw new UnsupportedOperationException(
- "Native SQL aggregate functions are not supported (BEAM-9954).");
+ "No additional statements are allowed after a SELECT statement.");
}
- return USER_DEFINED_SQL_FUNCTIONS;
- case "PY":
- case "PYTHON":
- case "JS":
- case "JAVASCRIPT":
- throw new UnsupportedOperationException(
- String.format(
- "Function %s uses unsupported language %s.",
- String.join(".", createFunctionStmt.getNamePath()),
- createFunctionStmt.getLanguage()));
- default:
- throw new IllegalArgumentException(
- String.format(
- "Function %s uses unrecognized language %s.",
- String.join(".", createFunctionStmt.getNamePath()),
- createFunctionStmt.getLanguage()));
+ }
+ } while (!SqlAnalyzer.isEndOfInput(parseResumeLocation));
+
+ if (!(statement instanceof ResolvedNodes.ResolvedQueryStmt)) {
+ throw new UnsupportedOperationException(
+ "Statement list must end in a SELECT statement, not " + statement.nodeKindString());
}
+ return (ResolvedNodes.ResolvedQueryStmt) statement;
}
- private Function createFunction(ResolvedCreateFunctionStmt createFunctionStmt) {
- return new Function(
- createFunctionStmt.getNamePath(),
- getFunctionGroup(createFunctionStmt),
- createFunctionStmt.getIsAggregate() ? Mode.AGGREGATE : Mode.SCALAR,
- ImmutableList.of(createFunctionStmt.getSignature()));
+ private static boolean isEndOfInput(ParseResumeLocation parseResumeLocation) {
+ return parseResumeLocation.getBytePosition()
+ >= parseResumeLocation.getInput().getBytes(UTF_8).length;
}
/**
@@ -172,33 +105,28 @@ private Function createFunction(ResolvedCreateFunctionStmt createFunctionStmt) {
* ParseResumeLocation to the start of the next statement. Adds user-defined functions to the
* catalog for use in following statements. Returns the resolved AST.
*/
- ResolvedStatement analyzeNextStatement(
- ParseResumeLocation parseResumeLocation, AnalyzerOptions options, SimpleCatalog catalog) {
+ private ResolvedStatement analyzeNextStatement(
+ ParseResumeLocation parseResumeLocation,
+ AnalyzerOptions options,
+ BeamZetaSqlCatalog catalog) {
ResolvedStatement resolvedStatement =
- Analyzer.analyzeNextStatement(parseResumeLocation, options, catalog);
+ Analyzer.analyzeNextStatement(parseResumeLocation, options, catalog.getZetaSqlCatalog());
if (resolvedStatement.nodeKind() == RESOLVED_CREATE_FUNCTION_STMT) {
ResolvedCreateFunctionStmt createFunctionStmt =
(ResolvedCreateFunctionStmt) resolvedStatement;
- Function userFunction = createFunction(createFunctionStmt);
try {
- catalog.addFunction(userFunction);
+ catalog.addFunction(createFunctionStmt);
} catch (IllegalArgumentException e) {
- throw new ParseException(
+ throw new RuntimeException(
String.format(
- "Failed to define function %s", String.join(".", createFunctionStmt.getNamePath())),
+ "Failed to define function '%s'",
+ String.join(".", createFunctionStmt.getNamePath())),
e);
}
} else if (resolvedStatement.nodeKind() == RESOLVED_CREATE_TABLE_FUNCTION_STMT) {
ResolvedCreateTableFunctionStmt createTableFunctionStmt =
(ResolvedCreateTableFunctionStmt) resolvedStatement;
- catalog.addTableValuedFunction(
- new TableValuedFunction.FixedOutputSchemaTVF(
- createTableFunctionStmt.getNamePath(),
- createTableFunctionStmt.getSignature(),
- TVFRelation.createColumnBased(
- createTableFunctionStmt.getQuery().getColumnList().stream()
- .map(c -> TVFRelation.Column.create(c.getName(), c.getType()))
- .collect(Collectors.toList()))));
+ catalog.addTableValuedFunction(createTableFunctionStmt);
} else if (!SUPPORTED_STATEMENT_KINDS.contains(resolvedStatement.nodeKind())) {
throw new UnsupportedOperationException(
"Unrecognized statement type " + resolvedStatement.nodeKindString());
@@ -248,179 +176,4 @@ static AnalyzerOptions getAnalyzerOptions(QueryParameters queryParams, String de
return options;
}
-
- /**
- * Creates a SimpleCatalog which represents the top-level schema, populates it with tables,
- * built-in functions.
- */
- SimpleCatalog createPopulatedCatalog(
- String catalogName, AnalyzerOptions options, List> tables) {
-
- SimpleCatalog catalog = new SimpleCatalog(catalogName);
- addBuiltinFunctionsToCatalog(catalog, options);
-
- tables.forEach(table -> addTableToLeafCatalog(queryTrait, catalog, table));
-
- return catalog;
- }
-
- private void addBuiltinFunctionsToCatalog(SimpleCatalog catalog, AnalyzerOptions options) {
- // Enable ZetaSQL builtin functions.
- ZetaSQLBuiltinFunctionOptions zetasqlBuiltinFunctionOptions =
- new ZetaSQLBuiltinFunctionOptions(options.getLanguageOptions());
-
- SupportedZetaSqlBuiltinFunctions.ALLOWLIST.forEach(
- zetasqlBuiltinFunctionOptions::includeFunctionSignatureId);
-
- catalog.addZetaSQLFunctions(zetasqlBuiltinFunctionOptions);
-
- FUNCTION_LIST.stream()
- .map(func -> (ResolvedCreateFunctionStmt) Analyzer.analyzeStatement(func, options, catalog))
- .map(
- resolvedFunc ->
- new Function(
- String.join(".", resolvedFunc.getNamePath()),
- PRE_DEFINED_WINDOW_FUNCTIONS,
- Mode.SCALAR,
- ImmutableList.of(resolvedFunc.getSignature())))
- .forEach(catalog::addFunction);
-
- FunctionArgumentType retType =
- new FunctionArgumentType(SignatureArgumentKind.ARG_TYPE_RELATION);
-
- FunctionArgumentType inputTableType =
- new FunctionArgumentType(SignatureArgumentKind.ARG_TYPE_RELATION);
-
- FunctionArgumentType descriptorType =
- new FunctionArgumentType(
- SignatureArgumentKind.ARG_TYPE_DESCRIPTOR,
- FunctionArgumentType.FunctionArgumentTypeOptions.builder()
- .setDescriptorResolutionTableOffset(0)
- .build(),
- 1);
-
- FunctionArgumentType stringType =
- new FunctionArgumentType(TypeFactory.createSimpleType(TypeKind.TYPE_STRING));
-
- // TUMBLE
- catalog.addTableValuedFunction(
- new TableValuedFunction.ForwardInputSchemaToOutputSchemaWithAppendedColumnTVF(
- ImmutableList.of(TVFStreamingUtils.FIXED_WINDOW_TVF),
- new FunctionSignature(
- retType, ImmutableList.of(inputTableType, descriptorType, stringType), -1),
- ImmutableList.of(
- TVFRelation.Column.create(
- TVFStreamingUtils.WINDOW_START,
- TypeFactory.createSimpleType(TypeKind.TYPE_TIMESTAMP)),
- TVFRelation.Column.create(
- TVFStreamingUtils.WINDOW_END,
- TypeFactory.createSimpleType(TypeKind.TYPE_TIMESTAMP))),
- null,
- null));
-
- // HOP
- catalog.addTableValuedFunction(
- new TableValuedFunction.ForwardInputSchemaToOutputSchemaWithAppendedColumnTVF(
- ImmutableList.of(TVFStreamingUtils.SLIDING_WINDOW_TVF),
- new FunctionSignature(
- retType,
- ImmutableList.of(inputTableType, descriptorType, stringType, stringType),
- -1),
- ImmutableList.of(
- TVFRelation.Column.create(
- TVFStreamingUtils.WINDOW_START,
- TypeFactory.createSimpleType(TypeKind.TYPE_TIMESTAMP)),
- TVFRelation.Column.create(
- TVFStreamingUtils.WINDOW_END,
- TypeFactory.createSimpleType(TypeKind.TYPE_TIMESTAMP))),
- null,
- null));
-
- // SESSION
- catalog.addTableValuedFunction(
- new TableValuedFunction.ForwardInputSchemaToOutputSchemaWithAppendedColumnTVF(
- ImmutableList.of(TVFStreamingUtils.SESSION_WINDOW_TVF),
- new FunctionSignature(
- retType,
- ImmutableList.of(inputTableType, descriptorType, descriptorType, stringType),
- -1),
- ImmutableList.of(
- TVFRelation.Column.create(
- TVFStreamingUtils.WINDOW_START,
- TypeFactory.createSimpleType(TypeKind.TYPE_TIMESTAMP)),
- TVFRelation.Column.create(
- TVFStreamingUtils.WINDOW_END,
- TypeFactory.createSimpleType(TypeKind.TYPE_TIMESTAMP))),
- null,
- null));
- }
-
- /**
- * Assume last element in tablePath is a table name, and everything before is catalogs. So the
- * logic is to create nested catalogs until the last level, then add a table at the last level.
- *
- * Table schema is extracted from Calcite schema based on the table name resultion strategy,
- * e.g. either by drilling down the schema.getSubschema() path or joining the table name with dots
- * to construct a single compound identifier (e.g. Data Catalog use case).
- */
- private void addTableToLeafCatalog(
- QueryTrait trait, SimpleCatalog topLevelCatalog, List tablePath) {
-
- SimpleCatalog leafCatalog = createNestedCatalogs(topLevelCatalog, tablePath);
-
- org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Table calciteTable =
- TableResolution.resolveCalciteTable(topLevelSchema, tablePath);
-
- if (calciteTable == null) {
- throw new SqlConversionException(
- "Wasn't able to resolve the path "
- + tablePath
- + " in schema: "
- + topLevelSchema.getName());
- }
-
- RelDataType rowType = calciteTable.getRowType(typeFactory);
-
- SimpleTableWithPath tableWithPath = SimpleTableWithPath.of(tablePath);
- trait.addResolvedTable(tableWithPath);
-
- addFieldsToTable(tableWithPath, rowType);
- leafCatalog.addSimpleTable(tableWithPath.getTable());
- }
-
- private void addFieldsToTable(SimpleTableWithPath tableWithPath, RelDataType rowType) {
- for (RelDataTypeField field : rowType.getFieldList()) {
- tableWithPath
- .getTable()
- .addSimpleColumn(
- field.getName(), ZetaSqlCalciteTranslationUtils.toZetaSqlType(field.getType()));
- }
- }
-
- /** For table path like a.b.c we assume c is the table and a.b are the nested catalogs/schemas. */
- private SimpleCatalog createNestedCatalogs(SimpleCatalog catalog, List tablePath) {
- SimpleCatalog currentCatalog = catalog;
- for (int i = 0; i < tablePath.size() - 1; i++) {
- String nextCatalogName = tablePath.get(i);
-
- Optional existing = tryGetExisting(currentCatalog, nextCatalogName);
-
- currentCatalog =
- existing.isPresent() ? existing.get() : addNewCatalog(currentCatalog, nextCatalogName);
- }
- return currentCatalog;
- }
-
- private Optional tryGetExisting(
- SimpleCatalog currentCatalog, String nextCatalogName) {
- return currentCatalog.getCatalogList().stream()
- .filter(c -> nextCatalogName.equals(c.getFullName()))
- .findFirst();
- }
-
- private SimpleCatalog addNewCatalog(SimpleCatalog currentCatalog, String nextCatalogName) {
- SimpleCatalog nextCatalog = new SimpleCatalog(nextCatalogName);
- currentCatalog.addSimpleCatalog(nextCatalog);
- return nextCatalog;
- }
}
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLPlannerImpl.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLPlannerImpl.java
index 028a697c144d..7004b0eb62a9 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLPlannerImpl.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLPlannerImpl.java
@@ -17,29 +17,14 @@
*/
package org.apache.beam.sdk.extensions.sql.zetasql;
-import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_CREATE_FUNCTION_STMT;
-import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_CREATE_TABLE_FUNCTION_STMT;
-import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_QUERY_STMT;
-
import com.google.zetasql.AnalyzerOptions;
import com.google.zetasql.LanguageOptions;
-import com.google.zetasql.ParseResumeLocation;
-import com.google.zetasql.SimpleCatalog;
-import com.google.zetasql.ZetaSQLType;
-import com.google.zetasql.resolvedast.ResolvedNode;
-import com.google.zetasql.resolvedast.ResolvedNodes;
-import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedCreateFunctionStmt;
-import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedCreateTableFunctionStmt;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedQueryStmt;
-import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedStatement;
import java.util.List;
-import org.apache.beam.sdk.extensions.sql.impl.JavaUdfLoader;
import org.apache.beam.sdk.extensions.sql.impl.QueryPlanner.QueryParameters;
-import org.apache.beam.sdk.extensions.sql.udf.ScalarFn;
import org.apache.beam.sdk.extensions.sql.zetasql.translation.ConversionContext;
import org.apache.beam.sdk.extensions.sql.zetasql.translation.ExpressionConverter;
import org.apache.beam.sdk.extensions.sql.zetasql.translation.QueryStatementConverter;
-import org.apache.beam.sdk.extensions.sql.zetasql.translation.UserFunctionDefinitions;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.adapter.java.JavaTypeFactory;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptCluster;
@@ -55,7 +40,6 @@
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.Frameworks;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.Program;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.util.Util;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
/** ZetaSQLPlannerImpl. */
@SuppressWarnings({
@@ -93,79 +77,25 @@ class ZetaSQLPlannerImpl {
public RelRoot rel(String sql, QueryParameters params) {
RelOptCluster cluster = RelOptCluster.create(planner, new RexBuilder(typeFactory));
- QueryTrait trait = new QueryTrait();
- SqlAnalyzer analyzer =
- new SqlAnalyzer(trait, defaultSchemaPlus, (JavaTypeFactory) cluster.getTypeFactory());
-
AnalyzerOptions options = SqlAnalyzer.getAnalyzerOptions(params, defaultTimezone);
+ BeamZetaSqlCatalog catalog =
+ BeamZetaSqlCatalog.create(
+ defaultSchemaPlus, (JavaTypeFactory) cluster.getTypeFactory(), options);
// Set up table providers that need to be pre-registered
+ SqlAnalyzer analyzer = new SqlAnalyzer();
List> tables = analyzer.extractTableNames(sql, options);
TableResolution.registerTables(this.defaultSchemaPlus, tables);
- SimpleCatalog catalog =
- analyzer.createPopulatedCatalog(defaultSchemaPlus.getName(), options, tables);
-
- ImmutableMap.Builder, ResolvedCreateFunctionStmt> udfBuilder =
- ImmutableMap.builder();
- ImmutableMap.Builder, ResolvedNode> udtvfBuilder = ImmutableMap.builder();
- ImmutableMap.Builder, UserFunctionDefinitions.JavaScalarFunction>
- javaScalarFunctionBuilder = ImmutableMap.builder();
- JavaUdfLoader javaUdfLoader = new JavaUdfLoader();
-
- ResolvedStatement statement;
- ParseResumeLocation parseResumeLocation = new ParseResumeLocation(sql);
- do {
- statement = analyzer.analyzeNextStatement(parseResumeLocation, options, catalog);
- if (statement.nodeKind() == RESOLVED_CREATE_FUNCTION_STMT) {
- ResolvedCreateFunctionStmt createFunctionStmt = (ResolvedCreateFunctionStmt) statement;
- String functionGroup = SqlAnalyzer.getFunctionGroup(createFunctionStmt);
- switch (functionGroup) {
- case SqlAnalyzer.USER_DEFINED_SQL_FUNCTIONS:
- udfBuilder.put(createFunctionStmt.getNamePath(), createFunctionStmt);
- break;
- case SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS:
- String jarPath = getJarPath(createFunctionStmt);
- ScalarFn scalarFn =
- javaUdfLoader.loadScalarFunction(createFunctionStmt.getNamePath(), jarPath);
- javaScalarFunctionBuilder.put(
- createFunctionStmt.getNamePath(),
- UserFunctionDefinitions.JavaScalarFunction.create(scalarFn, jarPath));
- break;
- default:
- throw new IllegalArgumentException(
- String.format("Encountered unrecognized function group %s.", functionGroup));
- }
- } else if (statement.nodeKind() == RESOLVED_CREATE_TABLE_FUNCTION_STMT) {
- ResolvedCreateTableFunctionStmt createTableFunctionStmt =
- (ResolvedCreateTableFunctionStmt) statement;
- udtvfBuilder.put(createTableFunctionStmt.getNamePath(), createTableFunctionStmt.getQuery());
- } else if (statement.nodeKind() == RESOLVED_QUERY_STMT) {
- if (!SqlAnalyzer.isEndOfInput(parseResumeLocation)) {
- throw new UnsupportedOperationException(
- "No additional statements are allowed after a SELECT statement.");
- }
- break;
- }
- } while (!SqlAnalyzer.isEndOfInput(parseResumeLocation));
-
- if (!(statement instanceof ResolvedQueryStmt)) {
- throw new UnsupportedOperationException(
- "Statement list must end in a SELECT statement, not " + statement.nodeKindString());
- }
-
- UserFunctionDefinitions userFunctionDefinitions =
- UserFunctionDefinitions.newBuilder()
- .setSqlScalarFunctions(udfBuilder.build())
- .setSqlTableValuedFunctions(udtvfBuilder.build())
- .setJavaScalarFunctions(javaScalarFunctionBuilder.build())
- .build();
+ QueryTrait trait = new QueryTrait();
+ catalog.addTables(tables, trait);
+
+ ResolvedQueryStmt statement = analyzer.analyzeQuery(sql, options, catalog);
ExpressionConverter expressionConverter =
- new ExpressionConverter(cluster, params, userFunctionDefinitions);
+ new ExpressionConverter(cluster, params, catalog.getUserFunctionDefinitions());
ConversionContext context = ConversionContext.of(config, expressionConverter, cluster, trait);
- RelNode convertedNode =
- QueryStatementConverter.convertRootQuery(context, (ResolvedQueryStmt) statement);
+ RelNode convertedNode = QueryStatementConverter.convertRootQuery(context, statement);
return RelRoot.of(convertedNode, SqlKind.ALL);
}
@@ -185,39 +115,4 @@ void setDefaultTimezone(String timezone) {
static LanguageOptions getLanguageOptions() {
return SqlAnalyzer.baseAnalyzerOptions().getLanguageOptions();
}
-
- private static String getJarPath(ResolvedCreateFunctionStmt createFunctionStmt) {
- String jarPath = getOptionStringValue(createFunctionStmt, "path");
- if (jarPath.isEmpty()) {
- throw new IllegalArgumentException(
- String.format(
- "No jar was provided to define function %s. Add 'OPTIONS (path=)' to the CREATE FUNCTION statement.",
- String.join(".", createFunctionStmt.getNamePath())));
- }
- return jarPath;
- }
-
- private static String getOptionStringValue(
- ResolvedCreateFunctionStmt createFunctionStmt, String optionName) {
- for (ResolvedNodes.ResolvedOption option : createFunctionStmt.getOptionList()) {
- if (optionName.equals(option.getName())) {
- if (option.getValue() == null) {
- throw new IllegalArgumentException(
- String.format(
- "Option '%s' has null value (expected %s).",
- optionName, ZetaSQLType.TypeKind.TYPE_STRING));
- }
- if (option.getValue().getType().getKind() != ZetaSQLType.TypeKind.TYPE_STRING) {
- throw new IllegalArgumentException(
- String.format(
- "Option '%s' has type %s (expected %s).",
- optionName,
- option.getValue().getType().getKind(),
- ZetaSQLType.TypeKind.TYPE_STRING));
- }
- return ((ResolvedNodes.ResolvedLiteral) option.getValue()).getValue().getStringValue();
- }
- }
- return "";
- }
}
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java
index 07984d7eefb4..33824c0483d7 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java
@@ -159,7 +159,7 @@ static boolean hasOnlyJavaUdfInProjects(RelOptRuleCall x) {
if (udf.function instanceof ZetaSqlScalarFunctionImpl) {
ZetaSqlScalarFunctionImpl scalarFunction = (ZetaSqlScalarFunctionImpl) udf.function;
if (!scalarFunction.functionGroup.equals(
- SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS)) {
+ BeamZetaSqlCatalog.USER_DEFINED_JAVA_SCALAR_FUNCTIONS)) {
// Reject ZetaSQL Builtin Scalar Functions
return false;
}
@@ -224,7 +224,7 @@ static boolean hasNoJavaUdfInProjects(RelOptRuleCall x) {
if (udf.function instanceof ZetaSqlScalarFunctionImpl) {
ZetaSqlScalarFunctionImpl scalarFunction = (ZetaSqlScalarFunctionImpl) udf.function;
if (scalarFunction.functionGroup.equals(
- SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS)) {
+ BeamZetaSqlCatalog.USER_DEFINED_JAVA_SCALAR_FUNCTIONS)) {
return false;
}
}
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java
index 2a2c08fa5c64..29a37cf6097e 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java
@@ -24,10 +24,10 @@
import static com.google.zetasql.ZetaSQLType.TypeKind.TYPE_INT64;
import static com.google.zetasql.ZetaSQLType.TypeKind.TYPE_STRING;
import static com.google.zetasql.ZetaSQLType.TypeKind.TYPE_TIMESTAMP;
-import static org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer.PRE_DEFINED_WINDOW_FUNCTIONS;
-import static org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS;
-import static org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer.USER_DEFINED_SQL_FUNCTIONS;
-import static org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME;
+import static org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog.PRE_DEFINED_WINDOW_FUNCTIONS;
+import static org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog.USER_DEFINED_JAVA_SCALAR_FUNCTIONS;
+import static org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog.USER_DEFINED_SQL_FUNCTIONS;
+import static org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog.ZETASQL_FUNCTION_GROUP_NAME;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import com.google.common.base.Ascii;
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java
index 584c727a3fd2..ea7e24f4c76c 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java
@@ -17,6 +17,8 @@
*/
package org.apache.beam.sdk.extensions.sql.zetasql.translation;
+import static org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog.ZETASQL_FUNCTION_GROUP_NAME;
+
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
@@ -28,7 +30,6 @@
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.CountIf;
import org.apache.beam.sdk.extensions.sql.impl.udaf.StringAgg;
import org.apache.beam.sdk.extensions.sql.zetasql.DateTimeUtils;
-import org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer;
import org.apache.beam.sdk.extensions.sql.zetasql.translation.impl.BeamBuiltinMethods;
import org.apache.beam.sdk.extensions.sql.zetasql.translation.impl.CastFunctionImpl;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.jdbc.JavaTypeFactoryImpl;
@@ -86,56 +87,43 @@ public class SqlOperators {
public static final SqlOperator START_WITHS =
createUdfOperator(
- "STARTS_WITH",
- BeamBuiltinMethods.STARTS_WITH_METHOD,
- SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
+ "STARTS_WITH", BeamBuiltinMethods.STARTS_WITH_METHOD, ZETASQL_FUNCTION_GROUP_NAME);
public static final SqlOperator CONCAT =
- createUdfOperator(
- "CONCAT", BeamBuiltinMethods.CONCAT_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
+ createUdfOperator("CONCAT", BeamBuiltinMethods.CONCAT_METHOD, ZETASQL_FUNCTION_GROUP_NAME);
public static final SqlOperator REPLACE =
- createUdfOperator(
- "REPLACE", BeamBuiltinMethods.REPLACE_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
+ createUdfOperator("REPLACE", BeamBuiltinMethods.REPLACE_METHOD, ZETASQL_FUNCTION_GROUP_NAME);
public static final SqlOperator TRIM =
- createUdfOperator(
- "TRIM", BeamBuiltinMethods.TRIM_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
+ createUdfOperator("TRIM", BeamBuiltinMethods.TRIM_METHOD, ZETASQL_FUNCTION_GROUP_NAME);
public static final SqlOperator LTRIM =
- createUdfOperator(
- "LTRIM", BeamBuiltinMethods.LTRIM_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
+ createUdfOperator("LTRIM", BeamBuiltinMethods.LTRIM_METHOD, ZETASQL_FUNCTION_GROUP_NAME);
public static final SqlOperator RTRIM =
- createUdfOperator(
- "RTRIM", BeamBuiltinMethods.RTRIM_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
+ createUdfOperator("RTRIM", BeamBuiltinMethods.RTRIM_METHOD, ZETASQL_FUNCTION_GROUP_NAME);
public static final SqlOperator SUBSTR =
- createUdfOperator(
- "SUBSTR", BeamBuiltinMethods.SUBSTR_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
+ createUdfOperator("SUBSTR", BeamBuiltinMethods.SUBSTR_METHOD, ZETASQL_FUNCTION_GROUP_NAME);
public static final SqlOperator REVERSE =
- createUdfOperator(
- "REVERSE", BeamBuiltinMethods.REVERSE_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
+ createUdfOperator("REVERSE", BeamBuiltinMethods.REVERSE_METHOD, ZETASQL_FUNCTION_GROUP_NAME);
public static final SqlOperator CHAR_LENGTH =
createUdfOperator(
- "CHAR_LENGTH",
- BeamBuiltinMethods.CHAR_LENGTH_METHOD,
- SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
+ "CHAR_LENGTH", BeamBuiltinMethods.CHAR_LENGTH_METHOD, ZETASQL_FUNCTION_GROUP_NAME);
public static final SqlOperator ENDS_WITH =
createUdfOperator(
- "ENDS_WITH",
- BeamBuiltinMethods.ENDS_WITH_METHOD,
- SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
+ "ENDS_WITH", BeamBuiltinMethods.ENDS_WITH_METHOD, ZETASQL_FUNCTION_GROUP_NAME);
public static final SqlOperator LIKE =
createUdfOperator(
"LIKE",
BeamBuiltinMethods.LIKE_METHOD,
SqlSyntax.BINARY,
- SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME,
+ ZETASQL_FUNCTION_GROUP_NAME,
"");
public static final SqlOperator VALIDATE_TIMESTAMP =
@@ -145,7 +133,7 @@ public class SqlOperators {
"validateTimestamp",
x -> NULLABLE_TIMESTAMP,
ImmutableList.of(TIMESTAMP),
- SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
+ ZETASQL_FUNCTION_GROUP_NAME);
public static final SqlOperator VALIDATE_TIME_INTERVAL =
createUdfOperator(
@@ -154,17 +142,14 @@ public class SqlOperators {
"validateTimeInterval",
x -> NULLABLE_BIGINT,
ImmutableList.of(BIGINT, OTHER),
- SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
+ ZETASQL_FUNCTION_GROUP_NAME);
public static final SqlOperator TIMESTAMP_OP =
createUdfOperator(
- "TIMESTAMP",
- BeamBuiltinMethods.TIMESTAMP_METHOD,
- SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
+ "TIMESTAMP", BeamBuiltinMethods.TIMESTAMP_METHOD, ZETASQL_FUNCTION_GROUP_NAME);
public static final SqlOperator DATE_OP =
- createUdfOperator(
- "DATE", BeamBuiltinMethods.DATE_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
+ createUdfOperator("DATE", BeamBuiltinMethods.DATE_METHOD, ZETASQL_FUNCTION_GROUP_NAME);
public static final SqlOperator BIT_XOR =
createUdafOperator(
diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java
index 41a55bbbadfe..2fa564648fe6 100644
--- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java
+++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlJavaUdfTest.java
@@ -244,8 +244,12 @@ public void testJavaUdfEmptyPath() {
String sql =
"CREATE FUNCTION foo() RETURNS STRING LANGUAGE java OPTIONS (path=''); SELECT foo();";
ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config);
- thrown.expect(IllegalArgumentException.class);
- thrown.expectMessage("No jar was provided to define function foo.");
+ thrown.expect(RuntimeException.class);
+ thrown.expectMessage("Failed to define function 'foo'");
+ thrown.expectCause(
+ allOf(
+ isA(IllegalArgumentException.class),
+ hasProperty("message", containsString("No jar was provided to define function foo."))));
zetaSQLQueryPlanner.convertToBeamRel(sql);
}
@@ -253,8 +257,12 @@ public void testJavaUdfEmptyPath() {
public void testJavaUdfNoJarProvided() {
String sql = "CREATE FUNCTION foo() RETURNS STRING LANGUAGE java; SELECT foo();";
ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config);
- thrown.expect(IllegalArgumentException.class);
- thrown.expectMessage("No jar was provided to define function foo.");
+ thrown.expect(RuntimeException.class);
+ thrown.expectMessage("Failed to define function 'foo'");
+ thrown.expectCause(
+ allOf(
+ isA(IllegalArgumentException.class),
+ hasProperty("message", containsString("No jar was provided to define function foo."))));
zetaSQLQueryPlanner.convertToBeamRel(sql);
}
@@ -263,8 +271,14 @@ public void testPathOptionNotString() {
String sql =
"CREATE FUNCTION foo() RETURNS STRING LANGUAGE java OPTIONS (path=23); SELECT foo();";
ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config);
- thrown.expect(IllegalArgumentException.class);
- thrown.expectMessage("Option 'path' has type TYPE_INT64 (expected TYPE_STRING).");
+ thrown.expect(RuntimeException.class);
+ thrown.expectMessage("Failed to define function 'foo'");
+ thrown.expectCause(
+ allOf(
+ isA(IllegalArgumentException.class),
+ hasProperty(
+ "message",
+ containsString("Option 'path' has type TYPE_INT64 (expected TYPE_STRING)."))));
zetaSQLQueryPlanner.convertToBeamRel(sql);
}
}
diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlUdfTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlNativeUdfTest.java
similarity index 96%
rename from sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlUdfTest.java
rename to sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlNativeUdfTest.java
index 53b309a10b95..0357242e7481 100644
--- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlUdfTest.java
+++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlNativeUdfTest.java
@@ -17,8 +17,9 @@
*/
package org.apache.beam.sdk.extensions.sql.zetasql;
+import static org.hamcrest.Matchers.isA;
+
import com.google.zetasql.SqlException;
-import org.apache.beam.sdk.extensions.sql.impl.ParseException;
import org.apache.beam.sdk.extensions.sql.impl.SqlConversionException;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSqlRelUtils;
@@ -36,9 +37,9 @@
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
-/** Tests for user defined functions in the ZetaSQL dialect. */
+/** Tests for SQL-native user defined functions in the ZetaSQL dialect. */
@RunWith(JUnit4.class)
-public class ZetaSqlUdfTest extends ZetaSqlTestBase {
+public class ZetaSqlNativeUdfTest extends ZetaSqlTestBase {
@Rule public transient TestPipeline pipeline = TestPipeline.create();
@Rule public ExpectedException thrown = ExpectedException.none();
@@ -51,8 +52,9 @@ public void setUp() {
public void testAlreadyDefinedUDFThrowsException() {
String sql = "CREATE FUNCTION foo() AS (0); CREATE FUNCTION foo() AS (1); SELECT foo();";
ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config);
- thrown.expect(ParseException.class);
- thrown.expectMessage("Failed to define function foo");
+ thrown.expect(RuntimeException.class);
+ thrown.expectMessage("Failed to define function 'foo'");
+ thrown.expectCause(isA(IllegalArgumentException.class));
zetaSQLQueryPlanner.convertToBeamRel(sql);
}