-
Notifications
You must be signed in to change notification settings - Fork 4.5k
[BEAM-11808][BEAM-9879] Support aggregate functions with two arguments #16200
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
eb0d35d
b20ae71
cd9032d
940977d
f335217
5b44024
5302e52
cefe7fb
f4411d4
618c69f
7ee7a5d
16d9d4c
8760842
1b69b20
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,15 +20,18 @@ | |
| import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_CAST; | ||
| import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_COLUMN_REF; | ||
| import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_GET_STRUCT_FIELD; | ||
| import static com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_LITERAL; | ||
|
|
||
| import com.google.zetasql.FunctionSignature; | ||
| import com.google.zetasql.ZetaSQLResolvedNodeKind; | ||
| import com.google.zetasql.ZetaSQLType.TypeKind; | ||
| import com.google.zetasql.resolvedast.ResolvedNode; | ||
| import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateFunctionCall; | ||
| import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateScan; | ||
| import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn; | ||
| import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedExpr; | ||
| import java.util.ArrayList; | ||
| import java.util.Arrays; | ||
| import java.util.Collections; | ||
| import java.util.List; | ||
| import java.util.stream.Collectors; | ||
|
|
@@ -149,23 +152,27 @@ private LogicalProject convertAggregateScanInputScanToLogicalProject( | |
| ResolvedAggregateFunctionCall aggregateFunctionCall = | ||
| ((ResolvedAggregateFunctionCall) resolvedComputedColumn.getExpr()); | ||
| if (aggregateFunctionCall.getArgumentList() != null | ||
| && aggregateFunctionCall.getArgumentList().size() == 1) { | ||
| && aggregateFunctionCall.getArgumentList().size() >= 1) { | ||
| ResolvedExpr resolvedExpr = aggregateFunctionCall.getArgumentList().get(0); | ||
|
|
||
| // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef). | ||
| // TODO: user might use multiple CAST so we need to handle this rare case. | ||
| projects.add( | ||
| getExpressionConverter() | ||
| .convertRexNodeFromResolvedExpr( | ||
| resolvedExpr, | ||
| node.getInputScan().getColumnList(), | ||
| input.getRowType().getFieldList(), | ||
| ImmutableMap.of())); | ||
| fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn())); | ||
| } else if (aggregateFunctionCall.getArgumentList() != null | ||
| && aggregateFunctionCall.getArgumentList().size() > 1) { | ||
| throw new IllegalArgumentException( | ||
| aggregateFunctionCall.getFunction().getName() + " has more than one argument."); | ||
| for (int i = 0; i < aggregateFunctionCall.getArgumentList().size(); i++) { | ||
| if (i == 0) { | ||
| // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef). | ||
| // TODO: user might use multiple CAST so we need to handle this rare case. | ||
| projects.add( | ||
| getExpressionConverter() | ||
| .convertRexNodeFromResolvedExpr( | ||
| resolvedExpr, | ||
| node.getInputScan().getColumnList(), | ||
| input.getRowType().getFieldList(), | ||
| ImmutableMap.of())); | ||
| } else { | ||
| projects.add( | ||
| getExpressionConverter() | ||
| .convertRexNodeFromResolvedExpr( | ||
| aggregateFunctionCall.getArgumentList().get(i))); | ||
| } | ||
| fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn())); | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -228,10 +235,7 @@ private AggregateCall convertAggCall( | |
| aggregateFunctionCall.getFunction().getName(), typeInference, impl); | ||
| } else { | ||
| // Look up builtin functions in SqlOperatorMappingTable. | ||
| sqlAggFunction = | ||
| (SqlAggFunction) | ||
| SqlOperatorMappingTable.ZETASQL_FUNCTION_TO_CALCITE_SQL_OPERATOR.get( | ||
| aggregateFunctionCall.getFunction().getName()); | ||
| sqlAggFunction = (SqlAggFunction) SqlOperatorMappingTable.create(aggregateFunctionCall); | ||
| if (sqlAggFunction == null) { | ||
| throw new UnsupportedOperationException( | ||
| "Does not support ZetaSQL aggregate function: " | ||
|
|
@@ -240,18 +244,22 @@ private AggregateCall convertAggCall( | |
| } | ||
|
|
||
| List<Integer> argList = new ArrayList<>(); | ||
| for (ResolvedExpr expr : | ||
| ((ResolvedAggregateFunctionCall) computedColumn.getExpr()).getArgumentList()) { | ||
| ResolvedAggregateFunctionCall expr = ((ResolvedAggregateFunctionCall) computedColumn.getExpr()); | ||
| List<ZetaSQLResolvedNodeKind.ResolvedNodeKind> resolvedNodeKinds = | ||
| Arrays.asList(RESOLVED_CAST, RESOLVED_COLUMN_REF, RESOLVED_GET_STRUCT_FIELD); | ||
| for (int i = 0; i < expr.getArgumentList().size(); i++) { | ||
| // Throw an error if aggregate function's input isn't either a ColumnRef or a cast(ColumnRef). | ||
| // TODO: is there a general way to handle aggregation calls conversion? | ||
| if (expr.nodeKind() == RESOLVED_CAST | ||
| || expr.nodeKind() == RESOLVED_COLUMN_REF | ||
| || expr.nodeKind() == RESOLVED_GET_STRUCT_FIELD) { | ||
| ZetaSQLResolvedNodeKind.ResolvedNodeKind resolvedNodeKind = | ||
| expr.getArgumentList().get(i).nodeKind(); | ||
| if (i == 0 && resolvedNodeKinds.contains(resolvedNodeKind)) { | ||
|
||
| argList.add(columnRefOff); | ||
| } else if (i > 0 && resolvedNodeKind == RESOLVED_LITERAL) { | ||
| continue; | ||
| } else { | ||
| throw new UnsupportedOperationException( | ||
| "Aggregate function only accepts Column Reference or CAST(Column Reference) as its" | ||
| + " input."); | ||
| "Aggregate function only accepts Column Reference or CAST(Column Reference) as the first argument and " | ||
| + "Literals as subsequent arguments as its inputs"); | ||
| } | ||
| } | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.