Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,28 @@

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.types.coercion.DateLikeType;
import org.apache.doris.nereids.types.coercion.IntegralType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;

import com.google.common.collect.Sets;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

Expand All @@ -40,19 +49,61 @@
*/
public class PredicatePropagation {

/**
 * Category used to decide how a comparison operand may be unwrapped for inference.
 * Each category carries the data-type super class an operand's type must be an
 * instance of; {@link #NONE} carries {@code null} and marks a comparison that
 * cannot take part in inference.
 */
private enum InferType {
    /** The comparison is not usable for inference; it has no type bound. */
    NONE(null),
    /** Operands whose data type is an {@link IntegralType}. */
    INTEGRAL(IntegralType.class),
    /** Operands whose data type is a {@link CharacterType}. */
    STRING(CharacterType.class),
    /** Operands whose data type is a {@link DateLikeType}. */
    DATE(DateLikeType.class),
    /** Any other {@link DataType}. */
    OTHER(DataType.class);

    /** Upper bound an operand's data-type class is checked against. */
    private final Class<? extends DataType> superClazz;

    InferType(Class<? extends DataType> typeBound) {
        superClazz = typeBound;
    }
}

private class ComparisonInferInfo {

public final InferType inferType;
public final Optional<Expression> left;
public final Optional<Expression> right;
public final ComparisonPredicate comparisonPredicate;

public ComparisonInferInfo(InferType inferType,
Optional<Expression> left, Optional<Expression> right,
ComparisonPredicate comparisonPredicate) {
this.inferType = inferType;
this.left = left;
this.right = right;
this.comparisonPredicate = comparisonPredicate;
}
}

/**
* infer additional predicates.
*/
public Set<Expression> infer(Set<Expression> predicates) {
Set<Expression> inferred = Sets.newHashSet();
for (Expression predicate : predicates) {
if (canEquivalentInfer(predicate)) {
List<Expression> newInferred = predicates.stream()
.filter(p -> !p.equals(predicate))
.map(p -> doInfer(predicate, p))
.collect(Collectors.toList());
inferred.addAll(newInferred);
if (!(predicate instanceof ComparisonPredicate)) {
continue;
}
ComparisonInferInfo equalInfo = getEquivalentInferInfo((ComparisonPredicate) predicate);
if (equalInfo.inferType == InferType.NONE) {
continue;
}
Set<Expression> newInferred = predicates.stream()
.filter(ComparisonPredicate.class::isInstance)
.filter(p -> !p.equals(predicate))
.map(ComparisonPredicate.class::cast)
.map(this::inferInferInfo)
.filter(predicateInfo -> predicateInfo.inferType != InferType.NONE)
.map(predicateInfo -> doInfer(equalInfo, predicateInfo))
.filter(Objects::nonNull)
.collect(Collectors.toSet());
inferred.addAll(newInferred);
}
inferred.removeAll(predicates);
return inferred;
Expand All @@ -64,64 +115,128 @@ public Set<Expression> infer(Set<Expression> predicates) {
* TODO: We should determine whether `expression` satisfies the condition for replacement
* eg: Satisfy `expression` is non-deterministic
*/
private Expression doInfer(Expression leftSlotEqualToRightSlot, Expression expression) {
return expression.accept(new DefaultExpressionRewriter<Void>() {
private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo predicateInfo) {
    // Both infos are expected to carry present operands; callers filter out InferType.NONE first.
    Expression lhs = predicateInfo.left.get();
    Expression rhs = predicateInfo.right.get();
    Expression eqLhs = equalInfo.left.get();
    Expression eqRhs = equalInfo.right.get();

    // Map each operand of the predicate through the equality.
    Expression mappedLhs = inferOneSide(lhs, eqLhs, eqRhs);
    Expression mappedRhs = inferOneSide(rhs, eqLhs, eqRhs);
    if (mappedLhs != null && mappedRhs != null) {
        // Rebuild the same kind of comparison on the mapped operands, then
        // re-apply type coercion and comparison simplification to the result.
        ComparisonPredicate rebuilt = (ComparisonPredicate) predicateInfo
                .comparisonPredicate.withChildren(mappedLhs, mappedRhs);
        return SimplifyComparisonPredicate.INSTANCE
                .rewrite(TypeCoercionUtils.processComparisonPredicate(rebuilt), null);
    }
    // At least one operand could not be mapped: nothing to infer.
    return null;
}

@Override
public Expression visit(Expression expr, Void context) {
return expr;
/**
 * Maps one operand of a predicate through the equality {@code equalLeft = equalRight}.
 *
 * @param predicateOneSide one operand of the predicate being rewritten
 * @param equalLeft left operand of the equality
 * @param equalRight right operand of the equality
 * @return the opposite operand of the equality when {@code predicateOneSide} is a slot equal to
 *         one of its operands; the operand itself when it is constant (integer-like literals are
 *         re-parsed from their SQL text); {@code null} when no substitution applies
 */
private Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {
if (predicateOneSide instanceof SlotReference) {
// A slot that matches one side of the equality is replaced by the other side.
if (predicateOneSide.equals(equalLeft)) {
return equalRight;
} else if (predicateOneSide.equals(equalRight)) {
return equalLeft;
}

@Override
public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) {
// we need to get expression covered by cast, because we want to infer different datatype
if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) && (cp.right().isConstant())) {
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left()));
} else if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && cp.left().isConstant()) {
return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right()));
}
return super.visit(cp, context);
} else if (predicateOneSide.isConstant()) {
// Constants are carried over unchanged, except integer-like literals, which are rebuilt
// by parsing their SQL text.
// NOTE(review): presumably this lets the literal get a fresh type before coercion is
// re-applied by the caller - confirm.
if (predicateOneSide instanceof IntegerLikeLiteral) {
return new NereidsParser().parseExpression(((IntegerLikeLiteral) predicateOneSide).toSql());
} else {
return predicateOneSide;
}
}
// No substitution applies; the caller treats null as "cannot infer".
return null;
}

