Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -356,10 +356,6 @@ public HyperGraph build() {
return new HyperGraph(finalOutputs, joinEdges, nodes, filterEdges, complexProject);
}

public List<HyperGraph> buildAll() {
return ImmutableList.of(build());
}

public void updateNode(int idx, Group group) {
Preconditions.checkArgument(nodes.get(idx) instanceof DPhyperNode);
nodes.set(idx, ((DPhyperNode) nodes.get(idx)).withGroup(group));
Expand Down
20 changes: 5 additions & 15 deletions fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.exploration.mv.StructInfo;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
Expand Down Expand Up @@ -76,8 +75,6 @@ public class Group {

private int chosenGroupExpressionId = -1;

private List<StructInfo> structInfos = new ArrayList<>();

private StructInfoMap structInfoMap = new StructInfoMap();

/**
Expand Down Expand Up @@ -472,6 +469,7 @@ public String toString() {
}
str.append(" stats").append("\n");
str.append(getStatistics() == null ? "" : getStatistics().detail(" "));

str.append(" lowest Plan(cost, properties, plan, childrenRequires)");
getAllProperties().forEach(
prop -> {
Expand All @@ -485,6 +483,10 @@ public String toString() {
}
}
);

str.append("\n").append(" struct info map").append("\n");
str.append(structInfoMap);

return str.toString();
}

Expand Down Expand Up @@ -557,16 +559,4 @@ public String treeString() {

return TreeStringUtils.treeString(this, toString, getChildren, getExtraPlans, displayExtraPlan);
}

public List<StructInfo> getStructInfos() {
return structInfos;
}

public void addStructInfo(StructInfo structInfo) {
this.structInfos.add(structInfo);
}

public void addStructInfo(List<StructInfo> structInfos) {
this.structInfos.addAll(structInfos);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand All @@ -49,7 +50,7 @@ public class StructInfoMap {
* @param group the group that the mv matched
* @return struct info or null if not found
*/
public @Nullable StructInfo getStructInfo(BitSet mvTableMap, BitSet foldTableMap, Group group) {
public @Nullable StructInfo getStructInfo(BitSet mvTableMap, BitSet foldTableMap, Group group, Plan originPlan) {
if (!infoMap.containsKey(mvTableMap)) {
if ((groupExpressionMap.containsKey(foldTableMap) || groupExpressionMap.isEmpty())
&& !groupExpressionMap.containsKey(mvTableMap)) {
Expand All @@ -59,25 +60,30 @@ public class StructInfoMap {
Pair<GroupExpression, List<BitSet>> groupExpressionBitSetPair = getGroupExpressionWithChildren(
mvTableMap);
StructInfo structInfo = constructStructInfo(groupExpressionBitSetPair.first,
groupExpressionBitSetPair.second, mvTableMap);
groupExpressionBitSetPair.second, mvTableMap, originPlan);
infoMap.put(mvTableMap, structInfo);
}
}

return infoMap.get(mvTableMap);
}

public Set<BitSet> getTableMaps() {
return groupExpressionMap.keySet();
}

public Collection<StructInfo> getStructInfos() {
return infoMap.values();
}

public Pair<GroupExpression, List<BitSet>> getGroupExpressionWithChildren(BitSet tableMap) {
return groupExpressionMap.get(tableMap);
}

private StructInfo constructStructInfo(GroupExpression groupExpression, List<BitSet> children, BitSet tableMap) {
private StructInfo constructStructInfo(GroupExpression groupExpression, List<BitSet> children,
BitSet tableMap, Plan originPlan) {
// this plan is not origin plan, should record origin plan in struct info
Plan plan = constructPlan(groupExpression, children, tableMap);
return StructInfo.of(plan).get(0);
return originPlan == null ? StructInfo.of(plan) : StructInfo.of(plan, originPlan);
}

private Plan constructPlan(GroupExpression groupExpression, List<BitSet> children, BitSet tableMap) {
Expand Down Expand Up @@ -120,12 +126,11 @@ public boolean refresh(Group group) {
refreshedGroup.add(child);
childrenTableMap.add(child.getstructInfoMap().getTableMaps());
}

if (needRefresh) {
Set<Pair<BitSet, List<BitSet>>> bitSetWithChildren = cartesianProduct(childrenTableMap);
for (Pair<BitSet, List<BitSet>> bitSetWithChild : bitSetWithChildren) {
groupExpressionMap.put(bitSetWithChild.first, Pair.of(groupExpression, bitSetWithChild.second));
}
// if cumulative child table map is different from current
// or current group expression map is empty, should update the groupExpressionMap currently
Set<Pair<BitSet, List<BitSet>>> bitSetWithChildren = cartesianProduct(childrenTableMap);
for (Pair<BitSet, List<BitSet>> bitSetWithChild : bitSetWithChildren) {
groupExpressionMap.putIfAbsent(bitSetWithChild.first, Pair.of(groupExpression, bitSetWithChild.second));
}
}
return originSize != groupExpressionMap.size();
Expand All @@ -135,7 +140,7 @@ private BitSet constructLeaf(GroupExpression groupExpression) {
Plan plan = groupExpression.getPlan();
BitSet tableMap = new BitSet();
if (plan instanceof LogicalCatalogRelation) {
// TODO: Bitmap is not compatible with long, use tree map instead
// TODO: Bitset is not compatible with long, use tree map instead
tableMap.set((int) ((LogicalCatalogRelation) plan).getTable().getId());
}
// one row relation / CTE consumer
Expand All @@ -154,4 +159,9 @@ private Set<Pair<BitSet, List<BitSet>>> cartesianProduct(List<Set<BitSet>> child
})
.collect(Collectors.toSet());
}

@Override
public String toString() {
return "StructInfoMap{ groupExpressionMap = " + groupExpressionMap + ", infoMap = " + infoMap + '}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -180,18 +181,27 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
}
// Firstly,if group by expression between query and view is equals, try to rewrite expression directly
Plan queryTopPlan = queryTopPlanAndAggPair.key();
if (isGroupByEquals(queryTopPlanAndAggPair, viewTopPlanAndAggPair, viewToQuerySlotMapping)) {
if (isGroupByEquals(queryTopPlanAndAggPair, viewTopPlanAndAggPair, viewToQuerySlotMapping, queryStructInfo,
viewStructInfo)) {
List<Expression> rewrittenQueryExpressions = rewriteExpression(queryTopPlan.getOutput(),
queryTopPlan,
materializationContext.getMvExprToMvScanExprMapping(),
viewToQuerySlotMapping,
true);
true,
queryStructInfo.getTableBitSet());
if (!rewrittenQueryExpressions.isEmpty()) {
return new LogicalProject<>(
rewrittenQueryExpressions.stream().map(NamedExpression.class::cast)
.collect(Collectors.toList()),
tempRewritedPlan);

List<NamedExpression> projects = new ArrayList<>();
for (Expression expression : rewrittenQueryExpressions) {
if (expression.containsType(AggregateFunction.class)) {
materializationContext.recordFailReason(queryStructInfo,
"rewritten expression contains aggregate functions when group equals aggregate rewrite",
() -> String.format("aggregate functions = %s\n", rewrittenQueryExpressions));
return null;
}
projects.add(expression instanceof NamedExpression
? (NamedExpression) expression : new Alias(expression));
}
return new LogicalProject<>(projects, tempRewritedPlan);
}
// if fails, record the reason and then try to roll up aggregate function
materializationContext.recordFailReason(queryStructInfo,
Expand Down Expand Up @@ -219,7 +229,7 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
// try to roll up.
// split the query top plan expressions to group expressions and functions, if can not, bail out.
Pair<Set<? extends Expression>, Set<? extends Expression>> queryGroupAndFunctionPair
= topPlanSplitToGroupAndFunction(queryTopPlanAndAggPair);
= topPlanSplitToGroupAndFunction(queryTopPlanAndAggPair, queryStructInfo);
Set<? extends Expression> queryTopPlanFunctionSet = queryGroupAndFunctionPair.value();
// try to rewrite, contains both roll up aggregate functions and aggregate group expression
List<NamedExpression> finalOutputExpressions = new ArrayList<>();
Expand All @@ -234,9 +244,10 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
if (queryTopPlanFunctionSet.contains(topExpression)) {
Expression queryFunctionShuttled = ExpressionUtils.shuttleExpressionWithLineage(
topExpression,
queryTopPlan);
queryTopPlan,
queryStructInfo.getTableBitSet());
AggregateExpressionRewriteContext context = new AggregateExpressionRewriteContext(
false, mvExprToMvScanExprQueryBased, queryTopPlan);
false, mvExprToMvScanExprQueryBased, queryTopPlan, queryStructInfo.getTableBitSet());
// queryFunctionShuttled maybe sum(column) + count(*), so need to use expression rewriter
Expression rollupedExpression = queryFunctionShuttled.accept(AGGREGATE_EXPRESSION_REWRITER,
context);
Expand All @@ -250,10 +261,10 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
finalOutputExpressions.add(new Alias(rollupedExpression));
} else {
// if group by expression, try to rewrite group by expression
Expression queryGroupShuttledExpr =
ExpressionUtils.shuttleExpressionWithLineage(topExpression, queryTopPlan);
AggregateExpressionRewriteContext context = new AggregateExpressionRewriteContext(
true, mvExprToMvScanExprQueryBased, queryTopPlan);
Expression queryGroupShuttledExpr = ExpressionUtils.shuttleExpressionWithLineage(
topExpression, queryTopPlan, queryStructInfo.getTableBitSet());
AggregateExpressionRewriteContext context = new AggregateExpressionRewriteContext(true,
mvExprToMvScanExprQueryBased, queryTopPlan, queryStructInfo.getTableBitSet());
// group by expression maybe group by a + b, so we need expression rewriter
Expression rewrittenGroupByExpression = queryGroupShuttledExpr.accept(AGGREGATE_EXPRESSION_REWRITER,
context);
Expand Down Expand Up @@ -302,16 +313,18 @@ protected Plan rewriteQueryByView(MatchMode matchMode,

private boolean isGroupByEquals(Pair<Plan, LogicalAggregate<Plan>> queryTopPlanAndAggPair,
Pair<Plan, LogicalAggregate<Plan>> viewTopPlanAndAggPair,
SlotMapping viewToQuerySlotMapping) {
SlotMapping viewToQuerySlotMapping,
StructInfo queryStructInfo,
StructInfo viewStructInfo) {
Plan queryTopPlan = queryTopPlanAndAggPair.key();
Plan viewTopPlan = viewTopPlanAndAggPair.key();
LogicalAggregate<Plan> queryAggregate = queryTopPlanAndAggPair.value();
LogicalAggregate<Plan> viewAggregate = viewTopPlanAndAggPair.value();
Set<? extends Expression> queryGroupShuttledExpression = new HashSet<>(
ExpressionUtils.shuttleExpressionWithLineage(
queryAggregate.getGroupByExpressions(), queryTopPlan));
queryAggregate.getGroupByExpressions(), queryTopPlan, queryStructInfo.getTableBitSet()));
Set<? extends Expression> viewGroupShuttledExpressionQueryBased = ExpressionUtils.shuttleExpressionWithLineage(
viewAggregate.getGroupByExpressions(), viewTopPlan)
viewAggregate.getGroupByExpressions(), viewTopPlan, viewStructInfo.getTableBitSet())
.stream()
.map(expr -> ExpressionUtils.replace(expr, viewToQuerySlotMapping.toSlotReferenceMap()))
.collect(Collectors.toSet());
Expand Down Expand Up @@ -384,7 +397,7 @@ private static boolean canRollup(Expression rollupExpression) {
}

private Pair<Set<? extends Expression>, Set<? extends Expression>> topPlanSplitToGroupAndFunction(
Pair<Plan, LogicalAggregate<Plan>> topPlanAndAggPair) {
Pair<Plan, LogicalAggregate<Plan>> topPlanAndAggPair, StructInfo queryStructInfo) {
LogicalAggregate<Plan> bottomQueryAggregate = topPlanAndAggPair.value();
Set<Expression> groupByExpressionSet = new HashSet<>(bottomQueryAggregate.getGroupByExpressions());
// when query is bitmap_count(bitmap_union), the plan is as following:
Expand All @@ -403,7 +416,7 @@ private Pair<Set<? extends Expression>, Set<? extends Expression>> topPlanSplitT
queryTopPlan.getOutput().forEach(expression -> {
ExpressionLineageReplacer.ExpressionReplaceContext replaceContext =
new ExpressionLineageReplacer.ExpressionReplaceContext(ImmutableList.of(expression),
ImmutableSet.of(), ImmutableSet.of());
ImmutableSet.of(), ImmutableSet.of(), queryStructInfo.getTableBitSet());
queryTopPlan.accept(ExpressionLineageReplacer.INSTANCE, replaceContext);
if (!Sets.intersection(bottomAggregateFunctionExprIdSet,
replaceContext.getExprIdExpressionMap().keySet()).isEmpty()) {
Expand Down Expand Up @@ -509,7 +522,8 @@ public Expression visitAggregateFunction(AggregateFunction aggregateFunction,
}
Expression queryFunctionShuttled = ExpressionUtils.shuttleExpressionWithLineage(
aggregateFunction,
rewriteContext.getQueryTopPlan());
rewriteContext.getQueryTopPlan(),
rewriteContext.getQueryTableBitSet());
Function rollupAggregateFunction = rollup(aggregateFunction, queryFunctionShuttled,
rewriteContext.getMvExprToMvScanExprQueryBasedMapping());
if (rollupAggregateFunction == null) {
Expand Down Expand Up @@ -565,12 +579,15 @@ protected static class AggregateExpressionRewriteContext {
private final boolean onlyContainGroupByExpression;
private final Map<Expression, Expression> mvExprToMvScanExprQueryBasedMapping;
private final Plan queryTopPlan;
private final BitSet queryTableBitSet;

public AggregateExpressionRewriteContext(boolean onlyContainGroupByExpression,
Map<Expression, Expression> mvExprToMvScanExprQueryBasedMapping, Plan queryTopPlan) {
Map<Expression, Expression> mvExprToMvScanExprQueryBasedMapping, Plan queryTopPlan,
BitSet queryTableBitSet) {
this.onlyContainGroupByExpression = onlyContainGroupByExpression;
this.mvExprToMvScanExprQueryBasedMapping = mvExprToMvScanExprQueryBasedMapping;
this.queryTopPlan = queryTopPlan;
this.queryTableBitSet = queryTableBitSet;
}

public boolean isValid() {
Expand All @@ -592,5 +609,9 @@ public Map<Expression, Expression> getMvExprToMvScanExprQueryBasedMapping() {
public Plan getQueryTopPlan() {
return queryTopPlan;
}

public BitSet getQueryTableBitSet() {
return queryTableBitSet;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
// Rewrite top projects, represent the query projects by view
List<Expression> expressionsRewritten = rewriteExpression(
queryStructInfo.getExpressions(),
queryStructInfo.getOriginalPlan(),
queryStructInfo.getTopPlan(),
materializationContext.getMvExprToMvScanExprMapping(),
targetToSourceMapping,
true
true,
queryStructInfo.getTableBitSet()
);
// Can not rewrite, bail out
if (expressionsRewritten.isEmpty()) {
Expand Down
Loading