From eb0d35d2b9512b4b7d4c50de31fa54291d0d030f Mon Sep 17 00:00:00 2001 From: Benjamin Gonzalez Date: Fri, 10 Dec 2021 15:23:26 -0600 Subject: [PATCH 01/14] [BEAM-11808] Enable two params in aggregate functions, add string_agg with delimiter --- .../extensions/sql/impl/udaf/StringAgg.java | 63 +++++++- .../SupportedZetaSqlBuiltinFunctions.java | 6 +- .../translation/AggregateScanConverter.java | 47 +++--- .../translation/ExpressionConverter.java | 2 +- .../SqlNullIfOperatorRewriter.java | 4 +- .../translation/SqlOperatorMappingTable.java | 148 ++++++++++-------- .../sql/zetasql/translation/SqlOperators.java | 46 +++++- .../sql/zetasql/ZetaSqlDialectSpecTest.java | 48 ++++++ 8 files changed, 262 insertions(+), 102 deletions(-) diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/udaf/StringAgg.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/udaf/StringAgg.java index f687551f3f3a..eb1c7736de44 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/udaf/StringAgg.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/udaf/StringAgg.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.extensions.sql.impl.udaf; +import java.nio.charset.StandardCharsets; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.transforms.Combine.CombineFn; @@ -28,10 +29,15 @@ @Experimental public class StringAgg { - /** A {@link CombineFn} that aggregates strings with comma as delimiter. */ + /** A {@link CombineFn} that aggregates strings with a string as delimiter. */ public static class StringAggString extends CombineFn { + private String delimiter = ","; - private static final String delimiter = ","; + public StringAggString() {} + + public StringAggString(String delimiter) { + this.delimiter = delimiter; + } @Override public String createAccumulator() { @@ -43,7 +49,7 @@ public String addInput(String curString, String nextString) { if (!nextString.isEmpty()) { if (!curString.isEmpty()) { - curString += StringAggString.delimiter + nextString; + curString += delimiter + nextString; } else { curString = nextString; } @@ -58,7 +64,7 @@ public String mergeAccumulators(Iterable accumList) { for (String stringAccum : accumList) { if (!stringAccum.isEmpty()) { if (!mergeString.isEmpty()) { - mergeString += StringAggString.delimiter + stringAccum; + mergeString += delimiter + stringAccum; } else { mergeString = stringAccum; } @@ -73,4 +79,53 @@ public String extractOutput(String output) { return output; } } + + /** A {@link CombineFn} that aggregates bytes with a byte as delimiter. */ + public static class StringAggByte extends CombineFn { + private String delimiter = ","; + + public StringAggByte() {} + + public StringAggByte(byte[] delimiter) { + this.delimiter = new String(delimiter, StandardCharsets.UTF_8); + } + + @Override + public String createAccumulator() { + return ""; + } + + @Override + public String addInput(String mutableAccumulator, byte[] input) { + if (input != null) { + if (!mutableAccumulator.isEmpty()) { + mutableAccumulator += delimiter + new String(input, StandardCharsets.UTF_8); + } else { + mutableAccumulator = new String(input, StandardCharsets.UTF_8); + } + } + return mutableAccumulator; + } + + @Override + public String mergeAccumulators(Iterable accumList) { + String mergeString = ""; + for (String stringAccum : accumList) { + if (!stringAccum.isEmpty()) { + if (!mergeString.isEmpty()) { + mergeString += delimiter + stringAccum; + } else { + mergeString = stringAccum; + } + } + } + + return mergeString; + } + + @Override + public byte[] extractOutput(String output) { + return output.getBytes(StandardCharsets.UTF_8); + } + } } diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SupportedZetaSqlBuiltinFunctions.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SupportedZetaSqlBuiltinFunctions.java index 65dee35bd435..601aee5667c2 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SupportedZetaSqlBuiltinFunctions.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SupportedZetaSqlBuiltinFunctions.java @@ -405,9 +405,9 @@ class SupportedZetaSqlBuiltinFunctions { FunctionSignatureId.FN_MAX, // max FunctionSignatureId.FN_MIN, // min FunctionSignatureId.FN_STRING_AGG_STRING, // string_agg(s) - // FunctionSignatureId.FN_STRING_AGG_DELIM_STRING, // string_agg(s, delim_s) - // FunctionSignatureId.FN_STRING_AGG_BYTES, // string_agg(b) - // FunctionSignatureId.FN_STRING_AGG_DELIM_BYTES, // string_agg(b, delim_b) + FunctionSignatureId.FN_STRING_AGG_DELIM_STRING, // string_agg(s, delim_s) + FunctionSignatureId.FN_STRING_AGG_BYTES, // string_agg(b) + FunctionSignatureId.FN_STRING_AGG_DELIM_BYTES, // string_agg(b, delim_b) FunctionSignatureId.FN_SUM_INT64, // sum FunctionSignatureId.FN_SUM_DOUBLE, // sum FunctionSignatureId.FN_SUM_NUMERIC, // sum diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java index 912a7b0abe9f..3c001cb49427 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java @@ -20,6 +20,7 @@ import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_CAST; import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_COLUMN_REF; import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_GET_STRUCT_FIELD; +import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_LITERAL; import com.google.zetasql.FunctionSignature; import com.google.zetasql.ZetaSQLType.TypeKind; @@ -148,24 +149,27 @@ private LogicalProject convertAggregateScanInputScanToLogicalProject( // aggregation? ResolvedAggregateFunctionCall aggregateFunctionCall = ((ResolvedAggregateFunctionCall) resolvedComputedColumn.getExpr()); - if (aggregateFunctionCall.getArgumentList() != null - && aggregateFunctionCall.getArgumentList().size() == 1) { - ResolvedExpr resolvedExpr = aggregateFunctionCall.getArgumentList().get(0); - - // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef). - // TODO: user might use multiple CAST so we need to handle this rare case. - projects.add( - getExpressionConverter() - .convertRexNodeFromResolvedExpr( - resolvedExpr, - node.getInputScan().getColumnList(), - input.getRowType().getFieldList(), - ImmutableMap.of())); - fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn())); - } else if (aggregateFunctionCall.getArgumentList() != null - && aggregateFunctionCall.getArgumentList().size() > 1) { - throw new IllegalArgumentException( - aggregateFunctionCall.getFunction().getName() + " has more than one argument."); + ImmutableList argumentList = + ImmutableList.copyOf(aggregateFunctionCall.getArgumentList()); + if (argumentList != null && argumentList.size() >= 1) { + ResolvedExpr resolvedExpr = argumentList.get(0); + for (int i = 0; i < argumentList.size(); i++) { + if (i == 0) { + // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef). + // TODO: user might use multiple CAST so we need to handle this rare case. + projects.add( + getExpressionConverter() + .convertRexNodeFromResolvedExpr( + resolvedExpr, + node.getInputScan().getColumnList(), + input.getRowType().getFieldList(), + ImmutableMap.of())); + } else { + projects.add( + getExpressionConverter().convertRexNodeFromResolvedExpr(argumentList.get(i))); + } + fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn())); + } } } @@ -228,10 +232,7 @@ private AggregateCall convertAggCall( aggregateFunctionCall.getFunction().getName(), typeInference, impl); } else { // Look up builtin functions in SqlOperatorMappingTable. - sqlAggFunction = - (SqlAggFunction) - SqlOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get( - aggregateFunctionCall.getFunction().getName()); + sqlAggFunction = (SqlAggFunction) SqlOperatorMappingTable.create(aggregateFunctionCall); if (sqlAggFunction == null) { throw new UnsupportedOperationException( "Does not support ZetaSQL aggregate function: " @@ -248,6 +249,8 @@ private AggregateCall convertAggCall( || expr.nodeKind() == RESOLVED_COLUMN_REF || expr.nodeKind() == RESOLVED_GET_STRUCT_FIELD) { argList.add(columnRefOff); + } else if (expr.nodeKind() == RESOLVED_LITERAL) { + continue; } else { throw new UnsupportedOperationException( "Aggregate function only accepts Column Reference or CAST(Column Reference) as its" diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java index 17e4766b815f..94d830ff2ad3 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java @@ -609,7 +609,7 @@ private RexNode convertResolvedFunctionCall( Map outerFunctionArguments) { final String funGroup = functionCall.getFunction().getGroup(); final String funName = functionCall.getFunction().getName(); - SqlOperator op = SqlOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get(funName); + SqlOperator op = SqlOperatorMappingTable.create(functionCall); List operands = new ArrayList<>(); if (PRE_DEFINED_WINDOW_FUNCTIONS.equals(funGroup)) { diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlNullIfOperatorRewriter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlNullIfOperatorRewriter.java index 4209bba0aaf5..17bc92d6114e 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlNullIfOperatorRewriter.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlNullIfOperatorRewriter.java @@ -43,7 +43,9 @@ public RexNode apply(RexBuilder rexBuilder, List operands) { operands.size() == 2, "NULLIF should have two arguments in function call."); SqlOperator op = - SqlOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get("$case_no_value"); + SqlOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR + .get("$case_no_value") + .apply(null); List newOperands = ImmutableList.of( rexBuilder.makeCall( diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java index 952d16252091..76f564e3239c 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java @@ -17,7 +17,9 @@ */ package org.apache.beam.sdk.extensions.sql.zetasql.translation; +import com.google.zetasql.resolvedast.ResolvedNodes; import java.util.Map; +import java.util.function.Function; import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlOperator; import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; @@ -28,74 +30,79 @@ class SqlOperatorMappingTable { // todo: Some of operators defined here are later overridden in ZetaSQLPlannerImpl. // We should remove them from this table and add generic way to provide custom // implementation. (Ex.: timestamp_add) - static final Map ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR = - ImmutableMap.builder() - // grouped window function - .put("TUMBLE", SqlStdOperatorTable.TUMBLE_OLD) - .put("HOP", SqlStdOperatorTable.HOP_OLD) - .put("SESSION", SqlStdOperatorTable.SESSION_OLD) + static final Map> + ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR = + ImmutableMap + .>builder() + // grouped window function + .put("TUMBLE", resolvedFunction -> SqlStdOperatorTable.TUMBLE_OLD) + .put("HOP", resolvedFunction -> SqlStdOperatorTable.HOP_OLD) + .put("SESSION", resolvedFunction -> SqlStdOperatorTable.SESSION_OLD) - // ZetaSQL functions - .put("$and", SqlStdOperatorTable.AND) - .put("$or", SqlStdOperatorTable.OR) - .put("$not", SqlStdOperatorTable.NOT) - .put("$equal", SqlStdOperatorTable.EQUALS) - .put("$not_equal", SqlStdOperatorTable.NOT_EQUALS) - .put("$greater", SqlStdOperatorTable.GREATER_THAN) - .put("$greater_or_equal", SqlStdOperatorTable.GREATER_THAN_OR_EQUAL) - .put("$less", SqlStdOperatorTable.LESS_THAN) - .put("$less_or_equal", SqlStdOperatorTable.LESS_THAN_OR_EQUAL) - .put("$like", SqlOperators.LIKE) - .put("$is_null", SqlStdOperatorTable.IS_NULL) - .put("$is_true", SqlStdOperatorTable.IS_TRUE) - .put("$is_false", SqlStdOperatorTable.IS_FALSE) - .put("$add", SqlStdOperatorTable.PLUS) - .put("$subtract", SqlStdOperatorTable.MINUS) - .put("$multiply", SqlStdOperatorTable.MULTIPLY) - .put("$unary_minus", SqlStdOperatorTable.UNARY_MINUS) - .put("$divide", SqlStdOperatorTable.DIVIDE) - .put("concat", SqlOperators.CONCAT) - .put("substr", SqlOperators.SUBSTR) - .put("substring", SqlOperators.SUBSTR) - .put("trim", SqlOperators.TRIM) - .put("replace", SqlOperators.REPLACE) - .put("char_length", SqlOperators.CHAR_LENGTH) - .put("starts_with", SqlOperators.START_WITHS) - .put("ends_with", SqlOperators.ENDS_WITH) - .put("ltrim", SqlOperators.LTRIM) - .put("rtrim", SqlOperators.RTRIM) - .put("reverse", SqlOperators.REVERSE) - .put("$count_star", SqlStdOperatorTable.COUNT) - .put("max", SqlStdOperatorTable.MAX) - .put("min", SqlStdOperatorTable.MIN) - .put("avg", SqlStdOperatorTable.AVG) - .put("sum", SqlStdOperatorTable.SUM) - .put("any_value", SqlStdOperatorTable.ANY_VALUE) - .put("count", SqlStdOperatorTable.COUNT) - .put("bit_and", SqlStdOperatorTable.BIT_AND) - .put("string_agg", SqlOperators.STRING_AGG_STRING_FN) // NULL values not supported - .put("array_agg", SqlOperators.ARRAY_AGG_FN) - .put("bit_or", SqlStdOperatorTable.BIT_OR) - .put("bit_xor", SqlOperators.BIT_XOR) - .put("ceil", SqlStdOperatorTable.CEIL) - .put("floor", SqlStdOperatorTable.FLOOR) - .put("mod", SqlStdOperatorTable.MOD) - .put("timestamp", SqlOperators.TIMESTAMP_OP) - .put("$case_no_value", SqlStdOperatorTable.CASE) + // ZetaSQL functions + .put("$and", resolvedFunction -> SqlStdOperatorTable.AND) + .put("$or", resolvedFunction -> SqlStdOperatorTable.OR) + .put("$not", resolvedFunction -> SqlStdOperatorTable.NOT) + .put("$equal", resolvedFunction -> SqlStdOperatorTable.EQUALS) + .put("$not_equal", resolvedFunction -> SqlStdOperatorTable.NOT_EQUALS) + .put("$greater", resolvedFunction -> SqlStdOperatorTable.GREATER_THAN) + .put( + "$greater_or_equal", + resolvedFunction -> SqlStdOperatorTable.GREATER_THAN_OR_EQUAL) + .put("$less", resolvedFunction -> SqlStdOperatorTable.LESS_THAN) + .put("$less_or_equal", resolvedFunction -> SqlStdOperatorTable.LESS_THAN_OR_EQUAL) + .put("$like", resolvedFunction -> SqlOperators.LIKE) + .put("$is_null", resolvedFunction -> SqlStdOperatorTable.IS_NULL) + .put("$is_true", resolvedFunction -> SqlStdOperatorTable.IS_TRUE) + .put("$is_false", resolvedFunction -> SqlStdOperatorTable.IS_FALSE) + .put("$add", resolvedFunction -> SqlStdOperatorTable.PLUS) + .put("$subtract", resolvedFunction -> SqlStdOperatorTable.MINUS) + .put("$multiply", resolvedFunction -> SqlStdOperatorTable.MULTIPLY) + .put("$unary_minus", resolvedFunction -> SqlStdOperatorTable.UNARY_MINUS) + .put("$divide", resolvedFunction -> SqlStdOperatorTable.DIVIDE) + .put("concat", resolvedFunction -> SqlOperators.CONCAT) + .put("substr", resolvedFunction -> SqlOperators.SUBSTR) + .put("substring", resolvedFunction -> SqlOperators.SUBSTR) + .put("trim", resolvedFunction -> SqlOperators.TRIM) + .put("replace", resolvedFunction -> SqlOperators.REPLACE) + .put("char_length", resolvedFunction -> SqlOperators.CHAR_LENGTH) + .put("starts_with", resolvedFunction -> SqlOperators.START_WITHS) + .put("ends_with", resolvedFunction -> SqlOperators.ENDS_WITH) + .put("ltrim", resolvedFunction -> SqlOperators.LTRIM) + .put("rtrim", resolvedFunction -> SqlOperators.RTRIM) + .put("reverse", resolvedFunction -> SqlOperators.REVERSE) + .put("$count_star", resolvedFunction -> SqlStdOperatorTable.COUNT) + .put("max", resolvedFunction -> SqlStdOperatorTable.MAX) + .put("min", resolvedFunction -> SqlStdOperatorTable.MIN) + .put("avg", resolvedFunction -> SqlStdOperatorTable.AVG) + .put("sum", resolvedFunction -> SqlStdOperatorTable.SUM) + .put("any_value", resolvedFunction -> SqlStdOperatorTable.ANY_VALUE) + .put("count", resolvedFunction -> SqlStdOperatorTable.COUNT) + .put("bit_and", resolvedFunction -> SqlStdOperatorTable.BIT_AND) + .put("string_agg", SqlOperators::createStringAggOperator) // NULL values not supported + .put("array_agg", resolvedFunction -> SqlOperators.ARRAY_AGG_FN) + .put("bit_or", resolvedFunction -> SqlStdOperatorTable.BIT_OR) + .put("bit_xor", resolvedFunction -> SqlOperators.BIT_XOR) + .put("ceil", resolvedFunction -> SqlStdOperatorTable.CEIL) + .put("floor", resolvedFunction -> SqlStdOperatorTable.FLOOR) + .put("mod", resolvedFunction -> SqlStdOperatorTable.MOD) + .put("timestamp", resolvedFunction -> SqlOperators.TIMESTAMP_OP) + .put("$case_no_value", resolvedFunction -> SqlStdOperatorTable.CASE) - // if operator - IF(cond, pos, neg) can actually be mapped directly to `CASE WHEN cond - // THEN pos ELSE neg` - .put("if", SqlStdOperatorTable.CASE) + // if operator - IF(cond, pos, neg) can actually be mapped directly to `CASE WHEN cond + // THEN pos ELSE neg` + .put("if", resolvedFunction -> SqlStdOperatorTable.CASE) - // $case_no_value specializations - // all of these operators can have their operands adjusted to achieve the same thing with - // a call to $case_with_value - .put("$case_with_value", SqlStdOperatorTable.CASE) - .put("coalesce", SqlStdOperatorTable.CASE) - .put("ifnull", SqlStdOperatorTable.CASE) - .put("nullif", SqlStdOperatorTable.CASE) - .put("countif", SqlOperators.COUNTIF) - .build(); + // $case_no_value specializations + // all of these operators can have their operands adjusted to achieve the same thing + // with + // a call to $case_with_value + .put("$case_with_value", resolvedFunction -> SqlStdOperatorTable.CASE) + .put("coalesce", resolvedFunction -> SqlStdOperatorTable.CASE) + .put("ifnull", resolvedFunction -> SqlStdOperatorTable.CASE) + .put("nullif", resolvedFunction -> SqlStdOperatorTable.CASE) + .put("countif", resolvedFunction -> SqlOperators.COUNTIF) + .build(); static final Map ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR_REWRITER = ImmutableMap.builder() @@ -105,4 +112,15 @@ class SqlOperatorMappingTable { .put("nullif", new SqlNullIfOperatorRewriter()) .put("$in", new SqlInOperatorRewriter()) .build(); + + public static SqlOperator create(ResolvedNodes.ResolvedFunctionCallBase aggregateFunctionCall) { + + Function sqlOperatorFactory = + ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get(aggregateFunctionCall.getFunction().getName()); + + if (sqlOperatorFactory != null) { + return sqlOperatorFactory.apply(aggregateFunctionCall); + } + return null; + } } diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java index 4fb24233abc8..99d4a2837ad9 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java @@ -19,6 +19,8 @@ import static org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog.ZETASQL_FUNCTION_GROUP_NAME; +import com.google.zetasql.Value; +import com.google.zetasql.resolvedast.ResolvedNodes; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; @@ -81,12 +83,6 @@ public class SqlOperators { private static final RelDataType BIGINT = createSqlType(SqlTypeName.BIGINT, false); private static final RelDataType NULLABLE_BIGINT = createSqlType(SqlTypeName.BIGINT, true); - public static final SqlOperator STRING_AGG_STRING_FN = - createUdafOperator( - "string_agg", - x -> createTypeFactory().createSqlType(SqlTypeName.VARCHAR), - new UdafImpl<>(new StringAgg.StringAggString())); - public static final SqlOperator ARRAY_AGG_FN = createUdafOperator( "array_agg", @@ -180,6 +176,44 @@ public class SqlOperators { null, new CastFunctionImpl()); + public static SqlOperator createStringAggOperator( + ResolvedNodes.ResolvedFunctionCallBase aggregateFunctionCall) { + com.google.common.collect.ImmutableList args = + aggregateFunctionCall.getArgumentList(); + String inputType = args.get(0).getType().typeName(); + Value delimiter = null; + if (args.size() == 2) { + delimiter = ((ResolvedNodes.ResolvedLiteral) args.get(1)).getValue(); + } + switch (inputType) { + case "BYTES": + if (delimiter != null) { + return SqlOperators.createUdafOperator( + "string_agg", + x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARBINARY), + new UdafImpl<>(new StringAgg.StringAggByte(delimiter.getBytesValue().toByteArray()))); + } + return SqlOperators.createUdafOperator( + "string_agg", + x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARBINARY), + new UdafImpl<>(new StringAgg.StringAggByte())); + case "STRING": + if (delimiter != null) { + return SqlOperators.createUdafOperator( + "string_agg", + x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARCHAR), + new UdafImpl<>(new StringAgg.StringAggString(delimiter.getStringValue()))); + } + return SqlOperators.createUdafOperator( + "string_agg", + x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARCHAR), + new UdafImpl<>(new StringAgg.StringAggString())); + default: + throw new UnsupportedOperationException( + String.format("[%s] is not supported in STRING_AGG", inputType)); + } + } + /** * Create a dummy SqlFunction of type OTHER_FUNCTION from given function name and return type. * These functions will be unparsed in either {@link diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java index 7c676d998d9b..29ef23f4a703 100644 --- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java +++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java @@ -2559,6 +2559,54 @@ public void testStringAggregation() { pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES)); } + @Test + public void testStringAggregationBytes() { + String sql = + "SELECT STRING_AGG(CAST(fruit as bytes)) AS string_agg" + + " FROM UNNEST([\"apple\", \"pear\", \"banana\", \"pear\"]) AS fruit"; + PCollection stream = execute(sql); + + Schema schema = Schema.builder().addByteArrayField("bytearray_field").build(); + PAssert.that(stream) + .containsInAnyOrder( + Row.withSchema(schema) + .addValue("apple,pear,banana,pear".getBytes(StandardCharsets.UTF_8)) + .build()); + + pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES)); + } + + @Test + public void testStringAggregationDelimiter() { + String sql = + "SELECT STRING_AGG(fruit, \"&\") AS string_agg" + + " FROM UNNEST([\"apple\", \"pear\", \"banana\", \"pear\"]) AS fruit"; + PCollection stream = execute(sql); + + Schema schema = Schema.builder().addStringField("string_field").build(); + PAssert.that(stream) + .containsInAnyOrder(Row.withSchema(schema).addValue("apple&pear&banana&pear").build()); + + pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES)); + } + + @Test + public void testStringAggregationBytesDelimiter() { + String sql = + "SELECT STRING_AGG(CAST(fruit as bytes), b\"&\") AS string_agg" + + " FROM UNNEST([\"apple\", \"pear\", \"banana\", \"pear\"]) AS fruit"; + PCollection stream = execute(sql); + + Schema schema = Schema.builder().addByteArrayField("bytearray_field").build(); + PAssert.that(stream) + .containsInAnyOrder( + Row.withSchema(schema) + .addValue("apple&pear&banana&pear".getBytes(StandardCharsets.UTF_8)) + .build()); + + pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES)); + } + @Test @Ignore("Seeing exception in Beam, need further investigation on the cause of this failed query.") public void testNamedUNNESTJoin() { From b20ae715c34d54ba1277ddc313633e6bed4fef23 Mon Sep 17 00:00:00 2001 From: Benjamin Gonzalez Date: Mon, 13 Dec 2021 12:56:39 -0600 Subject: [PATCH 02/14] [BEAM-11808] Fix checkstyle warning --- .../sdk/extensions/sql/zetasql/translation/SqlOperators.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java index 99d4a2837ad9..8fd6275044da 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java @@ -178,7 +178,7 @@ public class SqlOperators { public static SqlOperator createStringAggOperator( ResolvedNodes.ResolvedFunctionCallBase aggregateFunctionCall) { - com.google.common.collect.ImmutableList args = + List args = aggregateFunctionCall.getArgumentList(); String inputType = args.get(0).getType().typeName(); Value delimiter = null; From cd9032d2750eec9f12dc6fad1a1a73011ec6ad93 Mon Sep 17 00:00:00 2001 From: Benjamin Gonzalez Date: Mon, 13 Dec 2021 13:03:35 -0600 Subject: [PATCH 03/14] [BEAM-11808] Fix spotlessApply --- .../sdk/extensions/sql/zetasql/translation/SqlOperators.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java index 8fd6275044da..1c1f548ef327 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java @@ -178,8 +178,7 @@ public class SqlOperators { public static SqlOperator createStringAggOperator( ResolvedNodes.ResolvedFunctionCallBase aggregateFunctionCall) { - List args = - aggregateFunctionCall.getArgumentList(); + List args = aggregateFunctionCall.getArgumentList(); String inputType = args.get(0).getType().typeName(); Value delimiter = null; if (args.size() == 2) { From 940977d23a7ad54fb8f063dfadd153b1aa2c329e Mon Sep 17 00:00:00 2001 From: Benjamin Gonzalez Date: Mon, 13 Dec 2021 14:15:01 -0600 Subject: [PATCH 04/14] [BEAM-11808] Fix checkstyle warning --- .../sql/zetasql/translation/SqlOperatorMappingTable.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java index 76f564e3239c..33c29b235d5b 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java @@ -23,6 +23,7 @@ import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlOperator; import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.checkerframework.checker.nullness.qual.Nullable; /** SqlOperatorMappingTable. */ class SqlOperatorMappingTable { @@ -113,7 +114,7 @@ class SqlOperatorMappingTable { .put("$in", new SqlInOperatorRewriter()) .build(); - public static SqlOperator create(ResolvedNodes.ResolvedFunctionCallBase aggregateFunctionCall) { + public static @Nullable SqlOperator create(ResolvedNodes.ResolvedFunctionCallBase aggregateFunctionCall) { Function sqlOperatorFactory = ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get(aggregateFunctionCall.getFunction().getName()); From f3352170f33db9e7b67e466c6b61866a0c564d3c Mon Sep 17 00:00:00 2001 From: Benjamin Gonzalez Date: Mon, 13 Dec 2021 14:27:47 -0600 Subject: [PATCH 05/14] [BEAM-11808] Fix spotlessApply --- .../sql/zetasql/translation/SqlOperatorMappingTable.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java index 33c29b235d5b..1e9caae75975 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java @@ -114,7 +114,8 @@ class SqlOperatorMappingTable { .put("$in", new SqlInOperatorRewriter()) .build(); - public static @Nullable SqlOperator create(ResolvedNodes.ResolvedFunctionCallBase aggregateFunctionCall) { + public static @Nullable SqlOperator create( + ResolvedNodes.ResolvedFunctionCallBase aggregateFunctionCall) { Function sqlOperatorFactory = ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get(aggregateFunctionCall.getFunction().getName()); From 5b440242a67cb7f2c78a4e58be2b13ae69463498 Mon Sep 17 00:00:00 2001 From: Benjamin Gonzalez Date: Tue, 28 Dec 2021 13:00:34 -0600 Subject: [PATCH 06/14] [BEAM-11808] Change initialization StringAgg and minor fixes --- .../extensions/sql/impl/udaf/StringAgg.java | 10 +++----- .../translation/AggregateScanConverter.java | 25 +++++++++++-------- .../translation/SqlOperatorMappingTable.java | 2 +- .../sql/zetasql/translation/SqlOperators.java | 23 +++++++---------- 4 files changed, 28 insertions(+), 32 deletions(-) diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/udaf/StringAgg.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/udaf/StringAgg.java index eb1c7736de44..357cee4bb3ba 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/udaf/StringAgg.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/udaf/StringAgg.java @@ -31,9 +31,7 @@ public class StringAgg { /** A {@link CombineFn} that aggregates strings with a string as delimiter. */ public static class StringAggString extends CombineFn { - private String delimiter = ","; - - public StringAggString() {} + private final String delimiter; public StringAggString(String delimiter) { this.delimiter = delimiter; @@ -80,11 +78,9 @@ public String extractOutput(String output) { } } - /** A {@link CombineFn} that aggregates bytes with a byte as delimiter. */ + /** A {@link CombineFn} that aggregates bytes with a byte array as delimiter. */ public static class StringAggByte extends CombineFn { - private String delimiter = ","; - - public StringAggByte() {} + private final String delimiter; public StringAggByte(byte[] delimiter) { this.delimiter = new String(delimiter, StandardCharsets.UTF_8); diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java index 3c001cb49427..c548c43038ea 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java @@ -23,13 +23,16 @@ import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_LITERAL; import com.google.zetasql.FunctionSignature; +import com.google.zetasql.ZetaSQLResolvedNodeKind; import com.google.zetasql.ZetaSQLType.TypeKind; import com.google.zetasql.resolvedast.ResolvedNode; +import com.google.zetasql.resolvedast.ResolvedNodes; import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateFunctionCall; import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateScan; import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn; import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedExpr; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; @@ -149,8 +152,8 @@ private LogicalProject convertAggregateScanInputScanToLogicalProject( // aggregation? ResolvedAggregateFunctionCall aggregateFunctionCall = ((ResolvedAggregateFunctionCall) resolvedComputedColumn.getExpr()); - ImmutableList argumentList = - ImmutableList.copyOf(aggregateFunctionCall.getArgumentList()); + com.google.common.collect.ImmutableList argumentList = + aggregateFunctionCall.getArgumentList(); if (argumentList != null && argumentList.size() >= 1) { ResolvedExpr resolvedExpr = argumentList.get(0); for (int i = 0; i < argumentList.size(); i++) { @@ -241,20 +244,22 @@ private AggregateCall convertAggCall( } List argList = new ArrayList<>(); - for (ResolvedExpr expr : - ((ResolvedAggregateFunctionCall) computedColumn.getExpr()).getArgumentList()) { + com.google.common.collect.ImmutableList argumentList = + ((ResolvedAggregateFunctionCall) computedColumn.getExpr()).getArgumentList(); + List resolvedNodeKinds = + Arrays.asList(RESOLVED_CAST, RESOLVED_COLUMN_REF, RESOLVED_GET_STRUCT_FIELD); + for (int i = 0; i < argumentList.size(); i++) { // Throw an error if aggregate function's input isn't either a ColumnRef or a cast(ColumnRef). // TODO: is there a general way to handle aggregation calls conversion? - if (expr.nodeKind() == RESOLVED_CAST - || expr.nodeKind() == RESOLVED_COLUMN_REF - || expr.nodeKind() == RESOLVED_GET_STRUCT_FIELD) { + ZetaSQLResolvedNodeKind.ResolvedNodeKind resolvedNodeKind = argumentList.get(i).nodeKind(); + if (i == 0 && resolvedNodeKinds.contains(resolvedNodeKind)) { argList.add(columnRefOff); - } else if (expr.nodeKind() == RESOLVED_LITERAL) { + } else if (resolvedNodeKind == RESOLVED_LITERAL) { continue; } else { throw new UnsupportedOperationException( - "Aggregate function only accepts Column Reference or CAST(Column Reference) as its" - + " input."); + "Aggregate function only accepts Column Reference or CAST(Column Reference) as the first argument and " + + "Literals as subsequent arguments as its inputs"); } } diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java index 1e9caae75975..75ac29029100 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java @@ -114,7 +114,7 @@ class SqlOperatorMappingTable { .put("$in", new SqlInOperatorRewriter()) .build(); - public static @Nullable SqlOperator create( + static @Nullable SqlOperator create( ResolvedNodes.ResolvedFunctionCallBase aggregateFunctionCall) { Function sqlOperatorFactory = diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java index 1c1f548ef327..a65e8ebf4d68 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java @@ -22,6 +22,7 @@ import com.google.zetasql.Value; import com.google.zetasql.resolvedast.ResolvedNodes; import java.lang.reflect.Method; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; import org.apache.beam.sdk.annotations.Internal; @@ -186,27 +187,21 @@ public static SqlOperator createStringAggOperator( } switch (inputType) { case "BYTES": - if (delimiter != null) { - return SqlOperators.createUdafOperator( - "string_agg", - x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARBINARY), - new UdafImpl<>(new StringAgg.StringAggByte(delimiter.getBytesValue().toByteArray()))); - } return SqlOperators.createUdafOperator( "string_agg", x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARBINARY), - new UdafImpl<>(new StringAgg.StringAggByte())); + new UdafImpl<>( + new StringAgg.StringAggByte( + delimiter == null + ? ",".getBytes(StandardCharsets.UTF_8) + : delimiter.getBytesValue().toByteArray()))); case "STRING": - if (delimiter != null) { - return SqlOperators.createUdafOperator( - "string_agg", - x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARCHAR), - new UdafImpl<>(new StringAgg.StringAggString(delimiter.getStringValue()))); - } return SqlOperators.createUdafOperator( "string_agg", x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARCHAR), - new UdafImpl<>(new StringAgg.StringAggString())); + new UdafImpl<>( + new StringAgg.StringAggString( + delimiter == null ? "," : delimiter.getStringValue()))); default: throw new UnsupportedOperationException( String.format("[%s] is not supported in STRING_AGG", inputType)); From 5302e526d1bd286101fb3945dd508c457b07b3a3 Mon Sep 17 00:00:00 2001 From: Benjamin Gonzalez Date: Tue, 28 Dec 2021 13:37:54 -0600 Subject: [PATCH 07/14] [BEAM-11808] Fix checkstyle warnings --- .../translation/AggregateScanConverter.java | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java index c548c43038ea..4e4a0a36486b 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java @@ -26,7 +26,6 @@ import com.google.zetasql.ZetaSQLResolvedNodeKind; import com.google.zetasql.ZetaSQLType.TypeKind; import com.google.zetasql.resolvedast.ResolvedNode; -import com.google.zetasql.resolvedast.ResolvedNodes; import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateFunctionCall; import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateScan; import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn; @@ -152,11 +151,10 @@ private LogicalProject convertAggregateScanInputScanToLogicalProject( // aggregation? ResolvedAggregateFunctionCall aggregateFunctionCall = ((ResolvedAggregateFunctionCall) resolvedComputedColumn.getExpr()); - com.google.common.collect.ImmutableList argumentList = - aggregateFunctionCall.getArgumentList(); - if (argumentList != null && argumentList.size() >= 1) { - ResolvedExpr resolvedExpr = argumentList.get(0); - for (int i = 0; i < argumentList.size(); i++) { + if (aggregateFunctionCall.getArgumentList() != null + && aggregateFunctionCall.getArgumentList().size() >= 1) { + ResolvedExpr resolvedExpr = aggregateFunctionCall.getArgumentList().get(0); + for (int i = 0; i < aggregateFunctionCall.getArgumentList().size(); i++) { if (i == 0) { // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef). // TODO: user might use multiple CAST so we need to handle this rare case. @@ -169,7 +167,9 @@ private LogicalProject convertAggregateScanInputScanToLogicalProject( ImmutableMap.of())); } else { projects.add( - getExpressionConverter().convertRexNodeFromResolvedExpr(argumentList.get(i))); + getExpressionConverter() + .convertRexNodeFromResolvedExpr( + aggregateFunctionCall.getArgumentList().get(i))); } fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn())); } @@ -244,14 +244,14 @@ private AggregateCall convertAggCall( } List argList = new ArrayList<>(); - com.google.common.collect.ImmutableList argumentList = - ((ResolvedAggregateFunctionCall) computedColumn.getExpr()).getArgumentList(); + ResolvedAggregateFunctionCall expr = ((ResolvedAggregateFunctionCall) computedColumn.getExpr()); List resolvedNodeKinds = Arrays.asList(RESOLVED_CAST, RESOLVED_COLUMN_REF, RESOLVED_GET_STRUCT_FIELD); - for (int i = 0; i < argumentList.size(); i++) { + for (int i = 0; i < expr.getArgumentList().size(); i++) { // Throw an error if aggregate function's input isn't either a ColumnRef or a cast(ColumnRef). // TODO: is there a general way to handle aggregation calls conversion? - ZetaSQLResolvedNodeKind.ResolvedNodeKind resolvedNodeKind = argumentList.get(i).nodeKind(); + ZetaSQLResolvedNodeKind.ResolvedNodeKind resolvedNodeKind = + expr.getArgumentList().get(i).nodeKind(); if (i == 0 && resolvedNodeKinds.contains(resolvedNodeKind)) { argList.add(columnRefOff); } else if (resolvedNodeKind == RESOLVED_LITERAL) { From cefe7fb7fec9ba3930ae4c9b09de127a48c844ff Mon Sep 17 00:00:00 2001 From: Benjamin Gonzalez Date: Tue, 11 Jan 2022 17:37:39 -0600 Subject: [PATCH 08/14] [BEAM-11808] Add test cases for array_agg and timestamp null max,min --- .../translation/AggregateScanConverter.java | 4 ++- .../sql/zetasql/ZetaSqlDialectSpecTest.java | 33 +++++++++++++++++++ sdks/python/report.html | 0 3 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 sdks/python/report.html diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java index 4e4a0a36486b..4438dd4585b2 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java @@ -254,7 +254,9 @@ private AggregateCall convertAggCall( expr.getArgumentList().get(i).nodeKind(); if (i == 0 && resolvedNodeKinds.contains(resolvedNodeKind)) { argList.add(columnRefOff); - } else if (resolvedNodeKind == RESOLVED_LITERAL) { + } else if (i > 0 + && resolvedNodeKind + == RESOLVED_LITERAL) { // Doesn't support RESOLVED LITERAL as first argument continue; } else { throw new UnsupportedOperationException( diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java index 29ef23f4a703..06abfa6540fb 100644 --- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java +++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java @@ -3990,6 +3990,39 @@ public void testArrayAggEmpty() { pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES)); } + @Test + public void testArrayAggConstantValue() { + String sql = "SELECT ARRAY_AGG(1) b FROM UNNEST([1, 2, 3]) a"; + + PCollection stream = execute(sql); + + Schema schema = Schema.builder().addArrayField("array_field", FieldType.INT64).build(); + PAssert.that(stream).containsInAnyOrder(Row.withSchema(schema).addArray(1L, 1L, 1L).build()); + + pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES)); + } + + @Test + public void testTimestampNullMaxMin() { + String sql = + "SELECT MAX(CAST(NULL AS TIMESTAMP)) AS max_NULL,\n" + + " MIN(CAST(NULL AS TIMESTAMP)) AS min_NULL\n" + + "FROM (SELECT 1)"; + + PCollection stream = execute(sql); + + Schema schema = + Schema.builder() + .addNullableField("field1", FieldType.INT64) + .addNullableField("field2", FieldType.INT64) + .build(); + PAssert.that(stream) + .containsInAnyOrder( + Row.withSchema(schema).addValue((Long) null).addValue((Long) null).build()); + + pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES)); + } + @Test public void testInt64SumOverflow() { String sql = diff --git a/sdks/python/report.html b/sdks/python/report.html new file mode 100644 index 000000000000..e69de29bb2d1 From f4411d49349b9a852efbe2aff0862b3e97f3d63f Mon Sep 17 00:00:00 2001 From: Benjamin Gonzalez Date: Tue, 11 Jan 2022 17:41:28 -0600 Subject: [PATCH 09/14] [BEAM-11808] Remove leftover file --- sdks/python/report.html | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 sdks/python/report.html diff --git a/sdks/python/report.html b/sdks/python/report.html deleted file mode 100644 index e69de29bb2d1..000000000000 From 618c69f572862263c7f07668b4cbb2d409ba5a26 Mon Sep 17 00:00:00 2001 From: Benjamin Gonzalez Date: Tue, 11 Jan 2022 18:57:08 -0600 Subject: [PATCH 10/14] [BEAM-11808] Enable resolved_literal as firts arg --- .../sql/zetasql/translation/AggregateScanConverter.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java index 4438dd4585b2..4e4a0a36486b 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java @@ -254,9 +254,7 @@ private AggregateCall convertAggCall( expr.getArgumentList().get(i).nodeKind(); if (i == 0 && resolvedNodeKinds.contains(resolvedNodeKind)) { argList.add(columnRefOff); - } else if (i > 0 - && resolvedNodeKind - == RESOLVED_LITERAL) { // Doesn't support RESOLVED LITERAL as first argument + } else if (resolvedNodeKind == RESOLVED_LITERAL) { continue; } else { throw new UnsupportedOperationException( From 7ee7a5db840ebd97aa38ee3beff0f8d7ff8056ae Mon Sep 17 00:00:00 2001 From: Benjamin Gonzalez Date: Thu, 13 Jan 2022 18:21:13 -0600 Subject: [PATCH 11/14] [BEAM-11808] Remove tests, validate RESOLVED_LITERAL as second argument --- .../translation/AggregateScanConverter.java | 2 +- .../sql/zetasql/ZetaSqlDialectSpecTest.java | 33 ------------------- 2 files changed, 1 insertion(+), 34 deletions(-) diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java index 4e4a0a36486b..c1ab6897bde5 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java @@ -254,7 +254,7 @@ private AggregateCall convertAggCall( expr.getArgumentList().get(i).nodeKind(); if (i == 0 && resolvedNodeKinds.contains(resolvedNodeKind)) { argList.add(columnRefOff); - } else if (resolvedNodeKind == RESOLVED_LITERAL) { + } else if (i > 0 && resolvedNodeKind == RESOLVED_LITERAL) { continue; } else { throw new UnsupportedOperationException( diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java index 06abfa6540fb..29ef23f4a703 100644 --- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java +++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java @@ -3990,39 +3990,6 @@ public void testArrayAggEmpty() { pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES)); } - @Test - public void testArrayAggConstantValue() { - String sql = "SELECT ARRAY_AGG(1) b FROM UNNEST([1, 2, 3]) a"; - - PCollection stream = execute(sql); - - Schema schema = Schema.builder().addArrayField("array_field", FieldType.INT64).build(); - PAssert.that(stream).containsInAnyOrder(Row.withSchema(schema).addArray(1L, 1L, 1L).build()); - - pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES)); - } - - @Test - public void testTimestampNullMaxMin() { - String sql = - "SELECT MAX(CAST(NULL AS TIMESTAMP)) AS max_NULL,\n" - + " MIN(CAST(NULL AS TIMESTAMP)) AS min_NULL\n" - + "FROM (SELECT 1)"; - - PCollection stream = execute(sql); - - Schema schema = - Schema.builder() - .addNullableField("field1", FieldType.INT64) - .addNullableField("field2", FieldType.INT64) - .build(); - PAssert.that(stream) - .containsInAnyOrder( - Row.withSchema(schema).addValue((Long) null).addValue((Long) null).build()); - - pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES)); - } - @Test public void testInt64SumOverflow() { String sql = From 16d9d4cb4a92864d7411b654a46fcde5eb7a6412 Mon Sep 17 00:00:00 2001 From: Benjamin Gonzalez Date: Mon, 17 Jan 2022 10:40:26 -0600 Subject: [PATCH 12/14] [BEAM-11808] Add unsupportedException for delimiter as ResolvedParam --- .../sql/zetasql/translation/SqlOperators.java | 12 +++++++++++- .../sql/zetasql/ZetaSqlDialectSpecTest.java | 14 ++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java index a65e8ebf4d68..30a3b71e71c8 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java @@ -183,8 +183,18 @@ public static SqlOperator createStringAggOperator( String inputType = args.get(0).getType().typeName(); Value delimiter = null; if (args.size() == 2) { - delimiter = ((ResolvedNodes.ResolvedLiteral) args.get(1)).getValue(); + ResolvedNodes.ResolvedExpr resolvedExpr = args.get(1); + if (resolvedExpr instanceof ResolvedNodes.ResolvedLiteral) { + delimiter = ((ResolvedNodes.ResolvedLiteral) resolvedExpr).getValue(); + } else { + // TODO (BEAM-13673 Add support for params) + throw new UnsupportedOperationException( + String.format( + "STRING_AGG only supports ResolvedLiteral as delimiter, provided %s", + resolvedExpr.getClass().getName())); + } } + switch (inputType) { case "BYTES": return SqlOperators.createUdafOperator( diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java index 29ef23f4a703..862f9af54d9c 100644 --- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java +++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java @@ -2607,6 +2607,20 @@ public void testStringAggregationBytesDelimiter() { pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES)); } + @Test + public void testStringAggregationParamsDelimiter() { + String sql = "SELECT string_agg(\"s\", @separator) FROM (SELECT 1)"; + + ImmutableMap params = + ImmutableMap.builder() + .put("separator", Value.createStringValue(",")) + .build(); + + ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config); + thrown.expect(UnsupportedOperationException.class); + zetaSQLQueryPlanner.convertToBeamRel(sql, params); + } + @Test @Ignore("Seeing exception in Beam, need further investigation on the cause of this failed query.") public void testNamedUNNESTJoin() { From 87608428dc08d24ac4ea372cd2f5853dc2260613 Mon Sep 17 00:00:00 2001 From: Kyle Weaver Date: Tue, 18 Jan 2022 11:25:11 -0800 Subject: [PATCH 13/14] use zetasql exception --- .../sql/zetasql/translation/SqlOperators.java | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java index 30a3b71e71c8..778d3594d997 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java @@ -20,6 +20,8 @@ import static org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog.ZETASQL_FUNCTION_GROUP_NAME; import com.google.zetasql.Value; +import com.google.zetasql.io.grpc.Status; +import com.google.zetasql.io.grpc.StatusRuntimeException; import com.google.zetasql.resolvedast.ResolvedNodes; import java.lang.reflect.Method; import java.nio.charset.StandardCharsets; @@ -34,6 +36,7 @@ import org.apache.beam.sdk.extensions.sql.impl.udaf.ArrayAgg; import org.apache.beam.sdk.extensions.sql.impl.udaf.StringAgg; import org.apache.beam.sdk.extensions.sql.zetasql.DateTimeUtils; +import org.apache.beam.sdk.extensions.sql.zetasql.ZetaSqlException; import org.apache.beam.sdk.extensions.sql.zetasql.translation.impl.BeamBuiltinMethods; import org.apache.beam.sdk.extensions.sql.zetasql.translation.impl.CastFunctionImpl; import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.jdbc.JavaTypeFactoryImpl; @@ -187,11 +190,13 @@ public static SqlOperator createStringAggOperator( if (resolvedExpr instanceof ResolvedNodes.ResolvedLiteral) { delimiter = ((ResolvedNodes.ResolvedLiteral) resolvedExpr).getValue(); } else { - // TODO (BEAM-13673 Add support for params) - throw new UnsupportedOperationException( - String.format( - "STRING_AGG only supports ResolvedLiteral as delimiter, provided %s", - resolvedExpr.getClass().getName())); + // TODO(BEAM-13673) Add support for params + throw new ZetaSqlException( + new StatusRuntimeException( + Status.INVALID_ARGUMENT.withDescription( + String.format( + "STRING_AGG only supports ResolvedLiteral as delimiter, provided %s", + resolvedExpr.getClass().getName())))); } } From 1b69b20849c39b086be35eb84783eddff19bb174 Mon Sep 17 00:00:00 2001 From: Kyle Weaver Date: Tue, 18 Jan 2022 12:45:03 -0800 Subject: [PATCH 14/14] update test --- .../beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java index 862f9af54d9c..eaba01ba7826 100644 --- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java +++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java @@ -2617,7 +2617,7 @@ public void testStringAggregationParamsDelimiter() { .build(); ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config); - thrown.expect(UnsupportedOperationException.class); + thrown.expect(ZetaSqlException.class); // BEAM-13673 zetaSQLQueryPlanner.convertToBeamRel(sql, params); }