From 87ff5b42f1c0c8d62e66ccbf9f285d1162afe492 Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Fri, 5 Dec 2025 11:39:31 +0800 Subject: [PATCH 1/3] [feature](variable) enables views, materialized views, generated columns, and alias functions to persist session variables (#58031) Problem Summary: When creating views, generated columns, materialized views, or alias functions, session variables that affect query results (e.g., `enable_decimal256`) are not persisted. This causes inconsistent query results when session variables differ between creation and query time. This PR persists session variables marked with `affectQueryResult()` annotation for: - Views - Generated Columns - Materialized Views - Alias Functions When querying these objects, the system automatically uses the persisted session variables from creation time, ensuring consistent results. - Added `sessionVariables` field to `View`, `Column`, `MTMV`, and `AliasFunction` classes - Created `AutoCloseSessionVariable` utility to temporarily apply persisted variables during query processing - Modified view/column/MV/function parsing logic to use persisted session variables - Added `SessionVarGuardExpr` to protect expressions that depend on session variables --------- Co-authored-by: jacktengg --- be/src/olap/memtable.cpp | 6 +- be/src/olap/tablet_schema.cpp | 3 +- be/src/pipeline/exec/analytic_sink_operator.h | 8 +- be/src/runtime/primitive_type.h | 129 +---- be/src/runtime/runtime_state.h | 4 - .../aggregate_functions/aggregate_function.h | 1 - .../aggregate_function_ai_agg.cpp | 2 +- ...gregate_function_approx_count_distinct.cpp | 4 +- .../aggregate_function_array_agg.cpp | 1 + .../aggregate_function_avg.cpp | 36 +- .../aggregate_function_avg.h | 33 +- .../aggregate_function_bitmap.cpp | 5 +- .../aggregate_function_bitmap_agg.cpp | 1 + .../aggregate_function_collect.cpp | 1 + .../aggregate_function_corr.cpp | 2 + .../aggregate_function_count.cpp | 5 +- .../aggregate_function_count_by_enum.cpp | 1 + .../aggregate_function_covar.cpp | 2 + .../aggregate_function_distinct.cpp | 5 +- .../aggregate_function_foreach.cpp | 10 +- .../aggregate_function_foreachv2.cpp | 11 +- .../aggregate_function_group_array_set_op.cpp | 8 +- .../aggregate_function_group_concat.cpp | 1 + .../aggregate_function_histogram.cpp | 1 + .../aggregate_function_kurtosis.cpp | 1 + .../aggregate_function_linear_histogram.cpp | 1 + .../aggregate_function_map.cpp | 1 + .../aggregate_function_map_v2.cpp | 1 + .../aggregate_function_min_max.cpp | 1 + .../aggregate_function_min_max.h | 5 +- .../aggregate_function_min_max_by.h | 1 + .../aggregate_function_orthogonal_bitmap.cpp | 1 + .../aggregate_function_percentile.cpp | 8 +- ...ggregate_function_percentile_reservoir.cpp | 1 + .../aggregate_function_product.h | 11 +- .../aggregate_function_quantile_state.cpp | 4 +- .../aggregate_function_quantile_state.h | 4 +- .../aggregate_function_reader_first_last.h | 41 +- .../aggregate_function_regr_union.cpp | 1 + .../aggregate_function_sem.cpp | 12 +- .../aggregate_function_sequence_match.cpp | 1 + .../aggregate_function_simple_factory.h | 14 +- .../aggregate_function_skew.cpp | 1 + .../aggregate_function_stddev.cpp | 16 +- .../aggregate_function_sum.cpp | 18 +- .../aggregate_function_sum.h | 51 +- .../aggregate_function_topn.cpp | 3 + .../aggregate_function_uniq.cpp | 1 + ...aggregate_function_uniq_distribute_key.cpp | 1 + .../aggregate_function_window.cpp | 4 +- .../aggregate_function_window_funnel.cpp | 1 + be/src/vec/aggregate_functions/helpers.h | 72 +++ be/src/vec/core/call_on_type_index.h | 162 ++++++ be/src/vec/data_types/data_type_agg_state.h | 20 +- be/src/vec/exec/scan/file_scanner.cpp | 5 +- be/src/vec/exprs/vcast_expr.cpp | 5 +- be/src/vec/exprs/vectorized_agg_fn.cpp | 10 +- be/src/vec/exprs/vectorized_fn_call.cpp | 3 +- be/src/vec/exprs/vin_predicate.cpp | 5 +- be/src/vec/exprs/vmatch_predicate.cpp | 5 +- be/src/vec/exprs/vtopn_pred.h | 4 +- .../array/function_array_aggregation.cpp | 491 ++++++++++------ .../array/function_array_cum_sum.cpp | 131 +++-- .../functions/comparison_equal_for_null.cpp | 10 +- be/src/vec/functions/function.h | 73 ++- be/src/vec/functions/function_ifnull.h | 3 +- be/src/vec/functions/nullif.cpp | 10 +- .../vec/functions/simple_function_factory.h | 32 +- be/test/ai/aggregate_function_ai_agg_test.cpp | 2 +- .../operator/analytic_sink_operator_test.cpp | 35 +- .../operator/streaming_agg_operator_test.cpp | 5 +- .../testutil/mock/mock_agg_fn_evaluator.cpp | 21 +- be/test/testutil/mock/mock_agg_fn_evaluator.h | 4 +- .../agg_array_agg_test.cpp | 11 +- .../vec/aggregate_functions/agg_avg_test.cpp | 3 +- .../vec/aggregate_functions/agg_bit_test.cpp | 9 +- .../aggregate_functions/agg_bitmap_test.cpp | 10 +- .../agg_bool_union_test.cpp | 18 +- .../aggregate_functions/agg_collect_test.cpp | 14 +- .../vec/aggregate_functions/agg_corr_test.cpp | 6 +- .../aggregate_functions/agg_count_test.cpp | 3 +- .../aggregate_functions/agg_function_test.h | 5 +- .../agg_group_array_intersect_test.cpp | 8 +- .../agg_histogram_test.cpp | 2 +- .../agg_linear_histogram_test.cpp | 4 +- .../agg_min_max_by_test.cpp | 2 +- .../aggregate_functions/agg_min_max_test.cpp | 8 +- .../aggregate_functions/agg_replace_test.cpp | 4 +- be/test/vec/aggregate_functions/agg_test.cpp | 45 +- .../vec_count_by_enum_test.cpp | 2 +- .../vec_retention_test.cpp | 2 +- .../vec_sequence_match_test.cpp | 10 +- .../vec_window_funnel_test.cpp | 2 +- .../function/function_dict_get_many_test.cpp | 4 +- .../vec/function/function_dict_get_test.cpp | 4 +- .../function/simple_function_factory_test.cpp | 4 +- .../java/org/apache/doris/alter/Alter.java | 12 +- .../apache/doris/alter/AlterJobV2Factory.java | 9 +- .../apache/doris/alter/CloudRollupJobV2.java | 4 +- .../doris/alter/MaterializedViewHandler.java | 20 +- .../org/apache/doris/alter/RollupJobV2.java | 15 +- .../apache/doris/alter/SchemaChangeJobV2.java | 6 +- .../apache/doris/analysis/MVColumnItem.java | 3 +- .../apache/doris/catalog/AliasFunction.java | 27 + .../java/org/apache/doris/catalog/Column.java | 34 +- .../org/apache/doris/catalog/Database.java | 7 +- .../java/org/apache/doris/catalog/Env.java | 5 +- .../apache/doris/catalog/FunctionUtil.java | 37 +- .../doris/catalog/GlobalFunctionMgr.java | 4 +- .../java/org/apache/doris/catalog/MTMV.java | 70 ++- .../doris/catalog/MaterializedIndexMeta.java | 108 ++-- .../org/apache/doris/catalog/OlapTable.java | 10 +- .../doris/catalog/OlapTableFactory.java | 10 + .../java/org/apache/doris/catalog/View.java | 18 +- .../java/org/apache/doris/mtmv/MTMVCache.java | 26 +- .../org/apache/doris/mtmv/MTMVPlanUtil.java | 6 +- .../apache/doris/nereids/SqlCacheContext.java | 2 +- .../doris/nereids/StatementContext.java | 10 +- .../glue/translator/ExpressionTranslator.java | 6 + .../translator/PhysicalPlanTranslator.java | 46 +- .../doris/nereids/jobs/executor/Rewriter.java | 2 + .../doris/nereids/memo/StructInfoMap.java | 13 +- .../doris/nereids/minidump/MinidumpUtils.java | 4 +- .../nereids/parser/LogicalPlanBuilder.java | 6 +- .../post/CommonSubExpressionCollector.java | 18 +- .../nereids/rules/analysis/BindRelation.java | 22 +- .../nereids/rules/analysis/BindSink.java | 66 ++- .../rules/analysis/CollectRelation.java | 11 +- .../rules/analysis/NormalizeAggregate.java | 12 +- .../analysis/SessionVarGuardRewriter.java | 149 +++++ .../mv/InitMaterializationContextHook.java | 30 +- .../ExpressionBottomUpRewriter.java | 12 +- .../rules/expression/ExpressionRuleType.java | 4 +- .../rules/expression/MergeGuardExpr.java | 52 ++ .../rules/SimplifyComparisonPredicate.java | 6 +- .../SplitAggWithoutDistinct.java | 53 +- .../nereids/rules/rewrite/ColumnPruning.java | 2 +- .../rewrite/DistinctAggregateRewriter.java | 22 +- .../nereids/rules/rewrite/MergeAggregate.java | 46 +- .../doris/nereids/trees/expressions/Add.java | 2 +- .../nereids/trees/expressions/CaseWhen.java | 2 +- .../nereids/trees/expressions/Divide.java | 2 +- .../doris/nereids/trees/expressions/Mod.java | 2 +- .../nereids/trees/expressions/Multiply.java | 2 +- .../expressions/NeedSessionVarGuard.java | 23 +- .../expressions/SessionVarGuardExpr.java | 147 +++++ .../nereids/trees/expressions/Subtract.java | 2 +- .../trees/expressions/WindowExpression.java | 2 +- .../functions/ComputePrecisionForSum.java | 6 + .../functions/agg/AggregateFunction.java | 2 +- .../trees/expressions/functions/agg/Avg.java | 4 +- .../functions/agg/MultiDistinctSum.java | 3 +- .../functions/agg/MultiDistinctSum0.java | 3 +- .../trees/expressions/functions/agg/Sum.java | 3 +- .../trees/expressions/functions/agg/Sum0.java | 3 +- .../executable/NumericArithmetic.java | 18 +- .../functions/scalar/ArrayAvg.java | 3 +- .../functions/scalar/ArrayCumSum.java | 4 +- .../functions/scalar/ArrayMax.java | 3 +- .../functions/scalar/ArrayMin.java | 3 +- .../functions/scalar/ArrayProduct.java | 3 +- .../functions/scalar/ArraySum.java | 3 +- .../functions/scalar/Coalesce.java | 3 +- .../functions/scalar/GreatestLeast.java | 3 +- .../expressions/functions/scalar/If.java | 3 +- .../expressions/functions/scalar/NullIf.java | 3 +- .../expressions/functions/scalar/Nvl.java | 3 +- .../expressions/functions/udf/AliasUdf.java | 30 +- .../functions/udf/AliasUdfBuilder.java | 64 ++- .../expressions/literal/DecimalV3Literal.java | 19 +- .../trees/expressions/literal/Literal.java | 2 +- .../visitor/ExpressionVisitor.java | 5 + .../trees/plans/algebra/Aggregate.java | 29 + .../plans/commands/CreateFunctionCommand.java | 5 +- .../CreateMaterializedViewCommand.java | 5 +- .../plans/commands/info/ColumnDefinition.java | 13 +- .../plans/commands/info/CreateMTMVInfo.java | 9 +- .../commands/info/ShowCreateMTMVInfo.java | 3 +- .../doris/nereids/types/DecimalV2Type.java | 2 +- .../doris/nereids/types/DecimalV3Type.java | 24 +- .../apache/doris/nereids/util/PlanUtils.java | 42 ++ .../apache/doris/persist/AlterViewInfo.java | 20 +- .../doris/qe/AutoCloseSessionVariable.java | 63 +++ .../apache/doris/qe/ConnectContextUtil.java | 8 + .../org/apache/doris/qe/SessionVariable.java | 159 +++--- .../java/org/apache/doris/qe/VariableMgr.java | 87 +-- .../alter/MaterializedViewHandlerTest.java | 7 +- .../apache/doris/alter/RollupJobV2Test.java | 2 +- .../apache/doris/catalog/CreateViewTest.java | 5 +- .../apache/doris/catalog/DatabaseTest.java | 1 + .../nereids/rules/rewrite/PrepareTest.java | 83 +++ .../expressions/VariablePersistTest.java | 215 +++++++ .../commands/merge/MergeIntoCommandTest.java | 42 +- .../doris/persist/AlterViewInfoTest.java | 3 +- .../apache/doris/qe/OlapQueryCacheTest.java | 9 +- .../apache/doris/qe/SessionVariablesTest.java | 3 +- .../decimal_sum/test_decimal_sum.out | 5 - .../agg_state/decimalv3/test_decimalv3.out | 17 + .../variables_persist/test_alias_function.out | 37 ++ .../variables_persist/test_array_agg_view.out | Bin 0 -> 6518 bytes .../test_generated_column.out | 25 + .../variables_persist/test_mtmv.out | 21 + .../variables_persist/test_sync_mv.out | 16 + .../test_view_var_persist.out | 216 +++++++ .../variables_persist/use_view_create_mv.out | 11 + .../agg_state/decimalv3/test_decimalv3.groovy | 66 +++ .../nereids_p0/variables_persist/load.groovy | 44 ++ .../test_alias_function.groovy | 80 +++ .../test_array_agg_view.groovy | 77 +++ .../test_generated_column.groovy | 147 +++++ .../variables_persist/test_mtmv.groovy | 249 +++++++++ .../variables_persist/test_sync_mv.groovy | 117 ++++ .../test_view_var_persist.groovy | 525 ++++++++++++++++++ .../use_view_create_mv.groovy | 68 +++ 214 files changed, 4631 insertions(+), 1107 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SessionVarGuardRewriter.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/MergeGuardExpr.java rename regression-test/suites/datatype_p0/agg_state/decimal_sum/test_decimal_sum.groovy => fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NeedSessionVarGuard.java (54%) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SessionVarGuardExpr.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/qe/AutoCloseSessionVariable.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PrepareTest.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/VariablePersistTest.java delete mode 100644 regression-test/data/datatype_p0/agg_state/decimal_sum/test_decimal_sum.out create mode 100644 regression-test/data/datatype_p0/agg_state/decimalv3/test_decimalv3.out create mode 100644 regression-test/data/nereids_p0/variables_persist/test_alias_function.out create mode 100644 regression-test/data/nereids_p0/variables_persist/test_array_agg_view.out create mode 100644 regression-test/data/nereids_p0/variables_persist/test_generated_column.out create mode 100644 regression-test/data/nereids_p0/variables_persist/test_mtmv.out create mode 100644 regression-test/data/nereids_p0/variables_persist/test_sync_mv.out create mode 100644 regression-test/data/nereids_p0/variables_persist/test_view_var_persist.out create mode 100644 regression-test/data/nereids_p0/variables_persist/use_view_create_mv.out create mode 100644 regression-test/suites/datatype_p0/agg_state/decimalv3/test_decimalv3.groovy create mode 100644 regression-test/suites/nereids_p0/variables_persist/load.groovy create mode 100644 regression-test/suites/nereids_p0/variables_persist/test_alias_function.groovy create mode 100644 regression-test/suites/nereids_p0/variables_persist/test_array_agg_view.groovy create mode 100644 regression-test/suites/nereids_p0/variables_persist/test_generated_column.groovy create mode 100644 regression-test/suites/nereids_p0/variables_persist/test_mtmv.groovy create mode 100644 regression-test/suites/nereids_p0/variables_persist/test_sync_mv.groovy create mode 100644 regression-test/suites/nereids_p0/variables_persist/test_view_var_persist.groovy create mode 100644 regression-test/suites/nereids_p0/variables_persist/use_view_create_mv.groovy diff --git a/be/src/olap/memtable.cpp b/be/src/olap/memtable.cpp index f7d674eee43b87..c8d2bc86c842e9 100644 --- a/be/src/olap/memtable.cpp +++ b/be/src/olap/memtable.cpp @@ -114,13 +114,13 @@ void MemTable::_init_agg_functions(const vectorized::Block* block) { // the aggregate function manually. if (_skip_bitmap_col_idx != cid) { function = vectorized::AggregateFunctionSimpleFactory::instance().get( - "replace_load", {block->get_data_type(cid)}, + "replace_load", {block->get_data_type(cid)}, block->get_data_type(cid), block->get_data_type(cid)->is_nullable(), BeExecVersionManager::get_newest_version()); } else { function = vectorized::AggregateFunctionSimpleFactory::instance().get( - "bitmap_intersect", {block->get_data_type(cid)}, false, - BeExecVersionManager::get_newest_version()); + "bitmap_intersect", {block->get_data_type(cid)}, block->get_data_type(cid), + false, BeExecVersionManager::get_newest_version()); } } else { function = _tablet_schema->column(cid).get_aggregate_function( diff --git a/be/src/olap/tablet_schema.cpp b/be/src/olap/tablet_schema.cpp index d8a6c8375131c1..7d9336cf1081f2 100644 --- a/be/src/olap/tablet_schema.cpp +++ b/be/src/olap/tablet_schema.cpp @@ -776,7 +776,8 @@ vectorized::AggregateFunctionPtr TabletColumn::get_aggregate_function( std::transform(agg_name.begin(), agg_name.end(), agg_name.begin(), [](unsigned char c) { return std::tolower(c); }); function = vectorized::AggregateFunctionSimpleFactory::instance().get( - agg_name, {type}, type->is_nullable(), BeExecVersionManager::get_newest_version()); + agg_name, {type}, type, type->is_nullable(), + BeExecVersionManager::get_newest_version()); if (!function) { LOG(WARNING) << "get column aggregate function failed, aggregation_name=" << origin_name << ", column_type=" << type->get_name(); diff --git a/be/src/pipeline/exec/analytic_sink_operator.h b/be/src/pipeline/exec/analytic_sink_operator.h index c7c979223f1584..4e7d82679f3028 100644 --- a/be/src/pipeline/exec/analytic_sink_operator.h +++ b/be/src/pipeline/exec/analytic_sink_operator.h @@ -186,9 +186,15 @@ class AnalyticSinkOperatorX final : public DataSinkOperatorX= type1); } constexpr bool is_number(PrimitiveType type) { @@ -251,8 +266,6 @@ struct PrimitiveTypeTraits; * DataType: DataType which is mapping to this PrimitiveType * ColumnType: ColumnType which is mapping to this PrimitiveType * NearestFieldType: Nearest Doris type in execution engine - * AvgNearestFieldType: Nearest Doris type in execution engine for Avg - * AvgNearestFieldType256: Nearest Doris type in execution engine for Avg * NearestPrimitiveType: Nearest primitive type */ template <> @@ -264,11 +277,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeBool; using ColumnType = vectorized::ColumnUInt8; using NearestFieldType = vectorized::Int64; - using AvgNearestFieldType = vectorized::Int64; - using AvgNearestFieldType256 = vectorized::Int64; static constexpr PrimitiveType NearestPrimitiveType = TYPE_BIGINT; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_BIGINT; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_BIGINT; }; template <> struct PrimitiveTypeTraits { @@ -279,11 +289,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeInt8; using ColumnType = vectorized::ColumnInt8; using NearestFieldType = vectorized::Int64; - using AvgNearestFieldType = vectorized::Int64; - using AvgNearestFieldType256 = vectorized::Int64; static constexpr PrimitiveType NearestPrimitiveType = TYPE_BIGINT; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_BIGINT; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_BIGINT; }; template <> struct PrimitiveTypeTraits { @@ -294,11 +301,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeInt16; using ColumnType = vectorized::ColumnInt16; using NearestFieldType = vectorized::Int64; - using AvgNearestFieldType = vectorized::Int64; - using AvgNearestFieldType256 = vectorized::Int64; static constexpr PrimitiveType NearestPrimitiveType = TYPE_BIGINT; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_BIGINT; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_BIGINT; }; template <> struct PrimitiveTypeTraits { @@ -309,11 +313,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeInt32; using ColumnType = vectorized::ColumnInt32; using NearestFieldType = vectorized::Int64; - using AvgNearestFieldType = vectorized::Int64; - using AvgNearestFieldType256 = vectorized::Int64; static constexpr PrimitiveType NearestPrimitiveType = TYPE_BIGINT; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_BIGINT; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_BIGINT; }; template <> struct PrimitiveTypeTraits { @@ -324,11 +325,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeInt64; using ColumnType = vectorized::ColumnInt64; using NearestFieldType = vectorized::Int64; - using AvgNearestFieldType = vectorized::Int128; - using AvgNearestFieldType256 = vectorized::Int128; static constexpr PrimitiveType NearestPrimitiveType = TYPE_BIGINT; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_LARGEINT; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_LARGEINT; }; template <> struct PrimitiveTypeTraits { @@ -339,11 +337,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeInt128; using ColumnType = vectorized::ColumnInt128; using NearestFieldType = vectorized::Int128; - using AvgNearestFieldType = vectorized::Int128; - using AvgNearestFieldType256 = vectorized::Int128; static constexpr PrimitiveType NearestPrimitiveType = TYPE_LARGEINT; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_LARGEINT; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_LARGEINT; }; template <> struct PrimitiveTypeTraits { @@ -354,11 +349,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeNothing; using ColumnType = vectorized::IColumnDummy; using NearestFieldType = vectorized::Null; - using AvgNearestFieldType = vectorized::Null; - using AvgNearestFieldType256 = vectorized::Null; static constexpr PrimitiveType NearestPrimitiveType = TYPE_NULL; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_NULL; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_NULL; }; template <> struct PrimitiveTypeTraits { @@ -369,11 +361,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeFloat32; using ColumnType = vectorized::ColumnFloat32; using NearestFieldType = vectorized::Float64; - using AvgNearestFieldType = vectorized::Float64; - using AvgNearestFieldType256 = vectorized::Float64; static constexpr PrimitiveType NearestPrimitiveType = TYPE_DOUBLE; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_DOUBLE; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_DOUBLE; }; template <> struct PrimitiveTypeTraits { @@ -384,11 +373,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeFloat64; using ColumnType = vectorized::ColumnFloat64; using NearestFieldType = vectorized::Float64; - using AvgNearestFieldType = vectorized::Float64; - using AvgNearestFieldType256 = vectorized::Float64; static constexpr PrimitiveType NearestPrimitiveType = TYPE_DOUBLE; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_DOUBLE; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_DOUBLE; }; template <> struct PrimitiveTypeTraits { @@ -399,11 +385,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeTimeV2; using ColumnType = vectorized::ColumnTimeV2; using NearestFieldType = vectorized::Float64; - using AvgNearestFieldType = vectorized::Float64; - using AvgNearestFieldType256 = vectorized::Float64; static constexpr PrimitiveType NearestPrimitiveType = TYPE_DOUBLE; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_DOUBLE; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_DOUBLE; }; template <> struct PrimitiveTypeTraits { @@ -414,11 +397,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeTimeV2; using ColumnType = vectorized::ColumnTime; using NearestFieldType = vectorized::Float64; - using AvgNearestFieldType = vectorized::Float64; - using AvgNearestFieldType256 = vectorized::Float64; static constexpr PrimitiveType NearestPrimitiveType = TYPE_DOUBLE; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_DOUBLE; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_DOUBLE; }; template <> struct PrimitiveTypeTraits { @@ -430,11 +410,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeDate; using ColumnType = vectorized::ColumnDate; using NearestFieldType = vectorized::Int64; - using AvgNearestFieldType = vectorized::Int64; - using AvgNearestFieldType256 = vectorized::Int64; static constexpr PrimitiveType NearestPrimitiveType = TYPE_DATE; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_DATE; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_DATE; }; template <> struct PrimitiveTypeTraits { @@ -445,11 +422,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeDateTime; using ColumnType = vectorized::ColumnDateTime; using NearestFieldType = vectorized::Int64; - using AvgNearestFieldType = vectorized::Int64; - using AvgNearestFieldType256 = vectorized::Int64; static constexpr PrimitiveType NearestPrimitiveType = TYPE_DATETIME; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_DATETIME; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_DATETIME; }; template <> struct PrimitiveTypeTraits { @@ -460,11 +434,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeDateTimeV2; using ColumnType = vectorized::ColumnDateTimeV2; using NearestFieldType = vectorized::UInt64; - using AvgNearestFieldType = vectorized::UInt64; - using AvgNearestFieldType256 = vectorized::UInt64; static constexpr PrimitiveType NearestPrimitiveType = TYPE_DATETIMEV2; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_DATETIMEV2; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_DATETIMEV2; }; template <> struct PrimitiveTypeTraits { @@ -475,11 +446,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeDateV2; using ColumnType = vectorized::ColumnDateV2; using NearestFieldType = vectorized::UInt64; - using AvgNearestFieldType = vectorized::UInt32; - using AvgNearestFieldType256 = vectorized::UInt32; static constexpr PrimitiveType NearestPrimitiveType = TYPE_DATEV2; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_DATEV2; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_DATEV2; }; template <> struct PrimitiveTypeTraits { @@ -491,11 +459,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeDecimalV2; using ColumnType = vectorized::ColumnDecimal128V2; using NearestFieldType = vectorized::DecimalField; - using AvgNearestFieldType = vectorized::Decimal128V2; - using AvgNearestFieldType256 = vectorized::Decimal256; static constexpr PrimitiveType NearestPrimitiveType = TYPE_DECIMALV2; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_DECIMALV2; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_DECIMALV2; }; template <> struct PrimitiveTypeTraits { @@ -506,11 +471,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeDecimal32; using ColumnType = vectorized::ColumnDecimal32; using NearestFieldType = vectorized::DecimalField; - using AvgNearestFieldType = vectorized::Decimal128V3; - using AvgNearestFieldType256 = vectorized::Decimal256; static constexpr PrimitiveType NearestPrimitiveType = TYPE_DECIMAL32; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_DECIMAL128I; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_DECIMAL256; }; template <> struct PrimitiveTypeTraits { @@ -521,11 +483,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeDecimal64; using ColumnType = vectorized::ColumnDecimal64; using NearestFieldType = vectorized::DecimalField; - using AvgNearestFieldType = vectorized::Decimal128V3; - using AvgNearestFieldType256 = vectorized::Decimal256; static constexpr PrimitiveType NearestPrimitiveType = TYPE_DECIMAL64; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_DECIMAL128I; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_DECIMAL256; }; template <> struct PrimitiveTypeTraits { @@ -536,11 +495,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeDecimal128; using ColumnType = vectorized::ColumnDecimal128V3; using NearestFieldType = vectorized::DecimalField; - using AvgNearestFieldType = vectorized::Decimal128V3; - using AvgNearestFieldType256 = vectorized::Decimal256; static constexpr PrimitiveType NearestPrimitiveType = TYPE_DECIMAL128I; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_DECIMAL128I; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_DECIMAL256; }; template <> struct PrimitiveTypeTraits { @@ -551,11 +507,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeDecimal256; using ColumnType = vectorized::ColumnDecimal256; using NearestFieldType = vectorized::DecimalField; - using AvgNearestFieldType = vectorized::Decimal256; - using AvgNearestFieldType256 = vectorized::Decimal256; static constexpr PrimitiveType NearestPrimitiveType = TYPE_DECIMAL256; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_DECIMAL256; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_DECIMAL256; }; template <> struct PrimitiveTypeTraits { @@ -566,11 +519,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeIPv4; using ColumnType = vectorized::ColumnIPv4; using NearestFieldType = IPv4; - using AvgNearestFieldType = IPv4; - using AvgNearestFieldType256 = IPv4; static constexpr PrimitiveType NearestPrimitiveType = TYPE_IPV4; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_IPV4; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_IPV4; }; template <> struct PrimitiveTypeTraits { @@ -581,11 +531,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeIPv6; using ColumnType = vectorized::ColumnIPv6; using NearestFieldType = IPv6; - using AvgNearestFieldType = IPv6; - using AvgNearestFieldType256 = IPv6; static constexpr PrimitiveType NearestPrimitiveType = TYPE_IPV6; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_IPV6; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_IPV6; }; template <> struct PrimitiveTypeTraits { @@ -596,11 +543,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeString; using ColumnType = vectorized::ColumnString; using NearestFieldType = vectorized::String; - using AvgNearestFieldType = vectorized::String; - using AvgNearestFieldType256 = vectorized::String; static constexpr PrimitiveType NearestPrimitiveType = TYPE_CHAR; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_CHAR; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_CHAR; }; template <> struct PrimitiveTypeTraits { @@ -611,11 +555,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeString; using ColumnType = vectorized::ColumnString; using NearestFieldType = vectorized::String; - using AvgNearestFieldType = vectorized::String; - using AvgNearestFieldType256 = vectorized::String; static constexpr PrimitiveType NearestPrimitiveType = TYPE_VARCHAR; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_VARCHAR; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_VARCHAR; }; template <> struct PrimitiveTypeTraits { @@ -626,11 +567,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeString; using ColumnType = vectorized::ColumnString; using NearestFieldType = vectorized::String; - using AvgNearestFieldType = vectorized::String; - using AvgNearestFieldType256 = vectorized::String; static constexpr PrimitiveType NearestPrimitiveType = TYPE_STRING; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_STRING; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_STRING; }; template <> struct PrimitiveTypeTraits { @@ -641,11 +579,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeVarbinary; using ColumnType = vectorized::ColumnVarbinary; using NearestFieldType = doris::StringView; - using AvgNearestFieldType = doris::StringView; - using AvgNearestFieldType256 = doris::StringView; static constexpr PrimitiveType NearestPrimitiveType = TYPE_VARBINARY; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_VARBINARY; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_VARBINARY; }; template <> struct PrimitiveTypeTraits { @@ -656,11 +591,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeHLL; using ColumnType = vectorized::ColumnHLL; using NearestFieldType = HyperLogLog; - using AvgNearestFieldType = HyperLogLog; - using AvgNearestFieldType256 = HyperLogLog; static constexpr PrimitiveType NearestPrimitiveType = TYPE_HLL; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_HLL; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_HLL; }; template <> struct PrimitiveTypeTraits { @@ -671,11 +603,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeJsonb; using ColumnType = vectorized::ColumnString; using NearestFieldType = vectorized::JsonbField; - using AvgNearestFieldType = vectorized::JsonbField; - using AvgNearestFieldType256 = vectorized::JsonbField; static constexpr PrimitiveType NearestPrimitiveType = TYPE_JSONB; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_JSONB; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_JSONB; }; template <> struct PrimitiveTypeTraits { @@ -686,11 +615,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeArray; using ColumnType = vectorized::ColumnArray; using NearestFieldType = vectorized::Array; - using AvgNearestFieldType = vectorized::Array; - using AvgNearestFieldType256 = vectorized::Array; static constexpr PrimitiveType NearestPrimitiveType = TYPE_ARRAY; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_ARRAY; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_ARRAY; }; template <> struct PrimitiveTypeTraits { @@ -701,11 +627,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeMap; using ColumnType = vectorized::ColumnMap; using NearestFieldType = vectorized::Map; - using AvgNearestFieldType = vectorized::Map; - using AvgNearestFieldType256 = vectorized::Map; static constexpr PrimitiveType NearestPrimitiveType = TYPE_MAP; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_MAP; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_MAP; }; template <> struct PrimitiveTypeTraits { @@ -716,11 +639,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeStruct; using ColumnType = vectorized::ColumnStruct; using NearestFieldType = vectorized::Tuple; - using AvgNearestFieldType = vectorized::Tuple; - using AvgNearestFieldType256 = vectorized::Tuple; static constexpr PrimitiveType NearestPrimitiveType = TYPE_STRUCT; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_STRUCT; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_STRUCT; }; template <> struct PrimitiveTypeTraits { @@ -731,11 +651,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeVariant; using ColumnType = vectorized::ColumnVariant; using NearestFieldType = vectorized::VariantMap; - using AvgNearestFieldType = vectorized::VariantMap; - using AvgNearestFieldType256 = vectorized::VariantMap; static constexpr PrimitiveType NearestPrimitiveType = TYPE_VARIANT; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_VARIANT; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_VARIANT; }; template <> struct PrimitiveTypeTraits { @@ -746,11 +663,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeBitMap; using ColumnType = vectorized::ColumnBitmap; using NearestFieldType = BitmapValue; - using AvgNearestFieldType = BitmapValue; - using AvgNearestFieldType256 = BitmapValue; static constexpr PrimitiveType NearestPrimitiveType = TYPE_BITMAP; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_BITMAP; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_BITMAP; }; template <> struct PrimitiveTypeTraits { @@ -761,11 +675,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeQuantileState; using ColumnType = vectorized::ColumnQuantileState; using NearestFieldType = QuantileState; - using AvgNearestFieldType = QuantileState; - using AvgNearestFieldType256 = QuantileState; static constexpr PrimitiveType NearestPrimitiveType = TYPE_QUANTILE_STATE; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_QUANTILE_STATE; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_QUANTILE_STATE; }; template <> struct PrimitiveTypeTraits { @@ -776,11 +687,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeNothing; using ColumnType = vectorized::ColumnOffset32; using NearestFieldType = vectorized::UInt64; - using AvgNearestFieldType = vectorized::UInt64; - using AvgNearestFieldType256 = vectorized::UInt64; static constexpr PrimitiveType NearestPrimitiveType = TYPE_UINT32; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_UINT32; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_UINT32; }; template <> struct PrimitiveTypeTraits { @@ -791,11 +699,8 @@ struct PrimitiveTypeTraits { using DataType = vectorized::DataTypeNothing; using ColumnType = vectorized::ColumnOffset64; using NearestFieldType = vectorized::UInt64; - using AvgNearestFieldType = vectorized::Float64; - using AvgNearestFieldType256 = vectorized::Float64; static constexpr PrimitiveType NearestPrimitiveType = TYPE_UINT64; static constexpr PrimitiveType AvgNearestPrimitiveType = TYPE_DOUBLE; - static constexpr PrimitiveType AvgNearestPrimitiveType256 = TYPE_DOUBLE; }; template diff --git a/be/src/runtime/runtime_state.h b/be/src/runtime/runtime_state.h index b67b9004961ee3..5e869680961392 100644 --- a/be/src/runtime/runtime_state.h +++ b/be/src/runtime/runtime_state.h @@ -208,10 +208,6 @@ class RuntimeState { return _query_options.__isset.enable_insert_strict && _query_options.enable_insert_strict; } - bool enable_decimal256() const { - return _query_options.__isset.enable_decimal256 && _query_options.enable_decimal256; - } - bool enable_common_expr_pushdown() const { return _query_options.__isset.enable_common_expr_pushdown && _query_options.enable_common_expr_pushdown; diff --git a/be/src/vec/aggregate_functions/aggregate_function.h b/be/src/vec/aggregate_functions/aggregate_function.h index 95c96c8a038841..b9c2de9c148f0f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function.h +++ b/be/src/vec/aggregate_functions/aggregate_function.h @@ -46,7 +46,6 @@ class IColumn; class IDataType; struct AggregateFunctionAttr { - bool enable_decimal256 {false}; bool is_window_function {false}; bool is_foreach {false}; std::vector column_names; diff --git a/be/src/vec/aggregate_functions/aggregate_function_ai_agg.cpp b/be/src/vec/aggregate_functions/aggregate_function_ai_agg.cpp index 2508fa3805f9f1..b9e6e93ccf232f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_ai_agg.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_ai_agg.cpp @@ -26,7 +26,7 @@ QueryContext* AggregateFunctionAIAggData::_ctx = nullptr; void register_aggregate_function_ai_agg(AggregateFunctionSimpleFactory& factory) { factory.register_function_both("ai_agg", [](const std::string& name, const DataTypes& argument_types, - const bool result_is_nullable, + const DataTypePtr& result_type, const bool result_is_nullable, const AggregateFunctionAttr& attr) -> AggregateFunctionPtr { return creator_without_type::create( argument_types, result_is_nullable, attr); diff --git a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp index 4bc5da2ac723ca..4b61d4ca6075c3 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp @@ -25,8 +25,8 @@ namespace doris::vectorized { #include "common/compile_check_begin.h" AggregateFunctionPtr create_aggregate_function_approx_count_distinct( - const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, - const AggregateFunctionAttr& attr) { + const std::string& name, const DataTypes& argument_types, const DataTypePtr& result_type, + const bool result_is_nullable, const AggregateFunctionAttr& attr) { return creator_with_type_list< TYPE_BOOLEAN, TYPE_TINYINT, TYPE_SMALLINT, TYPE_INT, TYPE_BIGINT, TYPE_LARGEINT, TYPE_FLOAT, TYPE_DOUBLE, TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_DECIMAL128I, diff --git a/be/src/vec/aggregate_functions/aggregate_function_array_agg.cpp b/be/src/vec/aggregate_functions/aggregate_function_array_agg.cpp index 4d974229ce7f7c..a1b79dad5dd828 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_array_agg.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_array_agg.cpp @@ -41,6 +41,7 @@ AggregateFunctionPtr do_create_agg_function_collect(const DataTypes& argument_ty AggregateFunctionPtr create_aggregate_function_array_agg(const std::string& name, const DataTypes& argument_types, + const DataTypePtr& result_type, const bool result_is_nullable, const AggregateFunctionAttr& attr) { switch (argument_types[0]->get_primitive_type()) { diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp index 7ffbc57a01b94a..bf73f3fc2e996f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp @@ -28,41 +28,43 @@ namespace doris::vectorized { #include "common/compile_check_begin.h" +// TODO: use result type got from FE plan template struct Avg { - using FieldType = typename PrimitiveTypeTraits::AvgNearestFieldType; + static constexpr PrimitiveType ResultPType = T == TYPE_DECIMALV2 ? T : TYPE_DOUBLE; using Function = AggregateFunctionAvg< - T, AggregateFunctionAvgData::AvgNearestPrimitiveType>>; + T, ResultPType, + AggregateFunctionAvgData::AvgNearestPrimitiveType>>; }; template using AggregateFuncAvg = typename Avg::Function; -template -struct AvgDecimal256 { - using FieldType = typename PrimitiveTypeTraits::AvgNearestFieldType256; - using Function = AggregateFunctionAvg< - T, AggregateFunctionAvgData::AvgNearestPrimitiveType256>>; +// use result type got from FE plan +template +struct AvgDecimalV3 { + using Function = + AggregateFunctionAvg>; }; -template -using AggregateFuncAvgDecimal256 = typename AvgDecimal256::Function; +template +using AggregateFuncAvgDecimalV3 = typename AvgDecimalV3::Function; void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) { AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types, + const DataTypePtr& result_type, const bool result_is_nullable, const AggregateFunctionAttr& attr) { - if (attr.enable_decimal256 && is_decimal(types[0]->get_primitive_type())) { - return creator_with_type_list< - TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_DECIMAL128I, - TYPE_DECIMAL256>::creator(name, types, - result_is_nullable, attr); + if (is_decimalv3(types[0]->get_primitive_type())) { + return creator_with_type_list:: + creator_with_result_type(name, types, result_type, + result_is_nullable, attr); } else { return creator_with_type_list< TYPE_TINYINT, TYPE_SMALLINT, TYPE_INT, TYPE_BIGINT, TYPE_LARGEINT, TYPE_DOUBLE, - TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_DECIMAL128I, - TYPE_DECIMALV2>::creator(name, types, result_is_nullable, - attr); + TYPE_DECIMALV2>::creator(name, types, result_type, + result_is_nullable, attr); } }; factory.register_function_both("avg", creator); diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.h b/be/src/vec/aggregate_functions/aggregate_function_avg.h index cf65b41d90f643..0e263c0ec795c1 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.h +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.h @@ -30,6 +30,7 @@ #include #include "runtime/decimalv2_value.h" +#include "runtime/primitive_type.h" #include "vec/aggregate_functions/aggregate_function.h" #include "vec/columns/column.h" #include "vec/columns/column_fixed_length_object.h" @@ -105,23 +106,26 @@ struct AggregateFunctionAvgData { } }; +template +class AggregateFunctionAvg; + +template +constexpr static bool is_valid_avg_types = + (is_same_or_wider_decimalv3(T, TResult) || (is_decimalv2(T) && is_decimalv2(TResult)) || + (is_float_or_double(T) && is_float_or_double(TResult)) || + (is_int_or_bool(T) && (is_double(TResult) || is_int(TResult)))); /// Calculates arithmetic mean of numbers. -template -class AggregateFunctionAvg final - : public IAggregateFunctionDataHelper>, +template + requires(is_valid_avg_types) +class AggregateFunctionAvg final + : public IAggregateFunctionDataHelper>, UnaryExpression, NullableAggregateFunction { public: - using ResultType = std::conditional_t< - T == TYPE_DECIMALV2, Decimal128V2, - std::conditional_t>; - using ResultDataType = std::conditional_t< - T == TYPE_DECIMALV2, DataTypeDecimalV2, - std::conditional_t, DataTypeFloat64>>; - using ColVecType = typename PrimitiveTypeTraits::ColumnType; - using ColVecResult = std::conditional_t< - T == TYPE_DECIMALV2, ColumnDecimal128V2, - std::conditional_t, ColumnFloat64>>; + using ResultType = PrimitiveTypeTraits::ColumnItemType; + using ResultDataType = PrimitiveTypeTraits::DataType; + using ColVecType = PrimitiveTypeTraits::ColumnType; + using ColVecResult = PrimitiveTypeTraits::ColumnType; // The result calculated by PercentileApprox is an approximate value, // so the underlying storage uses float. The following calls will involve // an implicit cast to float. @@ -129,7 +133,8 @@ class AggregateFunctionAvg final using DataType = typename Data::ResultType; /// ctor for native types AggregateFunctionAvg(const DataTypes& argument_types_) - : IAggregateFunctionDataHelper>(argument_types_), + : IAggregateFunctionDataHelper>( + argument_types_), scale(get_decimal_scale(*argument_types_[0])) {} String get_name() const override { return "avg"; } diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp b/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp index 0a37d0f1ae7db2..0932216b908009 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp @@ -44,8 +44,8 @@ AggregateFunctionPtr create_with_int_data_type(const DataTypes& argument_type) { } AggregateFunctionPtr create_aggregate_function_bitmap_union_count( - const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, - const AggregateFunctionAttr& attr) { + const std::string& name, const DataTypes& argument_types, const DataTypePtr& result_type, + const bool result_is_nullable, const AggregateFunctionAttr& attr) { const bool arg_is_nullable = argument_types[0]->is_nullable(); if (arg_is_nullable) { return std::make_shared>(argument_types); @@ -56,6 +56,7 @@ AggregateFunctionPtr create_aggregate_function_bitmap_union_count( AggregateFunctionPtr create_aggregate_function_bitmap_union_int(const std::string& name, const DataTypes& argument_types, + const DataTypePtr& result_type, const bool result_is_nullable, const AggregateFunctionAttr& attr) { const bool arg_is_nullable = argument_types[0]->is_nullable(); diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.cpp b/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.cpp index 7d43a59987052c..7771b5bf4d798f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.cpp @@ -46,6 +46,7 @@ AggregateFunctionPtr create_with_int_data_type(const DataTypes& argument_types) AggregateFunctionPtr create_aggregate_function_bitmap_agg(const std::string& name, const DataTypes& argument_types, + const DataTypePtr& result_type, const bool result_is_nullable, const AggregateFunctionAttr& attr) { const bool arg_is_nullable = argument_types[0]->is_nullable(); diff --git a/be/src/vec/aggregate_functions/aggregate_function_collect.cpp b/be/src/vec/aggregate_functions/aggregate_function_collect.cpp index 91531a76747e2c..4df9c470c8db3f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_collect.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_collect.cpp @@ -127,6 +127,7 @@ AggregateFunctionPtr create_aggregate_function_collect_impl(const std::string& n AggregateFunctionPtr create_aggregate_function_collect(const std::string& name, const DataTypes& argument_types, + const DataTypePtr& result_type, const bool result_is_nullable, const AggregateFunctionAttr& attr) { assert_arity_range(name, argument_types, 1, 2); diff --git a/be/src/vec/aggregate_functions/aggregate_function_corr.cpp b/be/src/vec/aggregate_functions/aggregate_function_corr.cpp index ab993fc4d10d73..7ce9bb5e5de5f7 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_corr.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_corr.cpp @@ -30,6 +30,7 @@ using CorrMomentStat = StatFunc; AggregateFunctionPtr create_aggregate_corr_function(const std::string& name, const DataTypes& argument_types, + const DataTypePtr& result_type, const bool result_is_nullable, const AggregateFunctionAttr& attr) { assert_arity_range(name, argument_types, 2, 2); @@ -48,6 +49,7 @@ using CorrWelfordMomentStat = StatFunc; AggregateFunctionPtr create_aggregate_corr_welford_function(const std::string& name, const DataTypes& argument_types, + const DataTypePtr& result_type, const bool result_is_nullable, const AggregateFunctionAttr& attr) { assert_arity_range(name, argument_types, 2, 2); diff --git a/be/src/vec/aggregate_functions/aggregate_function_count.cpp b/be/src/vec/aggregate_functions/aggregate_function_count.cpp index 30d830e587272b..2fcfecbde9a21f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_count.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_count.cpp @@ -30,6 +30,7 @@ namespace doris::vectorized { AggregateFunctionPtr create_aggregate_function_count(const std::string& name, const DataTypes& argument_types, + const DataTypePtr& result_type, const bool result_is_nullable, const AggregateFunctionAttr& attr) { assert_arity_range(name, argument_types, 0, 1); @@ -38,8 +39,8 @@ AggregateFunctionPtr create_aggregate_function_count(const std::string& name, } AggregateFunctionPtr create_aggregate_function_count_not_null_unary( - const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, - const AggregateFunctionAttr& attr) { + const std::string& name, const DataTypes& argument_types, const DataTypePtr& result_type, + const bool result_is_nullable, const AggregateFunctionAttr& attr) { assert_arity_range(name, argument_types, 0, 1); return std::make_shared(argument_types); diff --git a/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.cpp b/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.cpp index c6fba1e76bb4ba..d05cf751d0db64 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.cpp @@ -30,6 +30,7 @@ namespace doris::vectorized { AggregateFunctionPtr create_aggregate_function_count_by_enum(const std::string& name, const DataTypes& argument_types, + const DataTypePtr& result_type, const bool result_is_nullable, const AggregateFunctionAttr& attr) { assert_arity_range(name, argument_types, 1, 1024); diff --git a/be/src/vec/aggregate_functions/aggregate_function_covar.cpp b/be/src/vec/aggregate_functions/aggregate_function_covar.cpp index f28efe228b97e1..a108589cf70b96 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_covar.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_covar.cpp @@ -35,6 +35,7 @@ AggregateFunctionPtr create_function_single_value(const String& name, AggregateFunctionPtr create_aggregate_function_covariance_samp(const std::string& name, const DataTypes& argument_types, + const DataTypePtr& result_type, const bool result_is_nullable, const AggregateFunctionAttr& attr) { return create_function_single_value( @@ -43,6 +44,7 @@ AggregateFunctionPtr create_aggregate_function_covariance_samp(const std::string AggregateFunctionPtr create_aggregate_function_covariance_pop(const std::string& name, const DataTypes& argument_types, + const DataTypePtr& result_type, const bool result_is_nullable, const AggregateFunctionAttr& attr) { return create_function_single_value( diff --git a/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp b/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp index 47bc3309699c3e..233222ea416730 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp @@ -84,6 +84,7 @@ class AggregateFunctionCombinatorDistinct final : public IAggregateFunctionCombi void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFactory& factory) { AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types, + const DataTypePtr& result_type, const bool result_is_nullable, const AggregateFunctionAttr& attr) { // 1. we should get not nullable types; @@ -93,8 +94,8 @@ void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFact auto function_combinator = std::make_shared(); auto transform_arguments = function_combinator->transform_arguments(nested_types); auto nested_function_name = name.substr(DISTINCT_FUNCTION_PREFIX.size()); - auto nested_function = factory.get(nested_function_name, transform_arguments, false, - BeExecVersionManager::get_newest_version(), attr); + auto nested_function = factory.get(nested_function_name, transform_arguments, result_type, + false, BeExecVersionManager::get_newest_version(), attr); return function_combinator->transform_aggregate_function(nested_function, types, result_is_nullable, attr); }; diff --git a/be/src/vec/aggregate_functions/aggregate_function_foreach.cpp b/be/src/vec/aggregate_functions/aggregate_function_foreach.cpp index 293621cc329383..845fb13dceb288 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_foreach.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_foreach.cpp @@ -33,7 +33,8 @@ namespace doris::vectorized { void register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFactory& factory) { AggregateFunctionCreator creator = - [&](const std::string& name, const DataTypes& types, const bool result_is_nullable, + [&](const std::string& name, const DataTypes& types, const DataTypePtr& result_type, + const bool result_is_nullable, const AggregateFunctionAttr& attr) -> AggregateFunctionPtr { const std::string& suffix = AggregateFunctionForEach::AGG_FOREACH_SUFFIX; DataTypes transform_arguments; @@ -42,10 +43,13 @@ void register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFacto assert_cast(remove_nullable(t).get())->get_nested_type(); transform_arguments.push_back((item_type)); } + auto result_item_type = + assert_cast(remove_nullable(result_type).get()) + ->get_nested_type(); auto nested_function_name = name.substr(0, name.size() - suffix.size()); auto nested_function = - factory.get(nested_function_name, transform_arguments, result_is_nullable, - BeExecVersionManager::get_newest_version(), attr); + factory.get(nested_function_name, transform_arguments, result_item_type, + result_is_nullable, BeExecVersionManager::get_newest_version(), attr); if (!nested_function) { throw Exception( ErrorCode::INTERNAL_ERROR, diff --git a/be/src/vec/aggregate_functions/aggregate_function_foreachv2.cpp b/be/src/vec/aggregate_functions/aggregate_function_foreachv2.cpp index 2ff8adc208f1f8..882aa2306a8b82 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_foreachv2.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_foreachv2.cpp @@ -78,7 +78,8 @@ class AggregateFunctionForEachV2 : public AggregateFunctionForEach { void register_aggregate_function_combinator_foreachv2(AggregateFunctionSimpleFactory& factory) { AggregateFunctionCreator creator = - [&](const std::string& name, const DataTypes& types, const bool result_is_nullable, + [&](const std::string& name, const DataTypes& types, const DataTypePtr& result_type, + const bool result_is_nullable, const AggregateFunctionAttr& attr) -> AggregateFunctionPtr { const std::string& suffix = AggregateFunctionForEachV2::AGG_FOREACH_SUFFIX; DataTypes transform_arguments; @@ -87,9 +88,13 @@ void register_aggregate_function_combinator_foreachv2(AggregateFunctionSimpleFac assert_cast(remove_nullable(t).get())->get_nested_type(); transform_arguments.push_back(item_type); } + auto result_item_type = + assert_cast(remove_nullable(result_type).get()) + ->get_nested_type(); auto nested_function_name = name.substr(0, name.size() - suffix.size()); - auto nested_function = factory.get(nested_function_name, transform_arguments, true, - BeExecVersionManager::get_newest_version(), attr); + auto nested_function = + factory.get(nested_function_name, transform_arguments, result_item_type, true, + BeExecVersionManager::get_newest_version(), attr); if (!nested_function) { throw Exception( ErrorCode::INTERNAL_ERROR, diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_set_op.cpp b/be/src/vec/aggregate_functions/aggregate_function_group_array_set_op.cpp index b0eaf53179f222..5c7805cbc23c4d 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_set_op.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_set_op.cpp @@ -115,8 +115,8 @@ inline AggregateFunctionPtr create_aggregate_function_group_array_impl( } AggregateFunctionPtr create_aggregate_function_group_array_intersect( - const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, - const AggregateFunctionAttr& attr) { + const std::string& name, const DataTypes& argument_types, const DataTypePtr& result_type, + const bool result_is_nullable, const AggregateFunctionAttr& attr) { assert_arity_range(name, argument_types, 1, 1); const DataTypePtr& argument_type = remove_nullable(argument_types[0]); @@ -133,8 +133,8 @@ AggregateFunctionPtr create_aggregate_function_group_array_intersect( } AggregateFunctionPtr create_aggregate_function_group_array_union( - const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, - const AggregateFunctionAttr& attr) { + const std::string& name, const DataTypes& argument_types, const DataTypePtr& result_type, + const bool result_is_nullable, const AggregateFunctionAttr& attr) { assert_arity_range(name, argument_types, 1, 1); const DataTypePtr& argument_type = remove_nullable(argument_types[0]); diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp b/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp index 75c3852b512cc6..58f0642b3b6557 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp @@ -29,6 +29,7 @@ const std::string AggregateFunctionGroupConcatImplStr::separator = ","; AggregateFunctionPtr create_aggregate_function_group_concat(const std::string& name, const DataTypes& argument_types, + const DataTypePtr& result_type, const bool result_is_nullable, const AggregateFunctionAttr& attr) { if (argument_types.size() == 1) { diff --git a/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp b/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp index 4322792c3b9945..cda8d4626b361c 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp @@ -36,6 +36,7 @@ using HistogramNormal = AggregateFunctionHistogram; AggregateFunctionPtr create_aggregate_function_histogram(const std::string& name, const DataTypes& argument_types, + const DataTypePtr& result_type, const bool result_is_nullable, const AggregateFunctionAttr& attr) { assert_arity_range(name, argument_types, 1, 2); diff --git a/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp b/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp index 6c6c0f0ca15ced..bf31382a36a1c3 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp @@ -27,6 +27,7 @@ namespace doris::vectorized { AggregateFunctionPtr create_aggregate_function_kurt(const std::string& name, const DataTypes& argument_types, + const DataTypePtr& result_type, const bool result_is_nullable, const AggregateFunctionAttr& attr) { assert_arity_range(name, argument_types, 1, 1); diff --git a/be/src/vec/aggregate_functions/aggregate_function_linear_histogram.cpp b/be/src/vec/aggregate_functions/aggregate_function_linear_histogram.cpp index a140159c5541e5..07a999d04a5af8 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_linear_histogram.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_linear_histogram.cpp @@ -32,6 +32,7 @@ using HistogramNormal = AggregateFunctionLinearHistogram; AggregateFunctionPtr create_aggregate_function_linear_histogram(const std::string& name, const DataTypes& argument_types, + const DataTypePtr& result_type, const bool result_is_nullable, const AggregateFunctionAttr& attr) { using creator = creator_with_type_listget_primitive_type()) { diff --git a/be/src/vec/aggregate_functions/aggregate_function_map_v2.cpp b/be/src/vec/aggregate_functions/aggregate_function_map_v2.cpp index 964fc7aa22e009..65367e5b9d76d5 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_map_v2.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_map_v2.cpp @@ -32,6 +32,7 @@ AggregateFunctionPtr create_agg_function_map_agg_v2(const DataTypes& argument_ty AggregateFunctionPtr create_aggregate_function_map_agg_v2(const std::string& name, const DataTypes& argument_types, + const DataTypePtr& result_type, const bool result_is_nullable, const AggregateFunctionAttr& attr) { switch (argument_types[0]->get_primitive_type()) { diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp index ccdbd9dfe45e13..844015cf081ed0 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp @@ -34,6 +34,7 @@ namespace doris::vectorized { template