diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/udaf/StringAgg.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/udaf/StringAgg.java index f687551f3f3a..357cee4bb3ba 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/udaf/StringAgg.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/udaf/StringAgg.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.extensions.sql.impl.udaf; +import java.nio.charset.StandardCharsets; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.transforms.Combine.CombineFn; @@ -28,10 +29,13 @@ @Experimental public class StringAgg { - /** A {@link CombineFn} that aggregates strings with comma as delimiter. */ + /** A {@link CombineFn} that aggregates strings with a string as delimiter. */ public static class StringAggString extends CombineFn { + private final String delimiter; - private static final String delimiter = ","; + public StringAggString(String delimiter) { + this.delimiter = delimiter; + } @Override public String createAccumulator() { @@ -43,7 +47,7 @@ public String addInput(String curString, String nextString) { if (!nextString.isEmpty()) { if (!curString.isEmpty()) { - curString += StringAggString.delimiter + nextString; + curString += delimiter + nextString; } else { curString = nextString; } @@ -58,7 +62,7 @@ public String mergeAccumulators(Iterable accumList) { for (String stringAccum : accumList) { if (!stringAccum.isEmpty()) { if (!mergeString.isEmpty()) { - mergeString += StringAggString.delimiter + stringAccum; + mergeString += delimiter + stringAccum; } else { mergeString = stringAccum; } @@ -73,4 +77,51 @@ public String extractOutput(String output) { return output; } } + + /** A {@link CombineFn} that aggregates bytes with a byte array as delimiter. */ + public static class StringAggByte extends CombineFn { + private final String delimiter; + + public StringAggByte(byte[] delimiter) { + this.delimiter = new String(delimiter, StandardCharsets.UTF_8); + } + + @Override + public String createAccumulator() { + return ""; + } + + @Override + public String addInput(String mutableAccumulator, byte[] input) { + if (input != null) { + if (!mutableAccumulator.isEmpty()) { + mutableAccumulator += delimiter + new String(input, StandardCharsets.UTF_8); + } else { + mutableAccumulator = new String(input, StandardCharsets.UTF_8); + } + } + return mutableAccumulator; + } + + @Override + public String mergeAccumulators(Iterable accumList) { + String mergeString = ""; + for (String stringAccum : accumList) { + if (!stringAccum.isEmpty()) { + if (!mergeString.isEmpty()) { + mergeString += delimiter + stringAccum; + } else { + mergeString = stringAccum; + } + } + } + + return mergeString; + } + + @Override + public byte[] extractOutput(String output) { + return output.getBytes(StandardCharsets.UTF_8); + } + } } diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SupportedZetaSqlBuiltinFunctions.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SupportedZetaSqlBuiltinFunctions.java index 65dee35bd435..601aee5667c2 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SupportedZetaSqlBuiltinFunctions.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SupportedZetaSqlBuiltinFunctions.java @@ -405,9 +405,9 @@ class SupportedZetaSqlBuiltinFunctions { FunctionSignatureId.FN_MAX, // max FunctionSignatureId.FN_MIN, // min FunctionSignatureId.FN_STRING_AGG_STRING, // string_agg(s) - // FunctionSignatureId.FN_STRING_AGG_DELIM_STRING, // string_agg(s, delim_s) - // FunctionSignatureId.FN_STRING_AGG_BYTES, // string_agg(b) - // FunctionSignatureId.FN_STRING_AGG_DELIM_BYTES, // string_agg(b, delim_b) + FunctionSignatureId.FN_STRING_AGG_DELIM_STRING, // string_agg(s, delim_s) + FunctionSignatureId.FN_STRING_AGG_BYTES, // string_agg(b) + FunctionSignatureId.FN_STRING_AGG_DELIM_BYTES, // string_agg(b, delim_b) FunctionSignatureId.FN_SUM_INT64, // sum FunctionSignatureId.FN_SUM_DOUBLE, // sum FunctionSignatureId.FN_SUM_NUMERIC, // sum diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java index 912a7b0abe9f..c1ab6897bde5 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/AggregateScanConverter.java @@ -20,8 +20,10 @@ import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_CAST; import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_COLUMN_REF; import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_GET_STRUCT_FIELD; +import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_LITERAL; import com.google.zetasql.FunctionSignature; +import com.google.zetasql.ZetaSQLResolvedNodeKind; import com.google.zetasql.ZetaSQLType.TypeKind; import com.google.zetasql.resolvedast.ResolvedNode; import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateFunctionCall; @@ -29,6 +31,7 @@ import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn; import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedExpr; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; @@ -149,23 +152,27 @@ private LogicalProject convertAggregateScanInputScanToLogicalProject( ResolvedAggregateFunctionCall aggregateFunctionCall = ((ResolvedAggregateFunctionCall) resolvedComputedColumn.getExpr()); if (aggregateFunctionCall.getArgumentList() != null - && aggregateFunctionCall.getArgumentList().size() == 1) { + && aggregateFunctionCall.getArgumentList().size() >= 1) { ResolvedExpr resolvedExpr = aggregateFunctionCall.getArgumentList().get(0); - - // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef). - // TODO: user might use multiple CAST so we need to handle this rare case. - projects.add( - getExpressionConverter() - .convertRexNodeFromResolvedExpr( - resolvedExpr, - node.getInputScan().getColumnList(), - input.getRowType().getFieldList(), - ImmutableMap.of())); - fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn())); - } else if (aggregateFunctionCall.getArgumentList() != null - && aggregateFunctionCall.getArgumentList().size() > 1) { - throw new IllegalArgumentException( - aggregateFunctionCall.getFunction().getName() + " has more than one argument."); + for (int i = 0; i < aggregateFunctionCall.getArgumentList().size(); i++) { + if (i == 0) { + // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef). + // TODO: user might use multiple CAST so we need to handle this rare case. + projects.add( + getExpressionConverter() + .convertRexNodeFromResolvedExpr( + resolvedExpr, + node.getInputScan().getColumnList(), + input.getRowType().getFieldList(), + ImmutableMap.of())); + } else { + projects.add( + getExpressionConverter() + .convertRexNodeFromResolvedExpr( + aggregateFunctionCall.getArgumentList().get(i))); + } + fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn())); + } } } @@ -228,10 +235,7 @@ private AggregateCall convertAggCall( aggregateFunctionCall.getFunction().getName(), typeInference, impl); } else { // Look up builtin functions in SqlOperatorMappingTable. - sqlAggFunction = - (SqlAggFunction) - SqlOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get( - aggregateFunctionCall.getFunction().getName()); + sqlAggFunction = (SqlAggFunction) SqlOperatorMappingTable.create(aggregateFunctionCall); if (sqlAggFunction == null) { throw new UnsupportedOperationException( "Does not support ZetaSQL aggregate function: " @@ -240,18 +244,22 @@ private AggregateCall convertAggCall( } List argList = new ArrayList<>(); - for (ResolvedExpr expr : - ((ResolvedAggregateFunctionCall) computedColumn.getExpr()).getArgumentList()) { + ResolvedAggregateFunctionCall expr = ((ResolvedAggregateFunctionCall) computedColumn.getExpr()); + List resolvedNodeKinds = + Arrays.asList(RESOLVED_CAST, RESOLVED_COLUMN_REF, RESOLVED_GET_STRUCT_FIELD); + for (int i = 0; i < expr.getArgumentList().size(); i++) { // Throw an error if aggregate function's input isn't either a ColumnRef or a cast(ColumnRef). // TODO: is there a general way to handle aggregation calls conversion? - if (expr.nodeKind() == RESOLVED_CAST - || expr.nodeKind() == RESOLVED_COLUMN_REF - || expr.nodeKind() == RESOLVED_GET_STRUCT_FIELD) { + ZetaSQLResolvedNodeKind.ResolvedNodeKind resolvedNodeKind = + expr.getArgumentList().get(i).nodeKind(); + if (i == 0 && resolvedNodeKinds.contains(resolvedNodeKind)) { argList.add(columnRefOff); + } else if (i > 0 && resolvedNodeKind == RESOLVED_LITERAL) { + continue; } else { throw new UnsupportedOperationException( - "Aggregate function only accepts Column Reference or CAST(Column Reference) as its" - + " input."); + "Aggregate function only accepts Column Reference or CAST(Column Reference) as the first argument and " + + "Literals as subsequent arguments as its inputs"); } } diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java index 17e4766b815f..94d830ff2ad3 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java @@ -609,7 +609,7 @@ private RexNode convertResolvedFunctionCall( Map outerFunctionArguments) { final String funGroup = functionCall.getFunction().getGroup(); final String funName = functionCall.getFunction().getName(); - SqlOperator op = SqlOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get(funName); + SqlOperator op = SqlOperatorMappingTable.create(functionCall); List operands = new ArrayList<>(); if (PRE_DEFINED_WINDOW_FUNCTIONS.equals(funGroup)) { diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlNullIfOperatorRewriter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlNullIfOperatorRewriter.java index 4209bba0aaf5..17bc92d6114e 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlNullIfOperatorRewriter.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlNullIfOperatorRewriter.java @@ -43,7 +43,9 @@ public RexNode apply(RexBuilder rexBuilder, List operands) { operands.size() == 2, "NULLIF should have two arguments in function call."); SqlOperator op = - SqlOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get("$case_no_value"); + SqlOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR + .get("$case_no_value") + .apply(null); List newOperands = ImmutableList.of( rexBuilder.makeCall( diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java index 952d16252091..75ac29029100 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperatorMappingTable.java @@ -17,10 +17,13 @@ */ package org.apache.beam.sdk.extensions.sql.zetasql.translation; +import com.google.zetasql.resolvedast.ResolvedNodes; import java.util.Map; +import java.util.function.Function; import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlOperator; import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.checkerframework.checker.nullness.qual.Nullable; /** SqlOperatorMappingTable. */ class SqlOperatorMappingTable { @@ -28,74 +31,79 @@ class SqlOperatorMappingTable { // todo: Some of operators defined here are later overridden in ZetaSQLPlannerImpl. // We should remove them from this table and add generic way to provide custom // implementation. (Ex.: timestamp_add) - static final Map ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR = - ImmutableMap.builder() - // grouped window function - .put("TUMBLE", SqlStdOperatorTable.TUMBLE_OLD) - .put("HOP", SqlStdOperatorTable.HOP_OLD) - .put("SESSION", SqlStdOperatorTable.SESSION_OLD) + static final Map> + ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR = + ImmutableMap + .>builder() + // grouped window function + .put("TUMBLE", resolvedFunction -> SqlStdOperatorTable.TUMBLE_OLD) + .put("HOP", resolvedFunction -> SqlStdOperatorTable.HOP_OLD) + .put("SESSION", resolvedFunction -> SqlStdOperatorTable.SESSION_OLD) - // ZetaSQL functions - .put("$and", SqlStdOperatorTable.AND) - .put("$or", SqlStdOperatorTable.OR) - .put("$not", SqlStdOperatorTable.NOT) - .put("$equal", SqlStdOperatorTable.EQUALS) - .put("$not_equal", SqlStdOperatorTable.NOT_EQUALS) - .put("$greater", SqlStdOperatorTable.GREATER_THAN) - .put("$greater_or_equal", SqlStdOperatorTable.GREATER_THAN_OR_EQUAL) - .put("$less", SqlStdOperatorTable.LESS_THAN) - .put("$less_or_equal", SqlStdOperatorTable.LESS_THAN_OR_EQUAL) - .put("$like", SqlOperators.LIKE) - .put("$is_null", SqlStdOperatorTable.IS_NULL) - .put("$is_true", SqlStdOperatorTable.IS_TRUE) - .put("$is_false", SqlStdOperatorTable.IS_FALSE) - .put("$add", SqlStdOperatorTable.PLUS) - .put("$subtract", SqlStdOperatorTable.MINUS) - .put("$multiply", SqlStdOperatorTable.MULTIPLY) - .put("$unary_minus", SqlStdOperatorTable.UNARY_MINUS) - .put("$divide", SqlStdOperatorTable.DIVIDE) - .put("concat", SqlOperators.CONCAT) - .put("substr", SqlOperators.SUBSTR) - .put("substring", SqlOperators.SUBSTR) - .put("trim", SqlOperators.TRIM) - .put("replace", SqlOperators.REPLACE) - .put("char_length", SqlOperators.CHAR_LENGTH) - .put("starts_with", SqlOperators.START_WITHS) - .put("ends_with", SqlOperators.ENDS_WITH) - .put("ltrim", SqlOperators.LTRIM) - .put("rtrim", SqlOperators.RTRIM) - .put("reverse", SqlOperators.REVERSE) - .put("$count_star", SqlStdOperatorTable.COUNT) - .put("max", SqlStdOperatorTable.MAX) - .put("min", SqlStdOperatorTable.MIN) - .put("avg", SqlStdOperatorTable.AVG) - .put("sum", SqlStdOperatorTable.SUM) - .put("any_value", SqlStdOperatorTable.ANY_VALUE) - .put("count", SqlStdOperatorTable.COUNT) - .put("bit_and", SqlStdOperatorTable.BIT_AND) - .put("string_agg", SqlOperators.STRING_AGG_STRING_FN) // NULL values not supported - .put("array_agg", SqlOperators.ARRAY_AGG_FN) - .put("bit_or", SqlStdOperatorTable.BIT_OR) - .put("bit_xor", SqlOperators.BIT_XOR) - .put("ceil", SqlStdOperatorTable.CEIL) - .put("floor", SqlStdOperatorTable.FLOOR) - .put("mod", SqlStdOperatorTable.MOD) - .put("timestamp", SqlOperators.TIMESTAMP_OP) - .put("$case_no_value", SqlStdOperatorTable.CASE) + // ZetaSQL functions + .put("$and", resolvedFunction -> SqlStdOperatorTable.AND) + .put("$or", resolvedFunction -> SqlStdOperatorTable.OR) + .put("$not", resolvedFunction -> SqlStdOperatorTable.NOT) + .put("$equal", resolvedFunction -> SqlStdOperatorTable.EQUALS) + .put("$not_equal", resolvedFunction -> SqlStdOperatorTable.NOT_EQUALS) + .put("$greater", resolvedFunction -> SqlStdOperatorTable.GREATER_THAN) + .put( + "$greater_or_equal", + resolvedFunction -> SqlStdOperatorTable.GREATER_THAN_OR_EQUAL) + .put("$less", resolvedFunction -> SqlStdOperatorTable.LESS_THAN) + .put("$less_or_equal", resolvedFunction -> SqlStdOperatorTable.LESS_THAN_OR_EQUAL) + .put("$like", resolvedFunction -> SqlOperators.LIKE) + .put("$is_null", resolvedFunction -> SqlStdOperatorTable.IS_NULL) + .put("$is_true", resolvedFunction -> SqlStdOperatorTable.IS_TRUE) + .put("$is_false", resolvedFunction -> SqlStdOperatorTable.IS_FALSE) + .put("$add", resolvedFunction -> SqlStdOperatorTable.PLUS) + .put("$subtract", resolvedFunction -> SqlStdOperatorTable.MINUS) + .put("$multiply", resolvedFunction -> SqlStdOperatorTable.MULTIPLY) + .put("$unary_minus", resolvedFunction -> SqlStdOperatorTable.UNARY_MINUS) + .put("$divide", resolvedFunction -> SqlStdOperatorTable.DIVIDE) + .put("concat", resolvedFunction -> SqlOperators.CONCAT) + .put("substr", resolvedFunction -> SqlOperators.SUBSTR) + .put("substring", resolvedFunction -> SqlOperators.SUBSTR) + .put("trim", resolvedFunction -> SqlOperators.TRIM) + .put("replace", resolvedFunction -> SqlOperators.REPLACE) + .put("char_length", resolvedFunction -> SqlOperators.CHAR_LENGTH) + .put("starts_with", resolvedFunction -> SqlOperators.START_WITHS) + .put("ends_with", resolvedFunction -> SqlOperators.ENDS_WITH) + .put("ltrim", resolvedFunction -> SqlOperators.LTRIM) + .put("rtrim", resolvedFunction -> SqlOperators.RTRIM) + .put("reverse", resolvedFunction -> SqlOperators.REVERSE) + .put("$count_star", resolvedFunction -> SqlStdOperatorTable.COUNT) + .put("max", resolvedFunction -> SqlStdOperatorTable.MAX) + .put("min", resolvedFunction -> SqlStdOperatorTable.MIN) + .put("avg", resolvedFunction -> SqlStdOperatorTable.AVG) + .put("sum", resolvedFunction -> SqlStdOperatorTable.SUM) + .put("any_value", resolvedFunction -> SqlStdOperatorTable.ANY_VALUE) + .put("count", resolvedFunction -> SqlStdOperatorTable.COUNT) + .put("bit_and", resolvedFunction -> SqlStdOperatorTable.BIT_AND) + .put("string_agg", SqlOperators::createStringAggOperator) // NULL values not supported + .put("array_agg", resolvedFunction -> SqlOperators.ARRAY_AGG_FN) + .put("bit_or", resolvedFunction -> SqlStdOperatorTable.BIT_OR) + .put("bit_xor", resolvedFunction -> SqlOperators.BIT_XOR) + .put("ceil", resolvedFunction -> SqlStdOperatorTable.CEIL) + .put("floor", resolvedFunction -> SqlStdOperatorTable.FLOOR) + .put("mod", resolvedFunction -> SqlStdOperatorTable.MOD) + .put("timestamp", resolvedFunction -> SqlOperators.TIMESTAMP_OP) + .put("$case_no_value", resolvedFunction -> SqlStdOperatorTable.CASE) - // if operator - IF(cond, pos, neg) can actually be mapped directly to `CASE WHEN cond - // THEN pos ELSE neg` - .put("if", SqlStdOperatorTable.CASE) + // if operator - IF(cond, pos, neg) can actually be mapped directly to `CASE WHEN cond + // THEN pos ELSE neg` + .put("if", resolvedFunction -> SqlStdOperatorTable.CASE) - // $case_no_value specializations - // all of these operators can have their operands adjusted to achieve the same thing with - // a call to $case_with_value - .put("$case_with_value", SqlStdOperatorTable.CASE) - .put("coalesce", SqlStdOperatorTable.CASE) - .put("ifnull", SqlStdOperatorTable.CASE) - .put("nullif", SqlStdOperatorTable.CASE) - .put("countif", SqlOperators.COUNTIF) - .build(); + // $case_no_value specializations + // all of these operators can have their operands adjusted to achieve the same thing + // with + // a call to $case_with_value + .put("$case_with_value", resolvedFunction -> SqlStdOperatorTable.CASE) + .put("coalesce", resolvedFunction -> SqlStdOperatorTable.CASE) + .put("ifnull", resolvedFunction -> SqlStdOperatorTable.CASE) + .put("nullif", resolvedFunction -> SqlStdOperatorTable.CASE) + .put("countif", resolvedFunction -> SqlOperators.COUNTIF) + .build(); static final Map ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR_REWRITER = ImmutableMap.builder() @@ -105,4 +113,16 @@ class SqlOperatorMappingTable { .put("nullif", new SqlNullIfOperatorRewriter()) .put("$in", new SqlInOperatorRewriter()) .build(); + + static @Nullable SqlOperator create( + ResolvedNodes.ResolvedFunctionCallBase aggregateFunctionCall) { + + Function sqlOperatorFactory = + ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get(aggregateFunctionCall.getFunction().getName()); + + if (sqlOperatorFactory != null) { + return sqlOperatorFactory.apply(aggregateFunctionCall); + } + return null; + } } diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java index 4fb24233abc8..778d3594d997 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/SqlOperators.java @@ -19,7 +19,12 @@ import static org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog.ZETASQL_FUNCTION_GROUP_NAME; +import com.google.zetasql.Value; +import com.google.zetasql.io.grpc.Status; +import com.google.zetasql.io.grpc.StatusRuntimeException; +import com.google.zetasql.resolvedast.ResolvedNodes; import java.lang.reflect.Method; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; import org.apache.beam.sdk.annotations.Internal; @@ -31,6 +36,7 @@ import org.apache.beam.sdk.extensions.sql.impl.udaf.ArrayAgg; import org.apache.beam.sdk.extensions.sql.impl.udaf.StringAgg; import org.apache.beam.sdk.extensions.sql.zetasql.DateTimeUtils; +import org.apache.beam.sdk.extensions.sql.zetasql.ZetaSqlException; import org.apache.beam.sdk.extensions.sql.zetasql.translation.impl.BeamBuiltinMethods; import org.apache.beam.sdk.extensions.sql.zetasql.translation.impl.CastFunctionImpl; import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.jdbc.JavaTypeFactoryImpl; @@ -81,12 +87,6 @@ public class SqlOperators { private static final RelDataType BIGINT = createSqlType(SqlTypeName.BIGINT, false); private static final RelDataType NULLABLE_BIGINT = createSqlType(SqlTypeName.BIGINT, true); - public static final SqlOperator STRING_AGG_STRING_FN = - createUdafOperator( - "string_agg", - x -> createTypeFactory().createSqlType(SqlTypeName.VARCHAR), - new UdafImpl<>(new StringAgg.StringAggString())); - public static final SqlOperator ARRAY_AGG_FN = createUdafOperator( "array_agg", @@ -180,6 +180,49 @@ public class SqlOperators { null, new CastFunctionImpl()); + public static SqlOperator createStringAggOperator( + ResolvedNodes.ResolvedFunctionCallBase aggregateFunctionCall) { + List args = aggregateFunctionCall.getArgumentList(); + String inputType = args.get(0).getType().typeName(); + Value delimiter = null; + if (args.size() == 2) { + ResolvedNodes.ResolvedExpr resolvedExpr = args.get(1); + if (resolvedExpr instanceof ResolvedNodes.ResolvedLiteral) { + delimiter = ((ResolvedNodes.ResolvedLiteral) resolvedExpr).getValue(); + } else { + // TODO(BEAM-13673) Add support for params + throw new ZetaSqlException( + new StatusRuntimeException( + Status.INVALID_ARGUMENT.withDescription( + String.format( + "STRING_AGG only supports ResolvedLiteral as delimiter, provided %s", + resolvedExpr.getClass().getName())))); + } + } + + switch (inputType) { + case "BYTES": + return SqlOperators.createUdafOperator( + "string_agg", + x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARBINARY), + new UdafImpl<>( + new StringAgg.StringAggByte( + delimiter == null + ? ",".getBytes(StandardCharsets.UTF_8) + : delimiter.getBytesValue().toByteArray()))); + case "STRING": + return SqlOperators.createUdafOperator( + "string_agg", + x -> SqlOperators.createTypeFactory().createSqlType(SqlTypeName.VARCHAR), + new UdafImpl<>( + new StringAgg.StringAggString( + delimiter == null ? "," : delimiter.getStringValue()))); + default: + throw new UnsupportedOperationException( + String.format("[%s] is not supported in STRING_AGG", inputType)); + } + } + /** * Create a dummy SqlFunction of type OTHER_FUNCTION from given function name and return type. * These functions will be unparsed in either {@link diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java index 7c676d998d9b..eaba01ba7826 100644 --- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java +++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java @@ -2559,6 +2559,68 @@ public void testStringAggregation() { pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES)); } + @Test + public void testStringAggregationBytes() { + String sql = + "SELECT STRING_AGG(CAST(fruit as bytes)) AS string_agg" + + " FROM UNNEST([\"apple\", \"pear\", \"banana\", \"pear\"]) AS fruit"; + PCollection stream = execute(sql); + + Schema schema = Schema.builder().addByteArrayField("bytearray_field").build(); + PAssert.that(stream) + .containsInAnyOrder( + Row.withSchema(schema) + .addValue("apple,pear,banana,pear".getBytes(StandardCharsets.UTF_8)) + .build()); + + pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES)); + } + + @Test + public void testStringAggregationDelimiter() { + String sql = + "SELECT STRING_AGG(fruit, \"&\") AS string_agg" + + " FROM UNNEST([\"apple\", \"pear\", \"banana\", \"pear\"]) AS fruit"; + PCollection stream = execute(sql); + + Schema schema = Schema.builder().addStringField("string_field").build(); + PAssert.that(stream) + .containsInAnyOrder(Row.withSchema(schema).addValue("apple&pear&banana&pear").build()); + + pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES)); + } + + @Test + public void testStringAggregationBytesDelimiter() { + String sql = + "SELECT STRING_AGG(CAST(fruit as bytes), b\"&\") AS string_agg" + + " FROM UNNEST([\"apple\", \"pear\", \"banana\", \"pear\"]) AS fruit"; + PCollection stream = execute(sql); + + Schema schema = Schema.builder().addByteArrayField("bytearray_field").build(); + PAssert.that(stream) + .containsInAnyOrder( + Row.withSchema(schema) + .addValue("apple&pear&banana&pear".getBytes(StandardCharsets.UTF_8)) + .build()); + + pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES)); + } + + @Test + public void testStringAggregationParamsDelimiter() { + String sql = "SELECT string_agg(\"s\", @separator) FROM (SELECT 1)"; + + ImmutableMap params = + ImmutableMap.builder() + .put("separator", Value.createStringValue(",")) + .build(); + + ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config); + thrown.expect(ZetaSqlException.class); // BEAM-13673 + zetaSQLQueryPlanner.convertToBeamRel(sql, params); + } + @Test @Ignore("Seeing exception in Beam, need further investigation on the cause of this failed query.") public void testNamedUNNESTJoin() {