From 42d36b5808acbdeedd7d2c9a0902379813f634ae Mon Sep 17 00:00:00 2001 From: englefly Date: Fri, 8 Mar 2024 17:32:33 +0800 Subject: [PATCH 01/11] cse --- .../translator/PhysicalPlanTranslator.java | 44 ++++-- .../post/CommonSubExpressionCollector.java | 59 +++++++++ .../post/CommonSubExpressionOpt.java | 125 ++++++++++++++++++ .../processor/post/PlanPostProcessors.java | 3 +- .../trees/plans/physical/PhysicalProject.java | 81 +++++++++++- .../org/apache/doris/planner/PlanNode.java | 40 +++++- .../postprocess/CommonSubExpressionTest.java | 114 ++++++++++++++++ gensrc/thrift/PlanNodes.thrift | 2 + .../suites/nereids_tpch_p0/tpch/cse.groovy | 1 + 9 files changed, 454 insertions(+), 15 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java create mode 100644 regression-test/suites/nereids_tpch_p0/tpch/cse.groovy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index f47b6826ebe2ce..b993da29fc376b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -1837,15 +1837,38 @@ public PlanFragment visitPhysicalProject(PhysicalProject project registerRewrittenSlot(project, (OlapScanNode) inputFragment.getPlanRoot()); } - List projectionExprs = project.getProjects() - .stream() - .map(e -> ExpressionTranslator.translate(e, context)) - .collect(Collectors.toList()); - List slots = project.getProjects() - .stream() - .map(NamedExpression::toSlot) - .collect(Collectors.toList()); - + PlanNode inputPlanNode = inputFragment.getPlanRoot(); + List projectionExprs = null; + List allProjectionExprs = Lists.newArrayList(); + List slots = null; + if (project.hasMultiLayerProjection()) { + int layerCount = project.getMultiLayerProjects().size(); + for (int i = 0; i < layerCount; i++) { + List layer = project.getMultiLayerProjects().get(i); + projectionExprs = layer.stream() + .map(e -> ExpressionTranslator.translate(e, context)) + .collect(Collectors.toList()); + slots = layer.stream() + .map(NamedExpression::toSlot) + .collect(Collectors.toList()); + if (i < layerCount - 1) { + inputPlanNode.addProjectList(projectionExprs); + TupleDescriptor projectionTuple = generateTupleDesc(slots, null, context); + inputPlanNode.addOutputTupleDescList(projectionTuple); + } + allProjectionExprs.addAll(projectionExprs); + } + } else { + projectionExprs = project.getProjects() + .stream() + .map(e -> ExpressionTranslator.translate(e, context)) + .collect(Collectors.toList()); + slots = project.getProjects() + .stream() + .map(NamedExpression::toSlot) + .collect(Collectors.toList()); + allProjectionExprs.addAll(projectionExprs); + } // process multicast sink if (inputFragment instanceof MultiCastPlanFragment) { MultiCastDataSink multiCastDataSink = (MultiCastDataSink) inputFragment.getSink(); @@ -1857,10 +1880,9 @@ public PlanFragment visitPhysicalProject(PhysicalProject project return inputFragment; } - PlanNode inputPlanNode = inputFragment.getPlanRoot(); List conjuncts = inputPlanNode.getConjuncts(); Set requiredSlotIdSet = Sets.newHashSet(); - for (Expr expr : projectionExprs) { + for (Expr expr : allProjectionExprs) { Expr.extractSlots(expr, requiredSlotIdSet); } Set requiredByProjectSlotIdSet = Sets.newHashSet(requiredSlotIdSet); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java new file mode 100644 index 00000000000000..5abc5f6f60ffa2 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.processor.post; + +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; + +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; + +/** + * collect common expr + */ +public class CommonSubExpressionCollector extends ExpressionVisitor { + public final Map> commonExprByDepth = new HashMap<>(); + private final Map> expressionsByDepth = new HashMap<>(); + + @Override + public Integer visit(Expression expr, Void context) { + if (expr.children().isEmpty()) { + return 0; + } + return collectCommonExpressionByDepth(expr.children().stream().map(child -> + child.accept(this, context)).reduce(Math::max).map(m -> m + 1).orElse(1), expr); + } + + private int collectCommonExpressionByDepth(int depth, Expression expr) { + Set expressions = getExpressionsFromDepthMap(depth, expressionsByDepth); + if (expressions.contains(expr)) { + Set commonExpression = getExpressionsFromDepthMap(depth, commonExprByDepth); + commonExpression.add(expr); + } + expressions.add(expr); + return depth; + } + + public static Set getExpressionsFromDepthMap( + int depth, Map> depthMap) { + depthMap.putIfAbsent(depth, new LinkedHashSet<>()); + return depthMap.get(depth); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java new file mode 100644 index 00000000000000..dfaf2de757e45e --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.processor.post; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; + +import com.google.common.collect.Lists; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Select A+B, (A+B+C)*2, (A+B+C)*3, D from T + * + * before optimize + * projection: + * Proj: A+B, (A+B+C)*2, (A+B+C)*3, D + * + * --- + * after optimize: + * Projection: List < List < Expression > > + * A+B, C, D + * A+B, A+B+C, D + * A+B, (A+B+C)*2, (A+B+C)*3, D + */ +public class CommonSubExpressionOpt extends PlanPostProcessor { + @Override + public PhysicalProject visitPhysicalProject(PhysicalProject project, CascadesContext ctx) { + + List> multiLayers = computeMultiLayerProjections( + project.getInputSlots(), project.getProjects()); + project.setMultiLayerProjects(multiLayers); + return project; + } + + private List> computeMultiLayerProjections( + Set inputSlots, List projects) { + + List> multiLayers = Lists.newArrayList(); + CommonSubExpressionCollector collector = new CommonSubExpressionCollector(); + for (Expression expr : projects) { + expr.accept(collector, null); + } + Map commonExprToAliasMap = new HashMap<>(); + collector.commonExprByDepth.values().stream().flatMap(expressions -> expressions.stream()) + .forEach(expression -> { + if (expression instanceof Alias) { + commonExprToAliasMap.put(expression, (Alias) expression); + } else { + commonExprToAliasMap.put(expression, new Alias(expression)); + } + }); + Map aliasMap = new HashMap<>(); + if (!collector.commonExprByDepth.isEmpty()) { + for (int i = 1; i <= collector.commonExprByDepth.size(); i++) { + List layer = Lists.newArrayList(); + layer.addAll(inputSlots); + Set exprsInDepth = CommonSubExpressionCollector + .getExpressionsFromDepthMap(i, collector.commonExprByDepth); + exprsInDepth.forEach(expr -> { + Expression rewritten = expr.accept(ExpressionReplacer.INSTANCE, aliasMap); + Alias alias = new Alias(rewritten); + aliasMap.put(expr, alias); + }); + layer.addAll(aliasMap.values()); + multiLayers.add(layer); + } + // final layer + List finalLayer = Lists.newArrayList(); + projects.forEach(expr -> { + Expression rewritten = expr.accept(ExpressionReplacer.INSTANCE, aliasMap); + if (rewritten instanceof Slot) { + finalLayer.add((NamedExpression) rewritten); + } else if (rewritten instanceof Alias) { + finalLayer.add(new Alias(expr.getExprId(), ((Alias) rewritten).child(), expr.getName())); + } + }); + multiLayers.add(finalLayer); + } + return multiLayers; + } + + /** + * replace sub expr by aliasMap + */ + public static class ExpressionReplacer + extends DefaultExpressionRewriter> { + public static final ExpressionReplacer INSTANCE = new ExpressionReplacer(); + + private ExpressionReplacer() { + } + + @Override + public Expression visit(Expression expr, Map replaceMap) { + if (replaceMap.containsKey(expr)) { + return replaceMap.get(expr).toSlot(); + } + return super.visit(expr, replaceMap); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java index 60c1a74445e1ff..86c8486ef45710 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java @@ -63,8 +63,9 @@ public List getProcessors() { builder.add(new MergeProjectPostProcessor()); builder.add(new RecomputeLogicalPropertiesProcessor()); builder.add(new AddOffsetIntoDistribute()); + builder.add(new CommonSubExpressionOpt()); + // DO NOT replace PLAN NODE from here builder.add(new TopNScanOpt()); - // after generate rf, DO NOT replace PLAN NODE builder.add(new FragmentProcessor()); if (!cascadesContext.getConnectContext().getSessionVariable().getRuntimeFilterMode() .toUpperCase().equals(TRuntimeFilterMode.OFF.name())) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java index af7bb950a97d96..e8472b6af23a6e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.processor.post.RuntimeFilterGenerator; import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.PhysicalProperties; +import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -41,6 +42,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; import java.util.List; import java.util.Objects; @@ -52,6 +54,12 @@ public class PhysicalProject extends PhysicalUnary implements Project { private final List projects; + //multiLayerProjects is used to extract common expressions + // projects: (A+B) * 2, (A+B) * 3 + // multiLayerProjects: + // L1: A+B as x + // L2: x*2, x*3 + private List> multiLayerProjects = Lists.newArrayList(); public PhysicalProject(List projects, LogicalProperties logicalProperties, CHILD_TYPE child) { this(projects, Optional.empty(), logicalProperties, child); @@ -227,7 +235,12 @@ public boolean pushDownRuntimeFilter(CascadesContext context, IdGenerator computeOutput() { - return projects.stream() + List output = projects; + if (! multiLayerProjects.isEmpty()) { + int layers = multiLayerProjects.size(); + output = multiLayerProjects.get(layers - 1); + } + return output.stream() .map(NamedExpression::toSlot) .collect(ImmutableList.toImmutableList()); } @@ -237,4 +250,70 @@ public PhysicalProject resetLogicalProperties() { return new PhysicalProject<>(projects, groupExpression, null, physicalProperties, statistics, child()); } + + /** + * extract common expr, set multi layer projects + */ + public void computeMultiLayerProjectsForCommonExpress() { + // hard code: select (s_suppkey + s_nationkey), 1+(s_suppkey + s_nationkey), s_name from supplier; + if (projects.size() == 3) { + if (projects.get(2) instanceof SlotReference) { + SlotReference sName = (SlotReference) projects.get(2); + if (sName.getName().equals("s_name")) { + Alias a1 = (Alias) projects.get(0); // (s_suppkey + s_nationkey) + Alias a2 = (Alias) projects.get(1); // 1+(s_suppkey + s_nationkey) + // L1: (s_suppkey + s_nationkey) as x, s_name + multiLayerProjects.add(Lists.newArrayList(projects.get(0), projects.get(2))); + List l2 = Lists.newArrayList(); + l2.add(a1.toSlot()); + Alias a3 = new Alias(a2.getExprId(), new Add(a1.toSlot(), a2.child().child(1)), a2.getName()); + l2.add(a3); + l2.add(sName); + // L2: x, (1+x) as y, s_name + multiLayerProjects.add(l2); + } + } + } + // hard code: + // select (s_suppkey + n_regionkey) + 1 as x, (s_suppkey + n_regionkey) + 2 as y + // from supplier join nation on s_nationkey=n_nationkey + // projects: x, y + // multi L1: s_suppkey, n_regionkey, (s_suppkey + n_regionkey) as z + // L2: z +1 as x, z+2 as y + if (projects.size() == 2 && projects.get(0) instanceof Alias && projects.get(1) instanceof Alias + && ((Alias) projects.get(0)).getName().equals("x") + && ((Alias) projects.get(1)).getName().equals("y")) { + Alias a0 = (Alias) projects.get(0); + Alias a1 = (Alias) projects.get(1); + Add common = (Add) a0.child().child(0); // s_suppkey + n_regionkey + List l1 = Lists.newArrayList(); + common.children().stream().forEach(child -> l1.add((SlotReference) child)); + Alias aliasOfCommon = new Alias(common); + l1.add(aliasOfCommon); + multiLayerProjects.add(l1); + Add add1 = new Add(common, a0.child().child(0).child(1)); + Alias aliasOfAdd1 = new Alias(a0.getExprId(), add1, a0.getName()); + Add add2 = new Add(common, a1.child().child(0).child(1)); + Alias aliasOfAdd2 = new Alias(a1.getExprId(), add2, a1.getName()); + List l2 = Lists.newArrayList(aliasOfAdd1, aliasOfAdd2); + multiLayerProjects.add(l2); + } + } + + public boolean hasMultiLayerProjection() { + return !multiLayerProjects.isEmpty(); + } + + public List> getMultiLayerProjects() { + return multiLayerProjects; + } + + public void setMultiLayerProjects(List> multiLayers) { + this.multiLayerProjects = multiLayers; + } + + @Override + public List getOutput() { + return computeOutput(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java index b404bc4ad3545c..dea8e3263a5579 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java @@ -59,6 +59,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; /** * Each PlanNode represents a single relational operator @@ -155,6 +156,8 @@ public abstract class PlanNode extends TreeNode implements PlanStats { protected int nereidsId = -1; private List> childrenDistributeExprLists = new ArrayList<>(); + private List outputTupleDescList = Lists.newArrayList(); + private List> projectListList = Lists.newArrayList(); protected PlanNode(PlanNodeId id, ArrayList tupleIds, String planNodeName, StatisticalType statisticalType) { @@ -536,10 +539,19 @@ protected final String getExplainString(String rootPrefix, String prefix, TExpla expBuilder.append(detailPrefix + "limit: " + limit + "\n"); } if (!CollectionUtils.isEmpty(projectList)) { - expBuilder.append(detailPrefix).append("projections: ").append(getExplainString(projectList)).append("\n"); - expBuilder.append(detailPrefix).append("project output tuple id: ") + expBuilder.append(detailPrefix).append("final projections: ").append(getExplainString(projectList)).append("\n"); + expBuilder.append(detailPrefix).append("final project output tuple id: ") .append(outputTupleDesc.getId().asInt()).append("\n"); } + if (!projectListList.isEmpty()) { + int layers = projectListList.size(); + for (int i = layers - 1; i >= 0; i--) { + expBuilder.append(detailPrefix).append("intermediate projections: ") + .append(getExplainString(projectListList.get(i))).append("\n"); + expBuilder.append(detailPrefix).append("intermediate tuple id: ") + .append(outputTupleDescList.get(i).getId().asInt()).append("\n"); + } + } if (!CollectionUtils.isEmpty(childrenDistributeExprLists)) { for (List distributeExprList : childrenDistributeExprLists) { expBuilder.append(detailPrefix).append("distribute expr lists: ") @@ -660,6 +672,22 @@ private void treeToThriftHelper(TPlan container) { } } } + if (outputTupleDescList != null && ! outputTupleDescList.isEmpty()) { + outputTupleDescList + .forEach(tupleDescriptor -> msg.addToOutputTupleIdList(tupleDescriptor.getId().asInt())); + // hashJoinNode.outputTupleDesc is null, its counterpart is vOutputTupleDesc + if (outputTupleDesc != null) { + msg.addToOutputTupleIdList(outputTupleDesc.getId().asInt()); + } + if (projectList != null) { + projectListList.forEach( + projectList -> msg.addToProjectionsList( + projectList.stream().map(expr -> expr.treeToThrift()).collect(Collectors.toList()))); + msg.addToProjectionsList(projectList.stream() + .map(expr -> expr.treeToThrift()).collect(Collectors.toList())); + } + } + if (this instanceof ExchangeNode) { msg.num_children = 0; return; @@ -1221,4 +1249,12 @@ public boolean pushDownAggNoGroupingCheckCol(FunctionCallExpr aggExpr, Column co public void setNereidsId(int nereidsId) { this.nereidsId = nereidsId; } + + public void addOutputTupleDescList(TupleDescriptor tupleDescriptor) { + outputTupleDescList.add(tupleDescriptor); + } + + public void addProjectList(List exprs) { + projectListList.add(exprs); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java new file mode 100644 index 00000000000000..c666371a46b4e1 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.postprocess; + +import org.apache.doris.nereids.processor.post.CommonSubExpressionCollector; +import org.apache.doris.nereids.processor.post.CommonSubExpressionOpt; +import org.apache.doris.nereids.rules.expression.ExpressionRewrite; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; + +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.StringType; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +public class CommonSubExpressionTest extends ExpressionRewriteTestHelper { + @Test + public void testExtractCommonExpr() { + List exprs = parseProjections("a+b, a+b+1, abs(a+b+1), a"); + CommonSubExpressionCollector collector = + new CommonSubExpressionCollector(); + exprs.forEach(expr -> collector.visit(expr, null)); + System.out.println(collector.commonExprByDepth); + Assertions.assertEquals(2, collector.commonExprByDepth.size()); + List l1 = collector.commonExprByDepth.get(Integer.valueOf(1)) + .stream().collect(Collectors.toList()); + List l2 = collector.commonExprByDepth.get(Integer.valueOf(2)) + .stream().collect(Collectors.toList()); + Assertions.assertEquals(1, l1.size()); + assertExpression(l1.get(0), "a+b"); + Assertions.assertEquals(1, l2.size()); + assertExpression(l2.get(0), "a+b+1"); + } + + @Test + public void testMultiLayers() throws Exception { + List exprs = parseProjections("a, a+b, a+b+1, abs(a+b+1), a"); + Set inputSlots = exprs.get(0).getInputSlots(); + CommonSubExpressionOpt opt = new CommonSubExpressionOpt(); + Method computeMultLayerProjectionsMethod = CommonSubExpressionOpt.class + .getDeclaredMethod("computeMultiLayerProjections", Set.class, List.class); + computeMultLayerProjectionsMethod.setAccessible(true); + List> multiLayers = (List>) computeMultLayerProjectionsMethod + .invoke(opt, inputSlots, exprs); + System.out.println(multiLayers); + } + + private void assertExpression(Expression expr, String str) { + Assertions.assertEquals(PARSER.parseExpression(str), expr); + } + + private List parseProjections(String exprList) { + List result = new ArrayList<>(); + String[] exprArray = exprList.split(","); + HashMap slotMap = new HashMap<>(); + for (String item : exprArray) { + Expression expr = PARSER.parseExpression(item); + expr = expr.accept(DataTypeAssignor.INSTANCE, slotMap); + if (expr instanceof NamedExpression) { + result.add(expr); + } else { + result.add(new Alias(expr)); + } + } + return result; + } + + public static class DataTypeAssignor extends DefaultExpressionRewriter> { + public static DataTypeAssignor INSTANCE = new DataTypeAssignor(); + + @Override + public Expression visitSlot(Slot slot, Map slotMap) { + SlotReference exitsSlot = slotMap.get(slot.getName()); + if (exitsSlot != null) { + return exitsSlot; + } else { + SlotReference slotReference = new SlotReference(slot.getName(), IntegerType.INSTANCE); + slotMap.put(slot.getName(), slotReference); + return slotReference; + } + } + } + +} diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift index 2fadcdae538795..b29700b6220d1d 100644 --- a/gensrc/thrift/PlanNodes.thrift +++ b/gensrc/thrift/PlanNodes.thrift @@ -1298,6 +1298,8 @@ struct TPlanNode { 101: optional list projections 102: optional Types.TTupleId output_tuple_id 103: optional TPartitionSortNode partition_sort_node + 104: optional list> projections_list + 105: optional list output_tuple_id_list } // A flattened representation of a tree of PlanNodes, obtained by depth-first diff --git a/regression-test/suites/nereids_tpch_p0/tpch/cse.groovy b/regression-test/suites/nereids_tpch_p0/tpch/cse.groovy new file mode 100644 index 00000000000000..a021e8cd1f7965 --- /dev/null +++ b/regression-test/suites/nereids_tpch_p0/tpch/cse.groovy @@ -0,0 +1 @@ +explain verbose select (s_nationkey + s_suppkey), (s_nationkey + s_suppkey) + 1, abs((s_nationkey + s_suppkey) + 1) from supplier; \ No newline at end of file From 72471acff940a39f6ad8b89709cc9e0210548dc0 Mon Sep 17 00:00:00 2001 From: englefly Date: Fri, 22 Mar 2024 21:43:57 +0800 Subject: [PATCH 02/11] TODO required slot --- .../nereids/glue/translator/PhysicalPlanTranslator.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index b993da29fc376b..65b9760c167161 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -1917,8 +1917,10 @@ public PlanFragment visitPhysicalProject(PhysicalProject project requiredSlotIdSet.forEach(e -> requiredExprIds.add(context.findExprId(e))); for (ExprId exprId : requiredExprIds) { SlotId slotId = ((HashJoinNode) joinNode).getHashOutputExprSlotIdMap().get(exprId); - Preconditions.checkState(slotId != null); - ((HashJoinNode) joinNode).addSlotIdToHashOutputSlotIds(slotId); + // Preconditions.checkState(slotId != null); + if (slotId != null) { + ((HashJoinNode) joinNode).addSlotIdToHashOutputSlotIds(slotId); + } } } return inputFragment; From 659e0482aa329b72b0104aecfe1e8bf5baf9f4b8 Mon Sep 17 00:00:00 2001 From: englefly Date: Sat, 23 Mar 2024 12:44:21 +0800 Subject: [PATCH 03/11] fmt --- .../src/main/java/org/apache/doris/planner/PlanNode.java | 3 ++- .../doris/nereids/postprocess/CommonSubExpressionTest.java | 6 +----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java index dea8e3263a5579..0853cb84797645 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java @@ -539,7 +539,8 @@ protected final String getExplainString(String rootPrefix, String prefix, TExpla expBuilder.append(detailPrefix + "limit: " + limit + "\n"); } if (!CollectionUtils.isEmpty(projectList)) { - expBuilder.append(detailPrefix).append("final projections: ").append(getExplainString(projectList)).append("\n"); + expBuilder.append(detailPrefix).append("final projections: ") + .append(getExplainString(projectList)).append("\n"); expBuilder.append(detailPrefix).append("final project output tuple id: ") .append(outputTupleDesc.getId().asInt()).append("\n"); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java index c666371a46b4e1..d15926d447e2d3 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java @@ -19,19 +19,15 @@ import org.apache.doris.nereids.processor.post.CommonSubExpressionCollector; import org.apache.doris.nereids.processor.post.CommonSubExpressionOpt; -import org.apache.doris.nereids.rules.expression.ExpressionRewrite; import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; - import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor; -import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.IntegerType; -import org.apache.doris.nereids.types.StringType; + import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; From 0ce465333c1ecf48dab431cc7cd438764597559f Mon Sep 17 00:00:00 2001 From: englefly Date: Mon, 25 Mar 2024 14:30:36 +0800 Subject: [PATCH 04/11] fix-ut --- .../doris/catalog/CreateFunctionTest.java | 41 ++++++++++++------- .../postprocess/CommonSubExpressionTest.java | 39 ++++++++++++++---- 2 files changed, 57 insertions(+), 23 deletions(-) diff --git a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java index 0f464ba2946b7d..c342d858fe1fb5 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java @@ -74,6 +74,7 @@ public static void teardown() { public void test() throws Exception { ConnectContext ctx = UtFrameUtils.createDefaultCtx(); ctx.getSessionVariable().setEnableNereidsPlanner(false); + ctx.getSessionVariable().enableFallbackToOriginalPlanner = true; ctx.getSessionVariable().setEnableFoldConstantByBe(false); // create database db1 createDatabase(ctx, "create database db1;"); @@ -113,8 +114,8 @@ public void test() throws Exception { Assert.assertTrue(constExprLists.get(0).get(0) instanceof FunctionCallExpr); queryStr = "select db1.id_masking(k1) from db1.tbl1"; - Assert.assertTrue( - dorisAssert.query(queryStr).explainQuery().contains("concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))")); // create alias function with cast // cast any type to decimal with specific precision and scale @@ -142,14 +143,16 @@ public void test() throws Exception { queryStr = "select db1.decimal(k3, 4, 1) from db1.tbl1;"; if (Config.enable_decimal_conversion) { - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMALV3(4, 1))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k3` AS DECIMALV3(4, 1))")); } else { - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMAL(4, 1))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k3` AS DECIMAL(4, 1))")); } // cast any type to varchar with fixed length - createFuncStr = "create alias function db1.varchar(all) with parameter(text) as " - + "cast(text as varchar(65533));"; + createFuncStr = "create alias function db1.varchar(all, int) with parameter(text, length) as " + + "cast(text as varchar(length));"; createFunctionStmt = (CreateFunctionStmt) UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx); Env.getCurrentEnv().createFunction(createFunctionStmt); @@ -172,7 +175,8 @@ public void test() throws Exception { Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral); queryStr = "select db1.varchar(k1, 4) from db1.tbl1;"; - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS VARCHAR(65533))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k1` AS VARCHAR(65533))")); // cast any type to char with fixed length createFuncStr = "create alias function db1.to_char(all, int) with parameter(text, length) as " @@ -199,7 +203,8 @@ public void test() throws Exception { Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral); queryStr = "select db1.to_char(k1, 4) from db1.tbl1;"; - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k1` AS CHARACTER")); } @Test @@ -235,8 +240,8 @@ public void testCreateGlobalFunction() throws Exception { testFunctionQuery(ctx, queryStr, false); queryStr = "select id_masking(k1) from db2.tbl1"; - Assert.assertTrue( - dorisAssert.query(queryStr).explainQuery().contains("concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))")); // 4. create alias function with cast // cast any type to decimal with specific precision and scale @@ -253,9 +258,11 @@ public void testCreateGlobalFunction() throws Exception { queryStr = "select decimal(k3, 4, 1) from db2.tbl1;"; if (Config.enable_decimal_conversion) { - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMALV3(4, 1))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k3` AS DECIMALV3(4, 1))")); } else { - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMAL(4, 1))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k3` AS DECIMAL(4, 1))")); } // 5. cast any type to varchar with fixed length @@ -271,7 +278,8 @@ public void testCreateGlobalFunction() throws Exception { testFunctionQuery(ctx, queryStr, true); queryStr = "select varchar(k1, 4) from db2.tbl1;"; - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS VARCHAR(65533))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k1` AS VARCHAR(65533))")); // 6. cast any type to char with fixed length createFuncStr = "create global alias function db2.to_char(all, int) with parameter(text, length) as " @@ -286,7 +294,8 @@ public void testCreateGlobalFunction() throws Exception { testFunctionQuery(ctx, queryStr, true); queryStr = "select to_char(k1, 4) from db2.tbl1;"; - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER)")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k1` AS CHARACTER)")); } private void testFunctionQuery(ConnectContext ctx, String queryStr, Boolean isStringLiteral) throws Exception { @@ -320,4 +329,8 @@ private void createDatabase(ConnectContext ctx, String createDbStmtStr) throws E Env.getCurrentEnv().createDb(createDbStmt); System.out.println(Env.getCurrentInternalCatalog().getDbNames()); } + + private boolean containsIgnoreCase(String str, String sub) { + return str.toLowerCase().contains(sub.toLowerCase()); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java index d15926d447e2d3..56b67e087d59ab 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java @@ -42,7 +42,7 @@ public class CommonSubExpressionTest extends ExpressionRewriteTestHelper { @Test public void testExtractCommonExpr() { - List exprs = parseProjections("a+b, a+b+1, abs(a+b+1), a"); + List exprs = parseProjections("a+b, a+b+1, abs(a+b+1), a"); CommonSubExpressionCollector collector = new CommonSubExpressionCollector(); exprs.forEach(expr -> collector.visit(expr, null)); @@ -60,7 +60,7 @@ public void testExtractCommonExpr() { @Test public void testMultiLayers() throws Exception { - List exprs = parseProjections("a, a+b, a+b+1, abs(a+b+1), a"); + List exprs = parseProjections("a, a+b, a+b+1, abs(a+b+1), a"); Set inputSlots = exprs.get(0).getInputSlots(); CommonSubExpressionOpt opt = new CommonSubExpressionOpt(); Method computeMultLayerProjectionsMethod = CommonSubExpressionOpt.class @@ -69,21 +69,32 @@ public void testMultiLayers() throws Exception { List> multiLayers = (List>) computeMultLayerProjectionsMethod .invoke(opt, inputSlots, exprs); System.out.println(multiLayers); + Assertions.assertEquals(3, multiLayers.size()); + List l0 = multiLayers.get(0); + Assertions.assertEquals(2, l0.size()); + Assertions.assertTrue(l0.contains(ExprParser.INSTANCE.parseExpression("a"))); + Assertions.assertTrue(l0.get(1) instanceof Alias); + assertExpression(l0.get(1).child(0), "a+b"); + Assertions.assertEquals(multiLayers.get(1).size(), 3); + Assertions.assertEquals(multiLayers.get(2).size(), 5); + List l2 = multiLayers.get(2); + for (int i = 0; i < 5; i++) { + Assertions.assertEquals(exprs.get(i).getExprId().asInt(), l2.get(i).getExprId().asInt()); + } + } private void assertExpression(Expression expr, String str) { - Assertions.assertEquals(PARSER.parseExpression(str), expr); + Assertions.assertEquals(ExprParser.INSTANCE.parseExpression(str), expr); } - private List parseProjections(String exprList) { - List result = new ArrayList<>(); + private List parseProjections(String exprList) { + List result = new ArrayList<>(); String[] exprArray = exprList.split(","); - HashMap slotMap = new HashMap<>(); for (String item : exprArray) { - Expression expr = PARSER.parseExpression(item); - expr = expr.accept(DataTypeAssignor.INSTANCE, slotMap); + Expression expr = ExprParser.INSTANCE.parseExpression(item); if (expr instanceof NamedExpression) { - result.add(expr); + result.add((NamedExpression) expr); } else { result.add(new Alias(expr)); } @@ -91,6 +102,16 @@ private List parseProjections(String exprList) { return result; } + public static class ExprParser { + public static ExprParser INSTANCE = new ExprParser(); + HashMap slotMap = new HashMap<>(); + + public Expression parseExpression(String str) { + Expression expr = PARSER.parseExpression(str); + return expr.accept(DataTypeAssignor.INSTANCE, slotMap); + } + } + public static class DataTypeAssignor extends DefaultExpressionRewriter> { public static DataTypeAssignor INSTANCE = new DataTypeAssignor(); From 089a62cebaa1f89963fcb202027755a810260b99 Mon Sep 17 00:00:00 2001 From: Mryange <2319153948@qq.com> Date: Sun, 24 Mar 2024 08:58:04 +0800 Subject: [PATCH 05/11] remove fake case --- regression-test/suites/nereids_tpch_p0/tpch/cse.groovy | 1 - 1 file changed, 1 deletion(-) delete mode 100644 regression-test/suites/nereids_tpch_p0/tpch/cse.groovy diff --git a/regression-test/suites/nereids_tpch_p0/tpch/cse.groovy b/regression-test/suites/nereids_tpch_p0/tpch/cse.groovy deleted file mode 100644 index a021e8cd1f7965..00000000000000 --- a/regression-test/suites/nereids_tpch_p0/tpch/cse.groovy +++ /dev/null @@ -1 +0,0 @@ -explain verbose select (s_nationkey + s_suppkey), (s_nationkey + s_suppkey) + 1, abs((s_nationkey + s_suppkey) + 1) from supplier; \ No newline at end of file From e04606d38770735ccd0aa85054f004c933335080 Mon Sep 17 00:00:00 2001 From: Mryange <2319153948@qq.com> Date: Mon, 25 Mar 2024 15:39:18 +0800 Subject: [PATCH 06/11] be upd --- be/src/exec/exec_node.cpp | 65 +++++++++++++++++++++---- be/src/exec/exec_node.h | 24 +++++++++ be/src/pipeline/pipeline_x/operator.cpp | 61 +++++++++++++++++++---- be/src/pipeline/pipeline_x/operator.h | 41 +++++++++++++++- be/src/vec/core/block.cpp | 8 +++ be/src/vec/core/block.h | 3 ++ be/src/vec/exec/scan/vscanner.cpp | 39 +++++++++++++-- be/src/vec/exec/scan/vscanner.h | 2 + 8 files changed, 217 insertions(+), 26 deletions(-) diff --git a/be/src/exec/exec_node.cpp b/be/src/exec/exec_node.cpp index ed032d0976700e..fe0c724bfbb72d 100644 --- a/be/src/exec/exec_node.cpp +++ b/be/src/exec/exec_node.cpp @@ -85,9 +85,20 @@ ExecNode::ExecNode(ObjectPool* pool, const TPlanNode& tnode, const DescriptorTbl _row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples), _resource_profile(tnode.resource_profile), _limit(tnode.limit) { - if (tnode.__isset.output_tuple_id) { - _output_row_descriptor = std::make_unique( - descs, std::vector {tnode.output_tuple_id}, std::vector {true}); + if (!tnode.output_tuple_id_list.empty()) { + // common subexpression elimination + DCHECK_EQ(tnode.output_tuple_id_list.size(), tnode.projections_list.size()); + for (auto output_tuple_id : tnode.output_tuple_id_list) { + _intermediate_output_row_descriptor.push_back(std::make_unique( + descs, std::vector {output_tuple_id}, std::vector {true})); + } + _output_row_descriptor = std::move(_intermediate_output_row_descriptor.back()); + _intermediate_output_row_descriptor.pop_back(); + } else { + if (tnode.__isset.output_tuple_id) { + _output_row_descriptor = std::make_unique( + descs, std::vector {tnode.output_tuple_id}, std::vector {true}); + } } _query_statistics = std::make_shared(); } @@ -110,9 +121,20 @@ Status ExecNode::init(const TPlanNode& tnode, RuntimeState* state) { } // create the projections expr - if (tnode.__isset.projections) { - DCHECK(tnode.__isset.output_tuple_id); - RETURN_IF_ERROR(vectorized::VExpr::create_expr_trees(tnode.projections, _projections)); + if (!tnode.projections_list.empty()) { + for (const auto& tnode_projections : tnode.projections_list) { + vectorized::VExprContextSPtrs projections; + RETURN_IF_ERROR(vectorized::VExpr::create_expr_trees(tnode_projections, projections)); + _intermediate_projections.push_back(projections); + } + _projections = _intermediate_projections.back(); + _intermediate_projections.pop_back(); + + } else { + if (tnode.__isset.projections) { + DCHECK(tnode.__isset.output_tuple_id); + RETURN_IF_ERROR(vectorized::VExpr::create_expr_trees(tnode.projections, _projections)); + } } return Status::OK(); @@ -143,7 +165,12 @@ Status ExecNode::prepare(RuntimeState* state) { RETURN_IF_ERROR(conjunct->prepare(state, intermediate_row_desc())); } - RETURN_IF_ERROR(vectorized::VExpr::prepare(_projections, state, intermediate_row_desc())); + for (int i = 0; i < _intermediate_projections.size(); i++) { + RETURN_IF_ERROR(vectorized::VExpr::prepare(_intermediate_projections[i], state, + intermediate_row_desc(i))); + } + + RETURN_IF_ERROR(vectorized::VExpr::prepare(_projections, state, projections_row_desc())); for (auto& i : _children) { RETURN_IF_ERROR(i->prepare(state)); @@ -155,6 +182,9 @@ Status ExecNode::alloc_resource(RuntimeState* state) { for (auto& conjunct : _conjuncts) { RETURN_IF_ERROR(conjunct->open(state)); } + for (auto& projections : _intermediate_projections) { + RETURN_IF_ERROR(vectorized::VExpr::open(projections, state)); + } RETURN_IF_ERROR(vectorized::VExpr::open(_projections, state)); return Status::OK(); } @@ -514,6 +544,22 @@ std::string ExecNode::get_name() { Status ExecNode::do_projections(vectorized::Block* origin_block, vectorized::Block* output_block) { SCOPED_TIMER(_exec_timer); SCOPED_TIMER(_projection_timer); + vectorized::Block input_block = *origin_block; + + const size_t rows = input_block.rows(); + if (rows == 0) { + return Status::OK(); + } + std::vector result_column_ids; + for (auto& projections : _intermediate_projections) { + result_column_ids.resize(projections.size()); + for (int i = 0; i < projections.size(); i++) { + RETURN_IF_ERROR(projections[i]->execute(&input_block, &result_column_ids[i])); + } + input_block.shuffle_columns(result_column_ids); + } + + DCHECK_EQ(rows, input_block.rows()); auto insert_column_datas = [&](auto& to, vectorized::ColumnPtr& from, size_t rows) { if (to->is_nullable() && !from->is_nullable()) { if (_keep_origin || !from->is_exclusive()) { @@ -535,7 +581,6 @@ Status ExecNode::do_projections(vectorized::Block* origin_block, vectorized::Blo using namespace vectorized; MutableBlock mutable_block = VectorizedUtils::build_mutable_mem_reuse_block(output_block, *_output_row_descriptor); - auto rows = origin_block->rows(); if (rows != 0) { auto& mutable_columns = mutable_block.mutable_columns(); @@ -549,8 +594,8 @@ Status ExecNode::do_projections(vectorized::Block* origin_block, vectorized::Blo for (int i = 0; i < mutable_columns.size(); ++i) { auto result_column_id = -1; - RETURN_IF_ERROR(_projections[i]->execute(origin_block, &result_column_id)); - auto column_ptr = origin_block->get_by_position(result_column_id) + RETURN_IF_ERROR(_projections[i]->execute(&input_block, &result_column_id)); + auto column_ptr = input_block.get_by_position(result_column_id) .column->convert_to_full_column_if_const(); //TODO: this is a quick fix, we need a new function like "change_to_nullable" to do it insert_column_datas(mutable_columns[i], column_ptr, rows); diff --git a/be/src/exec/exec_node.h b/be/src/exec/exec_node.h index f2303068437b2f..4e092b1d04a7b9 100644 --- a/be/src/exec/exec_node.h +++ b/be/src/exec/exec_node.h @@ -220,6 +220,26 @@ class ExecNode { return _output_row_descriptor ? *_output_row_descriptor : _row_descriptor; } virtual const RowDescriptor& intermediate_row_desc() const { return _row_descriptor; } + + // input expr -> intermediate_projections[0] -> intermediate_projections[1] -> intermediate_projections[2] ... -> final projections -> output expr + // prepare _row_descriptor intermediate_row_desc[0] intermediate_row_desc[1] intermediate_row_desc.end() _output_row_descriptor + + [[nodiscard]] const RowDescriptor& intermediate_row_desc(int idx) { + if (idx == 0) { + return intermediate_row_desc(); + } + DCHECK((idx - 1) < _intermediate_output_row_descriptor.size()); + return *_intermediate_output_row_descriptor[idx - 1]; + } + + [[nodiscard]] const RowDescriptor& projections_row_desc() const { + if (_intermediate_output_row_descriptor.empty()) { + return intermediate_row_desc(); + } else { + return *_intermediate_output_row_descriptor.back(); + } + } + int64_t rows_returned() const { return _num_rows_returned; } int64_t limit() const { return _limit; } bool reached_limit() const { return _limit != -1 && _num_rows_returned >= _limit; } @@ -270,6 +290,10 @@ class ExecNode { std::unique_ptr _output_row_descriptor; vectorized::VExprContextSPtrs _projections; + std::vector> _intermediate_output_row_descriptor; + // Used in common subexpression elimination to compute intermediate results. + std::vector _intermediate_projections; + /// Resource information sent from the frontend. const TBackendResourceProfile _resource_profile; diff --git a/be/src/pipeline/pipeline_x/operator.cpp b/be/src/pipeline/pipeline_x/operator.cpp index 989b1ee00a517d..cb6eff5a3f7e49 100644 --- a/be/src/pipeline/pipeline_x/operator.cpp +++ b/be/src/pipeline/pipeline_x/operator.cpp @@ -23,6 +23,8 @@ #include #include "common/logging.h" +#include "common/status.h" +#include "exec/exec_node.h" #include "pipeline/exec/aggregation_sink_operator.h" #include "pipeline/exec/aggregation_source_operator.h" #include "pipeline/exec/analytic_sink_operator.h" @@ -123,10 +125,23 @@ Status OperatorXBase::init(const TPlanNode& tnode, RuntimeState* /*state*/) { } // create the projections expr - if (tnode.__isset.projections) { - DCHECK(tnode.__isset.output_tuple_id); - RETURN_IF_ERROR(vectorized::VExpr::create_expr_trees(tnode.projections, _projections)); + + if (!tnode.projections_list.empty()) { + for (const auto& tnode_projections : tnode.projections_list) { + vectorized::VExprContextSPtrs projections; + RETURN_IF_ERROR(vectorized::VExpr::create_expr_trees(tnode_projections, projections)); + _intermediate_projections.push_back(projections); + } + _projections = _intermediate_projections.back(); + _intermediate_projections.pop_back(); + + } else { + if (tnode.__isset.projections) { + DCHECK(tnode.__isset.output_tuple_id); + RETURN_IF_ERROR(vectorized::VExpr::create_expr_trees(tnode.projections, _projections)); + } } + return Status::OK(); } @@ -134,8 +149,11 @@ Status OperatorXBase::prepare(RuntimeState* state) { for (auto& conjunct : _conjuncts) { RETURN_IF_ERROR(conjunct->prepare(state, intermediate_row_desc())); } - - RETURN_IF_ERROR(vectorized::VExpr::prepare(_projections, state, intermediate_row_desc())); + for (int i = 0; i < _intermediate_projections.size(); i++) { + RETURN_IF_ERROR(vectorized::VExpr::prepare(_intermediate_projections[i], state, + intermediate_row_desc(i))); + } + RETURN_IF_ERROR(vectorized::VExpr::prepare(_projections, state, projections_row_desc())); if (_child_x && !is_source()) { RETURN_IF_ERROR(_child_x->prepare(state)); @@ -149,6 +167,9 @@ Status OperatorXBase::open(RuntimeState* state) { RETURN_IF_ERROR(conjunct->open(state)); } RETURN_IF_ERROR(vectorized::VExpr::open(_projections, state)); + for (auto& projections : _intermediate_projections) { + RETURN_IF_ERROR(vectorized::VExpr::open(projections, state)); + } if (_child_x && !is_source()) { RETURN_IF_ERROR(_child_x->open(state)); } @@ -175,7 +196,22 @@ Status OperatorXBase::do_projections(RuntimeState* state, vectorized::Block* ori auto* local_state = state->get_local_state(operator_id()); SCOPED_TIMER(local_state->exec_time_counter()); SCOPED_TIMER(local_state->_projection_timer); + vectorized::Block input_block = *origin_block; + const size_t rows = input_block.rows(); + if (rows == 0) { + return Status::OK(); + } + std::vector result_column_ids; + for (auto& projections : _intermediate_projections) { + result_column_ids.resize(projections.size()); + for (int i = 0; i < projections.size(); i++) { + RETURN_IF_ERROR(projections[i]->execute(&input_block, &result_column_ids[i])); + } + input_block.shuffle_columns(result_column_ids); + } + + DCHECK_EQ(rows, input_block.rows()); auto insert_column_datas = [&](auto& to, vectorized::ColumnPtr& from, size_t rows) { if (to->is_nullable() && !from->is_nullable()) { if (_keep_origin || !from->is_exclusive()) { @@ -198,15 +234,13 @@ Status OperatorXBase::do_projections(RuntimeState* state, vectorized::Block* ori vectorized::MutableBlock mutable_block = vectorized::VectorizedUtils::build_mutable_mem_reuse_block(output_block, *_output_row_descriptor); - auto rows = origin_block->rows(); - if (rows != 0) { auto& mutable_columns = mutable_block.mutable_columns(); DCHECK(mutable_columns.size() == local_state->_projections.size()); for (int i = 0; i < mutable_columns.size(); ++i) { auto result_column_id = -1; - RETURN_IF_ERROR(local_state->_projections[i]->execute(origin_block, &result_column_id)); - auto column_ptr = origin_block->get_by_position(result_column_id) + RETURN_IF_ERROR(local_state->_projections[i]->execute(&input_block, &result_column_id)); + auto column_ptr = input_block.get_by_position(result_column_id) .column->convert_to_full_column_if_const(); insert_column_datas(mutable_columns[i], column_ptr, rows); } @@ -365,6 +399,15 @@ Status PipelineXLocalState::init(RuntimeState* state, LocalState for (size_t i = 0; i < _projections.size(); i++) { RETURN_IF_ERROR(_parent->_projections[i]->clone(state, _projections[i])); } + _intermediate_projections.resize(_parent->_intermediate_projections.size()); + for (int i = 0; i < _parent->_intermediate_projections.size(); i++) { + _intermediate_projections[i].resize(_parent->_intermediate_projections[i].size()); + for (int j = 0; j < _parent->_intermediate_projections[i].size(); j++) { + RETURN_IF_ERROR(_parent->_intermediate_projections[i][j]->clone( + state, _intermediate_projections[i][j])); + } + } + _rows_returned_counter = ADD_COUNTER_WITH_LEVEL(_runtime_profile, "RowsProduced", TUnit::UNIT, 1); _blocks_returned_counter = diff --git a/be/src/pipeline/pipeline_x/operator.h b/be/src/pipeline/pipeline_x/operator.h index c375efb924dcbc..8f6018138b8610 100644 --- a/be/src/pipeline/pipeline_x/operator.h +++ b/be/src/pipeline/pipeline_x/operator.h @@ -135,6 +135,9 @@ class PipelineXLocalStateBase { RuntimeState* _state = nullptr; vectorized::VExprContextSPtrs _conjuncts; vectorized::VExprContextSPtrs _projections; + // Used in common subexpression elimination to compute intermediate results. + std::vector _intermediate_projections; + bool _closed = false; vectorized::Block _origin_block; }; @@ -152,8 +155,19 @@ class OperatorXBase : public OperatorBase { _row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples), _resource_profile(tnode.resource_profile), _limit(tnode.limit) { - if (tnode.__isset.output_tuple_id) { - _output_row_descriptor.reset(new RowDescriptor(descs, {tnode.output_tuple_id}, {true})); + if (!tnode.output_tuple_id_list.empty()) { + // common subexpression elimination + for (auto output_tuple_id : tnode.output_tuple_id_list) { + _intermediate_output_row_descriptor.push_back(std::make_unique( + descs, std::vector {output_tuple_id}, std::vector {true})); + } + _output_row_descriptor = std::move(_intermediate_output_row_descriptor.back()); + _intermediate_output_row_descriptor.pop_back(); + } else { + if (tnode.__isset.output_tuple_id) { + _output_row_descriptor.reset( + new RowDescriptor(descs, {tnode.output_tuple_id}, {true})); + } } } @@ -247,6 +261,25 @@ class OperatorXBase : public OperatorBase { return _row_descriptor; } + // input expr -> intermediate_projections[0] -> intermediate_projections[1] -> intermediate_projections[2] ... -> final projections -> output expr + // prepare _row_descriptor intermediate_row_desc[0] intermediate_row_desc[1] intermediate_row_desc.end() _output_row_descriptor + + [[nodiscard]] const RowDescriptor& intermediate_row_desc(int idx) { + if (idx == 0) { + return intermediate_row_desc(); + } + DCHECK((idx - 1) < _intermediate_output_row_descriptor.size()); + return *_intermediate_output_row_descriptor[idx - 1]; + } + + [[nodiscard]] const RowDescriptor& projections_row_desc() const { + if (_intermediate_output_row_descriptor.empty()) { + return intermediate_row_desc(); + } else { + return *_intermediate_output_row_descriptor.back(); + } + } + [[nodiscard]] std::string debug_string() const override { return ""; } virtual std::string debug_string(int indentation_level = 0) const; @@ -318,6 +351,10 @@ class OperatorXBase : public OperatorBase { std::unique_ptr _output_row_descriptor = nullptr; vectorized::VExprContextSPtrs _projections; + std::vector> _intermediate_output_row_descriptor; + // Used in common subexpression elimination to compute intermediate results. + std::vector _intermediate_projections; + /// Resource information sent from the frontend. const TBackendResourceProfile _resource_profile; diff --git a/be/src/vec/core/block.cpp b/be/src/vec/core/block.cpp index c93bfb11f09d6d..ec7a6bf9256625 100644 --- a/be/src/vec/core/block.cpp +++ b/be/src/vec/core/block.cpp @@ -719,6 +719,14 @@ void Block::swap(Block&& other) noexcept { row_same_bit = std::move(other.row_same_bit); } +void Block::shuffle_columns(std::vector& result_column_ids) { + Container tmp_data; + for (const int result_column_id : result_column_ids) { + tmp_data.push_back(data[result_column_id]); + } + swap(Block {tmp_data}); +} + void Block::update_hash(SipHash& hash) const { for (size_t row_no = 0, num_rows = rows(); row_no < num_rows; ++row_no) { for (const auto& col : data) { diff --git a/be/src/vec/core/block.h b/be/src/vec/core/block.h index a9769e7b679287..9f32ac21b6d29a 100644 --- a/be/src/vec/core/block.h +++ b/be/src/vec/core/block.h @@ -234,6 +234,9 @@ class Block { void swap(Block& other) noexcept; void swap(Block&& other) noexcept; + // Shuffle columns in place based on the result_column_ids + void shuffle_columns(std::vector& result_column_ids); + // Default column size = -1 means clear all column in block // Else clear column [0, column_size) delete column [column_size, data.size) void clear_column_data(int column_size = -1) noexcept; diff --git a/be/src/vec/exec/scan/vscanner.cpp b/be/src/vec/exec/scan/vscanner.cpp index 39a9059d1d37c8..fe9f5eb86a979b 100644 --- a/be/src/vec/exec/scan/vscanner.cpp +++ b/be/src/vec/exec/scan/vscanner.cpp @@ -20,6 +20,7 @@ #include #include "common/config.h" +#include "exec/exec_node.h" #include "pipeline/exec/scan_operator.h" #include "runtime/descriptors.h" #include "util/runtime_profile.h" @@ -68,6 +69,19 @@ Status VScanner::prepare(RuntimeState* state, const VExprContextSPtrs& conjuncts } } + const auto& intermediate_projections = + _parent ? _parent->_intermediate_projections : _local_state->_intermediate_projections; + if (!intermediate_projections.empty()) { + _intermediate_projections.resize(intermediate_projections.size()); + for (int i = 0; i < intermediate_projections.size(); i++) { + _intermediate_projections[i].resize(intermediate_projections[i].size()); + for (int j = 0; j < intermediate_projections[i].size(); j++) { + RETURN_IF_ERROR(intermediate_projections[i][j]->clone( + state, _intermediate_projections[i][j])); + } + } + } + return Status::OK(); } @@ -169,14 +183,29 @@ Status VScanner::_filter_output_block(Block* block) { } Status VScanner::_do_projections(vectorized::Block* origin_block, vectorized::Block* output_block) { - auto projection_timer = _parent ? _parent->_projection_timer : _local_state->_projection_timer; - auto exec_timer = _parent ? _parent->_exec_timer : _local_state->_exec_timer; + auto& projection_timer = _parent ? _parent->_projection_timer : _local_state->_projection_timer; + auto& exec_timer = _parent ? _parent->_exec_timer : _local_state->_exec_timer; SCOPED_TIMER(exec_timer); SCOPED_TIMER(projection_timer); + vectorized::Block input_block = *origin_block; + + const size_t rows = input_block.rows(); + if (rows == 0) { + return Status::OK(); + } + std::vector result_column_ids; + for (auto& projections : _intermediate_projections) { + result_column_ids.resize(projections.size()); + for (int i = 0; i < projections.size(); i++) { + RETURN_IF_ERROR(projections[i]->execute(&input_block, &result_column_ids[i])); + } + input_block.shuffle_columns(result_column_ids); + } + + DCHECK_EQ(rows, input_block.rows()); MutableBlock mutable_block = VectorizedUtils::build_mutable_mem_reuse_block(output_block, *_output_row_descriptor); - auto rows = origin_block->rows(); if (rows != 0) { auto& mutable_columns = mutable_block.mutable_columns(); @@ -190,8 +219,8 @@ Status VScanner::_do_projections(vectorized::Block* origin_block, vectorized::Bl for (int i = 0; i < mutable_columns.size(); ++i) { auto result_column_id = -1; - RETURN_IF_ERROR(_projections[i]->execute(origin_block, &result_column_id)); - auto column_ptr = origin_block->get_by_position(result_column_id) + RETURN_IF_ERROR(_projections[i]->execute(&input_block, &result_column_id)); + auto column_ptr = input_block.get_by_position(result_column_id) .column->convert_to_full_column_if_const(); //TODO: this is a quick fix, we need a new function like "change_to_nullable" to do it if (mutable_columns[i]->is_nullable() xor column_ptr->is_nullable()) { diff --git a/be/src/vec/exec/scan/vscanner.h b/be/src/vec/exec/scan/vscanner.h index d264e99fc78306..8c205aaff5d4db 100644 --- a/be/src/vec/exec/scan/vscanner.h +++ b/be/src/vec/exec/scan/vscanner.h @@ -195,6 +195,8 @@ class VScanner { // It includes predicate in SQL and runtime filters. VExprContextSPtrs _conjuncts; VExprContextSPtrs _projections; + // Used in common subexpression elimination to compute intermediate results. + std::vector _intermediate_projections; vectorized::Block _origin_block; VExprContextSPtrs _common_expr_ctxs_push_down; From 39eed1df9ae0779b46c3ac7b7ec6016155c9f837 Mon Sep 17 00:00:00 2001 From: Mryange <2319153948@qq.com> Date: Wed, 27 Mar 2024 13:12:45 +0800 Subject: [PATCH 07/11] refine be code --- be/src/exec/exec_node.cpp | 41 ++++++++++---------- be/src/pipeline/pipeline_x/operator.cpp | 8 ++-- be/src/vec/core/block.cpp | 2 +- be/src/vec/core/block.h | 2 +- be/src/vec/exec/scan/vscanner.cpp | 50 ++++++++++++------------- 5 files changed, 49 insertions(+), 54 deletions(-) diff --git a/be/src/exec/exec_node.cpp b/be/src/exec/exec_node.cpp index fe0c724bfbb72d..f6448d939f9708 100644 --- a/be/src/exec/exec_node.cpp +++ b/be/src/exec/exec_node.cpp @@ -129,7 +129,6 @@ Status ExecNode::init(const TPlanNode& tnode, RuntimeState* state) { } _projections = _intermediate_projections.back(); _intermediate_projections.pop_back(); - } else { if (tnode.__isset.projections) { DCHECK(tnode.__isset.output_tuple_id); @@ -544,12 +543,12 @@ std::string ExecNode::get_name() { Status ExecNode::do_projections(vectorized::Block* origin_block, vectorized::Block* output_block) { SCOPED_TIMER(_exec_timer); SCOPED_TIMER(_projection_timer); - vectorized::Block input_block = *origin_block; - - const size_t rows = input_block.rows(); + const size_t rows = origin_block->rows(); if (rows == 0) { return Status::OK(); } + vectorized::Block input_block = *origin_block; + std::vector result_column_ids; for (auto& projections : _intermediate_projections) { result_column_ids.resize(projections.size()); @@ -582,27 +581,25 @@ Status ExecNode::do_projections(vectorized::Block* origin_block, vectorized::Blo MutableBlock mutable_block = VectorizedUtils::build_mutable_mem_reuse_block(output_block, *_output_row_descriptor); - if (rows != 0) { - auto& mutable_columns = mutable_block.mutable_columns(); + auto& mutable_columns = mutable_block.mutable_columns(); - if (mutable_columns.size() != _projections.size()) { - return Status::InternalError( - "Logical error during processing {}, output of projections {} mismatches with " - "exec node output {}", - this->get_name(), _projections.size(), mutable_columns.size()); - } + if (mutable_columns.size() != _projections.size()) { + return Status::InternalError( + "Logical error during processing {}, output of projections {} mismatches with " + "exec node output {}", + this->get_name(), _projections.size(), mutable_columns.size()); + } - for (int i = 0; i < mutable_columns.size(); ++i) { - auto result_column_id = -1; - RETURN_IF_ERROR(_projections[i]->execute(&input_block, &result_column_id)); - auto column_ptr = input_block.get_by_position(result_column_id) - .column->convert_to_full_column_if_const(); - //TODO: this is a quick fix, we need a new function like "change_to_nullable" to do it - insert_column_datas(mutable_columns[i], column_ptr, rows); - } - DCHECK(mutable_block.rows() == rows); - output_block->set_columns(std::move(mutable_columns)); + for (int i = 0; i < mutable_columns.size(); ++i) { + auto result_column_id = -1; + RETURN_IF_ERROR(_projections[i]->execute(&input_block, &result_column_id)); + auto column_ptr = input_block.get_by_position(result_column_id) + .column->convert_to_full_column_if_const(); + //TODO: this is a quick fix, we need a new function like "change_to_nullable" to do it + insert_column_datas(mutable_columns[i], column_ptr, rows); } + DCHECK(mutable_block.rows() == rows); + output_block->set_columns(std::move(mutable_columns)); return Status::OK(); } diff --git a/be/src/pipeline/pipeline_x/operator.cpp b/be/src/pipeline/pipeline_x/operator.cpp index cb6eff5a3f7e49..68cd947b5c388d 100644 --- a/be/src/pipeline/pipeline_x/operator.cpp +++ b/be/src/pipeline/pipeline_x/operator.cpp @@ -196,14 +196,14 @@ Status OperatorXBase::do_projections(RuntimeState* state, vectorized::Block* ori auto* local_state = state->get_local_state(operator_id()); SCOPED_TIMER(local_state->exec_time_counter()); SCOPED_TIMER(local_state->_projection_timer); - vectorized::Block input_block = *origin_block; - - const size_t rows = input_block.rows(); + const size_t rows = origin_block->rows(); if (rows == 0) { return Status::OK(); } + vectorized::Block input_block = *origin_block; + std::vector result_column_ids; - for (auto& projections : _intermediate_projections) { + for (const auto& projections : _intermediate_projections) { result_column_ids.resize(projections.size()); for (int i = 0; i < projections.size(); i++) { RETURN_IF_ERROR(projections[i]->execute(&input_block, &result_column_ids[i])); diff --git a/be/src/vec/core/block.cpp b/be/src/vec/core/block.cpp index ec7a6bf9256625..03a0d14ab1f697 100644 --- a/be/src/vec/core/block.cpp +++ b/be/src/vec/core/block.cpp @@ -719,7 +719,7 @@ void Block::swap(Block&& other) noexcept { row_same_bit = std::move(other.row_same_bit); } -void Block::shuffle_columns(std::vector& result_column_ids) { +void Block::shuffle_columns(const std::vector& result_column_ids) { Container tmp_data; for (const int result_column_id : result_column_ids) { tmp_data.push_back(data[result_column_id]); diff --git a/be/src/vec/core/block.h b/be/src/vec/core/block.h index 9f32ac21b6d29a..eb4fe43eca2faf 100644 --- a/be/src/vec/core/block.h +++ b/be/src/vec/core/block.h @@ -235,7 +235,7 @@ class Block { void swap(Block&& other) noexcept; // Shuffle columns in place based on the result_column_ids - void shuffle_columns(std::vector& result_column_ids); + void shuffle_columns(const std::vector& result_column_ids); // Default column size = -1 means clear all column in block // Else clear column [0, column_size) delete column [column_size, data.size) diff --git a/be/src/vec/exec/scan/vscanner.cpp b/be/src/vec/exec/scan/vscanner.cpp index fe9f5eb86a979b..bedd6fb9e46352 100644 --- a/be/src/vec/exec/scan/vscanner.cpp +++ b/be/src/vec/exec/scan/vscanner.cpp @@ -188,12 +188,12 @@ Status VScanner::_do_projections(vectorized::Block* origin_block, vectorized::Bl SCOPED_TIMER(exec_timer); SCOPED_TIMER(projection_timer); - vectorized::Block input_block = *origin_block; - - const size_t rows = input_block.rows(); + const size_t rows = origin_block->rows(); if (rows == 0) { return Status::OK(); } + vectorized::Block input_block = *origin_block; + std::vector result_column_ids; for (auto& projections : _intermediate_projections) { result_column_ids.resize(projections.size()); @@ -207,33 +207,31 @@ Status VScanner::_do_projections(vectorized::Block* origin_block, vectorized::Bl MutableBlock mutable_block = VectorizedUtils::build_mutable_mem_reuse_block(output_block, *_output_row_descriptor); - if (rows != 0) { - auto& mutable_columns = mutable_block.mutable_columns(); + auto& mutable_columns = mutable_block.mutable_columns(); - if (mutable_columns.size() != _projections.size()) { - return Status::InternalError( - "Logical error in scanner, output of projections {} mismatches with " - "scanner output {}", - _projections.size(), mutable_columns.size()); - } + if (mutable_columns.size() != _projections.size()) { + return Status::InternalError( + "Logical error in scanner, output of projections {} mismatches with " + "scanner output {}", + _projections.size(), mutable_columns.size()); + } - for (int i = 0; i < mutable_columns.size(); ++i) { - auto result_column_id = -1; - RETURN_IF_ERROR(_projections[i]->execute(&input_block, &result_column_id)); - auto column_ptr = input_block.get_by_position(result_column_id) - .column->convert_to_full_column_if_const(); - //TODO: this is a quick fix, we need a new function like "change_to_nullable" to do it - if (mutable_columns[i]->is_nullable() xor column_ptr->is_nullable()) { - DCHECK(mutable_columns[i]->is_nullable() && !column_ptr->is_nullable()); - reinterpret_cast(mutable_columns[i].get()) - ->insert_range_from_not_nullable(*column_ptr, 0, rows); - } else { - mutable_columns[i]->insert_range_from(*column_ptr, 0, rows); - } + for (int i = 0; i < mutable_columns.size(); ++i) { + auto result_column_id = -1; + RETURN_IF_ERROR(_projections[i]->execute(&input_block, &result_column_id)); + auto column_ptr = input_block.get_by_position(result_column_id) + .column->convert_to_full_column_if_const(); + //TODO: this is a quick fix, we need a new function like "change_to_nullable" to do it + if (mutable_columns[i]->is_nullable() xor column_ptr->is_nullable()) { + DCHECK(mutable_columns[i]->is_nullable() && !column_ptr->is_nullable()); + reinterpret_cast(mutable_columns[i].get()) + ->insert_range_from_not_nullable(*column_ptr, 0, rows); + } else { + mutable_columns[i]->insert_range_from(*column_ptr, 0, rows); } - DCHECK(mutable_block.rows() == rows); - output_block->set_columns(std::move(mutable_columns)); } + DCHECK(mutable_block.rows() == rows); + output_block->set_columns(std::move(mutable_columns)); return Status::OK(); } From 2e35d0b18cb0f3d53dd72342dc51099e6ed5dfbd Mon Sep 17 00:00:00 2001 From: Mryange <2319153948@qq.com> Date: Wed, 27 Mar 2024 15:05:46 +0800 Subject: [PATCH 08/11] upd thrift --- gensrc/thrift/PlanNodes.thrift | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift index b29700b6220d1d..d88ab993363352 100644 --- a/gensrc/thrift/PlanNodes.thrift +++ b/gensrc/thrift/PlanNodes.thrift @@ -1294,12 +1294,13 @@ struct TPlanNode { 49: optional i64 push_down_count 50: optional list> distribute_expr_lists - + // projections is final projections, which means projecting into results and materializing them into the output block. 101: optional list projections 102: optional Types.TTupleId output_tuple_id 103: optional TPartitionSortNode partition_sort_node - 104: optional list> projections_list - 105: optional list output_tuple_id_list + // Intermediate projections will not materialize into the output block. + 104: optional list> intermediate_projections_list + 105: optional list intermediate_output_tuple_id_list } // A flattened representation of a tree of PlanNodes, obtained by depth-first From c0c9a72bb4b18178ac8968af5c491a018df201cc Mon Sep 17 00:00:00 2001 From: Mryange <2319153948@qq.com> Date: Wed, 27 Mar 2024 15:15:09 +0800 Subject: [PATCH 09/11] fe code refine --- .../translator/PhysicalPlanTranslator.java | 4 +- .../org/apache/doris/planner/PlanNode.java | 45 +++++++++---------- 2 files changed, 23 insertions(+), 26 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 65b9760c167161..205cfbd25309d8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -1852,9 +1852,9 @@ public PlanFragment visitPhysicalProject(PhysicalProject project .map(NamedExpression::toSlot) .collect(Collectors.toList()); if (i < layerCount - 1) { - inputPlanNode.addProjectList(projectionExprs); + inputPlanNode.addIntermediateProjectList(projectionExprs); TupleDescriptor projectionTuple = generateTupleDesc(slots, null, context); - inputPlanNode.addOutputTupleDescList(projectionTuple); + inputPlanNode.addIntermediateOutputTupleDescList(projectionTuple); } allProjectionExprs.addAll(projectionExprs); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java index 0853cb84797645..8cc18a527a86d6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java @@ -156,8 +156,8 @@ public abstract class PlanNode extends TreeNode implements PlanStats { protected int nereidsId = -1; private List> childrenDistributeExprLists = new ArrayList<>(); - private List outputTupleDescList = Lists.newArrayList(); - private List> projectListList = Lists.newArrayList(); + private List intermediateOutputTupleDescList = Lists.newArrayList(); + private List> intermediateProjectListList = Lists.newArrayList(); protected PlanNode(PlanNodeId id, ArrayList tupleIds, String planNodeName, StatisticalType statisticalType) { @@ -544,13 +544,13 @@ protected final String getExplainString(String rootPrefix, String prefix, TExpla expBuilder.append(detailPrefix).append("final project output tuple id: ") .append(outputTupleDesc.getId().asInt()).append("\n"); } - if (!projectListList.isEmpty()) { - int layers = projectListList.size(); + if (!intermediateProjectListList.isEmpty()) { + int layers = intermediateProjectListList.size(); for (int i = layers - 1; i >= 0; i--) { expBuilder.append(detailPrefix).append("intermediate projections: ") - .append(getExplainString(projectListList.get(i))).append("\n"); + .append(getExplainString(intermediateProjectListList.get(i))).append("\n"); expBuilder.append(detailPrefix).append("intermediate tuple id: ") - .append(outputTupleDescList.get(i).getId().asInt()).append("\n"); + .append(intermediateOutputTupleDescList.get(i).getId().asInt()).append("\n"); } } if (!CollectionUtils.isEmpty(childrenDistributeExprLists)) { @@ -673,20 +673,17 @@ private void treeToThriftHelper(TPlan container) { } } } - if (outputTupleDescList != null && ! outputTupleDescList.isEmpty()) { - outputTupleDescList - .forEach(tupleDescriptor -> msg.addToOutputTupleIdList(tupleDescriptor.getId().asInt())); - // hashJoinNode.outputTupleDesc is null, its counterpart is vOutputTupleDesc - if (outputTupleDesc != null) { - msg.addToOutputTupleIdList(outputTupleDesc.getId().asInt()); - } - if (projectList != null) { - projectListList.forEach( - projectList -> msg.addToProjectionsList( - projectList.stream().map(expr -> expr.treeToThrift()).collect(Collectors.toList()))); - msg.addToProjectionsList(projectList.stream() - .map(expr -> expr.treeToThrift()).collect(Collectors.toList())); - } + + if (!intermediateOutputTupleDescList.isEmpty()) { + intermediateOutputTupleDescList + .forEach( + tupleDescriptor -> msg.addToIntermediateOutputTupleIdList(tupleDescriptor.getId().asInt())); + } + + if (!intermediateProjectListList.isEmpty()) { + intermediateProjectListList.forEach( + projectList -> msg.addToIntermediateProjectionsList( + projectList.stream().map(expr -> expr.treeToThrift()).collect(Collectors.toList()))); } if (this instanceof ExchangeNode) { @@ -1251,11 +1248,11 @@ public void setNereidsId(int nereidsId) { this.nereidsId = nereidsId; } - public void addOutputTupleDescList(TupleDescriptor tupleDescriptor) { - outputTupleDescList.add(tupleDescriptor); + public void addIntermediateOutputTupleDescList(TupleDescriptor tupleDescriptor) { + intermediateOutputTupleDescList.add(tupleDescriptor); } - public void addProjectList(List exprs) { - projectListList.add(exprs); + public void addIntermediateProjectList(List exprs) { + intermediateProjectListList.add(exprs); } } From 7e48d68254b1acd8bcff13e6ac3f992f1c6d07bd Mon Sep 17 00:00:00 2001 From: Mryange <2319153948@qq.com> Date: Wed, 27 Mar 2024 15:56:42 +0800 Subject: [PATCH 10/11] be code refine --- be/src/exec/exec_node.cpp | 43 ++++++++++++------------- be/src/exec/exec_node.h | 6 ++-- be/src/pipeline/pipeline_x/operator.cpp | 19 +++++------ be/src/pipeline/pipeline_x/operator.h | 33 +++++++++++-------- be/src/vec/core/block.cpp | 1 + 5 files changed, 52 insertions(+), 50 deletions(-) diff --git a/be/src/exec/exec_node.cpp b/be/src/exec/exec_node.cpp index f6448d939f9708..63b88aa9de2b92 100644 --- a/be/src/exec/exec_node.cpp +++ b/be/src/exec/exec_node.cpp @@ -85,21 +85,22 @@ ExecNode::ExecNode(ObjectPool* pool, const TPlanNode& tnode, const DescriptorTbl _row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples), _resource_profile(tnode.resource_profile), _limit(tnode.limit) { - if (!tnode.output_tuple_id_list.empty()) { + if (tnode.__isset.output_tuple_id) { + _output_row_descriptor = std::make_unique( + descs, std::vector {tnode.output_tuple_id}, std::vector {true}); + } + if (!tnode.intermediate_output_tuple_id_list.empty()) { + DCHECK(tnode.__isset.output_tuple_id) << " no final output tuple id"; // common subexpression elimination - DCHECK_EQ(tnode.output_tuple_id_list.size(), tnode.projections_list.size()); - for (auto output_tuple_id : tnode.output_tuple_id_list) { - _intermediate_output_row_descriptor.push_back(std::make_unique( - descs, std::vector {output_tuple_id}, std::vector {true})); - } - _output_row_descriptor = std::move(_intermediate_output_row_descriptor.back()); - _intermediate_output_row_descriptor.pop_back(); - } else { - if (tnode.__isset.output_tuple_id) { - _output_row_descriptor = std::make_unique( - descs, std::vector {tnode.output_tuple_id}, std::vector {true}); + DCHECK_EQ(tnode.intermediate_output_tuple_id_list.size(), + tnode.intermediate_projections_list.size()); + _intermediate_output_row_descriptor.reserve(tnode.intermediate_output_tuple_id_list.size()); + for (auto output_tuple_id : tnode.intermediate_output_tuple_id_list) { + _intermediate_output_row_descriptor.push_back( + RowDescriptor(descs, std::vector {output_tuple_id}, std::vector {true})); } } + _query_statistics = std::make_shared(); } @@ -121,21 +122,19 @@ Status ExecNode::init(const TPlanNode& tnode, RuntimeState* state) { } // create the projections expr - if (!tnode.projections_list.empty()) { - for (const auto& tnode_projections : tnode.projections_list) { + if (tnode.__isset.projections) { + DCHECK(tnode.__isset.output_tuple_id); + RETURN_IF_ERROR(vectorized::VExpr::create_expr_trees(tnode.projections, _projections)); + } + if (!tnode.intermediate_projections_list.empty()) { + DCHECK(tnode.__isset.projections) << "no final projections"; + _intermediate_projections.reserve(tnode.intermediate_projections_list.size()); + for (const auto& tnode_projections : tnode.intermediate_projections_list) { vectorized::VExprContextSPtrs projections; RETURN_IF_ERROR(vectorized::VExpr::create_expr_trees(tnode_projections, projections)); _intermediate_projections.push_back(projections); } - _projections = _intermediate_projections.back(); - _intermediate_projections.pop_back(); - } else { - if (tnode.__isset.projections) { - DCHECK(tnode.__isset.output_tuple_id); - RETURN_IF_ERROR(vectorized::VExpr::create_expr_trees(tnode.projections, _projections)); - } } - return Status::OK(); } diff --git a/be/src/exec/exec_node.h b/be/src/exec/exec_node.h index 4e092b1d04a7b9..10b035835d7a7f 100644 --- a/be/src/exec/exec_node.h +++ b/be/src/exec/exec_node.h @@ -229,14 +229,14 @@ class ExecNode { return intermediate_row_desc(); } DCHECK((idx - 1) < _intermediate_output_row_descriptor.size()); - return *_intermediate_output_row_descriptor[idx - 1]; + return _intermediate_output_row_descriptor[idx - 1]; } [[nodiscard]] const RowDescriptor& projections_row_desc() const { if (_intermediate_output_row_descriptor.empty()) { return intermediate_row_desc(); } else { - return *_intermediate_output_row_descriptor.back(); + return _intermediate_output_row_descriptor.back(); } } @@ -290,7 +290,7 @@ class ExecNode { std::unique_ptr _output_row_descriptor; vectorized::VExprContextSPtrs _projections; - std::vector> _intermediate_output_row_descriptor; + std::vector _intermediate_output_row_descriptor; // Used in common subexpression elimination to compute intermediate results. std::vector _intermediate_projections; diff --git a/be/src/pipeline/pipeline_x/operator.cpp b/be/src/pipeline/pipeline_x/operator.cpp index 68cd947b5c388d..4a16cb65a014be 100644 --- a/be/src/pipeline/pipeline_x/operator.cpp +++ b/be/src/pipeline/pipeline_x/operator.cpp @@ -126,22 +126,19 @@ Status OperatorXBase::init(const TPlanNode& tnode, RuntimeState* /*state*/) { // create the projections expr - if (!tnode.projections_list.empty()) { - for (const auto& tnode_projections : tnode.projections_list) { + if (tnode.__isset.projections) { + DCHECK(tnode.__isset.output_tuple_id); + RETURN_IF_ERROR(vectorized::VExpr::create_expr_trees(tnode.projections, _projections)); + } + if (!tnode.intermediate_projections_list.empty()) { + DCHECK(tnode.__isset.projections) << "no final projections"; + _intermediate_projections.reserve(tnode.intermediate_projections_list.size()); + for (const auto& tnode_projections : tnode.intermediate_projections_list) { vectorized::VExprContextSPtrs projections; RETURN_IF_ERROR(vectorized::VExpr::create_expr_trees(tnode_projections, projections)); _intermediate_projections.push_back(projections); } - _projections = _intermediate_projections.back(); - _intermediate_projections.pop_back(); - - } else { - if (tnode.__isset.projections) { - DCHECK(tnode.__isset.output_tuple_id); - RETURN_IF_ERROR(vectorized::VExpr::create_expr_trees(tnode.projections, _projections)); - } } - return Status::OK(); } diff --git a/be/src/pipeline/pipeline_x/operator.h b/be/src/pipeline/pipeline_x/operator.h index 8f6018138b8610..c3eb4d0cb51905 100644 --- a/be/src/pipeline/pipeline_x/operator.h +++ b/be/src/pipeline/pipeline_x/operator.h @@ -155,18 +155,23 @@ class OperatorXBase : public OperatorBase { _row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples), _resource_profile(tnode.resource_profile), _limit(tnode.limit) { - if (!tnode.output_tuple_id_list.empty()) { + if (tnode.__isset.output_tuple_id) { + _output_row_descriptor.reset(new RowDescriptor(descs, {tnode.output_tuple_id}, {true})); + } + if (tnode.__isset.output_tuple_id) { + _output_row_descriptor = std::make_unique( + descs, std::vector {tnode.output_tuple_id}, std::vector {true}); + } + if (!tnode.intermediate_output_tuple_id_list.empty()) { + DCHECK(tnode.__isset.output_tuple_id) << " no final output tuple id"; // common subexpression elimination - for (auto output_tuple_id : tnode.output_tuple_id_list) { - _intermediate_output_row_descriptor.push_back(std::make_unique( - descs, std::vector {output_tuple_id}, std::vector {true})); - } - _output_row_descriptor = std::move(_intermediate_output_row_descriptor.back()); - _intermediate_output_row_descriptor.pop_back(); - } else { - if (tnode.__isset.output_tuple_id) { - _output_row_descriptor.reset( - new RowDescriptor(descs, {tnode.output_tuple_id}, {true})); + DCHECK_EQ(tnode.intermediate_output_tuple_id_list.size(), + tnode.intermediate_projections_list.size()); + _intermediate_output_row_descriptor.reserve( + tnode.intermediate_output_tuple_id_list.size()); + for (auto output_tuple_id : tnode.intermediate_output_tuple_id_list) { + _intermediate_output_row_descriptor.push_back( + RowDescriptor(descs, std::vector {output_tuple_id}, std::vector {true})); } } } @@ -269,14 +274,14 @@ class OperatorXBase : public OperatorBase { return intermediate_row_desc(); } DCHECK((idx - 1) < _intermediate_output_row_descriptor.size()); - return *_intermediate_output_row_descriptor[idx - 1]; + return _intermediate_output_row_descriptor[idx - 1]; } [[nodiscard]] const RowDescriptor& projections_row_desc() const { if (_intermediate_output_row_descriptor.empty()) { return intermediate_row_desc(); } else { - return *_intermediate_output_row_descriptor.back(); + return _intermediate_output_row_descriptor.back(); } } @@ -351,7 +356,7 @@ class OperatorXBase : public OperatorBase { std::unique_ptr _output_row_descriptor = nullptr; vectorized::VExprContextSPtrs _projections; - std::vector> _intermediate_output_row_descriptor; + std::vector _intermediate_output_row_descriptor; // Used in common subexpression elimination to compute intermediate results. std::vector _intermediate_projections; diff --git a/be/src/vec/core/block.cpp b/be/src/vec/core/block.cpp index 03a0d14ab1f697..1d8d3e838015c9 100644 --- a/be/src/vec/core/block.cpp +++ b/be/src/vec/core/block.cpp @@ -721,6 +721,7 @@ void Block::swap(Block&& other) noexcept { void Block::shuffle_columns(const std::vector& result_column_ids) { Container tmp_data; + tmp_data.reserve(result_column_ids.size()); for (const int result_column_id : result_column_ids) { tmp_data.push_back(data[result_column_id]); } From ea683d4067a26b925ff4ec5f7d48c1802caa443e Mon Sep 17 00:00:00 2001 From: Mryange <2319153948@qq.com> Date: Fri, 29 Mar 2024 20:17:03 +0800 Subject: [PATCH 11/11] add some case --- .../data/tpch_sf0.1_p1/sql/cse.out | 31 +++++++++++++++++++ .../suites/tpch_sf0.1_p1/sql/cse.sql | 6 ++++ 2 files changed, 37 insertions(+) create mode 100644 regression-test/data/tpch_sf0.1_p1/sql/cse.out create mode 100644 regression-test/suites/tpch_sf0.1_p1/sql/cse.sql diff --git a/regression-test/data/tpch_sf0.1_p1/sql/cse.out b/regression-test/data/tpch_sf0.1_p1/sql/cse.out new file mode 100644 index 00000000000000..454fe1083b511e --- /dev/null +++ b/regression-test/data/tpch_sf0.1_p1/sql/cse.out @@ -0,0 +1,31 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !cse -- +1 1 3 4 +2 0 3 4 +3 1 5 6 +4 0 5 6 +5 4 10 11 +6 0 7 8 +7 3 11 12 +8 1 10 11 +9 4 14 15 +10 1 12 13 + +-- !cse_2 -- +17 1 18 19 19 +5 2 7 8 8 +1 3 4 5 5 +15 4 19 20 20 +11 5 16 17 17 +14 6 20 21 21 +23 7 30 31 31 +17 8 25 26 26 +10 9 19 20 20 +24 10 34 35 35 + +-- !cse_3 -- +12093 13093 14093 15093 + +-- !cse_4 -- +12093 13093 14093 15093 + diff --git a/regression-test/suites/tpch_sf0.1_p1/sql/cse.sql b/regression-test/suites/tpch_sf0.1_p1/sql/cse.sql new file mode 100644 index 00000000000000..a7885eb9ce349a --- /dev/null +++ b/regression-test/suites/tpch_sf0.1_p1/sql/cse.sql @@ -0,0 +1,6 @@ +select s_suppkey,n_regionkey,(s_suppkey + n_regionkey) + 1 as x, (s_suppkey + n_regionkey) + 2 as y +from supplier join nation on s_nationkey=n_nationkey order by s_suppkey , n_regionkey limit 10 ; +select s_nationkey,s_suppkey ,(s_nationkey + s_suppkey), (s_nationkey + s_suppkey) + 1, abs((s_nationkey + s_suppkey) + 1) +from supplier order by s_suppkey , s_suppkey limit 10 ; +select sum(s_nationkey),sum(s_nationkey +1 ) ,sum(s_nationkey +2 ) , sum(s_nationkey + 3 ) from supplier ; +select sum(s_nationkey),sum(s_nationkey) + count(1) ,sum(s_nationkey) + 2 * count(1) , sum(s_nationkey) + 3 * count(1) from supplier ; \ No newline at end of file