Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,16 @@ public static <C extends Expr> boolean containsAggregate(List<? extends Expr> in
return false;
}

public static void extractSlots(Expr root, Set<SlotId> slotIdSet) {
if (root instanceof SlotRef) {
slotIdSet.add(((SlotRef) root).getDesc().getId());
return;
}
for (Expr child : root.getChildren()) {
extractSlots(child, slotIdSet);
}
}

/**
* Returns an analyzed clone of 'this' with exprs substituted according to smap.
* Removes implicit casts and analysis state while cloning/substituting exprs within
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,52 @@

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

import java.util.stream.Collectors;
import java.util.List;
import java.util.stream.Stream;

/**
* the sort node will create new slots for order by keys if the order by keys is not in the output
* so need create a project above sort node to prune the unnecessary order by keys. This means the
* Tuple slots size is difference to PhysicalSort.output.size. If not prune and hide the order key,
* the upper plan node will see the temporary slots and treat as output, and then translate failed.
* This is trick, we should add sort output tuple to ensure the tuple slot size is equals, but it
* has large workload. I think we should refactor the PhysicalPlanTranslator in the future, and
* process PhysicalProject(output)/PhysicalDistribute more general.
* SortNode on BE always output order keys because BE needs them to do merge sort. So we normalize LogicalSort as BE
* expected to materialize order key before sort by bottom project and then prune the useless column after sort by
* top project.
*/
public class NormalizeSort extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalSort()
.when(sort -> !sort.isNormalized() && !sort.getOutputSet()
.containsAll(sort.getOrderKeys().stream()
.map(orderKey -> orderKey.getExpr()).collect(Collectors.toSet())))
return logicalSort().whenNot(sort -> sort.getOrderKeys().stream()
.map(OrderKey::getExpr).allMatch(Slot.class::isInstance))
.then(sort -> {
return new LogicalProject(sort.getOutput(), ImmutableList.of(), false,
sort.withNormalize(true));
List<NamedExpression> newProjects = Lists.newArrayList();
List<OrderKey> newOrderKeys = sort.getOrderKeys().stream()
.map(orderKey -> {
Expression expr = orderKey.getExpr();
if (!(expr instanceof Slot)) {
Alias alias = new Alias(expr, expr.toSql());
newProjects.add(alias);
expr = alias.toSlot();
}
return orderKey.withExpression(expr);
}).collect(ImmutableList.toImmutableList());
List<NamedExpression> bottomProjections = Stream.concat(
sort.child().getOutput().stream(),
newProjects.stream()
).collect(ImmutableList.toImmutableList());
List<NamedExpression> topProjections = sort.getOutput().stream()
.map(NamedExpression.class::cast)
.collect(ImmutableList.toImmutableList());
return new LogicalProject<>(topProjections, sort.withOrderKeysAndChild(newOrderKeys,
new LogicalProject<>(bottomProjections, sort.child())));
}).toRule(RuleType.NORMALIZE_SORT);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,29 +46,17 @@ public class LogicalSort<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_TYP

private final List<OrderKey> orderKeys;

private final boolean normalized;

public LogicalSort(List<OrderKey> orderKeys, CHILD_TYPE child) {
this(orderKeys, Optional.empty(), Optional.empty(), child);
}

public LogicalSort(List<OrderKey> orderKeys, CHILD_TYPE child, boolean normalized) {
this(orderKeys, Optional.empty(), Optional.empty(), child, normalized);
}

/**
* Constructor for LogicalSort.
*/
public LogicalSort(List<OrderKey> orderKeys, Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties, CHILD_TYPE child) {
this(orderKeys, groupExpression, logicalProperties, child, false);
}

public LogicalSort(List<OrderKey> orderKeys, Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties, CHILD_TYPE child, boolean normalized) {
super(PlanType.LOGICAL_SORT, groupExpression, logicalProperties, child);
this.orderKeys = ImmutableList.copyOf(Objects.requireNonNull(orderKeys, "orderKeys can not be null"));
this.normalized = normalized;
}