private boolean isDataTypeValid(DataType originDataType, Expression expr) {
if ((leftSlotEqualToRightSlot.child(0).getDataType() instanceof IntegralType)
&& (leftSlotEqualToRightSlot.child(1).getDataType() instanceof IntegralType)
&& (originDataType instanceof IntegralType)) {
// infer filter can not be lower than original datatype, or dataset would be wrong
if (!((IntegralType) originDataType).widerThan(
(IntegralType) leftSlotEqualToRightSlot.child(0).getDataType())
&& !((IntegralType) originDataType).widerThan(
(IntegralType) leftSlotEqualToRightSlot.child(1).getDataType())) {
return true;
/**
 * Decides whether an operand can be used for inference under the given category and, if so,
 * returns the expression to use in its place.
 *
 * <p>The operand's data-type class must be assignable to {@code inferType.superClazz}. A slot
 * reference or a constant is accepted as-is. A {@link Cast} is unwrapped recursively when the
 * category-specific rules below allow it. Everything else is rejected.
 *
 * @param expression the operand to examine
 * @param inferType the inference category; its {@code superClazz} is dereferenced, so
 *        {@code InferType.NONE} (whose {@code superClazz} is null) must not be passed
 * @return the accepted expression, or {@code Optional.empty()} when the operand is not usable
 */
private Optional<Expression> validForInfer(Expression expression, InferType inferType) {
// Reject operands whose data type does not belong to the category.
if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
return Optional.empty();
}
// Slots and constants need no unwrapping.
if (expression instanceof SlotReference || expression.isConstant()) {
return Optional.of(expression);
}
if (inferType == InferType.INTEGRAL) {
if (expression instanceof Cast) {
// NOTE: every integral cast is unwrapped here, including narrowing ones such as
// cast(int as smallint). The guard that would restrict unwrapping to widening casts
// is kept below for reference but is disabled:
// IntegralType dataType = (IntegralType) expression.getDataType();
// DataType childType = ((Cast) expression).child().getDataType();
// if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) {
// return validForInfer(((Cast) expression).child(), inferType);
// }
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (inferType == InferType.DATE) {
if (expression instanceof Cast) {
DataType dataType = expression.getDataType();
DataType childType = ((Cast) expression).child().getDataType();
// Avoid losing precision: a cast is unwrapped only for the target/child pairs below.
// Target DateType or DateV2Type: the child must be DateType or DateV2Type.
if (dataType instanceof DateType) {
if (childType instanceof DateV2Type || childType instanceof DateType) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateV2Type) {
if (childType instanceof DateType || childType instanceof DateV2Type) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateTimeType) {
// Target DateTimeType: any child except DateTimeV2Type.
if (!(childType instanceof DateTimeV2Type)) {
return validForInfer(((Cast) expression).child(), inferType);
}
} else if (dataType instanceof DateTimeV2Type) {
// Target DateTimeV2Type: any child is unwrapped.
return validForInfer(((Cast) expression).child(), inferType);
}
return false;
}

private Expression replaceSlot(Expression expr, DataType originDataType) {
return expr.rewriteUp(e -> {
if (isDataTypeValid(originDataType, leftSlotEqualToRightSlot)) {
if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(0))) {
return leftSlotEqualToRightSlot.child(1);
} else if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(1))) {
return leftSlotEqualToRightSlot.child(0);
}
}
return e;
});
} else if (inferType == InferType.STRING) {
if (expression instanceof Cast) {
DataType dataType = expression.getDataType();
DataType childType = ((Cast) expression).child().getDataType();
// Avoid truncating casts such as cast(char(3) as char(2)): unwrap only when the target
// width is non-positive, or is at least the child's width and the child's width is
// non-negative.
// NOTE(review): a non-positive width presumably means "no declared length" - confirm.
if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) {
return validForInfer(((Cast) expression).child(), inferType);
}
}
}, null);
} else {
// Other categories accept only slots and constants, which were handled above.
return Optional.empty();
}
// A non-cast expression, or a cast that is not allowed to be unwrapped.
return Optional.empty();
}

