Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import com.google.common.collect.ImmutableList;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;

/**
Expand Down Expand Up @@ -112,8 +113,14 @@ public static Expression simplifyCast(Cast cast) {
return new DecimalV3Literal(decimalV3Type,
new BigDecimal(((BigIntLiteral) child).getValue()));
} else if (child instanceof DecimalV3Literal) {
return new DecimalV3Literal(decimalV3Type,
((DecimalV3Literal) child).getValue());
DecimalV3Type childType = (DecimalV3Type) child.getDataType();
if (childType.getRange() <= decimalV3Type.getRange()) {
return new DecimalV3Literal(decimalV3Type,
((DecimalV3Literal) child).getValue()
.setScale(decimalV3Type.getScale(), RoundingMode.HALF_UP));
} else {
return cast;
}
}
}
} catch (Throwable t) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,10 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa
int toScale = ((DecimalV3Type) left.getDataType()).getScale();
if (comparisonPredicate instanceof EqualTo) {
try {
return comparisonPredicate.withChildren(left,
new DecimalV3Literal((DecimalV3Type) left.getDataType(),
literal.getValue().setScale(toScale)));
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate)
comparisonPredicate.withChildren(left,
new DecimalV3Literal((DecimalV3Type) left.getDataType(),
literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY))));
} catch (ArithmeticException e) {
if (left.nullable()) {
// TODO: the ideal way is to return an If expr like:
Expand All @@ -253,24 +254,25 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa
}
} else if (comparisonPredicate instanceof NullSafeEqual) {
try {
return comparisonPredicate.withChildren(left,
new DecimalV3Literal((DecimalV3Type) left.getDataType(),
literal.getValue().setScale(toScale)));
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate)
comparisonPredicate.withChildren(left,
new DecimalV3Literal((DecimalV3Type) left.getDataType(),
literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY))));
} catch (ArithmeticException e) {
return BooleanLiteral.of(false);
}
} else if (comparisonPredicate instanceof GreaterThan
|| comparisonPredicate instanceof LessThanEqual) {
return comparisonPredicate.withChildren(left, literal.roundFloor(toScale));
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate)
comparisonPredicate.withChildren(left, literal.roundFloor(toScale)));
} else if (comparisonPredicate instanceof LessThan
|| comparisonPredicate instanceof GreaterThanEqual) {
return comparisonPredicate.withChildren(left,
literal.roundCeiling(toScale));
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate)
comparisonPredicate.withChildren(left, literal.roundCeiling(toScale)));
}
}
} else if (left.getDataType().isIntegerLikeType()) {
return processIntegerDecimalLiteralComparison(comparisonPredicate, left,
literal.getValue());
return processIntegerDecimalLiteralComparison(comparisonPredicate, left, literal.getValue());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.trees.expressions.Cast;
Expand All @@ -25,7 +26,6 @@
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.types.DecimalV3Type;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.math.BigDecimal;
Expand Down Expand Up @@ -58,15 +58,17 @@ public static Expression simplify(ComparisonPredicate cp) {
if (left.getDataType() instanceof DecimalV3Type
&& left instanceof Cast
&& ((Cast) left).child().getDataType() instanceof DecimalV3Type
&& ((DecimalV3Type) left.getDataType()).getScale()
>= ((DecimalV3Type) ((Cast) left).child().getDataType()).getScale()
&& right instanceof DecimalV3Literal) {
return doProcess(cp, (Cast) left, (DecimalV3Literal) right);
try {
return doProcess(cp, (Cast) left, (DecimalV3Literal) right);
} catch (ArithmeticException e) {
return cp;
}
}

if (left != cp.left() || right != cp.right()) {
return cp.withChildren(left, right);
} else {
return cp;
}
return cp;
}

private static Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) {
Expand All @@ -80,13 +82,16 @@ private static Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3
}

Expression castChild = left.child();
Preconditions.checkState(castChild.getDataType() instanceof DecimalV3Type);
if (!(castChild.getDataType() instanceof DecimalV3Type)) {
throw new AnalysisException("cast child's type should be DecimalV3Type, but its type is "
+ castChild.getDataType().toSql());
}
DecimalV3Type leftType = (DecimalV3Type) castChild.getDataType();
if (scale <= leftType.getScale() && precision - scale <= leftType.getPrecision() - leftType.getScale()) {
if (scale <= leftType.getScale() && precision - scale <= leftType.getRange()) {
// precision and scale of literal all smaller than left, we don't need the cast
DecimalV3Literal newRight = new DecimalV3Literal(
DecimalV3Type.createDecimalV3TypeLooseCheck(leftType.getPrecision(), leftType.getScale()),
trailingZerosValue);
trailingZerosValue.setScale(leftType.getScale(), RoundingMode.UNNECESSARY));
return cp.withChildren(castChild, newRight);
} else {
return cp;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,11 @@ public double getDouble() {
}

public DecimalV3Literal roundCeiling(int newScale) {
return new DecimalV3Literal(DecimalV3Type
.createDecimalV3Type(((DecimalV3Type) dataType).getPrecision(), newScale),
value.setScale(newScale, RoundingMode.CEILING));
return new DecimalV3Literal(value.setScale(newScale, RoundingMode.CEILING));
}

public DecimalV3Literal roundFloor(int newScale) {
return new DecimalV3Literal(DecimalV3Type
.createDecimalV3Type(((DecimalV3Type) dataType).getPrecision(), newScale),
value.setScale(newScale, RoundingMode.FLOOR));
return new DecimalV3Literal(value.setScale(newScale, RoundingMode.FLOOR));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ public static DataType convertPrimitiveFromStrings(List<String> types) {
case "decimalv3":
switch (types.size()) {
case 1:
dataType = DecimalV3Type.CATALOG_DEFAULT;
dataType = DecimalV3Type.createDecimalV3Type(38, 9);
break;
case 2:
dataType = DecimalV3Type.createDecimalV3Type(Integer.parseInt(types.get(1)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.IntegerType;
Expand All @@ -42,7 +41,6 @@
import org.apache.doris.nereids.types.VarcharType;

import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
Expand All @@ -54,17 +52,17 @@ public void testSimplify() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
ExpressionRewrite.bottomUp(SimplifyCastRule.INSTANCE))
);
assertRewriteAfterSimplify("CAST('1' AS STRING)", "'1'", StringType.INSTANCE);
assertRewriteAfterSimplify("CAST('1' AS VARCHAR)", "'1'",
VarcharType.createVarcharType(-1));
assertRewriteAfterSimplify("CAST(1 AS DECIMAL)", "1.000000000",
DecimalV3Type.createDecimalV3Type(38, 9));
assertRewriteAfterSimplify("CAST(1000 AS DECIMAL)", "1000.000000000",
DecimalV3Type.createDecimalV3Type(38, 9));
assertRewriteAfterSimplify("CAST(1 AS DECIMALV3)", "1",
DecimalV3Type.createDecimalV3Type(9, 0));
assertRewriteAfterSimplify("CAST(1000 AS DECIMALV3)", "1000",
DecimalV3Type.createDecimalV3Type(9, 0));

assertRewrite(new Cast(new VarcharLiteral("1"), StringType.INSTANCE),
new StringLiteral("1"));
assertRewrite(new Cast(new VarcharLiteral("1"), VarcharType.SYSTEM_DEFAULT),
new VarcharLiteral("1", -1));
assertRewrite(new Cast(new TinyIntLiteral((byte) 1), DecimalV3Type.SYSTEM_DEFAULT),
new DecimalV3Literal(DecimalV3Type.SYSTEM_DEFAULT, new BigDecimal("1.000000000")));
assertRewrite(new Cast(new SmallIntLiteral((short) 1000), DecimalV3Type.SYSTEM_DEFAULT),
new DecimalV3Literal(DecimalV3Type.SYSTEM_DEFAULT, new BigDecimal("1000.000000000")));
assertRewrite(new Cast(new VarcharLiteral("1"), VarcharType.SYSTEM_DEFAULT), new VarcharLiteral("1", -1));
assertRewrite(new Cast(new VarcharLiteral("1"), VarcharType.SYSTEM_DEFAULT), new VarcharLiteral("1", -1));

Expression tinyIntLiteral = new TinyIntLiteral((byte) 12);
// cast tinyint as tinyint
Expand Down Expand Up @@ -143,17 +141,20 @@ public void testSimplify() {
// cast char(5) as string
assertRewrite(new Cast(charLiteral, StringType.INSTANCE), new StringLiteral("12345"));

Expression decimalV3Literal = new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(3, 1),
new BigDecimal("12.0"));
Expression decimalV3Literal = new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(5, 3),
new BigDecimal("12.000"));
// cast decimalv3(3,1) as decimalv3(5,1)
assertRewrite(new Cast(decimalV3Literal, DecimalV3Type.createDecimalV3Type(5, 1)),
new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(5, 1),
assertRewrite(new Cast(decimalV3Literal, DecimalV3Type.createDecimalV3Type(7, 3)),
new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(7, 3),
new BigDecimal("12.000")));
assertRewrite(new Cast(decimalV3Literal, DecimalV3Type.createDecimalV3Type(3, 1)),
new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(3, 1),
new BigDecimal("12.0")));

assertRewrite(new Cast(decimalV3Literal, DecimalV3Type.createDecimalV3Type(2, 1)),
new Cast(decimalV3Literal, DecimalV3Type.createDecimalV3Type(2, 1)));

// TODO unsupported but should?
// TODO unsupported, supported by org.apache.doris.nereids.trees.expressions.literal.Literal.uncheckedCastTo
// cast tinyint as smallint
assertRewrite(new Cast(tinyIntLiteral, SmallIntType.INSTANCE),
new Cast(tinyIntLiteral, SmallIntType.INSTANCE));
Expand Down Expand Up @@ -186,13 +187,4 @@ public void testSimplify() {
new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(6, 1),
new BigDecimal("12.0")));
}

private void assertRewriteAfterSimplify(String expr, String expected, DataType expectedType) {
Expression needRewriteExpression = PARSER.parseExpression(expr);
Expression rewritten = executor.rewrite(needRewriteExpression, context);
Expression expectedExpression = PARSER.parseExpression(expected);
Assertions.assertEquals(expectedExpression.toSql(), rewritten.toSql());
Assertions.assertEquals(expectedType, rewritten.getDataType());
}

}
Loading