diff --git a/be/src/common/daemon.cpp b/be/src/common/daemon.cpp index 450521ebc123d5..71d4833ba07024 100644 --- a/be/src/common/daemon.cpp +++ b/be/src/common/daemon.cpp @@ -45,6 +45,7 @@ #include "exprs/es_functions.h" #include "exprs/timestamp_functions.h" #include "exprs/decimal_operators.h" +#include "exprs/decimalv2_operators.h" #include "exprs/utility_functions.h" #include "exprs/json_functions.h" #include "exprs/hll_hash_function.h" @@ -182,6 +183,7 @@ void init_daemon(int argc, char** argv, const std::vector& paths) { EncryptionFunctions::init(); TimestampFunctions::init(); DecimalOperators::init(); + DecimalV2Operators::init(); UtilityFunctions::init(); CompoundPredicate::init(); JsonFunctions::init(); diff --git a/be/src/exec/hash_table.cpp b/be/src/exec/hash_table.cpp index 1625febdeb7eac..a312e25184aa58 100644 --- a/be/src/exec/hash_table.cpp +++ b/be/src/exec/hash_table.cpp @@ -159,7 +159,7 @@ uint32_t HashTable::hash_variable_len_row() { StringValue* str = reinterpret_cast(loc); hash = HashUtil::hash(str->ptr, str->len, hash); } - } else if (_build_expr_ctxs[i]->root()->type().is_decimal_type()) { + } else if (_build_expr_ctxs[i]->root()->type().type == TYPE_DECIMAL) { void* loc = _expr_values_buffer + _expr_values_buffer_offsets[i]; if (_expr_value_null_bits[i]) { // Hash the null random seed values at 'loc' @@ -169,7 +169,7 @@ uint32_t HashTable::hash_variable_len_row() { hash = decimal->hash(hash); } } - + } return hash; @@ -410,7 +410,7 @@ Function* HashTable::codegen_eval_tuple_row(RuntimeState* state, bool build) { for (int i = 0; i < ctxs.size(); ++i) { PrimitiveType type = ctxs[i]->root()->type().type; if (type == TYPE_DATE || type == TYPE_DATETIME - || type == TYPE_DECIMAL || type == TYPE_CHAR) { + || type == TYPE_DECIMAL || type == TYPE_CHAR || type == TYPE_DECIMALV2) { return NULL; } } diff --git a/be/src/exec/olap_common.cpp b/be/src/exec/olap_common.cpp index 35f28c714f579d..a3827fb5f501b7 100644 --- a/be/src/exec/olap_common.cpp +++ b/be/src/exec/olap_common.cpp @@ -45,6 +45,11 @@ void ColumnValueRange::convert_to_fixed_value() { return; } +template<> +void ColumnValueRange::convert_to_fixed_value() { + return; +} + template<> void ColumnValueRange<__int128>::convert_to_fixed_value() { return; @@ -147,6 +152,7 @@ Status DorisScanRange::init() { case TYPE_VARCHAR: case TYPE_CHAR: case TYPE_DECIMAL: + case TYPE_DECIMALV2: case TYPE_DATE: case TYPE_DATETIME: break; diff --git a/be/src/exec/olap_common.h b/be/src/exec/olap_common.h index a0d0bb356af4ca..bb8304f0b83d6f 100644 --- a/be/src/exec/olap_common.h +++ b/be/src/exec/olap_common.h @@ -253,7 +253,8 @@ typedef boost::variant < ColumnValueRange<__int128>, ColumnValueRange, ColumnValueRange, - ColumnValueRange > ColumnValueRangeType; + ColumnValueRange, + ColumnValueRange > ColumnValueRangeType; class DorisScanRange { public: @@ -388,6 +389,9 @@ void ColumnValueRange::convert_to_fixed_value(); template<> void ColumnValueRange::convert_to_fixed_value(); +template<> +void ColumnValueRange::convert_to_fixed_value(); + template<> void ColumnValueRange<__int128>::convert_to_fixed_value(); diff --git a/be/src/exec/olap_rewrite_node.cpp b/be/src/exec/olap_rewrite_node.cpp index 611edc4979c753..645eddaa36ac40 100644 --- a/be/src/exec/olap_rewrite_node.cpp +++ b/be/src/exec/olap_rewrite_node.cpp @@ -58,10 +58,14 @@ Status OlapRewriteNode::prepare(RuntimeState* state) { new RowBatch(child(0)->row_desc(), state->batch_size(), state->fragment_mem_tracker())); _max_decimal_val.resize(_column_types.size()); + _max_decimalv2_val.resize(_column_types.size()); for (int i = 0; i < _column_types.size(); ++i) { if (_column_types[i].type == TPrimitiveType::DECIMAL) { _max_decimal_val[i].to_max_decimal( _column_types[i].precision, _column_types[i].scale); + } else if (_column_types[i].type == TPrimitiveType::DECIMALV2) { + _max_decimalv2_val[i].to_max_decimal( + _column_types[i].precision, _column_types[i].scale); } } return Status::OK; @@ -179,6 +183,24 @@ bool OlapRewriteNode::copy_one_row(TupleRow* src_row, Tuple* tuple, } break; } + case TPrimitiveType::DECIMALV2: { + DecimalV2Value* dec_val = (DecimalV2Value*)src_value; + DecimalV2Value* dst_val = (DecimalV2Value*)tuple->get_slot(slot_desc->tuple_offset()); + if (dec_val->greater_than_scale(column_type.scale)) { + int code = dec_val->round(dst_val, column_type.scale, HALF_UP); + if (code != E_DEC_OK) { + (*ss) << "round one decimal failed.value=" << dec_val->to_string(); + return false; + } + } else { + *reinterpret_cast(dst_val) = + *reinterpret_cast(dec_val); + } + if (*dst_val > _max_decimalv2_val[i]) { + dst_val->to_max_decimal(column_type.precision, column_type.scale); + } + break; + } default: { void* dst_val = (void*)tuple->get_slot(slot_desc->tuple_offset()); RawValue::write(src_value, dst_val, slot_desc->type(), pool); diff --git a/be/src/exec/olap_rewrite_node.h b/be/src/exec/olap_rewrite_node.h index cd9a7722d3f158..d6b2681bcfa14e 100644 --- a/be/src/exec/olap_rewrite_node.h +++ b/be/src/exec/olap_rewrite_node.h @@ -63,6 +63,7 @@ class OlapRewriteNode : public ExecNode { TupleDescriptor* _output_tuple_desc; std::vector _max_decimal_val; + std::vector _max_decimalv2_val; }; } diff --git a/be/src/exec/olap_scan_node.cpp b/be/src/exec/olap_scan_node.cpp index d851ef1e87f524..3547ccfedf4b79 100644 --- a/be/src/exec/olap_scan_node.cpp +++ b/be/src/exec/olap_scan_node.cpp @@ -491,6 +491,17 @@ Status OlapScanNode::normalize_conjuncts() { break; } + case TYPE_DECIMALV2: { + DecimalV2Value min = DecimalV2Value::get_min_decimal(); + DecimalV2Value max = DecimalV2Value::get_max_decimal(); + ColumnValueRange range(slots[slot_idx]->col_name(), + slots[slot_idx]->type().type, + min, + max); + normalize_predicate(range, slots[slot_idx]); + break; + } + default: { VLOG(2) << "Unsupport Normalize Slot [ColName=" << slots[slot_idx]->col_name() << "]"; @@ -739,6 +750,7 @@ Status OlapScanNode::normalize_in_predicate(SlotDescriptor* slot, ColumnValueRan break; } case TYPE_DECIMAL: + case TYPE_DECIMALV2: case TYPE_LARGEINT: case TYPE_CHAR: case TYPE_VARCHAR: @@ -807,6 +819,7 @@ Status OlapScanNode::normalize_in_predicate(SlotDescriptor* slot, ColumnValueRan break; } case TYPE_DECIMAL: + case TYPE_DECIMALV2: case TYPE_CHAR: case TYPE_VARCHAR: case TYPE_HLL: @@ -919,6 +932,7 @@ Status OlapScanNode::normalize_binary_predicate(SlotDescriptor* slot, ColumnValu break; } case TYPE_DECIMAL: + case TYPE_DECIMALV2: case TYPE_CHAR: case TYPE_VARCHAR: case TYPE_HLL: diff --git a/be/src/exec/olap_scanner.cpp b/be/src/exec/olap_scanner.cpp index 2f1ea579ba1560..222708e840c26c 100644 --- a/be/src/exec/olap_scanner.cpp +++ b/be/src/exec/olap_scanner.cpp @@ -385,6 +385,16 @@ void OlapScanner::_convert_row_to_tuple(Tuple* tuple) { *slot = DecimalValue(int_value, frac_value); break; } + case TYPE_DECIMALV2: { + DecimalV2Value *slot = tuple->get_decimalv2_slot(slot_desc->tuple_offset()); + + int64_t int_value = *(int64_t*)(ptr); + int32_t frac_value = *(int32_t*)(ptr + sizeof(int64_t)); + if (!slot->from_olap_decimal(int_value, frac_value)) { + tuple->set_null(slot_desc->null_indicator_offset()); + } + break; + } case TYPE_DATETIME: { DateTimeValue *slot = tuple->get_datetime_slot(slot_desc->tuple_offset()); uint64_t value = *reinterpret_cast(ptr); diff --git a/be/src/exec/olap_table_sink.cpp b/be/src/exec/olap_table_sink.cpp index 5bf3fd93fc54cd..6e4a7b3aacba09 100644 --- a/be/src/exec/olap_table_sink.cpp +++ b/be/src/exec/olap_table_sink.cpp @@ -461,6 +461,9 @@ Status OlapTableSink::prepare(RuntimeState* state) { _max_decimal_val.resize(_output_tuple_desc->slots().size()); _min_decimal_val.resize(_output_tuple_desc->slots().size()); + + _max_decimalv2_val.resize(_output_tuple_desc->slots().size()); + _min_decimalv2_val.resize(_output_tuple_desc->slots().size()); // check if need validate batch for (int i = 0; i < _output_tuple_desc->slots().size(); ++i) { auto slot = _output_tuple_desc->slots()[i]; @@ -470,6 +473,11 @@ Status OlapTableSink::prepare(RuntimeState* state) { _min_decimal_val[i].to_min_decimal(slot->type().precision, slot->type().scale); _need_validate_data = true; break; + case TYPE_DECIMALV2: + _max_decimalv2_val[i].to_max_decimal(slot->type().precision, slot->type().scale); + _min_decimalv2_val[i].to_min_decimal(slot->type().precision, slot->type().scale); + _need_validate_data = true; + break; case TYPE_CHAR: case TYPE_VARCHAR: case TYPE_DATE: @@ -716,6 +724,44 @@ int OlapTableSink::_validate_data(RuntimeState* state, RowBatch* batch, Bitmap* LOG(INFO) << ss.str(); #else state->append_error_msg_to_file("", ss.str()); +#endif + filtered_rows++; + row_valid = false; + filter_bitmap->Set(row_no, true); + continue; + } + break; + } + case TYPE_DECIMALV2: { + DecimalV2Value dec_val(reinterpret_cast(slot)->value); + if (dec_val.greater_than_scale(desc->type().scale)) { + int code = dec_val.round(&dec_val, desc->type().scale, HALF_UP); + reinterpret_cast(slot)->value = dec_val.value(); + if (code != E_DEC_OK) { + std::stringstream ss; + ss << "round one decimal failed.value=" << dec_val.to_string(); +#if BE_TEST + LOG(INFO) << ss.str(); +#else + state->append_error_msg_to_file("", ss.str()); +#endif + + filtered_rows++; + row_valid = false; + filter_bitmap->Set(row_no, true); + continue; + } + } + if (dec_val > _max_decimalv2_val[i] || dec_val < _min_decimalv2_val[i]) { + std::stringstream ss; + ss << "decimal value is not valid for defination, column=" << desc->col_name() + << ", value=" << dec_val.to_string() + << ", precision=" << desc->type().precision + << ", scale=" << desc->type().scale; +#if BE_TEST + LOG(INFO) << ss.str(); +#else + state->append_error_msg_to_file("", ss.str()); #endif filtered_rows++; row_valid = false; diff --git a/be/src/exec/olap_table_sink.h b/be/src/exec/olap_table_sink.h index 44866d2374e984..d4856bc0067496 100644 --- a/be/src/exec/olap_table_sink.h +++ b/be/src/exec/olap_table_sink.h @@ -229,6 +229,9 @@ class OlapTableSink : public DataSink { std::vector _max_decimal_val; std::vector _min_decimal_val; + std::vector _max_decimalv2_val; + std::vector _min_decimalv2_val; + // Stats for this int64_t _convert_batch_ns = 0; int64_t _validate_data_ns = 0; diff --git a/be/src/exec/olap_utils.h b/be/src/exec/olap_utils.h index 596abc0c08a463..4925e895b35aad 100644 --- a/be/src/exec/olap_utils.h +++ b/be/src/exec/olap_utils.h @@ -68,6 +68,9 @@ inline CompareLargeFunc get_compare_func(PrimitiveType type) { case TYPE_DECIMAL: return compare_large; + case TYPE_DECIMALV2: + return compare_large; + case TYPE_CHAR: case TYPE_VARCHAR: return compare_large; @@ -182,6 +185,7 @@ inline int get_olap_size(PrimitiveType type) { return 8; } + case TYPE_DECIMALV2: case TYPE_DECIMAL: { return 12; } diff --git a/be/src/exec/partitioned_aggregation_node.cc b/be/src/exec/partitioned_aggregation_node.cc index dfddb3b85b38bc..d39f88ccafa0d6 100644 --- a/be/src/exec/partitioned_aggregation_node.cc +++ b/be/src/exec/partitioned_aggregation_node.cc @@ -1289,7 +1289,7 @@ llvm::Function* PartitionedAggregationNode::codegen_update_slot( break; } case AggFnEvaluator::SUM: - if (slot_desc->type().type != TYPE_DECIMAL) { + if (slot_desc->type().type != TYPE_DECIMAL && slot_desc->type().type != TYPE_DECIMALV2) { if (slot_desc->type().type == TYPE_FLOAT || slot_desc->type().type == TYPE_DOUBLE) { result = builder.CreateFAdd(dst_value, src.GetVal()); @@ -1298,7 +1298,7 @@ llvm::Function* PartitionedAggregationNode::codegen_update_slot( } break; } - DCHECK_EQ(slot_desc->type().type, TYPE_DECIMAL); + DCHECK(slot_desc->type().type == TYPE_DECIMAL || slot_desc->type().type == TYPE_DECIMALV2); // Fall through to xcompiled case case AggFnEvaluator::AVG: case AggFnEvaluator::NDV: { @@ -1422,6 +1422,11 @@ Function* PartitionedAggregationNode::codegen_update_tuple() { op == AggFnEvaluator::NDV)) { supported = false; } + if (type == TYPE_DECIMALV2 && + !(op == AggFnEvaluator::SUM || op == AggFnEvaluator::AVG || + op == AggFnEvaluator::NDV)) { + supported = false; + } if (!supported) { VLOG_QUERY << "Could not codegen update_tuple because intermediate type " << slot_desc->type() diff --git a/be/src/exec/pre_aggregation_node.cpp b/be/src/exec/pre_aggregation_node.cpp index 32eeb10cab49d9..46305dfe271adc 100644 --- a/be/src/exec/pre_aggregation_node.cpp +++ b/be/src/exec/pre_aggregation_node.cpp @@ -547,6 +547,10 @@ Status PreAggregationNode::update_agg_row(TupleRow* agg_row, TupleRow* probe_row UpdateMinSlot(slot, value); break; + case TYPE_DECIMALV2: + UpdateMinSlot(slot, value); + break; + default: LOG(WARNING) << "invalid type: " << type_to_string(agg_expr->type()); return Status("unknown type"); @@ -593,6 +597,10 @@ Status PreAggregationNode::update_agg_row(TupleRow* agg_row, TupleRow* probe_row UpdateMaxSlot(slot, value); break; + case TYPE_DECIMALV2: + UpdateMaxSlot(slot, value); + break; + default: LOG(WARNING) << "invalid type: " << type_to_string(agg_expr->type()); return Status("unknown type"); @@ -614,6 +622,10 @@ Status PreAggregationNode::update_agg_row(TupleRow* agg_row, TupleRow* probe_row UpdateSumSlot(slot, value); break; + case TYPE_DECIMALV2: + UpdateSumSlot(slot, value); + break; + default: LOG(WARNING) << "invalid type: " << type_to_string(agg_expr->type()); return Status("Aggsum not valid."); diff --git a/be/src/exec/schema_scanner/schema_columns_scanner.cpp b/be/src/exec/schema_scanner/schema_columns_scanner.cpp index c0ffcfa9b0abc7..9a54887ed8d20a 100644 --- a/be/src/exec/schema_scanner/schema_columns_scanner.cpp +++ b/be/src/exec/schema_scanner/schema_columns_scanner.cpp @@ -117,6 +117,7 @@ std::string SchemaColumnsScanner::type_to_string(TColumnDesc &desc) { return "date"; case TPrimitiveType::DATETIME: return "datetime"; + case TPrimitiveType::DECIMALV2: case TPrimitiveType::DECIMAL: { std::stringstream stream; stream << "decimal("; diff --git a/be/src/exec/text_converter.hpp b/be/src/exec/text_converter.hpp index 3ddbc0a081205c..6d3001aa2e6f03 100644 --- a/be/src/exec/text_converter.hpp +++ b/be/src/exec/text_converter.hpp @@ -23,6 +23,7 @@ #include #include "runtime/decimal_value.h" +#include "runtime/decimalv2_value.h" #include "runtime/descriptors.h" #include "runtime/mem_pool.h" #include "runtime/runtime_state.h" @@ -30,6 +31,7 @@ #include "runtime/datetime_value.h" #include "runtime/tuple.h" #include "util/string_parser.hpp" +#include "util/types.h" #include "olap/utils.h" namespace doris { @@ -162,6 +164,19 @@ inline bool TextConverter::write_slot(const SlotDescriptor* slot_desc, break; } + case TYPE_DECIMALV2: { + DecimalV2Value decimal_slot; + + if (decimal_slot.parse_from_str(data, len)) { + parse_result = StringParser::PARSE_FAILURE; + } + + *reinterpret_cast(slot) = + *reinterpret_cast(&decimal_slot); + + break; + } + default: DCHECK(false) << "bad slot type: " << slot_desc->type(); break; diff --git a/be/src/exprs/CMakeLists.txt b/be/src/exprs/CMakeLists.txt index a814d013e35512..658fe8452c58ed 100644 --- a/be/src/exprs/CMakeLists.txt +++ b/be/src/exprs/CMakeLists.txt @@ -36,6 +36,7 @@ add_library(Exprs conditional_functions.cpp conditional_functions_ir.cpp decimal_operators.cpp + decimalv2_operators.cpp es_functions.cpp literal.cpp expr.cpp diff --git a/be/src/exprs/agg_fn_evaluator.cpp b/be/src/exprs/agg_fn_evaluator.cpp index ac71d6e3a31292..f46b0ae6ce7760 100755 --- a/be/src/exprs/agg_fn_evaluator.cpp +++ b/be/src/exprs/agg_fn_evaluator.cpp @@ -43,6 +43,7 @@ using doris_udf::LargeIntVal; using doris_udf::FloatVal; using doris_udf::DoubleVal; using doris_udf::DecimalVal; +using doris_udf::DecimalV2Val; using doris_udf::DateTimeVal; using doris_udf::StringVal; using doris_udf::AnyVal; @@ -344,6 +345,11 @@ inline void AggFnEvaluator::set_any_val( reinterpret_cast(dst)); return; + case TYPE_DECIMALV2: + reinterpret_cast(dst)->val + = reinterpret_cast(slot)->value; + return; + case TYPE_LARGEINT: memcpy(&reinterpret_cast(dst)->val, slot, sizeof(__int128)); return; @@ -413,6 +419,11 @@ inline void AggFnEvaluator::set_output_slot(const AnyVal* src, *reinterpret_cast(src)); return; + case TYPE_DECIMALV2: + *reinterpret_cast(slot) = + reinterpret_cast(src)->val; + return; + case TYPE_LARGEINT: { memcpy(slot, &reinterpret_cast(src)->val, sizeof(__int128)); return; @@ -578,6 +589,13 @@ bool AggFnEvaluator::count_distinct_data_filter(TupleRow* row, Tuple* dst) { break; } + case TYPE_DECIMALV2: { + DecimalV2Val* value = reinterpret_cast(_staging_input_vals[i]); + memcpy(begin, value, sizeof(DecimalV2Val)); + begin += sizeof(DecimalV2Val); + break; + } + case TYPE_CHAR: case TYPE_VARCHAR: case TYPE_HLL: { @@ -656,6 +674,14 @@ bool AggFnEvaluator::sum_distinct_data_filter(TupleRow* row, Tuple* dst) { return is_filter; } + case TYPE_DECIMALV2: { + const DecimalV2Val* value = reinterpret_cast(_staging_input_vals[0]); + DecimalV2Value temp_value = DecimalV2Value::from_decimal_val(*value); + is_filter = is_in_hybirdmap((void*) & (temp_value), dst, &is_add_buckets); + update_mem_trackers(is_filter, is_add_buckets, DECIMALV2_SIZE); + return is_filter; + } + case TYPE_LARGEINT: { const LargeIntVal* value = reinterpret_cast(_staging_input_vals[0]); is_filter = is_in_hybirdmap((void*) & (value->val), dst, &is_add_buckets); @@ -936,6 +962,13 @@ void AggFnEvaluator::serialize_or_finalize(FunctionContext* agg_fn_ctx, Tuple* s break; } + case TYPE_DECIMALV2: { + typedef DecimalV2Val(*Fn)(FunctionContext*, AnyVal*); + DecimalV2Val v = reinterpret_cast(fn)(agg_fn_ctx, _staging_intermediate_val); + set_output_slot(&v, dst_slot_desc, dst); + break; + } + default: DCHECK(false) << "NYI"; } diff --git a/be/src/exprs/agg_fn_evaluator.h b/be/src/exprs/agg_fn_evaluator.h index 3c459409e313cd..09a7fce72d1165 100755 --- a/be/src/exprs/agg_fn_evaluator.h +++ b/be/src/exprs/agg_fn_evaluator.h @@ -160,6 +160,7 @@ class AggFnEvaluator { static const size_t FLOAT_SIZE = sizeof(float); static const size_t DOUBLE_SIZE = sizeof(double); static const size_t DECIMAL_SIZE = sizeof(DecimalValue); + static const size_t DECIMALV2_SIZE = sizeof(DecimalV2Value); static const size_t TIME_DURATION_SIZE = sizeof(boost::posix_time::time_duration); static const size_t DATE_SIZE = sizeof(boost::gregorian::date); static const size_t LARGEINT_SIZE = sizeof(__int128); diff --git a/be/src/exprs/aggregate_functions.cpp b/be/src/exprs/aggregate_functions.cpp index 0a7821cf4e1c7c..8f49189cedc865 100644 --- a/be/src/exprs/aggregate_functions.cpp +++ b/be/src/exprs/aggregate_functions.cpp @@ -42,6 +42,7 @@ using doris_udf::LargeIntVal; using doris_udf::FloatVal; using doris_udf::DoubleVal; using doris_udf::DecimalVal; +using doris_udf::DecimalV2Val; using doris_udf::DateTimeVal; using doris_udf::StringVal; using doris_udf::AnyVal; @@ -69,6 +70,11 @@ void AggregateFunctions::init_zero(FunctionContext*, DecimalVal* dst) { dst->set_to_zero(); } +template<> +void AggregateFunctions::init_zero(FunctionContext*, DecimalV2Val* dst) { + dst->set_to_zero(); +} + template void AggregateFunctions::sum_remove(FunctionContext* ctx, const SRC_VAL& src, DST_VAL* dst) { @@ -109,6 +115,27 @@ void AggregateFunctions::sum_remove(FunctionContext* ctx, const DecimalVal& src, new_dst.to_decimal_val(dst); } +template<> +void AggregateFunctions::sum_remove(FunctionContext* ctx, const DecimalV2Val& src, + DecimalV2Val* dst) { + if (ctx->impl()->num_removes() >= ctx->impl()->num_updates()) { + *dst = DecimalV2Val::null(); + return; + } + if (src.is_null) { + return; + } + if (dst->is_null) { + init_zero(ctx, dst); + } + + DecimalV2Value new_src = DecimalV2Value::from_decimal_val(src); + DecimalV2Value new_dst = DecimalV2Value::from_decimal_val(*dst); + new_dst = new_dst - new_src; + new_dst.to_decimal_val(dst); +} + + StringVal AggregateFunctions::string_val_get_value( FunctionContext* ctx, const StringVal& src) { if (src.is_null) { @@ -163,6 +190,11 @@ struct DecimalAvgState { int64_t count; }; +struct DecimalV2AvgState { + DecimalV2Val sum; + int64_t count; +}; + void AggregateFunctions::avg_init(FunctionContext* ctx, StringVal* dst) { dst->is_null = false; dst->len = sizeof(AvgState); @@ -180,6 +212,17 @@ void AggregateFunctions::decimal_avg_init(FunctionContext* ctx, StringVal* dst) avg->sum.set_to_zero(); } +void AggregateFunctions::decimalv2_avg_init(FunctionContext* ctx, StringVal* dst) { + dst->is_null = false; + dst->len = sizeof(DecimalV2AvgState); + dst->ptr = ctx->allocate(dst->len); + // memset(dst->ptr, 0, sizeof(DecimalAvgState)); + DecimalV2AvgState* avg = reinterpret_cast(dst->ptr); + avg->count = 0; + avg->sum.set_to_zero(); +} + + template void AggregateFunctions::avg_update(FunctionContext* ctx, const T& src, StringVal* dst) { if (src.is_null) { @@ -210,6 +253,24 @@ void AggregateFunctions::decimal_avg_update(FunctionContext* ctx, ++avg->count; } +void AggregateFunctions::decimalv2_avg_update(FunctionContext* ctx, + const DecimalV2Val& src, + StringVal* dst) { + if (src.is_null) { + return; + } + DCHECK(dst->ptr != NULL); + DCHECK_EQ(sizeof(DecimalV2AvgState), dst->len); + DecimalV2AvgState* avg = reinterpret_cast(dst->ptr); + + DecimalV2Value v1 = DecimalV2Value::from_decimal_val(avg->sum); + DecimalV2Value v2 = DecimalV2Value::from_decimal_val(src); + DecimalV2Value v = v1 + v2; + v.to_decimal_val(&avg->sum); + + ++avg->count; +} + template void AggregateFunctions::avg_remove(FunctionContext* ctx, const T& src, StringVal* dst) { // Remove doesn't need to explicitly check the number of calls to Update() or Remove() @@ -246,6 +307,27 @@ void AggregateFunctions::decimal_avg_remove(doris_udf::FunctionContext* ctx, DCHECK_GE(avg->count, 0); } +void AggregateFunctions::decimalv2_avg_remove(doris_udf::FunctionContext* ctx, + const DecimalV2Val& src, + StringVal* dst) { + // Remove doesn't need to explicitly check the number of calls to Update() or Remove() + // because Finalize() returns NULL if count is 0. + if (src.is_null) { + return; + } + DCHECK(dst->ptr != NULL); + DCHECK_EQ(sizeof(DecimalV2AvgState), dst->len); + DecimalV2AvgState* avg = reinterpret_cast(dst->ptr); + + DecimalV2Value v1 = DecimalV2Value::from_decimal_val(avg->sum); + DecimalV2Value v2 = DecimalV2Value::from_decimal_val(src); + DecimalV2Value v = v1 - v2; + v.to_decimal_val(&avg->sum); + + --avg->count; + DCHECK_GE(avg->count, 0); +} + void AggregateFunctions::avg_merge(FunctionContext* ctx, const StringVal& src, StringVal* dst) { const AvgState* src_struct = reinterpret_cast(src.ptr); @@ -270,6 +352,20 @@ void AggregateFunctions::decimal_avg_merge(FunctionContext* ctx, const StringVal dst_struct->count += src_struct->count; } +void AggregateFunctions::decimalv2_avg_merge(FunctionContext* ctx, const StringVal& src, + StringVal* dst) { + const DecimalV2AvgState* src_struct = reinterpret_cast(src.ptr); + DCHECK(dst->ptr != NULL); + DCHECK_EQ(sizeof(DecimalV2AvgState), dst->len); + DecimalV2AvgState* dst_struct = reinterpret_cast(dst->ptr); + + DecimalV2Value v1 = DecimalV2Value::from_decimal_val(dst_struct->sum); + DecimalV2Value v2 = DecimalV2Value::from_decimal_val(src_struct->sum); + DecimalV2Value v = v1 + v2; + v.to_decimal_val(&dst_struct->sum); + dst_struct->count += src_struct->count; +} + DoubleVal AggregateFunctions::avg_get_value(FunctionContext* ctx, const StringVal& src) { AvgState* val_struct = reinterpret_cast(src.ptr); if (val_struct->count == 0) { @@ -291,6 +387,19 @@ DecimalVal AggregateFunctions::decimal_avg_get_value(FunctionContext* ctx, const return res; } +DecimalV2Val AggregateFunctions::decimalv2_avg_get_value(FunctionContext* ctx, const StringVal& src) { + DecimalV2AvgState* val_struct = reinterpret_cast(src.ptr); + if (val_struct->count == 0) { + return DecimalV2Val::null(); + } + DecimalV2Value v1 = DecimalV2Value::from_decimal_val(val_struct->sum); + DecimalV2Value v = v1 / DecimalV2Value(val_struct->count, 0); + DecimalV2Val res; + v.to_decimal_val(&res); + + return res; +} + DoubleVal AggregateFunctions::avg_finalize(FunctionContext* ctx, const StringVal& src) { if (src.is_null) { return DoubleVal::null(); @@ -309,6 +418,15 @@ DecimalVal AggregateFunctions::decimal_avg_finalize(FunctionContext* ctx, const return result; } +DecimalV2Val AggregateFunctions::decimalv2_avg_finalize(FunctionContext* ctx, const StringVal& src) { + if (src.is_null) { + return DecimalV2Val::null(); + } + DecimalV2Val result = decimalv2_avg_get_value(ctx, src); + ctx->free(src.ptr); + return result; +} + void AggregateFunctions::timestamp_avg_update(FunctionContext* ctx, const DateTimeVal& src, StringVal* dst) { if (src.is_null) { @@ -399,6 +517,23 @@ void AggregateFunctions::sum(FunctionContext* ctx, const DecimalVal& src, Decima new_dst.to_decimal_val(dst); } +template<> +void AggregateFunctions::sum(FunctionContext* ctx, const DecimalV2Val& src, DecimalV2Val* dst) { + if (src.is_null) { + return; + } + + if (dst->is_null) { + dst->is_null = false; + dst->set_to_zero(); + } + + DecimalV2Value new_src = DecimalV2Value::from_decimal_val(src); + DecimalV2Value new_dst = DecimalV2Value::from_decimal_val(*dst); + new_dst = new_dst + new_src; + new_dst.to_decimal_val(dst); +} + template<> void AggregateFunctions::sum(FunctionContext* ctx, const LargeIntVal& src, LargeIntVal* dst) { if (src.is_null) { @@ -453,6 +588,25 @@ void AggregateFunctions::min(FunctionContext*, const DecimalVal& src, DecimalVal } } +template<> +void AggregateFunctions::min(FunctionContext*, const DecimalV2Val& src, DecimalV2Val* dst) { + if (src.is_null) { + return; + } + + if (dst->is_null) { + *dst = src; + } else { + DecimalV2Value new_src = DecimalV2Value::from_decimal_val(src); + DecimalV2Value new_dst = DecimalV2Value::from_decimal_val(*dst); + + if (new_src < new_dst) { + *dst = src; + } + } +} + + template<> void AggregateFunctions::min(FunctionContext*, const LargeIntVal& src, LargeIntVal* dst) { if (src.is_null) { @@ -487,6 +641,25 @@ void AggregateFunctions::max(FunctionContext*, const DecimalVal& src, DecimalVal } } +template<> +void AggregateFunctions::max(FunctionContext*, const DecimalV2Val& src, DecimalV2Val* dst) { + if (src.is_null) { + return; + } + + if (dst->is_null) { + *dst = src; + } else { + DecimalV2Value new_src = DecimalV2Value::from_decimal_val(src); + DecimalV2Value new_dst = DecimalV2Value::from_decimal_val(*dst); + + if (new_src > new_dst) { + *dst = src; + } + } +} + + template<> void AggregateFunctions::max(FunctionContext*, const LargeIntVal& src, LargeIntVal* dst) { if (src.is_null) { @@ -1397,6 +1570,90 @@ class MultiDistinctDecimalState { FunctionContext::Type _type; }; +class MultiDistinctDecimalV2State { +public: + + static void create(StringVal* dst) { + dst->is_null = false; + const int state_size = sizeof(MultiDistinctDecimalV2State); + MultiDistinctDecimalV2State* state = new MultiDistinctDecimalV2State(); + state->_type = FunctionContext::TYPE_DECIMALV2; + dst->len = state_size; + dst->ptr = (uint8_t*)state; + } + + static void destory(const StringVal& dst) { + delete (MultiDistinctDecimalV2State*)dst.ptr; + } + + void update(DecimalV2Val& t) { + _set.insert(DecimalV2Value::from_decimal_val(t)); + } + + // type:one byte value:sizeof(T) + StringVal serialize(FunctionContext* ctx) { + const int serialized_set_length = sizeof(uint8_t) + + DECIMAL_BYTE_SIZE * _set.size(); + StringVal result(ctx, serialized_set_length); + uint8_t* writer = result.ptr; + *writer = (uint8_t)_type; + writer++; + // for int_length and frac_length, uint8_t will not overflow. + for (auto& value : _set) { + __int128 v = value.value(); + memcpy(writer, &v, DECIMAL_BYTE_SIZE); + writer += DECIMAL_BYTE_SIZE; + } + return result; + } + + void unserialize(StringVal& src) { + const uint8_t* reader = src.ptr; + // type + _type = (FunctionContext::Type)*reader; + reader++; + const uint8_t* end = src.ptr + src.len; + // value + while (reader < end) { + __int128 v = 0; + memcpy(&v, reader, DECIMAL_BYTE_SIZE); + DecimalV2Value value(v); + reader += DECIMAL_BYTE_SIZE; + _set.insert(value); + } + } + + FunctionContext::Type set_type() { + return _type; + } + + // merge set + void merge(MultiDistinctDecimalV2State& state) { + _set.insert(state._set.begin(), state._set.end()); + } + + // count + BigIntVal count_finalize() { + return BigIntVal(_set.size()); + } + + DecimalV2Val sum_finalize() { + DecimalV2Value sum; + for (auto& value : _set) { + sum += value; + } + DecimalV2Val result; + sum.to_decimal_val(&result); + return result; + } + +private: + const int DECIMAL_BYTE_SIZE = 16; + + std::unordered_set _set; + FunctionContext::Type _type; +}; + // multi distinct state for date // serialize order type:packed_time:type:packed_time:type ... class MultiDistinctCountDateState { @@ -1503,6 +1760,10 @@ void AggregateFunctions::count_distinct_string_init(FunctionContext* ctx, String void AggregateFunctions::count_or_sum_distinct_decimal_init(FunctionContext* ctx, StringVal* dst) { MultiDistinctDecimalState::create(dst); } + +void AggregateFunctions::count_or_sum_distinct_decimalv2_init(FunctionContext* ctx, StringVal* dst) { + MultiDistinctDecimalV2State::create(dst); +} void AggregateFunctions::count_distinct_date_init(FunctionContext* ctx, StringVal* dst) { MultiDistinctCountDateState::create(dst); @@ -1533,7 +1794,15 @@ void AggregateFunctions::count_or_sum_distinct_decimal_update(FunctionContext* c MultiDistinctDecimalState* state = reinterpret_cast(dst->ptr); state->update(src); } - + +void AggregateFunctions::count_or_sum_distinct_decimalv2_update(FunctionContext* ctx, DecimalV2Val& src, + StringVal* dst) { + DCHECK(!dst->is_null); + if (src.is_null) return; + MultiDistinctDecimalV2State* state = reinterpret_cast(dst->ptr); + state->update(src); +} + void AggregateFunctions::count_distinct_date_update(FunctionContext* ctx, DateTimeVal& src, StringVal* dst) { DCHECK(!dst->is_null); @@ -1588,6 +1857,21 @@ void AggregateFunctions::count_or_sum_distinct_decimal_merge(FunctionContext* ct dst_state->merge(*src_state); MultiDistinctDecimalState::destory(src_state_val); } + +void AggregateFunctions::count_or_sum_distinct_decimalv2_merge(FunctionContext* ctx, StringVal& src, + StringVal* dst) { + DCHECK(!dst->is_null); + DCHECK(!src.is_null); + MultiDistinctDecimalV2State* dst_state = reinterpret_cast(dst->ptr); + // unserialize src + StringVal src_state_val; + MultiDistinctDecimalV2State::create(&src_state_val); + MultiDistinctDecimalV2State* src_state = reinterpret_cast(src_state_val.ptr); + src_state->unserialize(src); + DCHECK(dst_state->set_type() == src_state->set_type()); + dst_state->merge(*src_state); + MultiDistinctDecimalV2State::destory(src_state_val); +} void AggregateFunctions::count_distinct_date_merge(FunctionContext* ctx, StringVal& src, StringVal* dst) { @@ -1632,6 +1916,15 @@ StringVal AggregateFunctions::count_or_sum_distinct_decimal_serialize(FunctionCo return result; } +StringVal AggregateFunctions::count_or_sum_distinct_decimalv2_serialize(FunctionContext* ctx, const StringVal& state_sv) { + DCHECK(!state_sv.is_null); + MultiDistinctDecimalV2State* state = reinterpret_cast(state_sv.ptr); + StringVal result = state->serialize(ctx); + // release original object + MultiDistinctDecimalV2State::destory(state_sv); + return result; +} + StringVal AggregateFunctions::count_distinct_date_serialize(FunctionContext* ctx, const StringVal& state_sv) { DCHECK(!state_sv.is_null); MultiDistinctCountDateState* state = reinterpret_cast(state_sv.ptr); @@ -1692,6 +1985,14 @@ BigIntVal AggregateFunctions::count_distinct_decimal_finalize(FunctionContext* c MultiDistinctDecimalState::destory(state_sv); return result; } + +BigIntVal AggregateFunctions::count_distinct_decimalv2_finalize(FunctionContext* ctx, const StringVal& state_sv) { + DCHECK(!state_sv.is_null); + MultiDistinctDecimalV2State* state = reinterpret_cast(state_sv.ptr); + BigIntVal result = state->count_finalize(); + MultiDistinctDecimalV2State::destory(state_sv); + return result; +} DecimalVal AggregateFunctions::sum_distinct_decimal_finalize(FunctionContext* ctx, const StringVal& state_sv) { DCHECK(!state_sv.is_null); @@ -1700,6 +2001,14 @@ DecimalVal AggregateFunctions::sum_distinct_decimal_finalize(FunctionContext* ct MultiDistinctDecimalState::destory(state_sv); return result; } + +DecimalV2Val AggregateFunctions::sum_distinct_decimalv2_finalize(FunctionContext* ctx, const StringVal& state_sv) { + DCHECK(!state_sv.is_null); + MultiDistinctDecimalV2State* state = reinterpret_cast(state_sv.ptr); + DecimalV2Val result = state->sum_finalize(); + MultiDistinctDecimalV2State::destory(state_sv); + return result; +} BigIntVal AggregateFunctions::count_distinct_date_finalize(FunctionContext* ctx, const StringVal& state_sv) { DCHECK(!state_sv.is_null); @@ -2016,6 +2325,8 @@ template void AggregateFunctions::sum_remove( FunctionContext*, const DoubleVal& src, DoubleVal* dst); template void AggregateFunctions::sum_remove( FunctionContext*, const DecimalVal& src, DecimalVal* dst); +template void AggregateFunctions::sum_remove( + FunctionContext*, const DecimalV2Val& src, DecimalV2Val* dst); template void AggregateFunctions::sum_remove( FunctionContext*, const LargeIntVal& src, LargeIntVal* dst); @@ -2162,6 +2473,8 @@ template void AggregateFunctions::hll_update( FunctionContext*, const LargeIntVal&, StringVal*); template void AggregateFunctions::hll_update( FunctionContext*, const DecimalVal&, StringVal*); +template void AggregateFunctions::hll_update( + FunctionContext*, const DecimalV2Val&, StringVal*); template void AggregateFunctions::count_or_sum_distinct_numeric_init( FunctionContext* ctx, StringVal* dst); @@ -2306,13 +2619,17 @@ template void AggregateFunctions::first_val_rewrite_update( FunctionContext*, const DateTimeVal& src, const BigIntVal&, DateTimeVal* dst); template void AggregateFunctions::first_val_rewrite_update( FunctionContext*, const DecimalVal& src, const BigIntVal&, DecimalVal* dst); - +template void AggregateFunctions::first_val_rewrite_update( + FunctionContext*, const DecimalV2Val& src, const BigIntVal&, DecimalV2Val* dst); //template void AggregateFunctions::FirstValUpdate( // doris_udf::FunctionContext*, impala::StringValue const&, impala::StringValue*); template void AggregateFunctions::first_val_update( doris_udf::FunctionContext*, doris_udf::DecimalVal const&, doris_udf::DecimalVal*); +template void AggregateFunctions::first_val_update( + doris_udf::FunctionContext*, doris_udf::DecimalV2Val const&, doris_udf::DecimalV2Val*); + template void AggregateFunctions::last_val_update( FunctionContext*, const BooleanVal& src, BooleanVal* dst); template void AggregateFunctions::last_val_update( @@ -2333,6 +2650,8 @@ template void AggregateFunctions::last_val_update( FunctionContext*, const DateTimeVal& src, DateTimeVal* dst); template void AggregateFunctions::last_val_update( FunctionContext*, const DecimalVal& src, DecimalVal* dst); +template void AggregateFunctions::last_val_update( + FunctionContext*, const DecimalV2Val& src, DecimalV2Val* dst); template void AggregateFunctions::last_val_remove( FunctionContext*, const BooleanVal& src, BooleanVal* dst); @@ -2354,6 +2673,8 @@ template void AggregateFunctions::last_val_remove( FunctionContext*, const DateTimeVal& src, DateTimeVal* dst); template void AggregateFunctions::last_val_remove( FunctionContext*, const DecimalVal& src, DecimalVal* dst); +template void AggregateFunctions::last_val_remove( + FunctionContext*, const DecimalV2Val& src, DecimalV2Val* dst); template void AggregateFunctions::offset_fn_init( FunctionContext*, BooleanVal*); @@ -2375,6 +2696,8 @@ template void AggregateFunctions::offset_fn_init( FunctionContext*, DateTimeVal*); template void AggregateFunctions::offset_fn_init( FunctionContext*, DecimalVal*); +template void AggregateFunctions::offset_fn_init( + FunctionContext*, DecimalV2Val*); template void AggregateFunctions::offset_fn_update( FunctionContext*, const BooleanVal& src, const BigIntVal&, const BooleanVal&, @@ -2405,5 +2728,7 @@ template void AggregateFunctions::offset_fn_update( template void AggregateFunctions::offset_fn_update( FunctionContext*, const DecimalVal& src, const BigIntVal&, const DecimalVal&, DecimalVal* dst); - +template void AggregateFunctions::offset_fn_update( + FunctionContext*, const DecimalV2Val& src, const BigIntVal&, const DecimalV2Val&, + DecimalV2Val* dst); } diff --git a/be/src/exprs/aggregate_functions.h b/be/src/exprs/aggregate_functions.h index 44c5057bb05b6e..1b9b86194096e3 100644 --- a/be/src/exprs/aggregate_functions.h +++ b/be/src/exprs/aggregate_functions.h @@ -94,14 +94,23 @@ dst); // Avg for decimals. static void decimal_avg_init(doris_udf::FunctionContext* ctx, doris_udf::StringVal* dst); + static void decimalv2_avg_init(doris_udf::FunctionContext* ctx, doris_udf::StringVal* dst); static void decimal_avg_update(doris_udf::FunctionContext* ctx, const doris_udf::DecimalVal& src, doris_udf::StringVal* dst); + static void decimalv2_avg_update(doris_udf::FunctionContext* ctx, + const doris_udf::DecimalV2Val& src, + doris_udf::StringVal* dst); static void decimal_avg_merge(FunctionContext* ctx, const doris_udf::StringVal& src, doris_udf::StringVal* dst); + static void decimalv2_avg_merge(FunctionContext* ctx, const doris_udf::StringVal& src, + doris_udf::StringVal* dst); static void decimal_avg_remove(doris_udf::FunctionContext* ctx, const doris_udf::DecimalVal& src, doris_udf::StringVal* dst); + static void decimalv2_avg_remove(doris_udf::FunctionContext* ctx, + const doris_udf::DecimalV2Val& src, + doris_udf::StringVal* dst); // static void decimal_avg_add_or_remove(doris_udf::FunctionContext* ctx, // const doris_udf::DecimalVal& src, @@ -113,9 +122,12 @@ dst); // } static doris_udf::DecimalVal decimal_avg_get_value(doris_udf::FunctionContext* ctx, const doris_udf::StringVal& val); + static doris_udf::DecimalV2Val decimalv2_avg_get_value(doris_udf::FunctionContext* ctx, + const doris_udf::StringVal& val); static doris_udf::DecimalVal decimal_avg_finalize(doris_udf::FunctionContext* ctx, const doris_udf::StringVal& val); - + static doris_udf::DecimalV2Val decimalv2_avg_finalize(doris_udf::FunctionContext* ctx, + const doris_udf::StringVal& val); // SumUpdate, SumMerge template static void sum(doris_udf::FunctionContext*, const SRC_VAL& src, DST_VAL* dst); @@ -206,11 +218,17 @@ dst); // count distinct in multi distinct for decimal static void count_or_sum_distinct_decimal_init(doris_udf::FunctionContext* ctx, doris_udf::StringVal* dst); + static void count_or_sum_distinct_decimalv2_init(doris_udf::FunctionContext* ctx, doris_udf::StringVal* dst); static void count_or_sum_distinct_decimal_update(FunctionContext* ctx, DecimalVal& src, StringVal* dst); + static void count_or_sum_distinct_decimalv2_update(FunctionContext* ctx, DecimalV2Val& src, StringVal* dst); static void count_or_sum_distinct_decimal_merge(FunctionContext* ctx, StringVal& src, StringVal* dst); + static void count_or_sum_distinct_decimalv2_merge(FunctionContext* ctx, StringVal& src, StringVal* dst); static StringVal count_or_sum_distinct_decimal_serialize(FunctionContext* ctx, const StringVal& state_sv); + static StringVal count_or_sum_distinct_decimalv2_serialize(FunctionContext* ctx, const StringVal& state_sv); static BigIntVal count_distinct_decimal_finalize(FunctionContext* ctx, const StringVal& state_sv); + static BigIntVal count_distinct_decimalv2_finalize(FunctionContext* ctx, const StringVal& state_sv); static DecimalVal sum_distinct_decimal_finalize(FunctionContext* ctx, const StringVal& state_sv); + static DecimalV2Val sum_distinct_decimalv2_finalize(FunctionContext* ctx, const StringVal& state_sv); // count distinct in multi disticnt for Date static void count_distinct_date_init(doris_udf::FunctionContext* ctx, doris_udf::StringVal* dst); diff --git a/be/src/exprs/anyval_util.cpp b/be/src/exprs/anyval_util.cpp index 515c7c2f08cbeb..19428fe9c6d812 100755 --- a/be/src/exprs/anyval_util.cpp +++ b/be/src/exprs/anyval_util.cpp @@ -31,6 +31,7 @@ using doris_udf::LargeIntVal; using doris_udf::FloatVal; using doris_udf::DoubleVal; using doris_udf::DecimalVal; +using doris_udf::DecimalV2Val; using doris_udf::DateTimeVal; using doris_udf::StringVal; using doris_udf::AnyVal; @@ -86,6 +87,9 @@ AnyVal* create_any_val(ObjectPool* pool, const TypeDescriptor& type) { case TYPE_DECIMAL: return pool->add(new DecimalVal); + case TYPE_DECIMALV2: + return pool->add(new DecimalV2Val); + case TYPE_DATE: return pool->add(new DateTimeVal); @@ -147,6 +151,11 @@ FunctionContext::TypeDesc AnyValUtil::column_type_to_type_desc(const TypeDescrip // out.precision = type.precision; // out.scale = type.scale; break; + case TYPE_DECIMALV2: + out.type = FunctionContext::TYPE_DECIMALV2; + // out.precision = type.precision; + // out.scale = type.scale; + break; case TYPE_NULL: out.type = FunctionContext::TYPE_NULL; break; diff --git a/be/src/exprs/anyval_util.h b/be/src/exprs/anyval_util.h index 92917b3db5df0d..b47cc65948f72a 100755 --- a/be/src/exprs/anyval_util.h +++ b/be/src/exprs/anyval_util.h @@ -22,6 +22,7 @@ #include "runtime/primitive_type.h" #include "udf/udf.h" #include "util/hash_util.hpp" +#include "util/types.h" #include "common/status.h" namespace doris { @@ -73,6 +74,10 @@ class AnyValUtil { return tv.hash(seed); } + static uint32_t hash(const doris_udf::DecimalV2Val& v, int seed) { + return HashUtil::hash(&v.val, 16, seed); + } + static uint32_t hash(const doris_udf::LargeIntVal& v, int seed) { return HashUtil::hash(&v.val, 8, seed); } @@ -121,6 +126,10 @@ class AnyValUtil { return HashUtil::fnv_hash64(&tv, sizeof(DecimalValue), seed); } + static uint64_t hash64(const doris_udf::DecimalV2Val& v, int64_t seed) { + return HashUtil::fnv_hash64(&v.val, 16, seed); + } + static uint64_t hash64(const doris_udf::LargeIntVal& v, int64_t seed) { return HashUtil::fnv_hash64(&v.val, 8, seed); } @@ -167,6 +176,10 @@ class AnyValUtil { return HashUtil::murmur_hash64A(&tv, sizeof(DecimalValue), seed); } + static uint64_t hash64_murmur(const doris_udf::DecimalV2Val& v, int64_t seed) { + return HashUtil::murmur_hash64A(&v.val, 16, seed); + } + static uint64_t hash64_murmur(const doris_udf::LargeIntVal& v, int64_t seed) { return HashUtil::murmur_hash64A(&v.val, 8, seed); } @@ -201,6 +214,8 @@ class AnyValUtil { return doris_udf::FunctionContext::TYPE_STRING; case TYPE_DECIMAL: return doris_udf::FunctionContext::TYPE_DECIMAL; + case TYPE_DECIMALV2: + return doris_udf::FunctionContext::TYPE_DECIMALV2; break; default: DCHECK(false) << "Unknown type: " << type; @@ -246,6 +261,9 @@ class AnyValUtil { case TYPE_DECIMAL: return sizeof(doris_udf::DecimalVal); + case TYPE_DECIMALV2: + return sizeof(doris_udf::DecimalV2Val); + default: DCHECK(false) << t; return 0; @@ -271,6 +289,7 @@ class AnyValUtil { case TYPE_DATE: return alignof(DateTimeVal); case TYPE_DECIMAL: return alignof(DecimalVal); + case TYPE_DECIMALV2: return alignof(DecimalV2Val); default: DCHECK(false) << t; return 0; @@ -364,6 +383,10 @@ class AnyValUtil { reinterpret_cast(slot)->to_decimal_val( reinterpret_cast(dst)); return; + case TYPE_DECIMALV2: + reinterpret_cast(dst)->val = + reinterpret_cast(slot)->value; + return; case TYPE_DATE: reinterpret_cast(slot)->to_datetime_val( reinterpret_cast(dst)); @@ -437,6 +460,13 @@ inline bool AnyValUtil::equals_intenal(const DecimalVal& x, const DecimalVal& y) return x == y; } +template<> +inline bool AnyValUtil::equals_intenal(const DecimalV2Val& x, const DecimalV2Val& y) { + DCHECK(!x.is_null); + DCHECK(!y.is_null); + return x == y; +} + // Creates the corresponding AnyVal subclass for type. The object is added to the pool. doris_udf::AnyVal* create_any_val(ObjectPool* pool, const TypeDescriptor& type); diff --git a/be/src/exprs/binary_predicate.cpp b/be/src/exprs/binary_predicate.cpp index d371126c301c52..f01c6b64e670cb 100644 --- a/be/src/exprs/binary_predicate.cpp +++ b/be/src/exprs/binary_predicate.cpp @@ -27,6 +27,7 @@ #include "runtime/string_value.h" #include "runtime/datetime_value.h" #include "runtime/decimal_value.h" +#include "runtime/decimalv2_value.h" using llvm::BasicBlock; using llvm::CmpInst; @@ -67,6 +68,8 @@ Expr* BinaryPredicate::from_thrift(const TExprNode& node) { return new EqDateTimeValPred(node); case TPrimitiveType::DECIMAL: return new EqDecimalValPred(node); + case TPrimitiveType::DECIMALV2: + return new EqDecimalV2ValPred(node); default: return NULL; } @@ -97,6 +100,8 @@ Expr* BinaryPredicate::from_thrift(const TExprNode& node) { return new NeDateTimeValPred(node); case TPrimitiveType::DECIMAL: return new NeDecimalValPred(node); + case TPrimitiveType::DECIMALV2: + return new NeDecimalV2ValPred(node); default: return NULL; } @@ -127,6 +132,8 @@ Expr* BinaryPredicate::from_thrift(const TExprNode& node) { return new LtDateTimeValPred(node); case TPrimitiveType::DECIMAL: return new LtDecimalValPred(node); + case TPrimitiveType::DECIMALV2: + return new LtDecimalV2ValPred(node); default: return NULL; } @@ -157,6 +164,8 @@ Expr* BinaryPredicate::from_thrift(const TExprNode& node) { return new LeDateTimeValPred(node); case TPrimitiveType::DECIMAL: return new LeDecimalValPred(node); + case TPrimitiveType::DECIMALV2: + return new LeDecimalV2ValPred(node); default: return NULL; } @@ -187,6 +196,8 @@ Expr* BinaryPredicate::from_thrift(const TExprNode& node) { return new GtDateTimeValPred(node); case TPrimitiveType::DECIMAL: return new GtDecimalValPred(node); + case TPrimitiveType::DECIMALV2: + return new GtDecimalV2ValPred(node); default: return NULL; } @@ -217,6 +228,8 @@ Expr* BinaryPredicate::from_thrift(const TExprNode& node) { return new GeDateTimeValPred(node); case TPrimitiveType::DECIMAL: return new GeDecimalValPred(node); + case TPrimitiveType::DECIMALV2: + return new GeDecimalV2ValPred(node); default: return NULL; } @@ -418,6 +431,7 @@ BINARY_PRED_FLOAT_FNS(DoubleVal, get_double_val); COMPLICATE_BINARY_PRED_FN(Ge##TYPE##Pred, TYPE, FN, DORIS_TYPE, FROM_FUNC, >=) COMPLICATE_BINARY_PRED_FNS(DecimalVal, get_decimal_val, DecimalValue, from_decimal_val) +COMPLICATE_BINARY_PRED_FNS(DecimalV2Val, get_decimalv2_val, DecimalV2Value, from_decimal_val) #define DATETIME_BINARY_PRED_FN(CLASS, OP, LLVM_PRED) \ BooleanVal CLASS::get_boolean_val(ExprContext* ctx, TupleRow* row) { \ diff --git a/be/src/exprs/binary_predicate.h b/be/src/exprs/binary_predicate.h index 463ca9b92d889f..6e18b5f2bbde65 100644 --- a/be/src/exprs/binary_predicate.h +++ b/be/src/exprs/binary_predicate.h @@ -76,5 +76,6 @@ BIN_PRED_CLASSES_DEFINE(DoubleVal) BIN_PRED_CLASSES_DEFINE(StringVal) BIN_PRED_CLASSES_DEFINE(DateTimeVal) BIN_PRED_CLASSES_DEFINE(DecimalVal) +BIN_PRED_CLASSES_DEFINE(DecimalV2Val) } #endif diff --git a/be/src/exprs/case_expr.cpp b/be/src/exprs/case_expr.cpp index dd3d9438069e21..c52c2b2f751305 100644 --- a/be/src/exprs/case_expr.cpp +++ b/be/src/exprs/case_expr.cpp @@ -332,6 +332,9 @@ void CaseExpr::get_child_val(int child_idx, ExprContext* ctx, TupleRow* row, Any case TYPE_DECIMAL: *reinterpret_cast(dst) = _children[child_idx]->get_decimal_val(ctx, row); break; + case TYPE_DECIMALV2: + *reinterpret_cast(dst) = _children[child_idx]->get_decimalv2_val(ctx, row); + break; case TYPE_LARGEINT: *reinterpret_cast(dst) = _children[child_idx]->get_large_int_val(ctx, row); break; @@ -375,6 +378,9 @@ bool CaseExpr::any_val_eq(const TypeDescriptor& type, const AnyVal* v1, const An case TYPE_DECIMAL: return AnyValUtil::equals(type, *reinterpret_cast(v1), *reinterpret_cast(v2)); + case TYPE_DECIMALV2: + return AnyValUtil::equals(type, *reinterpret_cast(v1), + *reinterpret_cast(v2)); case TYPE_LARGEINT: return AnyValUtil::equals(type, *reinterpret_cast(v1), *reinterpret_cast(v2)); @@ -438,6 +444,7 @@ CASE_COMPUTE_FN_WAPPER(DoubleVal, double_val) CASE_COMPUTE_FN_WAPPER(StringVal, string_val) CASE_COMPUTE_FN_WAPPER(DateTimeVal, datetime_val) CASE_COMPUTE_FN_WAPPER(DecimalVal, decimal_val) +CASE_COMPUTE_FN_WAPPER(DecimalV2Val, decimalv2_val) } diff --git a/be/src/exprs/case_expr.h b/be/src/exprs/case_expr.h index 7b848a96f42971..470bb442903e43 100644 --- a/be/src/exprs/case_expr.h +++ b/be/src/exprs/case_expr.h @@ -47,12 +47,14 @@ class CaseExpr: public Expr { virtual StringVal get_string_val(ExprContext* ctx, TupleRow* row); virtual DateTimeVal get_datetime_val(ExprContext* ctx, TupleRow* row); virtual DecimalVal get_decimal_val(ExprContext* ctx, TupleRow* row); + virtual DecimalV2Val get_decimalv2_val(ExprContext* ctx, TupleRow* row); protected: friend class Expr; friend class ComputeFunctions; friend class ConditionalFunctions; friend class DecimalOperators; + friend class DecimalV2Operators; CaseExpr(const TExprNode& node); virtual Status prepare( diff --git a/be/src/exprs/conditional_functions.h b/be/src/exprs/conditional_functions.h index d14ba7b92a2941..3d279001fbf73a 100644 --- a/be/src/exprs/conditional_functions.h +++ b/be/src/exprs/conditional_functions.h @@ -54,6 +54,7 @@ class IfNullExpr : public Expr { virtual StringVal get_string_val(ExprContext* context, TupleRow* row); virtual DateTimeVal get_datetime_val(ExprContext* context, TupleRow* row); virtual DecimalVal get_decimal_val(ExprContext* context, TupleRow* row); + virtual DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow* row); virtual LargeIntVal get_large_int_val(ExprContext* context, TupleRow* row); virtual Status get_codegend_compute_fn(RuntimeState* state, llvm::Function** fn); @@ -111,6 +112,7 @@ class IfExpr : public Expr { virtual StringVal get_string_val(ExprContext* context, TupleRow* row); virtual DateTimeVal get_datetime_val(ExprContext* context, TupleRow* row); virtual DecimalVal get_decimal_val(ExprContext* context, TupleRow* row); + virtual DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow* row); virtual LargeIntVal get_large_int_val(ExprContext* context, TupleRow* row); virtual Status get_codegend_compute_fn(RuntimeState* state, llvm::Function** fn); @@ -140,6 +142,7 @@ class CoalesceExpr : public Expr { virtual StringVal get_string_val(ExprContext* context, TupleRow* row); virtual DateTimeVal get_datetime_val(ExprContext* context, TupleRow* row); virtual DecimalVal get_decimal_val(ExprContext* context, TupleRow* row); + virtual DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow* row); virtual LargeIntVal get_large_int_val(ExprContext* context, TupleRow* row); virtual Status get_codegend_compute_fn(RuntimeState* state, llvm::Function** fn); diff --git a/be/src/exprs/conditional_functions_ir.cpp b/be/src/exprs/conditional_functions_ir.cpp index 912d8248208c63..1dbb718b13aeb5 100644 --- a/be/src/exprs/conditional_functions_ir.cpp +++ b/be/src/exprs/conditional_functions_ir.cpp @@ -39,6 +39,7 @@ IF_NULL_COMPUTE_FUNCTION(DoubleVal, double_val); IF_NULL_COMPUTE_FUNCTION(StringVal, string_val); IF_NULL_COMPUTE_FUNCTION(DateTimeVal, datetime_val); IF_NULL_COMPUTE_FUNCTION(DecimalVal, decimal_val); +IF_NULL_COMPUTE_FUNCTION(DecimalV2Val, decimalv2_val); IF_NULL_COMPUTE_FUNCTION(LargeIntVal, large_int_val); #define NULL_IF_COMPUTE_FUNCTION(TYPE, type_name) \ @@ -91,6 +92,7 @@ IF_COMPUTE_FUNCTION(DoubleVal, double_val); IF_COMPUTE_FUNCTION(StringVal, string_val); IF_COMPUTE_FUNCTION(DateTimeVal, datetime_val); IF_COMPUTE_FUNCTION(DecimalVal, decimal_val); +IF_COMPUTE_FUNCTION(DecimalV2Val, decimalv2_val); IF_COMPUTE_FUNCTION(LargeIntVal, large_int_val); #define COALESCE_COMPUTE_FUNCTION(type, type_name) \ @@ -113,6 +115,7 @@ COALESCE_COMPUTE_FUNCTION(DoubleVal, double_val); COALESCE_COMPUTE_FUNCTION(StringVal, string_val); COALESCE_COMPUTE_FUNCTION(DateTimeVal, datetime_val); COALESCE_COMPUTE_FUNCTION(DecimalVal, decimal_val); +COALESCE_COMPUTE_FUNCTION(DecimalV2Val, decimalv2_val); COALESCE_COMPUTE_FUNCTION(LargeIntVal, large_int_val); } diff --git a/be/src/exprs/decimalv2_operators.cpp b/be/src/exprs/decimalv2_operators.cpp new file mode 100644 index 00000000000000..b15bc40863b1f8 --- /dev/null +++ b/be/src/exprs/decimalv2_operators.cpp @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "exprs/decimalv2_operators.h" + +#include +#include +#include + +#include "exprs/anyval_util.h" +#include "exprs/case_expr.h" +#include "exprs/expr.h" +#include "runtime/tuple_row.h" +// #include "util/decimal_util.h" +#include "util/string_parser.hpp" + +namespace doris { + +void DecimalV2Operators::init() { +} + +#define CAST_INT_TO_DECIMAL(from_type) \ + DecimalV2Val DecimalV2Operators::cast_to_decimalv2_val( \ + FunctionContext* context, const from_type& val) { \ + if (val.is_null) return DecimalV2Val::null(); \ + DecimalV2Value dv(val.val, 0);\ + DecimalV2Val result;\ + dv.to_decimal_val(&result);\ + return result;\ + } + +#define CAST_INT_TO_DECIMALS() \ + CAST_INT_TO_DECIMAL(TinyIntVal);\ + CAST_INT_TO_DECIMAL(SmallIntVal);\ + CAST_INT_TO_DECIMAL(IntVal);\ + CAST_INT_TO_DECIMAL(BigIntVal);\ + CAST_INT_TO_DECIMAL(LargeIntVal);\ + +CAST_INT_TO_DECIMALS(); + +DecimalV2Val DecimalV2Operators::cast_to_decimalv2_val( + FunctionContext* context, const FloatVal& val) { + if (val.is_null) { + return DecimalV2Val::null(); + } + DecimalV2Value dv; + dv.assign_from_float(val.val); + DecimalV2Val result; + dv.to_decimal_val(&result); + return result; +} + +DecimalV2Val DecimalV2Operators::cast_to_decimalv2_val( + FunctionContext* context, const DoubleVal& val) { + if (val.is_null) { + return DecimalV2Val::null(); + } + DecimalV2Value dv; + dv.assign_from_double(val.val); + DecimalV2Val result; + dv.to_decimal_val(&result); + return result; +} + +DecimalV2Val DecimalV2Operators::cast_to_decimalv2_val( + FunctionContext* context, const StringVal& val) { + if (val.is_null) { + return DecimalV2Val::null(); + } + DecimalV2Value dv; + if (dv.parse_from_str((const char*)val.ptr, val.len)) { + return DecimalV2Val::null(); + } + DecimalV2Val result; + dv.to_decimal_val(&result); + return result; +} + +#define CAST_DECIMAL_TO_INT(to_type, type_name) \ + to_type DecimalV2Operators::cast_to_##type_name( \ + FunctionContext* context, const DecimalV2Val& val) { \ + if (val.is_null) return to_type::null(); \ + DecimalV2Value dv = DecimalV2Value::from_decimal_val(val); \ + return to_type(dv);\ + } + +#define CAST_FROM_DECIMAL() \ + CAST_DECIMAL_TO_INT(BooleanVal, boolean_val);\ + CAST_DECIMAL_TO_INT(TinyIntVal, tiny_int_val);\ + CAST_DECIMAL_TO_INT(SmallIntVal, small_int_val);\ + CAST_DECIMAL_TO_INT(IntVal, int_val);\ + CAST_DECIMAL_TO_INT(BigIntVal, big_int_val);\ + CAST_DECIMAL_TO_INT(LargeIntVal, large_int_val);\ + CAST_DECIMAL_TO_INT(FloatVal, float_val);\ + CAST_DECIMAL_TO_INT(DoubleVal, double_val); + +CAST_FROM_DECIMAL(); + +StringVal DecimalV2Operators::cast_to_string_val( + FunctionContext* ctx, const DecimalV2Val& val) { + if (val.is_null) { + return StringVal::null(); + } + const DecimalV2Value& dv = DecimalV2Value::from_decimal_val(val); + return AnyValUtil::from_string_temp(ctx, dv.to_string()); +} + +DateTimeVal DecimalV2Operators::cast_to_datetime_val( + FunctionContext* context, const DecimalV2Val& val) { + if (val.is_null) { + return DateTimeVal::null(); + } + const DecimalV2Value& dv = DecimalV2Value::from_decimal_val(val); + DateTimeValue dt; + if (!dt.from_date_int64(dv)) { + return DateTimeVal::null(); + } + DateTimeVal result; + dt.to_datetime_val(&result); + return result; +} + +DecimalVal DecimalV2Operators::cast_to_decimal_val( + FunctionContext* context, const DecimalV2Val& val) { + if (val.is_null) return DecimalVal::null(); + DecimalV2Value v2(val.val); + DecimalValue dv(v2.int_value(), v2.frac_value()); + DecimalVal result; + dv.to_decimal_val(&result); + return result; +} + +#define DECIMAL_ARITHMETIC_OP(FN_NAME, OP) \ + DecimalV2Val DecimalV2Operators::FN_NAME##_decimalv2_val_decimalv2_val( \ + FunctionContext* context, const DecimalV2Val& v1, const DecimalV2Val& v2) { \ + if (v1.is_null || v2.is_null) return DecimalV2Val::null(); \ + DecimalV2Value iv1 = DecimalV2Value::from_decimal_val(v1); \ + DecimalV2Value iv2 = DecimalV2Value::from_decimal_val(v2); \ + DecimalV2Value ir = iv1 OP iv2; \ + DecimalV2Val result;\ + ir.to_decimal_val(&result); \ + return result; \ + } + +#define DECIMAL_ARITHMETIC_OPS() \ + DECIMAL_ARITHMETIC_OP(add, +);\ + DECIMAL_ARITHMETIC_OP(subtract, -);\ + DECIMAL_ARITHMETIC_OP(multiply, *);\ + DECIMAL_ARITHMETIC_OP(divide, /);\ + DECIMAL_ARITHMETIC_OP(mod, %);\ + +DECIMAL_ARITHMETIC_OPS(); + +#define DECIMAL_BINARY_PREDICATE_NONNUMERIC_FN(NAME, OP) \ + BooleanVal DecimalV2Operators::NAME##_decimalv2_val_decimalv2_val(\ + FunctionContext* c, const DecimalV2Val& v1, const DecimalV2Val& v2) {\ + if (v1.is_null || v2.is_null) return BooleanVal::null();\ + DecimalV2Value iv1 = DecimalV2Value::from_decimal_val(v1);\ + DecimalV2Value iv2 = DecimalV2Value::from_decimal_val(v2);\ + return BooleanVal(iv1 OP iv2);\ + } + +#define BINARY_PREDICATE_NONNUMERIC_FNS() \ + DECIMAL_BINARY_PREDICATE_NONNUMERIC_FN(eq, ==); \ + DECIMAL_BINARY_PREDICATE_NONNUMERIC_FN(ne, !=); \ + DECIMAL_BINARY_PREDICATE_NONNUMERIC_FN(gt, >); \ + DECIMAL_BINARY_PREDICATE_NONNUMERIC_FN(lt, <); \ + DECIMAL_BINARY_PREDICATE_NONNUMERIC_FN(ge, >=); \ + DECIMAL_BINARY_PREDICATE_NONNUMERIC_FN(le, <=); + +BINARY_PREDICATE_NONNUMERIC_FNS(); + +} + diff --git a/be/src/exprs/decimalv2_operators.h b/be/src/exprs/decimalv2_operators.h new file mode 100644 index 00000000000000..5a404b98992b27 --- /dev/null +++ b/be/src/exprs/decimalv2_operators.h @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef DORIS_BE_SRC_EXPRS_DECIMAL_OPERATORS_H +#define DORIS_BE_SRC_EXPRS_DECIMAL_OPERATORS_H + +#include +#include "runtime/decimalv2_value.h" +#include "udf/udf.h" + +namespace doris { + +class Expr; +struct ExprValue; +class TupleRow; + +/// Implementation of the decimal operators. These include the cast, +/// arithmetic and binary operators. +class DecimalV2Operators { +public: + static void init(); + + static DecimalV2Val cast_to_decimalv2_val(FunctionContext*, const TinyIntVal&); + static DecimalV2Val cast_to_decimalv2_val(FunctionContext*, const SmallIntVal&); + static DecimalV2Val cast_to_decimalv2_val(FunctionContext*, const IntVal&); + static DecimalV2Val cast_to_decimalv2_val(FunctionContext*, const BigIntVal&); + static DecimalV2Val cast_to_decimalv2_val(FunctionContext*, const LargeIntVal&); + static DecimalV2Val cast_to_decimalv2_val(FunctionContext*, const FloatVal&); + static DecimalV2Val cast_to_decimalv2_val(FunctionContext*, const DoubleVal&); + static DecimalV2Val cast_to_decimalv2_val(FunctionContext*, const StringVal&); + + static BooleanVal cast_to_boolean_val(FunctionContext*, const DecimalV2Val&); + static TinyIntVal cast_to_tiny_int_val(FunctionContext*, const DecimalV2Val&); + static SmallIntVal cast_to_small_int_val(FunctionContext*, const DecimalV2Val&); + static IntVal cast_to_int_val(FunctionContext*, const DecimalV2Val&); + static BigIntVal cast_to_big_int_val(FunctionContext*, const DecimalV2Val&); + static LargeIntVal cast_to_large_int_val(FunctionContext*, const DecimalV2Val&); + static FloatVal cast_to_float_val(FunctionContext*, const DecimalV2Val&); + static DoubleVal cast_to_double_val(FunctionContext*, const DecimalV2Val&); + static StringVal cast_to_string_val(FunctionContext*, const DecimalV2Val&); + static DateTimeVal cast_to_datetime_val(FunctionContext*, const DecimalV2Val&); + static DecimalVal cast_to_decimal_val(FunctionContext*, const DecimalV2Val&); + + static DecimalV2Val add_decimalv2_val_decimalv2_val( + FunctionContext*, const DecimalV2Val&, const DecimalV2Val&); + static DecimalV2Val subtract_decimalv2_val_decimalv2_val( + FunctionContext*, const DecimalV2Val&, const DecimalV2Val&); + static DecimalV2Val multiply_decimalv2_val_decimalv2_val( + FunctionContext*, const DecimalV2Val&, const DecimalV2Val&); + static DecimalV2Val divide_decimalv2_val_decimalv2_val( + FunctionContext*, const DecimalV2Val&, const DecimalV2Val&); + static DecimalV2Val mod_decimalv2_val_decimalv2_val( + FunctionContext*, const DecimalV2Val&, const DecimalV2Val&); + + static BooleanVal eq_decimalv2_val_decimalv2_val( + FunctionContext*, const DecimalV2Val&, const DecimalV2Val&); + static BooleanVal ne_decimalv2_val_decimalv2_val( + FunctionContext*, const DecimalV2Val&, const DecimalV2Val&); + static BooleanVal gt_decimalv2_val_decimalv2_val( + FunctionContext*, const DecimalV2Val&, const DecimalV2Val&); + static BooleanVal lt_decimalv2_val_decimalv2_val( + FunctionContext*, const DecimalV2Val&, const DecimalV2Val&); + static BooleanVal ge_decimalv2_val_decimalv2_val( + FunctionContext*, const DecimalV2Val&, const DecimalV2Val&); + static BooleanVal le_decimalv2_val_decimalv2_val( + FunctionContext*, const DecimalV2Val&, const DecimalV2Val&); +}; + +} + +#endif diff --git a/be/src/exprs/expr.cpp b/be/src/exprs/expr.cpp index 7420f314f833dd..439885f66b5ef9 100644 --- a/be/src/exprs/expr.cpp +++ b/be/src/exprs/expr.cpp @@ -141,6 +141,7 @@ Expr::Expr(const TypeDescriptor& type) : break; case TYPE_DECIMAL: + case TYPE_DECIMALV2: _node_type = (TExprNodeType::DECIMAL_LITERAL); break; @@ -198,6 +199,7 @@ Expr::Expr(const TypeDescriptor& type, bool is_slotref) : break; case TYPE_DECIMAL: + case TYPE_DECIMALV2: _node_type = (TExprNodeType::DECIMAL_LITERAL); break; @@ -753,6 +755,10 @@ doris_udf::AnyVal* Expr::get_const_val(ExprContext* context) { _constant_val.reset(new DecimalVal(get_decimal_val(context, NULL))); break; } + case TYPE_DECIMALV2: { + _constant_val.reset(new DecimalV2Val(get_decimalv2_val(context, NULL))); + break; + } case TYPE_NULL: { _constant_val.reset(new AnyVal(true)); break; @@ -836,6 +842,11 @@ DecimalVal Expr::get_decimal_val(ExprContext* context, TupleRow* row) { return val; } +DecimalV2Val Expr::get_decimalv2_val(ExprContext* context, TupleRow* row) { + DecimalV2Val val; + return val; +} + Status Expr::get_fn_context_error(ExprContext* ctx) { if (_fn_context_index != -1) { FunctionContext* fn_ctx = ctx->fn_context(_fn_context_index); diff --git a/be/src/exprs/expr.h b/be/src/exprs/expr.h index 0b3d02fc43fff5..7d1118acdcc2c1 100644 --- a/be/src/exprs/expr.h +++ b/be/src/exprs/expr.h @@ -33,6 +33,7 @@ #include "runtime/string_value.hpp" #include "runtime/datetime_value.h" #include "runtime/decimal_value.h" +#include "runtime/decimalv2_value.h" #include "udf/udf.h" #include "runtime/types.h" //#include @@ -122,6 +123,7 @@ class Expr { // virtual ArrayVal GetArrayVal(ExprContext* context, TupleRow*); virtual DateTimeVal get_datetime_val(ExprContext* context, TupleRow*); virtual DecimalVal get_decimal_val(ExprContext* context, TupleRow*); + virtual DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*); // Get the number of digits after the decimal that should be displayed for this // value. Returns -1 if no scale has been specified (currently the scale is only set for @@ -514,6 +516,7 @@ class Expr { static StringVal get_string_val(Expr* expr, ExprContext* context, TupleRow* row); static DateTimeVal get_datetime_val(Expr* expr, ExprContext* context, TupleRow* row); static DecimalVal get_decimal_val(Expr* expr, ExprContext* context, TupleRow* row); + static DecimalV2Val get_decimalv2_val(Expr* expr, ExprContext* context, TupleRow* row); // Helper function for InlineConstants(). Returns the IR version of what GetConstant() // would return. diff --git a/be/src/exprs/expr_context.cpp b/be/src/exprs/expr_context.cpp index c8235fbf732da8..77151f1b611700 100644 --- a/be/src/exprs/expr_context.cpp +++ b/be/src/exprs/expr_context.cpp @@ -374,6 +374,14 @@ void* ExprContext::get_value(Expr* e, TupleRow* row) { _result.decimal_val = DecimalValue::from_decimal_val(v); return &_result.decimal_val; } + case TYPE_DECIMALV2: { + DecimalV2Val v = e->get_decimalv2_val(this, row); + if (v.is_null) { + return NULL; + } + _result.decimalv2_val = DecimalV2Value::from_decimal_val(v); + return &_result.decimalv2_val; + } #if 0 case TYPE_ARRAY: case TYPE_MAP: { @@ -451,6 +459,10 @@ DecimalVal ExprContext::get_decimal_val(TupleRow* row) { return _root->get_decimal_val(this, row); } +DecimalV2Val ExprContext::get_decimalv2_val(TupleRow* row) { + return _root->get_decimalv2_val(this, row); +} + Status ExprContext::get_const_value(RuntimeState* state, Expr& expr, AnyVal** const_val) { DCHECK(_opened); diff --git a/be/src/exprs/expr_context.h b/be/src/exprs/expr_context.h index 9c3b8ddde7f90b..cbf2b6ea991134 100644 --- a/be/src/exprs/expr_context.h +++ b/be/src/exprs/expr_context.h @@ -142,6 +142,7 @@ class ExprContext { // ArrayVal GetArrayVal(TupleRow* row); DateTimeVal get_datetime_val(TupleRow* row); DecimalVal get_decimal_val(TupleRow* row); + DecimalV2Val get_decimalv2_val(TupleRow* row); /// Frees all local allocations made by fn_contexts_. This can be called when result /// data from this context is no longer needed. diff --git a/be/src/exprs/expr_ir.cpp b/be/src/exprs/expr_ir.cpp index f216fb2d29f761..b29b0fc3b2a054 100644 --- a/be/src/exprs/expr_ir.cpp +++ b/be/src/exprs/expr_ir.cpp @@ -74,4 +74,7 @@ DateTimeVal Expr::get_datetime_val(Expr* expr, ExprContext* context, TupleRow* r DecimalVal Expr::get_decimal_val(Expr* expr, ExprContext* context, TupleRow* row) { return expr->get_decimal_val(context, row); } +DecimalV2Val Expr::get_decimalv2_val(Expr* expr, ExprContext* context, TupleRow* row) { + return expr->get_decimalv2_val(context, row); +} } diff --git a/be/src/exprs/expr_value.h b/be/src/exprs/expr_value.h index e8d9d8cb4277e9..428abfd955f6a3 100644 --- a/be/src/exprs/expr_value.h +++ b/be/src/exprs/expr_value.h @@ -22,6 +22,7 @@ #include "runtime/string_value.hpp" #include "runtime/datetime_value.h" #include "runtime/decimal_value.h" +#include "runtime/decimalv2_value.h" #include "runtime/types.h" namespace doris { @@ -44,6 +45,7 @@ struct ExprValue { StringValue string_val; DateTimeValue datetime_val; DecimalValue decimal_val; + DecimalV2Value decimalv2_val; ExprValue() : bool_val(false), @@ -57,7 +59,8 @@ struct ExprValue { string_data(), string_val(NULL, 0), datetime_val(), - decimal_val() { + decimal_val(), + decimalv2_val() { } ExprValue(bool v): bool_val(v) {} @@ -68,7 +71,7 @@ struct ExprValue { ExprValue(__int128 value) : large_int_val(value) {} ExprValue(float v): float_val(v) {} ExprValue(double v): double_val(v) {} - ExprValue(int64_t i, int32_t f) : decimal_val(i, f) {} + ExprValue(int64_t i, int32_t f) : decimal_val(i, f), decimalv2_val(i, f) {} // c'tor for string values ExprValue(const std::string& str) : @@ -137,6 +140,10 @@ struct ExprValue { decimal_val.set_to_zero(); return &decimal_val; + case TYPE_DECIMALV2: + decimalv2_val.set_to_zero(); + return &decimalv2_val; + default: DCHECK(false); return NULL; @@ -185,6 +192,10 @@ struct ExprValue { decimal_val = DecimalValue::get_min_decimal(); return &decimal_val; + case TYPE_DECIMALV2: + decimalv2_val = DecimalV2Value::get_min_decimal(); + return &decimalv2_val; + default: DCHECK(false); return NULL; @@ -233,6 +244,10 @@ struct ExprValue { decimal_val = DecimalValue::get_max_decimal(); return &decimal_val; + case TYPE_DECIMALV2: + decimalv2_val = DecimalV2Value::get_max_decimal(); + return &decimalv2_val; + default: DCHECK(false); return NULL; diff --git a/be/src/exprs/hybird_set.cpp b/be/src/exprs/hybird_set.cpp index 57254ae76a794b..b485b09f7b8085 100644 --- a/be/src/exprs/hybird_set.cpp +++ b/be/src/exprs/hybird_set.cpp @@ -49,6 +49,9 @@ HybirdSetBase* HybirdSetBase::create_set(PrimitiveType type) { case TYPE_DECIMAL: return new(std::nothrow) HybirdSet(); + case TYPE_DECIMALV2: + return new(std::nothrow) HybirdSet(); + case TYPE_LARGEINT: return new(std::nothrow) HybirdSet<__int128>(); diff --git a/be/src/exprs/hybird_set.h b/be/src/exprs/hybird_set.h index 3812ba228f6e2f..54e3d38f6523da 100644 --- a/be/src/exprs/hybird_set.h +++ b/be/src/exprs/hybird_set.h @@ -26,6 +26,7 @@ #include "runtime/string_value.h" #include "runtime/datetime_value.h" #include "runtime/decimal_value.h" +#include "runtime/decimalv2_value.h" namespace doris { diff --git a/be/src/exprs/is_null_predicate.cpp b/be/src/exprs/is_null_predicate.cpp index aca5761cb9120d..efc3a05165709d 100644 --- a/be/src/exprs/is_null_predicate.cpp +++ b/be/src/exprs/is_null_predicate.cpp @@ -45,6 +45,7 @@ template BooleanVal IsNullPredicate::is_null(FunctionContext*, const DoubleVal&) template BooleanVal IsNullPredicate::is_null(FunctionContext*, const StringVal&); template BooleanVal IsNullPredicate::is_null(FunctionContext*, const DateTimeVal&); template BooleanVal IsNullPredicate::is_null(FunctionContext*, const DecimalVal&); +template BooleanVal IsNullPredicate::is_null(FunctionContext*, const DecimalV2Val&); template BooleanVal IsNullPredicate::is_not_null(FunctionContext*, const AnyVal&); template BooleanVal IsNullPredicate::is_not_null(FunctionContext*, const BooleanVal&); @@ -58,5 +59,6 @@ template BooleanVal IsNullPredicate::is_not_null(FunctionContext*, const DoubleV template BooleanVal IsNullPredicate::is_not_null(FunctionContext*, const StringVal&); template BooleanVal IsNullPredicate::is_not_null(FunctionContext*, const DateTimeVal&); template BooleanVal IsNullPredicate::is_not_null(FunctionContext*, const DecimalVal&); +template BooleanVal IsNullPredicate::is_not_null(FunctionContext*, const DecimalV2Val&); } diff --git a/be/src/exprs/literal.cpp b/be/src/exprs/literal.cpp index bb8a91d0e4090a..9a6b626252f0ea 100644 --- a/be/src/exprs/literal.cpp +++ b/be/src/exprs/literal.cpp @@ -99,6 +99,12 @@ Literal::Literal(const TExprNode& node) : _value.decimal_val = DecimalValue(node.decimal_literal.value); break; } + case TYPE_DECIMALV2: { + DCHECK_EQ(node.node_type, TExprNodeType::DECIMAL_LITERAL); + DCHECK(node.__isset.decimal_literal); + _value.decimalv2_val = DecimalV2Value(node.decimal_literal.value); + break; + } default: break; // DCHECK(false) << "Invalid type: " << TypeToString(_type.type); @@ -155,6 +161,13 @@ DecimalVal Literal::get_decimal_val(ExprContext* context, TupleRow* row) { return dec_val; } +DecimalV2Val Literal::get_decimalv2_val(ExprContext* context, TupleRow* row) { + DCHECK_EQ(_type.type, TYPE_DECIMALV2) << _type; + DecimalV2Val dec_val; + _value.decimalv2_val.to_decimal_val(&dec_val); + return dec_val; +} + DateTimeVal Literal::get_datetime_val(ExprContext* context, TupleRow* row) { DateTimeVal dt_val; _value.datetime_val.to_datetime_val(&dt_val); diff --git a/be/src/exprs/literal.h b/be/src/exprs/literal.h index 307775b89f9520..cf48f26689ed20 100644 --- a/be/src/exprs/literal.h +++ b/be/src/exprs/literal.h @@ -44,6 +44,7 @@ class Literal : public Expr { virtual FloatVal get_float_val(ExprContext* context, TupleRow*); virtual DoubleVal get_double_val(ExprContext* context, TupleRow*); virtual DecimalVal get_decimal_val(ExprContext* context, TupleRow*); + virtual DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*); virtual DateTimeVal get_datetime_val(ExprContext* context, TupleRow*); virtual StringVal get_string_val(ExprContext* context, TupleRow* row); diff --git a/be/src/exprs/math_functions.cpp b/be/src/exprs/math_functions.cpp index bfef98fb477841..66a23cee0b42ec 100644 --- a/be/src/exprs/math_functions.cpp +++ b/be/src/exprs/math_functions.cpp @@ -26,6 +26,7 @@ #include "exprs/expr.h" #include "runtime/tuple_row.h" #include "runtime/decimal_value.h" +#include "runtime/decimalv2_value.h" #include "util/string_parser.hpp" namespace doris { @@ -530,6 +531,11 @@ DecimalVal MathFunctions::positive_decimal( return val; } +DecimalV2Val MathFunctions::positive_decimal( + FunctionContext* ctx, const DecimalV2Val& val) { + return val; +} + BigIntVal MathFunctions::negative_bigint( FunctionContext* ctx, const BigIntVal& val) { if (val.is_null) { @@ -559,6 +565,17 @@ DecimalVal MathFunctions::negative_decimal( return result; } +DecimalV2Val MathFunctions::negative_decimal( + FunctionContext* ctx, const DecimalV2Val& val) { + if (val.is_null) { + return val; + } + const DecimalV2Value& dv1 = DecimalV2Value::from_decimal_val(val); + DecimalV2Val result; + (-dv1).to_decimal_val(&result); + return result; +} + #define LEAST_FN(TYPE) \ TYPE MathFunctions::least(\ FunctionContext* ctx, int num_args, const TYPE* args) { \ @@ -601,6 +618,7 @@ LEAST_FNS(); LEAST_NONNUMERIC_FN(string_val, StringVal, StringValue); \ LEAST_NONNUMERIC_FN(datetime_val, DateTimeVal, DateTimeValue); \ LEAST_NONNUMERIC_FN(decimal_val, DecimalVal, DecimalValue); \ + LEAST_NONNUMERIC_FN(decimal_val, DecimalV2Val, DecimalV2Value); \ LEAST_NONNUMERIC_FNS(); @@ -646,6 +664,7 @@ GREATEST_FNS(); GREATEST_NONNUMERIC_FN(string_val, StringVal, StringValue); \ GREATEST_NONNUMERIC_FN(datetime_val, DateTimeVal, DateTimeValue); \ GREATEST_NONNUMERIC_FN(decimal_val, DecimalVal, DecimalValue); \ + GREATEST_NONNUMERIC_FN(decimal_val, DecimalV2Val, DecimalV2Value); \ GREATEST_NONNUMERIC_FNS(); @@ -792,6 +811,24 @@ void* MathFunctions::least_decimal(Expr* e, TupleRow* row) { return &e->children()[result_idx]->_result.decimal_val; } +void* MathFunctions::least_decimalv2(Expr* e, TupleRow* row) { + DCHECK_GE(e->get_num_children(), 1); + int32_t num_args = e->get_num_children(); + int result_idx = 0; + // NOTE: loop index starts at 0, so If frist arg is NULL, we can return early.. + for (int i = 0; i < num_args; ++i) { + DecimalV2Value* arg = reinterpret_cast(e->children()[i]->get_value(row)); + if (arg == NULL) { + return NULL; + } + if (*arg < *reinterpret_cast(e->children()[result_idx]->get_value(row))) { + result_idx = i; + } + } + return &e->children()[result_idx]->_result.decimalv2_val; +} + + void* MathFunctions::least_string(Expr* e, TupleRow* row) { DCHECK_GE(e->get_num_children(), 1); int32_t num_args = e->get_num_children(); diff --git a/be/src/exprs/math_functions.h b/be/src/exprs/math_functions.h index 2d1fb64d3656e2..7ee8466e8b5b60 100644 --- a/be/src/exprs/math_functions.h +++ b/be/src/exprs/math_functions.h @@ -117,12 +117,16 @@ class MathFunctions { doris_udf::FunctionContext* ctx, const doris_udf::DoubleVal& val); static doris_udf::DecimalVal positive_decimal( doris_udf::FunctionContext* ctx, const doris_udf::DecimalVal& val); + static doris_udf::DecimalV2Val positive_decimal( + doris_udf::FunctionContext* ctx, const doris_udf::DecimalV2Val& val); static doris_udf::BigIntVal negative_bigint( doris_udf::FunctionContext* ctx, const doris_udf::BigIntVal& val); static doris_udf::DoubleVal negative_double( doris_udf::FunctionContext* ctx, const doris_udf::DoubleVal& val); static doris_udf::DecimalVal negative_decimal( doris_udf::FunctionContext* ctx, const doris_udf::DecimalVal& val); + static doris_udf::DecimalV2Val negative_decimal( + doris_udf::FunctionContext* ctx, const doris_udf::DecimalV2Val& val); static doris_udf::TinyIntVal least( doris_udf::FunctionContext* ctx, int num_args, const doris_udf::TinyIntVal* args); @@ -164,7 +168,10 @@ class MathFunctions { doris_udf::FunctionContext* ctx, int num_args, const doris_udf::DecimalVal* val); static doris_udf::DecimalVal greatest( doris_udf::FunctionContext* ctx, int num_args, const doris_udf::DecimalVal* val); - + static doris_udf::DecimalV2Val least( + doris_udf::FunctionContext* ctx, int num_args, const doris_udf::DecimalV2Val* val); + static doris_udf::DecimalV2Val greatest( + doris_udf::FunctionContext* ctx, int num_args, const doris_udf::DecimalV2Val* val); private: static const int32_t MIN_BASE = 2; static const int32_t MAX_BASE = 36; diff --git a/be/src/exprs/new_agg_fn_evaluator.cc b/be/src/exprs/new_agg_fn_evaluator.cc index 384b2b757cdd83..b37582c329183b 100644 --- a/be/src/exprs/new_agg_fn_evaluator.cc +++ b/be/src/exprs/new_agg_fn_evaluator.cc @@ -261,6 +261,10 @@ void NewAggFnEvaluator::SetDstSlot(const AnyVal* src, const SlotDescriptor& dst_ *reinterpret_cast(slot) = DecimalValue::from_decimal_val( *reinterpret_cast(src)); return; + case TYPE_DECIMALV2: + *reinterpret_cast(slot) = + reinterpret_cast(src)->val; + return; default: DCHECK(false) << "NYI: " << dst_slot_desc.type(); } @@ -362,6 +366,11 @@ inline void NewAggFnEvaluator::set_any_val( reinterpret_cast(dst)); return; + case TYPE_DECIMALV2: + reinterpret_cast(dst)->val = + reinterpret_cast(slot)->value; + return; + case TYPE_LARGEINT: memcpy(&reinterpret_cast(dst)->val, slot, sizeof(__int128)); return; @@ -545,6 +554,13 @@ void NewAggFnEvaluator::SerializeOrFinalize(Tuple* src, SetDstSlot(&v, dst_slot_desc, dst); break; } + case TYPE_DECIMALV2: { + typedef DecimalV2Val(*Fn)(FunctionContext*, AnyVal*); + DecimalV2Val v = reinterpret_cast(fn)( + agg_fn_ctx_.get(), staging_intermediate_val_); + SetDstSlot(&v, dst_slot_desc, dst); + break; + } case TYPE_DATE: case TYPE_DATETIME: { typedef DateTimeVal(*Fn)(FunctionContext*, AnyVal*); diff --git a/be/src/exprs/new_agg_fn_evaluator.h b/be/src/exprs/new_agg_fn_evaluator.h index 7c9bd72f5cad59..529bd240ecf53d 100644 --- a/be/src/exprs/new_agg_fn_evaluator.h +++ b/be/src/exprs/new_agg_fn_evaluator.h @@ -162,6 +162,7 @@ class NewAggFnEvaluator { static const size_t FLOAT_SIZE = sizeof(float); static const size_t DOUBLE_SIZE = sizeof(double); static const size_t DECIMAL_SIZE = sizeof(DecimalValue); + static const size_t DECIMALV2_SIZE = sizeof(DecimalV2Value); static const size_t TIME_DURATION_SIZE = sizeof(boost::posix_time::time_duration); static const size_t DATE_SIZE = sizeof(boost::gregorian::date); static const size_t LARGEINT_SIZE = sizeof(__int128); diff --git a/be/src/exprs/new_in_predicate.cpp b/be/src/exprs/new_in_predicate.cpp index 8bc20f25824dcc..026e52ab17e945 100644 --- a/be/src/exprs/new_in_predicate.cpp +++ b/be/src/exprs/new_in_predicate.cpp @@ -52,6 +52,12 @@ DecimalValue get_val( return DecimalValue::from_decimal_val(x); } +template<> +DecimalV2Value get_val( + const FunctionContext::TypeDesc* type, const DecimalV2Val& x) { + return DecimalV2Value::from_decimal_val(x); +} + template void InPredicate::set_lookup_prepare( FunctionContext* ctx, FunctionContext::FunctionStateScope scope) { @@ -189,6 +195,7 @@ IN_FUNCTIONS(DoubleVal, double, double_val) IN_FUNCTIONS(StringVal, StringValue, string_val) IN_FUNCTIONS(DateTimeVal, DateTimeValue, datetime_val) IN_FUNCTIONS(DecimalVal, DecimalValue, decimal_val) +IN_FUNCTIONS(DecimalV2Val, DecimalV2Value, decimalv2_val) IN_FUNCTIONS(LargeIntVal, __int128, large_int_val) // Needed for in-predicate-benchmark to build diff --git a/be/src/exprs/new_in_predicate.h b/be/src/exprs/new_in_predicate.h index b57400c2d0a170..0ae413079bab96 100644 --- a/be/src/exprs/new_in_predicate.h +++ b/be/src/exprs/new_in_predicate.h @@ -274,24 +274,46 @@ class InPredicate { doris_udf::FunctionContext* context, const doris_udf::DecimalVal& val, int num_args, const doris_udf::DecimalVal* args); + static doris_udf::BooleanVal in_iterate( + doris_udf::FunctionContext* context, const doris_udf::DecimalV2Val& val, + int num_args, const doris_udf::DecimalV2Val* args); + static doris_udf::BooleanVal not_in_iterate( doris_udf::FunctionContext* context, const doris_udf::DecimalVal& val, int num_args, const doris_udf::DecimalVal* args); + static doris_udf::BooleanVal not_in_iterate( + doris_udf::FunctionContext* context, const doris_udf::DecimalV2Val& val, + int num_args, const doris_udf::DecimalV2Val* args); + static void set_lookup_prepare_decimal_val(doris_udf::FunctionContext* ctx, doris_udf::FunctionContext::FunctionStateScope scope); + static void set_lookup_prepare_decimalv2_val(doris_udf::FunctionContext* ctx, + doris_udf::FunctionContext::FunctionStateScope scope); + static void set_lookup_close_decimal_val(doris_udf::FunctionContext* ctx, doris_udf::FunctionContext::FunctionStateScope scope); + static void set_lookup_close_decimalv2_val(doris_udf::FunctionContext* ctx, + doris_udf::FunctionContext::FunctionStateScope scope); + static doris_udf::BooleanVal in_set_lookup( doris_udf::FunctionContext* context, const doris_udf::DecimalVal& val, int num_args, const doris_udf::DecimalVal* args); + static doris_udf::BooleanVal in_set_lookup( + doris_udf::FunctionContext* context, const doris_udf::DecimalV2Val& val, + int num_args, const doris_udf::DecimalV2Val* args); + static doris_udf::BooleanVal not_in_set_lookup( doris_udf::FunctionContext* context, const doris_udf::DecimalVal& val, int num_args, const doris_udf::DecimalVal* args); + static doris_udf::BooleanVal not_in_set_lookup( + doris_udf::FunctionContext* context, const doris_udf::DecimalV2Val& val, + int num_args, const doris_udf::DecimalV2Val* args); + /* added by lide */ IN_FUNCTIONS_STMT(LargeIntVal, __int128, large_int_val) diff --git a/be/src/exprs/null_literal.cpp b/be/src/exprs/null_literal.cpp index 7f5ca92ef98e91..1917dee12b190d 100644 --- a/be/src/exprs/null_literal.cpp +++ b/be/src/exprs/null_literal.cpp @@ -75,6 +75,9 @@ DecimalVal NullLiteral::get_decimal_val(ExprContext*, TupleRow*) { return DecimalVal::null(); } +DecimalV2Val NullLiteral::get_decimalv2_val(ExprContext*, TupleRow*) { + return DecimalV2Val::null(); +} // Generated IR for a bigint NULL literal: // // define { i8, i64 } @NullLiteral(i8* %context, %"class.impala::TupleRow"* %row) { diff --git a/be/src/exprs/null_literal.h b/be/src/exprs/null_literal.h index 857b5b483f41c1..c53aee7ed995d5 100644 --- a/be/src/exprs/null_literal.h +++ b/be/src/exprs/null_literal.h @@ -47,6 +47,7 @@ class NullLiteral : public Expr { virtual doris_udf::StringVal get_string_val(ExprContext*, TupleRow*); virtual doris_udf::DateTimeVal get_datetime_val(ExprContext*, TupleRow*); virtual doris_udf::DecimalVal get_decimal_val(ExprContext*, TupleRow*); + virtual doris_udf::DecimalV2Val get_decimalv2_val(ExprContext*, TupleRow*); protected: friend class Expr; diff --git a/be/src/exprs/scalar_fn_call.cpp b/be/src/exprs/scalar_fn_call.cpp index 5ff95a5325dd6d..b14daf7fd1c4fd 100644 --- a/be/src/exprs/scalar_fn_call.cpp +++ b/be/src/exprs/scalar_fn_call.cpp @@ -441,7 +441,7 @@ Status ScalarFnCall::get_udf(RuntimeState* state, Function** udf) { Type* return_type = CodegenAnyVal::get_lowered_type(codegen, type()); std::vector arg_types; - if (type().type == TYPE_DECIMAL) { + if (type().type == TYPE_DECIMAL || type().type == TYPE_DECIMALV2) { // Per the x64 ABI, DecimalVals are returned via a DecmialVal* output argument return_type = codegen->void_type(); arg_types.push_back( @@ -747,6 +747,7 @@ typedef DoubleVal (*DoubleWrapper)(ExprContext*, TupleRow*); typedef StringVal (*StringWrapper)(ExprContext*, TupleRow*); typedef DateTimeVal (*DatetimeWrapper)(ExprContext*, TupleRow*); typedef DecimalVal (*DecimalWrapper)(ExprContext*, TupleRow*); +typedef DecimalV2Val (*DecimalV2Wrapper)(ExprContext*, TupleRow*); // TODO: macroify this? BooleanVal ScalarFnCall::get_boolean_val(ExprContext* context, TupleRow* row) { @@ -860,6 +861,17 @@ DecimalVal ScalarFnCall::get_decimal_val(ExprContext* context, TupleRow* row) { return fn(context, row); } +DecimalV2Val ScalarFnCall::get_decimalv2_val(ExprContext* context, TupleRow* row) { + DCHECK_EQ(_type.type, TYPE_DECIMALV2); + DCHECK(context != NULL); + if (_scalar_fn_wrapper == NULL) { + return interpret_eval(context, row); + } + DecimalV2Wrapper fn = reinterpret_cast(_scalar_fn_wrapper); + return fn(context, row); +} + + std::string ScalarFnCall::debug_string() const { std::stringstream out; out << "ScalarFnCall(udf_type=" << _fn.binary_type diff --git a/be/src/exprs/scalar_fn_call.h b/be/src/exprs/scalar_fn_call.h index dcd2ba782c83e8..4bf337723dd8b3 100644 --- a/be/src/exprs/scalar_fn_call.h +++ b/be/src/exprs/scalar_fn_call.h @@ -79,6 +79,7 @@ class ScalarFnCall : public Expr { virtual doris_udf::StringVal get_string_val(ExprContext* context, TupleRow*); virtual doris_udf::DateTimeVal get_datetime_val(ExprContext* context, TupleRow*); virtual doris_udf::DecimalVal get_decimal_val(ExprContext* context, TupleRow*); + virtual doris_udf::DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*); // virtual doris_udf::ArrayVal GetArrayVal(ExprContext* context, TupleRow*); private: diff --git a/be/src/exprs/slot_ref.cpp b/be/src/exprs/slot_ref.cpp index d8cb4b799c67a5..b3e91fa4a70ba5 100644 --- a/be/src/exprs/slot_ref.cpp +++ b/be/src/exprs/slot_ref.cpp @@ -521,4 +521,14 @@ DecimalVal SlotRef::get_decimal_val(ExprContext* context, TupleRow* row) { return dec_val; } +DecimalV2Val SlotRef::get_decimalv2_val(ExprContext* context, TupleRow* row) { + DCHECK_EQ(_type.type, TYPE_DECIMALV2); + Tuple* t = row->get_tuple(_tuple_idx); + if (t == NULL || t->is_null(_null_indicator_offset)) { + return DecimalV2Val::null(); + } + + return DecimalV2Val(reinterpret_cast(t->get_slot(_slot_offset))->value); +} + } diff --git a/be/src/exprs/slot_ref.h b/be/src/exprs/slot_ref.h index 6d6f7ffb38c5f8..acdecca9476b18 100644 --- a/be/src/exprs/slot_ref.h +++ b/be/src/exprs/slot_ref.h @@ -78,6 +78,7 @@ class SlotRef : public Expr { virtual doris_udf::StringVal get_string_val(ExprContext* context, TupleRow*); virtual doris_udf::DateTimeVal get_datetime_val(ExprContext* context, TupleRow*); virtual doris_udf::DecimalVal get_decimal_val(ExprContext* context, TupleRow*); + virtual doris_udf::DecimalV2Val get_decimalv2_val(ExprContext* context, TupleRow*); // virtual doris_udf::ArrayVal GetArrayVal(ExprContext* context, TupleRow*); private: diff --git a/be/src/exprs/udf_builtins.cpp b/be/src/exprs/udf_builtins.cpp index ad0116b2b6f17f..b04b7d22767e33 100755 --- a/be/src/exprs/udf_builtins.cpp +++ b/be/src/exprs/udf_builtins.cpp @@ -32,6 +32,7 @@ using doris_udf::LargeIntVal; using doris_udf::FloatVal; using doris_udf::DoubleVal; using doris_udf::DecimalVal; +using doris_udf::DecimalV2Val; using doris_udf::StringVal; using doris_udf::AnyVal; @@ -52,6 +53,16 @@ DecimalVal UdfBuiltins::decimal_abs(FunctionContext* context, const DecimalVal& return result; } +DecimalV2Val UdfBuiltins::decimal_abs(FunctionContext* context, const DecimalV2Val& v) { + if (v.is_null) { + return v; + } + DecimalV2Val result = v; + result.set_to_abs_value(); + return result; +} + + //for test BigIntVal UdfBuiltins::add_two_number( FunctionContext* context, diff --git a/be/src/exprs/udf_builtins.h b/be/src/exprs/udf_builtins.h index cb00acf6f74200..7781ae77e40597 100755 --- a/be/src/exprs/udf_builtins.h +++ b/be/src/exprs/udf_builtins.h @@ -32,6 +32,8 @@ class UdfBuiltins { const doris_udf::DoubleVal& v); static doris_udf::DecimalVal decimal_abs(doris_udf::FunctionContext* context, const doris_udf::DecimalVal& v); + static doris_udf::DecimalV2Val decimal_abs(doris_udf::FunctionContext* context, + const doris_udf::DecimalV2Val& v); static doris_udf::BigIntVal add_two_number( doris_udf::FunctionContext* context, const doris_udf::BigIntVal& v1, diff --git a/be/src/olap/field_info.cpp b/be/src/olap/field_info.cpp index e9107678e8065b..59ce1a41bab39b 100644 --- a/be/src/olap/field_info.cpp +++ b/be/src/olap/field_info.cpp @@ -224,6 +224,7 @@ uint32_t FieldInfo::get_field_length_by_type(TPrimitiveType::type type, uint32_t case TPrimitiveType::HLL: return string_length + sizeof(OLAP_STRING_MAX_LENGTH); case TPrimitiveType::DECIMAL: + case TPrimitiveType::DECIMALV2: return 12; // use 12 bytes in olap engine. default: OLAP_LOG_WARNING("unknown field type. [type=%d]", type); diff --git a/be/src/olap/memtable.cpp b/be/src/olap/memtable.cpp index f145ef028d26c3..5771414c48752b 100644 --- a/be/src/olap/memtable.cpp +++ b/be/src/olap/memtable.cpp @@ -116,6 +116,13 @@ void MemTable::insert(Tuple* tuple) { storage_decimal_value->fraction = decimal_value->frac_value(); break; } + case TYPE_DECIMALV2: { + DecimalV2Value* decimal_value = tuple->get_decimalv2_slot(slot->tuple_offset()); + decimal12_t* storage_decimal_value = reinterpret_cast(_tuple_buf + offset); + storage_decimal_value->integer = decimal_value->int_value(); + storage_decimal_value->fraction = decimal_value->frac_value(); + break; + } case TYPE_DATETIME: { DateTimeValue* datetime_value = tuple->get_datetime_slot(slot->tuple_offset()); uint64_t* storage_datetime_value = reinterpret_cast(_tuple_buf + offset); diff --git a/be/src/olap/olap_engine.cpp b/be/src/olap/olap_engine.cpp index 239ddb39fd05f9..f043cc43c7702a 100644 --- a/be/src/olap/olap_engine.cpp +++ b/be/src/olap/olap_engine.cpp @@ -2094,7 +2094,7 @@ OLAPStatus OLAPEngine::_create_new_table_header( string data_type; EnumToString(TPrimitiveType, column.column_type.type, data_type); header->mutable_column(i)->set_type(data_type); - if (column.column_type.type == TPrimitiveType::DECIMAL) { + if (column.column_type.type == TPrimitiveType::DECIMAL || column.column_type.type == TPrimitiveType::DECIMALV2) { if (column.column_type.__isset.precision && column.column_type.__isset.scale) { header->mutable_column(i)->set_precision(column.column_type.precision); header->mutable_column(i)->set_frac(column.column_type.scale); diff --git a/be/src/runtime/CMakeLists.txt b/be/src/runtime/CMakeLists.txt index 13464da33b01a4..6f820b9867c32b 100644 --- a/be/src/runtime/CMakeLists.txt +++ b/be/src/runtime/CMakeLists.txt @@ -53,6 +53,7 @@ add_library(Runtime STATIC thread_resource_mgr.cpp # timestamp_value.cpp decimal_value.cpp + decimalv2_value.cpp large_int_value.cpp tuple.cpp tuple_row.cpp @@ -116,5 +117,6 @@ add_library(Runtime STATIC #ADD_BE_TEST(parallel_executor_test) #ADD_BE_TEST(datetime_value_test) #ADD_BE_TEST(decimal_value_test) +#ADD_BE_TEST(decimalv2_value_test) #ADD_BE_TEST(string_value_test) #ADD_BE_TEST(thread_resource_mgr_test) diff --git a/be/src/runtime/decimalv2_value.cpp b/be/src/runtime/decimalv2_value.cpp new file mode 100644 index 00000000000000..f76387b2671ae8 --- /dev/null +++ b/be/src/runtime/decimalv2_value.cpp @@ -0,0 +1,439 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "runtime/decimalv2_value.h" +#include "util/string_parser.hpp" + +#include +#include +#include + +namespace doris { + +const char* DecimalV2Value::_s_llvm_class_name = "class.doris::DecimalV2Value"; + +static inline int128_t abs(const int128_t& x) { return (x < 0) ? -x : x; } + +// x>=0 && y>=0 +static int do_add(int128_t x, int128_t y, int128_t* result) { + int error = E_DEC_OK; + if (DecimalV2Value::MAX_DECIMAL_VALUE - x >= y) { + *result = x + y; + } else { + *result = DecimalV2Value::MAX_DECIMAL_VALUE; + error = E_DEC_OVERFLOW; + LOG(INFO) << "overflow (x=" << x << ", y=" << y << ")"; + } + return error; +} + +// x>=0 && y>=0 +static int do_sub(int128_t x, int128_t y, int128_t* result) { + int error = E_DEC_OK; + *result = x - y; + return error; +} + +// clear leading zero for __int128 +static int clz128(unsigned __int128 v) { + if (v == 0) return sizeof(__int128); + unsigned __int128 shifted = v >> 64; + if (shifted != 0) { + return __builtin_clzll(shifted); + } else { + return __builtin_clzll(v) + 64; + } +} + +// x>0 && y>0 +static int do_mul(int128_t x, int128_t y, int128_t* result) { + int error = E_DEC_OK; + int128_t max128 = ~(static_cast(1ll) << 127); + + int leading_zero_bits = clz128(x) + clz128(y); + if (leading_zero_bits < sizeof(int128_t) || max128 / x < y) { + *result = DecimalV2Value::MAX_DECIMAL_VALUE; + LOG(INFO) << "overflow (x=" << x << ", y=" << y << ")"; + error = E_DEC_OVERFLOW; + return error; + } + + int128_t product = x * y; + *result = product / DecimalV2Value::ONE_BILLION; + + // overflow + if (*result > DecimalV2Value::MAX_DECIMAL_VALUE) { + *result = DecimalV2Value::MAX_DECIMAL_VALUE; + LOG(INFO) << "overflow (x=" << x << ", y=" << y << ")"; + error = E_DEC_OVERFLOW; + return error; + } + + // truncate with round + int128_t remainder = product % DecimalV2Value::ONE_BILLION; + if (remainder != 0) { + error = E_DEC_TRUNCATED; + if (remainder >= (DecimalV2Value::ONE_BILLION >> 1)) { + *result += 1; + } + LOG(INFO) << "truncate (x=" << x << ", y=" << y << ")" << ", result=" << *result; + } + + return error; +} + +// x>0 && y>0 +static int do_div(int128_t x, int128_t y, int128_t* result) { + int error = E_DEC_OK; + int128_t dividend = x * DecimalV2Value::ONE_BILLION; + *result = dividend / y; + + // overflow + int128_t remainder = dividend % y; + if (remainder != 0) { + error = E_DEC_TRUNCATED; + if (remainder >= (y >> 1)) { + *result += 1; + } + LOG(INFO) << "truncate (x=" << x << ", y=" << y << ")" << ", result=" << *result; + } + + return error; +} + +// x>0 && y>0 +static int do_mod(int128_t x, int128_t y, int128_t* result) { + int error = E_DEC_OK; + *result = x % y; + return error; +} + +DecimalV2Value operator+(const DecimalV2Value& v1, const DecimalV2Value& v2) { + int128_t result; + int128_t x = v1.value(); + int128_t y = v2.value(); + if (x == 0) { + result = y; + } else if (y == 0) { + result = x; + } else if (x > 0) { + if (y > 0) { + do_add(x, y, &result); + } else { + do_sub(x, -y, &result); + } + } else { // x < 0 + if (y > 0) { + do_sub(y, -x, &result); + } else { + do_add(-x, -y, &result); + result = -result; + } + } + + return DecimalV2Value(result); +} + +DecimalV2Value operator-(const DecimalV2Value& v1, const DecimalV2Value& v2) { + int128_t result; + int128_t x = v1.value(); + int128_t y = v2.value(); + if (x == 0) { + result = -y; + } else if (y == 0) { + result = x; + } else if (x > 0) { + if (y > 0) { + do_sub(x, y, &result); + } else { + do_add(x, -y, &result); + } + } else { // x < 0 + if (y > 0) { + do_add(-x, y, &result); + result = -result; + } else { + do_sub(-x, -y, &result); + result = -result; + } + } + + return DecimalV2Value(result); +} + +DecimalV2Value operator*(const DecimalV2Value& v1, const DecimalV2Value& v2){ + int128_t result; + int128_t x = v1.value(); + int128_t y = v2.value(); + + if (x == 0 || y == 0) return DecimalV2Value(0); + + bool is_positive = (x > 0 && y > 0) || (x < 0 && y < 0); + + do_mul(abs(x), abs(y), &result); + + if (!is_positive) result = -result; + + return DecimalV2Value(result); +} + +DecimalV2Value operator/(const DecimalV2Value& v1, const DecimalV2Value& v2){ + int128_t result; + int128_t x = v1.value(); + int128_t y = v2.value(); + + //todo: return 0 for divide zero + if (x == 0 || y == 0) return DecimalV2Value(0); + bool is_positive = (x > 0 && y > 0) || (x < 0 && y < 0); + do_div(abs(x), abs(y), &result); + + if (!is_positive) result = -result; + + return DecimalV2Value(result); +} + +DecimalV2Value operator%(const DecimalV2Value& v1, const DecimalV2Value& v2){ + int128_t result; + int128_t x = v1.value(); + int128_t y = v2.value(); + + //todo: return 0 for divide zero + if (x == 0 || y == 0) return DecimalV2Value(0); + + do_mod(x, y, &result); + + return DecimalV2Value(result); +} + +std::ostream& operator<<(std::ostream& os, DecimalV2Value const& decimal_value) { + return os << decimal_value.to_string(); +} + +std::istream& operator>>(std::istream& ism, DecimalV2Value& decimal_value) { + std::string str_buff; + ism >> str_buff; + decimal_value.parse_from_str(str_buff.c_str(), str_buff.size()); + return ism; +} + +DecimalV2Value operator-(const DecimalV2Value& v) { + return DecimalV2Value(-v.value()); +} + +DecimalV2Value& DecimalV2Value::operator+=(const DecimalV2Value& other) { + *this = *this + other; + return *this; +} + +int DecimalV2Value::parse_from_str(const char* decimal_str, int32_t length) { + int32_t error = E_DEC_OK; + StringParser::ParseResult result = StringParser::PARSE_SUCCESS; + + _value = StringParser::string_to_decimal(decimal_str, length, + PRECISION, SCALE, &result); + + if (result == StringParser::PARSE_FAILURE) { + error = E_DEC_BAD_NUM; + } + return error; +} + +std::string DecimalV2Value::to_string(int round_scale) const { + if (_value == 0) return std::string(1, '0'); + + int last_char_idx = PRECISION + 2 + (_value < 0); + std::string str = std::string(last_char_idx, '0'); + + int128_t remaining_value = _value; + int first_digit_idx = 0; + if (_value < 0) { + remaining_value = -_value; + first_digit_idx = 1; + } + + int remaining_scale = SCALE; + do { + str[--last_char_idx] = (remaining_value % 10) + '0'; + remaining_value /= 10; + } while (--remaining_scale > 0); + str[--last_char_idx] = '.'; + + do { + str[--last_char_idx] = (remaining_value % 10) + '0'; + remaining_value /= 10; + if (remaining_value == 0) { + if (last_char_idx > first_digit_idx) str.erase(0, last_char_idx - first_digit_idx); + break; + } + } while (last_char_idx > first_digit_idx); + + if (_value < 0) str[0] = '-'; + + // right trim and round + int scale = 0; + int len = str.size(); + for(scale = 0; scale < SCALE && scale < len; scale++) { + if (str[len - scale - 1] != '0') break; + } + if (scale == SCALE) scale++; //integer, trim . + if (round_scale >= 0 && round_scale <= SCALE) { + scale = std::max(scale, SCALE - round_scale); + } + if (scale > 1 && scale <= len) str.erase(len - scale, len - 1); + + return str; +} + +std::string DecimalV2Value::to_string() const { + return to_string(-1); +} + +// NOTE: only change abstract value, do not change sign +void DecimalV2Value::to_max_decimal(int32_t precision, int32_t scale) { + bool is_negtive = (_value < 0); + static const int64_t INT_MAX_VALUE[PRECISION] = { + 9ll, + 99ll, + 999ll, + 9999ll, + 99999ll, + 999999ll, + 9999999ll, + 99999999ll, + 999999999ll, + 9999999999ll, + 99999999999ll, + 999999999999ll, + 9999999999999ll, + 99999999999999ll, + 999999999999999ll, + 9999999999999999ll, + 99999999999999999ll, + 999999999999999999ll + }; + static const int32_t FRAC_MAX_VALUE[SCALE] = { + 900000000, + 990000000, + 999000000, + 999900000, + 999990000, + 999999000, + 999999900, + 999999990, + 999999999 + }; + + // precison > 0 && scale >= 0 && scale <= SCALE + if (precision <= 0 || scale < 0) return; + if (scale > SCALE) scale = SCALE; + + // precision: (scale, PRECISION] + if (precision > PRECISION) precision = PRECISION; + if (precision - scale > PRECISION - SCALE) { + precision = PRECISION - SCALE + scale; + } else if (precision <= scale) { + LOG(WARNING) << "Warning: error precision: " << precision << " or scale: " << scale; + precision = scale + 1; // corect error precision + } + + int64_t int_value = INT_MAX_VALUE[precision - scale - 1]; + int64_t frac_value = scale == 0? 0 : FRAC_MAX_VALUE[scale - 1]; + _value = static_cast(int_value) * DecimalV2Value::ONE_BILLION + frac_value; + if (is_negtive) _value = -_value; +} + +std::size_t hash_value(DecimalV2Value const& value) { + return value.hash(0); +} + +int DecimalV2Value::round(DecimalV2Value *to, int rounding_scale, DecimalRoundMode op) { + int32_t error = E_DEC_OK; + int128_t result; + + if (rounding_scale >= SCALE) return error; + if (rounding_scale < -(PRECISION - SCALE)) return 0; + + int128_t base = get_scale_base(SCALE - rounding_scale); + result = _value / base; + + int one = _value > 0 ? 1 : -1; + int128_t remainder = _value % base; + switch (op) { + case HALF_UP: + case HALF_EVEN: + if (abs(remainder) >= (base >> 1)) { + result = (result + one) * base; + } else { + result = result * base; + } + break; + case CEILING: + if (remainder > 0 && _value > 0) { + result = (result + one) * base; + } else { + result = result * base; + } + break; + case FLOOR: + if (remainder < 0 && _value < 0) { + result = (result + one) * base; + } else { + result = result * base; + } + break; + case TRUNCATE: + result = result * base; + break; + default: + break; + } + + to->set_value(result); + return error; +} + +bool DecimalV2Value::greater_than_scale(int scale) { + if (scale >= SCALE || scale < 0) { + return false; + } else if (scale == SCALE) { + return true; + } + + int frac_val = frac_value(); + if (scale == 0) { + bool ret = frac_val == 0 ? false : true; + return ret; + } + + static const int values[SCALE] = { + 1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 10000000, + 100000000 + }; + + int base = values[SCALE - scale]; + if (frac_val % base != 0) return true; + return false; +} + +} // end namespace doris diff --git a/be/src/runtime/decimalv2_value.h b/be/src/runtime/decimalv2_value.h new file mode 100644 index 00000000000000..6a460ca0454d06 --- /dev/null +++ b/be/src/runtime/decimalv2_value.h @@ -0,0 +1,354 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef DORIS_BE_SRC_RUNTIME_DECIMALV2_VALUE_H +#define DORIS_BE_SRC_RUNTIME_DECIMALV2_VALUE_H + +#include +#include +#include +#include + +#include +#include +#include + +#include "common/logging.h" +#include "udf/udf.h" +#include "util/hash_util.hpp" +#include "runtime/decimal_value.h" + +namespace doris { + +typedef __int128_t int128_t; + +class DecimalV2Value { +public: + friend DecimalV2Value operator+(const DecimalV2Value& v1, const DecimalV2Value& v2); + friend DecimalV2Value operator-(const DecimalV2Value& v1, const DecimalV2Value& v2); + friend DecimalV2Value operator*(const DecimalV2Value& v1, const DecimalV2Value& v2); + friend DecimalV2Value operator/(const DecimalV2Value& v1, const DecimalV2Value& v2); + friend std::istream& operator>>(std::istream& ism, DecimalV2Value& decimal_value); + friend DecimalV2Value operator-(const DecimalV2Value& v); + + static const int32_t PRECISION = 27; + static const int32_t SCALE = 9; + static const uint32_t ONE_BILLION = 1000000000; + static const int64_t MAX_INT_VALUE = 999999999999999999; + static const int32_t MAX_FRAC_VALUE = 999999999; + static const int64_t MAX_INT64 = 9223372036854775807ll; + + static const int128_t MAX_DECIMAL_VALUE = + static_cast(MAX_INT64) * ONE_BILLION + MAX_FRAC_VALUE; + + DecimalV2Value() : _value(0){} + inline const int128_t& value() const { return _value;} + inline int128_t& value() { return _value; } + + DecimalV2Value(const std::string& decimal_str) { + parse_from_str(decimal_str.c_str(), decimal_str.size()); + } + + // Construct from olap engine + DecimalV2Value(int64_t int_value, int64_t frac_value) { + from_olap_decimal(int_value, frac_value); + } + + inline bool from_olap_decimal(int64_t int_value, int64_t frac_value) { + bool success = true; + bool is_negtive = (int_value < 0 || frac_value < 0); + if (is_negtive) { + int_value = std::abs(int_value); + frac_value = std::abs(frac_value); + } + + //if (int_value > MAX_INT_VALUE) { + // int_value = MAX_INT_VALUE; + // success = false; + //} + + if (frac_value > MAX_FRAC_VALUE) { + frac_value = MAX_FRAC_VALUE; + success = false; + } + + _value = static_cast(int_value) * ONE_BILLION + frac_value; + if (is_negtive) _value = -_value; + + return success; + } + + DecimalV2Value(int128_t int_value) { + _value = int_value; + } + + void set_value(int128_t value) { + _value = value; + } + + DecimalV2Value& assign_from_float(const float float_value) { + _value = static_cast(float_value * ONE_BILLION); + return *this; + } + + DecimalV2Value& assign_from_double(const double double_value) { + _value = static_cast(double_value * ONE_BILLION); + return *this; + } + + // These cast functions are needed in "functions.cc", which is generated by python script. + // e.g. "ComputeFunctions::Cast_DecimalV2Value_double()" + // Discard the scale part + // ATTN: invoker must make sure no OVERFLOW + operator int64_t() const { + return static_cast(_value / ONE_BILLION); + } + + // These cast functions are needed in "functions.cc", which is generated by python script. + // e.g. "ComputeFunctions::Cast_DecimalV2Value_double()" + // Discard the scale part + // ATTN: invoker must make sure no OVERFLOW + operator int128_t() const { + return static_cast(_value / ONE_BILLION); + } + + operator bool() const { + return _value != 0; + } + + operator int8_t() const { + return static_cast(operator int64_t()); + } + + operator int16_t() const { + return static_cast(operator int64_t()); + } + + operator int32_t() const { + return static_cast(operator int64_t()); + } + + operator size_t() const { + return static_cast(operator int64_t()); + } + + operator float() const { + return (float)operator double(); + } + + operator double() const { + std::string str_buff = to_string(); + double result = std::strtod(str_buff.c_str(), nullptr); + return result; + } + + DecimalV2Value& operator+=(const DecimalV2Value& other); + + // To be Compatible with OLAP + // ATTN: NO-OVERFLOW should be guaranteed. + int64_t int_value() const { + return operator int64_t(); + } + + // To be Compatible with OLAP + // NOTE: return a negative value if decimal is negative. + // ATTN: the max length of fraction part in OLAP is 9, so the 'big digits' except the first one + // will be truncated. + int32_t frac_value() const { + return static_cast(_value % ONE_BILLION); + } + + bool operator==(const DecimalV2Value& other) const { + return _value == other.value(); + } + + bool operator!=(const DecimalV2Value& other) const { + return _value != other.value(); + } + + bool operator<=(const DecimalV2Value& other) const { + return _value <= other.value(); + } + + bool operator>=(const DecimalV2Value& other) const { + return _value >= other.value(); + } + + bool operator<(const DecimalV2Value& other) const { + return _value < other.value(); + } + + bool operator>(const DecimalV2Value& other) const { + return _value > other.value(); + } + + // change to maximum value for given precision and scale + // precision/scale - see decimal_bin_size() below + // to - decimal where where the result will be stored + void to_max_decimal(int precision, int frac); + void to_min_decimal(int precision, int frac) { + to_max_decimal(precision, frac); + if (_value > 0) _value = -_value; + } + + // The maximum of fraction part is "scale". + // If the length of fraction part is less than "scale", '0' will be filled. + std::string to_string(int scale) const; + // Output actual "scale", remove ending zeroes. + std::string to_string() const; + + // Convert string to decimal + // @param from - value to convert. Doesn't have to be \0 terminated! + // will stop at the fist non-digit char(nor '.' 'e' 'E'), + // or reaches the length + // @param length - maximum lengnth + // @return error number. + // + // E_DEC_OK/E_DEC_TRUNCATED/E_DEC_OVERFLOW/E_DEC_BAD_NUM/E_DEC_OOM + // In case of E_DEC_FATAL_ERROR *to is set to decimal zero + // (to make error handling easier) + // + // e.g. "1.2" ".2" "1.2e-3" "1.2e3" + int parse_from_str(const char* decimal_str, int32_t length); + + std::string get_debug_info() const { + return to_string(); + } + + static DecimalV2Value get_min_decimal() { + return DecimalV2Value(-MAX_INT_VALUE, MAX_FRAC_VALUE); + } + + static DecimalV2Value get_max_decimal() { + return DecimalV2Value(MAX_INT_VALUE, MAX_FRAC_VALUE); + } + + static DecimalV2Value from_decimal_val(const DecimalV2Val& val) { + return DecimalV2Value(val.value()); + } + + void to_decimal_val(DecimalV2Val* value) const { + value->val = _value; + } + + // set DecimalV2Value to zero + void set_to_zero() { + _value = 0; + } + + void to_abs_value() { + if (_value < 0) _value = -_value; + } + + uint32_t hash(uint32_t seed) const { + return HashUtil::hash(&_value, sizeof(_value), seed); + } + + int32_t precision() const { + return PRECISION; + } + + int32_t scale() const { + return SCALE; + } + + bool greater_than_scale(int scale); + + int round(DecimalV2Value *to, int scale, DecimalRoundMode mode); + + inline static int128_t get_scale_base(int scale) { + static const int128_t values[] = { + static_cast(1ll), + static_cast(10ll), + static_cast(100ll), + static_cast(1000ll), + static_cast(10000ll), + static_cast(100000ll), + static_cast(1000000ll), + static_cast(10000000ll), + static_cast(100000000ll), + static_cast(1000000000ll), + static_cast(10000000000ll), + static_cast(100000000000ll), + static_cast(1000000000000ll), + static_cast(10000000000000ll), + static_cast(100000000000000ll), + static_cast(1000000000000000ll), + static_cast(10000000000000000ll), + static_cast(100000000000000000ll), + static_cast(1000000000000000000ll), + static_cast(1000000000000000000ll) * 10ll, + static_cast(1000000000000000000ll) * 100ll, + static_cast(1000000000000000000ll) * 1000ll, + static_cast(1000000000000000000ll) * 10000ll, + static_cast(1000000000000000000ll) * 100000ll, + static_cast(1000000000000000000ll) * 1000000ll, + static_cast(1000000000000000000ll) * 10000000ll, + static_cast(1000000000000000000ll) * 100000000ll, + static_cast(1000000000000000000ll) * 1000000000ll, + static_cast(1000000000000000000ll) * 10000000000ll, + static_cast(1000000000000000000ll) * 100000000000ll, + static_cast(1000000000000000000ll) * 1000000000000ll, + static_cast(1000000000000000000ll) * 10000000000000ll, + static_cast(1000000000000000000ll) * 100000000000000ll, + static_cast(1000000000000000000ll) * 1000000000000000ll, + static_cast(1000000000000000000ll) * 10000000000000000ll, + static_cast(1000000000000000000ll) * 100000000000000000ll, + static_cast(1000000000000000000ll) * 100000000000000000ll * 10ll, + static_cast(1000000000000000000ll) * 100000000000000000ll * 100ll, + static_cast(1000000000000000000ll) * 100000000000000000ll * 1000ll}; + if (scale >= 0 && scale < 38) return values[scale]; + return -1; // Overflow + } + + bool is_zero() const { + return _value == 0; + } + + // For C++/IR interop, we need to be able to look up types by name. + static const char* _s_llvm_class_name; + +private: + + int128_t _value; +}; + +DecimalV2Value operator+(const DecimalV2Value& v1, const DecimalV2Value& v2); +DecimalV2Value operator-(const DecimalV2Value& v1, const DecimalV2Value& v2); +DecimalV2Value operator*(const DecimalV2Value& v1, const DecimalV2Value& v2); +DecimalV2Value operator/(const DecimalV2Value& v1, const DecimalV2Value& v2); +DecimalV2Value operator%(const DecimalV2Value& v1, const DecimalV2Value& v2); + +DecimalV2Value operator-(const DecimalV2Value& v); + +std::ostream& operator<<(std::ostream& os, DecimalV2Value const& decimal_value); +std::istream& operator>>(std::istream& ism, DecimalV2Value& decimal_value); + +std::size_t hash_value(DecimalV2Value const& value); + +} // end namespace doris + +namespace std { + template<> + struct hash { + size_t operator()(const doris::DecimalV2Value& v) const { + return doris::hash_value(v); + } + }; +} + +#endif // DORIS_BE_SRC_RUNTIME_DECIMALV2_VALUE_H diff --git a/be/src/runtime/dpp_sink.cpp b/be/src/runtime/dpp_sink.cpp index 0604d128888d24..ad04adbb8a9249 100644 --- a/be/src/runtime/dpp_sink.cpp +++ b/be/src/runtime/dpp_sink.cpp @@ -468,6 +468,23 @@ Status Translator::create_value_updaters() { } break; } + case TYPE_DECIMALV2: { + switch (_rollup_schema.value_ops()[i]) { + case TAggregationType::MAX: + _value_updaters.push_back(update_max<__int128>); + break; + case TAggregationType::MIN: + _value_updaters.push_back(update_min<__int128>); + break; + case TAggregationType::SUM: + _value_updaters.push_back(update_sum<__int128>); + break; + default: + _value_updaters.push_back(fake_update); + } + break; + } + case TYPE_DATE: case TYPE_DATETIME: { switch (_rollup_schema.value_ops()[i]) { diff --git a/be/src/runtime/dpp_writer.cpp b/be/src/runtime/dpp_writer.cpp index 4501b7d6f05390..6bf9b027ec076a 100644 --- a/be/src/runtime/dpp_writer.cpp +++ b/be/src/runtime/dpp_writer.cpp @@ -24,6 +24,7 @@ #include "olap/utils.h" #include "exprs/expr.h" #include "util/debug_util.h" +#include "util/types.h" #include "runtime/primitive_type.h" #include "runtime/row_batch.h" #include "runtime/tuple_row.h" @@ -215,6 +216,14 @@ Status DppWriter::append_one_row(TupleRow* row) { append_to_buf(&frac_val, sizeof(frac_val)); break; } + case TYPE_DECIMALV2: { + const DecimalV2Value decimal_val(reinterpret_cast(item)->value); + int64_t int_val = decimal_val.int_value(); + int32_t frac_val = decimal_val.frac_value(); + append_to_buf(&int_val, sizeof(int_val)); + append_to_buf(&frac_val, sizeof(frac_val)); + break; + } default: { std::stringstream ss; ss << "Unknown column type " << _output_expr_ctxs[i]->root()->type(); diff --git a/be/src/runtime/export_sink.cpp b/be/src/runtime/export_sink.cpp index c93fb02637a72f..276e0502bb1f80 100644 --- a/be/src/runtime/export_sink.cpp +++ b/be/src/runtime/export_sink.cpp @@ -181,6 +181,19 @@ Status ExportSink::gen_row_buffer(TupleRow* row, std::stringstream* ss) { (*ss) << decimal_str; break; } + case TYPE_DECIMALV2: { + const DecimalV2Value decimal_val(reinterpret_cast(item)->value); + std::string decimal_str; + int output_scale = _output_expr_ctxs[i]->root()->output_scale(); + + if (output_scale > 0 && output_scale <= 30) { + decimal_str = decimal_val.to_string(output_scale); + } else { + decimal_str = decimal_val.to_string(); + } + (*ss) << decimal_str; + break; + } default: { std::stringstream err_ss; err_ss << "can't export this type. type = " << _output_expr_ctxs[i]->root()->type(); diff --git a/be/src/runtime/mysql_table_writer.cpp b/be/src/runtime/mysql_table_writer.cpp index 23e41a0ee2b0ad..467ca0006c4780 100644 --- a/be/src/runtime/mysql_table_writer.cpp +++ b/be/src/runtime/mysql_table_writer.cpp @@ -21,6 +21,7 @@ #include "runtime/row_batch.h" #include "runtime/tuple_row.h" #include "exprs/expr.h" +#include "util/types.h" namespace doris { @@ -149,6 +150,20 @@ Status MysqlTableWriter::insert_row(TupleRow* row) { ss << decimal_str; break; } + case TYPE_DECIMALV2: { + const DecimalV2Value decimal_val(reinterpret_cast(item)->value); + std::string decimal_str; + int output_scale = _output_expr_ctxs[i]->root()->output_scale(); + + if (output_scale > 0 && output_scale <= 30) { + decimal_str = decimal_val.to_string(output_scale); + } else { + decimal_str = decimal_val.to_string(); + } + ss << decimal_str; + break; + } + default: { std::stringstream err_ss; err_ss << "can't convert this type to mysql type. type = " << diff --git a/be/src/runtime/primitive_type.cpp b/be/src/runtime/primitive_type.cpp index 2d670c2ee23c3c..5ff86c23f3fbb1 100644 --- a/be/src/runtime/primitive_type.cpp +++ b/be/src/runtime/primitive_type.cpp @@ -77,6 +77,9 @@ PrimitiveType thrift_to_type(TPrimitiveType::type ttype) { case TPrimitiveType::DECIMAL: return TYPE_DECIMAL; + case TPrimitiveType::DECIMALV2: + return TYPE_DECIMALV2; + case TPrimitiveType::CHAR: return TYPE_CHAR; @@ -135,6 +138,9 @@ TPrimitiveType::type to_thrift(PrimitiveType ptype) { case TYPE_DECIMAL: return TPrimitiveType::DECIMAL; + case TYPE_DECIMALV2: + return TPrimitiveType::DECIMALV2; + case TYPE_CHAR: return TPrimitiveType::CHAR; @@ -193,6 +199,9 @@ std::string type_to_string(PrimitiveType t) { case TYPE_DECIMAL: return "DECIMAL"; + case TYPE_DECIMALV2: + return "DECIMALV2"; + case TYPE_CHAR: return "CHAR"; case TYPE_HLL: @@ -253,6 +262,9 @@ std::string type_to_odbc_string(PrimitiveType t) { case TYPE_DECIMAL: return "decimal"; + case TYPE_DECIMALV2: + return "decimalv2"; + case TYPE_CHAR: return "char"; diff --git a/be/src/runtime/primitive_type.h b/be/src/runtime/primitive_type.h index 89edc67e58e917..3477671ae93472 100644 --- a/be/src/runtime/primitive_type.h +++ b/be/src/runtime/primitive_type.h @@ -24,6 +24,7 @@ #include "gen_cpp/Types_types.h" #include "gen_cpp/Opcodes_types.h" #include "runtime/decimal_value.h" +#include "runtime/decimalv2_value.h" #include "runtime/datetime_value.h" #include "runtime/large_int_value.h" #include "runtime/string_value.h" @@ -51,7 +52,8 @@ enum PrimitiveType { TYPE_STRUCT, /* 16 */ TYPE_ARRAY, /* 17 */ TYPE_MAP, /* 18 */ - TYPE_HLL /* 19 */ + TYPE_HLL, /* 19 */ + TYPE_DECIMALV2 /* 20 */ }; inline bool is_enumeration_type(PrimitiveType type) { @@ -63,6 +65,7 @@ inline bool is_enumeration_type(PrimitiveType type) { case TYPE_VARCHAR: case TYPE_DATETIME: case TYPE_DECIMAL: + case TYPE_DECIMALV2: case TYPE_BOOLEAN: case TYPE_HLL: return false; @@ -117,6 +120,7 @@ inline int get_byte_size(PrimitiveType type) { case TYPE_LARGEINT: case TYPE_DATETIME: case TYPE_DATE: + case TYPE_DECIMALV2: return 16; case TYPE_DECIMAL: @@ -154,6 +158,7 @@ inline int get_real_byte_size(PrimitiveType type) { case TYPE_DATETIME: case TYPE_DATE: + case TYPE_DECIMALV2: return 16; case TYPE_DECIMAL: @@ -204,6 +209,9 @@ inline int get_slot_size(PrimitiveType type) { case TYPE_DECIMAL: return sizeof(DecimalValue); + case TYPE_DECIMALV2: + return 16; + case INVALID_TYPE: default: DCHECK(false); diff --git a/be/src/runtime/raw_value.cpp b/be/src/runtime/raw_value.cpp index 01f6267c95409c..63ca9ea17a4c1d 100644 --- a/be/src/runtime/raw_value.cpp +++ b/be/src/runtime/raw_value.cpp @@ -84,6 +84,10 @@ void RawValue::print_value_as_bytes(const void* value, const TypeDescriptor& typ stream->write(chars, sizeof(DecimalValue)); break; + case TYPE_DECIMALV2: + stream->write(chars, sizeof(DecimalV2Value)); + break; + case TYPE_LARGEINT: stream->write(chars, sizeof(__int128)); break; @@ -161,6 +165,10 @@ void RawValue::print_value(const void* value, const TypeDescriptor& type, int sc *stream << *reinterpret_cast(value); break; + case TYPE_DECIMALV2: + *stream << reinterpret_cast(value)->value; + break; + case TYPE_LARGEINT: *stream << reinterpret_cast(value)->value; break; @@ -270,6 +278,10 @@ void RawValue::write(const void* value, void* dst, const TypeDescriptor& type, M *reinterpret_cast(value); break; + case TYPE_DECIMALV2: + *reinterpret_cast(dst) = *reinterpret_cast(value); + break; + case TYPE_HLL: case TYPE_VARCHAR: case TYPE_CHAR: { @@ -339,6 +351,11 @@ void RawValue::write(const void* value, const TypeDescriptor& type, void* dst, u case TYPE_DECIMAL: *reinterpret_cast(dst) = *reinterpret_cast(value); break; + + case TYPE_DECIMALV2: + *reinterpret_cast(dst) = *reinterpret_cast(value); + break; + default: DCHECK(false) << "RawValue::write(): bad type: " << type.debug_string(); } diff --git a/be/src/runtime/raw_value.h b/be/src/runtime/raw_value.h index 31661b944de258..351aa10606e125 100644 --- a/be/src/runtime/raw_value.h +++ b/be/src/runtime/raw_value.h @@ -167,6 +167,10 @@ inline bool RawValue::lt(const void* v1, const void* v2, const TypeDescriptor& t return *reinterpret_cast(v1) < *reinterpret_cast(v2); + case TYPE_DECIMALV2: + return reinterpret_cast(v1)->value < + reinterpret_cast(v2)->value; + case TYPE_LARGEINT: return reinterpret_cast(v1)->value < reinterpret_cast(v2)->value; @@ -225,6 +229,10 @@ inline bool RawValue::eq(const void* v1, const void* v2, const TypeDescriptor& t return *reinterpret_cast(v1) == *reinterpret_cast(v2); + case TYPE_DECIMALV2: + return reinterpret_cast(v1)->value == + reinterpret_cast(v2)->value; + case TYPE_LARGEINT: return reinterpret_cast(v1)->value == reinterpret_cast(v2)->value; @@ -285,6 +293,9 @@ inline uint32_t RawValue::get_hash_value( case TYPE_DECIMAL: return HashUtil::hash(v, 40, seed); + case TYPE_DECIMALV2: + return HashUtil::hash(v, 16, seed); + case TYPE_LARGEINT: return HashUtil::hash(v, 16, seed); @@ -340,6 +351,9 @@ inline uint32_t RawValue::get_hash_value_fvn( case TYPE_DECIMAL: return ((DecimalValue *) v)->hash(seed); + case TYPE_DECIMALV2: + return HashUtil::fnv_hash(v, 16, seed); + case TYPE_LARGEINT: return HashUtil::fnv_hash(v, 16, seed); @@ -406,6 +420,14 @@ inline uint32_t RawValue::zlib_crc32(const void* v, const TypeDescriptor& type, seed = HashUtil::zlib_crc_hash(&int_val, sizeof(int_val), seed); return HashUtil::zlib_crc_hash(&frac_val, sizeof(frac_val), seed); } + + case TYPE_DECIMALV2: { + const DecimalV2Value* dec_val = (const DecimalV2Value*)v; + int64_t int_val = dec_val->int_value(); + int32_t frac_val = dec_val->frac_value(); + seed = HashUtil::zlib_crc_hash(&int_val, sizeof(int_val), seed); + return HashUtil::zlib_crc_hash(&frac_val, sizeof(frac_val), seed); + } default: DCHECK(false) << "invalid type: " << type; return 0; diff --git a/be/src/runtime/raw_value_ir.cpp b/be/src/runtime/raw_value_ir.cpp index 66b83f0d50f405..04675e3e32eeb6 100644 --- a/be/src/runtime/raw_value_ir.cpp +++ b/be/src/runtime/raw_value_ir.cpp @@ -99,6 +99,13 @@ int RawValue::compare(const void* v1, const void* v2, const TypeDescriptor& type return (*decimal_value1 > *decimal_value2) ? 1 : (*decimal_value1 < *decimal_value2 ? -1 : 0); + case TYPE_DECIMALV2: { + DecimalV2Value decimal_value1(reinterpret_cast(v1)->value); + DecimalV2Value decimal_value2(reinterpret_cast(v2)->value); + return (decimal_value1 > decimal_value2) + ? 1 : (decimal_value1 < decimal_value2 ? -1 : 0); + } + case TYPE_LARGEINT: { __int128 large_int_value1 = reinterpret_cast(v1)->value; __int128 large_int_value2 = reinterpret_cast(v2)->value; diff --git a/be/src/runtime/result_writer.cpp b/be/src/runtime/result_writer.cpp index c9ba7da0ef93b8..a543d2bef7877c 100644 --- a/be/src/runtime/result_writer.cpp +++ b/be/src/runtime/result_writer.cpp @@ -149,6 +149,21 @@ Status ResultWriter::add_one_row(TupleRow* row) { break; } + case TYPE_DECIMALV2: { + DecimalV2Value decimal_val(reinterpret_cast(item)->value); + std::string decimal_str; + int output_scale = _output_expr_ctxs[i]->root()->output_scale(); + + if (output_scale > 0 && output_scale <= 30) { + decimal_str = decimal_val.to_string(output_scale); + } else { + decimal_str = decimal_val.to_string(); + } + + buf_ret = _row_buffer->push_string(decimal_str.c_str(), decimal_str.length()); + break; + } + default: LOG(WARNING) << "can't convert this type to mysql type. type = " << _output_expr_ctxs[i]->root()->type(); diff --git a/be/src/runtime/tuple.h b/be/src/runtime/tuple.h index 7bdda8e8a28cb0..3d7389d0d5a936 100644 --- a/be/src/runtime/tuple.h +++ b/be/src/runtime/tuple.h @@ -169,6 +169,11 @@ class Tuple { return reinterpret_cast(reinterpret_cast(this) + offset); } + DecimalV2Value* get_decimalv2_slot(int offset) { + DCHECK(offset != -1); // -1 offset indicates non-materialized slot + return reinterpret_cast(reinterpret_cast(this) + offset); + } + // For C++/IR interop, we need to be able to look up types by name. static const char* _s_llvm_class_name; diff --git a/be/src/runtime/types.cpp b/be/src/runtime/types.cpp index 24446096d1681d..e21ba42c63fce0 100644 --- a/be/src/runtime/types.cpp +++ b/be/src/runtime/types.cpp @@ -40,7 +40,7 @@ TypeDescriptor::TypeDescriptor(const std::vector& types, int* idx) : if (type == TYPE_CHAR || type == TYPE_VARCHAR || type == TYPE_HLL) { DCHECK(scalar_type.__isset.len); len = scalar_type.len; - } else if (type == TYPE_DECIMAL) { + } else if (type == TYPE_DECIMAL || type == TYPE_DECIMALV2) { DCHECK(scalar_type.__isset.precision); DCHECK(scalar_type.__isset.scale); precision = scalar_type.precision; @@ -107,7 +107,7 @@ void TypeDescriptor::to_thrift(TTypeDesc* thrift_type) const { if (type == TYPE_CHAR || type == TYPE_VARCHAR || type == TYPE_HLL) { // DCHECK_NE(len, -1); scalar_type.__set_len(len); - } else if (type == TYPE_DECIMAL) { + } else if (type == TYPE_DECIMAL || type == TYPE_DECIMALV2) { DCHECK_NE(precision, -1); DCHECK_NE(scale, -1); scalar_type.__set_precision(precision); @@ -124,7 +124,7 @@ void TypeDescriptor::to_protobuf(PTypeDesc* ptype) const { scalar_type->set_type(doris::to_thrift(type)); if (type == TYPE_CHAR || type == TYPE_VARCHAR || type == TYPE_HLL) { scalar_type->set_len(len); - } else if (type == TYPE_DECIMAL) { + } else if (type == TYPE_DECIMAL || type == TYPE_DECIMALV2) { DCHECK_NE(precision, -1); DCHECK_NE(scale, -1); scalar_type->set_precision(precision); @@ -148,7 +148,7 @@ TypeDescriptor::TypeDescriptor( if (type == TYPE_CHAR || type == TYPE_VARCHAR || type == TYPE_HLL) { DCHECK(scalar_type.has_len()); len = scalar_type.len(); - } else if (type == TYPE_DECIMAL) { + } else if (type == TYPE_DECIMAL || type == TYPE_DECIMALV2) { DCHECK(scalar_type.has_precision()); DCHECK(scalar_type.has_scale()); precision = scalar_type.precision(); @@ -170,6 +170,9 @@ std::string TypeDescriptor::debug_string() const { case TYPE_DECIMAL: ss << "DECIMAL(" << precision << ", " << scale << ")"; return ss.str(); + case TYPE_DECIMALV2: + ss << "DECIMALV2(" << precision << ", " << scale << ")"; + return ss.str(); default: return type_to_string(type); } diff --git a/be/src/runtime/types.h b/be/src/runtime/types.h index de12b5ccfd1682..19ed320dabd932 100644 --- a/be/src/runtime/types.h +++ b/be/src/runtime/types.h @@ -120,6 +120,18 @@ struct TypeDescriptor { return ret; } + static TypeDescriptor create_decimalv2_type(int precision, int scale) { + DCHECK_LE(precision, MAX_PRECISION); + DCHECK_LE(scale, MAX_SCALE); + DCHECK_GE(precision, 0); + DCHECK_LE(scale, precision); + TypeDescriptor ret; + ret.type = TYPE_DECIMALV2; + ret.precision = precision; + ret.scale = scale; + return ret; + } + static TypeDescriptor from_thrift(const TTypeDesc& t) { int idx = 0; TypeDescriptor result(t.types, &idx); @@ -144,7 +156,7 @@ struct TypeDescriptor { if (type == TYPE_CHAR) { return len == o.len; } - if (type == TYPE_DECIMAL) { + if (type == TYPE_DECIMAL || type == TYPE_DECIMALV2) { return precision == o.precision && scale == o.scale; } return true; @@ -171,7 +183,7 @@ struct TypeDescriptor { } inline bool is_decimal_type() const { - return type == TYPE_DECIMAL; + return (type == TYPE_DECIMAL || type == TYPE_DECIMALV2); } inline bool is_var_len_string_type() const { @@ -214,6 +226,7 @@ struct TypeDescriptor { case TYPE_LARGEINT: case TYPE_DATETIME: case TYPE_DATE: + case TYPE_DECIMALV2: return 16; case TYPE_DECIMAL: @@ -261,6 +274,9 @@ struct TypeDescriptor { case TYPE_DECIMAL: return sizeof(DecimalValue); + case TYPE_DECIMALV2: + return 16; + case INVALID_TYPE: default: DCHECK(false); diff --git a/be/src/udf/udf.cpp b/be/src/udf/udf.cpp index 6125f7b93041a0..715e59f3f481cb 100755 --- a/be/src/udf/udf.cpp +++ b/be/src/udf/udf.cpp @@ -22,6 +22,7 @@ #include #include "runtime/decimal_value.h" +#include "runtime/decimalv2_value.h" // Be careful what this includes since this needs to be linked into the UDF's // binary. For example, it would be unfortunate if they had a random dependency diff --git a/be/src/udf/udf.h b/be/src/udf/udf.h index 9c042b44ddc75f..2573b262b535ae 100755 --- a/be/src/udf/udf.h +++ b/be/src/udf/udf.h @@ -42,6 +42,7 @@ struct BigIntVal; struct StringVal; struct DateTimeVal; struct DecimalVal; +struct DecimalV2Val; // The FunctionContext is passed to every UDF/UDA and is the interface for the UDF to the // rest of the system. It contains APIs to examine the system state, report errors @@ -71,6 +72,7 @@ class FunctionContext { TYPE_HLL, TYPE_STRING, TYPE_FIXED_BUFFER, + TYPE_DECIMALV2 }; struct TypeDesc { @@ -687,6 +689,50 @@ struct DecimalVal : public AnyVal { }; +struct DecimalV2Val : public AnyVal { + + __int128 val; + + // Default value is zero + DecimalV2Val() : val(0) {} + + const __int128& value() const { return val; } + + DecimalV2Val(__int128 value) : val(value) {} + + static DecimalV2Val null() { + DecimalV2Val result; + result.is_null = true; + return result; + } + + void set_to_zero() { + val = 0; + } + + void set_to_abs_value() { + if (val < 0) val = -val; + } + + bool operator==(const DecimalV2Val& other) const { + if (is_null && other.is_null) { + return true; + } + + if (is_null || other.is_null) { + return false; + } + + return val == other.val; + } + + bool operator!=(const DecimalV2Val& other) const { + return !(*this == other); + } + +}; + + struct LargeIntVal : public AnyVal { __int128 val; @@ -729,6 +775,7 @@ using doris_udf::FloatVal; using doris_udf::DoubleVal; using doris_udf::StringVal; using doris_udf::DecimalVal; +using doris_udf::DecimalV2Val; using doris_udf::DateTimeVal; using doris_udf::FunctionContext; diff --git a/be/src/util/string_parser.hpp b/be/src/util/string_parser.hpp index d584230d4b29f1..f4ee1553b1fd60 100644 --- a/be/src/util/string_parser.hpp +++ b/be/src/util/string_parser.hpp @@ -69,6 +69,8 @@ class StringParser { template static T numeric_limits(bool negative); + static inline __int128 get_scale_multiplier(int scale); + // This is considerably faster than glibc's implementation (25x). // In the case of overflow, the max/min value for the data type will be returned. // Assumes s represents a decimal number. @@ -117,6 +119,9 @@ class StringParser { return string_to_bool_internal(s + i, len - i, result); } + static inline __int128 string_to_decimal(const char* s, int len, + int type_precision, int type_scale, ParseResult* result); + private: // This is considerably faster than glibc's implementation. // In the case of overflow, the max/min value for the data type will be returned. @@ -495,6 +500,198 @@ inline int StringParser::StringParseTraits<__int128>::max_ascii_len() { return 39; } +inline __int128 StringParser::get_scale_multiplier(int scale) { + DCHECK_GE(scale, 0); + static const __int128 values[] = { + static_cast<__int128>(1ll), + static_cast<__int128>(10ll), + static_cast<__int128>(100ll), + static_cast<__int128>(1000ll), + static_cast<__int128>(10000ll), + static_cast<__int128>(100000ll), + static_cast<__int128>(1000000ll), + static_cast<__int128>(10000000ll), + static_cast<__int128>(100000000ll), + static_cast<__int128>(1000000000ll), + static_cast<__int128>(10000000000ll), + static_cast<__int128>(100000000000ll), + static_cast<__int128>(1000000000000ll), + static_cast<__int128>(10000000000000ll), + static_cast<__int128>(100000000000000ll), + static_cast<__int128>(1000000000000000ll), + static_cast<__int128>(10000000000000000ll), + static_cast<__int128>(100000000000000000ll), + static_cast<__int128>(1000000000000000000ll), + static_cast<__int128>(1000000000000000000ll) * 10ll, + static_cast<__int128>(1000000000000000000ll) * 100ll, + static_cast<__int128>(1000000000000000000ll) * 1000ll, + static_cast<__int128>(1000000000000000000ll) * 10000ll, + static_cast<__int128>(1000000000000000000ll) * 100000ll, + static_cast<__int128>(1000000000000000000ll) * 1000000ll, + static_cast<__int128>(1000000000000000000ll) * 10000000ll, + static_cast<__int128>(1000000000000000000ll) * 100000000ll, + static_cast<__int128>(1000000000000000000ll) * 1000000000ll, + static_cast<__int128>(1000000000000000000ll) * 10000000000ll, + static_cast<__int128>(1000000000000000000ll) * 100000000000ll, + static_cast<__int128>(1000000000000000000ll) * 1000000000000ll, + static_cast<__int128>(1000000000000000000ll) * 10000000000000ll, + static_cast<__int128>(1000000000000000000ll) * 100000000000000ll, + static_cast<__int128>(1000000000000000000ll) * 1000000000000000ll, + static_cast<__int128>(1000000000000000000ll) * 10000000000000000ll, + static_cast<__int128>(1000000000000000000ll) * 100000000000000000ll, + static_cast<__int128>(1000000000000000000ll) * 100000000000000000ll * 10ll, + static_cast<__int128>(1000000000000000000ll) * 100000000000000000ll * 100ll, + static_cast<__int128>(1000000000000000000ll) * 100000000000000000ll * 1000ll}; + if (scale >= 0 && scale < 39) return values[scale]; + return -1; // Overflow +} + +inline __int128 StringParser::string_to_decimal(const char* s, int len, + int type_precision, int type_scale, ParseResult* result) { + // Special cases: + // 1) '' == Fail, an empty string fails to parse. + // 2) ' # ' == #, leading and trailing white space is ignored. + // 3) '.' == 0, a single dot parses as zero (for consistency with other types). + // 4) '#.' == '#', a trailing dot is ignored. + + // Ignore leading and trailing spaces. + while (len > 0 && is_whitespace(*s)) { + ++s; + --len; + } + while (len > 0 && is_whitespace(s[len - 1])) { + --len; + } + + bool is_negative = false; + if (len > 0) { + switch (*s) { + case '-': + is_negative = true; + case '+': + ++s; + --len; + } + } + + // Ignore leading zeros. + bool found_value = false; + while (len > 0 && UNLIKELY(*s == '0')) { + found_value = true; + ++s; + --len; + } + + // Ignore leading zeros even after a dot. This allows for differentiating between + // cases like 0.01e2, which would fit in a DECIMAL(1, 0), and 0.10e2, which would + // overflow. + int scale = 0; + int found_dot = 0; + if (len > 0 && *s == '.') { + found_dot = 1; + ++s; + --len; + while (len > 0 && UNLIKELY(*s == '0')) { + found_value = true; + ++scale; + ++s; + --len; + } + } + + int precision = 0; + bool found_exponent = false; + int8_t exponent = 0; + __int128 value = 0; + for (int i = 0; i < len; ++i) { + const char& c = s[i]; + if (LIKELY('0' <= c && c <= '9')) { + found_value = true; + // Ignore digits once the type's precision limit is reached. This avoids + // overflowing the underlying storage while handling a string like + // 10000000000e-10 into a DECIMAL(1, 0). Adjustments for ignored digits and + // an exponent will be made later. + if (LIKELY(type_precision > precision)) { + value = (value * 10) + (c - '0'); // Benchmarks are faster with parenthesis... + } + DCHECK(value >= 0); // For some reason //DCHECK_GE doesn't work with __int128. + ++precision; + scale += found_dot; + } else if (c == '.' && LIKELY(!found_dot)) { + found_dot = 1; + } else if ((c == 'e' || c == 'E') && LIKELY(!found_exponent)) { + found_exponent = true; + exponent = string_to_int_internal(s + i + 1, len - i - 1, result); + if (UNLIKELY(*result != StringParser::PARSE_SUCCESS)) { + if (*result == StringParser::PARSE_OVERFLOW && exponent < 0) { + *result = StringParser::PARSE_UNDERFLOW; + } + return 0; + } + break; + } else { + if (value == 0) { + *result = StringParser::PARSE_FAILURE; + return 0; + } + *result = StringParser::PARSE_SUCCESS; + value *= get_scale_multiplier(type_scale - scale); + return is_negative ? -value : value; + } + } + + // Find the number of truncated digits before adjusting the precision for an exponent. + int truncated_digit_count = precision - type_precision; + if (exponent > scale) { + // Ex: 0.1e3 (which at this point would have precision == 1 and scale == 1), the + // scale must be set to 0 and the value set to 100 which means a precision of 3. + precision += exponent - scale; + value *= get_scale_multiplier(exponent - scale); + scale = 0; + } else { + // Ex: 100e-4, the scale must be set to 4 but no adjustment to the value is needed, + // the precision must also be set to 4 but that will be done below for the + // non-exponent case anyways. + scale -= exponent; + } + // Ex: 0.001, at this point would have precision 1 and scale 3 since leading zeros + // were ignored during previous parsing. + if (scale > precision) precision = scale; + + // Microbenchmarks show that beyond this point, returning on parse failure is slower + // than just letting the function run out. + *result = StringParser::PARSE_SUCCESS; + if (UNLIKELY(precision - scale > type_precision - type_scale)) { + *result = StringParser::PARSE_OVERFLOW; + } else if (UNLIKELY(scale > type_scale)) { + *result = StringParser::PARSE_UNDERFLOW; + int shift = scale - type_scale; + if (UNLIKELY(truncated_digit_count > 0)) shift -= truncated_digit_count; + if (shift > 0) { + __int128 divisor = get_scale_multiplier(shift); + if (LIKELY(divisor >= 0)) { + value /= divisor; + __int128 remainder = value % divisor; + if (abs(remainder) >= (divisor >> 1)) { + value += 1; + } + } else { + DCHECK(divisor == -1); // //DCHECK_EQ doesn't work with __int128. + value = 0; + } + } + DCHECK(value >= 0); // //DCHECK_GE doesn't work with __int128. + } else if (UNLIKELY(!found_value && !found_dot)) { + *result = StringParser::PARSE_FAILURE; + } + + if (type_scale > scale) { + value *= get_scale_multiplier(type_scale - scale); + } + + return is_negative ? -value : value; +} + } // end namespace doris #endif // end of DORIS_BE_SRC_COMMON_UTIL_STRING_PARSER_HPP diff --git a/be/src/util/symbols_util.cpp b/be/src/util/symbols_util.cpp index 1e8d062a5c3100..1b5f0c8ae8fb30 100644 --- a/be/src/util/symbols_util.cpp +++ b/be/src/util/symbols_util.cpp @@ -160,6 +160,9 @@ static void append_any_val_type( case TYPE_DECIMAL: append_mangled_token("DecimalVal", s); break; + case TYPE_DECIMALV2: + append_mangled_token("DecimalV2Val", s); + break; default: DCHECK(false) << "NYI: " << type.debug_string(); } diff --git a/be/test/runtime/CMakeLists.txt b/be/test/runtime/CMakeLists.txt index a42f23cbd08a32..09758663078ec2 100644 --- a/be/test/runtime/CMakeLists.txt +++ b/be/test/runtime/CMakeLists.txt @@ -32,6 +32,7 @@ ADD_BE_TEST(string_buffer_test) #ADD_BE_TEST(parallel_executor_test) ADD_BE_TEST(datetime_value_test) ADD_BE_TEST(decimal_value_test) +ADD_BE_TEST(decimalv2_value_test) ADD_BE_TEST(large_int_value_test) ADD_BE_TEST(string_value_test) #ADD_BE_TEST(thread_resource_mgr_test) diff --git a/be/test/runtime/decimalv2_value_test.cpp b/be/test/runtime/decimalv2_value_test.cpp new file mode 100644 index 00000000000000..ac398562e454d1 --- /dev/null +++ b/be/test/runtime/decimalv2_value_test.cpp @@ -0,0 +1,551 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "runtime/decimalv2_value.h" + +#include +#include + +#include + +#include "util/logging.h" + +namespace doris { + +class DecimalV2ValueTest : public testing::Test { +public: + DecimalV2ValueTest() { + } + +protected: + virtual void SetUp() { + } + virtual void TearDown() { + } +}; + +TEST_F(DecimalV2ValueTest, string_to_decimal) { + DecimalV2Value value(std::string("1.23")); + ASSERT_EQ("1.23", value.to_string(3)); + + DecimalV2Value value1(std::string("0.23")); + ASSERT_EQ("0.23", value1.to_string(3)); + + DecimalV2Value value2(std::string("1234567890123456789.0")); + ASSERT_EQ("1234567890123456789", value2.to_string(3)); +} + +TEST_F(DecimalV2ValueTest, negative_zero) { + DecimalV2Value value(std::string("-0.00")); + std::cout << "value: " << value.get_debug_info() << std::endl; + { + // positive zero VS negative zero + DecimalV2Value value2(std::string("0.00")); + std::cout << "value2: " << value2.get_debug_info() << std::endl; + ASSERT_TRUE(value == value2); + ASSERT_FALSE(value < value2); + ASSERT_FALSE(value < value2); + ASSERT_TRUE(value <= value2); + ASSERT_TRUE(value >= value2); + } + { + // from string, positive + DecimalV2Value value3(std::string("5.0")); + std::cout << "value3: " << value3.get_debug_info() << std::endl; + ASSERT_TRUE(value < value3); + ASSERT_TRUE(value <= value3); + ASSERT_TRUE(value3 > value); + ASSERT_TRUE(value3 >= value); + } + { + // from string, negative + DecimalV2Value value3(std::string("-5.0")); + std::cout << "value3: " << value3.get_debug_info() << std::endl; + ASSERT_TRUE(value > value3); + ASSERT_TRUE(value >= value3); + ASSERT_TRUE(value3 < value); + ASSERT_TRUE(value3 <= value); + } + { + // from int + DecimalV2Value value3(6); + std::cout << "value3: " << value3.get_debug_info() << std::endl; + ASSERT_TRUE(value < value3); + ASSERT_TRUE(value <= value3); + ASSERT_TRUE(value3 > value); + ASSERT_TRUE(value3 >= value); + + ASSERT_FALSE(!(value < value3)); + ASSERT_FALSE(!(value <= value3)); + ASSERT_FALSE(!(value3 > value)); + ASSERT_FALSE(!(value3 >= value)); + + } + { + // from int + DecimalV2Value value3(4, 0); + std::cout << "value3: " << value3.get_debug_info() << std::endl; + ASSERT_TRUE(value < value3); + ASSERT_TRUE(value <= value3); + ASSERT_TRUE(value3 > value); + ASSERT_TRUE(value3 >= value); + } + { + // from int + DecimalV2Value value3(3, -0); + std::cout << "value3: " << value3.get_debug_info() << std::endl; + ASSERT_TRUE(value < value3); + ASSERT_TRUE(value <= value3); + ASSERT_TRUE(value3 > value); + ASSERT_TRUE(value3 >= value); + } +} + +TEST_F(DecimalV2ValueTest, int_to_decimal) { + DecimalV2Value value1; + ASSERT_EQ("0", value1.to_string(3)); + + DecimalV2Value value2(111111111); // 9 digits + std::cout << "value2: " << value2.get_debug_info() << std::endl; + ASSERT_EQ("111111111", value2.to_string(3)); + + DecimalV2Value value3(111111111, 222222222); // 9 digits + std::cout << "value3: " << value3.get_debug_info() << std::endl; + ASSERT_EQ("111111111.222", value3.to_string(3)); + + DecimalV2Value value4(0, 222222222); // 9 digits + std::cout << "value4: " << value4.get_debug_info() << std::endl; + ASSERT_EQ("0.222", value4.to_string(3)); + + DecimalV2Value value5(111111111, 0); // 9 digits + std::cout << "value5: " << value5.get_debug_info() << std::endl; + ASSERT_EQ("111111111", value5.to_string(3)); + + DecimalV2Value value6(0, 0); // 9 digits + std::cout << "value6: " << value6.get_debug_info() << std::endl; + ASSERT_EQ("0", value6.to_string(3)); + + DecimalV2Value value7(0, 12345); // 9 digits + std::cout << "value7: " << value7.get_debug_info() << std::endl; + ASSERT_EQ("0.000012", value7.to_string(6)); + + DecimalV2Value value8(11, 0); + std::cout << "value8: " << value8.get_debug_info() << std::endl; + ASSERT_EQ("11", value8.to_string(3)); + + // more than 9digit, fraction will be trancated to 999999999 + DecimalV2Value value9(1230123456789, 1230123456789); + std::cout << "value9: " << value9.get_debug_info() << std::endl; + ASSERT_EQ("1230123456789.999999999", value9.to_string(10)); + + // negative + { + DecimalV2Value value2(-111111111); // 9 digits + std::cout << "value2: " << value2.get_debug_info() << std::endl; + ASSERT_EQ("-111111111", value2.to_string(3)); + + DecimalV2Value value3(-111111111, 222222222); // 9 digits + std::cout << "value3: " << value3.get_debug_info() << std::endl; + ASSERT_EQ("-111111111.222", value3.to_string(3)); + + DecimalV2Value value4(0, -222222222); // 9 digits + std::cout << "value4: " << value4.get_debug_info() << std::endl; + ASSERT_EQ("-0.222", value4.to_string(3)); + + DecimalV2Value value5(-111111111, 0); // 9 digits + std::cout << "value5: " << value5.get_debug_info() << std::endl; + ASSERT_EQ("-111111111", value5.to_string(3)); + + DecimalV2Value value7(0, -12345); // 9 digits + std::cout << "value7: " << value7.get_debug_info() << std::endl; + ASSERT_EQ("-0.000012", value7.to_string(6)); + + DecimalV2Value value8(-11, 0); + std::cout << "value8: " << value8.get_debug_info() << std::endl; + ASSERT_EQ("-11", value8.to_string(3)); + } +} + +TEST_F(DecimalV2ValueTest, add) { + DecimalV2Value value11(std::string("1111111111.222222222"));// 9 digits + DecimalV2Value value12(std::string("2222222222.111111111")); // 9 digits + DecimalV2Value add_result1 = value11 + value12; + std::cout << "add_result1: " << add_result1.get_debug_info() << std::endl; + ASSERT_EQ("3333333333.333333333", add_result1.to_string(9)); + + DecimalV2Value value21(std::string("-3333333333.222222222"));// 9 digits + DecimalV2Value value22(std::string("2222222222.111111111")); // 9 digits + DecimalV2Value add_result2 = value21 + value22; + std::cout << "add_result2: " << add_result2.get_debug_info() << std::endl; + ASSERT_EQ("-1111111111.111111111", add_result2.to_string(9)); +} + +TEST_F(DecimalV2ValueTest, compound_add) { + { + DecimalV2Value value1(std::string("111111111.222222222")); + DecimalV2Value value2(std::string("111111111.222222222")); + value1 += value2; + std::cout << "value1: " << value1.get_debug_info() << std::endl; + ASSERT_EQ("222222222.444444444", value1.to_string(9)); + } +} + +TEST_F(DecimalV2ValueTest, sub) { + DecimalV2Value value11(std::string("3333333333.222222222"));// 9 digits + DecimalV2Value value12(std::string("2222222222.111111111")); // 9 digits + DecimalV2Value sub_result1 = value11 - value12; + std::cout << "sub_result1: " << sub_result1.get_debug_info() << std::endl; + ASSERT_EQ("1111111111.111111111", sub_result1.to_string(9)); + + DecimalV2Value value21(std::string("-2222222222.111111111")); // 9 digits + DecimalV2Value sub_result2 = value11 - value21; + std::cout << "sub_result2: " << sub_result2.get_debug_info() << std::endl; + ASSERT_EQ("5555555555.333333333", sub_result2.to_string(9)); + + // small - big + { + DecimalV2Value value1(std::string("8.0")); + DecimalV2Value value2(std::string("0")); + DecimalV2Value sub_result = value2 - value1; + std::cout << "sub_result: " << sub_result.get_debug_info() << std::endl; + DecimalV2Value expected_value(std::string("-8.0")); + ASSERT_EQ(expected_value, sub_result); + ASSERT_FALSE(sub_result.is_zero()); + } + // minimum - maximal + { + DecimalV2Value value1(std::string( + "999999999999999999.999999999")); // 27 digits + DecimalV2Value value2(std::string( + "-999999999999999999.999999999")); // 27 digits + DecimalV2Value sub_result = value2 - value1; + std::cout << "sub_result: " << sub_result.get_debug_info() << std::endl; + DecimalV2Value expected_value = value2; + ASSERT_EQ(expected_value, sub_result); + ASSERT_FALSE(sub_result.is_zero()); + ASSERT_TRUE(value1 > value2); + } +} + +TEST_F(DecimalV2ValueTest, mul) { + DecimalV2Value value11(std::string("333333333.2222")); + DecimalV2Value value12(std::string("-222222222.1111")); + DecimalV2Value mul_result1 = value11 * value12; + std::cout << "mul_result1: " << mul_result1.get_debug_info() << std::endl; + ASSERT_EQ(DecimalV2Value( + std::string("-74074074012337037.04938642")), + mul_result1); + + DecimalV2Value value21(std::string("0")); // zero + DecimalV2Value mul_result2 = value11 * value21; + std::cout << "mul_result2: " << mul_result2.get_debug_info() << std::endl; + ASSERT_EQ(DecimalV2Value(std::string("0")), mul_result2); + +} + +TEST_F(DecimalV2ValueTest, div) { + DecimalV2Value value11(std::string("-74074074012337037.04938642")); + DecimalV2Value value12(std::string("-222222222.1111")); + DecimalV2Value div_result1 = value11 / value12; + std::cout << "div_result1: " << div_result1.get_debug_info() << std::endl; + ASSERT_EQ(DecimalV2Value(std::string("333333333.2222")), div_result1); + ASSERT_EQ("333333333.2222", div_result1.to_string()); + { + DecimalV2Value value11(std::string("32766.999943536")); + DecimalV2Value value12(std::string("604587")); + DecimalV2Value div_result1 = value11 / value12; + std::cout << "div_result1: " << div_result1.get_debug_info() << std::endl; + ASSERT_EQ(DecimalV2Value(std::string("0.054197328")), div_result1); + } +} + +TEST_F(DecimalV2ValueTest, unary_minus_operator) { + { + DecimalV2Value value1(std::string("111111111.222222222")); + DecimalV2Value value2 = -value1; + std::cout << "value1: " << value1.get_debug_info() << std::endl; + std::cout << "value2: " << value2.get_debug_info() << std::endl; + ASSERT_EQ("111111111.222222222", value1.to_string(10)); + ASSERT_EQ("-111111111.222222222", value2.to_string(10)); + + } +} + +TEST_F(DecimalV2ValueTest, to_int_frac_value) { + // positive & negative + { + DecimalV2Value value(std::string("123456789123456789.987654321")); + ASSERT_EQ(123456789123456789, value.int_value()); + ASSERT_EQ(987654321, value.frac_value()); + + DecimalV2Value value2(std::string("-123456789123456789.987654321")); + ASSERT_EQ(-123456789123456789, value2.int_value()); + ASSERT_EQ(-987654321, value2.frac_value()); + } + // int or frac part is 0 + { + DecimalV2Value value(std::string("-123456789123456789")); + ASSERT_EQ(-123456789123456789, value.int_value()); + ASSERT_EQ(0, value.frac_value()); + + DecimalV2Value value2(std::string("0.987654321")); + ASSERT_EQ(0, value2.int_value()); + ASSERT_EQ(987654321, value2.frac_value()); + } + // truncate frac part + { + DecimalV2Value value(std::string("-123456789.987654321987654321")); + ASSERT_EQ(-123456789, value.int_value()); + ASSERT_EQ(-987654321, value.frac_value()); + } +} + +// Half up +TEST_F(DecimalV2ValueTest, round_ops) { + // less than 5 + DecimalV2Value value(std::string("1.249")); + { + DecimalV2Value dst; + value.round(&dst, -1, HALF_UP); + ASSERT_EQ("0", dst.to_string()); + + value.round(&dst, -1, CEILING); + ASSERT_EQ("10", dst.to_string()); + + value.round(&dst, -1, FLOOR); + ASSERT_EQ("0", dst.to_string()); + + value.round(&dst, -1, TRUNCATE); + ASSERT_EQ("0", dst.to_string()); + } + { + DecimalV2Value dst; + value.round(&dst, 0, HALF_UP); + ASSERT_EQ("1", dst.to_string()); + + value.round(&dst, 0, CEILING); + ASSERT_EQ("2", dst.to_string()); + + value.round(&dst, 0, FLOOR); + ASSERT_EQ("1", dst.to_string()); + + value.round(&dst, 0, TRUNCATE); + ASSERT_EQ("1", dst.to_string()); + } + + { + DecimalV2Value dst; + value.round(&dst, 1, HALF_UP); + ASSERT_EQ("1.2", dst.to_string()); + + value.round(&dst, 1, CEILING); + ASSERT_EQ("1.3", dst.to_string()); + + value.round(&dst, 1, FLOOR); + ASSERT_EQ("1.2", dst.to_string()); + + value.round(&dst, 1, TRUNCATE); + ASSERT_EQ("1.2", dst.to_string()); + } + + { + DecimalV2Value dst; + value.round(&dst, 2, HALF_UP); + ASSERT_EQ("1.25", dst.to_string()); + + value.round(&dst, 2, CEILING); + ASSERT_EQ("1.25", dst.to_string()); + + value.round(&dst, 2, FLOOR); + ASSERT_EQ("1.24", dst.to_string()); + + value.round(&dst, 2, TRUNCATE); + ASSERT_EQ("1.24", dst.to_string()); + } + + { + DecimalV2Value dst; + value.round(&dst, 3, HALF_UP); + ASSERT_EQ("1.249", dst.to_string()); + + value.round(&dst, 3, CEILING); + ASSERT_EQ("1.249", dst.to_string()); + + value.round(&dst, 3, FLOOR); + ASSERT_EQ("1.249", dst.to_string()); + + value.round(&dst, 3, TRUNCATE); + ASSERT_EQ("1.249", dst.to_string()); + } + + { + DecimalV2Value dst; + value.round(&dst, 4, HALF_UP); + ASSERT_EQ("1.249", dst.to_string()); + + value.round(&dst, 4, CEILING); + ASSERT_EQ("1.249", dst.to_string()); + + value.round(&dst, 4, FLOOR); + ASSERT_EQ("1.249", dst.to_string()); + + value.round(&dst, 4, TRUNCATE); + ASSERT_EQ("1.249", dst.to_string()); + } +} + +// Half up +TEST_F(DecimalV2ValueTest, round_minus) { + // less than 5 + DecimalV2Value value(std::string("-1.249")); + { + DecimalV2Value dst; + value.round(&dst, -1, HALF_UP); + ASSERT_EQ("0", dst.to_string()); + + value.round(&dst, -1, CEILING); + ASSERT_EQ("0", dst.to_string()); + + value.round(&dst, -1, FLOOR); + ASSERT_EQ("-10", dst.to_string()); + + value.round(&dst, -1, TRUNCATE); + ASSERT_EQ("0", dst.to_string()); + } + { + DecimalV2Value dst; + value.round(&dst, 0, HALF_UP); + ASSERT_EQ("-1", dst.to_string()); + + value.round(&dst, 0, CEILING); + ASSERT_EQ("-1", dst.to_string()); + + value.round(&dst, 0, FLOOR); + ASSERT_EQ("-2", dst.to_string()); + + value.round(&dst, 0, TRUNCATE); + ASSERT_EQ("-1", dst.to_string()); + } + + { + DecimalV2Value dst; + value.round(&dst, 1, HALF_UP); + ASSERT_EQ("-1.2", dst.to_string()); + + value.round(&dst, 1, CEILING); + ASSERT_EQ("-1.2", dst.to_string()); + + value.round(&dst, 1, FLOOR); + ASSERT_EQ("-1.3", dst.to_string()); + + value.round(&dst, 1, TRUNCATE); + ASSERT_EQ("-1.2", dst.to_string()); + } + + { + DecimalV2Value dst; + value.round(&dst, 2, HALF_UP); + ASSERT_EQ("-1.25", dst.to_string()); + + value.round(&dst, 2, CEILING); + ASSERT_EQ("-1.24", dst.to_string()); + + value.round(&dst, 2, FLOOR); + ASSERT_EQ("-1.25", dst.to_string()); + + value.round(&dst, 2, TRUNCATE); + ASSERT_EQ("-1.24", dst.to_string()); + } + + { + DecimalV2Value dst; + value.round(&dst, 3, HALF_UP); + ASSERT_EQ("-1.249", dst.to_string()); + + value.round(&dst, 3, CEILING); + ASSERT_EQ("-1.249", dst.to_string()); + + value.round(&dst, 3, FLOOR); + ASSERT_EQ("-1.249", dst.to_string()); + + value.round(&dst, 3, TRUNCATE); + ASSERT_EQ("-1.249", dst.to_string()); + } + + { + DecimalV2Value dst; + value.round(&dst, 4, HALF_UP); + ASSERT_EQ("-1.249", dst.to_string()); + + value.round(&dst, 4, CEILING); + ASSERT_EQ("-1.249", dst.to_string()); + + value.round(&dst, 4, FLOOR); + ASSERT_EQ("-1.249", dst.to_string()); + + value.round(&dst, 4, TRUNCATE); + ASSERT_EQ("-1.249", dst.to_string()); + } +} + +// Half up +TEST_F(DecimalV2ValueTest, round_to_int) { + { + DecimalV2Value value(std::string("99.99")); + { + DecimalV2Value dst; + value.round(&dst, 1, HALF_UP); + ASSERT_EQ("100", dst.to_string()); + } + } + { + DecimalV2Value value(std::string("123.12399")); + { + DecimalV2Value dst; + value.round(&dst, 4, HALF_UP); + ASSERT_EQ("123.124", dst.to_string()); + } + } +} + +TEST_F(DecimalV2ValueTest, double_to_decimal) { + double i = 1.2; + DecimalV2Value *value = new DecimalV2Value(100, 9876); + value->assign_from_double(i); + ASSERT_STREQ("1.2", value->to_string().c_str()); + delete value; +} + +TEST_F(DecimalV2ValueTest, float_to_decimal) { + float i = 1.2; + DecimalV2Value *value = new DecimalV2Value(100, 9876); + value->assign_from_float(i); + ASSERT_STREQ("1.2", value->to_string().c_str()); + delete value; +} +} // end namespace doris + +int main(int argc, char** argv) { + // std::string conffile = std::string(getenv("DORIS_HOME")) + "/conf/be.conf"; + // if (!doris::config::init(conffile.c_str(), false)) { + // fprintf(stderr, "error read config file. \n"); + // return -1; + // } + doris::init_glog("be-test"); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/fe/src/main/cup/sql_parser.cup b/fe/src/main/cup/sql_parser.cup index b54ae1b2e3ad2b..7f0642ae07cc1d 100644 --- a/fe/src/main/cup/sql_parser.cup +++ b/fe/src/main/cup/sql_parser.cup @@ -3106,11 +3106,11 @@ type ::= | KW_CHAR {: RESULT = ScalarType.createCharType(-1); :} | KW_DECIMAL LPAREN INTEGER_LITERAL:precision RPAREN - {: RESULT = ScalarType.createDecimalType(precision.intValue()); :} + {: RESULT = ScalarType.createDecimalV2Type(precision.intValue()); :} | KW_DECIMAL LPAREN INTEGER_LITERAL:precision COMMA INTEGER_LITERAL:scale RPAREN - {: RESULT = ScalarType.createDecimalType(precision.intValue(), scale.intValue()); :} + {: RESULT = ScalarType.createDecimalV2Type(precision.intValue(), scale.intValue()); :} | KW_DECIMAL - {: RESULT = ScalarType.createDecimalType(); :} + {: RESULT = ScalarType.createDecimalV2Type(); :} | KW_HLL {: ScalarType type = ScalarType.createHllType(); type.setAssignedStrLenInColDefinition(); diff --git a/fe/src/main/java/org/apache/doris/analysis/AggregateInfoBase.java b/fe/src/main/java/org/apache/doris/analysis/AggregateInfoBase.java index 116792865cf4e2..286740d0c4aa43 100644 --- a/fe/src/main/java/org/apache/doris/analysis/AggregateInfoBase.java +++ b/fe/src/main/java/org/apache/doris/analysis/AggregateInfoBase.java @@ -162,7 +162,7 @@ private TupleDescriptor createTupleDesc(Analyzer analyzer, boolean isOutputTuple if (!intermediateType.isWildcardDecimal()) { slotDesc.setType(intermediateType); } else { - Preconditions.checkState(expr.getType().isDecimal()); + Preconditions.checkState(expr.getType().isDecimal() || expr.getType().isDecimalV2()); } } } diff --git a/fe/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java b/fe/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java index bc18b9eae59201..e7aaa91ab46177 100644 --- a/fe/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java +++ b/fe/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java @@ -108,6 +108,10 @@ public static void initBuiltins(FunctionSet functionSet) { Operator.DIVIDE.getName(), Lists.newArrayList(Type.DECIMAL, Type.DECIMAL), Type.DECIMAL)); + functionSet.addBuiltin(ScalarFunction.createBuiltinOperator( + Operator.DIVIDE.getName(), + Lists.newArrayList(Type.DECIMALV2, Type.DECIMALV2), + Type.DECIMALV2)); // MOD(), FACTORIAL(), BITAND(), BITOR(), BITXOR(), and BITNOT() are registered as // builtins, see palo_functions.py @@ -161,7 +165,7 @@ public String toSqlImpl() { @Override protected void toThrift(TExprNode msg) { msg.node_type = TExprNodeType.ARITHMETIC_EXPR; - if (!type.isDecimal()) { + if (!type.isDecimal() && !type.isDecimalV2()) { msg.setOpcode(op.getOpcode()); msg.setOutput_column(outputColumn); } @@ -195,6 +199,8 @@ private Type findCommonType(Type t1, Type t2) { if (pt1 == PrimitiveType.DOUBLE || pt2 == PrimitiveType.DOUBLE) { return Type.DOUBLE; + } else if (pt1 == PrimitiveType.DECIMALV2 || pt2 == PrimitiveType.DECIMALV2) { + return Type.DECIMALV2; } else if (pt1 == PrimitiveType.DECIMAL || pt2 == PrimitiveType.DECIMAL) { return Type.DECIMAL; } else if (pt1 == PrimitiveType.LARGEINT || pt2 == PrimitiveType.LARGEINT) { diff --git a/fe/src/main/java/org/apache/doris/analysis/BinaryPredicate.java b/fe/src/main/java/org/apache/doris/analysis/BinaryPredicate.java index 00b9778d528de8..76915a370ca057 100644 --- a/fe/src/main/java/org/apache/doris/analysis/BinaryPredicate.java +++ b/fe/src/main/java/org/apache/doris/analysis/BinaryPredicate.java @@ -284,6 +284,10 @@ private Type getCmpType() { if (t1 == PrimitiveType.BIGINT && t2 == PrimitiveType.BIGINT) { return Type.getAssignmentCompatibleType(getChild(0).getType(), getChild(1).getType(), false); } + if ((t1 == PrimitiveType.BIGINT || t1 == PrimitiveType.DECIMALV2) + && (t2 == PrimitiveType.BIGINT || t2 == PrimitiveType.DECIMALV2)) { + return Type.DECIMALV2; + } if ((t1 == PrimitiveType.BIGINT || t1 == PrimitiveType.DECIMAL) && (t2 == PrimitiveType.BIGINT || t2 == PrimitiveType.DECIMAL)) { return Type.DECIMAL; diff --git a/fe/src/main/java/org/apache/doris/analysis/CastExpr.java b/fe/src/main/java/org/apache/doris/analysis/CastExpr.java index 9bba56426502a9..da35ea8b6e3f50 100644 --- a/fe/src/main/java/org/apache/doris/analysis/CastExpr.java +++ b/fe/src/main/java/org/apache/doris/analysis/CastExpr.java @@ -104,7 +104,8 @@ public static void initBuiltins(FunctionSet functionSet) { continue; } // Disable casting from boolean/timestamp to decimal - if ((fromType.isBoolean() || fromType.isDateType()) && toType == Type.DECIMAL) { + if ((fromType.isBoolean() || fromType.isDateType()) && + (toType == Type.DECIMAL || toType == Type.DECIMALV2)) { continue; } @@ -112,7 +113,8 @@ public static void initBuiltins(FunctionSet functionSet) { if (fromType.equals(toType)) { continue; } - String beClass = toType.isDecimal() || fromType.isDecimal() ? "DecimalOperators" : "CastFunctions"; + String beClass = toType.isDecimalV2() || fromType.isDecimalV2() ? "DecimalV2Operators" : "CastFunctions"; + if (toType.isDecimal() || fromType.isDecimal()) beClass = "DecimalOperators"; String typeName = Function.getUdfTypeName(toType.getPrimitiveType()); if (toType.getPrimitiveType() == PrimitiveType.DATE) { typeName = "date_val"; diff --git a/fe/src/main/java/org/apache/doris/analysis/ColumnDef.java b/fe/src/main/java/org/apache/doris/analysis/ColumnDef.java index 7a6160fcbb7a9d..1ff3b09ba5fdf5 100644 --- a/fe/src/main/java/org/apache/doris/analysis/ColumnDef.java +++ b/fe/src/main/java/org/apache/doris/analysis/ColumnDef.java @@ -158,6 +158,7 @@ public static void validateDefaultValue(Type type, String defaultValue) throws A FloatLiteral doubleLiteral = new FloatLiteral(defaultValue); break; case DECIMAL: + case DECIMALV2: DecimalLiteral decimalLiteral = new DecimalLiteral(defaultValue); decimalLiteral.checkPrecisionAndScale(scalarType.getScalarPrecision(), scalarType.getScalarScale()); break; diff --git a/fe/src/main/java/org/apache/doris/analysis/CreateTableAsSelectStmt.java b/fe/src/main/java/org/apache/doris/analysis/CreateTableAsSelectStmt.java index 4f2bd24905c50a..51da87eddec913 100644 --- a/fe/src/main/java/org/apache/doris/analysis/CreateTableAsSelectStmt.java +++ b/fe/src/main/java/org/apache/doris/analysis/CreateTableAsSelectStmt.java @@ -60,7 +60,7 @@ public void analyze(Analyzer analyzer) throws UserException, AnalysisException { // TODO(zc): support char, varchar and decimal for (Expr expr : tmpStmt.getResultExprs()) { - if (expr.getType().isDecimal() || expr.getType().isStringType()) { + if (expr.getType().isDecimal() || expr.getType().isDecimalV2() || expr.getType().isStringType()) { ErrorReport.reportAnalysisException(ErrorCode.ERR_UNSUPPORTED_TYPE_IN_CTAS, expr.getType()); } } diff --git a/fe/src/main/java/org/apache/doris/analysis/DecimalLiteral.java b/fe/src/main/java/org/apache/doris/analysis/DecimalLiteral.java index a2b52afed97209..5895f3a4566a43 100644 --- a/fe/src/main/java/org/apache/doris/analysis/DecimalLiteral.java +++ b/fe/src/main/java/org/apache/doris/analysis/DecimalLiteral.java @@ -73,7 +73,7 @@ public Expr clone() { private void init(BigDecimal value) { this.value = value; - type = Type.DECIMAL; + type = Type.DECIMALV2; } public BigDecimal getValue() { @@ -130,6 +130,7 @@ public ByteBuffer getHashValue(PrimitiveType type) { buffer.putLong(value.longValue()); break; case DECIMAL: + case DECIMALV2: buffer = ByteBuffer.allocate(12); buffer.order(ByteOrder.LITTLE_ENDIAN); diff --git a/fe/src/main/java/org/apache/doris/analysis/FloatLiteral.java b/fe/src/main/java/org/apache/doris/analysis/FloatLiteral.java index b84e0f42451a06..6f9d35b85817e3 100644 --- a/fe/src/main/java/org/apache/doris/analysis/FloatLiteral.java +++ b/fe/src/main/java/org/apache/doris/analysis/FloatLiteral.java @@ -152,13 +152,13 @@ public double getValue() { @Override protected Expr uncheckedCastTo(Type targetType) throws AnalysisException { - if (!(targetType.isFloatingPointType() || targetType.isDecimal())) { + if (!(targetType.isFloatingPointType() || targetType.isDecimal() || targetType.isDecimalV2())) { return super.uncheckedCastTo(targetType); } if (targetType.isFloatingPointType()) { type = targetType; return this; - } else if (targetType.isDecimal()) { + } else if (targetType.isDecimal() || targetType.isDecimalV2()) { return new DecimalLiteral(new BigDecimal(value)); } return this; diff --git a/fe/src/main/java/org/apache/doris/analysis/IntLiteral.java b/fe/src/main/java/org/apache/doris/analysis/IntLiteral.java index 1f5b1df9983089..7f4b3f2d11bc9b 100644 --- a/fe/src/main/java/org/apache/doris/analysis/IntLiteral.java +++ b/fe/src/main/java/org/apache/doris/analysis/IntLiteral.java @@ -307,7 +307,7 @@ protected Expr uncheckedCastTo(Type targetType) throws AnalysisException { } } else if (targetType.isFloatingPointType()) { return new FloatLiteral(new Double(value), targetType); - } else if (targetType.isDecimal()) { + } else if (targetType.isDecimal() || targetType.isDecimalV2()) { return new DecimalLiteral(new BigDecimal(value)); } return this; diff --git a/fe/src/main/java/org/apache/doris/analysis/LargeIntLiteral.java b/fe/src/main/java/org/apache/doris/analysis/LargeIntLiteral.java index d0162709a54f50..d6dc148b1d2885 100644 --- a/fe/src/main/java/org/apache/doris/analysis/LargeIntLiteral.java +++ b/fe/src/main/java/org/apache/doris/analysis/LargeIntLiteral.java @@ -187,7 +187,7 @@ protected void toThrift(TExprNode msg) { protected Expr uncheckedCastTo(Type targetType) throws AnalysisException { if (targetType.isFloatingPointType()) { return new FloatLiteral(new Double(value.doubleValue()), targetType); - } else if (targetType.isDecimal()) { + } else if (targetType.isDecimal() || targetType.isDecimalV2()) { return new DecimalLiteral(new BigDecimal(value)); } else if (targetType.isNumericType()) { try { diff --git a/fe/src/main/java/org/apache/doris/analysis/LiteralExpr.java b/fe/src/main/java/org/apache/doris/analysis/LiteralExpr.java index faf423d3fa288a..047237857fc5df 100644 --- a/fe/src/main/java/org/apache/doris/analysis/LiteralExpr.java +++ b/fe/src/main/java/org/apache/doris/analysis/LiteralExpr.java @@ -68,6 +68,7 @@ public static LiteralExpr create(String value, Type type) throws AnalysisExcepti literalExpr = new FloatLiteral(value); break; case DECIMAL: + case DECIMALV2: literalExpr = new DecimalLiteral(value); break; case CHAR: diff --git a/fe/src/main/java/org/apache/doris/analysis/StringLiteral.java b/fe/src/main/java/org/apache/doris/analysis/StringLiteral.java index 98029706d944df..942012af9f954c 100644 --- a/fe/src/main/java/org/apache/doris/analysis/StringLiteral.java +++ b/fe/src/main/java/org/apache/doris/analysis/StringLiteral.java @@ -186,6 +186,7 @@ protected Expr uncheckedCastTo(Type targetType) throws AnalysisException { } break; case DECIMAL: + case DECIMALV2: return new DecimalLiteral(value); default: break; diff --git a/fe/src/main/java/org/apache/doris/analysis/TypeDef.java b/fe/src/main/java/org/apache/doris/analysis/TypeDef.java index 13c486d39da870..6fe23c19e97c6c 100644 --- a/fe/src/main/java/org/apache/doris/analysis/TypeDef.java +++ b/fe/src/main/java/org/apache/doris/analysis/TypeDef.java @@ -101,7 +101,8 @@ private void analyzeScalarType(ScalarType scalarType) } break; } - case DECIMAL: { + case DECIMAL: + case DECIMALV2: { int precision = scalarType.decimalPrecision(); int scale = scalarType.decimalScale(); // precision: [1, 27] diff --git a/fe/src/main/java/org/apache/doris/catalog/AggregateType.java b/fe/src/main/java/org/apache/doris/catalog/AggregateType.java index 5561c1219ce7d4..45b8ebb3d27fb7 100644 --- a/fe/src/main/java/org/apache/doris/catalog/AggregateType.java +++ b/fe/src/main/java/org/apache/doris/catalog/AggregateType.java @@ -50,6 +50,7 @@ public enum AggregateType { primitiveTypeList.add(PrimitiveType.FLOAT); primitiveTypeList.add(PrimitiveType.DOUBLE); primitiveTypeList.add(PrimitiveType.DECIMAL); + primitiveTypeList.add(PrimitiveType.DECIMALV2); compatibilityMap.put(SUM, EnumSet.copyOf(primitiveTypeList)); primitiveTypeList.clear(); @@ -61,6 +62,7 @@ public enum AggregateType { primitiveTypeList.add(PrimitiveType.FLOAT); primitiveTypeList.add(PrimitiveType.DOUBLE); primitiveTypeList.add(PrimitiveType.DECIMAL); + primitiveTypeList.add(PrimitiveType.DECIMALV2); primitiveTypeList.add(PrimitiveType.DATE); primitiveTypeList.add(PrimitiveType.DATETIME); compatibilityMap.put(MIN, EnumSet.copyOf(primitiveTypeList)); @@ -74,6 +76,7 @@ public enum AggregateType { primitiveTypeList.add(PrimitiveType.FLOAT); primitiveTypeList.add(PrimitiveType.DOUBLE); primitiveTypeList.add(PrimitiveType.DECIMAL); + primitiveTypeList.add(PrimitiveType.DECIMALV2); primitiveTypeList.add(PrimitiveType.DATE); primitiveTypeList.add(PrimitiveType.DATETIME); compatibilityMap.put(MAX, EnumSet.copyOf(primitiveTypeList)); diff --git a/fe/src/main/java/org/apache/doris/catalog/Column.java b/fe/src/main/java/org/apache/doris/catalog/Column.java index 2c07f5906879ac..af41d8ed6578d0 100644 --- a/fe/src/main/java/org/apache/doris/catalog/Column.java +++ b/fe/src/main/java/org/apache/doris/catalog/Column.java @@ -255,7 +255,8 @@ public void checkSchemaChangeAllowed(Column other) throws DdlException { public String toSql() { StringBuilder sb = new StringBuilder(); sb.append("`").append(name).append("` "); - sb.append(type.toSql()).append(" "); + String typeStr = type.toSql(); + sb.append(typeStr).append(" "); if (aggregationType != null && aggregationType != AggregateType.NONE && !isAggregationTypeImplicit) { sb.append(aggregationType.name()).append(" "); } diff --git a/fe/src/main/java/org/apache/doris/catalog/ColumnType.java b/fe/src/main/java/org/apache/doris/catalog/ColumnType.java index a40a5c92f09ad7..196576020dc9a1 100644 --- a/fe/src/main/java/org/apache/doris/catalog/ColumnType.java +++ b/fe/src/main/java/org/apache/doris/catalog/ColumnType.java @@ -73,7 +73,11 @@ static boolean isSchemaChangeAllowed(Type lhs, Type rhs) { public static void write(DataOutput out, Type type) throws IOException { Preconditions.checkArgument(type.isScalarType(), "only support scalar type serialization"); ScalarType scalarType = (ScalarType) type; - Text.writeString(out, scalarType.getPrimitiveType().name()); + if (scalarType.getPrimitiveType() == PrimitiveType.DECIMALV2) { + Text.writeString(out, PrimitiveType.DECIMAL.name()); + } else { + Text.writeString(out, scalarType.getPrimitiveType().name()); + } out.writeInt(scalarType.getScalarScale()); out.writeInt(scalarType.getScalarPrecision()); out.writeInt(scalarType.getLength()); @@ -83,6 +87,9 @@ public static void write(DataOutput out, Type type) throws IOException { public static Type read(DataInput in) throws IOException { PrimitiveType primitiveType = PrimitiveType.valueOf(Text.readString(in)); + if (primitiveType == PrimitiveType.DECIMAL) { + primitiveType = PrimitiveType.DECIMALV2; + } int scale = in.readInt(); int precision = in.readInt(); int len = in.readInt(); diff --git a/fe/src/main/java/org/apache/doris/catalog/Function.java b/fe/src/main/java/org/apache/doris/catalog/Function.java index 18875245ce3239..0df7bc951dde58 100644 --- a/fe/src/main/java/org/apache/doris/catalog/Function.java +++ b/fe/src/main/java/org/apache/doris/catalog/Function.java @@ -459,6 +459,8 @@ public static String getUdfTypeName(PrimitiveType t) { return "datetime_val"; case DECIMAL: return "decimal_val"; + case DECIMALV2: + return "decimalv2_val"; default: Preconditions.checkState(false, t.toString()); return ""; @@ -494,6 +496,8 @@ public static String getUdfType(PrimitiveType t) { return "DateTimeVal"; case DECIMAL: return "DecimalVal"; + case DECIMALV2: + return "DecimalV2Val"; default: Preconditions.checkState(false, t.toString()); return ""; diff --git a/fe/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/src/main/java/org/apache/doris/catalog/FunctionSet.java index 675b17dbf3a5a6..95f7a4c15d3630 100644 --- a/fe/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -92,6 +92,8 @@ public void init() { "3minIN9doris_udf11DateTimeValEEEvPNS2_15FunctionContextERKT_PS6_") .put(Type.DECIMAL, "3minIN9doris_udf10DecimalValEEEvPNS2_15FunctionContextERKT_PS6_") + .put(Type.DECIMALV2, + "3minIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionContextERKT_PS6_") .put(Type.LARGEINT, "3minIN9doris_udf11LargeIntValEEEvPNS2_15FunctionContextERKT_PS6_") .build(); @@ -122,6 +124,8 @@ public void init() { "3maxIN9doris_udf11DateTimeValEEEvPNS2_15FunctionContextERKT_PS6_") .put(Type.DECIMAL, "3maxIN9doris_udf10DecimalValEEEvPNS2_15FunctionContextERKT_PS6_") + .put(Type.DECIMALV2, + "3maxIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionContextERKT_PS6_") .put(Type.LARGEINT, "3maxIN9doris_udf11LargeIntValEEEvPNS2_15FunctionContextERKT_PS6_") .build(); @@ -136,6 +140,7 @@ public void init() { .put(Type.DOUBLE, Type.DOUBLE) .put(Type.LARGEINT, Type.LARGEINT) .put(Type.DECIMAL, Type.DECIMAL) + .put(Type.DECIMALV2, Type.DECIMALV2) .build(); private static final Map MULTI_DISTINCT_INIT_SYMBOL = @@ -283,6 +288,8 @@ public void init() { "10hll_updateIN9doris_udf11DateTimeValEEEvPNS2_15FunctionContextERKT_PNS2_9StringValE") .put(Type.DECIMAL, "10hll_updateIN9doris_udf10DecimalValEEEvPNS2_15FunctionContextERKT_PNS2_9StringValE") + .put(Type.DECIMALV2, + "10hll_updateIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionContextERKT_PNS2_9StringValE") .put(Type.LARGEINT, "10hll_updateIN9doris_udf11LargeIntValEEEvPNS2_15FunctionContextERKT_PNS2_9StringValE") .build(); @@ -302,6 +309,8 @@ public void init() { "14offset_fn_initIN9doris_udf10BooleanValEEEvPNS2_15FunctionContextEPT_") .put(Type.DECIMAL, "14offset_fn_initIN9doris_udf10DecimalValEEEvPNS2_15FunctionContextEPT_") + .put(Type.DECIMALV2, + "14offset_fn_initIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionContextEPT_") .put(Type.TINYINT, "14offset_fn_initIN9doris_udf10TinyIntValEEEvPNS2_15FunctionContextEPT_") .put(Type.SMALLINT, @@ -333,6 +342,8 @@ public void init() { "16offset_fn_updateIN9doris_udf10BooleanValEEEvPNS2_15FunctionContextERKT_RKNS2_9BigIntValES8_PS6_") .put(Type.DECIMAL, "16offset_fn_updateIN9doris_udf10DecimalValEEEvPNS2_15FunctionContextERKT_RKNS2_9BigIntValES8_PS6_") + .put(Type.DECIMALV2, + "16offset_fn_updateIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionContextERKT_RKNS2_9BigIntValES8_PS6_") .put(Type.TINYINT, "16offset_fn_updateIN9doris_udf10TinyIntValEEEvPNS2_15" + "FunctionContextERKT_RKNS2_9BigIntValES8_PS6_") @@ -368,6 +379,8 @@ public void init() { "15last_val_updateIN9doris_udf10BooleanValEEEvPNS2_15FunctionContextERKT_PS6_") .put(Type.DECIMAL, "15last_val_updateIN9doris_udf10DecimalValEEEvPNS2_15FunctionContextERKT_PS6_") + .put(Type.DECIMALV2, + "15last_val_updateIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionContextERKT_PS6_") .put(Type.TINYINT, "15last_val_updateIN9doris_udf10TinyIntValEEEvPNS2_15FunctionContextERKT_PS6_") .put(Type.SMALLINT, @@ -400,6 +413,9 @@ public void init() { .put(Type.DECIMAL, "24first_val_rewrite_updateIN9doris_udf10DecimalValEEEvPNS2_15" + "FunctionContextERKT_RKNS2_9BigIntValEPS6_") + .put(Type.DECIMALV2, + "24first_val_rewrite_updateIN9doris_udf12DecimalV2ValEEEvPNS2_15" + + "FunctionContextERKT_RKNS2_9BigIntValEPS6_") .put(Type.TINYINT, "24first_val_rewrite_updateIN9doris_udf10TinyIntValEEEvPNS2_15" + "FunctionContextERKT_RKNS2_9BigIntValEPS6_") @@ -438,6 +454,8 @@ public void init() { "15last_val_removeIN9doris_udf10BooleanValEEEvPNS2_15FunctionContextERKT_PS6_") .put(Type.DECIMAL, "15last_val_removeIN9doris_udf10DecimalValEEEvPNS2_15FunctionContextERKT_PS6_") + .put(Type.DECIMALV2, + "15last_val_removeIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionContextERKT_PS6_") .put(Type.TINYINT, "15last_val_removeIN9doris_udf10TinyIntValEEEvPNS2_15FunctionContextERKT_PS6_") .put(Type.SMALLINT, @@ -468,6 +486,8 @@ public void init() { "16first_val_updateIN9doris_udf10BooleanValEEEvPNS2_15FunctionContextERKT_PS6_") .put(Type.DECIMAL, "16first_val_updateIN9doris_udf10DecimalValEEEvPNS2_15FunctionContextERKT_PS6_") + .put(Type.DECIMALV2, + "16first_val_updateIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionContextERKT_PS6_") .put(Type.TINYINT, "16first_val_updateIN9doris_udf10TinyIntValEEEvPNS2_15FunctionContextERKT_PS6_") .put(Type.SMALLINT, @@ -700,6 +720,18 @@ private void initAggregateBuiltins() { null, prefix + "31count_distinct_decimal_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", false, true, true)); + } else if (t == Type.DECIMALV2) { + addBuiltin(AggregateFunction.createBuiltin("multi_distinct_count", Lists.newArrayList(t), + Type.BIGINT, + Type.VARCHAR, + prefix + "36count_or_sum_distinct_decimalv2_initEPN9doris_udf15FunctionContextEPNS1_9StringValE", + prefix + "38count_or_sum_distinct_decimalv2_updateEPN9doris_udf15FunctionContextERNS1_12DecimalV2ValEPNS1_9StringValE", + prefix + "37count_or_sum_distinct_decimalv2_mergeEPN9doris_udf15FunctionContextERNS1_9StringValEPS4_", + prefix + "41count_or_sum_distinct_decimalv2_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + null, + null, + prefix + "33count_distinct_decimalv2_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + false, true, true)); } // sum in multi distinct @@ -727,6 +759,18 @@ private void initAggregateBuiltins() { null, prefix + "29sum_distinct_decimal_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", false, true, true)); + } else if (t == Type.DECIMALV2) { + addBuiltin(AggregateFunction.createBuiltin("multi_distinct_sum", Lists.newArrayList(t), + MULTI_DISTINCT_SUM_RETURN_TYPE.get(t), + Type.VARCHAR, + prefix + "36count_or_sum_distinct_decimalv2_initEPN9doris_udf15FunctionContextEPNS1_9StringValE", + prefix + "38count_or_sum_distinct_decimalv2_updateEPN9doris_udf15FunctionContextERNS1_12DecimalV2ValEPNS1_9StringValE", + prefix + "37count_or_sum_distinct_decimalv2_mergeEPN9doris_udf15FunctionContextERNS1_9StringValEPS4_", + prefix + "41count_or_sum_distinct_decimalv2_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + null, + null, + prefix + "31sum_distinct_decimalv2_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + false, true, true)); } // Min String minMaxInit = t.isStringType() ? initNullString : initNull; @@ -861,6 +905,13 @@ private void initAggregateBuiltins() { null, null, prefix + "10sum_removeIN9doris_udf10DecimalValES3_EEvPNS2_15FunctionContextERKT_PT0_", null, false, true, false)); + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.newArrayList(Type.DECIMALV2), Type.DECIMALV2, Type.DECIMALV2, initNull, + prefix + "3sumIN9doris_udf12DecimalV2ValES3_EEvPNS2_15FunctionContextERKT_PT0_", + prefix + "3sumIN9doris_udf12DecimalV2ValES3_EEvPNS2_15FunctionContextERKT_PT0_", + null, null, + prefix + "10sum_removeIN9doris_udf12DecimalV2ValES3_EEvPNS2_15FunctionContextERKT_PT0_", + null, false, true, false)); addBuiltin(AggregateFunction.createBuiltin(name, Lists.newArrayList(Type.LARGEINT), Type.LARGEINT, Type.LARGEINT, initNull, prefix + "3sumIN9doris_udf11LargeIntValES3_EEvPNS2_15FunctionContextERKT_PT0_", @@ -903,6 +954,16 @@ private void initAggregateBuiltins() { prefix + "18decimal_avg_removeEPN9doris_udf15FunctionContextERKNS1_10DecimalValEPNS1_9StringValE", prefix + "20decimal_avg_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", false, true, false)); + addBuiltin(AggregateFunction.createBuiltin("avg", + Lists.newArrayList(Type.DECIMALV2), Type.DECIMALV2, Type.VARCHAR, + prefix + "18decimalv2_avg_initEPN9doris_udf15FunctionContextEPNS1_9StringValE", + prefix + "20decimalv2_avg_updateEPN9doris_udf15FunctionContextERKNS1_12DecimalV2ValEPNS1_9StringValE", + prefix + "19decimalv2_avg_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_", + stringValSerializeOrFinalize, + prefix + "23decimalv2_avg_get_valueEPN9doris_udf15FunctionContextERKNS1_9StringValE", + prefix + "20decimalv2_avg_removeEPN9doris_udf15FunctionContextERKNS1_12DecimalV2ValEPNS1_9StringValE", + prefix + "22decimalv2_avg_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + false, true, false)); // Avg(Timestamp) addBuiltin(AggregateFunction.createBuiltin("avg", Lists.newArrayList(Type.DATE), Type.DATE, Type.VARCHAR, diff --git a/fe/src/main/java/org/apache/doris/catalog/PrimitiveType.java b/fe/src/main/java/org/apache/doris/catalog/PrimitiveType.java index a15c572e9d0d06..1b3f3818a54450 100644 --- a/fe/src/main/java/org/apache/doris/catalog/PrimitiveType.java +++ b/fe/src/main/java/org/apache/doris/catalog/PrimitiveType.java @@ -50,6 +50,7 @@ public enum PrimitiveType { VARCHAR("VARCHAR", 16, TPrimitiveType.VARCHAR), DECIMAL("DECIMAL", 40, TPrimitiveType.DECIMAL), + DECIMALV2("DECIMALV2", 16, TPrimitiveType.DECIMALV2), HLL("HLL", 16, TPrimitiveType.HLL), // Unsupported scalar types. @@ -75,6 +76,7 @@ public enum PrimitiveType { builder.put(NULL_TYPE, DATE); builder.put(NULL_TYPE, DATETIME); builder.put(NULL_TYPE, DECIMAL); + builder.put(NULL_TYPE, DECIMALV2); builder.put(NULL_TYPE, CHAR); builder.put(NULL_TYPE, VARCHAR); // Boolean @@ -89,6 +91,7 @@ public enum PrimitiveType { builder.put(BOOLEAN, DATE); builder.put(BOOLEAN, DATETIME); builder.put(BOOLEAN, DECIMAL); + builder.put(BOOLEAN, DECIMALV2); builder.put(BOOLEAN, VARCHAR); // Tinyint builder.put(TINYINT, BOOLEAN); @@ -102,6 +105,7 @@ public enum PrimitiveType { builder.put(TINYINT, DATE); builder.put(TINYINT, DATETIME); builder.put(TINYINT, DECIMAL); + builder.put(TINYINT, DECIMALV2); builder.put(TINYINT, VARCHAR); // Smallint builder.put(SMALLINT, BOOLEAN); @@ -115,6 +119,7 @@ public enum PrimitiveType { builder.put(SMALLINT, DATE); builder.put(SMALLINT, DATETIME); builder.put(SMALLINT, DECIMAL); + builder.put(SMALLINT, DECIMALV2); builder.put(SMALLINT, VARCHAR); // Int builder.put(INT, BOOLEAN); @@ -128,6 +133,7 @@ public enum PrimitiveType { builder.put(INT, DATE); builder.put(INT, DATETIME); builder.put(INT, DECIMAL); + builder.put(INT, DECIMALV2); builder.put(INT, VARCHAR); // Bigint builder.put(BIGINT, BOOLEAN); @@ -141,6 +147,7 @@ public enum PrimitiveType { builder.put(BIGINT, DATE); builder.put(BIGINT, DATETIME); builder.put(BIGINT, DECIMAL); + builder.put(BIGINT, DECIMALV2); builder.put(BIGINT, VARCHAR); // Largeint builder.put(LARGEINT, BOOLEAN); @@ -154,6 +161,7 @@ public enum PrimitiveType { builder.put(LARGEINT, DATE); builder.put(LARGEINT, DATETIME); builder.put(LARGEINT, DECIMAL); + builder.put(LARGEINT, DECIMALV2); builder.put(LARGEINT, VARCHAR); // Float builder.put(FLOAT, BOOLEAN); @@ -167,6 +175,7 @@ public enum PrimitiveType { builder.put(FLOAT, DATE); builder.put(FLOAT, DATETIME); builder.put(FLOAT, DECIMAL); + builder.put(FLOAT, DECIMALV2); builder.put(FLOAT, VARCHAR); // Double builder.put(DOUBLE, BOOLEAN); @@ -180,6 +189,7 @@ public enum PrimitiveType { builder.put(DOUBLE, DATE); builder.put(DOUBLE, DATETIME); builder.put(DOUBLE, DECIMAL); + builder.put(DOUBLE, DECIMALV2); builder.put(DOUBLE, VARCHAR); // Date builder.put(DATE, BOOLEAN); @@ -193,6 +203,7 @@ public enum PrimitiveType { builder.put(DATE, DATE); builder.put(DATE, DATETIME); builder.put(DATE, DECIMAL); + builder.put(DATE, DECIMALV2); builder.put(DATE, VARCHAR); // Datetime builder.put(DATETIME, BOOLEAN); @@ -206,6 +217,7 @@ public enum PrimitiveType { builder.put(DATETIME, DATE); builder.put(DATETIME, DATETIME); builder.put(DATETIME, DECIMAL); + builder.put(DATETIME, DECIMALV2); builder.put(DATETIME, VARCHAR); // Char builder.put(CHAR, CHAR); @@ -222,6 +234,7 @@ public enum PrimitiveType { builder.put(VARCHAR, DATE); builder.put(VARCHAR, DATETIME); builder.put(VARCHAR, DECIMAL); + builder.put(VARCHAR, DECIMALV2); builder.put(VARCHAR, VARCHAR); builder.put(VARCHAR, HLL); // Decimal @@ -234,8 +247,21 @@ public enum PrimitiveType { builder.put(DECIMAL, FLOAT); builder.put(DECIMAL, DOUBLE); builder.put(DECIMAL, DECIMAL); + builder.put(DECIMAL, DECIMALV2); builder.put(DECIMAL, VARCHAR); - + // DecimalV2 + builder.put(DECIMALV2, BOOLEAN); + builder.put(DECIMALV2, TINYINT); + builder.put(DECIMALV2, SMALLINT); + builder.put(DECIMALV2, INT); + builder.put(DECIMALV2, BIGINT); + builder.put(DECIMALV2, LARGEINT); + builder.put(DECIMALV2, FLOAT); + builder.put(DECIMALV2, DOUBLE); + builder.put(DECIMALV2, DECIMAL); + builder.put(DECIMALV2, DECIMALV2); + builder.put(DECIMALV2, VARCHAR); + // HLL builder.put(HLL, HLL); builder.put(HLL, VARCHAR); @@ -264,6 +290,7 @@ public enum PrimitiveType { numericTypes.add(FLOAT); numericTypes.add(DOUBLE); numericTypes.add(DECIMAL); + numericTypes.add(DECIMALV2); supportedTypes = Lists.newArrayList(); supportedTypes.add(NULL_TYPE); @@ -281,6 +308,7 @@ public enum PrimitiveType { supportedTypes.add(DATE); supportedTypes.add(DATETIME); supportedTypes.add(DECIMAL); + supportedTypes.add(DECIMALV2); } public static ArrayList getIntegerTypes() { @@ -331,6 +359,7 @@ public static boolean isImplicitCast(PrimitiveType type, PrimitiveType target) { compatibilityMatrix[NULL_TYPE.ordinal()][CHAR.ordinal()] = CHAR; compatibilityMatrix[NULL_TYPE.ordinal()][VARCHAR.ordinal()] = VARCHAR; compatibilityMatrix[NULL_TYPE.ordinal()][DECIMAL.ordinal()] = DECIMAL; + compatibilityMatrix[NULL_TYPE.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; compatibilityMatrix[BOOLEAN.ordinal()][BOOLEAN.ordinal()] = BOOLEAN; compatibilityMatrix[BOOLEAN.ordinal()][TINYINT.ordinal()] = TINYINT; @@ -345,6 +374,7 @@ public static boolean isImplicitCast(PrimitiveType type, PrimitiveType target) { compatibilityMatrix[BOOLEAN.ordinal()][CHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[BOOLEAN.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[BOOLEAN.ordinal()][DECIMAL.ordinal()] = DECIMAL; + compatibilityMatrix[BOOLEAN.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; compatibilityMatrix[TINYINT.ordinal()][TINYINT.ordinal()] = TINYINT; compatibilityMatrix[TINYINT.ordinal()][SMALLINT.ordinal()] = SMALLINT; @@ -358,6 +388,7 @@ public static boolean isImplicitCast(PrimitiveType type, PrimitiveType target) { compatibilityMatrix[TINYINT.ordinal()][CHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[TINYINT.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[TINYINT.ordinal()][DECIMAL.ordinal()] = DECIMAL; + compatibilityMatrix[TINYINT.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; compatibilityMatrix[SMALLINT.ordinal()][SMALLINT.ordinal()] = SMALLINT; compatibilityMatrix[SMALLINT.ordinal()][INT.ordinal()] = INT; @@ -370,6 +401,7 @@ public static boolean isImplicitCast(PrimitiveType type, PrimitiveType target) { compatibilityMatrix[SMALLINT.ordinal()][CHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[SMALLINT.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[SMALLINT.ordinal()][DECIMAL.ordinal()] = DECIMAL; + compatibilityMatrix[SMALLINT.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; compatibilityMatrix[INT.ordinal()][INT.ordinal()] = INT; compatibilityMatrix[INT.ordinal()][BIGINT.ordinal()] = BIGINT; @@ -381,6 +413,7 @@ public static boolean isImplicitCast(PrimitiveType type, PrimitiveType target) { compatibilityMatrix[INT.ordinal()][CHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[INT.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[INT.ordinal()][DECIMAL.ordinal()] = DECIMAL; + compatibilityMatrix[INT.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; compatibilityMatrix[BIGINT.ordinal()][BIGINT.ordinal()] = BIGINT; compatibilityMatrix[BIGINT.ordinal()][LARGEINT.ordinal()] = LARGEINT; @@ -391,6 +424,7 @@ public static boolean isImplicitCast(PrimitiveType type, PrimitiveType target) { compatibilityMatrix[BIGINT.ordinal()][CHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[BIGINT.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[BIGINT.ordinal()][DECIMAL.ordinal()] = DECIMAL; + compatibilityMatrix[BIGINT.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; compatibilityMatrix[LARGEINT.ordinal()][LARGEINT.ordinal()] = LARGEINT; compatibilityMatrix[LARGEINT.ordinal()][FLOAT.ordinal()] = DOUBLE; @@ -400,6 +434,7 @@ public static boolean isImplicitCast(PrimitiveType type, PrimitiveType target) { compatibilityMatrix[LARGEINT.ordinal()][CHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[LARGEINT.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[LARGEINT.ordinal()][DECIMAL.ordinal()] = DECIMAL; + compatibilityMatrix[LARGEINT.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; compatibilityMatrix[FLOAT.ordinal()][FLOAT.ordinal()] = FLOAT; compatibilityMatrix[FLOAT.ordinal()][DOUBLE.ordinal()] = DOUBLE; @@ -408,6 +443,7 @@ public static boolean isImplicitCast(PrimitiveType type, PrimitiveType target) { compatibilityMatrix[FLOAT.ordinal()][CHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[FLOAT.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[FLOAT.ordinal()][DECIMAL.ordinal()] = DECIMAL; + compatibilityMatrix[FLOAT.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; compatibilityMatrix[DOUBLE.ordinal()][DOUBLE.ordinal()] = DOUBLE; compatibilityMatrix[DOUBLE.ordinal()][DATE.ordinal()] = INVALID_TYPE; @@ -415,26 +451,33 @@ public static boolean isImplicitCast(PrimitiveType type, PrimitiveType target) { compatibilityMatrix[DOUBLE.ordinal()][CHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[DOUBLE.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[DOUBLE.ordinal()][DECIMAL.ordinal()] = DECIMAL; + compatibilityMatrix[DOUBLE.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; compatibilityMatrix[DATE.ordinal()][DATE.ordinal()] = DATE; compatibilityMatrix[DATE.ordinal()][DATETIME.ordinal()] = DATETIME; compatibilityMatrix[DATE.ordinal()][CHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATE.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATE.ordinal()][DECIMAL.ordinal()] = INVALID_TYPE; + compatibilityMatrix[DATE.ordinal()][DECIMALV2.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATETIME.ordinal()][DATETIME.ordinal()] = DATETIME; compatibilityMatrix[DATETIME.ordinal()][CHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATETIME.ordinal()][VARCHAR.ordinal()] = INVALID_TYPE; compatibilityMatrix[DATETIME.ordinal()][DECIMAL.ordinal()] = INVALID_TYPE; + compatibilityMatrix[DATETIME.ordinal()][DECIMALV2.ordinal()] = INVALID_TYPE; compatibilityMatrix[CHAR.ordinal()][CHAR.ordinal()] = CHAR; compatibilityMatrix[CHAR.ordinal()][VARCHAR.ordinal()] = VARCHAR; compatibilityMatrix[CHAR.ordinal()][DECIMAL.ordinal()] = INVALID_TYPE; + compatibilityMatrix[CHAR.ordinal()][DECIMALV2.ordinal()] = INVALID_TYPE; compatibilityMatrix[VARCHAR.ordinal()][VARCHAR.ordinal()] = VARCHAR; compatibilityMatrix[VARCHAR.ordinal()][DECIMAL.ordinal()] = INVALID_TYPE; + compatibilityMatrix[VARCHAR.ordinal()][DECIMALV2.ordinal()] = INVALID_TYPE; compatibilityMatrix[DECIMAL.ordinal()][DECIMAL.ordinal()] = DECIMAL; + compatibilityMatrix[DECIMALV2.ordinal()][DECIMALV2.ordinal()] = DECIMALV2; + compatibilityMatrix[DECIMALV2.ordinal()][DECIMAL.ordinal()] = DECIMALV2; compatibilityMatrix[HLL.ordinal()][HLL.ordinal()] = HLL; } @@ -442,7 +485,7 @@ public static boolean isImplicitCast(PrimitiveType type, PrimitiveType target) { private static PrimitiveType[][] schemaChangeCompatibilityMatrix; static { - schemaChangeCompatibilityMatrix = new PrimitiveType[DECIMAL.ordinal() + 1][DECIMAL.ordinal() + 1]; + schemaChangeCompatibilityMatrix = new PrimitiveType[HLL.ordinal() + 1][HLL.ordinal() + 1]; // NULL_TYPE is compatible with any type and results in the non-null type. compatibilityMatrix[NULL_TYPE.ordinal()][NULL_TYPE.ordinal()] = NULL_TYPE; @@ -566,6 +609,10 @@ public boolean isDecimalType() { return this == DECIMAL; } + public boolean isDecimalV2Type() { + return this == DECIMALV2; + } + public PrimitiveType getNumResultType() { switch (this) { case BOOLEAN: @@ -585,6 +632,8 @@ public PrimitiveType getNumResultType() { return DOUBLE; case DECIMAL: return DECIMAL; + case DECIMALV2: + return DECIMALV2; case HLL: return HLL; default: @@ -613,6 +662,8 @@ public PrimitiveType getResultType() { return VARCHAR; case DECIMAL: return DECIMAL; + case DECIMALV2: + return DECIMALV2; case HLL: return HLL; default: @@ -631,6 +682,8 @@ public PrimitiveType getMaxResolutionType() { return BIGINT; } else if (isDecimalType()) { return DECIMAL; + } else if (isDecimalV2Type()) { + return DECIMALV2; } else if (isDateType()) { return DATETIME; // Timestamps get summed as DOUBLE for AVG. @@ -644,7 +697,7 @@ public PrimitiveType getMaxResolutionType() { } public boolean isNumericType() { - return isFixedPointType() || isFloatingPointType() || isDecimalType(); + return isFixedPointType() || isFloatingPointType() || isDecimalType() || isDecimalV2Type(); } public boolean isValid() { @@ -695,6 +748,7 @@ public MysqlColType toMysqlType() { } } case DECIMAL: + case DECIMALV2: return MysqlColType.MYSQL_TYPE_DECIMAL; default: return MysqlColType.MYSQL_TYPE_STRING; @@ -713,6 +767,7 @@ public int getOlapColumnIndexSize() { // char index size is length return -1; case DECIMAL: + case DECIMALV2: return DECIMAL_INDEX_LEN; default: return this.getSlotSize(); @@ -741,6 +796,12 @@ public static PrimitiveType getCmpType(PrimitiveType t1, PrimitiveType t2) { || t2ResultType == PrimitiveType.DECIMAL)) { return PrimitiveType.DECIMAL; } + if ((t1ResultType == PrimitiveType.BIGINT + || t1ResultType == PrimitiveType.DECIMALV2) + && (t2ResultType == PrimitiveType.BIGINT + || t2ResultType == PrimitiveType.DECIMALV2)) { + return PrimitiveType.DECIMALV2; + } if ((t1ResultType == PrimitiveType.BIGINT || t1ResultType == PrimitiveType.LARGEINT) && (t2ResultType == PrimitiveType.BIGINT diff --git a/fe/src/main/java/org/apache/doris/catalog/ScalarFunction.java b/fe/src/main/java/org/apache/doris/catalog/ScalarFunction.java index dc54d50516af2d..9a1e1928ee0e8b 100644 --- a/fe/src/main/java/org/apache/doris/catalog/ScalarFunction.java +++ b/fe/src/main/java/org/apache/doris/catalog/ScalarFunction.java @@ -27,6 +27,8 @@ import org.apache.doris.thrift.TFunction; import org.apache.doris.thrift.TFunctionBinaryType; import org.apache.doris.thrift.TScalarFunction; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import java.io.DataInput; import java.io.DataOutput; @@ -40,6 +42,7 @@ * Internal representation of a scalar function. */ public class ScalarFunction extends Function { + private static final Logger LOG = LogManager.getLogger(ScalarFunction.class); // The name inside the binary at location_ that contains this particular // function. e.g. org.example.MyUdf.class. private String symbolName; @@ -128,6 +131,7 @@ public static ScalarFunction createBuiltinOperator( // Convert Add(TINYINT, TINYINT) --> Add_TinyIntVal_TinyIntVal String beFn = name; boolean usesDecimal = false; + boolean usesDecimalV2 = false; for (int i = 0; i < argTypes.size(); ++i) { switch (argTypes.get(i).getPrimitiveType()) { case BOOLEAN: @@ -167,11 +171,16 @@ public static ScalarFunction createBuiltinOperator( beFn += "_decimal_val"; usesDecimal = true; break; + case DECIMALV2: + beFn += "_decimalv2_val"; + usesDecimalV2 = true; + break; default: Preconditions.checkState(false, "Argument type not supported: " + argTypes.get(i)); } } String beClass = usesDecimal ? "DecimalOperators" : "Operators"; + if (usesDecimalV2) beClass = "DecimalV2Operators"; String symbol = "doris::" + beClass + "::" + beFn; return createBuiltinOperator(name, symbol, argTypes, retType); } diff --git a/fe/src/main/java/org/apache/doris/catalog/ScalarType.java b/fe/src/main/java/org/apache/doris/catalog/ScalarType.java index eb6880545c54aa..876f185162070f 100644 --- a/fe/src/main/java/org/apache/doris/catalog/ScalarType.java +++ b/fe/src/main/java/org/apache/doris/catalog/ScalarType.java @@ -87,6 +87,8 @@ public static ScalarType createType(PrimitiveType type, int len, int precision, return createVarcharType(len); case DECIMAL: return createDecimalType(precision, scale); + case DECIMALV2: + return createDecimalV2Type(precision, scale); default: return createType(type); } @@ -124,6 +126,8 @@ public static ScalarType createType(PrimitiveType type) { return DATETIME; case DECIMAL: return (ScalarType) createDecimalType(); + case DECIMALV2: + return DEFAULT_DECIMALV2; case LARGEINT: return LARGEINT; default: @@ -165,6 +169,8 @@ public static ScalarType createType(String type) { return DATETIME; case "DECIMAL": return (ScalarType) createDecimalType(); + case "DECIMALV2": + return (ScalarType) createDecimalV2Type(); case "LARGEINT": return LARGEINT; default: @@ -190,10 +196,18 @@ public static ScalarType createDecimalType() { return DEFAULT_DECIMAL; } + public static ScalarType createDecimalV2Type() { + return DEFAULT_DECIMALV2; + } + public static ScalarType createDecimalType(int precision) { return createDecimalType(precision, DEFAULT_SCALE); } + public static ScalarType createDecimalV2Type(int precision) { + return createDecimalV2Type(precision, DEFAULT_SCALE); + } + public static ScalarType createDecimalType(int precision, int scale) { // Preconditions.checkState(precision >= 0); // Enforced by parser // Preconditions.checkState(scale >= 0); // Enforced by parser. @@ -203,6 +217,15 @@ public static ScalarType createDecimalType(int precision, int scale) { return type; } + public static ScalarType createDecimalV2Type(int precision, int scale) { + // Preconditions.checkState(precision >= 0); // Enforced by parser + // Preconditions.checkState(scale >= 0); // Enforced by parser. + ScalarType type = new ScalarType(PrimitiveType.DECIMALV2); + type.precision = precision; + type.scale = scale; + return type; + } + // Identical to createDecimalType except that higher precisions are truncated // to the max storable precision. The BE will report overflow in these cases // (think of this as adding ints to BIGINT but BIGINT can still overflow). @@ -213,6 +236,13 @@ public static ScalarType createDecimalTypeInternal(int precision, int scale) { return type; } + public static ScalarType createDecimalV2TypeInternal(int precision, int scale) { + ScalarType type = new ScalarType(PrimitiveType.DECIMALV2); + type.precision = Math.min(precision, MAX_PRECISION); + type.scale = Math.min(type.precision, scale); + return type; + } + public static ScalarType createVarcharType(int len) { // length checked in analysis ScalarType type = new ScalarType(PrimitiveType.VARCHAR); @@ -249,6 +279,11 @@ public String toString() { return "DECIMAL(*,*)"; } return "DECIMAL(" + precision + "," + scale + ")"; + } else if (type == PrimitiveType.DECIMALV2) { + if (isWildcardDecimal()) { + return "DECIMAL(*,*)"; + } + return "DECIMAL(" + precision + "," + scale + ")"; } else if (type == PrimitiveType.VARCHAR) { if (isWildcardVarchar()) { return "VARCHAR(*)"; @@ -271,6 +306,9 @@ public String toSql(int depth) { case DECIMAL: stringBuilder.append("decimal").append("(").append(precision).append(", ").append(scale).append(")"); break; + case DECIMALV2: + stringBuilder.append("decimal").append("(").append(precision).append(", ").append(scale).append(")"); + break; case BOOLEAN: return "tinyint(1)"; case TINYINT: @@ -317,7 +355,8 @@ public void toThrift(TTypeDesc container) { node.setScalar_type(scalarType); break; } - case DECIMAL: { + case DECIMAL: + case DECIMALV2: { node.setType(TTypeNodeType.SCALAR); TScalarType scalarType = new TScalarType(); scalarType.setType(type.toThrift()); @@ -345,12 +384,12 @@ public static Type[] toColumnType(PrimitiveType[] types) { } public int decimalPrecision() { - Preconditions.checkState(type == PrimitiveType.DECIMAL); + Preconditions.checkState(type == PrimitiveType.DECIMAL || type == PrimitiveType.DECIMALV2); return precision; } public int decimalScale() { - Preconditions.checkState(type == PrimitiveType.DECIMAL); + Preconditions.checkState(type == PrimitiveType.DECIMAL || type == PrimitiveType.DECIMALV2); return scale; } @@ -368,7 +407,8 @@ public int decimalScale() { @Override public boolean isWildcardDecimal() { - return type == PrimitiveType.DECIMAL && precision == -1 && scale == -1; + return (type == PrimitiveType.DECIMAL || type == PrimitiveType.DECIMALV2) + && precision == -1 && scale == -1; } @Override @@ -386,7 +426,7 @@ public boolean isWildcardChar() { */ @Override public boolean isFullySpecifiedDecimal() { - if (!isDecimal()) return false; + if (!isDecimal() && !isDecimalV2()) return false; if (isWildcardDecimal()) return false; if (precision <= 0 || precision > MAX_PRECISION) return false; if (scale < 0 || scale > precision) return false; @@ -399,7 +439,7 @@ public boolean isFixedLengthType() { || type == PrimitiveType.SMALLINT || type == PrimitiveType.INT || type == PrimitiveType.BIGINT || type == PrimitiveType.FLOAT || type == PrimitiveType.DOUBLE || type == PrimitiveType.DATE - || type == PrimitiveType.DATETIME + || type == PrimitiveType.DATETIME || type == PrimitiveType.DECIMALV2 || type == PrimitiveType.CHAR || type == PrimitiveType.DECIMAL; } @@ -457,13 +497,16 @@ public boolean matchesType(Type t) { if (type == PrimitiveType.HLL && scalarType.isStringType()) { return true; } - if (isDecimal() && scalarType.isWildcardDecimal()) { + if ((isDecimal() || isDecimalV2()) && scalarType.isWildcardDecimal()) { Preconditions.checkState(!isWildcardDecimal()); return true; } if (isDecimal() && scalarType.isDecimal()) { return true; } + if (isDecimalV2() && scalarType.isDecimalV2()) { + return true; + } return false; } @@ -482,7 +525,7 @@ public boolean equals(Object o) { if (type == PrimitiveType.VARCHAR) { return len == other.len; } - if (type == PrimitiveType.DECIMAL) { + if (type == PrimitiveType.DECIMAL || type == PrimitiveType.DECIMALV2) { return precision == other.precision && scale == other.scale; } return true; @@ -498,6 +541,8 @@ public Type getMaxResolutionType() { return ScalarType.NULL; } else if (isDecimal()) { return createDecimalTypeInternal(MAX_PRECISION, scale); + } else if (isDecimalV2()) { + return createDecimalV2TypeInternal(MAX_PRECISION, scale); } else if (isLargeIntType()) { return ScalarType.LARGEINT; } else { @@ -511,6 +556,8 @@ public ScalarType getNextResolutionType() { return this; } else if (type == PrimitiveType.DECIMAL) { return createDecimalTypeInternal(MAX_PRECISION, scale); + } else if (type == PrimitiveType.DECIMALV2) { + return createDecimalV2TypeInternal(MAX_PRECISION, scale); } return createType(PrimitiveType.values()[type.ordinal() + 1]); } @@ -524,6 +571,7 @@ public ScalarType getMinResolutionDecimal() { case NULL_TYPE: return Type.NULL; case DECIMAL: + case DECIMALV2: return this; case TINYINT: return createDecimalType(3); @@ -534,9 +582,9 @@ public ScalarType getMinResolutionDecimal() { case BIGINT: return createDecimalType(19); case FLOAT: - return createDecimalTypeInternal(MAX_PRECISION, 9); + return createDecimalV2TypeInternal(MAX_PRECISION, 9); case DOUBLE: - return createDecimalTypeInternal(MAX_PRECISION, 17); + return createDecimalV2TypeInternal(MAX_PRECISION, 17); default: return ScalarType.INVALID; } @@ -549,8 +597,8 @@ public ScalarType getMinResolutionDecimal() { * the decimal point must be greater or equal. */ public boolean isSupertypeOf(ScalarType o) { - Preconditions.checkState(isDecimal()); - Preconditions.checkState(o.isDecimal()); + Preconditions.checkState(isDecimal() || isDecimalV2()); + Preconditions.checkState(o.isDecimal() || o.isDecimalV2()); if (isWildcardDecimal()) { return true; } @@ -601,6 +649,10 @@ public static ScalarType getAssignmentCompatibleType( return INVALID; } + if (t1.isDecimalV2() || t2.isDecimalV2()) { + return DECIMALV2; + } + if (t1.isDecimal() || t2.isDecimal()) { return DECIMAL; // // The case of decimal and float/double must be handled carefully. There are two @@ -688,6 +740,8 @@ public int getStorageLayoutBytes() { return 8; case DECIMAL: return 40; + case DECIMALV2: + return 16; case CHAR: case VARCHAR: return len; @@ -705,7 +759,7 @@ public TColumnType toColumnTypeThrift() { if (type == PrimitiveType.CHAR || type == PrimitiveType.VARCHAR || type == PrimitiveType.HLL) { thrift.setLen(len); } - if (type == PrimitiveType.DECIMAL) { + if (type == PrimitiveType.DECIMAL || type == PrimitiveType.DECIMALV2) { thrift.setPrecision(precision); thrift.setScale(scale); } diff --git a/fe/src/main/java/org/apache/doris/catalog/Type.java b/fe/src/main/java/org/apache/doris/catalog/Type.java index 43f98b3b4f4741..24c70718a09666 100644 --- a/fe/src/main/java/org/apache/doris/catalog/Type.java +++ b/fe/src/main/java/org/apache/doris/catalog/Type.java @@ -66,7 +66,11 @@ public abstract class Type { public static final ScalarType DEFAULT_DECIMAL = (ScalarType) ScalarType.createDecimalType(ScalarType.DEFAULT_PRECISION, ScalarType.DEFAULT_SCALE); + public static final ScalarType DEFAULT_DECIMALV2 = (ScalarType) + ScalarType.createDecimalV2Type(ScalarType.DEFAULT_PRECISION, + ScalarType.DEFAULT_SCALE); public static final ScalarType DECIMAL = DEFAULT_DECIMAL; + public static final ScalarType DECIMALV2 = DEFAULT_DECIMALV2; // (ScalarType) ScalarType.createDecimalTypeInternal(-1, -1); public static final ScalarType DEFAULT_VARCHAR = ScalarType.createVarcharType(-1); public static final ScalarType VARCHAR = ScalarType.createVarcharType(-1); @@ -94,6 +98,7 @@ public abstract class Type { numericTypes.add(FLOAT); numericTypes.add(DOUBLE); numericTypes.add(DECIMAL); + numericTypes.add(DECIMALV2); supportedTypes = Lists.newArrayList(); supportedTypes.add(NULL); @@ -111,6 +116,7 @@ public abstract class Type { supportedTypes.add(DATE); supportedTypes.add(DATETIME); supportedTypes.add(DECIMAL); + supportedTypes.add(DECIMALV2); } public static ArrayList getIntegerTypes() { @@ -166,6 +172,10 @@ public boolean isDecimal() { return isScalarType(PrimitiveType.DECIMAL); } + public boolean isDecimalV2() { + return isScalarType(PrimitiveType.DECIMALV2); + } + public boolean isDecimalOrNull() { return isDecimal() || isNull(); } public boolean isFullySpecifiedDecimal() { return false; } public boolean isWildcardDecimal() { return false; } @@ -213,7 +223,7 @@ public boolean isFixedLengthType() { } public boolean isNumericType() { - return isFixedPointType() || isFloatingPointType() || isDecimal(); + return isFixedPointType() || isFloatingPointType() || isDecimal() || isDecimalV2(); } public boolean isNativeType() { @@ -453,6 +463,8 @@ public static Type fromPrimitiveType(PrimitiveType type) { return Type.DATETIME; case DECIMAL: return Type.DECIMAL; + case DECIMALV2: + return Type.DECIMALV2; case CHAR: return Type.CHAR; case VARCHAR: @@ -508,6 +520,11 @@ protected static Pair fromThrift(TTypeDesc col, int nodeIdx) { && scalarType.isSetPrecision()); type = ScalarType.createDecimalType(scalarType.getPrecision(), scalarType.getScale()); + } else if (scalarType.getType() == TPrimitiveType.DECIMALV2) { + Preconditions.checkState(scalarType.isSetPrecision() + && scalarType.isSetPrecision()); + type = ScalarType.createDecimalV2Type(scalarType.getPrecision(), + scalarType.getScale()); } else { type = ScalarType.createType( PrimitiveType.fromThrift(scalarType.getType())); @@ -608,6 +625,7 @@ public Integer getPrecision() { case DOUBLE: return 15; case DECIMAL: + case DECIMALV2: return t.decimalPrecision(); default: return null; @@ -635,6 +653,7 @@ public Integer getDecimalDigits() { case DOUBLE: return 15; case DECIMAL: + case DECIMALV2: return t.decimalScale(); default: return null; @@ -664,6 +683,7 @@ public Integer getNumPrecRadix() { case FLOAT: case DOUBLE: case DECIMAL: + case DECIMALV2: return 10; default: // everything else (including boolean and string) is null @@ -789,6 +809,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[LARGEINT.ordinal()][CHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[LARGEINT.ordinal()][VARCHAR.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[LARGEINT.ordinal()][DECIMAL.ordinal()] = PrimitiveType.DECIMAL; + compatibilityMatrix[LARGEINT.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.DECIMALV2; compatibilityMatrix[LARGEINT.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; compatibilityMatrix[FLOAT.ordinal()][DOUBLE.ordinal()] = PrimitiveType.DOUBLE; @@ -823,6 +844,7 @@ public Integer getNumPrecRadix() { compatibilityMatrix[VARCHAR.ordinal()][HLL.ordinal()] = PrimitiveType.INVALID_TYPE; + compatibilityMatrix[DECIMAL.ordinal()][DECIMALV2.ordinal()] = PrimitiveType.DECIMALV2; // Check all of the necessary entries that should be filled. // ignore binary @@ -835,6 +857,7 @@ public Integer getNumPrecRadix() { t2 == PrimitiveType.INVALID_TYPE) continue; if (t1 == PrimitiveType.NULL_TYPE || t2 == PrimitiveType.NULL_TYPE) continue; if (t1 == PrimitiveType.DECIMAL || t2 == PrimitiveType.DECIMAL) continue; + if (t1 == PrimitiveType.DECIMALV2 || t2 == PrimitiveType.DECIMALV2) continue; Preconditions.checkNotNull(compatibilityMatrix[i][j]); } } @@ -862,6 +885,8 @@ public Type getResultType() { return VARCHAR; case DECIMAL: return DECIMAL; + case DECIMALV2: + return DECIMALV2; default: return INVALID; @@ -885,6 +910,12 @@ public static Type getCmpType(Type t1, Type t2) { || t2ResultType == PrimitiveType.DECIMAL)) { return Type.DECIMAL; } + if ((t1ResultType == PrimitiveType.BIGINT + || t1ResultType == PrimitiveType.DECIMALV2) + && (t2ResultType == PrimitiveType.BIGINT + || t2ResultType == PrimitiveType.DECIMALV2)) { + return Type.DECIMALV2; + } if ((t1ResultType == PrimitiveType.BIGINT || t1ResultType == PrimitiveType.LARGEINT) && (t2ResultType == PrimitiveType.BIGINT @@ -919,6 +950,8 @@ public Type getNumResultType() { return Type.DOUBLE; case DECIMAL: return Type.DECIMAL; + case DECIMALV2: + return Type.DECIMALV2; default: return Type.INVALID; diff --git a/fe/src/main/java/org/apache/doris/common/util/Util.java b/fe/src/main/java/org/apache/doris/common/util/Util.java index 9eeb20be68daed..a857b1cbe7249d 100644 --- a/fe/src/main/java/org/apache/doris/common/util/Util.java +++ b/fe/src/main/java/org/apache/doris/common/util/Util.java @@ -61,6 +61,7 @@ public class Util { TYPE_STRING_MAP.put(PrimitiveType.CHAR, "char(%d)"); TYPE_STRING_MAP.put(PrimitiveType.VARCHAR, "varchar(%d)"); TYPE_STRING_MAP.put(PrimitiveType.DECIMAL, "decimal(%d,%d)"); + TYPE_STRING_MAP.put(PrimitiveType.DECIMALV2, "decimal(%d,%d)"); TYPE_STRING_MAP.put(PrimitiveType.HLL, "varchar(%d)"); } @@ -224,6 +225,7 @@ public static int schemaHash(int schemaVersion, List columns, Set outputExprs) for (Expr expr : outputExprs) { List slotList = Lists.newArrayList(); expr.getIds(null, slotList); - if (PrimitiveType.DECIMAL == expr.getType().getPrimitiveType() - && slotList.contains(slotDesc.getId()) - && PrimitiveType.DECIMAL == slotDesc.getType().getPrimitiveType() - && null != slotDesc.getColumn()) { + if (PrimitiveType.DECIMAL != expr.getType().getPrimitiveType() && + PrimitiveType.DECIMALV2 != expr.getType().getPrimitiveType()) { + continue; + } + + if (PrimitiveType.DECIMAL != slotDesc.getType().getPrimitiveType() && + PrimitiveType.DECIMALV2 != slotDesc.getType().getPrimitiveType()) { + continue; + } + + if (slotList.contains(slotDesc.getId()) && null != slotDesc.getColumn()) { // TODO output scale // int outputScale = slotDesc.getColumn().getType().getScale(); int outputScale = 10; diff --git a/fe/src/main/java/org/apache/doris/rewrite/FEFunctions.java b/fe/src/main/java/org/apache/doris/rewrite/FEFunctions.java index c3b16680eb21b5..995dbd0ea3b2f7 100644 --- a/fe/src/main/java/org/apache/doris/rewrite/FEFunctions.java +++ b/fe/src/main/java/org/apache/doris/rewrite/FEFunctions.java @@ -178,6 +178,15 @@ public static DecimalLiteral addDecimal(LiteralExpr first, LiteralExpr second) t return new DecimalLiteral(result); } + @FEFunction(name = "add", argTypes = { "DECIMALV2", "DECIMALV2" }, returnType = "DECIMALV2") + public static DecimalLiteral addDecimalV2(LiteralExpr first, LiteralExpr second) throws AnalysisException { + BigDecimal left = new BigDecimal(first.getStringValue()); + BigDecimal right = new BigDecimal(second.getStringValue()); + + BigDecimal result = left.add(right); + return new DecimalLiteral(result); + } + @FEFunction(name = "add", argTypes = { "LARGEINT", "LARGEINT" }, returnType = "LARGEINT") public static LargeIntLiteral addBigInt(LiteralExpr first, LiteralExpr second) throws AnalysisException { BigInteger left = new BigInteger(first.getStringValue()); @@ -206,6 +215,15 @@ public static DecimalLiteral subtractDecimal(LiteralExpr first, LiteralExpr seco return new DecimalLiteral(result); } + @FEFunction(name = "subtract", argTypes = { "DECIMALV2", "DECIMALV2" }, returnType = "DECIMALV2") + public static DecimalLiteral subtractDecimalV2(LiteralExpr first, LiteralExpr second) throws AnalysisException { + BigDecimal left = new BigDecimal(first.getStringValue()); + BigDecimal right = new BigDecimal(second.getStringValue()); + + BigDecimal result = left.subtract(right); + return new DecimalLiteral(result); + } + @FEFunction(name = "subtract", argTypes = { "LARGEINT", "LARGEINT" }, returnType = "LARGEINT") public static LargeIntLiteral subtractBigInt(LiteralExpr first, LiteralExpr second) throws AnalysisException { BigInteger left = new BigInteger(first.getStringValue()); @@ -236,6 +254,15 @@ public static DecimalLiteral multiplyDecimal(LiteralExpr first, LiteralExpr seco return new DecimalLiteral(result); } + @FEFunction(name = "multiply", argTypes = { "DECIMALV2", "DECIMALV2" }, returnType = "DECIMALV2") + public static DecimalLiteral multiplyDecimalV2(LiteralExpr first, LiteralExpr second) throws AnalysisException { + BigDecimal left = new BigDecimal(first.getStringValue()); + BigDecimal right = new BigDecimal(second.getStringValue()); + + BigDecimal result = left.multiply(right); + return new DecimalLiteral(result); + } + @FEFunction(name = "multiply", argTypes = { "LARGEINT", "LARGEINT" }, returnType = "LARGEINT") public static LargeIntLiteral multiplyBigInt(LiteralExpr first, LiteralExpr second) throws AnalysisException { BigInteger left = new BigInteger(first.getStringValue()); @@ -257,4 +284,13 @@ public static DecimalLiteral divideDecimal(LiteralExpr first, LiteralExpr second BigDecimal result = left.divide(right); return new DecimalLiteral(result); } + + @FEFunction(name = "divide", argTypes = { "DECIMALV2", "DECIMALV2" }, returnType = "DECIMALV2") + public static DecimalLiteral divideDecimalV2(LiteralExpr first, LiteralExpr second) throws AnalysisException { + BigDecimal left = new BigDecimal(first.getStringValue()); + BigDecimal right = new BigDecimal(second.getStringValue()); + + BigDecimal result = left.divide(right); + return new DecimalLiteral(result); + } } diff --git a/fe/src/main/java/org/apache/doris/task/HadoopLoadPendingTask.java b/fe/src/main/java/org/apache/doris/task/HadoopLoadPendingTask.java index 0c204aeeaf3281..4d8334b08020ef 100644 --- a/fe/src/main/java/org/apache/doris/task/HadoopLoadPendingTask.java +++ b/fe/src/main/java/org/apache/doris/task/HadoopLoadPendingTask.java @@ -532,6 +532,9 @@ public Map toDppColumn() { case DECIMAL: columnType = "DECIMAL"; break; + case DECIMALV2: + columnType = "DECIMAL"; + break; default: columnType = type.toString(); break; @@ -558,7 +561,7 @@ public Map toDppColumn() { } // decimal precision scale - if (type == PrimitiveType.DECIMAL) { + if (type == PrimitiveType.DECIMAL || type == PrimitiveType.DECIMALV2) { dppColumn.put("precision", column.getPrecision()); dppColumn.put("scale", column.getScale()); } diff --git a/fe/src/test/java/org/apache/doris/catalog/ColumnTypeTest.java b/fe/src/test/java/org/apache/doris/catalog/ColumnTypeTest.java index 8c05796a0799f1..dc5a1d91dddf31 100644 --- a/fe/src/test/java/org/apache/doris/catalog/ColumnTypeTest.java +++ b/fe/src/test/java/org/apache/doris/catalog/ColumnTypeTest.java @@ -158,6 +158,9 @@ public void testSerialization() throws Exception { ScalarType type3 = ScalarType.createDecimalType(1, 1); ColumnType.write(dos, type3); + ScalarType type4 = ScalarType.createDecimalV2Type(1, 1); + ColumnType.write(dos, type4); + // 2. Read objects from file DataInputStream dis = new DataInputStream(new FileInputStream(file)); Type rType1 = ColumnType.read(dis); @@ -167,7 +170,9 @@ public void testSerialization() throws Exception { Assert.assertTrue(rType2.equals(type2)); Type rType3 = ColumnType.read(dis); - Assert.assertTrue(rType3.equals(type3)); + + // Change it when remove DecimalV2 + Assert.assertTrue(rType3.equals(type3) || rType3.equals(type4)); Assert.assertFalse(type1.equals(this)); diff --git a/gensrc/script/doris_builtins_functions.py b/gensrc/script/doris_builtins_functions.py index 2714520ddb9ce6..c81c647237cbd2 100755 --- a/gensrc/script/doris_builtins_functions.py +++ b/gensrc/script/doris_builtins_functions.py @@ -321,6 +321,9 @@ [['mod'], 'DECIMAL', ['DECIMAL', 'DECIMAL'], '_ZN5doris16DecimalOperators27mod_decimal_val_decimal_valEPN9doris_udf' '15FunctionContextERKNS1_10DecimalValES6_'], + [['mod'], 'DECIMALV2', ['DECIMALV2', 'DECIMALV2'], + '_ZN5doris18DecimalV2Operators31mod_decimalv2_val_decimalv2_valEPN9doris_udf' + '15FunctionContextERKNS1_12DecimalV2ValES6_'], [['mod', 'fmod'], 'FLOAT', ['FLOAT', 'FLOAT'], '_ZN5doris13MathFunctions10fmod_floatEPN9doris_udf15FunctionContextERKNS1_8FloatValES6_'], [['mod', 'fmod'], 'DOUBLE', ['DOUBLE', 'DOUBLE'], @@ -335,6 +338,9 @@ [['positive'], 'DECIMAL', ['DECIMAL'], '_ZN5doris13MathFunctions16positive_decimalEPN9doris_udf' '15FunctionContextERKNS1_10DecimalValE'], + [['positive'], 'DECIMALV2', ['DECIMALV2'], + '_ZN5doris13MathFunctions16positive_decimalEPN9doris_udf' + '15FunctionContextERKNS1_12DecimalV2ValE'], [['negative'], 'BIGINT', ['BIGINT'], '_ZN5doris13MathFunctions15negative_bigintEPN9doris_udf' '15FunctionContextERKNS1_9BigIntValE'], @@ -344,6 +350,9 @@ [['negative'], 'DECIMAL', ['DECIMAL'], '_ZN5doris13MathFunctions16negative_decimalEPN9doris_udf' '15FunctionContextERKNS1_10DecimalValE'], + [['negative'], 'DECIMALV2', ['DECIMALV2'], + '_ZN5doris13MathFunctions16negative_decimalEPN9doris_udf' + '15FunctionContextERKNS1_12DecimalV2ValE'], [['least'], 'TINYINT', ['TINYINT', '...'], '_ZN5doris13MathFunctions5leastEPN9doris_udf15FunctionContextEiPKNS1_10TinyIntValE'], @@ -365,6 +374,8 @@ '_ZN5doris13MathFunctions5leastEPN9doris_udf15FunctionContextEiPKNS1_11DateTimeValE'], [['least'], 'DECIMAL', ['DECIMAL', '...'], '_ZN5doris13MathFunctions5leastEPN9doris_udf15FunctionContextEiPKNS1_10DecimalValE'], + [['least'], 'DECIMALV2', ['DECIMALV2', '...'], + '_ZN5doris13MathFunctions5leastEPN9doris_udf15FunctionContextEiPKNS1_12DecimalV2ValE'], [['greatest'], 'TINYINT', ['TINYINT', '...'], '_ZN5doris13MathFunctions8greatestEPN9doris_udf15FunctionContextEiPKNS1_10TinyIntValE'], @@ -386,6 +397,8 @@ '_ZN5doris13MathFunctions8greatestEPN9doris_udf15FunctionContextEiPKNS1_11DateTimeValE'], [['greatest'], 'DECIMAL', ['DECIMAL', '...'], '_ZN5doris13MathFunctions8greatestEPN9doris_udf15FunctionContextEiPKNS1_10DecimalValE'], + [['greatest'], 'DECIMALV2', ['DECIMALV2', '...'], + '_ZN5doris13MathFunctions8greatestEPN9doris_udf15FunctionContextEiPKNS1_12DecimalV2ValE'], # Conditional Functions # Some of these have empty symbols because the BE special-cases them based on the @@ -401,6 +414,7 @@ [['if'], 'VARCHAR', ['BOOLEAN', 'VARCHAR', 'VARCHAR'], ''], [['if'], 'DATETIME', ['BOOLEAN', 'DATETIME', 'DATETIME'], ''], [['if'], 'DECIMAL', ['BOOLEAN', 'DECIMAL', 'DECIMAL'], ''], + [['if'], 'DECIMALV2', ['BOOLEAN', 'DECIMALV2', 'DECIMALV2'], ''], [['nullif'], 'BOOLEAN', ['BOOLEAN', 'BOOLEAN'], ''], [['nullif'], 'TINYINT', ['TINYINT', 'TINYINT'], ''], @@ -413,6 +427,7 @@ [['nullif'], 'VARCHAR', ['VARCHAR', 'VARCHAR'], ''], [['nullif'], 'DATETIME', ['DATETIME', 'DATETIME'], ''], [['nullif'], 'DECIMAL', ['DECIMAL', 'DECIMAL'], ''], + [['nullif'], 'DECIMALV2', ['DECIMALV2', 'DECIMALV2'], ''], [['ifnull'], 'BOOLEAN', ['BOOLEAN', 'BOOLEAN'], ''], [['ifnull'], 'TINYINT', ['TINYINT', 'TINYINT'], ''], @@ -425,6 +440,7 @@ [['ifnull'], 'VARCHAR', ['VARCHAR', 'VARCHAR'], ''], [['ifnull'], 'DATETIME', ['DATETIME', 'DATETIME'], ''], [['ifnull'], 'DECIMAL', ['DECIMAL', 'DECIMAL'], ''], + [['ifnull'], 'DECIMALV2', ['DECIMALV2', 'DECIMALV2'], ''], [['coalesce'], 'BOOLEAN', ['BOOLEAN', '...'], ''], [['coalesce'], 'TINYINT', ['TINYINT', '...'], ''], @@ -437,6 +453,7 @@ [['coalesce'], 'VARCHAR', ['VARCHAR', '...'], ''], [['coalesce'], 'DATETIME', ['DATETIME', '...'], ''], [['coalesce'], 'DECIMAL', ['DECIMAL', '...'], ''], + [['coalesce'], 'DECIMALV2', ['DECIMALV2', '...'], ''], [['esquery'], 'BOOLEAN', ['VARCHAR', 'VARCHAR'], '_ZN5doris11ESFunctions5matchEPN' diff --git a/gensrc/script/doris_functions.py b/gensrc/script/doris_functions.py index 549e0044c9f7b3..31b87862ab0726 100755 --- a/gensrc/script/doris_functions.py +++ b/gensrc/script/doris_functions.py @@ -89,12 +89,14 @@ ['Math_Greatest', 'BIGINT', ['BIGINT', '...'], 'MathFunctions::greatest_bigint', ['greatest']], ['Math_Greatest', 'DOUBLE', ['DOUBLE', '...'], 'MathFunctions::greatest_double', ['greatest']], ['Math_Greatest', 'DECIMAL', ['DECIMAL', '...'], 'MathFunctions::greatest_decimal', ['greatest']], + ['Math_Greatest', 'DECIMALV2', ['DECIMALV2', '...'], 'MathFunctions::greatest_decimal', ['greatest']], ['Math_Greatest', 'VARCHAR', ['VARCHAR', '...'], 'MathFunctions::greatest_string', ['greatest']], ['Math_Greatest', 'DATETIME', ['DATETIME', '...'], \ 'MathFunctions::greatest_timestamp', ['greatest']], ['Math_Least', 'BIGINT', ['BIGINT', '...'], 'MathFunctions::least_bigint', ['least']], ['Math_Least', 'DOUBLE', ['DOUBLE', '...'], 'MathFunctions::least_double', ['least']], ['Math_Least', 'DECIMAL', ['DECIMAL', '...'], 'MathFunctions::least_decimal', ['least']], + ['Math_Least', 'DECIMALV2', ['DECIMALV2', '...'], 'MathFunctions::least_decimalv2', ['least']], ['Math_Least', 'VARCHAR', ['VARCHAR', '...'], 'MathFunctions::least_string', ['least']], ['Math_Least', 'DATETIME', ['DATETIME', '...'], 'MathFunctions::least_timestamp', ['least']], @@ -305,6 +307,9 @@ udf_functions = [ ['Udf_Math_Abs', 'DECIMAL', ['DECIMAL'], 'UdfBuiltins::decimal_abs', ['udf_abs'], ''], + ['Udf_Math_Abs', 'DECIMALV2', ['DECIMALV2'], 'UdfBuiltins::decimal_abs', ['udf_abs'], + ''], + ['Udf_Sub_String', 'VARCHAR', ['VARCHAR', 'INT', 'INT'], ['Udf_Sub_String', 'VARCHAR', ['VARCHAR', 'INT', 'INT'], 'UdfBuiltins::sub_string', ['udf_substring'], ''], ['Udf_Add_Two_Number', 'BIGINT', ['BIGINT', 'BIGINT'], diff --git a/gensrc/script/gen_functions.py b/gensrc/script/gen_functions.py index b09122359a1b59..fd3027c494db26 100755 --- a/gensrc/script/gen_functions.py +++ b/gensrc/script/gen_functions.py @@ -386,19 +386,20 @@ 'DATE': ['DATE'], 'DATETIME': ['DATETIME'], 'DECIMAL': ['DECIMAL'], + 'DECIMALV2': ['DECIMALV2'], 'NATIVE_INT_TYPES': ['TINYINT', 'SMALLINT', 'INT', 'BIGINT'], 'INT_TYPES': ['TINYINT', 'SMALLINT', 'INT', 'BIGINT', 'LARGEINT'], 'FLOAT_TYPES': ['FLOAT', 'DOUBLE'], 'NUMERIC_TYPES': ['TINYINT', 'SMALLINT', 'INT', 'BIGINT', 'FLOAT', 'DOUBLE', \ - 'LARGEINT', 'DECIMAL'], + 'LARGEINT', 'DECIMAL', 'DECIMALV2'], 'STRING_TYPES': ['VARCHAR'], 'DATETIME_TYPES': ['DATE', 'DATETIME'], 'FIXED_TYPES': ['BOOLEAN', 'TINYINT', 'SMALLINT', 'INT', 'BIGINT', 'LARGEINT'], 'NATIVE_TYPES': ['BOOLEAN', 'TINYINT', 'SMALLINT', 'INT', 'BIGINT', 'FLOAT', 'DOUBLE'], 'STRCAST_FIXED_TYPES': ['BOOLEAN', 'SMALLINT', 'INT', 'BIGINT'], 'ALL_TYPES': ['BOOLEAN', 'TINYINT', 'SMALLINT', 'INT', 'BIGINT', 'LARGEINT', 'FLOAT',\ - 'DOUBLE', 'VARCHAR', 'DATETIME', 'DECIMAL'], - 'MAX_TYPES': ['BIGINT', 'LARGEINT', 'DOUBLE', 'DECIMAL'], + 'DOUBLE', 'VARCHAR', 'DATETIME', 'DECIMAL', 'DECIMALV2'], + 'MAX_TYPES': ['BIGINT', 'LARGEINT', 'DOUBLE', 'DECIMAL', 'DECIMALV2'], } # Operation, [ReturnType], [[Args1], [Args2], ... [ArgsN]] @@ -411,6 +412,7 @@ ['Int_Divide', ['INT_TYPES'], [['INT_TYPES'], ['INT_TYPES']]], ['Mod', ['INT_TYPES'], [['INT_TYPES'], ['INT_TYPES']]], ['Mod', ['DECIMAL'], [['DECIMAL'], ['DECIMAL']]], + ['Mod', ['DECIMALV2'], [['DECIMALV2'], ['DECIMALV2']]], ['Mod', ['DOUBLE'], [['DOUBLE'], ['DOUBLE']], double_mod], ['BitAnd', ['INT_TYPES'], [['INT_TYPES'], ['INT_TYPES']]], ['BitXor', ['INT_TYPES'], [['INT_TYPES'], ['INT_TYPES']]], @@ -448,6 +450,12 @@ ['Lt', ['BOOLEAN'], [['DECIMAL'], ['DECIMAL']],], ['Ge', ['BOOLEAN'], [['DECIMAL'], ['DECIMAL']],], ['Le', ['BOOLEAN'], [['DECIMAL'], ['DECIMAL']],], + ['Eq', ['BOOLEAN'], [['DECIMALV2'], ['DECIMALV2']],], + ['Ne', ['BOOLEAN'], [['DECIMALV2'], ['DECIMALV2']],], + ['Gt', ['BOOLEAN'], [['DECIMALV2'], ['DECIMALV2']],], + ['Lt', ['BOOLEAN'], [['DECIMALV2'], ['DECIMALV2']],], + ['Ge', ['BOOLEAN'], [['DECIMALV2'], ['DECIMALV2']],], + ['Le', ['BOOLEAN'], [['DECIMALV2'], ['DECIMALV2']],], # Casts ['Cast', ['BOOLEAN'], [['NATIVE_TYPES'], ['BOOLEAN']]], @@ -457,13 +465,18 @@ ['Cast', ['BIGINT'], [['NATIVE_TYPES'], ['BIGINT']]], ['Cast', ['LARGEINT'], [['NATIVE_TYPES'], ['LARGEINT']]], ['Cast', ['LARGEINT'], [['DECIMAL'], ['LARGEINT']]], + ['Cast', ['LARGEINT'], [['DECIMALV2'], ['LARGEINT']]], ['Cast', ['NATIVE_TYPES'], [['LARGEINT'], ['NATIVE_TYPES']]], ['Cast', ['FLOAT'], [['NATIVE_TYPES'], ['FLOAT']]], ['Cast', ['DOUBLE'], [['NATIVE_TYPES'], ['DOUBLE']]], ['Cast', ['DECIMAL'], [['FIXED_TYPES'], ['DECIMAL']]], + ['Cast', ['DECIMALV2'], [['FIXED_TYPES'], ['DECIMALV2']]], ['Cast', ['DECIMAL'], [['FLOAT'], ['DECIMAL']], float_to_decimal], + ['Cast', ['DECIMALV2'], [['FLOAT'], ['DECIMALV2']], float_to_decimal], ['Cast', ['DECIMAL'], [['DOUBLE'], ['DECIMAL']], double_to_decimal], + ['Cast', ['DECIMALV2'], [['DOUBLE'], ['DECIMALV2']], double_to_decimal], ['Cast', ['NATIVE_TYPES'], [['DECIMAL'], ['NATIVE_TYPES']]], + ['Cast', ['NATIVE_TYPES'], [['DECIMALV2'], ['NATIVE_TYPES']]], ['Cast', ['NATIVE_INT_TYPES'], [['STRING'], ['NATIVE_INT_TYPES']], string_to_int], ['Cast', ['LARGEINT'], [['STRING'], ['LARGEINT']], string_to_int], ['Cast', ['FLOAT_TYPES'], [['STRING'], ['FLOAT_TYPES']], string_to_float], @@ -473,6 +486,7 @@ ['Cast', ['STRING'], [['DOUBLE'], ['STRING']], double_to_string], ['Cast', ['STRING'], [['TINYINT'], ['STRING']], tinyint_to_string], ['Cast', ['STRING'], [['DECIMAL'], ['STRING']], decimal_to_string], + ['Cast', ['STRING'], [['DECIMALV2'], ['STRING']], decimal_to_string], # Datetime cast ['Cast', ['DATE'], [['NUMERIC_TYPES'], ['DATE']], numeric_to_date], ['Cast', ['DATETIME'], [['NUMERIC_TYPES'], ['DATETIME']], numeric_to_datetime], @@ -507,6 +521,7 @@ 'DATE': 'Date', 'DATETIME': 'DateTime', 'DECIMAL': 'DecimalValue', + 'DECIMALV2': 'DecimalV2Value', } # Portable type used in the function implementation @@ -523,6 +538,7 @@ 'DATE': 'DateTimeValue', 'DATETIME': 'DateTimeValue', 'DECIMAL': 'DecimalValue', + 'DECIMALV2': 'DecimalV2Value', } result_fields = { 'BOOLEAN': 'bool_val', @@ -537,6 +553,7 @@ 'DATE': 'datetime_val', 'DATETIME': 'datetime_val', 'DECIMAL': 'decimal_val', + 'DECIMALV2': 'decimalv2_val', } native_ops = { diff --git a/gensrc/script/gen_opcodes.py b/gensrc/script/gen_opcodes.py index 3b7827f069662e..48bff40d9d28c2 100755 --- a/gensrc/script/gen_opcodes.py +++ b/gensrc/script/gen_opcodes.py @@ -61,6 +61,7 @@ 'DATE': 'Date', 'DATETIME': 'DateTime', 'DECIMAL': 'DecimalValue', + 'DECIMALV2': 'DecimalV2Value', } thrift_preamble = '\ diff --git a/gensrc/script/gen_vector_functions.py b/gensrc/script/gen_vector_functions.py index fab13008736e8f..b1aa3e185beae9 100755 --- a/gensrc/script/gen_vector_functions.py +++ b/gensrc/script/gen_vector_functions.py @@ -285,6 +285,7 @@ 'DATE': ['DATE'], 'DATETIME': ['DATETIME'], 'DECIMAL': ['DECIMAL'], + 'DECIMALV2': ['DECIMALV2'], 'NATIVE_INT_TYPES': ['TINYINT', 'SMALLINT', 'INT', 'BIGINT'], 'INT_TYPES': ['TINYINT', 'SMALLINT', 'INT', 'BIGINT', 'LARGEINT'], 'FLOAT_TYPES': ['FLOAT', 'DOUBLE'], @@ -292,8 +293,8 @@ 'NATIVE_TYPES': ['BOOLEAN', 'TINYINT', 'SMALLINT', 'INT', 'BIGINT', 'FLOAT', 'DOUBLE'], 'STRCAST_TYPES': ['BOOLEAN', 'SMALLINT', 'INT', 'BIGINT', 'FLOAT', 'DOUBLE'], 'ALL_TYPES': ['BOOLEAN', 'TINYINT', 'SMALLINT', 'INT', 'BIGINT', 'LARGEINT', 'FLOAT',\ - 'DOUBLE', 'VARCHAR', 'DATETIME', 'DECIMAL'], - 'MAX_TYPES': ['BIGINT', 'LARGEINT', 'DOUBLE', 'DECIMAL'], + 'DOUBLE', 'VARCHAR', 'DATETIME', 'DECIMAL', 'DECIMALV2'], + 'MAX_TYPES': ['BIGINT', 'LARGEINT', 'DOUBLE', 'DECIMAL', 'DECIMALV2'], } # Operation, [ReturnType], [[Args1], [Args2], ... [ArgsN]] @@ -323,6 +324,7 @@ 'DATE': 'DateTimeValue', 'DATETIME': 'DateTimeValue', 'DECIMAL': 'DecimalValue', + 'DECIMALV2': 'DecimalV2Value', } # Portable type used in the function implementation @@ -339,6 +341,7 @@ 'DATE': 'DateTimeValue', 'DATETIME': 'DateTimeValue', 'DECIMAL': 'DecimalValue', + 'DECIMALV2': 'DecimalV2Value', } native_ops = { diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift index c9934ff74ad7ed..81374e685e16ee 100644 --- a/gensrc/thrift/Types.thrift +++ b/gensrc/thrift/Types.thrift @@ -71,6 +71,7 @@ enum TPrimitiveType { LARGEINT, VARCHAR, HLL, + DECIMALV2 } enum TTypeNodeType {