diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Add.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Add.java index bf19192ff5c6fb..0ebd65b769b9cc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Add.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Add.java @@ -18,7 +18,7 @@ package org.apache.doris.nereids.trees.expressions; import org.apache.doris.analysis.ArithmeticExpr.Operator; -import org.apache.doris.nereids.trees.expressions.functions.CheckOverflowNullable; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DecimalV3Type; @@ -31,7 +31,7 @@ /** * Add Expression. */ -public class Add extends BinaryArithmetic implements CheckOverflowNullable { +public class Add extends BinaryArithmetic implements PropagateNullable { public Add(Expression left, Expression right) { super(ImmutableList.of(left, right), Operator.ADD); @@ -49,9 +49,9 @@ public Expression withChildren(List children) { @Override public DecimalV3Type getDataTypeForDecimalV3(DecimalV3Type t1, DecimalV3Type t2) { - DecimalV3Type decimalV3Type = (DecimalV3Type) DecimalV3Type.widerDecimalV3Type(t1, t2, false); - return DecimalV3Type.createDecimalV3Type(decimalV3Type.getPrecision() + 1, - decimalV3Type.getScale()); + int targetScale = Math.max(t1.getScale(), t2.getScale()); + int integralPart = Math.max(t1.getRange(), t2.getRange()); + return processDecimalV3OverFlow(integralPart + 1, targetScale, integralPart); } @Override @@ -59,11 +59,6 @@ public DataType getDataTypeForOthers(DataType t1, DataType t2) { return super.getDataTypeForOthers(t1, t2).promotion(); } - @Override - public boolean nullable() { - return CheckOverflowNullable.super.nullable(); - } - @Override public R accept(ExpressionVisitor visitor, C context) { return visitor.visitAdd(this, context); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryArithmetic.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryArithmetic.java index 84eb08eb5810dd..5db61e211b3a0b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryArithmetic.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryArithmetic.java @@ -27,6 +27,7 @@ import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.coercion.NumericType; import org.apache.doris.nereids.util.TypeCoercionUtils; +import org.apache.doris.qe.ConnectContext; import java.util.List; @@ -93,4 +94,23 @@ public DataType getDataTypeForOthers(DataType t1, DataType t2) { public R accept(ExpressionVisitor visitor, C context) { return visitor.visitBinaryArithmetic(this, context); } + + protected DecimalV3Type processDecimalV3OverFlow(int integralPart, int targetScale, int maxIntegralPart) { + int precision = integralPart + targetScale; + boolean enableDecimal256 = false; + ConnectContext connectContext = ConnectContext.get(); + if (connectContext != null) { + enableDecimal256 = connectContext.getSessionVariable().isEnableDecimal256(); + } + if (enableDecimal256) { + if (precision > DecimalV3Type.MAX_DECIMAL256_PRECISION) { + precision = DecimalV3Type.MAX_DECIMAL256_PRECISION; + } + } else { + if (precision > DecimalV3Type.MAX_DECIMAL128_PRECISION) { + precision = DecimalV3Type.MAX_DECIMAL128_PRECISION; + } + } + return DecimalV3Type.createDecimalV3Type(precision, precision - maxIntegralPart); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Divide.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Divide.java index 002849bb8166ca..ef0d86b579cd9d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Divide.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Divide.java @@ -25,6 +25,7 @@ import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.qe.ConnectContext; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -67,21 +68,40 @@ public DataType getDataType() throws UnboundException { @Override public DecimalV3Type getDataTypeForDecimalV3(DecimalV3Type t1, DecimalV3Type t2) { - int retPercision = t1.getPrecision() + t2.getScale() + Config.div_precision_increment; - Preconditions.checkState(retPercision <= DecimalV3Type.MAX_DECIMAL256_PRECISION, - "target precision " + retPercision + " larger than precision " - + DecimalV3Type.MAX_DECIMAL256_PRECISION + " in Divide return type"); - int retScale = t1.getScale() + t2.getScale() - + Config.div_precision_increment; - int targetPercision = retPercision; - int targetScale = t1.getScale() + t2.getScale(); - Preconditions.checkState(targetPercision >= targetScale, - "target scale " + targetScale + " larger than precision " + retPercision - + " in Divide return type"); - Preconditions.checkState(retPercision >= retScale, - "scale " + retScale + " larger than precision " + retPercision - + " in Divide return type"); - return DecimalV3Type.createDecimalV3Type(retPercision, retScale); + int precision = t1.getPrecision() + t2.getScale() + Config.div_precision_increment; + int scale = t1.getScale(); + boolean enableDecimal256 = false; + int defaultScale = 6; + ConnectContext connectContext = ConnectContext.get(); + if (connectContext != null) { + enableDecimal256 = connectContext.getSessionVariable().isEnableDecimal256(); + defaultScale = connectContext.getSessionVariable().decimalOverflowScale; + } + if (enableDecimal256 && precision > DecimalV3Type.MAX_DECIMAL256_PRECISION) { + int integralPartBoundary = DecimalV3Type.MAX_DECIMAL256_PRECISION - defaultScale; + if (precision - scale < integralPartBoundary) { + // retains more int part + scale = DecimalV3Type.MAX_DECIMAL256_PRECISION - (precision - scale); + } else if (precision - scale > integralPartBoundary && scale < defaultScale) { + // scale not change, retains more scale part + } else { + scale = defaultScale; + } + precision = DecimalV3Type.MAX_DECIMAL256_PRECISION; + } else if (!enableDecimal256 && precision > DecimalV3Type.MAX_DECIMAL128_PRECISION) { + int integralPartBoundary = DecimalV3Type.MAX_DECIMAL128_PRECISION - defaultScale; + if (precision - scale < integralPartBoundary) { + // retains more int part + scale = DecimalV3Type.MAX_DECIMAL128_PRECISION - (precision - scale); + } else if (precision - scale > integralPartBoundary && scale < defaultScale) { + // scale not change, retains more scale part + } else { + scale = defaultScale; + } + precision = DecimalV3Type.MAX_DECIMAL128_PRECISION; + } + scale = Math.min(precision, scale + t2.getScale() + Config.div_precision_increment); + return DecimalV3Type.createDecimalV3Type(precision, scale); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Mod.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Mod.java index 9a2865509151ab..569e7cabb3b14e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Mod.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Mod.java @@ -49,11 +49,9 @@ public Expression withChildren(List children) { @Override public DecimalV3Type getDataTypeForDecimalV3(DecimalV3Type t1, DecimalV3Type t2) { - // TODO use max int part + max scale of two operands as result type - // because BE require the result and operands types are the exact the same decimalv3 type - int scale = Math.max(t1.getScale(), t2.getScale()); - int precision = Math.max(t1.getRange(), t2.getRange()) + scale; - return DecimalV3Type.createDecimalV3Type(precision, scale); + int targetScale = Math.max(t1.getScale(), t2.getScale()); + int integralPart = Math.max(t1.getRange(), t2.getRange()); + return processDecimalV3OverFlow(integralPart, targetScale, integralPart); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Multiply.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Multiply.java index 52c43fa85612a9..d2af719208f460 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Multiply.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Multiply.java @@ -18,7 +18,7 @@ package org.apache.doris.nereids.trees.expressions; import org.apache.doris.analysis.ArithmeticExpr.Operator; -import org.apache.doris.nereids.trees.expressions.functions.CheckOverflowNullable; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DecimalV3Type; @@ -32,7 +32,7 @@ /** * Multiply Expression. */ -public class Multiply extends BinaryArithmetic implements CheckOverflowNullable { +public class Multiply extends BinaryArithmetic implements PropagateNullable { public Multiply(Expression left, Expression right) { super(ImmutableList.of(left, right), Operator.MULTIPLY); @@ -50,26 +50,42 @@ public Expression withChildren(List children) { @Override public DecimalV3Type getDataTypeForDecimalV3(DecimalV3Type t1, DecimalV3Type t2) { - int retPercision = t1.getPrecision() + t2.getPrecision(); + int retPrecision = t1.getPrecision() + t2.getPrecision(); int retScale = t1.getScale() + t2.getScale(); - if (retPercision > DecimalV3Type.MAX_DECIMAL128_PRECISION) { - boolean enableDecimal256 = false; - ConnectContext connectContext = ConnectContext.get(); - if (connectContext != null) { - enableDecimal256 = connectContext.getSessionVariable().isEnableDecimal256(); + boolean enableDecimal256 = false; + int defaultScale = 6; + ConnectContext connectContext = ConnectContext.get(); + if (connectContext != null) { + enableDecimal256 = connectContext.getSessionVariable().isEnableDecimal256(); + defaultScale = connectContext.getSessionVariable().decimalOverflowScale; + } + if (!enableDecimal256 && retPrecision > DecimalV3Type.MAX_DECIMAL128_PRECISION) { + int integralPartBoundary = DecimalV3Type.MAX_DECIMAL128_PRECISION - defaultScale; + if (retPrecision - retScale < integralPartBoundary) { + // retains more int part + retScale = DecimalV3Type.MAX_DECIMAL128_PRECISION - (retPrecision - retScale); + } else if (retPrecision - retScale > integralPartBoundary && retScale < defaultScale) { + // retScale not change, retains more scale part + } else { + retScale = defaultScale; } - if (enableDecimal256) { - if (retPercision > DecimalV3Type.MAX_DECIMAL256_PRECISION) { - retPercision = DecimalV3Type.MAX_DECIMAL256_PRECISION; - } + retPrecision = DecimalV3Type.MAX_DECIMAL128_PRECISION; + } else if (enableDecimal256 && retPrecision > DecimalV3Type.MAX_DECIMAL256_PRECISION) { + int integralPartBoundary = DecimalV3Type.MAX_DECIMAL256_PRECISION - defaultScale; + if (retPrecision - retScale < integralPartBoundary) { + // retains more int part + retScale = DecimalV3Type.MAX_DECIMAL256_PRECISION - (retPrecision - retScale); + } else if (retPrecision - retScale > integralPartBoundary && retScale < defaultScale) { + // retScale not change, retains more scale part } else { - retPercision = DecimalV3Type.MAX_DECIMAL128_PRECISION; + retScale = defaultScale; } + retPrecision = DecimalV3Type.MAX_DECIMAL256_PRECISION; } - Preconditions.checkState(retPercision >= retScale, - "scale " + retScale + " larger than precision " + retPercision + Preconditions.checkState(retPrecision >= retScale, + "scale " + retScale + " larger than precision " + retPrecision + " in Multiply return type"); - return DecimalV3Type.createDecimalV3Type(retPercision, retScale); + return DecimalV3Type.createDecimalV3Type(retPrecision, retScale); } @Override @@ -77,11 +93,6 @@ public DataType getDataTypeForOthers(DataType t1, DataType t2) { return super.getDataTypeForOthers(t1, t2).promotion(); } - @Override - public boolean nullable() { - return CheckOverflowNullable.super.nullable(); - } - @Override public R accept(ExpressionVisitor visitor, C context) { return visitor.visitMultiply(this, context); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Subtract.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Subtract.java index 0b38694769b777..8e80e5e673c42c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Subtract.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Subtract.java @@ -18,7 +18,7 @@ package org.apache.doris.nereids.trees.expressions; import org.apache.doris.analysis.ArithmeticExpr.Operator; -import org.apache.doris.nereids.trees.expressions.functions.CheckOverflowNullable; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DecimalV3Type; @@ -31,7 +31,7 @@ /** * Subtract Expression. BinaryExpression. */ -public class Subtract extends BinaryArithmetic implements CheckOverflowNullable { +public class Subtract extends BinaryArithmetic implements PropagateNullable { public Subtract(Expression left, Expression right) { super(ImmutableList.of(left, right), Operator.SUBTRACT); @@ -49,9 +49,9 @@ public Expression withChildren(List children) { @Override public DecimalV3Type getDataTypeForDecimalV3(DecimalV3Type t1, DecimalV3Type t2) { - DecimalV3Type decimalV3Type = (DecimalV3Type) DecimalV3Type.widerDecimalV3Type(t1, t2, false); - return DecimalV3Type.createDecimalV3Type(decimalV3Type.getPrecision() + 1, - decimalV3Type.getScale()); + int targetScale = Math.max(t1.getScale(), t2.getScale()); + int integralPart = Math.max(t1.getRange(), t2.getRange()); + return processDecimalV3OverFlow(integralPart + 1, targetScale, integralPart); } @Override @@ -59,11 +59,6 @@ public DataType getDataTypeForOthers(DataType t1, DataType t2) { return super.getDataTypeForOthers(t1, t2).promotion(); } - @Override - public boolean nullable() { - return CheckOverflowNullable.super.nullable(); - } - @Override public R accept(ExpressionVisitor visitor, C context) { return visitor.visitSubtract(this, context); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/CheckOverflowNullable.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/CheckOverflowNullable.java deleted file mode 100644 index 13a24faa7a73ea..00000000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/CheckOverflowNullable.java +++ /dev/null @@ -1,34 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.trees.expressions.functions; - -import org.apache.doris.qe.ConnectContext; - -/** - * if session variable check_overflow_for_decimal set to true, the expression's return always nullable - */ -public interface CheckOverflowNullable extends PropagateNullable { - @Override - default boolean nullable() { - if (ConnectContext.get() != null && ConnectContext.get().getSessionVariable().checkOverflowForDecimal()) { - return true; - } else { - return PropagateNullable.super.nullable(); - } - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index 2a6ba2f2ee8895..bf1a7c35cd22ed 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -104,7 +104,6 @@ import org.apache.doris.nereids.types.coercion.IntegralType; import org.apache.doris.nereids.types.coercion.NumericType; import org.apache.doris.nereids.types.coercion.PrimitiveType; -import org.apache.doris.qe.ConnectContext; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; @@ -667,20 +666,15 @@ public static Expression processDivide(Divide divide) { DataType commonType = DoubleType.INSTANCE; if (t1.isFloatLikeType() || t2.isFloatLikeType()) { // double type - } else if (t1.isDecimalV3Type() || t2.isDecimalV3Type()) { + } else if (t1.isDecimalV3Type() || t2.isDecimalV3Type() + // decimalv2 vs bigint, largeint treat as decimalv3 + || ((t1.isBigIntType() || t1.isLargeIntType()) && t2.isDecimalV2Type()) + || (t1.isDecimalV2Type() && (t2.isBigIntType() || t2.isLargeIntType()))) { // divide should cast to precision and target scale - DecimalV3Type retType; DecimalV3Type dt1 = DecimalV3Type.forType(t1); DecimalV3Type dt2 = DecimalV3Type.forType(t2); - try { - retType = divide.getDataTypeForDecimalV3(dt1, dt2); - } catch (Exception e) { - // exception means overflow. - return castChildren(divide, left, right, DoubleType.INSTANCE); - } - return divide.withChildren(castIfNotSameType(left, - DecimalV3Type.createDecimalV3Type(retType.getPrecision(), retType.getScale())), - castIfNotSameType(right, dt2)); + DecimalV3Type retType = divide.getDataTypeForDecimalV3(dt1, dt2); + return divide.withChildren(castIfNotSameType(left, retType), castIfNotSameType(right, dt2)); } else if (t1.isDecimalV2Type() || t2.isDecimalV2Type()) { commonType = DecimalV2Type.SYSTEM_DEFAULT; } @@ -773,18 +767,16 @@ public static Expression processBinaryArithmetic(BinaryArithmetic binaryArithmet commonType = DoubleType.INSTANCE; } - if (t1.isDecimalV3Type() && t2.isDecimalV2Type() - || t1.isDecimalV2Type() && t2.isDecimalV3Type()) { + // we treat decimalv2 vs dicimalv3, largeint or bigint as decimalv3 way. + if ((t1.isDecimalV3Type() || t1.isBigIntType() || t1.isLargeIntType()) && t2.isDecimalV2Type() + || t1.isDecimalV2Type() && (t2.isDecimalV3Type() || t2.isBigIntType() || t2.isLargeIntType())) { return processDecimalV3BinaryArithmetic(binaryArithmetic, left, right); } if (t1.isDecimalV2Type() || t2.isDecimalV2Type()) { - // to be consitent with old planner + // to be consistent with old planner // see findCommonType() method in ArithmeticExpr.java - commonType = t1.isDecimalV2Type() && t2.isDecimalV2Type() - || (ConnectContext.get() != null - && ConnectContext.get().getSessionVariable().roundPreciseDecimalV2Value) - ? DecimalV2Type.SYSTEM_DEFAULT : DoubleType.INSTANCE; + commonType = DecimalV2Type.SYSTEM_DEFAULT; } boolean isBitArithmetic = binaryArithmetic instanceof BitAnd @@ -814,7 +806,7 @@ public static Expression processBinaryArithmetic(BinaryArithmetic binaryArithmet return castChildren(binaryArithmetic, left, right, commonType); } - // double and float already process, we only process decimalv2 and fixed point number. + // double and float already process, we only process decimalv3 and fixed point number. if (t1 instanceof DecimalV3Type || t2 instanceof DecimalV3Type) { return processDecimalV3BinaryArithmetic(binaryArithmetic, left, right); } @@ -1582,13 +1574,7 @@ private static Expression processDecimalV3BinaryArithmetic(BinaryArithmetic bina DecimalV3Type.forType(TypeCoercionUtils.getNumResultType(right.getDataType())); // check return type whether overflow, if true, turn to double - DecimalV3Type retType; - try { - retType = binaryArithmetic.getDataTypeForDecimalV3(dt1, dt2); - } catch (Exception e) { - // exception means overflow. - return castChildren(binaryArithmetic, left, right, DoubleType.INSTANCE); - } + DecimalV3Type retType = binaryArithmetic.getDataTypeForDecimalV3(dt1, dt2); // add, subtract and mod should cast children to exactly same type as return type if (binaryArithmetic instanceof Add || binaryArithmetic instanceof Subtract diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index 1fcaa56ddf0dc6..911c876675c0c6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -242,6 +242,8 @@ public class SessionVariable implements Serializable, Writable { public static final String CHECK_OVERFLOW_FOR_DECIMAL = "check_overflow_for_decimal"; + public static final String DECIMAL_OVERFLOW_SCALE = "decimal_overflow_scale"; + public static final String TRIM_TAILING_SPACES_FOR_EXTERNAL_TABLE_QUERY = "trim_tailing_spaces_for_external_table_query"; @@ -904,6 +906,13 @@ public void setMaxJoinNumberOfReorder(int maxJoinNumberOfReorder) { @VariableMgr.VarAttr(name = CHECK_OVERFLOW_FOR_DECIMAL) private boolean checkOverflowForDecimal = false; + @VariableMgr.VarAttr(name = DECIMAL_OVERFLOW_SCALE, needForward = true, description = { + "当decimal数值计算结果精度溢出时,计算结果最多可保留的小数位数", "When the precision of the result of" + + " a decimal numerical calculation overflows," + + "the maximum number of decimal scale that the result can be retained" + }) + public int decimalOverflowScale = 6; + @VariableMgr.VarAttr(name = ENABLE_DPHYP_OPTIMIZER) public boolean enableDPHypOptimizer = false;