Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@
<suppress id="ForbidNonVendoredGuava" files=".*sql.*BeamEnumerableConverterTest\.java" />
<suppress id="ForbidNonVendoredGuava" files=".*zetasql.*TableScanConverter\.java" />
<suppress id="ForbidNonVendoredGuava" files=".*zetasql.*ExpressionConverter\.java" />
<suppress id="ForbidNonVendoredGuava" files=".*zetasql.*ZetaSQLPlannerImpl\.java" />
<suppress id="ForbidNonVendoredGuava" files=".*zetasql.*SqlAnalyzer\.java" />
<suppress id="ForbidNonVendoredGuava" files=".*zetasql.*BeamZetaSqlCatalog\.java" />
<suppress id="ForbidNonVendoredGuava" files=".*pubsublite.*BufferingPullSubscriberTest\.java" />

<!-- gRPC/protobuf exceptions -->
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,14 @@
*/
package org.apache.beam.sdk.extensions.sql.zetasql;

import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_CREATE_FUNCTION_STMT;
import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_CREATE_TABLE_FUNCTION_STMT;
import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_QUERY_STMT;

import com.google.zetasql.AnalyzerOptions;
import com.google.zetasql.LanguageOptions;
import com.google.zetasql.ParseResumeLocation;
import com.google.zetasql.SimpleCatalog;
import com.google.zetasql.ZetaSQLType;
import com.google.zetasql.resolvedast.ResolvedNode;
import com.google.zetasql.resolvedast.ResolvedNodes;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedCreateFunctionStmt;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedCreateTableFunctionStmt;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedQueryStmt;
import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedStatement;
import java.util.List;
import org.apache.beam.sdk.extensions.sql.impl.JavaUdfLoader;
import org.apache.beam.sdk.extensions.sql.impl.QueryPlanner.QueryParameters;
import org.apache.beam.sdk.extensions.sql.udf.ScalarFn;
import org.apache.beam.sdk.extensions.sql.zetasql.translation.ConversionContext;
import org.apache.beam.sdk.extensions.sql.zetasql.translation.ExpressionConverter;
import org.apache.beam.sdk.extensions.sql.zetasql.translation.QueryStatementConverter;
import org.apache.beam.sdk.extensions.sql.zetasql.translation.UserFunctionDefinitions;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.adapter.java.JavaTypeFactory;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptCluster;
Expand All @@ -55,7 +40,6 @@
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.Frameworks;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.Program;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.util.Util;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;

