From 7ccf613459c1e6d13cc20be0241d6711f142babd Mon Sep 17 00:00:00 2001 From: 924060929 Date: Tue, 25 Nov 2025 21:12:56 +0800 Subject: [PATCH] opt --- .../rewrite/AccessPathPlanCollector.java | 25 ++++++++++ .../rules/rewrite/PushDownProject.java | 47 ++++++++++--------- .../rules/rewrite/PruneNestedColumnTest.java | 21 ++++++--- 3 files changed, 65 insertions(+), 28 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java index 514f7bb1e8cfc1..ed253167b385c0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java @@ -19,7 +19,9 @@ import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.rules.rewrite.AccessPathExpressionCollector.CollectAccessPathResult; +import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; @@ -28,6 +30,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalFileScan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation; import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor; @@ -53,6 +56,28 @@ public Map> collect(Plan root, StatementCont return scanSlotToAccessPaths; } + @Override + public Void visitLogicalProject(LogicalProject project, StatementContext context) { + AccessPathExpressionCollector exprCollector + = new AccessPathExpressionCollector(context, allSlotToAccessPaths, false); + for (NamedExpression output : project.getProjects()) { + // e.g. select struct_element(s, 'city') from (select s from tbl)a; + // we will not treat the inner `s` access all path + if (output instanceof Slot && allSlotToAccessPaths.containsKey(output.getExprId().asInt())) { + continue; + } else if (output instanceof Alias && output.child(0) instanceof Slot + && allSlotToAccessPaths.containsKey(output.getExprId().asInt())) { + Slot innerSlot = (Slot) output.child(0); + Collection outerSlotAccessPaths = allSlotToAccessPaths.get( + output.getExprId().asInt()); + allSlotToAccessPaths.putAll(innerSlot.getExprId().asInt(), outerSlotAccessPaths); + } else { + exprCollector.collect(output); + } + } + return project.child().accept(this, context); + } + @Override public Void visitLogicalFilter(LogicalFilter filter, StatementContext context) { boolean bottomFilter = filter.child().arity() == 0; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java index 8b1fbaac8aec9e..9fbc9413b29c9c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java @@ -48,7 +48,6 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.function.Function; /** push down project if the expression instance of PreferPushDownProject */ @@ -320,13 +319,13 @@ private List replaceSlot( private static class PushdownProjectHelper { private final Plan plan; private final StatementContext statementContext; - private final Map> exprToChildAndSlot; + private final Map oldExprToNewExpr; private final Multimap childToPushDownProjects; public PushdownProjectHelper(StatementContext statementContext, Plan plan) { this.statementContext = statementContext; this.plan = plan; - this.exprToChildAndSlot = new LinkedHashMap<>(); + this.oldExprToNewExpr = new LinkedHashMap<>(); this.childToPushDownProjects = ArrayListMultimap.create(); } @@ -357,32 +356,36 @@ public , E extends Expression> Pair pushDown } public Optional pushDownExpression(E expression) { - if (!(expression instanceof PreferPushDownProject - || (expression instanceof Alias && expression.child(0) instanceof PreferPushDownProject))) { + if (!expression.containsType(PreferPushDownProject.class)) { return Optional.empty(); } - Pair existPushdown = exprToChildAndSlot.get(expression); + Expression existPushdown = oldExprToNewExpr.get(expression); if (existPushdown != null) { - return Optional.of((E) existPushdown.first); + return Optional.of((E) existPushdown); } - Alias pushDownAlias = null; - if (expression instanceof Alias) { - pushDownAlias = (Alias) expression; - } else { - pushDownAlias = new Alias(statementContext.getNextExprId(), expression); - } - - Set inputSlots = expression.getInputSlots(); - for (Plan child : plan.children()) { - if (child.getOutputSet().containsAll(inputSlots)) { - Slot remaimSlot = pushDownAlias.toSlot(); - exprToChildAndSlot.put(expression, Pair.of(remaimSlot, child)); - childToPushDownProjects.put(child, pushDownAlias); - return Optional.of((E) remaimSlot); + Expression newExpression = expression.rewriteDownShortCircuit(e -> { + if (e instanceof PreferPushDownProject) { + List children = plan.children(); + for (int i = 0; i < children.size(); i++) { + Plan child = children.get(i); + if (child.getOutputSet().containsAll(e.getInputSlots())) { + Alias alias = new Alias(statementContext.getNextExprId(), e); + Slot slot = alias.toSlot(); + childToPushDownProjects.put(child, alias); + return slot; + } + } } + return e; + }); + + if (newExpression != expression) { + oldExprToNewExpr.put(expression, newExpression); + return Optional.of((E) newExpression); + } else { + return Optional.empty(); } - return Optional.empty(); } public List buildNewChildren() { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java index 9c6674deb3df0e..96d80793d733c2 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java @@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Coalesce; import org.apache.doris.nereids.trees.expressions.functions.scalar.StructElement; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEConsumer; @@ -67,6 +68,7 @@ public void createTable() throws Exception { createTable("create table tbl(\n" + " id int,\n" + + " value int,\n" + " s struct<\n" + " city: string,\n" + " data: array", ImmutableList.of(path("s", "city")), ImmutableList.of() ); - assertColumn("select * from (select struct_element(s, 'city') from tbl union all select null)a", + assertColumn("select * from (select coalesce(struct_element(s, 'city'), 'abc') from tbl union all select null)a", "struct", ImmutableList.of(path("s", "city")), ImmutableList.of() @@ -402,7 +405,7 @@ public void testCteAndUnion() throws Throwable { @Test public void testPushDownThroughJoin() { PlanChecker.from(connectContext) - .analyze("select struct_element(s, 'city') from (select * from tbl)a join (select 100 id, 'f1' name)b on a.id=b.id") + .analyze("select coalesce(struct_element(s, 'city'), 'abc') from (select * from tbl)a join (select 100 id, 'f1' name)b on a.id=b.id") .rewrite() .matches( logicalResultSink( @@ -421,7 +424,9 @@ public void testPushDownThroughJoin() { logicalOneRowRelation() ) ).when(p -> { - Assertions.assertTrue(p.getProjects().size() == 1 && p.getProjects().get(0) instanceof SlotReference); + Assertions.assertTrue(p.getProjects().size() == 1 && p.getProjects().get(0) instanceof Alias + && p.getProjects().get(0).child(0) instanceof Coalesce + && p.getProjects().get(0).child(0).child(0) instanceof Slot); return true; }) ) @@ -474,7 +479,9 @@ public void testPushDownThroughWindow() { }) ) ).when(p -> { - Assertions.assertTrue(p.getProjects().size() == 2 && p.getProjects().get(0) instanceof SlotReference); + Assertions.assertTrue(p.getProjects().size() == 2 + && (p.getProjects().get(0) instanceof SlotReference + || (p.getProjects().get(0) instanceof Alias && p.getProjects().get(0).child(0) instanceof SlotReference))); return true; }) ) @@ -504,7 +511,9 @@ public void testPushDownThroughPartitionTopN() { ) ) ).when(p -> { - Assertions.assertTrue(p.getProjects().size() == 2 && p.getProjects().get(0) instanceof SlotReference); + Assertions.assertTrue(p.getProjects().size() == 2 + && (p.getProjects().get(0) instanceof SlotReference + || p.getProjects().get(0) instanceof Alias && p.getProjects().get(0).child(0) instanceof SlotReference)); return true; }) )