Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.doris.nereids.trees.expressions;

import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.nereids.trees.expressions.functions.CheckOverflowNullable;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV3Type;
Expand All @@ -31,7 +31,7 @@
/**
* Add Expression.
*/
public class Add extends BinaryArithmetic implements CheckOverflowNullable {
public class Add extends BinaryArithmetic implements PropagateNullable {

public Add(Expression left, Expression right) {
super(ImmutableList.of(left, right), Operator.ADD);
Expand All @@ -49,21 +49,16 @@ public Expression withChildren(List<Expression> children) {

@Override
public DecimalV3Type getDataTypeForDecimalV3(DecimalV3Type t1, DecimalV3Type t2) {
DecimalV3Type decimalV3Type = (DecimalV3Type) DecimalV3Type.widerDecimalV3Type(t1, t2, false);
return DecimalV3Type.createDecimalV3Type(decimalV3Type.getPrecision() + 1,
decimalV3Type.getScale());
int targetScale = Math.max(t1.getScale(), t2.getScale());
int integralPart = Math.max(t1.getRange(), t2.getRange());
return processDecimalV3OverFlow(integralPart + 1, targetScale, integralPart);
}

@Override
public DataType getDataTypeForOthers(DataType t1, DataType t2) {
return super.getDataTypeForOthers(t1, t2).promotion();
}

@Override
public boolean nullable() {
return CheckOverflowNullable.super.nullable();
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitAdd(this, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.coercion.NumericType;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.qe.ConnectContext;

import java.util.List;

Expand Down Expand Up @@ -93,4 +94,23 @@ public DataType getDataTypeForOthers(DataType t1, DataType t2) {
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitBinaryArithmetic(this, context);
}

protected DecimalV3Type processDecimalV3OverFlow(int integralPart, int targetScale, int maxIntegralPart) {
int precision = integralPart + targetScale;
boolean enableDecimal256 = false;
ConnectContext connectContext = ConnectContext.get();
if (connectContext != null) {
enableDecimal256 = connectContext.getSessionVariable().isEnableDecimal256();
}
if (enableDecimal256) {
if (precision > DecimalV3Type.MAX_DECIMAL256_PRECISION) {
precision = DecimalV3Type.MAX_DECIMAL256_PRECISION;
}
} else {
if (precision > DecimalV3Type.MAX_DECIMAL128_PRECISION) {
precision = DecimalV3Type.MAX_DECIMAL128_PRECISION;
}
}
return DecimalV3Type.createDecimalV3Type(precision, precision - maxIntegralPart);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.qe.ConnectContext;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -67,21 +68,40 @@ public DataType getDataType() throws UnboundException {

@Override
public DecimalV3Type getDataTypeForDecimalV3(DecimalV3Type t1, DecimalV3Type t2) {
int retPercision = t1.getPrecision() + t2.getScale() + Config.div_precision_increment;
Preconditions.checkState(retPercision <= DecimalV3Type.MAX_DECIMAL256_PRECISION,
"target precision " + retPercision + " larger than precision "
+ DecimalV3Type.MAX_DECIMAL256_PRECISION + " in Divide return type");
int retScale = t1.getScale() + t2.getScale()
+ Config.div_precision_increment;
int targetPercision = retPercision;
int targetScale = t1.getScale() + t2.getScale();
Preconditions.checkState(targetPercision >= targetScale,
"target scale " + targetScale + " larger than precision " + retPercision
+ " in Divide return type");
Preconditions.checkState(retPercision >= retScale,
"scale " + retScale + " larger than precision " + retPercision
+ " in Divide return type");
return DecimalV3Type.createDecimalV3Type(retPercision, retScale);
int precision = t1.getPrecision() + t2.getScale() + Config.div_precision_increment;
int scale = t1.getScale();
boolean enableDecimal256 = false;
int defaultScale = 6;
ConnectContext connectContext = ConnectContext.get();
if (connectContext != null) {
enableDecimal256 = connectContext.getSessionVariable().isEnableDecimal256();
defaultScale = connectContext.getSessionVariable().decimalOverflowScale;
}
if (enableDecimal256 && precision > DecimalV3Type.MAX_DECIMAL256_PRECISION) {
int integralPartBoundary = DecimalV3Type.MAX_DECIMAL256_PRECISION - defaultScale;
if (precision - scale < integralPartBoundary) {
// retains more int part
scale = DecimalV3Type.MAX_DECIMAL256_PRECISION - (precision - scale);
} else if (precision - scale > integralPartBoundary && scale < defaultScale) {
// scale not change, retains more scale part
} else {
scale = defaultScale;
}
precision = DecimalV3Type.MAX_DECIMAL256_PRECISION;
} else if (!enableDecimal256 && precision > DecimalV3Type.MAX_DECIMAL128_PRECISION) {
int integralPartBoundary = DecimalV3Type.MAX_DECIMAL128_PRECISION - defaultScale;
if (precision - scale < integralPartBoundary) {
// retains more int part
scale = DecimalV3Type.MAX_DECIMAL128_PRECISION - (precision - scale);
} else if (precision - scale > integralPartBoundary && scale < defaultScale) {
// scale not change, retains more scale part
} else {
scale = defaultScale;
}
precision = DecimalV3Type.MAX_DECIMAL128_PRECISION;
}
scale = Math.min(precision, scale + t2.getScale() + Config.div_precision_increment);
return DecimalV3Type.createDecimalV3Type(precision, scale);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,9 @@ public Expression withChildren(List<Expression> children) {

@Override
public DecimalV3Type getDataTypeForDecimalV3(DecimalV3Type t1, DecimalV3Type t2) {
// TODO use max int part + max scale of two operands as result type
// because BE require the result and operands types are the exact the same decimalv3 type
int scale = Math.max(t1.getScale(), t2.getScale());
int precision = Math.max(t1.getRange(), t2.getRange()) + scale;
return DecimalV3Type.createDecimalV3Type(precision, scale);
int targetScale = Math.max(t1.getScale(), t2.getScale());
int integralPart = Math.max(t1.getRange(), t2.getRange());
return processDecimalV3OverFlow(integralPart, targetScale, integralPart);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.doris.nereids.trees.expressions;

import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.nereids.trees.expressions.functions.CheckOverflowNullable;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV3Type;
Expand All @@ -32,7 +32,7 @@
/**
* Multiply Expression.
*/
public class Multiply extends BinaryArithmetic implements CheckOverflowNullable {
public class Multiply extends BinaryArithmetic implements PropagateNullable {

public Multiply(Expression left, Expression right) {
super(ImmutableList.of(left, right), Operator.MULTIPLY);
Expand All @@ -50,38 +50,49 @@ public Expression withChildren(List<Expression> children) {

@Override
public DecimalV3Type getDataTypeForDecimalV3(DecimalV3Type t1, DecimalV3Type t2) {
int retPercision = t1.getPrecision() + t2.getPrecision();
int retPrecision = t1.getPrecision() + t2.getPrecision();
int retScale = t1.getScale() + t2.getScale();
if (retPercision > DecimalV3Type.MAX_DECIMAL128_PRECISION) {
boolean enableDecimal256 = false;
ConnectContext connectContext = ConnectContext.get();
if (connectContext != null) {
enableDecimal256 = connectContext.getSessionVariable().isEnableDecimal256();
boolean enableDecimal256 = false;
int defaultScale = 6;
ConnectContext connectContext = ConnectContext.get();
if (connectContext != null) {
enableDecimal256 = connectContext.getSessionVariable().isEnableDecimal256();
defaultScale = connectContext.getSessionVariable().decimalOverflowScale;
}
if (!enableDecimal256 && retPrecision > DecimalV3Type.MAX_DECIMAL128_PRECISION) {
int integralPartBoundary = DecimalV3Type.MAX_DECIMAL128_PRECISION - defaultScale;
if (retPrecision - retScale < integralPartBoundary) {
// retains more int part
retScale = DecimalV3Type.MAX_DECIMAL128_PRECISION - (retPrecision - retScale);
} else if (retPrecision - retScale > integralPartBoundary && retScale < defaultScale) {
// retScale not change, retains more scale part
} else {
retScale = defaultScale;
}
if (enableDecimal256) {
if (retPercision > DecimalV3Type.MAX_DECIMAL256_PRECISION) {
retPercision = DecimalV3Type.MAX_DECIMAL256_PRECISION;
}
retPrecision = DecimalV3Type.MAX_DECIMAL128_PRECISION;
} else if (enableDecimal256 && retPrecision > DecimalV3Type.MAX_DECIMAL256_PRECISION) {
int integralPartBoundary = DecimalV3Type.MAX_DECIMAL256_PRECISION - defaultScale;
if (retPrecision - retScale < integralPartBoundary) {
// retains more int part
retScale = DecimalV3Type.MAX_DECIMAL256_PRECISION - (retPrecision - retScale);
} else if (retPrecision - retScale > integralPartBoundary && retScale < defaultScale) {
// retScale not change, retains more scale part
} else {
retPercision = DecimalV3Type.MAX_DECIMAL128_PRECISION;
retScale = defaultScale;
}
retPrecision = DecimalV3Type.MAX_DECIMAL256_PRECISION;
}
Preconditions.checkState(retPercision >= retScale,
"scale " + retScale + " larger than precision " + retPercision
Preconditions.checkState(retPrecision >= retScale,
"scale " + retScale + " larger than precision " + retPrecision
+ " in Multiply return type");
return DecimalV3Type.createDecimalV3Type(retPercision, retScale);
return DecimalV3Type.createDecimalV3Type(retPrecision, retScale);
}

@Override
public DataType getDataTypeForOthers(DataType t1, DataType t2) {
return super.getDataTypeForOthers(t1, t2).promotion();
}

@Override
public boolean nullable() {
return CheckOverflowNullable.super.nullable();
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitMultiply(this, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.doris.nereids.trees.expressions;

import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.nereids.trees.expressions.functions.CheckOverflowNullable;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV3Type;
Expand All @@ -31,7 +31,7 @@
/**
* Subtract Expression. BinaryExpression.
*/
public class Subtract extends BinaryArithmetic implements CheckOverflowNullable {
public class Subtract extends BinaryArithmetic implements PropagateNullable {

public Subtract(Expression left, Expression right) {
super(ImmutableList.of(left, right), Operator.SUBTRACT);
Expand All @@ -49,21 +49,16 @@ public Expression withChildren(List<Expression> children) {

@Override
public DecimalV3Type getDataTypeForDecimalV3(DecimalV3Type t1, DecimalV3Type t2) {
DecimalV3Type decimalV3Type = (DecimalV3Type) DecimalV3Type.widerDecimalV3Type(t1, t2, false);
return DecimalV3Type.createDecimalV3Type(decimalV3Type.getPrecision() + 1,
decimalV3Type.getScale());
int targetScale = Math.max(t1.getScale(), t2.getScale());
int integralPart = Math.max(t1.getRange(), t2.getRange());
return processDecimalV3OverFlow(integralPart + 1, targetScale, integralPart);
}

@Override
public DataType getDataTypeForOthers(DataType t1, DataType t2) {
return super.getDataTypeForOthers(t1, t2).promotion();
}

@Override
public boolean nullable() {
return CheckOverflowNullable.super.nullable();
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitSubtract(this, context);
Expand Down

This file was deleted.

Loading