Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.NullType;
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.qe.ConnectContext;
Expand All @@ -49,14 +50,14 @@ public class Avg extends NullableAggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, ComputePrecision, SupportWindowAnalytic {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD)
FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE)
);

/**
Expand All @@ -80,8 +81,9 @@ private Avg(boolean distinct, boolean alwaysNullable, Expression arg) {
@Override
public void checkLegalityBeforeTypeCoercion() {
DataType argType = child().getDataType();
if (((!argType.isNumericType() && !argType.isNullType()) || argType.isOnlyMetricType())) {
throw new AnalysisException("avg requires a numeric parameter: " + toSql());
if (!argType.isNumericType() && !argType.isBooleanType()
&& !argType.isNullType() && !argType.isStringLikeType()) {
throw new AnalysisException("avg requires a numeric, boolean or string parameter: " + this.toSql());
}
}

Expand Down Expand Up @@ -153,4 +155,12 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}

@Override
public FunctionSignature searchSignature(List<FunctionSignature> signatures) {
if (getArgument(0).getDataType() instanceof NullType) {
return FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE);
}
return ExplicitlyCastableSignature.super.searchSignature(signatures);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.NullType;
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.TinyIntType;

Expand All @@ -50,14 +52,15 @@ public class Sum extends NullableAggregateFunction
RollUpTrait {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE)
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE)
);

/**
Expand Down Expand Up @@ -87,9 +90,9 @@ public MultiDistinctSum convertToMultiDistinct() {
@Override
public void checkLegalityBeforeTypeCoercion() {
DataType argType = child().getDataType();
if ((!argType.isNumericType() && !argType.isBooleanType() && !argType.isNullType())
|| argType.isOnlyMetricType()) {
throw new AnalysisException("sum requires a numeric or boolean parameter: " + this.toSql());
if (!argType.isNumericType() && !argType.isBooleanType()
&& !argType.isNullType() && !argType.isStringLikeType()) {
throw new AnalysisException("sum requires a numeric, boolean or string parameter: " + this.toSql());
}
}

Expand Down Expand Up @@ -119,8 +122,10 @@ public List<FunctionSignature> getSignatures() {

@Override
public FunctionSignature searchSignature(List<FunctionSignature> signatures) {
if (getArgument(0).getDataType() instanceof FloatType) {
return FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE);
if (getArgument(0).getDataType() instanceof NullType) {
return FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE);
} else if (getArgument(0).getDataType() instanceof DecimalV2Type) {
return FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD);
}
return ExplicitlyCastableSignature.super.searchSignature(signatures);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.NullType;
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.TinyIntType;

Expand All @@ -57,14 +59,15 @@ public class Sum0 extends NotNullableAggregateFunction
SupportWindowAnalytic, RollUpTrait {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE)
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE)
);

/**
Expand All @@ -90,9 +93,9 @@ public MultiDistinctSum0 convertToMultiDistinct() {
@Override
public void checkLegalityBeforeTypeCoercion() {
DataType argType = child().getDataType();
if ((!argType.isNumericType() && !argType.isBooleanType() && !argType.isNullType())
|| argType.isOnlyMetricType()) {
throw new AnalysisException("sum0 requires a numeric or boolean parameter: " + this.toSql());
if (!argType.isNumericType() && !argType.isBooleanType()
&& !argType.isNullType() && !argType.isStringLikeType()) {
throw new AnalysisException("sum0 requires a numeric, boolean or string parameter: " + this.toSql());
}
}

Expand All @@ -117,8 +120,10 @@ public List<FunctionSignature> getSignatures() {

@Override
public FunctionSignature searchSignature(List<FunctionSignature> signatures) {
if (getArgument(0).getDataType() instanceof FloatType) {
return FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE);
if (getArgument(0).getDataType() instanceof NullType) {
return FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE);
} else if (getArgument(0).getDataType() instanceof DecimalV2Type) {
return FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD);
}
return ExplicitlyCastableSignature.super.searchSignature(signatures);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class CheckExpressionLegalityTest implements MemoPatternMatchSupported {
public void testAvg() {
ConnectContext connectContext = MemoTestUtils.createConnectContext();
ExceptionChecker.expectThrowsWithMsg(
AnalysisException.class, "avg requires a numeric parameter", () -> {
AnalysisException.class, "avg requires a numeric", () -> {
PlanChecker.from(connectContext)
.analyze("select avg(id) from (select to_bitmap(1) id) tbl");
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@

package org.apache.doris.nereids.trees.expressions;

import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.CharLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
Expand All @@ -35,6 +38,7 @@
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.LargeIntType;
Expand All @@ -57,6 +61,7 @@ public class GetDataTypeTest {
FloatLiteral floatLiteral = new FloatLiteral(1.0F);
DoubleLiteral doubleLiteral = new DoubleLiteral(1.0);
DecimalLiteral decimalLiteral = new DecimalLiteral(BigDecimal.ONE);
DecimalV3Literal decimalV3Literal = new DecimalV3Literal(new BigDecimal("123.123456"));
CharLiteral charLiteral = new CharLiteral("hello", 5);
VarcharLiteral varcharLiteral = new VarcharLiteral("hello", 5);
StringLiteral stringLiteral = new StringLiteral("hello");
Expand All @@ -75,14 +80,55 @@ public void testSum() {
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Sum(floatLiteral)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Sum(doubleLiteral)));
Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(38, 0), checkAndGetDataType(new Sum(decimalLiteral)));
Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new Sum(bigIntLiteral)));
Assertions.assertThrows(RuntimeException.class, () -> checkAndGetDataType(new Sum(charLiteral)));
Assertions.assertThrows(RuntimeException.class, () -> checkAndGetDataType(new Sum(varcharLiteral)));
Assertions.assertThrows(RuntimeException.class, () -> checkAndGetDataType(new Sum(stringLiteral)));
Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(38, 6), checkAndGetDataType(new Sum(decimalV3Literal)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Sum(charLiteral)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Sum(varcharLiteral)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Sum(stringLiteral)));
Assertions.assertThrows(RuntimeException.class, () -> checkAndGetDataType(new Sum(dateLiteral)));
Assertions.assertThrows(RuntimeException.class, () -> checkAndGetDataType(new Sum(dateTimeLiteral)));
}

@Test
public void testSum0() {
Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new Sum0(nullLiteral)));
Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new Sum0(booleanLiteral)));
Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new Sum0(tinyIntLiteral)));
Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new Sum0(smallIntLiteral)));
Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new Sum0(integerLiteral)));
Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new Sum0(bigIntLiteral)));
Assertions.assertEquals(LargeIntType.INSTANCE, checkAndGetDataType(new Sum0(largeIntLiteral)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Sum0(floatLiteral)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Sum0(doubleLiteral)));
Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(38, 0), checkAndGetDataType(new Sum0(decimalLiteral)));
Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(38, 6), checkAndGetDataType(new Sum0(decimalV3Literal)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Sum0(charLiteral)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Sum0(varcharLiteral)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Sum0(stringLiteral)));
Assertions.assertThrows(RuntimeException.class, () -> checkAndGetDataType(new Sum0(dateLiteral)));
Assertions.assertThrows(RuntimeException.class, () -> checkAndGetDataType(new Sum0(dateTimeLiteral)));
}

@Test
public void testAvg() {
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Avg(nullLiteral)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Avg(booleanLiteral)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Avg(tinyIntLiteral)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Avg(smallIntLiteral)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Avg(integerLiteral)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Avg(bigIntLiteral)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Avg(largeIntLiteral)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Avg(floatLiteral)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Avg(doubleLiteral)));
Assertions.assertEquals(DecimalV2Type.createDecimalV2Type(27, 9), checkAndGetDataType(new Avg(decimalLiteral)));
Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(38, 6), checkAndGetDataType(new Avg(decimalV3Literal)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Avg(bigIntLiteral)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Avg(charLiteral)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Avg(varcharLiteral)));
Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Avg(stringLiteral)));
Assertions.assertThrows(RuntimeException.class, () -> checkAndGetDataType(new Avg(dateLiteral)));
Assertions.assertThrows(RuntimeException.class, () -> checkAndGetDataType(new Avg(dateTimeLiteral)));
}

private DataType checkAndGetDataType(Expression expression) {
expression.checkLegalityBeforeTypeCoercion();
expression.checkLegalityAfterRewrite();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2338,6 +2338,20 @@ suite("nereids_agg_fn") {
qt_sql_sum_LargeInt_agg_phase_4_notnull '''
select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT, TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/ count(distinct id), sum(klint) from fn_test'''

// sum on string like
explain {
sql("select sum(kstr) from fn_test;")
contains "partial_sum(cast(kstr as DOUBLE"
}
explain {
sql("select sum(kvchrs3) from fn_test;")
contains "partial_sum(cast(kvchrs3 as DOUBLE"
}
explain {
sql("select sum(kchrs3) from fn_test;")
contains "partial_sum(cast(kchrs3 as DOUBLE"
}

qt_sql_sum0_Boolean '''
select sum0(kbool) from fn_test'''
qt_sql_sum0_Boolean_gb '''
Expand Down Expand Up @@ -2523,6 +2537,20 @@ suite("nereids_agg_fn") {
qt_sql_sum0_LargeInt_agg_phase_4_notnull '''
select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT, TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/ count(distinct id), sum0(klint) from fn_test'''

// sum on string like
explain {
sql("select sum0(kstr) from fn_test;")
contains "partial_sum0(cast(kstr as DOUBLE"
}
explain {
sql("select sum0(kvchrs3) from fn_test;")
contains "partial_sum0(cast(kvchrs3 as DOUBLE"
}
explain {
sql("select sum0(kchrs3) from fn_test;")
contains "partial_sum0(cast(kchrs3 as DOUBLE"
}

qt_sql_topn_Varchar_Integer_gb '''
select topn(kvchrs1, 3) from fn_test group by kbool order by kbool'''
qt_sql_topn_Varchar_Integer '''
Expand Down
Loading