@Override
Expand All @@ -80,10 +68,6 @@ public List<OrderKey> getOrderKeys() {
return orderKeys;
}

public boolean isNormalized() {
return normalized;
}

@Override
public String toString() {
return Utils.toSqlString("LogicalSort[" + id.asInt() + "]",
Expand All @@ -98,7 +82,7 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) {
return false;
}
LogicalSort that = (LogicalSort) o;
LogicalSort<?> that = (LogicalSort<?>) o;
return Objects.equals(orderKeys, that.orderKeys);
}

Expand All @@ -122,30 +106,27 @@ public List<? extends Expression> getExpressions() {
@Override
public LogicalSort<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
return new LogicalSort<>(orderKeys, children.get(0), normalized);
return new LogicalSort<>(orderKeys, children.get(0));
}

@Override
public LogicalSort<Plan> withGroupExpression(Optional<GroupExpression> groupExpression) {
return new LogicalSort<>(orderKeys, groupExpression, Optional.of(getLogicalProperties()), child(),
normalized);
return new LogicalSort<>(orderKeys, groupExpression, Optional.of(getLogicalProperties()), child());
}

@Override
public Plan withGroupExprLogicalPropChildren(Optional<GroupExpression> groupExpression,
public LogicalSort<Plan> withGroupExprLogicalPropChildren(Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties, List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
return new LogicalSort<>(orderKeys, groupExpression, logicalProperties, children.get(0),
normalized);
return new LogicalSort<>(orderKeys, groupExpression, logicalProperties, children.get(0));
}

public LogicalSort<Plan> withOrderKeys(List<OrderKey> orderKeys) {
return new LogicalSort<>(orderKeys, Optional.empty(),
Optional.of(getLogicalProperties()), child(), false);
Optional.of(getLogicalProperties()), child());
}

