diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_union.h b/be/src/vec/aggregate_functions/aggregate_function_regr_union.h index 1cc30b8c4307e7..dde9fc5e48fdc3 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_regr_union.h +++ b/be/src/vec/aggregate_functions/aggregate_function_regr_union.h @@ -32,88 +32,238 @@ namespace doris::vectorized { #include "common/compile_check_begin.h" -template +template struct AggregateFunctionRegrData { static constexpr PrimitiveType Type = T; - UInt64 count = 0; - Float64 sum_x {}; - Float64 sum_y {}; - Float64 sum_of_x_mul_y {}; - Float64 sum_of_x_squared {}; + + static_assert(!NeedSxy || (SxLevel > 0 && SyLevel > 0), + "NeedSxy requires SxLevel > 0 and SyLevel > 0"); + static_assert(SxLevel <= 2 && SyLevel <= 2, "Sx/Sy level must be <= 2"); + + static constexpr bool need_sx = SxLevel > 0; + static constexpr bool need_sy = SyLevel > 0; + static constexpr bool need_sxx = SxLevel > 1; + static constexpr bool need_syy = SyLevel > 1; + static constexpr bool need_sxy = NeedSxy; + + static constexpr size_t kMomentSize = SxLevel + SyLevel + size_t {need_sxy}; + static_assert(kMomentSize > 0 && kMomentSize <= 5, "Unexpected size of regr moment array"); + + /** + * The moments array is: + * Sx = sum(X) + * Sy = sum(Y) + * Sxx = sum((X-Sx/N)^2) + * Syy = sum((Y-Sy/N)^2) + * Sxy = sum((X-Sx/N)*(Y-Sy/N)) + */ + std::array moments {}; + UInt64 n {}; + + static constexpr size_t idx_sx() { + static_assert(need_sx, "sx not enabled"); + return 0; + } + static constexpr size_t idx_sy() { + static_assert(need_sy, "sy not enabled"); + return size_t {need_sx}; + } + static constexpr size_t idx_sxx() { + static_assert(need_sxx, "sxx not enabled"); + return size_t {need_sx + need_sy}; + } + static constexpr size_t idx_syy() { + static_assert(need_syy, "syy not enabled"); + return size_t {need_sx + need_sy + need_sxx}; + } + static constexpr size_t idx_sxy() { + static_assert(need_sxy, "sxy not enabled"); + return size_t {need_sx + need_sy + need_sxx + need_syy}; + } + + Float64& sx() { return moments[idx_sx()]; } + Float64& sy() { return moments[idx_sy()]; } + Float64& sxx() { return moments[idx_sxx()]; } + Float64& syy() { return moments[idx_syy()]; } + Float64& sxy() { return moments[idx_sxy()]; } + + const Float64& sx() const { return moments[idx_sx()]; } + const Float64& sy() const { return moments[idx_sy()]; } + const Float64& sxx() const { return moments[idx_sxx()]; } + const Float64& syy() const { return moments[idx_syy()]; } + const Float64& sxy() const { return moments[idx_sxy()]; } void write(BufferWritable& buf) const { - buf.write_binary(sum_x); - buf.write_binary(sum_y); - buf.write_binary(sum_of_x_mul_y); - buf.write_binary(sum_of_x_squared); - buf.write_binary(count); + if constexpr (need_sx) { + buf.write_binary(sx()); + } + if constexpr (need_sy) { + buf.write_binary(sy()); + } + if constexpr (need_sxx) { + buf.write_binary(sxx()); + } + if constexpr (need_syy) { + buf.write_binary(syy()); + } + if constexpr (need_sxy) { + buf.write_binary(sxy()); + } + buf.write_binary(n); } void read(BufferReadable& buf) { - buf.read_binary(sum_x); - buf.read_binary(sum_y); - buf.read_binary(sum_of_x_mul_y); - buf.read_binary(sum_of_x_squared); - buf.read_binary(count); + if constexpr (need_sx) { + buf.read_binary(sx()); + } + if constexpr (need_sy) { + buf.read_binary(sy()); + } + if constexpr (need_sxx) { + buf.read_binary(sxx()); + } + if constexpr (need_syy) { + buf.read_binary(syy()); + } + if constexpr (need_sxy) { + buf.read_binary(sxy()); + } + buf.read_binary(n); } void reset() { - sum_x = {}; - sum_y = {}; - sum_of_x_mul_y = {}; - sum_of_x_squared = {}; - count = 0; + moments.fill({}); + n = {}; } + /** + * The merge function uses the Youngs–Cramer algorithm: + * N = N1 + N2 + * Sx = Sx1 + Sx2 + * Sy = Sy1 + Sy2 + * Sxx = Sxx1 + Sxx2 + N1 * N2 * (Sx1/N1 - Sx2/N2)^2 / N + * Syy = Syy1 + Syy2 + N1 * N2 * (Sy1/N1 - Sy2/N2)^2 / N + * Sxy = Sxy1 + Sxy2 + N1 * N2 * (Sx1/N1 - Sx2/N2) * (Sy1/N1 - Sy2/N2) / N + */ void merge(const AggregateFunctionRegrData& rhs) { - if (rhs.count == 0) { + if (rhs.n == 0) { + return; + } + if (n == 0) { + *this = rhs; return; } - sum_x += rhs.sum_x; - sum_y += rhs.sum_y; - sum_of_x_mul_y += rhs.sum_of_x_mul_y; - sum_of_x_squared += rhs.sum_of_x_squared; - count += rhs.count; + const auto n1 = static_cast(n); + const auto n2 = static_cast(rhs.n); + const auto nsum = n1 + n2; + + Float64 dx {}; + Float64 dy {}; + if constexpr (need_sxx || need_sxy) { + dx = sx() / n1 - rhs.sx() / n2; + } + if constexpr (need_syy || need_sxy) { + dy = sy() / n1 - rhs.sy() / n2; + } + + n += rhs.n; + if constexpr (need_sx) { + sx() += rhs.sx(); + } + if constexpr (need_sy) { + sy() += rhs.sy(); + } + if constexpr (need_sxx) { + sxx() += rhs.sxx() + n1 * n2 * dx * dx / nsum; + } + if constexpr (need_syy) { + syy() += rhs.syy() + n1 * n2 * dy * dy / nsum; + } + if constexpr (need_sxy) { + sxy() += rhs.sxy() + n1 * n2 * dx * dy / nsum; + } } + /** + * N + * Sx = sum(X) + * Sy = sum(Y) + * Sxx = sum((X-Sx/N)^2) + * Syy = sum((Y-Sy/N)^2) + * Sxy = sum((X-Sx/N)*(Y-Sy/N)) + */ void add(typename PrimitiveTypeTraits::ColumnItemType value_y, typename PrimitiveTypeTraits::ColumnItemType value_x) { - sum_x += (double)value_x; - sum_y += (double)value_y; - sum_of_x_mul_y += (double)value_x * (double)value_y; - sum_of_x_squared += (double)value_x * (double)value_x; - count += 1; - } + const auto x = static_cast(value_x); + const auto y = static_cast(value_y); - Float64 get_slope() const { - Float64 denominator = (double)count * sum_of_x_squared - sum_x * sum_x; - if (count < 2 || denominator == 0.0) { - return std::numeric_limits::quiet_NaN(); + if constexpr (need_sx) { + sx() += x; + } + if constexpr (need_sy) { + sy() += y; + } + + if (n == 0) [[unlikely]] { + n = 1; + return; + } + const auto n_old = static_cast(n); + const auto n_new = n_old + 1; + const auto scale = 1.0 / (n_new * n_old); + n += 1; + + Float64 tmp_x {}; + Float64 tmp_y {}; + if constexpr (need_sxx || need_sxy) { + tmp_x = x * n_new - sx(); + } + if constexpr (need_syy || need_sxy) { + tmp_y = y * n_new - sy(); + } + + if constexpr (need_sxx) { + sxx() += tmp_x * tmp_x * scale; + } + if constexpr (need_syy) { + syy() += tmp_y * tmp_y * scale; + } + if constexpr (need_sxy) { + sxy() += tmp_x * tmp_y * scale; } - Float64 slope = ((double)count * sum_of_x_mul_y - sum_x * sum_y) / denominator; - return slope; } }; template -struct RegrSlopeFunc : AggregateFunctionRegrData { +struct RegrSlopeFunc : AggregateFunctionRegrData { static constexpr const char* name = "regr_slope"; - Float64 get_result() const { return this->get_slope(); } + Float64 get_result() const { + if (this->n < 1 || this->sxx() == 0.0) { + return std::numeric_limits::quiet_NaN(); + } + return this->sxy() / this->sxx(); + } }; template -struct RegrInterceptFunc : AggregateFunctionRegrData { +struct RegrInterceptFunc : AggregateFunctionRegrData { static constexpr const char* name = "regr_intercept"; Float64 get_result() const { - auto slope = this->get_slope(); - if (std::isnan(slope)) { - return slope; - } else { - Float64 intercept = (this->sum_y - slope * this->sum_x) / (double)this->count; - return intercept; + if (this->n < 1 || this->sxx() == 0.0) { + return std::numeric_limits::quiet_NaN(); } + return (this->sy() - this->sx() * this->sxy() / this->sxx()) / + static_cast(this->n); } }; @@ -147,7 +297,7 @@ class AggregateFunctionRegrSimple const XInputCol* x_nested_column = nullptr; if constexpr (y_nullable) { - const ColumnNullable& y_column_nullable = + const auto& y_column_nullable = assert_cast(*columns[0]); y_null = y_column_nullable.is_null_at(row_num); y_nested_column = assert_cast( @@ -158,7 +308,7 @@ class AggregateFunctionRegrSimple } if constexpr (x_nullable) { - const ColumnNullable& x_column_nullable = + const auto& x_column_nullable = assert_cast(*columns[1]); x_null = x_column_nullable.is_null_at(row_num); x_nested_column = assert_cast( diff --git a/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out b/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out index 88a91371f5f820..f58aaf4a55f698 100644 --- a/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out +++ b/regression-test/data/query_p0/aggregate/support_type/regr_intercept/regr_intercept.out @@ -6,17 +6,17 @@ -990000.0 -- !regr_intercept_int -- -1000001.0 +-9.999E9 -- !regr_intercept_bigint -- \N -- !regr_intercept_largeint -- -9.999999999999989E19 +1.0E20 -- !regr_intercept_float -- -13.241664047161644 +13.24166404716167 -- !regr_intercept_double -- -58.05515207632899 +58.05515207633332 diff --git a/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out b/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out index 77140f0d1d3871..0e9d13ae71d07a 100644 --- a/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out +++ b/regression-test/data/query_p0/aggregate/support_type/regr_slope/regr_slope.out @@ -6,17 +6,17 @@ 1.0 -- !regr_slope_int -- --0.0 +1.0 -- !regr_slope_bigint -- \N -- !regr_slope_largeint -- -17725.127617654194 +0.0 -- !regr_slope_float -- --2.79289213515492 +-2.792892135154929 -- !regr_slope_double -- --0.5501239199999569 +-0.5501239199999999 diff --git a/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy b/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy index f7c44642427c3c..10683585309a33 100644 --- a/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy +++ b/regression-test/suites/query_p0/aggregate/test_regr_intercept.groovy @@ -51,9 +51,9 @@ suite("test_regr_intercept") { // no value // agg function without group by should return null - qt_sql_empty_1 "select regr_intercept(y,x) from test_regr_intercept_int" + qt_sql_empty_1 "select regr_intercept(y, x) from test_regr_intercept_int" // agg function with group by should return empty set - qt_sql_empty_2 "select regr_intercept(y,x) from test_regr_intercept_int group by id" + qt_sql_empty_2 "select regr_intercept(y, x) from test_regr_intercept_int group by id" sql """ TRUNCATE TABLE test_regr_intercept_int """ @@ -83,7 +83,7 @@ suite("test_regr_intercept") { qt_sql_int_2 "select regr_intercept(x, 4) from test_regr_intercept_int" // int value - qt_sql_int_3 "select regr_intercept(y,x) from test_regr_intercept_int" + qt_sql_int_3 "select regr_intercept(y, x) from test_regr_intercept_int" // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column qt_sql_int_4 "select regr_intercept(non_nullable(y), non_nullable(x)) from test_regr_intercept_int" @@ -122,8 +122,8 @@ suite("test_regr_intercept") { qt_sql_int_7 "select regr_intercept(x, 4) from test_regr_intercept_int" // int value - qt_sql_int_8 "select regr_intercept(y,x) from test_regr_intercept_int" - qt_sql_int_8 "select regr_intercept(y,x) from test_regr_intercept_int group by id order by id" + qt_sql_int_8 "select regr_intercept(y, x) from test_regr_intercept_int" + qt_sql_int_8 "select regr_intercept(y, x) from test_regr_intercept_int group by id order by id" // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column qt_sql_int_9 "select regr_intercept(non_nullable(y), non_nullable(x)) from test_regr_intercept_int where id >= 3" @@ -142,8 +142,8 @@ suite("test_regr_intercept") { qt_sql_double_2 "select regr_intercept(x, 4) from test_regr_intercept_double" // int value - qt_sql_double_3 "select regr_intercept(y,x) from test_regr_intercept_double" - qt_sql_double_3 "select regr_intercept(y,x) from test_regr_intercept_double group by id order by id" + qt_sql_double_3 "select regr_intercept(y, x) from test_regr_intercept_double" + qt_sql_double_3 "select regr_intercept(y, x) from test_regr_intercept_double group by id order by id" // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column qt_sql_double_4 "select regr_intercept(non_nullable(y), non_nullable(x)) from test_regr_intercept_double" @@ -183,8 +183,8 @@ suite("test_regr_intercept") { qt_sql_double_7 "select regr_intercept(x, 4) from test_regr_intercept_double" // int value - qt_sql_double_8 "select regr_intercept(y,x) from test_regr_intercept_double" - qt_sql_double_8 "select regr_intercept(y,x) from test_regr_intercept_double group by id order by id" + qt_sql_double_8 "select regr_intercept(y, x) from test_regr_intercept_double" + qt_sql_double_8 "select regr_intercept(y, x) from test_regr_intercept_double group by id order by id" // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column qt_sql_double_9 "select regr_intercept(non_nullable(y), non_nullable(x)) from test_regr_intercept_double where id >= 3" diff --git a/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy b/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy index 19397036234595..0c6007103678f9 100644 --- a/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy +++ b/regression-test/suites/query_p0/aggregate/test_regr_slope.groovy @@ -51,9 +51,9 @@ suite("test_regr_slope") { // no value // agg function without group by should return null - qt_sql_empty_1 "select regr_slope(y,x) from test_regr_slope_int" + qt_sql_empty_1 "select regr_slope(y, x) from test_regr_slope_int" // agg function with group by should return empty set - qt_sql_empty_2 "select regr_slope(y,x) from test_regr_slope_int group by id" + qt_sql_empty_2 "select regr_slope(y, x) from test_regr_slope_int group by id" sql """ TRUNCATE TABLE test_regr_slope_int """ @@ -83,7 +83,7 @@ suite("test_regr_slope") { qt_sql_int_2 "select regr_slope(x, 4) from test_regr_slope_int" // int value - qt_sql_int_3 "select regr_slope(y,x) from test_regr_slope_int" + qt_sql_int_3 "select regr_slope(y, x) from test_regr_slope_int" // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column qt_sql_int_4 "select regr_slope(non_nullable(y), non_nullable(x)) from test_regr_slope_int" @@ -122,8 +122,8 @@ suite("test_regr_slope") { qt_sql_int_7 "select regr_slope(x, 4) from test_regr_slope_int" // int value - qt_sql_int_8 "select regr_slope(y,x) from test_regr_slope_int" - qt_sql_int_8 "select regr_slope(y,x) from test_regr_slope_int group by id order by id" + qt_sql_int_8 "select regr_slope(y, x) from test_regr_slope_int" + qt_sql_int_8 "select regr_slope(y, x) from test_regr_slope_int group by id order by id" // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column qt_sql_int_9 "select regr_slope(non_nullable(y), non_nullable(x)) from test_regr_slope_int where id >= 3" @@ -142,8 +142,8 @@ suite("test_regr_slope") { qt_sql_double_2 "select regr_slope(x, 4) from test_regr_slope_double" // int value - qt_sql_double_3 "select regr_slope(y,x) from test_regr_slope_double" - qt_sql_double_3 "select regr_slope(y,x) from test_regr_slope_double group by id order by id" + qt_sql_double_3 "select regr_slope(y, x) from test_regr_slope_double" + qt_sql_double_3 "select regr_slope(y, x) from test_regr_slope_double group by id order by id" // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column qt_sql_double_4 "select regr_slope(non_nullable(y), non_nullable(x)) from test_regr_slope_double" @@ -183,8 +183,8 @@ suite("test_regr_slope") { qt_sql_double_7 "select regr_slope(x, 4) from test_regr_slope_double" // int value - qt_sql_double_8 "select regr_slope(y,x) from test_regr_slope_double" - qt_sql_double_8 "select regr_slope(y,x) from test_regr_slope_double group by id order by id" + qt_sql_double_8 "select regr_slope(y, x) from test_regr_slope_double" + qt_sql_double_8 "select regr_slope(y, x) from test_regr_slope_double group by id order by id" // qt_sql_int_3 tests Nullable input column, qt_sql_int_4 test non-Nullable input column qt_sql_double_9 "select regr_slope(non_nullable(y), non_nullable(x)) from test_regr_slope_double where id >= 3"