Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions be/src/exprs/aggregate_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2061,6 +2061,13 @@ struct KnuthVarianceState {
int64_t count;
};

// Use Decimal to store the intermediate results of the variance algorithm
struct DecimalV2KnuthVarianceState {
DecimalV2Val mean;
DecimalV2Val m2;
int64_t count = 0;
};

// Set pop=true for population variance, false for sample variance
static double compute_knuth_variance(const KnuthVarianceState& state, bool pop) {
// Return zero for 1 tuple specified by
Expand All @@ -2070,6 +2077,16 @@ static double compute_knuth_variance(const KnuthVarianceState& state, bool pop)
return state.m2 / (state.count - 1);
}

// The algorithm is the same as above, using decimal as the intermediate variable
static DecimalV2Value decimalv2_compute_knuth_variance(const DecimalV2KnuthVarianceState& state, bool pop) {
DecimalV2Value new_count = DecimalV2Value();
new_count.assign_from_double(state.count);
if (state.count == 1) return new_count;
DecimalV2Value new_m2 = DecimalV2Value::from_decimal_val(state.m2);
if (pop) return new_m2 / new_count;
else return new_m2 / new_count.assign_from_double(state.count - 1);
}

void AggregateFunctions::knuth_var_init(FunctionContext* ctx, StringVal* dst) {
dst->is_null = false;
// TODO(zc)
Expand All @@ -2079,6 +2096,15 @@ void AggregateFunctions::knuth_var_init(FunctionContext* ctx, StringVal* dst) {
memset(dst->ptr, 0, dst->len);
}

void AggregateFunctions::decimalv2_knuth_var_init(FunctionContext* ctx, StringVal* dst) {
dst->is_null = false;
dst->len = sizeof(DecimalV2KnuthVarianceState);
// The memory for int128 need to be aligned by 16.
// So the constructor has been used instead of allocating memory.
// Also, it will be release in finalize.
dst->ptr = (uint8_t*) new DecimalV2KnuthVarianceState;
}

template <typename T>
void AggregateFunctions::knuth_var_update(FunctionContext* ctx, const T& src, StringVal* dst) {
DCHECK(!dst->is_null);
Expand All @@ -2093,6 +2119,34 @@ void AggregateFunctions::knuth_var_update(FunctionContext* ctx, const T& src, St
state->count = temp;
}

void AggregateFunctions::knuth_var_update(FunctionContext* ctx, const DecimalV2Val& src, StringVal* dst) {
DCHECK(!dst->is_null);
DCHECK_EQ(dst->len, sizeof(DecimalV2KnuthVarianceState));
if (src.is_null) return;
DecimalV2KnuthVarianceState* state = reinterpret_cast<DecimalV2KnuthVarianceState*>(dst->ptr);

DecimalV2Value new_src = DecimalV2Value::from_decimal_val(src);
DecimalV2Value new_mean = DecimalV2Value::from_decimal_val(state->mean);
DecimalV2Value new_m2 = DecimalV2Value::from_decimal_val(state->m2);
DecimalV2Value new_count = DecimalV2Value();
new_count.assign_from_double(state->count);

DecimalV2Value temp = DecimalV2Value();
temp.assign_from_double(1 + state->count);
DecimalV2Value delta = new_src - new_mean;
DecimalV2Value r = delta / temp;
new_mean += r;
// This may cause Decimal to overflow. When it overflows, m2 will be equal to 9223372036854775807999999999,
// which is the maximum value that DecimalV2Value can represent. When using double to store the intermediate result m2,
// it can be expressed by scientific and technical methods and will not overflow.
// Spark's handling of decimal overflow is to return null or report an error, which can be controlled by parameters.
// Spark's handling of decimal reference: https://cloud.tencent.com/developer/news/483615
new_m2 += new_count * delta * r;
++state->count;
new_mean.to_decimal_val(&state->mean);
new_m2.to_decimal_val(&state->m2);
}

void AggregateFunctions::knuth_var_merge(FunctionContext* ctx, const StringVal& src,
StringVal* dst) {
DCHECK(!dst->is_null);
Expand All @@ -2112,6 +2166,33 @@ void AggregateFunctions::knuth_var_merge(FunctionContext* ctx, const StringVal&
dst_state->count = sum_count;
}

void AggregateFunctions::decimalv2_knuth_var_merge(FunctionContext* ctx, const StringVal& src,
StringVal* dst) {
DecimalV2KnuthVarianceState src_state;
memcpy(&src_state, src.ptr, sizeof(DecimalV2KnuthVarianceState));
DCHECK(!dst->is_null);
DCHECK_EQ(dst->len, sizeof(DecimalV2KnuthVarianceState));
DecimalV2KnuthVarianceState* dst_state = reinterpret_cast<DecimalV2KnuthVarianceState*>(dst->ptr);
if (src_state.count == 0) return;

DecimalV2Value new_src_mean = DecimalV2Value::from_decimal_val(src_state.mean);
DecimalV2Value new_dst_mean = DecimalV2Value::from_decimal_val(dst_state->mean);
DecimalV2Value new_src_count = DecimalV2Value();
new_src_count.assign_from_double(src_state.count);
DecimalV2Value new_dst_count = DecimalV2Value();
new_dst_count.assign_from_double(dst_state->count);
DecimalV2Value new_src_m2 = DecimalV2Value::from_decimal_val(src_state.m2);
DecimalV2Value new_dst_m2 = DecimalV2Value::from_decimal_val(dst_state->m2);

DecimalV2Value delta = new_dst_mean - new_src_mean;
DecimalV2Value sum_count = new_dst_count + new_src_count;
new_dst_mean = new_src_mean + delta * (new_dst_count / sum_count);
new_dst_m2 = (new_src_m2) + new_dst_m2 + (delta * delta) * (new_src_count * new_dst_count / sum_count);
dst_state->count += src_state.count;
new_dst_mean.to_decimal_val(&dst_state->mean);
new_dst_m2.to_decimal_val(&dst_state->m2);
}

DoubleVal AggregateFunctions::knuth_var_finalize(FunctionContext* ctx, const StringVal& state_sv) {
DCHECK(!state_sv.is_null);
KnuthVarianceState* state = reinterpret_cast<KnuthVarianceState*>(state_sv.ptr);
Expand All @@ -2121,6 +2202,19 @@ DoubleVal AggregateFunctions::knuth_var_finalize(FunctionContext* ctx, const Str
return DoubleVal(variance);
}

DecimalV2Val AggregateFunctions::decimalv2_knuth_var_finalize(FunctionContext* ctx,
const StringVal& state_sv) {
DCHECK(!state_sv.is_null);
DCHECK_EQ(state_sv.len, sizeof(DecimalV2KnuthVarianceState));
DecimalV2KnuthVarianceState* state = reinterpret_cast<DecimalV2KnuthVarianceState*>(state_sv.ptr);
if (state->count == 0 || state->count == 1) return DecimalV2Val::null();
DecimalV2Value variance = decimalv2_compute_knuth_variance(*state, false);
DecimalV2Val res;
variance.to_decimal_val(&res);
delete (DecimalV2KnuthVarianceState*)state_sv.ptr;
return res;
}

DoubleVal AggregateFunctions::knuth_var_pop_finalize(FunctionContext* ctx,
const StringVal& state_sv) {
DCHECK(!state_sv.is_null);
Expand All @@ -2132,6 +2226,19 @@ DoubleVal AggregateFunctions::knuth_var_pop_finalize(FunctionContext* ctx,
return DoubleVal(variance);
}

DecimalV2Val AggregateFunctions::decimalv2_knuth_var_pop_finalize(FunctionContext* ctx,
const StringVal& state_sv) {
DCHECK(!state_sv.is_null);
DCHECK_EQ(state_sv.len, sizeof(DecimalV2KnuthVarianceState));
DecimalV2KnuthVarianceState* state = reinterpret_cast<DecimalV2KnuthVarianceState*>(state_sv.ptr);
if (state->count == 0) return DecimalV2Val::null();
DecimalV2Value variance = decimalv2_compute_knuth_variance(*state, true);
DecimalV2Val res;
variance.to_decimal_val(&res);
delete (DecimalV2KnuthVarianceState*)state_sv.ptr;
return res;
}

DoubleVal AggregateFunctions::knuth_stddev_finalize(FunctionContext* ctx,
const StringVal& state_sv) {
DCHECK(!state_sv.is_null);
Expand All @@ -2143,6 +2250,20 @@ DoubleVal AggregateFunctions::knuth_stddev_finalize(FunctionContext* ctx,
return DoubleVal(variance);
}

DecimalV2Val AggregateFunctions::decimalv2_knuth_stddev_finalize(FunctionContext* ctx,
const StringVal& state_sv) {
DCHECK(!state_sv.is_null);
DCHECK_EQ(state_sv.len, sizeof(DecimalV2KnuthVarianceState));
DecimalV2KnuthVarianceState* state = reinterpret_cast<DecimalV2KnuthVarianceState*>(state_sv.ptr);
if (state->count == 0 || state->count == 1) return DecimalV2Val::null();
DecimalV2Value variance = decimalv2_compute_knuth_variance(*state, false);
variance = DecimalV2Value::sqrt(variance);
DecimalV2Val res;
variance.to_decimal_val(&res);
delete (DecimalV2KnuthVarianceState*)state_sv.ptr;
return res;
}

DoubleVal AggregateFunctions::knuth_stddev_pop_finalize(FunctionContext* ctx,
const StringVal& state_sv) {
DCHECK(!state_sv.is_null);
Expand All @@ -2154,6 +2275,20 @@ DoubleVal AggregateFunctions::knuth_stddev_pop_finalize(FunctionContext* ctx,
return DoubleVal(variance);
}

DecimalV2Val AggregateFunctions::decimalv2_knuth_stddev_pop_finalize(FunctionContext* ctx,
const StringVal& state_sv) {
DCHECK(!state_sv.is_null);
DCHECK_EQ(state_sv.len, sizeof(DecimalV2KnuthVarianceState));
DecimalV2KnuthVarianceState* state = reinterpret_cast<DecimalV2KnuthVarianceState*>(state_sv.ptr);
if (state->count == 0) return DecimalV2Val::null();
DecimalV2Value variance = decimalv2_compute_knuth_variance(*state, true);
variance = DecimalV2Value::sqrt(variance);
DecimalV2Val res;
variance.to_decimal_val(&res);
delete (DecimalV2KnuthVarianceState*)state_sv.ptr;
return res;
}

struct RankState {
int64_t rank;
int64_t count;
Expand Down
9 changes: 9 additions & 0 deletions be/src/exprs/aggregate_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,15 @@ class AggregateFunctions {
/// Calculates the biased STDDEV, uses KnuthVar Init-Update-Merge functions
static DoubleVal knuth_stddev_pop_finalize(FunctionContext* context, const StringVal& val);

// variance/stddev for decimals.
static void decimalv2_knuth_var_init(FunctionContext* context, StringVal* val);
static void knuth_var_update(FunctionContext* context, const DecimalV2Val& src, StringVal* val);
static void decimalv2_knuth_var_merge(FunctionContext* context, const StringVal& src, StringVal* val);
static DecimalV2Val decimalv2_knuth_var_finalize(FunctionContext* context, const StringVal& val);
static DecimalV2Val decimalv2_knuth_var_pop_finalize(FunctionContext* context, const StringVal& val);
static DecimalV2Val decimalv2_knuth_stddev_finalize(FunctionContext* context, const StringVal& val);
static DecimalV2Val decimalv2_knuth_stddev_pop_finalize(FunctionContext* context, const StringVal& val);

/// ----------------------------- Analytic Functions ---------------------------------
/// Analytic functions implement the UDA interface (except Merge(), Serialize()) and are
/// used internally by the AnalyticEvalNode. Some analytic functions store intermediate
Expand Down
107 changes: 107 additions & 0 deletions be/src/runtime/decimalv2_value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,113 @@ DecimalV2Value& DecimalV2Value::operator+=(const DecimalV2Value& other) {
return *this;
}

// Solve a one-dimensional quadratic equation: ax2 + bx + c =0
// Reference: https://gist.github.com/miloyip/1fcc1859c94d33a01957cf41a7c25fdf
// Reference: https://www.zhihu.com/question/51381686
static std::pair<double, double> quadratic_equation_naive(__uint128_t a, __uint128_t b, __uint128_t c) {
__uint128_t dis = b * b - 4 * a * c;
// assert(dis >= 0);
// not handling complex root
if (dis < 0) return std::make_pair(0, 0);
double sqrtdis = std::sqrt(static_cast<double>(dis));
double a_r = static_cast<double>(a);
double b_r = static_cast<double>(b);
double x1 = (-b_r - sqrtdis) / (a_r + a_r);
double x2 = (-b_r + sqrtdis) / (a_r + a_r);
return std::make_pair(x1, x2);
}

static inline double sgn(double x) {
if (x > 0) return 1;
else if (x < 0) return -1;
else return 0;
}

// In the above quadratic_equation_naive solution process, we found that -b + sqrtdis will
// get the correct answer, and -b-sqrtdis will get the wrong answer. For two close floating-point
// decimals a, b, a-b will cause larger errors than a + b, which is called catastrophic cancellation.
// Both -b and sqrtdis are positive numbers. We can first find the roots brought by -b + sqrtdis,
// and then use the product of the two roots of the quadratic equation in one unknown to find another root
static std::pair<double, double> quadratic_equation_better(int128_t a, int128_t b, int128_t c) {
if (b == 0) return quadratic_equation_naive(a, b, c);
int128_t dis = b * b - 4 * a * c;
// assert(dis >= 0);
// not handling complex root
if (dis < 0) return std::make_pair(0, 0);

// There may be a loss of precision, but here is used to find the mantissa of the square root.
// The current SCALE=9, which is less than the 15 significant digits of the double type,
// so theoretically the loss of precision will not be reflected in the result.
double sqrtdis = std::sqrt(static_cast<double>(dis));
double a_r = static_cast<double>(a);
double b_r = static_cast<double>(b);
double c_r = static_cast<double>(c);
// Here b comes from an unsigned integer, and sgn(b) is always 1,
// which is only used to preserve the complete algorithm
double x1 = (-b_r - sgn(b_r) * sqrtdis) / (a_r + a_r);
double x2 = c_r / (a_r * x1);
return std::make_pair(x1, x2);
}

// Large integer square roots, returns the integer part.
// The time complexity is lower than the traditional dichotomy
// and Newton iteration method, and the number of iterations is fixed.
// in real-time systems, functions that execute an unpredictable number of iterations
// will make the total time per task unpredictable, and introduce jitter
// Reference: https://www.embedded.com/integer-square-roots/
// Reference: https://link.zhihu.com/?target=https%3A//gist.github.com/miloyip/69663b78b26afa0dcc260382a6034b1a
// Reference: https://www.zhihu.com/question/35122102
static std::pair<__uint128_t, __uint128_t> sqrt_integer(__uint128_t n) {
__uint128_t remainder = 0, root = 0;
for (size_t i = 0; i < 64; i++) {
root <<= 1;
++root;
remainder <<= 2;
remainder |= n >> 126; n <<= 2; // Extract 2 MSB from n
if (root <= remainder) {
remainder -= root;
++root;
}
else{
--root;
}
}
return std::make_pair(root >>= 1, remainder);
}

// According to the integer part and the remainder of the square root,
// Use one-dimensional quadratic equation to solve the fractional part of the square root
static double sqrt_fractional(int128_t sqrt_int, int128_t remainder) {
std::pair<double, double> p = quadratic_equation_better(1, 2*sqrt_int, -remainder);
if ((0 < p.first) && (p.first < 1)) return p.first;
if ((0 < p.second) && (p.second < 1)) return p.second;
return 0;
}

const int128_t DecimalV2Value::SQRT_MOLECULAR_MAGNIFICATION = get_scale_base(PRECISION/2);
const int128_t DecimalV2Value::SQRT_DENOMINATOR = std::sqrt(ONE_BILLION) * get_scale_base(PRECISION/2 - SCALE);

DecimalV2Value DecimalV2Value::sqrt(const DecimalV2Value& v) {
int128_t x = v.value();
std::pair<__uint128_t, __uint128_t> sqrt_integer_ret;
bool is_negative = (x < 0);
if (x == 0) {
return DecimalV2Value(0);
}
sqrt_integer_ret = sqrt_integer(abs(x));
int128_t integer_root = static_cast<int128_t>(sqrt_integer_ret.first);
int128_t integer_remainder = static_cast<int128_t>(sqrt_integer_ret.second);
double fractional = sqrt_fractional(integer_root, integer_remainder);

// Multiplying by SQRT_MOLECULAR_MAGNIFICATION here will not overflow,
// because integer_root can be up to 64 bits.
int128_t molecular_integer = integer_root * SQRT_MOLECULAR_MAGNIFICATION;
int128_t molecular_fractional = static_cast<int128_t>(fractional * SQRT_MOLECULAR_MAGNIFICATION);
int128_t ret = (molecular_integer + molecular_fractional)/SQRT_DENOMINATOR;
if (is_negative) ret = -ret;
return DecimalV2Value(ret);
}

int DecimalV2Value::parse_from_str(const char* decimal_str, int32_t length) {
int32_t error = E_DEC_OK;
StringParser::ParseResult result = StringParser::PARSE_SUCCESS;
Expand Down
8 changes: 8 additions & 0 deletions be/src/runtime/decimalv2_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ class DecimalV2Value {
static const int64_t MAX_INT_VALUE = 999999999999999999;
static const int32_t MAX_FRAC_VALUE = 999999999;
static const int64_t MAX_INT64 = 9223372036854775807ll;
// In sqrt, the integer part and the decimal part of the square root to be solved separately are
// multiplied by the PRECISION/2 power of 10, so that they can be placed in an int128_t variable
static const int128_t SQRT_MOLECULAR_MAGNIFICATION;
// sqrt(ONE_BILLION) * pow(10, PRECISION/2 - SCALE), it is used to calculate SCALE of the sqrt result
static const int128_t SQRT_DENOMINATOR;

static const int128_t MAX_DECIMAL_VALUE =
static_cast<int128_t>(MAX_INT64) * ONE_BILLION + MAX_FRAC_VALUE;
Expand Down Expand Up @@ -204,6 +209,9 @@ class DecimalV2Value {

void to_decimal_val(DecimalV2Val* value) const { value->val = _value; }

// Solve Square root for int128
static DecimalV2Value sqrt(const DecimalV2Value& v);

// set DecimalV2Value to zero
void set_to_zero() { _value = 0; }

Expand Down
Loading