Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,13 @@ public PlanFragment visitPhysicalHashAggregate(
// 1. generate slot reference for each group expression
List<SlotReference> groupSlots = collectGroupBySlots(groupByExpressions, outputExpressions);
ArrayList<Expr> execGroupingExpressions = groupByExpressions.stream()
.map(e -> ExpressionTranslator.translate(e, context))
.map(e -> {
Expr result = ExpressionTranslator.translate(e, context);
if (result == null) {
throw new RuntimeException("translate " + e + " failed");
}
return result;
})
.collect(Collectors.toCollection(ArrayList::new));
// 2. collect agg expressions and generate agg function to slot reference map
List<Slot> aggFunctionOutput = Lists.newArrayList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
* create project under aggregate to enable CSE
Expand Down Expand Up @@ -102,8 +103,38 @@ public Plan visitPhysicalHashAggregate(PhysicalHashAggregate<? extends Plan> agg
}

if (aggregate.child() instanceof PhysicalProject) {
List<NamedExpression> newProjections = Lists.newArrayList();
// do column prune
// case 1:
// original plan
// agg(groupKey[C+1, abs(C+1)]
// -->project(A+B as C)
//
// "A+B as C" should be reserved
// new plan
// agg(groupKey=[D, abs(D)])
// -->project(A+B as C, C+1 as D)
// case 2:
// original plan
// agg(groupKey[A+1, abs(A+1)], output[sum(B)])
// --> project(A, B)
// "A+1" is extracted, we have
// plan1:
// agg(groupKey[X, abs(X)], output[sum(B)])
// --> project(A, B, A+1 as X)
// then column prune(A should be pruned, because it is not used directly by AGG)
// we have plan2:
// agg(groupKey[X, abs(X)], output[sum(B)])
// -->project(B, A+1 as X)
PhysicalProject<? extends Plan> project = (PhysicalProject<? extends Plan>) aggregate.child();
List<NamedExpression> newProjections = Lists.newArrayList(project.getProjects());
Set<Slot> newInputSlots = aggOutputReplaced.stream()
.flatMap(expr -> expr.getInputSlots().stream())
.collect(Collectors.toSet());
for (NamedExpression expr : project.getProjects()) {
if (!(expr instanceof SlotReference) || newInputSlots.contains(expr)) {
newProjections.add(expr);
}
}
newProjections.addAll(cseCandidates.values());
project = project.withProjectionsAndChild(newProjections, (Plan) project.child());
aggregate = (PhysicalHashAggregate<? extends Plan>) aggregate
Expand Down
Loading