Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.apache.doris.nereids.trees.expressions.literal.NumericLiteral;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
Expand All @@ -71,12 +72,6 @@
public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule implements ExpressionPatternRuleFactory {
public static SimplifyComparisonPredicate INSTANCE = new SimplifyComparisonPredicate();

enum AdjustType {
LOWER,
UPPER,
NONE
}

@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
Expand Down Expand Up @@ -119,18 +114,23 @@ public static Expression simplify(ComparisonPredicate cp) {
return result;
}

private static Expression processComparisonPredicateDateTimeV2Literal(
private static Expression processDateTimeLikeComparisonPredicateDateTimeV2Literal(
ComparisonPredicate comparisonPredicate, Expression left, DateTimeV2Literal right) {
DateTimeV2Type leftType = (DateTimeV2Type) left.getDataType();
DataType leftType = left.getDataType();
int toScale = 0;
if (leftType instanceof DateTimeType) {
toScale = 0;
} else if (leftType instanceof DateTimeV2Type) {
toScale = ((DateTimeV2Type) leftType).getScale();
} else {
return comparisonPredicate;
}
DateTimeV2Type rightType = right.getDataType();
if (leftType.getScale() < rightType.getScale()) {
int toScale = leftType.getScale();
if (toScale < rightType.getScale()) {
if (comparisonPredicate instanceof EqualTo) {
long originValue = right.getMicroSecond();
right = right.roundCeiling(toScale);
if (right.getMicroSecond() == originValue) {
return comparisonPredicate.withChildren(left, right);
} else {
if (right.getMicroSecond() != originValue) {
// TODO: the ideal way is to return an If expr like:
// return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
// BooleanLiteral.of(false));
Expand All @@ -142,50 +142,55 @@ private static Expression processComparisonPredicateDateTimeV2Literal(
} else if (comparisonPredicate instanceof NullSafeEqual) {
long originValue = right.getMicroSecond();
right = right.roundCeiling(toScale);
if (right.getMicroSecond() == originValue) {
return comparisonPredicate.withChildren(left, right);
} else {
if (right.getMicroSecond() != originValue) {
return BooleanLiteral.of(false);
}
} else if (comparisonPredicate instanceof GreaterThan
|| comparisonPredicate instanceof LessThanEqual) {
return comparisonPredicate.withChildren(left, right.roundFloor(toScale));
right = right.roundFloor(toScale);
} else if (comparisonPredicate instanceof LessThan
|| comparisonPredicate instanceof GreaterThanEqual) {
return comparisonPredicate.withChildren(left, right.roundCeiling(toScale));
right = right.roundCeiling(toScale);
} else {
return comparisonPredicate;
}
Expression newRight = leftType instanceof DateTimeType ? migrateToDateTime(right) : right;
return comparisonPredicate.withChildren(left, newRight);
} else {
if (leftType instanceof DateTimeType) {
return comparisonPredicate.withChildren(left, migrateToDateTime(right));
} else {
return comparisonPredicate;
}
}
return comparisonPredicate;
}

private static Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expression left, Expression right) {
if (left instanceof Cast && right instanceof DateLiteral) {
Cast cast = (Cast) left;
if (cast.child().getDataType() instanceof DateTimeType) {
if (cast.child().getDataType() instanceof DateTimeType
|| cast.child().getDataType() instanceof DateTimeV2Type) {
if (right instanceof DateTimeV2Literal) {
left = cast.child();
right = migrateToDateTime((DateTimeV2Literal) right);
}
}
if (cast.child().getDataType() instanceof DateTimeV2Type) {
if (right instanceof DateTimeV2Literal) {
left = cast.child();
return processComparisonPredicateDateTimeV2Literal(cp, left, (DateTimeV2Literal) right);
return processDateTimeLikeComparisonPredicateDateTimeV2Literal(
cp, cast.child(), (DateTimeV2Literal) right);
}
}

// datetime to datev2
if (cast.child().getDataType() instanceof DateType || cast.child().getDataType() instanceof DateV2Type) {
if (right instanceof DateTimeLiteral) {
if (cannotAdjust((DateTimeLiteral) right, cp)) {
return cp;
}
AdjustType type = AdjustType.NONE;
if (cp instanceof GreaterThanEqual || cp instanceof LessThan) {
type = AdjustType.UPPER;
} else if (cp instanceof GreaterThan || cp instanceof LessThanEqual) {
type = AdjustType.LOWER;
DateTimeLiteral dateTimeLiteral = (DateTimeLiteral) right;
right = migrateToDateV2(dateTimeLiteral);
if (dateTimeLiteral.getHour() != 0 || dateTimeLiteral.getMinute() != 0
|| dateTimeLiteral.getSecond() != 0) {
if (cp instanceof EqualTo) {
return ExpressionUtils.falseOrNull(cast.child());
} else if (cp instanceof NullSafeEqual) {
return BooleanLiteral.FALSE;
} else if (cp instanceof GreaterThanEqual || cp instanceof LessThan) {
right = ((DateV2Literal) right).plusDays(1);
}
}
right = migrateToDateV2((DateTimeLiteral) right, type);
if (cast.child().getDataType() instanceof DateV2Type) {
left = cast.child();
}
Expand Down Expand Up @@ -416,17 +421,8 @@ private static Expression migrateToDateTime(DateTimeV2Literal l) {
return new DateTimeLiteral(l.getYear(), l.getMonth(), l.getDay(), l.getHour(), l.getMinute(), l.getSecond());
}

private static boolean cannotAdjust(DateTimeLiteral l, ComparisonPredicate cp) {
return cp instanceof EqualTo && (l.getHour() != 0 || l.getMinute() != 0 || l.getSecond() != 0);
}

private static Expression migrateToDateV2(DateTimeLiteral l, AdjustType type) {
DateV2Literal d = new DateV2Literal(l.getYear(), l.getMonth(), l.getDay());
if (type == AdjustType.UPPER && (l.getHour() != 0 || l.getMinute() != 0 || l.getSecond() != 0)) {
return d.plusDays(1);
} else {
return d;
}
private static Expression migrateToDateV2(DateTimeLiteral l) {
return new DateV2Literal(l.getYear(), l.getMonth(), l.getDay());
}

private static Expression migrateToDate(DateV2Literal l) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.IntegerType;
Expand Down Expand Up @@ -95,11 +98,11 @@ void testSimplifyComparisonPredicateRule() {
new LessThan(dv2, dv2PlusOne));
assertRewrite(
new EqualTo(new Cast(dv2, DateTimeV2Type.SYSTEM_DEFAULT), dtv2),
new EqualTo(new Cast(dv2, DateTimeV2Type.SYSTEM_DEFAULT), dtv2));
BooleanLiteral.FALSE);

assertRewrite(
new EqualTo(new Cast(d, DateTimeV2Type.SYSTEM_DEFAULT), dtv2),
new EqualTo(new Cast(d, DateTimeV2Type.SYSTEM_DEFAULT), dtv2));
BooleanLiteral.FALSE);

// test hour, minute and second all zero
Expression dtv2AtZeroClock = new DateTimeV2Literal(1, 1, 1, 0, 0, 0, 0);
Expand Down Expand Up @@ -140,6 +143,100 @@ void testDateTimeV2CmpDateTimeV2() {
expression = new GreaterThan(left, right);
rewrittenExpression = executor.rewrite(typeCoercion(expression), context);
Assertions.assertEquals(dt.getDataType(), rewrittenExpression.child(0).getDataType());

Expression date = new SlotReference("a", DateV2Type.INSTANCE);
Expression datev1 = new SlotReference("a", DateType.INSTANCE);
Expression datetime0 = new SlotReference("a", DateTimeV2Type.of(0));
Expression datetime2 = new SlotReference("a", DateTimeV2Type.of(2));
Expression datetimev1 = new SlotReference("a", DateTimeType.INSTANCE);

// date
// cast (date as datetimev1) cmp datetimev1
assertRewrite(new EqualTo(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:00")),
new EqualTo(date, new DateV2Literal("2020-01-01")));
assertRewrite(new EqualTo(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
ExpressionUtils.falseOrNull(date));
assertRewrite(new NullSafeEqual(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new GreaterThan(date, new DateV2Literal("2020-01-01")));
assertRewrite(new GreaterThanEqual(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new GreaterThanEqual(date, new DateV2Literal("2020-01-02")));
assertRewrite(new LessThan(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new LessThan(date, new DateV2Literal("2020-01-02")));
assertRewrite(new LessThanEqual(new Cast(date, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new LessThanEqual(date, new DateV2Literal("2020-01-01")));
// cast (date as datev1) = datev1-literal
// assertRewrite(new EqualTo(new Cast(date, DateType.INSTANCE), new DateLiteral("2020-01-01")),
// new EqualTo(date, new DateV2Literal("2020-01-01")));
// assertRewrite(new GreaterThan(new Cast(date, DateType.INSTANCE), new DateLiteral("2020-01-01")),
// new GreaterThan(date, new DateV2Literal("2020-01-01")));

// cast (datev1 as datetimev1) cmp datetimev1
assertRewrite(new EqualTo(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:00")),
new EqualTo(datev1, new DateLiteral("2020-01-01")));
assertRewrite(new EqualTo(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
ExpressionUtils.falseOrNull(datev1));
assertRewrite(new NullSafeEqual(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new GreaterThan(datev1, new DateLiteral("2020-01-01")));
assertRewrite(new GreaterThanEqual(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new GreaterThanEqual(datev1, new DateLiteral("2020-01-02")));
assertRewrite(new LessThan(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new LessThan(datev1, new DateLiteral("2020-01-02")));
assertRewrite(new LessThanEqual(new Cast(datev1, DateTimeType.INSTANCE), new DateTimeLiteral("2020-01-01 00:00:01")),
new LessThanEqual(datev1, new DateLiteral("2020-01-01")));
assertRewrite(new EqualTo(new Cast(datev1, DateV2Type.INSTANCE), new DateV2Literal("2020-01-01")),
new EqualTo(datev1, new DateLiteral("2020-01-01")));
assertRewrite(new GreaterThan(new Cast(datev1, DateV2Type.INSTANCE), new DateV2Literal("2020-01-01")),
new GreaterThan(datev1, new DateLiteral("2020-01-01")));

// cast (datetimev1 as datetime) cmp datetime
assertRewrite(new EqualTo(new Cast(datetimev1, DateTimeV2Type.of(0)), new DateTimeV2Literal("2020-01-01 00:00:00")),
new EqualTo(datetimev1, new DateTimeLiteral("2020-01-01 00:00:00")));
assertRewrite(new GreaterThan(new Cast(datetimev1, DateTimeV2Type.of(0)), new DateTimeV2Literal("2020-01-01 00:00:00")),
new GreaterThan(datetimev1, new DateTimeLiteral("2020-01-01 00:00:00")));
assertRewrite(new EqualTo(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
ExpressionUtils.falseOrNull(datetimev1));
assertRewrite(new NullSafeEqual(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new GreaterThan(datetimev1, new DateTimeLiteral("2020-01-01 00:00:00")));
assertRewrite(new GreaterThanEqual(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new GreaterThanEqual(datetimev1, new DateTimeLiteral("2020-01-01 00:00:01")));
assertRewrite(new LessThan(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new LessThan(datetimev1, new DateTimeLiteral("2020-01-01 00:00:01")));
assertRewrite(new LessThanEqual(new Cast(datetimev1, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new LessThanEqual(datetimev1, new DateTimeLiteral("2020-01-01 00:00:00")));

// cast (datetime0 as datetime) cmp datetime
assertRewrite(new EqualTo(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
ExpressionUtils.falseOrNull(datetime0));
assertRewrite(new NullSafeEqual(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new GreaterThan(datetime0, new DateTimeV2Literal("2020-01-01 00:00:00")));
assertRewrite(new GreaterThanEqual(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new GreaterThanEqual(datetime0, new DateTimeV2Literal("2020-01-01 00:00:01")));
assertRewrite(new LessThan(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new LessThan(datetime0, new DateTimeV2Literal("2020-01-01 00:00:01")));
assertRewrite(new LessThanEqual(new Cast(datetime0, DateTimeV2Type.of(2)), new DateTimeV2Literal("2020-01-01 00:00:00.12")),
new LessThanEqual(datetime0, new DateTimeV2Literal("2020-01-01 00:00:00")));

// cast (datetime2 as datetime) cmp datetime
assertRewrite(new EqualTo(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")),
ExpressionUtils.falseOrNull(datetime2));
assertRewrite(new NullSafeEqual(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")),
new GreaterThan(datetime2, new DateTimeV2Literal("2020-01-01 00:00:00.12")));
assertRewrite(new GreaterThanEqual(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")),
new GreaterThanEqual(datetime2, new DateTimeV2Literal("2020-01-01 00:00:00.13")));
assertRewrite(new LessThan(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")),
new LessThan(datetime2, new DateTimeV2Literal("2020-01-01 00:00:00.13")));
assertRewrite(new LessThanEqual(new Cast(datetime2, DateTimeV2Type.of(3)), new DateTimeV2Literal("2020-01-01 00:00:00.123")),
new LessThanEqual(datetime2, new DateTimeV2Literal("2020-01-01 00:00:00.12")));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1528,7 +1528,7 @@ PhysicalResultSink
----------PhysicalProject
------------PhysicalStorageLayerAggregate[test_pull_up_predicate_literal]
------PhysicalProject
--------filter((cast(d_date as DATETIMEV2(0)) = '2024-08-02 10:10:00'))
--------filter((t2.d_date = '2024-08-02'))
----------PhysicalOlapScan[test_types]

-- !const_value_and_join_column_type170 --
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,7 @@ select c1 from (select
qt_const_value_and_join_column_type169 """
explain shape plan
select c1 from (select
'2024-08-02 10:10:00.123332' as c1 from test_pull_up_predicate_literal limit 10) t inner join test_types t2 on d_date=t.c1"""
'2024-08-02 00:00:00.000000' as c1 from test_pull_up_predicate_literal limit 10) t inner join test_types t2 on d_date=t.c1"""

qt_const_value_and_join_column_type170 """
explain shape plan
Expand Down
Loading