From 7763f0022d80a05cf443857deb6b5bd726f40e01 Mon Sep 17 00:00:00 2001 From: minghong Date: Wed, 15 Jan 2025 18:01:50 +0800 Subject: [PATCH] [opt](nereids)prune unused column after push down common column from agg (#46627) ### What problem does this PR solve? after extracting common expressions for agg, the underlying projection may project redundant columns. for example: original plan Agg(groupkey=[A+B, A+B+1]) --> project(A, B) after extracting, "A+B as C" is detected as a common expression, and the plan becomes Agg(groupKey=[C, C+1]) -->project(A, B, A+B as C) here A, B should not be projected, since they are not used any more. so the optimal plan is Agg(groupKey=[C, C+1]) -->project(A+B as C) Related PR: #40473 --- .../translator/PhysicalPlanTranslator.java | 8 ++++- .../ProjectAggregateExpressionsForCse.java | 33 ++++++++++++++++++- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 28b14398c86deb..ef65d755795ac3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -959,7 +959,13 @@ public PlanFragment visitPhysicalHashAggregate( // 1. generate slot reference for each group expression List groupSlots = collectGroupBySlots(groupByExpressions, outputExpressions); ArrayList execGroupingExpressions = groupByExpressions.stream() - .map(e -> ExpressionTranslator.translate(e, context)) + .map(e -> { + Expr result = ExpressionTranslator.translate(e, context); + if (result == null) { + throw new RuntimeException("translate " + e + " failed"); + } + return result; + }) .collect(Collectors.toCollection(ArrayList::new)); // 2. collect agg expressions and generate agg function to slot reference map List aggFunctionOutput = Lists.newArrayList(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/ProjectAggregateExpressionsForCse.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/ProjectAggregateExpressionsForCse.java index a8038ab30b04ae..d7a13148c1040d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/ProjectAggregateExpressionsForCse.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/ProjectAggregateExpressionsForCse.java @@ -46,6 +46,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; /** * create project under aggregate to enable CSE @@ -102,8 +103,38 @@ public Plan visitPhysicalHashAggregate(PhysicalHashAggregate agg } if (aggregate.child() instanceof PhysicalProject) { + List newProjections = Lists.newArrayList(); + // do column prune + // case 1: + // original plan + // agg(groupKey[C+1, abs(C+1)] + // -->project(A+B as C) + // + // "A+B as C" should be reserved + // new plan + // agg(groupKey=[D, abs(D)]) + // -->project(A+B as C, C+1 as D) + // case 2: + // original plan + // agg(groupKey[A+1, abs(A+1)], output[sum(B)]) + // --> project(A, B) + // "A+1" is extracted, we have + // plan1: + // agg(groupKey[X, abs(X)], output[sum(B)]) + // --> project(A, B, A+1 as X) + // then column prune(A should be pruned, because it is not used directly by AGG) + // we have plan2: + // agg(groupKey[X, abs(X)], output[sum(B)]) + // -->project(B, A+1 as X) PhysicalProject project = (PhysicalProject) aggregate.child(); - List newProjections = Lists.newArrayList(project.getProjects()); + Set newInputSlots = aggOutputReplaced.stream() + .flatMap(expr -> expr.getInputSlots().stream()) + .collect(Collectors.toSet()); + for (NamedExpression expr : project.getProjects()) { + if (!(expr instanceof SlotReference) || newInputSlots.contains(expr)) { + newProjections.add(expr); + } + } newProjections.addAll(cseCandidates.values()); project = project.withProjectionsAndChild(newProjections, (Plan) project.child()); aggregate = (PhysicalHashAggregate) aggregate