From 2e1375c6bf6668ac7b0bce2e3e8c0898621df5a8 Mon Sep 17 00:00:00 2001 From: Jover Zhang Date: Fri, 12 Sep 2025 11:11:53 +0800 Subject: [PATCH 1/3] [fix](regr) Use Youngs-Cramer for REGR_SLOPE/INTERCEPT align with PG --- .../aggregate_function_regr_union.h | 133 +++++++++++------- .../regr_intercept/regr_intercept.out | 8 +- .../support_type/regr_slope/regr_slope.out | 8 +- .../aggregate/test_regr_intercept.groovy | 18 +-- .../query_p0/aggregate/test_regr_slope.groovy | 18 +-- 5 files changed, 112 insertions(+), 73 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_union.h b/be/src/vec/aggregate_functions/aggregate_function_regr_union.h index 1cc30b8c4307e7..685e0e750272f6 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_regr_union.h +++ b/be/src/vec/aggregate_functions/aggregate_function_regr_union.h @@ -35,63 +35,100 @@ namespace doris::vectorized { template struct AggregateFunctionRegrData { static constexpr PrimitiveType Type = T; - UInt64 count = 0; - Float64 sum_x {}; - Float64 sum_y {}; - Float64 sum_of_x_mul_y {}; - Float64 sum_of_x_squared {}; + UInt64 n {}; + Float64 sx {}; + Float64 sy {}; + Float64 sxx {}; + Float64 syy {}; + Float64 sxy {}; void write(BufferWritable& buf) const { - buf.write_binary(sum_x); - buf.write_binary(sum_y); - buf.write_binary(sum_of_x_mul_y); - buf.write_binary(sum_of_x_squared); - buf.write_binary(count); + buf.write_binary(sx); + buf.write_binary(sy); + buf.write_binary(sxx); + buf.write_binary(syy); + buf.write_binary(sxy); + buf.write_binary(n); } void read(BufferReadable& buf) { - buf.read_binary(sum_x); - buf.read_binary(sum_y); - buf.read_binary(sum_of_x_mul_y); - buf.read_binary(sum_of_x_squared); - buf.read_binary(count); + buf.read_binary(sx); + buf.read_binary(sy); + buf.read_binary(sxx); + buf.read_binary(syy); + buf.read_binary(sxy); + buf.read_binary(n); } void reset() { - sum_x = {}; - sum_y = {}; - sum_of_x_mul_y = {}; - sum_of_x_squared = {}; - count = 0; + sx = {}; + sy = {}; + sxx = {}; + syy = {}; + sxy = {}; + n = {}; } + /** + * The merge function uses the Youngs–Cramer algorithm: + * N = N1 + N2 + * Sx = Sx1 + Sx2 + * Sy = Sy1 + Sy2 + * Sxx = Sxx1 + Sxx2 + N1 * N2 * (Sx1/N1 - Sx2/N2)^2 / N + * Syy = Syy1 + Syy2 + N1 * N2 * (Sy1/N1 - Sy2/N2)^2 / N + * Sxy = Sxy1 + Sxy2 + N1 * N2 * (Sx1/N1 - Sx2/N2) * (Sy1/N1 - Sy2/N2) / N + */ void merge(const AggregateFunctionRegrData& rhs) { - if (rhs.count == 0) { + if (rhs.n == 0) { return; } - sum_x += rhs.sum_x; - sum_y += rhs.sum_y; - sum_of_x_mul_y += rhs.sum_of_x_mul_y; - sum_of_x_squared += rhs.sum_of_x_squared; - count += rhs.count; + if (n == 0) { + *this = rhs; + return; + } + const auto n1 = static_cast(n); + const auto n2 = static_cast(rhs.n); + const auto nsum = n1 + n2; + + const auto dx = sx / n1 - rhs.sx / n2; + const auto dy = sy / n1 - rhs.sy / n2; + + n += rhs.n; + sx += rhs.sx; + sy += rhs.sy; + sxx += rhs.sxx + n1 * n2 * dx * dx / nsum; + syy += rhs.syy + n1 * n2 * dy * dy / nsum; + sxy += rhs.sxy + n1 * n2 * dx * dy / nsum; } + /** + * N + * Sx = sum(X) + * Sy = sum(Y) + * Sxx = sum((X-Sx/N)^2) + * Syy = sum((Y-Sy/N)^2) + * Sxy = sum((X-Sx/N)*(Y-Sy/N)) + */ void add(typename PrimitiveTypeTraits::ColumnItemType value_y, typename PrimitiveTypeTraits::ColumnItemType value_x) { - sum_x += (double)value_x; - sum_y += (double)value_y; - sum_of_x_mul_y += (double)value_x * (double)value_y; - sum_of_x_squared += (double)value_x * (double)value_x; - count += 1; - } + const auto x = static_cast(value_x); + const auto y = static_cast(value_y); + sx += x; + sy += y; - Float64 get_slope() const { - Float64 denominator = (double)count * sum_of_x_squared - sum_x * sum_x; - if (count < 2 || denominator == 0.0) { - return std::numeric_limits::quiet_NaN(); + if (n == 0) [[unlikely]] { + n = 1; + return; } - Float64 slope = ((double)count * sum_of_x_mul_y - sum_x * sum_y) / denominator; - return slope; + const auto tmp_n = static_cast(n + 1); + const auto tmp_x = x * tmp_n - sx; + const auto tmp_y = y * tmp_n - sy; + const auto scale = 1.0 / (tmp_n * static_cast(n)); + + n += 1; + sxx += tmp_x * tmp_x * scale; + syy += tmp_y * tmp_y * scale; + sxy += tmp_x * tmp_y * scale; } }; @@ -99,7 +136,12 @@ template struct RegrSlopeFunc : AggregateFunctionRegrData { static constexpr const char* name = "regr_slope"; - Float64 get_result() const { return this->get_slope(); } + Float64 get_result() const { + if (this->n < 1 || this->sxx == 0.0) { + return std::numeric_limits::quiet_NaN(); + } + return this->sxy / this->sxx; + } }; template @@ -107,13 +149,10 @@ struct RegrInterceptFunc : AggregateFunctionRegrData { static constexpr const char* name = "regr_intercept"; Float64 get_result() const { - auto slope = this->get_slope(); - if (std::isnan(slope)) { - return slope; - } else { - Float64 intercept = (this->sum_y - slope * this->sum_x) / (double)this->count; - return intercept; + if (this->n < 1 || this->sxx == 0.0) { + return std::numeric_limits::quiet_NaN(); } + return (this->sy - this->sx * this->sxy / this->sxx) / static_cast(this->n); } }; @@ -147,7 +186,7 @@ class AggregateFunctionRegrSimple const XInputCol* x_nested_column = nullptr; if constexpr (y_nullable) { - const ColumnNullable& y_column_nullable = + const auto& y_column_nullable = assert_cast(*columns[0]); y_null = y_column_nullable.is_null_at(row_num); y_nested_column = assert_cast( @@ -158,7 +197,7 @@ class AggregateFunctionRegrSimple } if constexpr (x_nullable) { - const ColumnNullable& x_column_nullable = + const auto& x_column_nullable = assert_cast(*columns[1]); x_null = x_column_nullable.is_null_at(row_num); x_nested_column = assert_cast( diff --git a/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out b/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out index 88a91371f5f820..f58aaf4a55f698 100644 --- a/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out +++ b/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out @@ -6,17 +6,17 @@ -990000.0 -- !regr_intercept_int -- -1000001.0 +-9.999E9 -- !regr_intercept_bigint -- \N -- !regr_intercept_largeint -- -9.999999999999989E19 +1.0E20 -- !regr_intercept_float -- -13.241664047161644 +13.24166404716167 -- !regr_intercept_double -- -58.05515207632899 +58.05515207633332 diff --git a/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out b/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out index 77140f0d1d3871..0e9d13ae71d07a 100644 --- a/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out +++ b/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out @@ -6,17 +6,17 @@ 1.0 -- !regr_slope_int -- --0.0 +1.0 -- !regr_slope_bigint -- \N -- !regr_slope_largeint -- -17725.127617654194 +0.0 -- !regr_slope_float -- --2.79289213515492 +-2.792892135154929 -- !regr_slope_double -- --0.5501239199999569 +-0.5501239199999999 diff --git a/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy b/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy index f7c44642427c3c..10683585309a33 100644 --- a/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy +++ b/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy @@ -51,9 +51,9 @@ suite("test_regr_intercept") { // no value // agg function without group by should return null - qt_sql_empty_1 "select regr_intercept(y,x) from test_regr_intercept_int" + qt_sql_empty_1 "select regr_intercept(y, x) from test_regr_intercept_int" // agg function with group by should return empty set - qt_sql_empty_2 "select regr_intercept(y,x) from test_regr_intercept_int group by id" + qt_sql_empty_2 "select regr_intercept(y, x) from test_regr_intercept_int group by id" sql """ TRUNCATE TABLE test_regr_intercept_int """ @@ -83,7 +83,7 @@ suite("test_regr_intercept") { qt_sql_int_2 "select regr_intercept(x, 4) from test_regr_intercept_int" // int value - qt_sql_int_3 "select regr_intercept(y,x) from test_regr_intercept_int" + qt_sql_int_3 "select regr_intercept(y, x) from test_regr_intercept_int" // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column qt_sql_int_4 "select regr_intercept(non_nullable(y), non_nullable(x)) from test_regr_intercept_int" @@ -122,8 +122,8 @@ suite("test_regr_intercept") { qt_sql_int_7 "select regr_intercept(x, 4) from test_regr_intercept_int" // int value - qt_sql_int_8 "select regr_intercept(y,x) from test_regr_intercept_int" - qt_sql_int_8 "select regr_intercept(y,x) from test_regr_intercept_int group by id order by id" + qt_sql_int_8 "select regr_intercept(y, x) from test_regr_intercept_int" + qt_sql_int_8 "select regr_intercept(y, x) from test_regr_intercept_int group by id order by id" // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column qt_sql_int_9 "select regr_intercept(non_nullable(y), non_nullable(x)) from test_regr_intercept_int where id >= 3" @@ -142,8 +142,8 @@ suite("test_regr_intercept") { qt_sql_double_2 "select regr_intercept(x, 4) from test_regr_intercept_double" // int value - qt_sql_double_3 "select regr_intercept(y,x) from test_regr_intercept_double" - qt_sql_double_3 "select regr_intercept(y,x) from test_regr_intercept_double group by id order by id" + qt_sql_double_3 "select regr_intercept(y, x) from test_regr_intercept_double" + qt_sql_double_3 "select regr_intercept(y, x) from test_regr_intercept_double group by id order by id" // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column qt_sql_double_4 "select regr_intercept(non_nullable(y), non_nullable(x)) from test_regr_intercept_double" @@ -183,8 +183,8 @@ suite("test_regr_intercept") { qt_sql_double_7 "select regr_intercept(x, 4) from test_regr_intercept_double" // int value - qt_sql_double_8 "select regr_intercept(y,x) from test_regr_intercept_double" - qt_sql_double_8 "select regr_intercept(y,x) from test_regr_intercept_double group by id order by id" + qt_sql_double_8 "select regr_intercept(y, x) from test_regr_intercept_double" + qt_sql_double_8 "select regr_intercept(y, x) from test_regr_intercept_double group by id order by id" // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column qt_sql_double_9 "select regr_intercept(non_nullable(y), non_nullable(x)) from test_regr_intercept_double where id >= 3" diff --git a/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy b/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy index 19397036234595..0c6007103678f9 100644 --- a/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy +++ b/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy @@ -51,9 +51,9 @@ suite("test_regr_slope") { // no value // agg function without group by should return null - qt_sql_empty_1 "select regr_slope(y,x) from test_regr_slope_int" + qt_sql_empty_1 "select regr_slope(y, x) from test_regr_slope_int" // agg function with group by should return empty set - qt_sql_empty_2 "select regr_slope(y,x) from test_regr_slope_int group by id" + qt_sql_empty_2 "select regr_slope(y, x) from test_regr_slope_int group by id" sql """ TRUNCATE TABLE test_regr_slope_int """ @@ -83,7 +83,7 @@ suite("test_regr_slope") { qt_sql_int_2 "select regr_slope(x, 4) from test_regr_slope_int" // int value - qt_sql_int_3 "select regr_slope(y,x) from test_regr_slope_int" + qt_sql_int_3 "select regr_slope(y, x) from test_regr_slope_int" // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column qt_sql_int_4 "select regr_slope(non_nullable(y), non_nullable(x)) from test_regr_slope_int" @@ -122,8 +122,8 @@ suite("test_regr_slope") { qt_sql_int_7 "select regr_slope(x, 4) from test_regr_slope_int" // int value - qt_sql_int_8 "select regr_slope(y,x) from test_regr_slope_int" - qt_sql_int_8 "select regr_slope(y,x) from test_regr_slope_int group by id order by id" + qt_sql_int_8 "select regr_slope(y, x) from test_regr_slope_int" + qt_sql_int_8 "select regr_slope(y, x) from test_regr_slope_int group by id order by id" // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column qt_sql_int_9 "select regr_slope(non_nullable(y), non_nullable(x)) from test_regr_slope_int where id >= 3" @@ -142,8 +142,8 @@ suite("test_regr_slope") { qt_sql_double_2 "select regr_slope(x, 4) from test_regr_slope_double" // int value - qt_sql_double_3 "select regr_slope(y,x) from test_regr_slope_double" - qt_sql_double_3 "select regr_slope(y,x) from test_regr_slope_double group by id order by id" + qt_sql_double_3 "select regr_slope(y, x) from test_regr_slope_double" + qt_sql_double_3 "select regr_slope(y, x) from test_regr_slope_double group by id order by id" // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column qt_sql_double_4 "select regr_slope(non_nullable(y), non_nullable(x)) from test_regr_slope_double" @@ -183,8 +183,8 @@ suite("test_regr_slope") { qt_sql_double_7 "select regr_slope(x, 4) from test_regr_slope_double" // int value - qt_sql_double_8 "select regr_slope(y,x) from test_regr_slope_double" - qt_sql_double_8 "select regr_slope(y,x) from test_regr_slope_double group by id order by id" + qt_sql_double_8 "select regr_slope(y, x) from test_regr_slope_double" + qt_sql_double_8 "select regr_slope(y, x) from test_regr_slope_double group by id order by id" // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column qt_sql_double_9 "select regr_slope(non_nullable(y), non_nullable(x)) from test_regr_slope_double where id >= 3" From a938543288f567bd5bf1b6b4f7a546c70360469c Mon Sep 17 00:00:00 2001 From: Jover Zhang Date: Tue, 9 Dec 2025 17:36:07 +0800 Subject: [PATCH 2/3] Refactor regr_* aggregate state to level-based central moments --- .../aggregate_function_regr_union.h | 211 +++++++++++++----- 1 file changed, 161 insertions(+), 50 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_union.h b/be/src/vec/aggregate_functions/aggregate_function_regr_union.h index 685e0e750272f6..fa7732109f646b 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_regr_union.h +++ b/be/src/vec/aggregate_functions/aggregate_function_regr_union.h @@ -32,51 +32,126 @@ namespace doris::vectorized { #include "common/compile_check_begin.h" -template +template struct AggregateFunctionRegrData { static constexpr PrimitiveType Type = T; + + static_assert(!NeedSxy || (SxLevel >= 1 && SyLevel >= 1), + "NeedSxy requires SxLevel >= 1 and SyLevel >= 1"); + static_assert(SxLevel <= 2 && SyLevel <= 2, "Sx/Sy level must be <= 2"); + + static constexpr bool need_sx = SxLevel >= 1; + static constexpr bool need_sy = SyLevel >= 1; + static constexpr bool need_sxx = SxLevel >= 2; + static constexpr bool need_syy = SyLevel >= 2; + static constexpr bool need_sxy = NeedSxy; + + static constexpr size_t kMomentSize = SxLevel + SyLevel + size_t {need_sxy}; + static_assert(kMomentSize > 0 && kMomentSize <= 5, "Unexpected size of regr moment array"); + + /** + * The moments array is: + * Sx = sum(X) + * Sy = sum(Y) + * Sxx = sum((X-Sx/N)^2) + * Syy = sum((Y-Sy/N)^2) + * Sxy = sum((X-Sx/N)*(Y-Sy/N)) + */ + std::array moments {}; UInt64 n {}; - Float64 sx {}; - Float64 sy {}; - Float64 sxx {}; - Float64 syy {}; - Float64 sxy {}; + + static constexpr size_t idx_sx() { + static_assert(need_sx, "sx not enabled"); + return 0; + } + static constexpr size_t idx_sy() { + static_assert(need_sy, "sy not enabled"); + return size_t {need_sx}; + } + static constexpr size_t idx_sxx() { + static_assert(need_sxx, "sxx not enabled"); + return size_t {need_sx + need_sy}; + } + static constexpr size_t idx_syy() { + static_assert(need_syy, "syy not enabled"); + return size_t {need_sx + need_sy + need_sxx}; + } + static constexpr size_t idx_sxy() { + static_assert(need_sxy, "sxy not enabled"); + return size_t {need_sx + need_sy + need_sxx + need_syy}; + } + + Float64& sx() { return moments[idx_sx()]; } + Float64& sy() { return moments[idx_sy()]; } + Float64& sxx() { return moments[idx_sxx()]; } + Float64& syy() { return moments[idx_syy()]; } + Float64& sxy() { return moments[idx_sxy()]; } + + const Float64& sx() const { return moments[idx_sx()]; } + const Float64& sy() const { return moments[idx_sy()]; } + const Float64& sxx() const { return moments[idx_sxx()]; } + const Float64& syy() const { return moments[idx_syy()]; } + const Float64& sxy() const { return moments[idx_sxy()]; } void write(BufferWritable& buf) const { - buf.write_binary(sx); - buf.write_binary(sy); - buf.write_binary(sxx); - buf.write_binary(syy); - buf.write_binary(sxy); + if constexpr (need_sx) { + buf.write_binary(sx()); + } + if constexpr (need_sy) { + buf.write_binary(sy()); + } + if constexpr (need_sxx) { + buf.write_binary(sxx()); + } + if constexpr (need_syy) { + buf.write_binary(syy()); + } + if constexpr (need_sxy) { + buf.write_binary(sxy()); + } buf.write_binary(n); } void read(BufferReadable& buf) { - buf.read_binary(sx); - buf.read_binary(sy); - buf.read_binary(sxx); - buf.read_binary(syy); - buf.read_binary(sxy); + if constexpr (need_sx) { + buf.read_binary(sx()); + } + if constexpr (need_sy) { + buf.read_binary(sy()); + } + if constexpr (need_sxx) { + buf.read_binary(sxx()); + } + if constexpr (need_syy) { + buf.read_binary(syy()); + } + if constexpr (need_sxy) { + buf.read_binary(sxy()); + } buf.read_binary(n); } void reset() { - sx = {}; - sy = {}; - sxx = {}; - syy = {}; - sxy = {}; + moments.fill({}); n = {}; } /** * The merge function uses the Youngs–Cramer algorithm: - * N = N1 + N2 - * Sx = Sx1 + Sx2 - * Sy = Sy1 + Sy2 - * Sxx = Sxx1 + Sxx2 + N1 * N2 * (Sx1/N1 - Sx2/N2)^2 / N - * Syy = Syy1 + Syy2 + N1 * N2 * (Sy1/N1 - Sy2/N2)^2 / N - * Sxy = Sxy1 + Sxy2 + N1 * N2 * (Sx1/N1 - Sx2/N2) * (Sy1/N1 - Sy2/N2) / N + * N = N1 + N2 + * Sx = Sx1 + Sx2 + * Sy = Sy1 + Sy2 + * Sxx = Sxx1 + Sxx2 + N1 * N2 * (Sx1/N1 - Sx2/N2)^2 / N + * Syy = Syy1 + Syy2 + N1 * N2 * (Sy1/N1 - Sy2/N2)^2 / N + * Sxy = Sxy1 + Sxy2 + N1 * N2 * (Sx1/N1 - Sx2/N2) * (Sy1/N1 - Sy2/N2) / N */ void merge(const AggregateFunctionRegrData& rhs) { if (rhs.n == 0) { @@ -90,15 +165,31 @@ struct AggregateFunctionRegrData { const auto n2 = static_cast(rhs.n); const auto nsum = n1 + n2; - const auto dx = sx / n1 - rhs.sx / n2; - const auto dy = sy / n1 - rhs.sy / n2; + Float64 dx {}; + Float64 dy {}; + if constexpr (need_sxx || need_sxy) { + dx = sx() / n1 - rhs.sx() / n2; + } + if constexpr (need_syy || need_sxy) { + dy = sy() / n1 - rhs.sy() / n2; + } n += rhs.n; - sx += rhs.sx; - sy += rhs.sy; - sxx += rhs.sxx + n1 * n2 * dx * dx / nsum; - syy += rhs.syy + n1 * n2 * dy * dy / nsum; - sxy += rhs.sxy + n1 * n2 * dx * dy / nsum; + if constexpr (need_sx) { + sx() += rhs.sx(); + } + if constexpr (need_sy) { + sy() += rhs.sy(); + } + if constexpr (need_sxx) { + sxx() += rhs.sxx() + n1 * n2 * dx * dx / nsum; + } + if constexpr (need_syy) { + syy() += rhs.syy() + n1 * n2 * dy * dy / nsum; + } + if constexpr (need_sxy) { + sxy() += rhs.sxy() + n1 * n2 * dx * dy / nsum; + } } /** @@ -113,46 +204,66 @@ struct AggregateFunctionRegrData { typename PrimitiveTypeTraits::ColumnItemType value_x) { const auto x = static_cast(value_x); const auto y = static_cast(value_y); - sx += x; - sy += y; + + if constexpr (need_sx) { + sx() += x; + } + if constexpr (need_sy) { + sy() += y; + } if (n == 0) [[unlikely]] { n = 1; return; } - const auto tmp_n = static_cast(n + 1); - const auto tmp_x = x * tmp_n - sx; - const auto tmp_y = y * tmp_n - sy; - const auto scale = 1.0 / (tmp_n * static_cast(n)); - + const auto n_old = static_cast(n); + const auto n_new = n_old + 1; + const auto scale = 1.0 / (n_new * n_old); n += 1; - sxx += tmp_x * tmp_x * scale; - syy += tmp_y * tmp_y * scale; - sxy += tmp_x * tmp_y * scale; + + Float64 tmp_x {}; + Float64 tmp_y {}; + if constexpr (need_sxx || need_sxy) { + tmp_x = x * n_new - sx(); + } + if constexpr (need_syy || need_sxy) { + tmp_y = y * n_new - sy(); + } + + if constexpr (need_sxx) { + sxx() += tmp_x * tmp_x * scale; + } + if constexpr (need_syy) { + syy() += tmp_y * tmp_y * scale; + } + if constexpr (need_sxy) { + sxy() += tmp_x * tmp_y * scale; + } } }; template -struct RegrSlopeFunc : AggregateFunctionRegrData { +struct RegrSlopeFunc : AggregateFunctionRegrData { static constexpr const char* name = "regr_slope"; Float64 get_result() const { - if (this->n < 1 || this->sxx == 0.0) { + if (this->n < 1 || this->sxx() == 0.0) { return std::numeric_limits::quiet_NaN(); } - return this->sxy / this->sxx; + return this->sxy() / this->sxx(); } }; template -struct RegrInterceptFunc : AggregateFunctionRegrData { +struct RegrInterceptFunc : AggregateFunctionRegrData { static constexpr const char* name = "regr_intercept"; Float64 get_result() const { - if (this->n < 1 || this->sxx == 0.0) { + if (this->n < 1 || this->sxx() == 0.0) { return std::numeric_limits::quiet_NaN(); } - return (this->sy - this->sx * this->sxy / this->sxx) / static_cast(this->n); + return (this->sy() - this->sx() * this->sxy() / this->sxx()) / + static_cast(this->n); } }; From 9b7c0e17526c21fcfcb9f4282da2782a2de683c4 Mon Sep 17 00:00:00 2001 From: Jover Zhang Date: Tue, 9 Dec 2025 17:48:52 +0800 Subject: [PATCH 3/3] fix typo --- .../aggregate_function_regr_union.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_union.h b/be/src/vec/aggregate_functions/aggregate_function_regr_union.h index fa7732109f646b..dde9fc5e48fdc3 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_regr_union.h +++ b/be/src/vec/aggregate_functions/aggregate_function_regr_union.h @@ -44,14 +44,14 @@ template = 1 && SyLevel >= 1), - "NeedSxy requires SxLevel >= 1 and SyLevel >= 1"); + static_assert(!NeedSxy || (SxLevel > 0 && SyLevel > 0), + "NeedSxy requires SxLevel > 0 and SyLevel > 0"); static_assert(SxLevel <= 2 && SyLevel <= 2, "Sx/Sy level must be <= 2"); - static constexpr bool need_sx = SxLevel >= 1; - static constexpr bool need_sy = SyLevel >= 1; - static constexpr bool need_sxx = SxLevel >= 2; - static constexpr bool need_syy = SyLevel >= 2; + static constexpr bool need_sx = SxLevel > 0; + static constexpr bool need_sy = SyLevel > 0; + static constexpr bool need_sxx = SxLevel > 1; + static constexpr bool need_syy = SyLevel > 1; static constexpr bool need_sxy = NeedSxy; static constexpr size_t kMomentSize = SxLevel + SyLevel + size_t {need_sxy};