diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java index 6c1b302d7dc11d..c17fd2eee57e1b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java @@ -27,6 +27,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; /** * Function dependence items. @@ -96,11 +97,25 @@ private void dfs(Set parent, Set> visited, Set cir } } - // find item that not in a circle - private Set findValidItems() { + // Find items that are not part of a circular dependency. + // To keep the slots in requireOutputs, we need to always keep the edges that start with output slots. + // Note: We reduce the last edge in a circular dependency, + // so we need to traverse from parents that contain the required output slots. + private Set findValidItems(Set requireOutputs) { Set circleItem = new HashSet<>(); Set> visited = new HashSet<>(); - for (Set parent : edges.keySet()) { + Set> parentInOutput = edges.keySet().stream() + .filter(requireOutputs::containsAll) + .collect(Collectors.toSet()); + for (Set parent : parentInOutput) { + if (!visited.contains(parent)) { + dfs(parent, visited, circleItem); + } + } + Set> otherParent = edges.keySet().stream() + .filter(parent -> !parentInOutput.contains(parent)) + .collect(Collectors.toSet()); + for (Set parent : otherParent) { if (!visited.contains(parent)) { dfs(parent, visited, circleItem); } @@ -126,10 +141,10 @@ private Set findValidItems() { * @param slots the initial set of slot sets to be reduced * @return the minimal set of slot sets after applying all possible reductions */ - public Set> eliminateDeps(Set> slots) { + public Set> eliminateDeps(Set> slots, Set requireOutputs) { Set> minSlotSet = Sets.newHashSet(slots); Set> eliminatedSlots = new HashSet<>(); - Set validItems = findValidItems(); + Set validItems = findValidItems(requireOutputs); for (FuncDepsItem funcDepsItem : validItems) { if (minSlotSet.contains(funcDepsItem.dependencies) && minSlotSet.contains(funcDepsItem.determinants)) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java index 9e205f858090bc..fbe0988daff5bc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java @@ -91,7 +91,7 @@ LogicalAggregate eliminateGroupByKey(LogicalAggregate agg, return null; } - Set> minGroupBySlots = funcDeps.eliminateDeps(new HashSet<>(groupBySlots.values())); + Set> minGroupBySlots = funcDeps.eliminateDeps(new HashSet<>(groupBySlots.values()), requireOutput); Set removeExpression = new HashSet<>(); for (Entry> entry : groupBySlots.entrySet()) { if (!minGroupBySlots.contains(entry.getValue()) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/FuncDepsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/FuncDepsTest.java index 64df33acd602c8..6b17305ed7a3e8 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/FuncDepsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/FuncDepsTest.java @@ -21,6 +21,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.types.IntegerType; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -43,7 +44,7 @@ void testOneEliminate() { Set> slotSet = Sets.newHashSet(set1, set2, set3, set4); FuncDeps funcDeps = new FuncDeps(); funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2)); - Set> slots = funcDeps.eliminateDeps(slotSet); + Set> slots = funcDeps.eliminateDeps(slotSet, ImmutableSet.of()); Set> expected = new HashSet<>(); expected.add(set1); expected.add(set3); @@ -58,7 +59,7 @@ void testChainEliminate() { funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2)); funcDeps.addFuncItems(Sets.newHashSet(s2), Sets.newHashSet(s3)); funcDeps.addFuncItems(Sets.newHashSet(s3), Sets.newHashSet(s4)); - Set> slots = funcDeps.eliminateDeps(slotSet); + Set> slots = funcDeps.eliminateDeps(slotSet, ImmutableSet.of()); Set> expected = new HashSet<>(); expected.add(set1); Assertions.assertEquals(expected, slots); @@ -71,7 +72,7 @@ void testTreeEliminate() { funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2)); funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s3)); funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s4)); - Set> slots = funcDeps.eliminateDeps(slotSet); + Set> slots = funcDeps.eliminateDeps(slotSet, ImmutableSet.of()); Set> expected = new HashSet<>(); expected.add(set1); Assertions.assertEquals(expected, slots); @@ -83,7 +84,7 @@ void testCircleEliminate1() { FuncDeps funcDeps = new FuncDeps(); funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2)); funcDeps.addFuncItems(Sets.newHashSet(s2), Sets.newHashSet(s1)); - Set> slots = funcDeps.eliminateDeps(slotSet); + Set> slots = funcDeps.eliminateDeps(slotSet, ImmutableSet.of()); Set> expected = new HashSet<>(); expected.add(set1); expected.add(set3); @@ -99,7 +100,7 @@ void testCircleEliminate2() { funcDeps.addFuncItems(Sets.newHashSet(s2), Sets.newHashSet(s3)); funcDeps.addFuncItems(Sets.newHashSet(s3), Sets.newHashSet(s4)); funcDeps.addFuncItems(Sets.newHashSet(s4), Sets.newHashSet(s1)); - Set> slots = funcDeps.eliminateDeps(slotSet); + Set> slots = funcDeps.eliminateDeps(slotSet, ImmutableSet.of()); Set> expected = new HashSet<>(); expected.add(set1); Assertions.assertEquals(expected, slots); @@ -112,7 +113,7 @@ void testGraphEliminate1() { funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s2)); funcDeps.addFuncItems(Sets.newHashSet(s1), Sets.newHashSet(s3)); funcDeps.addFuncItems(Sets.newHashSet(s3), Sets.newHashSet(s4)); - Set> slots = funcDeps.eliminateDeps(slotSet); + Set> slots = funcDeps.eliminateDeps(slotSet, ImmutableSet.of()); Set> expected = new HashSet<>(); expected.add(set1); Assertions.assertEquals(expected, slots); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java index 203e902b3ebddf..5a9e15cf4774d1 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java @@ -66,7 +66,7 @@ void testEliminateChain() { funcDeps.addFuncItems(set1, set2); funcDeps.addFuncItems(set2, set3); funcDeps.addFuncItems(set3, set4); - Set> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, set2, set3, set4)); + Set> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, set2, set3, set4), ImmutableSet.of()); Assertions.assertEquals(1, slots.size()); Assertions.assertEquals(set1, slots.iterator().next()); } @@ -78,7 +78,7 @@ void testEliminateCircle() { funcDeps.addFuncItems(set2, set3); funcDeps.addFuncItems(set3, set4); funcDeps.addFuncItems(set4, set1); - Set> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, set2, set3, set4)); + Set> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, set2, set3, set4), ImmutableSet.of()); Assertions.assertEquals(1, slots.size()); Assertions.assertEquals(set1, slots.iterator().next()); } @@ -89,7 +89,7 @@ void testEliminateTree() { funcDeps.addFuncItems(set1, set2); funcDeps.addFuncItems(set1, set3); funcDeps.addFuncItems(set1, set4); - Set> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, set2, set3, set4)); + Set> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, set2, set3, set4), ImmutableSet.of()); Assertions.assertEquals(1, slots.size()); Assertions.assertEquals(set1, slots.iterator().next()); } @@ -163,11 +163,20 @@ void testEliminateByPk() throws Exception { @Test void testEliminateByEqual() { PlanChecker.from(connectContext) - .analyze("select count(t1.name) from t1 as t1 join t1 as t2 on t1.name = t2.name group by t1.name, t2.name") + .analyze("select t1.name from t1 as t1 join t1 as t2 on t1.name = t2.name group by t1.name, t2.name") .rewrite() .printlnTree() .matches(logicalAggregate().when(agg -> - agg.getGroupByExpressions().size() == 1 && agg.getGroupByExpressions().get(0).toSql().equals("name"))); - } + agg.getGroupByExpressions().size() == 1 + && agg.getGroupByExpressions().get(0).toSql().equals("name"))); + PlanChecker.from(connectContext) + .analyze("select t2.name from t1 as t1 join t1 as t2 " + + "on t1.name = t2.name group by t1.name, t2.name") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> + agg.getGroupByExpressions().size() == 1 + && agg.getGroupByExpressions().get(0).toSql().equals("name"))); + } }