Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.google.common.collect.Sets;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -103,11 +104,15 @@ public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, JobContext
}

private Set<Expression> getAllExpressions(Plan left, Plan right, Optional<Expression> condition) {
Set<Expression> baseExpressions = pullUpPredicates(left);
baseExpressions.addAll(pullUpPredicates(right));
condition.ifPresent(on -> baseExpressions.addAll(ExpressionUtils.extractConjunction(on)));
baseExpressions.addAll(propagation.infer(baseExpressions));
return baseExpressions;
Map<Expression, Boolean> baseExpressions = pullUpPredicates(left).stream()
.collect(Collectors.toMap(e -> e, e -> true));
baseExpressions.putAll(pullUpPredicates(right).stream()
.collect(Collectors.toMap(e -> e, e -> true)));
condition.ifPresent(on -> baseExpressions.putAll(ExpressionUtils.extractConjunction(on).stream()
.collect(Collectors.toMap(e -> e, e -> false))));
Set<Expression> allExpressions = Sets.newHashSet(baseExpressions.keySet());
allExpressions.addAll(propagation.infer(baseExpressions));
return allExpressions;
}

private Set<Expression> pullUpPredicates(Plan plan) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
import com.google.common.collect.Sets;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* derive additional predicates.
Expand All @@ -43,25 +45,34 @@ public class PredicatePropagation {
/**
* equal predicate with literal in one side would be chosen to be source predicates and used to infer all predicates
*/
private Set<Expression> sourcePredicates = Sets.newHashSet();
private final Set<Expression> sourcePredicates = Sets.newHashSet();

/**
* infer additional predicates.
*
* @param predicates predicate expression to shouldRecordIntoSourcePredicates flag. If the value is false
* we do not record itself and any expression inferred from it into sourcePredicates set.
* This is because these expression should not be seen when we process upper join node.
* the example could see the UT test
* {@link InferPredicatesTest#shouldNotSaveOnClausePredicates()}
*/
public Set<Expression> infer(Set<Expression> predicates) {
public Set<Expression> infer(Map<Expression, Boolean> predicates) {
Set<Expression> inferred = Sets.newHashSet();
predicates.addAll(sourcePredicates);
for (Expression predicate : predicates) {
if (canEquivalentInfer(predicate)) {
List<Expression> newInferred = predicates.stream()
Set<Expression> newSourcePredicates = Sets.newHashSet();
for (Map.Entry<Expression, Boolean> predicate : predicates.entrySet()) {
if (canEquivalentInfer(predicate.getKey())) {
List<Expression> newInferred = Stream.concat(sourcePredicates.stream(), predicates.keySet().stream())
.filter(p -> !p.equals(predicate))
.map(p -> doInfer(predicate, p))
.map(p -> doInfer(predicate.getKey(), p, predicate.getValue()))
.collect(Collectors.toList());
inferred.addAll(newInferred);
newInferred.removeAll(predicates.keySet());
if (predicate.getValue()) {
newSourcePredicates.addAll(newInferred);
}
}
}
inferred.removeAll(predicates);
sourcePredicates.addAll(inferred);
sourcePredicates.addAll(newSourcePredicates);
return inferred;
}

Expand All @@ -71,25 +82,30 @@ public Set<Expression> infer(Set<Expression> predicates) {
* TODO: We should determine whether `expression` satisfies the condition for replacement
* eg: Satisfy `expression` is non-deterministic
*/
private Expression doInfer(Expression leftSlotEqualToRightSlot, Expression expression) {
return expression.accept(new DefaultExpressionRewriter<Void>() {
private Expression doInfer(Expression leftSlotEqualToRightSlot,
Expression expression, boolean recordSourcePredicates) {
return expression.accept(new DefaultExpressionRewriter<Boolean>() {

@Override
public Expression visit(Expression expr, Void context) {
public Expression visit(Expression expr, Boolean recordSourcePredicates) {
return expr;
}

@Override
public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) {
public Expression visitComparisonPredicate(ComparisonPredicate cp, Boolean recordSourcePredicates) {
// we need to get expression covered by cast, because we want to infer different datatype
if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) && (cp.right().isConstant())) {
sourcePredicates.add(cp);
if (recordSourcePredicates) {
sourcePredicates.add(cp);
}
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left()));
} else if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && cp.left().isConstant()) {
sourcePredicates.add(cp);
if (recordSourcePredicates) {
sourcePredicates.add(cp);
}
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right()));
}
return super.visit(cp, context);
return super.visit(cp, recordSourcePredicates);
}

private boolean isDataTypeValid(DataType originDataType, Expression expr) {
Expand Down Expand Up @@ -119,7 +135,7 @@ private Expression replaceSlot(Expression expr, DataType originDataType) {
return e;
});
}
}, null);
}, recordSourcePredicates);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ private ImmutableSet<Expression> cacheOrElse(Plan plan, Supplier<ImmutableSet<Ex

private ImmutableSet<Expression> getAvailableExpressions(Collection<Expression> predicates, Plan plan) {
Set<Expression> expressions = Sets.newHashSet(predicates);
expressions.addAll(propagation.infer(expressions));
expressions.addAll(propagation.infer(expressions.stream().collect(Collectors.toMap(e -> e, e -> false))));
return expressions.stream()
.filter(p -> plan.getOutputSet().containsAll(p.getInputSlots()))
.collect(ImmutableSet.toImmutableSet());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;
Expand Down Expand Up @@ -604,5 +605,25 @@ public void inferPredicatesTest22() {
)
);
}

@Test
public void shouldNotSaveOnClausePredicates() {
String sql = "select * from student s1"
+ " left join (select sid as id1, sid as id2, grade from score) s2 on s1.id = s2.id1 and s1.id = 1"
+ " join (select sid as id1, sid as id2, grade from score) s3 on s1.id = s3.id1 where s1.id = 2";
PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matches(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getConjuncts().size() == 1
&& filter.getPredicate().toSql().contains("id = 2")),
any()
).when(join -> join.getJoinType() == JoinType.LEFT_OUTER_JOIN)
);
}
}