Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.doris.nereids.memo;

import org.apache.doris.catalog.MTMV;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.cost.Cost;
Expand All @@ -33,6 +34,7 @@
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.LeafPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
Expand All @@ -55,6 +57,7 @@
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
Expand All @@ -69,6 +72,8 @@ public class Memo {
EventChannel.getDefaultChannel().addConsumers(new LogConsumer(GroupMergeEvent.class, EventChannel.LOG)));
private static long stateId = 0;
private final ConnectContext connectContext;
private final Set<Long> needRefreshTableIdSet = new HashSet<>();
private final AtomicLong refreshVersion = new AtomicLong(1);
private final IdGenerator<GroupId> groupIdGenerator = GroupId.createGenerator();
private final Map<GroupId, Group> groups = Maps.newLinkedHashMap();
// we could not use Set, because Set does not have get method.
Expand Down Expand Up @@ -118,6 +123,10 @@ public int getGroupExpressionsSize() {
return groupExpressions.size();
}

public long getRefreshVersion() {
return refreshVersion.get();
}

private Plan skipProject(Plan plan, Group targetGroup) {
// Some top project can't be eliminated
if (plan instanceof LogicalProject && ((LogicalProject<?>) plan).canEliminate()) {
Expand Down Expand Up @@ -406,14 +415,15 @@ private CopyInResult doCopyIn(Plan plan, @Nullable Group targetGroup, @Nullable
plan.getLogicalProperties(), targetGroup.getLogicalProperties());
throw new IllegalStateException("Insert a plan into targetGroup but differ in logicalproperties");
}
// TODO Support sync materialized view in the future
if (plan instanceof CatalogRelation && ((CatalogRelation) plan).getTable() instanceof MTMV) {
refreshVersion.incrementAndGet();
}
Optional<GroupExpression> groupExpr = plan.getGroupExpression();
if (groupExpr.isPresent()) {
Preconditions.checkState(groupExpressions.containsKey(groupExpr.get()));
return CopyInResult.of(false, groupExpr.get());
}
if (targetGroup != null) {
targetGroup.getstructInfoMap().setRefreshed(false);
}
List<Group> childrenGroups = Lists.newArrayList();
for (int i = 0; i < plan.children().size(); i++) {
// skip useless project.
Expand Down Expand Up @@ -562,7 +572,6 @@ public void mergeGroup(Group source, Group destination, HashMap<Long, Group> pla
if (source == root) {
root = destination;
}
destination.getstructInfoMap().setRefreshed(false);
groups.remove(source.getGroupId());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,31 +42,34 @@
public class StructInfoMap {
private final Map<BitSet, Pair<GroupExpression, List<BitSet>>> groupExpressionMap = new HashMap<>();
private final Map<BitSet, StructInfo> infoMap = new HashMap<>();
private boolean refreshed;
private long refreshVersion = 0;

/**
* get struct info according to table map
*
* @param mvTableMap the original table map
* @param tableMap the original table map
* @param foldTableMap the fold table map
* @param group the group that the mv matched
* @return struct info or null if not found
*/
public @Nullable StructInfo getStructInfo(BitSet mvTableMap, BitSet foldTableMap, Group group, Plan originPlan) {
if (!infoMap.containsKey(mvTableMap)) {
if ((groupExpressionMap.containsKey(foldTableMap) || groupExpressionMap.isEmpty())
&& !groupExpressionMap.containsKey(mvTableMap)) {
refresh(group);
}
if (groupExpressionMap.containsKey(mvTableMap)) {
Pair<GroupExpression, List<BitSet>> groupExpressionBitSetPair = getGroupExpressionWithChildren(
mvTableMap);
StructInfo structInfo = constructStructInfo(groupExpressionBitSetPair.first,
groupExpressionBitSetPair.second, mvTableMap, originPlan);
infoMap.put(mvTableMap, structInfo);
}
public @Nullable StructInfo getStructInfo(Memo memo, BitSet tableMap, BitSet foldTableMap,
Group group, Plan originPlan) {
StructInfo structInfo = infoMap.get(tableMap);
if (structInfo != null) {
return structInfo;
}
if (groupExpressionMap.isEmpty() || !groupExpressionMap.containsKey(tableMap)) {
refresh(group, memo.getRefreshVersion(), foldTableMap);
group.getstructInfoMap().setRefreshVersion(memo.getRefreshVersion());
}
return infoMap.get(mvTableMap);
if (groupExpressionMap.containsKey(tableMap)) {
Pair<GroupExpression, List<BitSet>> groupExpressionBitSetPair = getGroupExpressionWithChildren(
tableMap);
structInfo = constructStructInfo(groupExpressionBitSetPair.first,
groupExpressionBitSetPair.second, tableMap, originPlan);
infoMap.put(tableMap, structInfo);
}
return structInfo;
}

public Set<BitSet> getTableMaps() {
Expand All @@ -81,12 +84,12 @@ public Pair<GroupExpression, List<BitSet>> getGroupExpressionWithChildren(BitSet
return groupExpressionMap.get(tableMap);
}

public boolean isRefreshed() {
return refreshed;
public long getRefreshVersion() {
return refreshVersion;
}

public void setRefreshed(boolean refreshed) {
this.refreshed = refreshed;
public void setRefreshVersion(long refreshVersion) {
this.refreshVersion = refreshVersion;
}

private StructInfo constructStructInfo(GroupExpression groupExpression, List<BitSet> children,
Expand Down Expand Up @@ -114,27 +117,24 @@ private Plan constructPlan(GroupExpression groupExpression, List<BitSet> childre
*
* @param group the root group
*
* @return whether groupExpressionMap is updated
*/
public boolean refresh(Group group) {
Set<Group> refreshedGroup = new HashSet<>();
int originSize = groupExpressionMap.size();
public void refresh(Group group, long refreshVersion, BitSet targetBitSet) {
Set<Integer> refreshedGroup = new HashSet<>();
for (GroupExpression groupExpression : group.getLogicalExpressions()) {
List<Set<BitSet>> childrenTableMap = new ArrayList<>();
boolean needRefresh = groupExpressionMap.isEmpty();
List<Set<BitSet>> childrenTableMap = new LinkedList<>();
if (groupExpression.children().isEmpty()) {
BitSet leaf = constructLeaf(groupExpression);
groupExpressionMap.put(leaf, Pair.of(groupExpression, new ArrayList<>()));
groupExpressionMap.put(leaf, Pair.of(groupExpression, new LinkedList<>()));
continue;
}

for (Group child : groupExpression.children()) {
if (!refreshedGroup.contains(child) && !child.getstructInfoMap().isRefreshed()) {
StructInfoMap childStructInfoMap = child.getstructInfoMap();
needRefresh |= childStructInfoMap.refresh(child);
childStructInfoMap.setRefreshed(true);
StructInfoMap childStructInfoMap = child.getstructInfoMap();
if (!refreshedGroup.contains(child.getGroupId().asInt())
&& refreshVersion != childStructInfoMap.getRefreshVersion()) {
childStructInfoMap.refresh(child, refreshVersion, targetBitSet);
childStructInfoMap.setRefreshVersion(refreshVersion);
}
refreshedGroup.add(child);
refreshedGroup.add(child.getGroupId().asInt());
childrenTableMap.add(child.getstructInfoMap().getTableMaps());
}
// if one same groupExpression have refreshed, continue
Expand All @@ -150,15 +150,14 @@ public boolean refresh(Group group) {
}
// if cumulative child table map is different from current
// or current group expression map is empty, should update the groupExpressionMap currently
Collection<Pair<BitSet, List<BitSet>>> bitSetWithChildren = cartesianProduct(childrenTableMap);
if (needRefresh) {
for (Pair<BitSet, List<BitSet>> bitSetWithChild : bitSetWithChildren) {
groupExpressionMap.putIfAbsent(bitSetWithChild.first,
Pair.of(groupExpression, bitSetWithChild.second));
}
Collection<Pair<BitSet, List<BitSet>>> bitSetWithChildren = cartesianProduct(childrenTableMap,
new BitSet());
for (Pair<BitSet, List<BitSet>> bitSetWithChild : bitSetWithChildren) {
groupExpressionMap.putIfAbsent(bitSetWithChild.first,
Pair.of(groupExpression, bitSetWithChild.second));
}

}
return originSize != groupExpressionMap.size();
}

private BitSet constructLeaf(GroupExpression groupExpression) {
Expand All @@ -172,14 +171,19 @@ private BitSet constructLeaf(GroupExpression groupExpression) {
return tableMap;
}

private Collection<Pair<BitSet, List<BitSet>>> cartesianProduct(List<Set<BitSet>> childrenTableMap) {
private Collection<Pair<BitSet, List<BitSet>>> cartesianProduct(List<Set<BitSet>> childrenTableMap,
BitSet targetBitSet) {
Set<List<BitSet>> cartesianLists = Sets.cartesianProduct(childrenTableMap);
List<Pair<BitSet, List<BitSet>>> resultPairSet = new LinkedList<>();
for (List<BitSet> bitSetList : cartesianLists) {
BitSet bitSet = new BitSet();
for (BitSet b : bitSetList) {
bitSet.or(b);
}
// filter the useless bitset which targetBitSet not contains, avoid exponential expansion
if (!targetBitSet.isEmpty() && !StructInfo.containsAll(targetBitSet, bitSet)) {
continue;
}
resultPairSet.add(Pair.of(bitSet, bitSetList));
}
return resultPairSet;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ public List<Plan> rewrite(Plan queryPlan, CascadesContext cascadesContext) {
protected List<StructInfo> getValidQueryStructInfos(Plan queryPlan, CascadesContext cascadesContext,
BitSet materializedViewTableSet) {
List<StructInfo> validStructInfos = new ArrayList<>();
// For every materialized view we should trigger refreshing struct info map
List<StructInfo> uncheckedStructInfos = MaterializedViewUtils.extractStructInfo(queryPlan, cascadesContext,
materializedViewTableSet);
uncheckedStructInfos.forEach(queryStructInfo -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,23 @@ public static List<StructInfo> extractStructInfo(Plan plan, CascadesContext casc
if (plan.getGroupExpression().isPresent()) {
Group ownerGroup = plan.getGroupExpression().get().getOwnerGroup();
StructInfoMap structInfoMap = ownerGroup.getstructInfoMap();
structInfoMap.refresh(ownerGroup);
if (cascadesContext.getMemo().getRefreshVersion() != structInfoMap.getRefreshVersion()
|| structInfoMap.getTableMaps().isEmpty()) {
structInfoMap.refresh(ownerGroup, cascadesContext.getMemo().getRefreshVersion(),
materializedViewTableSet);
structInfoMap.setRefreshVersion(cascadesContext.getMemo().getRefreshVersion());
}
Set<BitSet> queryTableSets = structInfoMap.getTableMaps();
ImmutableList.Builder<StructInfo> structInfosBuilder = ImmutableList.builder();
if (!queryTableSets.isEmpty()) {
for (BitSet queryTableSet : queryTableSets) {
// TODO As only support MatchMode.COMPLETE, so only get equaled query table struct info
if (!materializedViewTableSet.isEmpty()
&& !StructInfo.containsAll(materializedViewTableSet, queryTableSet)) {
&& !materializedViewTableSet.equals(queryTableSet)) {
continue;
}
StructInfo structInfo = structInfoMap.getStructInfo(queryTableSet, queryTableSet, ownerGroup, plan);
StructInfo structInfo = structInfoMap.getStructInfo(cascadesContext.getMemo(),
queryTableSet, queryTableSet, ownerGroup, plan);
if (structInfo != null) {
structInfosBuilder.add(structInfo);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ void testTableMap() throws Exception {
Group root = c1.getMemo().getRoot();
Set<BitSet> tableMaps = root.getstructInfoMap().getTableMaps();
Assertions.assertTrue(tableMaps.isEmpty());
root.getstructInfoMap().refresh(root);
root.getstructInfoMap().refresh(root, 1, new BitSet());
Assertions.assertEquals(1, tableMaps.size());
new MockUp<MTMVRelationManager>() {
@Mock
Expand All @@ -76,7 +76,7 @@ public boolean isMVPartitionValid(MTMV mtmv, ConnectContext ctx) {
.optimize()
.printlnBestPlanTree();
root = c1.getMemo().getRoot();
root.getstructInfoMap().refresh(root);
root.getstructInfoMap().refresh(root, 1, new BitSet());
tableMaps = root.getstructInfoMap().getTableMaps();
Assertions.assertEquals(2, tableMaps.size());
dropMvByNereids("drop materialized view mv1");
Expand All @@ -97,10 +97,8 @@ void testLazyRefresh() throws Exception {
Group root = c1.getMemo().getRoot();
Set<BitSet> tableMaps = root.getstructInfoMap().getTableMaps();
Assertions.assertTrue(tableMaps.isEmpty());
boolean refreshed = root.getstructInfoMap().refresh(root);
Assertions.assertTrue(refreshed);
refreshed = root.getstructInfoMap().refresh(root);
Assertions.assertFalse(refreshed);
root.getstructInfoMap().refresh(root, 1, new BitSet());
root.getstructInfoMap().refresh(root, 1, new BitSet());
Assertions.assertEquals(1, tableMaps.size());
new MockUp<MTMVRelationManager>() {
@Mock
Expand All @@ -126,10 +124,8 @@ public boolean isMVPartitionValid(MTMV mtmv, ConnectContext ctx) {
.optimize()
.printlnBestPlanTree();
root = c1.getMemo().getRoot();
refreshed = root.getstructInfoMap().refresh(root);
Assertions.assertTrue(refreshed);
refreshed = root.getstructInfoMap().refresh(root);
Assertions.assertFalse(refreshed);
root.getstructInfoMap().refresh(root, 1, new BitSet());
root.getstructInfoMap().refresh(root, 1, new BitSet());
tableMaps = root.getstructInfoMap().getTableMaps();
Assertions.assertEquals(2, tableMaps.size());
dropMvByNereids("drop materialized view mv1");
Expand Down Expand Up @@ -166,13 +162,13 @@ public boolean isMVPartitionValid(MTMV mtmv, ConnectContext ctx) {
.rewrite()
.optimize();
Group root = c1.getMemo().getRoot();
root.getstructInfoMap().refresh(root);
root.getstructInfoMap().refresh(root, 1, new BitSet());
StructInfoMap structInfoMap = root.getstructInfoMap();
Assertions.assertEquals(2, structInfoMap.getTableMaps().size());
BitSet mvMap = structInfoMap.getTableMaps().stream()
.filter(b -> b.cardinality() == 2)
.collect(Collectors.toList()).get(0);
StructInfo structInfo = structInfoMap.getStructInfo(mvMap, mvMap, root, null);
StructInfo structInfo = structInfoMap.getStructInfo(c1.getMemo(), mvMap, mvMap, root, null);
System.out.println(structInfo.getOriginalPlan().treeString());
BitSet bitSet = new BitSet();
structInfo.getRelations().forEach(r -> bitSet.set((int) r.getTable().getId()));
Expand Down