diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/ScalarFunctionImpl.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/ScalarFunctionImpl.java index da8cb269748d..df725a72a808 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/ScalarFunctionImpl.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/ScalarFunctionImpl.java @@ -61,8 +61,7 @@ public class ScalarFunctionImpl extends UdfImplReflectiveFunctionBase private final CallImplementor implementor; - /** Private constructor. */ - private ScalarFunctionImpl(Method method, CallImplementor implementor) { + protected ScalarFunctionImpl(Method method, CallImplementor implementor) { super(method); this.implementor = implementor; } @@ -86,24 +85,6 @@ public static ImmutableMultimap createAll(Class clazz) { return builder.build(); } - /** - * Creates {@link org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Function} from - * given class. - * - *

If a method of the given name is not found or it does not suit, returns {@code null}. - * - * @param clazz class that is used to implement the function - * @param methodName Method name (typically "eval") - * @return created {@link ScalarFunction} or null - */ - public static Function create(Class clazz, String methodName) { - final Method method = findMethod(clazz, methodName); - if (method == null) { - return null; - } - return create(method); - } - /** * Creates {@link org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Function} from * given method. When {@code eval} method does not suit, {@code null} is returned. @@ -112,6 +93,12 @@ public static Function create(Class clazz, String methodName) { * @return created {@link Function} or null */ public static Function create(Method method) { + validateMethod(method); + CallImplementor implementor = createImplementor(method); + return new ScalarFunctionImpl(method, implementor); + } + + protected static void validateMethod(Method method) { if (!Modifier.isStatic(method.getModifiers())) { Class clazz = method.getDeclaringClass(); if (!classHasPublicZeroArgsConstructor(clazz)) { @@ -121,9 +108,6 @@ public static Function create(Method method) { if (method.getExceptionTypes().length != 0) { throw new RuntimeException(method.getName() + " must not throw checked exception"); } - - CallImplementor implementor = createImplementor(method); - return new ScalarFunctionImpl(method, implementor); } @Override @@ -191,7 +175,7 @@ public Expression implement( } } - private static CallImplementor createImplementor(Method method) { + protected static CallImplementor createImplementor(Method method) { final NullPolicy nullPolicy = getNullPolicy(method); return RexImpTable.createImplementor( new ScalarReflectiveCallNotNullImplementor(method), nullPolicy, false); @@ -247,21 +231,6 @@ static boolean classHasPublicZeroArgsConstructor(Class clazz) { } return false; } - - /* - * Finds a method in a given class by name. - * @param clazz class to search method in - * @param name name of the method to find - * @return the first method with matching name or null when no method found - */ - static Method findMethod(Class clazz, String name) { - for (Method method : clazz.getMethods()) { - if (method.getName().equals(name) && !method.isBridge()) { - return method; - } - } - return null; - } } // End ScalarFunctionImpl.java diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java index b4666cd11405..f4db1f194a94 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java @@ -69,8 +69,14 @@ "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402) }) public class SqlAnalyzer { + // ZetaSQL function group identifiers. Different function groups may have divergent translation + // paths. public static final String PRE_DEFINED_WINDOW_FUNCTIONS = "pre_defined_window_functions"; public static final String USER_DEFINED_FUNCTIONS = "user_defined_functions"; + /** + * Same as {@link Function}.ZETASQL_FUNCTION_GROUP_NAME. Identifies built-in ZetaSQL functions. + */ + public static final String ZETASQL_FUNCTION_GROUP_NAME = "ZetaSQL"; private static final ImmutableSet SUPPORTED_STATEMENT_KINDS = ImmutableSet.of( diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java index 24a5e1c74c09..a4d0f0335bed 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java @@ -26,6 +26,7 @@ import static com.google.zetasql.ZetaSQLType.TypeKind.TYPE_TIMESTAMP; import static org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer.PRE_DEFINED_WINDOW_FUNCTIONS; import static org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer.USER_DEFINED_FUNCTIONS; +import static org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; import com.google.common.base.Ascii; @@ -608,7 +609,7 @@ private RexNode convertResolvedFunctionCall( SqlOperator op = SqlOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get(funName); List operands = new ArrayList<>(); - if (funGroup.equals(PRE_DEFINED_WINDOW_FUNCTIONS)) { + if (PRE_DEFINED_WINDOW_FUNCTIONS.equals(funGroup)) { switch (funName) { case FIXED_WINDOW: case SESSION_WINDOW: @@ -646,7 +647,7 @@ private RexNode convertResolvedFunctionCall( throw new UnsupportedOperationException( "Unsupported function: " + funName + ". Only support TUMBLE, HOP, and SESSION now."); } - } else if (funGroup.equals("ZetaSQL")) { + } else if (ZETASQL_FUNCTION_GROUP_NAME.equals(funGroup)) { if (op == null) { Type returnType = functionCall.getSignature().getResultType().getType(); if (returnType != null) { @@ -664,7 +665,7 @@ private RexNode convertResolvedFunctionCall( operands.add( convertRexNodeFromResolvedExpr(expr, columnList, fieldList, outerFunctionArguments)); } - } else if (funGroup.equals(USER_DEFINED_FUNCTIONS)) { + } else if (USER_DEFINED_FUNCTIONS.equals(funGroup)) { ResolvedCreateFunctionStmt createFunctionStmt = userFunctionDefinitions.sqlScalarFunctions.get(functionCall.getFunction().getNamePath()); ResolvedExpr functionExpression = createFunctionStmt.getFunctionExpression(); diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java index 1c0835c4d96d..592c8d25fa66 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java @@ -26,6 +26,7 @@ import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRelDataTypeSystem; import org.apache.beam.sdk.extensions.sql.impl.udaf.StringAgg; import org.apache.beam.sdk.extensions.sql.zetasql.DateTimeUtils; +import org.apache.beam.sdk.extensions.sql.zetasql.SqlAnalyzer; import org.apache.beam.sdk.extensions.sql.zetasql.translation.impl.BeamBuiltinMethods; import org.apache.beam.sdk.extensions.sql.zetasql.translation.impl.CastFunctionImpl; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.jdbc.JavaTypeFactoryImpl; @@ -82,36 +83,57 @@ public class SqlOperators { new UdafImpl<>(new StringAgg.StringAggString())); public static final SqlOperator START_WITHS = - createUdfOperator("STARTS_WITH", BeamBuiltinMethods.STARTS_WITH_METHOD); + createUdfOperator( + "STARTS_WITH", + BeamBuiltinMethods.STARTS_WITH_METHOD, + SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator CONCAT = - createUdfOperator("CONCAT", BeamBuiltinMethods.CONCAT_METHOD); + createUdfOperator( + "CONCAT", BeamBuiltinMethods.CONCAT_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator REPLACE = - createUdfOperator("REPLACE", BeamBuiltinMethods.REPLACE_METHOD); + createUdfOperator( + "REPLACE", BeamBuiltinMethods.REPLACE_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); - public static final SqlOperator TRIM = createUdfOperator("TRIM", BeamBuiltinMethods.TRIM_METHOD); + public static final SqlOperator TRIM = + createUdfOperator( + "TRIM", BeamBuiltinMethods.TRIM_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator LTRIM = - createUdfOperator("LTRIM", BeamBuiltinMethods.LTRIM_METHOD); + createUdfOperator( + "LTRIM", BeamBuiltinMethods.LTRIM_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator RTRIM = - createUdfOperator("RTRIM", BeamBuiltinMethods.RTRIM_METHOD); + createUdfOperator( + "RTRIM", BeamBuiltinMethods.RTRIM_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator SUBSTR = - createUdfOperator("SUBSTR", BeamBuiltinMethods.SUBSTR_METHOD); + createUdfOperator( + "SUBSTR", BeamBuiltinMethods.SUBSTR_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator REVERSE = - createUdfOperator("REVERSE", BeamBuiltinMethods.REVERSE_METHOD); + createUdfOperator( + "REVERSE", BeamBuiltinMethods.REVERSE_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator CHAR_LENGTH = - createUdfOperator("CHAR_LENGTH", BeamBuiltinMethods.CHAR_LENGTH_METHOD); + createUdfOperator( + "CHAR_LENGTH", + BeamBuiltinMethods.CHAR_LENGTH_METHOD, + SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator ENDS_WITH = - createUdfOperator("ENDS_WITH", BeamBuiltinMethods.ENDS_WITH_METHOD); + createUdfOperator( + "ENDS_WITH", + BeamBuiltinMethods.ENDS_WITH_METHOD, + SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator LIKE = - createUdfOperator("LIKE", BeamBuiltinMethods.LIKE_METHOD, SqlSyntax.BINARY); + createUdfOperator( + "LIKE", + BeamBuiltinMethods.LIKE_METHOD, + SqlSyntax.BINARY, + SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator VALIDATE_TIMESTAMP = createUdfOperator( @@ -119,7 +141,8 @@ public class SqlOperators { DateTimeUtils.class, "validateTimestamp", x -> NULLABLE_TIMESTAMP, - ImmutableList.of(TIMESTAMP)); + ImmutableList.of(TIMESTAMP), + SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator VALIDATE_TIME_INTERVAL = createUdfOperator( @@ -127,13 +150,18 @@ public class SqlOperators { DateTimeUtils.class, "validateTimeInterval", x -> NULLABLE_BIGINT, - ImmutableList.of(BIGINT, OTHER)); + ImmutableList.of(BIGINT, OTHER), + SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator TIMESTAMP_OP = - createUdfOperator("TIMESTAMP", BeamBuiltinMethods.TIMESTAMP_METHOD); + createUdfOperator( + "TIMESTAMP", + BeamBuiltinMethods.TIMESTAMP_METHOD, + SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); public static final SqlOperator DATE_OP = - createUdfOperator("DATE", BeamBuiltinMethods.DATE_METHOD); + createUdfOperator( + "DATE", BeamBuiltinMethods.DATE_METHOD, SqlAnalyzer.ZETASQL_FUNCTION_GROUP_NAME); public static final SqlUserDefinedFunction CAST_OP = new SqlUserDefinedFunction( @@ -158,7 +186,7 @@ public static SqlFunction createZetaSqlFunction(String name, SqlTypeName returnT SqlFunctionCategory.USER_DEFINED_FUNCTION); } - private static SqlUserDefinedAggFunction createUdafOperator( + static SqlUserDefinedAggFunction createUdafOperator( String name, SqlReturnTypeInference returnTypeInference, AggregateFunction function) { return new SqlUserDefinedAggFunction( new SqlIdentifier(name, SqlParserPos.ZERO), @@ -177,26 +205,24 @@ private static SqlUserDefinedFunction createUdfOperator( Class methodClass, String methodName, SqlReturnTypeInference returnTypeInference, - List paramTypes) { + List paramTypes, + String funGroup) { return new SqlUserDefinedFunction( new SqlIdentifier(name, SqlParserPos.ZERO), returnTypeInference, null, null, paramTypes, - ScalarFunctionImpl.create(methodClass, methodName)); + ZetaSqlScalarFunctionImpl.create(methodClass, methodName, funGroup)); } - // Helper function to create SqlUserDefinedFunction based on a function name and a method. - // SqlUserDefinedFunction will be able to pass through Calcite codegen and get proper function - // called. - private static SqlUserDefinedFunction createUdfOperator(String name, Method method) { - return createUdfOperator(name, method, SqlSyntax.FUNCTION); + static SqlUserDefinedFunction createUdfOperator(String name, Method method, String funGroup) { + return createUdfOperator(name, method, SqlSyntax.FUNCTION, funGroup); } private static SqlUserDefinedFunction createUdfOperator( - String name, Method method, final SqlSyntax syntax) { - Function function = ScalarFunctionImpl.create(method); + String name, Method method, final SqlSyntax syntax, String funGroup) { + Function function = ZetaSqlScalarFunctionImpl.create(method, funGroup); final RelDataTypeFactory typeFactory = createTypeFactory(); List argTypes = new ArrayList<>(); diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ZetaSqlScalarFunctionImpl.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ZetaSqlScalarFunctionImpl.java new file mode 100644 index 000000000000..5255eec75e59 --- /dev/null +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ZetaSqlScalarFunctionImpl.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.sql.zetasql.translation; + +import java.lang.reflect.Method; +import org.apache.beam.sdk.extensions.sql.impl.ScalarFunctionImpl; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.adapter.enumerable.CallImplementor; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Function; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.ScalarFunction; + +/** ZetaSQL-specific extension to {@link ScalarFunctionImpl}. */ +public class ZetaSqlScalarFunctionImpl extends ScalarFunctionImpl { + /** + * ZetaSQL function group identifier. Different function groups may have divergent translation + * paths. + */ + public final String functionGroup; + + private ZetaSqlScalarFunctionImpl( + Method method, CallImplementor implementor, String functionGroup) { + super(method, implementor); + this.functionGroup = functionGroup; + } + + /** + * Creates {@link org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Function} from + * given class. + * + *

If a method of the given name is not found or it does not suit, returns {@code null}. + * + * @param clazz class that is used to implement the function + * @param methodName Method name (typically "eval") + * @param functionGroup ZetaSQL function group identifier. Different function groups may have + * divergent translation paths. + * @return created {@link ScalarFunction} or null + */ + public static Function create(Class clazz, String methodName, String functionGroup) { + return create(findMethod(clazz, methodName)); + } + + /** + * Creates {@link org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Function} from + * given method. When {@code eval} method does not suit, {@code null} is returned. + * + * @param method method that is used to implement the function + * @param functionGroup ZetaSQL function group identifier. Different function groups may have + * divergent translation paths. + * @return created {@link Function} or null + */ + public static Function create(Method method, String functionGroup) { + validateMethod(method); + CallImplementor implementor = createImplementor(method); + return new ZetaSqlScalarFunctionImpl(method, implementor, functionGroup); + } + + /* + * Finds a method in a given class by name. + * @param clazz class to search method in + * @param name name of the method to find + * @return the first method with matching name or null when no method found + */ + private static Method findMethod(Class clazz, String name) { + for (Method method : clazz.getMethods()) { + if (method.getName().equals(name) && !method.isBridge()) { + return method; + } + } + throw new NoSuchMethodError( + String.format("Method %s not found in class %s.", name, clazz.getName())); + } +}