Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions be/src/olap/memtable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,13 @@ void MemTable::_init_agg_functions(const vectorized::Block* block) {
// the aggregate function manually.
if (_skip_bitmap_col_idx != cid) {
function = vectorized::AggregateFunctionSimpleFactory::instance().get(
"replace_load", {block->get_data_type(cid)},
"replace_load", {block->get_data_type(cid)}, block->get_data_type(cid),
block->get_data_type(cid)->is_nullable(),
BeExecVersionManager::get_newest_version());
} else {
function = vectorized::AggregateFunctionSimpleFactory::instance().get(
"bitmap_intersect", {block->get_data_type(cid)}, false,
BeExecVersionManager::get_newest_version());
"bitmap_intersect", {block->get_data_type(cid)}, block->get_data_type(cid),
false, BeExecVersionManager::get_newest_version());
}
} else {
function = _tablet_schema->column(cid).get_aggregate_function(
Expand Down
3 changes: 2 additions & 1 deletion be/src/olap/tablet_schema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,8 @@ vectorized::AggregateFunctionPtr TabletColumn::get_aggregate_function(
std::transform(agg_name.begin(), agg_name.end(), agg_name.begin(),
[](unsigned char c) { return std::tolower(c); });
function = vectorized::AggregateFunctionSimpleFactory::instance().get(
agg_name, {type}, type->is_nullable(), BeExecVersionManager::get_newest_version());
agg_name, {type}, type, type->is_nullable(),
BeExecVersionManager::get_newest_version());
if (!function) {
LOG(WARNING) << "get column aggregate function failed, aggregation_name=" << origin_name
<< ", column_type=" << type->get_name();
Expand Down
8 changes: 7 additions & 1 deletion be/src/pipeline/exec/analytic_sink_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,15 @@ class AnalyticSinkOperatorX final : public DataSinkOperatorX<AnalyticSinkLocalSt
#ifdef BE_TEST
AnalyticSinkOperatorX(ObjectPool* pool)
: _pool(pool),
_intermediate_tuple_id(0),
_output_tuple_id(0),
_buffered_tuple_id(0),
_is_colocate(false),
_require_bucket_distribution(false) {}
_require_bucket_distribution(false),
_has_window(false),
_has_range_window(false),
_has_window_start(false),
_has_window_end(false) {}
#endif

Status init(const TDataSink& tsink) override {
Expand Down
129 changes: 17 additions & 112 deletions be/src/runtime/primitive_type.h

Large diffs are not rendered by default.

4 changes: 0 additions & 4 deletions be/src/runtime/runtime_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,6 @@ class RuntimeState {
return _query_options.__isset.enable_insert_strict && _query_options.enable_insert_strict;
}

bool enable_decimal256() const {
return _query_options.__isset.enable_decimal256 && _query_options.enable_decimal256;
}

bool enable_common_expr_pushdown() const {
return _query_options.__isset.enable_common_expr_pushdown &&
_query_options.enable_common_expr_pushdown;
Expand Down
1 change: 0 additions & 1 deletion be/src/vec/aggregate_functions/aggregate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ class IColumn;
class IDataType;

struct AggregateFunctionAttr {
bool enable_decimal256 {false};
bool is_window_function {false};
bool is_foreach {false};
std::vector<std::string> column_names;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ QueryContext* AggregateFunctionAIAggData::_ctx = nullptr;
void register_aggregate_function_ai_agg(AggregateFunctionSimpleFactory& factory) {
factory.register_function_both("ai_agg",
[](const std::string& name, const DataTypes& argument_types,
const bool result_is_nullable,
const DataTypePtr& result_type, const bool result_is_nullable,
const AggregateFunctionAttr& attr) -> AggregateFunctionPtr {
return creator_without_type::create<AggregateFunctionAIAgg>(
argument_types, result_is_nullable, attr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ namespace doris::vectorized {
#include "common/compile_check_begin.h"

AggregateFunctionPtr create_aggregate_function_approx_count_distinct(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const std::string& name, const DataTypes& argument_types, const DataTypePtr& result_type,
const bool result_is_nullable, const AggregateFunctionAttr& attr) {
return creator_with_type_list<
TYPE_BOOLEAN, TYPE_TINYINT, TYPE_SMALLINT, TYPE_INT, TYPE_BIGINT, TYPE_LARGEINT,
TYPE_FLOAT, TYPE_DOUBLE, TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_DECIMAL128I,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ AggregateFunctionPtr do_create_agg_function_collect(const DataTypes& argument_ty

AggregateFunctionPtr create_aggregate_function_array_agg(const std::string& name,
const DataTypes& argument_types,
const DataTypePtr& result_type,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
switch (argument_types[0]->get_primitive_type()) {
Expand Down
36 changes: 19 additions & 17 deletions be/src/vec/aggregate_functions/aggregate_function_avg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,41 +28,43 @@
namespace doris::vectorized {
#include "common/compile_check_begin.h"

// TODO: use result type got from FE plan
template <PrimitiveType T>
struct Avg {
using FieldType = typename PrimitiveTypeTraits<T>::AvgNearestFieldType;
static constexpr PrimitiveType ResultPType = T == TYPE_DECIMALV2 ? T : TYPE_DOUBLE;
using Function = AggregateFunctionAvg<
T, AggregateFunctionAvgData<PrimitiveTypeTraits<T>::AvgNearestPrimitiveType>>;
T, ResultPType,
AggregateFunctionAvgData<PrimitiveTypeTraits<T>::AvgNearestPrimitiveType>>;
};

template <PrimitiveType T>
using AggregateFuncAvg = typename Avg<T>::Function;

template <PrimitiveType T>
struct AvgDecimal256 {
using FieldType = typename PrimitiveTypeTraits<T>::AvgNearestFieldType256;
using Function = AggregateFunctionAvg<
T, AggregateFunctionAvgData<PrimitiveTypeTraits<T>::AvgNearestPrimitiveType256>>;
// use result type got from FE plan
template <PrimitiveType InputType, PrimitiveType ResultType>
struct AvgDecimalV3 {
using Function =
AggregateFunctionAvg<InputType, ResultType, AggregateFunctionAvgData<ResultType>>;
};

template <PrimitiveType T>
using AggregateFuncAvgDecimal256 = typename AvgDecimal256<T>::Function;
template <PrimitiveType InputType, PrimitiveType ResultType>
using AggregateFuncAvgDecimalV3 = typename AvgDecimalV3<InputType, ResultType>::Function;

void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) {
AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types,
const DataTypePtr& result_type,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (attr.enable_decimal256 && is_decimal(types[0]->get_primitive_type())) {
return creator_with_type_list<
TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_DECIMAL128I,
TYPE_DECIMAL256>::creator<AggregateFuncAvgDecimal256>(name, types,
result_is_nullable, attr);
if (is_decimalv3(types[0]->get_primitive_type())) {
return creator_with_type_list<TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_DECIMAL128I,
TYPE_DECIMAL256>::
creator_with_result_type<AggregateFuncAvgDecimalV3>(name, types, result_type,
result_is_nullable, attr);
} else {
return creator_with_type_list<
TYPE_TINYINT, TYPE_SMALLINT, TYPE_INT, TYPE_BIGINT, TYPE_LARGEINT, TYPE_DOUBLE,
TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_DECIMAL128I,
TYPE_DECIMALV2>::creator<AggregateFuncAvg>(name, types, result_is_nullable,
attr);
TYPE_DECIMALV2>::creator<AggregateFuncAvg>(name, types, result_type,
result_is_nullable, attr);
}
};
factory.register_function_both("avg", creator);
Expand Down
33 changes: 19 additions & 14 deletions be/src/vec/aggregate_functions/aggregate_function_avg.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <vector>

#include "runtime/decimalv2_value.h"
#include "runtime/primitive_type.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column.h"
#include "vec/columns/column_fixed_length_object.h"
Expand Down Expand Up @@ -105,31 +106,35 @@ struct AggregateFunctionAvgData {
}
};

template <PrimitiveType T, PrimitiveType TResult, typename Data>
class AggregateFunctionAvg;

template <PrimitiveType T, PrimitiveType TResult>
constexpr static bool is_valid_avg_types =
(is_same_or_wider_decimalv3(T, TResult) || (is_decimalv2(T) && is_decimalv2(TResult)) ||
(is_float_or_double(T) && is_float_or_double(TResult)) ||
(is_int_or_bool(T) && (is_double(TResult) || is_int(TResult))));
/// Calculates arithmetic mean of numbers.
template <PrimitiveType T, typename Data>
class AggregateFunctionAvg final
: public IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>,
template <PrimitiveType T, PrimitiveType TResult, typename Data>
requires(is_valid_avg_types<T, TResult>)
class AggregateFunctionAvg<T, TResult, Data> final
: public IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, TResult, Data>>,
UnaryExpression,
NullableAggregateFunction {
public:
using ResultType = std::conditional_t<
T == TYPE_DECIMALV2, Decimal128V2,
std::conditional_t<is_decimal(T), typename Data::ResultType, Float64>>;
using ResultDataType = std::conditional_t<
T == TYPE_DECIMALV2, DataTypeDecimalV2,
std::conditional_t<is_decimal(T), DataTypeDecimal<Data::ResultPType>, DataTypeFloat64>>;
using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
using ColVecResult = std::conditional_t<
T == TYPE_DECIMALV2, ColumnDecimal128V2,
std::conditional_t<is_decimal(T), ColumnDecimal<Data::ResultPType>, ColumnFloat64>>;
using ResultType = PrimitiveTypeTraits<TResult>::ColumnItemType;
using ResultDataType = PrimitiveTypeTraits<TResult>::DataType;
using ColVecType = PrimitiveTypeTraits<T>::ColumnType;
using ColVecResult = PrimitiveTypeTraits<TResult>::ColumnType;
// The result calculated by PercentileApprox is an approximate value,
// so the underlying storage uses float. The following calls will involve
// an implicit cast to float.

using DataType = typename Data::ResultType;
/// ctor for native types
AggregateFunctionAvg(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>(argument_types_),
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, TResult, Data>>(
argument_types_),
scale(get_decimal_scale(*argument_types_[0])) {}

String get_name() const override { return "avg"; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ AggregateFunctionPtr create_with_int_data_type(const DataTypes& argument_type) {
}

AggregateFunctionPtr create_aggregate_function_bitmap_union_count(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const std::string& name, const DataTypes& argument_types, const DataTypePtr& result_type,
const bool result_is_nullable, const AggregateFunctionAttr& attr) {
const bool arg_is_nullable = argument_types[0]->is_nullable();
if (arg_is_nullable) {
return std::make_shared<AggregateFunctionBitmapCount<true, ColumnBitmap>>(argument_types);
Expand All @@ -56,6 +56,7 @@ AggregateFunctionPtr create_aggregate_function_bitmap_union_count(

AggregateFunctionPtr create_aggregate_function_bitmap_union_int(const std::string& name,
const DataTypes& argument_types,
const DataTypePtr& result_type,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const bool arg_is_nullable = argument_types[0]->is_nullable();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ AggregateFunctionPtr create_with_int_data_type(const DataTypes& argument_types)

AggregateFunctionPtr create_aggregate_function_bitmap_agg(const std::string& name,
const DataTypes& argument_types,
const DataTypePtr& result_type,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const bool arg_is_nullable = argument_types[0]->is_nullable();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ AggregateFunctionPtr create_aggregate_function_collect_impl(const std::string& n

AggregateFunctionPtr create_aggregate_function_collect(const std::string& name,
const DataTypes& argument_types,
const DataTypePtr& result_type,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_arity_range(name, argument_types, 1, 2);
Expand Down
2 changes: 2 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_corr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ using CorrMomentStat = StatFunc<T, CorrMoment>;

AggregateFunctionPtr create_aggregate_corr_function(const std::string& name,
const DataTypes& argument_types,
const DataTypePtr& result_type,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_arity_range(name, argument_types, 2, 2);
Expand All @@ -48,6 +49,7 @@ using CorrWelfordMomentStat = StatFunc<T, CorrMomentWelford>;

AggregateFunctionPtr create_aggregate_corr_welford_function(const std::string& name,
const DataTypes& argument_types,
const DataTypePtr& result_type,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_arity_range(name, argument_types, 2, 2);
Expand Down
5 changes: 3 additions & 2 deletions be/src/vec/aggregate_functions/aggregate_function_count.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ namespace doris::vectorized {

AggregateFunctionPtr create_aggregate_function_count(const std::string& name,
const DataTypes& argument_types,
const DataTypePtr& result_type,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_arity_range(name, argument_types, 0, 1);
Expand All @@ -38,8 +39,8 @@ AggregateFunctionPtr create_aggregate_function_count(const std::string& name,
}

AggregateFunctionPtr create_aggregate_function_count_not_null_unary(
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
const std::string& name, const DataTypes& argument_types, const DataTypePtr& result_type,
const bool result_is_nullable, const AggregateFunctionAttr& attr) {
assert_arity_range(name, argument_types, 0, 1);

return std::make_shared<AggregateFunctionCountNotNullUnary>(argument_types);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ namespace doris::vectorized {

AggregateFunctionPtr create_aggregate_function_count_by_enum(const std::string& name,
const DataTypes& argument_types,
const DataTypePtr& result_type,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
assert_arity_range(name, argument_types, 1, 1024);
Expand Down
2 changes: 2 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_covar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ AggregateFunctionPtr create_function_single_value(const String& name,

AggregateFunctionPtr create_aggregate_function_covariance_samp(const std::string& name,
const DataTypes& argument_types,
const DataTypePtr& result_type,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
return create_function_single_value<AggregateFunctionSampCovariance, SampData>(
Expand All @@ -43,6 +44,7 @@ AggregateFunctionPtr create_aggregate_function_covariance_samp(const std::string

AggregateFunctionPtr create_aggregate_function_covariance_pop(const std::string& name,
const DataTypes& argument_types,
const DataTypePtr& result_type,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
return create_function_single_value<AggregateFunctionSampCovariance, PopData>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class AggregateFunctionCombinatorDistinct final : public IAggregateFunctionCombi

void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFactory& factory) {
AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types,
const DataTypePtr& result_type,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
// 1. we should get not nullable types;
Expand All @@ -93,8 +94,8 @@ void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFact
auto function_combinator = std::make_shared<AggregateFunctionCombinatorDistinct>();
auto transform_arguments = function_combinator->transform_arguments(nested_types);
auto nested_function_name = name.substr(DISTINCT_FUNCTION_PREFIX.size());
auto nested_function = factory.get(nested_function_name, transform_arguments, false,
BeExecVersionManager::get_newest_version(), attr);
auto nested_function = factory.get(nested_function_name, transform_arguments, result_type,
false, BeExecVersionManager::get_newest_version(), attr);
return function_combinator->transform_aggregate_function(nested_function, types,
result_is_nullable, attr);
};
Expand Down
Loading
Loading