/**
 * Analyses a comparison predicate for inference. The category is chosen from the data type
 * of the left operand, then both operands are validated against it. If either operand is
 * rejected, the resulting info is tagged {@code InferType.NONE}.
 */
private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) {
    InferType category = classifyInferType(comparisonPredicate.left().getDataType());
    Optional<Expression> left = validForInfer(comparisonPredicate.left(), category);
    Optional<Expression> right = validForInfer(comparisonPredicate.right(), category);
    boolean bothUsable = left.isPresent() && right.isPresent();
    return new ComparisonInferInfo(bothUsable ? category : InferType.NONE,
            left, right, comparisonPredicate);
}

/** Picks the inference category matching a data type, checking string, integral, then date. */
private InferType classifyInferType(DataType type) {
    if (type instanceof CharacterType) {
        return InferType.STRING;
    }
    if (type instanceof IntegralType) {
        return InferType.INTEGRAL;
    }
    if (type instanceof DateLikeType) {
        return InferType.DATE;
    }
    return InferType.OTHER;
}

/**
* Currently only equivalence derivation is supported
* and requires that the left and right sides of an expression must be slot
*/
private boolean canEquivalentInfer(Expression predicate) {
return predicate instanceof EqualTo
&& predicate.children().stream().allMatch(e ->
(e instanceof SlotReference) || (e instanceof Cast && e.child(0) instanceof SlotReference))
&& predicate.child(0).getDataType().equals(predicate.child(1).getDataType());
private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate predicate) {
    if (predicate instanceof EqualTo) {
        ComparisonInferInfo analysed = inferInferInfo(predicate);
        if (analysed.inferType == InferType.NONE) {
            return analysed;
        }
        // Only an equality between two slots can drive inference.
        boolean slotEqualsSlot = analysed.left.get() instanceof SlotReference
                && analysed.right.get() instanceof SlotReference;
        return slotEqualsSlot
                ? analysed
                : new ComparisonInferInfo(InferType.NONE,
                        analysed.left, analysed.right, analysed.comparisonPredicate);
    }
    // Anything other than EqualTo is not an equivalence and is tagged as unusable.
    return new ComparisonInferInfo(InferType.NONE,
            Optional.of(predicate.left()), Optional.of(predicate.right()), predicate);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.types.DataType;

import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
Expand Down Expand Up @@ -253,34 +252,6 @@ public static Optional<Slot> extractSlotOrCastOnSlot(Expression expr) {
}
}

/**
 * Strips every enclosing cast and returns the data type of the expression underneath.
 * example: input: cast(cast(table.columnA)) output: columnA.datatype
 */
public static DataType getDatatypeCoveredByCast(Expression expr) {
    Expression innermost = expr;
    while (innermost instanceof Cast) {
        innermost = ((Cast) innermost).child();
    }
    return innermost.getDataType();
}

/**
 * Tells whether the expression is a slot reference, possibly wrapped in any number of casts.
 * example: cast(cast(table.columnA))
 */
public static boolean isExpressionSlotCoveredByCast(Expression expr) {
    Expression innermost = expr;
    while (innermost instanceof Cast) {
        innermost = ((Cast) innermost).child();
    }
    return innermost instanceof SlotReference;
}

/**
 * Compares two expressions by the result {@code extractSlotOrCastOnSlot} yields for each,
 * so that two expressions yielding equal results are considered equal.
 */
public static boolean isTwoExpressionEqualWithCast(Expression left, Expression right) {
    return extractSlotOrCastOnSlot(left)
            .equals(extractSlotOrCastOnSlot(right));
}

/**
* Replace expression node in the expression tree by `replaceMap` in top-down manner.
* For example.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ suite("test_infer_predicate") {

explain {
sql "select * from infer_tb1 inner join infer_tb2 where cast(infer_tb2.k4 as int) = infer_tb1.k2 and infer_tb2.k4 = 1;"
notContains "PREDICATES: k2"
contains "PREDICATES: k2"
}

explain {
Expand Down