Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 200 additions & 50 deletions be/src/vec/aggregate_functions/aggregate_function_regr_union.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,88 +32,238 @@
namespace doris::vectorized {
#include "common/compile_check_begin.h"

template <PrimitiveType T>
template <PrimitiveType T,
// requires Sx and Sy
bool NeedSxy,
// level 1: Sx
// level 2: Sxx
size_t SxLevel = size_t {NeedSxy},
// level 1: Sy
// level 2: Syy
size_t SyLevel = size_t {NeedSxy}>
struct AggregateFunctionRegrData {
static constexpr PrimitiveType Type = T;
UInt64 count = 0;
Float64 sum_x {};
Float64 sum_y {};
Float64 sum_of_x_mul_y {};
Float64 sum_of_x_squared {};

static_assert(!NeedSxy || (SxLevel > 0 && SyLevel > 0),
"NeedSxy requires SxLevel > 0 and SyLevel > 0");
static_assert(SxLevel <= 2 && SyLevel <= 2, "Sx/Sy level must be <= 2");

static constexpr bool need_sx = SxLevel > 0;
static constexpr bool need_sy = SyLevel > 0;
static constexpr bool need_sxx = SxLevel > 1;
static constexpr bool need_syy = SyLevel > 1;
static constexpr bool need_sxy = NeedSxy;

static constexpr size_t kMomentSize = SxLevel + SyLevel + size_t {need_sxy};
static_assert(kMomentSize > 0 && kMomentSize <= 5, "Unexpected size of regr moment array");

/**
* The moments array is:
* Sx = sum(X)
* Sy = sum(Y)
* Sxx = sum((X-Sx/N)^2)
* Syy = sum((Y-Sy/N)^2)
* Sxy = sum((X-Sx/N)*(Y-Sy/N))
*/
std::array<Float64, kMomentSize> moments {};
UInt64 n {};

static constexpr size_t idx_sx() {
static_assert(need_sx, "sx not enabled");
return 0;
}
static constexpr size_t idx_sy() {
static_assert(need_sy, "sy not enabled");
return size_t {need_sx};
}
static constexpr size_t idx_sxx() {
static_assert(need_sxx, "sxx not enabled");
return size_t {need_sx + need_sy};
}
static constexpr size_t idx_syy() {
static_assert(need_syy, "syy not enabled");
return size_t {need_sx + need_sy + need_sxx};
}
static constexpr size_t idx_sxy() {
static_assert(need_sxy, "sxy not enabled");
return size_t {need_sx + need_sy + need_sxx + need_syy};
}

Float64& sx() { return moments[idx_sx()]; }
Float64& sy() { return moments[idx_sy()]; }
Float64& sxx() { return moments[idx_sxx()]; }
Float64& syy() { return moments[idx_syy()]; }
Float64& sxy() { return moments[idx_sxy()]; }

const Float64& sx() const { return moments[idx_sx()]; }
const Float64& sy() const { return moments[idx_sy()]; }
const Float64& sxx() const { return moments[idx_sxx()]; }
const Float64& syy() const { return moments[idx_syy()]; }
const Float64& sxy() const { return moments[idx_sxy()]; }

void write(BufferWritable& buf) const {
buf.write_binary(sum_x);
buf.write_binary(sum_y);
buf.write_binary(sum_of_x_mul_y);
buf.write_binary(sum_of_x_squared);
buf.write_binary(count);
if constexpr (need_sx) {
buf.write_binary(sx());
}
if constexpr (need_sy) {
buf.write_binary(sy());
}
if constexpr (need_sxx) {
buf.write_binary(sxx());
}
if constexpr (need_syy) {
buf.write_binary(syy());
}
if constexpr (need_sxy) {
buf.write_binary(sxy());
}
buf.write_binary(n);
}

void read(BufferReadable& buf) {
buf.read_binary(sum_x);
buf.read_binary(sum_y);
buf.read_binary(sum_of_x_mul_y);
buf.read_binary(sum_of_x_squared);
buf.read_binary(count);
if constexpr (need_sx) {
buf.read_binary(sx());
}
if constexpr (need_sy) {
buf.read_binary(sy());
}
if constexpr (need_sxx) {
buf.read_binary(sxx());
}
if constexpr (need_syy) {
buf.read_binary(syy());
}
if constexpr (need_sxy) {
buf.read_binary(sxy());
}
buf.read_binary(n);
}

void reset() {
sum_x = {};
sum_y = {};
sum_of_x_mul_y = {};
sum_of_x_squared = {};
count = 0;
moments.fill({});
n = {};
}

/**
* The merge function uses the Youngs–Cramer algorithm:
* N = N1 + N2
* Sx = Sx1 + Sx2
* Sy = Sy1 + Sy2
* Sxx = Sxx1 + Sxx2 + N1 * N2 * (Sx1/N1 - Sx2/N2)^2 / N
* Syy = Syy1 + Syy2 + N1 * N2 * (Sy1/N1 - Sy2/N2)^2 / N
* Sxy = Sxy1 + Sxy2 + N1 * N2 * (Sx1/N1 - Sx2/N2) * (Sy1/N1 - Sy2/N2) / N
*/
void merge(const AggregateFunctionRegrData& rhs) {
if (rhs.count == 0) {
if (rhs.n == 0) {
return;
}
if (n == 0) {
*this = rhs;
return;
}
sum_x += rhs.sum_x;
sum_y += rhs.sum_y;
sum_of_x_mul_y += rhs.sum_of_x_mul_y;
sum_of_x_squared += rhs.sum_of_x_squared;
count += rhs.count;
const auto n1 = static_cast<Float64>(n);
const auto n2 = static_cast<Float64>(rhs.n);
const auto nsum = n1 + n2;

Float64 dx {};
Float64 dy {};
if constexpr (need_sxx || need_sxy) {
dx = sx() / n1 - rhs.sx() / n2;
}
if constexpr (need_syy || need_sxy) {
dy = sy() / n1 - rhs.sy() / n2;
}

n += rhs.n;
if constexpr (need_sx) {
sx() += rhs.sx();
}
if constexpr (need_sy) {
sy() += rhs.sy();
}
if constexpr (need_sxx) {
sxx() += rhs.sxx() + n1 * n2 * dx * dx / nsum;
}
if constexpr (need_syy) {
syy() += rhs.syy() + n1 * n2 * dy * dy / nsum;
}
if constexpr (need_sxy) {
sxy() += rhs.sxy() + n1 * n2 * dx * dy / nsum;
}
}

/**
* N
* Sx = sum(X)
* Sy = sum(Y)
* Sxx = sum((X-Sx/N)^2)
* Syy = sum((Y-Sy/N)^2)
* Sxy = sum((X-Sx/N)*(Y-Sy/N))
*/
void add(typename PrimitiveTypeTraits<T>::ColumnItemType value_y,
typename PrimitiveTypeTraits<T>::ColumnItemType value_x) {
sum_x += (double)value_x;
sum_y += (double)value_y;
sum_of_x_mul_y += (double)value_x * (double)value_y;
sum_of_x_squared += (double)value_x * (double)value_x;
count += 1;
}
const auto x = static_cast<Float64>(value_x);
const auto y = static_cast<Float64>(value_y);

Float64 get_slope() const {
Float64 denominator = (double)count * sum_of_x_squared - sum_x * sum_x;
if (count < 2 || denominator == 0.0) {
return std::numeric_limits<Float64>::quiet_NaN();
if constexpr (need_sx) {
sx() += x;
}
if constexpr (need_sy) {
sy() += y;
}

if (n == 0) [[unlikely]] {
n = 1;
return;
}
const auto n_old = static_cast<Float64>(n);
const auto n_new = n_old + 1;
const auto scale = 1.0 / (n_new * n_old);
n += 1;

Float64 tmp_x {};
Float64 tmp_y {};
if constexpr (need_sxx || need_sxy) {
tmp_x = x * n_new - sx();
}
if constexpr (need_syy || need_sxy) {
tmp_y = y * n_new - sy();
}

if constexpr (need_sxx) {
sxx() += tmp_x * tmp_x * scale;
}
if constexpr (need_syy) {
syy() += tmp_y * tmp_y * scale;
}
if constexpr (need_sxy) {
sxy() += tmp_x * tmp_y * scale;
}
Float64 slope = ((double)count * sum_of_x_mul_y - sum_x * sum_y) / denominator;
return slope;
}
};

template <PrimitiveType T>
struct RegrSlopeFunc : AggregateFunctionRegrData<T> {
struct RegrSlopeFunc : AggregateFunctionRegrData<T, true, 2, 1> {
static constexpr const char* name = "regr_slope";

Float64 get_result() const { return this->get_slope(); }
Float64 get_result() const {
if (this->n < 1 || this->sxx() == 0.0) {
return std::numeric_limits<Float64>::quiet_NaN();
}
return this->sxy() / this->sxx();
}
};

template <PrimitiveType T>
struct RegrInterceptFunc : AggregateFunctionRegrData<T> {
struct RegrInterceptFunc : AggregateFunctionRegrData<T, true, 2, 2> {
static constexpr const char* name = "regr_intercept";

Float64 get_result() const {
auto slope = this->get_slope();
if (std::isnan(slope)) {
return slope;
} else {
Float64 intercept = (this->sum_y - slope * this->sum_x) / (double)this->count;
return intercept;
if (this->n < 1 || this->sxx() == 0.0) {
return std::numeric_limits<Float64>::quiet_NaN();
}
return (this->sy() - this->sx() * this->sxy() / this->sxx()) /
static_cast<Float64>(this->n);
}
};

Expand Down Expand Up @@ -147,7 +297,7 @@ class AggregateFunctionRegrSimple
const XInputCol* x_nested_column = nullptr;

if constexpr (y_nullable) {
const ColumnNullable& y_column_nullable =
const auto& y_column_nullable =
assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*columns[0]);
y_null = y_column_nullable.is_null_at(row_num);
y_nested_column = assert_cast<const YInputCol*, TypeCheckOnRelease::DISABLE>(
Expand All @@ -158,7 +308,7 @@ class AggregateFunctionRegrSimple
}

if constexpr (x_nullable) {
const ColumnNullable& x_column_nullable =
const auto& x_column_nullable =
assert_cast<const ColumnNullable&, TypeCheckOnRelease::DISABLE>(*columns[1]);
x_null = x_column_nullable.is_null_at(row_num);
x_nested_column = assert_cast<const XInputCol*, TypeCheckOnRelease::DISABLE>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
-990000.0

-- !regr_intercept_int --
1000001.0
-9.999E9

-- !regr_intercept_bigint --
\N

-- !regr_intercept_largeint --
9.999999999999989E19
1.0E20

-- !regr_intercept_float --
13.241664047161644
13.24166404716167

-- !regr_intercept_double --
58.05515207632899
58.05515207633332

Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
1.0

-- !regr_slope_int --
-0.0
1.0

-- !regr_slope_bigint --
\N

-- !regr_slope_largeint --
17725.127617654194
0.0

-- !regr_slope_float --
-2.79289213515492
-2.792892135154929

-- !regr_slope_double --
-0.5501239199999569
-0.5501239199999999

Loading
Loading