Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 45 additions & 23 deletions be/src/vec/aggregate_functions/aggregate_function_avg.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,43 +56,48 @@ template <PrimitiveType T>
struct AggregateFunctionAvgData {
using ResultType = typename PrimitiveTypeTraits<T>::ColumnItemType;
static constexpr PrimitiveType ResultPType = T;
typename PrimitiveTypeTraits<T>::ColumnItemType sum {};
ResultType sum {};
UInt64 count = 0;

AggregateFunctionAvgData& operator=(const AggregateFunctionAvgData<T>& src) {
sum = src.sum;
count = src.count;
return *this;
}
AggregateFunctionAvgData& operator=(const AggregateFunctionAvgData<T>& src) = default;

template <typename ResultT>
ResultT result() const {
if constexpr (std::is_floating_point_v<ResultT>) {
if constexpr (std::numeric_limits<ResultT>::is_iec559) {
return static_cast<ResultT>(sum) /
static_cast<ResultT>(count); /// allow division by zero
}
}

ResultT result(ResultType multiplier) const {
if (!count) {
// null is handled in AggregationNode::_get_without_key_result
return static_cast<ResultT>(sum);
}
// to keep the same result with row vesion; see AggregateFunctions::decimalv2_avg_get_value
if constexpr (T == TYPE_DECIMALV2 && IsDecimalV2<ResultT>) {
DecimalV2Value decimal_val_count(count, 0);
DecimalV2Value decimal_val_sum(sum);
DecimalV2Value decimal_val_sum(sum * multiplier);
DecimalV2Value cal_ret = decimal_val_sum / decimal_val_count;
Decimal128V2 ret(cal_ret.value());
return ret;
} else {
if constexpr (T == TYPE_DECIMAL256) {
return static_cast<ResultT>(sum /
return static_cast<ResultT>(sum * multiplier /
typename PrimitiveTypeTraits<T>::ColumnItemType(count));
} else {
return static_cast<ResultT>(sum) / static_cast<ResultT>(count);
return static_cast<ResultT>(sum * multiplier) / static_cast<ResultT>(count);
}
}
}

template <typename ResultT>
ResultT result() const {
if constexpr (std::is_floating_point_v<ResultT>) {
if constexpr (std::numeric_limits<ResultT>::is_iec559) {
return static_cast<ResultT>(sum) /
static_cast<ResultT>(count); /// allow division by zero
}
}

if (!count) {
// null is handled in AggregationNode::_get_without_key_result
return static_cast<ResultT>(sum);
}
return static_cast<ResultT>(sum) / static_cast<ResultT>(count);
}

void write(BufferWritable& buf) const {
Expand Down Expand Up @@ -131,17 +136,29 @@ class AggregateFunctionAvg<T, TResult, Data> final
// an implicit cast to float.

using DataType = typename Data::ResultType;

// consistent with fe/fe-common/src/main/java/org/apache/doris/catalog/ScalarType.java
static constexpr uint32_t DEFAULT_MIN_AVG_DECIMAL128_SCALE = 4;

/// ctor for native types
AggregateFunctionAvg(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, TResult, Data>>(
argument_types_),
scale(get_decimal_scale(*argument_types_[0])) {}
output_scale(std::max(DEFAULT_MIN_AVG_DECIMAL128_SCALE,
get_decimal_scale(*argument_types_[0]))) {
if constexpr (is_decimal(T)) {
multiplier = ResultType(ResultDataType::get_scale_multiplier(
output_scale - get_decimal_scale(*argument_types_[0])));
}
}

String get_name() const override { return "avg"; }

DataTypePtr get_return_type() const override {
if constexpr (is_decimal(T)) {
return std::make_shared<ResultDataType>(ResultDataType::max_precision(), scale);
return std::make_shared<ResultDataType>(
ResultDataType::max_precision(),
std::max(DEFAULT_MIN_AVG_DECIMAL128_SCALE, output_scale));
} else {
return std::make_shared<ResultDataType>();
}
Expand All @@ -157,14 +174,14 @@ class AggregateFunctionAvg<T, TResult, Data> final
assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]);
if constexpr (is_add) {
if constexpr (is_decimal(T)) {
this->data(place).sum += (DataType)column.get_data()[row_num].value;
this->data(place).sum += column.get_data()[row_num].value;
} else {
this->data(place).sum += (DataType)column.get_data()[row_num];
}
++this->data(place).count;
} else {
if constexpr (is_decimal(T)) {
this->data(place).sum -= (DataType)column.get_data()[row_num].value;
this->data(place).sum -= column.get_data()[row_num].value;
} else {
this->data(place).sum -= (DataType)column.get_data()[row_num];
}
Expand Down Expand Up @@ -203,7 +220,11 @@ class AggregateFunctionAvg<T, TResult, Data> final

void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
auto& column = assert_cast<ColVecResult&>(to);
column.get_data().push_back(this->data(place).template result<ResultType>());
if constexpr (is_decimal(T)) {
column.get_data().push_back(this->data(place).template result<ResultType>(multiplier));
} else {
column.get_data().push_back(this->data(place).template result<ResultType>());
}
}

void deserialize_from_column(AggregateDataPtr places, const IColumn& column, Arena&,
Expand Down Expand Up @@ -351,7 +372,8 @@ class AggregateFunctionAvg<T, TResult, Data> final
}

private:
UInt32 scale;
uint32_t output_scale;
ResultType multiplier;
};

} // namespace doris::vectorized
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,28 +105,16 @@ public FunctionSignature computePrecision(FunctionSignature signature) {
}
DecimalV3Type decimalV3Type = DecimalV3Type.forType(argumentType);
// DecimalV3 scale lower than DEFAULT_MIN_AVG_DECIMAL128_SCALE should do cast
int precision = decimalV3Type.getPrecision();
int scale = decimalV3Type.getScale();
if (decimalV3Type.getScale() < ScalarType.DEFAULT_MIN_AVG_DECIMAL128_SCALE) {
scale = ScalarType.DEFAULT_MIN_AVG_DECIMAL128_SCALE;
precision = precision - decimalV3Type.getScale() + scale;
if (enableDecimal256) {
if (precision > DecimalV3Type.MAX_DECIMAL256_PRECISION) {
precision = DecimalV3Type.MAX_DECIMAL256_PRECISION;
}
} else {
if (precision > DecimalV3Type.MAX_DECIMAL128_PRECISION) {
precision = DecimalV3Type.MAX_DECIMAL128_PRECISION;
}
}
}
decimalV3Type = DecimalV3Type.createDecimalV3Type(precision, scale);
return signature.withArgumentType(0, decimalV3Type)
.withReturnType(DecimalV3Type.createDecimalV3Type(
enableDecimal256 ? DecimalV3Type.MAX_DECIMAL256_PRECISION
: DecimalV3Type.MAX_DECIMAL128_PRECISION,
decimalV3Type.getScale()
));
scale)
);
} else {
return signature;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NeedSessionVarGuard;
Expand All @@ -28,13 +29,15 @@
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.qe.ConnectContext;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -70,35 +73,6 @@ private ArrayAvg(ScalarFunctionParams functionParams) {
super(functionParams);
}

// TODO use this computePrecision if be support dynamic scale
// @Override
// public FunctionSignature computePrecision(FunctionSignature signature) {
// DataType argumentType = getArgumentType(0);
// if (argumentType instanceof ArrayType) {
// DataType argType = ((ArrayType) argumentType).getItemType();
// DataType sigType = ((ArrayType) signature.getArgType(0)).getItemType();
// if (sigType instanceof DecimalV3Type) {
// DecimalV3Type decimalV3Type = DecimalV3Type.forType(argType);
// // DecimalV3 scale lower than DEFAULT_MIN_AVG_DECIMAL128_SCALE should do cast
// int precision = decimalV3Type.getPrecision();
// int scale = decimalV3Type.getScale();
// if (decimalV3Type.getScale() < ScalarType.DEFAULT_MIN_AVG_DECIMAL128_SCALE) {
// scale = ScalarType.DEFAULT_MIN_AVG_DECIMAL128_SCALE;
// precision = precision - decimalV3Type.getScale() + scale;
// if (precision > DecimalV3Type.MAX_DECIMAL128_PRECISION) {
// precision = DecimalV3Type.MAX_DECIMAL128_PRECISION;
// }
// }
// decimalV3Type = DecimalV3Type.createDecimalV3Type(precision, scale);
// return signature.withArgumentType(0, ArrayType.of(decimalV3Type))
// .withReturnType(ArrayType.of(DecimalV3Type.createDecimalV3Type(
// DecimalV3Type.MAX_DECIMAL128_PRECISION, decimalV3Type.getScale()
// )));
// }
// }
// return signature;
// }

/**
* array_avg needs to calculate the average of the elements in the array.
* so the element type must be numeric, boolean or string.
Expand All @@ -112,6 +86,33 @@ public void checkLegalityBeforeTypeCoercion() {
}
}

@Override
public FunctionSignature computePrecision(FunctionSignature signature) {
if (!(getArgumentType(0) instanceof ArrayType)) {
return signature;
}
DataType argumentType = ((ArrayType) getArgumentType(0)).getItemType();
if (!(argumentType instanceof DecimalV3Type)) {
return signature;
}
boolean enableDecimal256 = false;
ConnectContext connectContext = ConnectContext.get();
if (connectContext != null) {
enableDecimal256 = connectContext.getSessionVariable().isEnableDecimal256();
}
DecimalV3Type decimalV3Type = DecimalV3Type.forType(argumentType);
// DecimalV3 scale lower than DEFAULT_MIN_AVG_DECIMAL128_SCALE should do cast
int scale = decimalV3Type.getScale();
if (decimalV3Type.getScale() < ScalarType.DEFAULT_MIN_AVG_DECIMAL128_SCALE) {
scale = ScalarType.DEFAULT_MIN_AVG_DECIMAL128_SCALE;
}
return signature.withReturnType(DecimalV3Type.createDecimalV3Type(
enableDecimal256 ? DecimalV3Type.MAX_DECIMAL256_PRECISION
: DecimalV3Type.MAX_DECIMAL128_PRECISION,
scale)
);
}

/**
* withChildren.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
2.0

-- !sql_5 --
1.6
1.6666

-- !sql_6 --
5.0
Expand Down Expand Up @@ -294,7 +294,7 @@ true
2023-02-05

-- !sql --
166.666
166.6665

-- !sql --
333.333
Expand Down Expand Up @@ -1005,7 +1005,7 @@ _
2.0

-- !sql_5 --
1.6
1.6666

-- !sql_6 --
5.0
Expand Down Expand Up @@ -1278,7 +1278,7 @@ true
2023-02-05

-- !sql --
166.666
166.6665

-- !sql --
333.333
Expand Down Expand Up @@ -1989,7 +1989,7 @@ _
2.0

-- !sql_5 --
1.6
1.6666

-- !sql_6 --
5.0
Expand Down Expand Up @@ -2262,7 +2262,7 @@ true
2023-02-05

-- !sql --
166.666
166.6665

-- !sql --
333.333
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
------PhysicalDistribute[DistributionSpecGather]
--------PhysicalTopN[LOCAL_SORT]
----------PhysicalProject
------------hashJoin[INNER_JOIN shuffleBucket] hashCondition=((ctr1.ctr_store_sk = ctr2.ctr_store_sk)) otherCondition=((cast(ctr_total_return as DECIMALV3(38, 5)) > (avg(cast(ctr_total_return as DECIMALV3(38, 4))) * 1.2)))
------------hashJoin[INNER_JOIN shuffleBucket] hashCondition=((ctr1.ctr_store_sk = ctr2.ctr_store_sk)) otherCondition=((cast(ctr_total_return as DECIMALV3(38, 5)) > (avg(ctr_total_return) * 1.2)))
--------------PhysicalProject
----------------hashJoin[INNER_JOIN broadcast] hashCondition=((store.s_store_sk = ctr1.ctr_store_sk)) otherCondition=() build RFs:RF2 s_store_sk->[ctr_store_sk]
------------------PhysicalProject
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
----------PhysicalDistribute[DistributionSpecGather]
------------PhysicalTopN[LOCAL_SORT]
--------------PhysicalProject
----------------hashJoin[INNER_JOIN shuffleBucket] hashCondition=((ctr1.ctr_state = ctr2.ctr_state)) otherCondition=((cast(ctr_total_return as DECIMALV3(38, 5)) > (avg(cast(ctr_total_return as DECIMALV3(38, 4))) * 1.2)))
----------------hashJoin[INNER_JOIN shuffleBucket] hashCondition=((ctr1.ctr_state = ctr2.ctr_state)) otherCondition=((cast(ctr_total_return as DECIMALV3(38, 5)) > (avg(ctr_total_return) * 1.2)))
------------------PhysicalProject
--------------------hashJoin[INNER_JOIN broadcast] hashCondition=((customer_address.ca_address_sk = customer.c_current_addr_sk)) otherCondition=() build RFs:RF3 ca_address_sk->[c_current_addr_sk]
----------------------PhysicalProject
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ PhysicalResultSink
--------PhysicalDistribute[DistributionSpecGather]
----------hashAgg[LOCAL]
------------PhysicalProject
--------------filter((cast(cs_ext_discount_amt as DECIMALV3(38, 5)) > (1.3 * avg(cast(cs_ext_discount_amt as DECIMALV3(9, 4))) OVER(PARTITION BY i_item_sk))))
--------------filter((cast(cs_ext_discount_amt as DECIMALV3(38, 5)) > (1.3 * avg(cs_ext_discount_amt) OVER(PARTITION BY i_item_sk))))
----------------PhysicalWindow
------------------PhysicalQuickSort[LOCAL_SORT]
--------------------PhysicalDistribute[DistributionSpecHash]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,24 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
--------PhysicalQuickSort[LOCAL_SORT]
----------PhysicalWindow
------------PhysicalQuickSort[LOCAL_SORT]
--------------PhysicalProject
----------------hashAgg[GLOBAL]
------------------PhysicalDistribute[DistributionSpecHash]
--------------------hashAgg[LOCAL]
----------------------PhysicalProject
------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_store_sk = store.s_store_sk)) otherCondition=()
--------------------------PhysicalProject
----------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF1 d_date_sk->[ss_sold_date_sk]
------------------------------PhysicalProject
--------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_item_sk = item.i_item_sk)) otherCondition=()
----------------------------------PhysicalProject
------------------------------------PhysicalOlapScan[store_sales] apply RFs: RF1
----------------------------------PhysicalProject
------------------------------------PhysicalOlapScan[item]
------------------------------PhysicalProject
--------------------------------filter(OR[(date_dim.d_year = 2001),AND[(date_dim.d_year = 2000),(date_dim.d_moy = 12)],AND[(date_dim.d_year = 2002),(date_dim.d_moy = 1)]] and d_year IN (2000, 2001, 2002))
----------------------------------PhysicalOlapScan[date_dim]
--------------------------PhysicalProject
----------------------------PhysicalOlapScan[store]
--------------hashAgg[GLOBAL]
----------------PhysicalDistribute[DistributionSpecHash]
------------------hashAgg[LOCAL]
--------------------PhysicalProject
----------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_store_sk = store.s_store_sk)) otherCondition=()
------------------------PhysicalProject
--------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF1 d_date_sk->[ss_sold_date_sk]
----------------------------PhysicalProject
------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_item_sk = item.i_item_sk)) otherCondition=()
--------------------------------PhysicalProject
----------------------------------PhysicalOlapScan[store_sales] apply RFs: RF1
--------------------------------PhysicalProject
----------------------------------PhysicalOlapScan[item]
----------------------------PhysicalProject
------------------------------filter(OR[(date_dim.d_year = 2001),AND[(date_dim.d_year = 2000),(date_dim.d_moy = 12)],AND[(date_dim.d_year = 2002),(date_dim.d_moy = 1)]] and d_year IN (2000, 2001, 2002))
--------------------------------PhysicalOlapScan[date_dim]
------------------------PhysicalProject
--------------------------PhysicalOlapScan[store]
--PhysicalResultSink
----PhysicalProject
------PhysicalTopN[MERGE_SORT]
Expand Down
Loading
Loading