Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,17 @@
*/
public class PredicatePropagation {

/**
* equal predicate with literal in one side would be chosen to be source predicates and used to infer all predicates
*/
private Set<Expression> sourcePredicates = Sets.newHashSet();

/**
* infer additional predicates.
*/
public Set<Expression> infer(Set<Expression> predicates) {
Set<Expression> inferred = Sets.newHashSet();
predicates.addAll(sourcePredicates);
for (Expression predicate : predicates) {
if (canEquivalentInfer(predicate)) {
List<Expression> newInferred = predicates.stream()
Expand All @@ -55,6 +61,7 @@ public Set<Expression> infer(Set<Expression> predicates) {
}
}
inferred.removeAll(predicates);
sourcePredicates.addAll(inferred);
return inferred;
}

Expand All @@ -76,8 +83,10 @@ public Expression visit(Expression expr, Void context) {
public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) {
// we need to get expression covered by cast, because we want to infer different datatype
if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) && (cp.right().isConstant())) {
sourcePredicates.add(cp);
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left()));
} else if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && cp.left().isConstant()) {
sourcePredicates.add(cp);
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right()));
}
return super.visit(cp, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ suite("test_infer_predicate") {

sql 'drop table if exists infer_tb1;'
sql 'drop table if exists infer_tb2;'
sql 'drop table if exists infer_tb3;'

sql '''create table infer_tb1 (k1 int, k2 int) distributed by hash(k1) buckets 3 properties('replication_num' = '1');'''

Expand All @@ -47,4 +48,10 @@ suite("test_infer_predicate") {
sql "select * from infer_tb1 inner join infer_tb3 where infer_tb3.k1 = infer_tb1.k2 and infer_tb3.k1 = '123';"
notContains "PREDICATES: k2[#6] = '123'"
}

explain {
sql "select * from infer_tb1 left join infer_tb2 on infer_tb1.k1 = infer_tb2.k3 left join infer_tb3 on " +
"infer_tb2.k3 = infer_tb3.k2 where infer_tb1.k1 = 1;"
contains "PREDICATES: k3[#4] = 1"
}
}