/** ZetaSQLPlannerImpl. */
@SuppressWarnings({
Expand Down Expand Up @@ -93,79 +77,25 @@ class ZetaSQLPlannerImpl {

public RelRoot rel(String sql, QueryParameters params) {
RelOptCluster cluster = RelOptCluster.create(planner, new RexBuilder(typeFactory));
QueryTrait trait = new QueryTrait();
SqlAnalyzer analyzer =
new SqlAnalyzer(trait, defaultSchemaPlus, (JavaTypeFactory) cluster.getTypeFactory());

AnalyzerOptions options = SqlAnalyzer.getAnalyzerOptions(params, defaultTimezone);
BeamZetaSqlCatalog catalog =
BeamZetaSqlCatalog.create(
defaultSchemaPlus, (JavaTypeFactory) cluster.getTypeFactory(), options);

// Set up table providers that need to be pre-registered
SqlAnalyzer analyzer = new SqlAnalyzer();
List<List<String>> tables = analyzer.extractTableNames(sql, options);
TableResolution.registerTables(this.defaultSchemaPlus, tables);
SimpleCatalog catalog =
analyzer.createPopulatedCatalog(defaultSchemaPlus.getName(), options, tables);

ImmutableMap.Builder<List<String>, ResolvedCreateFunctionStmt> udfBuilder =
ImmutableMap.builder();
ImmutableMap.Builder<List<String>, ResolvedNode> udtvfBuilder = ImmutableMap.builder();
ImmutableMap.Builder<List<String>, UserFunctionDefinitions.JavaScalarFunction>
javaScalarFunctionBuilder = ImmutableMap.builder();
JavaUdfLoader javaUdfLoader = new JavaUdfLoader();

ResolvedStatement statement;
ParseResumeLocation parseResumeLocation = new ParseResumeLocation(sql);
do {
statement = analyzer.analyzeNextStatement(parseResumeLocation, options, catalog);
if (statement.nodeKind() == RESOLVED_CREATE_FUNCTION_STMT) {
ResolvedCreateFunctionStmt createFunctionStmt = (ResolvedCreateFunctionStmt) statement;
String functionGroup = SqlAnalyzer.getFunctionGroup(createFunctionStmt);
switch (functionGroup) {
case SqlAnalyzer.USER_DEFINED_SQL_FUNCTIONS:
udfBuilder.put(createFunctionStmt.getNamePath(), createFunctionStmt);
break;
case SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS:
String jarPath = getJarPath(createFunctionStmt);
ScalarFn scalarFn =
javaUdfLoader.loadScalarFunction(createFunctionStmt.getNamePath(), jarPath);
javaScalarFunctionBuilder.put(
createFunctionStmt.getNamePath(),
UserFunctionDefinitions.JavaScalarFunction.create(scalarFn, jarPath));
break;
default:
throw new IllegalArgumentException(
String.format("Encountered unrecognized function group %s.", functionGroup));
}
} else if (statement.nodeKind() == RESOLVED_CREATE_TABLE_FUNCTION_STMT) {
ResolvedCreateTableFunctionStmt createTableFunctionStmt =
(ResolvedCreateTableFunctionStmt) statement;
udtvfBuilder.put(createTableFunctionStmt.getNamePath(), createTableFunctionStmt.getQuery());
} else if (statement.nodeKind() == RESOLVED_QUERY_STMT) {
if (!SqlAnalyzer.isEndOfInput(parseResumeLocation)) {
throw new UnsupportedOperationException(
"No additional statements are allowed after a SELECT statement.");
}
break;
}
} while (!SqlAnalyzer.isEndOfInput(parseResumeLocation));

if (!(statement instanceof ResolvedQueryStmt)) {
throw new UnsupportedOperationException(
"Statement list must end in a SELECT statement, not " + statement.nodeKindString());
}

UserFunctionDefinitions userFunctionDefinitions =
UserFunctionDefinitions.newBuilder()
.setSqlScalarFunctions(udfBuilder.build())
.setSqlTableValuedFunctions(udtvfBuilder.build())
.setJavaScalarFunctions(javaScalarFunctionBuilder.build())
.build();
QueryTrait trait = new QueryTrait();
catalog.addTables(tables, trait);

ResolvedQueryStmt statement = analyzer.analyzeQuery(sql, options, catalog);

ExpressionConverter expressionConverter =
new ExpressionConverter(cluster, params, userFunctionDefinitions);
new ExpressionConverter(cluster, params, catalog.getUserFunctionDefinitions());
ConversionContext context = ConversionContext.of(config, expressionConverter, cluster, trait);

RelNode convertedNode =
QueryStatementConverter.convertRootQuery(context, (ResolvedQueryStmt) statement);
RelNode convertedNode = QueryStatementConverter.convertRootQuery(context, statement);
return RelRoot.of(convertedNode, SqlKind.ALL);
}

Expand All @@ -185,39 +115,4 @@ void setDefaultTimezone(String timezone) {
static LanguageOptions getLanguageOptions() {
return SqlAnalyzer.baseAnalyzerOptions().getLanguageOptions();
}

private static String getJarPath(ResolvedCreateFunctionStmt createFunctionStmt) {
String jarPath = getOptionStringValue(createFunctionStmt, "path");
if (jarPath.isEmpty()) {
throw new IllegalArgumentException(
String.format(
"No jar was provided to define function %s. Add 'OPTIONS (path=<jar location>)' to the CREATE FUNCTION statement.",
String.join(".", createFunctionStmt.getNamePath())));
}
return jarPath;
}

private static String getOptionStringValue(
ResolvedCreateFunctionStmt createFunctionStmt, String optionName) {
for (ResolvedNodes.ResolvedOption option : createFunctionStmt.getOptionList()) {
if (optionName.equals(option.getName())) {
if (option.getValue() == null) {
throw new IllegalArgumentException(
String.format(
"Option '%s' has null value (expected %s).",
optionName, ZetaSQLType.TypeKind.TYPE_STRING));
}
if (option.getValue().getType().getKind() != ZetaSQLType.TypeKind.TYPE_STRING) {
throw new IllegalArgumentException(
String.format(
"Option '%s' has type %s (expected %s).",
optionName,
option.getValue().getType().getKind(),
ZetaSQLType.TypeKind.TYPE_STRING));
}
return ((ResolvedNodes.ResolvedLiteral) option.getValue()).getValue().getStringValue();
}
}
return "";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ static boolean hasOnlyJavaUdfInProjects(RelOptRuleCall x) {
if (udf.function instanceof ZetaSqlScalarFunctionImpl) {
ZetaSqlScalarFunctionImpl scalarFunction = (ZetaSqlScalarFunctionImpl) udf.function;
if (!scalarFunction.functionGroup.equals(
SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS)) {
BeamZetaSqlCatalog.USER_DEFINED_JAVA_SCALAR_FUNCTIONS)) {
// Reject ZetaSQL Builtin Scalar Functions
return false;
}
Expand Down Expand Up @@ -224,7 +224,7 @@ static boolean hasNoJavaUdfInProjects(RelOptRuleCall x) {
if (udf.function instanceof ZetaSqlScalarFunctionImpl) {
ZetaSqlScalarFunctionImpl scalarFunction = (ZetaSqlScalarFunctionImpl) udf.function;
if (scalarFunction.functionGroup.equals(
SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS)) {
BeamZetaSqlCatalog.USER_DEFINED_JAVA_SCALAR_FUNCTIONS)) {
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
import static com.google.zetasql.ZetaSQLType.TypeKind.TYPE_INT64;
import static com.google.zetasql.ZetaSQLType.TypeKind.TYPE_STRING;
import static com.google.zetasql.ZetaSQLType.TypeKind.TYPE_TIMESTAMP;
import static org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer.PRE_DEFINED_WINDOW_FUNCTIONS;
import static org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer.USER_DEFINED_JAVA_SCALAR_FUNCTIONS;
import static org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer.USER_DEFINED_SQL_FUNCTIONS;
import static org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME;
import static org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog.PRE_DEFINED_WINDOW_FUNCTIONS;
import static org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog.USER_DEFINED_JAVA_SCALAR_FUNCTIONS;
import static org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog.USER_DEFINED_SQL_FUNCTIONS;
import static org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog.ZETASQL_FUNCTION_GROUP_NAME;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;

import com.google.common.base.Ascii;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package org.apache.beam.sdk.extensions.sql.zetasql.translation;

import static org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog.ZETASQL_FUNCTION_GROUP_NAME;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
Expand All @@ -28,7 +30,6 @@
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.CountIf;
import org.apache.beam.sdk.extensions.sql.impl.udaf.StringAgg;
import org.apache.beam.sdk.extensions.sql.zetasql.DateTimeUtils;
import org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer;
import org.apache.beam.sdk.extensions.sql.zetasql.translation.impl.BeamBuiltinMethods;
import org.apache.beam.sdk.extensions.sql.zetasql.translation.impl.CastFunctionImpl;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.jdbc.JavaTypeFactoryImpl;
Expand Down Expand Up @@ -86,56 +87,43 @@ public class SqlOperators {

public static final SqlOperator START_WITHS =
createUdfOperator(
"STARTS_WITH",
BeamBuiltinMethods.STARTS_WITH_METHOD,
SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
"STARTS_WITH", BeamBuiltinMethods.STARTS_WITH_METHOD, ZETASQL_FUNCTION_GROUP_NAME);

public static final SqlOperator CONCAT =
createUdfOperator(
"CONCAT", BeamBuiltinMethods.CONCAT_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
createUdfOperator("CONCAT", BeamBuiltinMethods.CONCAT_METHOD, ZETASQL_FUNCTION_GROUP_NAME);

public static final SqlOperator REPLACE =
createUdfOperator(
"REPLACE", BeamBuiltinMethods.REPLACE_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
createUdfOperator("REPLACE", BeamBuiltinMethods.REPLACE_METHOD, ZETASQL_FUNCTION_GROUP_NAME);

public static final SqlOperator TRIM =
createUdfOperator(
"TRIM", BeamBuiltinMethods.TRIM_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
createUdfOperator("TRIM", BeamBuiltinMethods.TRIM_METHOD, ZETASQL_FUNCTION_GROUP_NAME);

public static final SqlOperator LTRIM =
createUdfOperator(
"LTRIM", BeamBuiltinMethods.LTRIM_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
createUdfOperator("LTRIM", BeamBuiltinMethods.LTRIM_METHOD, ZETASQL_FUNCTION_GROUP_NAME);

public static final SqlOperator RTRIM =
createUdfOperator(
"RTRIM", BeamBuiltinMethods.RTRIM_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
createUdfOperator("RTRIM", BeamBuiltinMethods.RTRIM_METHOD, ZETASQL_FUNCTION_GROUP_NAME);

public static final SqlOperator SUBSTR =
createUdfOperator(
"SUBSTR", BeamBuiltinMethods.SUBSTR_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
createUdfOperator("SUBSTR", BeamBuiltinMethods.SUBSTR_METHOD, ZETASQL_FUNCTION_GROUP_NAME);

public static final SqlOperator REVERSE =
createUdfOperator(
"REVERSE", BeamBuiltinMethods.REVERSE_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
createUdfOperator("REVERSE", BeamBuiltinMethods.REVERSE_METHOD, ZETASQL_FUNCTION_GROUP_NAME);

public static final SqlOperator CHAR_LENGTH =
createUdfOperator(
"CHAR_LENGTH",
BeamBuiltinMethods.CHAR_LENGTH_METHOD,
SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
"CHAR_LENGTH", BeamBuiltinMethods.CHAR_LENGTH_METHOD, ZETASQL_FUNCTION_GROUP_NAME);

public static final SqlOperator ENDS_WITH =
createUdfOperator(
"ENDS_WITH",
BeamBuiltinMethods.ENDS_WITH_METHOD,
SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
"ENDS_WITH", BeamBuiltinMethods.ENDS_WITH_METHOD, ZETASQL_FUNCTION_GROUP_NAME);

public static final SqlOperator LIKE =
createUdfOperator(
"LIKE",
BeamBuiltinMethods.LIKE_METHOD,
SqlSyntax.BINARY,
SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME,
ZETASQL_FUNCTION_GROUP_NAME,
"");

public static final SqlOperator VALIDATE_TIMESTAMP =
Expand All @@ -145,7 +133,7 @@ public class SqlOperators {
"validateTimestamp",
x -> NULLABLE_TIMESTAMP,
ImmutableList.of(TIMESTAMP),
SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
ZETASQL_FUNCTION_GROUP_NAME);

public static final SqlOperator VALIDATE_TIME_INTERVAL =
createUdfOperator(
Expand All @@ -154,17 +142,14 @@ public class SqlOperators {
"validateTimeInterval",
x -> NULLABLE_BIGINT,
ImmutableList.of(BIGINT, OTHER),
SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
ZETASQL_FUNCTION_GROUP_NAME);

public static final SqlOperator TIMESTAMP_OP =
createUdfOperator(
"TIMESTAMP",
BeamBuiltinMethods.TIMESTAMP_METHOD,
SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
"TIMESTAMP", BeamBuiltinMethods.TIMESTAMP_METHOD, ZETASQL_FUNCTION_GROUP_NAME);

public static final SqlOperator DATE_OP =
createUdfOperator(
"DATE", BeamBuiltinMethods.DATE_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME);
createUdfOperator("DATE", BeamBuiltinMethods.DATE_METHOD, ZETASQL_FUNCTION_GROUP_NAME);

public static final SqlOperator BIT_XOR =
createUdafOperator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,17 +244,25 @@ public void testJavaUdfEmptyPath() {
String sql =
"CREATE FUNCTION foo() RETURNS STRING LANGUAGE java OPTIONS (path=''); SELECT foo();";
ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("No jar was provided to define function foo.");
thrown.expect(RuntimeException.class);
thrown.expectMessage("Failed to define function 'foo'");
thrown.expectCause(
allOf(
isA(IllegalArgumentException.class),
hasProperty("message", containsString("No jar was provided to define function foo."))));
zetaSQLQueryPlanner.convertToBeamRel(sql);
}

@Test
public void testJavaUdfNoJarProvided() {
String sql = "CREATE FUNCTION foo() RETURNS STRING LANGUAGE java; SELECT foo();";
ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("No jar was provided to define function foo.");
thrown.expect(RuntimeException.class);
thrown.expectMessage("Failed to define function 'foo'");
thrown.expectCause(
allOf(
isA(IllegalArgumentException.class),
hasProperty("message", containsString("No jar was provided to define function foo."))));
zetaSQLQueryPlanner.convertToBeamRel(sql);
}

Expand All @@ -263,8 +271,14 @@ public void testPathOptionNotString() {
String sql =
"CREATE FUNCTION foo() RETURNS STRING LANGUAGE java OPTIONS (path=23); SELECT foo();";
ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("Option 'path' has type TYPE_INT64 (expected TYPE_STRING).");
thrown.expect(RuntimeException.class);
thrown.expectMessage("Failed to define function 'foo'");
thrown.expectCause(
allOf(
isA(IllegalArgumentException.class),
hasProperty(
"message",
containsString("Option 'path' has type TYPE_INT64 (expected TYPE_STRING)."))));
zetaSQLQueryPlanner.convertToBeamRel(sql);
}
}
Loading