Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Function dependence items.
Expand Down Expand Up @@ -96,11 +97,25 @@ private void dfs(Set<Slot> parent, Set<Set<Slot>> visited, Set<FuncDepsItem> cir
}
}

// Find the items that are not in a circle.
private Set<FuncDepsItem> findValidItems() {
// Find the items that are not part of a circular dependency.
// To keep the slots in requireOutputs, we must always keep the edges that start from output slots.
// Note: the last edge of a circular dependency is the one that gets reduced, so we traverse first
// from the parents whose slots are all contained in requireOutputs, and then from the remaining parents.
private Set<FuncDepsItem> findValidItems(Set<Slot> requireOutputs) {
Set<FuncDepsItem> circleItem = new HashSet<>();
Set<Set<Slot>> visited = new HashSet<>();
for (Set<Slot> parent : edges.keySet()) {
Set<Set<Slot>> parentInOutput = edges.keySet().stream()
.filter(requireOutputs::containsAll)
.collect(Collectors.toSet());
for (Set<Slot> parent : parentInOutput) {
if (!visited.contains(parent)) {
dfs(parent, visited, circleItem);
}
}
Set<Set<Slot>> otherParent = edges.keySet().stream()
.filter(parent -> !parentInOutput.contains(parent))
.collect(Collectors.toSet());
for (Set<Slot> parent : otherParent) {
if (!visited.contains(parent)) {
dfs(parent, visited, circleItem);
}
Expand All @@ -126,10 +141,10 @@ private Set<FuncDepsItem> findValidItems() {
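* @param slots the initial set of slot sets to be reduced
* @param requireOutputs the slots that should be kept; parents consisting only of these slots
*        are traversed first when looking for circular dependencies
* @return the minimal set of slot sets after applying all possible reductions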
*/
public Set<Set<Slot>> eliminateDeps(Set<Set<Slot>> slots) {
public Set<Set<Slot>> eliminateDeps(Set<Set<Slot>> slots, Set<Slot> requireOutputs) {
Set<Set<Slot>> minSlotSet = Sets.newHashSet(slots);
Set<Set<Slot>> eliminatedSlots = new HashSet<>();
Set<FuncDepsItem> validItems = findValidItems();
Set<FuncDepsItem> validItems = findValidItems(requireOutputs);
for (FuncDepsItem funcDepsItem : validItems) {
if (minSlotSet.contains(funcDepsItem.dependencies)
&& minSlotSet.contains(funcDepsItem.determinants)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ LogicalAggregate<Plan> eliminateGroupByKey(LogicalAggregate<? extends Plan> agg,
return null;
}

Set<Set<Slot>> minGroupBySlots = funcDeps.eliminateDeps(new HashSet<>(groupBySlots.values()));
Set<Set<Slot>> minGroupBySlots = funcDeps.eliminateDeps(new HashSet<>(groupBySlots.values()), requireOutput);
Set<Expression> removeExpression = new HashSet<>();
for (Entry<Expression, Set<Slot>> entry : groupBySlots.entrySet()) {
if (!minGroupBySlots.contains(entry.getValue())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.types.IntegerType;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
Expand All @@ -43,7 +44,7 @@ void testOneEliminate() {
Set<Set<Slot>> slotSet = Sets.newHashSet(set1, set2, set3, set4);
FuncDeps funcDeps = new FuncDeps();
funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2));
Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet);
Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet, ImmutableSet.of());
Set<Set<Slot>> expected = new HashSet<>();
expected.add(set1);
expected.add(set3);
Expand All @@ -58,7 +59,7 @@ void testChainEliminate() {
funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2));
funcDeps.addFuncItems(Sets.newHashSet(s2), Sets.newHashSet(s3));
funcDeps.addFuncItems(Sets.newHashSet(s3), Sets.newHashSet(s4));
Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet);
Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet, ImmutableSet.of());
Set<Set<Slot>> expected = new HashSet<>();
expected.add(set1);
Assertions.assertEquals(expected, slots);
Expand All @@ -71,7 +72,7 @@ void testTreeEliminate() {
funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2));
funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s3));
funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s4));
Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet);
Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet, ImmutableSet.of());
Set<Set<Slot>> expected = new HashSet<>();
expected.add(set1);
Assertions.assertEquals(expected, slots);
Expand All @@ -83,7 +84,7 @@ void testCircleEliminate1() {
FuncDeps funcDeps = new FuncDeps();
funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2));
funcDeps.addFuncItems(Sets.newHashSet(s2), Sets.newHashSet(s1));
Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet);
Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet, ImmutableSet.of());
Set<Set<Slot>> expected = new HashSet<>();
expected.add(set1);
expected.add(set3);
Expand All @@ -99,7 +100,7 @@ void testCircleEliminate2() {
funcDeps.addFuncItems(Sets.newHashSet(s2), Sets.newHashSet(s3));
funcDeps.addFuncItems(Sets.newHashSet(s3), Sets.newHashSet(s4));
funcDeps.addFuncItems(Sets.newHashSet(s4), Sets.newHashSet(s1));
Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet);
Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet, ImmutableSet.of());
Set<Set<Slot>> expected = new HashSet<>();
expected.add(set1);
Assertions.assertEquals(expected, slots);
Expand All @@ -112,7 +113,7 @@ void testGraphEliminate1() {
funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2));
funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s3));
funcDeps.addFuncItems(Sets.newHashSet(s3), Sets.newHashSet(s4));
Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet);
Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet, ImmutableSet.of());
Set<Set<Slot>> expected = new HashSet<>();
expected.add(set1);
Assertions.assertEquals(expected, slots);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void testEliminateChain() {
funcDeps.addFuncItems(set1, set2);
funcDeps.addFuncItems(set2, set3);
funcDeps.addFuncItems(set3, set4);
Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, set2, set3, set4));
Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, set2, set3, set4), ImmutableSet.of());
Assertions.assertEquals(1, slots.size());
Assertions.assertEquals(set1, slots.iterator().next());
}
Expand All @@ -78,7 +78,7 @@ void testEliminateCircle() {
funcDeps.addFuncItems(set2, set3);
funcDeps.addFuncItems(set3, set4);
funcDeps.addFuncItems(set4, set1);
Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, set2, set3, set4));
Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, set2, set3, set4), ImmutableSet.of());
Assertions.assertEquals(1, slots.size());
Assertions.assertEquals(set1, slots.iterator().next());
}
Expand All @@ -89,7 +89,7 @@ void testEliminateTree() {
funcDeps.addFuncItems(set1, set2);
funcDeps.addFuncItems(set1, set3);
funcDeps.addFuncItems(set1, set4);
Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, set2, set3, set4));
Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, set2, set3, set4), ImmutableSet.of());
Assertions.assertEquals(1, slots.size());
Assertions.assertEquals(set1, slots.iterator().next());
}
Expand Down Expand Up @@ -163,11 +163,20 @@ void testEliminateByPk() throws Exception {
@Test
void testEliminateByEqual() {
PlanChecker.from(connectContext)
.analyze("select count(t1.name) from t1 as t1 join t1 as t2 on t1.name = t2.name group by t1.name, t2.name")
.analyze("select t1.name from t1 as t1 join t1 as t2 on t1.name = t2.name group by t1.name, t2.name")
.rewrite()
.printlnTree()
.matches(logicalAggregate().when(agg ->
agg.getGroupByExpressions().size() == 1 && agg.getGroupByExpressions().get(0).toSql().equals("name")));
}
agg.getGroupByExpressions().size() == 1
&& agg.getGroupByExpressions().get(0).toSql().equals("name")));

PlanChecker.from(connectContext)
.analyze("select t2.name from t1 as t1 join t1 as t2 "
+ "on t1.name = t2.name group by t1.name, t2.name")
.rewrite()
.printlnTree()
.matches(logicalAggregate().when(agg ->
agg.getGroupByExpressions().size() == 1
&& agg.getGroupByExpressions().get(0).toSql().equals("name")));
}
}