public LogicalSort<Plan> withNormalize(boolean orderKeysPruned) {
return new LogicalSort<>(orderKeys, groupExpression, Optional.of(getLogicalProperties()), child(),
orderKeysPruned);
public LogicalSort<Plan> withOrderKeysAndChild(List<OrderKey> orderKeys, Plan child) {
return new LogicalSort<>(orderKeys, child);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,8 @@

package org.apache.doris.planner;

import org.apache.doris.analysis.Analyzer;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.SlotDescriptor;
import org.apache.doris.analysis.SlotId;
import org.apache.doris.analysis.SortInfo;
import org.apache.doris.common.NotImplementedException;
import org.apache.doris.nereids.trees.plans.WindowFuncType;
import org.apache.doris.statistics.StatisticalType;
import org.apache.doris.thrift.TExplainLevel;
Expand All @@ -34,40 +30,29 @@

import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/**
* PartitionSortNode.
* PartitionSortNode is only used in the Nereids.
*/
public class PartitionSortNode extends PlanNode {
private static final Logger LOG = LogManager.getLogger(PartitionSortNode.class);
private List<Expr> resolvedTupleExprs;
private final WindowFuncType function;
private final List<Expr> partitionExprs;
private final SortInfo info;
private final boolean hasGlobalLimit;
private final long partitionLimit;

private boolean isUnusedExprRemoved = false;
private ArrayList<Boolean> nullabilityChangedFlags = Lists.newArrayList();

/**
* Constructor.
*/
public PartitionSortNode(PlanNodeId id, PlanNode input, WindowFuncType function, List<Expr> partitionExprs,
SortInfo info, boolean hasGlobalLimit, long partitionLimit,
List<Expr> outputList, List<Expr> orderingExpr) {
SortInfo info, boolean hasGlobalLimit, long partitionLimit) {
super(id, "PartitionTopN", StatisticalType.PARTITION_TOPN_NODE);
Preconditions.checkArgument(info.getOrderingExprs().size() == info.getIsAscOrder().size());
this.function = function;
this.partitionExprs = partitionExprs;
this.info = info;
Expand All @@ -77,38 +62,12 @@ public PartitionSortNode(PlanNodeId id, PlanNode input, WindowFuncType function,
this.tblRefIds.addAll(Lists.newArrayList(info.getSortTupleDescriptor().getId()));
this.nullableTupleIds.addAll(input.getNullableTupleIds());
this.children.add(input);

List<Expr> resolvedTupleExprs = new ArrayList<>();
for (Expr order : orderingExpr) {
if (!resolvedTupleExprs.contains(order)) {
resolvedTupleExprs.add(order);
}
}
for (Expr output : outputList) {
if (!resolvedTupleExprs.contains(output)) {
resolvedTupleExprs.add(output);
}
}
this.resolvedTupleExprs = ImmutableList.copyOf(resolvedTupleExprs);
info.setSortTupleSlotExprs(resolvedTupleExprs);

nullabilityChangedFlags.clear();
for (int i = 0; i < resolvedTupleExprs.size(); i++) {
nullabilityChangedFlags.add(false);
}
Preconditions.checkArgument(info.getOrderingExprs().size() == info.getIsAscOrder().size());
}

public SortInfo getSortInfo() {
return info;
}

@Override
public void getMaterializedIds(Analyzer analyzer, List<SlotId> ids) {
super.getMaterializedIds(analyzer, ids);
Expr.getIds(info.getOrderingExprs(), null, ids);
}

@Override
public String getNodeExplainString(String prefix, TExplainLevel detailLevel) {
if (detailLevel == TExplainLevel.BRIEF) {
Expand Down Expand Up @@ -164,34 +123,12 @@ public String getNodeExplainString(String prefix, TExplainLevel detailLevel) {
return output.toString();
}

private void removeUnusedExprs() {
if (!isUnusedExprRemoved) {
if (resolvedTupleExprs != null) {
List<SlotDescriptor> slotDescriptorList = this.info.getSortTupleDescriptor().getSlots();
for (int i = slotDescriptorList.size() - 1; i >= 0; i--) {
if (!slotDescriptorList.get(i).isMaterialized()) {
resolvedTupleExprs.remove(i);
nullabilityChangedFlags.remove(i);
}
}
}
isUnusedExprRemoved = true;
}
}

@Override
protected void toThrift(TPlanNode msg) {
msg.node_type = TPlanNodeType.PARTITION_SORT_NODE;

TSortInfo sortInfo = info.toThrift();
Preconditions.checkState(tupleIds.size() == 1, "Incorrect size for tupleIds in PartitionSortNode");
removeUnusedExprs();
if (resolvedTupleExprs != null) {
sortInfo.setSortTupleSlotExprs(Expr.treesToThrift(resolvedTupleExprs));
// FIXME this is a bottom line solution for wrong nullability of resolvedTupleExprs
// remove the following line after nereids online
sortInfo.setSlotExprsNullabilityChangedFlags(nullabilityChangedFlags);
}

TopNAlgorithm topNAlgorithm;
if (function == WindowFuncType.ROW_NUMBER) {
Expand All @@ -210,13 +147,4 @@ protected void toThrift(TPlanNode msg) {
partitionSortNode.setPartitionInnerLimit(partitionLimit);
msg.partition_sort_node = partitionSortNode;
}

@Override
public Set<SlotId> computeInputSlotIds(Analyzer analyzer) throws NotImplementedException {
removeUnusedExprs();
List<Expr> materializedTupleExprs = new ArrayList<>(resolvedTupleExprs);
List<SlotId> result = Lists.newArrayList();
Expr.getIds(materializedTupleExprs, null, result);
return new HashSet<>(result);
}
}
28 changes: 0 additions & 28 deletions fe/fe-core/src/main/java/org/apache/doris/planner/SortNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.apache.doris.analysis.SlotId;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.SortInfo;
import org.apache.doris.analysis.TupleDescriptor;
import org.apache.doris.common.NotImplementedException;
import org.apache.doris.common.UserException;
import org.apache.doris.statistics.StatisticalType;
Expand Down Expand Up @@ -329,31 +328,4 @@ public Set<SlotId> computeInputSlotIds(Analyzer analyzer) throws NotImplementedE
Expr.getIds(materializedTupleExprs, null, result);
return new HashSet<>(result);
}

/**
* Supplement the information needed by be for the sort node.
* TODO: currently we only process slotref, so when order key is a + 1, we will failed.
*/
public void finalizeForNereids(TupleDescriptor tupleDescriptor,
List<Expr> outputList, List<Expr> orderingExpr) {
resolvedTupleExprs = Lists.newArrayList();
// TODO: should fix the duplicate order by exprs in nereids code later
for (Expr order : orderingExpr) {
if (!resolvedTupleExprs.contains(order)) {
resolvedTupleExprs.add(order);
}
}
for (Expr output : outputList) {
if (!resolvedTupleExprs.contains(output)) {
resolvedTupleExprs.add(output);
}
}
info.setSortTupleDesc(tupleDescriptor);
info.setSortTupleSlotExprs(resolvedTupleExprs);

nullabilityChangedFlags.clear();
for (int i = 0; i < resolvedTupleExprs.size(); i++) {
nullabilityChangedFlags.add(false);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
------PhysicalDistribute
--------PhysicalTopN
----------PhysicalProject
------------hashJoin[INNER_JOIN](s_store_name = v1_lead.s_store_name)(v1.i_category = v1_lead.i_category)(v1.i_brand = v1_lead.i_brand)(v1.s_company_name = v1_lead.s_company_name)(v1.rn = expr_(rn - 1))
------------hashJoin[INNER_JOIN](v1.i_category = v1_lead.i_category)(v1.i_brand = v1_lead.i_brand)(v1.s_store_name = v1_lead.s_store_name)(v1.s_company_name = v1_lead.s_company_name)(v1.rn = expr_(rn - 1))
--------------PhysicalDistribute
----------------PhysicalProject
------------------PhysicalCteConsumer ( cteId=CTEId#0 )
--------------PhysicalDistribute
----------------PhysicalProject
------------------hashJoin[INNER_JOIN](s_store_name = v1_lag.s_store_name)(v1.i_category = v1_lag.i_category)(v1.i_brand = v1_lag.i_brand)(v1.s_company_name = v1_lag.s_company_name)(v1.rn = expr_(rn + 1))
------------------hashJoin[INNER_JOIN](v1.i_category = v1_lag.i_category)(v1.i_brand = v1_lag.i_brand)(v1.s_store_name = v1_lag.s_store_name)(v1.s_company_name = v1_lag.s_company_name)(v1.rn = expr_(rn + 1))
--------------------PhysicalDistribute
----------------------PhysicalProject
------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
------PhysicalDistribute
--------PhysicalTopN
----------PhysicalProject
------------hashJoin[INNER_JOIN](i_brand = v1_lead.i_brand)(v1.i_category = v1_lead.i_category)(v1.cc_name = v1_lead.cc_name)(v1.rn = expr_(rn - 1))
------------hashJoin[INNER_JOIN](v1.i_category = v1_lead.i_category)(v1.i_brand = v1_lead.i_brand)(v1.cc_name = v1_lead.cc_name)(v1.rn = expr_(rn - 1))
--------------PhysicalDistribute
----------------PhysicalProject
------------------PhysicalCteConsumer ( cteId=CTEId#0 )
--------------PhysicalDistribute
----------------PhysicalProject
------------------hashJoin[INNER_JOIN](i_brand = v1_lag.i_brand)(v1.i_category = v1_lag.i_category)(v1.cc_name = v1_lag.cc_name)(v1.rn = expr_(rn + 1))
------------------hashJoin[INNER_JOIN](v1.i_category = v1_lag.i_category)(v1.i_brand = v1_lag.i_brand)(v1.cc_name = v1_lag.cc_name)(v1.rn = expr_(rn + 1))
--------------------PhysicalDistribute
----------------------PhysicalProject
------------------------PhysicalCteConsumer ( cteId=CTEId#0 )
Expand Down