Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ public Set<Expression> infer(Set<Expression> predicates) {
}

/**
* Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression`
* Use the left or right child of `equalExpr` to replace the left or right child of `expression`
* Currently only `ComparisonPredicate` inference is supported.
* TODO: We should check whether `expression` satisfies the conditions for replacement,
* e.g. whether `expression` is non-deterministic.
*/
private Expression doInfer(Expression leftSlotEqualToRightSlot, Expression expression) {
private Expression doInfer(Expression equalExpr, Expression expression) {
return expression.accept(new DefaultExpressionRewriter<Void>() {

@Override
Expand All @@ -76,36 +76,43 @@ public Expression visit(Expression expr, Void context) {
public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) {
// get the data type of the expression covered by the cast, because we want to infer across different data types
if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) && (cp.right().isConstant())) {
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left()));
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left()), equalExpr);
} else if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && cp.left().isConstant()) {
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right()));
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right()), equalExpr);
}
return super.visit(cp, context);
}

private boolean isDataTypeValid(DataType originDataType, Expression expr) {
if ((leftSlotEqualToRightSlot.child(0).getDataType() instanceof IntegralType)
&& (leftSlotEqualToRightSlot.child(1).getDataType() instanceof IntegralType)
if ((expr.child(0).getDataType() instanceof IntegralType)
&& (expr.child(1).getDataType() instanceof IntegralType)
&& (originDataType instanceof IntegralType)) {
// the inferred filter must not be narrower than the original data type, otherwise the result set would be wrong
if (!((IntegralType) originDataType).widerThan(
(IntegralType) leftSlotEqualToRightSlot.child(0).getDataType())
(IntegralType) expr.child(0).getDataType())
&& !((IntegralType) originDataType).widerThan(
(IntegralType) leftSlotEqualToRightSlot.child(1).getDataType())) {
(IntegralType) expr.child(1).getDataType())) {
return true;
}
} else if (expr.child(0).getDataType().equals(expr.child(1).getDataType())) {
return true;
}
return false;
}

private Expression replaceSlot(Expression expr, DataType originDataType) {
return expr.rewriteUp(e -> {
if (isDataTypeValid(originDataType, leftSlotEqualToRightSlot)) {
if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(0))) {
return leftSlotEqualToRightSlot.child(1);
} else if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(1))) {
return leftSlotEqualToRightSlot.child(0);
}
private Expression replaceSlot(Expression sourcePredicate, DataType originDataType, Expression equal) {
if (!isDataTypeValid(originDataType, equal)) {
return sourcePredicate;
}
return sourcePredicate.rewriteUp(e -> {
// a Cast expression must not be replaced with a slot, because during rewriteUp the child of the cast has already been replaced
if (e instanceof Cast) {
return e;
}
if (ExpressionUtils.isTwoExpressionEqualWithCast(e, equal.child(0))) {
return equal.child(1);
} else if (ExpressionUtils.isTwoExpressionEqualWithCast(e, equal.child(1))) {
return equal.child(0);
}
return e;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,33 @@

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.utframe.TestWithFeService;

import com.google.common.collect.Sets;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.util.Optional;
import java.util.Set;

public class InferPredicatesTest extends TestWithFeService implements MemoPatternMatchSupported {

private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);

private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);

private final PredicatePropagation propagation = new PredicatePropagation();

@Override
protected void runBeforeAll() throws Exception {
createDatabase("test");
Expand Down Expand Up @@ -628,4 +646,16 @@ public void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() {
).when(join -> join.getJoinType() == JoinType.LEFT_OUTER_JOIN)
);
}

/**
 * Checks that a comparison on a cast slot is propagated through a slot equality:
 * given `cast(t1.slot0 as bigint) = 1` and `t2.slot0 = t1.slot0`, the inferred
 * predicates must contain `cast(t2.slot0 as bigint) = 1`.
 */
@Test
void testInfer() {
    EqualTo castEqualsLiteral = new EqualTo(new Cast(scan1.getOutput().get(0), BigIntType.INSTANCE), Literal.of(1));
    EqualTo slotEqualsSlot = new EqualTo(scan2.getOutput().get(0), scan1.getOutput().get(0));
    Set<Expression> predicates = Sets.newHashSet();
    predicates.add(slotEqualsSlot);
    predicates.add(castEqualsLiteral);
    Expression expected = new EqualTo(new Cast(scan2.getOutput().get(0), BigIntType.INSTANCE), Literal.of(1));

    Set<Expression> newPredicates = propagation.infer(predicates);

    // Search the whole result instead of taking the first element: the result is a Set, so
    // iteration order is not guaranteed, and an empty result should fail the assertion
    // rather than throw NoSuchElementException from Optional.get().
    Optional<Expression> inferred = newPredicates.stream().filter(expected::equals).findFirst();
    Assertions.assertTrue(inferred.isPresent(),
            "expected inferred predicate " + expected + " but got " + newPredicates);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,19 @@ suite("test_infer_predicate") {
sql 'drop table if exists infer_tb1;'
sql 'drop table if exists infer_tb2;'
sql 'drop table if exists infer_tb3;'
sql 'drop table if exists infer_tb4;'
sql 'drop table if exists infer_tb5;'

sql '''create table infer_tb1 (k1 int, k2 int) distributed by hash(k1) buckets 3 properties('replication_num' = '1');'''

sql '''create table infer_tb2 (k1 tinyint, k2 smallint, k3 int, k4 bigint, k5 largeint, k6 date, k7 datetime, k8 float, k9 double) distributed by hash(k1) buckets 3 properties('replication_num' = '1');'''

sql '''create table infer_tb3 (k1 varchar(100), k2 int) distributed by hash(k1) buckets 3 properties('replication_num' = '1');'''

sql '''create table infer_tb4 (k1 varchar(100), k2 date) distributed by hash(k1) buckets 3 properties('replication_num' = '1');'''

sql '''create table infer_tb5 (k1 varchar(100), k3 date) distributed by hash(k1) buckets 3 properties('replication_num' = '1');'''

explain {
sql "select * from infer_tb1 inner join infer_tb2 where infer_tb2.k1 = infer_tb1.k2 and infer_tb2.k1 = 1;"
contains "PREDICATES: k2"
Expand All @@ -55,4 +61,16 @@ suite("test_infer_predicate") {
contains "PREDICATES: k3"
contains "PREDICATES: k2"
}

explain {
sql "select * from infer_tb4 left join infer_tb5 on infer_tb4.k2 = infer_tb5.k3 where infer_tb4.k2 = '20230901';"
contains "PREDICATES: k3"
contains "PREDICATES: k2"
}

sql 'drop table if exists infer_tb1;'
sql 'drop table if exists infer_tb2;'
sql 'drop table if exists infer_tb3;'
sql 'drop table if exists infer_tb4;'
sql 'drop table if exists infer_tb5;'
}