diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java index 9736db848265ee..f21d932c67932f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java @@ -31,6 +31,7 @@ import com.google.common.collect.Sets; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -103,11 +104,15 @@ public Plan visitLogicalFilter(LogicalFilter filter, JobContext } private Set getAllExpressions(Plan left, Plan right, Optional condition) { - Set baseExpressions = pullUpPredicates(left); - baseExpressions.addAll(pullUpPredicates(right)); - condition.ifPresent(on -> baseExpressions.addAll(ExpressionUtils.extractConjunction(on))); - baseExpressions.addAll(propagation.infer(baseExpressions)); - return baseExpressions; + Map baseExpressions = pullUpPredicates(left).stream() + .collect(Collectors.toMap(e -> e, e -> true)); + baseExpressions.putAll(pullUpPredicates(right).stream() + .collect(Collectors.toMap(e -> e, e -> true))); + condition.ifPresent(on -> baseExpressions.putAll(ExpressionUtils.extractConjunction(on).stream() + .collect(Collectors.toMap(e -> e, e -> false)))); + Set allExpressions = Sets.newHashSet(baseExpressions.keySet()); + allExpressions.addAll(propagation.infer(baseExpressions)); + return allExpressions; } private Set pullUpPredicates(Plan plan) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java index f6f04e899bc2f3..b0257dba3445e0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java @@ -30,8 +30,10 @@ import com.google.common.collect.Sets; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * derive additional predicates. @@ -43,25 +45,34 @@ public class PredicatePropagation { /** * equal predicate with literal in one side would be chosen to be source predicates and used to infer all predicates */ - private Set sourcePredicates = Sets.newHashSet(); + private final Set sourcePredicates = Sets.newHashSet(); /** * infer additional predicates. + * + * @param predicates predicate expression to shouldRecordIntoSourcePredicates flag. If the value is false + * we do not record itself and any expression inferred from it into sourcePredicates set. + * This is because these expression should not be seen when we process upper join node. + * the example could see the UT test + * {@link InferPredicatesTest#shouldNotSaveOnClausePredicates()} */ - public Set infer(Set predicates) { + public Set infer(Map predicates) { Set inferred = Sets.newHashSet(); - predicates.addAll(sourcePredicates); - for (Expression predicate : predicates) { - if (canEquivalentInfer(predicate)) { - List newInferred = predicates.stream() + Set newSourcePredicates = Sets.newHashSet(); + for (Map.Entry predicate : predicates.entrySet()) { + if (canEquivalentInfer(predicate.getKey())) { + List newInferred = Stream.concat(sourcePredicates.stream(), predicates.keySet().stream()) .filter(p -> !p.equals(predicate)) - .map(p -> doInfer(predicate, p)) + .map(p -> doInfer(predicate.getKey(), p, predicate.getValue())) .collect(Collectors.toList()); inferred.addAll(newInferred); + newInferred.removeAll(predicates.keySet()); + if (predicate.getValue()) { + newSourcePredicates.addAll(newInferred); + } } } - inferred.removeAll(predicates); - sourcePredicates.addAll(inferred); + sourcePredicates.addAll(newSourcePredicates); return inferred; } @@ -71,25 +82,30 @@ public Set infer(Set predicates) { * TODO: We should determine whether `expression` satisfies the condition for replacement * eg: Satisfy `expression` is non-deterministic */ - private Expression doInfer(Expression leftSlotEqualToRightSlot, Expression expression) { - return expression.accept(new DefaultExpressionRewriter() { + private Expression doInfer(Expression leftSlotEqualToRightSlot, + Expression expression, boolean recordSourcePredicates) { + return expression.accept(new DefaultExpressionRewriter() { @Override - public Expression visit(Expression expr, Void context) { + public Expression visit(Expression expr, Boolean recordSourcePredicates) { return expr; } @Override - public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) { + public Expression visitComparisonPredicate(ComparisonPredicate cp, Boolean recordSourcePredicates) { // we need to get expression covered by cast, because we want to infer different datatype if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) && (cp.right().isConstant())) { - sourcePredicates.add(cp); + if (recordSourcePredicates) { + sourcePredicates.add(cp); + } return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left())); } else if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && cp.left().isConstant()) { - sourcePredicates.add(cp); + if (recordSourcePredicates) { + sourcePredicates.add(cp); + } return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right())); } - return super.visit(cp, context); + return super.visit(cp, recordSourcePredicates); } private boolean isDataTypeValid(DataType originDataType, Expression expr) { @@ -119,7 +135,7 @@ private Expression replaceSlot(Expression expr, DataType originDataType) { return e; }); } - }, null); + }, recordSourcePredicates); } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java index 781d056422d411..56b475b3f3b69c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java @@ -154,7 +154,7 @@ private ImmutableSet cacheOrElse(Plan plan, Supplier getAvailableExpressions(Collection predicates, Plan plan) { Set expressions = Sets.newHashSet(predicates); - expressions.addAll(propagation.infer(expressions)); + expressions.addAll(propagation.infer(expressions.stream().collect(Collectors.toMap(e -> e, e -> false)))); return expressions.stream() .filter(p -> plan.getOutputSet().containsAll(p.getInputSlots())) .collect(ImmutableSet.toImmutableSet()); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java index 04613f7e75e7f0..2f88070d2489a1 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.utframe.TestWithFeService; @@ -604,5 +605,25 @@ public void inferPredicatesTest22() { ) ); } + + @Test + public void shouldNotSaveOnClausePredicates() { + String sql = "select * from student s1" + + " left join (select sid as id1, sid as id2, grade from score) s2 on s1.id = s2.id1 and s1.id = 1" + + " join (select sid as id1, sid as id2, grade from score) s3 on s1.id = s3.id1 where s1.id = 2"; + PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree(); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalJoin( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getConjuncts().size() == 1 + && filter.getPredicate().toSql().contains("id = 2")), + any() + ).when(join -> join.getJoinType() == JoinType.LEFT_OUTER_JOIN) + ); + } }