StringAgg.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.extensions.sql.impl.udaf;
 
+import java.nio.charset.StandardCharsets;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 
@@ -28,10 +29,13 @@
 @Experimental
 public class StringAgg {
 
-  /** A {@link CombineFn} that aggregates strings with comma as delimiter. */
+  /** A {@link CombineFn} that aggregates strings with a string as delimiter. */
   public static class StringAggString extends CombineFn<String, String, String> {
+    private final String delimiter;
 
-    private static final String delimiter = ",";
+    public StringAggString(String delimiter) {
+      this.delimiter = delimiter;
+    }
 
     @Override
     public String createAccumulator() {
@@ -43,7 +47,7 @@ public String addInput(String curString, String nextString) {
 
       if (!nextString.isEmpty()) {
         if (!curString.isEmpty()) {
-          curString += StringAggString.delimiter + nextString;
+          curString += delimiter + nextString;
         } else {
           curString = nextString;
         }
@@ -58,7 +62,7 @@ public String mergeAccumulators(Iterable<String> accumList) {
       for (String stringAccum : accumList) {
         if (!stringAccum.isEmpty()) {
           if (!mergeString.isEmpty()) {
-            mergeString += StringAggString.delimiter + stringAccum;
+            mergeString += delimiter + stringAccum;
           } else {
             mergeString = stringAccum;
           }
@@ -73,4 +77,51 @@ public String extractOutput(String output) {
       return output;
     }
   }
+
+  /** A {@link CombineFn} that aggregates bytes with a byte array as delimiter. */
+  public static class StringAggByte extends CombineFn<byte[], String, byte[]> {
+    private final String delimiter;
+
+    public StringAggByte(byte[] delimiter) {
+      this.delimiter = new String(delimiter, StandardCharsets.UTF_8);
+    }
+
+    @Override
+    public String createAccumulator() {
+      return "";
+    }
+
+    @Override
+    public String addInput(String mutableAccumulator, byte[] input) {
+      if (input != null) {
+        if (!mutableAccumulator.isEmpty()) {
+          mutableAccumulator += delimiter + new String(input, StandardCharsets.UTF_8);
+        } else {
+          mutableAccumulator = new String(input, StandardCharsets.UTF_8);
+        }
+      }
+      return mutableAccumulator;
+    }
+
+    @Override
+    public String mergeAccumulators(Iterable<String> accumList) {
+      String mergeString = "";
+      for (String stringAccum : accumList) {
+        if (!stringAccum.isEmpty()) {
+          if (!mergeString.isEmpty()) {
+            mergeString += delimiter + stringAccum;
+          } else {
+            mergeString = stringAccum;
+          }
+        }
+      }
+
+      return mergeString;
+    }
+
+    @Override
+    public byte[] extractOutput(String output) {
+      return output.getBytes(StandardCharsets.UTF_8);
+    }
+  }
 }
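As a quick check of the CombineFn contract these classes implement, here is a minimal sketch that drives StringAggString by hand, outside a pipeline (the fruit values and the " & " delimiter are invented for illustration):

import java.util.Arrays;
import org.apache.beam.sdk.extensions.sql.impl.udaf.StringAgg;

class StringAggDemo {
  public static void main(String[] args) {
    // Follow the CombineFn lifecycle: create an accumulator, add inputs,
    // merge partial accumulators, then extract the final output.
    StringAgg.StringAggString agg = new StringAgg.StringAggString(" & ");
    String acc = agg.createAccumulator();          // ""
    acc = agg.addInput(acc, "apple");              // "apple"
    acc = agg.addInput(acc, "banana");             // "apple & banana"
    String merged = agg.mergeAccumulators(Arrays.asList(acc, "cherry"));
    System.out.println(agg.extractOutput(merged)); // apple & banana & cherry
  }
}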
SupportedZetaSqlBuiltinFunctions.java
@@ -405,9 +405,9 @@ class SupportedZetaSqlBuiltinFunctions {
           FunctionSignatureId.FN_MAX, // max
           FunctionSignatureId.FN_MIN, // min
           FunctionSignatureId.FN_STRING_AGG_STRING, // string_agg(s)
-          // FunctionSignatureId.FN_STRING_AGG_DELIM_STRING, // string_agg(s, delim_s)
-          // FunctionSignatureId.FN_STRING_AGG_BYTES, // string_agg(b)
-          // FunctionSignatureId.FN_STRING_AGG_DELIM_BYTES, // string_agg(b, delim_b)
+          FunctionSignatureId.FN_STRING_AGG_DELIM_STRING, // string_agg(s, delim_s)
+          FunctionSignatureId.FN_STRING_AGG_BYTES, // string_agg(b)
+          FunctionSignatureId.FN_STRING_AGG_DELIM_BYTES, // string_agg(b, delim_b)
           FunctionSignatureId.FN_SUM_INT64, // sum
           FunctionSignatureId.FN_SUM_DOUBLE, // sum
           FunctionSignatureId.FN_SUM_NUMERIC, // sum
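Uncommenting these signature IDs is what exposes the delimiter and bytes variants of string_agg to the ZetaSQL planner. A hedged sketch of how a user would reach them (SqlTransform and ZetaSQLQueryPlanner are Beam's public entry points, but the input schema and query here are invented):

import org.apache.beam.sdk.extensions.sql.SqlTransform;
import org.apache.beam.sdk.extensions.sql.zetasql.ZetaSQLQueryPlanner;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;

// Given a PCollection<Row> `input` with a string field `fruit`, aggregate it
// with an explicit delimiter literal -- the string_agg(s, delim_s) form above.
PCollection<Row> result =
    input.apply(
        SqlTransform.query("SELECT STRING_AGG(fruit, ' & ') AS fruits FROM PCOLLECTION")
            .withQueryPlannerClass(ZetaSQLQueryPlanner.class));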
@@ -20,15 +20,18 @@
 import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_CAST;
 import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_COLUMN_REF;
 import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_GET_STRUCT_FIELD;
+import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_LITERAL;
 
 import com.google.zetasql.FunctionSignature;
+import com.google.zetasql.ZetaSQLResolvedNodeKind;
 import com.google.zetasql.ZetaSQLType.TypeKind;
 import com.google.zetasql.resolvedast.ResolvedNode;
 import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateFunctionCall;
 import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateScan;
 import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn;
 import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedExpr;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.stream.Collectors;
@@ -149,23 +152,27 @@ private LogicalProject convertAggregateScanInputScanToLogicalProject(
       ResolvedAggregateFunctionCall aggregateFunctionCall =
           ((ResolvedAggregateFunctionCall) resolvedComputedColumn.getExpr());
       if (aggregateFunctionCall.getArgumentList() != null
-          && aggregateFunctionCall.getArgumentList().size() == 1) {
+          && aggregateFunctionCall.getArgumentList().size() >= 1) {
         ResolvedExpr resolvedExpr = aggregateFunctionCall.getArgumentList().get(0);
 
-        // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef).
-        // TODO: user might use multiple CAST so we need to handle this rare case.
-        projects.add(
-            getExpressionConverter()
-                .convertRexNodeFromResolvedExpr(
-                    resolvedExpr,
-                    node.getInputScan().getColumnList(),
-                    input.getRowType().getFieldList(),
-                    ImmutableMap.of()));
-        fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn()));
-      } else if (aggregateFunctionCall.getArgumentList() != null
-          && aggregateFunctionCall.getArgumentList().size() > 1) {
-        throw new IllegalArgumentException(
-            aggregateFunctionCall.getFunction().getName() + " has more than one argument.");
+        for (int i = 0; i < aggregateFunctionCall.getArgumentList().size(); i++) {
+          if (i == 0) {
+            // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef).
+            // TODO: user might use multiple CAST so we need to handle this rare case.
+            projects.add(
+                getExpressionConverter()
+                    .convertRexNodeFromResolvedExpr(
+                        resolvedExpr,
+                        node.getInputScan().getColumnList(),
+                        input.getRowType().getFieldList(),
+                        ImmutableMap.of()));
+          } else {
+            projects.add(
+                getExpressionConverter()
+                    .convertRexNodeFromResolvedExpr(
+                        aggregateFunctionCall.getArgumentList().get(i)));
+          }
+          fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn()));
+        }
       }
     }
@@ -228,10 +235,7 @@ private AggregateCall convertAggCall(
             aggregateFunctionCall.getFunction().getName(), typeInference, impl);
       } else {
         // Look up builtin functions in SqlOperatorMappingTable.
-        sqlAggFunction =
-            (SqlAggFunction)
-                SqlOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get(
-                    aggregateFunctionCall.getFunction().getName());
+        sqlAggFunction = (SqlAggFunction) SqlOperatorMappingTable.create(aggregateFunctionCall);
         if (sqlAggFunction == null) {
           throw new UnsupportedOperationException(
               "Does not support ZetaSQL aggregate function: "
@@ -240,18 +244,22 @@
     }
 
     List<Integer> argList = new ArrayList<>();
-    for (ResolvedExpr expr :
-        ((ResolvedAggregateFunctionCall) computedColumn.getExpr()).getArgumentList()) {
+    ResolvedAggregateFunctionCall expr = ((ResolvedAggregateFunctionCall) computedColumn.getExpr());
+    List<ZetaSQLResolvedNodeKind.ResolvedNodeKind> resolvedNodeKinds =
+        Arrays.asList(RESOLVED_CAST, RESOLVED_COLUMN_REF, RESOLVED_GET_STRUCT_FIELD);
+    for (int i = 0; i < expr.getArgumentList().size(); i++) {
       // Throw an error if aggregate function's input isn't either a ColumnRef or a cast(ColumnRef).
       // TODO: is there a general way to handle aggregation calls conversion?
-      if (expr.nodeKind() == RESOLVED_CAST
-          || expr.nodeKind() == RESOLVED_COLUMN_REF
-          || expr.nodeKind() == RESOLVED_GET_STRUCT_FIELD) {
+      ZetaSQLResolvedNodeKind.ResolvedNodeKind resolvedNodeKind =
+          expr.getArgumentList().get(i).nodeKind();
+      if (i == 0 && resolvedNodeKinds.contains(resolvedNodeKind)) {
[Review thread on the new first-argument check]

Reviewer: What about the case of a literal as the first argument? If I'm understanding the logic right, it doesn't throw an error even though it should (instead it continues).

Author: I added both test cases and changed the condition to throw an exception when a literal is the first argument. Both tests fall into this case: https://ci-beam.apache.org/job/beam_PreCommit_SQL_Commit/4616/. The tests also failed before the condition change, but with a different exception.

Reviewer: What was the other exception? Also, do you think it would be possible to support literals as the first argument? I'm not sure why that limitation exists.

Author: Both were IndexOutOfBounds exceptions thrown while inferring types: https://ci-beam.apache.org/job/beam_PreCommit_SQL_Commit/4618/. I think it would be possible to support literals as the first argument, but we would need to verify every function/case.

Reviewer (@ibzib, Jan 12, 2022): Alright, for now let's do this:
  • Throw an UnsupportedOperationException when literals are passed as the first argument. In other words, keep the current behavior (prior to this PR).
  • Remove the new array_agg and timestamp test cases.
  • We can add support for literals later if we think it's important. I filed a separate JIRA: https://issues.apache.org/jira/browse/BEAM-13648

Author: Done!

         argList.add(columnRefOff);
+      } else if (i > 0 && resolvedNodeKind == RESOLVED_LITERAL) {
+        continue;
       } else {
         throw new UnsupportedOperationException(
-            "Aggregate function only accepts Column Reference or CAST(Column Reference) as its"
-                + " input.");
+            "Aggregate function only accepts Column Reference or CAST(Column Reference) as the first argument and "
+                + "Literals as subsequent arguments as its inputs");
       }
     }
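Taken together, the converter now projects the first argument as a column and lets literal arguments (such as a delimiter) pass through to the operator. A rough summary of which call shapes pass validation (the queries below are illustrative, not the PR's actual test cases):

// Accepted: a column reference (or CAST of one) first, literals after it.
String ok1 = "SELECT STRING_AGG(fruit) FROM PCOLLECTION";
String ok2 = "SELECT STRING_AGG(fruit, ' & ') FROM PCOLLECTION";
String ok3 = "SELECT STRING_AGG(CAST(n AS STRING), ',') FROM PCOLLECTION";

// Rejected: a literal in the first position now fails fast with
// UnsupportedOperationException (previously it surfaced as an
// IndexOutOfBounds exception during type inference; see the thread above).
String bad = "SELECT STRING_AGG('x', ' & ') FROM PCOLLECTION";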
@@ -609,7 +609,7 @@ private RexNode convertResolvedFunctionCall(
       Map<String, RexNode> outerFunctionArguments) {
     final String funGroup = functionCall.getFunction().getGroup();
     final String funName = functionCall.getFunction().getName();
-    SqlOperator op = SqlOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get(funName);
+    SqlOperator op = SqlOperatorMappingTable.create(functionCall);
     List<RexNode> operands = new ArrayList<>();
 
     if (PRE_DEFINED_WINDOW_FUNCTIONS.equals(funGroup)) {
@@ -43,7 +43,9 @@ public RexNode apply(RexBuilder rexBuilder, List<RexNode> operands) {
         operands.size() == 2, "NULLIF should have two arguments in function call.");
 
     SqlOperator op =
-        SqlOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get("$case_no_value");
+        SqlOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR
+            .get("$case_no_value")
+            .apply(null);
     List<RexNode> newOperands =
         ImmutableList.of(
             rexBuilder.makeCall(
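The PR doesn't show SqlOperatorMappingTable itself, but the new create(...) entry point and the .get("$case_no_value").apply(null) chain above suggest the table now maps each ZetaSQL function name to an operator factory rather than a fixed SqlOperator, so an operator can be specialized per call (e.g. a delimiter-aware STRING_AGG). A speculative sketch of that shape, using stand-in types rather than Beam's actual signatures:

import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Stand-ins for Calcite's SqlOperator and ZetaSQL's resolved function call;
// the real types live in Calcite and the ZetaSQL resolved AST.
interface SqlOperator {}
interface ResolvedCall { String functionName(); }

class SqlOperatorMappingTableSketch {
  // Function name -> factory, so the chosen operator can depend on the
  // concrete call (its arguments, literals, etc.) instead of being fixed.
  static final Map<String, Function<ResolvedCall, SqlOperator>> TABLE = new HashMap<>();

  static SqlOperator create(ResolvedCall call) {
    Function<ResolvedCall, SqlOperator> factory = TABLE.get(call.functionName());
    // Returning null preserves the callers' "unsupported function" error paths.
    return factory == null ? null : factory.apply(call);
  }
}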