Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add test case. add description about the root cause of bug and the PR intro it

Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
public class MergeAggregate implements RewriteRuleFactory {
private static final ImmutableSet<String> ALLOW_MERGE_AGGREGATE_FUNCTIONS =
ImmutableSet.of("min", "max", "sum", "any_value");
private Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = new HashMap<>();

@Override
public List<Rule> buildRules() {
Expand All @@ -75,7 +74,7 @@ public List<Rule> buildRules() {
*/
private Plan mergeTwoAggregate(LogicalAggregate<LogicalAggregate<Plan>> outerAgg) {
LogicalAggregate<Plan> innerAgg = outerAgg.child();

Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg);
List<NamedExpression> newOutputExpressions = outerAgg.getOutputExpressions().stream()
.map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc))
.collect(Collectors.toList());
Expand All @@ -97,6 +96,7 @@ private Plan mergeAggProjectAgg(LogicalAggregate<LogicalProject<LogicalAggregate
List<NamedExpression> outputExpressions = outerAgg.getOutputExpressions();
List<NamedExpression> replacedOutputExpressions = PlanUtils.replaceExpressionByProjections(
project.getProjects(), (List) outputExpressions);
Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg);
// rewrite agg function. e.g. max(max)
List<NamedExpression> replacedAggFunc = replacedOutputExpressions.stream()
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
Expand Down Expand Up @@ -152,10 +152,7 @@ private NamedExpression rewriteAggregateFunction(NamedExpression e,

private boolean commonCheck(LogicalAggregate<? extends Plan> outerAgg, LogicalAggregate<Plan> innerAgg,
boolean sameGroupBy, Optional<LogicalProject> projectOptional) {
innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream()
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
.collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0),
(existValue, newValue) -> existValue));
Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = getInnerAggExprIdToAggFuncMap(innerAgg);
Set<AggregateFunction> aggregateFunctions = outerAgg.getAggregateFunctions();
List<AggregateFunction> replacedAggFunctions = projectOptional.map(project ->
(List<AggregateFunction>) PlanUtils.replaceExpressionByProjections(
Expand Down Expand Up @@ -225,4 +222,11 @@ private boolean canMergeAggregateWithProject(LogicalAggregate<LogicalProject<Log
boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());
return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.of(project));
}

private Map<ExprId, AggregateFunction> getInnerAggExprIdToAggFuncMap(LogicalAggregate<Plan> innerAgg) {
return innerAgg.getOutputExpressions().stream()
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
.collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0),
(existValue, newValue) -> existValue));
}
}