From 568c7654f277048c649f40895bbd1b91b3820dc2 Mon Sep 17 00:00:00 2001 From: morrySnow Date: Thu, 24 Aug 2023 17:21:38 +0800 Subject: [PATCH] [fix](Nereids) infer predicates generate wrong result we use two facilities to do predicate infer: PredicatePropagation and PullUpPredicates. When we do propagation in PredicatePropagation, we save the source predicates could be used in the upper node. However, we shoud not save any predicates from join on clause. Because these expression is not same with predicate in filter and could not prepagate to other not except the join's Immediate children. For example: ```sql select a.c1 from a left join b on a.c2 = b.c2 and a.c1 = '1' left join c on a.c2 = c.c2 and a.c1 = '2' inner join d on a.c3=d.c3 ``` the predicates `a.c1 = '1'` and `a.c1 = '2'` should not be inferred as filter to relation `a` --- .../rules/rewrite/InferPredicates.java | 15 ++++-- .../rules/rewrite/PredicatePropagation.java | 50 ++++++++++++------- .../rules/rewrite/PullUpPredicates.java | 2 +- .../rules/rewrite/InferPredicatesTest.java | 21 ++++++++ 4 files changed, 65 insertions(+), 23 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java index 9736db848265ee..f21d932c67932f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java @@ -31,6 +31,7 @@ import com.google.common.collect.Sets; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -103,11 +104,15 @@ public Plan visitLogicalFilter(LogicalFilter filter, JobContext } private Set getAllExpressions(Plan left, Plan right, Optional condition) { - Set baseExpressions = pullUpPredicates(left); - baseExpressions.addAll(pullUpPredicates(right)); - condition.ifPresent(on -> baseExpressions.addAll(ExpressionUtils.extractConjunction(on))); - baseExpressions.addAll(propagation.infer(baseExpressions)); - return baseExpressions; + Map baseExpressions = pullUpPredicates(left).stream() + .collect(Collectors.toMap(e -> e, e -> true)); + baseExpressions.putAll(pullUpPredicates(right).stream() + .collect(Collectors.toMap(e -> e, e -> true))); + condition.ifPresent(on -> baseExpressions.putAll(ExpressionUtils.extractConjunction(on).stream() + .collect(Collectors.toMap(e -> e, e -> false)))); + Set allExpressions = Sets.newHashSet(baseExpressions.keySet()); + allExpressions.addAll(propagation.infer(baseExpressions)); + return allExpressions; } private Set pullUpPredicates(Plan plan) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java index f6f04e899bc2f3..b0257dba3445e0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java @@ -30,8 +30,10 @@ import com.google.common.collect.Sets; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * derive additional predicates. @@ -43,25 +45,34 @@ public class PredicatePropagation { /** * equal predicate with literal in one side would be chosen to be source predicates and used to infer all predicates */ - private Set sourcePredicates = Sets.newHashSet(); + private final Set sourcePredicates = Sets.newHashSet(); /** * infer additional predicates. + * + * @param predicates predicate expression to shouldRecordIntoSourcePredicates flag. If the value is false + * we do not record itself and any expression inferred from it into sourcePredicates set. + * This is because these expression should not be seen when we process upper join node. + * the example could see the UT test + * {@link InferPredicatesTest#shouldNotSaveOnClausePredicates()} */ - public Set infer(Set predicates) { + public Set infer(Map predicates) { Set inferred = Sets.newHashSet(); - predicates.addAll(sourcePredicates); - for (Expression predicate : predicates) { - if (canEquivalentInfer(predicate)) { - List newInferred = predicates.stream() + Set newSourcePredicates = Sets.newHashSet(); + for (Map.Entry predicate : predicates.entrySet()) { + if (canEquivalentInfer(predicate.getKey())) { + List newInferred = Stream.concat(sourcePredicates.stream(), predicates.keySet().stream()) .filter(p -> !p.equals(predicate)) - .map(p -> doInfer(predicate, p)) + .map(p -> doInfer(predicate.getKey(), p, predicate.getValue())) .collect(Collectors.toList()); inferred.addAll(newInferred); + newInferred.removeAll(predicates.keySet()); + if (predicate.getValue()) { + newSourcePredicates.addAll(newInferred); + } } } - inferred.removeAll(predicates); - sourcePredicates.addAll(inferred); + sourcePredicates.addAll(newSourcePredicates); return inferred; } @@ -71,25 +82,30 @@ public Set infer(Set predicates) { * TODO: We should determine whether `expression` satisfies the condition for replacement * eg: Satisfy `expression` is non-deterministic */ - private Expression doInfer(Expression leftSlotEqualToRightSlot, Expression expression) { - return expression.accept(new DefaultExpressionRewriter() { + private Expression doInfer(Expression leftSlotEqualToRightSlot, + Expression expression, boolean recordSourcePredicates) { + return expression.accept(new DefaultExpressionRewriter() { @Override - public Expression visit(Expression expr, Void context) { + public Expression visit(Expression expr, Boolean recordSourcePredicates) { return expr; } @Override - public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) { + public Expression visitComparisonPredicate(ComparisonPredicate cp, Boolean recordSourcePredicates) { // we need to get expression covered by cast, because we want to infer different datatype if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) && (cp.right().isConstant())) { - sourcePredicates.add(cp); + if (recordSourcePredicates) { + sourcePredicates.add(cp); + } return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left())); } else if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && cp.left().isConstant()) { - sourcePredicates.add(cp); + if (recordSourcePredicates) { + sourcePredicates.add(cp); + } return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right())); } - return super.visit(cp, context); + return super.visit(cp, recordSourcePredicates); } private boolean isDataTypeValid(DataType originDataType, Expression expr) { @@ -119,7 +135,7 @@ private Expression replaceSlot(Expression expr, DataType originDataType) { return e; }); } - }, null); + }, recordSourcePredicates); } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java index 781d056422d411..56b475b3f3b69c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java @@ -154,7 +154,7 @@ private ImmutableSet cacheOrElse(Plan plan, Supplier getAvailableExpressions(Collection predicates, Plan plan) { Set expressions = Sets.newHashSet(predicates); - expressions.addAll(propagation.infer(expressions)); + expressions.addAll(propagation.infer(expressions.stream().collect(Collectors.toMap(e -> e, e -> false)))); return expressions.stream() .filter(p -> plan.getOutputSet().containsAll(p.getInputSlots())) .collect(ImmutableSet.toImmutableSet()); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java index 04613f7e75e7f0..2f88070d2489a1 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.utframe.TestWithFeService; @@ -604,5 +605,25 @@ public void inferPredicatesTest22() { ) ); } + + @Test + public void shouldNotSaveOnClausePredicates() { + String sql = "select * from student s1" + + " left join (select sid as id1, sid as id2, grade from score) s2 on s1.id = s2.id1 and s1.id = 1" + + " join (select sid as id1, sid as id2, grade from score) s3 on s1.id = s3.id1 where s1.id = 2"; + PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree(); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalJoin( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getConjuncts().size() == 1 + && filter.getPredicate().toSql().contains("id = 2")), + any() + ).when(join -> join.getJoinType() == JoinType.LEFT_OUTER_JOIN) + ); + } }